summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGELOG.rst131
-rw-r--r--MANIFEST.in7
-rw-r--r--PKG-INFO12
-rw-r--r--README.rst7
-rwxr-xr-xbuild-deb.sh12
-rw-r--r--doc/source/Tutorials/array_widget.rst4
-rw-r--r--doc/source/Tutorials/fit.rst6
-rw-r--r--doc/source/Tutorials/fitconfig.rst4
-rw-r--r--doc/source/Tutorials/writing_NXdata.rst3
-rw-r--r--doc/source/applications/view.rst4
-rw-r--r--doc/source/ext/snapshotqt_directive.py2
-rw-r--r--doc/source/index.rst11
-rw-r--r--doc/source/install.rst23
-rw-r--r--doc/source/modules/gui/icons.rst388
-rw-r--r--doc/source/modules/gui/plot/getting_started.rst6
-rw-r--r--doc/source/modules/gui/widgets/printpreview.rst4
-rw-r--r--doc/source/modules/io/fioh5.rst35
-rw-r--r--doc/source/modules/io/index.rst1
-rw-r--r--doc/source/virtualenv.rst2
-rw-r--r--examples/colormapDialog.py8
-rw-r--r--examples/compareImages.py4
-rw-r--r--examples/compositeline.py2
-rw-r--r--examples/customDataView.py2
-rw-r--r--examples/customHdf5TreeModel.py2
-rw-r--r--examples/dropZones.py4
-rw-r--r--examples/exampleBaseline.py4
-rwxr-xr-xexamples/fftPlotAction.py2
-rw-r--r--examples/fileDialog.py12
-rw-r--r--examples/findContours.py4
-rwxr-xr-xexamples/hdf5widget.py11
-rw-r--r--examples/icons.py4
-rw-r--r--examples/imageStack.py4
-rwxr-xr-xexamples/imageview.py87
-rw-r--r--examples/periodicTable.py5
-rw-r--r--examples/plot3dContextMenu.py6
-rw-r--r--examples/plot3dSceneWindow.py4
-rw-r--r--examples/plot3dUpdateScatterFromThread.py2
-rw-r--r--examples/plotClearAction.py4
-rw-r--r--examples/plotContextMenu.py4
-rw-r--r--examples/plotCurveLegendWidget.py6
-rw-r--r--examples/plotInteractiveImageROI.py4
-rwxr-xr-xexamples/plotItemsSelector.py4
-rw-r--r--examples/plotLimits.py4
-rw-r--r--examples/plotProfile.py2
-rw-r--r--examples/plotROIStats.py6
-rw-r--r--examples/plotStats.py4
-rw-r--r--examples/plotUpdateCurveFromThread.py4
-rw-r--r--examples/plotUpdateImageFromThread.py5
-rw-r--r--examples/plotWidget.py6
-rwxr-xr-xexamples/printPreview.py4
-rw-r--r--examples/scatterMask.py2
-rwxr-xr-xexamples/scatterview.py2
-rwxr-xr-xexamples/shiftPlotAction.py4
-rwxr-xr-xexamples/simplewidget.py2
-rw-r--r--examples/stackView.py2
-rw-r--r--examples/syncPlotLocation.py2
-rw-r--r--examples/syncaxis.py2
-rw-r--r--examples/viewer3DVolume.py2
-rw-r--r--package/debian10/control3
-rw-r--r--package/debian11/control3
-rw-r--r--package/windows/README.rst24
-rw-r--r--package/windows/create-installer.iss.template92
-rw-r--r--package/windows/pyinstaller-silx-view.spec55
-rw-r--r--package/windows/pyinstaller.spec162
-rw-r--r--pyproject.toml1
-rw-r--r--qtdesigner_plugins/plot1dplugin.py6
-rw-r--r--qtdesigner_plugins/plot2dplugin.py6
-rw-r--r--qtdesigner_plugins/plotwidgetplugin.py6
-rw-r--r--qtdesigner_plugins/plotwindowplugin.py6
-rw-r--r--requirements-dev.txt8
-rw-r--r--requirements.txt5
-rwxr-xr-xrun_tests.py350
-rw-r--r--setup.py51
-rw-r--r--silx.egg-info/PKG-INFO12
-rw-r--r--silx.egg-info/SOURCES.txt2230
-rw-r--r--silx.egg-info/requires.txt6
-rw-r--r--silx/__init__.py61
-rw-r--r--silx/app/convert.py525
-rw-r--r--silx/app/test/__init__.py39
-rw-r--r--silx/app/test/test_convert.py167
-rw-r--r--silx/app/test_.py159
-rw-r--r--silx/app/view/About.py257
-rw-r--r--silx/app/view/ApplicationContext.py194
-rw-r--r--silx/app/view/CustomNxdataWidget.py1008
-rw-r--r--silx/app/view/Viewer.py971
-rw-r--r--silx/app/view/main.py171
-rw-r--r--silx/app/view/test/__init__.py41
-rw-r--r--silx/app/view/test/test_launcher.py151
-rw-r--r--silx/app/view/test/test_view.py394
-rw-r--r--silx/gui/_glutils/FramebufferTexture.py165
-rw-r--r--silx/gui/_glutils/OpenGLWidget.py423
-rw-r--r--silx/gui/_glutils/font.py163
-rw-r--r--silx/gui/_glutils/utils.py121
-rwxr-xr-xsilx/gui/colors.py1326
-rw-r--r--silx/gui/console.py202
-rw-r--r--silx/gui/data/ArrayTableModel.py670
-rw-r--r--silx/gui/data/ArrayTableWidget.py492
-rw-r--r--silx/gui/data/Hdf5TableView.py646
-rw-r--r--silx/gui/data/HexaTableView.py286
-rw-r--r--silx/gui/data/NXdataWidgets.py1081
-rw-r--r--silx/gui/data/RecordTableView.py447
-rw-r--r--silx/gui/data/TextFormatter.py395
-rw-r--r--silx/gui/data/test/__init__.py45
-rw-r--r--silx/gui/data/test/test_arraywidget.py329
-rw-r--r--silx/gui/data/test/test_dataviewer.py314
-rw-r--r--silx/gui/data/test/test_numpyaxesselector.py161
-rw-r--r--silx/gui/data/test/test_textformatter.py212
-rw-r--r--silx/gui/dialog/AbstractDataFileDialog.py1742
-rw-r--r--silx/gui/dialog/ColormapDialog.py1771
-rw-r--r--silx/gui/dialog/DataFileDialog.py340
-rw-r--r--silx/gui/dialog/DatasetDialog.py122
-rw-r--r--silx/gui/dialog/GroupDialog.py230
-rw-r--r--silx/gui/dialog/ImageFileDialog.py354
-rw-r--r--silx/gui/dialog/SafeFileSystemModel.py804
-rw-r--r--silx/gui/dialog/test/__init__.py49
-rw-r--r--silx/gui/dialog/test/test_colormapdialog.py453
-rw-r--r--silx/gui/dialog/test/test_datafiledialog.py939
-rw-r--r--silx/gui/dialog/test/test_imagefiledialog.py784
-rw-r--r--silx/gui/dialog/utils.py106
-rw-r--r--silx/gui/fit/BackgroundWidget.py534
-rw-r--r--silx/gui/fit/FitConfig.py543
-rw-r--r--silx/gui/fit/FitWidget.py739
-rw-r--r--silx/gui/fit/FitWidgets.py559
-rw-r--r--silx/gui/fit/Parameters.py882
-rw-r--r--silx/gui/fit/test/__init__.py43
-rw-r--r--silx/gui/fit/test/testBackgroundWidget.py83
-rw-r--r--silx/gui/fit/test/testFitConfig.py95
-rw-r--r--silx/gui/fit/test/testFitWidget.py135
-rw-r--r--silx/gui/hdf5/Hdf5Formatter.py241
-rw-r--r--silx/gui/hdf5/Hdf5HeaderView.py195
-rw-r--r--silx/gui/hdf5/Hdf5TreeModel.py778
-rw-r--r--silx/gui/hdf5/Hdf5TreeView.py271
-rw-r--r--silx/gui/hdf5/_utils.py461
-rw-r--r--silx/gui/hdf5/test/__init__.py39
-rwxr-xr-xsilx/gui/hdf5/test/test_hdf5.py1140
-rw-r--r--silx/gui/plot/AlphaSlider.py300
-rw-r--r--silx/gui/plot/ColorBar.py881
-rw-r--r--silx/gui/plot/CompareImages.py1249
-rw-r--r--silx/gui/plot/ComplexImageView.py518
-rw-r--r--silx/gui/plot/CurvesROIWidget.py1584
-rw-r--r--silx/gui/plot/ImageStack.py636
-rw-r--r--silx/gui/plot/ImageView.py854
-rw-r--r--silx/gui/plot/ItemsSelectionDialog.py286
-rwxr-xr-xsilx/gui/plot/LegendSelector.py1036
-rw-r--r--silx/gui/plot/MaskToolsWidget.py919
-rw-r--r--silx/gui/plot/PlotInteraction.py1748
-rwxr-xr-xsilx/gui/plot/PlotWidget.py3621
-rw-r--r--silx/gui/plot/PlotWindow.py994
-rw-r--r--silx/gui/plot/PrintPreviewToolButton.py392
-rw-r--r--silx/gui/plot/ROIStatsWidget.py780
-rw-r--r--silx/gui/plot/ScatterMaskToolsWidget.py621
-rw-r--r--silx/gui/plot/ScatterView.py405
-rw-r--r--silx/gui/plot/StackView.py1254
-rw-r--r--silx/gui/plot/StatsWidget.py1661
-rw-r--r--silx/gui/plot/_utils/__init__.py93
-rw-r--r--silx/gui/plot/_utils/panzoom.py292
-rw-r--r--silx/gui/plot/_utils/test/__init__.py43
-rw-r--r--silx/gui/plot/_utils/test/test_dtime_ticklayout.py93
-rw-r--r--silx/gui/plot/_utils/test/test_ticklayout.py92
-rw-r--r--silx/gui/plot/actions/fit.py403
-rw-r--r--silx/gui/plot/actions/histogram.py392
-rw-r--r--silx/gui/plot/actions/io.py818
-rwxr-xr-xsilx/gui/plot/backends/BackendBase.py578
-rwxr-xr-xsilx/gui/plot/backends/BackendMatplotlib.py1544
-rwxr-xr-xsilx/gui/plot/backends/BackendOpenGL.py1420
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotCurve.py1375
-rw-r--r--silx/gui/plot/backends/glutils/GLPlotFrame.py1219
-rw-r--r--silx/gui/plot/items/__init__.py52
-rw-r--r--silx/gui/plot/items/axis.py569
-rw-r--r--silx/gui/plot/items/core.py1734
-rw-r--r--silx/gui/plot/items/curve.py326
-rw-r--r--silx/gui/plot/items/image.py617
-rw-r--r--silx/gui/plot/items/scatter.py973
-rw-r--r--silx/gui/plot/items/shape.py288
-rw-r--r--silx/gui/plot/test/__init__.py92
-rw-r--r--silx/gui/plot/test/testAlphaSlider.py218
-rw-r--r--silx/gui/plot/test/testColorBar.py354
-rw-r--r--silx/gui/plot/test/testCompareImages.py117
-rw-r--r--silx/gui/plot/test/testComplexImageView.py95
-rw-r--r--silx/gui/plot/test/testCurvesROIWidget.py469
-rw-r--r--silx/gui/plot/test/testImageStack.py197
-rw-r--r--silx/gui/plot/test/testImageView.py136
-rw-r--r--silx/gui/plot/test/testInteraction.py89
-rw-r--r--silx/gui/plot/test/testItem.py340
-rw-r--r--silx/gui/plot/test/testLegendSelector.py142
-rw-r--r--silx/gui/plot/test/testLimitConstraints.py125
-rw-r--r--silx/gui/plot/test/testMaskToolsWidget.py316
-rw-r--r--silx/gui/plot/test/testPixelIntensityHistoAction.py157
-rw-r--r--silx/gui/plot/test/testPlotInteraction.py172
-rwxr-xr-xsilx/gui/plot/test/testPlotWidget.py2072
-rw-r--r--silx/gui/plot/test/testPlotWidgetNoBackend.py631
-rw-r--r--silx/gui/plot/test/testPlotWindow.py185
-rw-r--r--silx/gui/plot/test/testRoiStatsWidget.py290
-rw-r--r--silx/gui/plot/test/testSaveAction.py143
-rw-r--r--silx/gui/plot/test/testScatterMaskToolsWidget.py318
-rw-r--r--silx/gui/plot/test/testScatterView.py134
-rw-r--r--silx/gui/plot/test/testStackView.py261
-rw-r--r--silx/gui/plot/test/testStats.py1058
-rw-r--r--silx/gui/plot/test/testUtilsAxis.py214
-rw-r--r--silx/gui/plot/test/utils.py94
-rw-r--r--silx/gui/plot/tools/PositionInfo.py376
-rw-r--r--silx/gui/plot/tools/profile/manager.py1076
-rw-r--r--silx/gui/plot/tools/profile/rois.py1156
-rw-r--r--silx/gui/plot/tools/roi.py1417
-rw-r--r--silx/gui/plot/tools/test/__init__.py52
-rw-r--r--silx/gui/plot/tools/test/testCurveLegendsWidget.py125
-rw-r--r--silx/gui/plot/tools/test/testProfile.py673
-rw-r--r--silx/gui/plot/tools/test/testROI.py694
-rw-r--r--silx/gui/plot/tools/test/testScatterProfileToolBar.py196
-rw-r--r--silx/gui/plot/tools/test/testTools.py147
-rw-r--r--silx/gui/plot/utils/axis.py403
-rw-r--r--silx/gui/plot3d/ParamTreeView.py546
-rw-r--r--silx/gui/plot3d/Plot3DWidget.py460
-rw-r--r--silx/gui/plot3d/SFViewParamTree.py1817
-rw-r--r--silx/gui/plot3d/_model/items.py1760
-rw-r--r--silx/gui/plot3d/actions/io.py336
-rw-r--r--silx/gui/plot3d/actions/mode.py178
-rw-r--r--silx/gui/plot3d/items/core.py779
-rw-r--r--silx/gui/plot3d/items/image.py425
-rw-r--r--silx/gui/plot3d/scene/test/__init__.py43
-rw-r--r--silx/gui/plot3d/scene/test/test_transform.py91
-rw-r--r--silx/gui/plot3d/scene/test/test_utils.py275
-rw-r--r--silx/gui/plot3d/scene/window.py430
-rw-r--r--silx/gui/plot3d/test/__init__.py75
-rw-r--r--silx/gui/plot3d/test/testGL.py84
-rw-r--r--silx/gui/plot3d/test/testScalarFieldView.py139
-rw-r--r--silx/gui/plot3d/test/testSceneWidget.py84
-rw-r--r--silx/gui/plot3d/test/testSceneWidgetPicking.py326
-rw-r--r--silx/gui/plot3d/test/testSceneWindow.py245
-rw-r--r--silx/gui/plot3d/test/testStatsWidget.py216
-rw-r--r--silx/gui/plot3d/tools/GroupPropertiesWidget.py202
-rw-r--r--silx/gui/plot3d/tools/PositionInfoWidget.py219
-rw-r--r--silx/gui/plot3d/tools/test/__init__.py41
-rw-r--r--silx/gui/plot3d/tools/test/testPositionInfoWidget.py101
-rw-r--r--silx/gui/qt/__init__.py60
-rw-r--r--silx/gui/qt/_macosx.py68
-rw-r--r--silx/gui/qt/_pyside_dynamic.py239
-rw-r--r--silx/gui/qt/_pyside_missing.py274
-rw-r--r--silx/gui/qt/_qt.py289
-rw-r--r--silx/gui/qt/_utils.py71
-rw-r--r--silx/gui/qt/inspect.py87
-rw-r--r--silx/gui/test/__init__.py113
-rwxr-xr-xsilx/gui/test/test_colors.py619
-rw-r--r--silx/gui/test/test_console.py91
-rw-r--r--silx/gui/test/test_icons.py158
-rw-r--r--silx/gui/test/test_qt.py201
-rw-r--r--silx/gui/utils/glutils/__init__.py199
-rw-r--r--silx/gui/utils/image.py143
-rw-r--r--silx/gui/utils/matplotlib.py71
-rwxr-xr-xsilx/gui/utils/test/__init__.py56
-rw-r--r--silx/gui/utils/test/test.py76
-rw-r--r--silx/gui/utils/test/test_async.py138
-rw-r--r--silx/gui/utils/test/test_glutils.py66
-rw-r--r--silx/gui/utils/test/test_image.py90
-rwxr-xr-xsilx/gui/utils/test/test_qtutils.py75
-rw-r--r--silx/gui/utils/test/test_testutils.py55
-rw-r--r--silx/gui/utils/testutils.py518
-rw-r--r--silx/gui/widgets/ElidedLabel.py137
-rw-r--r--silx/gui/widgets/FloatEdit.py65
-rw-r--r--silx/gui/widgets/PeriodicTable.py831
-rw-r--r--silx/gui/widgets/PrintGeometryDialog.py222
-rw-r--r--silx/gui/widgets/PrintPreview.py728
-rw-r--r--silx/gui/widgets/RangeSlider.py765
-rw-r--r--silx/gui/widgets/TableWidget.py626
-rw-r--r--silx/gui/widgets/UrlSelectionTable.py172
-rw-r--r--silx/gui/widgets/WaitingPushButton.py245
-rw-r--r--silx/gui/widgets/test/__init__.py59
-rw-r--r--silx/gui/widgets/test/test_boxlayoutdockwidget.py83
-rw-r--r--silx/gui/widgets/test/test_elidedlabel.py111
-rw-r--r--silx/gui/widgets/test/test_flowlayout.py77
-rw-r--r--silx/gui/widgets/test/test_framebrowser.py73
-rw-r--r--silx/gui/widgets/test/test_hierarchicaltableview.py117
-rw-r--r--silx/gui/widgets/test/test_legendiconwidget.py74
-rw-r--r--silx/gui/widgets/test/test_periodictable.py163
-rw-r--r--silx/gui/widgets/test/test_printpreview.py74
-rw-r--r--silx/gui/widgets/test/test_rangeslider.py114
-rw-r--r--silx/gui/widgets/test/test_tablewidget.py61
-rw-r--r--silx/gui/widgets/test/test_threadpoolpushbutton.py135
-rw-r--r--silx/image/marchingsquares/setup.py51
-rw-r--r--silx/image/marchingsquares/test/__init__.py40
-rw-r--r--silx/image/marchingsquares/test/test_funcapi.py99
-rw-r--r--silx/image/marchingsquares/test/test_mergeimpl.py272
-rw-r--r--silx/image/test/__init__.py48
-rw-r--r--silx/image/test/test_bb.py86
-rw-r--r--silx/image/test/test_bilinear.py178
-rw-r--r--silx/image/test/test_medianfilter.py76
-rw-r--r--silx/image/test/test_shapes.py366
-rw-r--r--silx/image/test/test_tomography.py66
-rw-r--r--silx/io/commonh5.py1083
-rw-r--r--silx/io/convert.py343
-rw-r--r--silx/io/dictdump.py842
-rwxr-xr-xsilx/io/fabioh5.py1051
-rw-r--r--silx/io/h5py_utils.py317
-rw-r--r--silx/io/nxdata/_utils.py184
-rw-r--r--silx/io/nxdata/parse.py997
-rw-r--r--silx/io/nxdata/write.py203
-rw-r--r--silx/io/specfile.pyx1268
-rw-r--r--silx/io/spech5.py883
-rw-r--r--silx/io/test/__init__.py61
-rw-r--r--silx/io/test/test_commonh5.py295
-rw-r--r--silx/io/test/test_dictdump.py1025
-rwxr-xr-xsilx/io/test/test_fabioh5.py629
-rw-r--r--silx/io/test/test_h5py_utils.py397
-rw-r--r--silx/io/test/test_nxdata.py579
-rw-r--r--silx/io/test/test_octaveh5.py165
-rw-r--r--silx/io/test/test_rawh5.py96
-rw-r--r--silx/io/test/test_specfile.py433
-rw-r--r--silx/io/test/test_specfilewrapper.py206
-rw-r--r--silx/io/test/test_spech5.py881
-rw-r--r--silx/io/test/test_spectoh5.py194
-rw-r--r--silx/io/test/test_url.py228
-rw-r--r--silx/io/test/test_utils.py888
-rw-r--r--silx/io/url.py390
-rw-r--r--silx/io/utils.py1142
-rw-r--r--silx/math/colormap.pyx559
-rw-r--r--silx/math/fft/test/__init__.py25
-rw-r--r--silx/math/fft/test/test_fft.py270
-rw-r--r--silx/math/fit/fitmanager.py1087
-rw-r--r--silx/math/fit/fittheories.py1374
-rw-r--r--silx/math/fit/test/__init__.py46
-rw-r--r--silx/math/fit/test/test_bgtheories.py169
-rw-r--r--silx/math/fit/test/test_filters.py137
-rw-r--r--silx/math/fit/test/test_fit.py387
-rw-r--r--silx/math/fit/test/test_fitmanager.py513
-rw-r--r--silx/math/fit/test/test_functions.py272
-rw-r--r--silx/math/fit/test/test_peaks.py146
-rw-r--r--silx/math/medianfilter/test/__init__.py36
-rw-r--r--silx/math/medianfilter/test/benchmark.py122
-rw-r--r--silx/math/medianfilter/test/test_medianfilter.py740
-rw-r--r--silx/math/setup.py99
-rw-r--r--silx/math/test/__init__.py58
-rw-r--r--silx/math/test/benchmark_combo.py203
-rw-r--r--silx/math/test/test_HistogramndLut_nominal.py587
-rw-r--r--silx/math/test/test_calibration.py158
-rw-r--r--silx/math/test/test_colormap.py266
-rw-r--r--silx/math/test/test_combo.py218
-rw-r--r--silx/math/test/test_histogramnd_error.py535
-rw-r--r--silx/math/test/test_histogramnd_nominal.py949
-rw-r--r--silx/math/test/test_histogramnd_vs_np.py848
-rw-r--r--silx/math/test/test_interpolate.py136
-rw-r--r--silx/math/test/test_marchingcubes.py188
-rw-r--r--silx/opencl/codec/test/__init__.py37
-rw-r--r--silx/opencl/codec/test/test_byte_offset.py315
-rw-r--r--silx/opencl/common.py689
-rw-r--r--silx/opencl/test/__init__.py68
-rw-r--r--silx/opencl/test/test_addition.py154
-rw-r--r--silx/opencl/test/test_array_utils.py161
-rw-r--r--silx/opencl/test/test_backprojection.py231
-rw-r--r--silx/opencl/test/test_convolution.py265
-rw-r--r--silx/opencl/test/test_doubleword.py258
-rw-r--r--silx/opencl/test/test_image.py137
-rw-r--r--silx/opencl/test/test_kahan.py269
-rw-r--r--silx/opencl/test/test_linalg.py216
-rw-r--r--silx/opencl/test/test_medfilt.py175
-rw-r--r--silx/opencl/test/test_projection.py131
-rw-r--r--silx/opencl/test/test_sparse.py203
-rw-r--r--silx/opencl/test/test_stats.py116
-rw-r--r--silx/resources/gui/icons/compare-align-auto.svg4
-rw-r--r--silx/resources/gui/icons/compare-align-center.svg4
-rw-r--r--silx/resources/gui/icons/compare-align-origin.svg4
-rw-r--r--silx/resources/gui/icons/compare-align-stretch.svg4
-rw-r--r--silx/resources/gui/icons/math-peak-search.svg2
-rw-r--r--silx/resources/gui/icons/remove.svg2
-rw-r--r--silx/resources/gui/icons/zoom-back.svg2
-rw-r--r--silx/resources/gui/icons/zoom-in.svg2
-rw-r--r--silx/resources/gui/icons/zoom-original.svg2
-rw-r--r--silx/resources/gui/icons/zoom-out.svg2
-rw-r--r--silx/resources/gui/icons/zoom.svg2
-rw-r--r--silx/setup.py54
-rw-r--r--silx/sx/_plot.py623
-rw-r--r--silx/test/__init__.py104
-rw-r--r--silx/test/test_resources.py200
-rw-r--r--silx/test/test_sx.py292
-rw-r--r--silx/test/test_version.py49
-rw-r--r--silx/test/utils.py204
-rw-r--r--silx/third_party/setup.py49
-rw-r--r--silx/utils/ExternalResources.py320
-rw-r--r--silx/utils/_have_openmp.pxd49
-rw-r--r--silx/utils/array_like.py596
-rw-r--r--silx/utils/debug.py103
-rw-r--r--silx/utils/html.py60
-rw-r--r--silx/utils/proxy.py241
-rwxr-xr-xsilx/utils/test/__init__.py59
-rw-r--r--silx/utils/test/test_array_like.py445
-rw-r--r--silx/utils/test/test_debug.py99
-rw-r--r--silx/utils/test/test_deprecation.py107
-rw-r--r--silx/utils/test/test_enum.py96
-rw-r--r--silx/utils/test/test_external_resources.py99
-rw-r--r--silx/utils/test/test_html.py61
-rw-r--r--silx/utils/test/test_launcher.py204
-rw-r--r--silx/utils/test/test_number.py186
-rw-r--r--silx/utils/test/test_proxy.py344
-rw-r--r--silx/utils/test/test_retry.py179
-rwxr-xr-xsilx/utils/test/test_testutils.py105
-rw-r--r--silx/utils/test/test_weakref.py330
-rwxr-xr-xsilx/utils/testutils.py333
-rw-r--r--src/silx/__init__.py58
-rw-r--r--src/silx/__main__.py (renamed from silx/__main__.py)0
-rw-r--r--src/silx/_config.py (renamed from silx/_config.py)0
-rw-r--r--src/silx/_version.py120
-rw-r--r--src/silx/app/__init__.py (renamed from silx/app/__init__.py)0
-rw-r--r--src/silx/app/convert.py548
-rw-r--r--src/silx/app/setup.py (renamed from silx/app/setup.py)0
-rw-r--r--src/silx/app/test/__init__.py24
-rw-r--r--src/silx/app/test/test_convert.py156
-rw-r--r--src/silx/app/test_.py45
-rw-r--r--src/silx/app/view/About.py258
-rw-r--r--src/silx/app/view/ApplicationContext.py195
-rw-r--r--src/silx/app/view/CustomNxdataWidget.py1002
-rw-r--r--src/silx/app/view/DataPanel.py (renamed from silx/app/view/DataPanel.py)0
-rw-r--r--src/silx/app/view/Viewer.py962
-rw-r--r--src/silx/app/view/__init__.py (renamed from silx/app/view/__init__.py)0
-rw-r--r--src/silx/app/view/main.py186
-rw-r--r--src/silx/app/view/setup.py (renamed from silx/app/view/setup.py)0
-rw-r--r--src/silx/app/view/test/__init__.py24
-rw-r--r--src/silx/app/view/test/test_launcher.py140
-rw-r--r--src/silx/app/view/test/test_view.py388
-rw-r--r--src/silx/app/view/utils.py (renamed from silx/app/view/utils.py)0
-rw-r--r--src/silx/conftest.py130
-rw-r--r--src/silx/gui/__init__.py (renamed from silx/gui/__init__.py)0
-rw-r--r--src/silx/gui/_glutils/Context.py (renamed from silx/gui/_glutils/Context.py)0
-rw-r--r--src/silx/gui/_glutils/FramebufferTexture.py168
-rw-r--r--src/silx/gui/_glutils/OpenGLWidget.py422
-rw-r--r--src/silx/gui/_glutils/Program.py (renamed from silx/gui/_glutils/Program.py)0
-rw-r--r--src/silx/gui/_glutils/Texture.py (renamed from silx/gui/_glutils/Texture.py)0
-rw-r--r--src/silx/gui/_glutils/VertexBuffer.py (renamed from silx/gui/_glutils/VertexBuffer.py)0
-rw-r--r--src/silx/gui/_glutils/__init__.py (renamed from silx/gui/_glutils/__init__.py)0
-rw-r--r--src/silx/gui/_glutils/font.py156
-rw-r--r--src/silx/gui/_glutils/gl.py (renamed from silx/gui/_glutils/gl.py)0
-rw-r--r--src/silx/gui/_glutils/utils.py123
-rwxr-xr-xsrc/silx/gui/colors.py1036
-rw-r--r--src/silx/gui/conftest.py5
-rw-r--r--src/silx/gui/console.py202
-rw-r--r--src/silx/gui/data/ArrayTableModel.py650
-rw-r--r--src/silx/gui/data/ArrayTableWidget.py492
-rw-r--r--src/silx/gui/data/DataViewer.py (renamed from silx/gui/data/DataViewer.py)0
-rw-r--r--src/silx/gui/data/DataViewerFrame.py (renamed from silx/gui/data/DataViewerFrame.py)0
-rw-r--r--src/silx/gui/data/DataViewerSelector.py (renamed from silx/gui/data/DataViewerSelector.py)0
-rw-r--r--src/silx/gui/data/DataViews.py (renamed from silx/gui/data/DataViews.py)0
-rw-r--r--src/silx/gui/data/Hdf5TableView.py634
-rw-r--r--src/silx/gui/data/HexaTableView.py272
-rw-r--r--src/silx/gui/data/NXdataWidgets.py1086
-rw-r--r--src/silx/gui/data/NumpyAxesSelector.py (renamed from silx/gui/data/NumpyAxesSelector.py)0
-rw-r--r--src/silx/gui/data/RecordTableView.py439
-rw-r--r--src/silx/gui/data/TextFormatter.py386
-rw-r--r--src/silx/gui/data/_RecordPlot.py (renamed from silx/gui/data/_RecordPlot.py)0
-rw-r--r--src/silx/gui/data/_VolumeWindow.py (renamed from silx/gui/data/_VolumeWindow.py)0
-rw-r--r--src/silx/gui/data/__init__.py (renamed from silx/gui/data/__init__.py)0
-rw-r--r--src/silx/gui/data/setup.py (renamed from silx/gui/data/setup.py)0
-rw-r--r--src/silx/gui/data/test/__init__.py24
-rw-r--r--src/silx/gui/data/test/test_arraywidget.py316
-rw-r--r--src/silx/gui/data/test/test_dataviewer.py304
-rw-r--r--src/silx/gui/data/test/test_numpyaxesselector.py150
-rw-r--r--src/silx/gui/data/test/test_textformatter.py199
-rw-r--r--src/silx/gui/dialog/AbstractDataFileDialog.py1731
-rw-r--r--src/silx/gui/dialog/ColormapDialog.py1775
-rw-r--r--src/silx/gui/dialog/DataFileDialog.py340
-rw-r--r--src/silx/gui/dialog/DatasetDialog.py122
-rw-r--r--src/silx/gui/dialog/FileTypeComboBox.py (renamed from silx/gui/dialog/FileTypeComboBox.py)0
-rw-r--r--src/silx/gui/dialog/GroupDialog.py230
-rw-r--r--src/silx/gui/dialog/ImageFileDialog.py354
-rw-r--r--src/silx/gui/dialog/SafeFileIconProvider.py (renamed from silx/gui/dialog/SafeFileIconProvider.py)0
-rw-r--r--src/silx/gui/dialog/SafeFileSystemModel.py802
-rw-r--r--src/silx/gui/dialog/__init__.py (renamed from silx/gui/dialog/__init__.py)0
-rw-r--r--src/silx/gui/dialog/setup.py (renamed from silx/gui/dialog/setup.py)0
-rw-r--r--src/silx/gui/dialog/test/__init__.py24
-rw-r--r--src/silx/gui/dialog/test/test_colormapdialog.py395
-rw-r--r--src/silx/gui/dialog/test/test_datafiledialog.py924
-rw-r--r--src/silx/gui/dialog/test/test_imagefiledialog.py772
-rw-r--r--src/silx/gui/dialog/utils.py99
-rw-r--r--src/silx/gui/fit/BackgroundWidget.py534
-rw-r--r--src/silx/gui/fit/FitConfig.py543
-rw-r--r--src/silx/gui/fit/FitWidget.py751
-rw-r--r--src/silx/gui/fit/FitWidgets.py555
-rw-r--r--src/silx/gui/fit/Parameters.py882
-rw-r--r--src/silx/gui/fit/__init__.py (renamed from silx/gui/fit/__init__.py)0
-rw-r--r--src/silx/gui/fit/setup.py (renamed from silx/gui/fit/setup.py)0
-rw-r--r--src/silx/gui/fit/test/__init__.py24
-rw-r--r--src/silx/gui/fit/test/testBackgroundWidget.py72
-rw-r--r--src/silx/gui/fit/test/testFitConfig.py84
-rw-r--r--src/silx/gui/fit/test/testFitWidget.py124
-rw-r--r--src/silx/gui/hdf5/Hdf5Formatter.py240
-rw-r--r--src/silx/gui/hdf5/Hdf5HeaderView.py184
-rwxr-xr-xsrc/silx/gui/hdf5/Hdf5Item.py (renamed from silx/gui/hdf5/Hdf5Item.py)0
-rw-r--r--src/silx/gui/hdf5/Hdf5LoadingItem.py (renamed from silx/gui/hdf5/Hdf5LoadingItem.py)0
-rw-r--r--src/silx/gui/hdf5/Hdf5Node.py (renamed from silx/gui/hdf5/Hdf5Node.py)0
-rw-r--r--src/silx/gui/hdf5/Hdf5TreeModel.py742
-rw-r--r--src/silx/gui/hdf5/Hdf5TreeView.py269
-rw-r--r--src/silx/gui/hdf5/NexusSortFilterProxyModel.py (renamed from silx/gui/hdf5/NexusSortFilterProxyModel.py)0
-rw-r--r--src/silx/gui/hdf5/__init__.py (renamed from silx/gui/hdf5/__init__.py)0
-rw-r--r--src/silx/gui/hdf5/_utils.py461
-rw-r--r--src/silx/gui/hdf5/setup.py (renamed from silx/gui/hdf5/setup.py)0
-rw-r--r--src/silx/gui/hdf5/test/__init__.py24
-rwxr-xr-xsrc/silx/gui/hdf5/test/test_hdf5.py1092
-rw-r--r--src/silx/gui/icons.py (renamed from silx/gui/icons.py)0
-rw-r--r--src/silx/gui/plot/AlphaSlider.py300
-rw-r--r--src/silx/gui/plot/ColorBar.py883
-rw-r--r--src/silx/gui/plot/Colormap.py (renamed from silx/gui/plot/Colormap.py)0
-rw-r--r--src/silx/gui/plot/ColormapDialog.py (renamed from silx/gui/plot/ColormapDialog.py)0
-rw-r--r--src/silx/gui/plot/Colors.py (renamed from silx/gui/plot/Colors.py)0
-rw-r--r--src/silx/gui/plot/CompareImages.py1259
-rw-r--r--src/silx/gui/plot/ComplexImageView.py518
-rw-r--r--src/silx/gui/plot/CurvesROIWidget.py1581
-rw-r--r--src/silx/gui/plot/ImageStack.py640
-rw-r--r--src/silx/gui/plot/ImageView.py1057
-rw-r--r--src/silx/gui/plot/Interaction.py (renamed from silx/gui/plot/Interaction.py)0
-rw-r--r--src/silx/gui/plot/ItemsSelectionDialog.py286
-rwxr-xr-xsrc/silx/gui/plot/LegendSelector.py1039
-rw-r--r--src/silx/gui/plot/LimitsHistory.py (renamed from silx/gui/plot/LimitsHistory.py)0
-rw-r--r--src/silx/gui/plot/MaskToolsWidget.py919
-rw-r--r--src/silx/gui/plot/PlotActions.py (renamed from silx/gui/plot/PlotActions.py)0
-rw-r--r--src/silx/gui/plot/PlotEvents.py (renamed from silx/gui/plot/PlotEvents.py)0
-rw-r--r--src/silx/gui/plot/PlotInteraction.py1746
-rw-r--r--src/silx/gui/plot/PlotToolButtons.py (renamed from silx/gui/plot/PlotToolButtons.py)0
-rw-r--r--src/silx/gui/plot/PlotTools.py (renamed from silx/gui/plot/PlotTools.py)0
-rwxr-xr-xsrc/silx/gui/plot/PlotWidget.py3628
-rw-r--r--src/silx/gui/plot/PlotWindow.py993
-rw-r--r--src/silx/gui/plot/PrintPreviewToolButton.py388
-rw-r--r--src/silx/gui/plot/Profile.py (renamed from silx/gui/plot/Profile.py)0
-rw-r--r--src/silx/gui/plot/ProfileMainWindow.py (renamed from silx/gui/plot/ProfileMainWindow.py)0
-rw-r--r--src/silx/gui/plot/ROIStatsWidget.py780
-rw-r--r--src/silx/gui/plot/ScatterMaskToolsWidget.py621
-rw-r--r--src/silx/gui/plot/ScatterView.py404
-rw-r--r--src/silx/gui/plot/StackView.py1254
-rw-r--r--src/silx/gui/plot/StatsWidget.py1658
-rw-r--r--src/silx/gui/plot/_BaseMaskToolsWidget.py (renamed from silx/gui/plot/_BaseMaskToolsWidget.py)0
-rw-r--r--src/silx/gui/plot/__init__.py (renamed from silx/gui/plot/__init__.py)0
-rw-r--r--src/silx/gui/plot/_utils/__init__.py92
-rw-r--r--src/silx/gui/plot/_utils/delaunay.py (renamed from silx/gui/plot/_utils/delaunay.py)0
-rw-r--r--src/silx/gui/plot/_utils/dtime_ticklayout.py (renamed from silx/gui/plot/_utils/dtime_ticklayout.py)0
-rw-r--r--src/silx/gui/plot/_utils/panzoom.py325
-rw-r--r--src/silx/gui/plot/_utils/setup.py (renamed from silx/gui/plot/_utils/setup.py)0
-rw-r--r--src/silx/gui/plot/_utils/test/__init__.py24
-rw-r--r--src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py79
-rw-r--r--src/silx/gui/plot/_utils/test/test_ticklayout.py81
-rw-r--r--src/silx/gui/plot/_utils/ticklayout.py (renamed from silx/gui/plot/_utils/ticklayout.py)0
-rw-r--r--src/silx/gui/plot/actions/PlotAction.py (renamed from silx/gui/plot/actions/PlotAction.py)0
-rw-r--r--src/silx/gui/plot/actions/PlotToolAction.py (renamed from silx/gui/plot/actions/PlotToolAction.py)0
-rw-r--r--src/silx/gui/plot/actions/__init__.py (renamed from silx/gui/plot/actions/__init__.py)0
-rwxr-xr-xsrc/silx/gui/plot/actions/control.py (renamed from silx/gui/plot/actions/control.py)0
-rw-r--r--src/silx/gui/plot/actions/fit.py485
-rw-r--r--src/silx/gui/plot/actions/histogram.py542
-rw-r--r--src/silx/gui/plot/actions/io.py819
-rw-r--r--src/silx/gui/plot/actions/medfilt.py (renamed from silx/gui/plot/actions/medfilt.py)0
-rw-r--r--src/silx/gui/plot/actions/mode.py (renamed from silx/gui/plot/actions/mode.py)0
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendBase.py568
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendMatplotlib.py1557
-rwxr-xr-xsrc/silx/gui/plot/backends/BackendOpenGL.py1420
-rw-r--r--src/silx/gui/plot/backends/__init__.py (renamed from silx/gui/plot/backends/__init__.py)0
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotCurve.py1380
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotFrame.py1210
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotImage.py (renamed from silx/gui/plot/backends/glutils/GLPlotImage.py)0
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotItem.py (renamed from silx/gui/plot/backends/glutils/GLPlotItem.py)0
-rw-r--r--src/silx/gui/plot/backends/glutils/GLPlotTriangles.py (renamed from silx/gui/plot/backends/glutils/GLPlotTriangles.py)0
-rw-r--r--src/silx/gui/plot/backends/glutils/GLSupport.py (renamed from silx/gui/plot/backends/glutils/GLSupport.py)0
-rw-r--r--src/silx/gui/plot/backends/glutils/GLText.py (renamed from silx/gui/plot/backends/glutils/GLText.py)0
-rw-r--r--src/silx/gui/plot/backends/glutils/GLTexture.py (renamed from silx/gui/plot/backends/glutils/GLTexture.py)0
-rw-r--r--src/silx/gui/plot/backends/glutils/PlotImageFile.py (renamed from silx/gui/plot/backends/glutils/PlotImageFile.py)0
-rw-r--r--src/silx/gui/plot/backends/glutils/__init__.py (renamed from silx/gui/plot/backends/glutils/__init__.py)0
-rw-r--r--src/silx/gui/plot/items/__init__.py53
-rw-r--r--src/silx/gui/plot/items/_arc_roi.py (renamed from silx/gui/plot/items/_arc_roi.py)0
-rw-r--r--src/silx/gui/plot/items/_pick.py (renamed from silx/gui/plot/items/_pick.py)0
-rw-r--r--src/silx/gui/plot/items/_roi_base.py (renamed from silx/gui/plot/items/_roi_base.py)0
-rw-r--r--src/silx/gui/plot/items/axis.py560
-rw-r--r--src/silx/gui/plot/items/complex.py (renamed from silx/gui/plot/items/complex.py)0
-rw-r--r--src/silx/gui/plot/items/core.py1733
-rw-r--r--src/silx/gui/plot/items/curve.py325
-rw-r--r--src/silx/gui/plot/items/histogram.py (renamed from silx/gui/plot/items/histogram.py)0
-rw-r--r--src/silx/gui/plot/items/image.py641
-rw-r--r--src/silx/gui/plot/items/image_aggregated.py229
-rwxr-xr-xsrc/silx/gui/plot/items/marker.py (renamed from silx/gui/plot/items/marker.py)0
-rw-r--r--src/silx/gui/plot/items/roi.py (renamed from silx/gui/plot/items/roi.py)0
-rw-r--r--src/silx/gui/plot/items/scatter.py1002
-rw-r--r--src/silx/gui/plot/items/shape.py287
-rw-r--r--src/silx/gui/plot/matplotlib/Colormap.py (renamed from silx/gui/plot/matplotlib/Colormap.py)0
-rw-r--r--src/silx/gui/plot/matplotlib/__init__.py (renamed from silx/gui/plot/matplotlib/__init__.py)0
-rw-r--r--src/silx/gui/plot/setup.py (renamed from silx/gui/plot/setup.py)0
-rw-r--r--src/silx/gui/plot/stats/__init__.py (renamed from silx/gui/plot/stats/__init__.py)0
-rw-r--r--src/silx/gui/plot/stats/stats.py (renamed from silx/gui/plot/stats/stats.py)0
-rw-r--r--src/silx/gui/plot/stats/statshandler.py (renamed from silx/gui/plot/stats/statshandler.py)0
-rw-r--r--src/silx/gui/plot/test/__init__.py24
-rw-r--r--src/silx/gui/plot/test/testAlphaSlider.py204
-rw-r--r--src/silx/gui/plot/test/testColorBar.py340
-rw-r--r--src/silx/gui/plot/test/testCompareImages.py106
-rw-r--r--src/silx/gui/plot/test/testComplexImageView.py84
-rw-r--r--src/silx/gui/plot/test/testCurvesROIWidget.py465
-rw-r--r--src/silx/gui/plot/test/testImageStack.py186
-rw-r--r--src/silx/gui/plot/test/testImageView.py194
-rw-r--r--src/silx/gui/plot/test/testInteraction.py78
-rw-r--r--src/silx/gui/plot/test/testItem.py360
-rw-r--r--src/silx/gui/plot/test/testLegendSelector.py130
-rw-r--r--src/silx/gui/plot/test/testLimitConstraints.py114
-rw-r--r--src/silx/gui/plot/test/testMaskToolsWidget.py306
-rw-r--r--src/silx/gui/plot/test/testPixelIntensityHistoAction.py145
-rw-r--r--src/silx/gui/plot/test/testPlotActions.py110
-rw-r--r--src/silx/gui/plot/test/testPlotInteraction.py160
-rwxr-xr-xsrc/silx/gui/plot/test/testPlotWidget.py2113
-rw-r--r--src/silx/gui/plot/test/testPlotWidgetNoBackend.py618
-rw-r--r--src/silx/gui/plot/test/testPlotWindow.py174
-rw-r--r--src/silx/gui/plot/test/testRoiStatsWidget.py277
-rw-r--r--src/silx/gui/plot/test/testSaveAction.py132
-rw-r--r--src/silx/gui/plot/test/testScatterMaskToolsWidget.py306
-rw-r--r--src/silx/gui/plot/test/testScatterView.py123
-rw-r--r--src/silx/gui/plot/test/testStackView.py248
-rw-r--r--src/silx/gui/plot/test/testStats.py1047
-rw-r--r--src/silx/gui/plot/test/testUtilsAxis.py203
-rw-r--r--src/silx/gui/plot/test/utils.py93
-rw-r--r--src/silx/gui/plot/tools/CurveLegendsWidget.py (renamed from silx/gui/plot/tools/CurveLegendsWidget.py)0
-rw-r--r--src/silx/gui/plot/tools/LimitsToolBar.py (renamed from silx/gui/plot/tools/LimitsToolBar.py)0
-rw-r--r--src/silx/gui/plot/tools/PositionInfo.py373
-rw-r--r--src/silx/gui/plot/tools/RadarView.py (renamed from silx/gui/plot/tools/RadarView.py)0
-rw-r--r--src/silx/gui/plot/tools/__init__.py (renamed from silx/gui/plot/tools/__init__.py)0
-rw-r--r--src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py (renamed from silx/gui/plot/tools/profile/ScatterProfileToolBar.py)0
-rw-r--r--src/silx/gui/plot/tools/profile/__init__.py (renamed from silx/gui/plot/tools/profile/__init__.py)0
-rw-r--r--src/silx/gui/plot/tools/profile/core.py (renamed from silx/gui/plot/tools/profile/core.py)0
-rw-r--r--src/silx/gui/plot/tools/profile/editors.py (renamed from silx/gui/plot/tools/profile/editors.py)0
-rw-r--r--src/silx/gui/plot/tools/profile/manager.py1079
-rw-r--r--src/silx/gui/plot/tools/profile/rois.py1156
-rw-r--r--src/silx/gui/plot/tools/profile/toolbar.py (renamed from silx/gui/plot/tools/profile/toolbar.py)0
-rw-r--r--src/silx/gui/plot/tools/roi.py1417
-rw-r--r--src/silx/gui/plot/tools/test/__init__.py24
-rw-r--r--src/silx/gui/plot/tools/test/testCurveLegendsWidget.py113
-rw-r--r--src/silx/gui/plot/tools/test/testProfile.py654
-rw-r--r--src/silx/gui/plot/tools/test/testROI.py682
-rw-r--r--src/silx/gui/plot/tools/test/testScatterProfileToolBar.py184
-rw-r--r--src/silx/gui/plot/tools/test/testTools.py135
-rw-r--r--src/silx/gui/plot/tools/toolbars.py (renamed from silx/gui/plot/tools/toolbars.py)0
-rw-r--r--src/silx/gui/plot/utils/__init__.py (renamed from silx/gui/plot/utils/__init__.py)0
-rw-r--r--src/silx/gui/plot/utils/axis.py398
-rw-r--r--src/silx/gui/plot/utils/intersections.py (renamed from silx/gui/plot/utils/intersections.py)0
-rw-r--r--src/silx/gui/plot3d/ParamTreeView.py522
-rw-r--r--src/silx/gui/plot3d/Plot3DWidget.py463
-rw-r--r--src/silx/gui/plot3d/Plot3DWindow.py (renamed from silx/gui/plot3d/Plot3DWindow.py)0
-rw-r--r--src/silx/gui/plot3d/SFViewParamTree.py1814
-rw-r--r--src/silx/gui/plot3d/ScalarFieldView.py (renamed from silx/gui/plot3d/ScalarFieldView.py)0
-rw-r--r--src/silx/gui/plot3d/SceneWidget.py (renamed from silx/gui/plot3d/SceneWidget.py)0
-rw-r--r--src/silx/gui/plot3d/SceneWindow.py (renamed from silx/gui/plot3d/SceneWindow.py)0
-rw-r--r--src/silx/gui/plot3d/__init__.py (renamed from silx/gui/plot3d/__init__.py)0
-rw-r--r--src/silx/gui/plot3d/_model/__init__.py (renamed from silx/gui/plot3d/_model/__init__.py)0
-rw-r--r--src/silx/gui/plot3d/_model/core.py (renamed from silx/gui/plot3d/_model/core.py)0
-rw-r--r--src/silx/gui/plot3d/_model/items.py1759
-rw-r--r--src/silx/gui/plot3d/_model/model.py (renamed from silx/gui/plot3d/_model/model.py)0
-rw-r--r--src/silx/gui/plot3d/actions/Plot3DAction.py (renamed from silx/gui/plot3d/actions/Plot3DAction.py)0
-rw-r--r--src/silx/gui/plot3d/actions/__init__.py (renamed from silx/gui/plot3d/actions/__init__.py)0
-rw-r--r--src/silx/gui/plot3d/actions/io.py337
-rw-r--r--src/silx/gui/plot3d/actions/mode.py178
-rw-r--r--src/silx/gui/plot3d/actions/viewpoint.py (renamed from silx/gui/plot3d/actions/viewpoint.py)0
-rw-r--r--src/silx/gui/plot3d/conftest.py5
-rw-r--r--src/silx/gui/plot3d/items/__init__.py (renamed from silx/gui/plot3d/items/__init__.py)0
-rw-r--r--src/silx/gui/plot3d/items/_pick.py (renamed from silx/gui/plot3d/items/_pick.py)0
-rw-r--r--src/silx/gui/plot3d/items/clipplane.py (renamed from silx/gui/plot3d/items/clipplane.py)0
-rw-r--r--src/silx/gui/plot3d/items/core.py778
-rw-r--r--src/silx/gui/plot3d/items/image.py425
-rw-r--r--src/silx/gui/plot3d/items/mesh.py (renamed from silx/gui/plot3d/items/mesh.py)0
-rw-r--r--src/silx/gui/plot3d/items/mixins.py (renamed from silx/gui/plot3d/items/mixins.py)0
-rw-r--r--src/silx/gui/plot3d/items/scatter.py (renamed from silx/gui/plot3d/items/scatter.py)0
-rw-r--r--src/silx/gui/plot3d/items/volume.py (renamed from silx/gui/plot3d/items/volume.py)0
-rw-r--r--src/silx/gui/plot3d/scene/__init__.py (renamed from silx/gui/plot3d/scene/__init__.py)0
-rw-r--r--src/silx/gui/plot3d/scene/axes.py (renamed from silx/gui/plot3d/scene/axes.py)0
-rw-r--r--src/silx/gui/plot3d/scene/camera.py (renamed from silx/gui/plot3d/scene/camera.py)0
-rw-r--r--src/silx/gui/plot3d/scene/core.py (renamed from silx/gui/plot3d/scene/core.py)0
-rw-r--r--src/silx/gui/plot3d/scene/cutplane.py (renamed from silx/gui/plot3d/scene/cutplane.py)0
-rw-r--r--src/silx/gui/plot3d/scene/event.py (renamed from silx/gui/plot3d/scene/event.py)0
-rw-r--r--src/silx/gui/plot3d/scene/function.py (renamed from silx/gui/plot3d/scene/function.py)0
-rw-r--r--src/silx/gui/plot3d/scene/interaction.py (renamed from silx/gui/plot3d/scene/interaction.py)0
-rw-r--r--src/silx/gui/plot3d/scene/primitives.py (renamed from silx/gui/plot3d/scene/primitives.py)0
-rw-r--r--src/silx/gui/plot3d/scene/test/__init__.py24
-rw-r--r--src/silx/gui/plot3d/scene/test/test_transform.py80
-rw-r--r--src/silx/gui/plot3d/scene/test/test_utils.py258
-rw-r--r--src/silx/gui/plot3d/scene/text.py (renamed from silx/gui/plot3d/scene/text.py)0
-rw-r--r--src/silx/gui/plot3d/scene/transform.py (renamed from silx/gui/plot3d/scene/transform.py)0
-rw-r--r--src/silx/gui/plot3d/scene/utils.py (renamed from silx/gui/plot3d/scene/utils.py)0
-rw-r--r--src/silx/gui/plot3d/scene/viewport.py (renamed from silx/gui/plot3d/scene/viewport.py)0
-rw-r--r--src/silx/gui/plot3d/scene/window.py433
-rw-r--r--src/silx/gui/plot3d/setup.py (renamed from silx/gui/plot3d/setup.py)0
-rw-r--r--src/silx/gui/plot3d/test/__init__.py25
-rw-r--r--src/silx/gui/plot3d/test/testGL.py73
-rw-r--r--src/silx/gui/plot3d/test/testScalarFieldView.py128
-rw-r--r--src/silx/gui/plot3d/test/testSceneWidget.py72
-rw-r--r--src/silx/gui/plot3d/test/testSceneWidgetPicking.py314
-rw-r--r--src/silx/gui/plot3d/test/testSceneWindow.py233
-rw-r--r--src/silx/gui/plot3d/test/testStatsWidget.py201
-rw-r--r--src/silx/gui/plot3d/tools/GroupPropertiesWidget.py202
-rw-r--r--src/silx/gui/plot3d/tools/PositionInfoWidget.py225
-rw-r--r--src/silx/gui/plot3d/tools/ViewpointTools.py (renamed from silx/gui/plot3d/tools/ViewpointTools.py)0
-rw-r--r--src/silx/gui/plot3d/tools/__init__.py (renamed from silx/gui/plot3d/tools/__init__.py)0
-rw-r--r--src/silx/gui/plot3d/tools/test/__init__.py25
-rw-r--r--src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py89
-rw-r--r--src/silx/gui/plot3d/tools/toolbars.py (renamed from silx/gui/plot3d/tools/toolbars.py)0
-rw-r--r--src/silx/gui/plot3d/utils/__init__.py (renamed from silx/gui/plot3d/utils/__init__.py)0
-rw-r--r--src/silx/gui/plot3d/utils/mng.py (renamed from silx/gui/plot3d/utils/mng.py)0
-rw-r--r--src/silx/gui/printer.py (renamed from silx/gui/printer.py)0
-rw-r--r--src/silx/gui/qt/__init__.py54
-rw-r--r--src/silx/gui/qt/_pyside_dynamic.py235
-rw-r--r--src/silx/gui/qt/_qt.py232
-rw-r--r--src/silx/gui/qt/_utils.py68
-rw-r--r--src/silx/gui/qt/inspect.py75
-rw-r--r--src/silx/gui/setup.py (renamed from silx/gui/setup.py)0
-rw-r--r--src/silx/gui/test/__init__.py24
-rwxr-xr-xsrc/silx/gui/test/test_colors.py603
-rw-r--r--src/silx/gui/test/test_console.py75
-rw-r--r--src/silx/gui/test/test_icons.py144
-rw-r--r--src/silx/gui/test/test_qt.py212
-rw-r--r--src/silx/gui/test/utils.py (renamed from silx/gui/test/utils.py)0
-rwxr-xr-xsrc/silx/gui/utils/__init__.py (renamed from silx/gui/utils/__init__.py)0
-rw-r--r--src/silx/gui/utils/concurrent.py (renamed from silx/gui/utils/concurrent.py)0
-rw-r--r--src/silx/gui/utils/glutils/__init__.py199
-rw-r--r--src/silx/gui/utils/image.py143
-rw-r--r--src/silx/gui/utils/matplotlib.py65
-rw-r--r--src/silx/gui/utils/projecturl.py (renamed from silx/gui/utils/projecturl.py)0
-rwxr-xr-xsrc/silx/gui/utils/qtutils.py (renamed from silx/gui/utils/qtutils.py)0
-rw-r--r--src/silx/gui/utils/signal.py (renamed from silx/gui/utils/signal.py)0
-rwxr-xr-xsrc/silx/gui/utils/test/__init__.py25
-rw-r--r--src/silx/gui/utils/test/test.py63
-rw-r--r--src/silx/gui/utils/test/test_async.py127
-rw-r--r--src/silx/gui/utils/test/test_glutils.py55
-rw-r--r--src/silx/gui/utils/test/test_image.py79
-rwxr-xr-xsrc/silx/gui/utils/test/test_qtutils.py65
-rw-r--r--src/silx/gui/utils/test/test_testutils.py44
-rw-r--r--src/silx/gui/utils/testutils.py508
-rw-r--r--src/silx/gui/widgets/BoxLayoutDockWidget.py (renamed from silx/gui/widgets/BoxLayoutDockWidget.py)0
-rw-r--r--src/silx/gui/widgets/ColormapNameComboBox.py (renamed from silx/gui/widgets/ColormapNameComboBox.py)0
-rw-r--r--src/silx/gui/widgets/ElidedLabel.py140
-rw-r--r--src/silx/gui/widgets/FloatEdit.py71
-rw-r--r--src/silx/gui/widgets/FlowLayout.py (renamed from silx/gui/widgets/FlowLayout.py)0
-rw-r--r--src/silx/gui/widgets/FrameBrowser.py (renamed from silx/gui/widgets/FrameBrowser.py)0
-rw-r--r--src/silx/gui/widgets/HierarchicalTableView.py (renamed from silx/gui/widgets/HierarchicalTableView.py)0
-rwxr-xr-xsrc/silx/gui/widgets/LegendIconWidget.py (renamed from silx/gui/widgets/LegendIconWidget.py)0
-rw-r--r--src/silx/gui/widgets/MedianFilterDialog.py (renamed from silx/gui/widgets/MedianFilterDialog.py)0
-rw-r--r--src/silx/gui/widgets/MultiModeAction.py (renamed from silx/gui/widgets/MultiModeAction.py)0
-rw-r--r--src/silx/gui/widgets/PeriodicTable.py831
-rw-r--r--src/silx/gui/widgets/PrintGeometryDialog.py222
-rw-r--r--src/silx/gui/widgets/PrintPreview.py697
-rw-r--r--src/silx/gui/widgets/RangeSlider.py776
-rw-r--r--src/silx/gui/widgets/TableWidget.py626
-rw-r--r--src/silx/gui/widgets/ThreadPoolPushButton.py (renamed from silx/gui/widgets/ThreadPoolPushButton.py)0
-rw-r--r--src/silx/gui/widgets/UrlSelectionTable.py169
-rw-r--r--src/silx/gui/widgets/WaitingPushButton.py241
-rw-r--r--src/silx/gui/widgets/__init__.py (renamed from silx/gui/widgets/__init__.py)0
-rw-r--r--src/silx/gui/widgets/setup.py (renamed from silx/gui/widgets/setup.py)0
-rw-r--r--src/silx/gui/widgets/test/__init__.py24
-rw-r--r--src/silx/gui/widgets/test/test_boxlayoutdockwidget.py72
-rw-r--r--src/silx/gui/widgets/test/test_elidedlabel.py100
-rw-r--r--src/silx/gui/widgets/test/test_flowlayout.py66
-rw-r--r--src/silx/gui/widgets/test/test_framebrowser.py62
-rw-r--r--src/silx/gui/widgets/test/test_hierarchicaltableview.py103
-rw-r--r--src/silx/gui/widgets/test/test_legendiconwidget.py63
-rw-r--r--src/silx/gui/widgets/test/test_periodictable.py148
-rw-r--r--src/silx/gui/widgets/test/test_printpreview.py63
-rw-r--r--src/silx/gui/widgets/test/test_rangeslider.py103
-rw-r--r--src/silx/gui/widgets/test/test_tablewidget.py50
-rw-r--r--src/silx/gui/widgets/test/test_threadpoolpushbutton.py124
-rw-r--r--src/silx/image/__init__.py (renamed from silx/image/__init__.py)0
-rw-r--r--src/silx/image/_boundingbox.py (renamed from silx/image/_boundingbox.py)0
-rw-r--r--src/silx/image/backprojection.py (renamed from silx/image/backprojection.py)0
-rw-r--r--src/silx/image/bilinear.pyx (renamed from silx/image/bilinear.pyx)0
-rw-r--r--src/silx/image/marchingsquares/__init__.py (renamed from silx/image/marchingsquares/__init__.py)0
-rw-r--r--src/silx/image/marchingsquares/_mergeimpl.pyx (renamed from silx/image/marchingsquares/_mergeimpl.pyx)0
-rw-r--r--src/silx/image/marchingsquares/_skimage.py (renamed from silx/image/marchingsquares/_skimage.py)0
-rw-r--r--src/silx/image/marchingsquares/include/patterns.h (renamed from silx/image/marchingsquares/include/patterns.h)0
-rw-r--r--src/silx/image/marchingsquares/setup.py51
-rw-r--r--src/silx/image/marchingsquares/test/__init__.py24
-rw-r--r--src/silx/image/marchingsquares/test/test_funcapi.py92
-rw-r--r--src/silx/image/marchingsquares/test/test_mergeimpl.py264
-rw-r--r--src/silx/image/medianfilter.py (renamed from silx/image/medianfilter.py)0
-rw-r--r--src/silx/image/phantomgenerator.py (renamed from silx/image/phantomgenerator.py)0
-rw-r--r--src/silx/image/projection.py (renamed from silx/image/projection.py)0
-rw-r--r--src/silx/image/reconstruction.py (renamed from silx/image/reconstruction.py)0
-rw-r--r--src/silx/image/setup.py (renamed from silx/image/setup.py)0
-rw-r--r--src/silx/image/shapes.pyx (renamed from silx/image/shapes.pyx)0
-rw-r--r--src/silx/image/sift.py (renamed from silx/image/sift.py)0
-rw-r--r--src/silx/image/test/__init__.py24
-rw-r--r--src/silx/image/test/test_bb.py74
-rw-r--r--src/silx/image/test/test_bilinear.py167
-rw-r--r--src/silx/image/test/test_medianfilter.py64
-rw-r--r--src/silx/image/test/test_shapes.py354
-rw-r--r--src/silx/image/test/test_tomography.py54
-rw-r--r--src/silx/image/tomography.py (renamed from silx/image/tomography.py)0
-rw-r--r--src/silx/image/utils.py (renamed from silx/image/utils.py)0
-rw-r--r--src/silx/io/__init__.py (renamed from silx/io/__init__.py)0
-rw-r--r--src/silx/io/commonh5.py1061
-rw-r--r--src/silx/io/configdict.py (renamed from silx/io/configdict.py)0
-rw-r--r--src/silx/io/convert.py335
-rw-r--r--src/silx/io/dictdump.py843
-rwxr-xr-xsrc/silx/io/fabioh5.py1050
-rw-r--r--src/silx/io/fioh5.py490
-rw-r--r--src/silx/io/h5py_utils.py440
-rw-r--r--src/silx/io/nxdata/__init__.py (renamed from silx/io/nxdata/__init__.py)0
-rw-r--r--src/silx/io/nxdata/_utils.py183
-rw-r--r--src/silx/io/nxdata/parse.py1004
-rw-r--r--src/silx/io/nxdata/write.py202
-rw-r--r--src/silx/io/octaveh5.py (renamed from silx/io/octaveh5.py)0
-rw-r--r--src/silx/io/rawh5.py (renamed from silx/io/rawh5.py)0
-rw-r--r--src/silx/io/setup.py (renamed from silx/io/setup.py)0
-rw-r--r--src/silx/io/specfile.pyx1268
-rw-r--r--src/silx/io/specfile/include/Lists.h (renamed from silx/io/specfile/include/Lists.h)0
-rw-r--r--src/silx/io/specfile/include/SpecFile.h (renamed from silx/io/specfile/include/SpecFile.h)0
-rw-r--r--src/silx/io/specfile/include/SpecFileCython.h (renamed from silx/io/specfile/include/SpecFileCython.h)0
-rw-r--r--src/silx/io/specfile/include/SpecFileP.h (renamed from silx/io/specfile/include/SpecFileP.h)0
-rw-r--r--src/silx/io/specfile/include/locale_management.h (renamed from silx/io/specfile/include/locale_management.h)0
-rw-r--r--src/silx/io/specfile/src/locale_management.c (renamed from silx/io/specfile/src/locale_management.c)0
-rw-r--r--src/silx/io/specfile/src/sfdata.c (renamed from silx/io/specfile/src/sfdata.c)0
-rw-r--r--src/silx/io/specfile/src/sfheader.c (renamed from silx/io/specfile/src/sfheader.c)0
-rw-r--r--src/silx/io/specfile/src/sfindex.c (renamed from silx/io/specfile/src/sfindex.c)0
-rw-r--r--src/silx/io/specfile/src/sfinit.c (renamed from silx/io/specfile/src/sfinit.c)0
-rw-r--r--src/silx/io/specfile/src/sflabel.c (renamed from silx/io/specfile/src/sflabel.c)0
-rw-r--r--src/silx/io/specfile/src/sflists.c (renamed from silx/io/specfile/src/sflists.c)0
-rw-r--r--src/silx/io/specfile/src/sfmca.c (renamed from silx/io/specfile/src/sfmca.c)0
-rw-r--r--src/silx/io/specfile/src/sftools.c (renamed from silx/io/specfile/src/sftools.c)0
-rw-r--r--src/silx/io/specfile/src/sfwrite.c (renamed from silx/io/specfile/src/sfwrite.c)0
-rw-r--r--src/silx/io/specfile_wrapper.pxd (renamed from silx/io/specfile_wrapper.pxd)0
-rw-r--r--src/silx/io/specfilewrapper.py (renamed from silx/io/specfilewrapper.py)0
-rw-r--r--src/silx/io/spech5.py907
-rw-r--r--src/silx/io/spectoh5.py (renamed from silx/io/spectoh5.py)0
-rw-r--r--src/silx/io/test/__init__.py23
-rw-r--r--src/silx/io/test/test_commonh5.py285
-rw-r--r--src/silx/io/test/test_dictdump.py1009
-rwxr-xr-xsrc/silx/io/test/test_fabioh5.py615
-rw-r--r--src/silx/io/test/test_fioh5.py299
-rw-r--r--src/silx/io/test/test_h5py_utils.py451
-rw-r--r--src/silx/io/test/test_nxdata.py563
-rw-r--r--src/silx/io/test/test_octaveh5.py156
-rw-r--r--src/silx/io/test/test_rawh5.py85
-rw-r--r--src/silx/io/test/test_specfile.py420
-rw-r--r--src/silx/io/test/test_specfilewrapper.py195
-rw-r--r--src/silx/io/test/test_spech5.py929
-rw-r--r--src/silx/io/test/test_spectoh5.py183
-rw-r--r--src/silx/io/test/test_url.py217
-rw-r--r--src/silx/io/test/test_utils.py923
-rw-r--r--src/silx/io/test/test_write_to_h5.py118
-rw-r--r--src/silx/io/url.py388
-rw-r--r--src/silx/io/utils.py1185
-rw-r--r--src/silx/math/__init__.py (renamed from silx/math/__init__.py)0
-rw-r--r--src/silx/math/_colormap.pyx571
-rw-r--r--src/silx/math/calibration.py (renamed from silx/math/calibration.py)0
-rw-r--r--src/silx/math/chistogramnd.pyx (renamed from silx/math/chistogramnd.pyx)0
-rw-r--r--src/silx/math/chistogramnd_lut.pyx (renamed from silx/math/chistogramnd_lut.pyx)0
-rw-r--r--src/silx/math/colormap.py450
-rw-r--r--src/silx/math/combo.pyx (renamed from silx/math/combo.pyx)0
-rw-r--r--src/silx/math/fft/__init__.py (renamed from silx/math/fft/__init__.py)0
-rw-r--r--src/silx/math/fft/basefft.py (renamed from silx/math/fft/basefft.py)0
-rw-r--r--src/silx/math/fft/clfft.py (renamed from silx/math/fft/clfft.py)0
-rw-r--r--src/silx/math/fft/cufft.py (renamed from silx/math/fft/cufft.py)0
-rw-r--r--src/silx/math/fft/fft.py (renamed from silx/math/fft/fft.py)0
-rw-r--r--src/silx/math/fft/fftw.py (renamed from silx/math/fft/fftw.py)0
-rw-r--r--src/silx/math/fft/npfft.py (renamed from silx/math/fft/npfft.py)0
-rw-r--r--src/silx/math/fft/setup.py (renamed from silx/math/fft/setup.py)0
-rw-r--r--src/silx/math/fft/test/__init__.py23
-rw-r--r--src/silx/math/fft/test/test_fft.py257
-rw-r--r--src/silx/math/fit/__init__.py (renamed from silx/math/fit/__init__.py)0
-rw-r--r--src/silx/math/fit/bgtheories.py (renamed from silx/math/fit/bgtheories.py)0
-rw-r--r--src/silx/math/fit/filters.pyx (renamed from silx/math/fit/filters.pyx)0
-rw-r--r--src/silx/math/fit/filters/include/filters.h (renamed from silx/math/fit/filters/include/filters.h)0
-rw-r--r--src/silx/math/fit/filters/src/smoothnd.c (renamed from silx/math/fit/filters/src/smoothnd.c)0
-rw-r--r--src/silx/math/fit/filters/src/snip1d.c (renamed from silx/math/fit/filters/src/snip1d.c)0
-rw-r--r--src/silx/math/fit/filters/src/snip2d.c (renamed from silx/math/fit/filters/src/snip2d.c)0
-rw-r--r--src/silx/math/fit/filters/src/snip3d.c (renamed from silx/math/fit/filters/src/snip3d.c)0
-rw-r--r--src/silx/math/fit/filters/src/strip.c (renamed from silx/math/fit/filters/src/strip.c)0
-rw-r--r--src/silx/math/fit/filters_wrapper.pxd (renamed from silx/math/fit/filters_wrapper.pxd)0
-rw-r--r--src/silx/math/fit/fitmanager.py1087
-rw-r--r--src/silx/math/fit/fittheories.py1374
-rw-r--r--src/silx/math/fit/fittheory.py (renamed from silx/math/fit/fittheory.py)0
-rw-r--r--src/silx/math/fit/functions.pyx (renamed from silx/math/fit/functions.pyx)0
-rw-r--r--src/silx/math/fit/functions/include/functions.h (renamed from silx/math/fit/functions/include/functions.h)0
-rw-r--r--src/silx/math/fit/functions/src/funs.c (renamed from silx/math/fit/functions/src/funs.c)0
-rw-r--r--src/silx/math/fit/functions_wrapper.pxd (renamed from silx/math/fit/functions_wrapper.pxd)0
-rw-r--r--src/silx/math/fit/leastsq.py (renamed from silx/math/fit/leastsq.py)0
-rw-r--r--src/silx/math/fit/peaks.pyx (renamed from silx/math/fit/peaks.pyx)0
-rw-r--r--src/silx/math/fit/peaks/include/peaks.h (renamed from silx/math/fit/peaks/include/peaks.h)0
-rw-r--r--src/silx/math/fit/peaks/src/peaks.c (renamed from silx/math/fit/peaks/src/peaks.c)0
-rw-r--r--src/silx/math/fit/peaks_wrapper.pxd (renamed from silx/math/fit/peaks_wrapper.pxd)0
-rw-r--r--src/silx/math/fit/setup.py (renamed from silx/math/fit/setup.py)0
-rw-r--r--src/silx/math/fit/test/__init__.py23
-rw-r--r--src/silx/math/fit/test/test_bgtheories.py154
-rw-r--r--src/silx/math/fit/test/test_filters.py122
-rw-r--r--src/silx/math/fit/test/test_fit.py373
-rw-r--r--src/silx/math/fit/test/test_fitmanager.py498
-rw-r--r--src/silx/math/fit/test/test_functions.py259
-rw-r--r--src/silx/math/fit/test/test_peaks.py132
-rw-r--r--src/silx/math/histogram.py (renamed from silx/math/histogram.py)0
-rw-r--r--src/silx/math/histogramnd/include/histogramnd_c.h (renamed from silx/math/histogramnd/include/histogramnd_c.h)0
-rw-r--r--src/silx/math/histogramnd/include/msvc/stdint.h (renamed from silx/math/histogramnd/include/msvc/stdint.h)0
-rw-r--r--src/silx/math/histogramnd/include/templates.h (renamed from silx/math/histogramnd/include/templates.h)0
-rw-r--r--src/silx/math/histogramnd/src/histogramnd_c.c (renamed from silx/math/histogramnd/src/histogramnd_c.c)0
-rw-r--r--src/silx/math/histogramnd/src/histogramnd_template.c (renamed from silx/math/histogramnd/src/histogramnd_template.c)0
-rw-r--r--src/silx/math/histogramnd_c.pxd (renamed from silx/math/histogramnd_c.pxd)0
-rw-r--r--src/silx/math/include/math_compatibility.h (renamed from silx/math/include/math_compatibility.h)0
-rw-r--r--src/silx/math/interpolate.pyx (renamed from silx/math/interpolate.pyx)0
-rw-r--r--src/silx/math/marchingcubes.pyx (renamed from silx/math/marchingcubes.pyx)0
-rw-r--r--src/silx/math/marchingcubes/mc.hpp (renamed from silx/math/marchingcubes/mc.hpp)0
-rw-r--r--src/silx/math/marchingcubes/mc_lut.cpp (renamed from silx/math/marchingcubes/mc_lut.cpp)0
-rw-r--r--src/silx/math/math_compatibility.pxd (renamed from silx/math/math_compatibility.pxd)0
-rw-r--r--src/silx/math/mc.pxd (renamed from silx/math/mc.pxd)0
-rw-r--r--src/silx/math/medianfilter/__init__.py (renamed from silx/math/medianfilter/__init__.py)0
-rw-r--r--src/silx/math/medianfilter/include/median_filter.hpp (renamed from silx/math/medianfilter/include/median_filter.hpp)0
-rw-r--r--src/silx/math/medianfilter/median_filter.pxd (renamed from silx/math/medianfilter/median_filter.pxd)0
-rw-r--r--src/silx/math/medianfilter/medianfilter.pyx (renamed from silx/math/medianfilter/medianfilter.pyx)0
-rw-r--r--src/silx/math/medianfilter/setup.py (renamed from silx/math/medianfilter/setup.py)0
-rw-r--r--src/silx/math/medianfilter/test/__init__.py23
-rw-r--r--src/silx/math/medianfilter/test/benchmark.py122
-rw-r--r--src/silx/math/medianfilter/test/test_medianfilter.py722
-rw-r--r--src/silx/math/setup.py99
-rw-r--r--src/silx/math/test/__init__.py23
-rw-r--r--src/silx/math/test/benchmark_combo.py192
-rw-r--r--src/silx/math/test/histo_benchmarks.py (renamed from silx/math/test/histo_benchmarks.py)0
-rw-r--r--src/silx/math/test/test_HistogramndLut_nominal.py571
-rw-r--r--src/silx/math/test/test_calibration.py145
-rw-r--r--src/silx/math/test/test_colormap.py269
-rw-r--r--src/silx/math/test/test_combo.py207
-rw-r--r--src/silx/math/test/test_histogramnd_error.py519
-rw-r--r--src/silx/math/test/test_histogramnd_nominal.py937
-rw-r--r--src/silx/math/test/test_histogramnd_vs_np.py826
-rw-r--r--src/silx/math/test/test_interpolate.py125
-rw-r--r--src/silx/math/test/test_marchingcubes.py174
-rw-r--r--src/silx/opencl/__init__.py (renamed from silx/opencl/__init__.py)0
-rw-r--r--src/silx/opencl/backprojection.py (renamed from silx/opencl/backprojection.py)0
-rw-r--r--src/silx/opencl/codec/__init__.py (renamed from silx/opencl/codec/__init__.py)0
-rw-r--r--src/silx/opencl/codec/byte_offset.py (renamed from silx/opencl/codec/byte_offset.py)0
-rw-r--r--src/silx/opencl/codec/setup.py (renamed from silx/opencl/codec/setup.py)0
-rw-r--r--src/silx/opencl/codec/test/__init__.py23
-rw-r--r--src/silx/opencl/codec/test/test_byte_offset.py303
-rw-r--r--src/silx/opencl/common.py692
-rw-r--r--src/silx/opencl/conftest.py5
-rw-r--r--src/silx/opencl/convolution.py (renamed from silx/opencl/convolution.py)0
-rw-r--r--src/silx/opencl/image.py (renamed from silx/opencl/image.py)0
-rw-r--r--src/silx/opencl/linalg.py (renamed from silx/opencl/linalg.py)0
-rw-r--r--src/silx/opencl/medfilt.py (renamed from silx/opencl/medfilt.py)0
-rw-r--r--src/silx/opencl/processing.py (renamed from silx/opencl/processing.py)0
-rw-r--r--src/silx/opencl/projection.py (renamed from silx/opencl/projection.py)0
-rw-r--r--src/silx/opencl/reconstruction.py (renamed from silx/opencl/reconstruction.py)0
-rw-r--r--src/silx/opencl/setup.py (renamed from silx/opencl/setup.py)0
-rw-r--r--src/silx/opencl/sinofilter.py (renamed from silx/opencl/sinofilter.py)0
-rw-r--r--src/silx/opencl/sparse.py (renamed from silx/opencl/sparse.py)0
-rw-r--r--src/silx/opencl/statistics.py (renamed from silx/opencl/statistics.py)0
-rw-r--r--src/silx/opencl/test/__init__.py23
-rw-r--r--src/silx/opencl/test/test_addition.py140
-rw-r--r--src/silx/opencl/test/test_array_utils.py152
-rw-r--r--src/silx/opencl/test/test_backprojection.py217
-rw-r--r--src/silx/opencl/test/test_convolution.py280
-rw-r--r--src/silx/opencl/test/test_doubleword.py244
-rw-r--r--src/silx/opencl/test/test_image.py125
-rw-r--r--src/silx/opencl/test/test_kahan.py254
-rw-r--r--src/silx/opencl/test/test_linalg.py204
-rw-r--r--src/silx/opencl/test/test_medfilt.py162
-rw-r--r--src/silx/opencl/test/test_projection.py121
-rw-r--r--src/silx/opencl/test/test_sparse.py188
-rw-r--r--src/silx/opencl/test/test_stats.py106
-rw-r--r--src/silx/opencl/utils.py (renamed from silx/opencl/utils.py)0
-rw-r--r--src/silx/resources/__init__.py (renamed from silx/resources/__init__.py)0
-rw-r--r--src/silx/resources/gui/colormaps/cividis.npy (renamed from silx/resources/gui/colormaps/cividis.npy)bin3200 -> 3200 bytes
-rw-r--r--src/silx/resources/gui/colormaps/inferno.npy (renamed from silx/resources/gui/colormaps/inferno.npy)bin3152 -> 3152 bytes
-rw-r--r--src/silx/resources/gui/colormaps/magma.npy (renamed from silx/resources/gui/colormaps/magma.npy)bin3152 -> 3152 bytes
-rw-r--r--src/silx/resources/gui/colormaps/plasma.npy (renamed from silx/resources/gui/colormaps/plasma.npy)bin3152 -> 3152 bytes
-rw-r--r--src/silx/resources/gui/colormaps/viridis.npy (renamed from silx/resources/gui/colormaps/viridis.npy)bin3152 -> 3152 bytes
-rw-r--r--src/silx/resources/gui/icons/3d-plane-normal-x.png (renamed from silx/resources/gui/icons/3d-plane-normal-x.png)bin743 -> 743 bytes
-rw-r--r--src/silx/resources/gui/icons/3d-plane-normal-x.svg (renamed from silx/resources/gui/icons/3d-plane-normal-x.svg)0
-rw-r--r--src/silx/resources/gui/icons/3d-plane-normal-y.png (renamed from silx/resources/gui/icons/3d-plane-normal-y.png)bin791 -> 791 bytes
-rw-r--r--src/silx/resources/gui/icons/3d-plane-normal-y.svg (renamed from silx/resources/gui/icons/3d-plane-normal-y.svg)0
-rw-r--r--src/silx/resources/gui/icons/3d-plane-normal-z.png (renamed from silx/resources/gui/icons/3d-plane-normal-z.png)bin681 -> 681 bytes
-rw-r--r--src/silx/resources/gui/icons/3d-plane-normal-z.svg (renamed from silx/resources/gui/icons/3d-plane-normal-z.svg)0
-rw-r--r--src/silx/resources/gui/icons/3d-plane-pan.png (renamed from silx/resources/gui/icons/3d-plane-pan.png)bin1428 -> 1428 bytes
-rw-r--r--src/silx/resources/gui/icons/3d-plane-pan.svg (renamed from silx/resources/gui/icons/3d-plane-pan.svg)0
-rw-r--r--src/silx/resources/gui/icons/3d-plane.png (renamed from silx/resources/gui/icons/3d-plane.png)bin1134 -> 1134 bytes
-rw-r--r--src/silx/resources/gui/icons/3d-plane.svg (renamed from silx/resources/gui/icons/3d-plane.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-range-horizontal.png (renamed from silx/resources/gui/icons/add-range-horizontal.png)bin560 -> 560 bytes
-rw-r--r--src/silx/resources/gui/icons/add-range-horizontal.svg (renamed from silx/resources/gui/icons/add-range-horizontal.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-arc.png (renamed from silx/resources/gui/icons/add-shape-arc.png)bin1164 -> 1164 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-arc.svg (renamed from silx/resources/gui/icons/add-shape-arc.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-circle.png (renamed from silx/resources/gui/icons/add-shape-circle.png)bin1238 -> 1238 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-circle.svg (renamed from silx/resources/gui/icons/add-shape-circle.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-cross.png (renamed from silx/resources/gui/icons/add-shape-cross.png)bin501 -> 501 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-cross.svg (renamed from silx/resources/gui/icons/add-shape-cross.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-diagonal.png (renamed from silx/resources/gui/icons/add-shape-diagonal.png)bin626 -> 626 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-diagonal.svg (renamed from silx/resources/gui/icons/add-shape-diagonal.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-ellipse.png (renamed from silx/resources/gui/icons/add-shape-ellipse.png)bin1180 -> 1180 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-ellipse.svg (renamed from silx/resources/gui/icons/add-shape-ellipse.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-horizontal.png (renamed from silx/resources/gui/icons/add-shape-horizontal.png)bin408 -> 408 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-horizontal.svg (renamed from silx/resources/gui/icons/add-shape-horizontal.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-point.png (renamed from silx/resources/gui/icons/add-shape-point.png)bin482 -> 482 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-point.svg (renamed from silx/resources/gui/icons/add-shape-point.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-polygon.png (renamed from silx/resources/gui/icons/add-shape-polygon.png)bin1217 -> 1217 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-polygon.svg (renamed from silx/resources/gui/icons/add-shape-polygon.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-rectangle.png (renamed from silx/resources/gui/icons/add-shape-rectangle.png)bin463 -> 463 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-rectangle.svg (renamed from silx/resources/gui/icons/add-shape-rectangle.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-unknown.png (renamed from silx/resources/gui/icons/add-shape-unknown.png)bin1506 -> 1506 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-unknown.svg (renamed from silx/resources/gui/icons/add-shape-unknown.svg)0
-rw-r--r--src/silx/resources/gui/icons/add-shape-vertical.png (renamed from silx/resources/gui/icons/add-shape-vertical.png)bin422 -> 422 bytes
-rw-r--r--src/silx/resources/gui/icons/add-shape-vertical.svg (renamed from silx/resources/gui/icons/add-shape-vertical.svg)0
-rw-r--r--src/silx/resources/gui/icons/add.png (renamed from silx/resources/gui/icons/add.png)bin470 -> 470 bytes
-rw-r--r--src/silx/resources/gui/icons/add.svg (renamed from silx/resources/gui/icons/add.svg)0
-rw-r--r--src/silx/resources/gui/icons/aggregation-mode.pngbin0 -> 826 bytes
-rw-r--r--src/silx/resources/gui/icons/aggregation-mode.svg4
-rw-r--r--src/silx/resources/gui/icons/arrow-keys.png (renamed from silx/resources/gui/icons/arrow-keys.png)bin669 -> 669 bytes
-rw-r--r--src/silx/resources/gui/icons/arrow-keys.svg (renamed from silx/resources/gui/icons/arrow-keys.svg)0
-rw-r--r--src/silx/resources/gui/icons/axis.png (renamed from silx/resources/gui/icons/axis.png)bin1740 -> 1740 bytes
-rw-r--r--src/silx/resources/gui/icons/axis.svg (renamed from silx/resources/gui/icons/axis.svg)0
-rw-r--r--src/silx/resources/gui/icons/backend-opengl.png (renamed from silx/resources/gui/icons/backend-opengl.png)bin1582 -> 1582 bytes
-rw-r--r--src/silx/resources/gui/icons/backend-opengl.svg (renamed from silx/resources/gui/icons/backend-opengl.svg)0
-rw-r--r--src/silx/resources/gui/icons/camera.png (renamed from silx/resources/gui/icons/camera.png)bin348 -> 348 bytes
-rw-r--r--src/silx/resources/gui/icons/camera.svg (renamed from silx/resources/gui/icons/camera.svg)0
-rw-r--r--src/silx/resources/gui/icons/clipboard.png (renamed from silx/resources/gui/icons/clipboard.png)bin736 -> 736 bytes
-rw-r--r--src/silx/resources/gui/icons/clipboard.svg (renamed from silx/resources/gui/icons/clipboard.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/close.png (renamed from silx/resources/gui/icons/close.png)bin2243 -> 2243 bytes
-rw-r--r--src/silx/resources/gui/icons/close.svg (renamed from silx/resources/gui/icons/close.svg)0
-rw-r--r--src/silx/resources/gui/icons/colorbar.png (renamed from silx/resources/gui/icons/colorbar.png)bin657 -> 657 bytes
-rw-r--r--src/silx/resources/gui/icons/colorbar.svg (renamed from silx/resources/gui/icons/colorbar.svg)0
-rw-r--r--src/silx/resources/gui/icons/colormap-histogram.png (renamed from silx/resources/gui/icons/colormap-histogram.png)bin641 -> 641 bytes
-rw-r--r--src/silx/resources/gui/icons/colormap-histogram.svg (renamed from silx/resources/gui/icons/colormap-histogram.svg)0
-rw-r--r--src/silx/resources/gui/icons/colormap-none.png (renamed from silx/resources/gui/icons/colormap-none.png)bin232 -> 232 bytes
-rw-r--r--src/silx/resources/gui/icons/colormap-none.svg (renamed from silx/resources/gui/icons/colormap-none.svg)0
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-arcsinh.png (renamed from silx/resources/gui/icons/colormap-norm-arcsinh.png)bin648 -> 648 bytes
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-arcsinh.svg (renamed from silx/resources/gui/icons/colormap-norm-arcsinh.svg)0
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-gamma.png (renamed from silx/resources/gui/icons/colormap-norm-gamma.png)bin994 -> 994 bytes
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-gamma.svg (renamed from silx/resources/gui/icons/colormap-norm-gamma.svg)0
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-linear.png (renamed from silx/resources/gui/icons/colormap-norm-linear.png)bin675 -> 675 bytes
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-linear.svg (renamed from silx/resources/gui/icons/colormap-norm-linear.svg)0
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-log.png (renamed from silx/resources/gui/icons/colormap-norm-log.png)bin512 -> 512 bytes
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-log.svg (renamed from silx/resources/gui/icons/colormap-norm-log.svg)0
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-sqrt.png (renamed from silx/resources/gui/icons/colormap-norm-sqrt.png)bin569 -> 569 bytes
-rw-r--r--src/silx/resources/gui/icons/colormap-norm-sqrt.svg (renamed from silx/resources/gui/icons/colormap-norm-sqrt.svg)0
-rw-r--r--src/silx/resources/gui/icons/colormap-range.png (renamed from silx/resources/gui/icons/colormap-range.png)bin284 -> 284 bytes
-rw-r--r--src/silx/resources/gui/icons/colormap-range.svg (renamed from silx/resources/gui/icons/colormap-range.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/colormap.png (renamed from silx/resources/gui/icons/colormap.png)bin1583 -> 1583 bytes
-rw-r--r--src/silx/resources/gui/icons/colormap.svg (renamed from silx/resources/gui/icons/colormap.svg)0
-rw-r--r--src/silx/resources/gui/icons/compare-align-auto.png (renamed from silx/resources/gui/icons/compare-align-auto.png)bin1446 -> 1446 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-align-auto.svg4
-rw-r--r--src/silx/resources/gui/icons/compare-align-center.png (renamed from silx/resources/gui/icons/compare-align-center.png)bin716 -> 716 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-align-center.svg4
-rw-r--r--src/silx/resources/gui/icons/compare-align-origin.png (renamed from silx/resources/gui/icons/compare-align-origin.png)bin728 -> 728 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-align-origin.svg4
-rw-r--r--src/silx/resources/gui/icons/compare-align-stretch.png (renamed from silx/resources/gui/icons/compare-align-stretch.png)bin903 -> 903 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-align-stretch.svg4
-rw-r--r--src/silx/resources/gui/icons/compare-keypoints.png (renamed from silx/resources/gui/icons/compare-keypoints.png)bin616 -> 616 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-keypoints.svg (renamed from silx/resources/gui/icons/compare-keypoints.svg)0
-rw-r--r--src/silx/resources/gui/icons/compare-mode-a-minus-b.png (renamed from silx/resources/gui/icons/compare-mode-a-minus-b.png)bin3862 -> 3862 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-mode-a-minus-b.svg (renamed from silx/resources/gui/icons/compare-mode-a-minus-b.svg)0
-rw-r--r--src/silx/resources/gui/icons/compare-mode-a.png (renamed from silx/resources/gui/icons/compare-mode-a.png)bin803 -> 803 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-mode-a.svg (renamed from silx/resources/gui/icons/compare-mode-a.svg)0
-rw-r--r--src/silx/resources/gui/icons/compare-mode-b.png (renamed from silx/resources/gui/icons/compare-mode-b.png)bin740 -> 740 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-mode-b.svg (renamed from silx/resources/gui/icons/compare-mode-b.svg)0
-rw-r--r--src/silx/resources/gui/icons/compare-mode-hline.png (renamed from silx/resources/gui/icons/compare-mode-hline.png)bin902 -> 902 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-mode-hline.svg (renamed from silx/resources/gui/icons/compare-mode-hline.svg)0
-rw-r--r--src/silx/resources/gui/icons/compare-mode-rb-channel.png (renamed from silx/resources/gui/icons/compare-mode-rb-channel.png)bin1269 -> 1269 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-mode-rb-channel.svg (renamed from silx/resources/gui/icons/compare-mode-rb-channel.svg)0
-rw-r--r--src/silx/resources/gui/icons/compare-mode-rbneg-channel.png (renamed from silx/resources/gui/icons/compare-mode-rbneg-channel.png)bin1260 -> 1260 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-mode-rbneg-channel.svg (renamed from silx/resources/gui/icons/compare-mode-rbneg-channel.svg)0
-rw-r--r--src/silx/resources/gui/icons/compare-mode-vline.png (renamed from silx/resources/gui/icons/compare-mode-vline.png)bin1079 -> 1079 bytes
-rw-r--r--src/silx/resources/gui/icons/compare-mode-vline.svg (renamed from silx/resources/gui/icons/compare-mode-vline.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/crop.png (renamed from silx/resources/gui/icons/crop.png)bin642 -> 642 bytes
-rw-r--r--src/silx/resources/gui/icons/crop.svg (renamed from silx/resources/gui/icons/crop.svg)0
-rw-r--r--src/silx/resources/gui/icons/crosshair.png (renamed from silx/resources/gui/icons/crosshair.png)bin1196 -> 1196 bytes
-rw-r--r--src/silx/resources/gui/icons/crosshair.svg (renamed from silx/resources/gui/icons/crosshair.svg)0
-rw-r--r--src/silx/resources/gui/icons/cube-back.png (renamed from silx/resources/gui/icons/cube-back.png)bin737 -> 737 bytes
-rw-r--r--src/silx/resources/gui/icons/cube-back.svg (renamed from silx/resources/gui/icons/cube-back.svg)0
-rw-r--r--src/silx/resources/gui/icons/cube-bottom.png (renamed from silx/resources/gui/icons/cube-bottom.png)bin833 -> 833 bytes
-rw-r--r--src/silx/resources/gui/icons/cube-bottom.svg (renamed from silx/resources/gui/icons/cube-bottom.svg)0
-rw-r--r--src/silx/resources/gui/icons/cube-front.png (renamed from silx/resources/gui/icons/cube-front.png)bin708 -> 708 bytes
-rw-r--r--src/silx/resources/gui/icons/cube-front.svg (renamed from silx/resources/gui/icons/cube-front.svg)0
-rw-r--r--src/silx/resources/gui/icons/cube-left.png (renamed from silx/resources/gui/icons/cube-left.png)bin712 -> 712 bytes
-rw-r--r--src/silx/resources/gui/icons/cube-left.svg (renamed from silx/resources/gui/icons/cube-left.svg)0
-rw-r--r--src/silx/resources/gui/icons/cube-right.png (renamed from silx/resources/gui/icons/cube-right.png)bin701 -> 701 bytes
-rw-r--r--src/silx/resources/gui/icons/cube-right.svg (renamed from silx/resources/gui/icons/cube-right.svg)0
-rw-r--r--src/silx/resources/gui/icons/cube-rotate.png (renamed from silx/resources/gui/icons/cube-rotate.png)bin955 -> 955 bytes
-rw-r--r--src/silx/resources/gui/icons/cube-rotate.svg (renamed from silx/resources/gui/icons/cube-rotate.svg)0
-rw-r--r--src/silx/resources/gui/icons/cube-top.png (renamed from silx/resources/gui/icons/cube-top.png)bin767 -> 767 bytes
-rw-r--r--src/silx/resources/gui/icons/cube-top.svg (renamed from silx/resources/gui/icons/cube-top.svg)0
-rw-r--r--src/silx/resources/gui/icons/cube.png (renamed from silx/resources/gui/icons/cube.png)bin953 -> 953 bytes
-rw-r--r--src/silx/resources/gui/icons/cube.svg (renamed from silx/resources/gui/icons/cube.svg)0
-rw-r--r--src/silx/resources/gui/icons/description-description.png (renamed from silx/resources/gui/icons/description-description.png)bin756 -> 756 bytes
-rw-r--r--src/silx/resources/gui/icons/description-description.svg (renamed from silx/resources/gui/icons/description-description.svg)0
-rw-r--r--src/silx/resources/gui/icons/description-error.png (renamed from silx/resources/gui/icons/description-error.png)bin952 -> 952 bytes
-rw-r--r--src/silx/resources/gui/icons/description-error.svg (renamed from silx/resources/gui/icons/description-error.svg)0
-rw-r--r--src/silx/resources/gui/icons/description-name.png (renamed from silx/resources/gui/icons/description-name.png)bin822 -> 822 bytes
-rw-r--r--src/silx/resources/gui/icons/description-name.svg (renamed from silx/resources/gui/icons/description-name.svg)0
-rw-r--r--src/silx/resources/gui/icons/description-program.png (renamed from silx/resources/gui/icons/description-program.png)bin767 -> 767 bytes
-rw-r--r--src/silx/resources/gui/icons/description-program.svg (renamed from silx/resources/gui/icons/description-program.svg)0
-rw-r--r--src/silx/resources/gui/icons/description-title.png (renamed from silx/resources/gui/icons/description-title.png)bin707 -> 707 bytes
-rw-r--r--src/silx/resources/gui/icons/description-title.svg (renamed from silx/resources/gui/icons/description-title.svg)0
-rw-r--r--src/silx/resources/gui/icons/description-value.png (renamed from silx/resources/gui/icons/description-value.png)bin833 -> 833 bytes
-rw-r--r--src/silx/resources/gui/icons/description-value.svg (renamed from silx/resources/gui/icons/description-value.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/document-open.png (renamed from silx/resources/gui/icons/document-open.png)bin2676 -> 2676 bytes
-rw-r--r--src/silx/resources/gui/icons/document-open.svg (renamed from silx/resources/gui/icons/document-open.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/document-print.png (renamed from silx/resources/gui/icons/document-print.png)bin702 -> 702 bytes
-rw-r--r--src/silx/resources/gui/icons/document-print.svg (renamed from silx/resources/gui/icons/document-print.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/document-save.png (renamed from silx/resources/gui/icons/document-save.png)bin535 -> 535 bytes
-rw-r--r--src/silx/resources/gui/icons/document-save.svg (renamed from silx/resources/gui/icons/document-save.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/draw-brush.png (renamed from silx/resources/gui/icons/draw-brush.png)bin1466 -> 1466 bytes
-rw-r--r--src/silx/resources/gui/icons/draw-brush.svg (renamed from silx/resources/gui/icons/draw-brush.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/draw-pencil.png (renamed from silx/resources/gui/icons/draw-pencil.png)bin1055 -> 1055 bytes
-rw-r--r--src/silx/resources/gui/icons/draw-pencil.svg (renamed from silx/resources/gui/icons/draw-pencil.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/draw-rubber.png (renamed from silx/resources/gui/icons/draw-rubber.png)bin1154 -> 1154 bytes
-rw-r--r--src/silx/resources/gui/icons/draw-rubber.svg (renamed from silx/resources/gui/icons/draw-rubber.svg)0
-rw-r--r--src/silx/resources/gui/icons/edit-copy.png (renamed from silx/resources/gui/icons/edit-copy.png)bin2191 -> 2191 bytes
-rw-r--r--src/silx/resources/gui/icons/edit-copy.svg (renamed from silx/resources/gui/icons/edit-copy.svg)0
-rw-r--r--src/silx/resources/gui/icons/eye.png (renamed from silx/resources/gui/icons/eye.png)bin755 -> 755 bytes
-rw-r--r--src/silx/resources/gui/icons/eye.svg (renamed from silx/resources/gui/icons/eye.svg)0
-rw-r--r--src/silx/resources/gui/icons/first.png (renamed from silx/resources/gui/icons/first.png)bin1177 -> 1177 bytes
-rw-r--r--src/silx/resources/gui/icons/first.svg (renamed from silx/resources/gui/icons/first.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/folder.png (renamed from silx/resources/gui/icons/folder.png)bin2583 -> 2583 bytes
-rw-r--r--src/silx/resources/gui/icons/folder.svg (renamed from silx/resources/gui/icons/folder.svg)0
-rw-r--r--src/silx/resources/gui/icons/image-mask.png (renamed from silx/resources/gui/icons/image-mask.png)bin852 -> 852 bytes
-rw-r--r--src/silx/resources/gui/icons/image-mask.svg (renamed from silx/resources/gui/icons/image-mask.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/image-select-add.png (renamed from silx/resources/gui/icons/image-select-add.png)bin2531 -> 2531 bytes
-rw-r--r--src/silx/resources/gui/icons/image-select-add.svg (renamed from silx/resources/gui/icons/image-select-add.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/image-select-box.png (renamed from silx/resources/gui/icons/image-select-box.png)bin3036 -> 3036 bytes
-rw-r--r--src/silx/resources/gui/icons/image-select-box.svg (renamed from silx/resources/gui/icons/image-select-box.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/image-select-brush.png (renamed from silx/resources/gui/icons/image-select-brush.png)bin3300 -> 3300 bytes
-rw-r--r--src/silx/resources/gui/icons/image-select-brush.svg (renamed from silx/resources/gui/icons/image-select-brush.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/image-select-erase-rubber.png (renamed from silx/resources/gui/icons/image-select-erase-rubber.png)bin1638 -> 1638 bytes
-rw-r--r--src/silx/resources/gui/icons/image-select-erase-rubber.svg (renamed from silx/resources/gui/icons/image-select-erase-rubber.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/image-select-erase.png (renamed from silx/resources/gui/icons/image-select-erase.png)bin2286 -> 2286 bytes
-rw-r--r--src/silx/resources/gui/icons/image-select-erase.svg (renamed from silx/resources/gui/icons/image-select-erase.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/image.png (renamed from silx/resources/gui/icons/image.png)bin2572 -> 2572 bytes
-rw-r--r--src/silx/resources/gui/icons/image.svg (renamed from silx/resources/gui/icons/image.svg)0
-rw-r--r--src/silx/resources/gui/icons/item-0dim.png (renamed from silx/resources/gui/icons/item-0dim.png)bin305 -> 305 bytes
-rw-r--r--src/silx/resources/gui/icons/item-0dim.svg (renamed from silx/resources/gui/icons/item-0dim.svg)0
-rw-r--r--src/silx/resources/gui/icons/item-1dim.png (renamed from silx/resources/gui/icons/item-1dim.png)bin674 -> 674 bytes
-rw-r--r--src/silx/resources/gui/icons/item-1dim.svg (renamed from silx/resources/gui/icons/item-1dim.svg)0
-rw-r--r--src/silx/resources/gui/icons/item-2dim.png (renamed from silx/resources/gui/icons/item-2dim.png)bin233 -> 233 bytes
-rw-r--r--src/silx/resources/gui/icons/item-2dim.svg (renamed from silx/resources/gui/icons/item-2dim.svg)0
-rw-r--r--src/silx/resources/gui/icons/item-3dim.png (renamed from silx/resources/gui/icons/item-3dim.png)bin582 -> 582 bytes
-rw-r--r--src/silx/resources/gui/icons/item-3dim.svg (renamed from silx/resources/gui/icons/item-3dim.svg)0
-rw-r--r--src/silx/resources/gui/icons/item-ndim.png (renamed from silx/resources/gui/icons/item-ndim.png)bin947 -> 947 bytes
-rw-r--r--src/silx/resources/gui/icons/item-ndim.svg (renamed from silx/resources/gui/icons/item-ndim.svg)0
-rw-r--r--src/silx/resources/gui/icons/item-none.png (renamed from silx/resources/gui/icons/item-none.png)bin637 -> 637 bytes
-rw-r--r--src/silx/resources/gui/icons/item-none.svg (renamed from silx/resources/gui/icons/item-none.svg)0
-rw-r--r--src/silx/resources/gui/icons/item-object.png (renamed from silx/resources/gui/icons/item-object.png)bin836 -> 836 bytes
-rw-r--r--src/silx/resources/gui/icons/item-object.svg (renamed from silx/resources/gui/icons/item-object.svg)0
-rw-r--r--src/silx/resources/gui/icons/last.png (renamed from silx/resources/gui/icons/last.png)bin1111 -> 1111 bytes
-rw-r--r--src/silx/resources/gui/icons/last.svg (renamed from silx/resources/gui/icons/last.svg)0
-rw-r--r--src/silx/resources/gui/icons/layer-nx.png (renamed from silx/resources/gui/icons/layer-nx.png)bin459 -> 459 bytes
-rw-r--r--src/silx/resources/gui/icons/layer-nx.svg (renamed from silx/resources/gui/icons/layer-nx.svg)0
-rw-r--r--src/silx/resources/gui/icons/mask-clear-all.png (renamed from silx/resources/gui/icons/mask-clear-all.png)bin1383 -> 1383 bytes
-rw-r--r--src/silx/resources/gui/icons/mask-clear-all.svg (renamed from silx/resources/gui/icons/mask-clear-all.svg)0
-rw-r--r--src/silx/resources/gui/icons/mask-clear.png (renamed from silx/resources/gui/icons/mask-clear.png)bin1086 -> 1086 bytes
-rw-r--r--src/silx/resources/gui/icons/mask-clear.svg (renamed from silx/resources/gui/icons/mask-clear.svg)0
-rw-r--r--src/silx/resources/gui/icons/mask-invert.png (renamed from silx/resources/gui/icons/mask-invert.png)bin717 -> 717 bytes
-rw-r--r--src/silx/resources/gui/icons/mask-invert.svg (renamed from silx/resources/gui/icons/mask-invert.svg)0
-rw-r--r--src/silx/resources/gui/icons/math-amplitude.png (renamed from silx/resources/gui/icons/math-amplitude.png)bin526 -> 526 bytes
-rw-r--r--src/silx/resources/gui/icons/math-amplitude.svg (renamed from silx/resources/gui/icons/math-amplitude.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-average.png (renamed from silx/resources/gui/icons/math-average.png)bin571 -> 571 bytes
-rw-r--r--src/silx/resources/gui/icons/math-average.svg (renamed from silx/resources/gui/icons/math-average.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-derive.png (renamed from silx/resources/gui/icons/math-derive.png)bin593 -> 593 bytes
-rw-r--r--src/silx/resources/gui/icons/math-derive.svg (renamed from silx/resources/gui/icons/math-derive.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-energy.png (renamed from silx/resources/gui/icons/math-energy.png)bin645 -> 645 bytes
-rw-r--r--src/silx/resources/gui/icons/math-energy.svg (renamed from silx/resources/gui/icons/math-energy.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-fit.png (renamed from silx/resources/gui/icons/math-fit.png)bin768 -> 768 bytes
-rw-r--r--src/silx/resources/gui/icons/math-fit.svg (renamed from silx/resources/gui/icons/math-fit.svg)0
-rw-r--r--src/silx/resources/gui/icons/math-imaginary.png (renamed from silx/resources/gui/icons/math-imaginary.png)bin630 -> 630 bytes
-rw-r--r--src/silx/resources/gui/icons/math-imaginary.svg (renamed from silx/resources/gui/icons/math-imaginary.svg)0
-rw-r--r--src/silx/resources/gui/icons/math-mean.png (renamed from silx/resources/gui/icons/math-mean.png)bin1487 -> 1487 bytes
-rw-r--r--src/silx/resources/gui/icons/math-mean.svg (renamed from silx/resources/gui/icons/math-mean.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-normalize.png (renamed from silx/resources/gui/icons/math-normalize.png)bin653 -> 653 bytes
-rw-r--r--src/silx/resources/gui/icons/math-normalize.svg (renamed from silx/resources/gui/icons/math-normalize.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-peak-reset.png (renamed from silx/resources/gui/icons/math-peak-reset.png)bin1420 -> 1420 bytes
-rw-r--r--src/silx/resources/gui/icons/math-peak-reset.svg (renamed from silx/resources/gui/icons/math-peak-reset.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-peak-search.png (renamed from silx/resources/gui/icons/math-peak-search.png)bin2163 -> 2163 bytes
-rw-r--r--src/silx/resources/gui/icons/math-peak-search.svg2
-rwxr-xr-xsrc/silx/resources/gui/icons/math-peak.png (renamed from silx/resources/gui/icons/math-peak.png)bin829 -> 829 bytes
-rw-r--r--src/silx/resources/gui/icons/math-peak.svg (renamed from silx/resources/gui/icons/math-peak.svg)0
-rw-r--r--src/silx/resources/gui/icons/math-phase-color-log.png (renamed from silx/resources/gui/icons/math-phase-color-log.png)bin2256 -> 2256 bytes
-rw-r--r--src/silx/resources/gui/icons/math-phase-color-log.svg (renamed from silx/resources/gui/icons/math-phase-color-log.svg)0
-rw-r--r--src/silx/resources/gui/icons/math-phase-color.png (renamed from silx/resources/gui/icons/math-phase-color.png)bin2127 -> 2127 bytes
-rw-r--r--src/silx/resources/gui/icons/math-phase-color.svg (renamed from silx/resources/gui/icons/math-phase-color.svg)0
-rw-r--r--src/silx/resources/gui/icons/math-phase.png (renamed from silx/resources/gui/icons/math-phase.png)bin1868 -> 1868 bytes
-rw-r--r--src/silx/resources/gui/icons/math-phase.svg (renamed from silx/resources/gui/icons/math-phase.svg)0
-rw-r--r--src/silx/resources/gui/icons/math-real.png (renamed from silx/resources/gui/icons/math-real.png)bin749 -> 749 bytes
-rw-r--r--src/silx/resources/gui/icons/math-real.svg (renamed from silx/resources/gui/icons/math-real.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-sigma.png (renamed from silx/resources/gui/icons/math-sigma.png)bin744 -> 744 bytes
-rw-r--r--src/silx/resources/gui/icons/math-sigma.svg (renamed from silx/resources/gui/icons/math-sigma.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-smooth.png (renamed from silx/resources/gui/icons/math-smooth.png)bin1243 -> 1243 bytes
-rw-r--r--src/silx/resources/gui/icons/math-smooth.svg (renamed from silx/resources/gui/icons/math-smooth.svg)0
-rw-r--r--src/silx/resources/gui/icons/math-square-amplitude.png (renamed from silx/resources/gui/icons/math-square-amplitude.png)bin592 -> 592 bytes
-rw-r--r--src/silx/resources/gui/icons/math-square-amplitude.svg (renamed from silx/resources/gui/icons/math-square-amplitude.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-substract.png (renamed from silx/resources/gui/icons/math-substract.png)bin845 -> 845 bytes
-rw-r--r--src/silx/resources/gui/icons/math-substract.svg (renamed from silx/resources/gui/icons/math-substract.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-swap-sign.png (renamed from silx/resources/gui/icons/math-swap-sign.png)bin1007 -> 1007 bytes
-rw-r--r--src/silx/resources/gui/icons/math-swap-sign.svg (renamed from silx/resources/gui/icons/math-swap-sign.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/math-ymin-to-zero.png (renamed from silx/resources/gui/icons/math-ymin-to-zero.png)bin666 -> 666 bytes
-rw-r--r--src/silx/resources/gui/icons/math-ymin-to-zero.svg (renamed from silx/resources/gui/icons/math-ymin-to-zero.svg)0
-rw-r--r--src/silx/resources/gui/icons/median-filter.png (renamed from silx/resources/gui/icons/median-filter.png)bin694 -> 694 bytes
-rw-r--r--src/silx/resources/gui/icons/median-filter.svg (renamed from silx/resources/gui/icons/median-filter.svg)0
-rw-r--r--src/silx/resources/gui/icons/next.png (renamed from silx/resources/gui/icons/next.png)bin1092 -> 1092 bytes
-rw-r--r--src/silx/resources/gui/icons/next.svg (renamed from silx/resources/gui/icons/next.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/normal.png (renamed from silx/resources/gui/icons/normal.png)bin1264 -> 1264 bytes
-rw-r--r--src/silx/resources/gui/icons/normal.svg (renamed from silx/resources/gui/icons/normal.svg)0
-rw-r--r--src/silx/resources/gui/icons/nxdata-axis-add.png (renamed from silx/resources/gui/icons/nxdata-axis-add.png)bin686 -> 686 bytes
-rw-r--r--src/silx/resources/gui/icons/nxdata-axis-add.svg (renamed from silx/resources/gui/icons/nxdata-axis-add.svg)0
-rw-r--r--src/silx/resources/gui/icons/nxdata-axis-remove.png (renamed from silx/resources/gui/icons/nxdata-axis-remove.png)bin967 -> 967 bytes
-rw-r--r--src/silx/resources/gui/icons/nxdata-axis-remove.svg (renamed from silx/resources/gui/icons/nxdata-axis-remove.svg)0
-rw-r--r--src/silx/resources/gui/icons/nxdata-create.png (renamed from silx/resources/gui/icons/nxdata-create.png)bin867 -> 867 bytes
-rw-r--r--src/silx/resources/gui/icons/nxdata-create.svg (renamed from silx/resources/gui/icons/nxdata-create.svg)0
-rw-r--r--src/silx/resources/gui/icons/nxdata-remove.png (renamed from silx/resources/gui/icons/nxdata-remove.png)bin1265 -> 1265 bytes
-rw-r--r--src/silx/resources/gui/icons/nxdata-remove.svg (renamed from silx/resources/gui/icons/nxdata-remove.svg)0
-rw-r--r--src/silx/resources/gui/icons/pan.png (renamed from silx/resources/gui/icons/pan.png)bin526 -> 526 bytes
-rw-r--r--src/silx/resources/gui/icons/pan.svg (renamed from silx/resources/gui/icons/pan.svg)0
-rw-r--r--src/silx/resources/gui/icons/pixel-intensities.png (renamed from silx/resources/gui/icons/pixel-intensities.png)bin654 -> 654 bytes
-rw-r--r--src/silx/resources/gui/icons/pixel-intensities.svg (renamed from silx/resources/gui/icons/pixel-intensities.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-grid.png (renamed from silx/resources/gui/icons/plot-grid.png)bin446 -> 446 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-grid.svg (renamed from silx/resources/gui/icons/plot-grid.svg)0
-rw-r--r--src/silx/resources/gui/icons/plot-roi-above.png (renamed from silx/resources/gui/icons/plot-roi-above.png)bin999 -> 999 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-roi-above.svg (renamed from silx/resources/gui/icons/plot-roi-above.svg)0
-rw-r--r--src/silx/resources/gui/icons/plot-roi-below.png (renamed from silx/resources/gui/icons/plot-roi-below.png)bin988 -> 988 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-roi-below.svg (renamed from silx/resources/gui/icons/plot-roi-below.svg)0
-rw-r--r--src/silx/resources/gui/icons/plot-roi-between.png (renamed from silx/resources/gui/icons/plot-roi-between.png)bin1021 -> 1021 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-roi-between.svg (renamed from silx/resources/gui/icons/plot-roi-between.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-roi-reset.png (renamed from silx/resources/gui/icons/plot-roi-reset.png)bin2063 -> 2063 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-roi-reset.svg (renamed from silx/resources/gui/icons/plot-roi-reset.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-roi.png (renamed from silx/resources/gui/icons/plot-roi.png)bin903 -> 903 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-roi.svg (renamed from silx/resources/gui/icons/plot-roi.svg)0
-rw-r--r--src/silx/resources/gui/icons/plot-symbols.png (renamed from silx/resources/gui/icons/plot-symbols.png)bin672 -> 672 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-symbols.svg (renamed from silx/resources/gui/icons/plot-symbols.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-toggle-points.png (renamed from silx/resources/gui/icons/plot-toggle-points.png)bin484 -> 484 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-toggle-points.svg (renamed from silx/resources/gui/icons/plot-toggle-points.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-widget.png (renamed from silx/resources/gui/icons/plot-widget.png)bin1093 -> 1093 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-widget.svg (renamed from silx/resources/gui/icons/plot-widget.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-window-image.png (renamed from silx/resources/gui/icons/plot-window-image.png)bin1188 -> 1188 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-window-image.svg (renamed from silx/resources/gui/icons/plot-window-image.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-window.png (renamed from silx/resources/gui/icons/plot-window.png)bin955 -> 955 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-window.svg (renamed from silx/resources/gui/icons/plot-window.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-xauto.png (renamed from silx/resources/gui/icons/plot-xauto.png)bin626 -> 626 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-xauto.svg (renamed from silx/resources/gui/icons/plot-xauto.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-xlog.png (renamed from silx/resources/gui/icons/plot-xlog.png)bin679 -> 679 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-xlog.svg (renamed from silx/resources/gui/icons/plot-xlog.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-yauto.png (renamed from silx/resources/gui/icons/plot-yauto.png)bin676 -> 676 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-yauto.svg (renamed from silx/resources/gui/icons/plot-yauto.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-ydown.png (renamed from silx/resources/gui/icons/plot-ydown.png)bin701 -> 701 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-ydown.svg (renamed from silx/resources/gui/icons/plot-ydown.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-ylog.png (renamed from silx/resources/gui/icons/plot-ylog.png)bin772 -> 772 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-ylog.svg (renamed from silx/resources/gui/icons/plot-ylog.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/plot-yup.png (renamed from silx/resources/gui/icons/plot-yup.png)bin667 -> 667 bytes
-rw-r--r--src/silx/resources/gui/icons/plot-yup.svg (renamed from silx/resources/gui/icons/plot-yup.svg)0
-rw-r--r--src/silx/resources/gui/icons/pointing-hand.png (renamed from silx/resources/gui/icons/pointing-hand.png)bin680 -> 680 bytes
-rw-r--r--src/silx/resources/gui/icons/pointing-hand.svg (renamed from silx/resources/gui/icons/pointing-hand.svg)0
-rw-r--r--src/silx/resources/gui/icons/previous.png (renamed from silx/resources/gui/icons/previous.png)bin1151 -> 1151 bytes
-rw-r--r--src/silx/resources/gui/icons/previous.svg (renamed from silx/resources/gui/icons/previous.svg)0
-rw-r--r--src/silx/resources/gui/icons/process-working.mng (renamed from silx/resources/gui/icons/process-working.mng)bin15966 -> 15966 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/00.png (renamed from silx/resources/gui/icons/process-working/00.png)bin778 -> 778 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/01.png (renamed from silx/resources/gui/icons/process-working/01.png)bin789 -> 789 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/02.png (renamed from silx/resources/gui/icons/process-working/02.png)bin785 -> 785 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/03.png (renamed from silx/resources/gui/icons/process-working/03.png)bin785 -> 785 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/04.png (renamed from silx/resources/gui/icons/process-working/04.png)bin766 -> 766 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/05.png (renamed from silx/resources/gui/icons/process-working/05.png)bin777 -> 777 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/06.png (renamed from silx/resources/gui/icons/process-working/06.png)bin784 -> 784 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/07.png (renamed from silx/resources/gui/icons/process-working/07.png)bin783 -> 783 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/08.png (renamed from silx/resources/gui/icons/process-working/08.png)bin762 -> 762 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/09.png (renamed from silx/resources/gui/icons/process-working/09.png)bin781 -> 781 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/10.png (renamed from silx/resources/gui/icons/process-working/10.png)bin771 -> 771 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/11.png (renamed from silx/resources/gui/icons/process-working/11.png)bin768 -> 768 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/12.png (renamed from silx/resources/gui/icons/process-working/12.png)bin759 -> 759 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/13.png (renamed from silx/resources/gui/icons/process-working/13.png)bin767 -> 767 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/14.png (renamed from silx/resources/gui/icons/process-working/14.png)bin778 -> 778 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/15.png (renamed from silx/resources/gui/icons/process-working/15.png)bin760 -> 760 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/16.png (renamed from silx/resources/gui/icons/process-working/16.png)bin754 -> 754 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/17.png (renamed from silx/resources/gui/icons/process-working/17.png)bin782 -> 782 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/18.png (renamed from silx/resources/gui/icons/process-working/18.png)bin775 -> 775 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/19.png (renamed from silx/resources/gui/icons/process-working/19.png)bin764 -> 764 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/20.png (renamed from silx/resources/gui/icons/process-working/20.png)bin764 -> 764 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/21.png (renamed from silx/resources/gui/icons/process-working/21.png)bin772 -> 772 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/22.png (renamed from silx/resources/gui/icons/process-working/22.png)bin769 -> 769 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/23.png (renamed from silx/resources/gui/icons/process-working/23.png)bin773 -> 773 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/24.png (renamed from silx/resources/gui/icons/process-working/24.png)bin757 -> 757 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/25.png (renamed from silx/resources/gui/icons/process-working/25.png)bin759 -> 759 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/26.png (renamed from silx/resources/gui/icons/process-working/26.png)bin774 -> 774 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/27.png (renamed from silx/resources/gui/icons/process-working/27.png)bin766 -> 766 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/28.png (renamed from silx/resources/gui/icons/process-working/28.png)bin760 -> 760 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/29.png (renamed from silx/resources/gui/icons/process-working/29.png)bin777 -> 777 bytes
-rw-r--r--src/silx/resources/gui/icons/process-working/30.png (renamed from silx/resources/gui/icons/process-working/30.png)bin775 -> 775 bytes
-rw-r--r--src/silx/resources/gui/icons/profile-clear.png (renamed from silx/resources/gui/icons/profile-clear.png)bin917 -> 917 bytes
-rw-r--r--src/silx/resources/gui/icons/profile-clear.svg (renamed from silx/resources/gui/icons/profile-clear.svg)0
-rw-r--r--src/silx/resources/gui/icons/profile1D.png (renamed from silx/resources/gui/icons/profile1D.png)bin347 -> 347 bytes
-rw-r--r--src/silx/resources/gui/icons/profile1D.svg (renamed from silx/resources/gui/icons/profile1D.svg)0
-rw-r--r--src/silx/resources/gui/icons/profile2D.png (renamed from silx/resources/gui/icons/profile2D.png)bin1403 -> 1403 bytes
-rw-r--r--src/silx/resources/gui/icons/profile2D.svg (renamed from silx/resources/gui/icons/profile2D.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/remove.png (renamed from silx/resources/gui/icons/remove.png)bin680 -> 680 bytes
-rw-r--r--src/silx/resources/gui/icons/remove.svg2
-rw-r--r--src/silx/resources/gui/icons/rm.png (renamed from silx/resources/gui/icons/rm.png)bin348 -> 348 bytes
-rw-r--r--src/silx/resources/gui/icons/rm.svg (renamed from silx/resources/gui/icons/rm.svg)0
-rw-r--r--src/silx/resources/gui/icons/rotate-3d.png (renamed from silx/resources/gui/icons/rotate-3d.png)bin760 -> 760 bytes
-rw-r--r--src/silx/resources/gui/icons/rotate-3d.svg (renamed from silx/resources/gui/icons/rotate-3d.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/rudder.png (renamed from silx/resources/gui/icons/rudder.png)bin877 -> 877 bytes
-rw-r--r--src/silx/resources/gui/icons/rudder.svg (renamed from silx/resources/gui/icons/rudder.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/selected.png (renamed from silx/resources/gui/icons/selected.png)bin1411 -> 1411 bytes
-rw-r--r--src/silx/resources/gui/icons/selected.svg (renamed from silx/resources/gui/icons/selected.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/shape-circle-solid.png (renamed from silx/resources/gui/icons/shape-circle-solid.png)bin562 -> 562 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-circle-solid.svg (renamed from silx/resources/gui/icons/shape-circle-solid.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/shape-circle.png (renamed from silx/resources/gui/icons/shape-circle.png)bin722 -> 722 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-circle.svg (renamed from silx/resources/gui/icons/shape-circle.svg)0
-rw-r--r--src/silx/resources/gui/icons/shape-cross.png (renamed from silx/resources/gui/icons/shape-cross.png)bin356 -> 356 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-cross.svg (renamed from silx/resources/gui/icons/shape-cross.svg)0
-rw-r--r--src/silx/resources/gui/icons/shape-diagonal-directed.png (renamed from silx/resources/gui/icons/shape-diagonal-directed.png)bin542 -> 542 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-diagonal-directed.svg (renamed from silx/resources/gui/icons/shape-diagonal-directed.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/shape-diagonal.png (renamed from silx/resources/gui/icons/shape-diagonal.png)bin461 -> 461 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-diagonal.svg (renamed from silx/resources/gui/icons/shape-diagonal.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/shape-ellipse-solid.png (renamed from silx/resources/gui/icons/shape-ellipse-solid.png)bin541 -> 541 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-ellipse-solid.svg (renamed from silx/resources/gui/icons/shape-ellipse-solid.svg)0
-rw-r--r--src/silx/resources/gui/icons/shape-ellipse.png (renamed from silx/resources/gui/icons/shape-ellipse.png)bin643 -> 643 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-ellipse.svg (renamed from silx/resources/gui/icons/shape-ellipse.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/shape-horizontal.png (renamed from silx/resources/gui/icons/shape-horizontal.png)bin301 -> 301 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-horizontal.svg (renamed from silx/resources/gui/icons/shape-horizontal.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/shape-polygon.png (renamed from silx/resources/gui/icons/shape-polygon.png)bin819 -> 819 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-polygon.svg (renamed from silx/resources/gui/icons/shape-polygon.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/shape-rectangle.png (renamed from silx/resources/gui/icons/shape-rectangle.png)bin337 -> 337 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-rectangle.svg (renamed from silx/resources/gui/icons/shape-rectangle.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/shape-square.png (renamed from silx/resources/gui/icons/shape-square.png)bin417 -> 417 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-square.svg (renamed from silx/resources/gui/icons/shape-square.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/shape-vertical.png (renamed from silx/resources/gui/icons/shape-vertical.png)bin294 -> 294 bytes
-rw-r--r--src/silx/resources/gui/icons/shape-vertical.svg (renamed from silx/resources/gui/icons/shape-vertical.svg)0
-rw-r--r--src/silx/resources/gui/icons/side-histograms.pngbin0 -> 518 bytes
-rw-r--r--src/silx/resources/gui/icons/side-histograms.svg2
-rwxr-xr-xsrc/silx/resources/gui/icons/silx.png (renamed from silx/resources/gui/icons/silx.png)bin2048 -> 2048 bytes
-rw-r--r--src/silx/resources/gui/icons/silx.svg (renamed from silx/resources/gui/icons/silx.svg)0
-rw-r--r--src/silx/resources/gui/icons/slice-cross.png (renamed from silx/resources/gui/icons/slice-cross.png)bin1057 -> 1057 bytes
-rw-r--r--src/silx/resources/gui/icons/slice-cross.svg (renamed from silx/resources/gui/icons/slice-cross.svg)0
-rw-r--r--src/silx/resources/gui/icons/slice-horizontal.png (renamed from silx/resources/gui/icons/slice-horizontal.png)bin967 -> 967 bytes
-rw-r--r--src/silx/resources/gui/icons/slice-horizontal.svg (renamed from silx/resources/gui/icons/slice-horizontal.svg)0
-rw-r--r--src/silx/resources/gui/icons/slice-vertical.png (renamed from silx/resources/gui/icons/slice-vertical.png)bin1023 -> 1023 bytes
-rw-r--r--src/silx/resources/gui/icons/slice-vertical.svg (renamed from silx/resources/gui/icons/slice-vertical.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/sliders-off.png (renamed from silx/resources/gui/icons/sliders-off.png)bin1111 -> 1111 bytes
-rw-r--r--src/silx/resources/gui/icons/sliders-off.svg (renamed from silx/resources/gui/icons/sliders-off.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/sliders-on.png (renamed from silx/resources/gui/icons/sliders-on.png)bin691 -> 691 bytes
-rw-r--r--src/silx/resources/gui/icons/sliders-on.svg (renamed from silx/resources/gui/icons/sliders-on.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/spec.png (renamed from silx/resources/gui/icons/spec.png)bin1044 -> 1044 bytes
-rw-r--r--src/silx/resources/gui/icons/spec.svg (renamed from silx/resources/gui/icons/spec.svg)0
-rw-r--r--src/silx/resources/gui/icons/stats-active-items.png (renamed from silx/resources/gui/icons/stats-active-items.png)bin1521 -> 1521 bytes
-rw-r--r--src/silx/resources/gui/icons/stats-active-items.svg (renamed from silx/resources/gui/icons/stats-active-items.svg)0
-rw-r--r--src/silx/resources/gui/icons/stats-visible-data.png (renamed from silx/resources/gui/icons/stats-visible-data.png)bin662 -> 662 bytes
-rw-r--r--src/silx/resources/gui/icons/stats-visible-data.svg (renamed from silx/resources/gui/icons/stats-visible-data.svg)0
-rw-r--r--src/silx/resources/gui/icons/stats-whole-data.png (renamed from silx/resources/gui/icons/stats-whole-data.png)bin923 -> 923 bytes
-rw-r--r--src/silx/resources/gui/icons/stats-whole-data.svg (renamed from silx/resources/gui/icons/stats-whole-data.svg)0
-rw-r--r--src/silx/resources/gui/icons/stats-whole-items.png (renamed from silx/resources/gui/icons/stats-whole-items.png)bin1333 -> 1333 bytes
-rw-r--r--src/silx/resources/gui/icons/stats-whole-items.svg (renamed from silx/resources/gui/icons/stats-whole-items.svg)0
-rw-r--r--src/silx/resources/gui/icons/tree-collapse-all.png (renamed from silx/resources/gui/icons/tree-collapse-all.png)bin508 -> 508 bytes
-rw-r--r--src/silx/resources/gui/icons/tree-collapse-all.svg (renamed from silx/resources/gui/icons/tree-collapse-all.svg)0
-rw-r--r--src/silx/resources/gui/icons/tree-expand-all.png (renamed from silx/resources/gui/icons/tree-expand-all.png)bin602 -> 602 bytes
-rw-r--r--src/silx/resources/gui/icons/tree-expand-all.svg (renamed from silx/resources/gui/icons/tree-expand-all.svg)0
-rw-r--r--src/silx/resources/gui/icons/tree-sort.png (renamed from silx/resources/gui/icons/tree-sort.png)bin655 -> 655 bytes
-rw-r--r--src/silx/resources/gui/icons/tree-sort.svg (renamed from silx/resources/gui/icons/tree-sort.svg)0
-rw-r--r--src/silx/resources/gui/icons/view-1d.png (renamed from silx/resources/gui/icons/view-1d.png)bin881 -> 881 bytes
-rw-r--r--src/silx/resources/gui/icons/view-1d.svg (renamed from silx/resources/gui/icons/view-1d.svg)0
-rw-r--r--src/silx/resources/gui/icons/view-2d-stack.png (renamed from silx/resources/gui/icons/view-2d-stack.png)bin710 -> 710 bytes
-rw-r--r--src/silx/resources/gui/icons/view-2d-stack.svg (renamed from silx/resources/gui/icons/view-2d-stack.svg)0
-rw-r--r--src/silx/resources/gui/icons/view-2d.png (renamed from silx/resources/gui/icons/view-2d.png)bin304 -> 304 bytes
-rw-r--r--src/silx/resources/gui/icons/view-2d.svg (renamed from silx/resources/gui/icons/view-2d.svg)0
-rw-r--r--src/silx/resources/gui/icons/view-3d.png (renamed from silx/resources/gui/icons/view-3d.png)bin1073 -> 1073 bytes
-rw-r--r--src/silx/resources/gui/icons/view-3d.svg (renamed from silx/resources/gui/icons/view-3d.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/view-fullscreen.png (renamed from silx/resources/gui/icons/view-fullscreen.png)bin1829 -> 1829 bytes
-rw-r--r--src/silx/resources/gui/icons/view-fullscreen.svg (renamed from silx/resources/gui/icons/view-fullscreen.svg)0
-rw-r--r--src/silx/resources/gui/icons/view-hdf5.png (renamed from silx/resources/gui/icons/view-hdf5.png)bin1347 -> 1347 bytes
-rw-r--r--src/silx/resources/gui/icons/view-hdf5.svg (renamed from silx/resources/gui/icons/view-hdf5.svg)0
-rw-r--r--src/silx/resources/gui/icons/view-nexus.png (renamed from silx/resources/gui/icons/view-nexus.png)bin1332 -> 1332 bytes
-rw-r--r--src/silx/resources/gui/icons/view-nexus.svg (renamed from silx/resources/gui/icons/view-nexus.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/view-nofullscreen.png (renamed from silx/resources/gui/icons/view-nofullscreen.png)bin1799 -> 1799 bytes
-rw-r--r--src/silx/resources/gui/icons/view-nofullscreen.svg (renamed from silx/resources/gui/icons/view-nofullscreen.svg)0
-rw-r--r--src/silx/resources/gui/icons/view-raw.png (renamed from silx/resources/gui/icons/view-raw.png)bin641 -> 641 bytes
-rw-r--r--src/silx/resources/gui/icons/view-raw.svg (renamed from silx/resources/gui/icons/view-raw.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/view-refresh.png (renamed from silx/resources/gui/icons/view-refresh.png)bin1184 -> 1184 bytes
-rw-r--r--src/silx/resources/gui/icons/view-refresh.svg (renamed from silx/resources/gui/icons/view-refresh.svg)0
-rw-r--r--src/silx/resources/gui/icons/view-text.png (renamed from silx/resources/gui/icons/view-text.png)bin872 -> 872 bytes
-rw-r--r--src/silx/resources/gui/icons/view-text.svg (renamed from silx/resources/gui/icons/view-text.svg)0
-rwxr-xr-xsrc/silx/resources/gui/icons/window-new.png (renamed from silx/resources/gui/icons/window-new.png)bin698 -> 698 bytes
-rw-r--r--src/silx/resources/gui/icons/window-new.svg (renamed from silx/resources/gui/icons/window-new.svg)0
-rw-r--r--src/silx/resources/gui/icons/zoom-back.png (renamed from silx/resources/gui/icons/zoom-back.png)bin1432 -> 1432 bytes
-rw-r--r--src/silx/resources/gui/icons/zoom-back.svg2
-rwxr-xr-xsrc/silx/resources/gui/icons/zoom-in.png (renamed from silx/resources/gui/icons/zoom-in.png)bin1612 -> 1612 bytes
-rw-r--r--src/silx/resources/gui/icons/zoom-in.svg2
-rwxr-xr-xsrc/silx/resources/gui/icons/zoom-original.png (renamed from silx/resources/gui/icons/zoom-original.png)bin1518 -> 1518 bytes
-rw-r--r--src/silx/resources/gui/icons/zoom-original.svg2
-rwxr-xr-xsrc/silx/resources/gui/icons/zoom-out.png (renamed from silx/resources/gui/icons/zoom-out.png)bin1567 -> 1567 bytes
-rw-r--r--src/silx/resources/gui/icons/zoom-out.svg2
-rwxr-xr-xsrc/silx/resources/gui/icons/zoom.png (renamed from silx/resources/gui/icons/zoom.png)bin1448 -> 1448 bytes
-rw-r--r--src/silx/resources/gui/icons/zoom.svg2
-rw-r--r--src/silx/resources/gui/logo/silx.png (renamed from silx/resources/gui/logo/silx.png)bin21257 -> 21257 bytes
-rw-r--r--src/silx/resources/gui/logo/silx.svg (renamed from silx/resources/gui/logo/silx.svg)0
-rw-r--r--src/silx/resources/opencl/addition.cl (renamed from silx/resources/opencl/addition.cl)0
-rw-r--r--src/silx/resources/opencl/array_utils.cl (renamed from silx/resources/opencl/array_utils.cl)0
-rw-r--r--src/silx/resources/opencl/backproj.cl (renamed from silx/resources/opencl/backproj.cl)0
-rw-r--r--src/silx/resources/opencl/backproj_helper.cl (renamed from silx/resources/opencl/backproj_helper.cl)0
-rw-r--r--src/silx/resources/opencl/bitonic.cl (renamed from silx/resources/opencl/bitonic.cl)0
-rw-r--r--src/silx/resources/opencl/codec/byte_offset.cl (renamed from silx/resources/opencl/codec/byte_offset.cl)0
-rw-r--r--src/silx/resources/opencl/convolution.cl (renamed from silx/resources/opencl/convolution.cl)0
-rw-r--r--src/silx/resources/opencl/convolution_textures.cl (renamed from silx/resources/opencl/convolution_textures.cl)0
-rw-r--r--src/silx/resources/opencl/doubleword.cl (renamed from silx/resources/opencl/doubleword.cl)0
-rw-r--r--src/silx/resources/opencl/image/cast.cl (renamed from silx/resources/opencl/image/cast.cl)0
-rw-r--r--src/silx/resources/opencl/image/histogram.cl (renamed from silx/resources/opencl/image/histogram.cl)0
-rw-r--r--src/silx/resources/opencl/image/map.cl (renamed from silx/resources/opencl/image/map.cl)0
-rw-r--r--src/silx/resources/opencl/image/max_min.cl (renamed from silx/resources/opencl/image/max_min.cl)0
-rw-r--r--src/silx/resources/opencl/kahan.cl (renamed from silx/resources/opencl/kahan.cl)0
-rw-r--r--src/silx/resources/opencl/linalg.cl (renamed from silx/resources/opencl/linalg.cl)0
-rw-r--r--src/silx/resources/opencl/medfilt.cl (renamed from silx/resources/opencl/medfilt.cl)0
-rw-r--r--src/silx/resources/opencl/preprocess.cl (renamed from silx/resources/opencl/preprocess.cl)0
-rw-r--r--src/silx/resources/opencl/proj.cl (renamed from silx/resources/opencl/proj.cl)0
-rw-r--r--src/silx/resources/opencl/sparse.cl (renamed from silx/resources/opencl/sparse.cl)0
-rw-r--r--src/silx/resources/opencl/statistics.cl (renamed from silx/resources/opencl/statistics.cl)0
-rw-r--r--src/silx/setup.py54
-rw-r--r--src/silx/sx/__init__.py (renamed from silx/sx/__init__.py)0
-rw-r--r--src/silx/sx/_plot.py625
-rw-r--r--src/silx/sx/_plot3d.py (renamed from silx/sx/_plot3d.py)0
-rw-r--r--src/silx/test/__init__.py53
-rw-r--r--src/silx/test/test_resources.py187
-rw-r--r--src/silx/test/test_sx.py265
-rw-r--r--src/silx/test/test_version.py38
-rw-r--r--src/silx/test/utils.py198
-rw-r--r--src/silx/third_party/EdfFile.py (renamed from silx/third_party/EdfFile.py)0
-rw-r--r--src/silx/third_party/TiffIO.py (renamed from silx/third_party/TiffIO.py)0
-rw-r--r--src/silx/third_party/__init__.py (renamed from silx/third_party/__init__.py)0
-rw-r--r--src/silx/third_party/scipy_spatial.py (renamed from silx/third_party/scipy_spatial.py)0
-rw-r--r--src/silx/third_party/setup.py49
-rw-r--r--src/silx/utils/ExternalResources.py321
-rw-r--r--src/silx/utils/__init__.py (renamed from silx/utils/__init__.py)0
-rw-r--r--src/silx/utils/_have_openmp.pxd49
-rw-r--r--src/silx/utils/array_like.py595
-rw-r--r--src/silx/utils/debug.py100
-rw-r--r--src/silx/utils/deprecation.py (renamed from silx/utils/deprecation.py)0
-rw-r--r--src/silx/utils/enum.py (renamed from silx/utils/enum.py)0
-rw-r--r--src/silx/utils/exceptions.py (renamed from silx/utils/exceptions.py)0
-rw-r--r--src/silx/utils/files.py (renamed from silx/utils/files.py)0
-rw-r--r--src/silx/utils/html.py37
-rw-r--r--src/silx/utils/include/silx_store_openmp.h (renamed from silx/utils/include/silx_store_openmp.h)0
-rw-r--r--src/silx/utils/launcher.py (renamed from silx/utils/launcher.py)0
-rwxr-xr-xsrc/silx/utils/number.py (renamed from silx/utils/number.py)0
-rw-r--r--src/silx/utils/property.py (renamed from silx/utils/property.py)0
-rw-r--r--src/silx/utils/proxy.py208
-rw-r--r--src/silx/utils/retry.py (renamed from silx/utils/retry.py)0
-rw-r--r--src/silx/utils/setup.py (renamed from silx/utils/setup.py)0
-rwxr-xr-xsrc/silx/utils/test/__init__.py24
-rw-r--r--src/silx/utils/test/test_array_like.py430
-rw-r--r--src/silx/utils/test/test_debug.py88
-rw-r--r--src/silx/utils/test/test_deprecation.py96
-rw-r--r--src/silx/utils/test/test_enum.py85
-rw-r--r--src/silx/utils/test/test_external_resources.py89
-rw-r--r--src/silx/utils/test/test_launcher.py191
-rw-r--r--src/silx/utils/test/test_launcher_command.py (renamed from silx/utils/test/test_launcher_command.py)0
-rw-r--r--src/silx/utils/test/test_number.py175
-rw-r--r--src/silx/utils/test/test_proxy.py330
-rw-r--r--src/silx/utils/test/test_retry.py169
-rwxr-xr-xsrc/silx/utils/test/test_testutils.py94
-rw-r--r--src/silx/utils/test/test_weakref.py315
-rwxr-xr-xsrc/silx/utils/testutils.py351
-rw-r--r--src/silx/utils/weakref.py (renamed from silx/utils/weakref.py)0
-rw-r--r--version.py120
1441 files changed, 127646 insertions, 128285 deletions
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 3556fbb..4e99a93 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,134 @@
Change Log
==========
+1.0.0: 2021/11/XX
+-----------------
+
+This the first version of `silx` supporting `PySide6` (for `Qt6`) and using `pytest` to run the tests.
+
+* `silx view`:
+
+ * Added Windows installer generation (PR #3548)
+ * Updated 'About' dialog (#3547, #3475)
+ * Fixed: Keep curve legend selection with changing dimensions (PR #3529)
+ * Fixed: Increase max number of opened file at start-up (PR #3545)
+
+* `silx.gui`:
+
+ * Added PySide6 support (PR #3486, #3528, #3479, #3542, #3549, #3478, #3481):
+ * Removed support of PyQt4 / Pyside (PR #3423, #3424, #3480, #3482)
+ * `silx.gui.colors`:
+
+ * Fixed duplicated logs when colormap vmin/vmax are not valid (PR #3471)
+
+ * `silx.gui.plot`:
+
+ * `silx.gui.plot.actions`:
+
+ * `silx.gui.plot.actions.fit`:
+
+ * Updated behaviour of fitted item auto update (PR #3532)
+
+ * `silx.gui.plot.actions.histogram`:
+
+ * Enhanced: Allow user to change histogram nbins and range (PR #3514, #3514)
+ * Updated `PixelIntensitiesHistoAction` to use `PlotWidget.selection` (PR #3408)
+ * Fixed issue when the whole image is masked (PR #3544)
+ * Fixed error on macOS 11 with 3D display in `silx view` (PR #3544)
+
+ * `silx.gui.plot.CompareImages`:
+
+ * Fixed `colormap`: avoid forcing vmin and vmax when not in 'HORIZONTAL_LINE' or 'VERTICAL_LINE' mode (PR #3510)
+
+ * `silx.gui.plot.items`:
+
+ * Added 'image_aggregated.ImageDataAggregated': item allowing to aggregate image data before display (PR #3503)
+ * Fixed `ArcROI.setGeometry` (fix #3492)
+
+ * `silx.gui.plot.ImageStack`:
+
+ * Enhanced management of the `animation thread` (PR #3440, PR #3441)
+
+ * `silx.gui.plot.ImageView`:
+
+ * Added action to show/hide the side histogram (PR #3488)
+ * Added 'resetzoom' parameter to 'ImageView.setImage' (PR #3488)
+ * Added empty array support to 'ImageView.setImage' (PR #3530)
+ * Added aggregation mode action (PR #3536)
+ * Added support of RGB and RGBA images (PR #3487)
+ * Updated 'imageview' example with a '--live' option (PR #3488)
+ * Fixed profile window, added `setProfileWindowBehavior` method (PR #3457)
+ * Fixed issue with profile window size (PR #3455)
+
+ * `silx.gui.plot.PlotWidget`:
+
+ * Fixed update of `Scatter` item binned statistics visualization (PR #3452)
+ * Fixed OpenGL backend memory leak (PR #3453)
+ * Enhanced: Optimized scatter when rendered as regular grid with the OpenGL backend (PR #3447)
+ * Enhanced axis limits management by the OpenGL backend (PR #3504)
+ * Enhanced control of repaint (PR #3449)
+ * Enhanced text label background rendering with OpenGL backend (PR #3565)
+
+ * `silx.gui.plot.PlotWindow`:
+
+ * Fixed returned action from 'getKeepDataAspectRatioAction' (PR #3500)
+
+ * `silx.gui.plot3d`:
+
+ * Fixed picking on highdpi screen (PR #3550)
+ * Fixed issue in parameter tree (PR #3550)
+
+* `silx.io`:
+
+ * Added read support for FIO files (PR #3539) thanks to tifuchs contribution
+ * `silx.io.dictdump`:
+
+ * Fixed missing conversion of the key (PR #3505) thanks to rnwatanabe contribution
+ * Extract update modes list to a constant global variable (PR #3460) thanks to jpcbertoldo
+
+ * `silx.io.convert`:
+
+ * Enhanced `write_to_h5`: `infile` parameter can now also be a HDF5 file as input (PR #3511)
+
+ * `silx.io.h5py_utils`:
+
+ * Added support of `locking` argument from the h5py.File when possible (PR #3554)
+ * Added log a critical message for unsupported versions of libhdf5 (PR #3533)
+
+ * `silx.io.spech5`:
+
+ * Enhanced: Improve robustness (PR #3507, #3463)
+
+ * `silx.io.url`:
+
+ * Fixed `is_absolute` in the case the `file_path()` returns None (PR #3437)
+
+ * `silx.io.utils`:
+
+ * Added 'silx.io.utils.visitall': provides a visitor of all items including links that works for both `commonh5` and `h5py` (PR #3511)
+
+* `silx.math`:
+
+ * `silx.math.colormap`:
+
+ * Added `apply_colormap` function (PR #3525)
+ * Enhanced `cmap` error messages (PR #3522)
+
+* `silx.opencl`:
+
+ * Added description of compute capabilities for Ampere generation GPU from Nvidia (PR #3535)
+ * Added doubleword OpenCL library (PR #3466, PR #3472)
+
+* Miscellaneous:
+
+ * Enhanced: Setup the project to use `pytest` (PR #3431, #3516, #3526)
+ * Enhanced: Minor test clean up (PR #3515, #3508)
+ * Updated project structure: move `silx` sources in `src/silx` (PR #3412)
+ * Fixed 'run_test.py --qt-binding' option (PR #3527)
+ * Fixed support of numpy 1.21rc1 (PR ##3476)
+ * Removed `six` dependency (PR #3483)
+
+
0.15.2: 2021/06/21
------------------
@@ -31,7 +159,6 @@ Minor release:
* Fixed profile window default behavior (PR #3458)
* Added `setProfileWindowBehavior` method (PR #3458)
-
0.15.0: 2021/03/18
------------------
@@ -101,7 +228,6 @@ Main new features are the `silx.io.h5py_utils` module which provides `h5py` conc
* Fixed debian packaging (PR #3362)
* Fixed `silx test` application on Windows (PR #3411)
-
0.14.1: 2021/04/30
------------------
@@ -110,7 +236,6 @@ This is a bug-fix version of silx.
* silx.gui.plot: Fixed `PlotWidget` OpenGL backend memory leak (PR #3445)
* silx.gui.utils.glutils: Fixed `isOpenGLAvailable` (PR #3356)
-
0.14.0: 2020/12/11
------------------
diff --git a/MANIFEST.in b/MANIFEST.in
index da024c2..5ee9a8d 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -4,18 +4,17 @@ include copyright
include LICENSE
include MANIFEST.in
include run_tests.py
-include version.py
include stdeb.cfg
include build-deb.sh
include requirements.txt
include requirements-dev.txt
include pyproject.toml
-recursive-include silx *.pyx *.pxd *.pxi
-recursive-include silx *.h *.c *.hpp *.cpp
+recursive-include src/silx *.pyx *.pxd *.pxi
+recursive-include src/silx *.h *.c *.hpp *.cpp
recursive-include doc/source *.py *.rst *.png *.ico *.ipynb
global-exclude .ipynb_checkpoints/*
recursive-include qtdesigner_plugins *.py *.rst
-recursive-include silx/resources *
+recursive-include src/silx/resources *
recursive-include examples *
recursive-include package *
diff --git a/PKG-INFO b/PKG-INFO
index 04ec406..43179ac 100644
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,13 +1,13 @@
Metadata-Version: 2.1
Name: silx
-Version: 0.15.2
+Version: 1.0.0
Summary: Software library for X-ray data analysis
Home-page: http://www.silx.org/
Author: data analysis unit
Author-email: silx@esrf.fr
License: UNKNOWN
Platform: UNKNOWN
-Classifier: Development Status :: 4 - Beta
+Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Console
Classifier: Environment :: MacOS X
Classifier: Environment :: Win32 (MS Windows)
@@ -26,12 +26,16 @@ Classifier: Topic :: Scientific/Engineering :: Physics
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.5
Provides-Extra: full
+Provides-Extra: test
License-File: LICENSE
silx toolkit
============
+.. |silxView| image:: http://www.silx.org/doc/silx/img/silx-view-v1-0.gif
+ :height: 480px
+
The purpose of the *silx* project is to provide a collection of Python packages to support the
development of data assessment, reduction and analysis applications at synchrotron
radiation facilities.
@@ -59,8 +63,12 @@ The current version features:
* a set of applications:
* a unified viewer (*silx view filename*) for HDF5, SPEC and image file formats
+
+ |silxView|
+
* a unified converter to HDF5 format (*silx convert filename*)
+
Installation
------------
diff --git a/README.rst b/README.rst
index 37915ef..5b8a3ca 100644
--- a/README.rst
+++ b/README.rst
@@ -2,6 +2,9 @@
silx toolkit
============
+.. |silxView| image:: http://www.silx.org/doc/silx/img/silx-view-v1-0.gif
+ :height: 480px
+
The purpose of the *silx* project is to provide a collection of Python packages to support the
development of data assessment, reduction and analysis applications at synchrotron
radiation facilities.
@@ -29,8 +32,12 @@ The current version features:
* a set of applications:
* a unified viewer (*silx view filename*) for HDF5, SPEC and image file formats
+
+ |silxView|
+
* a unified converter to HDF5 format (*silx convert filename*)
+
Installation
------------
diff --git a/build-deb.sh b/build-deb.sh
index 25718f3..23b0a86 100755
--- a/build-deb.sh
+++ b/build-deb.sh
@@ -3,7 +3,7 @@
# Project: Silx
# https://github.com/silx-kit/silx
#
-# Copyright (C) 2015-2020 European Synchrotron Radiation Facility, Grenoble, France
+# Copyright (C) 2015-2021 European Synchrotron Radiation Facility, Grenoble, France
#
# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
@@ -29,9 +29,6 @@
project=silx
source_project=silx
-version=$(python3 -c"import version; print(version.version)")
-strictversion=$(python3 -c"import version; print(version.strictversion)")
-debianversion=$(python3 -c"import version; print(version.debianversion)")
deb_name=$(echo "$source_project" | tr '[:upper:]' '[:lower:]')
@@ -66,6 +63,13 @@ project_directory="`( cd \"$project_directory\" && pwd )`" # absolutized
dist_directory=${project_directory}/dist/${target_system}
build_directory=${project_directory}/build/${target_system}
+# Get version info
+cd ${project_directory}/src/${project}
+version=$(python3 -c"import _version; print(_version.version)")
+strictversion=$(python3 -c"import _version; print(_version.strictversion)")
+debianversion=$(python3 -c"import _version; print(_version.debianversion)")
+cd ${project_directory}
+
if [ -d /usr/lib/ccache ];
then
export PATH=/usr/lib/ccache:$PATH
diff --git a/doc/source/Tutorials/array_widget.rst b/doc/source/Tutorials/array_widget.rst
index c6a32e3..b0bc890 100644
--- a/doc/source/Tutorials/array_widget.rst
+++ b/doc/source/Tutorials/array_widget.rst
@@ -23,7 +23,7 @@ Let's look at a simple usage example:
w = ArrayTableWidget()
w.setArrayData(array, labels=True)
w.show()
- app.exec_()
+ app.exec()
.. |imgArray0| image:: img/arraywidget3D_0.png
@@ -242,6 +242,6 @@ of RGB colors.
fgcolors=fcolors)
atw.show()
- app.exec_()
+ app.exec()
diff --git a/doc/source/Tutorials/fit.rst b/doc/source/Tutorials/fit.rst
index d9671f4..c0eafc5 100644
--- a/doc/source/Tutorials/fit.rst
+++ b/doc/source/Tutorials/fit.rst
@@ -450,7 +450,7 @@ The following example illustrates the strip background removal process:
app = qt.QApplication([])
plot(x, y, x, actual_bg, x, strip_bg)
plot(x, y, x, (y - strip_bg))
- app.exec_()
+ app.exec()
.. |imgStrip1| image:: img/stripbg_plot1.png
:height: 300px
@@ -545,7 +545,7 @@ Simple usage
w.setData(x=x, y=y)
w.show()
- a.exec_()
+ a.exec()
.. |imgFitWidget1| image:: img/fitwidget1.png
:width: 300px
@@ -638,7 +638,7 @@ The :class:`FitWidget` can be initialised with a non-standard
fw = FitWidget(fitmngr=myfitmngr)
fw.show()
- a.exec_()
+ a.exec()
In our previous example, we didn't load a customised :class:`FitManager`,
therefore, the fit widget automatically initialised the default fit manager and
diff --git a/doc/source/Tutorials/fitconfig.rst b/doc/source/Tutorials/fitconfig.rst
index 225ef8f..0d7538c 100644
--- a/doc/source/Tutorials/fitconfig.rst
+++ b/doc/source/Tutorials/fitconfig.rst
@@ -52,7 +52,7 @@ dialog by FitWidget:
- :meth:`show`: should cause the widget to become visible to the
user)
- - :meth:`exec_`: should run while the user is interacting with the
+ - :meth:`exec`: should run while the user is interacting with the
widget, interrupting the rest of the program. It should
typically end (*return*) when the user clicks an *OK*
or a *Cancel* button.
@@ -175,7 +175,7 @@ used by our fit function to scale the *y* values.
fw.associateConfigDialog("scaled linear", CustomConfigWidget())
fw.show()
- app.exec_()
+ app.exec()
.. |img0| image:: img/custom_config_scale1.0.png
:height: 300px
diff --git a/doc/source/Tutorials/writing_NXdata.rst b/doc/source/Tutorials/writing_NXdata.rst
index 1c65199..4d87e3d 100644
--- a/doc/source/Tutorials/writing_NXdata.rst
+++ b/doc/source/Tutorials/writing_NXdata.rst
@@ -154,8 +154,7 @@ a *frame number*.
.. note::
- This additional attribute is not mentionned in the official NXdata
- specification.
+ This attribute is documented in the official NeXus `description <https://manual.nexusformat.org/nxdl_desc.html>`_
Writing NXdata with h5py
diff --git a/doc/source/applications/view.rst b/doc/source/applications/view.rst
index d4145c2..747a121 100644
--- a/doc/source/applications/view.rst
+++ b/doc/source/applications/view.rst
@@ -1,7 +1,11 @@
+.. _silx view:
silx view
=========
+.. figure:: http://www.silx.org/doc/silx/img/silx-view-v1-0.gif
+ :align: center
+
Purpose
-------
diff --git a/doc/source/ext/snapshotqt_directive.py b/doc/source/ext/snapshotqt_directive.py
index 582b934..84b3ac6 100644
--- a/doc/source/ext/snapshotqt_directive.py
+++ b/doc/source/ext/snapshotqt_directive.py
@@ -152,7 +152,7 @@ else:
_towrite = _line.lstrip(' ')
if not _towrite.startswith(':'):
_file.write(_towrite + '\n')
- _file.write("app.exec_()")
+ _file.write("app.exec()")
self.content = []
if script is not None:
_logger.warning('Cannot specify a script if source code (content) is given.'
diff --git a/doc/source/index.rst b/doc/source/index.rst
index 1c4ad72..027bd6f 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -1,6 +1,9 @@
silx |version|
==============
+.. |silxView| image:: http://www.silx.org/doc/silx/img/silx-view-v1-0.gif
+ :height: 80px
+
The silx project aims to provide a collection of Python packages to support the
development of data assessment, reduction and analysis at synchrotron radiation
facilities.
@@ -28,9 +31,13 @@ The current version (v\ |version|) caters for:
* a set of applications:
- * a unified viewer (*silx view filename*) for HDF5, SPEC and image file formats
- * a unified converter to HDF5 format (*silx convert filename*)
+ * a unified viewer (:ref:`silx view` *filename*) for HDF5, SPEC and image file formats
+
+ |silxView|
+
+ * a unified converter to HDF5 format (*silx convert filename*)
+
.. toctree::
:hidden:
diff --git a/doc/source/install.rst b/doc/source/install.rst
index b0d6b4f..0841c2a 100644
--- a/doc/source/install.rst
+++ b/doc/source/install.rst
@@ -7,15 +7,15 @@ programming language.
This table summarizes the support matrix of silx:
-+------------+--------------+---------------------+
-| System | Python vers. | Qt and its bindings |
-+------------+--------------+---------------------+
-| `Windows`_ | 3.6-3.9 | PyQt5.6+, PySide2 |
-+------------+--------------+---------------------+
-| `MacOS`_ | 3.6-3.9 | PyQt5.6+, PySide2 |
-+------------+--------------+---------------------+
-| `Linux`_ | 3.6-3.9 | PyQt5.3+, PySide2 |
-+------------+--------------+---------------------+
++------------+--------------+----------------------------+
+| System | Python vers. | Qt and its bindings |
++------------+--------------+----------------------------+
+| `Windows`_ | 3.6-3.9 | PyQt5.6+, PySide2, PySide6 |
++------------+--------------+----------------------------+
+| `MacOS`_ | 3.6-3.9 | PyQt5.6+, PySide2, PySide6 |
++------------+--------------+----------------------------+
+| `Linux`_ | 3.6-3.9 | PyQt5.3+, PySide2, PySide6 |
++------------+--------------+----------------------------+
For the description of *silx* dependencies, see the Dependencies_ section.
@@ -66,7 +66,8 @@ The mandatory dependencies are:
The GUI widgets depend on the following extra packages:
* A Qt binding: either `PyQt5 <https://riverbankcomputing.com/software/pyqt/intro>`_,
- or `PySide2 <https://wiki.qt.io/Qt_for_Python>`_
+ `PySide2 <https://pypi.org/project/PySide2/>`_, or
+ `PySide6 <https://pypi.org/project/PySide6/>`_
* `matplotlib <http://matplotlib.org/>`_
* `PyOpenGL <http://pyopengl.sourceforge.net/>`_
* `qt_console <https://pypi.org/project/qtconsole>`_
@@ -245,7 +246,7 @@ installed using:
.. code-block:: bash
- pip install -r https://github.com/silx-kit/silx/raw/0.8/requirements-dev.txt
+ pip install -r https://github.com/silx-kit/silx/raw/master/requirements-dev.txt
Building from source
diff --git a/doc/source/modules/gui/icons.rst b/doc/source/modules/gui/icons.rst
index 67235c2..cb8e044 100644
--- a/doc/source/modules/gui/icons.rst
+++ b/doc/source/modules/gui/icons.rst
@@ -55,6 +55,8 @@ Available icons
- add-shape-vertical
* - |add|
- add
+ * - |aggregation-mode|
+ - aggregation-mode
* - |arrow-keys|
- arrow-keys
* - |axis|
@@ -339,6 +341,8 @@ Available icons
- shape-square
* - |shape-vertical|
- shape-vertical
+ * - |side-histograms|
+ - side-histograms
* - |silx|
- silx
* - |slice-cross|
@@ -402,194 +406,196 @@ Available icons
* - |zoom|
- zoom
-.. |3d-plane-normal-x| image:: ../../../../silx/resources/gui/icons/3d-plane-normal-x.png
-.. |3d-plane-normal-y| image:: ../../../../silx/resources/gui/icons/3d-plane-normal-y.png
-.. |3d-plane-normal-z| image:: ../../../../silx/resources/gui/icons/3d-plane-normal-z.png
-.. |3d-plane-pan| image:: ../../../../silx/resources/gui/icons/3d-plane-pan.png
-.. |3d-plane| image:: ../../../../silx/resources/gui/icons/3d-plane.png
-.. |add-range-horizontal| image:: ../../../../silx/resources/gui/icons/add-range-horizontal.png
-.. |add-shape-arc| image:: ../../../../silx/resources/gui/icons/add-shape-arc.png
-.. |add-shape-circle| image:: ../../../../silx/resources/gui/icons/add-shape-circle.png
-.. |add-shape-cross| image:: ../../../../silx/resources/gui/icons/add-shape-cross.png
-.. |add-shape-diagonal| image:: ../../../../silx/resources/gui/icons/add-shape-diagonal.png
-.. |add-shape-ellipse| image:: ../../../../silx/resources/gui/icons/add-shape-ellipse.png
-.. |add-shape-horizontal| image:: ../../../../silx/resources/gui/icons/add-shape-horizontal.png
-.. |add-shape-point| image:: ../../../../silx/resources/gui/icons/add-shape-point.png
-.. |add-shape-polygon| image:: ../../../../silx/resources/gui/icons/add-shape-polygon.png
-.. |add-shape-rectangle| image:: ../../../../silx/resources/gui/icons/add-shape-rectangle.png
-.. |add-shape-unknown| image:: ../../../../silx/resources/gui/icons/add-shape-unknown.png
-.. |add-shape-vertical| image:: ../../../../silx/resources/gui/icons/add-shape-vertical.png
-.. |add| image:: ../../../../silx/resources/gui/icons/add.png
-.. |arrow-keys| image:: ../../../../silx/resources/gui/icons/arrow-keys.png
-.. |axis| image:: ../../../../silx/resources/gui/icons/axis.png
-.. |backend-opengl| image:: ../../../../silx/resources/gui/icons/backend-opengl.png
-.. |camera| image:: ../../../../silx/resources/gui/icons/camera.png
-.. |clipboard| image:: ../../../../silx/resources/gui/icons/clipboard.png
-.. |close| image:: ../../../../silx/resources/gui/icons/close.png
-.. |colorbar| image:: ../../../../silx/resources/gui/icons/colorbar.png
-.. |colormap-histogram| image:: ../../../../silx/resources/gui/icons/colormap-histogram.png
-.. |colormap-none| image:: ../../../../silx/resources/gui/icons/colormap-none.png
-.. |colormap-norm-arcsinh| image:: ../../../../silx/resources/gui/icons/colormap-norm-arcsinh.png
-.. |colormap-norm-gamma| image:: ../../../../silx/resources/gui/icons/colormap-norm-gamma.png
-.. |colormap-norm-linear| image:: ../../../../silx/resources/gui/icons/colormap-norm-linear.png
-.. |colormap-norm-log| image:: ../../../../silx/resources/gui/icons/colormap-norm-log.png
-.. |colormap-norm-sqrt| image:: ../../../../silx/resources/gui/icons/colormap-norm-sqrt.png
-.. |colormap-range| image:: ../../../../silx/resources/gui/icons/colormap-range.png
-.. |colormap| image:: ../../../../silx/resources/gui/icons/colormap.png
-.. |compare-align-auto| image:: ../../../../silx/resources/gui/icons/compare-align-auto.png
-.. |compare-align-center| image:: ../../../../silx/resources/gui/icons/compare-align-center.png
-.. |compare-align-origin| image:: ../../../../silx/resources/gui/icons/compare-align-origin.png
-.. |compare-align-stretch| image:: ../../../../silx/resources/gui/icons/compare-align-stretch.png
-.. |compare-keypoints| image:: ../../../../silx/resources/gui/icons/compare-keypoints.png
-.. |compare-mode-a-minus-b| image:: ../../../../silx/resources/gui/icons/compare-mode-a-minus-b.png
-.. |compare-mode-a| image:: ../../../../silx/resources/gui/icons/compare-mode-a.png
-.. |compare-mode-b| image:: ../../../../silx/resources/gui/icons/compare-mode-b.png
-.. |compare-mode-hline| image:: ../../../../silx/resources/gui/icons/compare-mode-hline.png
-.. |compare-mode-rb-channel| image:: ../../../../silx/resources/gui/icons/compare-mode-rb-channel.png
-.. |compare-mode-rbneg-channel| image:: ../../../../silx/resources/gui/icons/compare-mode-rbneg-channel.png
-.. |compare-mode-vline| image:: ../../../../silx/resources/gui/icons/compare-mode-vline.png
-.. |crop| image:: ../../../../silx/resources/gui/icons/crop.png
-.. |crosshair| image:: ../../../../silx/resources/gui/icons/crosshair.png
-.. |cube-back| image:: ../../../../silx/resources/gui/icons/cube-back.png
-.. |cube-bottom| image:: ../../../../silx/resources/gui/icons/cube-bottom.png
-.. |cube-front| image:: ../../../../silx/resources/gui/icons/cube-front.png
-.. |cube-left| image:: ../../../../silx/resources/gui/icons/cube-left.png
-.. |cube-right| image:: ../../../../silx/resources/gui/icons/cube-right.png
-.. |cube-rotate| image:: ../../../../silx/resources/gui/icons/cube-rotate.png
-.. |cube-top| image:: ../../../../silx/resources/gui/icons/cube-top.png
-.. |cube| image:: ../../../../silx/resources/gui/icons/cube.png
-.. |description-description| image:: ../../../../silx/resources/gui/icons/description-description.png
-.. |description-error| image:: ../../../../silx/resources/gui/icons/description-error.png
-.. |description-name| image:: ../../../../silx/resources/gui/icons/description-name.png
-.. |description-program| image:: ../../../../silx/resources/gui/icons/description-program.png
-.. |description-title| image:: ../../../../silx/resources/gui/icons/description-title.png
-.. |description-value| image:: ../../../../silx/resources/gui/icons/description-value.png
-.. |document-open| image:: ../../../../silx/resources/gui/icons/document-open.png
-.. |document-print| image:: ../../../../silx/resources/gui/icons/document-print.png
-.. |document-save| image:: ../../../../silx/resources/gui/icons/document-save.png
-.. |draw-brush| image:: ../../../../silx/resources/gui/icons/draw-brush.png
-.. |draw-pencil| image:: ../../../../silx/resources/gui/icons/draw-pencil.png
-.. |draw-rubber| image:: ../../../../silx/resources/gui/icons/draw-rubber.png
-.. |edit-copy| image:: ../../../../silx/resources/gui/icons/edit-copy.png
-.. |eye| image:: ../../../../silx/resources/gui/icons/eye.png
-.. |first| image:: ../../../../silx/resources/gui/icons/first.png
-.. |folder| image:: ../../../../silx/resources/gui/icons/folder.png
-.. |image-mask| image:: ../../../../silx/resources/gui/icons/image-mask.png
-.. |image-select-add| image:: ../../../../silx/resources/gui/icons/image-select-add.png
-.. |image-select-box| image:: ../../../../silx/resources/gui/icons/image-select-box.png
-.. |image-select-brush| image:: ../../../../silx/resources/gui/icons/image-select-brush.png
-.. |image-select-erase-rubber| image:: ../../../../silx/resources/gui/icons/image-select-erase-rubber.png
-.. |image-select-erase| image:: ../../../../silx/resources/gui/icons/image-select-erase.png
-.. |image| image:: ../../../../silx/resources/gui/icons/image.png
-.. |item-0dim| image:: ../../../../silx/resources/gui/icons/item-0dim.png
-.. |item-1dim| image:: ../../../../silx/resources/gui/icons/item-1dim.png
-.. |item-2dim| image:: ../../../../silx/resources/gui/icons/item-2dim.png
-.. |item-3dim| image:: ../../../../silx/resources/gui/icons/item-3dim.png
-.. |item-ndim| image:: ../../../../silx/resources/gui/icons/item-ndim.png
-.. |item-none| image:: ../../../../silx/resources/gui/icons/item-none.png
-.. |item-object| image:: ../../../../silx/resources/gui/icons/item-object.png
-.. |last| image:: ../../../../silx/resources/gui/icons/last.png
-.. |layer-nx| image:: ../../../../silx/resources/gui/icons/layer-nx.png
-.. |mask-clear-all| image:: ../../../../silx/resources/gui/icons/mask-clear-all.png
-.. |mask-clear| image:: ../../../../silx/resources/gui/icons/mask-clear.png
-.. |mask-invert| image:: ../../../../silx/resources/gui/icons/mask-invert.png
-.. |math-amplitude| image:: ../../../../silx/resources/gui/icons/math-amplitude.png
-.. |math-average| image:: ../../../../silx/resources/gui/icons/math-average.png
-.. |math-derive| image:: ../../../../silx/resources/gui/icons/math-derive.png
-.. |math-energy| image:: ../../../../silx/resources/gui/icons/math-energy.png
-.. |math-fit| image:: ../../../../silx/resources/gui/icons/math-fit.png
-.. |math-imaginary| image:: ../../../../silx/resources/gui/icons/math-imaginary.png
-.. |math-mean| image:: ../../../../silx/resources/gui/icons/math-mean.png
-.. |math-normalize| image:: ../../../../silx/resources/gui/icons/math-normalize.png
-.. |math-peak-reset| image:: ../../../../silx/resources/gui/icons/math-peak-reset.png
-.. |math-peak-search| image:: ../../../../silx/resources/gui/icons/math-peak-search.png
-.. |math-peak| image:: ../../../../silx/resources/gui/icons/math-peak.png
-.. |math-phase-color-log| image:: ../../../../silx/resources/gui/icons/math-phase-color-log.png
-.. |math-phase-color| image:: ../../../../silx/resources/gui/icons/math-phase-color.png
-.. |math-phase| image:: ../../../../silx/resources/gui/icons/math-phase.png
-.. |math-real| image:: ../../../../silx/resources/gui/icons/math-real.png
-.. |math-sigma| image:: ../../../../silx/resources/gui/icons/math-sigma.png
-.. |math-smooth| image:: ../../../../silx/resources/gui/icons/math-smooth.png
-.. |math-square-amplitude| image:: ../../../../silx/resources/gui/icons/math-square-amplitude.png
-.. |math-substract| image:: ../../../../silx/resources/gui/icons/math-substract.png
-.. |math-swap-sign| image:: ../../../../silx/resources/gui/icons/math-swap-sign.png
-.. |math-ymin-to-zero| image:: ../../../../silx/resources/gui/icons/math-ymin-to-zero.png
-.. |median-filter| image:: ../../../../silx/resources/gui/icons/median-filter.png
-.. |next| image:: ../../../../silx/resources/gui/icons/next.png
-.. |normal| image:: ../../../../silx/resources/gui/icons/normal.png
-.. |nxdata-axis-add| image:: ../../../../silx/resources/gui/icons/nxdata-axis-add.png
-.. |nxdata-axis-remove| image:: ../../../../silx/resources/gui/icons/nxdata-axis-remove.png
-.. |nxdata-create| image:: ../../../../silx/resources/gui/icons/nxdata-create.png
-.. |nxdata-remove| image:: ../../../../silx/resources/gui/icons/nxdata-remove.png
-.. |pan| image:: ../../../../silx/resources/gui/icons/pan.png
-.. |pixel-intensities| image:: ../../../../silx/resources/gui/icons/pixel-intensities.png
-.. |plot-grid| image:: ../../../../silx/resources/gui/icons/plot-grid.png
-.. |plot-roi-above| image:: ../../../../silx/resources/gui/icons/plot-roi-above.png
-.. |plot-roi-below| image:: ../../../../silx/resources/gui/icons/plot-roi-below.png
-.. |plot-roi-between| image:: ../../../../silx/resources/gui/icons/plot-roi-between.png
-.. |plot-roi-reset| image:: ../../../../silx/resources/gui/icons/plot-roi-reset.png
-.. |plot-roi| image:: ../../../../silx/resources/gui/icons/plot-roi.png
-.. |plot-symbols| image:: ../../../../silx/resources/gui/icons/plot-symbols.png
-.. |plot-toggle-points| image:: ../../../../silx/resources/gui/icons/plot-toggle-points.png
-.. |plot-widget| image:: ../../../../silx/resources/gui/icons/plot-widget.png
-.. |plot-window-image| image:: ../../../../silx/resources/gui/icons/plot-window-image.png
-.. |plot-window| image:: ../../../../silx/resources/gui/icons/plot-window.png
-.. |plot-xauto| image:: ../../../../silx/resources/gui/icons/plot-xauto.png
-.. |plot-xlog| image:: ../../../../silx/resources/gui/icons/plot-xlog.png
-.. |plot-yauto| image:: ../../../../silx/resources/gui/icons/plot-yauto.png
-.. |plot-ydown| image:: ../../../../silx/resources/gui/icons/plot-ydown.png
-.. |plot-ylog| image:: ../../../../silx/resources/gui/icons/plot-ylog.png
-.. |plot-yup| image:: ../../../../silx/resources/gui/icons/plot-yup.png
-.. |pointing-hand| image:: ../../../../silx/resources/gui/icons/pointing-hand.png
-.. |previous| image:: ../../../../silx/resources/gui/icons/previous.png
-.. |profile-clear| image:: ../../../../silx/resources/gui/icons/profile-clear.png
-.. |profile1D| image:: ../../../../silx/resources/gui/icons/profile1D.png
-.. |profile2D| image:: ../../../../silx/resources/gui/icons/profile2D.png
-.. |remove| image:: ../../../../silx/resources/gui/icons/remove.png
-.. |rm| image:: ../../../../silx/resources/gui/icons/rm.png
-.. |rotate-3d| image:: ../../../../silx/resources/gui/icons/rotate-3d.png
-.. |rudder| image:: ../../../../silx/resources/gui/icons/rudder.png
-.. |selected| image:: ../../../../silx/resources/gui/icons/selected.png
-.. |shape-circle-solid| image:: ../../../../silx/resources/gui/icons/shape-circle-solid.png
-.. |shape-circle| image:: ../../../../silx/resources/gui/icons/shape-circle.png
-.. |shape-cross| image:: ../../../../silx/resources/gui/icons/shape-cross.png
-.. |shape-diagonal-directed| image:: ../../../../silx/resources/gui/icons/shape-diagonal-directed.png
-.. |shape-diagonal| image:: ../../../../silx/resources/gui/icons/shape-diagonal.png
-.. |shape-ellipse-solid| image:: ../../../../silx/resources/gui/icons/shape-ellipse-solid.png
-.. |shape-ellipse| image:: ../../../../silx/resources/gui/icons/shape-ellipse.png
-.. |shape-horizontal| image:: ../../../../silx/resources/gui/icons/shape-horizontal.png
-.. |shape-polygon| image:: ../../../../silx/resources/gui/icons/shape-polygon.png
-.. |shape-rectangle| image:: ../../../../silx/resources/gui/icons/shape-rectangle.png
-.. |shape-square| image:: ../../../../silx/resources/gui/icons/shape-square.png
-.. |shape-vertical| image:: ../../../../silx/resources/gui/icons/shape-vertical.png
-.. |silx| image:: ../../../../silx/resources/gui/icons/silx.png
-.. |slice-cross| image:: ../../../../silx/resources/gui/icons/slice-cross.png
-.. |slice-horizontal| image:: ../../../../silx/resources/gui/icons/slice-horizontal.png
-.. |slice-vertical| image:: ../../../../silx/resources/gui/icons/slice-vertical.png
-.. |sliders-off| image:: ../../../../silx/resources/gui/icons/sliders-off.png
-.. |sliders-on| image:: ../../../../silx/resources/gui/icons/sliders-on.png
-.. |spec| image:: ../../../../silx/resources/gui/icons/spec.png
-.. |stats-active-items| image:: ../../../../silx/resources/gui/icons/stats-active-items.png
-.. |stats-visible-data| image:: ../../../../silx/resources/gui/icons/stats-visible-data.png
-.. |stats-whole-data| image:: ../../../../silx/resources/gui/icons/stats-whole-data.png
-.. |stats-whole-items| image:: ../../../../silx/resources/gui/icons/stats-whole-items.png
-.. |tree-collapse-all| image:: ../../../../silx/resources/gui/icons/tree-collapse-all.png
-.. |tree-expand-all| image:: ../../../../silx/resources/gui/icons/tree-expand-all.png
-.. |tree-sort| image:: ../../../../silx/resources/gui/icons/tree-sort.png
-.. |view-1d| image:: ../../../../silx/resources/gui/icons/view-1d.png
-.. |view-2d-stack| image:: ../../../../silx/resources/gui/icons/view-2d-stack.png
-.. |view-2d| image:: ../../../../silx/resources/gui/icons/view-2d.png
-.. |view-3d| image:: ../../../../silx/resources/gui/icons/view-3d.png
-.. |view-fullscreen| image:: ../../../../silx/resources/gui/icons/view-fullscreen.png
-.. |view-hdf5| image:: ../../../../silx/resources/gui/icons/view-hdf5.png
-.. |view-nexus| image:: ../../../../silx/resources/gui/icons/view-nexus.png
-.. |view-nofullscreen| image:: ../../../../silx/resources/gui/icons/view-nofullscreen.png
-.. |view-raw| image:: ../../../../silx/resources/gui/icons/view-raw.png
-.. |view-refresh| image:: ../../../../silx/resources/gui/icons/view-refresh.png
-.. |view-text| image:: ../../../../silx/resources/gui/icons/view-text.png
-.. |window-new| image:: ../../../../silx/resources/gui/icons/window-new.png
-.. |zoom-back| image:: ../../../../silx/resources/gui/icons/zoom-back.png
-.. |zoom-in| image:: ../../../../silx/resources/gui/icons/zoom-in.png
-.. |zoom-original| image:: ../../../../silx/resources/gui/icons/zoom-original.png
-.. |zoom-out| image:: ../../../../silx/resources/gui/icons/zoom-out.png
-.. |zoom| image:: ../../../../silx/resources/gui/icons/zoom.png
+.. |3d-plane-normal-x| image:: ../../../../src/silx/resources/gui/icons/3d-plane-normal-x.png
+.. |3d-plane-normal-y| image:: ../../../../src/silx/resources/gui/icons/3d-plane-normal-y.png
+.. |3d-plane-normal-z| image:: ../../../../src/silx/resources/gui/icons/3d-plane-normal-z.png
+.. |3d-plane-pan| image:: ../../../../src/silx/resources/gui/icons/3d-plane-pan.png
+.. |3d-plane| image:: ../../../../src/silx/resources/gui/icons/3d-plane.png
+.. |add-range-horizontal| image:: ../../../../src/silx/resources/gui/icons/add-range-horizontal.png
+.. |add-shape-arc| image:: ../../../../src/silx/resources/gui/icons/add-shape-arc.png
+.. |add-shape-circle| image:: ../../../../src/silx/resources/gui/icons/add-shape-circle.png
+.. |add-shape-cross| image:: ../../../../src/silx/resources/gui/icons/add-shape-cross.png
+.. |add-shape-diagonal| image:: ../../../../src/silx/resources/gui/icons/add-shape-diagonal.png
+.. |add-shape-ellipse| image:: ../../../../src/silx/resources/gui/icons/add-shape-ellipse.png
+.. |add-shape-horizontal| image:: ../../../../src/silx/resources/gui/icons/add-shape-horizontal.png
+.. |add-shape-point| image:: ../../../../src/silx/resources/gui/icons/add-shape-point.png
+.. |add-shape-polygon| image:: ../../../../src/silx/resources/gui/icons/add-shape-polygon.png
+.. |add-shape-rectangle| image:: ../../../../src/silx/resources/gui/icons/add-shape-rectangle.png
+.. |add-shape-unknown| image:: ../../../../src/silx/resources/gui/icons/add-shape-unknown.png
+.. |add-shape-vertical| image:: ../../../../src/silx/resources/gui/icons/add-shape-vertical.png
+.. |add| image:: ../../../../src/silx/resources/gui/icons/add.png
+.. |aggregation-mode| image:: ../../../../src/silx/resources/gui/icons/aggregation-mode.png
+.. |arrow-keys| image:: ../../../../src/silx/resources/gui/icons/arrow-keys.png
+.. |axis| image:: ../../../../src/silx/resources/gui/icons/axis.png
+.. |backend-opengl| image:: ../../../../src/silx/resources/gui/icons/backend-opengl.png
+.. |camera| image:: ../../../../src/silx/resources/gui/icons/camera.png
+.. |clipboard| image:: ../../../../src/silx/resources/gui/icons/clipboard.png
+.. |close| image:: ../../../../src/silx/resources/gui/icons/close.png
+.. |colorbar| image:: ../../../../src/silx/resources/gui/icons/colorbar.png
+.. |colormap-histogram| image:: ../../../../src/silx/resources/gui/icons/colormap-histogram.png
+.. |colormap-none| image:: ../../../../src/silx/resources/gui/icons/colormap-none.png
+.. |colormap-norm-arcsinh| image:: ../../../../src/silx/resources/gui/icons/colormap-norm-arcsinh.png
+.. |colormap-norm-gamma| image:: ../../../../src/silx/resources/gui/icons/colormap-norm-gamma.png
+.. |colormap-norm-linear| image:: ../../../../src/silx/resources/gui/icons/colormap-norm-linear.png
+.. |colormap-norm-log| image:: ../../../../src/silx/resources/gui/icons/colormap-norm-log.png
+.. |colormap-norm-sqrt| image:: ../../../../src/silx/resources/gui/icons/colormap-norm-sqrt.png
+.. |colormap-range| image:: ../../../../src/silx/resources/gui/icons/colormap-range.png
+.. |colormap| image:: ../../../../src/silx/resources/gui/icons/colormap.png
+.. |compare-align-auto| image:: ../../../../src/silx/resources/gui/icons/compare-align-auto.png
+.. |compare-align-center| image:: ../../../../src/silx/resources/gui/icons/compare-align-center.png
+.. |compare-align-origin| image:: ../../../../src/silx/resources/gui/icons/compare-align-origin.png
+.. |compare-align-stretch| image:: ../../../../src/silx/resources/gui/icons/compare-align-stretch.png
+.. |compare-keypoints| image:: ../../../../src/silx/resources/gui/icons/compare-keypoints.png
+.. |compare-mode-a-minus-b| image:: ../../../../src/silx/resources/gui/icons/compare-mode-a-minus-b.png
+.. |compare-mode-a| image:: ../../../../src/silx/resources/gui/icons/compare-mode-a.png
+.. |compare-mode-b| image:: ../../../../src/silx/resources/gui/icons/compare-mode-b.png
+.. |compare-mode-hline| image:: ../../../../src/silx/resources/gui/icons/compare-mode-hline.png
+.. |compare-mode-rb-channel| image:: ../../../../src/silx/resources/gui/icons/compare-mode-rb-channel.png
+.. |compare-mode-rbneg-channel| image:: ../../../../src/silx/resources/gui/icons/compare-mode-rbneg-channel.png
+.. |compare-mode-vline| image:: ../../../../src/silx/resources/gui/icons/compare-mode-vline.png
+.. |crop| image:: ../../../../src/silx/resources/gui/icons/crop.png
+.. |crosshair| image:: ../../../../src/silx/resources/gui/icons/crosshair.png
+.. |cube-back| image:: ../../../../src/silx/resources/gui/icons/cube-back.png
+.. |cube-bottom| image:: ../../../../src/silx/resources/gui/icons/cube-bottom.png
+.. |cube-front| image:: ../../../../src/silx/resources/gui/icons/cube-front.png
+.. |cube-left| image:: ../../../../src/silx/resources/gui/icons/cube-left.png
+.. |cube-right| image:: ../../../../src/silx/resources/gui/icons/cube-right.png
+.. |cube-rotate| image:: ../../../../src/silx/resources/gui/icons/cube-rotate.png
+.. |cube-top| image:: ../../../../src/silx/resources/gui/icons/cube-top.png
+.. |cube| image:: ../../../../src/silx/resources/gui/icons/cube.png
+.. |description-description| image:: ../../../../src/silx/resources/gui/icons/description-description.png
+.. |description-error| image:: ../../../../src/silx/resources/gui/icons/description-error.png
+.. |description-name| image:: ../../../../src/silx/resources/gui/icons/description-name.png
+.. |description-program| image:: ../../../../src/silx/resources/gui/icons/description-program.png
+.. |description-title| image:: ../../../../src/silx/resources/gui/icons/description-title.png
+.. |description-value| image:: ../../../../src/silx/resources/gui/icons/description-value.png
+.. |document-open| image:: ../../../../src/silx/resources/gui/icons/document-open.png
+.. |document-print| image:: ../../../../src/silx/resources/gui/icons/document-print.png
+.. |document-save| image:: ../../../../src/silx/resources/gui/icons/document-save.png
+.. |draw-brush| image:: ../../../../src/silx/resources/gui/icons/draw-brush.png
+.. |draw-pencil| image:: ../../../../src/silx/resources/gui/icons/draw-pencil.png
+.. |draw-rubber| image:: ../../../../src/silx/resources/gui/icons/draw-rubber.png
+.. |edit-copy| image:: ../../../../src/silx/resources/gui/icons/edit-copy.png
+.. |eye| image:: ../../../../src/silx/resources/gui/icons/eye.png
+.. |first| image:: ../../../../src/silx/resources/gui/icons/first.png
+.. |folder| image:: ../../../../src/silx/resources/gui/icons/folder.png
+.. |image-mask| image:: ../../../../src/silx/resources/gui/icons/image-mask.png
+.. |image-select-add| image:: ../../../../src/silx/resources/gui/icons/image-select-add.png
+.. |image-select-box| image:: ../../../../src/silx/resources/gui/icons/image-select-box.png
+.. |image-select-brush| image:: ../../../../src/silx/resources/gui/icons/image-select-brush.png
+.. |image-select-erase-rubber| image:: ../../../../src/silx/resources/gui/icons/image-select-erase-rubber.png
+.. |image-select-erase| image:: ../../../../src/silx/resources/gui/icons/image-select-erase.png
+.. |image| image:: ../../../../src/silx/resources/gui/icons/image.png
+.. |item-0dim| image:: ../../../../src/silx/resources/gui/icons/item-0dim.png
+.. |item-1dim| image:: ../../../../src/silx/resources/gui/icons/item-1dim.png
+.. |item-2dim| image:: ../../../../src/silx/resources/gui/icons/item-2dim.png
+.. |item-3dim| image:: ../../../../src/silx/resources/gui/icons/item-3dim.png
+.. |item-ndim| image:: ../../../../src/silx/resources/gui/icons/item-ndim.png
+.. |item-none| image:: ../../../../src/silx/resources/gui/icons/item-none.png
+.. |item-object| image:: ../../../../src/silx/resources/gui/icons/item-object.png
+.. |last| image:: ../../../../src/silx/resources/gui/icons/last.png
+.. |layer-nx| image:: ../../../../src/silx/resources/gui/icons/layer-nx.png
+.. |mask-clear-all| image:: ../../../../src/silx/resources/gui/icons/mask-clear-all.png
+.. |mask-clear| image:: ../../../../src/silx/resources/gui/icons/mask-clear.png
+.. |mask-invert| image:: ../../../../src/silx/resources/gui/icons/mask-invert.png
+.. |math-amplitude| image:: ../../../../src/silx/resources/gui/icons/math-amplitude.png
+.. |math-average| image:: ../../../../src/silx/resources/gui/icons/math-average.png
+.. |math-derive| image:: ../../../../src/silx/resources/gui/icons/math-derive.png
+.. |math-energy| image:: ../../../../src/silx/resources/gui/icons/math-energy.png
+.. |math-fit| image:: ../../../../src/silx/resources/gui/icons/math-fit.png
+.. |math-imaginary| image:: ../../../../src/silx/resources/gui/icons/math-imaginary.png
+.. |math-mean| image:: ../../../../src/silx/resources/gui/icons/math-mean.png
+.. |math-normalize| image:: ../../../../src/silx/resources/gui/icons/math-normalize.png
+.. |math-peak-reset| image:: ../../../../src/silx/resources/gui/icons/math-peak-reset.png
+.. |math-peak-search| image:: ../../../../src/silx/resources/gui/icons/math-peak-search.png
+.. |math-peak| image:: ../../../../src/silx/resources/gui/icons/math-peak.png
+.. |math-phase-color-log| image:: ../../../../src/silx/resources/gui/icons/math-phase-color-log.png
+.. |math-phase-color| image:: ../../../../src/silx/resources/gui/icons/math-phase-color.png
+.. |math-phase| image:: ../../../../src/silx/resources/gui/icons/math-phase.png
+.. |math-real| image:: ../../../../src/silx/resources/gui/icons/math-real.png
+.. |math-sigma| image:: ../../../../src/silx/resources/gui/icons/math-sigma.png
+.. |math-smooth| image:: ../../../../src/silx/resources/gui/icons/math-smooth.png
+.. |math-square-amplitude| image:: ../../../../src/silx/resources/gui/icons/math-square-amplitude.png
+.. |math-substract| image:: ../../../../src/silx/resources/gui/icons/math-substract.png
+.. |math-swap-sign| image:: ../../../../src/silx/resources/gui/icons/math-swap-sign.png
+.. |math-ymin-to-zero| image:: ../../../../src/silx/resources/gui/icons/math-ymin-to-zero.png
+.. |median-filter| image:: ../../../../src/silx/resources/gui/icons/median-filter.png
+.. |next| image:: ../../../../src/silx/resources/gui/icons/next.png
+.. |normal| image:: ../../../../src/silx/resources/gui/icons/normal.png
+.. |nxdata-axis-add| image:: ../../../../src/silx/resources/gui/icons/nxdata-axis-add.png
+.. |nxdata-axis-remove| image:: ../../../../src/silx/resources/gui/icons/nxdata-axis-remove.png
+.. |nxdata-create| image:: ../../../../src/silx/resources/gui/icons/nxdata-create.png
+.. |nxdata-remove| image:: ../../../../src/silx/resources/gui/icons/nxdata-remove.png
+.. |pan| image:: ../../../../src/silx/resources/gui/icons/pan.png
+.. |pixel-intensities| image:: ../../../../src/silx/resources/gui/icons/pixel-intensities.png
+.. |plot-grid| image:: ../../../../src/silx/resources/gui/icons/plot-grid.png
+.. |plot-roi-above| image:: ../../../../src/silx/resources/gui/icons/plot-roi-above.png
+.. |plot-roi-below| image:: ../../../../src/silx/resources/gui/icons/plot-roi-below.png
+.. |plot-roi-between| image:: ../../../../src/silx/resources/gui/icons/plot-roi-between.png
+.. |plot-roi-reset| image:: ../../../../src/silx/resources/gui/icons/plot-roi-reset.png
+.. |plot-roi| image:: ../../../../src/silx/resources/gui/icons/plot-roi.png
+.. |plot-symbols| image:: ../../../../src/silx/resources/gui/icons/plot-symbols.png
+.. |plot-toggle-points| image:: ../../../../src/silx/resources/gui/icons/plot-toggle-points.png
+.. |plot-widget| image:: ../../../../src/silx/resources/gui/icons/plot-widget.png
+.. |plot-window-image| image:: ../../../../src/silx/resources/gui/icons/plot-window-image.png
+.. |plot-window| image:: ../../../../src/silx/resources/gui/icons/plot-window.png
+.. |plot-xauto| image:: ../../../../src/silx/resources/gui/icons/plot-xauto.png
+.. |plot-xlog| image:: ../../../../src/silx/resources/gui/icons/plot-xlog.png
+.. |plot-yauto| image:: ../../../../src/silx/resources/gui/icons/plot-yauto.png
+.. |plot-ydown| image:: ../../../../src/silx/resources/gui/icons/plot-ydown.png
+.. |plot-ylog| image:: ../../../../src/silx/resources/gui/icons/plot-ylog.png
+.. |plot-yup| image:: ../../../../src/silx/resources/gui/icons/plot-yup.png
+.. |pointing-hand| image:: ../../../../src/silx/resources/gui/icons/pointing-hand.png
+.. |previous| image:: ../../../../src/silx/resources/gui/icons/previous.png
+.. |profile-clear| image:: ../../../../src/silx/resources/gui/icons/profile-clear.png
+.. |profile1D| image:: ../../../../src/silx/resources/gui/icons/profile1D.png
+.. |profile2D| image:: ../../../../src/silx/resources/gui/icons/profile2D.png
+.. |remove| image:: ../../../../src/silx/resources/gui/icons/remove.png
+.. |rm| image:: ../../../../src/silx/resources/gui/icons/rm.png
+.. |rotate-3d| image:: ../../../../src/silx/resources/gui/icons/rotate-3d.png
+.. |rudder| image:: ../../../../src/silx/resources/gui/icons/rudder.png
+.. |selected| image:: ../../../../src/silx/resources/gui/icons/selected.png
+.. |shape-circle-solid| image:: ../../../../src/silx/resources/gui/icons/shape-circle-solid.png
+.. |shape-circle| image:: ../../../../src/silx/resources/gui/icons/shape-circle.png
+.. |shape-cross| image:: ../../../../src/silx/resources/gui/icons/shape-cross.png
+.. |shape-diagonal-directed| image:: ../../../../src/silx/resources/gui/icons/shape-diagonal-directed.png
+.. |shape-diagonal| image:: ../../../../src/silx/resources/gui/icons/shape-diagonal.png
+.. |shape-ellipse-solid| image:: ../../../../src/silx/resources/gui/icons/shape-ellipse-solid.png
+.. |shape-ellipse| image:: ../../../../src/silx/resources/gui/icons/shape-ellipse.png
+.. |shape-horizontal| image:: ../../../../src/silx/resources/gui/icons/shape-horizontal.png
+.. |shape-polygon| image:: ../../../../src/silx/resources/gui/icons/shape-polygon.png
+.. |shape-rectangle| image:: ../../../../src/silx/resources/gui/icons/shape-rectangle.png
+.. |shape-square| image:: ../../../../src/silx/resources/gui/icons/shape-square.png
+.. |shape-vertical| image:: ../../../../src/silx/resources/gui/icons/shape-vertical.png
+.. |side-histograms| image:: ../../../../src/silx/resources/gui/icons/side-histograms.png
+.. |silx| image:: ../../../../src/silx/resources/gui/icons/silx.png
+.. |slice-cross| image:: ../../../../src/silx/resources/gui/icons/slice-cross.png
+.. |slice-horizontal| image:: ../../../../src/silx/resources/gui/icons/slice-horizontal.png
+.. |slice-vertical| image:: ../../../../src/silx/resources/gui/icons/slice-vertical.png
+.. |sliders-off| image:: ../../../../src/silx/resources/gui/icons/sliders-off.png
+.. |sliders-on| image:: ../../../../src/silx/resources/gui/icons/sliders-on.png
+.. |spec| image:: ../../../../src/silx/resources/gui/icons/spec.png
+.. |stats-active-items| image:: ../../../../src/silx/resources/gui/icons/stats-active-items.png
+.. |stats-visible-data| image:: ../../../../src/silx/resources/gui/icons/stats-visible-data.png
+.. |stats-whole-data| image:: ../../../../src/silx/resources/gui/icons/stats-whole-data.png
+.. |stats-whole-items| image:: ../../../../src/silx/resources/gui/icons/stats-whole-items.png
+.. |tree-collapse-all| image:: ../../../../src/silx/resources/gui/icons/tree-collapse-all.png
+.. |tree-expand-all| image:: ../../../../src/silx/resources/gui/icons/tree-expand-all.png
+.. |tree-sort| image:: ../../../../src/silx/resources/gui/icons/tree-sort.png
+.. |view-1d| image:: ../../../../src/silx/resources/gui/icons/view-1d.png
+.. |view-2d-stack| image:: ../../../../src/silx/resources/gui/icons/view-2d-stack.png
+.. |view-2d| image:: ../../../../src/silx/resources/gui/icons/view-2d.png
+.. |view-3d| image:: ../../../../src/silx/resources/gui/icons/view-3d.png
+.. |view-fullscreen| image:: ../../../../src/silx/resources/gui/icons/view-fullscreen.png
+.. |view-hdf5| image:: ../../../../src/silx/resources/gui/icons/view-hdf5.png
+.. |view-nexus| image:: ../../../../src/silx/resources/gui/icons/view-nexus.png
+.. |view-nofullscreen| image:: ../../../../src/silx/resources/gui/icons/view-nofullscreen.png
+.. |view-raw| image:: ../../../../src/silx/resources/gui/icons/view-raw.png
+.. |view-refresh| image:: ../../../../src/silx/resources/gui/icons/view-refresh.png
+.. |view-text| image:: ../../../../src/silx/resources/gui/icons/view-text.png
+.. |window-new| image:: ../../../../src/silx/resources/gui/icons/window-new.png
+.. |zoom-back| image:: ../../../../src/silx/resources/gui/icons/zoom-back.png
+.. |zoom-in| image:: ../../../../src/silx/resources/gui/icons/zoom-in.png
+.. |zoom-original| image:: ../../../../src/silx/resources/gui/icons/zoom-original.png
+.. |zoom-out| image:: ../../../../src/silx/resources/gui/icons/zoom-out.png
+.. |zoom| image:: ../../../../src/silx/resources/gui/icons/zoom.png
diff --git a/doc/source/modules/gui/plot/getting_started.rst b/doc/source/modules/gui/plot/getting_started.rst
index c105395..1c29f23 100644
--- a/doc/source/modules/gui/plot/getting_started.rst
+++ b/doc/source/modules/gui/plot/getting_started.rst
@@ -20,7 +20,7 @@ For a complete description of the API, see :mod:`silx.gui.plot`.
Use :mod:`silx.gui.plot` from (I)Python console
-----------------------------------------------
-We recommend to use (I)Python >=3.5 and PyQt5.
+We recommend to use (I)Python >=3.6 and PyQt5.
From a Python or IPython interpreter, the simplest way is to import the :mod:`silx.sx` module:
@@ -87,9 +87,9 @@ A Qt GUI script must have a QApplication initialised before creating widgets:
if __name__ == '__main__':
[...]
- qapp.exec_()
+ qapp.exec()
-Unless a Qt binding has already been loaded, :mod:`silx.gui.qt` uses one of the supported Qt bindings (PyQt5, PyQt4, PySide2).
+Unless a Qt binding has already been loaded, :mod:`silx.gui.qt` uses one of the supported Qt bindings (PyQt5, PySide2, PySide6).
If you prefer to choose the Qt binding yourself, import it before importing
a module from :mod:`silx.gui`:
diff --git a/doc/source/modules/gui/widgets/printpreview.rst b/doc/source/modules/gui/widgets/printpreview.rst
index d0b7999..bff2381 100644
--- a/doc/source/modules/gui/widgets/printpreview.rst
+++ b/doc/source/modules/gui/widgets/printpreview.rst
@@ -56,5 +56,5 @@ Example
commentPosition="CENTER")
w.addImage(qt.QImage(filename), comment=comment, commentPosition="LEFT")
- w.exec_()
- a.exec_()
+ w.exec()
+ a.exec()
diff --git a/doc/source/modules/io/fioh5.rst b/doc/source/modules/io/fioh5.rst
new file mode 100644
index 0000000..c901878
--- /dev/null
+++ b/doc/source/modules/io/fioh5.rst
@@ -0,0 +1,35 @@
+
+.. currentmodule:: silx.io
+
+:mod:`fioh5`: h5py-like API to FIO file
+----------------------------------------
+
+.. automodule:: silx.io.fioh5
+
+
+Classes
++++++++
+
+- :class:`FioH5`
+- :class:`FioFile`
+
+.. autoclass:: FioH5
+ :members:
+ :show-inheritance:
+ :undoc-members:
+ :inherited-members: name, basename, attrs, h5py_class, parent,
+ get, keys, values, items,
+ :special-members: __getitem__, __len__, __contains__, __enter__, __exit__, __iter__
+ :exclude-members: add_node
+
+.. autoclass:: FioFile
+
+.. autoclass:: silx.io.commonh5.Group
+ :show-inheritance:
+ :undoc-members:
+ :members: name, basename, file, attrs, h5py_class, parent,
+ get, keys, values, items, visit, visititems
+ :special-members: __getitem__, __len__, __contains__, __iter__
+ :exclude-members: add_node
+
+.. autofunction:: is_fiofile \ No newline at end of file
diff --git a/doc/source/modules/io/index.rst b/doc/source/modules/io/index.rst
index 581f763..a511bef 100644
--- a/doc/source/modules/io/index.rst
+++ b/doc/source/modules/io/index.rst
@@ -16,6 +16,7 @@
specfile.rst
specfilewrapper.rst
spech5.rst
+ fioh5.rst
url.rst
utils.rst
h5py_utils.rst
diff --git a/doc/source/virtualenv.rst b/doc/source/virtualenv.rst
index ccdd9b6..280c031 100644
--- a/doc/source/virtualenv.rst
+++ b/doc/source/virtualenv.rst
@@ -132,7 +132,7 @@ To test *silx*, open an interactive python console:
python
-If you don't have PyQt5 or PySide2, run:
+If you don't have PyQt5, PySide2 or PySide6, run:
.. code-block:: bash
diff --git a/examples/colormapDialog.py b/examples/colormapDialog.py
index c9e7c35..d389327 100644
--- a/examples/colormapDialog.py
+++ b/examples/colormapDialog.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -140,7 +140,7 @@ class ColormapDialogExample(qt.QMainWindow):
button = qt.QPushButton("Negative to positive")
button.clicked.connect(self.setDataFromNegToPos)
layout.addWidget(button)
- button = qt.QPushButton("Only non finite values")
+ button = qt.QPushButton("With non finite values")
button.clicked.connect(self.setDataWithNonFinite)
layout.addWidget(button)
@@ -259,7 +259,7 @@ class ColormapDialogExample(qt.QMainWindow):
if scipy is not None:
from scipy import ndimage
data = ndimage.gaussian_filter(data, sigma=20)
- data = numpy.random.poisson(data)
+ data = numpy.random.poisson(data).astype(numpy.float32)
data[10] = float("nan")
data[50] = float("+inf")
data[100] = float("-inf")
@@ -275,7 +275,7 @@ def main():
example = ColormapDialogExample()
example.show()
- app.exec_()
+ app.exec()
if __name__ == '__main__':
diff --git a/examples/compareImages.py b/examples/compareImages.py
index 623216a..3408a72 100644
--- a/examples/compareImages.py
+++ b/examples/compareImages.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -208,4 +208,4 @@ if __name__ == "__main__":
window.setFiles(options.files)
window.setVisible(True)
- app.exec_()
+ app.exec()
diff --git a/examples/compositeline.py b/examples/compositeline.py
index 892ecf3..72398e6 100644
--- a/examples/compositeline.py
+++ b/examples/compositeline.py
@@ -72,7 +72,7 @@ def main(argv=None):
addLine([70, 0], [70, 100], "_", "_", "l6", "black")
mainWindow.setVisible(True)
- return app.exec_()
+ return app.exec()
if __name__ == "__main__":
diff --git a/examples/customDataView.py b/examples/customDataView.py
index 33662e8..e02e577 100644
--- a/examples/customDataView.py
+++ b/examples/customDataView.py
@@ -96,7 +96,7 @@ def main():
widget.addView(MyColorView(widget))
widget.setData(Color.GREEN)
widget.show()
- result = app.exec_()
+ result = app.exec()
# remove ending warnings relative to QTimer
app.deleteLater()
sys.exit(result)
diff --git a/examples/customHdf5TreeModel.py b/examples/customHdf5TreeModel.py
index fde76c5..ffc0220 100644
--- a/examples/customHdf5TreeModel.py
+++ b/examples/customHdf5TreeModel.py
@@ -279,7 +279,7 @@ def main(filenames):
sys.excepthook = qt.exceptionHandler
window = Hdf5TreeViewExample(filenames)
window.show()
- result = app.exec_()
+ result = app.exec()
# remove ending warnings relative to QTimer
app.deleteLater()
sys.exit(result)
diff --git a/examples/dropZones.py b/examples/dropZones.py
index 68d0a57..6593bbb 100644
--- a/examples/dropZones.py
+++ b/examples/dropZones.py
@@ -140,7 +140,7 @@ class DragLabel(qt.QLabel):
self._url.path().encode(encoding='utf-8'))
drag = qt.QDrag(self)
drag.setMimeData(mimeData)
- dropAction = drag.exec_()
+ dropAction = drag.exec()
class DragAndDropExample(qt.QMainWindow):
@@ -178,7 +178,7 @@ def main():
silx.io.url.DataUrl(file_path=filename, data_path='/curve', scheme="silx")))
example.setWindowTitle("Drag&Drop URLs sample code")
example.show()
- app.exec_()
+ app.exec()
if __name__ == "__main__":
diff --git a/examples/exampleBaseline.py b/examples/exampleBaseline.py
index edd0fc3..b53b412 100644
--- a/examples/exampleBaseline.py
+++ b/examples/exampleBaseline.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -157,7 +157,7 @@ def main(argv):
plot_log = get_plot_log(backend=options.backend)
plot_log.show()
- qapp.exec_()
+ qapp.exec()
if __name__ == '__main__':
diff --git a/examples/fftPlotAction.py b/examples/fftPlotAction.py
index bdea08d..f7c819f 100755
--- a/examples/fftPlotAction.py
+++ b/examples/fftPlotAction.py
@@ -190,5 +190,5 @@ plotwin.getYAxis().setLabel("amplitude")
plotwin.getXAxis().setLabel("time")
plotwin.show()
-app.exec_()
+app.exec()
sys.excepthook = sys.__excepthook__
diff --git a/examples/fileDialog.py b/examples/fileDialog.py
index 40191bb..fa11ed5 100644
--- a/examples/fileDialog.py
+++ b/examples/fileDialog.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -216,7 +216,7 @@ class DialogExample(qt.QMainWindow):
dialog = self.createDialog()
# Execute the dialog as modal
- result = dialog.exec_()
+ result = dialog.exec()
self.printResult(dialog, result)
def openDialogStoredState(self):
@@ -226,7 +226,7 @@ class DialogExample(qt.QMainWindow):
dialog.restoreState(self.__state[dialog.__class__])
# Execute the dialog as modal
- result = dialog.exec_()
+ result = dialog.exec()
self.__state[dialog.__class__] = dialog.saveState()
self.printResult(dialog, result)
@@ -237,7 +237,7 @@ class DialogExample(qt.QMainWindow):
dialog.setDirectory(path)
# Execute the dialog as modal
- result = dialog.exec_()
+ result = dialog.exec()
self.printResult(dialog, result)
def openDialogAtComputer(self):
@@ -247,7 +247,7 @@ class DialogExample(qt.QMainWindow):
dialog.setDirectory(path)
# Execute the dialog as modal
- result = dialog.exec_()
+ result = dialog.exec()
self.printResult(dialog, result)
@@ -255,7 +255,7 @@ def main():
app = qt.QApplication([])
example = DialogExample()
example.show()
- app.exec_()
+ app.exec()
if __name__ == "__main__":
diff --git a/examples/findContours.py b/examples/findContours.py
index 6199ba6..acf5199 100644
--- a/examples/findContours.py
+++ b/examples/findContours.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -694,7 +694,7 @@ def main():
window = FindContours()
window.generateIsland()
window.show()
- result = app.exec_()
+ result = app.exec()
# remove ending warnings relative to QTimer
app.deleteLater()
return result
diff --git a/examples/hdf5widget.py b/examples/hdf5widget.py
index 0d45b8f..82ce27d 100755
--- a/examples/hdf5widget.py
+++ b/examples/hdf5widget.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -25,12 +25,12 @@
# ###########################################################################*/
"""Qt Hdf5 widget examples"""
+import html
import logging
import sys
import tempfile
import numpy
-import six
logging.basicConfig()
_logger = logging.getLogger("hdf5widget")
@@ -46,7 +46,6 @@ except ImportError:
import h5py
import silx.gui.hdf5
-import silx.utils.html
from silx.gui import qt
from silx.gui.data.DataViewerFrame import DataViewerFrame
from silx.gui.widgets.ThreadPoolPushButton import ThreadPoolPushButton
@@ -59,7 +58,7 @@ _file_cache = {}
def str_attrs(str_list):
"""Return a numpy array of unicode strings"""
- text_dtype = h5py.special_dtype(vlen=six.text_type)
+ text_dtype = h5py.special_dtype(vlen=str)
return numpy.array(str_list, dtype=text_dtype)
@@ -573,7 +572,7 @@ class Hdf5TreeViewExample(qt.QMainWindow):
"""Called to log event in widget
"""
def formatKey(name, value):
- name, value = silx.utils.html.escape(str(name)), silx.utils.html.escape(str(value))
+ name, value = html.escape(str(name)), html.escape(str(value))
return "<li><b>%s</b>: %s</li>" % (name, value)
text = "<html>"
@@ -791,7 +790,7 @@ def main(filenames):
sys.excepthook = qt.exceptionHandler
window = Hdf5TreeViewExample(filenames)
window.show()
- result = app.exec_()
+ result = app.exec()
# remove ending warnings relative to QTimer
app.deleteLater()
sys.exit(result)
diff --git a/examples/icons.py b/examples/icons.py
index ae77630..ff8410d 100644
--- a/examples/icons.py
+++ b/examples/icons.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -177,4 +177,4 @@ if __name__ == "__main__":
app = qt.QApplication([])
window = IconPreview()
window.setVisible(True)
- app.exec_()
+ app.exec()
diff --git a/examples/imageStack.py b/examples/imageStack.py
index 0437a6e..4c211b5 100644
--- a/examples/imageStack.py
+++ b/examples/imageStack.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -132,7 +132,7 @@ def main():
urls = create_datasets(folder=dataset_folder)
widget.setUrls(urls=urls)
widget.show()
- qapp.exec_()
+ qapp.exec()
widget.close()
shutil.rmtree(dataset_folder)
diff --git a/examples/imageview.py b/examples/imageview.py
index 81741b1..40b5dff 100755
--- a/examples/imageview.py
+++ b/examples/imageview.py
@@ -40,14 +40,76 @@ __license__ = "MIT"
__date__ = "08/11/2018"
import logging
+import numpy
+import time
+import threading
+
+from silx.gui.utils import concurrent
from silx.gui.plot.ImageView import ImageViewMainWindow
from silx.gui import qt
-import numpy
logging.basicConfig()
logger = logging.getLogger(__name__)
+Nx = 150
+Ny = 50
+
+
+class UpdateThread(threading.Thread):
+ """Thread updating the image of a :class:`~sil.gui.plot.Plot2D`
+
+ :param plot2d: The Plot2D to update."""
+
+ def __init__(self, imageview):
+ self.imageview = imageview
+ self.running = False
+ self.future_result = None
+ super(UpdateThread, self).__init__()
+
+ def start(self):
+ """Start the update thread"""
+ self.running = True
+ super(UpdateThread, self).start()
+
+ def run(self, pos={'x0': 0, 'y0': 0}):
+ """Method implementing thread loop that updates the plot
+
+ It produces an image every 10 ms or so, and
+ either updates the plot or skip the image
+ """
+ while self.running:
+ time.sleep(0.01)
+
+ # Create image
+ # width of peak
+ sigma_x = 0.15
+ sigma_y = 0.25
+ # x and y positions
+ x = numpy.linspace(-1.5, 1.5, Nx)
+ y = numpy.linspace(-1.0, 1.0, Ny)
+ xv, yv = numpy.meshgrid(x, y)
+ signal = numpy.exp(- ((xv - pos['x0']) ** 2 / sigma_x ** 2
+ + (yv - pos['y0']) ** 2 / sigma_y ** 2))
+ # add noise
+ signal += 0.3 * numpy.random.random(size=signal.shape)
+ # random walk of center of peak ('drift')
+ pos['x0'] += 0.05 * (numpy.random.random() - 0.5)
+ pos['y0'] += 0.05 * (numpy.random.random() - 0.5)
+
+ # If previous frame was not added to the plot yet, skip this one
+ if self.future_result is None or self.future_result.done():
+ # plot the data asynchronously, and
+ # keep a reference to the `future` object
+ self.future_result = concurrent.submitToQtMainThread(
+ self.imageview.setImage, signal, resetzoom=False)
+
+ def stop(self):
+ """Stop the update thread"""
+ self.running = False
+ self.join(2)
+
+
def main(argv=None):
"""Display an image from a file in an :class:`ImageView` widget.
@@ -82,12 +144,17 @@ def main(argv=None):
parser.add_argument(
'filename', nargs='?',
help='EDF filename of the image to open')
+ parser.add_argument(
+ '--live', action='store_true',
+ help='Live update of a generated image')
args = parser.parse_args(args=argv)
# Open the input file
- if not args.filename:
+ edfFile = None
+ if args.live:
+ data = None
+ elif not args.filename:
logger.warning('No image file provided, displaying dummy data')
- edfFile = None
size = 512
xx, yy = numpy.ogrid[-size:size, -size:size]
data = numpy.cos(xx / (size//5)) + numpy.cos(yy / (size//5))
@@ -118,9 +185,10 @@ def main(argv=None):
colormap = mainWindow.getDefaultColormap()
colormap.setNormalization(colormap.LOGARITHM)
- mainWindow.setImage(data,
- origin=args.origin,
- scale=args.scale)
+ if data is not None:
+ mainWindow.setImage(data,
+ origin=args.origin,
+ scale=args.scale)
if edfFile is not None and nbFrames > 1:
# Add a toolbar for multi-frame EDF support
@@ -144,7 +212,12 @@ def main(argv=None):
mainWindow.show()
mainWindow.setFocus(qt.Qt.OtherFocusReason)
- return app.exec_()
+ if args.live:
+ # Start updating the plot
+ updateThread = UpdateThread(mainWindow)
+ updateThread.start()
+
+ return app.exec()
if __name__ == "__main__":
diff --git a/examples/periodicTable.py b/examples/periodicTable.py
index e329ef7..fc3985f 100644
--- a/examples/periodicTable.py
+++ b/examples/periodicTable.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -77,5 +77,4 @@ w.addTab(pl, "PeriodicList")
w.addTab(comboContainer, "PeriodicCombo")
w.show()
-a.exec_()
-
+a.exec()
diff --git a/examples/plot3dContextMenu.py b/examples/plot3dContextMenu.py
index d33bb8f..0802b29 100644
--- a/examples/plot3dContextMenu.py
+++ b/examples/plot3dContextMenu.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -86,7 +86,7 @@ class ScalarFieldViewWithContextMenu(ScalarFieldView):
# The position received as argument is relative to Plot3DWidget
# and needs to be converted.
globalPosition = self.getPlot3DWidget().mapToGlobal(pos)
- menu.exec_(globalPosition)
+ menu.exec(globalPosition)
# Start Qt QApplication
@@ -109,4 +109,4 @@ window.setData(data)
window.addIsosurface(0.2, '#FF0000FF')
window.show()
-app.exec_()
+app.exec()
diff --git a/examples/plot3dSceneWindow.py b/examples/plot3dSceneWindow.py
index 1b2f808..436b121 100644
--- a/examples/plot3dSceneWindow.py
+++ b/examples/plot3dSceneWindow.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -200,4 +200,4 @@ window.show()
sys.excepthook = qt.exceptionHandler
# Run Qt event loop
-qapp.exec_()
+qapp.exec()
diff --git a/examples/plot3dUpdateScatterFromThread.py b/examples/plot3dUpdateScatterFromThread.py
index 9c2213f..a02fec6 100644
--- a/examples/plot3dUpdateScatterFromThread.py
+++ b/examples/plot3dUpdateScatterFromThread.py
@@ -167,7 +167,7 @@ def main():
updateThread = UpdateScatterThread(scatter)
updateThread.start() # Start updating the plot
- app.exec_()
+ app.exec()
updateThread.stop() # Stop updating the plot
diff --git a/examples/plotClearAction.py b/examples/plotClearAction.py
index e1130cb..6f1823a 100644
--- a/examples/plotClearAction.py
+++ b/examples/plotClearAction.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -72,4 +72,4 @@ if __name__ == '__main__':
plot.addCurve((0, 1, 2, 3, 4), (0, 1, 1.5, 1, 0)) # Add a curve to the plot
plot.show() # Show the plot widget
- app.exec_() # Start Qt application
+ app.exec() # Start Qt application
diff --git a/examples/plotContextMenu.py b/examples/plotContextMenu.py
index 5f02f5f..bd1ad87 100644
--- a/examples/plotContextMenu.py
+++ b/examples/plotContextMenu.py
@@ -87,7 +87,7 @@ class PlotWidgetWithContextMenu(PlotWidget):
# plot area, and thus needs to be converted.
plotArea = self.getWidgetHandle()
globalPosition = plotArea.mapToGlobal(pos)
- menu.exec_(globalPosition)
+ menu.exec(globalPosition)
# Start the QApplication
@@ -100,4 +100,4 @@ plot.addCurve(x, numpy.sin(x), legend='sin')
# Show the widget and start the application
plot.show()
-app.exec_()
+app.exec()
diff --git a/examples/plotCurveLegendWidget.py b/examples/plotCurveLegendWidget.py
index 97c516a..98ba30b 100644
--- a/examples/plotCurveLegendWidget.py
+++ b/examples/plotCurveLegendWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -112,7 +112,7 @@ class MyCurveLegendsWidget(CurveLegendsWidget):
functools.partial(self._switchCurveVisibility, curve))
globalPosition = self.mapToGlobal(pos)
- menu.exec_(globalPosition)
+ menu.exec(globalPosition)
# First create the QApplication
@@ -151,4 +151,4 @@ window.addDockWidget(qt.Qt.RightDockWidgetArea, dock)
window.setAttribute(qt.Qt.WA_DeleteOnClose)
window.show()
-app.exec_()
+app.exec()
diff --git a/examples/plotInteractiveImageROI.py b/examples/plotInteractiveImageROI.py
index 7254b7e..298f7af 100644
--- a/examples/plotInteractiveImageROI.py
+++ b/examples/plotInteractiveImageROI.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -153,6 +153,6 @@ plot.addTabbedDockWidget(dock)
# Show the widget and start the application
plot.show()
-result = app.exec_()
+result = app.exec()
app.deleteLater()
sys.exit(result)
diff --git a/examples/plotItemsSelector.py b/examples/plotItemsSelector.py
index 177489f..d7493ae 100755
--- a/examples/plotItemsSelector.py
+++ b/examples/plotItemsSelector.py
@@ -46,11 +46,11 @@ pw.show()
isd = ItemsSelectionDialog(plot=pw)
isd.setItemsSelectionMode(qt.QTableWidget.ExtendedSelection)
-result = isd.exec_()
+result = isd.exec()
if result:
for item in isd.getSelectedItems():
print(item.getName(), type(item))
else:
print("Selection cancelled")
-app.exec_()
+app.exec()
diff --git a/examples/plotLimits.py b/examples/plotLimits.py
index c7cc7f5..75440f4 100644
--- a/examples/plotLimits.py
+++ b/examples/plotLimits.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -90,4 +90,4 @@ if __name__ == "__main__":
app = qt.QApplication([])
window = ConstrainedViewPlot()
window.setVisible(True)
- app.exec_()
+ app.exec()
diff --git a/examples/plotProfile.py b/examples/plotProfile.py
index 931f9b4..40e199a 100644
--- a/examples/plotProfile.py
+++ b/examples/plotProfile.py
@@ -201,7 +201,7 @@ def main():
app = qt.QApplication([])
widget = Example()
widget.show()
- app.exec_()
+ app.exec()
if __name__ == "__main__":
main()
diff --git a/examples/plotROIStats.py b/examples/plotROIStats.py
index 3caff7e..e713592 100644
--- a/examples/plotROIStats.py
+++ b/examples/plotROIStats.py
@@ -248,7 +248,7 @@ def example_curve(mode):
window.setUpdateMode(mode)
window.show()
- app.exec_()
+ app.exec()
def example_image(mode):
@@ -278,7 +278,7 @@ def example_image(mode):
window.setUpdateMode(mode)
window.show()
- app.exec_()
+ app.exec()
updateThread.stop() # Stop updating the plot
@@ -314,7 +314,7 @@ def example_curve_image(mode):
updateThread.start() # Start updating the plot
window.show()
- app.exec_()
+ app.exec()
updateThread.stop() # Stop updating the plot
diff --git a/examples/plotStats.py b/examples/plotStats.py
index 030caf8..433088f 100644
--- a/examples/plotStats.py
+++ b/examples/plotStats.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -161,7 +161,7 @@ def main(argv):
plot.getStatsWidget().parent().setVisible(True)
plot.show()
- app.exec_()
+ app.exec()
updateThread.stop() # Stop updating the plot
diff --git a/examples/plotUpdateCurveFromThread.py b/examples/plotUpdateCurveFromThread.py
index a36e5ee..27dbf9b 100644
--- a/examples/plotUpdateCurveFromThread.py
+++ b/examples/plotUpdateCurveFromThread.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -95,7 +95,7 @@ def main():
updateThread = UpdateThread(plot1d)
updateThread.start() # Start updating the plot
- app.exec_()
+ app.exec()
updateThread.stop() # Stop updating the plot
diff --git a/examples/plotUpdateImageFromThread.py b/examples/plotUpdateImageFromThread.py
index 5850263..de23d3f 100644
--- a/examples/plotUpdateImageFromThread.py
+++ b/examples/plotUpdateImageFromThread.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -116,6 +116,7 @@ def main():
# Create a Plot2D, set its limits and display it
plot2d = Plot2D()
+ plot2d.getIntensityHistogramAction().setVisible(True)
plot2d.setLimits(0, Nx, 0, Ny)
plot2d.getDefaultColormap().setVRange(0., 1.5)
plot2d.show()
@@ -124,7 +125,7 @@ def main():
updateThread = UpdateThread(plot2d)
updateThread.start() # Start updating the plot
- app.exec_()
+ app.exec()
updateThread.stop() # Stop updating the plot
diff --git a/examples/plotWidget.py b/examples/plotWidget.py
index af64afb..5d1f4b6 100644
--- a/examples/plotWidget.py
+++ b/examples/plotWidget.py
@@ -1,7 +1,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -64,7 +64,7 @@ class MyPlotWindow(qt.QMainWindow):
# Make ColorBarWidget background white by changing its palette
colorBar.setAutoFillBackground(True)
palette = colorBar.palette()
- palette.setColor(qt.QPalette.Background, qt.Qt.white)
+ palette.setColor(qt.QPalette.Window, qt.Qt.white)
palette.setColor(qt.QPalette.Window, qt.Qt.white)
colorBar.setPalette(palette)
@@ -178,7 +178,7 @@ def main():
window.setAttribute(qt.Qt.WA_DeleteOnClose)
window.show()
window.showImage()
- app.exec_()
+ app.exec()
if __name__ == '__main__':
diff --git a/examples/printPreview.py b/examples/printPreview.py
index 6de8209..7fe5480 100755
--- a/examples/printPreview.py
+++ b/examples/printPreview.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -96,4 +96,4 @@ pw3.show()
pw3.addCurve(x, numpy.cos(x * 2 * numpy.pi / 1000))
-app.exec_()
+app.exec()
diff --git a/examples/scatterMask.py b/examples/scatterMask.py
index 7a407ad..839fa3a 100644
--- a/examples/scatterMask.py
+++ b/examples/scatterMask.py
@@ -150,4 +150,4 @@ if __name__ == "__main__":
msw.setScatter(x, y, v=v)
msw.setBackgroundImage(bg_img)
msw.show()
- app.exec_()
+ app.exec()
diff --git a/examples/scatterview.py b/examples/scatterview.py
index cab32c0..5df11be 100755
--- a/examples/scatterview.py
+++ b/examples/scatterview.py
@@ -91,7 +91,7 @@ def main(argv=None):
mainWindow.show()
mainWindow.setFocus(qt.Qt.OtherFocusReason)
- return app.exec_()
+ return app.exec()
if __name__ == "__main__":
diff --git a/examples/shiftPlotAction.py b/examples/shiftPlotAction.py
index f272cda..6c5f8cf 100755
--- a/examples/shiftPlotAction.py
+++ b/examples/shiftPlotAction.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -110,4 +110,4 @@ plotwin.addCurve(x, y1, legend="triangle shaped curve")
plotwin.addCurve(x, y2, legend="oblique line")
plotwin.show()
-app.exec_()
+app.exec()
diff --git a/examples/simplewidget.py b/examples/simplewidget.py
index 0e9038e..0a1c336 100755
--- a/examples/simplewidget.py
+++ b/examples/simplewidget.py
@@ -219,7 +219,7 @@ def main():
sys.excepthook = qt.exceptionHandler
window = SimpleWidgetExample()
window.show()
- result = app.exec_()
+ result = app.exec()
# remove ending warnings relative to QTimer
app.deleteLater()
sys.excepthook = sys.__excepthook__
diff --git a/examples/stackView.py b/examples/stackView.py
index a4b6e8c..f857140 100644
--- a/examples/stackView.py
+++ b/examples/stackView.py
@@ -61,4 +61,4 @@ maskToolsWidget.setItemMaskUpdated(True)
sv.show()
-app.exec_()
+app.exec()
diff --git a/examples/syncPlotLocation.py b/examples/syncPlotLocation.py
index 55332bc..83c1ade 100644
--- a/examples/syncPlotLocation.py
+++ b/examples/syncPlotLocation.py
@@ -102,4 +102,4 @@ if __name__ == "__main__":
window = SyncPlot()
window.setAttribute(qt.Qt.WA_DeleteOnClose, True)
window.setVisible(True)
- app.exec_()
+ app.exec()
diff --git a/examples/syncaxis.py b/examples/syncaxis.py
index 02505c9..2976231 100644
--- a/examples/syncaxis.py
+++ b/examples/syncaxis.py
@@ -99,4 +99,4 @@ if __name__ == "__main__":
window = SyncPlot()
window.setAttribute(qt.Qt.WA_DeleteOnClose, True)
window.setVisible(True)
- app.exec_()
+ app.exec()
diff --git a/examples/viewer3DVolume.py b/examples/viewer3DVolume.py
index 2193402..5b86199 100644
--- a/examples/viewer3DVolume.py
+++ b/examples/viewer3DVolume.py
@@ -203,4 +203,4 @@ else:
window.addIsosurface(default_isolevel, '#FF0000FF')
window.show()
-app.exec_()
+app.exec()
diff --git a/package/debian10/control b/package/debian10/control
index d724e69..25cf293 100644
--- a/package/debian10/control
+++ b/package/debian10/control
@@ -16,7 +16,6 @@ Build-Depends: cython3 (>= 0.23.2),
python3-all-dev,
python3-dateutil,
python3-qtconsole,
- python3-six,
python3-fabio,
python3-h5py,
python3-mako,
@@ -29,6 +28,8 @@ Build-Depends: cython3 (>= 0.23.2),
python3-pyqt5,
python3-pyqt5.qtopengl,
python3-pyqt5.qtsvg,
+ python3-pytest,
+ python3-pytest-xvfb,
python3-scipy,
python3-setuptools,
python3-sphinx,
diff --git a/package/debian11/control b/package/debian11/control
index 5e387fc..775753d 100644
--- a/package/debian11/control
+++ b/package/debian11/control
@@ -34,11 +34,12 @@ Build-Depends: cython3 (>= 0.23.2),
python3-pyqt5.qtopengl-dbg,
python3-pyqt5.qtsvg,
python3-pyqt5.qtsvg-dbg,
+ python3-pytest,
+ python3-pytest-xvfb,
python3-qtconsole,
python3-scipy,
python3-scipy-dbg,
python3-setuptools,
- python3-six,
python3-sphinx,
python3-sphinxcontrib.programoutput,
xauth,
diff --git a/package/windows/README.rst b/package/windows/README.rst
index 97c1d54..cbf6fa3 100644
--- a/package/windows/README.rst
+++ b/package/windows/README.rst
@@ -4,20 +4,13 @@ Generate silx fat binary for Windows
Pre-requisites
--------------
-- PyInstaller must be installed.
- It is best to use the development version to use package specific hooks up-to-date.
- Run either::
-
- pip install -r requirements-dev.txt
-
- or::
-
- pip install https://github.com/pyinstaller/pyinstaller/archive/develop.zip
-
+- `PyInstaller <https://pyinstaller.readthedocs.io/>`_ must be installed.
+ Run: ``pip install -r requirements-dev.txt``
+- `InnoSetup <https://jrsoftware.org/isinfo.php>`_ must be installed and in the ``PATH``.
- silx and all its dependencies must be INSTALLED::
pip install silx[full]
-
+
or from the source directory::
pip install .[full]
@@ -28,10 +21,5 @@ Procedure
- Go to the ``package/windows`` folder in the source directory
- Run ``pyinstaller pyinstaller.spec``.
- This generates a fat binary in ``package/windows/dist/silx/`` for the generic launcher ``silx.exe``.
-- Run ``pyinstaller pyinstaller-silx-view.spec``.
- This generates a fat binary in ``package/windows/dist/silx-view/`` for the silx view command ``silx-view.exe``.
-- Copy ``silx-view.exe`` and ``silx-view.exe.manifest`` to ``package/windows/dist/silx/``.
- This is a hack until PyInstaller supports multiple executables (see https://github.com/pyinstaller/pyinstaller/issues/1527).
-- Zip ``package\windows\dist\silx`` to make the application available as a single zip file.
-
+ This will generates the fat binary in ``package/windows/dist/``.
+ It will then run InnoSetup to create the installer in ``package/windows/artifacts/``.
diff --git a/package/windows/create-installer.iss.template b/package/windows/create-installer.iss.template
new file mode 100644
index 0000000..ffb6af4
--- /dev/null
+++ b/package/windows/create-installer.iss.template
@@ -0,0 +1,92 @@
+[Setup]
+AppId={{A694A78C-B4D1-4983-8BC6-A84D30341EB4}
+AppName=silx view
+AppVersion=#Version
+AppVerName=silx-view-#Version
+AppPublisher=ESRF
+AppPublisherURL=https://www.silx.org
+AppSupportURL=https://github.com/silx-kit/silx
+AppUpdatesURL=https://github.com/silx-kit/silx/releases
+DefaultDirName={autopf}\silx
+DefaultGroupName=silx
+LicenseFile=..\..\LICENSE
+OutputDir=artifacts
+OutputBaseFilename=silx-#Version-x64
+Compression=lzma
+SolidCompression=yes
+ArchitecturesAllowed=x64
+ArchitecturesInstallIn64BitMode=x64
+WizardStyle=modern
+
+[Languages]
+Name: "english"; MessagesFile: "compiler:Default.isl"
+
+[Files]
+Source: "dist\silx\*"; DestDir: "{app}"; Flags: ignoreversion recursesubdirs
+Source: "silx.ico"; DestDir: "{app}"
+
+[Icons]
+Name: "{group}\silx"; Filename: "{app}\silx-view.exe"; IconFilename: "{app}\silx.ico"
+Name: "{group}\Uninstall"; Filename: "{uninstallexe}"
+
+// Code from https://stackoverflow.com/questions/2000296/inno-setup-how-to-automatically-uninstall-previous-installed-version/2099805#209980
+[Code]
+
+/////////////////////////////////////////////////////////////////////
+function GetUninstallString(): String;
+var
+ sUnInstPath: String;
+ sUnInstallString: String;
+begin
+ sUnInstPath := ExpandConstant('Software\Microsoft\Windows\CurrentVersion\Uninstall\{#emit SetupSetting("AppId")}_is1');
+ sUnInstallString := '';
+ if not RegQueryStringValue(HKLM, sUnInstPath, 'UninstallString', sUnInstallString) then
+ RegQueryStringValue(HKCU, sUnInstPath, 'UninstallString', sUnInstallString);
+ Result := sUnInstallString;
+end;
+
+
+/////////////////////////////////////////////////////////////////////
+function IsUpgrade(): Boolean;
+begin
+ Result := (GetUninstallString() <> '');
+end;
+
+
+/////////////////////////////////////////////////////////////////////
+function UnInstallOldVersion(): Integer;
+var
+ sUnInstallString: String;
+ iResultCode: Integer;
+begin
+ // Return Values:
+ // 1 - uninstall string is empty
+ // 2 - error executing the UnInstallString
+ // 3 - successfully executed the UnInstallString
+
+ // default return value
+ Result := 0;
+
+ // get the uninstall string of the old app
+ sUnInstallString := GetUninstallString();
+ if sUnInstallString <> '' then begin
+ sUnInstallString := RemoveQuotes(sUnInstallString);
+ if Exec(sUnInstallString, '/VERYSILENT /NORESTART /SUPPRESSMSGBOXES','', SW_HIDE, ewWaitUntilTerminated, iResultCode) then
+ Result := 3
+ else
+ Result := 2;
+ end else
+ Result := 1;
+end;
+
+/////////////////////////////////////////////////////////////////////
+procedure CurStepChanged(CurStep: TSetupStep);
+begin
+ if (CurStep=ssInstall) then
+ begin
+ if (IsUpgrade()) then
+ begin
+ UnInstallOldVersion();
+ end;
+ end;
+end;
diff --git a/package/windows/pyinstaller-silx-view.spec b/package/windows/pyinstaller-silx-view.spec
deleted file mode 100644
index cf01fd1..0000000
--- a/package/windows/pyinstaller-silx-view.spec
+++ /dev/null
@@ -1,55 +0,0 @@
-# -*- mode: python -*-
-import os.path
-from PyInstaller.utils.hooks import collect_data_files, collect_submodules
-
-datas = []
-
-PROJECT_PATH = os.path.abspath(os.path.join(SPECPATH, "..", ".."))
-datas.append((os.path.join(PROJECT_PATH, "README.rst"), "."))
-datas.append((os.path.join(PROJECT_PATH, "LICENSE"), "."))
-datas.append((os.path.join(PROJECT_PATH, "copyright"), "."))
-datas += collect_data_files("silx.resources")
-
-
-hiddenimports = collect_submodules('fabio')
-
-
-block_cipher = None
-
-
-a = Analysis(['bootstrap-silx-view.py'],
- pathex=[],
- binaries=[],
- datas=datas,
- hiddenimports=hiddenimports,
- hookspath=[],
- runtime_hooks=[],
- excludes=[],
- win_no_prefer_redirects=False,
- win_private_assemblies=False,
- cipher=block_cipher,
- noarchive=False)
-
-pyz = PYZ(a.pure,
- a.zipped_data,
- cipher=block_cipher)
-
-exe = EXE(pyz,
- a.scripts,
- [],
- exclude_binaries=True,
- name='silx-view',
- debug=False,
- bootloader_ignore_signals=False,
- strip=False,
- upx=False,
- console=False,
- icon='silx.ico')
-
-coll = COLLECT(exe,
- a.binaries,
- a.zipfiles,
- a.datas,
- strip=False,
- upx=False,
- name='silx-view')
diff --git a/package/windows/pyinstaller.spec b/package/windows/pyinstaller.spec
index 548e41a..59b66c1 100644
--- a/package/windows/pyinstaller.spec
+++ b/package/windows/pyinstaller.spec
@@ -1,5 +1,9 @@
# -*- mode: python -*-
import os.path
+from pathlib import Path
+import shutil
+import subprocess
+
from PyInstaller.utils.hooks import collect_data_files, collect_submodules
datas = []
@@ -9,47 +13,131 @@ datas.append((os.path.join(PROJECT_PATH, "README.rst"), "."))
datas.append((os.path.join(PROJECT_PATH, "LICENSE"), "."))
datas.append((os.path.join(PROJECT_PATH, "copyright"), "."))
datas += collect_data_files("silx.resources")
+datas += collect_data_files("hdf5plugin")
-hiddenimports = collect_submodules('fabio')
+hiddenimports = []
+hiddenimports += collect_submodules('fabio')
+hiddenimports += collect_submodules('hdf5plugin')
block_cipher = None
-a = Analysis(['bootstrap.py'],
- pathex=[],
- binaries=[],
- datas=datas,
- hiddenimports=hiddenimports,
- hookspath=[],
- runtime_hooks=[],
- excludes=[],
- win_no_prefer_redirects=False,
- win_private_assemblies=False,
- cipher=block_cipher,
- noarchive=False)
-
-pyz = PYZ(a.pure,
- a.zipped_data,
- cipher=block_cipher)
-
-exe = EXE(pyz,
- a.scripts,
- [],
- exclude_binaries=True,
- name='silx',
- debug=False,
- bootloader_ignore_signals=False,
- strip=False,
- upx=False,
- console=True,
- icon='silx.ico')
-
-coll = COLLECT(exe,
- a.binaries,
- a.zipfiles,
- a.datas,
- strip=False,
- upx=False,
- name='silx')
+silx_a = Analysis(
+ ['bootstrap.py'],
+ pathex=[],
+ binaries=[],
+ datas=datas,
+ hiddenimports=hiddenimports,
+ hookspath=[],
+ runtime_hooks=[],
+ excludes=[],
+ win_no_prefer_redirects=False,
+ win_private_assemblies=False,
+ cipher=block_cipher,
+ noarchive=False)
+
+silx_view_a = Analysis(
+ ['bootstrap-silx-view.py'],
+ pathex=[],
+ binaries=[],
+ datas=datas,
+ hiddenimports=hiddenimports,
+ hookspath=[],
+ runtime_hooks=[],
+ excludes=[],
+ win_no_prefer_redirects=False,
+ win_private_assemblies=False,
+ cipher=block_cipher,
+ noarchive=False)
+
+MERGE(
+ (silx_a, 'silx', os.path.join('silx', 'silx')),
+ (silx_view_a, 'silx-view', os.path.join('silx-view', 'silx-view')),
+)
+
+
+silx_pyz = PYZ(
+ silx_a.pure,
+ silx_a.zipped_data,
+ cipher=block_cipher)
+
+silx_exe = EXE(
+ silx_pyz,
+ silx_a.scripts,
+ silx_a.dependencies,
+ [],
+ exclude_binaries=True,
+ name='silx',
+ debug=False,
+ bootloader_ignore_signals=False,
+ strip=False,
+ upx=False,
+ console=True,
+ icon='silx.ico')
+
+silx_coll = COLLECT(
+ silx_exe,
+ silx_a.binaries,
+ silx_a.zipfiles,
+ silx_a.datas,
+ strip=False,
+ upx=False,
+ name='silx')
+
+
+silx_view_pyz = PYZ(
+ silx_view_a.pure,
+ silx_view_a.zipped_data,
+ cipher=block_cipher)
+
+silx_view_exe = EXE(
+ silx_view_pyz,
+ silx_view_a.scripts,
+ silx_view_a.dependencies,
+ [],
+ exclude_binaries=True,
+ name='silx-view',
+ debug=False,
+ bootloader_ignore_signals=False,
+ strip=False,
+ upx=False,
+ console=False,
+ icon='silx.ico')
+
+silx_view_coll = COLLECT(
+ silx_view_exe,
+ silx_view_a.binaries,
+ silx_view_a.zipfiles,
+ silx_view_a.datas,
+ strip=False,
+ upx=False,
+ name='silx-view')
+
+
+# Fix MERGE by copying produced silx-view.exe file
+def move_silx_view_exe():
+ dist = Path(SPECPATH) / 'dist'
+ shutil.copy2(
+ src=dist / 'silx-view' / 'silx-view.exe',
+ dst=dist / 'silx',
+ )
+ shutil.rmtree(dist / 'silx-view')
+
+move_silx_view_exe()
+
+# Run innosetup
+def innosetup():
+ from silx import version
+
+ config_name = "create-installer.iss"
+ with open(config_name + '.template') as f:
+ content = f.read().replace("#Version", version)
+ with open(config_name, "w") as f:
+ f.write(content)
+
+ subprocess.call(["iscc", os.path.join(SPECPATH, config_name)])
+ os.remove(config_name)
+
+innosetup()
diff --git a/pyproject.toml b/pyproject.toml
index c80dee7..c92bf42 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,3 +5,4 @@ requires = [
"numpy>=1.12",
"Cython>=0.21.1"
]
+build-backend = "setuptools.build_meta" \ No newline at end of file
diff --git a/qtdesigner_plugins/plot1dplugin.py b/qtdesigner_plugins/plot1dplugin.py
index 86982af..8bfdc50 100644
--- a/qtdesigner_plugins/plot1dplugin.py
+++ b/qtdesigner_plugins/plot1dplugin.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,9 +32,7 @@ __date__ = "30/05/2016"
from silx.gui import icons, qt
-if qt.BINDING == 'PyQt4':
- from PyQt4 import QtDesigner
-elif qt.BINDING == 'PyQt5':
+if qt.BINDING == 'PyQt5':
from PyQt5 import QtDesigner
else:
raise RuntimeError("Unsupport Qt BINDING: %s" % qt.BINDING)
diff --git a/qtdesigner_plugins/plot2dplugin.py b/qtdesigner_plugins/plot2dplugin.py
index 1a07510..31c7557 100644
--- a/qtdesigner_plugins/plot2dplugin.py
+++ b/qtdesigner_plugins/plot2dplugin.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,9 +32,7 @@ __date__ = "30/05/2016"
from silx.gui import icons, qt
-if qt.BINDING == 'PyQt4':
- from PyQt4 import QtDesigner
-elif qt.BINDING == 'PyQt5':
+if qt.BINDING == 'PyQt5':
from PyQt5 import QtDesigner
else:
raise RuntimeError("Unsupport Qt BINDING: %s" % qt.BINDING)
diff --git a/qtdesigner_plugins/plotwidgetplugin.py b/qtdesigner_plugins/plotwidgetplugin.py
index 6cc97a5..5c92ebe 100644
--- a/qtdesigner_plugins/plotwidgetplugin.py
+++ b/qtdesigner_plugins/plotwidgetplugin.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2015-2016 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,9 +32,7 @@ __date__ = "30/05/2016"
from silx.gui import qt, icons
-if qt.BINDING == 'PyQt4':
- from PyQt4 import QtDesigner
-elif qt.BINDING == 'PyQt5':
+if qt.BINDING == 'PyQt5':
from PyQt5 import QtDesigner
else:
raise RuntimeError("Unsupport Qt BINDING: %s" % qt.BINDING)
diff --git a/qtdesigner_plugins/plotwindowplugin.py b/qtdesigner_plugins/plotwindowplugin.py
index b666399..89e6c47 100644
--- a/qtdesigner_plugins/plotwindowplugin.py
+++ b/qtdesigner_plugins/plotwindowplugin.py
@@ -2,7 +2,7 @@
# coding: utf-8
# /*##########################################################################
#
-# Copyright (c) 2015-2016 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,9 +32,7 @@ __date__ = "30/05/2016"
from silx.gui import icons, qt
-if qt.BINDING == 'PyQt4':
- from PyQt4 import QtDesigner
-elif qt.BINDING == 'PyQt5':
+if qt.BINDING == 'PyQt5':
from PyQt5 import QtDesigner
else:
raise RuntimeError("Unsupport Qt BINDING: %s" % qt.BINDING)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index dac7fad..b5eea71 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,11 +4,15 @@
-r requirements.txt
wheel # To build wheels
Sphinx # To build the documentation in doc/
-lxml # For test coverage in run_test.py
-coverage # For test coverage in run_test.py
pillow # For loading images in documentation generation
nbsphinx # For converting ipynb in documentation
pandoc # For documentation Qt snapshot updates
+pytest # For testing
+pytest-xvfb # For GUI testing
+pytest-cov # For coverage
+
+hdf5plugin # For HDF5 compression filters handling
+
# Use dev version of PyInstaller to keep hooks up-to-date
https://github.com/pyinstaller/pyinstaller/archive/develop.zip; sys_platform == "win32"
diff --git a/requirements.txt b/requirements.txt
index fb5690d..06f57e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@
--trusted-host www.silx.org
--find-links http://www.silx.org/pub/wheelhouse/
---only-binary numpy,h5py,scipy,PySide2,PyQt5
+--only-binary numpy,h5py,scipy,PySide2,PyQt5,PySide6
# Required dependencies (from setup.py setup_requires and install_requires)
numpy >= 1.12
@@ -11,7 +11,6 @@ setuptools
Cython >= 0.21.1
h5py
fabio >= 0.9
-six
# Extra dependencies (from setup.py extra_requires 'full' target)
pyopencl; platform_machine in "i386, x86_64, AMD64" # For silx.opencl
@@ -22,4 +21,4 @@ PyOpenGL # For silx.gui.plot3d
python-dateutil # For silx.gui.plot
scipy # For silx.math.fit demo, silx.image.sift demo, silx.image.sift.test
Pillow # For silx.opencl.image.test
-PyQt5 # or PySide2 # For silx.gui
+PyQt5 # PySide2, PySide6 # For silx.gui
diff --git a/run_tests.py b/run_tests.py
index 5d3155a..bc8efe8 100755
--- a/run_tests.py
+++ b/run_tests.py
@@ -2,7 +2,7 @@
# coding: utf8
# /*##########################################################################
#
-# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -40,43 +40,9 @@ import logging
import os
import subprocess
import sys
-import time
-import unittest
-import collections
-from argparse import ArgumentParser
+import importlib
-class StreamHandlerUnittestReady(logging.StreamHandler):
- """The unittest class TestResult redefine sys.stdout/err to capture
- stdout/err from tests and to display them only when a test fail.
- This class allow to use unittest stdout-capture by using the last sys.stdout
- and not a cached one.
- """
-
- def emit(self, record):
- """
- :type record: logging.LogRecord
- """
- self.stream = sys.stderr
- super(StreamHandlerUnittestReady, self).emit(record)
-
- def flush(self):
- pass
-
-
-def createBasicHandler():
- """Create the handler using the basic configuration"""
- hdlr = StreamHandlerUnittestReady()
- fs = logging.BASIC_FORMAT
- dfs = None
- fmt = logging.Formatter(fs, dfs)
- hdlr.setFormatter(fmt)
- return hdlr
-
-
-# Use an handler compatible with unittests, else use_buffer is not working
-logging.root.addHandler(createBasicHandler())
-
# Capture all default warnings
logging.captureWarnings(True)
import warnings
@@ -87,25 +53,6 @@ logger.setLevel(logging.WARNING)
logger.info("Python %s %s", sys.version, tuple.__itemsize__ * 8)
-try:
- import resource
-except ImportError:
- resource = None
- logger.warning("resource module missing")
-
-try:
- import importlib
- importer = importlib.import_module
-except ImportError:
-
- def importer(name):
- module = __import__(name)
- # returns the leaf module, instead of the root module
- subnames = name.split(".")
- subnames.pop(0)
- for subname in subnames:
- module = getattr(module, subname)
- return module
try:
import numpy
@@ -136,123 +83,6 @@ def get_project_name(root_dir):
return name.split()[-1].decode('ascii')
-class TextTestResultWithSkipList(unittest.TextTestResult):
- """Override default TextTestResult to display list of skipped tests at the
- end
- """
-
- def printErrors(self):
- unittest.TextTestResult.printErrors(self)
- # Print skipped tests at the end
- self.printGroupedList("SKIPPED", self.skipped)
-
- def printGroupedList(self, flavour, errors):
- grouped = collections.OrderedDict()
-
- for test, err in errors:
- if err in grouped:
- grouped[err] = grouped[err] + [test]
- else:
- grouped[err] = [test]
-
- for err, tests in grouped.items():
- self.stream.writeln(self.separator1)
- for test in tests:
- self.stream.writeln("%s: %s" % (flavour, self.getDescription(test)))
- self.stream.writeln(self.separator2)
- self.stream.writeln("%s" % err)
-
-
-class ProfileTextTestResult(unittest.TextTestRunner.resultclass):
-
- def __init__(self, *arg, **kwarg):
- unittest.TextTestRunner.resultclass.__init__(self, *arg, **kwarg)
- self.logger = logging.getLogger("memProf")
- self.logger.setLevel(min(logging.INFO, logging.root.level))
- self.logger.handlers.append(logging.FileHandler("profile.log"))
-
- def startTest(self, test):
- unittest.TextTestRunner.resultclass.startTest(self, test)
- if resource:
- self.__mem_start = \
- resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
- self.__time_start = time.time()
-
- def stopTest(self, test):
- unittest.TextTestRunner.resultclass.stopTest(self, test)
- # see issue 311. For other platform, get size of ru_maxrss in "man getrusage"
- if sys.platform == "darwin":
- ratio = 1e-6
- else:
- ratio = 1e-3
- if resource:
- memusage = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss -
- self.__mem_start) * ratio
- else:
- memusage = 0
- self.logger.info("Time: %.3fs \t RAM: %.3f Mb\t%s",
- time.time() - self.__time_start,
- memusage, test.id())
-
-
-def report_rst(cov, package, version="0.0.0", base=""):
- """
- Generate a report of test coverage in RST (for Sphinx inclusion)
-
- :param cov: test coverage instance
- :param str package: Name of the package
- :param str base: base directory of modules to include in the report
- :return: RST string
- """
- import tempfile
- fd, fn = tempfile.mkstemp(suffix=".xml")
- os.close(fd)
- cov.xml_report(outfile=fn)
-
- from lxml import etree
- xml = etree.parse(fn)
- classes = xml.xpath("//class")
-
- line0 = "Test coverage report for %s" % package
- res = [line0, "=" * len(line0), ""]
- res.append("Measured on *%s* version %s, %s" %
- (package, version, time.strftime("%d/%m/%Y")))
- res += ["",
- ".. csv-table:: Test suite coverage",
- ' :header: "Name", "Stmts", "Exec", "Cover"',
- ' :widths: 35, 8, 8, 8',
- '']
- tot_sum_lines = 0
- tot_sum_hits = 0
-
- for cl in classes:
- name = cl.get("name")
- fname = cl.get("filename")
- if not os.path.abspath(fname).startswith(base):
- continue
- lines = cl.find("lines").getchildren()
- hits = [int(i.get("hits")) for i in lines]
-
- sum_hits = sum(hits)
- sum_lines = len(lines)
-
- cover = 100.0 * sum_hits / sum_lines if sum_lines else 0
-
- if base:
- name = os.path.relpath(fname, base)
-
- res.append(' "%s", "%s", "%s", "%.1f %%"' %
- (name, sum_lines, sum_hits, cover))
- tot_sum_lines += sum_lines
- tot_sum_hits += sum_hits
- res.append("")
- res.append(' "%s total", "%s", "%s", "%.1f %%"' %
- (package, tot_sum_lines, tot_sum_hits,
- 100.0 * tot_sum_hits / tot_sum_lines if tot_sum_lines else 0))
- res.append("")
- return os.linesep.join(res)
-
-
def is_debug_python():
"""Returns true if the Python interpreter is in debug mode."""
try:
@@ -304,14 +134,9 @@ def build_project(name, root_dir):
def import_project_module(project_name, project_dir):
"""Import project module, from the system of from the project directory"""
- # Prevent importing from source directory
- if (os.path.dirname(os.path.abspath(__file__)) == os.path.abspath(sys.path[0])):
- removed_from_sys_path = sys.path.pop(0)
- logger.info("Patched sys.path, removed: '%s'", removed_from_sys_path)
-
if "--installed" in sys.argv:
try:
- module = importer(project_name)
+ module = importlib.import_module(project_name)
except Exception:
logger.error("Cannot run tests on installed version: %s not installed or raising error.",
project_name)
@@ -322,25 +147,13 @@ def import_project_module(project_name, project_dir):
logging.error("Built project is not available !!! investigate")
sys.path.insert(0, build_dir)
logger.warning("Patched sys.path, added: '%s'", build_dir)
- module = importer(project_name)
+ module = importlib.import_module(project_name)
return module
-def get_test_options(project_module):
- """Returns the test options if available, else None"""
- module_name = project_module.__name__ + '.test.utils'
- logger.info('Import %s', module_name)
- try:
- test_utils = importer(module_name)
- except ImportError:
- logger.warning("No module named '%s'. No test options available.", module_name)
- return None
-
- test_options = getattr(test_utils, "test_options", None)
- return test_options
-
-
if __name__ == "__main__": # Needed for multiprocessing support on Windows
+ import pytest
+
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_NAME = get_project_name(PROJECT_DIR)
logger.info("Project name: %s", PROJECT_NAME)
@@ -349,143 +162,20 @@ if __name__ == "__main__": # Needed for multiprocessing support on Windows
PROJECT_VERSION = getattr(project_module, 'version', '')
PROJECT_PATH = project_module.__path__[0]
- test_options = get_test_options(project_module)
- """Contains extra configuration for the tests."""
-
- epilog = """Environment variables:
- WITH_QT_TEST=False to disable graphical tests
- SILX_OPENCL=False to disable OpenCL tests
- SILX_TEST_LOW_MEM=True to disable tests taking large amount of memory
- GPU=False to disable the use of a GPU with OpenCL test
- WITH_GL_TEST=False to disable tests using OpenGL
- """
- parser = ArgumentParser(description='Run the tests.',
- epilog=epilog)
-
- parser.add_argument("--installed",
- action="store_true", dest="installed", default=False,
- help=("Test the installed version instead of" +
- "building from the source"))
- parser.add_argument("-c", "--coverage", dest="coverage",
- action="store_true", default=False,
- help=("Report code coverage" +
- "(requires 'coverage' and 'lxml' module)"))
- parser.add_argument("-m", "--memprofile", dest="memprofile",
- action="store_true", default=False,
- help="Report memory profiling")
- parser.add_argument("-v", "--verbose", default=0,
- action="count", dest="verbose",
- help="Increase verbosity. Option -v prints additional " +
- "INFO messages. Use -vv for full verbosity, " +
- "including debug messages and test help strings.")
- parser.add_argument("--qt-binding", dest="qt_binding", default=None,
- help="Force using a Qt binding, from 'PyQt4', 'PyQt5', or 'PySide'")
- if test_options is not None:
- test_options.add_parser_argument(parser)
-
- default_test_name = "%s.test.suite" % PROJECT_NAME
- parser.add_argument("test_name", nargs='*',
- default=(default_test_name,),
- help="Test names to run (Default: %s)" % default_test_name)
- options = parser.parse_args()
- sys.argv = [sys.argv[0]]
-
- test_verbosity = 1
- use_buffer = True
- if options.verbose == 1:
- logging.root.setLevel(logging.INFO)
- logger.info("Set log level: INFO")
- test_verbosity = 2
- use_buffer = False
- elif options.verbose > 1:
- logging.root.setLevel(logging.DEBUG)
- logger.info("Set log level: DEBUG")
- test_verbosity = 2
- use_buffer = False
-
- if options.coverage:
- logger.info("Running test-coverage")
- import coverage
- omits = ["*test*", "*third_party*", "*/setup.py",
- # temporary test modules (silx.math.fit.test.test_fitmanager)
- "*customfun.py", ]
- try:
- cov = coverage.Coverage(omit=omits)
- except AttributeError:
- cov = coverage.coverage(omit=omits)
- cov.start()
-
- if options.qt_binding:
- binding = options.qt_binding.lower()
- if binding == "pyqt4":
- logger.info("Force using PyQt4")
- if sys.version < "3.0.0":
- try:
- import sip
- sip.setapi("QString", 2)
- sip.setapi("QVariant", 2)
- except Exception:
- logger.warning("Cannot set sip API")
- import PyQt4.QtCore # noqa
- elif binding == "pyqt5":
- logger.info("Force using PyQt5")
- import PyQt5.QtCore # noqa
- elif binding == "pyside":
- logger.info("Force using PySide")
- import PySide.QtCore # noqa
- elif binding == "pyside2":
- logger.info("Force using PySide2")
- import PySide2.QtCore # noqa
- else:
- raise ValueError("Qt binding '%s' is unknown" % options.qt_binding)
-
- # Run the tests
- runnerArgs = {}
- runnerArgs["verbosity"] = test_verbosity
- runnerArgs["buffer"] = use_buffer
- if options.memprofile:
- runnerArgs["resultclass"] = ProfileTextTestResult
- else:
- runnerArgs["resultclass"] = TextTestResultWithSkipList
- runner = unittest.TextTestRunner(**runnerArgs)
-
- logger.warning("Test %s %s from %s",
- PROJECT_NAME, PROJECT_VERSION, PROJECT_PATH)
-
- test_module_name = PROJECT_NAME + '.test'
- logger.info('Import %s', test_module_name)
- test_module = importer(test_module_name)
- test_suite = unittest.TestSuite()
+ def normalize_option(option):
+ option_parts = option.split(os.path.sep)
+ if option_parts == ["src", "silx"]:
+ return PROJECT_PATH
+ if option_parts[:2] == ["src", "silx"]:
+ return os.path.join(PROJECT_PATH, *option_parts[2:])
+ return option
- if test_options is not None:
- # Configure the test options according to the command lines and the the environment
- test_options.configure(options)
- else:
- logger.warning("No test options available.")
-
- if not options.test_name:
- # Do not use test loader to avoid cryptic exception
- # when an error occur during import
- project_test_suite = getattr(test_module, 'suite')
- test_suite.addTest(project_test_suite())
- else:
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromNames(options.test_name))
-
- # Display the result when using CTRL-C
- unittest.installHandler()
-
- result = runner.run(test_suite)
-
- if result.wasSuccessful():
- exit_status = 0
- else:
- exit_status = 1
+ args = [normalize_option(p) for p in sys.argv[1:] if p != "--installed"]
- if options.coverage:
- cov.stop()
- cov.save()
- with open("coverage.rst", "w") as fn:
- fn.write(report_rst(cov, PROJECT_NAME, PROJECT_VERSION, PROJECT_PATH))
+ # Run test on PROJECT_PATH if nothing is specified
+ without_options = [a for a in args if not a.startswith("-")]
+ if len(without_options) == 0:
+ args += [PROJECT_PATH]
- sys.exit(exit_status)
+ argv = ["--rootdir", PROJECT_PATH] + args
+ sys.exit(pytest.main(argv))
diff --git a/setup.py b/setup.py
index 771374c..045b9a0 100644
--- a/setup.py
+++ b/setup.py
@@ -47,7 +47,6 @@ from distutils.command.clean import clean as Clean
from distutils.command.build import build as _build
try:
from setuptools import Command
- from setuptools.command.build_py import build_py as _build_py
from setuptools.command.sdist import sdist
try:
from Cython.Build import build_ext
@@ -60,7 +59,6 @@ except ImportError:
from numpy.distutils.core import Command
except ImportError:
from distutils.core import Command
- from distutils.command.build_py import build_py as _build_py
from distutils.command.sdist import sdist
try:
from Cython.Build import build_ext
@@ -88,13 +86,14 @@ export LC_ALL=en_US.utf-8
""")
-def get_version():
- """Returns current version number from version.py file"""
- dirname = os.path.dirname(os.path.abspath(__file__))
+def get_version(debian=False):
+ """Returns current version number from _version.py file"""
+ dirname = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "src", PROJECT)
sys.path.insert(0, dirname)
- import version
+ import _version
sys.path = sys.path[1:]
- return version.strictversion
+ return _version.debianversion if debian else _version.strictversion
def get_readme():
@@ -106,7 +105,7 @@ def get_readme():
return long_description
-classifiers = ["Development Status :: 4 - Beta",
+classifiers = ["Development Status :: 5 - Production/Stable",
"Environment :: Console",
"Environment :: MacOS X",
"Environment :: Win32 (MS Windows)",
@@ -125,22 +124,6 @@ classifiers = ["Development Status :: 4 - Beta",
"Topic :: Software Development :: Libraries :: Python Modules",
]
-# ########## #
-# version.py #
-# ########## #
-
-
-class build_py(_build_py):
- """
- Enhanced build_py which copies version.py to <PROJECT>._version.py
- """
-
- def find_package_modules(self, package, package_dir):
- modules = _build_py.find_package_modules(self, package, package_dir)
- if package == PROJECT:
- modules.append((PROJECT, '_version', 'version.py'))
- return modules
-
########
# Test #
########
@@ -476,7 +459,9 @@ def configuration(parent_package='', top_path=None):
assume_default_configuration=True,
delegate_options_to_subpackages=True,
quiet=True)
- config.add_subpackage(PROJECT)
+ config.add_subpackage(
+ PROJECT, subpackage_path=os.path.join(
+ os.path.abspath(os.path.dirname(__file__)), 'src', PROJECT))
return config
# ############## #
@@ -778,8 +763,7 @@ class sdist_debian(sdist):
@staticmethod
def get_debian_name():
- import version
- name = "%s_%s" % (PROJECT, version.debianversion)
+ name = "%s_%s" % (PROJECT, get_version(debian=True))
return name
def prune_file_list(self):
@@ -846,8 +830,6 @@ def get_project_configuration(dry_run):
# for io support
"h5py",
"fabio>=0.9",
- # Python 2/3 compatibility
- "six",
]
# Add Python 2.7 backports
@@ -858,8 +840,6 @@ def get_project_configuration(dry_run):
install_requires.append("enum34")
install_requires.append("futures")
- setup_requires = ["setuptools", "numpy>=1.12", "Cython>=0.21.1"]
-
# extras requirements: target 'full' to install all dependencies at once
full_requires = [
# opencl
@@ -872,11 +852,18 @@ def get_project_configuration(dry_run):
'python-dateutil',
'PyQt5',
# extra
+ 'hdf5plugin',
'scipy',
'Pillow']
+ test_requires = [
+ "pytest",
+ "pytest-xvfb"
+ ]
+
extras_require = {
'full': full_requires,
+ 'test': test_requires,
}
# Here for packaging purpose only
@@ -910,7 +897,6 @@ def get_project_configuration(dry_run):
cmdclass = dict(
build=Build,
- build_py=build_py,
test=PyTest,
build_screenshots=BuildDocAndGenerateScreenshotCommand,
build_doc=BuildDocCommand,
@@ -940,7 +926,6 @@ def get_project_configuration(dry_run):
description="Software library for X-ray data analysis",
long_description=get_readme(),
install_requires=install_requires,
- setup_requires=setup_requires,
extras_require=extras_require,
cmdclass=cmdclass,
package_data=package_data,
diff --git a/silx.egg-info/PKG-INFO b/silx.egg-info/PKG-INFO
index 04ec406..43179ac 100644
--- a/silx.egg-info/PKG-INFO
+++ b/silx.egg-info/PKG-INFO
@@ -1,13 +1,13 @@
Metadata-Version: 2.1
Name: silx
-Version: 0.15.2
+Version: 1.0.0
Summary: Software library for X-ray data analysis
Home-page: http://www.silx.org/
Author: data analysis unit
Author-email: silx@esrf.fr
License: UNKNOWN
Platform: UNKNOWN
-Classifier: Development Status :: 4 - Beta
+Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Console
Classifier: Environment :: MacOS X
Classifier: Environment :: Win32 (MS Windows)
@@ -26,12 +26,16 @@ Classifier: Topic :: Scientific/Engineering :: Physics
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.5
Provides-Extra: full
+Provides-Extra: test
License-File: LICENSE
silx toolkit
============
+.. |silxView| image:: http://www.silx.org/doc/silx/img/silx-view-v1-0.gif
+ :height: 480px
+
The purpose of the *silx* project is to provide a collection of Python packages to support the
development of data assessment, reduction and analysis applications at synchrotron
radiation facilities.
@@ -59,8 +63,12 @@ The current version features:
* a set of applications:
* a unified viewer (*silx view filename*) for HDF5, SPEC and image file formats
+
+ |silxView|
+
* a unified converter to HDF5 format (*silx convert filename*)
+
Installation
------------
diff --git a/silx.egg-info/SOURCES.txt b/silx.egg-info/SOURCES.txt
index c6acc15..99516d4 100644
--- a/silx.egg-info/SOURCES.txt
+++ b/silx.egg-info/SOURCES.txt
@@ -10,7 +10,6 @@ requirements.txt
run_tests.py
setup.py
stdeb.cfg
-version.py
doc/source/changelog.rst
doc/source/conf.py
doc/source/index.rst
@@ -237,6 +236,7 @@ doc/source/modules/image/sift.rst
doc/source/modules/io/configdict.rst
doc/source/modules/io/convert.rst
doc/source/modules/io/dictdump.rst
+doc/source/modules/io/fioh5.rst
doc/source/modules/io/h5py_utils.rst
doc/source/modules/io/index.rst
doc/source/modules/io/nxdata.rst
@@ -399,7 +399,7 @@ package/desktop/silx.xml
package/windows/README.rst
package/windows/bootstrap-silx-view.py
package/windows/bootstrap.py
-package/windows/pyinstaller-silx-view.spec
+package/windows/create-installer.iss.template
package/windows/pyinstaller.spec
package/windows/silx.ico
qtdesigner_plugins/README.rst
@@ -407,10 +407,6 @@ qtdesigner_plugins/plot1dplugin.py
qtdesigner_plugins/plot2dplugin.py
qtdesigner_plugins/plotwidgetplugin.py
qtdesigner_plugins/plotwindowplugin.py
-silx/__init__.py
-silx/__main__.py
-silx/_config.py
-silx/setup.py
silx.egg-info/PKG-INFO
silx.egg-info/SOURCES.txt
silx.egg-info/dependency_links.txt
@@ -418,1106 +414,1122 @@ silx.egg-info/entry_points.txt
silx.egg-info/not-zip-safe
silx.egg-info/requires.txt
silx.egg-info/top_level.txt
-silx/app/__init__.py
-silx/app/convert.py
-silx/app/setup.py
-silx/app/test_.py
-silx/app/test/__init__.py
-silx/app/test/test_convert.py
-silx/app/view/About.py
-silx/app/view/ApplicationContext.py
-silx/app/view/CustomNxdataWidget.py
-silx/app/view/DataPanel.py
-silx/app/view/Viewer.py
-silx/app/view/__init__.py
-silx/app/view/main.py
-silx/app/view/setup.py
-silx/app/view/utils.py
-silx/app/view/test/__init__.py
-silx/app/view/test/test_launcher.py
-silx/app/view/test/test_view.py
-silx/gui/__init__.py
-silx/gui/colors.py
-silx/gui/console.py
-silx/gui/icons.py
-silx/gui/printer.py
-silx/gui/setup.py
-silx/gui/_glutils/Context.py
-silx/gui/_glutils/FramebufferTexture.py
-silx/gui/_glutils/OpenGLWidget.py
-silx/gui/_glutils/Program.py
-silx/gui/_glutils/Texture.py
-silx/gui/_glutils/VertexBuffer.py
-silx/gui/_glutils/__init__.py
-silx/gui/_glutils/font.py
-silx/gui/_glutils/gl.py
-silx/gui/_glutils/utils.py
-silx/gui/data/ArrayTableModel.py
-silx/gui/data/ArrayTableWidget.py
-silx/gui/data/DataViewer.py
-silx/gui/data/DataViewerFrame.py
-silx/gui/data/DataViewerSelector.py
-silx/gui/data/DataViews.py
-silx/gui/data/Hdf5TableView.py
-silx/gui/data/HexaTableView.py
-silx/gui/data/NXdataWidgets.py
-silx/gui/data/NumpyAxesSelector.py
-silx/gui/data/RecordTableView.py
-silx/gui/data/TextFormatter.py
-silx/gui/data/_RecordPlot.py
-silx/gui/data/_VolumeWindow.py
-silx/gui/data/__init__.py
-silx/gui/data/setup.py
-silx/gui/data/test/__init__.py
-silx/gui/data/test/test_arraywidget.py
-silx/gui/data/test/test_dataviewer.py
-silx/gui/data/test/test_numpyaxesselector.py
-silx/gui/data/test/test_textformatter.py
-silx/gui/dialog/AbstractDataFileDialog.py
-silx/gui/dialog/ColormapDialog.py
-silx/gui/dialog/DataFileDialog.py
-silx/gui/dialog/DatasetDialog.py
-silx/gui/dialog/FileTypeComboBox.py
-silx/gui/dialog/GroupDialog.py
-silx/gui/dialog/ImageFileDialog.py
-silx/gui/dialog/SafeFileIconProvider.py
-silx/gui/dialog/SafeFileSystemModel.py
-silx/gui/dialog/__init__.py
-silx/gui/dialog/setup.py
-silx/gui/dialog/utils.py
-silx/gui/dialog/test/__init__.py
-silx/gui/dialog/test/test_colormapdialog.py
-silx/gui/dialog/test/test_datafiledialog.py
-silx/gui/dialog/test/test_imagefiledialog.py
-silx/gui/fit/BackgroundWidget.py
-silx/gui/fit/FitConfig.py
-silx/gui/fit/FitWidget.py
-silx/gui/fit/FitWidgets.py
-silx/gui/fit/Parameters.py
-silx/gui/fit/__init__.py
-silx/gui/fit/setup.py
-silx/gui/fit/test/__init__.py
-silx/gui/fit/test/testBackgroundWidget.py
-silx/gui/fit/test/testFitConfig.py
-silx/gui/fit/test/testFitWidget.py
-silx/gui/hdf5/Hdf5Formatter.py
-silx/gui/hdf5/Hdf5HeaderView.py
-silx/gui/hdf5/Hdf5Item.py
-silx/gui/hdf5/Hdf5LoadingItem.py
-silx/gui/hdf5/Hdf5Node.py
-silx/gui/hdf5/Hdf5TreeModel.py
-silx/gui/hdf5/Hdf5TreeView.py
-silx/gui/hdf5/NexusSortFilterProxyModel.py
-silx/gui/hdf5/__init__.py
-silx/gui/hdf5/_utils.py
-silx/gui/hdf5/setup.py
-silx/gui/hdf5/test/__init__.py
-silx/gui/hdf5/test/test_hdf5.py
-silx/gui/plot/AlphaSlider.py
-silx/gui/plot/ColorBar.py
-silx/gui/plot/Colormap.py
-silx/gui/plot/ColormapDialog.py
-silx/gui/plot/Colors.py
-silx/gui/plot/CompareImages.py
-silx/gui/plot/ComplexImageView.py
-silx/gui/plot/CurvesROIWidget.py
-silx/gui/plot/ImageStack.py
-silx/gui/plot/ImageView.py
-silx/gui/plot/Interaction.py
-silx/gui/plot/ItemsSelectionDialog.py
-silx/gui/plot/LegendSelector.py
-silx/gui/plot/LimitsHistory.py
-silx/gui/plot/MaskToolsWidget.py
-silx/gui/plot/PlotActions.py
-silx/gui/plot/PlotEvents.py
-silx/gui/plot/PlotInteraction.py
-silx/gui/plot/PlotToolButtons.py
-silx/gui/plot/PlotTools.py
-silx/gui/plot/PlotWidget.py
-silx/gui/plot/PlotWindow.py
-silx/gui/plot/PrintPreviewToolButton.py
-silx/gui/plot/Profile.py
-silx/gui/plot/ProfileMainWindow.py
-silx/gui/plot/ROIStatsWidget.py
-silx/gui/plot/ScatterMaskToolsWidget.py
-silx/gui/plot/ScatterView.py
-silx/gui/plot/StackView.py
-silx/gui/plot/StatsWidget.py
-silx/gui/plot/_BaseMaskToolsWidget.py
-silx/gui/plot/__init__.py
-silx/gui/plot/setup.py
-silx/gui/plot/_utils/__init__.py
-silx/gui/plot/_utils/delaunay.py
-silx/gui/plot/_utils/dtime_ticklayout.py
-silx/gui/plot/_utils/panzoom.py
-silx/gui/plot/_utils/setup.py
-silx/gui/plot/_utils/ticklayout.py
-silx/gui/plot/_utils/test/__init__.py
-silx/gui/plot/_utils/test/test_dtime_ticklayout.py
-silx/gui/plot/_utils/test/test_ticklayout.py
-silx/gui/plot/actions/PlotAction.py
-silx/gui/plot/actions/PlotToolAction.py
-silx/gui/plot/actions/__init__.py
-silx/gui/plot/actions/control.py
-silx/gui/plot/actions/fit.py
-silx/gui/plot/actions/histogram.py
-silx/gui/plot/actions/io.py
-silx/gui/plot/actions/medfilt.py
-silx/gui/plot/actions/mode.py
-silx/gui/plot/backends/BackendBase.py
-silx/gui/plot/backends/BackendMatplotlib.py
-silx/gui/plot/backends/BackendOpenGL.py
-silx/gui/plot/backends/__init__.py
-silx/gui/plot/backends/glutils/GLPlotCurve.py
-silx/gui/plot/backends/glutils/GLPlotFrame.py
-silx/gui/plot/backends/glutils/GLPlotImage.py
-silx/gui/plot/backends/glutils/GLPlotItem.py
-silx/gui/plot/backends/glutils/GLPlotTriangles.py
-silx/gui/plot/backends/glutils/GLSupport.py
-silx/gui/plot/backends/glutils/GLText.py
-silx/gui/plot/backends/glutils/GLTexture.py
-silx/gui/plot/backends/glutils/PlotImageFile.py
-silx/gui/plot/backends/glutils/__init__.py
-silx/gui/plot/items/__init__.py
-silx/gui/plot/items/_arc_roi.py
-silx/gui/plot/items/_pick.py
-silx/gui/plot/items/_roi_base.py
-silx/gui/plot/items/axis.py
-silx/gui/plot/items/complex.py
-silx/gui/plot/items/core.py
-silx/gui/plot/items/curve.py
-silx/gui/plot/items/histogram.py
-silx/gui/plot/items/image.py
-silx/gui/plot/items/marker.py
-silx/gui/plot/items/roi.py
-silx/gui/plot/items/scatter.py
-silx/gui/plot/items/shape.py
-silx/gui/plot/matplotlib/Colormap.py
-silx/gui/plot/matplotlib/__init__.py
-silx/gui/plot/stats/__init__.py
-silx/gui/plot/stats/stats.py
-silx/gui/plot/stats/statshandler.py
-silx/gui/plot/test/__init__.py
-silx/gui/plot/test/testAlphaSlider.py
-silx/gui/plot/test/testColorBar.py
-silx/gui/plot/test/testCompareImages.py
-silx/gui/plot/test/testComplexImageView.py
-silx/gui/plot/test/testCurvesROIWidget.py
-silx/gui/plot/test/testImageStack.py
-silx/gui/plot/test/testImageView.py
-silx/gui/plot/test/testInteraction.py
-silx/gui/plot/test/testItem.py
-silx/gui/plot/test/testLegendSelector.py
-silx/gui/plot/test/testLimitConstraints.py
-silx/gui/plot/test/testMaskToolsWidget.py
-silx/gui/plot/test/testPixelIntensityHistoAction.py
-silx/gui/plot/test/testPlotInteraction.py
-silx/gui/plot/test/testPlotWidget.py
-silx/gui/plot/test/testPlotWidgetNoBackend.py
-silx/gui/plot/test/testPlotWindow.py
-silx/gui/plot/test/testRoiStatsWidget.py
-silx/gui/plot/test/testSaveAction.py
-silx/gui/plot/test/testScatterMaskToolsWidget.py
-silx/gui/plot/test/testScatterView.py
-silx/gui/plot/test/testStackView.py
-silx/gui/plot/test/testStats.py
-silx/gui/plot/test/testUtilsAxis.py
-silx/gui/plot/test/utils.py
-silx/gui/plot/tools/CurveLegendsWidget.py
-silx/gui/plot/tools/LimitsToolBar.py
-silx/gui/plot/tools/PositionInfo.py
-silx/gui/plot/tools/RadarView.py
-silx/gui/plot/tools/__init__.py
-silx/gui/plot/tools/roi.py
-silx/gui/plot/tools/toolbars.py
-silx/gui/plot/tools/profile/ScatterProfileToolBar.py
-silx/gui/plot/tools/profile/__init__.py
-silx/gui/plot/tools/profile/core.py
-silx/gui/plot/tools/profile/editors.py
-silx/gui/plot/tools/profile/manager.py
-silx/gui/plot/tools/profile/rois.py
-silx/gui/plot/tools/profile/toolbar.py
-silx/gui/plot/tools/test/__init__.py
-silx/gui/plot/tools/test/testCurveLegendsWidget.py
-silx/gui/plot/tools/test/testProfile.py
-silx/gui/plot/tools/test/testROI.py
-silx/gui/plot/tools/test/testScatterProfileToolBar.py
-silx/gui/plot/tools/test/testTools.py
-silx/gui/plot/utils/__init__.py
-silx/gui/plot/utils/axis.py
-silx/gui/plot/utils/intersections.py
-silx/gui/plot3d/ParamTreeView.py
-silx/gui/plot3d/Plot3DWidget.py
-silx/gui/plot3d/Plot3DWindow.py
-silx/gui/plot3d/SFViewParamTree.py
-silx/gui/plot3d/ScalarFieldView.py
-silx/gui/plot3d/SceneWidget.py
-silx/gui/plot3d/SceneWindow.py
-silx/gui/plot3d/__init__.py
-silx/gui/plot3d/setup.py
-silx/gui/plot3d/_model/__init__.py
-silx/gui/plot3d/_model/core.py
-silx/gui/plot3d/_model/items.py
-silx/gui/plot3d/_model/model.py
-silx/gui/plot3d/actions/Plot3DAction.py
-silx/gui/plot3d/actions/__init__.py
-silx/gui/plot3d/actions/io.py
-silx/gui/plot3d/actions/mode.py
-silx/gui/plot3d/actions/viewpoint.py
-silx/gui/plot3d/items/__init__.py
-silx/gui/plot3d/items/_pick.py
-silx/gui/plot3d/items/clipplane.py
-silx/gui/plot3d/items/core.py
-silx/gui/plot3d/items/image.py
-silx/gui/plot3d/items/mesh.py
-silx/gui/plot3d/items/mixins.py
-silx/gui/plot3d/items/scatter.py
-silx/gui/plot3d/items/volume.py
-silx/gui/plot3d/scene/__init__.py
-silx/gui/plot3d/scene/axes.py
-silx/gui/plot3d/scene/camera.py
-silx/gui/plot3d/scene/core.py
-silx/gui/plot3d/scene/cutplane.py
-silx/gui/plot3d/scene/event.py
-silx/gui/plot3d/scene/function.py
-silx/gui/plot3d/scene/interaction.py
-silx/gui/plot3d/scene/primitives.py
-silx/gui/plot3d/scene/text.py
-silx/gui/plot3d/scene/transform.py
-silx/gui/plot3d/scene/utils.py
-silx/gui/plot3d/scene/viewport.py
-silx/gui/plot3d/scene/window.py
-silx/gui/plot3d/scene/test/__init__.py
-silx/gui/plot3d/scene/test/test_transform.py
-silx/gui/plot3d/scene/test/test_utils.py
-silx/gui/plot3d/test/__init__.py
-silx/gui/plot3d/test/testGL.py
-silx/gui/plot3d/test/testScalarFieldView.py
-silx/gui/plot3d/test/testSceneWidget.py
-silx/gui/plot3d/test/testSceneWidgetPicking.py
-silx/gui/plot3d/test/testSceneWindow.py
-silx/gui/plot3d/test/testStatsWidget.py
-silx/gui/plot3d/tools/GroupPropertiesWidget.py
-silx/gui/plot3d/tools/PositionInfoWidget.py
-silx/gui/plot3d/tools/ViewpointTools.py
-silx/gui/plot3d/tools/__init__.py
-silx/gui/plot3d/tools/toolbars.py
-silx/gui/plot3d/tools/test/__init__.py
-silx/gui/plot3d/tools/test/testPositionInfoWidget.py
-silx/gui/plot3d/utils/__init__.py
-silx/gui/plot3d/utils/mng.py
-silx/gui/qt/__init__.py
-silx/gui/qt/_macosx.py
-silx/gui/qt/_pyside_dynamic.py
-silx/gui/qt/_pyside_missing.py
-silx/gui/qt/_qt.py
-silx/gui/qt/_utils.py
-silx/gui/qt/inspect.py
-silx/gui/test/__init__.py
-silx/gui/test/test_colors.py
-silx/gui/test/test_console.py
-silx/gui/test/test_icons.py
-silx/gui/test/test_qt.py
-silx/gui/test/utils.py
-silx/gui/utils/__init__.py
-silx/gui/utils/concurrent.py
-silx/gui/utils/image.py
-silx/gui/utils/matplotlib.py
-silx/gui/utils/projecturl.py
-silx/gui/utils/qtutils.py
-silx/gui/utils/signal.py
-silx/gui/utils/testutils.py
-silx/gui/utils/glutils/__init__.py
-silx/gui/utils/test/__init__.py
-silx/gui/utils/test/test.py
-silx/gui/utils/test/test_async.py
-silx/gui/utils/test/test_glutils.py
-silx/gui/utils/test/test_image.py
-silx/gui/utils/test/test_qtutils.py
-silx/gui/utils/test/test_testutils.py
-silx/gui/widgets/BoxLayoutDockWidget.py
-silx/gui/widgets/ColormapNameComboBox.py
-silx/gui/widgets/ElidedLabel.py
-silx/gui/widgets/FloatEdit.py
-silx/gui/widgets/FlowLayout.py
-silx/gui/widgets/FrameBrowser.py
-silx/gui/widgets/HierarchicalTableView.py
-silx/gui/widgets/LegendIconWidget.py
-silx/gui/widgets/MedianFilterDialog.py
-silx/gui/widgets/MultiModeAction.py
-silx/gui/widgets/PeriodicTable.py
-silx/gui/widgets/PrintGeometryDialog.py
-silx/gui/widgets/PrintPreview.py
-silx/gui/widgets/RangeSlider.py
-silx/gui/widgets/TableWidget.py
-silx/gui/widgets/ThreadPoolPushButton.py
-silx/gui/widgets/UrlSelectionTable.py
-silx/gui/widgets/WaitingPushButton.py
-silx/gui/widgets/__init__.py
-silx/gui/widgets/setup.py
-silx/gui/widgets/test/__init__.py
-silx/gui/widgets/test/test_boxlayoutdockwidget.py
-silx/gui/widgets/test/test_elidedlabel.py
-silx/gui/widgets/test/test_flowlayout.py
-silx/gui/widgets/test/test_framebrowser.py
-silx/gui/widgets/test/test_hierarchicaltableview.py
-silx/gui/widgets/test/test_legendiconwidget.py
-silx/gui/widgets/test/test_periodictable.py
-silx/gui/widgets/test/test_printpreview.py
-silx/gui/widgets/test/test_rangeslider.py
-silx/gui/widgets/test/test_tablewidget.py
-silx/gui/widgets/test/test_threadpoolpushbutton.py
-silx/image/__init__.py
-silx/image/_boundingbox.py
-silx/image/backprojection.py
-silx/image/bilinear.pyx
-silx/image/medianfilter.py
-silx/image/phantomgenerator.py
-silx/image/projection.py
-silx/image/reconstruction.py
-silx/image/setup.py
-silx/image/shapes.pyx
-silx/image/sift.py
-silx/image/tomography.py
-silx/image/utils.py
-silx/image/marchingsquares/__init__.py
-silx/image/marchingsquares/_mergeimpl.pyx
-silx/image/marchingsquares/_skimage.py
-silx/image/marchingsquares/setup.py
-silx/image/marchingsquares/include/patterns.h
-silx/image/marchingsquares/test/__init__.py
-silx/image/marchingsquares/test/test_funcapi.py
-silx/image/marchingsquares/test/test_mergeimpl.py
-silx/image/test/__init__.py
-silx/image/test/test_bb.py
-silx/image/test/test_bilinear.py
-silx/image/test/test_medianfilter.py
-silx/image/test/test_shapes.py
-silx/image/test/test_tomography.py
-silx/io/__init__.py
-silx/io/commonh5.py
-silx/io/configdict.py
-silx/io/convert.py
-silx/io/dictdump.py
-silx/io/fabioh5.py
-silx/io/h5py_utils.py
-silx/io/octaveh5.py
-silx/io/rawh5.py
-silx/io/setup.py
-silx/io/specfile.pyx
-silx/io/specfile_wrapper.pxd
-silx/io/specfilewrapper.py
-silx/io/spech5.py
-silx/io/spectoh5.py
-silx/io/url.py
-silx/io/utils.py
-silx/io/nxdata/__init__.py
-silx/io/nxdata/_utils.py
-silx/io/nxdata/parse.py
-silx/io/nxdata/write.py
-silx/io/specfile/include/Lists.h
-silx/io/specfile/include/SpecFile.h
-silx/io/specfile/include/SpecFileCython.h
-silx/io/specfile/include/SpecFileP.h
-silx/io/specfile/include/locale_management.h
-silx/io/specfile/src/locale_management.c
-silx/io/specfile/src/sfdata.c
-silx/io/specfile/src/sfheader.c
-silx/io/specfile/src/sfindex.c
-silx/io/specfile/src/sfinit.c
-silx/io/specfile/src/sflabel.c
-silx/io/specfile/src/sflists.c
-silx/io/specfile/src/sfmca.c
-silx/io/specfile/src/sftools.c
-silx/io/specfile/src/sfwrite.c
-silx/io/test/__init__.py
-silx/io/test/test_commonh5.py
-silx/io/test/test_dictdump.py
-silx/io/test/test_fabioh5.py
-silx/io/test/test_h5py_utils.py
-silx/io/test/test_nxdata.py
-silx/io/test/test_octaveh5.py
-silx/io/test/test_rawh5.py
-silx/io/test/test_specfile.py
-silx/io/test/test_specfilewrapper.py
-silx/io/test/test_spech5.py
-silx/io/test/test_spectoh5.py
-silx/io/test/test_url.py
-silx/io/test/test_utils.py
-silx/math/__init__.py
-silx/math/calibration.py
-silx/math/chistogramnd.pyx
-silx/math/chistogramnd_lut.pyx
-silx/math/colormap.pyx
-silx/math/combo.pyx
-silx/math/histogram.py
-silx/math/histogramnd_c.pxd
-silx/math/interpolate.pyx
-silx/math/marchingcubes.pyx
-silx/math/math_compatibility.pxd
-silx/math/mc.pxd
-silx/math/setup.py
-silx/math/fft/__init__.py
-silx/math/fft/basefft.py
-silx/math/fft/clfft.py
-silx/math/fft/cufft.py
-silx/math/fft/fft.py
-silx/math/fft/fftw.py
-silx/math/fft/npfft.py
-silx/math/fft/setup.py
-silx/math/fft/test/__init__.py
-silx/math/fft/test/test_fft.py
-silx/math/fit/__init__.py
-silx/math/fit/bgtheories.py
-silx/math/fit/filters.pyx
-silx/math/fit/filters_wrapper.pxd
-silx/math/fit/fitmanager.py
-silx/math/fit/fittheories.py
-silx/math/fit/fittheory.py
-silx/math/fit/functions.pyx
-silx/math/fit/functions_wrapper.pxd
-silx/math/fit/leastsq.py
-silx/math/fit/peaks.pyx
-silx/math/fit/peaks_wrapper.pxd
-silx/math/fit/setup.py
-silx/math/fit/filters/include/filters.h
-silx/math/fit/filters/src/smoothnd.c
-silx/math/fit/filters/src/snip1d.c
-silx/math/fit/filters/src/snip2d.c
-silx/math/fit/filters/src/snip3d.c
-silx/math/fit/filters/src/strip.c
-silx/math/fit/functions/include/functions.h
-silx/math/fit/functions/src/funs.c
-silx/math/fit/peaks/include/peaks.h
-silx/math/fit/peaks/src/peaks.c
-silx/math/fit/test/__init__.py
-silx/math/fit/test/test_bgtheories.py
-silx/math/fit/test/test_filters.py
-silx/math/fit/test/test_fit.py
-silx/math/fit/test/test_fitmanager.py
-silx/math/fit/test/test_functions.py
-silx/math/fit/test/test_peaks.py
-silx/math/histogramnd/include/histogramnd_c.h
-silx/math/histogramnd/include/templates.h
-silx/math/histogramnd/include/msvc/stdint.h
-silx/math/histogramnd/src/histogramnd_c.c
-silx/math/histogramnd/src/histogramnd_template.c
-silx/math/include/math_compatibility.h
-silx/math/marchingcubes/mc.hpp
-silx/math/marchingcubes/mc_lut.cpp
-silx/math/medianfilter/__init__.py
-silx/math/medianfilter/median_filter.pxd
-silx/math/medianfilter/medianfilter.pyx
-silx/math/medianfilter/setup.py
-silx/math/medianfilter/include/median_filter.hpp
-silx/math/medianfilter/test/__init__.py
-silx/math/medianfilter/test/benchmark.py
-silx/math/medianfilter/test/test_medianfilter.py
-silx/math/test/__init__.py
-silx/math/test/benchmark_combo.py
-silx/math/test/histo_benchmarks.py
-silx/math/test/test_HistogramndLut_nominal.py
-silx/math/test/test_calibration.py
-silx/math/test/test_colormap.py
-silx/math/test/test_combo.py
-silx/math/test/test_histogramnd_error.py
-silx/math/test/test_histogramnd_nominal.py
-silx/math/test/test_histogramnd_vs_np.py
-silx/math/test/test_interpolate.py
-silx/math/test/test_marchingcubes.py
-silx/opencl/__init__.py
-silx/opencl/backprojection.py
-silx/opencl/common.py
-silx/opencl/convolution.py
-silx/opencl/image.py
-silx/opencl/linalg.py
-silx/opencl/medfilt.py
-silx/opencl/processing.py
-silx/opencl/projection.py
-silx/opencl/reconstruction.py
-silx/opencl/setup.py
-silx/opencl/sinofilter.py
-silx/opencl/sparse.py
-silx/opencl/statistics.py
-silx/opencl/utils.py
-silx/opencl/codec/__init__.py
-silx/opencl/codec/byte_offset.py
-silx/opencl/codec/setup.py
-silx/opencl/codec/test/__init__.py
-silx/opencl/codec/test/test_byte_offset.py
-silx/opencl/sift/__init__.py
-silx/opencl/sift/alignment.py
-silx/opencl/sift/match.py
-silx/opencl/sift/param.py
-silx/opencl/sift/plan.py
-silx/opencl/sift/setup.py
-silx/opencl/sift/sift.py
-silx/opencl/sift/utils.py
-silx/opencl/sift/test/__init__.py
-silx/opencl/sift/test/test_algebra.py
-silx/opencl/sift/test/test_align.py
-silx/opencl/sift/test/test_convol.py
-silx/opencl/sift/test/test_gaussian.py
-silx/opencl/sift/test/test_image.py
-silx/opencl/sift/test/test_image_functions.py
-silx/opencl/sift/test/test_image_setup.py
-silx/opencl/sift/test/test_keypoints.py
-silx/opencl/sift/test/test_matching.py
-silx/opencl/sift/test/test_preproc.py
-silx/opencl/sift/test/test_reductions.py
-silx/opencl/sift/test/test_transform.py
-silx/opencl/test/__init__.py
-silx/opencl/test/test_addition.py
-silx/opencl/test/test_array_utils.py
-silx/opencl/test/test_backprojection.py
-silx/opencl/test/test_convolution.py
-silx/opencl/test/test_doubleword.py
-silx/opencl/test/test_image.py
-silx/opencl/test/test_kahan.py
-silx/opencl/test/test_linalg.py
-silx/opencl/test/test_medfilt.py
-silx/opencl/test/test_projection.py
-silx/opencl/test/test_sparse.py
-silx/opencl/test/test_stats.py
-silx/resources/__init__.py
-silx/resources/gui/colormaps/cividis.npy
-silx/resources/gui/colormaps/inferno.npy
-silx/resources/gui/colormaps/magma.npy
-silx/resources/gui/colormaps/plasma.npy
-silx/resources/gui/colormaps/viridis.npy
-silx/resources/gui/icons/3d-plane-normal-x.png
-silx/resources/gui/icons/3d-plane-normal-x.svg
-silx/resources/gui/icons/3d-plane-normal-y.png
-silx/resources/gui/icons/3d-plane-normal-y.svg
-silx/resources/gui/icons/3d-plane-normal-z.png
-silx/resources/gui/icons/3d-plane-normal-z.svg
-silx/resources/gui/icons/3d-plane-pan.png
-silx/resources/gui/icons/3d-plane-pan.svg
-silx/resources/gui/icons/3d-plane.png
-silx/resources/gui/icons/3d-plane.svg
-silx/resources/gui/icons/add-range-horizontal.png
-silx/resources/gui/icons/add-range-horizontal.svg
-silx/resources/gui/icons/add-shape-arc.png
-silx/resources/gui/icons/add-shape-arc.svg
-silx/resources/gui/icons/add-shape-circle.png
-silx/resources/gui/icons/add-shape-circle.svg
-silx/resources/gui/icons/add-shape-cross.png
-silx/resources/gui/icons/add-shape-cross.svg
-silx/resources/gui/icons/add-shape-diagonal.png
-silx/resources/gui/icons/add-shape-diagonal.svg
-silx/resources/gui/icons/add-shape-ellipse.png
-silx/resources/gui/icons/add-shape-ellipse.svg
-silx/resources/gui/icons/add-shape-horizontal.png
-silx/resources/gui/icons/add-shape-horizontal.svg
-silx/resources/gui/icons/add-shape-point.png
-silx/resources/gui/icons/add-shape-point.svg
-silx/resources/gui/icons/add-shape-polygon.png
-silx/resources/gui/icons/add-shape-polygon.svg
-silx/resources/gui/icons/add-shape-rectangle.png
-silx/resources/gui/icons/add-shape-rectangle.svg
-silx/resources/gui/icons/add-shape-unknown.png
-silx/resources/gui/icons/add-shape-unknown.svg
-silx/resources/gui/icons/add-shape-vertical.png
-silx/resources/gui/icons/add-shape-vertical.svg
-silx/resources/gui/icons/add.png
-silx/resources/gui/icons/add.svg
-silx/resources/gui/icons/arrow-keys.png
-silx/resources/gui/icons/arrow-keys.svg
-silx/resources/gui/icons/axis.png
-silx/resources/gui/icons/axis.svg
-silx/resources/gui/icons/backend-opengl.png
-silx/resources/gui/icons/backend-opengl.svg
-silx/resources/gui/icons/camera.png
-silx/resources/gui/icons/camera.svg
-silx/resources/gui/icons/clipboard.png
-silx/resources/gui/icons/clipboard.svg
-silx/resources/gui/icons/close.png
-silx/resources/gui/icons/close.svg
-silx/resources/gui/icons/colorbar.png
-silx/resources/gui/icons/colorbar.svg
-silx/resources/gui/icons/colormap-histogram.png
-silx/resources/gui/icons/colormap-histogram.svg
-silx/resources/gui/icons/colormap-none.png
-silx/resources/gui/icons/colormap-none.svg
-silx/resources/gui/icons/colormap-norm-arcsinh.png
-silx/resources/gui/icons/colormap-norm-arcsinh.svg
-silx/resources/gui/icons/colormap-norm-gamma.png
-silx/resources/gui/icons/colormap-norm-gamma.svg
-silx/resources/gui/icons/colormap-norm-linear.png
-silx/resources/gui/icons/colormap-norm-linear.svg
-silx/resources/gui/icons/colormap-norm-log.png
-silx/resources/gui/icons/colormap-norm-log.svg
-silx/resources/gui/icons/colormap-norm-sqrt.png
-silx/resources/gui/icons/colormap-norm-sqrt.svg
-silx/resources/gui/icons/colormap-range.png
-silx/resources/gui/icons/colormap-range.svg
-silx/resources/gui/icons/colormap.png
-silx/resources/gui/icons/colormap.svg
-silx/resources/gui/icons/compare-align-auto.png
-silx/resources/gui/icons/compare-align-auto.svg
-silx/resources/gui/icons/compare-align-center.png
-silx/resources/gui/icons/compare-align-center.svg
-silx/resources/gui/icons/compare-align-origin.png
-silx/resources/gui/icons/compare-align-origin.svg
-silx/resources/gui/icons/compare-align-stretch.png
-silx/resources/gui/icons/compare-align-stretch.svg
-silx/resources/gui/icons/compare-keypoints.png
-silx/resources/gui/icons/compare-keypoints.svg
-silx/resources/gui/icons/compare-mode-a-minus-b.png
-silx/resources/gui/icons/compare-mode-a-minus-b.svg
-silx/resources/gui/icons/compare-mode-a.png
-silx/resources/gui/icons/compare-mode-a.svg
-silx/resources/gui/icons/compare-mode-b.png
-silx/resources/gui/icons/compare-mode-b.svg
-silx/resources/gui/icons/compare-mode-hline.png
-silx/resources/gui/icons/compare-mode-hline.svg
-silx/resources/gui/icons/compare-mode-rb-channel.png
-silx/resources/gui/icons/compare-mode-rb-channel.svg
-silx/resources/gui/icons/compare-mode-rbneg-channel.png
-silx/resources/gui/icons/compare-mode-rbneg-channel.svg
-silx/resources/gui/icons/compare-mode-vline.png
-silx/resources/gui/icons/compare-mode-vline.svg
-silx/resources/gui/icons/crop.png
-silx/resources/gui/icons/crop.svg
-silx/resources/gui/icons/crosshair.png
-silx/resources/gui/icons/crosshair.svg
-silx/resources/gui/icons/cube-back.png
-silx/resources/gui/icons/cube-back.svg
-silx/resources/gui/icons/cube-bottom.png
-silx/resources/gui/icons/cube-bottom.svg
-silx/resources/gui/icons/cube-front.png
-silx/resources/gui/icons/cube-front.svg
-silx/resources/gui/icons/cube-left.png
-silx/resources/gui/icons/cube-left.svg
-silx/resources/gui/icons/cube-right.png
-silx/resources/gui/icons/cube-right.svg
-silx/resources/gui/icons/cube-rotate.png
-silx/resources/gui/icons/cube-rotate.svg
-silx/resources/gui/icons/cube-top.png
-silx/resources/gui/icons/cube-top.svg
-silx/resources/gui/icons/cube.png
-silx/resources/gui/icons/cube.svg
-silx/resources/gui/icons/description-description.png
-silx/resources/gui/icons/description-description.svg
-silx/resources/gui/icons/description-error.png
-silx/resources/gui/icons/description-error.svg
-silx/resources/gui/icons/description-name.png
-silx/resources/gui/icons/description-name.svg
-silx/resources/gui/icons/description-program.png
-silx/resources/gui/icons/description-program.svg
-silx/resources/gui/icons/description-title.png
-silx/resources/gui/icons/description-title.svg
-silx/resources/gui/icons/description-value.png
-silx/resources/gui/icons/description-value.svg
-silx/resources/gui/icons/document-open.png
-silx/resources/gui/icons/document-open.svg
-silx/resources/gui/icons/document-print.png
-silx/resources/gui/icons/document-print.svg
-silx/resources/gui/icons/document-save.png
-silx/resources/gui/icons/document-save.svg
-silx/resources/gui/icons/draw-brush.png
-silx/resources/gui/icons/draw-brush.svg
-silx/resources/gui/icons/draw-pencil.png
-silx/resources/gui/icons/draw-pencil.svg
-silx/resources/gui/icons/draw-rubber.png
-silx/resources/gui/icons/draw-rubber.svg
-silx/resources/gui/icons/edit-copy.png
-silx/resources/gui/icons/edit-copy.svg
-silx/resources/gui/icons/eye.png
-silx/resources/gui/icons/eye.svg
-silx/resources/gui/icons/first.png
-silx/resources/gui/icons/first.svg
-silx/resources/gui/icons/folder.png
-silx/resources/gui/icons/folder.svg
-silx/resources/gui/icons/image-mask.png
-silx/resources/gui/icons/image-mask.svg
-silx/resources/gui/icons/image-select-add.png
-silx/resources/gui/icons/image-select-add.svg
-silx/resources/gui/icons/image-select-box.png
-silx/resources/gui/icons/image-select-box.svg
-silx/resources/gui/icons/image-select-brush.png
-silx/resources/gui/icons/image-select-brush.svg
-silx/resources/gui/icons/image-select-erase-rubber.png
-silx/resources/gui/icons/image-select-erase-rubber.svg
-silx/resources/gui/icons/image-select-erase.png
-silx/resources/gui/icons/image-select-erase.svg
-silx/resources/gui/icons/image.png
-silx/resources/gui/icons/image.svg
-silx/resources/gui/icons/item-0dim.png
-silx/resources/gui/icons/item-0dim.svg
-silx/resources/gui/icons/item-1dim.png
-silx/resources/gui/icons/item-1dim.svg
-silx/resources/gui/icons/item-2dim.png
-silx/resources/gui/icons/item-2dim.svg
-silx/resources/gui/icons/item-3dim.png
-silx/resources/gui/icons/item-3dim.svg
-silx/resources/gui/icons/item-ndim.png
-silx/resources/gui/icons/item-ndim.svg
-silx/resources/gui/icons/item-none.png
-silx/resources/gui/icons/item-none.svg
-silx/resources/gui/icons/item-object.png
-silx/resources/gui/icons/item-object.svg
-silx/resources/gui/icons/last.png
-silx/resources/gui/icons/last.svg
-silx/resources/gui/icons/layer-nx.png
-silx/resources/gui/icons/layer-nx.svg
-silx/resources/gui/icons/mask-clear-all.png
-silx/resources/gui/icons/mask-clear-all.svg
-silx/resources/gui/icons/mask-clear.png
-silx/resources/gui/icons/mask-clear.svg
-silx/resources/gui/icons/mask-invert.png
-silx/resources/gui/icons/mask-invert.svg
-silx/resources/gui/icons/math-amplitude.png
-silx/resources/gui/icons/math-amplitude.svg
-silx/resources/gui/icons/math-average.png
-silx/resources/gui/icons/math-average.svg
-silx/resources/gui/icons/math-derive.png
-silx/resources/gui/icons/math-derive.svg
-silx/resources/gui/icons/math-energy.png
-silx/resources/gui/icons/math-energy.svg
-silx/resources/gui/icons/math-fit.png
-silx/resources/gui/icons/math-fit.svg
-silx/resources/gui/icons/math-imaginary.png
-silx/resources/gui/icons/math-imaginary.svg
-silx/resources/gui/icons/math-mean.png
-silx/resources/gui/icons/math-mean.svg
-silx/resources/gui/icons/math-normalize.png
-silx/resources/gui/icons/math-normalize.svg
-silx/resources/gui/icons/math-peak-reset.png
-silx/resources/gui/icons/math-peak-reset.svg
-silx/resources/gui/icons/math-peak-search.png
-silx/resources/gui/icons/math-peak-search.svg
-silx/resources/gui/icons/math-peak.png
-silx/resources/gui/icons/math-peak.svg
-silx/resources/gui/icons/math-phase-color-log.png
-silx/resources/gui/icons/math-phase-color-log.svg
-silx/resources/gui/icons/math-phase-color.png
-silx/resources/gui/icons/math-phase-color.svg
-silx/resources/gui/icons/math-phase.png
-silx/resources/gui/icons/math-phase.svg
-silx/resources/gui/icons/math-real.png
-silx/resources/gui/icons/math-real.svg
-silx/resources/gui/icons/math-sigma.png
-silx/resources/gui/icons/math-sigma.svg
-silx/resources/gui/icons/math-smooth.png
-silx/resources/gui/icons/math-smooth.svg
-silx/resources/gui/icons/math-square-amplitude.png
-silx/resources/gui/icons/math-square-amplitude.svg
-silx/resources/gui/icons/math-substract.png
-silx/resources/gui/icons/math-substract.svg
-silx/resources/gui/icons/math-swap-sign.png
-silx/resources/gui/icons/math-swap-sign.svg
-silx/resources/gui/icons/math-ymin-to-zero.png
-silx/resources/gui/icons/math-ymin-to-zero.svg
-silx/resources/gui/icons/median-filter.png
-silx/resources/gui/icons/median-filter.svg
-silx/resources/gui/icons/next.png
-silx/resources/gui/icons/next.svg
-silx/resources/gui/icons/normal.png
-silx/resources/gui/icons/normal.svg
-silx/resources/gui/icons/nxdata-axis-add.png
-silx/resources/gui/icons/nxdata-axis-add.svg
-silx/resources/gui/icons/nxdata-axis-remove.png
-silx/resources/gui/icons/nxdata-axis-remove.svg
-silx/resources/gui/icons/nxdata-create.png
-silx/resources/gui/icons/nxdata-create.svg
-silx/resources/gui/icons/nxdata-remove.png
-silx/resources/gui/icons/nxdata-remove.svg
-silx/resources/gui/icons/pan.png
-silx/resources/gui/icons/pan.svg
-silx/resources/gui/icons/pixel-intensities.png
-silx/resources/gui/icons/pixel-intensities.svg
-silx/resources/gui/icons/plot-grid.png
-silx/resources/gui/icons/plot-grid.svg
-silx/resources/gui/icons/plot-roi-above.png
-silx/resources/gui/icons/plot-roi-above.svg
-silx/resources/gui/icons/plot-roi-below.png
-silx/resources/gui/icons/plot-roi-below.svg
-silx/resources/gui/icons/plot-roi-between.png
-silx/resources/gui/icons/plot-roi-between.svg
-silx/resources/gui/icons/plot-roi-reset.png
-silx/resources/gui/icons/plot-roi-reset.svg
-silx/resources/gui/icons/plot-roi.png
-silx/resources/gui/icons/plot-roi.svg
-silx/resources/gui/icons/plot-symbols.png
-silx/resources/gui/icons/plot-symbols.svg
-silx/resources/gui/icons/plot-toggle-points.png
-silx/resources/gui/icons/plot-toggle-points.svg
-silx/resources/gui/icons/plot-widget.png
-silx/resources/gui/icons/plot-widget.svg
-silx/resources/gui/icons/plot-window-image.png
-silx/resources/gui/icons/plot-window-image.svg
-silx/resources/gui/icons/plot-window.png
-silx/resources/gui/icons/plot-window.svg
-silx/resources/gui/icons/plot-xauto.png
-silx/resources/gui/icons/plot-xauto.svg
-silx/resources/gui/icons/plot-xlog.png
-silx/resources/gui/icons/plot-xlog.svg
-silx/resources/gui/icons/plot-yauto.png
-silx/resources/gui/icons/plot-yauto.svg
-silx/resources/gui/icons/plot-ydown.png
-silx/resources/gui/icons/plot-ydown.svg
-silx/resources/gui/icons/plot-ylog.png
-silx/resources/gui/icons/plot-ylog.svg
-silx/resources/gui/icons/plot-yup.png
-silx/resources/gui/icons/plot-yup.svg
-silx/resources/gui/icons/pointing-hand.png
-silx/resources/gui/icons/pointing-hand.svg
-silx/resources/gui/icons/previous.png
-silx/resources/gui/icons/previous.svg
-silx/resources/gui/icons/process-working.mng
-silx/resources/gui/icons/profile-clear.png
-silx/resources/gui/icons/profile-clear.svg
-silx/resources/gui/icons/profile1D.png
-silx/resources/gui/icons/profile1D.svg
-silx/resources/gui/icons/profile2D.png
-silx/resources/gui/icons/profile2D.svg
-silx/resources/gui/icons/remove.png
-silx/resources/gui/icons/remove.svg
-silx/resources/gui/icons/rm.png
-silx/resources/gui/icons/rm.svg
-silx/resources/gui/icons/rotate-3d.png
-silx/resources/gui/icons/rotate-3d.svg
-silx/resources/gui/icons/rudder.png
-silx/resources/gui/icons/rudder.svg
-silx/resources/gui/icons/selected.png
-silx/resources/gui/icons/selected.svg
-silx/resources/gui/icons/shape-circle-solid.png
-silx/resources/gui/icons/shape-circle-solid.svg
-silx/resources/gui/icons/shape-circle.png
-silx/resources/gui/icons/shape-circle.svg
-silx/resources/gui/icons/shape-cross.png
-silx/resources/gui/icons/shape-cross.svg
-silx/resources/gui/icons/shape-diagonal-directed.png
-silx/resources/gui/icons/shape-diagonal-directed.svg
-silx/resources/gui/icons/shape-diagonal.png
-silx/resources/gui/icons/shape-diagonal.svg
-silx/resources/gui/icons/shape-ellipse-solid.png
-silx/resources/gui/icons/shape-ellipse-solid.svg
-silx/resources/gui/icons/shape-ellipse.png
-silx/resources/gui/icons/shape-ellipse.svg
-silx/resources/gui/icons/shape-horizontal.png
-silx/resources/gui/icons/shape-horizontal.svg
-silx/resources/gui/icons/shape-polygon.png
-silx/resources/gui/icons/shape-polygon.svg
-silx/resources/gui/icons/shape-rectangle.png
-silx/resources/gui/icons/shape-rectangle.svg
-silx/resources/gui/icons/shape-square.png
-silx/resources/gui/icons/shape-square.svg
-silx/resources/gui/icons/shape-vertical.png
-silx/resources/gui/icons/shape-vertical.svg
-silx/resources/gui/icons/silx.png
-silx/resources/gui/icons/silx.svg
-silx/resources/gui/icons/slice-cross.png
-silx/resources/gui/icons/slice-cross.svg
-silx/resources/gui/icons/slice-horizontal.png
-silx/resources/gui/icons/slice-horizontal.svg
-silx/resources/gui/icons/slice-vertical.png
-silx/resources/gui/icons/slice-vertical.svg
-silx/resources/gui/icons/sliders-off.png
-silx/resources/gui/icons/sliders-off.svg
-silx/resources/gui/icons/sliders-on.png
-silx/resources/gui/icons/sliders-on.svg
-silx/resources/gui/icons/spec.png
-silx/resources/gui/icons/spec.svg
-silx/resources/gui/icons/stats-active-items.png
-silx/resources/gui/icons/stats-active-items.svg
-silx/resources/gui/icons/stats-visible-data.png
-silx/resources/gui/icons/stats-visible-data.svg
-silx/resources/gui/icons/stats-whole-data.png
-silx/resources/gui/icons/stats-whole-data.svg
-silx/resources/gui/icons/stats-whole-items.png
-silx/resources/gui/icons/stats-whole-items.svg
-silx/resources/gui/icons/tree-collapse-all.png
-silx/resources/gui/icons/tree-collapse-all.svg
-silx/resources/gui/icons/tree-expand-all.png
-silx/resources/gui/icons/tree-expand-all.svg
-silx/resources/gui/icons/tree-sort.png
-silx/resources/gui/icons/tree-sort.svg
-silx/resources/gui/icons/view-1d.png
-silx/resources/gui/icons/view-1d.svg
-silx/resources/gui/icons/view-2d-stack.png
-silx/resources/gui/icons/view-2d-stack.svg
-silx/resources/gui/icons/view-2d.png
-silx/resources/gui/icons/view-2d.svg
-silx/resources/gui/icons/view-3d.png
-silx/resources/gui/icons/view-3d.svg
-silx/resources/gui/icons/view-fullscreen.png
-silx/resources/gui/icons/view-fullscreen.svg
-silx/resources/gui/icons/view-hdf5.png
-silx/resources/gui/icons/view-hdf5.svg
-silx/resources/gui/icons/view-nexus.png
-silx/resources/gui/icons/view-nexus.svg
-silx/resources/gui/icons/view-nofullscreen.png
-silx/resources/gui/icons/view-nofullscreen.svg
-silx/resources/gui/icons/view-raw.png
-silx/resources/gui/icons/view-raw.svg
-silx/resources/gui/icons/view-refresh.png
-silx/resources/gui/icons/view-refresh.svg
-silx/resources/gui/icons/view-text.png
-silx/resources/gui/icons/view-text.svg
-silx/resources/gui/icons/window-new.png
-silx/resources/gui/icons/window-new.svg
-silx/resources/gui/icons/zoom-back.png
-silx/resources/gui/icons/zoom-back.svg
-silx/resources/gui/icons/zoom-in.png
-silx/resources/gui/icons/zoom-in.svg
-silx/resources/gui/icons/zoom-original.png
-silx/resources/gui/icons/zoom-original.svg
-silx/resources/gui/icons/zoom-out.png
-silx/resources/gui/icons/zoom-out.svg
-silx/resources/gui/icons/zoom.png
-silx/resources/gui/icons/zoom.svg
-silx/resources/gui/icons/process-working/00.png
-silx/resources/gui/icons/process-working/01.png
-silx/resources/gui/icons/process-working/02.png
-silx/resources/gui/icons/process-working/03.png
-silx/resources/gui/icons/process-working/04.png
-silx/resources/gui/icons/process-working/05.png
-silx/resources/gui/icons/process-working/06.png
-silx/resources/gui/icons/process-working/07.png
-silx/resources/gui/icons/process-working/08.png
-silx/resources/gui/icons/process-working/09.png
-silx/resources/gui/icons/process-working/10.png
-silx/resources/gui/icons/process-working/11.png
-silx/resources/gui/icons/process-working/12.png
-silx/resources/gui/icons/process-working/13.png
-silx/resources/gui/icons/process-working/14.png
-silx/resources/gui/icons/process-working/15.png
-silx/resources/gui/icons/process-working/16.png
-silx/resources/gui/icons/process-working/17.png
-silx/resources/gui/icons/process-working/18.png
-silx/resources/gui/icons/process-working/19.png
-silx/resources/gui/icons/process-working/20.png
-silx/resources/gui/icons/process-working/21.png
-silx/resources/gui/icons/process-working/22.png
-silx/resources/gui/icons/process-working/23.png
-silx/resources/gui/icons/process-working/24.png
-silx/resources/gui/icons/process-working/25.png
-silx/resources/gui/icons/process-working/26.png
-silx/resources/gui/icons/process-working/27.png
-silx/resources/gui/icons/process-working/28.png
-silx/resources/gui/icons/process-working/29.png
-silx/resources/gui/icons/process-working/30.png
-silx/resources/gui/logo/silx.png
-silx/resources/gui/logo/silx.svg
-silx/resources/opencl/addition.cl
-silx/resources/opencl/array_utils.cl
-silx/resources/opencl/backproj.cl
-silx/resources/opencl/backproj_helper.cl
-silx/resources/opencl/bitonic.cl
-silx/resources/opencl/convolution.cl
-silx/resources/opencl/convolution_textures.cl
-silx/resources/opencl/doubleword.cl
-silx/resources/opencl/kahan.cl
-silx/resources/opencl/linalg.cl
-silx/resources/opencl/medfilt.cl
-silx/resources/opencl/preprocess.cl
-silx/resources/opencl/proj.cl
-silx/resources/opencl/sparse.cl
-silx/resources/opencl/statistics.cl
-silx/resources/opencl/codec/byte_offset.cl
-silx/resources/opencl/image/cast.cl
-silx/resources/opencl/image/histogram.cl
-silx/resources/opencl/image/map.cl
-silx/resources/opencl/image/max_min.cl
-silx/resources/opencl/sift/addition.cl
-silx/resources/opencl/sift/algebra.cl
-silx/resources/opencl/sift/convolution.cl
-silx/resources/opencl/sift/descriptor_cpu.cl
-silx/resources/opencl/sift/descriptor_gpu1.cl
-silx/resources/opencl/sift/descriptor_gpu2.cl
-silx/resources/opencl/sift/gaussian.cl
-silx/resources/opencl/sift/image.cl
-silx/resources/opencl/sift/interpolation.cl
-silx/resources/opencl/sift/matching_cpu.cl
-silx/resources/opencl/sift/matching_gpu.cl
-silx/resources/opencl/sift/memset.cl
-silx/resources/opencl/sift/orientation_cpu.cl
-silx/resources/opencl/sift/orientation_gpu.cl
-silx/resources/opencl/sift/preprocess.cl
-silx/resources/opencl/sift/reductions.cl
-silx/resources/opencl/sift/sift.cl
-silx/resources/opencl/sift/transform.cl
-silx/sx/__init__.py
-silx/sx/_plot.py
-silx/sx/_plot3d.py
-silx/test/__init__.py
-silx/test/test_resources.py
-silx/test/test_sx.py
-silx/test/test_version.py
-silx/test/utils.py
-silx/third_party/EdfFile.py
-silx/third_party/TiffIO.py
-silx/third_party/__init__.py
-silx/third_party/scipy_spatial.py
-silx/third_party/setup.py
-silx/third_party/_local/__init__.py
-silx/third_party/_local/scipy_spatial/__init__.py
-silx/third_party/_local/scipy_spatial/qhull.pxd
-silx/third_party/_local/scipy_spatial/qhull.pyx
-silx/third_party/_local/scipy_spatial/qhull_misc.h
-silx/third_party/_local/scipy_spatial/setlist.pxd
-silx/third_party/_local/scipy_spatial/setup.py
-silx/third_party/_local/scipy_spatial/qhull/COPYING.txt
-silx/third_party/_local/scipy_spatial/qhull/src/geom2_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/geom_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/geom_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/global_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/io_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/io_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/libqhull_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/libqhull_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/mem_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/mem_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/merge_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/merge_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/poly2_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/poly_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/poly_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/qhull_ra.h
-silx/third_party/_local/scipy_spatial/qhull/src/qset_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/qset_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/random_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/random_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/rboxlib_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/stat_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/stat_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/user_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/user_r.h
-silx/third_party/_local/scipy_spatial/qhull/src/usermem_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/userprintf_r.c
-silx/third_party/_local/scipy_spatial/qhull/src/userprintf_rbox_r.c
-silx/utils/ExternalResources.py
-silx/utils/__init__.py
-silx/utils/_have_openmp.pxd
-silx/utils/array_like.py
-silx/utils/debug.py
-silx/utils/deprecation.py
-silx/utils/enum.py
-silx/utils/exceptions.py
-silx/utils/files.py
-silx/utils/html.py
-silx/utils/launcher.py
-silx/utils/number.py
-silx/utils/property.py
-silx/utils/proxy.py
-silx/utils/retry.py
-silx/utils/setup.py
-silx/utils/testutils.py
-silx/utils/weakref.py
-silx/utils/include/silx_store_openmp.h
-silx/utils/test/__init__.py
-silx/utils/test/test_array_like.py
-silx/utils/test/test_debug.py
-silx/utils/test/test_deprecation.py
-silx/utils/test/test_enum.py
-silx/utils/test/test_external_resources.py
-silx/utils/test/test_html.py
-silx/utils/test/test_launcher.py
-silx/utils/test/test_launcher_command.py
-silx/utils/test/test_number.py
-silx/utils/test/test_proxy.py
-silx/utils/test/test_retry.py
-silx/utils/test/test_testutils.py
-silx/utils/test/test_weakref.py \ No newline at end of file
+src/silx/__init__.py
+src/silx/__main__.py
+src/silx/_config.py
+src/silx/_version.py
+src/silx/conftest.py
+src/silx/setup.py
+src/silx/app/__init__.py
+src/silx/app/convert.py
+src/silx/app/setup.py
+src/silx/app/test_.py
+src/silx/app/test/__init__.py
+src/silx/app/test/test_convert.py
+src/silx/app/view/About.py
+src/silx/app/view/ApplicationContext.py
+src/silx/app/view/CustomNxdataWidget.py
+src/silx/app/view/DataPanel.py
+src/silx/app/view/Viewer.py
+src/silx/app/view/__init__.py
+src/silx/app/view/main.py
+src/silx/app/view/setup.py
+src/silx/app/view/utils.py
+src/silx/app/view/test/__init__.py
+src/silx/app/view/test/test_launcher.py
+src/silx/app/view/test/test_view.py
+src/silx/gui/__init__.py
+src/silx/gui/colors.py
+src/silx/gui/conftest.py
+src/silx/gui/console.py
+src/silx/gui/icons.py
+src/silx/gui/printer.py
+src/silx/gui/setup.py
+src/silx/gui/_glutils/Context.py
+src/silx/gui/_glutils/FramebufferTexture.py
+src/silx/gui/_glutils/OpenGLWidget.py
+src/silx/gui/_glutils/Program.py
+src/silx/gui/_glutils/Texture.py
+src/silx/gui/_glutils/VertexBuffer.py
+src/silx/gui/_glutils/__init__.py
+src/silx/gui/_glutils/font.py
+src/silx/gui/_glutils/gl.py
+src/silx/gui/_glutils/utils.py
+src/silx/gui/data/ArrayTableModel.py
+src/silx/gui/data/ArrayTableWidget.py
+src/silx/gui/data/DataViewer.py
+src/silx/gui/data/DataViewerFrame.py
+src/silx/gui/data/DataViewerSelector.py
+src/silx/gui/data/DataViews.py
+src/silx/gui/data/Hdf5TableView.py
+src/silx/gui/data/HexaTableView.py
+src/silx/gui/data/NXdataWidgets.py
+src/silx/gui/data/NumpyAxesSelector.py
+src/silx/gui/data/RecordTableView.py
+src/silx/gui/data/TextFormatter.py
+src/silx/gui/data/_RecordPlot.py
+src/silx/gui/data/_VolumeWindow.py
+src/silx/gui/data/__init__.py
+src/silx/gui/data/setup.py
+src/silx/gui/data/test/__init__.py
+src/silx/gui/data/test/test_arraywidget.py
+src/silx/gui/data/test/test_dataviewer.py
+src/silx/gui/data/test/test_numpyaxesselector.py
+src/silx/gui/data/test/test_textformatter.py
+src/silx/gui/dialog/AbstractDataFileDialog.py
+src/silx/gui/dialog/ColormapDialog.py
+src/silx/gui/dialog/DataFileDialog.py
+src/silx/gui/dialog/DatasetDialog.py
+src/silx/gui/dialog/FileTypeComboBox.py
+src/silx/gui/dialog/GroupDialog.py
+src/silx/gui/dialog/ImageFileDialog.py
+src/silx/gui/dialog/SafeFileIconProvider.py
+src/silx/gui/dialog/SafeFileSystemModel.py
+src/silx/gui/dialog/__init__.py
+src/silx/gui/dialog/setup.py
+src/silx/gui/dialog/utils.py
+src/silx/gui/dialog/test/__init__.py
+src/silx/gui/dialog/test/test_colormapdialog.py
+src/silx/gui/dialog/test/test_datafiledialog.py
+src/silx/gui/dialog/test/test_imagefiledialog.py
+src/silx/gui/fit/BackgroundWidget.py
+src/silx/gui/fit/FitConfig.py
+src/silx/gui/fit/FitWidget.py
+src/silx/gui/fit/FitWidgets.py
+src/silx/gui/fit/Parameters.py
+src/silx/gui/fit/__init__.py
+src/silx/gui/fit/setup.py
+src/silx/gui/fit/test/__init__.py
+src/silx/gui/fit/test/testBackgroundWidget.py
+src/silx/gui/fit/test/testFitConfig.py
+src/silx/gui/fit/test/testFitWidget.py
+src/silx/gui/hdf5/Hdf5Formatter.py
+src/silx/gui/hdf5/Hdf5HeaderView.py
+src/silx/gui/hdf5/Hdf5Item.py
+src/silx/gui/hdf5/Hdf5LoadingItem.py
+src/silx/gui/hdf5/Hdf5Node.py
+src/silx/gui/hdf5/Hdf5TreeModel.py
+src/silx/gui/hdf5/Hdf5TreeView.py
+src/silx/gui/hdf5/NexusSortFilterProxyModel.py
+src/silx/gui/hdf5/__init__.py
+src/silx/gui/hdf5/_utils.py
+src/silx/gui/hdf5/setup.py
+src/silx/gui/hdf5/test/__init__.py
+src/silx/gui/hdf5/test/test_hdf5.py
+src/silx/gui/plot/AlphaSlider.py
+src/silx/gui/plot/ColorBar.py
+src/silx/gui/plot/Colormap.py
+src/silx/gui/plot/ColormapDialog.py
+src/silx/gui/plot/Colors.py
+src/silx/gui/plot/CompareImages.py
+src/silx/gui/plot/ComplexImageView.py
+src/silx/gui/plot/CurvesROIWidget.py
+src/silx/gui/plot/ImageStack.py
+src/silx/gui/plot/ImageView.py
+src/silx/gui/plot/Interaction.py
+src/silx/gui/plot/ItemsSelectionDialog.py
+src/silx/gui/plot/LegendSelector.py
+src/silx/gui/plot/LimitsHistory.py
+src/silx/gui/plot/MaskToolsWidget.py
+src/silx/gui/plot/PlotActions.py
+src/silx/gui/plot/PlotEvents.py
+src/silx/gui/plot/PlotInteraction.py
+src/silx/gui/plot/PlotToolButtons.py
+src/silx/gui/plot/PlotTools.py
+src/silx/gui/plot/PlotWidget.py
+src/silx/gui/plot/PlotWindow.py
+src/silx/gui/plot/PrintPreviewToolButton.py
+src/silx/gui/plot/Profile.py
+src/silx/gui/plot/ProfileMainWindow.py
+src/silx/gui/plot/ROIStatsWidget.py
+src/silx/gui/plot/ScatterMaskToolsWidget.py
+src/silx/gui/plot/ScatterView.py
+src/silx/gui/plot/StackView.py
+src/silx/gui/plot/StatsWidget.py
+src/silx/gui/plot/_BaseMaskToolsWidget.py
+src/silx/gui/plot/__init__.py
+src/silx/gui/plot/setup.py
+src/silx/gui/plot/_utils/__init__.py
+src/silx/gui/plot/_utils/delaunay.py
+src/silx/gui/plot/_utils/dtime_ticklayout.py
+src/silx/gui/plot/_utils/panzoom.py
+src/silx/gui/plot/_utils/setup.py
+src/silx/gui/plot/_utils/ticklayout.py
+src/silx/gui/plot/_utils/test/__init__.py
+src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
+src/silx/gui/plot/_utils/test/test_ticklayout.py
+src/silx/gui/plot/actions/PlotAction.py
+src/silx/gui/plot/actions/PlotToolAction.py
+src/silx/gui/plot/actions/__init__.py
+src/silx/gui/plot/actions/control.py
+src/silx/gui/plot/actions/fit.py
+src/silx/gui/plot/actions/histogram.py
+src/silx/gui/plot/actions/io.py
+src/silx/gui/plot/actions/medfilt.py
+src/silx/gui/plot/actions/mode.py
+src/silx/gui/plot/backends/BackendBase.py
+src/silx/gui/plot/backends/BackendMatplotlib.py
+src/silx/gui/plot/backends/BackendOpenGL.py
+src/silx/gui/plot/backends/__init__.py
+src/silx/gui/plot/backends/glutils/GLPlotCurve.py
+src/silx/gui/plot/backends/glutils/GLPlotFrame.py
+src/silx/gui/plot/backends/glutils/GLPlotImage.py
+src/silx/gui/plot/backends/glutils/GLPlotItem.py
+src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
+src/silx/gui/plot/backends/glutils/GLSupport.py
+src/silx/gui/plot/backends/glutils/GLText.py
+src/silx/gui/plot/backends/glutils/GLTexture.py
+src/silx/gui/plot/backends/glutils/PlotImageFile.py
+src/silx/gui/plot/backends/glutils/__init__.py
+src/silx/gui/plot/items/__init__.py
+src/silx/gui/plot/items/_arc_roi.py
+src/silx/gui/plot/items/_pick.py
+src/silx/gui/plot/items/_roi_base.py
+src/silx/gui/plot/items/axis.py
+src/silx/gui/plot/items/complex.py
+src/silx/gui/plot/items/core.py
+src/silx/gui/plot/items/curve.py
+src/silx/gui/plot/items/histogram.py
+src/silx/gui/plot/items/image.py
+src/silx/gui/plot/items/image_aggregated.py
+src/silx/gui/plot/items/marker.py
+src/silx/gui/plot/items/roi.py
+src/silx/gui/plot/items/scatter.py
+src/silx/gui/plot/items/shape.py
+src/silx/gui/plot/matplotlib/Colormap.py
+src/silx/gui/plot/matplotlib/__init__.py
+src/silx/gui/plot/stats/__init__.py
+src/silx/gui/plot/stats/stats.py
+src/silx/gui/plot/stats/statshandler.py
+src/silx/gui/plot/test/__init__.py
+src/silx/gui/plot/test/testAlphaSlider.py
+src/silx/gui/plot/test/testColorBar.py
+src/silx/gui/plot/test/testCompareImages.py
+src/silx/gui/plot/test/testComplexImageView.py
+src/silx/gui/plot/test/testCurvesROIWidget.py
+src/silx/gui/plot/test/testImageStack.py
+src/silx/gui/plot/test/testImageView.py
+src/silx/gui/plot/test/testInteraction.py
+src/silx/gui/plot/test/testItem.py
+src/silx/gui/plot/test/testLegendSelector.py
+src/silx/gui/plot/test/testLimitConstraints.py
+src/silx/gui/plot/test/testMaskToolsWidget.py
+src/silx/gui/plot/test/testPixelIntensityHistoAction.py
+src/silx/gui/plot/test/testPlotActions.py
+src/silx/gui/plot/test/testPlotInteraction.py
+src/silx/gui/plot/test/testPlotWidget.py
+src/silx/gui/plot/test/testPlotWidgetNoBackend.py
+src/silx/gui/plot/test/testPlotWindow.py
+src/silx/gui/plot/test/testRoiStatsWidget.py
+src/silx/gui/plot/test/testSaveAction.py
+src/silx/gui/plot/test/testScatterMaskToolsWidget.py
+src/silx/gui/plot/test/testScatterView.py
+src/silx/gui/plot/test/testStackView.py
+src/silx/gui/plot/test/testStats.py
+src/silx/gui/plot/test/testUtilsAxis.py
+src/silx/gui/plot/test/utils.py
+src/silx/gui/plot/tools/CurveLegendsWidget.py
+src/silx/gui/plot/tools/LimitsToolBar.py
+src/silx/gui/plot/tools/PositionInfo.py
+src/silx/gui/plot/tools/RadarView.py
+src/silx/gui/plot/tools/__init__.py
+src/silx/gui/plot/tools/roi.py
+src/silx/gui/plot/tools/toolbars.py
+src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
+src/silx/gui/plot/tools/profile/__init__.py
+src/silx/gui/plot/tools/profile/core.py
+src/silx/gui/plot/tools/profile/editors.py
+src/silx/gui/plot/tools/profile/manager.py
+src/silx/gui/plot/tools/profile/rois.py
+src/silx/gui/plot/tools/profile/toolbar.py
+src/silx/gui/plot/tools/test/__init__.py
+src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
+src/silx/gui/plot/tools/test/testProfile.py
+src/silx/gui/plot/tools/test/testROI.py
+src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
+src/silx/gui/plot/tools/test/testTools.py
+src/silx/gui/plot/utils/__init__.py
+src/silx/gui/plot/utils/axis.py
+src/silx/gui/plot/utils/intersections.py
+src/silx/gui/plot3d/ParamTreeView.py
+src/silx/gui/plot3d/Plot3DWidget.py
+src/silx/gui/plot3d/Plot3DWindow.py
+src/silx/gui/plot3d/SFViewParamTree.py
+src/silx/gui/plot3d/ScalarFieldView.py
+src/silx/gui/plot3d/SceneWidget.py
+src/silx/gui/plot3d/SceneWindow.py
+src/silx/gui/plot3d/__init__.py
+src/silx/gui/plot3d/conftest.py
+src/silx/gui/plot3d/setup.py
+src/silx/gui/plot3d/_model/__init__.py
+src/silx/gui/plot3d/_model/core.py
+src/silx/gui/plot3d/_model/items.py
+src/silx/gui/plot3d/_model/model.py
+src/silx/gui/plot3d/actions/Plot3DAction.py
+src/silx/gui/plot3d/actions/__init__.py
+src/silx/gui/plot3d/actions/io.py
+src/silx/gui/plot3d/actions/mode.py
+src/silx/gui/plot3d/actions/viewpoint.py
+src/silx/gui/plot3d/items/__init__.py
+src/silx/gui/plot3d/items/_pick.py
+src/silx/gui/plot3d/items/clipplane.py
+src/silx/gui/plot3d/items/core.py
+src/silx/gui/plot3d/items/image.py
+src/silx/gui/plot3d/items/mesh.py
+src/silx/gui/plot3d/items/mixins.py
+src/silx/gui/plot3d/items/scatter.py
+src/silx/gui/plot3d/items/volume.py
+src/silx/gui/plot3d/scene/__init__.py
+src/silx/gui/plot3d/scene/axes.py
+src/silx/gui/plot3d/scene/camera.py
+src/silx/gui/plot3d/scene/core.py
+src/silx/gui/plot3d/scene/cutplane.py
+src/silx/gui/plot3d/scene/event.py
+src/silx/gui/plot3d/scene/function.py
+src/silx/gui/plot3d/scene/interaction.py
+src/silx/gui/plot3d/scene/primitives.py
+src/silx/gui/plot3d/scene/text.py
+src/silx/gui/plot3d/scene/transform.py
+src/silx/gui/plot3d/scene/utils.py
+src/silx/gui/plot3d/scene/viewport.py
+src/silx/gui/plot3d/scene/window.py
+src/silx/gui/plot3d/scene/test/__init__.py
+src/silx/gui/plot3d/scene/test/test_transform.py
+src/silx/gui/plot3d/scene/test/test_utils.py
+src/silx/gui/plot3d/test/__init__.py
+src/silx/gui/plot3d/test/testGL.py
+src/silx/gui/plot3d/test/testScalarFieldView.py
+src/silx/gui/plot3d/test/testSceneWidget.py
+src/silx/gui/plot3d/test/testSceneWidgetPicking.py
+src/silx/gui/plot3d/test/testSceneWindow.py
+src/silx/gui/plot3d/test/testStatsWidget.py
+src/silx/gui/plot3d/tools/GroupPropertiesWidget.py
+src/silx/gui/plot3d/tools/PositionInfoWidget.py
+src/silx/gui/plot3d/tools/ViewpointTools.py
+src/silx/gui/plot3d/tools/__init__.py
+src/silx/gui/plot3d/tools/toolbars.py
+src/silx/gui/plot3d/tools/test/__init__.py
+src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
+src/silx/gui/plot3d/utils/__init__.py
+src/silx/gui/plot3d/utils/mng.py
+src/silx/gui/qt/__init__.py
+src/silx/gui/qt/_pyside_dynamic.py
+src/silx/gui/qt/_qt.py
+src/silx/gui/qt/_utils.py
+src/silx/gui/qt/inspect.py
+src/silx/gui/test/__init__.py
+src/silx/gui/test/test_colors.py
+src/silx/gui/test/test_console.py
+src/silx/gui/test/test_icons.py
+src/silx/gui/test/test_qt.py
+src/silx/gui/test/utils.py
+src/silx/gui/utils/__init__.py
+src/silx/gui/utils/concurrent.py
+src/silx/gui/utils/image.py
+src/silx/gui/utils/matplotlib.py
+src/silx/gui/utils/projecturl.py
+src/silx/gui/utils/qtutils.py
+src/silx/gui/utils/signal.py
+src/silx/gui/utils/testutils.py
+src/silx/gui/utils/glutils/__init__.py
+src/silx/gui/utils/test/__init__.py
+src/silx/gui/utils/test/test.py
+src/silx/gui/utils/test/test_async.py
+src/silx/gui/utils/test/test_glutils.py
+src/silx/gui/utils/test/test_image.py
+src/silx/gui/utils/test/test_qtutils.py
+src/silx/gui/utils/test/test_testutils.py
+src/silx/gui/widgets/BoxLayoutDockWidget.py
+src/silx/gui/widgets/ColormapNameComboBox.py
+src/silx/gui/widgets/ElidedLabel.py
+src/silx/gui/widgets/FloatEdit.py
+src/silx/gui/widgets/FlowLayout.py
+src/silx/gui/widgets/FrameBrowser.py
+src/silx/gui/widgets/HierarchicalTableView.py
+src/silx/gui/widgets/LegendIconWidget.py
+src/silx/gui/widgets/MedianFilterDialog.py
+src/silx/gui/widgets/MultiModeAction.py
+src/silx/gui/widgets/PeriodicTable.py
+src/silx/gui/widgets/PrintGeometryDialog.py
+src/silx/gui/widgets/PrintPreview.py
+src/silx/gui/widgets/RangeSlider.py
+src/silx/gui/widgets/TableWidget.py
+src/silx/gui/widgets/ThreadPoolPushButton.py
+src/silx/gui/widgets/UrlSelectionTable.py
+src/silx/gui/widgets/WaitingPushButton.py
+src/silx/gui/widgets/__init__.py
+src/silx/gui/widgets/setup.py
+src/silx/gui/widgets/test/__init__.py
+src/silx/gui/widgets/test/test_boxlayoutdockwidget.py
+src/silx/gui/widgets/test/test_elidedlabel.py
+src/silx/gui/widgets/test/test_flowlayout.py
+src/silx/gui/widgets/test/test_framebrowser.py
+src/silx/gui/widgets/test/test_hierarchicaltableview.py
+src/silx/gui/widgets/test/test_legendiconwidget.py
+src/silx/gui/widgets/test/test_periodictable.py
+src/silx/gui/widgets/test/test_printpreview.py
+src/silx/gui/widgets/test/test_rangeslider.py
+src/silx/gui/widgets/test/test_tablewidget.py
+src/silx/gui/widgets/test/test_threadpoolpushbutton.py
+src/silx/image/__init__.py
+src/silx/image/_boundingbox.py
+src/silx/image/backprojection.py
+src/silx/image/bilinear.pyx
+src/silx/image/medianfilter.py
+src/silx/image/phantomgenerator.py
+src/silx/image/projection.py
+src/silx/image/reconstruction.py
+src/silx/image/setup.py
+src/silx/image/shapes.pyx
+src/silx/image/sift.py
+src/silx/image/tomography.py
+src/silx/image/utils.py
+src/silx/image/marchingsquares/__init__.py
+src/silx/image/marchingsquares/_mergeimpl.pyx
+src/silx/image/marchingsquares/_skimage.py
+src/silx/image/marchingsquares/setup.py
+src/silx/image/marchingsquares/include/patterns.h
+src/silx/image/marchingsquares/test/__init__.py
+src/silx/image/marchingsquares/test/test_funcapi.py
+src/silx/image/marchingsquares/test/test_mergeimpl.py
+src/silx/image/test/__init__.py
+src/silx/image/test/test_bb.py
+src/silx/image/test/test_bilinear.py
+src/silx/image/test/test_medianfilter.py
+src/silx/image/test/test_shapes.py
+src/silx/image/test/test_tomography.py
+src/silx/io/__init__.py
+src/silx/io/commonh5.py
+src/silx/io/configdict.py
+src/silx/io/convert.py
+src/silx/io/dictdump.py
+src/silx/io/fabioh5.py
+src/silx/io/fioh5.py
+src/silx/io/h5py_utils.py
+src/silx/io/octaveh5.py
+src/silx/io/rawh5.py
+src/silx/io/setup.py
+src/silx/io/specfile.pyx
+src/silx/io/specfile_wrapper.pxd
+src/silx/io/specfilewrapper.py
+src/silx/io/spech5.py
+src/silx/io/spectoh5.py
+src/silx/io/url.py
+src/silx/io/utils.py
+src/silx/io/nxdata/__init__.py
+src/silx/io/nxdata/_utils.py
+src/silx/io/nxdata/parse.py
+src/silx/io/nxdata/write.py
+src/silx/io/specfile/include/Lists.h
+src/silx/io/specfile/include/SpecFile.h
+src/silx/io/specfile/include/SpecFileCython.h
+src/silx/io/specfile/include/SpecFileP.h
+src/silx/io/specfile/include/locale_management.h
+src/silx/io/specfile/src/locale_management.c
+src/silx/io/specfile/src/sfdata.c
+src/silx/io/specfile/src/sfheader.c
+src/silx/io/specfile/src/sfindex.c
+src/silx/io/specfile/src/sfinit.c
+src/silx/io/specfile/src/sflabel.c
+src/silx/io/specfile/src/sflists.c
+src/silx/io/specfile/src/sfmca.c
+src/silx/io/specfile/src/sftools.c
+src/silx/io/specfile/src/sfwrite.c
+src/silx/io/test/__init__.py
+src/silx/io/test/test_commonh5.py
+src/silx/io/test/test_dictdump.py
+src/silx/io/test/test_fabioh5.py
+src/silx/io/test/test_fioh5.py
+src/silx/io/test/test_h5py_utils.py
+src/silx/io/test/test_nxdata.py
+src/silx/io/test/test_octaveh5.py
+src/silx/io/test/test_rawh5.py
+src/silx/io/test/test_specfile.py
+src/silx/io/test/test_specfilewrapper.py
+src/silx/io/test/test_spech5.py
+src/silx/io/test/test_spectoh5.py
+src/silx/io/test/test_url.py
+src/silx/io/test/test_utils.py
+src/silx/io/test/test_write_to_h5.py
+src/silx/math/__init__.py
+src/silx/math/_colormap.pyx
+src/silx/math/calibration.py
+src/silx/math/chistogramnd.pyx
+src/silx/math/chistogramnd_lut.pyx
+src/silx/math/colormap.py
+src/silx/math/combo.pyx
+src/silx/math/histogram.py
+src/silx/math/histogramnd_c.pxd
+src/silx/math/interpolate.pyx
+src/silx/math/marchingcubes.pyx
+src/silx/math/math_compatibility.pxd
+src/silx/math/mc.pxd
+src/silx/math/setup.py
+src/silx/math/fft/__init__.py
+src/silx/math/fft/basefft.py
+src/silx/math/fft/clfft.py
+src/silx/math/fft/cufft.py
+src/silx/math/fft/fft.py
+src/silx/math/fft/fftw.py
+src/silx/math/fft/npfft.py
+src/silx/math/fft/setup.py
+src/silx/math/fft/test/__init__.py
+src/silx/math/fft/test/test_fft.py
+src/silx/math/fit/__init__.py
+src/silx/math/fit/bgtheories.py
+src/silx/math/fit/filters.pyx
+src/silx/math/fit/filters_wrapper.pxd
+src/silx/math/fit/fitmanager.py
+src/silx/math/fit/fittheories.py
+src/silx/math/fit/fittheory.py
+src/silx/math/fit/functions.pyx
+src/silx/math/fit/functions_wrapper.pxd
+src/silx/math/fit/leastsq.py
+src/silx/math/fit/peaks.pyx
+src/silx/math/fit/peaks_wrapper.pxd
+src/silx/math/fit/setup.py
+src/silx/math/fit/filters/include/filters.h
+src/silx/math/fit/filters/src/smoothnd.c
+src/silx/math/fit/filters/src/snip1d.c
+src/silx/math/fit/filters/src/snip2d.c
+src/silx/math/fit/filters/src/snip3d.c
+src/silx/math/fit/filters/src/strip.c
+src/silx/math/fit/functions/include/functions.h
+src/silx/math/fit/functions/src/funs.c
+src/silx/math/fit/peaks/include/peaks.h
+src/silx/math/fit/peaks/src/peaks.c
+src/silx/math/fit/test/__init__.py
+src/silx/math/fit/test/test_bgtheories.py
+src/silx/math/fit/test/test_filters.py
+src/silx/math/fit/test/test_fit.py
+src/silx/math/fit/test/test_fitmanager.py
+src/silx/math/fit/test/test_functions.py
+src/silx/math/fit/test/test_peaks.py
+src/silx/math/histogramnd/include/histogramnd_c.h
+src/silx/math/histogramnd/include/templates.h
+src/silx/math/histogramnd/include/msvc/stdint.h
+src/silx/math/histogramnd/src/histogramnd_c.c
+src/silx/math/histogramnd/src/histogramnd_template.c
+src/silx/math/include/math_compatibility.h
+src/silx/math/marchingcubes/mc.hpp
+src/silx/math/marchingcubes/mc_lut.cpp
+src/silx/math/medianfilter/__init__.py
+src/silx/math/medianfilter/median_filter.pxd
+src/silx/math/medianfilter/medianfilter.pyx
+src/silx/math/medianfilter/setup.py
+src/silx/math/medianfilter/include/median_filter.hpp
+src/silx/math/medianfilter/test/__init__.py
+src/silx/math/medianfilter/test/benchmark.py
+src/silx/math/medianfilter/test/test_medianfilter.py
+src/silx/math/test/__init__.py
+src/silx/math/test/benchmark_combo.py
+src/silx/math/test/histo_benchmarks.py
+src/silx/math/test/test_HistogramndLut_nominal.py
+src/silx/math/test/test_calibration.py
+src/silx/math/test/test_colormap.py
+src/silx/math/test/test_combo.py
+src/silx/math/test/test_histogramnd_error.py
+src/silx/math/test/test_histogramnd_nominal.py
+src/silx/math/test/test_histogramnd_vs_np.py
+src/silx/math/test/test_interpolate.py
+src/silx/math/test/test_marchingcubes.py
+src/silx/opencl/__init__.py
+src/silx/opencl/backprojection.py
+src/silx/opencl/common.py
+src/silx/opencl/conftest.py
+src/silx/opencl/convolution.py
+src/silx/opencl/image.py
+src/silx/opencl/linalg.py
+src/silx/opencl/medfilt.py
+src/silx/opencl/processing.py
+src/silx/opencl/projection.py
+src/silx/opencl/reconstruction.py
+src/silx/opencl/setup.py
+src/silx/opencl/sinofilter.py
+src/silx/opencl/sparse.py
+src/silx/opencl/statistics.py
+src/silx/opencl/utils.py
+src/silx/opencl/codec/__init__.py
+src/silx/opencl/codec/byte_offset.py
+src/silx/opencl/codec/setup.py
+src/silx/opencl/codec/test/__init__.py
+src/silx/opencl/codec/test/test_byte_offset.py
+src/silx/opencl/sift/__init__.py
+src/silx/opencl/sift/alignment.py
+src/silx/opencl/sift/match.py
+src/silx/opencl/sift/param.py
+src/silx/opencl/sift/plan.py
+src/silx/opencl/sift/setup.py
+src/silx/opencl/sift/sift.py
+src/silx/opencl/sift/utils.py
+src/silx/opencl/sift/test/__init__.py
+src/silx/opencl/sift/test/test_algebra.py
+src/silx/opencl/sift/test/test_align.py
+src/silx/opencl/sift/test/test_convol.py
+src/silx/opencl/sift/test/test_gaussian.py
+src/silx/opencl/sift/test/test_image.py
+src/silx/opencl/sift/test/test_image_functions.py
+src/silx/opencl/sift/test/test_image_setup.py
+src/silx/opencl/sift/test/test_keypoints.py
+src/silx/opencl/sift/test/test_matching.py
+src/silx/opencl/sift/test/test_preproc.py
+src/silx/opencl/sift/test/test_reductions.py
+src/silx/opencl/sift/test/test_transform.py
+src/silx/opencl/test/__init__.py
+src/silx/opencl/test/test_addition.py
+src/silx/opencl/test/test_array_utils.py
+src/silx/opencl/test/test_backprojection.py
+src/silx/opencl/test/test_convolution.py
+src/silx/opencl/test/test_doubleword.py
+src/silx/opencl/test/test_image.py
+src/silx/opencl/test/test_kahan.py
+src/silx/opencl/test/test_linalg.py
+src/silx/opencl/test/test_medfilt.py
+src/silx/opencl/test/test_projection.py
+src/silx/opencl/test/test_sparse.py
+src/silx/opencl/test/test_stats.py
+src/silx/resources/__init__.py
+src/silx/resources/gui/colormaps/cividis.npy
+src/silx/resources/gui/colormaps/inferno.npy
+src/silx/resources/gui/colormaps/magma.npy
+src/silx/resources/gui/colormaps/plasma.npy
+src/silx/resources/gui/colormaps/viridis.npy
+src/silx/resources/gui/icons/3d-plane-normal-x.png
+src/silx/resources/gui/icons/3d-plane-normal-x.svg
+src/silx/resources/gui/icons/3d-plane-normal-y.png
+src/silx/resources/gui/icons/3d-plane-normal-y.svg
+src/silx/resources/gui/icons/3d-plane-normal-z.png
+src/silx/resources/gui/icons/3d-plane-normal-z.svg
+src/silx/resources/gui/icons/3d-plane-pan.png
+src/silx/resources/gui/icons/3d-plane-pan.svg
+src/silx/resources/gui/icons/3d-plane.png
+src/silx/resources/gui/icons/3d-plane.svg
+src/silx/resources/gui/icons/add-range-horizontal.png
+src/silx/resources/gui/icons/add-range-horizontal.svg
+src/silx/resources/gui/icons/add-shape-arc.png
+src/silx/resources/gui/icons/add-shape-arc.svg
+src/silx/resources/gui/icons/add-shape-circle.png
+src/silx/resources/gui/icons/add-shape-circle.svg
+src/silx/resources/gui/icons/add-shape-cross.png
+src/silx/resources/gui/icons/add-shape-cross.svg
+src/silx/resources/gui/icons/add-shape-diagonal.png
+src/silx/resources/gui/icons/add-shape-diagonal.svg
+src/silx/resources/gui/icons/add-shape-ellipse.png
+src/silx/resources/gui/icons/add-shape-ellipse.svg
+src/silx/resources/gui/icons/add-shape-horizontal.png
+src/silx/resources/gui/icons/add-shape-horizontal.svg
+src/silx/resources/gui/icons/add-shape-point.png
+src/silx/resources/gui/icons/add-shape-point.svg
+src/silx/resources/gui/icons/add-shape-polygon.png
+src/silx/resources/gui/icons/add-shape-polygon.svg
+src/silx/resources/gui/icons/add-shape-rectangle.png
+src/silx/resources/gui/icons/add-shape-rectangle.svg
+src/silx/resources/gui/icons/add-shape-unknown.png
+src/silx/resources/gui/icons/add-shape-unknown.svg
+src/silx/resources/gui/icons/add-shape-vertical.png
+src/silx/resources/gui/icons/add-shape-vertical.svg
+src/silx/resources/gui/icons/add.png
+src/silx/resources/gui/icons/add.svg
+src/silx/resources/gui/icons/aggregation-mode.png
+src/silx/resources/gui/icons/aggregation-mode.svg
+src/silx/resources/gui/icons/arrow-keys.png
+src/silx/resources/gui/icons/arrow-keys.svg
+src/silx/resources/gui/icons/axis.png
+src/silx/resources/gui/icons/axis.svg
+src/silx/resources/gui/icons/backend-opengl.png
+src/silx/resources/gui/icons/backend-opengl.svg
+src/silx/resources/gui/icons/camera.png
+src/silx/resources/gui/icons/camera.svg
+src/silx/resources/gui/icons/clipboard.png
+src/silx/resources/gui/icons/clipboard.svg
+src/silx/resources/gui/icons/close.png
+src/silx/resources/gui/icons/close.svg
+src/silx/resources/gui/icons/colorbar.png
+src/silx/resources/gui/icons/colorbar.svg
+src/silx/resources/gui/icons/colormap-histogram.png
+src/silx/resources/gui/icons/colormap-histogram.svg
+src/silx/resources/gui/icons/colormap-none.png
+src/silx/resources/gui/icons/colormap-none.svg
+src/silx/resources/gui/icons/colormap-norm-arcsinh.png
+src/silx/resources/gui/icons/colormap-norm-arcsinh.svg
+src/silx/resources/gui/icons/colormap-norm-gamma.png
+src/silx/resources/gui/icons/colormap-norm-gamma.svg
+src/silx/resources/gui/icons/colormap-norm-linear.png
+src/silx/resources/gui/icons/colormap-norm-linear.svg
+src/silx/resources/gui/icons/colormap-norm-log.png
+src/silx/resources/gui/icons/colormap-norm-log.svg
+src/silx/resources/gui/icons/colormap-norm-sqrt.png
+src/silx/resources/gui/icons/colormap-norm-sqrt.svg
+src/silx/resources/gui/icons/colormap-range.png
+src/silx/resources/gui/icons/colormap-range.svg
+src/silx/resources/gui/icons/colormap.png
+src/silx/resources/gui/icons/colormap.svg
+src/silx/resources/gui/icons/compare-align-auto.png
+src/silx/resources/gui/icons/compare-align-auto.svg
+src/silx/resources/gui/icons/compare-align-center.png
+src/silx/resources/gui/icons/compare-align-center.svg
+src/silx/resources/gui/icons/compare-align-origin.png
+src/silx/resources/gui/icons/compare-align-origin.svg
+src/silx/resources/gui/icons/compare-align-stretch.png
+src/silx/resources/gui/icons/compare-align-stretch.svg
+src/silx/resources/gui/icons/compare-keypoints.png
+src/silx/resources/gui/icons/compare-keypoints.svg
+src/silx/resources/gui/icons/compare-mode-a-minus-b.png
+src/silx/resources/gui/icons/compare-mode-a-minus-b.svg
+src/silx/resources/gui/icons/compare-mode-a.png
+src/silx/resources/gui/icons/compare-mode-a.svg
+src/silx/resources/gui/icons/compare-mode-b.png
+src/silx/resources/gui/icons/compare-mode-b.svg
+src/silx/resources/gui/icons/compare-mode-hline.png
+src/silx/resources/gui/icons/compare-mode-hline.svg
+src/silx/resources/gui/icons/compare-mode-rb-channel.png
+src/silx/resources/gui/icons/compare-mode-rb-channel.svg
+src/silx/resources/gui/icons/compare-mode-rbneg-channel.png
+src/silx/resources/gui/icons/compare-mode-rbneg-channel.svg
+src/silx/resources/gui/icons/compare-mode-vline.png
+src/silx/resources/gui/icons/compare-mode-vline.svg
+src/silx/resources/gui/icons/crop.png
+src/silx/resources/gui/icons/crop.svg
+src/silx/resources/gui/icons/crosshair.png
+src/silx/resources/gui/icons/crosshair.svg
+src/silx/resources/gui/icons/cube-back.png
+src/silx/resources/gui/icons/cube-back.svg
+src/silx/resources/gui/icons/cube-bottom.png
+src/silx/resources/gui/icons/cube-bottom.svg
+src/silx/resources/gui/icons/cube-front.png
+src/silx/resources/gui/icons/cube-front.svg
+src/silx/resources/gui/icons/cube-left.png
+src/silx/resources/gui/icons/cube-left.svg
+src/silx/resources/gui/icons/cube-right.png
+src/silx/resources/gui/icons/cube-right.svg
+src/silx/resources/gui/icons/cube-rotate.png
+src/silx/resources/gui/icons/cube-rotate.svg
+src/silx/resources/gui/icons/cube-top.png
+src/silx/resources/gui/icons/cube-top.svg
+src/silx/resources/gui/icons/cube.png
+src/silx/resources/gui/icons/cube.svg
+src/silx/resources/gui/icons/description-description.png
+src/silx/resources/gui/icons/description-description.svg
+src/silx/resources/gui/icons/description-error.png
+src/silx/resources/gui/icons/description-error.svg
+src/silx/resources/gui/icons/description-name.png
+src/silx/resources/gui/icons/description-name.svg
+src/silx/resources/gui/icons/description-program.png
+src/silx/resources/gui/icons/description-program.svg
+src/silx/resources/gui/icons/description-title.png
+src/silx/resources/gui/icons/description-title.svg
+src/silx/resources/gui/icons/description-value.png
+src/silx/resources/gui/icons/description-value.svg
+src/silx/resources/gui/icons/document-open.png
+src/silx/resources/gui/icons/document-open.svg
+src/silx/resources/gui/icons/document-print.png
+src/silx/resources/gui/icons/document-print.svg
+src/silx/resources/gui/icons/document-save.png
+src/silx/resources/gui/icons/document-save.svg
+src/silx/resources/gui/icons/draw-brush.png
+src/silx/resources/gui/icons/draw-brush.svg
+src/silx/resources/gui/icons/draw-pencil.png
+src/silx/resources/gui/icons/draw-pencil.svg
+src/silx/resources/gui/icons/draw-rubber.png
+src/silx/resources/gui/icons/draw-rubber.svg
+src/silx/resources/gui/icons/edit-copy.png
+src/silx/resources/gui/icons/edit-copy.svg
+src/silx/resources/gui/icons/eye.png
+src/silx/resources/gui/icons/eye.svg
+src/silx/resources/gui/icons/first.png
+src/silx/resources/gui/icons/first.svg
+src/silx/resources/gui/icons/folder.png
+src/silx/resources/gui/icons/folder.svg
+src/silx/resources/gui/icons/image-mask.png
+src/silx/resources/gui/icons/image-mask.svg
+src/silx/resources/gui/icons/image-select-add.png
+src/silx/resources/gui/icons/image-select-add.svg
+src/silx/resources/gui/icons/image-select-box.png
+src/silx/resources/gui/icons/image-select-box.svg
+src/silx/resources/gui/icons/image-select-brush.png
+src/silx/resources/gui/icons/image-select-brush.svg
+src/silx/resources/gui/icons/image-select-erase-rubber.png
+src/silx/resources/gui/icons/image-select-erase-rubber.svg
+src/silx/resources/gui/icons/image-select-erase.png
+src/silx/resources/gui/icons/image-select-erase.svg
+src/silx/resources/gui/icons/image.png
+src/silx/resources/gui/icons/image.svg
+src/silx/resources/gui/icons/item-0dim.png
+src/silx/resources/gui/icons/item-0dim.svg
+src/silx/resources/gui/icons/item-1dim.png
+src/silx/resources/gui/icons/item-1dim.svg
+src/silx/resources/gui/icons/item-2dim.png
+src/silx/resources/gui/icons/item-2dim.svg
+src/silx/resources/gui/icons/item-3dim.png
+src/silx/resources/gui/icons/item-3dim.svg
+src/silx/resources/gui/icons/item-ndim.png
+src/silx/resources/gui/icons/item-ndim.svg
+src/silx/resources/gui/icons/item-none.png
+src/silx/resources/gui/icons/item-none.svg
+src/silx/resources/gui/icons/item-object.png
+src/silx/resources/gui/icons/item-object.svg
+src/silx/resources/gui/icons/last.png
+src/silx/resources/gui/icons/last.svg
+src/silx/resources/gui/icons/layer-nx.png
+src/silx/resources/gui/icons/layer-nx.svg
+src/silx/resources/gui/icons/mask-clear-all.png
+src/silx/resources/gui/icons/mask-clear-all.svg
+src/silx/resources/gui/icons/mask-clear.png
+src/silx/resources/gui/icons/mask-clear.svg
+src/silx/resources/gui/icons/mask-invert.png
+src/silx/resources/gui/icons/mask-invert.svg
+src/silx/resources/gui/icons/math-amplitude.png
+src/silx/resources/gui/icons/math-amplitude.svg
+src/silx/resources/gui/icons/math-average.png
+src/silx/resources/gui/icons/math-average.svg
+src/silx/resources/gui/icons/math-derive.png
+src/silx/resources/gui/icons/math-derive.svg
+src/silx/resources/gui/icons/math-energy.png
+src/silx/resources/gui/icons/math-energy.svg
+src/silx/resources/gui/icons/math-fit.png
+src/silx/resources/gui/icons/math-fit.svg
+src/silx/resources/gui/icons/math-imaginary.png
+src/silx/resources/gui/icons/math-imaginary.svg
+src/silx/resources/gui/icons/math-mean.png
+src/silx/resources/gui/icons/math-mean.svg
+src/silx/resources/gui/icons/math-normalize.png
+src/silx/resources/gui/icons/math-normalize.svg
+src/silx/resources/gui/icons/math-peak-reset.png
+src/silx/resources/gui/icons/math-peak-reset.svg
+src/silx/resources/gui/icons/math-peak-search.png
+src/silx/resources/gui/icons/math-peak-search.svg
+src/silx/resources/gui/icons/math-peak.png
+src/silx/resources/gui/icons/math-peak.svg
+src/silx/resources/gui/icons/math-phase-color-log.png
+src/silx/resources/gui/icons/math-phase-color-log.svg
+src/silx/resources/gui/icons/math-phase-color.png
+src/silx/resources/gui/icons/math-phase-color.svg
+src/silx/resources/gui/icons/math-phase.png
+src/silx/resources/gui/icons/math-phase.svg
+src/silx/resources/gui/icons/math-real.png
+src/silx/resources/gui/icons/math-real.svg
+src/silx/resources/gui/icons/math-sigma.png
+src/silx/resources/gui/icons/math-sigma.svg
+src/silx/resources/gui/icons/math-smooth.png
+src/silx/resources/gui/icons/math-smooth.svg
+src/silx/resources/gui/icons/math-square-amplitude.png
+src/silx/resources/gui/icons/math-square-amplitude.svg
+src/silx/resources/gui/icons/math-substract.png
+src/silx/resources/gui/icons/math-substract.svg
+src/silx/resources/gui/icons/math-swap-sign.png
+src/silx/resources/gui/icons/math-swap-sign.svg
+src/silx/resources/gui/icons/math-ymin-to-zero.png
+src/silx/resources/gui/icons/math-ymin-to-zero.svg
+src/silx/resources/gui/icons/median-filter.png
+src/silx/resources/gui/icons/median-filter.svg
+src/silx/resources/gui/icons/next.png
+src/silx/resources/gui/icons/next.svg
+src/silx/resources/gui/icons/normal.png
+src/silx/resources/gui/icons/normal.svg
+src/silx/resources/gui/icons/nxdata-axis-add.png
+src/silx/resources/gui/icons/nxdata-axis-add.svg
+src/silx/resources/gui/icons/nxdata-axis-remove.png
+src/silx/resources/gui/icons/nxdata-axis-remove.svg
+src/silx/resources/gui/icons/nxdata-create.png
+src/silx/resources/gui/icons/nxdata-create.svg
+src/silx/resources/gui/icons/nxdata-remove.png
+src/silx/resources/gui/icons/nxdata-remove.svg
+src/silx/resources/gui/icons/pan.png
+src/silx/resources/gui/icons/pan.svg
+src/silx/resources/gui/icons/pixel-intensities.png
+src/silx/resources/gui/icons/pixel-intensities.svg
+src/silx/resources/gui/icons/plot-grid.png
+src/silx/resources/gui/icons/plot-grid.svg
+src/silx/resources/gui/icons/plot-roi-above.png
+src/silx/resources/gui/icons/plot-roi-above.svg
+src/silx/resources/gui/icons/plot-roi-below.png
+src/silx/resources/gui/icons/plot-roi-below.svg
+src/silx/resources/gui/icons/plot-roi-between.png
+src/silx/resources/gui/icons/plot-roi-between.svg
+src/silx/resources/gui/icons/plot-roi-reset.png
+src/silx/resources/gui/icons/plot-roi-reset.svg
+src/silx/resources/gui/icons/plot-roi.png
+src/silx/resources/gui/icons/plot-roi.svg
+src/silx/resources/gui/icons/plot-symbols.png
+src/silx/resources/gui/icons/plot-symbols.svg
+src/silx/resources/gui/icons/plot-toggle-points.png
+src/silx/resources/gui/icons/plot-toggle-points.svg
+src/silx/resources/gui/icons/plot-widget.png
+src/silx/resources/gui/icons/plot-widget.svg
+src/silx/resources/gui/icons/plot-window-image.png
+src/silx/resources/gui/icons/plot-window-image.svg
+src/silx/resources/gui/icons/plot-window.png
+src/silx/resources/gui/icons/plot-window.svg
+src/silx/resources/gui/icons/plot-xauto.png
+src/silx/resources/gui/icons/plot-xauto.svg
+src/silx/resources/gui/icons/plot-xlog.png
+src/silx/resources/gui/icons/plot-xlog.svg
+src/silx/resources/gui/icons/plot-yauto.png
+src/silx/resources/gui/icons/plot-yauto.svg
+src/silx/resources/gui/icons/plot-ydown.png
+src/silx/resources/gui/icons/plot-ydown.svg
+src/silx/resources/gui/icons/plot-ylog.png
+src/silx/resources/gui/icons/plot-ylog.svg
+src/silx/resources/gui/icons/plot-yup.png
+src/silx/resources/gui/icons/plot-yup.svg
+src/silx/resources/gui/icons/pointing-hand.png
+src/silx/resources/gui/icons/pointing-hand.svg
+src/silx/resources/gui/icons/previous.png
+src/silx/resources/gui/icons/previous.svg
+src/silx/resources/gui/icons/process-working.mng
+src/silx/resources/gui/icons/profile-clear.png
+src/silx/resources/gui/icons/profile-clear.svg
+src/silx/resources/gui/icons/profile1D.png
+src/silx/resources/gui/icons/profile1D.svg
+src/silx/resources/gui/icons/profile2D.png
+src/silx/resources/gui/icons/profile2D.svg
+src/silx/resources/gui/icons/remove.png
+src/silx/resources/gui/icons/remove.svg
+src/silx/resources/gui/icons/rm.png
+src/silx/resources/gui/icons/rm.svg
+src/silx/resources/gui/icons/rotate-3d.png
+src/silx/resources/gui/icons/rotate-3d.svg
+src/silx/resources/gui/icons/rudder.png
+src/silx/resources/gui/icons/rudder.svg
+src/silx/resources/gui/icons/selected.png
+src/silx/resources/gui/icons/selected.svg
+src/silx/resources/gui/icons/shape-circle-solid.png
+src/silx/resources/gui/icons/shape-circle-solid.svg
+src/silx/resources/gui/icons/shape-circle.png
+src/silx/resources/gui/icons/shape-circle.svg
+src/silx/resources/gui/icons/shape-cross.png
+src/silx/resources/gui/icons/shape-cross.svg
+src/silx/resources/gui/icons/shape-diagonal-directed.png
+src/silx/resources/gui/icons/shape-diagonal-directed.svg
+src/silx/resources/gui/icons/shape-diagonal.png
+src/silx/resources/gui/icons/shape-diagonal.svg
+src/silx/resources/gui/icons/shape-ellipse-solid.png
+src/silx/resources/gui/icons/shape-ellipse-solid.svg
+src/silx/resources/gui/icons/shape-ellipse.png
+src/silx/resources/gui/icons/shape-ellipse.svg
+src/silx/resources/gui/icons/shape-horizontal.png
+src/silx/resources/gui/icons/shape-horizontal.svg
+src/silx/resources/gui/icons/shape-polygon.png
+src/silx/resources/gui/icons/shape-polygon.svg
+src/silx/resources/gui/icons/shape-rectangle.png
+src/silx/resources/gui/icons/shape-rectangle.svg
+src/silx/resources/gui/icons/shape-square.png
+src/silx/resources/gui/icons/shape-square.svg
+src/silx/resources/gui/icons/shape-vertical.png
+src/silx/resources/gui/icons/shape-vertical.svg
+src/silx/resources/gui/icons/side-histograms.png
+src/silx/resources/gui/icons/side-histograms.svg
+src/silx/resources/gui/icons/silx.png
+src/silx/resources/gui/icons/silx.svg
+src/silx/resources/gui/icons/slice-cross.png
+src/silx/resources/gui/icons/slice-cross.svg
+src/silx/resources/gui/icons/slice-horizontal.png
+src/silx/resources/gui/icons/slice-horizontal.svg
+src/silx/resources/gui/icons/slice-vertical.png
+src/silx/resources/gui/icons/slice-vertical.svg
+src/silx/resources/gui/icons/sliders-off.png
+src/silx/resources/gui/icons/sliders-off.svg
+src/silx/resources/gui/icons/sliders-on.png
+src/silx/resources/gui/icons/sliders-on.svg
+src/silx/resources/gui/icons/spec.png
+src/silx/resources/gui/icons/spec.svg
+src/silx/resources/gui/icons/stats-active-items.png
+src/silx/resources/gui/icons/stats-active-items.svg
+src/silx/resources/gui/icons/stats-visible-data.png
+src/silx/resources/gui/icons/stats-visible-data.svg
+src/silx/resources/gui/icons/stats-whole-data.png
+src/silx/resources/gui/icons/stats-whole-data.svg
+src/silx/resources/gui/icons/stats-whole-items.png
+src/silx/resources/gui/icons/stats-whole-items.svg
+src/silx/resources/gui/icons/tree-collapse-all.png
+src/silx/resources/gui/icons/tree-collapse-all.svg
+src/silx/resources/gui/icons/tree-expand-all.png
+src/silx/resources/gui/icons/tree-expand-all.svg
+src/silx/resources/gui/icons/tree-sort.png
+src/silx/resources/gui/icons/tree-sort.svg
+src/silx/resources/gui/icons/view-1d.png
+src/silx/resources/gui/icons/view-1d.svg
+src/silx/resources/gui/icons/view-2d-stack.png
+src/silx/resources/gui/icons/view-2d-stack.svg
+src/silx/resources/gui/icons/view-2d.png
+src/silx/resources/gui/icons/view-2d.svg
+src/silx/resources/gui/icons/view-3d.png
+src/silx/resources/gui/icons/view-3d.svg
+src/silx/resources/gui/icons/view-fullscreen.png
+src/silx/resources/gui/icons/view-fullscreen.svg
+src/silx/resources/gui/icons/view-hdf5.png
+src/silx/resources/gui/icons/view-hdf5.svg
+src/silx/resources/gui/icons/view-nexus.png
+src/silx/resources/gui/icons/view-nexus.svg
+src/silx/resources/gui/icons/view-nofullscreen.png
+src/silx/resources/gui/icons/view-nofullscreen.svg
+src/silx/resources/gui/icons/view-raw.png
+src/silx/resources/gui/icons/view-raw.svg
+src/silx/resources/gui/icons/view-refresh.png
+src/silx/resources/gui/icons/view-refresh.svg
+src/silx/resources/gui/icons/view-text.png
+src/silx/resources/gui/icons/view-text.svg
+src/silx/resources/gui/icons/window-new.png
+src/silx/resources/gui/icons/window-new.svg
+src/silx/resources/gui/icons/zoom-back.png
+src/silx/resources/gui/icons/zoom-back.svg
+src/silx/resources/gui/icons/zoom-in.png
+src/silx/resources/gui/icons/zoom-in.svg
+src/silx/resources/gui/icons/zoom-original.png
+src/silx/resources/gui/icons/zoom-original.svg
+src/silx/resources/gui/icons/zoom-out.png
+src/silx/resources/gui/icons/zoom-out.svg
+src/silx/resources/gui/icons/zoom.png
+src/silx/resources/gui/icons/zoom.svg
+src/silx/resources/gui/icons/process-working/00.png
+src/silx/resources/gui/icons/process-working/01.png
+src/silx/resources/gui/icons/process-working/02.png
+src/silx/resources/gui/icons/process-working/03.png
+src/silx/resources/gui/icons/process-working/04.png
+src/silx/resources/gui/icons/process-working/05.png
+src/silx/resources/gui/icons/process-working/06.png
+src/silx/resources/gui/icons/process-working/07.png
+src/silx/resources/gui/icons/process-working/08.png
+src/silx/resources/gui/icons/process-working/09.png
+src/silx/resources/gui/icons/process-working/10.png
+src/silx/resources/gui/icons/process-working/11.png
+src/silx/resources/gui/icons/process-working/12.png
+src/silx/resources/gui/icons/process-working/13.png
+src/silx/resources/gui/icons/process-working/14.png
+src/silx/resources/gui/icons/process-working/15.png
+src/silx/resources/gui/icons/process-working/16.png
+src/silx/resources/gui/icons/process-working/17.png
+src/silx/resources/gui/icons/process-working/18.png
+src/silx/resources/gui/icons/process-working/19.png
+src/silx/resources/gui/icons/process-working/20.png
+src/silx/resources/gui/icons/process-working/21.png
+src/silx/resources/gui/icons/process-working/22.png
+src/silx/resources/gui/icons/process-working/23.png
+src/silx/resources/gui/icons/process-working/24.png
+src/silx/resources/gui/icons/process-working/25.png
+src/silx/resources/gui/icons/process-working/26.png
+src/silx/resources/gui/icons/process-working/27.png
+src/silx/resources/gui/icons/process-working/28.png
+src/silx/resources/gui/icons/process-working/29.png
+src/silx/resources/gui/icons/process-working/30.png
+src/silx/resources/gui/logo/silx.png
+src/silx/resources/gui/logo/silx.svg
+src/silx/resources/opencl/addition.cl
+src/silx/resources/opencl/array_utils.cl
+src/silx/resources/opencl/backproj.cl
+src/silx/resources/opencl/backproj_helper.cl
+src/silx/resources/opencl/bitonic.cl
+src/silx/resources/opencl/convolution.cl
+src/silx/resources/opencl/convolution_textures.cl
+src/silx/resources/opencl/doubleword.cl
+src/silx/resources/opencl/kahan.cl
+src/silx/resources/opencl/linalg.cl
+src/silx/resources/opencl/medfilt.cl
+src/silx/resources/opencl/preprocess.cl
+src/silx/resources/opencl/proj.cl
+src/silx/resources/opencl/sparse.cl
+src/silx/resources/opencl/statistics.cl
+src/silx/resources/opencl/codec/byte_offset.cl
+src/silx/resources/opencl/image/cast.cl
+src/silx/resources/opencl/image/histogram.cl
+src/silx/resources/opencl/image/map.cl
+src/silx/resources/opencl/image/max_min.cl
+src/silx/resources/opencl/sift/addition.cl
+src/silx/resources/opencl/sift/algebra.cl
+src/silx/resources/opencl/sift/convolution.cl
+src/silx/resources/opencl/sift/descriptor_cpu.cl
+src/silx/resources/opencl/sift/descriptor_gpu1.cl
+src/silx/resources/opencl/sift/descriptor_gpu2.cl
+src/silx/resources/opencl/sift/gaussian.cl
+src/silx/resources/opencl/sift/image.cl
+src/silx/resources/opencl/sift/interpolation.cl
+src/silx/resources/opencl/sift/matching_cpu.cl
+src/silx/resources/opencl/sift/matching_gpu.cl
+src/silx/resources/opencl/sift/memset.cl
+src/silx/resources/opencl/sift/orientation_cpu.cl
+src/silx/resources/opencl/sift/orientation_gpu.cl
+src/silx/resources/opencl/sift/preprocess.cl
+src/silx/resources/opencl/sift/reductions.cl
+src/silx/resources/opencl/sift/sift.cl
+src/silx/resources/opencl/sift/transform.cl
+src/silx/sx/__init__.py
+src/silx/sx/_plot.py
+src/silx/sx/_plot3d.py
+src/silx/test/__init__.py
+src/silx/test/test_resources.py
+src/silx/test/test_sx.py
+src/silx/test/test_version.py
+src/silx/test/utils.py
+src/silx/third_party/EdfFile.py
+src/silx/third_party/TiffIO.py
+src/silx/third_party/__init__.py
+src/silx/third_party/scipy_spatial.py
+src/silx/third_party/setup.py
+src/silx/third_party/_local/__init__.py
+src/silx/third_party/_local/scipy_spatial/__init__.py
+src/silx/third_party/_local/scipy_spatial/qhull.pxd
+src/silx/third_party/_local/scipy_spatial/qhull.pyx
+src/silx/third_party/_local/scipy_spatial/qhull_misc.h
+src/silx/third_party/_local/scipy_spatial/setlist.pxd
+src/silx/third_party/_local/scipy_spatial/setup.py
+src/silx/third_party/_local/scipy_spatial/qhull/COPYING.txt
+src/silx/third_party/_local/scipy_spatial/qhull/src/geom2_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/geom_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/geom_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/global_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/io_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/io_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/libqhull_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/libqhull_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/mem_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/mem_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/merge_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/merge_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/poly2_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/poly_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/poly_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/qhull_ra.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/qset_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/qset_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/random_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/random_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/rboxlib_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/stat_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/stat_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/user_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/user_r.h
+src/silx/third_party/_local/scipy_spatial/qhull/src/usermem_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/userprintf_r.c
+src/silx/third_party/_local/scipy_spatial/qhull/src/userprintf_rbox_r.c
+src/silx/utils/ExternalResources.py
+src/silx/utils/__init__.py
+src/silx/utils/_have_openmp.pxd
+src/silx/utils/array_like.py
+src/silx/utils/debug.py
+src/silx/utils/deprecation.py
+src/silx/utils/enum.py
+src/silx/utils/exceptions.py
+src/silx/utils/files.py
+src/silx/utils/html.py
+src/silx/utils/launcher.py
+src/silx/utils/number.py
+src/silx/utils/property.py
+src/silx/utils/proxy.py
+src/silx/utils/retry.py
+src/silx/utils/setup.py
+src/silx/utils/testutils.py
+src/silx/utils/weakref.py
+src/silx/utils/include/silx_store_openmp.h
+src/silx/utils/test/__init__.py
+src/silx/utils/test/test_array_like.py
+src/silx/utils/test/test_debug.py
+src/silx/utils/test/test_deprecation.py
+src/silx/utils/test/test_enum.py
+src/silx/utils/test/test_external_resources.py
+src/silx/utils/test/test_launcher.py
+src/silx/utils/test/test_launcher_command.py
+src/silx/utils/test/test_number.py
+src/silx/utils/test/test_proxy.py
+src/silx/utils/test/test_retry.py
+src/silx/utils/test/test_testutils.py
+src/silx/utils/test/test_weakref.py \ No newline at end of file
diff --git a/silx.egg-info/requires.txt b/silx.egg-info/requires.txt
index 4ed2690..11d2418 100644
--- a/silx.egg-info/requires.txt
+++ b/silx.egg-info/requires.txt
@@ -2,7 +2,6 @@ numpy>=1.12.0
setuptools
h5py
fabio>=0.9
-six
[full]
pyopencl
@@ -12,5 +11,10 @@ matplotlib>=1.2.0
PyOpenGL
python-dateutil
PyQt5
+hdf5plugin
scipy
Pillow
+
+[test]
+pytest
+pytest-xvfb
diff --git a/silx/__init__.py b/silx/__init__.py
deleted file mode 100644
index 2892572..0000000
--- a/silx/__init__.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""The silx package contains the following main sub-packages:
-
-- silx.gui: Qt widgets for data visualization and data file browsing
-- silx.image: Some processing functions for 2D images
-- silx.io: Reading and writing data files (HDF5/NeXus, SPEC, ...)
-- silx.math: Some processing functions for 1D, 2D, 3D, nD arrays
-- silx.opencl: OpenCL-based data processing
-- silx.sx: High-level silx functions suited for (I)Python console.
-- silx.utils: Miscellaneous convenient functions
-
-See silx documentation: http://www.silx.org/doc/silx/latest/
-"""
-
-from __future__ import absolute_import, print_function, division
-
-__authors__ = ["Jérôme Kieffer"]
-__license__ = "MIT"
-__date__ = "26/04/2018"
-
-import os as _os
-import logging as _logging
-from ._config import Config as _Config
-
-config = _Config()
-"""Global configuration shared with the whole library"""
-
-# Attach a do nothing logging handler for silx
-_logging.getLogger(__name__).addHandler(_logging.NullHandler())
-
-
-project = _os.path.basename(_os.path.dirname(_os.path.abspath(__file__)))
-
-try:
- from ._version import __date__ as date # noqa
- from ._version import version, version_info, hexversion, strictversion # noqa
-except ImportError:
- raise RuntimeError("Do NOT use %s from its sources: build it and use the built version" % project)
diff --git a/silx/app/convert.py b/silx/app/convert.py
deleted file mode 100644
index 7e601ce..0000000
--- a/silx/app/convert.py
+++ /dev/null
@@ -1,525 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Convert silx supported data files into HDF5 files"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "05/02/2019"
-
-import ast
-import os
-import argparse
-from glob import glob
-import logging
-import re
-import time
-import numpy
-import six
-
-import silx.io
-from silx.io.specfile import is_specfile
-from silx.io import fabioh5
-
-_logger = logging.getLogger(__name__)
-"""Module logger"""
-
-
-def c_format_string_to_re(pattern_string):
- """
-
- :param pattern_string: C style format string with integer patterns
- (e.g. "%d", "%04d").
- Not supported: fixed length padded with whitespaces (e.g "%4d", "%-4d")
- :return: Equivalent regular expression (e.g. "\\d+", "\\d{4}")
- """
- # escape dots and backslashes
- pattern_string = pattern_string.replace("\\", "\\\\")
- pattern_string = pattern_string.replace(".", r"\.")
-
- # %d
- pattern_string = pattern_string.replace("%d", r"([-+]?\d+)")
-
- # %0nd
- for sub_pattern in re.findall(r"%0\d+d", pattern_string):
- n = int(re.search(r"%0(\d+)d", sub_pattern).group(1))
- if n == 1:
- re_sub_pattern = r"([+-]?\d)"
- else:
- re_sub_pattern = r"([\d+-]\d{%d})" % (n - 1)
- pattern_string = pattern_string.replace(sub_pattern, re_sub_pattern, 1)
-
- return pattern_string
-
-
-def drop_indices_before_begin(filenames, regex, begin):
- """
-
- :param List[str] filenames: list of filenames
- :param str regex: Regexp used to find indices in a filename
- :param str begin: Comma separated list of begin indices
- :return: List of filenames with only indices >= begin
- """
- begin_indices = list(map(int, begin.split(",")))
- output_filenames = []
- for fname in filenames:
- m = re.match(regex, fname)
- file_indices = list(map(int, m.groups()))
- if len(file_indices) != len(begin_indices):
- raise IOError("Number of indices found in filename "
- "does not match number of parsed end indices.")
- good_indices = True
- for i, fidx in enumerate(file_indices):
- if fidx < begin_indices[i]:
- good_indices = False
- if good_indices:
- output_filenames.append(fname)
- return output_filenames
-
-
-def drop_indices_after_end(filenames, regex, end):
- """
-
- :param List[str] filenames: list of filenames
- :param str regex: Regexp used to find indices in a filename
- :param str end: Comma separated list of end indices
- :return: List of filenames with only indices <= end
- """
- end_indices = list(map(int, end.split(",")))
- output_filenames = []
- for fname in filenames:
- m = re.match(regex, fname)
- file_indices = list(map(int, m.groups()))
- if len(file_indices) != len(end_indices):
- raise IOError("Number of indices found in filename "
- "does not match number of parsed end indices.")
- good_indices = True
- for i, fidx in enumerate(file_indices):
- if fidx > end_indices[i]:
- good_indices = False
- if good_indices:
- output_filenames.append(fname)
- return output_filenames
-
-
-def are_files_missing_in_series(filenames, regex):
- """Return True if any file is missing in a list of filenames
- that are supposed to follow a pattern.
-
- :param List[str] filenames: list of filenames
- :param str regex: Regexp used to find indices in a filename
- :return: boolean
- :raises AssertionError: if a filename does not match the regexp
- """
- previous_indices = None
- for fname in filenames:
- m = re.match(regex, fname)
- assert m is not None, \
- "regex %s does not match filename %s" % (fname, regex)
- new_indices = list(map(int, m.groups()))
- if previous_indices is not None:
- for old_idx, new_idx in zip(previous_indices, new_indices):
- if (new_idx - old_idx) > 1:
- _logger.error("Index increment > 1 in file series: "
- "previous idx %d, next idx %d",
- old_idx, new_idx)
- return True
- previous_indices = new_indices
- return False
-
-
-def are_all_specfile(filenames):
- """Return True if all files in a list are SPEC files.
- :param List[str] filenames: list of filenames
- """
- for fname in filenames:
- if not is_specfile(fname):
- return False
- return True
-
-
-def contains_specfile(filenames):
- """Return True if any file in a list are SPEC files.
- :param List[str] filenames: list of filenames
- """
- for fname in filenames:
- if is_specfile(fname):
- return True
- return False
-
-
-def main(argv):
- """
- Main function to launch the converter as an application
-
- :param argv: Command line arguments
- :returns: exit status
- """
- parser = argparse.ArgumentParser(description=__doc__)
- parser.add_argument(
- 'input_files',
- nargs="*",
- help='Input files (EDF, TIFF, SPEC...). When specifying multiple '
- 'files, you cannot specify both fabio images and SPEC files. '
- 'Multiple SPEC files will simply be concatenated, with one '
- 'entry per scan. Multiple image files will be merged into '
- 'a single entry with a stack of images.')
- # input_files and --filepattern are mutually exclusive
- parser.add_argument(
- '--file-pattern',
- help='File name pattern for loading a series of indexed image files '
- '(toto_%%04d.edf). This argument is incompatible with argument '
- 'input_files. If an output URI with a HDF5 path is provided, '
- 'only the content of the NXdetector group will be copied there. '
- 'If no HDF5 path, or just "/", is given, a complete NXdata '
- 'structure will be created.')
- parser.add_argument(
- '-o', '--output-uri',
- default=time.strftime("%Y%m%d-%H%M%S") + '.h5',
- help='Output file name (HDF5). An URI can be provided to write'
- ' the data into a specific group in the output file: '
- '/path/to/file::/path/to/group. '
- 'If not provided, the filename defaults to a timestamp:'
- ' YYYYmmdd-HHMMSS.h5')
- parser.add_argument(
- '-m', '--mode',
- default="w-",
- help='Write mode: "r+" (read/write, file must exist), '
- '"w" (write, existing file is lost), '
- '"w-" (write, fail if file exists) or '
- '"a" (read/write if exists, create otherwise)')
- parser.add_argument(
- '--begin',
- help='First file index, or first file indices to be considered. '
- 'This argument only makes sense when used together with '
- '--file-pattern. Provide as many start indices as there '
- 'are indices in the file pattern, separated by commas. '
- 'Examples: "--filepattern toto_%%d.edf --begin 100", '
- ' "--filepattern toto_%%d_%%04d_%%02d.edf --begin 100,2000,5".')
- parser.add_argument(
- '--end',
- help='Last file index, or last file indices to be considered. '
- 'The same rules as with argument --begin apply. '
- 'Example: "--filepattern toto_%%d_%%d.edf --end 199,1999"')
- parser.add_argument(
- '--add-root-group',
- action="store_true",
- help='This option causes each input file to be written to a '
- 'specific root group with the same name as the file. When '
- 'merging multiple input files, this can help preventing conflicts'
- ' when datasets have the same name (see --overwrite-data). '
- 'This option is ignored when using --file-pattern.')
- parser.add_argument(
- '--overwrite-data',
- action="store_true",
- help='If the output path exists and an input dataset has the same'
- ' name as an existing output dataset, overwrite the output '
- 'dataset (in modes "r+" or "a").')
- parser.add_argument(
- '--min-size',
- type=int,
- default=500,
- help='Minimum number of elements required to be in a dataset to '
- 'apply compression or chunking (default 500).')
- parser.add_argument(
- '--chunks',
- nargs="?",
- const="auto",
- help='Chunk shape. Provide an argument that evaluates as a python '
- 'tuple (e.g. "(1024, 768)"). If this option is provided without '
- 'specifying an argument, the h5py library will guess a chunk for '
- 'you. Note that if you specify an explicit chunking shape, it '
- 'will be applied identically to all datasets with a large enough '
- 'size (see --min-size). ')
- parser.add_argument(
- '--compression',
- nargs="?",
- const="gzip",
- help='Compression filter. By default, the datasets in the output '
- 'file are not compressed. If this option is specified without '
- 'argument, the GZIP compression is used. Additional compression '
- 'filters may be available, depending on your HDF5 installation.')
-
- def check_gzip_compression_opts(value):
- ivalue = int(value)
- if ivalue < 0 or ivalue > 9:
- raise argparse.ArgumentTypeError(
- "--compression-opts must be an int from 0 to 9")
- return ivalue
-
- parser.add_argument(
- '--compression-opts',
- type=check_gzip_compression_opts,
- help='Compression options. For "gzip", this may be an integer from '
- '0 to 9, with a default of 4. This is only supported for GZIP.')
- parser.add_argument(
- '--shuffle',
- action="store_true",
- help='Enables the byte shuffle filter. This may improve the compression '
- 'ratio for block oriented compressors like GZIP or LZF.')
- parser.add_argument(
- '--fletcher32',
- action="store_true",
- help='Adds a checksum to each chunk to detect data corruption.')
- parser.add_argument(
- '--debug',
- action="store_true",
- default=False,
- help='Set logging system in debug mode')
-
- options = parser.parse_args(argv[1:])
-
- if options.debug:
- logging.root.setLevel(logging.DEBUG)
-
- # Import after parsing --debug
- try:
- # it should be loaded before h5py
- import hdf5plugin # noqa
- except ImportError:
- _logger.debug("Backtrace", exc_info=True)
- hdf5plugin = None
-
- import h5py
-
- try:
- from silx.io.convert import write_to_h5
- except ImportError:
- _logger.debug("Backtrace", exc_info=True)
- write_to_h5 = None
-
- if hdf5plugin is None:
- message = "Module 'hdf5plugin' is not installed. It supports additional hdf5"\
- + " compressions. You can install it using \"pip install hdf5plugin\"."
- _logger.debug(message)
-
- # Process input arguments (mutually exclusive arguments)
- if bool(options.input_files) == bool(options.file_pattern is not None):
- if not options.input_files:
- message = "You must specify either input files (at least one), "
- message += "or a file pattern."
- else:
- message = "You cannot specify input files and a file pattern"
- message += " at the same time."
- _logger.error(message)
- return -1
- elif options.input_files:
- # some shells (windows) don't interpret wildcard characters (*, ?, [])
- old_input_list = list(options.input_files)
- options.input_files = []
- for fname in old_input_list:
- globbed_files = glob(fname)
- if not globbed_files:
- # no files found, keep the name as it is, to raise an error later
- options.input_files += [fname]
- else:
- # glob does not sort files, but the bash shell does
- options.input_files += sorted(globbed_files)
- else:
- # File series
- dirname = os.path.dirname(options.file_pattern)
- file_pattern_re = c_format_string_to_re(options.file_pattern) + "$"
- files_in_dir = glob(os.path.join(dirname, "*"))
- _logger.debug("""
- Processing file_pattern
- dirname: %s
- file_pattern_re: %s
- files_in_dir: %s
- """, dirname, file_pattern_re, files_in_dir)
-
- options.input_files = sorted(list(filter(lambda name: re.match(file_pattern_re, name),
- files_in_dir)))
- _logger.debug("options.input_files: %s", options.input_files)
-
- if options.begin is not None:
- options.input_files = drop_indices_before_begin(options.input_files,
- file_pattern_re,
- options.begin)
- _logger.debug("options.input_files after applying --begin: %s",
- options.input_files)
-
- if options.end is not None:
- options.input_files = drop_indices_after_end(options.input_files,
- file_pattern_re,
- options.end)
- _logger.debug("options.input_files after applying --end: %s",
- options.input_files)
-
- if are_files_missing_in_series(options.input_files,
- file_pattern_re):
- _logger.error("File missing in the file series. Aborting.")
- return -1
-
- if not options.input_files:
- _logger.error("No file matching --file-pattern found.")
- return -1
-
- # Test that the output path is writeable
- if "::" in options.output_uri:
- output_name, hdf5_path = options.output_uri.split("::")
- else:
- output_name, hdf5_path = options.output_uri, "/"
-
- if os.path.isfile(output_name):
- if options.mode == "w-":
- _logger.error("Output file %s exists and mode is 'w-' (default)."
- " Aborting. To append data to an existing file, "
- "use 'a' or 'r+'.",
- output_name)
- return -1
- elif not os.access(output_name, os.W_OK):
- _logger.error("Output file %s exists and is not writeable.",
- output_name)
- return -1
- elif options.mode == "w":
- _logger.info("Output file %s exists and mode is 'w'. "
- "Overwriting existing file.", output_name)
- elif options.mode in ["a", "r+"]:
- _logger.info("Appending data to existing file %s.",
- output_name)
- else:
- if options.mode == "r+":
- _logger.error("Output file %s does not exist and mode is 'r+'"
- " (append, file must exist). Aborting.",
- output_name)
- return -1
- else:
- _logger.info("Creating new output file %s.",
- output_name)
-
- # Test that all input files exist and are readable
- bad_input = False
- for fname in options.input_files:
- if not os.access(fname, os.R_OK):
- _logger.error("Cannot read input file %s.",
- fname)
- bad_input = True
- if bad_input:
- _logger.error("Aborting.")
- return -1
-
- # create_dataset special args
- create_dataset_args = {}
- if options.chunks is not None:
- if options.chunks.lower() in ["auto", "true"]:
- create_dataset_args["chunks"] = True
- else:
- try:
- chunks = ast.literal_eval(options.chunks)
- except (ValueError, SyntaxError):
- _logger.error("Invalid --chunks argument %s", options.chunks)
- return -1
- if not isinstance(chunks, (tuple, list)):
- _logger.error("--chunks argument str does not evaluate to a tuple")
- return -1
- else:
- nitems = numpy.prod(chunks)
- nbytes = nitems * 8
- if nbytes > 10**6:
- _logger.warning("Requested chunk size might be larger than"
- " the default 1MB chunk cache, for float64"
- " data. This can dramatically affect I/O "
- "performances.")
- create_dataset_args["chunks"] = chunks
-
- if options.compression is not None:
- try:
- compression = int(options.compression)
- except ValueError:
- compression = options.compression
- create_dataset_args["compression"] = compression
-
- if options.compression_opts is not None:
- create_dataset_args["compression_opts"] = options.compression_opts
-
- if options.shuffle:
- create_dataset_args["shuffle"] = True
-
- if options.fletcher32:
- create_dataset_args["fletcher32"] = True
-
- if (len(options.input_files) > 1 and
- not contains_specfile(options.input_files) and
- not options.add_root_group) or options.file_pattern is not None:
- # File series -> stack of images
- input_group = fabioh5.File(file_series=options.input_files)
- if hdf5_path != "/":
- # we want to append only data and headers to an existing file
- input_group = input_group["/scan_0/instrument/detector_0"]
- with h5py.File(output_name, mode=options.mode) as h5f:
- write_to_h5(input_group, h5f,
- h5path=hdf5_path,
- overwrite_data=options.overwrite_data,
- create_dataset_args=create_dataset_args,
- min_size=options.min_size)
-
- elif len(options.input_files) == 1 or \
- are_all_specfile(options.input_files) or\
- options.add_root_group:
- # single file, or spec files
- h5paths_and_groups = []
- for input_name in options.input_files:
- hdf5_path_for_file = hdf5_path
- if options.add_root_group:
- hdf5_path_for_file = hdf5_path.rstrip("/") + "/" + os.path.basename(input_name)
- try:
- h5paths_and_groups.append((hdf5_path_for_file,
- silx.io.open(input_name)))
- except IOError:
- _logger.error("Cannot read file %s. If this is a file format "
- "supported by the fabio library, you can try to"
- " install fabio (`pip install fabio`)."
- " Aborting conversion.",
- input_name)
- return -1
-
- with h5py.File(output_name, mode=options.mode) as h5f:
- for hdf5_path_for_file, input_group in h5paths_and_groups:
- write_to_h5(input_group, h5f,
- h5path=hdf5_path_for_file,
- overwrite_data=options.overwrite_data,
- create_dataset_args=create_dataset_args,
- min_size=options.min_size)
-
- else:
- # multiple file, SPEC and fabio images mixed
- _logger.error("Multiple files with incompatible formats specified. "
- "You can provide multiple SPEC files or multiple image "
- "files, but not both.")
- return -1
-
- with h5py.File(output_name, mode="r+") as h5f:
- # append "silx convert" to the creator attribute, for NeXus files
- previous_creator = h5f.attrs.get("creator", u"")
- creator = "silx convert (v%s)" % silx.version
- # only if it not already there
- if creator not in previous_creator:
- if not previous_creator:
- new_creator = creator
- else:
- new_creator = previous_creator + "; " + creator
- h5f.attrs["creator"] = numpy.array(
- new_creator,
- dtype=h5py.special_dtype(vlen=six.text_type))
-
- return 0
diff --git a/silx/app/test/__init__.py b/silx/app/test/__init__.py
deleted file mode 100644
index 7c91134..0000000
--- a/silx/app/test/__init__.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "06/06/2018"
-
-import unittest
-
-from ..view import test as test_view
-from . import test_convert
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(test_view.suite())
- test_suite.addTest(test_convert.suite())
- return test_suite
diff --git a/silx/app/test/test_convert.py b/silx/app/test/test_convert.py
deleted file mode 100644
index 857f30c..0000000
--- a/silx/app/test/test_convert.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Module testing silx.app.convert"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import os
-import sys
-import tempfile
-import unittest
-import io
-import gc
-import h5py
-
-import silx
-from .. import convert
-from silx.utils import testutils
-from silx.io.utils import h5py_read_dataset
-
-
-# content of a spec file
-sftext = """#F /tmp/sf.dat
-#E 1455180875
-#D Thu Feb 11 09:54:35 2016
-#C imaging User = opid17
-#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
-#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
-#o0 pshg mrtu mrtd
-#o2 ss1vo ss1ho ss1vg
-
-#J0 Seconds IA ion.mono Current
-#J1 xbpmc2 idgap1 Inorm
-
-#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
-#D Thu Feb 11 09:55:20 2016
-#T 0.2 (Seconds)
-#P0 180.005 -0.66875 0.87125
-#P1 14.74255 16.197579 12.238283
-#N 4
-#L MRTSlit UP second column 3rd_col
--1.23 5.89 8
-8.478100E+01 5 1.56
-3.14 2.73 -3.14
-1.2 2.3 3.4
-
-#S 1 aaaaaa
-#D Thu Feb 11 10:00:32 2016
-#@MCADEV 1
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#N 3
-#L uno duo
-1 2
-@A 0 1 2
-@A 10 9 8
-3 4
-@A 3.1 4 5
-@A 7 6 5
-5 6
-@A 6 7.7 8
-@A 4 3 2
-"""
-
-
-class TestConvertCommand(unittest.TestCase):
- """Test command line parsing"""
-
- def testHelp(self):
- # option -h must cause a `raise SystemExit` or a `return 0`
- try:
- result = convert.main(["convert", "--help"])
- except SystemExit as e:
- result = e.args[0]
- self.assertEqual(result, 0)
-
- def testWrongOption(self):
- # presence of a wrong option must cause a SystemExit or a return
- # with a non-zero status
- try:
- result = convert.main(["convert", "--foo"])
- except SystemExit as e:
- result = e.args[0]
- self.assertNotEqual(result, 0)
-
- @testutils.test_logging(convert._logger.name, error=3)
- # one error log per missing file + one "Aborted" error log
- def testWrongFiles(self):
- result = convert.main(["convert", "foo.spec", "bar.edf"])
- self.assertNotEqual(result, 0)
-
- def testFile(self):
- # create a writable temp directory
- tempdir = tempfile.mkdtemp()
-
- # write a temporary SPEC file
- specname = os.path.join(tempdir, "input.dat")
- with io.open(specname, "wb") as fd:
- if sys.version_info < (3, ):
- fd.write(sftext)
- else:
- fd.write(bytes(sftext, 'ascii'))
-
- # convert it
- h5name = os.path.join(tempdir, "output.h5")
- assert not os.path.isfile(h5name)
- command_list = ["convert", "-m", "w",
- specname, "-o", h5name]
- result = convert.main(command_list)
-
- self.assertEqual(result, 0)
- self.assertTrue(os.path.isfile(h5name))
-
- with h5py.File(h5name, "r") as h5f:
- title12 = h5py_read_dataset(h5f["/1.2/title"])
- if sys.version_info < (3, ):
- title12 = title12.encode("utf-8")
- self.assertEqual(title12,
- "aaaaaa")
-
- creator = h5f.attrs.get("creator")
- self.assertIsNotNone(creator, "No creator attribute in NXroot group")
- if sys.version_info < (3, ):
- creator = creator.encode("utf-8")
- self.assertIn("silx convert (v%s)" % silx.version, creator)
-
- # delete input file
- gc.collect() # necessary to free spec file on Windows
- os.unlink(specname)
- os.unlink(h5name)
- os.rmdir(tempdir)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loader(TestConvertCommand))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/app/test_.py b/silx/app/test_.py
deleted file mode 100644
index a8e58bf..0000000
--- a/silx/app/test_.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Launch unittests of the library"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "12/01/2018"
-
-import sys
-import argparse
-import logging
-import unittest
-
-
-class StreamHandlerUnittestReady(logging.StreamHandler):
- """The unittest class TestResult redefine sys.stdout/err to capture
- stdout/err from tests and to display them only when a test fail.
-
- This class allow to use unittest stdout-capture by using the last sys.stdout
- and not a cached one.
- """
-
- def emit(self, record):
- """
- :type record: logging.LogRecord
- """
- self.stream = sys.stderr
- super(StreamHandlerUnittestReady, self).emit(record)
-
- def flush(self):
- pass
-
-
-def createBasicHandler():
- """Create the handler using the basic configuration"""
- hdlr = StreamHandlerUnittestReady()
- fs = logging.BASIC_FORMAT
- dfs = None
- fmt = logging.Formatter(fs, dfs)
- hdlr.setFormatter(fmt)
- return hdlr
-
-
-# Use an handler compatible with unittests, else use_buffer is not working
-for h in logging.root.handlers:
- logging.root.removeHandler(h)
-logging.root.addHandler(createBasicHandler())
-logging.captureWarnings(True)
-
-_logger = logging.getLogger(__name__)
-"""Module logger"""
-
-
-class TextTestResultWithSkipList(unittest.TextTestResult):
- """Override default TextTestResult to display list of skipped tests at the
- end
- """
-
- def printErrors(self):
- unittest.TextTestResult.printErrors(self)
- # Print skipped tests at the end
- self.printErrorList("SKIPPED", self.skipped)
-
-
-def main(argv):
- """
- Main function to launch the unittests as an application
-
- :param argv: Command line arguments
- :returns: exit status
- """
- from silx.test import utils
-
- parser = argparse.ArgumentParser(description=__doc__)
- parser.add_argument("-v", "--verbose", default=0,
- action="count", dest="verbose",
- help="Increase verbosity. Option -v prints additional " +
- "INFO messages. Use -vv for full verbosity, " +
- "including debug messages and test help strings.")
- parser.add_argument("--qt-binding", dest="qt_binding", default=None,
- help="Force using a Qt binding: 'PyQt5' or 'PySide2'")
- utils.test_options.add_parser_argument(parser)
-
- options = parser.parse_args(argv[1:])
-
- test_verbosity = 1
- use_buffer = True
- if options.verbose == 1:
- logging.root.setLevel(logging.INFO)
- _logger.info("Set log level: INFO")
- test_verbosity = 2
- use_buffer = False
- elif options.verbose > 1:
- logging.root.setLevel(logging.DEBUG)
- _logger.info("Set log level: DEBUG")
- test_verbosity = 2
- use_buffer = False
-
- if options.qt_binding:
- binding = options.qt_binding.lower()
- if binding == "pyqt4":
- _logger.info("Force using PyQt4")
- import PyQt4.QtCore # noqa
- elif binding == "pyqt5":
- _logger.info("Force using PyQt5")
- import PyQt5.QtCore # noqa
- elif binding == "pyside":
- _logger.info("Force using PySide")
- import PySide.QtCore # noqa
- elif binding == "pyside2":
- _logger.info("Force using PySide2")
- import PySide2.QtCore # noqa
- else:
- raise ValueError("Qt binding '%s' is unknown" % options.qt_binding)
-
- # Configure test options
- utils.test_options.configure(options)
-
- # Run the tests
- runnerArgs = {}
- runnerArgs["verbosity"] = test_verbosity
- runnerArgs["buffer"] = use_buffer
- runner = unittest.TextTestRunner(**runnerArgs)
- runner.resultclass = TextTestResultWithSkipList
-
- # Display the result when using CTRL-C
- unittest.installHandler()
-
- import silx.test
- test_suite = unittest.TestSuite()
- test_suite.addTest(silx.test.suite())
- result = runner.run(test_suite)
-
- if result.wasSuccessful():
- exit_status = 0
- else:
- exit_status = 1
- return exit_status
diff --git a/silx/app/view/About.py b/silx/app/view/About.py
deleted file mode 100644
index a2b430f..0000000
--- a/silx/app/view/About.py
+++ /dev/null
@@ -1,257 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""About box for Silx viewer"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "05/07/2018"
-
-import os
-import sys
-
-from silx.gui import qt
-from silx.gui import icons
-
-_LICENSE_TEMPLATE = """<p align="center">
-<b>Copyright (C) {year} European Synchrotron Radiation Facility</b>
-</p>
-
-<p align="justify">
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-</p>
-
-<p align="justify">
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-</p>
-
-<p align="justify">
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
-</p>
-"""
-
-
-class About(qt.QDialog):
- """
- Util dialog to display an common about box for all the silx GUIs.
- """
-
- def __init__(self, parent=None):
- """
- :param files_: List of HDF5 or Spec files (pathes or
- :class:`silx.io.spech5.SpecH5` or :class:`h5py.File`
- instances)
- """
- super(About, self).__init__(parent)
- self.__createLayout()
- self.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
- self.setModal(True)
- self.setApplicationName(None)
-
- def __createLayout(self):
- layout = qt.QVBoxLayout(self)
- layout.setContentsMargins(24, 15, 24, 20)
- layout.setSpacing(8)
-
- self.__label = qt.QLabel(self)
- self.__label.setWordWrap(True)
- flags = self.__label.textInteractionFlags()
- flags = flags | qt.Qt.TextSelectableByKeyboard
- flags = flags | qt.Qt.TextSelectableByMouse
- self.__label.setTextInteractionFlags(flags)
- self.__label.setOpenExternalLinks(True)
- self.__label.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Preferred)
-
- licenseButton = qt.QPushButton(self)
- licenseButton.setText("License...")
- licenseButton.clicked.connect(self.__displayLicense)
- licenseButton.setAutoDefault(False)
-
- self.__options = qt.QDialogButtonBox()
- self.__options.addButton(licenseButton, qt.QDialogButtonBox.ActionRole)
- okButton = self.__options.addButton(qt.QDialogButtonBox.Ok)
- okButton.setDefault(True)
- okButton.clicked.connect(self.accept)
-
- layout.addWidget(self.__label)
- layout.addWidget(self.__options)
- layout.setStretch(0, 100)
- layout.setStretch(1, 0)
-
- def getHtmlLicense(self):
- """Returns the text license in HTML format.
-
- :rtype: str
- """
- from silx._version import __date__ as date
- year = date.split("/")[2]
- info = dict(
- year=year
- )
- textLicense = _LICENSE_TEMPLATE.format(**info)
- return textLicense
-
- def __displayLicense(self):
- """Displays the license used by silx."""
- text = self.getHtmlLicense()
- licenseDialog = qt.QMessageBox(self)
- licenseDialog.setWindowTitle("License")
- licenseDialog.setText(text)
- licenseDialog.exec_()
-
- def setApplicationName(self, name):
- self.__applicationName = name
- if name is None:
- self.setWindowTitle("About")
- else:
- self.setWindowTitle("About %s" % name)
- self.__updateText()
-
- @staticmethod
- def __formatOptionalLibraries(name, isAvailable):
- """Utils to format availability of features"""
- if isAvailable:
- template = '<b>%s</b> is <font color="green">loaded</font>'
- else:
- template = '<b>%s</b> is <font color="red">not loaded</font>'
- return template % name
-
- @staticmethod
- def __formatOptionalFilters(name, isAvailable):
- """Utils to format availability of features"""
- if isAvailable:
- template = '<b>%s</b> is <font color="green">available</font>'
- else:
- template = '<b>%s</b> is <font color="red">not available</font>'
- return template % name
-
- def __updateText(self):
- """Update the content of the dialog according to the settings."""
- import silx._version
-
- message = """<table>
- <tr><td width="50%" align="center" valign="middle">
- <img src="{silx_image_path}" width="100" />
- </td><td width="50%" align="center" valign="middle">
- <b>{application_name}</b>
- <br />
- <br />{silx_version}
- <br />
- <br /><a href="{project_url}">Upstream project on GitHub</a>
- </td></tr>
- </table>
- <dl>
- <dt><b>Silx version</b></dt><dd>{silx_version}</dd>
- <dt><b>Qt version</b></dt><dd>{qt_version}</dd>
- <dt><b>Qt binding</b></dt><dd>{qt_binding}</dd>
- <dt><b>Python version</b></dt><dd>{python_version}</dd>
- <dt><b>Optional libraries</b></dt><dd>{optional_lib}</dd>
- </dl>
- <p>
- Copyright (C) <a href="{esrf_url}">European Synchrotron Radiation Facility</a>
- </p>
- """
-
- optionals = []
- optionals.append(self.__formatOptionalLibraries("H5py", "h5py" in sys.modules))
- optionals.append(self.__formatOptionalLibraries("FabIO", "fabio" in sys.modules))
-
- try:
- import h5py.version
- if h5py.version.hdf5_version_tuple >= (1, 10, 2):
- # Previous versions only return True if the filter was first used
- # to decode a dataset
- import h5py.h5z
- FILTER_LZ4 = 32004
- FILTER_BITSHUFFLE = 32008
- filters = [
- ("HDF5 LZ4 filter", FILTER_LZ4),
- ("HDF5 Bitshuffle filter", FILTER_BITSHUFFLE),
- ]
- for name, filterId in filters:
- isAvailable = h5py.h5z.filter_avail(filterId)
- optionals.append(self.__formatOptionalFilters(name, isAvailable))
- else:
- optionals.append(self.__formatOptionalLibraries("hdf5plugin", "hdf5plugin" in sys.modules))
- except ImportError:
- pass
-
- # Access to the logo in SVG or PNG
- logo = icons.getQFile("silx:" + os.path.join("gui", "logo", "silx"))
-
- info = dict(
- application_name=self.__applicationName,
- esrf_url="http://www.esrf.eu",
- project_url="https://github.com/silx-kit/silx",
- silx_version=silx._version.version,
- qt_binding=qt.BINDING,
- qt_version=qt.qVersion(),
- python_version=sys.version.replace("\n", "<br />"),
- optional_lib="<br />".join(optionals),
- silx_image_path=logo.fileName()
- )
-
- self.__label.setText(message.format(**info))
- self.__updateSize()
-
- def __updateSize(self):
- """Force the size to a QMessageBox like size."""
- screenSize = qt.QApplication.desktop().availableGeometry(qt.QCursor.pos()).size()
- hardLimit = min(screenSize.width() - 480, 1000)
- if screenSize.width() <= 1024:
- hardLimit = screenSize.width()
- softLimit = min(screenSize.width() / 2, 420)
-
- layoutMinimumSize = self.layout().totalMinimumSize()
- width = layoutMinimumSize.width()
- if width > softLimit:
- width = softLimit
- if width > hardLimit:
- width = hardLimit
-
- height = layoutMinimumSize.height()
- self.setFixedSize(width, height)
-
- @staticmethod
- def about(parent, applicationName):
- """Displays a silx about box with title and text text.
-
- :param qt.QWidget parent: The parent widget
- :param str title: The title of the dialog
- :param str applicationName: The content of the dialog
- """
- dialog = About(parent)
- dialog.setApplicationName(applicationName)
- dialog.exec_()
diff --git a/silx/app/view/ApplicationContext.py b/silx/app/view/ApplicationContext.py
deleted file mode 100644
index 8693848..0000000
--- a/silx/app/view/ApplicationContext.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Browse a data file with a GUI"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "23/05/2018"
-
-import weakref
-import logging
-
-import silx
-from silx.gui.data.DataViews import DataViewHooks
-from silx.gui.colors import Colormap
-from silx.gui.dialog.ColormapDialog import ColormapDialog
-
-
-_logger = logging.getLogger(__name__)
-
-
-class ApplicationContext(DataViewHooks):
- """
- Store the conmtext of the application
-
- It overwrites the DataViewHooks to custom the use of the DataViewer for
- the silx view application.
-
- - Create a single colormap shared with all the views
- - Create a single colormap dialog shared with all the views
- """
-
- def __init__(self, parent, settings=None):
- self.__parent = weakref.ref(parent)
- self.__defaultColormap = None
- self.__defaultColormapDialog = None
- self.__settings = settings
- self.__recentFiles = []
-
- def getSettings(self):
- """Returns actual application settings.
-
- :rtype: qt.QSettings
- """
- return self.__settings
-
- def restoreLibrarySettings(self):
- """Restore the library settings, which must be done early"""
- settings = self.__settings
- if settings is None:
- return
- settings.beginGroup("library")
- plotBackend = settings.value("plot.backend", "")
- plotImageYAxisOrientation = settings.value("plot-image.y-axis-orientation", "")
- settings.endGroup()
-
- if plotBackend != "":
- silx.config.DEFAULT_PLOT_BACKEND = plotBackend
- if plotImageYAxisOrientation != "":
- silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = plotImageYAxisOrientation
-
- def restoreSettings(self):
- """Restore the settings of all the application"""
- settings = self.__settings
- if settings is None:
- return
- parent = self.__parent()
- parent.restoreSettings(settings)
-
- settings.beginGroup("colormap")
- byteArray = settings.value("default", None)
- if byteArray is not None:
- try:
- colormap = Colormap()
- colormap.restoreState(byteArray)
- self.__defaultColormap = colormap
- except Exception:
- _logger.debug("Backtrace", exc_info=True)
- settings.endGroup()
-
- self.__recentFiles = []
- settings.beginGroup("recent-files")
- for index in range(1, 10 + 1):
- if not settings.contains("path%d" % index):
- break
- filePath = settings.value("path%d" % index)
- self.__recentFiles.append(filePath)
- settings.endGroup()
-
- def saveSettings(self):
- """Save the settings of all the application"""
- settings = self.__settings
- if settings is None:
- return
- parent = self.__parent()
- parent.saveSettings(settings)
-
- if self.__defaultColormap is not None:
- settings.beginGroup("colormap")
- settings.setValue("default", self.__defaultColormap.saveState())
- settings.endGroup()
-
- settings.beginGroup("library")
- settings.setValue("plot.backend", silx.config.DEFAULT_PLOT_BACKEND)
- settings.setValue("plot-image.y-axis-orientation", silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION)
- settings.endGroup()
-
- settings.beginGroup("recent-files")
- for index in range(0, 11):
- key = "path%d" % (index + 1)
- if index < len(self.__recentFiles):
- filePath = self.__recentFiles[index]
- settings.setValue(key, filePath)
- else:
- settings.remove(key)
- settings.endGroup()
-
- def getRecentFiles(self):
- """Returns the list of recently opened files.
-
- The list is limited to the last 10 entries. The newest file path is
- in first.
-
- :rtype: List[str]
- """
- return self.__recentFiles
-
- def pushRecentFile(self, filePath):
- """Push a new recent file to the list.
-
- If the file is duplicated in the list, all duplications are removed
- before inserting the new filePath.
-
- If the list becan bigger than 10 items, oldest paths are removed.
-
- :param filePath: File path to push
- """
- # Remove old occurencies
- self.__recentFiles[:] = (f for f in self.__recentFiles if f != filePath)
- self.__recentFiles.insert(0, filePath)
- while len(self.__recentFiles) > 10:
- self.__recentFiles.pop()
-
- def clearRencentFiles(self):
- """Clear the history of the rencent files.
- """
- self.__recentFiles[:] = []
-
- def getColormap(self, view):
- """Returns a default colormap.
-
- Override from DataViewHooks
-
- :rtype: Colormap
- """
- if self.__defaultColormap is None:
- self.__defaultColormap = Colormap(name="viridis")
- return self.__defaultColormap
-
- def getColormapDialog(self, view):
- """Returns a shared color dialog as default for all the views.
-
- Override from DataViewHooks
-
- :rtype: ColorDialog
- """
- if self.__defaultColormapDialog is None:
- parent = self.__parent()
- if parent is None:
- return None
- dialog = ColormapDialog(parent=parent)
- dialog.setModal(False)
- self.__defaultColormapDialog = dialog
- return self.__defaultColormapDialog
diff --git a/silx/app/view/CustomNxdataWidget.py b/silx/app/view/CustomNxdataWidget.py
deleted file mode 100644
index 72c9940..0000000
--- a/silx/app/view/CustomNxdataWidget.py
+++ /dev/null
@@ -1,1008 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-
-"""Widget to custom NXdata groups"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "15/06/2018"
-
-import logging
-import numpy
-import weakref
-
-from silx.gui import qt
-from silx.io import commonh5
-import silx.io.nxdata
-from silx.gui.hdf5._utils import Hdf5DatasetMimeData
-from silx.gui.data.TextFormatter import TextFormatter
-from silx.gui.hdf5.Hdf5Formatter import Hdf5Formatter
-from silx.gui import icons
-
-
-_logger = logging.getLogger(__name__)
-_formatter = TextFormatter()
-_hdf5Formatter = Hdf5Formatter(textFormatter=_formatter)
-
-
-class _RowItems(qt.QStandardItem):
- """Define the list of items used for a specific row."""
-
- def type(self):
- return qt.QStandardItem.UserType + 1
-
- def getRowItems(self):
- """Returns the list of items used for a specific row.
-
- The first item should be this class.
-
- :rtype: List[qt.QStandardItem]
- """
- raise NotImplementedError()
-
-
-class _DatasetItemRow(_RowItems):
- """Define a row which can contain a dataset."""
-
- def __init__(self, label="", dataset=None):
- """Constructor"""
- super(_DatasetItemRow, self).__init__(label)
- self.setEditable(False)
- self.setDropEnabled(False)
- self.setDragEnabled(False)
-
- self.__name = qt.QStandardItem()
- self.__name.setEditable(False)
- self.__name.setDropEnabled(True)
-
- self.__type = qt.QStandardItem()
- self.__type.setEditable(False)
- self.__type.setDropEnabled(False)
- self.__type.setDragEnabled(False)
-
- self.__shape = qt.QStandardItem()
- self.__shape.setEditable(False)
- self.__shape.setDropEnabled(False)
- self.__shape.setDragEnabled(False)
-
- self.setDataset(dataset)
-
- def getDefaultFormatter(self):
- """Get the formatter used to display dataset informations.
-
- :rtype: Hdf5Formatter
- """
- return _hdf5Formatter
-
- def setDataset(self, dataset):
- """Set the dataset stored in this item.
-
- :param Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset] dataset:
- The dataset to store.
- """
- self.__dataset = dataset
- if self.__dataset is not None:
- name = self.__dataset.name
-
- if silx.io.is_dataset(dataset):
- type_ = self.getDefaultFormatter().humanReadableType(dataset)
- shape = self.getDefaultFormatter().humanReadableShape(dataset)
-
- if dataset.shape is None:
- icon_name = "item-none"
- elif len(dataset.shape) < 4:
- icon_name = "item-%ddim" % len(dataset.shape)
- else:
- icon_name = "item-ndim"
- icon = icons.getQIcon(icon_name)
- else:
- type_ = ""
- shape = ""
- icon = qt.QIcon()
- else:
- name = ""
- type_ = ""
- shape = ""
- icon = qt.QIcon()
-
- self.__icon = icon
- self.__name.setText(name)
- self.__name.setDragEnabled(self.__dataset is not None)
- self.__name.setIcon(self.__icon)
- self.__type.setText(type_)
- self.__shape.setText(shape)
-
- parent = self.parent()
- if parent is not None:
- self.parent()._datasetUpdated()
-
- def getDataset(self):
- """Returns the dataset stored within the item."""
- return self.__dataset
-
- def getRowItems(self):
- """Returns the list of items used for a specific row.
-
- The first item should be this class.
-
- :rtype: List[qt.QStandardItem]
- """
- return [self, self.__name, self.__type, self.__shape]
-
-
-class _DatasetAxisItemRow(_DatasetItemRow):
- """Define a row describing an axis."""
-
- def __init__(self):
- """Constructor"""
- super(_DatasetAxisItemRow, self).__init__()
-
- def setAxisId(self, axisId):
- """Set the id of the axis (the first axis is 0)
-
- :param int axisId: Identifier of this axis.
- """
- self.__axisId = axisId
- label = "Axis %d" % (axisId + 1)
- self.setText(label)
-
- def getAxisId(self):
- """Returns the identifier of this axis.
-
- :rtype: int
- """
- return self.__axisId
-
-
-class _NxDataItem(qt.QStandardItem):
- """
- Define a custom NXdata.
- """
-
- def __init__(self):
- """Constructor"""
- qt.QStandardItem.__init__(self)
- self.__error = None
- self.__title = None
- self.__axes = []
- self.__virtual = None
-
- item = _DatasetItemRow("Signal", None)
- self.appendRow(item.getRowItems())
- self.__signal = item
-
- self.setEditable(False)
- self.setDragEnabled(False)
- self.setDropEnabled(False)
- self.__setError(None)
-
- def getRowItems(self):
- """Returns the list of items used for a specific row.
-
- The first item should be this class.
-
- :rtype: List[qt.QStandardItem]
- """
- row = [self]
- for _ in range(3):
- item = qt.QStandardItem("")
- item.setEditable(False)
- item.setDragEnabled(False)
- item.setDropEnabled(False)
- row.append(item)
- return row
-
- def _datasetUpdated(self):
- """Called when the NXdata contained of the item have changed.
-
- It invalidates the NXdata stored and send an event `sigNxdataUpdated`.
- """
- self.__virtual = None
- self.__setError(None)
- model = self.model()
- if model is not None:
- model.sigNxdataUpdated.emit(self.index())
-
- def createVirtualGroup(self):
- """Returns a new virtual Group using a NeXus NXdata structure to store
- data
-
- :rtype: silx.io.commonh5.Group
- """
- name = ""
- if self.__title is not None:
- name = self.__title
- virtual = commonh5.Group(name)
- virtual.attrs["NX_class"] = "NXdata"
-
- if self.__title is not None:
- virtual.attrs["title"] = self.__title
-
- if self.__signal is not None:
- signal = self.__signal.getDataset()
- if signal is not None:
- # Could be done using a link instead of a copy
- node = commonh5.DatasetProxy("signal", target=signal)
- virtual.attrs["signal"] = "signal"
- virtual.add_node(node)
-
- axesAttr = []
- for i, axis in enumerate(self.__axes):
- if axis is None:
- name = "."
- else:
- axis = axis.getDataset()
- if axis is None:
- name = "."
- else:
- name = "axis%d" % i
- node = commonh5.DatasetProxy(name, target=axis)
- virtual.add_node(node)
- axesAttr.append(name)
-
- if axesAttr != []:
- virtual.attrs["axes"] = numpy.array(axesAttr)
-
- validator = silx.io.nxdata.NXdata(virtual)
- if not validator.is_valid:
- message = "<html>"
- message += "This NXdata is not consistant"
- message += "<ul>"
- for issue in validator.issues:
- message += "<li>%s</li>" % issue
- message += "</ul>"
- message += "</html>"
- self.__setError(message)
- else:
- self.__setError(None)
- return virtual
-
- def isValid(self):
- """Returns true if the stored NXdata is valid
-
- :rtype: bool
- """
- return self.__error is None
-
- def getVirtualGroup(self):
- """Returns a cached virtual Group using a NeXus NXdata structure to
- store data.
-
- If the stored NXdata was invalidated, :meth:`createVirtualGroup` is
- internally called to update the cache.
-
- :rtype: silx.io.commonh5.Group
- """
- if self.__virtual is None:
- self.__virtual = self.createVirtualGroup()
- return self.__virtual
-
- def getTitle(self):
- """Returns the title of the NXdata
-
- :rtype: str
- """
- return self.text()
-
- def setTitle(self, title):
- """Set the title of the NXdata
-
- :param str title: The title of this NXdata
- """
- self.setText(title)
-
- def __setError(self, error):
- """Set the error message in case of the current state of the stored
- NXdata is not valid.
-
- :param str error: Message to display
- """
- self.__error = error
- style = qt.QApplication.style()
- if error is None:
- message = ""
- icon = style.standardIcon(qt.QStyle.SP_DirLinkIcon)
- else:
- message = error
- icon = style.standardIcon(qt.QStyle.SP_MessageBoxCritical)
- self.setIcon(icon)
- self.setToolTip(message)
-
- def getError(self):
- """Returns the error message in case the NXdata is not valid.
-
- :rtype: str"""
- return self.__error
-
- def setSignalDataset(self, dataset):
- """Set the dataset to use as signal with this NXdata.
-
- :param Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset] dataset:
- The dataset to use as signal.
- """
-
- self.__signal.setDataset(dataset)
- self._datasetUpdated()
-
- def getSignalDataset(self):
- """Returns the dataset used as signal.
-
- :rtype: Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset]
- """
- return self.__signal.getDataset()
-
- def setAxesDatasets(self, datasets):
- """Set all the available dataset used as axes.
-
- Axes will be created or removed from the GUI in order to provide the
- same amount of requested axes.
-
- A `None` element is an axes with no dataset.
-
- :param List[Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset,None]] datasets:
- List of dataset to use as axes.
- """
- for i, dataset in enumerate(datasets):
- if i < len(self.__axes):
- mustAppend = False
- item = self.__axes[i]
- else:
- mustAppend = True
- item = _DatasetAxisItemRow()
- item.setAxisId(i)
- item.setDataset(dataset)
- if mustAppend:
- self.__axes.append(item)
- self.appendRow(item.getRowItems())
-
- # Clean up extra axis
- for i in range(len(datasets), len(self.__axes)):
- item = self.__axes.pop(len(datasets))
- self.removeRow(item.row())
-
- self._datasetUpdated()
-
- def getAxesDatasets(self):
- """Returns available axes as dataset.
-
- A `None` element is an axes with no dataset.
-
- :rtype: List[Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset,None]]
- """
- datasets = []
- for axis in self.__axes:
- datasets.append(axis.getDataset())
- return datasets
-
-
-class _Model(qt.QStandardItemModel):
- """Model storing a list of custom NXdata items.
-
- Supports drag and drop of datasets.
- """
-
- sigNxdataUpdated = qt.Signal(qt.QModelIndex)
- """Emitted when stored NXdata was edited"""
-
- def __init__(self, parent=None):
- """Constructor"""
- qt.QStandardItemModel.__init__(self, parent)
- root = self.invisibleRootItem()
- root.setDropEnabled(True)
- root.setDragEnabled(False)
-
- def supportedDropActions(self):
- """Inherited method to redefine supported drop actions."""
- return qt.Qt.CopyAction | qt.Qt.MoveAction
-
- def mimeTypes(self):
- """Inherited method to redefine draggable mime types."""
- return [Hdf5DatasetMimeData.MIME_TYPE]
-
- def mimeData(self, indexes):
- """
- Returns an object that contains serialized items of data corresponding
- to the list of indexes specified.
-
- :param List[qt.QModelIndex] indexes: List of indexes
- :rtype: qt.QMimeData
- """
- if len(indexes) > 1:
- return None
- if len(indexes) == 0:
- return None
-
- qindex = indexes[0]
- qindex = self.index(qindex.row(), 0, parent=qindex.parent())
- item = self.itemFromIndex(qindex)
- if isinstance(item, _DatasetItemRow):
- dataset = item.getDataset()
- if dataset is None:
- return None
- else:
- mimeData = Hdf5DatasetMimeData(dataset=item.getDataset())
- else:
- mimeData = None
- return mimeData
-
- def dropMimeData(self, mimedata, action, row, column, parentIndex):
- """Inherited method to handle a drop operation to this model."""
- if action == qt.Qt.IgnoreAction:
- return True
-
- if mimedata.hasFormat(Hdf5DatasetMimeData.MIME_TYPE):
- if row != -1 or column != -1:
- # It is not a drop on a specific item
- return False
- item = self.itemFromIndex(parentIndex)
- if item is None or item is self.invisibleRootItem():
- # Drop at the end
- dataset = mimedata.dataset()
- if silx.io.is_dataset(dataset):
- self.createFromSignal(dataset)
- elif silx.io.is_group(dataset):
- nxdata = dataset
- try:
- self.createFromNxdata(nxdata)
- except ValueError:
- _logger.error("Error while dropping a group as an NXdata")
- _logger.debug("Backtrace", exc_info=True)
- return False
- else:
- _logger.error("Dropping a wrong object")
- return False
- else:
- item = item.parent().child(item.row(), 0)
- if not isinstance(item, _DatasetItemRow):
- # Dropped at a bad place
- return False
- dataset = mimedata.dataset()
- if silx.io.is_dataset(dataset):
- item.setDataset(dataset)
- else:
- _logger.error("Dropping a wrong object")
- return False
- return True
-
- return False
-
- def __getNxdataByTitle(self, title):
- """Returns an NXdata item by its title, else None.
-
- :rtype: Union[_NxDataItem,None]
- """
- for row in range(self.rowCount()):
- qindex = self.index(row, 0)
- item = self.itemFromIndex(qindex)
- if item.getTitle() == title:
- return item
- return None
-
- def findFreeNxdataTitle(self):
- """Returns an NXdata title which is not yet used.
-
- :rtype: str
- """
- for i in range(self.rowCount() + 1):
- name = "NXData #%d" % (i + 1)
- group = self.__getNxdataByTitle(name)
- if group is None:
- break
- return name
-
- def createNewNxdata(self, name=None):
- """Create a new NXdata item.
-
- :param Union[str,None] name: A title for the new NXdata
- """
- item = _NxDataItem()
- if name is None:
- name = self.findFreeNxdataTitle()
- item.setTitle(name)
- self.appendRow(item.getRowItems())
-
- def createFromSignal(self, dataset):
- """Create a new NXdata item from a signal dataset.
-
- This signal will also define an amount of axes according to its number
- of dimensions.
-
- :param Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset] dataset:
- A dataset uses as signal.
- """
-
- item = _NxDataItem()
- name = self.findFreeNxdataTitle()
- item.setTitle(name)
- item.setSignalDataset(dataset)
- item.setAxesDatasets([None] * len(dataset.shape))
- self.appendRow(item.getRowItems())
-
- def createFromNxdata(self, nxdata):
- """Create a new custom NXdata item from an existing NXdata group.
-
- If the NXdata is not valid, nothing is created, and an exception is
- returned.
-
- :param Union[h5py.Group,silx.io.commonh5.Group] nxdata: An h5py group
- following the NXData specification.
- :raise ValueError:If `nxdata` is not valid.
- """
- validator = silx.io.nxdata.NXdata(nxdata)
- if validator.is_valid:
- item = _NxDataItem()
- title = validator.title
- if title in [None or ""]:
- title = self.findFreeNxdataTitle()
- item.setTitle(title)
- item.setSignalDataset(validator.signal)
- item.setAxesDatasets(validator.axes)
- self.appendRow(item.getRowItems())
- else:
- raise ValueError("Not a valid NXdata")
-
- def removeNxdataItem(self, item):
- """Remove an NXdata item from this model.
-
- :param _NxDataItem item: An item
- """
- if isinstance(item, _NxDataItem):
- parent = item.parent()
- assert(parent is None)
- model = item.model()
- model.removeRow(item.row())
- else:
- _logger.error("Unexpected item")
-
- def appendAxisToNxdataItem(self, item):
- """Append a new axes to this item (or the NXdata item own by this item).
-
- :param Union[_NxDataItem,qt.QStandardItem] item: An item
- """
- if item is not None and not isinstance(item, _NxDataItem):
- item = item.parent()
- nxdataItem = item
- if isinstance(item, _NxDataItem):
- datasets = nxdataItem.getAxesDatasets()
- datasets.append(None)
- nxdataItem.setAxesDatasets(datasets)
- else:
- _logger.error("Unexpected item")
-
- def removeAxisItem(self, item):
- """Remove an axis item from this model.
-
- :param _DatasetAxisItemRow item: An axis item
- """
- if isinstance(item, _DatasetAxisItemRow):
- axisId = item.getAxisId()
- nxdataItem = item.parent()
- datasets = nxdataItem.getAxesDatasets()
- del datasets[axisId]
- nxdataItem.setAxesDatasets(datasets)
- else:
- _logger.error("Unexpected item")
-
-
-class CustomNxDataToolBar(qt.QToolBar):
- """A specialised toolbar to manage custom NXdata model and items."""
-
- def __init__(self, parent=None):
- """Constructor"""
- super(CustomNxDataToolBar, self).__init__(parent=parent)
- self.__nxdataWidget = None
- self.__initContent()
- # Initialize action state
- self.__currentSelectionChanged(qt.QModelIndex(), qt.QModelIndex())
-
- def __initContent(self):
- """Create all expected actions and set the content of this toolbar."""
- action = qt.QAction("Create a new custom NXdata", self)
- action.setIcon(icons.getQIcon("nxdata-create"))
- action.triggered.connect(self.__createNewNxdata)
- self.addAction(action)
- self.__addNxDataAction = action
-
- action = qt.QAction("Remove the selected NXdata", self)
- action.setIcon(icons.getQIcon("nxdata-remove"))
- action.triggered.connect(self.__removeSelectedNxdata)
- self.addAction(action)
- self.__removeNxDataAction = action
-
- self.addSeparator()
-
- action = qt.QAction("Create a new axis to the selected NXdata", self)
- action.setIcon(icons.getQIcon("nxdata-axis-add"))
- action.triggered.connect(self.__appendNewAxisToSelectedNxdata)
- self.addAction(action)
- self.__addNxDataAxisAction = action
-
- action = qt.QAction("Remove the selected NXdata axis", self)
- action.setIcon(icons.getQIcon("nxdata-axis-remove"))
- action.triggered.connect(self.__removeSelectedAxis)
- self.addAction(action)
- self.__removeNxDataAxisAction = action
-
- def __getSelectedItem(self):
- """Get the selected item from the linked CustomNxdataWidget.
-
- :rtype: qt.QStandardItem
- """
- selectionModel = self.__nxdataWidget.selectionModel()
- index = selectionModel.currentIndex()
- if not index.isValid():
- return
- model = self.__nxdataWidget.model()
- index = model.index(index.row(), 0, index.parent())
- item = model.itemFromIndex(index)
- return item
-
- def __createNewNxdata(self):
- """Create a new NXdata item to the linked CustomNxdataWidget."""
- if self.__nxdataWidget is None:
- return
- model = self.__nxdataWidget.model()
- model.createNewNxdata()
-
- def __removeSelectedNxdata(self):
- """Remove the NXdata item currently selected in the linked
- CustomNxdataWidget."""
- if self.__nxdataWidget is None:
- return
- model = self.__nxdataWidget.model()
- item = self.__getSelectedItem()
- model.removeNxdataItem(item)
-
- def __appendNewAxisToSelectedNxdata(self):
- """Append a new axis to the NXdata item currently selected in the
- linked CustomNxdataWidget."""
- if self.__nxdataWidget is None:
- return
- model = self.__nxdataWidget.model()
- item = self.__getSelectedItem()
- model.appendAxisToNxdataItem(item)
-
- def __removeSelectedAxis(self):
- """Remove the axis item currently selected in the linked
- CustomNxdataWidget."""
- if self.__nxdataWidget is None:
- return
- model = self.__nxdataWidget.model()
- item = self.__getSelectedItem()
- model.removeAxisItem(item)
-
- def setCustomNxDataWidget(self, widget):
- """Set the linked CustomNxdataWidget to this toolbar."""
- assert(isinstance(widget, CustomNxdataWidget))
- if self.__nxdataWidget is not None:
- selectionModel = self.__nxdataWidget.selectionModel()
- selectionModel.currentChanged.disconnect(self.__currentSelectionChanged)
- self.__nxdataWidget = widget
- if self.__nxdataWidget is not None:
- selectionModel = self.__nxdataWidget.selectionModel()
- selectionModel.currentChanged.connect(self.__currentSelectionChanged)
-
- def __currentSelectionChanged(self, current, previous):
- """Update the actions according to the linked CustomNxdataWidget
- item selection"""
- if not current.isValid():
- item = None
- else:
- model = self.__nxdataWidget.model()
- index = model.index(current.row(), 0, current.parent())
- item = model.itemFromIndex(index)
- self.__removeNxDataAction.setEnabled(isinstance(item, _NxDataItem))
- self.__removeNxDataAxisAction.setEnabled(isinstance(item, _DatasetAxisItemRow))
- self.__addNxDataAxisAction.setEnabled(isinstance(item, _NxDataItem) or isinstance(item, _DatasetItemRow))
-
-
-class _HashDropZones(qt.QStyledItemDelegate):
- """Delegate item displaying a drop zone when the item do not contains
- dataset."""
-
- def __init__(self, parent=None):
- """Constructor"""
- super(_HashDropZones, self).__init__(parent)
- pen = qt.QPen()
- pen.setColor(qt.QColor("#D0D0D0"))
- pen.setStyle(qt.Qt.DotLine)
- pen.setWidth(2)
- self.__dropPen = pen
-
- def paint(self, painter, option, index):
- """
- Paint the item
-
- :param qt.QPainter painter: A painter
- :param qt.QStyleOptionViewItem option: Options of the item to paint
- :param qt.QModelIndex index: Index of the item to paint
- """
- displayDropZone = False
- if index.isValid():
- model = index.model()
- rowIndex = model.index(index.row(), 0, index.parent())
- rowItem = model.itemFromIndex(rowIndex)
- if isinstance(rowItem, _DatasetItemRow):
- displayDropZone = rowItem.getDataset() is None
-
- if displayDropZone:
- painter.save()
-
- # Draw background if selected
- if option.state & qt.QStyle.State_Selected:
- colorGroup = qt.QPalette.Inactive
- if option.state & qt.QStyle.State_Active:
- colorGroup = qt.QPalette.Active
- if not option.state & qt.QStyle.State_Enabled:
- colorGroup = qt.QPalette.Disabled
- brush = option.palette.brush(colorGroup, qt.QPalette.Highlight)
- painter.fillRect(option.rect, brush)
-
- painter.setPen(self.__dropPen)
- painter.drawRect(option.rect.adjusted(3, 3, -3, -3))
- painter.restore()
- else:
- qt.QStyledItemDelegate.paint(self, painter, option, index)
-
-
-class CustomNxdataWidget(qt.QTreeView):
- """Widget providing a table displaying and allowing to custom virtual
- NXdata."""
-
- sigNxdataItemUpdated = qt.Signal(qt.QStandardItem)
- """Emitted when the NXdata from an NXdata item was edited"""
-
- sigNxdataItemRemoved = qt.Signal(qt.QStandardItem)
- """Emitted when an NXdata item was removed"""
-
- def __init__(self, parent=None):
- """Constructor"""
- qt.QTreeView.__init__(self, parent=None)
- self.__model = _Model(self)
- self.__model.setColumnCount(4)
- self.__model.setHorizontalHeaderLabels(["Name", "Dataset", "Type", "Shape"])
- self.setModel(self.__model)
-
- self.setItemDelegateForColumn(1, _HashDropZones(self))
-
- self.__model.sigNxdataUpdated.connect(self.__nxdataUpdate)
- self.__model.rowsAboutToBeRemoved.connect(self.__rowsAboutToBeRemoved)
- self.__model.rowsAboutToBeInserted.connect(self.__rowsAboutToBeInserted)
-
- header = self.header()
- if qt.qVersion() < "5.0":
- setResizeMode = header.setResizeMode
- else:
- setResizeMode = header.setSectionResizeMode
- setResizeMode(0, qt.QHeaderView.ResizeToContents)
- setResizeMode(1, qt.QHeaderView.Stretch)
- setResizeMode(2, qt.QHeaderView.ResizeToContents)
- setResizeMode(3, qt.QHeaderView.ResizeToContents)
-
- self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
- self.setDropIndicatorShown(True)
- self.setDragDropOverwriteMode(True)
- self.setDragEnabled(True)
- self.viewport().setAcceptDrops(True)
-
- self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
- self.customContextMenuRequested[qt.QPoint].connect(self.__executeContextMenu)
-
- def __rowsAboutToBeInserted(self, parentIndex, start, end):
- if qt.qVersion()[0:2] == "5.":
- # FIXME: workaround for https://github.com/silx-kit/silx/issues/1919
- # Uses of ResizeToContents looks to break nice update of cells with Qt5
- # This patch make the view blinking
- self.repaint()
-
- def __rowsAboutToBeRemoved(self, parentIndex, start, end):
- """Called when an item was removed from the model."""
- items = []
- model = self.model()
- for index in range(start, end):
- qindex = model.index(index, 0, parent=parentIndex)
- item = self.__model.itemFromIndex(qindex)
- if isinstance(item, _NxDataItem):
- items.append(item)
- for item in items:
- self.sigNxdataItemRemoved.emit(item)
-
- if qt.qVersion()[0:2] == "5.":
- # FIXME: workaround for https://github.com/silx-kit/silx/issues/1919
- # Uses of ResizeToContents looks to break nice update of cells with Qt5
- # This patch make the view blinking
- self.repaint()
-
- def __nxdataUpdate(self, index):
- """Called when a virtual NXdata was updated from the model."""
- model = self.model()
- item = model.itemFromIndex(index)
- self.sigNxdataItemUpdated.emit(item)
-
- def createDefaultContextMenu(self, index):
- """Create a default context menu at this position.
-
- :param qt.QModelIndex index: Index of the item
- """
- index = self.__model.index(index.row(), 0, parent=index.parent())
- item = self.__model.itemFromIndex(index)
-
- menu = qt.QMenu()
-
- weakself = weakref.proxy(self)
-
- if isinstance(item, _NxDataItem):
- action = qt.QAction("Add a new axis", menu)
- action.triggered.connect(lambda: weakself.model().appendAxisToNxdataItem(item))
- action.setIcon(icons.getQIcon("nxdata-axis-add"))
- action.setIconVisibleInMenu(True)
- menu.addAction(action)
- menu.addSeparator()
- action = qt.QAction("Remove this NXdata", menu)
- action.triggered.connect(lambda: weakself.model().removeNxdataItem(item))
- action.setIcon(icons.getQIcon("remove"))
- action.setIconVisibleInMenu(True)
- menu.addAction(action)
- else:
- if isinstance(item, _DatasetItemRow):
- if item.getDataset() is not None:
- action = qt.QAction("Remove this dataset", menu)
- action.triggered.connect(lambda: item.setDataset(None))
- menu.addAction(action)
-
- if isinstance(item, _DatasetAxisItemRow):
- menu.addSeparator()
- action = qt.QAction("Remove this axis", menu)
- action.triggered.connect(lambda: weakself.model().removeAxisItem(item))
- action.setIcon(icons.getQIcon("remove"))
- action.setIconVisibleInMenu(True)
- menu.addAction(action)
-
- return menu
-
- def __executeContextMenu(self, point):
- """Execute the context menu at this position."""
- index = self.indexAt(point)
- menu = self.createDefaultContextMenu(index)
- if menu is None or menu.isEmpty():
- return
- menu.exec_(qt.QCursor.pos())
-
- def removeDatasetsFrom(self, root):
- """
- Remove all datasets provided by this root
-
- :param root: The root file of datasets to remove
- """
- for row in range(self.__model.rowCount()):
- qindex = self.__model.index(row, 0)
- item = self.model().itemFromIndex(qindex)
-
- edited = False
- datasets = item.getAxesDatasets()
- for i, dataset in enumerate(datasets):
- if dataset is not None:
- # That's an approximation, IS can't be used as h5py generates
- # To objects for each requests to a node
- if dataset.file.filename == root.file.filename:
- datasets[i] = None
- edited = True
- if edited:
- item.setAxesDatasets(datasets)
-
- dataset = item.getSignalDataset()
- if dataset is not None:
- # That's an approximation, IS can't be used as h5py generates
- # To objects for each requests to a node
- if dataset.file.filename == root.file.filename:
- item.setSignalDataset(None)
-
- def replaceDatasetsFrom(self, removedRoot, loadedRoot):
- """
- Replace any dataset from any NXdata items using the same dataset name
- from another root.
-
- Usually used when a file was synchronized.
-
- :param removedRoot: The h5py root file which is replaced
- (which have to be removed)
- :param loadedRoot: The new h5py root file which have to be used
- instread.
- """
- for row in range(self.__model.rowCount()):
- qindex = self.__model.index(row, 0)
- item = self.model().itemFromIndex(qindex)
-
- edited = False
- datasets = item.getAxesDatasets()
- for i, dataset in enumerate(datasets):
- newDataset = self.__replaceDatasetRoot(dataset, removedRoot, loadedRoot)
- if dataset is not newDataset:
- datasets[i] = newDataset
- edited = True
- if edited:
- item.setAxesDatasets(datasets)
-
- dataset = item.getSignalDataset()
- newDataset = self.__replaceDatasetRoot(dataset, removedRoot, loadedRoot)
- if dataset is not newDataset:
- item.setSignalDataset(newDataset)
-
- def __replaceDatasetRoot(self, dataset, fromRoot, toRoot):
- """
- Replace the dataset by the same dataset name from another root.
- """
- if dataset is None:
- return None
-
- if dataset.file is None:
- # Not from the expected root
- return dataset
-
- # That's an approximation, IS can't be used as h5py generates
- # To objects for each requests to a node
- if dataset.file.filename == fromRoot.file.filename:
- # Try to find the same dataset name
- try:
- return toRoot[dataset.name]
- except Exception:
- _logger.debug("Backtrace", exc_info=True)
- return None
- else:
- # Not from the expected root
- return dataset
-
- def selectedItems(self):
- """Returns the list of selected items containing NXdata
-
- :rtype: List[qt.QStandardItem]
- """
- result = []
- for qindex in self.selectedIndexes():
- if qindex.column() != 0:
- continue
- if not qindex.isValid():
- continue
- item = self.__model.itemFromIndex(qindex)
- if not isinstance(item, _NxDataItem):
- continue
- result.append(item)
- return result
-
- def selectedNxdata(self):
- """Returns the list of selected NXdata
-
- :rtype: List[silx.io.commonh5.Group]
- """
- result = []
- for qindex in self.selectedIndexes():
- if qindex.column() != 0:
- continue
- if not qindex.isValid():
- continue
- item = self.__model.itemFromIndex(qindex)
- if not isinstance(item, _NxDataItem):
- continue
- result.append(item.getVirtualGroup())
- return result
diff --git a/silx/app/view/Viewer.py b/silx/app/view/Viewer.py
deleted file mode 100644
index dd4d075..0000000
--- a/silx/app/view/Viewer.py
+++ /dev/null
@@ -1,971 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Browse a data file with a GUI"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "15/01/2019"
-
-
-import os
-import collections
-import logging
-import functools
-
-import silx.io.nxdata
-from silx.gui import qt
-from silx.gui import icons
-import silx.gui.hdf5
-from .ApplicationContext import ApplicationContext
-from .CustomNxdataWidget import CustomNxdataWidget
-from .CustomNxdataWidget import CustomNxDataToolBar
-from . import utils
-from silx.gui.utils import projecturl
-from .DataPanel import DataPanel
-
-
-_logger = logging.getLogger(__name__)
-
-
-class Viewer(qt.QMainWindow):
- """
- This window allows to browse a data file like images or HDF5 and it's
- content.
- """
-
- def __init__(self, parent=None, settings=None):
- """
- Constructor
- """
-
- qt.QMainWindow.__init__(self, parent)
- self.setWindowTitle("Silx viewer")
-
- silxIcon = icons.getQIcon("silx")
- self.setWindowIcon(silxIcon)
-
- self.__context = self.createApplicationContext(settings)
- self.__context.restoreLibrarySettings()
-
- self.__dialogState = None
- self.__customNxDataItem = None
- self.__treeview = silx.gui.hdf5.Hdf5TreeView(self)
- self.__treeview.setExpandsOnDoubleClick(False)
- """Silx HDF5 TreeView"""
-
- rightPanel = qt.QSplitter(self)
- rightPanel.setOrientation(qt.Qt.Vertical)
- self.__splitter2 = rightPanel
-
- self.__displayIt = None
- self.__treeWindow = self.__createTreeWindow(self.__treeview)
-
- # Custom the model to be able to manage the life cycle of the files
- treeModel = silx.gui.hdf5.Hdf5TreeModel(self.__treeview, ownFiles=False)
- treeModel.sigH5pyObjectLoaded.connect(self.__h5FileLoaded)
- treeModel.sigH5pyObjectRemoved.connect(self.__h5FileRemoved)
- treeModel.sigH5pyObjectSynchronized.connect(self.__h5FileSynchonized)
- treeModel.setDatasetDragEnabled(True)
- self.__treeModelSorted = silx.gui.hdf5.NexusSortFilterProxyModel(self.__treeview)
- self.__treeModelSorted.setSourceModel(treeModel)
- self.__treeModelSorted.sort(0, qt.Qt.AscendingOrder)
- self.__treeModelSorted.setSortCaseSensitivity(qt.Qt.CaseInsensitive)
-
- self.__treeview.setModel(self.__treeModelSorted)
- rightPanel.addWidget(self.__treeWindow)
-
- self.__customNxdata = CustomNxdataWidget(self)
- self.__customNxdata.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
- # optimise the rendering
- self.__customNxdata.setUniformRowHeights(True)
- self.__customNxdata.setIconSize(qt.QSize(16, 16))
- self.__customNxdata.setExpandsOnDoubleClick(False)
-
- self.__customNxdataWindow = self.__createCustomNxdataWindow(self.__customNxdata)
- self.__customNxdataWindow.setVisible(False)
- rightPanel.addWidget(self.__customNxdataWindow)
-
- rightPanel.setStretchFactor(1, 1)
- rightPanel.setCollapsible(0, False)
- rightPanel.setCollapsible(1, False)
-
- self.__dataPanel = DataPanel(self, self.__context)
-
- spliter = qt.QSplitter(self)
- spliter.addWidget(rightPanel)
- spliter.addWidget(self.__dataPanel)
- spliter.setStretchFactor(1, 1)
- spliter.setCollapsible(0, False)
- spliter.setCollapsible(1, False)
- self.__splitter = spliter
-
- main_panel = qt.QWidget(self)
- layout = qt.QVBoxLayout()
- layout.addWidget(spliter)
- layout.setStretchFactor(spliter, 1)
- main_panel.setLayout(layout)
-
- self.setCentralWidget(main_panel)
-
- self.__treeview.activated.connect(self.displaySelectedData)
- self.__customNxdata.activated.connect(self.displaySelectedCustomData)
- self.__customNxdata.sigNxdataItemRemoved.connect(self.__customNxdataRemoved)
- self.__customNxdata.sigNxdataItemUpdated.connect(self.__customNxdataUpdated)
- self.__treeview.addContextMenuCallback(self.customContextMenu)
-
- treeModel = self.__treeview.findHdf5TreeModel()
- columns = list(treeModel.COLUMN_IDS)
- columns.remove(treeModel.VALUE_COLUMN)
- columns.remove(treeModel.NODE_COLUMN)
- columns.remove(treeModel.DESCRIPTION_COLUMN)
- columns.insert(1, treeModel.DESCRIPTION_COLUMN)
- self.__treeview.header().setSections(columns)
-
- self._iconUpward = icons.getQIcon('plot-yup')
- self._iconDownward = icons.getQIcon('plot-ydown')
-
- self.createActions()
- self.createMenus()
- self.__context.restoreSettings()
-
- def createApplicationContext(self, settings):
- return ApplicationContext(self, settings)
-
- def __createTreeWindow(self, treeView):
- toolbar = qt.QToolBar(self)
- toolbar.setIconSize(qt.QSize(16, 16))
- toolbar.setStyleSheet("QToolBar { border: 0px }")
-
- action = qt.QAction(toolbar)
- action.setIcon(icons.getQIcon("view-refresh"))
- action.setText("Refresh")
- action.setToolTip("Refresh all selected items")
- action.triggered.connect(self.__refreshSelected)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_F5))
- toolbar.addAction(action)
- treeView.addAction(action)
- self.__refreshAction = action
-
- # Another shortcut for refresh
- action = qt.QAction(toolbar)
- action.setShortcut(qt.QKeySequence(qt.Qt.ControlModifier + qt.Qt.Key_R))
- treeView.addAction(action)
- action.triggered.connect(self.__refreshSelected)
-
- action = qt.QAction(toolbar)
- # action.setIcon(icons.getQIcon("view-refresh"))
- action.setText("Close")
- action.setToolTip("Close selected item")
- action.triggered.connect(self.__removeSelected)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_Delete))
- treeView.addAction(action)
- self.__closeAction = action
-
- toolbar.addSeparator()
-
- action = qt.QAction(toolbar)
- action.setIcon(icons.getQIcon("tree-expand-all"))
- action.setText("Expand all")
- action.setToolTip("Expand all selected items")
- action.triggered.connect(self.__expandAllSelected)
- action.setShortcut(qt.QKeySequence(qt.Qt.ControlModifier + qt.Qt.Key_Plus))
- toolbar.addAction(action)
- treeView.addAction(action)
- self.__expandAllAction = action
-
- action = qt.QAction(toolbar)
- action.setIcon(icons.getQIcon("tree-collapse-all"))
- action.setText("Collapse all")
- action.setToolTip("Collapse all selected items")
- action.triggered.connect(self.__collapseAllSelected)
- action.setShortcut(qt.QKeySequence(qt.Qt.ControlModifier + qt.Qt.Key_Minus))
- toolbar.addAction(action)
- treeView.addAction(action)
- self.__collapseAllAction = action
-
- action = qt.QAction("&Sort file content", toolbar)
- action.setIcon(icons.getQIcon("tree-sort"))
- action.setToolTip("Toggle sorting of file content")
- action.setCheckable(True)
- action.setChecked(True)
- action.triggered.connect(self.setContentSorted)
- toolbar.addAction(action)
- treeView.addAction(action)
- self._sortContentAction = action
-
- widget = qt.QWidget(self)
- layout = qt.QVBoxLayout(widget)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
- layout.addWidget(toolbar)
- layout.addWidget(treeView)
- return widget
-
- def __removeSelected(self):
- """Close selected items"""
- qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
-
- selection = self.__treeview.selectionModel()
- indexes = selection.selectedIndexes()
- selectedItems = []
- model = self.__treeview.model()
- h5files = set([])
- while len(indexes) > 0:
- index = indexes.pop(0)
- if index.column() != 0:
- continue
- h5 = model.data(index, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
- rootIndex = index
- # Reach the root of the tree
- while rootIndex.parent().isValid():
- rootIndex = rootIndex.parent()
- rootRow = rootIndex.row()
- relativePath = self.__getRelativePath(model, rootIndex, index)
- selectedItems.append((rootRow, relativePath))
- h5files.add(h5.file)
-
- if len(h5files) != 0:
- model = self.__treeview.findHdf5TreeModel()
- for h5 in h5files:
- row = model.h5pyObjectRow(h5)
- model.removeH5pyObject(h5)
-
- qt.QApplication.restoreOverrideCursor()
-
- def __refreshSelected(self):
- """Refresh all selected items
- """
- qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
-
- selection = self.__treeview.selectionModel()
- indexes = selection.selectedIndexes()
- selectedItems = []
- model = self.__treeview.model()
- h5files = set([])
- while len(indexes) > 0:
- index = indexes.pop(0)
- if index.column() != 0:
- continue
- h5 = model.data(index, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
- rootIndex = index
- # Reach the root of the tree
- while rootIndex.parent().isValid():
- rootIndex = rootIndex.parent()
- rootRow = rootIndex.row()
- relativePath = self.__getRelativePath(model, rootIndex, index)
- selectedItems.append((rootRow, relativePath))
- h5files.add(h5.file)
-
- if len(h5files) == 0:
- qt.QApplication.restoreOverrideCursor()
- return
-
- model = self.__treeview.findHdf5TreeModel()
- for h5 in h5files:
- self.__synchronizeH5pyObject(h5)
-
- model = self.__treeview.model()
- itemSelection = qt.QItemSelection()
- for rootRow, relativePath in selectedItems:
- rootIndex = model.index(rootRow, 0, qt.QModelIndex())
- index = self.__indexFromPath(model, rootIndex, relativePath)
- if index is None:
- continue
- indexEnd = model.index(index.row(), model.columnCount() - 1, index.parent())
- itemSelection.select(index, indexEnd)
- selection.select(itemSelection, qt.QItemSelectionModel.ClearAndSelect)
-
- qt.QApplication.restoreOverrideCursor()
-
- def __synchronizeH5pyObject(self, h5):
- model = self.__treeview.findHdf5TreeModel()
- # This is buggy right now while h5py do not allow to close a file
- # while references are still used.
- # FIXME: The architecture have to be reworked to support this feature.
- # model.synchronizeH5pyObject(h5)
-
- filename = h5.filename
- row = model.h5pyObjectRow(h5)
- index = self.__treeview.model().index(row, 0, qt.QModelIndex())
- paths = self.__getPathFromExpandedNodes(self.__treeview, index)
- model.removeH5pyObject(h5)
- model.insertFile(filename, row)
- index = self.__treeview.model().index(row, 0, qt.QModelIndex())
- self.__expandNodesFromPaths(self.__treeview, index, paths)
-
- def __getRelativePath(self, model, rootIndex, index):
- """Returns a relative path from an index to his rootIndex.
-
- If the path is empty the index is also the rootIndex.
- """
- path = ""
- while index.isValid():
- if index == rootIndex:
- return path
- name = model.data(index)
- if path == "":
- path = name
- else:
- path = name + "/" + path
- index = index.parent()
-
- # index is not a children of rootIndex
- raise ValueError("index is not a children of the rootIndex")
-
- def __getPathFromExpandedNodes(self, view, rootIndex):
- """Return relative path from the root index of the extended nodes"""
- model = view.model()
- rootPath = None
- paths = []
- indexes = [rootIndex]
- while len(indexes):
- index = indexes.pop(0)
- if not view.isExpanded(index):
- continue
-
- node = model.data(index, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE)
- path = node._getCanonicalName()
- if rootPath is None:
- rootPath = path
- path = path[len(rootPath):]
- paths.append(path)
-
- for child in range(model.rowCount(index)):
- childIndex = model.index(child, 0, index)
- indexes.append(childIndex)
- return paths
-
- def __indexFromPath(self, model, rootIndex, path):
- elements = path.split("/")
- if elements[0] == "":
- elements.pop(0)
- index = rootIndex
- while len(elements) != 0:
- element = elements.pop(0)
- found = False
- for child in range(model.rowCount(index)):
- childIndex = model.index(child, 0, index)
- name = model.data(childIndex)
- if element == name:
- index = childIndex
- found = True
- break
- if not found:
- return None
- return index
-
- def __expandNodesFromPaths(self, view, rootIndex, paths):
- model = view.model()
- for path in paths:
- index = self.__indexFromPath(model, rootIndex, path)
- if index is not None:
- view.setExpanded(index, True)
-
- def __expandAllSelected(self):
- """Expand all selected items of the tree.
-
- The depth is fixed to avoid infinite loop with recurssive links.
- """
- qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
-
- selection = self.__treeview.selectionModel()
- indexes = selection.selectedIndexes()
- model = self.__treeview.model()
- while len(indexes) > 0:
- index = indexes.pop(0)
- if isinstance(index, tuple):
- index, depth = index
- else:
- depth = 0
- if index.column() != 0:
- continue
-
- if depth > 10:
- # Avoid infinite loop with recursive links
- break
-
- if model.hasChildren(index):
- self.__treeview.setExpanded(index, True)
- for row in range(model.rowCount(index)):
- childIndex = model.index(row, 0, index)
- indexes.append((childIndex, depth + 1))
- qt.QApplication.restoreOverrideCursor()
-
- def __collapseAllSelected(self):
- """Collapse all selected items of the tree.
-
- The depth is fixed to avoid infinite loop with recurssive links.
- """
- selection = self.__treeview.selectionModel()
- indexes = selection.selectedIndexes()
- model = self.__treeview.model()
- while len(indexes) > 0:
- index = indexes.pop(0)
- if isinstance(index, tuple):
- index, depth = index
- else:
- depth = 0
- if index.column() != 0:
- continue
-
- if depth > 10:
- # Avoid infinite loop with recursive links
- break
-
- if model.hasChildren(index):
- self.__treeview.setExpanded(index, False)
- for row in range(model.rowCount(index)):
- childIndex = model.index(row, 0, index)
- indexes.append((childIndex, depth + 1))
-
- def __createCustomNxdataWindow(self, customNxdataWidget):
- toolbar = CustomNxDataToolBar(self)
- toolbar.setCustomNxDataWidget(customNxdataWidget)
- toolbar.setIconSize(qt.QSize(16, 16))
- toolbar.setStyleSheet("QToolBar { border: 0px }")
-
- widget = qt.QWidget(self)
- layout = qt.QVBoxLayout(widget)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
- layout.addWidget(toolbar)
- layout.addWidget(customNxdataWidget)
- return widget
-
- def __h5FileLoaded(self, loadedH5):
- self.__context.pushRecentFile(loadedH5.file.filename)
- if loadedH5.file.filename == self.__displayIt:
- self.__displayIt = None
- self.displayData(loadedH5)
-
- def __h5FileRemoved(self, removedH5):
- self.__dataPanel.removeDatasetsFrom(removedH5)
- self.__customNxdata.removeDatasetsFrom(removedH5)
- removedH5.close()
-
- def __h5FileSynchonized(self, removedH5, loadedH5):
- self.__dataPanel.replaceDatasetsFrom(removedH5, loadedH5)
- self.__customNxdata.replaceDatasetsFrom(removedH5, loadedH5)
- removedH5.close()
-
- def closeEvent(self, event):
- self.__context.saveSettings()
-
- # Clean up as much as possible Python objects
- self.displayData(None)
- customModel = self.__customNxdata.model()
- customModel.clear()
- hdf5Model = self.__treeview.findHdf5TreeModel()
- hdf5Model.clear()
-
- def saveSettings(self, settings):
- """Save the window settings to this settings object
-
- :param qt.QSettings settings: Initialized settings
- """
- isFullScreen = bool(self.windowState() & qt.Qt.WindowFullScreen)
- if isFullScreen:
- # show in normal to catch the normal geometry
- self.showNormal()
-
- settings.beginGroup("mainwindow")
- settings.setValue("size", self.size())
- settings.setValue("pos", self.pos())
- settings.setValue("full-screen", isFullScreen)
- settings.endGroup()
-
- settings.beginGroup("mainlayout")
- settings.setValue("spliter", self.__splitter.sizes())
- settings.setValue("spliter2", self.__splitter2.sizes())
- isVisible = self.__customNxdataWindow.isVisible()
- settings.setValue("custom-nxdata-window-visible", isVisible)
- settings.endGroup()
-
- settings.beginGroup("content")
- isSorted = self._sortContentAction.isChecked()
- settings.setValue("is-sorted", isSorted)
- settings.endGroup()
-
- if isFullScreen:
- self.showFullScreen()
-
- def restoreSettings(self, settings):
- """Restore the window settings using this settings object
-
- :param qt.QSettings settings: Initialized settings
- """
- settings.beginGroup("mainwindow")
- size = settings.value("size", qt.QSize(640, 480))
- pos = settings.value("pos", qt.QPoint())
- isFullScreen = settings.value("full-screen", False)
- try:
- if not isinstance(isFullScreen, bool):
- isFullScreen = utils.stringToBool(isFullScreen)
- except ValueError:
- isFullScreen = False
- settings.endGroup()
-
- settings.beginGroup("mainlayout")
- try:
- data = settings.value("spliter")
- data = [int(d) for d in data]
- self.__splitter.setSizes(data)
- except Exception:
- _logger.debug("Backtrace", exc_info=True)
- try:
- data = settings.value("spliter2")
- data = [int(d) for d in data]
- self.__splitter2.setSizes(data)
- except Exception:
- _logger.debug("Backtrace", exc_info=True)
- isVisible = settings.value("custom-nxdata-window-visible", False)
- try:
- if not isinstance(isVisible, bool):
- isVisible = utils.stringToBool(isVisible)
- except ValueError:
- isVisible = False
- self.__customNxdataWindow.setVisible(isVisible)
- self._displayCustomNxdataWindow.setChecked(isVisible)
-
- settings.endGroup()
-
- settings.beginGroup("content")
- isSorted = settings.value("is-sorted", True)
- try:
- if not isinstance(isSorted, bool):
- isSorted = utils.stringToBool(isSorted)
- except ValueError:
- isSorted = True
- self.setContentSorted(isSorted)
- settings.endGroup()
-
- if not pos.isNull():
- self.move(pos)
- if not size.isNull():
- self.resize(size)
- if isFullScreen:
- self.showFullScreen()
-
- def createActions(self):
- action = qt.QAction("E&xit", self)
- action.setShortcuts(qt.QKeySequence.Quit)
- action.setStatusTip("Exit the application")
- action.triggered.connect(self.close)
- self._exitAction = action
-
- action = qt.QAction("&Open...", self)
- action.setStatusTip("Open a file")
- action.triggered.connect(self.open)
- self._openAction = action
-
- action = qt.QAction("Open Recent", self)
- action.setStatusTip("Open a recently openned file")
- action.triggered.connect(self.open)
- self._openRecentAction = action
-
- action = qt.QAction("Close All", self)
- action.setStatusTip("Close all opened files")
- action.triggered.connect(self.closeAll)
- self._closeAllAction = action
-
- action = qt.QAction("&About", self)
- action.setStatusTip("Show the application's About box")
- action.triggered.connect(self.about)
- self._aboutAction = action
-
- action = qt.QAction("&Documentation", self)
- action.setStatusTip("Show the Silx library's documentation")
- action.triggered.connect(self.showDocumentation)
- self._documentationAction = action
-
- # Plot backend
-
- action = qt.QAction("Plot rendering backend", self)
- action.setStatusTip("Select plot rendering backend")
- self._plotBackendSelection = action
-
- menu = qt.QMenu()
- action.setMenu(menu)
- group = qt.QActionGroup(self)
- group.setExclusive(True)
-
- action = qt.QAction("matplotlib", self)
- action.setStatusTip("Plot will be rendered using matplotlib")
- action.setCheckable(True)
- action.triggered.connect(self.__forceMatplotlibBackend)
- group.addAction(action)
- menu.addAction(action)
- self._usePlotWithMatplotlib = action
-
- action = qt.QAction("OpenGL", self)
- action.setStatusTip("Plot will be rendered using OpenGL")
- action.setCheckable(True)
- action.triggered.connect(self.__forceOpenglBackend)
- group.addAction(action)
- menu.addAction(action)
- self._usePlotWithOpengl = action
-
- # Plot image orientation
-
- action = qt.QAction("Default plot image y-axis orientation", self)
- action.setStatusTip("Select the default y-axis orientation used by plot displaying images")
- self._plotImageOrientation = action
-
- menu = qt.QMenu()
- action.setMenu(menu)
- group = qt.QActionGroup(self)
- group.setExclusive(True)
-
- action = qt.QAction("Downward, origin on top", self)
- action.setIcon(self._iconDownward)
- action.setStatusTip("Plot images will use a downward Y-axis orientation")
- action.setCheckable(True)
- action.triggered.connect(self.__forcePlotImageDownward)
- group.addAction(action)
- menu.addAction(action)
- self._useYAxisOrientationDownward = action
-
- action = qt.QAction("Upward, origin on bottom", self)
- action.setIcon(self._iconUpward)
- action.setStatusTip("Plot images will use a upward Y-axis orientation")
- action.setCheckable(True)
- action.triggered.connect(self.__forcePlotImageUpward)
- group.addAction(action)
- menu.addAction(action)
- self._useYAxisOrientationUpward = action
-
- # Windows
-
- action = qt.QAction("Show custom NXdata selector", self)
- action.setStatusTip("Show a widget which allow to create plot by selecting data and axes")
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_F6))
- action.toggled.connect(self.__toggleCustomNxdataWindow)
- self._displayCustomNxdataWindow = action
-
- def __toggleCustomNxdataWindow(self):
- isVisible = self._displayCustomNxdataWindow.isChecked()
- self.__customNxdataWindow.setVisible(isVisible)
-
- def __updateFileMenu(self):
- files = self.__context.getRecentFiles()
- self._openRecentAction.setEnabled(len(files) != 0)
- menu = None
- if len(files) != 0:
- menu = qt.QMenu()
- for filePath in files:
- baseName = os.path.basename(filePath)
- action = qt.QAction(baseName, self)
- action.setToolTip(filePath)
- action.triggered.connect(functools.partial(self.__openRecentFile, filePath))
- menu.addAction(action)
- menu.addSeparator()
- baseName = os.path.basename(filePath)
- action = qt.QAction("Clear history", self)
- action.setToolTip("Clear the history of the recent files")
- action.triggered.connect(self.__clearRecentFile)
- menu.addAction(action)
- self._openRecentAction.setMenu(menu)
-
- def __clearRecentFile(self):
- self.__context.clearRencentFiles()
-
- def __openRecentFile(self, filePath):
- self.appendFile(filePath)
-
- def __updateOptionMenu(self):
- """Update the state of the checked options as it is based on global
- environment values."""
-
- # plot backend
-
- action = self._plotBackendSelection
- title = action.text().split(": ", 1)[0]
- action.setText("%s: %s" % (title, silx.config.DEFAULT_PLOT_BACKEND))
-
- action = self._usePlotWithMatplotlib
- action.setChecked(silx.config.DEFAULT_PLOT_BACKEND in ["matplotlib", "mpl"])
- title = action.text().split(" (", 1)[0]
- if not action.isChecked():
- title += " (applied after application restart)"
- action.setText(title)
-
- action = self._usePlotWithOpengl
- action.setChecked(silx.config.DEFAULT_PLOT_BACKEND in ["opengl", "gl"])
- title = action.text().split(" (", 1)[0]
- if not action.isChecked():
- title += " (applied after application restart)"
- action.setText(title)
-
- # plot orientation
-
- action = self._plotImageOrientation
- if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
- action.setIcon(self._iconDownward)
- else:
- action.setIcon(self._iconUpward)
- action.setIconVisibleInMenu(True)
-
- action = self._useYAxisOrientationDownward
- action.setChecked(silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward")
- title = action.text().split(" (", 1)[0]
- if not action.isChecked():
- title += " (applied after application restart)"
- action.setText(title)
-
- action = self._useYAxisOrientationUpward
- action.setChecked(silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION != "downward")
- title = action.text().split(" (", 1)[0]
- if not action.isChecked():
- title += " (applied after application restart)"
- action.setText(title)
-
- def createMenus(self):
- fileMenu = self.menuBar().addMenu("&File")
- fileMenu.addAction(self._openAction)
- fileMenu.addAction(self._openRecentAction)
- fileMenu.addAction(self._closeAllAction)
- fileMenu.addSeparator()
- fileMenu.addAction(self._exitAction)
- fileMenu.aboutToShow.connect(self.__updateFileMenu)
-
- optionMenu = self.menuBar().addMenu("&Options")
- optionMenu.addAction(self._plotImageOrientation)
- optionMenu.addAction(self._plotBackendSelection)
- optionMenu.aboutToShow.connect(self.__updateOptionMenu)
-
- viewMenu = self.menuBar().addMenu("&Views")
- viewMenu.addAction(self._displayCustomNxdataWindow)
-
- helpMenu = self.menuBar().addMenu("&Help")
- helpMenu.addAction(self._aboutAction)
- helpMenu.addAction(self._documentationAction)
-
- def open(self):
- dialog = self.createFileDialog()
- if self.__dialogState is None:
- currentDirectory = os.getcwd()
- dialog.setDirectory(currentDirectory)
- else:
- dialog.restoreState(self.__dialogState)
-
- result = dialog.exec_()
- if not result:
- return
-
- self.__dialogState = dialog.saveState()
-
- filenames = dialog.selectedFiles()
- for filename in filenames:
- self.appendFile(filename)
-
- def closeAll(self):
- """Close all currently opened files"""
- model = self.__treeview.findHdf5TreeModel()
- model.clear()
-
- def createFileDialog(self):
- dialog = qt.QFileDialog(self)
- dialog.setWindowTitle("Open")
- dialog.setModal(True)
-
- # NOTE: hdf5plugin have to be loaded before
- extensions = collections.OrderedDict()
- for description, ext in silx.io.supported_extensions().items():
- extensions[description] = " ".join(sorted(list(ext)))
-
- # Add extensions supported by fabio
- extensions["NeXus layout from EDF files"] = "*.edf"
- extensions["NeXus layout from TIFF image files"] = "*.tif *.tiff"
- extensions["NeXus layout from CBF files"] = "*.cbf"
- extensions["NeXus layout from MarCCD image files"] = "*.mccd"
-
- all_supported_extensions = set()
- for name, exts in extensions.items():
- exts = exts.split(" ")
- all_supported_extensions.update(exts)
- all_supported_extensions = sorted(list(all_supported_extensions))
-
- filters = []
- filters.append("All supported files (%s)" % " ".join(all_supported_extensions))
- for name, extension in extensions.items():
- filters.append("%s (%s)" % (name, extension))
- filters.append("All files (*)")
-
- dialog.setNameFilters(filters)
- dialog.setFileMode(qt.QFileDialog.ExistingFiles)
- return dialog
-
- def about(self):
- from .About import About
- About.about(self, "Silx viewer")
-
- def showDocumentation(self):
- subpath = "index.html"
- url = projecturl.getDocumentationUrl(subpath)
- qt.QDesktopServices.openUrl(qt.QUrl(url))
-
- def setContentSorted(self, sort):
- """Set whether file content should be sorted or not.
-
- :param bool sort:
- """
- sort = bool(sort)
- if sort != self.isContentSorted():
-
- # save expanded nodes
- pathss = []
- root = qt.QModelIndex()
- model = self.__treeview.model()
- for i in range(model.rowCount(root)):
- index = model.index(i, 0, root)
- paths = self.__getPathFromExpandedNodes(self.__treeview, index)
- pathss.append(paths)
-
- self.__treeview.setModel(
- self.__treeModelSorted if sort else self.__treeModelSorted.sourceModel())
- self._sortContentAction.setChecked(self.isContentSorted())
-
- # restore expanded nodes
- model = self.__treeview.model()
- for i in range(model.rowCount(root)):
- index = model.index(i, 0, root)
- paths = pathss.pop(0)
- self.__expandNodesFromPaths(self.__treeview, index, paths)
-
- def isContentSorted(self):
- """Returns whether the file content is sorted or not.
-
- :rtype: bool
- """
- return self.__treeview.model() is self.__treeModelSorted
-
- def __forcePlotImageDownward(self):
- silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = "downward"
-
- def __forcePlotImageUpward(self):
- silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = "upward"
-
- def __forceMatplotlibBackend(self):
- silx.config.DEFAULT_PLOT_BACKEND = "matplotlib"
-
- def __forceOpenglBackend(self):
- silx.config.DEFAULT_PLOT_BACKEND = "opengl"
-
- def appendFile(self, filename):
- if self.__displayIt is None:
- # Store the file to display it (loading could be async)
- self.__displayIt = filename
- self.__treeview.findHdf5TreeModel().appendFile(filename)
-
- def displaySelectedData(self):
- """Called to update the dataviewer with the selected data.
- """
- selected = list(self.__treeview.selectedH5Nodes(ignoreBrokenLinks=False))
- if len(selected) == 1:
- # Update the viewer for a single selection
- data = selected[0]
- self.__dataPanel.setData(data)
- else:
- _logger.debug("Too many data selected")
-
- def displayData(self, data):
- """Called to update the dataviewer with a secific data.
- """
- self.__dataPanel.setData(data)
-
- def displaySelectedCustomData(self):
- selected = list(self.__customNxdata.selectedItems())
- if len(selected) == 1:
- # Update the viewer for a single selection
- item = selected[0]
- self.__dataPanel.setCustomDataItem(item)
- else:
- _logger.debug("Too many items selected")
-
- def __customNxdataRemoved(self, item):
- if self.__dataPanel.getCustomNxdataItem() is item:
- self.__dataPanel.setCustomDataItem(None)
-
- def __customNxdataUpdated(self, item):
- if self.__dataPanel.getCustomNxdataItem() is item:
- self.__dataPanel.setCustomDataItem(item)
-
- def __makeSureCustomNxDataWindowIsVisible(self):
- if not self.__customNxdataWindow.isVisible():
- self.__customNxdataWindow.setVisible(True)
- self._displayCustomNxdataWindow.setChecked(True)
-
- def useAsNewCustomSignal(self, h5dataset):
- self.__makeSureCustomNxDataWindowIsVisible()
- model = self.__customNxdata.model()
- model.createFromSignal(h5dataset)
-
- def useAsNewCustomNxdata(self, h5nxdata):
- self.__makeSureCustomNxDataWindowIsVisible()
- model = self.__customNxdata.model()
- model.createFromNxdata(h5nxdata)
-
- def customContextMenu(self, event):
- """Called to populate the context menu
-
- :param silx.gui.hdf5.Hdf5ContextMenuEvent event: Event
- containing expected information to populate the context menu
- """
- selectedObjects = event.source().selectedH5Nodes(ignoreBrokenLinks=False)
- menu = event.menu()
-
- if not menu.isEmpty():
- menu.addSeparator()
-
- for obj in selectedObjects:
- h5 = obj.h5py_object
-
- name = obj.name
- if name.startswith("/"):
- name = name[1:]
- if name == "":
- name = "the root"
-
- action = qt.QAction("Show %s" % name, event.source())
- action.triggered.connect(lambda: self.displayData(h5))
- menu.addAction(action)
-
- if silx.io.is_dataset(h5):
- action = qt.QAction("Use as a new custom signal", event.source())
- action.triggered.connect(lambda: self.useAsNewCustomSignal(h5))
- menu.addAction(action)
-
- if silx.io.is_group(h5) and silx.io.nxdata.is_valid_nxdata(h5):
- action = qt.QAction("Use as a new custom NXdata", event.source())
- action.triggered.connect(lambda: self.useAsNewCustomNxdata(h5))
- menu.addAction(action)
-
- if silx.io.is_file(h5):
- action = qt.QAction("Close %s" % obj.local_filename, event.source())
- action.triggered.connect(lambda: self.__treeview.findHdf5TreeModel().removeH5pyObject(h5))
- menu.addAction(action)
- action = qt.QAction("Synchronize %s" % obj.local_filename, event.source())
- action.triggered.connect(lambda: self.__synchronizeH5pyObject(h5))
- menu.addAction(action)
diff --git a/silx/app/view/main.py b/silx/app/view/main.py
deleted file mode 100644
index a1369c1..0000000
--- a/silx/app/view/main.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Module containing launcher of the `silx view` application"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "17/01/2019"
-
-import argparse
-import logging
-import os
-import signal
-import sys
-
-
-_logger = logging.getLogger(__name__)
-"""Module logger"""
-
-
-def createParser():
- parser = argparse.ArgumentParser(description=__doc__)
- parser.add_argument(
- 'files',
- nargs=argparse.ZERO_OR_MORE,
- help='Data file to show (h5 file, edf files, spec files)')
- parser.add_argument(
- '--debug',
- dest="debug",
- action="store_true",
- default=False,
- help='Set logging system in debug mode')
- parser.add_argument(
- '--use-opengl-plot',
- dest="use_opengl_plot",
- action="store_true",
- default=False,
- help='Use OpenGL for plots (instead of matplotlib)')
- parser.add_argument(
- '-f', '--fresh',
- dest="fresh_preferences",
- action="store_true",
- default=False,
- help='Start the application using new fresh user preferences')
- parser.add_argument(
- '--hdf5-file-locking',
- dest="hdf5_file_locking",
- action="store_true",
- default=False,
- help='Start the application with HDF5 file locking enabled (it is disabled by default)')
- return parser
-
-
-def createWindow(parent, settings):
- from .Viewer import Viewer
- window = Viewer(parent=None, settings=settings)
- return window
-
-
-def mainQt(options):
- """Part of the main depending on Qt"""
- if options.debug:
- logging.root.setLevel(logging.DEBUG)
-
- #
- # Import most of the things here to be sure to use the right logging level
- #
-
- # This needs to be done prior to load HDF5
- hdf5_file_locking = 'TRUE' if options.hdf5_file_locking else 'FALSE'
- _logger.info('Set HDF5_USE_FILE_LOCKING=%s', hdf5_file_locking)
- os.environ['HDF5_USE_FILE_LOCKING'] = hdf5_file_locking
-
- try:
- # it should be loaded before h5py
- import hdf5plugin # noqa
- except ImportError:
- _logger.debug("Backtrace", exc_info=True)
-
- import h5py
-
- import silx
- import silx.utils.files
- from silx.gui import qt
- # Make sure matplotlib is configured
- # Needed for Debian 8: compatibility between Qt4/Qt5 and old matplotlib
- import silx.gui.utils.matplotlib # noqa
-
- app = qt.QApplication([])
- qt.QLocale.setDefault(qt.QLocale.c())
-
- def sigintHandler(*args):
- """Handler for the SIGINT signal."""
- qt.QApplication.quit()
-
- signal.signal(signal.SIGINT, sigintHandler)
- sys.excepthook = qt.exceptionHandler
-
- timer = qt.QTimer()
- timer.start(500)
- # Application have to wake up Python interpreter, else SIGINT is not
- # catched
- timer.timeout.connect(lambda: None)
-
- settings = qt.QSettings(qt.QSettings.IniFormat,
- qt.QSettings.UserScope,
- "silx",
- "silx-view",
- None)
- if options.fresh_preferences:
- settings.clear()
-
- window = createWindow(parent=None, settings=settings)
- window.setAttribute(qt.Qt.WA_DeleteOnClose, True)
-
- if options.use_opengl_plot:
- # It have to be done after the settings (after the Viewer creation)
- silx.config.DEFAULT_PLOT_BACKEND = "opengl"
-
- # NOTE: under Windows, cmd does not convert `*.tif` into existing files
- options.files = silx.utils.files.expand_filenames(options.files)
-
- for filename in options.files:
- # TODO: Would be nice to add a process widget and a cancel button
- try:
- window.appendFile(filename)
- except IOError as e:
- _logger.error(e.args[0])
- _logger.debug("Backtrace", exc_info=True)
-
- window.show()
- result = app.exec_()
- # remove ending warnings relative to QTimer
- app.deleteLater()
- return result
-
-
-def main(argv):
- """
- Main function to launch the viewer as an application
-
- :param argv: Command line arguments
- :returns: exit status
- """
- parser = createParser()
- options = parser.parse_args(argv[1:])
- mainQt(options)
-
-
-if __name__ == '__main__':
- main(sys.argv)
diff --git a/silx/app/view/test/__init__.py b/silx/app/view/test/__init__.py
deleted file mode 100644
index 8e64948..0000000
--- a/silx/app/view/test/__init__.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "07/06/2018"
-
-import unittest
-
-from silx.test.utils import test_options
-
-
-def suite():
- test_suite = unittest.TestSuite()
- if test_options.WITH_QT_TEST:
- from . import test_launcher
- from . import test_view
- test_suite.addTest(test_view.suite())
- test_suite.addTest(test_launcher.suite())
- return test_suite
diff --git a/silx/app/view/test/test_launcher.py b/silx/app/view/test/test_launcher.py
deleted file mode 100644
index 5f03de9..0000000
--- a/silx/app/view/test/test_launcher.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Module testing silx.app.view"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "07/06/2018"
-
-
-import os
-import shutil
-import sys
-import tempfile
-import unittest
-import logging
-import subprocess
-
-from silx.test.utils import test_options
-from .. import main
-from silx import __main__ as silx_main
-
-_logger = logging.getLogger(__name__)
-
-
-@unittest.skipUnless(test_options.WITH_QT_TEST, test_options.WITH_QT_TEST_REASON)
-class TestLauncher(unittest.TestCase):
- """Test command line parsing"""
-
- def testHelp(self):
- # option -h must cause a raise SystemExit or a return 0
- try:
- parser = main.createParser()
- parser.parse_args(["view", "--help"])
- result = 0
- except SystemExit as e:
- result = e.args[0]
- self.assertEqual(result, 0)
-
- def testWrongOption(self):
- try:
- parser = main.createParser()
- parser.parse_args(["view", "--foo"])
- self.fail()
- except SystemExit as e:
- result = e.args[0]
- self.assertNotEqual(result, 0)
-
- def testWrongFile(self):
- try:
- parser = main.createParser()
- result = parser.parse_args(["view", "__file.not.found__"])
- result = 0
- except SystemExit as e:
- result = e.args[0]
- self.assertEqual(result, 0)
-
- def executeAsScript(self, filename, *args):
- """Execute a command line.
-
- Log output as debug in case of bad return code.
- """
- env = self.createTestEnv()
-
- with tempfile.TemporaryDirectory() as tmpdir:
- # Copy file to temporary dir to avoid import from current dir.
- script = os.path.join(tmpdir, 'launcher.py')
- shutil.copyfile(filename, script)
- command_line = [sys.executable, script] + list(args)
-
- _logger.info("Execute: %s", " ".join(command_line))
- p = subprocess.Popen(command_line,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- env=env)
- out, err = p.communicate()
- _logger.info("Return code: %d", p.returncode)
- try:
- out = out.decode('utf-8')
- except UnicodeError:
- pass
- try:
- err = err.decode('utf-8')
- except UnicodeError:
- pass
-
- if p.returncode != 0:
- _logger.info("stdout:")
- _logger.info("%s", out)
- _logger.info("stderr:")
- _logger.info("%s", err)
- else:
- _logger.debug("stdout:")
- _logger.debug("%s", out)
- _logger.debug("stderr:")
- _logger.debug("%s", err)
- self.assertEqual(p.returncode, 0)
-
- def createTestEnv(self):
- """
- Returns an associated environment with a working project.
- """
- env = dict((str(k), str(v)) for k, v in os.environ.items())
- env["PYTHONPATH"] = os.pathsep.join(sys.path)
- return env
-
- def testExecuteViewHelp(self):
- """Test if the main module is well connected.
-
- Uses subprocess to avoid to parasite the current environment.
- """
- self.executeAsScript(main.__file__, "--help")
-
- def testExecuteSilxViewHelp(self):
- """Test if the main module is well connected.
-
- Uses subprocess to avoid to parasite the current environment.
- """
- self.executeAsScript(silx_main.__file__, "view", "--help")
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loader(TestLauncher))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/app/view/test/test_view.py b/silx/app/view/test/test_view.py
deleted file mode 100644
index 7ea5a2c..0000000
--- a/silx/app/view/test/test_view.py
+++ /dev/null
@@ -1,394 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Module testing silx.app.view"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "07/06/2018"
-
-
-import unittest
-import weakref
-import numpy
-import tempfile
-import shutil
-import os.path
-import h5py
-
-from silx.gui import qt
-from silx.app.view.Viewer import Viewer
-from silx.app.view.About import About
-from silx.app.view.DataPanel import DataPanel
-from silx.app.view.CustomNxdataWidget import CustomNxdataWidget
-from silx.gui.hdf5._utils import Hdf5DatasetMimeData
-from silx.gui.utils.testutils import TestCaseQt
-from silx.io import commonh5
-
-_tmpDirectory = None
-
-
-def setUpModule():
- global _tmpDirectory
- _tmpDirectory = tempfile.mkdtemp(prefix=__name__)
-
- # create h5 data
- filename = _tmpDirectory + "/data.h5"
- f = h5py.File(filename, "w")
- g = f.create_group("arrays")
- g.create_dataset("scalar", data=10)
- g.create_dataset("integers", data=numpy.array([10, 20, 30]))
- f.close()
-
- # create h5 data
- filename = _tmpDirectory + "/data2.h5"
- f = h5py.File(filename, "w")
- g = f.create_group("arrays")
- g.create_dataset("scalar", data=20)
- g.create_dataset("integers", data=numpy.array([10, 20, 30]))
- f.close()
-
-
-def tearDownModule():
- global _tmpDirectory
- shutil.rmtree(_tmpDirectory)
- _tmpDirectory = None
-
-
-class TestViewer(TestCaseQt):
- """Test for Viewer class"""
-
- def testConstruct(self):
- widget = Viewer()
- self.qWaitForWindowExposed(widget)
-
- def testDestroy(self):
- widget = Viewer()
- ref = weakref.ref(widget)
- widget = None
- self.qWaitForDestroy(ref)
-
-
-class TestAbout(TestCaseQt):
- """Test for About box class"""
-
- def testConstruct(self):
- widget = About()
- self.qWaitForWindowExposed(widget)
-
- def testLicense(self):
- widget = About()
- widget.getHtmlLicense()
- self.qWaitForWindowExposed(widget)
-
- def testDestroy(self):
- widget = About()
- ref = weakref.ref(widget)
- widget = None
- self.qWaitForDestroy(ref)
-
-
-class TestDataPanel(TestCaseQt):
-
- def testConstruct(self):
- widget = DataPanel()
- self.qWaitForWindowExposed(widget)
-
- def testDestroy(self):
- widget = DataPanel()
- ref = weakref.ref(widget)
- widget = None
- self.qWaitForDestroy(ref)
-
- def testHeaderLabelPaintEvent(self):
- widget = DataPanel()
- data = numpy.array([1, 2, 3, 4, 5])
- widget.setData(data)
- # Expected to execute HeaderLabel.paintEvent
- widget.setVisible(True)
- self.qWaitForWindowExposed(widget)
-
- def testData(self):
- widget = DataPanel()
- data = numpy.array([1, 2, 3, 4, 5])
- widget.setData(data)
- self.assertIs(widget.getData(), data)
- self.assertIs(widget.getCustomNxdataItem(), None)
-
- def testDataNone(self):
- widget = DataPanel()
- widget.setData(None)
- self.assertIs(widget.getData(), None)
- self.assertIs(widget.getCustomNxdataItem(), None)
-
- def testCustomDataItem(self):
- class CustomDataItemMock(object):
- def getVirtualGroup(self):
- return None
-
- def text(self):
- return ""
-
- data = CustomDataItemMock()
- widget = DataPanel()
- widget.setCustomDataItem(data)
- self.assertIs(widget.getData(), None)
- self.assertIs(widget.getCustomNxdataItem(), data)
-
- def testCustomDataItemNone(self):
- data = None
- widget = DataPanel()
- widget.setCustomDataItem(data)
- self.assertIs(widget.getData(), None)
- self.assertIs(widget.getCustomNxdataItem(), data)
-
- def testRemoveDatasetsFrom(self):
- f = h5py.File(os.path.join(_tmpDirectory, "data.h5"), mode='r')
- try:
- widget = DataPanel()
- widget.setData(f["arrays/scalar"])
- widget.removeDatasetsFrom(f)
- self.assertIs(widget.getData(), None)
- finally:
- widget.setData(None)
- f.close()
-
- def testReplaceDatasetsFrom(self):
- f = h5py.File(os.path.join(_tmpDirectory, "data.h5"), mode='r')
- f2 = h5py.File(os.path.join(_tmpDirectory, "data2.h5"), mode='r')
- try:
- widget = DataPanel()
- widget.setData(f["arrays/scalar"])
- self.assertEqual(widget.getData()[()], 10)
- widget.replaceDatasetsFrom(f, f2)
- self.assertEqual(widget.getData()[()], 20)
- finally:
- widget.setData(None)
- f.close()
- f2.close()
-
-
-class TestCustomNxdataWidget(TestCaseQt):
-
- def testConstruct(self):
- widget = CustomNxdataWidget()
- self.qWaitForWindowExposed(widget)
-
- def testDestroy(self):
- widget = CustomNxdataWidget()
- ref = weakref.ref(widget)
- widget = None
- self.qWaitForDestroy(ref)
-
- def testCreateNxdata(self):
- widget = CustomNxdataWidget()
- model = widget.model()
- model.createNewNxdata()
- model.createNewNxdata("Foo")
- widget.setVisible(True)
- self.qWaitForWindowExposed(widget)
-
- def testCreateNxdataFromDataset(self):
- widget = CustomNxdataWidget()
- model = widget.model()
- signal = commonh5.Dataset("foo", data=numpy.array([[[5]]]))
- model.createFromSignal(signal)
- widget.setVisible(True)
- self.qWaitForWindowExposed(widget)
-
- def testCreateNxdataFromNxdata(self):
- widget = CustomNxdataWidget()
- model = widget.model()
- data = numpy.array([[[5]]])
- nxdata = commonh5.Group("foo")
- nxdata.attrs["NX_class"] = "NXdata"
- nxdata.attrs["signal"] = "signal"
- nxdata.create_dataset("signal", data=data)
- model.createFromNxdata(nxdata)
- widget.setVisible(True)
- self.qWaitForWindowExposed(widget)
-
- def testCreateBadNxdata(self):
- widget = CustomNxdataWidget()
- model = widget.model()
- signal = commonh5.Dataset("foo", data=numpy.array([[[5]]]))
- model.createFromSignal(signal)
- axis = commonh5.Dataset("foo", data=numpy.array([[[5]]]))
- nxdataIndex = model.index(0, 0)
- item = model.itemFromIndex(nxdataIndex)
- item.setAxesDatasets([axis])
- nxdata = item.getVirtualGroup()
- self.assertIsNotNone(nxdata)
- self.assertFalse(item.isValid())
-
- def testRemoveDatasetsFrom(self):
- f = h5py.File(os.path.join(_tmpDirectory, "data.h5"), mode='r')
- try:
- widget = CustomNxdataWidget()
- model = widget.model()
- dataset = f["arrays/integers"]
- model.createFromSignal(dataset)
- widget.removeDatasetsFrom(f)
- finally:
- model.clear()
- f.close()
-
- def testReplaceDatasetsFrom(self):
- f = h5py.File(os.path.join(_tmpDirectory, "data.h5"), mode='r')
- f2 = h5py.File(os.path.join(_tmpDirectory, "data2.h5"), mode='r')
- try:
- widget = CustomNxdataWidget()
- model = widget.model()
- dataset = f["arrays/integers"]
- model.createFromSignal(dataset)
- widget.replaceDatasetsFrom(f, f2)
- finally:
- model.clear()
- f.close()
- f2.close()
-
-
-class TestCustomNxdataWidgetInteraction(TestCaseQt):
- """Test CustomNxdataWidget with user interaction"""
-
- def setUp(self):
- TestCaseQt.setUp(self)
-
- self.widget = CustomNxdataWidget()
- self.model = self.widget.model()
- data = numpy.array([[[5]]])
- dataset = commonh5.Dataset("foo", data=data)
- self.model.createFromSignal(dataset)
- self.selectionModel = self.widget.selectionModel()
-
- def tearDown(self):
- self.selectionModel = None
- self.model.clear()
- self.model = None
- self.widget = None
- TestCaseQt.tearDown(self)
-
- def testSelectedNxdata(self):
- index = self.model.index(0, 0)
- self.selectionModel.setCurrentIndex(index, qt.QItemSelectionModel.ClearAndSelect)
- nxdata = self.widget.selectedNxdata()
- self.assertEqual(len(nxdata), 1)
- self.assertIsNot(nxdata[0], None)
-
- def testSelectedItems(self):
- index = self.model.index(0, 0)
- self.selectionModel.setCurrentIndex(index, qt.QItemSelectionModel.ClearAndSelect)
- items = self.widget.selectedItems()
- self.assertEqual(len(items), 1)
- self.assertIsNot(items[0], None)
- self.assertIsInstance(items[0], qt.QStandardItem)
-
- def testRowsAboutToBeRemoved(self):
- self.model.removeRow(0)
- self.qWaitForWindowExposed(self.widget)
-
- def testPaintItems(self):
- self.widget.expandAll()
- self.widget.setVisible(True)
- self.qWaitForWindowExposed(self.widget)
-
- def testCreateDefaultContextMenu(self):
- nxDataIndex = self.model.index(0, 0)
- menu = self.widget.createDefaultContextMenu(nxDataIndex)
- self.assertIsNot(menu, None)
- self.assertIsInstance(menu, qt.QMenu)
-
- signalIndex = self.model.index(0, 0, nxDataIndex)
- menu = self.widget.createDefaultContextMenu(signalIndex)
- self.assertIsNot(menu, None)
- self.assertIsInstance(menu, qt.QMenu)
-
- axesIndex = self.model.index(1, 0, nxDataIndex)
- menu = self.widget.createDefaultContextMenu(axesIndex)
- self.assertIsNot(menu, None)
- self.assertIsInstance(menu, qt.QMenu)
-
- def testDropNewDataset(self):
- dataset = commonh5.Dataset("foo", numpy.array([1, 2, 3, 4]))
- mimedata = Hdf5DatasetMimeData(dataset=dataset)
- self.model.dropMimeData(mimedata, qt.Qt.CopyAction, -1, -1, qt.QModelIndex())
- self.assertEqual(self.model.rowCount(qt.QModelIndex()), 2)
-
- def testDropNewNxdata(self):
- data = numpy.array([[[5]]])
- nxdata = commonh5.Group("foo")
- nxdata.attrs["NX_class"] = "NXdata"
- nxdata.attrs["signal"] = "signal"
- nxdata.create_dataset("signal", data=data)
- mimedata = Hdf5DatasetMimeData(dataset=nxdata)
- self.model.dropMimeData(mimedata, qt.Qt.CopyAction, -1, -1, qt.QModelIndex())
- self.assertEqual(self.model.rowCount(qt.QModelIndex()), 2)
-
- def testDropAxisDataset(self):
- dataset = commonh5.Dataset("foo", numpy.array([1, 2, 3, 4]))
- mimedata = Hdf5DatasetMimeData(dataset=dataset)
- nxDataIndex = self.model.index(0, 0)
- axesIndex = self.model.index(1, 0, nxDataIndex)
- self.model.dropMimeData(mimedata, qt.Qt.CopyAction, -1, -1, axesIndex)
- self.assertEqual(self.model.rowCount(qt.QModelIndex()), 1)
- item = self.model.itemFromIndex(axesIndex)
- self.assertIsNot(item.getDataset(), None)
-
- def testMimeData(self):
- nxDataIndex = self.model.index(0, 0)
- signalIndex = self.model.index(0, 0, nxDataIndex)
- mimeData = self.model.mimeData([signalIndex])
- self.assertIsNot(mimeData, None)
- self.assertIsInstance(mimeData, qt.QMimeData)
-
- def testRemoveNxdataItem(self):
- nxdataIndex = self.model.index(0, 0)
- item = self.model.itemFromIndex(nxdataIndex)
- self.model.removeNxdataItem(item)
-
- def testAppendAxisToNxdataItem(self):
- nxdataIndex = self.model.index(0, 0)
- item = self.model.itemFromIndex(nxdataIndex)
- self.model.appendAxisToNxdataItem(item)
-
- def testRemoveAxisItem(self):
- nxdataIndex = self.model.index(0, 0)
- axesIndex = self.model.index(1, 0, nxdataIndex)
- item = self.model.itemFromIndex(axesIndex)
- self.model.removeAxisItem(item)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loader(TestViewer))
- test_suite.addTest(loader(TestAbout))
- test_suite.addTest(loader(TestDataPanel))
- test_suite.addTest(loader(TestCustomNxdataWidget))
- test_suite.addTest(loader(TestCustomNxdataWidgetInteraction))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/_glutils/FramebufferTexture.py b/silx/gui/_glutils/FramebufferTexture.py
deleted file mode 100644
index e065030..0000000
--- a/silx/gui/_glutils/FramebufferTexture.py
+++ /dev/null
@@ -1,165 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Association of a texture and a framebuffer object for off-screen rendering.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "25/07/2016"
-
-
-import logging
-
-from . import gl
-from .Texture import Texture
-
-
-_logger = logging.getLogger(__name__)
-
-
-class FramebufferTexture(object):
- """Framebuffer with a texture.
-
- Aimed at off-screen rendering to texture.
-
- :param internalFormat: OpenGL texture internal format
- :param shape: Shape (height, width) of the framebuffer and texture
- :type shape: 2-tuple of int
- :param stencilFormat: Stencil renderbuffer format
- :param depthFormat: Depth renderbuffer format
- :param kwargs: Extra arguments for :class:`Texture` constructor
- """
-
- _PACKED_FORMAT = gl.GL_DEPTH24_STENCIL8, gl.GL_DEPTH_STENCIL
-
- def __init__(self,
- internalFormat,
- shape,
- stencilFormat=gl.GL_DEPTH24_STENCIL8,
- depthFormat=gl.GL_DEPTH24_STENCIL8,
- **kwargs):
-
- self._texture = Texture(internalFormat, shape=shape, **kwargs)
- self._texture.prepare()
-
- self._previousFramebuffer = 0 # Used by with statement
-
- self._name = gl.glGenFramebuffers(1)
-
- with self: # Bind FBO
- # Attachments
- gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER,
- gl.GL_COLOR_ATTACHMENT0,
- gl.GL_TEXTURE_2D,
- self._texture.name,
- 0)
-
- height, width = self._texture.shape
-
- if stencilFormat is not None:
- self._stencilId = gl.glGenRenderbuffers(1)
- gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._stencilId)
- gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
- stencilFormat,
- width, height)
- gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
- gl.GL_STENCIL_ATTACHMENT,
- gl.GL_RENDERBUFFER,
- self._stencilId)
- else:
- self._stencilId = None
-
- if depthFormat is not None:
- if self._stencilId and depthFormat in self._PACKED_FORMAT:
- self._depthId = self._stencilId
- else:
- self._depthId = gl.glGenRenderbuffers(1)
- gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._depthId)
- gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
- depthFormat,
- width, height)
- gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
- gl.GL_DEPTH_ATTACHMENT,
- gl.GL_RENDERBUFFER,
- self._depthId)
- else:
- self._depthId = None
-
- assert (gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) ==
- gl.GL_FRAMEBUFFER_COMPLETE)
-
- @property
- def shape(self):
- """Shape of the framebuffer (height, width)"""
- return self._texture.shape
-
- @property
- def texture(self):
- """The texture this framebuffer is rendering to.
-
- The life-cycle of the texture is managed by this object"""
- return self._texture
-
- @property
- def name(self):
- """OpenGL name of the framebuffer"""
- if self._name is not None:
- return self._name
- else:
- raise RuntimeError("No OpenGL framebuffer resource, \
- discard has already been called")
-
- def bind(self):
- """Bind this framebuffer for rendering"""
- gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.name)
-
- # with statement
-
- def __enter__(self):
- self._previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
- self.bind()
-
- def __exit__(self, exctype, excvalue, traceback):
- gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self._previousFramebuffer)
- self._previousFramebuffer = None
-
- def discard(self):
- """Delete associated OpenGL resources including texture"""
- if self._name is not None:
- gl.glDeleteFramebuffers(self._name)
- self._name = None
-
- if self._stencilId is not None:
- gl.glDeleteRenderbuffers(self._stencilId)
- if self._stencilId == self._depthId:
- self._depthId = None
- self._stencilId = None
- if self._depthId is not None:
- gl.glDeleteRenderbuffers(self._depthId)
- self._depthId = None
-
- self._texture.discard() # Also discard the texture
- else:
- _logger.warning("Discard has already been called")
diff --git a/silx/gui/_glutils/OpenGLWidget.py b/silx/gui/_glutils/OpenGLWidget.py
deleted file mode 100644
index 5e3fcb8..0000000
--- a/silx/gui/_glutils/OpenGLWidget.py
+++ /dev/null
@@ -1,423 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This package provides a compatibility layer for OpenGL widget.
-
-It provides a compatibility layer for Qt OpenGL widget used in silx
-across Qt<=5.3 QtOpenGL.QGLWidget and QOpenGLWidget.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "22/11/2019"
-
-
-import logging
-import sys
-
-from .. import qt
-from ..utils.glutils import isOpenGLAvailable
-from .._glutils import gl
-
-
-_logger = logging.getLogger(__name__)
-
-
-if not hasattr(qt, 'QOpenGLWidget') and not hasattr(qt, 'QGLWidget'):
- OpenGLWidget = None
-
-else:
- if hasattr(qt, 'QOpenGLWidget'): # PyQt>=5.4
- _logger.info('Using QOpenGLWidget')
- _BaseOpenGLWidget = qt.QOpenGLWidget
-
- else:
- _logger.info('Using QGLWidget')
- _BaseOpenGLWidget = qt.QGLWidget
-
- class _OpenGLWidget(_BaseOpenGLWidget):
- """Wrapper over QOpenGLWidget and QGLWidget"""
-
- sigOpenGLContextError = qt.Signal(str)
- """Signal emitted when an OpenGL context error is detected at runtime.
-
- It provides the error reason as a str.
- """
-
- def __init__(self, parent,
- alphaBufferSize=0,
- depthBufferSize=24,
- stencilBufferSize=8,
- version=(2, 0),
- f=qt.Qt.WindowFlags()):
- # True if using QGLWidget, False if using QOpenGLWidget
- self.__legacy = not hasattr(qt, 'QOpenGLWidget')
-
- self.__devicePixelRatio = 1.0
- self.__requestedOpenGLVersion = int(version[0]), int(version[1])
- self.__isValid = False
-
- if self.__legacy: # QGLWidget
- format_ = qt.QGLFormat()
- format_.setAlphaBufferSize(alphaBufferSize)
- format_.setAlpha(alphaBufferSize != 0)
- format_.setDepthBufferSize(depthBufferSize)
- format_.setDepth(depthBufferSize != 0)
- format_.setStencilBufferSize(stencilBufferSize)
- format_.setStencil(stencilBufferSize != 0)
- format_.setVersion(*self.__requestedOpenGLVersion)
- format_.setDoubleBuffer(True)
-
- super(_OpenGLWidget, self).__init__(format_, parent, None, f)
-
- else: # QOpenGLWidget
- super(_OpenGLWidget, self).__init__(parent, f)
-
- format_ = qt.QSurfaceFormat()
- format_.setAlphaBufferSize(alphaBufferSize)
- format_.setDepthBufferSize(depthBufferSize)
- format_.setStencilBufferSize(stencilBufferSize)
- format_.setVersion(*self.__requestedOpenGLVersion)
- format_.setSwapBehavior(qt.QSurfaceFormat.DoubleBuffer)
- self.setFormat(format_)
-
- # Enable receiving mouse move events when no buttons are pressed
- self.setMouseTracking(True)
-
- def getDevicePixelRatio(self):
- """Returns the ratio device-independent / device pixel size
-
- It should be either 1.0 or 2.0.
-
- :return: Scale factor between screen and Qt units
- :rtype: float
- """
- return self.__devicePixelRatio
-
- def getRequestedOpenGLVersion(self):
- """Returns the requested OpenGL version.
-
- :return: (major, minor)
- :rtype: 2-tuple of int"""
- return self.__requestedOpenGLVersion
-
- def getOpenGLVersion(self):
- """Returns the available OpenGL version.
-
- :return: (major, minor)
- :rtype: 2-tuple of int"""
- if self.__legacy: # QGLWidget
- supportedVersion = 0, 0
-
- # Go through all OpenGL version flags checking support
- flags = self.format().openGLVersionFlags()
- for version in ((1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
- (2, 0), (2, 1),
- (3, 0), (3, 1), (3, 2), (3, 3),
- (4, 0)):
- versionFlag = getattr(qt.QGLFormat,
- 'OpenGL_Version_%d_%d' % version)
- if not versionFlag & flags:
- break
- supportedVersion = version
- return supportedVersion
-
- else: # QOpenGLWidget
- return self.format().version()
-
- # QOpenGLWidget methods
-
- def isValid(self):
- """Returns True if OpenGL is available.
-
- This adds extra checks to Qt isValid method.
-
- :rtype: bool
- """
- return self.__isValid and super(_OpenGLWidget, self).isValid()
-
- def defaultFramebufferObject(self):
- """Returns the framebuffer object handle.
-
- See :meth:`QOpenGLWidget.defaultFramebufferObject`
- """
- if self.__legacy: # QGLWidget
- return 0
- else: # QOpenGLWidget
- return super(_OpenGLWidget, self).defaultFramebufferObject()
-
- # *GL overridden methods
-
- def initializeGL(self):
- parent = self.parent()
- if parent is None:
- _logger.error('_OpenGLWidget has no parent')
- return
-
- # Check OpenGL version
- if self.getOpenGLVersion() >= self.getRequestedOpenGLVersion():
- try:
- gl.glGetError() # clear any previous error (if any)
- version = gl.glGetString(gl.GL_VERSION)
- except:
- version = None
-
- if version:
- self.__isValid = True
- else:
- errMsg = 'OpenGL not available'
- if sys.platform.startswith('linux'):
- errMsg += ': If connected remotely, ' \
- 'GLX forwarding might be disabled.'
- _logger.error(errMsg)
- self.sigOpenGLContextError.emit(errMsg)
- self.__isValid = False
-
- else:
- errMsg = 'OpenGL %d.%d not available' % \
- self.getRequestedOpenGLVersion()
- _logger.error('OpenGL widget disabled: %s', errMsg)
- self.sigOpenGLContextError.emit(errMsg)
- self.__isValid = False
-
- if self.isValid():
- parent.initializeGL()
-
- def paintGL(self):
- parent = self.parent()
- if parent is None:
- _logger.error('_OpenGLWidget has no parent')
- return
-
- if qt.BINDING in ('PyQt5', 'PySide2'):
- devicePixelRatio = self.window().windowHandle().devicePixelRatio()
-
- if devicePixelRatio != self.getDevicePixelRatio():
- # Update devicePixelRatio and call resizeOpenGL
- # as resizeGL is not always called.
- self.__devicePixelRatio = devicePixelRatio
- self.makeCurrent()
- parent.resizeGL(self.width(), self.height())
-
- if self.isValid():
- parent.paintGL()
-
- def resizeGL(self, width, height):
- parent = self.parent()
- if parent is None:
- _logger.error('_OpenGLWidget has no parent')
- return
-
- if self.isValid():
- # Call parent resizeGL with device-independent pixel unit
- # This works over both QGLWidget and QOpenGLWidget
- parent.resizeGL(self.width(), self.height())
-
-
-class OpenGLWidget(qt.QWidget):
- """OpenGL widget wrapper over QGLWidget and QOpenGLWidget
-
- This wrapper API implements a subset of QOpenGLWidget API.
- The constructor takes a different set of arguments.
- Methods returning object like :meth:`context` returns either
- QGL* or QOpenGL* objects.
-
- :param parent: Parent widget see :class:`QWidget`
- :param int alphaBufferSize:
- Size in bits of the alpha channel (default: 0).
- Set to 0 to disable alpha channel.
- :param int depthBufferSize:
- Size in bits of the depth buffer (default: 24).
- Set to 0 to disable depth buffer.
- :param int stencilBufferSize:
- Size in bits of the stencil buffer (default: 8).
- Set to 0 to disable stencil buffer
- :param version: Requested OpenGL version (default: (2, 0)).
- :type version: 2-tuple of int
- :param f: see :class:`QWidget`
- """
-
- def __init__(self, parent=None,
- alphaBufferSize=0,
- depthBufferSize=24,
- stencilBufferSize=8,
- version=(2, 0),
- f=qt.Qt.WindowFlags()):
- super(OpenGLWidget, self).__init__(parent, f)
-
- layout = qt.QHBoxLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
- self.setLayout(layout)
-
- self.__context = None
-
- _check = isOpenGLAvailable(version=version, runtimeCheck=False)
- if _OpenGLWidget is None or not _check:
- _logger.error('OpenGL-based widget disabled: %s', _check.error)
- self.__openGLWidget = None
- label = self._createErrorQLabel(_check.error)
- self.layout().addWidget(label)
-
- else:
- self.__openGLWidget = _OpenGLWidget(
- parent=self,
- alphaBufferSize=alphaBufferSize,
- depthBufferSize=depthBufferSize,
- stencilBufferSize=stencilBufferSize,
- version=version,
- f=f)
- # Async connection need, otherwise issue when hiding OpenGL
- # widget while doing the rendering..
- self.__openGLWidget.sigOpenGLContextError.connect(
- self._handleOpenGLInitError, qt.Qt.QueuedConnection)
- self.layout().addWidget(self.__openGLWidget)
-
- @staticmethod
- def _createErrorQLabel(error):
- """Create QLabel displaying error message in place of OpenGL widget
-
- :param str error: The error message to display"""
- label = qt.QLabel()
- label.setText('OpenGL-based widget disabled:\n%s' % error)
- label.setAlignment(qt.Qt.AlignCenter)
- label.setWordWrap(True)
- return label
-
- def _handleOpenGLInitError(self, error):
- """Handle runtime errors in OpenGL widget"""
- if self.__openGLWidget is not None:
- self.__openGLWidget.setVisible(False)
- self.__openGLWidget.setParent(None)
- self.__openGLWidget = None
-
- label = self._createErrorQLabel(error)
- self.layout().addWidget(label)
-
- # Additional API
-
- def getDevicePixelRatio(self):
- """Returns the ratio device-independent / device pixel size
-
- It should be either 1.0 or 2.0.
-
- :return: Scale factor between screen and Qt units
- :rtype: float
- """
- if self.__openGLWidget is None:
- return 1.
- else:
- return self.__openGLWidget.getDevicePixelRatio()
-
- def getDotsPerInch(self):
- """Returns current screen resolution as device pixels per inch.
-
- :rtype: float
- """
- screen = self.window().windowHandle().screen()
- if screen is not None:
- # TODO check if this is correct on different OS/screen
- # OK on macOS10.12/qt5.13.2
- dpi = screen.physicalDotsPerInch() * self.getDevicePixelRatio()
- else: # Fallback
- dpi = 96. * self.getDevicePixelRatio()
- return dpi
-
- def getOpenGLVersion(self):
- """Returns the available OpenGL version.
-
- :return: (major, minor)
- :rtype: 2-tuple of int"""
- if self.__openGLWidget is None:
- return 0, 0
- else:
- return self.__openGLWidget.getOpenGLVersion()
-
- # QOpenGLWidget API
-
- def isValid(self):
- """Returns True if OpenGL with the requested version is available.
-
- :rtype: bool
- """
- if self.__openGLWidget is None:
- return False
- else:
- return self.__openGLWidget.isValid()
-
- def context(self):
- """Return Qt OpenGL context object or None.
-
- See :meth:`QOpenGLWidget.context` and :meth:`QGLWidget.context`
- """
- if self.__openGLWidget is None:
- return None
- else:
- # Keep a reference on QOpenGLContext to make
- # else PyQt5 keeps creating a new one.
- self.__context = self.__openGLWidget.context()
- return self.__context
-
- def defaultFramebufferObject(self):
- """Returns the framebuffer object handle.
-
- See :meth:`QOpenGLWidget.defaultFramebufferObject`
- """
- if self.__openGLWidget is None:
- return 0
- else:
- return self.__openGLWidget.defaultFramebufferObject()
-
- def makeCurrent(self):
- """Make the underlying OpenGL widget's context current.
-
- See :meth:`QOpenGLWidget.makeCurrent`
- """
- if self.__openGLWidget is not None:
- self.__openGLWidget.makeCurrent()
-
- def update(self):
- """Async update of the OpenGL widget.
-
- See :meth:`QOpenGLWidget.update`
- """
- if self.__openGLWidget is not None:
- self.__openGLWidget.update()
-
- # QOpenGLWidget API to override
-
- def initializeGL(self):
- """Override to implement OpenGL initialization."""
- pass
-
- def paintGL(self):
- """Override to implement OpenGL rendering."""
- pass
-
- def resizeGL(self, width, height):
- """Override to implement resize of OpenGL framebuffer.
-
- :param int width: Width in device-independent pixels
- :param int height: Height in device-independent pixels
- """
- pass
diff --git a/silx/gui/_glutils/font.py b/silx/gui/_glutils/font.py
deleted file mode 100644
index 6a4c489..0000000
--- a/silx/gui/_glutils/font.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Text rasterisation feature leveraging Qt font and text layout support."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "13/10/2016"
-
-
-import logging
-import numpy
-
-from ..utils.image import convertQImageToArray
-from .. import qt
-
-_logger = logging.getLogger(__name__)
-
-
-def getDefaultFontFamily():
- """Returns the default font family of the application"""
- return qt.QApplication.instance().font().family()
-
-
-# Font weights
-ULTRA_LIGHT = 0
-"""Lightest characters: Minimum font weight"""
-
-LIGHT = 25
-"""Light characters"""
-
-NORMAL = 50
-"""Normal characters"""
-
-SEMI_BOLD = 63
-"""Between normal and bold characters"""
-
-BOLD = 74
-"""Thicker characters"""
-
-BLACK = 87
-"""Really thick characters"""
-
-ULTRA_BLACK = 99
-"""Thickest characters: Maximum font weight"""
-
-
-def rasterText(text, font,
- size=-1,
- weight=-1,
- italic=False,
- devicePixelRatio=1.0):
- """Raster text using Qt.
-
- It supports multiple lines.
-
- :param str text: The text to raster
- :param font: Font name or QFont to use
- :type font: str or :class:`QFont`
- :param int size:
- Font size in points
- Used only if font is given as name.
- :param int weight:
- Font weight in [0, 99], see QFont.Weight.
- Used only if font is given as name.
- :param bool italic:
- True for italic font (default: False).
- Used only if font is given as name.
- :param float devicePixelRatio:
- The current ratio between device and device-independent pixel
- (default: 1.0)
- :return: Corresponding image in gray scale and baseline offset from top
- :rtype: (HxW numpy.ndarray of uint8, int)
- """
- if not text:
- _logger.info("Trying to raster empty text, replaced by white space")
- text = ' ' # Replace empty text by white space to produce an image
-
- if (devicePixelRatio != 1.0 and
- not hasattr(qt.QImage, 'setDevicePixelRatio')): # Qt 4
- _logger.error('devicePixelRatio not supported')
- devicePixelRatio = 1.0
-
- if not isinstance(font, qt.QFont):
- font = qt.QFont(font, size, weight, italic)
-
- # get text size
- image = qt.QImage(1, 1, qt.QImage.Format_RGB888)
- painter = qt.QPainter()
- painter.begin(image)
- painter.setPen(qt.Qt.white)
- painter.setFont(font)
- bounds = painter.boundingRect(
- qt.QRect(0, 0, 4096, 4096), qt.Qt.TextExpandTabs, text)
- painter.end()
-
- metrics = qt.QFontMetrics(font)
-
- # This does not provide the correct text bbox on macOS
- # size = metrics.size(qt.Qt.TextExpandTabs, text)
- # bounds = metrics.boundingRect(
- # qt.QRect(0, 0, size.width(), size.height()),
- # qt.Qt.TextExpandTabs,
- # text)
-
- # Add extra border and handle devicePixelRatio
- width = bounds.width() * devicePixelRatio + 2
- # align line size to 32 bits to ease conversion to numpy array
- width = 4 * ((width + 3) // 4)
- image = qt.QImage(int(width),
- int(bounds.height() * devicePixelRatio + 2),
- qt.QImage.Format_RGB888)
- if (devicePixelRatio != 1.0 and
- hasattr(image, 'setDevicePixelRatio')): # Qt 5
- image.setDevicePixelRatio(devicePixelRatio)
-
- # TODO if Qt5 use Format_Grayscale8 instead
- image.fill(0)
-
- # Raster text
- painter = qt.QPainter()
- painter.begin(image)
- painter.setPen(qt.Qt.white)
- painter.setFont(font)
- painter.drawText(bounds, qt.Qt.TextExpandTabs, text)
- painter.end()
-
- array = convertQImageToArray(image)
-
- # RGB to R
- array = numpy.ascontiguousarray(array[:, :, 0])
-
- # Remove leading and trailing empty columns but one on each side
- column_cumsum = numpy.cumsum(numpy.sum(array, axis=0))
- array = array[:, column_cumsum.argmin():column_cumsum.argmax() + 2]
-
- # Remove leading and trailing empty rows but one on each side
- row_cumsum = numpy.cumsum(numpy.sum(array, axis=1))
- min_row = row_cumsum.argmin()
- array = array[min_row:row_cumsum.argmax() + 2, :]
-
- return array, metrics.ascent() - min_row
diff --git a/silx/gui/_glutils/utils.py b/silx/gui/_glutils/utils.py
deleted file mode 100644
index d5627ef..0000000
--- a/silx/gui/_glutils/utils.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2014-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides conversion functions between OpenGL and numpy types.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "10/01/2017"
-
-import numpy
-
-from OpenGL.constants import BYTE_SIZES as _BYTE_SIZES
-from OpenGL.constants import ARRAY_TO_GL_TYPE_MAPPING as _ARRAY_TO_GL_TYPE_MAPPING
-
-
-def sizeofGLType(type_):
- """Returns the size in bytes of an element of type `type_`"""
- return _BYTE_SIZES[type_]
-
-
-def isSupportedGLType(type_):
- """Test if a numpy type or dtype can be converted to a GL type."""
- return numpy.dtype(type_).char in _ARRAY_TO_GL_TYPE_MAPPING
-
-
-def numpyToGLType(type_):
- """Returns the GL type corresponding the provided numpy type or dtype."""
- return _ARRAY_TO_GL_TYPE_MAPPING[numpy.dtype(type_).char]
-
-
-def segmentTrianglesIntersection(segment, triangles):
- """Check for segment/triangles intersection.
-
- This is based on signed tetrahedron volume comparison.
-
- See A. Kensler, A., Shirley, P.
- Optimizing Ray-Triangle Intersection via Automated Search.
- Symposium on Interactive Ray Tracing, vol. 0, p33-38 (2006)
-
- :param numpy.ndarray segment:
- Segment end points as a 2x3 array of coordinates
- :param numpy.ndarray triangles:
- Nx3x3 array of triangles
- :return: (triangle indices, segment parameter, barycentric coord)
- Indices of intersected triangles, "depth" along the segment
- of the intersection point and barycentric coordinates of intersection
- point in the triangle.
- :rtype: List[numpy.ndarray]
- """
- # TODO triangles from vertices + indices
- # TODO early rejection? e.g., check segment bbox vs triangle bbox
- segment = numpy.asarray(segment)
- assert segment.ndim == 2
- assert segment.shape == (2, 3)
-
- triangles = numpy.asarray(triangles)
- assert triangles.ndim == 3
- assert triangles.shape[1] == 3
-
- # Test line/triangles intersection
- d = segment[1] - segment[0]
- t0s0 = segment[0] - triangles[:, 0, :]
- edge01 = triangles[:, 1, :] - triangles[:, 0, :]
- edge02 = triangles[:, 2, :] - triangles[:, 0, :]
-
- dCrossEdge02 = numpy.cross(d, edge02)
- t0s0CrossEdge01 = numpy.cross(t0s0, edge01)
- volume = numpy.sum(dCrossEdge02 * edge01, axis=1)
- del edge01
- subVolumes = numpy.empty((len(triangles), 3), dtype=triangles.dtype)
- subVolumes[:, 1] = numpy.sum(dCrossEdge02 * t0s0, axis=1)
- del dCrossEdge02
- subVolumes[:, 2] = numpy.sum(t0s0CrossEdge01 * d, axis=1)
- subVolumes[:, 0] = volume - subVolumes[:, 1] - subVolumes[:, 2]
- intersect = numpy.logical_or(
- numpy.all(subVolumes >= 0., axis=1), # All positive
- numpy.all(subVolumes <= 0., axis=1)) # All negative
- intersect = numpy.where(intersect)[0] # Indices of intersected triangles
-
- # Get barycentric coordinates
- barycentric = subVolumes[intersect] / volume[intersect].reshape(-1, 1)
- del subVolumes
-
- # Test segment/triangles intersection
- volAlpha = numpy.sum(t0s0CrossEdge01[intersect] * edge02[intersect], axis=1)
- t = volAlpha / volume[intersect] # segment parameter of intersected triangles
- del t0s0CrossEdge01
- del edge02
- del volAlpha
- del volume
-
- inSegmentMask = numpy.logical_and(t >= 0., t <= 1.)
- intersect = intersect[inSegmentMask]
- t = t[inSegmentMask]
- barycentric = barycentric[inSegmentMask]
-
- # Sort intersecting triangles by t
- indices = numpy.argsort(t)
- return intersect[indices], t[indices], barycentric[indices]
diff --git a/silx/gui/colors.py b/silx/gui/colors.py
deleted file mode 100755
index db837b5..0000000
--- a/silx/gui/colors.py
+++ /dev/null
@@ -1,1326 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides API to manage colors.
-"""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent", "H.Payno"]
-__license__ = "MIT"
-__date__ = "29/01/2019"
-
-import numpy
-import logging
-import collections
-import warnings
-
-from silx.gui import qt
-from silx.gui.utils import blockSignals
-from silx.math.combo import min_max
-from silx.math import colormap as _colormap
-from silx.utils.exceptions import NotEditableError
-from silx.utils import deprecation
-from silx.resources import resource_filename as _resource_filename
-
-
-_logger = logging.getLogger(__name__)
-
-try:
- import silx.gui.utils.matplotlib # noqa Initalize matplotlib
- from matplotlib import cm as _matplotlib_cm
- from matplotlib.pyplot import colormaps as _matplotlib_colormaps
-except ImportError:
- _logger.info("matplotlib not available, only embedded colormaps available")
- _matplotlib_cm = None
- _matplotlib_colormaps = None
-
-
-_COLORDICT = {}
-"""Dictionary of common colors."""
-
-_COLORDICT['b'] = _COLORDICT['blue'] = '#0000ff'
-_COLORDICT['r'] = _COLORDICT['red'] = '#ff0000'
-_COLORDICT['g'] = _COLORDICT['green'] = '#00ff00'
-_COLORDICT['k'] = _COLORDICT['black'] = '#000000'
-_COLORDICT['w'] = _COLORDICT['white'] = '#ffffff'
-_COLORDICT['pink'] = '#ff66ff'
-_COLORDICT['brown'] = '#a52a2a'
-_COLORDICT['orange'] = '#ff9900'
-_COLORDICT['violet'] = '#6600ff'
-_COLORDICT['gray'] = _COLORDICT['grey'] = '#a0a0a4'
-# _COLORDICT['darkGray'] = _COLORDICT['darkGrey'] = '#808080'
-# _COLORDICT['lightGray'] = _COLORDICT['lightGrey'] = '#c0c0c0'
-_COLORDICT['y'] = _COLORDICT['yellow'] = '#ffff00'
-_COLORDICT['m'] = _COLORDICT['magenta'] = '#ff00ff'
-_COLORDICT['c'] = _COLORDICT['cyan'] = '#00ffff'
-_COLORDICT['darkBlue'] = '#000080'
-_COLORDICT['darkRed'] = '#800000'
-_COLORDICT['darkGreen'] = '#008000'
-_COLORDICT['darkBrown'] = '#660000'
-_COLORDICT['darkCyan'] = '#008080'
-_COLORDICT['darkYellow'] = '#808000'
-_COLORDICT['darkMagenta'] = '#800080'
-_COLORDICT['transparent'] = '#00000000'
-
-
-# FIXME: It could be nice to expose a functional API instead of that attribute
-COLORDICT = _COLORDICT
-
-
-_LUT_DESCRIPTION = collections.namedtuple("_LUT_DESCRIPTION", ["source", "cursor_color", "preferred"])
-"""Description of a LUT for internal purpose."""
-
-
-_AVAILABLE_LUTS = collections.OrderedDict([
- ('gray', _LUT_DESCRIPTION('builtin', 'pink', True)),
- ('reversed gray', _LUT_DESCRIPTION('builtin', 'pink', True)),
- ('red', _LUT_DESCRIPTION('builtin', 'green', True)),
- ('green', _LUT_DESCRIPTION('builtin', 'pink', True)),
- ('blue', _LUT_DESCRIPTION('builtin', 'yellow', True)),
- ('viridis', _LUT_DESCRIPTION('resource', 'pink', True)),
- ('cividis', _LUT_DESCRIPTION('resource', 'pink', True)),
- ('magma', _LUT_DESCRIPTION('resource', 'green', True)),
- ('inferno', _LUT_DESCRIPTION('resource', 'green', True)),
- ('plasma', _LUT_DESCRIPTION('resource', 'green', True)),
- ('temperature', _LUT_DESCRIPTION('builtin', 'pink', True)),
- ('jet', _LUT_DESCRIPTION('matplotlib', 'pink', True)),
- ('hsv', _LUT_DESCRIPTION('matplotlib', 'black', True)),
-])
-"""Description for internal porpose of all the default LUT provided by the library."""
-
-
-DEFAULT_MIN_LIN = 0
-"""Default min value if in linear normalization"""
-DEFAULT_MAX_LIN = 1
-"""Default max value if in linear normalization"""
-
-
-def rgba(color, colorDict=None):
- """Convert color code '#RRGGBB' and '#RRGGBBAA' to a tuple (R, G, B, A)
- of floats.
-
- It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
- QColor as color argument.
-
- :param str color: The color to convert
- :param dict colorDict: A dictionary of color name conversion to color code
- :returns: RGBA colors as floats in [0., 1.]
- :rtype: tuple
- """
- if colorDict is None:
- colorDict = _COLORDICT
-
- if hasattr(color, 'getRgbF'): # QColor support
- color = color.getRgbF()
-
- values = numpy.asarray(color).ravel()
-
- if values.dtype.kind in 'iuf': # integer or float
- # Color is an array
- assert len(values) in (3, 4)
-
- # Convert from integers in [0, 255] to float in [0, 1]
- if values.dtype.kind in 'iu':
- values = values / 255.
-
- # Clip to [0, 1]
- values[values < 0.] = 0.
- values[values > 1.] = 1.
-
- if len(values) == 3:
- return values[0], values[1], values[2], 1.
- else:
- return tuple(values)
-
- # We assume color is a string
- if not color.startswith('#'):
- color = colorDict[color]
-
- assert len(color) in (7, 9) and color[0] == '#'
- r = int(color[1:3], 16) / 255.
- g = int(color[3:5], 16) / 255.
- b = int(color[5:7], 16) / 255.
- a = int(color[7:9], 16) / 255. if len(color) == 9 else 1.
- return r, g, b, a
-
-
-def greyed(color, colorDict=None):
- """Convert color code '#RRGGBB' and '#RRGGBBAA' to a grey color
- (R, G, B, A).
-
- It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
- QColor as color argument.
-
- :param str color: The color to convert
- :param dict colorDict: A dictionary of color name conversion to color code
- :returns: RGBA colors as floats in [0., 1.]
- :rtype: tuple
- """
- r, g, b, a = rgba(color=color, colorDict=colorDict)
- g = 0.21 * r + 0.72 * g + 0.07 * b
- return g, g, g, a
-
-
-def asQColor(color):
- """Convert color code '#RRGGBB' and '#RRGGBBAA' to a `qt.QColor`.
-
- It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
- QColor as color argument.
-
- :param str color: The color to convert
- :rtype: qt.QColor
- """
- color = rgba(color)
- return qt.QColor.fromRgbF(*color)
-
-
-def cursorColorForColormap(colormapName):
- """Get a color suitable for overlay over a colormap.
-
- :param str colormapName: The name of the colormap.
- :return: Name of the color.
- :rtype: str
- """
- description = _AVAILABLE_LUTS.get(colormapName, None)
- if description is not None:
- color = description.cursor_color
- if color is not None:
- return color
- return 'black'
-
-
-# Colormap loader
-
-_COLORMAP_CACHE = {}
-"""Cache already used colormaps as name: color LUT"""
-
-
-def _arrayToRgba8888(colors):
- """Convert colors from a numpy array using float (0..1) int or uint
- (0..255) to uint8 RGBA.
-
- :param numpy.ndarray colors: Array of float int or uint colors to convert
- :return: colors as uint8
- :rtype: numpy.ndarray
- """
- assert len(colors.shape) == 2
- assert colors.shape[1] in (3, 4)
-
- if colors.dtype == numpy.uint8:
- pass
- elif colors.dtype.kind == 'f':
- # Each bin is [N, N+1[ except the last one: [255, 256]
- colors = numpy.clip(colors.astype(numpy.float64) * 256, 0., 255.)
- colors = colors.astype(numpy.uint8)
- elif colors.dtype.kind in 'iu':
- colors = numpy.clip(colors, 0, 255)
- colors = colors.astype(numpy.uint8)
-
- if colors.shape[1] == 3:
- tmp = numpy.empty((len(colors), 4), dtype=numpy.uint8)
- tmp[:, 0:3] = colors
- tmp[:, 3] = 255
- colors = tmp
-
- return colors
-
-
-def _createColormapLut(name):
- """Returns the color LUT corresponding to a colormap name
-
- :param str name: Name of the colormap to load
- :returns: Corresponding table of colors
- :rtype: numpy.ndarray
- :raise ValueError: If no colormap corresponds to name
- """
- description = _AVAILABLE_LUTS.get(name)
- use_mpl = False
- if description is not None:
- if description.source == "builtin":
- # Build colormap LUT
- lut = numpy.zeros((256, 4), dtype=numpy.uint8)
- lut[:, 3] = 255
-
- if name == 'gray':
- lut[:, :3] = numpy.arange(256, dtype=numpy.uint8).reshape(-1, 1)
- elif name == 'reversed gray':
- lut[:, :3] = numpy.arange(255, -1, -1, dtype=numpy.uint8).reshape(-1, 1)
- elif name == 'red':
- lut[:, 0] = numpy.arange(256, dtype=numpy.uint8)
- elif name == 'green':
- lut[:, 1] = numpy.arange(256, dtype=numpy.uint8)
- elif name == 'blue':
- lut[:, 2] = numpy.arange(256, dtype=numpy.uint8)
- elif name == 'temperature':
- # Red
- lut[128:192, 0] = numpy.arange(2, 255, 4, dtype=numpy.uint8)
- lut[192:, 0] = 255
- # Green
- lut[:64, 1] = numpy.arange(0, 255, 4, dtype=numpy.uint8)
- lut[64:192, 1] = 255
- lut[192:, 1] = numpy.arange(252, -1, -4, dtype=numpy.uint8)
- # Blue
- lut[:64, 2] = 255
- lut[64:128, 2] = numpy.arange(254, 0, -4, dtype=numpy.uint8)
- else:
- raise RuntimeError("Built-in colormap not implemented")
- return lut
-
- elif description.source == "resource":
- # Load colormap LUT
- colors = numpy.load(_resource_filename("gui/colormaps/%s.npy" % name))
- # Convert to uint8 and add alpha channel
- lut = _arrayToRgba8888(colors)
- return lut
-
- elif description.source == "matplotlib":
- use_mpl = True
-
- else:
- raise RuntimeError("Internal LUT source '%s' unsupported" % description.source)
-
- # Here it expect a matplotlib LUTs
-
- if use_mpl:
- # matplotlib is mandatory
- if _matplotlib_cm is None:
- raise ValueError("The colormap '%s' expect matplotlib, but matplotlib is not installed" % name)
-
- if _matplotlib_cm is not None: # Try to load with matplotlib
- colormap = _matplotlib_cm.get_cmap(name)
- lut = colormap(numpy.linspace(0, 1, colormap.N, endpoint=True))
- lut = _arrayToRgba8888(lut)
- return lut
-
- raise ValueError("Unknown colormap '%s'" % name)
-
-
-def _getColormap(name):
- """Returns the color LUT corresponding to a colormap name
-
- :param str name: Name of the colormap to load
- :returns: Corresponding table of colors
- :rtype: numpy.ndarray
- :raise ValueError: If no colormap corresponds to name
- """
- name = str(name)
- if name not in _COLORMAP_CACHE:
- lut = _createColormapLut(name)
- _COLORMAP_CACHE[name] = lut
- return _COLORMAP_CACHE[name]
-
-
-# Normalizations
-
-class _NormalizationMixIn:
- """Colormap normalization mix-in class"""
-
- DEFAULT_RANGE = 0, 1
- """Fallback for (vmin, vmax)"""
-
- def isValid(self, value):
- """Check if a value is in the valid range for this normalization.
-
- Override in subclass.
-
- :param Union[float,numpy.ndarray] value:
- :rtype: Union[bool,numpy.ndarray]
- """
- if isinstance(value, collections.abc.Iterable):
- return numpy.ones_like(value, dtype=numpy.bool_)
- else:
- return True
-
- def autoscale(self, data, mode):
- """Returns range for given data and autoscale mode.
-
- :param Union[None,numpy.ndarray] data:
- :param str mode: Autoscale mode, see :class:`Colormap`
- :returns: Range as (min, max)
- :rtype: Tuple[float,float]
- """
- data = None if data is None else numpy.array(data, copy=False)
- if data is None or data.size == 0:
- return self.DEFAULT_RANGE
-
- if mode == Colormap.MINMAX:
- vmin, vmax = self.autoscaleMinMax(data)
- elif mode == Colormap.STDDEV3:
- dmin, dmax = self.autoscaleMinMax(data)
- stdmin, stdmax = self.autoscaleMean3Std(data)
- if dmin is None:
- vmin = stdmin
- elif stdmin is None:
- vmin = dmin
- else:
- vmin = max(dmin, stdmin)
-
- if dmax is None:
- vmax = stdmax
- elif stdmax is None:
- vmax = dmax
- else:
- vmax = min(dmax, stdmax)
-
- else:
- raise ValueError('Unsupported mode: %s' % mode)
-
- # Check returned range and handle fallbacks
- if vmin is None or not numpy.isfinite(vmin):
- vmin = self.DEFAULT_RANGE[0]
- if vmax is None or not numpy.isfinite(vmax):
- vmax = self.DEFAULT_RANGE[1]
- if vmax < vmin:
- vmax = vmin
- return float(vmin), float(vmax)
-
- def autoscaleMinMax(self, data):
- """Autoscale using min/max
-
- :param numpy.ndarray data:
- :returns: (vmin, vmax)
- :rtype: Tuple[float,float]
- """
- data = data[self.isValid(data)]
- if data.size == 0:
- return None, None
- result = min_max(data, min_positive=False, finite=True)
- return result.minimum, result.maximum
-
- def autoscaleMean3Std(self, data):
- """Autoscale using mean+/-3std
-
- This implementation only works for normalization that do NOT
- use the data range.
- Override this method for normalization using the range.
-
- :param numpy.ndarray data:
- :returns: (vmin, vmax)
- :rtype: Tuple[float,float]
- """
- # Use [0, 1] as data range for normalization not using range
- normdata = self.apply(data, 0., 1.)
- if normdata.dtype.kind == 'f': # Replaces inf by NaN
- normdata[numpy.isfinite(normdata) == False] = numpy.nan
- if normdata.size == 0: # Fallback
- return None, None
-
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
- # Ignore nanmean "Mean of empty slice" warning and
- # nanstd "Degrees of freedom <= 0 for slice" warning
- mean, std = numpy.nanmean(normdata), numpy.nanstd(normdata)
-
- return self.revert(mean - 3 * std, 0., 1.), self.revert(mean + 3 * std, 0., 1.)
-
-
-class _LinearNormalizationMixIn(_NormalizationMixIn):
- """Colormap normalization mix-in class specific to autoscale taken from initial range"""
-
- def autoscaleMean3Std(self, data):
- """Autoscale using mean+/-3std
-
- Do the autoscale on the data itself, not the normalized data.
-
- :param numpy.ndarray data:
- :returns: (vmin, vmax)
- :rtype: Tuple[float,float]
- """
- if data.dtype.kind == 'f': # Replaces inf by NaN
- data = numpy.array(data, copy=True) # Work on a copy
- data[numpy.isfinite(data) == False] = numpy.nan
- if data.size == 0: # Fallback
- return None, None
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
- # Ignore nanmean "Mean of empty slice" warning and
- # nanstd "Degrees of freedom <= 0 for slice" warning
- mean, std = numpy.nanmean(data), numpy.nanstd(data)
- return mean - 3 * std, mean + 3 * std
-
-
-class _LinearNormalization(_colormap.LinearNormalization, _LinearNormalizationMixIn):
- """Linear normalization"""
- def __init__(self):
- _colormap.LinearNormalization.__init__(self)
- _LinearNormalizationMixIn.__init__(self)
-
-
-class _LogarithmicNormalization(_colormap.LogarithmicNormalization, _NormalizationMixIn):
- """Logarithm normalization"""
-
- DEFAULT_RANGE = 1, 10
-
- def __init__(self):
- _colormap.LogarithmicNormalization.__init__(self)
- _NormalizationMixIn.__init__(self)
-
- def isValid(self, value):
- return value > 0.
-
- def autoscaleMinMax(self, data):
- result = min_max(data, min_positive=True, finite=True)
- return result.min_positive, result.maximum
-
-
-class _SqrtNormalization(_colormap.SqrtNormalization, _NormalizationMixIn):
- """Square root normalization"""
-
- DEFAULT_RANGE = 0, 1
-
- def __init__(self):
- _colormap.SqrtNormalization.__init__(self)
- _NormalizationMixIn.__init__(self)
-
- def isValid(self, value):
- return value >= 0.
-
-
-class _GammaNormalization(_colormap.PowerNormalization, _LinearNormalizationMixIn):
- """Gamma correction normalization:
-
- Linear normalization to [0, 1] followed by power normalization.
-
- :param gamma: Gamma correction factor
- """
- def __init__(self, gamma):
- _colormap.PowerNormalization.__init__(self, gamma)
- _LinearNormalizationMixIn.__init__(self)
-
-
-class _ArcsinhNormalization(_colormap.ArcsinhNormalization, _NormalizationMixIn):
- """Inverse hyperbolic sine normalization"""
-
- def __init__(self):
- _colormap.ArcsinhNormalization.__init__(self)
- _NormalizationMixIn.__init__(self)
-
-
-class Colormap(qt.QObject):
- """Description of a colormap
-
- If no `name` nor `colors` are provided, a default gray LUT is used.
-
- :param str name: Name of the colormap
- :param tuple colors: optional, custom colormap.
- Nx3 or Nx4 numpy array of RGB(A) colors,
- either uint8 or float in [0, 1].
- If 'name' is None, then this array is used as the colormap.
- :param str normalization: Normalization: 'linear' (default) or 'log'
- :param vmin: Lower bound of the colormap or None for autoscale (default)
- :type vmin: Union[None, float]
- :param vmax: Upper bounds of the colormap or None for autoscale (default)
- :type vmax: Union[None, float]
- """
-
- LINEAR = 'linear'
- """constant for linear normalization"""
-
- LOGARITHM = 'log'
- """constant for logarithmic normalization"""
-
- SQRT = 'sqrt'
- """constant for square root normalization"""
-
- GAMMA = 'gamma'
- """Constant for gamma correction normalization"""
-
- ARCSINH = 'arcsinh'
- """constant for inverse hyperbolic sine normalization"""
-
- _BASIC_NORMALIZATIONS = {
- LINEAR: _LinearNormalization(),
- LOGARITHM: _LogarithmicNormalization(),
- SQRT: _SqrtNormalization(),
- ARCSINH: _ArcsinhNormalization(),
- }
- """Normalizations without parameters"""
-
- NORMALIZATIONS = LINEAR, LOGARITHM, SQRT, GAMMA, ARCSINH
- """Tuple of managed normalizations"""
-
- MINMAX = 'minmax'
- """constant for autoscale using min/max data range"""
-
- STDDEV3 = 'stddev3'
- """constant for autoscale using mean +/- 3*std(data)
- with a clamp on min/max of the data"""
-
- AUTOSCALE_MODES = (MINMAX, STDDEV3)
- """Tuple of managed auto scale algorithms"""
-
- sigChanged = qt.Signal()
- """Signal emitted when the colormap has changed."""
-
- _DEFAULT_NAN_COLOR = 255, 255, 255, 0
-
- def __init__(self, name=None, colors=None, normalization=LINEAR, vmin=None, vmax=None, autoscaleMode=MINMAX):
- qt.QObject.__init__(self)
- self._editable = True
- self.__gamma = 2.0
- # Default NaN color: fully transparent white
- self.__nanColor = numpy.array(self._DEFAULT_NAN_COLOR, dtype=numpy.uint8)
-
- assert normalization in Colormap.NORMALIZATIONS
- assert autoscaleMode in Colormap.AUTOSCALE_MODES
-
- if normalization is Colormap.LOGARITHM:
- if (vmin is not None and vmin < 0) or (vmax is not None and vmax < 0):
- m = "Unsuported vmin (%s) and/or vmax (%s) given for a log scale."
- m += ' Autoscale will be performed.'
- m = m % (vmin, vmax)
- _logger.warning(m)
- vmin = None
- vmax = None
-
- self._name = None
- self._colors = None
-
- if colors is not None and name is not None:
- deprecation.deprecated_warning("Argument",
- name="silx.gui.plot.Colors",
- reason="name and colors can't be used at the same time",
- since_version="0.10.0",
- skip_backtrace_count=1)
-
- colors = None
-
- if name is not None:
- self.setName(name) # And resets colormap LUT
- elif colors is not None:
- self.setColormapLUT(colors)
- else:
- # Default colormap is grey
- self.setName("gray")
-
- self._normalization = str(normalization)
- self._autoscaleMode = str(autoscaleMode)
- self._vmin = float(vmin) if vmin is not None else None
- self._vmax = float(vmax) if vmax is not None else None
-
- def setFromColormap(self, other):
- """Set this colormap using information from the `other` colormap.
-
- :param ~silx.gui.colors.Colormap other: Colormap to use as reference.
- """
- if not self.isEditable():
- raise NotEditableError('Colormap is not editable')
- if self == other:
- return
- with blockSignals(self):
- name = other.getName()
- if name is not None:
- self.setName(name)
- else:
- self.setColormapLUT(other.getColormapLUT())
- self.setNaNColor(other.getNaNColor())
- self.setNormalization(other.getNormalization())
- self.setGammaNormalizationParameter(
- other.getGammaNormalizationParameter())
- self.setAutoscaleMode(other.getAutoscaleMode())
- self.setVRange(*other.getVRange())
- self.setEditable(other.isEditable())
- self.sigChanged.emit()
-
- def getNColors(self, nbColors=None):
- """Returns N colors computed by sampling the colormap regularly.
-
- :param nbColors:
- The number of colors in the returned array or None for the default value.
- The default value is the size of the colormap LUT.
- :type nbColors: int or None
- :return: 2D array of uint8 of shape (nbColors, 4)
- :rtype: numpy.ndarray
- """
- # Handle default value for nbColors
- if nbColors is None:
- return numpy.array(self._colors, copy=True)
- else:
- nbColors = int(nbColors)
- colormap = self.copy()
- colormap.setNormalization(Colormap.LINEAR)
- colormap.setVRange(vmin=0, vmax=nbColors - 1)
- colors = colormap.applyToData(
- numpy.arange(nbColors, dtype=numpy.int32))
- return colors
-
- def getName(self):
- """Return the name of the colormap
- :rtype: str
- """
- return self._name
-
- def setName(self, name):
- """Set the name of the colormap to use.
-
- :param str name: The name of the colormap.
- At least the following names are supported: 'gray',
- 'reversed gray', 'temperature', 'red', 'green', 'blue', 'jet',
- 'viridis', 'magma', 'inferno', 'plasma'.
- """
- name = str(name)
- if self._name == name:
- return
- if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- if name not in self.getSupportedColormaps():
- raise ValueError("Colormap name '%s' is not supported" % name)
- self._name = name
- self._colors = _getColormap(self._name)
- self.sigChanged.emit()
-
- def getColormapLUT(self, copy=True):
- """Return the list of colors for the colormap or None if not set.
-
- This returns None if the colormap was set with :meth:`setName`.
- Use :meth:`getNColors` to get the colormap LUT for any colormap.
-
- :param bool copy: If true a copy of the numpy array is provided
- :return: the list of colors for the colormap or None if not set
- :rtype: numpy.ndarray or None
- """
- if self._name is None:
- return numpy.array(self._colors, copy=copy)
- else:
- return None
-
- def setColormapLUT(self, colors):
- """Set the colors of the colormap.
-
- :param numpy.ndarray colors: the colors of the LUT.
- If float, it is converted from [0, 1] to uint8 range.
- Otherwise it is casted to uint8.
-
- .. warning: this will set the value of name to None
- """
- if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- assert colors is not None
-
- colors = numpy.array(colors, copy=False)
- if colors.shape == ():
- raise TypeError("An array is expected for 'colors' argument. '%s' was found." % type(colors))
- assert len(colors) != 0
- assert colors.ndim >= 2
- colors.shape = -1, colors.shape[-1]
- self._colors = _arrayToRgba8888(colors)
- self._name = None
- self.sigChanged.emit()
-
- def getNaNColor(self):
- """Returns the color to use for Not-A-Number floating point value.
-
- :rtype: QColor
- """
- return qt.QColor(*self.__nanColor)
-
- def setNaNColor(self, color):
- """Set the color to use for Not-A-Number floating point value.
-
- :param color: RGB(A) color to use for NaN values
- :type color: QColor, str, tuple of uint8 or float in [0., 1.]
- """
- color = (numpy.array(rgba(color)) * 255).astype(numpy.uint8)
- if not numpy.array_equal(self.__nanColor, color):
- self.__nanColor = color
- self.sigChanged.emit()
-
- def getNormalization(self):
- """Return the normalization of the colormap.
-
- See :meth:`setNormalization` for returned values.
-
- :return: the normalization of the colormap
- :rtype: str
- """
- return self._normalization
-
- def setNormalization(self, norm):
- """Set the colormap normalization.
-
- Accepted normalizations: 'log', 'linear', 'sqrt'
-
- :param str norm: the norm to set
- """
- assert norm in self.NORMALIZATIONS
- if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- self._normalization = str(norm)
- self.sigChanged.emit()
-
- def setGammaNormalizationParameter(self, gamma: float) -> None:
- """Set the gamma correction parameter.
-
- Only used for gamma correction normalization.
-
- :param float gamma:
- :raise ValueError: If gamma is not valid
- """
- if gamma < 0. or not numpy.isfinite(gamma):
- raise ValueError("Gamma value not supported")
- if gamma != self.__gamma:
- self.__gamma = gamma
- self.sigChanged.emit()
-
- def getGammaNormalizationParameter(self) -> float:
- """Returns the gamma correction parameter value.
-
- :rtype: float
- """
- return self.__gamma
-
- def getAutoscaleMode(self):
- """Return the autoscale mode of the colormap ('minmax' or 'stddev3')
-
- :rtype: str
- """
- return self._autoscaleMode
-
- def setAutoscaleMode(self, mode):
- """Set the autoscale mode: either 'minmax' or 'stddev3'
-
- :param str mode: the mode to set
- """
- if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- assert mode in self.AUTOSCALE_MODES
- if mode != self._autoscaleMode:
- self._autoscaleMode = mode
- self.sigChanged.emit()
-
- def isAutoscale(self):
- """Return True if both min and max are in autoscale mode"""
- return self._vmin is None and self._vmax is None
-
- def getVMin(self):
- """Return the lower bound of the colormap
-
- :return: the lower bound of the colormap
- :rtype: float or None
- """
- return self._vmin
-
- def setVMin(self, vmin):
- """Set the minimal value of the colormap
-
- :param float vmin: Lower bound of the colormap or None for autoscale
- (default)
- value)
- """
- if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- if vmin is not None:
- if self._vmax is not None and vmin > self._vmax:
- err = "Can't set vmin because vmin >= vmax. " \
- "vmin = %s, vmax = %s" % (vmin, self._vmax)
- raise ValueError(err)
-
- self._vmin = vmin
- self.sigChanged.emit()
-
- def getVMax(self):
- """Return the upper bounds of the colormap or None
-
- :return: the upper bounds of the colormap or None
- :rtype: float or None
- """
- return self._vmax
-
- def setVMax(self, vmax):
- """Set the maximal value of the colormap
-
- :param float vmax: Upper bounds of the colormap or None for autoscale
- (default)
- """
- if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- if vmax is not None:
- if self._vmin is not None and vmax < self._vmin:
- err = "Can't set vmax because vmax <= vmin. " \
- "vmin = %s, vmax = %s" % (self._vmin, vmax)
- raise ValueError(err)
-
- self._vmax = vmax
- self.sigChanged.emit()
-
- def isEditable(self):
- """ Return if the colormap is editable or not
-
- :return: editable state of the colormap
- :rtype: bool
- """
- return self._editable
-
- def setEditable(self, editable):
- """
- Set the editable state of the colormap
-
- :param bool editable: is the colormap editable
- """
- assert type(editable) is bool
- self._editable = editable
- self.sigChanged.emit()
-
- def _getNormalizer(self):
- """Returns normalizer object"""
- normalization = self.getNormalization()
- if normalization == self.GAMMA:
- return _GammaNormalization(self.getGammaNormalizationParameter())
- else:
- return self._BASIC_NORMALIZATIONS[normalization]
-
- def _computeAutoscaleRange(self, data):
- """Compute the data range which will be used in autoscale mode.
-
- :param numpy.ndarray data: The data for which to compute the range
- :return: (vmin, vmax) range
- """
- return self._getNormalizer().autoscale(
- data, mode=self.getAutoscaleMode())
-
- def getColormapRange(self, data=None):
- """Return (vmin, vmax) the range of the colormap for the given data or item.
-
- :param Union[numpy.ndarray,~silx.gui.plot.items.ColormapMixIn] data:
- The data or item to use for autoscale bounds.
- :return: (vmin, vmax) corresponding to the colormap applied to data if provided.
- :rtype: tuple
- """
- vmin = self._vmin
- vmax = self._vmax
- assert vmin is None or vmax is None or vmin <= vmax # TODO handle this in setters
-
- normalizer = self._getNormalizer()
-
- # Handle invalid bounds as autoscale
- if vmin is not None and not normalizer.isValid(vmin):
- _logger.info(
- 'Invalid vmin, switching to autoscale for lower bound')
- vmin = None
- if vmax is not None and not normalizer.isValid(vmax):
- _logger.info(
- 'Invalid vmax, switching to autoscale for upper bound')
- vmax = None
-
- if vmin is None or vmax is None: # Handle autoscale
- from .plot.items.core import ColormapMixIn # avoid cyclic import
- if isinstance(data, ColormapMixIn):
- min_, max_ = data._getColormapAutoscaleRange(self)
- # Make sure min_, max_ are not None
- min_ = normalizer.DEFAULT_RANGE[0] if min_ is None else min_
- max_ = normalizer.DEFAULT_RANGE[1] if max_ is None else max_
- else:
- min_, max_ = normalizer.autoscale(
- data, mode=self.getAutoscaleMode())
-
- if vmin is None: # Set vmin respecting provided vmax
- vmin = min_ if vmax is None else min(min_, vmax)
-
- if vmax is None:
- vmax = max(max_, vmin) # Handle max_ <= 0 for log scale
-
- return vmin, vmax
-
- def getVRange(self):
- """Get the bounds of the colormap
-
- :rtype: Tuple(Union[float,None],Union[float,None])
- :returns: A tuple of 2 values for min and max. Or None instead of float
- for autoscale
- """
- return self.getVMin(), self.getVMax()
-
- def setVRange(self, vmin, vmax):
- """Set the bounds of the colormap
-
- :param vmin: Lower bound of the colormap or None for autoscale
- (default)
- :param vmax: Upper bounds of the colormap or None for autoscale
- (default)
- """
- if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- if vmin is not None and vmax is not None:
- if vmin > vmax:
- err = "Can't set vmin and vmax because vmin >= vmax " \
- "vmin = %s, vmax = %s" % (vmin, vmax)
- raise ValueError(err)
-
- if self._vmin == vmin and self._vmax == vmax:
- return
-
- self._vmin = vmin
- self._vmax = vmax
- self.sigChanged.emit()
-
- def __getitem__(self, item):
- if item == 'autoscale':
- return self.isAutoscale()
- elif item == 'name':
- return self.getName()
- elif item == 'normalization':
- return self.getNormalization()
- elif item == 'vmin':
- return self.getVMin()
- elif item == 'vmax':
- return self.getVMax()
- elif item == 'colors':
- return self.getColormapLUT()
- elif item == 'autoscaleMode':
- return self.getAutoscaleMode()
- else:
- raise KeyError(item)
-
- def _toDict(self):
- """Return the equivalent colormap as a dictionary
- (old colormap representation)
-
- :return: the representation of the Colormap as a dictionary
- :rtype: dict
- """
- return {
- 'name': self._name,
- 'colors': self.getColormapLUT(),
- 'vmin': self._vmin,
- 'vmax': self._vmax,
- 'autoscale': self.isAutoscale(),
- 'normalization': self.getNormalization(),
- 'autoscaleMode': self.getAutoscaleMode(),
- }
-
- def _setFromDict(self, dic):
- """Set values to the colormap from a dictionary
-
- :param dict dic: the colormap as a dictionary
- """
- if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- name = dic['name'] if 'name' in dic else None
- colors = dic['colors'] if 'colors' in dic else None
- if name is not None and colors is not None:
- if isinstance(colors, int):
- # Filter out argument which was supported but never used
- _logger.info("Unused 'colors' from colormap dictionary filterer.")
- colors = None
- vmin = dic['vmin'] if 'vmin' in dic else None
- vmax = dic['vmax'] if 'vmax' in dic else None
- if 'normalization' in dic:
- normalization = dic['normalization']
- else:
- warn = 'Normalization not given in the dictionary, '
- warn += 'set by default to ' + Colormap.LINEAR
- _logger.warning(warn)
- normalization = Colormap.LINEAR
-
- if name is None and colors is None:
- err = 'The colormap should have a name defined or a tuple of colors'
- raise ValueError(err)
- if normalization not in Colormap.NORMALIZATIONS:
- err = 'Given normalization is not recognized (%s)' % normalization
- raise ValueError(err)
-
- autoscaleMode = dic.get('autoscaleMode', Colormap.MINMAX)
- if autoscaleMode not in Colormap.AUTOSCALE_MODES:
- err = 'Given autoscale mode is not recognized (%s)' % autoscaleMode
- raise ValueError(err)
-
- # If autoscale, then set boundaries to None
- if dic.get('autoscale', False):
- vmin, vmax = None, None
-
- if name is not None:
- self.setName(name)
- else:
- self.setColormapLUT(colors)
- self._vmin = vmin
- self._vmax = vmax
- self._autoscale = True if (vmin is None and vmax is None) else False
- self._normalization = normalization
- self._autoscaleMode = autoscaleMode
-
- self.sigChanged.emit()
-
- @staticmethod
- def _fromDict(dic):
- colormap = Colormap()
- colormap._setFromDict(dic)
- return colormap
-
- def copy(self):
- """Return a copy of the Colormap.
-
- :rtype: silx.gui.colors.Colormap
- """
- colormap = Colormap(name=self._name,
- colors=self.getColormapLUT(),
- vmin=self._vmin,
- vmax=self._vmax,
- normalization=self.getNormalization(),
- autoscaleMode=self.getAutoscaleMode())
- colormap.setNaNColor(self.getNaNColor())
- colormap.setGammaNormalizationParameter(
- self.getGammaNormalizationParameter())
- colormap.setEditable(self.isEditable())
- return colormap
-
- def applyToData(self, data, reference=None):
- """Apply the colormap to the data
-
- :param Union[numpy.ndarray,~silx.gui.plot.item.ColormapMixIn] data:
- The data to convert or the item for which to apply the colormap.
- :param Union[numpy.ndarray,~silx.gui.plot.item.ColormapMixIn,None] reference:
- The data or item to use as reference to compute autoscale
- """
- if reference is None:
- reference = data
- vmin, vmax = self.getColormapRange(reference)
-
- if hasattr(data, "getColormappedData"): # Use item's data
- data = data.getColormappedData(copy=False)
-
- return _colormap.cmap(
- data,
- self._colors,
- vmin,
- vmax,
- self._getNormalizer(),
- self.__nanColor)
-
- @staticmethod
- def getSupportedColormaps():
- """Get the supported colormap names as a tuple of str.
-
- The list should at least contain and start by:
-
- ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue',
- 'viridis', 'magma', 'inferno', 'plasma')
-
- :rtype: tuple
- """
- colormaps = set()
- if _matplotlib_colormaps is not None:
- colormaps.update(_matplotlib_colormaps())
- colormaps.update(_AVAILABLE_LUTS.keys())
-
- colormaps = tuple(cmap for cmap in sorted(colormaps)
- if cmap not in _AVAILABLE_LUTS.keys())
-
- return tuple(_AVAILABLE_LUTS.keys()) + colormaps
-
- def __str__(self):
- return str(self._toDict())
-
- def __eq__(self, other):
- """Compare colormap values and not pointers"""
- if other is None:
- return False
- if not isinstance(other, Colormap):
- return False
- if self.getNormalization() != other.getNormalization():
- return False
- if self.getNormalization() == self.GAMMA:
- delta = self.getGammaNormalizationParameter() - other.getGammaNormalizationParameter()
- if abs(delta) > 0.001:
- return False
- return (self.getName() == other.getName() and
- self.getAutoscaleMode() == other.getAutoscaleMode() and
- self.getVMin() == other.getVMin() and
- self.getVMax() == other.getVMax() and
- numpy.array_equal(self.getColormapLUT(), other.getColormapLUT())
- )
-
- _SERIAL_VERSION = 3
-
- def restoreState(self, byteArray):
- """
- Read the colormap state from a QByteArray.
-
- :param qt.QByteArray byteArray: Stream containing the state
- :return: True if the restoration sussseed
- :rtype: bool
- """
- if self.isEditable() is False:
- raise NotEditableError('Colormap is not editable')
- stream = qt.QDataStream(byteArray, qt.QIODevice.ReadOnly)
-
- className = stream.readQString()
- if className != self.__class__.__name__:
- _logger.warning("Classname mismatch. Found %s." % className)
- return False
-
- version = stream.readUInt32()
- if version not in numpy.arange(1, self._SERIAL_VERSION+1):
- _logger.warning("Serial version mismatch. Found %d." % version)
- return False
-
- name = stream.readQString()
- isNull = stream.readBool()
- if not isNull:
- vmin = stream.readQVariant()
- else:
- vmin = None
- isNull = stream.readBool()
- if not isNull:
- vmax = stream.readQVariant()
- else:
- vmax = None
-
- normalization = stream.readQString()
- if normalization == Colormap.GAMMA:
- gamma = stream.readFloat()
- else:
- gamma = None
-
- if version == 1:
- autoscaleMode = Colormap.MINMAX
- else:
- autoscaleMode = stream.readQString()
-
- if version <= 2:
- nanColor = self._DEFAULT_NAN_COLOR
- else:
- nanColor = stream.readInt32(), stream.readInt32(), stream.readInt32(), stream.readInt32()
-
- # emit change event only once
- old = self.blockSignals(True)
- try:
- self.setName(name)
- self.setNormalization(normalization)
- self.setAutoscaleMode(autoscaleMode)
- self.setVRange(vmin, vmax)
- if gamma is not None:
- self.setGammaNormalizationParameter(gamma)
- self.setNaNColor(nanColor)
- finally:
- self.blockSignals(old)
- self.sigChanged.emit()
- return True
-
- def saveState(self):
- """
- Save state of the colomap into a QDataStream.
-
- :rtype: qt.QByteArray
- """
- data = qt.QByteArray()
- stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
-
- stream.writeQString(self.__class__.__name__)
- stream.writeUInt32(self._SERIAL_VERSION)
- stream.writeQString(self.getName())
- stream.writeBool(self.getVMin() is None)
- if self.getVMin() is not None:
- stream.writeQVariant(self.getVMin())
- stream.writeBool(self.getVMax() is None)
- if self.getVMax() is not None:
- stream.writeQVariant(self.getVMax())
- stream.writeQString(self.getNormalization())
- if self.getNormalization() == Colormap.GAMMA:
- stream.writeFloat(self.getGammaNormalizationParameter())
- stream.writeQString(self.getAutoscaleMode())
- nanColor = self.getNaNColor()
- stream.writeInt32(nanColor.red())
- stream.writeInt32(nanColor.green())
- stream.writeInt32(nanColor.blue())
- stream.writeInt32(nanColor.alpha())
-
- return data
-
-
-_PREFERRED_COLORMAPS = None
-"""
-Tuple of preferred colormap names accessed with :meth:`preferredColormaps`.
-"""
-
-
-def preferredColormaps():
- """Returns the name of the preferred colormaps.
-
- This list is used by widgets allowing to change the colormap
- like the :class:`ColormapDialog` as a subset of colormap choices.
-
- :rtype: tuple of str
- """
- global _PREFERRED_COLORMAPS
- if _PREFERRED_COLORMAPS is None:
- # Initialize preferred colormaps
- default_preferred = []
- for name, info in _AVAILABLE_LUTS.items():
- if (info.preferred and
- (info.source != 'matplotlib' or _matplotlib_cm is not None)):
- default_preferred.append(name)
- setPreferredColormaps(default_preferred)
- return tuple(_PREFERRED_COLORMAPS)
-
-
-def setPreferredColormaps(colormaps):
- """Set the list of preferred colormap names.
-
- Warning: If a colormap name is not available
- it will be removed from the list.
-
- :param colormaps: Not empty list of colormap names
- :type colormaps: iterable of str
- :raise ValueError: if the list of available preferred colormaps is empty.
- """
- supportedColormaps = Colormap.getSupportedColormaps()
- colormaps = [cmap for cmap in colormaps if cmap in supportedColormaps]
- if len(colormaps) == 0:
- raise ValueError("Cannot set preferred colormaps to an empty list")
-
- global _PREFERRED_COLORMAPS
- _PREFERRED_COLORMAPS = colormaps
-
-
-def registerLUT(name, colors, cursor_color='black', preferred=True):
- """Register a custom LUT to be used with `Colormap` objects.
-
- It can override existing LUT names.
-
- :param str name: Name of the LUT as defined to configure colormaps
- :param numpy.ndarray colors: The custom LUT to register.
- Nx3 or Nx4 numpy array of RGB(A) colors,
- either uint8 or float in [0, 1].
- :param bool preferred: If true, this LUT will be displayed as part of the
- preferred colormaps in dialogs.
- :param str cursor_color: Color used to display overlay over images using
- colormap with this LUT.
- """
- description = _LUT_DESCRIPTION('user', cursor_color, preferred=preferred)
- colors = _arrayToRgba8888(colors)
- _AVAILABLE_LUTS[name] = description
-
- if preferred:
- # Invalidate the preferred cache
- global _PREFERRED_COLORMAPS
- if _PREFERRED_COLORMAPS is not None:
- if name not in _PREFERRED_COLORMAPS:
- _PREFERRED_COLORMAPS.append(name)
- else:
- # The cache is not yet loaded, it's fine
- pass
-
- # Register the cache as the LUT was already loaded
- _COLORMAP_CACHE[name] = colors
diff --git a/silx/gui/console.py b/silx/gui/console.py
deleted file mode 100644
index 5dc6336..0000000
--- a/silx/gui/console.py
+++ /dev/null
@@ -1,202 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides an IPython console widget.
-
-You can push variables - any python object - to the
-console's interactive namespace. This provides users with an advanced way
-of interacting with your program. For instance, if your program has a
-:class:`PlotWidget` or a :class:`PlotWindow`, you can push a reference to
-these widgets to allow your users to add curves, save data to files… by using
-the widgets' methods from the console.
-
-.. note::
-
- This module has a dependency on
- `qtconsole <https://pypi.org/project/qtconsole/>`_.
- An ``ImportError`` will be raised if it is
- imported while the dependencies are not satisfied.
-
-Basic usage example::
-
- from silx.gui import qt
- from silx.gui.console import IPythonWidget
-
- app = qt.QApplication([])
-
- hello_button = qt.QPushButton("Hello World!", None)
- hello_button.show()
-
- console = IPythonWidget()
- console.show()
- console.pushVariables({"the_button": hello_button})
-
- app.exec_()
-
-This program will display a console widget and a push button in two separate
-windows. You will be able to interact with the button from the console,
-for example change its text::
-
- >>> the_button.setText("Spam spam")
-
-An IPython interactive console is a powerful tool that enables you to work
-with data and plot it.
-See `this tutorial <https://plot.ly/python/ipython-notebook-tutorial/>`_
-for more information on some of the rich features of IPython.
-"""
-__authors__ = ["Tim Rae", "V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "24/05/2016"
-
-import logging
-
-from . import qt
-
-_logger = logging.getLogger(__name__)
-
-
-# This widget cannot be used inside an interactive IPython shell.
-# It would raise MultipleInstanceError("Multiple incompatible subclass
-# instances of InProcessInteractiveShell are being created").
-try:
- __IPYTHON__
-except NameError:
- pass # Not in IPython
-else:
- msg = "Module " + __name__ + " cannot be used within an IPython shell"
- raise ImportError(msg)
-
-try:
- from qtconsole.rich_jupyter_widget import RichJupyterWidget as \
- _RichJupyterWidget
-except ImportError:
- try:
- from qtconsole.rich_ipython_widget import RichJupyterWidget as \
- _RichJupyterWidget
- except ImportError:
- from qtconsole.rich_ipython_widget import RichIPythonWidget as \
- _RichJupyterWidget
-
-from qtconsole.inprocess import QtInProcessKernelManager
-
-try:
- from ipykernel import version_info as _ipykernel_version_info
-except ImportError:
- _ipykernel_version_info = None
-
-
-class IPythonWidget(_RichJupyterWidget):
- """Live IPython console widget.
-
- .. image:: img/IPythonWidget.png
-
- :param custom_banner: Custom welcome message to be printed at the top of
- the console.
- """
-
- def __init__(self, parent=None, custom_banner=None, *args, **kwargs):
- if parent is not None:
- kwargs["parent"] = parent
- super(IPythonWidget, self).__init__(*args, **kwargs)
- if custom_banner is not None:
- self.banner = custom_banner
- self.setWindowTitle(self.banner)
- self.kernel_manager = kernel_manager = QtInProcessKernelManager()
- kernel_manager.start_kernel()
-
- # Monkey-patch to workaround issue:
- # https://github.com/ipython/ipykernel/issues/370
- if (_ipykernel_version_info is not None and
- _ipykernel_version_info[0] > 4 and
- _ipykernel_version_info[:3] <= (5, 1, 0)):
- def _abort_queues(*args, **kwargs):
- pass
- kernel_manager.kernel._abort_queues = _abort_queues
-
- self.kernel_client = kernel_client = self._kernel_manager.client()
- kernel_client.start_channels()
-
- def stop():
- kernel_client.stop_channels()
- kernel_manager.shutdown_kernel()
- self.exit_requested.connect(stop)
-
- def sizeHint(self):
- """Return a reasonable default size for usage in :class:`PlotWindow`"""
- return qt.QSize(500, 300)
-
- def pushVariables(self, variable_dict):
- """ Given a dictionary containing name / value pairs, push those
- variables to the IPython console widget.
-
- :param variable_dict: Dictionary of variables to be pushed to the
- console's interactive namespace (```{variable_name: object, …}```)
- """
- self.kernel_manager.kernel.shell.push(variable_dict)
-
-
-class IPythonDockWidget(qt.QDockWidget):
- """Dock Widget including a :class:`IPythonWidget` inside
- a vertical layout.
-
- .. image:: img/IPythonDockWidget.png
-
- :param available_vars: Dictionary of variables to be pushed to the
- console's interactive namespace: ``{"variable_name": object, …}``
- :param custom_banner: Custom welcome message to be printed at the top of
- the console
- :param title: Dock widget title
- :param parent: Parent :class:`qt.QMainWindow` containing this
- :class:`qt.QDockWidget`
- """
- def __init__(self, parent=None, available_vars=None, custom_banner=None,
- title="Console"):
- super(IPythonDockWidget, self).__init__(title, parent)
-
- self.ipyconsole = IPythonWidget(custom_banner=custom_banner)
-
- self.layout().setContentsMargins(0, 0, 0, 0)
- self.setWidget(self.ipyconsole)
-
- if available_vars is not None:
- self.ipyconsole.pushVariables(available_vars)
-
- def showEvent(self, event):
- """Make sure this widget is raised when it is shown
- (when it is first created as a tab in PlotWindow or when it is shown
- again after hiding).
- """
- self.raise_()
-
-
-def main():
- """Run a Qt app with an IPython console"""
- app = qt.QApplication([])
- widget = IPythonDockWidget()
- widget.show()
- app.exec_()
-
-
-if __name__ == '__main__':
- main()
diff --git a/silx/gui/data/ArrayTableModel.py b/silx/gui/data/ArrayTableModel.py
deleted file mode 100644
index b7bd9c4..0000000
--- a/silx/gui/data/ArrayTableModel.py
+++ /dev/null
@@ -1,670 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module defines a data model for displaying and editing arrays of any
-number of dimensions in a table view.
-"""
-from __future__ import division
-import numpy
-import logging
-from silx.gui import qt
-from silx.gui.data.TextFormatter import TextFormatter
-
-__authors__ = ["V.A. Sole"]
-__license__ = "MIT"
-__date__ = "27/09/2017"
-
-
-_logger = logging.getLogger(__name__)
-
-
-def _is_array(data):
- """Return True if object implements all necessary attributes to be used
- as a numpy array.
-
- :param object data: Array-like object (numpy array, h5py dataset...)
- :return: boolean
- """
- # add more required attribute if necessary
- for attr in ("shape", "dtype"):
- if not hasattr(data, attr):
- return False
- return True
-
-
-class ArrayTableModel(qt.QAbstractTableModel):
- """This data model provides access to 2D slices in a N-dimensional
- array.
-
- A slice for a 3-D array is characterized by a perspective (the number of
- the axis orthogonal to the slice) and an index at which the slice
- intersects the orthogonal axis.
-
- In the n-D case, only slices parallel to the last two axes are handled. A
- slice is therefore characterized by a list of indices locating the
- slice on all the :math:`n - 2` orthogonal axes.
-
- :param parent: Parent QObject
- :param data: Numpy array, or object implementing a similar interface
- (e.g. h5py dataset)
- :param str fmt: Format string for representing numerical values.
- Default is ``"%g"``.
- :param sequence[int] perspective: See documentation
- of :meth:`setPerspective`.
- """
-
- MAX_NUMBER_OF_SECTIONS = 10e6
- """Maximum number of displayed rows and columns"""
-
- def __init__(self, parent=None, data=None, perspective=None):
- qt.QAbstractTableModel.__init__(self, parent)
-
- self._array = None
- """n-dimensional numpy array"""
-
- self._bgcolors = None
- """(n+1)-dimensional numpy array containing RGB(A) color data
- for the background color
- """
-
- self._fgcolors = None
- """(n+1)-dimensional numpy array containing RGB(A) color data
- for the foreground color
- """
-
- self._formatter = None
- """Formatter for text representation of data"""
-
- formatter = TextFormatter(self)
- formatter.setUseQuoteForText(False)
- self.setFormatter(formatter)
-
- self._index = None
- """This attribute stores the slice index, as a list of indices
- where the frame intersects orthogonal axis."""
-
- self._perspective = None
- """Sequence of dimensions orthogonal to the frame to be viewed.
- For an array with ``n`` dimensions, this is a sequence of ``n-2``
- integers. the first dimension is numbered ``0``.
- By default, the data frames use the last two dimensions as their axes
- and therefore the perspective is a sequence of the first ``n-2``
- dimensions.
- For example, for a 5-D array, the default perspective is ``(0, 1, 2)``
- and the default frames axes are ``(3, 4)``."""
-
- # set _data and _perspective
- self.setArrayData(data, perspective=perspective)
-
- def _getRowDim(self):
- """The row axis is the first axis parallel to the frames
- (lowest dimension number)
-
- Return None for 0-D (scalar) or 1-D arrays
- """
- n_dimensions = len(self._array.shape)
- if n_dimensions < 2:
- # scalar or 1D array: no row index
- return None
- # take all dimensions and remove the orthogonal ones
- frame_axes = set(range(0, n_dimensions)) - set(self._perspective)
- # sanity check
- assert len(frame_axes) == 2
- return min(frame_axes)
-
- def _getColumnDim(self):
- """The column axis is the second (highest dimension) axis parallel
- to the frames
-
- Return None for 0-D (scalar)
- """
- n_dimensions = len(self._array.shape)
- if n_dimensions < 1:
- # scalar: no column index
- return None
- frame_axes = set(range(0, n_dimensions)) - set(self._perspective)
- # sanity check
- assert (len(frame_axes) == 2) if n_dimensions > 1 else (len(frame_axes) == 1)
- return max(frame_axes)
-
- def _getIndexTuple(self, table_row, table_col):
- """Return the n-dimensional index of a value in the original array,
- based on its row and column indices in the table view
-
- :param table_row: Row index (0-based) of a table cell
- :param table_col: Column index (0-based) of a table cell
- :return: Tuple of indices of the element in the numpy array
- """
- row_dim = self._getRowDim()
- col_dim = self._getColumnDim()
-
- # get indices on all orthogonal axes
- selection = list(self._index)
- # insert indices on parallel axes
- if row_dim is not None:
- selection.insert(row_dim, table_row)
- if col_dim is not None:
- selection.insert(col_dim, table_col)
- return tuple(selection)
-
- # Methods to be implemented to subclass QAbstractTableModel
- def rowCount(self, parent_idx=None):
- """QAbstractTableModel method
- Return number of rows to be displayed in table"""
- row_dim = self._getRowDim()
- if row_dim is None:
- # 0-D and 1-D arrays
- return 1
- return min(self._array.shape[row_dim], self.MAX_NUMBER_OF_SECTIONS)
-
- def columnCount(self, parent_idx=None):
- """QAbstractTableModel method
- Return number of columns to be displayed in table"""
- col_dim = self._getColumnDim()
- if col_dim is None:
- # 0-D array
- return 1
- return min(self._array.shape[col_dim], self.MAX_NUMBER_OF_SECTIONS)
-
- def __isClipped(self, orientation=qt.Qt.Vertical) -> bool:
- """Returns whether or not array is clipped in a given orientation"""
- if orientation == qt.Qt.Vertical:
- dim = self._getRowDim()
- else:
- dim = self._getColumnDim()
- return (dim is not None and
- self._array.shape[dim] > self.MAX_NUMBER_OF_SECTIONS)
-
- def __isClippedIndex(self, index) -> bool:
- """Returns whether or not index's cell represents clipped data."""
- if not index.isValid():
- return False
- if index.row() == self.MAX_NUMBER_OF_SECTIONS - 2:
- return self.__isClipped(qt.Qt.Vertical)
- if index.column() == self.MAX_NUMBER_OF_SECTIONS - 2:
- return self.__isClipped(qt.Qt.Horizontal)
- return False
-
- def __clippedData(self, role=qt.Qt.DisplayRole):
- """Return data for cells representing clipped data"""
- if role == qt.Qt.DisplayRole:
- return "..."
- elif role == qt.Qt.ToolTipRole:
- return "Dataset is too large: display is clipped"
- else:
- return None
-
- def data(self, index, role=qt.Qt.DisplayRole):
- """QAbstractTableModel method to access data values
- in the format ready to be displayed"""
- if index.isValid():
- if self.__isClippedIndex(index): # Special displayed for clipped data
- return self.__clippedData(role)
-
- row, column = index.row(), index.column()
-
- # When clipped, display last data of the array in last column of the table
- if (self.__isClipped(qt.Qt.Vertical) and
- row == self.MAX_NUMBER_OF_SECTIONS - 1):
- row = self._array.shape[self._getRowDim()] - 1
- if (self.__isClipped(qt.Qt.Horizontal) and
- column == self.MAX_NUMBER_OF_SECTIONS - 1):
- column = self._array.shape[self._getColumnDim()] - 1
-
- selection = self._getIndexTuple(row, column)
-
- if role == qt.Qt.DisplayRole:
- return self._formatter.toString(self._array[selection], self._array.dtype)
-
- if role == qt.Qt.BackgroundRole and self._bgcolors is not None:
- r, g, b = self._bgcolors[selection][0:3]
- if self._bgcolors.shape[-1] == 3:
- return qt.QColor(r, g, b)
- if self._bgcolors.shape[-1] == 4:
- a = self._bgcolors[selection][3]
- return qt.QColor(r, g, b, a)
-
- if role == qt.Qt.ForegroundRole:
- if self._fgcolors is not None:
- r, g, b = self._fgcolors[selection][0:3]
- if self._fgcolors.shape[-1] == 3:
- return qt.QColor(r, g, b)
- if self._fgcolors.shape[-1] == 4:
- a = self._fgcolors[selection][3]
- return qt.QColor(r, g, b, a)
-
- # no fg color given, use black or white
- # based on luminosity threshold
- elif self._bgcolors is not None:
- r, g, b = self._bgcolors[selection][0:3]
- lum = 0.21 * r + 0.72 * g + 0.07 * b
- if lum < 128:
- return qt.QColor(qt.Qt.white)
- else:
- return qt.QColor(qt.Qt.black)
-
- def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
- """QAbstractTableModel method
- Return the 0-based row or column index, for display in the
- horizontal and vertical headers"""
- if self.__isClipped(orientation): # Header is clipped
- if section == self.MAX_NUMBER_OF_SECTIONS - 2:
- # Represent clipped data
- return self.__clippedData(role)
-
- elif section == self.MAX_NUMBER_OF_SECTIONS - 1:
- # Display last index from data not table
- if role == qt.Qt.DisplayRole:
- if orientation == qt.Qt.Vertical:
- dim = self._getRowDim()
- else:
- dim = self._getColumnDim()
- return str(self._array.shape[dim] - 1)
- else:
- return None
-
- if role == qt.Qt.DisplayRole:
- return "%d" % section
- return None
-
- def flags(self, index):
- """QAbstractTableModel method to inform the view whether data
- is editable or not."""
- if not self._editable or self.__isClippedIndex(index):
- return qt.QAbstractTableModel.flags(self, index)
- return qt.QAbstractTableModel.flags(self, index) | qt.Qt.ItemIsEditable
-
- def setData(self, index, value, role=None):
- """QAbstractTableModel method to handle editing data.
- Cast the new value into the same format as the array before editing
- the array value."""
- if index.isValid() and role == qt.Qt.EditRole:
- try:
- # cast value to same type as array
- v = numpy.array(value, dtype=self._array.dtype).item()
- except ValueError:
- return False
-
- selection = self._getIndexTuple(index.row(),
- index.column())
- self._array[selection] = v
- self.dataChanged.emit(index, index)
- return True
- else:
- return False
-
- # Public methods
- def setArrayData(self, data, copy=True,
- perspective=None, editable=False):
- """Set the data array and the viewing perspective.
-
- You can set ``copy=False`` if you need more performances, when dealing
- with a large numpy array. In this case, a simple reference to the data
- is used to access the data, rather than a copy of the array.
-
- .. warning::
-
- Any change to the data model will affect your original data
- array, when using a reference rather than a copy..
-
- :param data: n-dimensional numpy array, or any object that can be
- converted to a numpy array using ``numpy.array(data)`` (e.g.
- a nested sequence).
- :param bool copy: If *True* (default), a copy of the array is stored
- and the original array is not modified if the table is edited.
- If *False*, then the behavior depends on the data type:
- if possible (if the original array is a proper numpy array)
- a reference to the original array is used.
- :param perspective: See documentation of :meth:`setPerspective`.
- If None, the default perspective is the list of the first ``n-2``
- dimensions, to view frames parallel to the last two axes.
- :param bool editable: Flag to enable editing data. Default *False*.
- """
- if qt.qVersion() > "4.6":
- self.beginResetModel()
- else:
- self.reset()
-
- if data is None:
- # empty array
- self._array = numpy.array([])
- elif copy:
- # copy requested (default)
- self._array = numpy.array(data, copy=True)
- if hasattr(data, "dtype"):
- # Avoid to lose the monkey-patched h5py dtype
- self._array.dtype = data.dtype
- elif not _is_array(data):
- raise TypeError("data is not a proper array. Try setting" +
- " copy=True to convert it into a numpy array" +
- " (this will cause the data to be copied!)")
- # # copy not requested, but necessary
- # _logger.warning(
- # "data is not an array-like object. " +
- # "Data must be copied.")
- # self._array = numpy.array(data, copy=True)
- else:
- # Copy explicitly disabled & data implements required attributes.
- # We can use a reference.
- self._array = data
-
- # reset colors to None if new data shape is inconsistent
- valid_color_shapes = (self._array.shape + (3,),
- self._array.shape + (4,))
- if self._bgcolors is not None:
- if self._bgcolors.shape not in valid_color_shapes:
- self._bgcolors = None
- if self._fgcolors is not None:
- if self._fgcolors.shape not in valid_color_shapes:
- self._fgcolors = None
-
- self.setEditable(editable)
-
- self._index = [0 for _i in range((len(self._array.shape) - 2))]
- self._perspective = tuple(perspective) if perspective is not None else\
- tuple(range(0, len(self._array.shape) - 2))
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
-
- def setArrayColors(self, bgcolors=None, fgcolors=None):
- """Set the colors for all table cells by passing an array
- of RGB or RGBA values (integers between 0 and 255).
-
- The shape of the colors array must be consistent with the data shape.
-
- If the data array is n-dimensional, the colors array must be
- (n+1)-dimensional, with the first n-dimensions identical to the data
- array dimensions, and the last dimension length-3 (RGB) or
- length-4 (RGBA).
-
- :param bgcolors: RGB or RGBA colors array, defining the background color
- for each cell in the table.
- :param fgcolors: RGB or RGBA colors array, defining the foreground color
- (text color) for each cell in the table.
- """
- # array must be RGB or RGBA
- valid_shapes = (self._array.shape + (3,), self._array.shape + (4,))
- errmsg = "Inconsistent shape for color array, should be %s or %s" % valid_shapes
-
- if bgcolors is not None:
- if not _is_array(bgcolors):
- bgcolors = numpy.array(bgcolors)
- assert bgcolors.shape in valid_shapes, errmsg
-
- self._bgcolors = bgcolors
-
- if fgcolors is not None:
- if not _is_array(fgcolors):
- fgcolors = numpy.array(fgcolors)
- assert fgcolors.shape in valid_shapes, errmsg
-
- self._fgcolors = fgcolors
-
- def setEditable(self, editable):
- """Set flags to make the data editable.
-
- .. warning::
-
- If the data is a reference to a h5py dataset open in read-only
- mode, setting *editable=True* will fail and print a warning.
-
- .. warning::
-
- Making the data editable means that the underlying data structure
- in this data model will be modified.
- If the data is a reference to a public object (open with
- ``copy=False``), this could have side effects. If it is a
- reference to an HDF5 dataset, this means the file will be
- modified.
-
- :param bool editable: Flag to enable editing data.
- :return: True if setting desired flag succeeded, False if it failed.
- """
- self._editable = editable
- if hasattr(self._array, "file"):
- if hasattr(self._array.file, "mode"):
- if editable and self._array.file.mode == "r":
- _logger.warning(
- "Data is a HDF5 dataset open in read-only " +
- "mode. Editing must be disabled.")
- self._editable = False
- return False
- return True
-
- def getData(self, copy=True):
- """Return a copy of the data array, or a reference to it
- if *copy=False* is passed as parameter.
-
- In case the shape was modified, to convert 0-D or 1-D data
- into 2-D data, the original shape is restored in the returned data.
-
- :param bool copy: If *True* (default), return a copy of the data. If
- *False*, return a reference.
- :return: numpy array of data, or reference to original data object
- if *copy=False*
- """
- data = self._array if not copy else numpy.array(self._array, copy=True)
- return data
-
- def setFrameIndex(self, index):
- """Set the active slice index.
-
- This method is only relevant to arrays with at least 3 dimensions.
-
- :param index: Index of the active slice in the array.
- In the general n-D case, this is a sequence of :math:`n - 2`
- indices where the slice intersects the respective orthogonal axes.
- :raise IndexError: If any index in the index sequence is out of bound
- on its respective axis.
- """
- shape = self._array.shape
- if len(shape) < 3:
- # index is ignored
- return
-
- if qt.qVersion() > "4.6":
- self.beginResetModel()
- else:
- self.reset()
-
- if len(shape) == 3:
- len_ = shape[self._perspective[0]]
- # accept integers as index in the case of 3-D arrays
- if not hasattr(index, "__len__"):
- self._index = [index]
- else:
- self._index = index
- if not 0 <= self._index[0] < len_:
- raise ValueError("Index must be a positive integer " +
- "lower than %d" % len_)
- else:
- # general n-D case
- for i_, idx in enumerate(index):
- if not 0 <= idx < shape[self._perspective[i_]]:
- raise IndexError("Invalid index %d " % idx +
- "not in range 0-%d" % (shape[i_] - 1))
- self._index = index
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
-
- def setFormatter(self, formatter):
- """Set the formatter object to be used to display data from the model
-
- :param TextFormatter formatter: Formatter to use
- """
- if formatter is self._formatter:
- return
-
- if qt.qVersion() > "4.6":
- self.beginResetModel()
-
- if self._formatter is not None:
- self._formatter.formatChanged.disconnect(self.__formatChanged)
-
- self._formatter = formatter
- if self._formatter is not None:
- self._formatter.formatChanged.connect(self.__formatChanged)
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
- else:
- self.reset()
-
- def getFormatter(self):
- """Returns the text formatter used.
-
- :rtype: TextFormatter
- """
- return self._formatter
-
- def __formatChanged(self):
- """Called when the format changed.
- """
- self.reset()
-
- def setPerspective(self, perspective):
- """Set the perspective by defining a sequence listing all axes
- orthogonal to the frame or 2-D slice to be visualized.
-
- Alternatively, you can use :meth:`setFrameAxes` for the complementary
- approach of specifying the two axes parallel to the frame.
-
- In the 1-D or 2-D case, this parameter is irrelevant.
-
- In the 3-D case, if the unit vectors describing
- your axes are :math:`\vec{x}, \vec{y}, \vec{z}`, a perspective of 0
- means you slices are parallel to :math:`\vec{y}\vec{z}`, 1 means they
- are parallel to :math:`\vec{x}\vec{z}` and 2 means they
- are parallel to :math:`\vec{x}\vec{y}`.
-
- In the n-D case, this parameter is a sequence of :math:`n-2` axes
- numbers.
- For instance if you want to display 2-D frames whose axes are the
- second and third dimensions of a 5-D array, set the perspective to
- ``(0, 3, 4)``.
-
- :param perspective: Sequence of dimensions/axes orthogonal to the
- frames.
- :raise: IndexError if any value in perspective is higher than the
- number of dimensions minus one (first dimension is 0), or
- if the number of values is different from the number of dimensions
- minus two.
- """
- n_dimensions = len(self._array.shape)
- if n_dimensions < 3:
- _logger.warning(
- "perspective is not relevant for 1D and 2D arrays")
- return
-
- if not hasattr(perspective, "__len__"):
- # we can tolerate an integer for 3-D array
- if n_dimensions == 3:
- perspective = [perspective]
- else:
- raise ValueError("perspective must be a sequence of integers")
-
- # ensure unicity of dimensions in perspective
- perspective = tuple(set(perspective))
-
- if len(perspective) != n_dimensions - 2 or\
- min(perspective) < 0 or max(perspective) >= n_dimensions:
- raise IndexError(
- "Invalid perspective " + str(perspective) +
- " for %d-D array " % n_dimensions +
- "with shape " + str(self._array.shape))
-
- if qt.qVersion() > "4.6":
- self.beginResetModel()
- else:
- self.reset()
-
- self._perspective = perspective
-
- # reset index
- self._index = [0 for _i in range(n_dimensions - 2)]
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
-
- def setFrameAxes(self, row_axis, col_axis):
- """Set the perspective by specifying the two axes parallel to the frame
- to be visualised.
-
- The complementary approach of defining the orthogonal axes can be used
- with :meth:`setPerspective`.
-
- :param int row_axis: Index (0-based) of the first dimension used as a frame
- axis
- :param int col_axis: Index (0-based) of the 2nd dimension used as a frame
- axis
- :raise: IndexError if axes are invalid
- """
- if row_axis > col_axis:
- _logger.warning("The dimension of the row axis must be lower " +
- "than the dimension of the column axis. Swapping.")
- row_axis, col_axis = min(row_axis, col_axis), max(row_axis, col_axis)
-
- n_dimensions = len(self._array.shape)
- if n_dimensions < 3:
- _logger.warning(
- "Frame axes cannot be changed for 1D and 2D arrays")
- return
-
- perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis})
-
- if len(perspective) != n_dimensions - 2 or\
- min(perspective) < 0 or max(perspective) >= n_dimensions:
- raise IndexError(
- "Invalid perspective " + str(perspective) +
- " for %d-D array " % n_dimensions +
- "with shape " + str(self._array.shape))
-
- if qt.qVersion() > "4.6":
- self.beginResetModel()
- else:
- self.reset()
-
- self._perspective = perspective
- # reset index
- self._index = [0 for _i in range(n_dimensions - 2)]
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
-
-
-if __name__ == "__main__":
- app = qt.QApplication([])
- w = qt.QTableView()
- d = numpy.random.normal(0, 1, (5, 1000, 1000))
- for i in range(5):
- d[i, :, :] += i * 10
- m = ArrayTableModel(data=d)
- w.setModel(m)
- m.setFrameIndex(3)
- # m.setArrayData(numpy.ones((100,)))
- w.show()
- app.exec_()
diff --git a/silx/gui/data/ArrayTableWidget.py b/silx/gui/data/ArrayTableWidget.py
deleted file mode 100644
index cb8e915..0000000
--- a/silx/gui/data/ArrayTableWidget.py
+++ /dev/null
@@ -1,492 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module defines a widget designed to display data arrays with any
-number of dimensions as 2D frames (images, slices) in a table view.
-The dimensions not displayed in the table can be browsed using improved
-sliders.
-
-The widget uses a TableView that relies on a custom abstract item
-model: :class:`silx.gui.data.ArrayTableModel`.
-"""
-from __future__ import division
-import sys
-
-from silx.gui import qt
-from silx.gui.widgets.TableWidget import TableView
-from .ArrayTableModel import ArrayTableModel
-from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
-
-__authors__ = ["V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "24/01/2017"
-
-
-class AxesSelector(qt.QWidget):
- """Widget with two combo-boxes to select two dimensions among
- all possible dimensions of an n-dimensional array.
-
- The first combobox contains values from :math:`0` to :math:`n-2`.
-
- The choices in the 2nd CB depend on the value selected in the first one.
- If the value selected in the first CB is :math:`m`, the second one lets you
- select values from :math:`m+1` to :math:`n-1`.
-
- The two axes can be used to select the row axis and the column axis t
- display a slice of the array data in a table view.
- """
- sigDimensionsChanged = qt.Signal(int, int)
- """Signal emitted whenever one of the comboboxes is changed.
- The signal carries the two selected dimensions."""
-
- def __init__(self, parent=None, n=None):
- qt.QWidget.__init__(self, parent)
- self.layout = qt.QHBoxLayout(self)
- self.layout.setContentsMargins(0, 2, 0, 2)
- self.layout.setSpacing(10)
-
- self.rowsCB = qt.QComboBox(self)
- self.columnsCB = qt.QComboBox(self)
-
- self.layout.addWidget(qt.QLabel("Rows dimension", self))
- self.layout.addWidget(self.rowsCB)
- self.layout.addWidget(qt.QLabel(" ", self))
- self.layout.addWidget(qt.QLabel("Columns dimension", self))
- self.layout.addWidget(self.columnsCB)
- self.layout.addStretch(1)
-
- self._slotsAreConnected = False
- if n is not None:
- self.setNDimensions(n)
-
- def setNDimensions(self, n):
- """Initialize combo-boxes depending on number of dimensions of array.
- Initially, the rows dimension is the second-to-last one, and the
- columns dimension is the last one.
-
- Link the CBs together. MAke them emit a signal when their value is
- changed.
-
- :param int n: Number of dimensions of array
- """
- # remember the number of dimensions and the rows dimension
- self.n = n
- self._rowsDim = n - 2
-
- # ensure slots are disconnected before (re)initializing widget
- if self._slotsAreConnected:
- self.rowsCB.currentIndexChanged.disconnect(self._rowDimChanged)
- self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged)
-
- self._clear()
- self.rowsCB.addItems([str(i) for i in range(n - 1)])
- self.rowsCB.setCurrentIndex(n - 2)
- if n >= 1:
- self.columnsCB.addItem(str(n - 1))
- self.columnsCB.setCurrentIndex(0)
-
- # reconnect slots
- self.rowsCB.currentIndexChanged.connect(self._rowDimChanged)
- self.columnsCB.currentIndexChanged.connect(self._colDimChanged)
- self._slotsAreConnected = True
-
- # emit new dimensions
- if n > 2:
- self.sigDimensionsChanged.emit(n - 2, n - 1)
-
- def setDimensions(self, row_dim, col_dim):
- """Set the rows and columns dimensions.
-
- The rows dimension must be lower than the columns dimension.
-
- :param int row_dim: Rows dimension
- :param int col_dim: Columns dimension
- """
- if row_dim >= col_dim:
- raise IndexError("Row dimension must be lower than column dimension")
- if not (0 <= row_dim < self.n - 1):
- raise IndexError("Row dimension must be between 0 and %d" % (self.n - 2))
- if not (row_dim < col_dim <= self.n - 1):
- raise IndexError("Col dimension must be between %d and %d" % (row_dim + 1, self.n - 1))
-
- # set the rows dimension; this triggers an update of columnsCB
- self.rowsCB.setCurrentIndex(row_dim)
- # columnsCB first item is "row_dim + 1". So index of "col_dim" is
- # col_dim - (row_dim + 1)
- self.columnsCB.setCurrentIndex(col_dim - row_dim - 1)
-
- def getDimensions(self):
- """Return a 2-tuple of the rows dimension and the columns dimension.
-
- :return: 2-tuple of axes numbers (row_dimension, col_dimension)
- """
- return self._getRowDim(), self._getColDim()
-
- def _clear(self):
- """Empty the combo-boxes"""
- self.rowsCB.clear()
- self.columnsCB.clear()
-
- def _getRowDim(self):
- """Get rows dimension, selected in :attr:`rowsCB`
- """
- # rows combobox contains elements "0", ..."n-2",
- # so the selected dim is always equal to the index
- return self.rowsCB.currentIndex()
-
- def _getColDim(self):
- """Get columns dimension, selected in :attr:`columnsCB`"""
- # columns combobox contains elements "row_dim+1", "row_dim+2", ..., "n-1"
- # so the selected dim is equal to row_dim + 1 + index
- return self._rowsDim + 1 + self.columnsCB.currentIndex()
-
- def _rowDimChanged(self):
- """Update columns combobox when the rows dimension is changed.
-
- Emit :attr:`sigDimensionsChanged`"""
- old_col_dim = self._getColDim()
- new_row_dim = self._getRowDim()
-
- # clear cols CB
- self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged)
- self.columnsCB.clear()
- # refill cols CB
- for i in range(new_row_dim + 1, self.n):
- self.columnsCB.addItem(str(i))
-
- # keep previous col dimension if possible
- new_col_cb_idx = old_col_dim - (new_row_dim + 1)
- if new_col_cb_idx < 0:
- # if row_dim is now greater than the previous col_dim,
- # we select a new col_dim = row_dim + 1 (first element in cols CB)
- new_col_cb_idx = 0
- self.columnsCB.setCurrentIndex(new_col_cb_idx)
-
- # reconnect slot
- self.columnsCB.currentIndexChanged.connect(self._colDimChanged)
-
- self._rowsDim = new_row_dim
-
- self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim())
-
- def _colDimChanged(self):
- """Emit :attr:`sigDimensionsChanged`"""
- self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim())
-
-
-def _get_shape(array_like):
- """Return shape of an array like object.
-
- In case the object is a nested sequence (list of lists, tuples...),
- the size of each dimension is assumed to be uniform, and is deduced from
- the length of the first sequence.
-
- :param array_like: Array like object: numpy array, hdf5 dataset,
- multi-dimensional sequence
- :return: Shape of array, as a tuple of integers
- """
- if hasattr(array_like, "shape"):
- return array_like.shape
-
- shape = []
- subsequence = array_like
- while hasattr(subsequence, "__len__"):
- shape.append(len(subsequence))
- subsequence = subsequence[0]
-
- return tuple(shape)
-
-
-class ArrayTableWidget(qt.QWidget):
- """This widget is designed to display data of 2D frames (images, slices)
- in a table view. The widget can load any n-dimensional array, and display
- any 2-D frame/slice in the array.
-
- The index of the dimensions orthogonal to the displayed frame can be set
- interactively using a browser widget (sliders, buttons and text entries).
-
- To set the data, use :meth:`setArrayData`.
- To select the perspective, use :meth:`setPerspective` or
- use :meth:`setFrameAxes`.
- To select the frame, use :meth:`setFrameIndex`.
-
- .. image:: img/ArrayTableWidget.png
- """
- def __init__(self, parent=None):
- """
-
- :param parent: parent QWidget
- :param labels: list of labels for each dimension of the array
- """
- qt.QWidget.__init__(self, parent)
- self.mainLayout = qt.QVBoxLayout(self)
- self.mainLayout.setContentsMargins(0, 0, 0, 0)
- self.mainLayout.setSpacing(0)
-
- self.browserContainer = qt.QWidget(self)
- self.browserLayout = qt.QGridLayout(self.browserContainer)
- self.browserLayout.setContentsMargins(0, 0, 0, 0)
- self.browserLayout.setSpacing(0)
-
- self._dimensionLabelsText = []
- """List of text labels sorted in the increasing order of the dimension
- they apply to."""
- self._browserLabels = []
- """List of QLabel widgets."""
- self._browserWidgets = []
- """List of HorizontalSliderWithBrowser widgets."""
-
- self.axesSelector = AxesSelector(self)
-
- self.view = TableView(self)
-
- self.mainLayout.addWidget(self.browserContainer)
- self.mainLayout.addWidget(self.axesSelector)
- self.mainLayout.addWidget(self.view)
-
- self.model = ArrayTableModel(self)
- self.view.setModel(self.model)
-
- def setArrayData(self, data, labels=None, copy=True, editable=False):
- """Set the data array. Update frame browsers and labels.
-
- :param data: Numpy array or similar object (e.g. nested sequence,
- h5py dataset...)
- :param labels: list of labels for each dimension of the array, or
- boolean ``True`` to use default labels ("dimension 0",
- "dimension 1", ...). `None` to disable labels (default).
- :param bool copy: If *True*, store a copy of *data* in the model. If
- *False*, store a reference to *data* if possible (only possible if
- *data* is a proper numpy array or an object that implements the
- same methods).
- :param bool editable: Flag to enable editing data. Default is *False*
- """
- self._data_shape = _get_shape(data)
-
- n_widgets = len(self._browserWidgets)
- n_dimensions = len(self._data_shape)
-
- # Reset text of labels
- self._dimensionLabelsText = []
- for i in range(n_dimensions):
- if labels in [True, 1]:
- label_text = "Dimension %d" % i
- elif labels is None or i >= len(labels):
- label_text = ""
- else:
- label_text = labels[i]
- self._dimensionLabelsText.append(label_text)
-
- # not enough widgets, create new ones (we need n_dim - 2)
- for i in range(n_widgets, n_dimensions - 2):
- browser = HorizontalSliderWithBrowser(self.browserContainer)
- self.browserLayout.addWidget(browser, i, 1)
- self._browserWidgets.append(browser)
- browser.valueChanged.connect(self._browserSlot)
- browser.setEnabled(False)
- browser.hide()
-
- label = qt.QLabel(self.browserContainer)
- self._browserLabels.append(label)
- self.browserLayout.addWidget(label, i, 0)
- label.hide()
-
- n_widgets = len(self._browserWidgets)
- for i in range(n_widgets):
- label = self._browserLabels[i]
- browser = self._browserWidgets[i]
-
- if (i + 2) < n_dimensions:
- label.setText(self._dimensionLabelsText[i])
- browser.setRange(0, self._data_shape[i] - 1)
- browser.setEnabled(True)
- browser.show()
- if labels is not None:
- label.show()
- else:
- label.hide()
- else:
- browser.setEnabled(False)
- browser.hide()
- label.hide()
-
- # set model
- self.model.setArrayData(data, copy=copy, editable=editable)
- # some linux distributions need this call
- self.view.setModel(self.model)
- if editable:
- self.view.enableCut()
- self.view.enablePaste()
-
- # initialize & connect axesSelector
- self.axesSelector.setNDimensions(n_dimensions)
- self.axesSelector.sigDimensionsChanged.connect(self.setFrameAxes)
-
- def setArrayColors(self, bgcolors=None, fgcolors=None):
- """Set the colors for all table cells by passing an array
- of RGB or RGBA values (integers between 0 and 255).
-
- The shape of the colors array must be consistent with the data shape.
-
- If the data array is n-dimensional, the colors array must be
- (n+1)-dimensional, with the first n-dimensions identical to the data
- array dimensions, and the last dimension length-3 (RGB) or
- length-4 (RGBA).
-
- :param bgcolors: RGB or RGBA colors array, defining the background color
- for each cell in the table.
- :param fgcolors: RGB or RGBA colors array, defining the foreground color
- (text color) for each cell in the table.
- """
- self.model.setArrayColors(bgcolors, fgcolors)
-
- def displayAxesSelector(self, isVisible):
- """Allow to display or hide the axes selector.
-
- :param bool isVisible: True to display the axes selector.
- """
- self.axesSelector.setVisible(isVisible)
-
- def setFrameIndex(self, index):
- """Set the active slice/image index in the n-dimensional array.
-
- A frame is a 2D array extracted from an array. This frame is
- necessarily parallel to 2 axes, and orthogonal to all other axes.
-
- The index of a frame is a sequence of indices along the orthogonal
- axes, where the frame intersects the respective axis. The indices
- are listed in the same order as the corresponding dimensions of the
- data array.
-
- For example, it the data array has 5 dimensions, and we are
- considering frames whose parallel axes are the 2nd and 4th dimensions
- of the array, the frame index will be a sequence of length 3
- corresponding to the indices where the frame intersects the 1st, 3rd
- and 5th axes.
-
- :param index: Sequence of indices defining the active data slice in
- a n-dimensional array. The sequence length is :math:`n-2`
- :raise: IndexError if any index in the index sequence is out of bound
- on its respective axis.
- """
- self.model.setFrameIndex(index)
-
- def _resetBrowsers(self, perspective):
- """Adjust limits for browsers based on the perspective and the
- size of the corresponding dimensions. Reset the index to 0.
- Update the dimension in the labels.
-
- :param perspective: Sequence of axes/dimensions numbers (0-based)
- defining the axes orthogonal to the frame.
- """
- # for 3D arrays we can accept an int rather than a 1-tuple
- if not hasattr(perspective, "__len__"):
- perspective = [perspective]
-
- # perspective must be sorted
- perspective = sorted(perspective)
-
- n_dimensions = len(self._data_shape)
- for i in range(n_dimensions - 2):
- browser = self._browserWidgets[i]
- label = self._browserLabels[i]
- browser.setRange(0, self._data_shape[perspective[i]] - 1)
- browser.setValue(0)
- label.setText(self._dimensionLabelsText[perspective[i]])
-
- def setPerspective(self, perspective):
- """Set the *perspective* by specifying which axes are orthogonal
- to the frame.
-
- For the opposite approach (defining parallel axes), use
- :meth:`setFrameAxes` instead.
-
- :param perspective: Sequence of unique axes numbers (0-based) defining
- the orthogonal axes. For a n-dimensional array, the sequence
- length is :math:`n-2`. The order is of the sequence is not taken
- into account (the dimensions are displayed in increasing order
- in the widget).
- """
- self.model.setPerspective(perspective)
- self._resetBrowsers(perspective)
-
- def setFrameAxes(self, row_axis, col_axis):
- """Set the *perspective* by specifying which axes are parallel
- to the frame.
-
- For the opposite approach (defining orthogonal axes), use
- :meth:`setPerspective` instead.
-
- :param int row_axis: Index (0-based) of the first dimension used as a frame
- axis
- :param int col_axis: Index (0-based) of the 2nd dimension used as a frame
- axis
- """
- self.model.setFrameAxes(row_axis, col_axis)
- n_dimensions = len(self._data_shape)
- perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis})
- self._resetBrowsers(perspective)
-
- def _browserSlot(self, value):
- index = []
- for browser in self._browserWidgets:
- if browser.isEnabled():
- index.append(browser.value())
- self.setFrameIndex(index)
- self.view.reset()
-
- def getData(self, copy=True):
- """Return a copy of the data array, or a reference to it if
- *copy=False* is passed as parameter.
-
- :param bool copy: If *True* (default), return a copy of the data. If
- *False*, return a reference.
- :return: Numpy array of data, or reference to original data object
- if *copy=False*
- """
- return self.model.getData(copy=copy)
-
-
-def main():
- import numpy
- a = qt.QApplication([])
- d = numpy.random.normal(0, 1, (4, 5, 1000, 1000))
- for j in range(4):
- for i in range(5):
- d[j, i, :, :] += i + 10 * j
- w = ArrayTableWidget()
- if "2" in sys.argv:
- print("sending a single image")
- w.setArrayData(d[0, 0])
- elif "3" in sys.argv:
- print("sending 5 images")
- w.setArrayData(d[0])
- else:
- print("sending 4 * 5 images ")
- w.setArrayData(d, labels=True)
- w.show()
- a.exec_()
-
-if __name__ == "__main__":
- main()
diff --git a/silx/gui/data/Hdf5TableView.py b/silx/gui/data/Hdf5TableView.py
deleted file mode 100644
index 7749326..0000000
--- a/silx/gui/data/Hdf5TableView.py
+++ /dev/null
@@ -1,646 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module define model and widget to display 1D slices from numpy
-array using compound data types or hdf5 databases.
-"""
-from __future__ import division
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "12/02/2019"
-
-import collections
-import functools
-import os.path
-import logging
-import h5py
-import numpy
-
-from silx.gui import qt
-import silx.io
-from .TextFormatter import TextFormatter
-import silx.gui.hdf5
-from silx.gui.widgets import HierarchicalTableView
-from ..hdf5.Hdf5Formatter import Hdf5Formatter
-from ..hdf5._utils import htmlFromDict
-
-
-_logger = logging.getLogger(__name__)
-
-
-class _CellData(object):
- """Store a table item
- """
- def __init__(self, value=None, isHeader=False, span=None, tooltip=None):
- """
- Constructor
-
- :param str value: Label of this property
- :param bool isHeader: True if the cell is an header
- :param tuple span: Tuple of row, column span
- """
- self.__value = value
- self.__isHeader = isHeader
- self.__span = span
- self.__tooltip = tooltip
-
- def isHeader(self):
- """Returns true if the property is a sub-header title.
-
- :rtype: bool
- """
- return self.__isHeader
-
- def value(self):
- """Returns the value of the item.
- """
- return self.__value
-
- def span(self):
- """Returns the span size of the cell.
-
- :rtype: tuple
- """
- return self.__span
-
- def tooltip(self):
- """Returns the tooltip of the item.
-
- :rtype: tuple
- """
- return self.__tooltip
-
- def invalidateValue(self):
- self.__value = None
-
- def invalidateToolTip(self):
- self.__tooltip = None
-
- def data(self, role):
- return None
-
-
-class _TableData(object):
- """Modelize a table with header, row and column span.
-
- It is mostly defined as a row based table.
- """
-
- def __init__(self, columnCount):
- """Constructor.
-
- :param int columnCount: Define the number of column of the table
- """
- self.__colCount = columnCount
- self.__data = []
-
- def rowCount(self):
- """Returns the number of rows.
-
- :rtype: int
- """
- return len(self.__data)
-
- def columnCount(self):
- """Returns the number of columns.
-
- :rtype: int
- """
- return self.__colCount
-
- def clear(self):
- """Remove all the cells of the table"""
- self.__data = []
-
- def cellAt(self, row, column):
- """Returns the cell at the row column location. Else None if there is
- nothing.
-
- :rtype: _CellData
- """
- if row < 0:
- return None
- if column < 0:
- return None
- if row >= len(self.__data):
- return None
- cells = self.__data[row]
- if column >= len(cells):
- return None
- return cells[column]
-
- def addHeaderRow(self, headerLabel):
- """Append the table with header on the full row.
-
- :param str headerLabel: label of the header.
- """
- item = _CellData(value=headerLabel, isHeader=True, span=(1, self.__colCount))
- self.__data.append([item])
-
- def addHeaderValueRow(self, headerLabel, value, tooltip=None):
- """Append the table with a row using the first column as an header and
- other cells as a single cell for the value.
-
- :param str headerLabel: label of the header.
- :param object value: value to store.
- """
- header = _CellData(value=headerLabel, isHeader=True)
- value = _CellData(value=value, span=(1, self.__colCount), tooltip=tooltip)
- self.__data.append([header, value])
-
- def addRow(self, *args):
- """Append the table with a row using arguments for each cells
-
- :param list[object] args: List of cell values for the row
- """
- row = []
- for value in args:
- if not isinstance(value, _CellData):
- value = _CellData(value=value)
- row.append(value)
- self.__data.append(row)
-
-
-class _CellFilterAvailableData(_CellData):
- """Cell rendering for availability of a filter"""
-
- _states = {
- True: ("Available", qt.QColor(0x000000), None, None),
- False: ("Not available", qt.QColor(0xFFFFFF), qt.QColor(0xFF0000),
- "You have to install this filter on your system to be able to read this dataset"),
- "na": ("n.a.", qt.QColor(0x000000), None,
- "This version of h5py/hdf5 is not able to display the information"),
- }
-
- def __init__(self, filterId):
- if h5py.version.hdf5_version_tuple >= (1, 10, 2):
- # Previous versions only returns True if the filter was first used
- # to decode a dataset
- self.__availability = h5py.h5z.filter_avail(filterId)
- else:
- self.__availability = "na"
- _CellData.__init__(self)
-
- def value(self):
- state = self._states[self.__availability]
- return state[0]
-
- def tooltip(self):
- state = self._states[self.__availability]
- return state[3]
-
- def data(self, role=qt.Qt.DisplayRole):
- state = self._states[self.__availability]
- if role == qt.Qt.TextColorRole:
- return state[1]
- elif role == qt.Qt.BackgroundColorRole:
- return state[2]
- else:
- return None
-
-
-class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
- """This data model provides access to HDF5 node content (File, Group,
- Dataset). Main info, like name, file, attributes... are displayed
- """
-
- def __init__(self, parent=None, data=None):
- """
- Constructor
-
- :param qt.QObject parent: Parent object
- :param object data: An h5py-like object (file, group or dataset)
- """
- super(Hdf5TableModel, self).__init__(parent)
-
- self.__obj = None
- self.__data = _TableData(columnCount=5)
- self.__formatter = None
- self.__hdf5Formatter = Hdf5Formatter(self)
- formatter = TextFormatter(self)
- self.setFormatter(formatter)
- self.setObject(data)
-
- def rowCount(self, parent_idx=None):
- """Returns number of rows to be displayed in table"""
- return self.__data.rowCount()
-
- def columnCount(self, parent_idx=None):
- """Returns number of columns to be displayed in table"""
- return self.__data.columnCount()
-
- def data(self, index, role=qt.Qt.DisplayRole):
- """QAbstractTableModel method to access data values
- in the format ready to be displayed"""
- if not index.isValid():
- return None
-
- cell = self.__data.cellAt(index.row(), index.column())
- if cell is None:
- return None
-
- if role == self.SpanRole:
- return cell.span()
- elif role == self.IsHeaderRole:
- return cell.isHeader()
- elif role in (qt.Qt.DisplayRole, qt.Qt.EditRole):
- value = cell.value()
- if callable(value):
- try:
- value = value(self.__obj)
- except Exception:
- cell.invalidateValue()
- raise
- return value
- elif role == qt.Qt.ToolTipRole:
- value = cell.tooltip()
- if callable(value):
- try:
- value = value(self.__obj)
- except Exception:
- cell.invalidateToolTip()
- raise
- return value
- else:
- return cell.data(role)
- return None
-
- def isSupportedObject(self, h5pyObject):
- """
- Returns true if the provided object can be modelized using this model.
- """
- isSupported = False
- isSupported = isSupported or silx.io.is_group(h5pyObject)
- isSupported = isSupported or silx.io.is_dataset(h5pyObject)
- isSupported = isSupported or isinstance(h5pyObject, silx.gui.hdf5.H5Node)
- return isSupported
-
- def setObject(self, h5pyObject):
- """Set the h5py-like object exposed by the model
-
- :param h5pyObject: A h5py-like object. It can be a `h5py.Dataset`,
- a `h5py.File`, a `h5py.Group`. It also can be a,
- `silx.gui.hdf5.H5Node` which is needed to display some local path
- information.
- """
- if qt.qVersion() > "4.6":
- self.beginResetModel()
-
- if h5pyObject is None or self.isSupportedObject(h5pyObject):
- self.__obj = h5pyObject
- else:
- _logger.warning("Object class %s unsupported. Object ignored.", type(h5pyObject))
- self.__initProperties()
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
- else:
- self.reset()
-
- def __formatHdf5Type(self, dataset):
- """Format the HDF5 type"""
- return self.__hdf5Formatter.humanReadableHdf5Type(dataset)
-
- def __attributeTooltip(self, attribute):
- attributeDict = collections.OrderedDict()
- if hasattr(attribute, "shape"):
- attributeDict["Shape"] = self.__hdf5Formatter.humanReadableShape(attribute)
- attributeDict["Data type"] = self.__hdf5Formatter.humanReadableType(attribute, full=True)
- html = htmlFromDict(attributeDict, title="HDF5 Attribute")
- return html
-
- def __formatDType(self, dataset):
- """Format the numpy dtype"""
- return self.__hdf5Formatter.humanReadableType(dataset, full=True)
-
- def __formatShape(self, dataset):
- """Format the shape"""
- if dataset.shape is None or len(dataset.shape) <= 1:
- return self.__hdf5Formatter.humanReadableShape(dataset)
- size = dataset.size
- shape = self.__hdf5Formatter.humanReadableShape(dataset)
- return u"%s = %s" % (shape, size)
-
- def __formatChunks(self, dataset):
- """Format the shape"""
- chunks = dataset.chunks
- if chunks is None:
- return ""
- shape = " \u00D7 ".join([str(i) for i in chunks])
- sizes = numpy.product(chunks)
- text = "%s = %s" % (shape, sizes)
- return text
-
- def __initProperties(self):
- """Initialize the list of available properties according to the defined
- h5py-like object."""
- self.__data.clear()
- if self.__obj is None:
- return
-
- obj = self.__obj
-
- hdf5obj = obj
- if isinstance(obj, silx.gui.hdf5.H5Node):
- hdf5obj = obj.h5py_object
-
- if silx.io.is_file(hdf5obj):
- objectType = "File"
- elif silx.io.is_group(hdf5obj):
- objectType = "Group"
- elif silx.io.is_dataset(hdf5obj):
- objectType = "Dataset"
- else:
- objectType = obj.__class__.__name__
- self.__data.addHeaderRow(headerLabel="HDF5 %s" % objectType)
-
- SEPARATOR = "::"
-
- self.__data.addHeaderRow(headerLabel="Path info")
- showPhysicalLocation = True
- if isinstance(obj, silx.gui.hdf5.H5Node):
- # helpful informations if the object come from an HDF5 tree
- self.__data.addHeaderValueRow("Basename", lambda x: x.local_basename)
- self.__data.addHeaderValueRow("Name", lambda x: x.local_name)
- local = lambda x: x.local_filename + SEPARATOR + x.local_name
- self.__data.addHeaderValueRow("Local", local)
- else:
- # it's a real H5py object
- self.__data.addHeaderValueRow("Basename", lambda x: os.path.basename(x.name))
- self.__data.addHeaderValueRow("Name", lambda x: x.name)
- if obj.file is not None:
- self.__data.addHeaderValueRow("File", lambda x: x.file.filename)
- if hasattr(obj, "path"):
- # That's a link
- if hasattr(obj, "filename"):
- # External link
- link = lambda x: x.filename + SEPARATOR + x.path
- else:
- # Soft link
- link = lambda x: x.path
- self.__data.addHeaderValueRow("Link", link)
- showPhysicalLocation = False
-
- # External data (nothing to do with external links)
- nExtSources = 0
- firstExtSource = None
- extType = None
- if silx.io.is_dataset(hdf5obj):
- if hasattr(hdf5obj, "is_virtual"):
- if hdf5obj.is_virtual:
- extSources = hdf5obj.virtual_sources()
- if extSources:
- firstExtSource = extSources[0].file_name + SEPARATOR + extSources[0].dset_name
- extType = "Virtual"
- nExtSources = len(extSources)
- if hasattr(hdf5obj, "external"):
- extSources = hdf5obj.external
- if extSources:
- firstExtSource = extSources[0][0]
- extType = "Raw"
- nExtSources = len(extSources)
-
- if showPhysicalLocation:
- def _physical_location(x):
- if isinstance(obj, silx.gui.hdf5.H5Node):
- return x.physical_filename + SEPARATOR + x.physical_name
- elif silx.io.is_file(obj):
- return x.filename + SEPARATOR + x.name
- elif obj.file is not None:
- return x.file.filename + SEPARATOR + x.name
- else:
- # Guess it is a virtual node
- return "No physical location"
-
- self.__data.addHeaderValueRow("Physical", _physical_location)
-
- if extType:
- def _first_source(x):
- # Absolute path
- if os.path.isabs(firstExtSource):
- return firstExtSource
-
- # Relative path with respect to the file directory
- if isinstance(obj, silx.gui.hdf5.H5Node):
- filename = x.physical_filename
- elif silx.io.is_file(obj):
- filename = x.filename
- elif obj.file is not None:
- filename = x.file.filename
- else:
- return firstExtSource
-
- if firstExtSource[0] == ".":
- firstExtSource.pop(0)
- return os.path.join(os.path.dirname(filename), firstExtSource)
-
- self.__data.addHeaderRow(headerLabel="External sources")
- self.__data.addHeaderValueRow("Type", extType)
- self.__data.addHeaderValueRow("Count", str(nExtSources))
- self.__data.addHeaderValueRow("First", _first_source)
-
- if hasattr(obj, "dtype"):
-
- self.__data.addHeaderRow(headerLabel="Data info")
-
- if hasattr(obj, "id") and hasattr(obj.id, "get_type"):
- # display the HDF5 type
- self.__data.addHeaderValueRow("HDF5 type", self.__formatHdf5Type)
- self.__data.addHeaderValueRow("dtype", self.__formatDType)
- if hasattr(obj, "shape"):
- self.__data.addHeaderValueRow("shape", self.__formatShape)
- if hasattr(obj, "chunks") and obj.chunks is not None:
- self.__data.addHeaderValueRow("chunks", self.__formatChunks)
-
- # relative to compression
- # h5py expose compression, compression_opts but are not initialized
- # for external plugins, then we use id
- # h5py also expose fletcher32 and shuffle attributes, but it is also
- # part of the filters
- if hasattr(obj, "shape") and hasattr(obj, "id"):
- if hasattr(obj.id, "get_create_plist"):
- dcpl = obj.id.get_create_plist()
- if dcpl.get_nfilters() > 0:
- self.__data.addHeaderRow(headerLabel="Compression info")
- pos = _CellData(value="Position", isHeader=True)
- hdf5id = _CellData(value="HDF5 ID", isHeader=True)
- name = _CellData(value="Name", isHeader=True)
- options = _CellData(value="Options", isHeader=True)
- availability = _CellData(value="", isHeader=True)
- self.__data.addRow(pos, hdf5id, name, options, availability)
- for index in range(dcpl.get_nfilters()):
- filterId, name, options = self.__getFilterInfo(obj, index)
- pos = _CellData(value=str(index))
- hdf5id = _CellData(value=str(filterId))
- name = _CellData(value=name)
- options = _CellData(value=options)
- availability = _CellFilterAvailableData(filterId=filterId)
- self.__data.addRow(pos, hdf5id, name, options, availability)
-
- if hasattr(obj, "attrs"):
- if len(obj.attrs) > 0:
- self.__data.addHeaderRow(headerLabel="Attributes")
- for key in sorted(obj.attrs.keys()):
- callback = lambda key, x: self.__formatter.toString(x.attrs[key])
- callbackTooltip = lambda key, x: self.__attributeTooltip(x.attrs[key])
- self.__data.addHeaderValueRow(headerLabel=key,
- value=functools.partial(callback, key),
- tooltip=functools.partial(callbackTooltip, key))
-
- def __getFilterInfo(self, dataset, filterIndex):
- """Get a tuple of readable info from dataset filters
-
- :param h5py.Dataset dataset: A h5py dataset
- :param int filterId:
- """
- try:
- dcpl = dataset.id.get_create_plist()
- info = dcpl.get_filter(filterIndex)
- filterId, _flags, cdValues, name = info
- name = self.__formatter.toString(name)
- options = " ".join([self.__formatter.toString(i) for i in cdValues])
- return (filterId, name, options)
- except Exception:
- _logger.debug("Backtrace", exc_info=True)
- return (None, None, None)
-
- def object(self):
- """Returns the internal object modelized.
-
- :rtype: An h5py-like object
- """
- return self.__obj
-
- def setFormatter(self, formatter):
- """Set the formatter object to be used to display data from the model
-
- :param TextFormatter formatter: Formatter to use
- """
- if formatter is self.__formatter:
- return
-
- self.__hdf5Formatter.setTextFormatter(formatter)
-
- if qt.qVersion() > "4.6":
- self.beginResetModel()
-
- if self.__formatter is not None:
- self.__formatter.formatChanged.disconnect(self.__formatChanged)
-
- self.__formatter = formatter
- if self.__formatter is not None:
- self.__formatter.formatChanged.connect(self.__formatChanged)
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
- else:
- self.reset()
-
- def getFormatter(self):
- """Returns the text formatter used.
-
- :rtype: TextFormatter
- """
- return self.__formatter
-
- def __formatChanged(self):
- """Called when the format changed.
- """
- self.reset()
-
-
-class Hdf5TableItemDelegate(HierarchicalTableView.HierarchicalItemDelegate):
- """Item delegate the :class:`Hdf5TableView` with read-only text editor"""
-
- def createEditor(self, parent, option, index):
- """See :meth:`QStyledItemDelegate.createEditor`"""
- editor = super().createEditor(parent, option, index)
- if isinstance(editor, qt.QLineEdit):
- editor.setReadOnly(True)
- editor.deselect()
- editor.textChanged.connect(self.__textChanged, qt.Qt.QueuedConnection)
- self.installEventFilter(editor)
- return editor
-
- def __textChanged(self, text):
- sender = self.sender()
- if sender is not None:
- sender.deselect()
-
- def eventFilter(self, watched, event):
- eventType = event.type()
- if eventType == qt.QEvent.FocusIn:
- watched.selectAll()
- qt.QTimer.singleShot(0, watched.selectAll)
- elif eventType == qt.QEvent.FocusOut:
- watched.deselect()
- return super().eventFilter(watched, event)
-
-
-class Hdf5TableView(HierarchicalTableView.HierarchicalTableView):
- """A widget to display metadata about a HDF5 node using a table."""
-
- def __init__(self, parent=None):
- super(Hdf5TableView, self).__init__(parent)
- self.setModel(Hdf5TableModel(self))
- self.setItemDelegate(Hdf5TableItemDelegate(self))
- self.setSelectionMode(qt.QAbstractItemView.NoSelection)
-
- def isSupportedData(self, data):
- """
- Returns true if the provided object can be modelized using this model.
- """
- return self.model().isSupportedObject(data)
-
- def setData(self, data):
- """Set the h5py-like object exposed by the model
-
- :param data: A h5py-like object. It can be a `h5py.Dataset`,
- a `h5py.File`, a `h5py.Group`. It also can be a,
- `silx.gui.hdf5.H5Node` which is needed to display some local path
- information.
- """
- model = self.model()
-
- model.setObject(data)
- header = self.horizontalHeader()
- if qt.qVersion() < "5.0":
- setResizeMode = header.setResizeMode
- else:
- setResizeMode = header.setSectionResizeMode
- setResizeMode(0, qt.QHeaderView.Fixed)
- setResizeMode(1, qt.QHeaderView.ResizeToContents)
- setResizeMode(2, qt.QHeaderView.Stretch)
- setResizeMode(3, qt.QHeaderView.ResizeToContents)
- setResizeMode(4, qt.QHeaderView.ResizeToContents)
- header.setStretchLastSection(False)
-
- for row in range(model.rowCount()):
- for column in range(model.columnCount()):
- index = model.index(row, column)
- if (index.isValid() and index.data(
- HierarchicalTableView.HierarchicalTableModel.IsHeaderRole) is False):
- self.openPersistentEditor(index)
diff --git a/silx/gui/data/HexaTableView.py b/silx/gui/data/HexaTableView.py
deleted file mode 100644
index 1617f0a..0000000
--- a/silx/gui/data/HexaTableView.py
+++ /dev/null
@@ -1,286 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module defines model and widget to display raw data using an
-hexadecimal viewer.
-"""
-from __future__ import division
-
-import collections
-
-import numpy
-import six
-
-from silx.gui import qt
-import silx.io.utils
-from silx.gui.widgets.TableWidget import CopySelectedCellsAction
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "23/05/2018"
-
-
-class _VoidConnector(object):
- """Byte connector to a numpy.void data.
-
- It uses a cache of 32 x 1KB and a direct read access API from HDF5.
- """
-
- def __init__(self, data):
- self.__cache = collections.OrderedDict()
- self.__len = data.itemsize
- self.__data = data
-
- def __getBuffer(self, bufferId):
- if bufferId not in self.__cache:
- pos = bufferId << 10
- data = self.__data
- if hasattr(data, "tobytes"):
- data = data.tobytes()[pos:pos + 1024]
- else:
- # Old fashion
- data = data.data[pos:pos + 1024]
-
- self.__cache[bufferId] = data
- if len(self.__cache) > 32:
- self.__cache.popitem()
- else:
- data = self.__cache[bufferId]
- return data
-
- def __getitem__(self, pos):
- """Returns the value of the byte at the given position.
-
- :param uint pos: Position of the byte
- :rtype: int
- """
- bufferId = pos >> 10
- bufferPos = pos & 0b1111111111
- data = self.__getBuffer(bufferId)
- value = data[bufferPos]
- if six.PY2:
- return ord(value)
- else:
- return value
-
- def __len__(self):
- """
- Returns the number of available bytes.
-
- :rtype: uint
- """
- return self.__len
-
-
-class HexaTableModel(qt.QAbstractTableModel):
- """This data model provides access to a numpy void data.
-
- Bytes are displayed one by one as a hexadecimal viewer.
-
- The 16th first columns display bytes as hexadecimal, the last column
- displays the same data as ASCII.
-
- :param qt.QObject parent: Parent object
- :param data: A numpy array or a h5py dataset
- """
- def __init__(self, parent=None, data=None):
- qt.QAbstractTableModel.__init__(self, parent)
-
- self.__data = None
- self.__connector = None
- self.setArrayData(data)
-
- if hasattr(qt.QFontDatabase, "systemFont"):
- self.__font = qt.QFontDatabase.systemFont(qt.QFontDatabase.FixedFont)
- else:
- self.__font = qt.QFont("Monospace")
- self.__font.setStyleHint(qt.QFont.TypeWriter)
- self.__palette = qt.QPalette()
-
- def rowCount(self, parent_idx=None):
- """Returns number of rows to be displayed in table"""
- if self.__connector is None:
- return 0
- return ((len(self.__connector) - 1) >> 4) + 1
-
- def columnCount(self, parent_idx=None):
- """Returns number of columns to be displayed in table"""
- return 0x10 + 1
-
- def data(self, index, role=qt.Qt.DisplayRole):
- """QAbstractTableModel method to access data values
- in the format ready to be displayed"""
- if not index.isValid():
- return None
-
- if self.__connector is None:
- return None
-
- row = index.row()
- column = index.column()
-
- if role == qt.Qt.DisplayRole:
- if column == 0x10:
- start = (row << 4)
- text = ""
- for i in range(0x10):
- pos = start + i
- if pos >= len(self.__connector):
- break
- value = self.__connector[pos]
- if value > 0x20 and value < 0x7F:
- text += chr(value)
- else:
- text += "."
- return text
- else:
- pos = (row << 4) + column
- if pos < len(self.__connector):
- value = self.__connector[pos]
- return "%02X" % value
- else:
- return ""
- elif role == qt.Qt.FontRole:
- return self.__font
-
- elif role == qt.Qt.BackgroundColorRole:
- pos = (row << 4) + column
- if column != 0x10 and pos >= len(self.__connector):
- return self.__palette.color(qt.QPalette.Disabled, qt.QPalette.Background)
- else:
- return None
-
- return None
-
- def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
- """Returns the 0-based row or column index, for display in the
- horizontal and vertical headers"""
- if section == -1:
- # PyQt4 send -1 when there is columns but no rows
- return None
-
- if role == qt.Qt.DisplayRole:
- if orientation == qt.Qt.Vertical:
- return "%02X" % (section << 4)
- if orientation == qt.Qt.Horizontal:
- if section == 0x10:
- return "ASCII"
- else:
- return "%02X" % section
- elif role == qt.Qt.FontRole:
- return self.__font
- elif role == qt.Qt.TextAlignmentRole:
- if orientation == qt.Qt.Vertical:
- return qt.Qt.AlignRight
- if orientation == qt.Qt.Horizontal:
- if section == 0x10:
- return qt.Qt.AlignLeft
- else:
- return qt.Qt.AlignCenter
- return None
-
- def flags(self, index):
- """QAbstractTableModel method to inform the view whether data
- is editable or not.
- """
- row = index.row()
- column = index.column()
- pos = (row << 4) + column
- if column != 0x10 and pos >= len(self.__connector):
- return qt.Qt.NoItemFlags
- return qt.QAbstractTableModel.flags(self, index)
-
- def setArrayData(self, data):
- """Set the data array.
-
- :param data: A numpy object or a dataset.
- """
- if qt.qVersion() > "4.6":
- self.beginResetModel()
-
- self.__connector = None
- self.__data = data
- if self.__data is not None:
- if silx.io.utils.is_dataset(self.__data):
- data = data[()]
- elif isinstance(self.__data, numpy.ndarray):
- data = data[()]
- self.__connector = _VoidConnector(data)
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
- else:
- self.reset()
-
- def arrayData(self):
- """Returns the internal data.
-
- :rtype: numpy.ndarray of h5py.Dataset
- """
- return self.__data
-
-
-class HexaTableView(qt.QTableView):
- """TableView using HexaTableModel as default model.
-
- It customs the column size to provide a better layout.
- """
- def __init__(self, parent=None):
- """
- Constructor
-
- :param qt.QWidget parent: parent QWidget
- """
- qt.QTableView.__init__(self, parent)
-
- model = HexaTableModel(self)
- self.setModel(model)
- self._copyAction = CopySelectedCellsAction(self)
- self.addAction(self._copyAction)
-
- def copy(self):
- self._copyAction.trigger()
-
- def setArrayData(self, data):
- """Set the data array.
-
- :param data: A numpy object or a dataset.
- """
- self.model().setArrayData(data)
- self.__fixHeader()
-
- def __fixHeader(self):
- """Update the view according to the state of the auto-resize"""
- header = self.horizontalHeader()
- if qt.qVersion() < "5.0":
- setResizeMode = header.setResizeMode
- else:
- setResizeMode = header.setSectionResizeMode
-
- header.setDefaultSectionSize(30)
- header.setStretchLastSection(True)
- for i in range(0x10):
- setResizeMode(i, qt.QHeaderView.Fixed)
- setResizeMode(0x10, qt.QHeaderView.Stretch)
diff --git a/silx/gui/data/NXdataWidgets.py b/silx/gui/data/NXdataWidgets.py
deleted file mode 100644
index be7d0e3..0000000
--- a/silx/gui/data/NXdataWidgets.py
+++ /dev/null
@@ -1,1081 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module defines widgets used by _NXdataView.
-"""
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "12/11/2018"
-
-import logging
-import numpy
-
-from silx.gui import qt
-from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
-from silx.gui.plot import Plot1D, Plot2D, StackView, ScatterView
-from silx.gui.plot.ComplexImageView import ComplexImageView
-from silx.gui.colors import Colormap
-from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
-
-from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration
-
-
-_logger = logging.getLogger(__name__)
-
-
-class ArrayCurvePlot(qt.QWidget):
- """
- Widget for plotting a curve from a multi-dimensional signal array
- and a 1D axis array.
-
- The signal array can have an arbitrary number of dimensions, the only
- limitation being that the last dimension must have the same length as
- the axis array.
-
- The widget provides sliders to select indices on the first (n - 1)
- dimensions of the signal array, and buttons to add/replace selected
- curves to the plot.
-
- This widget also handles simple 2D or 3D scatter plots (third dimension
- displayed as colour of points).
- """
- def __init__(self, parent=None):
- """
-
- :param parent: Parent QWidget
- """
- super(ArrayCurvePlot, self).__init__(parent)
-
- self.__signals = None
- self.__signals_names = None
- self.__signal_errors = None
- self.__axis = None
- self.__axis_name = None
- self.__x_axis_errors = None
- self.__values = None
-
- self._plot = Plot1D(self)
-
- self._selector = NumpyAxesSelector(self)
- self._selector.setNamedAxesSelectorVisibility(False)
- self.__selector_is_connected = False
-
- self._plot.sigActiveCurveChanged.connect(self._setYLabelFromActiveLegend)
-
- layout = qt.QVBoxLayout()
- layout.setContentsMargins(0, 0, 0, 0)
- layout.addWidget(self._plot)
- layout.addWidget(self._selector)
-
- self.setLayout(layout)
-
- def getPlot(self):
- """Returns the plot used for the display
-
- :rtype: Plot1D
- """
- return self._plot
-
- def setCurvesData(self, ys, x=None,
- yerror=None, xerror=None,
- ylabels=None, xlabel=None, title=None,
- xscale=None, yscale=None):
- """
-
- :param List[ndarray] ys: List of arrays to be represented by the y (vertical) axis.
- It can be multiple n-D array whose last dimension must
- have the same length as x (and values must be None)
- :param ndarray x: 1-D dataset used as the curve's x values. If provided,
- its lengths must be equal to the length of the last dimension of
- ``y`` (and equal to the length of ``value``, for a scatter plot).
- :param ndarray yerror: Single array of errors for y (same shape), or None.
- There can only be one array, and it applies to the first/main y
- (no y errors for auxiliary_signals curves).
- :param ndarray xerror: 1-D dataset of errors for x, or None
- :param str ylabels: Labels for each curve's Y axis
- :param str xlabel: Label for X axis
- :param str title: Graph title
- :param str xscale: Scale of X axis in (None, 'linear', 'log')
- :param str yscale: Scale of Y axis in (None, 'linear', 'log')
- """
- self.__signals = ys
- self.__signals_names = ylabels or (["Y"] * len(ys))
- self.__signal_errors = yerror
- self.__axis = x
- self.__axis_name = xlabel
- self.__x_axis_errors = xerror
-
- if self.__selector_is_connected:
- self._selector.selectionChanged.disconnect(self._updateCurve)
- self.__selector_is_connected = False
- self._selector.setData(ys[0])
- self._selector.setAxisNames(["Y"])
-
- if len(ys[0].shape) < 2:
- self._selector.hide()
- else:
- self._selector.show()
-
- self._plot.setGraphTitle(title or "")
- if xscale is not None:
- self._plot.getXAxis().setScale(
- 'log' if xscale == 'log' else 'linear')
- if yscale is not None:
- self._plot.getYAxis().setScale(
- 'log' if yscale == 'log' else 'linear')
- self._updateCurve()
-
- if not self.__selector_is_connected:
- self._selector.selectionChanged.connect(self._updateCurve)
- self.__selector_is_connected = True
-
- def _updateCurve(self):
- selection = self._selector.selection()
- ys = [sig[selection] for sig in self.__signals]
- y0 = ys[0]
- len_y = len(y0)
- x = self.__axis
- if x is None:
- x = numpy.arange(len_y)
- elif numpy.isscalar(x) or len(x) == 1:
- # constant axis
- x = x * numpy.ones_like(y0)
- elif len(x) == 2 and len_y != 2:
- # linear calibration a + b * x
- x = x[0] + x[1] * numpy.arange(len_y)
-
- self._plot.remove(kind=("curve",))
-
- for i in range(len(self.__signals)):
- legend = self.__signals_names[i]
-
- # errors only supported for primary signal in NXdata
- y_errors = None
- if i == 0 and self.__signal_errors is not None:
- y_errors = self.__signal_errors[self._selector.selection()]
- self._plot.addCurve(x, ys[i], legend=legend,
- xerror=self.__x_axis_errors,
- yerror=y_errors)
- if i == 0:
- self._plot.setActiveCurve(legend)
-
- self._plot.resetZoom()
- self._plot.getXAxis().setLabel(self.__axis_name)
- self._plot.getYAxis().setLabel(self.__signals_names[0])
-
- def _setYLabelFromActiveLegend(self, previous_legend, new_legend):
- for ylabel in self.__signals_names:
- if new_legend is not None and new_legend == ylabel:
- self._plot.getYAxis().setLabel(ylabel)
- break
-
- def clear(self):
- old = self._selector.blockSignals(True)
- self._selector.clear()
- self._selector.blockSignals(old)
- self._plot.clear()
-
-
-class XYVScatterPlot(qt.QWidget):
- """
- Widget for plotting one or more scatters
- (with identical x, y coordinates).
- """
- def __init__(self, parent=None):
- """
-
- :param parent: Parent QWidget
- """
- super(XYVScatterPlot, self).__init__(parent)
-
- self.__y_axis = None
- """1D array"""
- self.__y_axis_name = None
- self.__values = None
- """List of 1D arrays (for multiple scatters with identical
- x, y coordinates)"""
-
- self.__x_axis = None
- self.__x_axis_name = None
- self.__x_axis_errors = None
- self.__y_axis = None
- self.__y_axis_name = None
- self.__y_axis_errors = None
-
- self._plot = ScatterView(self)
- self._plot.setColormap(Colormap(name="viridis",
- vmin=None, vmax=None,
- normalization=Colormap.LINEAR))
-
- self._slider = HorizontalSliderWithBrowser(parent=self)
- self._slider.setMinimum(0)
- self._slider.setValue(0)
- self._slider.valueChanged[int].connect(self._sliderIdxChanged)
- self._slider.setToolTip("Select auxiliary signals")
-
- layout = qt.QGridLayout()
- layout.setContentsMargins(0, 0, 0, 0)
- layout.addWidget(self._plot, 0, 0)
- layout.addWidget(self._slider, 1, 0)
-
- self.setLayout(layout)
-
- def _sliderIdxChanged(self, value):
- self._updateScatter()
-
- def getScatterView(self):
- """Returns the :class:`ScatterView` used for the display
-
- :rtype: ScatterView
- """
- return self._plot
-
- def getPlot(self):
- """Returns the plot used for the display
-
- :rtype: PlotWidget
- """
- return self._plot.getPlotWidget()
-
- def setScattersData(self, y, x, values,
- yerror=None, xerror=None,
- ylabel=None, xlabel=None,
- title="", scatter_titles=None,
- xscale=None, yscale=None):
- """
-
- :param ndarray y: 1D array for y (vertical) coordinates.
- :param ndarray x: 1D array for x coordinates.
- :param List[ndarray] values: List of 1D arrays of values.
- This will be used to compute the color map and assign colors
- to the points. There should be as many arrays in the list as
- scatters to be represented.
- :param ndarray yerror: 1D array of errors for y (same shape), or None.
- :param ndarray xerror: 1D array of errors for x, or None
- :param str ylabel: Label for Y axis
- :param str xlabel: Label for X axis
- :param str title: Main graph title
- :param List[str] scatter_titles: Subtitles (one per scatter)
- :param str xscale: Scale of X axis in (None, 'linear', 'log')
- :param str yscale: Scale of Y axis in (None, 'linear', 'log')
- """
- self.__y_axis = y
- self.__x_axis = x
- self.__x_axis_name = xlabel or "X"
- self.__y_axis_name = ylabel or "Y"
- self.__x_axis_errors = xerror
- self.__y_axis_errors = yerror
- self.__values = values
-
- self.__graph_title = title or ""
- self.__scatter_titles = scatter_titles
-
- self._slider.valueChanged[int].disconnect(self._sliderIdxChanged)
- self._slider.setMaximum(len(values) - 1)
- if len(values) > 1:
- self._slider.show()
- else:
- self._slider.hide()
- self._slider.setValue(0)
- self._slider.valueChanged[int].connect(self._sliderIdxChanged)
-
- if xscale is not None:
- self._plot.getXAxis().setScale(
- 'log' if xscale == 'log' else 'linear')
- if yscale is not None:
- self._plot.getYAxis().setScale(
- 'log' if yscale == 'log' else 'linear')
-
- self._updateScatter()
-
- def _updateScatter(self):
- x = self.__x_axis
- y = self.__y_axis
-
- idx = self._slider.value()
-
- if self.__graph_title:
- title = self.__graph_title # main NXdata @title
- if len(self.__scatter_titles) > 1:
- # Append dataset name only when there is many datasets
- title += '\n' + self.__scatter_titles[idx]
- else:
- title = self.__scatter_titles[idx] # scatter dataset name
-
- self._plot.setGraphTitle(title)
- self._plot.setData(x, y, self.__values[idx],
- xerror=self.__x_axis_errors,
- yerror=self.__y_axis_errors)
- self._plot.resetZoom()
- self._plot.getXAxis().setLabel(self.__x_axis_name)
- self._plot.getYAxis().setLabel(self.__y_axis_name)
-
- def clear(self):
- self._plot.getPlotWidget().clear()
-
-
-class ArrayImagePlot(qt.QWidget):
- """
- Widget for plotting an image from a multi-dimensional signal array
- and two 1D axes array.
-
- The signal array can have an arbitrary number of dimensions, the only
- limitation being that the last two dimensions must have the same length as
- the axes arrays.
-
- Sliders are provided to select indices on the first (n - 2) dimensions of
- the signal array, and the plot is updated to show the image corresponding
- to the selection.
-
- If one or both of the axes does not have regularly spaced values, the
- the image is plotted as a coloured scatter plot.
- """
- def __init__(self, parent=None):
- """
-
- :param parent: Parent QWidget
- """
- super(ArrayImagePlot, self).__init__(parent)
-
- self.__signals = None
- self.__signals_names = None
- self.__x_axis = None
- self.__x_axis_name = None
- self.__y_axis = None
- self.__y_axis_name = None
-
- self._plot = Plot2D(self)
- self._plot.setDefaultColormap(Colormap(name="viridis",
- vmin=None, vmax=None,
- normalization=Colormap.LINEAR))
- self._plot.getIntensityHistogramAction().setVisible(True)
- self._plot.setKeepDataAspectRatio(True)
- maskToolWidget = self._plot.getMaskToolsDockWidget().widget()
- maskToolWidget.setItemMaskUpdated(True)
-
- # not closable
- self._selector = NumpyAxesSelector(self)
- self._selector.setNamedAxesSelectorVisibility(False)
- self._selector.selectionChanged.connect(self._updateImage)
-
- self._auxSigSlider = HorizontalSliderWithBrowser(parent=self)
- self._auxSigSlider.setMinimum(0)
- self._auxSigSlider.setValue(0)
- self._auxSigSlider.valueChanged[int].connect(self._sliderIdxChanged)
- self._auxSigSlider.setToolTip("Select auxiliary signals")
-
- layout = qt.QVBoxLayout()
- layout.addWidget(self._plot)
- layout.addWidget(self._selector)
- layout.addWidget(self._auxSigSlider)
-
- self.setLayout(layout)
-
- def _sliderIdxChanged(self, value):
- self._updateImage()
-
- def getPlot(self):
- """Returns the plot used for the display
-
- :rtype: Plot2D
- """
- return self._plot
-
- def setImageData(self, signals,
- x_axis=None, y_axis=None,
- signals_names=None,
- xlabel=None, ylabel=None,
- title=None, isRgba=False,
- xscale=None, yscale=None):
- """
-
- :param signals: list of n-D datasets, whose last 2 dimensions are used as the
- image's values, or list of 3D datasets interpreted as RGBA image.
- :param x_axis: 1-D dataset used as the image's x coordinates. If
- provided, its lengths must be equal to the length of the last
- dimension of ``signal``.
- :param y_axis: 1-D dataset used as the image's y. If provided,
- its lengths must be equal to the length of the 2nd to last
- dimension of ``signal``.
- :param signals_names: Names for each image, used as subtitle and legend.
- :param xlabel: Label for X axis
- :param ylabel: Label for Y axis
- :param title: Graph title
- :param isRgba: True if data is a 3D RGBA image
- :param str xscale: Scale of X axis in (None, 'linear', 'log')
- :param str yscale: Scale of Y axis in (None, 'linear', 'log')
- """
- self._selector.selectionChanged.disconnect(self._updateImage)
- self._auxSigSlider.valueChanged.disconnect(self._sliderIdxChanged)
-
- self.__signals = signals
- self.__signals_names = signals_names
- self.__x_axis = x_axis
- self.__x_axis_name = xlabel
- self.__y_axis = y_axis
- self.__y_axis_name = ylabel
- self.__title = title
-
- self._selector.clear()
- if not isRgba:
- self._selector.setAxisNames(["Y", "X"])
- img_ndim = 2
- else:
- self._selector.setAxisNames(["Y", "X", "RGB(A) channel"])
- img_ndim = 3
- self._selector.setData(signals[0])
-
- if len(signals[0].shape) <= img_ndim:
- self._selector.hide()
- else:
- self._selector.show()
-
- self._auxSigSlider.setMaximum(len(signals) - 1)
- if len(signals) > 1:
- self._auxSigSlider.show()
- else:
- self._auxSigSlider.hide()
- self._auxSigSlider.setValue(0)
-
- self._axis_scales = xscale, yscale
- self._updateImage()
- self._plot.resetZoom()
-
- self._selector.selectionChanged.connect(self._updateImage)
- self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged)
-
- def _updateImage(self):
- selection = self._selector.selection()
- auxSigIdx = self._auxSigSlider.value()
-
- legend = self.__signals_names[auxSigIdx]
-
- images = [img[selection] for img in self.__signals]
- image = images[auxSigIdx]
-
- x_axis = self.__x_axis
- y_axis = self.__y_axis
-
- if x_axis is None and y_axis is None:
- xcalib = NoCalibration()
- ycalib = NoCalibration()
- else:
- if x_axis is None:
- # no calibration
- x_axis = numpy.arange(image.shape[1])
- elif numpy.isscalar(x_axis) or len(x_axis) == 1:
- # constant axis
- x_axis = x_axis * numpy.ones((image.shape[1], ))
- elif len(x_axis) == 2:
- # linear calibration
- x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1]
-
- if y_axis is None:
- y_axis = numpy.arange(image.shape[0])
- elif numpy.isscalar(y_axis) or len(y_axis) == 1:
- y_axis = y_axis * numpy.ones((image.shape[0], ))
- elif len(y_axis) == 2:
- y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1]
-
- xcalib = ArrayCalibration(x_axis)
- ycalib = ArrayCalibration(y_axis)
-
- self._plot.remove(kind=("scatter", "image",))
- if xcalib.is_affine() and ycalib.is_affine():
- # regular image
- xorigin, xscale = xcalib(0), xcalib.get_slope()
- yorigin, yscale = ycalib(0), ycalib.get_slope()
- origin = (xorigin, yorigin)
- scale = (xscale, yscale)
-
- self._plot.getXAxis().setScale('linear')
- self._plot.getYAxis().setScale('linear')
- self._plot.addImage(image, legend=legend,
- origin=origin, scale=scale,
- replace=True, resetzoom=False)
- else:
- xaxisscale, yaxisscale = self._axis_scales
-
- if xaxisscale is not None:
- self._plot.getXAxis().setScale(
- 'log' if xaxisscale == 'log' else 'linear')
- if yaxisscale is not None:
- self._plot.getYAxis().setScale(
- 'log' if yaxisscale == 'log' else 'linear')
-
- scatterx, scattery = numpy.meshgrid(x_axis, y_axis)
- # fixme: i don't think this can handle "irregular" RGBA images
- self._plot.addScatter(numpy.ravel(scatterx),
- numpy.ravel(scattery),
- numpy.ravel(image),
- legend=legend)
-
- if self.__title:
- title = self.__title
- if len(self.__signals_names) > 1:
- # Append dataset name only when there is many datasets
- title += '\n' + self.__signals_names[auxSigIdx]
- else:
- title = self.__signals_names[auxSigIdx]
- self._plot.setGraphTitle(title)
- self._plot.getXAxis().setLabel(self.__x_axis_name)
- self._plot.getYAxis().setLabel(self.__y_axis_name)
-
- def clear(self):
- old = self._selector.blockSignals(True)
- self._selector.clear()
- self._selector.blockSignals(old)
- self._plot.clear()
-
-
-class ArrayComplexImagePlot(qt.QWidget):
- """
- Widget for plotting an image of complex from a multi-dimensional signal array
- and two 1D axes array.
-
- The signal array can have an arbitrary number of dimensions, the only
- limitation being that the last two dimensions must have the same length as
- the axes arrays.
-
- Sliders are provided to select indices on the first (n - 2) dimensions of
- the signal array, and the plot is updated to show the image corresponding
- to the selection.
-
- If one or both of the axes does not have regularly spaced values, the
- the image is plotted as a coloured scatter plot.
- """
- def __init__(self, parent=None, colormap=None):
- """
-
- :param parent: Parent QWidget
- """
- super(ArrayComplexImagePlot, self).__init__(parent)
-
- self.__signals = None
- self.__signals_names = None
- self.__x_axis = None
- self.__x_axis_name = None
- self.__y_axis = None
- self.__y_axis_name = None
-
- self._plot = ComplexImageView(self)
- if colormap is not None:
- for mode in (ComplexImageView.ComplexMode.ABSOLUTE,
- ComplexImageView.ComplexMode.SQUARE_AMPLITUDE,
- ComplexImageView.ComplexMode.REAL,
- ComplexImageView.ComplexMode.IMAGINARY):
- self._plot.setColormap(colormap, mode)
-
- self._plot.getPlot().getIntensityHistogramAction().setVisible(True)
- self._plot.setKeepDataAspectRatio(True)
- maskToolWidget = self._plot.getPlot().getMaskToolsDockWidget().widget()
- maskToolWidget.setItemMaskUpdated(True)
-
- # not closable
- self._selector = NumpyAxesSelector(self)
- self._selector.setNamedAxesSelectorVisibility(False)
- self._selector.selectionChanged.connect(self._updateImage)
-
- self._auxSigSlider = HorizontalSliderWithBrowser(parent=self)
- self._auxSigSlider.setMinimum(0)
- self._auxSigSlider.setValue(0)
- self._auxSigSlider.valueChanged[int].connect(self._sliderIdxChanged)
- self._auxSigSlider.setToolTip("Select auxiliary signals")
-
- layout = qt.QVBoxLayout()
- layout.addWidget(self._plot)
- layout.addWidget(self._selector)
- layout.addWidget(self._auxSigSlider)
-
- self.setLayout(layout)
-
- def _sliderIdxChanged(self, value):
- self._updateImage()
-
- def getPlot(self):
- """Returns the plot used for the display
-
- :rtype: PlotWidget
- """
- return self._plot.getPlot()
-
- def setImageData(self, signals,
- x_axis=None, y_axis=None,
- signals_names=None,
- xlabel=None, ylabel=None,
- title=None):
- """
-
- :param signals: list of n-D datasets, whose last 2 dimensions are used as the
- image's values, or list of 3D datasets interpreted as RGBA image.
- :param x_axis: 1-D dataset used as the image's x coordinates. If
- provided, its lengths must be equal to the length of the last
- dimension of ``signal``.
- :param y_axis: 1-D dataset used as the image's y. If provided,
- its lengths must be equal to the length of the 2nd to last
- dimension of ``signal``.
- :param signals_names: Names for each image, used as subtitle and legend.
- :param xlabel: Label for X axis
- :param ylabel: Label for Y axis
- :param title: Graph title
- """
- self._selector.selectionChanged.disconnect(self._updateImage)
- self._auxSigSlider.valueChanged.disconnect(self._sliderIdxChanged)
-
- self.__signals = signals
- self.__signals_names = signals_names
- self.__x_axis = x_axis
- self.__x_axis_name = xlabel
- self.__y_axis = y_axis
- self.__y_axis_name = ylabel
- self.__title = title
-
- self._selector.clear()
- self._selector.setAxisNames(["Y", "X"])
- self._selector.setData(signals[0])
-
- if len(signals[0].shape) <= 2:
- self._selector.hide()
- else:
- self._selector.show()
-
- self._auxSigSlider.setMaximum(len(signals) - 1)
- if len(signals) > 1:
- self._auxSigSlider.show()
- else:
- self._auxSigSlider.hide()
- self._auxSigSlider.setValue(0)
-
- self._updateImage()
- self._plot.getPlot().resetZoom()
-
- self._selector.selectionChanged.connect(self._updateImage)
- self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged)
-
- def _updateImage(self):
- selection = self._selector.selection()
- auxSigIdx = self._auxSigSlider.value()
-
- images = [img[selection] for img in self.__signals]
- image = images[auxSigIdx]
-
- x_axis = self.__x_axis
- y_axis = self.__y_axis
-
- if x_axis is None and y_axis is None:
- xcalib = NoCalibration()
- ycalib = NoCalibration()
- else:
- if x_axis is None:
- # no calibration
- x_axis = numpy.arange(image.shape[1])
- elif numpy.isscalar(x_axis) or len(x_axis) == 1:
- # constant axis
- x_axis = x_axis * numpy.ones((image.shape[1], ))
- elif len(x_axis) == 2:
- # linear calibration
- x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1]
-
- if y_axis is None:
- y_axis = numpy.arange(image.shape[0])
- elif numpy.isscalar(y_axis) or len(y_axis) == 1:
- y_axis = y_axis * numpy.ones((image.shape[0], ))
- elif len(y_axis) == 2:
- y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1]
-
- xcalib = ArrayCalibration(x_axis)
- ycalib = ArrayCalibration(y_axis)
-
- self._plot.setData(image)
- if xcalib.is_affine():
- xorigin, xscale = xcalib(0), xcalib.get_slope()
- else:
- _logger.warning("Unsupported complex image X axis calibration")
- xorigin, xscale = 0., 1.
-
- if ycalib.is_affine():
- yorigin, yscale = ycalib(0), ycalib.get_slope()
- else:
- _logger.warning("Unsupported complex image Y axis calibration")
- yorigin, yscale = 0., 1.
-
- self._plot.setOrigin((xorigin, yorigin))
- self._plot.setScale((xscale, yscale))
-
- if self.__title:
- title = self.__title
- if len(self.__signals_names) > 1:
- # Append dataset name only when there is many datasets
- title += '\n' + self.__signals_names[auxSigIdx]
- else:
- title = self.__signals_names[auxSigIdx]
- self._plot.setGraphTitle(title)
- self._plot.getXAxis().setLabel(self.__x_axis_name)
- self._plot.getYAxis().setLabel(self.__y_axis_name)
-
- def clear(self):
- old = self._selector.blockSignals(True)
- self._selector.clear()
- self._selector.blockSignals(old)
- self._plot.setData(None)
-
-
-class ArrayStackPlot(qt.QWidget):
- """
- Widget for plotting a n-D array (n >= 3) as a stack of images.
- Three axis arrays can be provided to calibrate the axes.
-
- The signal array can have an arbitrary number of dimensions, the only
- limitation being that the last 3 dimensions must have the same length as
- the axes arrays.
-
- Sliders are provided to select indices on the first (n - 3) dimensions of
- the signal array, and the plot is updated to load the stack corresponding
- to the selection.
- """
- def __init__(self, parent=None):
- """
-
- :param parent: Parent QWidget
- """
- super(ArrayStackPlot, self).__init__(parent)
-
- self.__signal = None
- self.__signal_name = None
- # the Z, Y, X axes apply to the last three dimensions of the signal
- # (in that order)
- self.__z_axis = None
- self.__z_axis_name = None
- self.__y_axis = None
- self.__y_axis_name = None
- self.__x_axis = None
- self.__x_axis_name = None
-
- self._stack_view = StackView(self)
- maskToolWidget = self._stack_view.getPlotWidget().getMaskToolsDockWidget().widget()
- maskToolWidget.setItemMaskUpdated(True)
-
- self._hline = qt.QFrame(self)
- self._hline.setFrameStyle(qt.QFrame.HLine)
- self._hline.setFrameShadow(qt.QFrame.Sunken)
- self._legend = qt.QLabel(self)
- self._selector = NumpyAxesSelector(self)
- self._selector.setNamedAxesSelectorVisibility(False)
- self.__selector_is_connected = False
-
- layout = qt.QVBoxLayout()
- layout.addWidget(self._stack_view)
- layout.addWidget(self._hline)
- layout.addWidget(self._legend)
- layout.addWidget(self._selector)
-
- self.setLayout(layout)
-
- def getStackView(self):
- """Returns the plot used for the display
-
- :rtype: StackView
- """
- return self._stack_view
-
- def setStackData(self, signal,
- x_axis=None, y_axis=None, z_axis=None,
- signal_name=None,
- xlabel=None, ylabel=None, zlabel=None,
- title=None):
- """
-
- :param signal: n-D dataset, whose last 3 dimensions are used as the
- 3D stack values.
- :param x_axis: 1-D dataset used as the image's x coordinates. If
- provided, its lengths must be equal to the length of the last
- dimension of ``signal``.
- :param y_axis: 1-D dataset used as the image's y. If provided,
- its lengths must be equal to the length of the 2nd to last
- dimension of ``signal``.
- :param z_axis: 1-D dataset used as the image's z. If provided,
- its lengths must be equal to the length of the 3rd to last
- dimension of ``signal``.
- :param signal_name: Label used in the legend
- :param xlabel: Label for X axis
- :param ylabel: Label for Y axis
- :param zlabel: Label for Z axis
- :param title: Graph title
- """
- if self.__selector_is_connected:
- self._selector.selectionChanged.disconnect(self._updateStack)
- self.__selector_is_connected = False
-
- self.__signal = signal
- self.__signal_name = signal_name or ""
- self.__x_axis = x_axis
- self.__x_axis_name = xlabel
- self.__y_axis = y_axis
- self.__y_axis_name = ylabel
- self.__z_axis = z_axis
- self.__z_axis_name = zlabel
-
- self._selector.setData(signal)
- self._selector.setAxisNames(["Y", "X", "Z"])
-
- self._stack_view.setGraphTitle(title or "")
- # by default, the z axis is the image position (dimension not plotted)
- self._stack_view.getPlotWidget().getXAxis().setLabel(self.__x_axis_name or "X")
- self._stack_view.getPlotWidget().getYAxis().setLabel(self.__y_axis_name or "Y")
-
- self._updateStack()
-
- ndims = len(signal.shape)
- self._stack_view.setFirstStackDimension(ndims - 3)
-
- # the legend label shows the selection slice producing the volume
- # (only interesting for ndim > 3)
- if ndims > 3:
- self._selector.setVisible(True)
- self._legend.setVisible(True)
- self._hline.setVisible(True)
- else:
- self._selector.setVisible(False)
- self._legend.setVisible(False)
- self._hline.setVisible(False)
-
- if not self.__selector_is_connected:
- self._selector.selectionChanged.connect(self._updateStack)
- self.__selector_is_connected = True
-
- @staticmethod
- def _get_origin_scale(axis):
- """Assuming axis is a regularly spaced 1D array,
- return a tuple (origin, scale) where:
- - origin = axis[0]
- - scale = (axis[n-1] - axis[0]) / (n -1)
- :param axis: 1D numpy array
- :return: Tuple (axis[0], (axis[-1] - axis[0]) / (len(axis) - 1))
- """
- return axis[0], (axis[-1] - axis[0]) / (len(axis) - 1)
-
- def _updateStack(self):
- """Update displayed stack according to the current axes selector
- data."""
- stk = self._selector.selectedData()
- x_axis = self.__x_axis
- y_axis = self.__y_axis
- z_axis = self.__z_axis
-
- calibrations = []
- for axis in [z_axis, y_axis, x_axis]:
-
- if axis is None:
- calibrations.append(NoCalibration())
- elif len(axis) == 2:
- calibrations.append(
- LinearCalibration(y_intercept=axis[0],
- slope=axis[1]))
- else:
- calibrations.append(ArrayCalibration(axis))
-
- legend = self.__signal_name + "["
- for sl in self._selector.selection():
- if sl == slice(None):
- legend += ":, "
- else:
- legend += str(sl) + ", "
- legend = legend[:-2] + "]"
- self._legend.setText("Displayed data: " + legend)
-
- self._stack_view.setStack(stk, calibrations=calibrations)
- self._stack_view.setLabels(
- labels=[self.__z_axis_name,
- self.__y_axis_name,
- self.__x_axis_name])
-
- def clear(self):
- old = self._selector.blockSignals(True)
- self._selector.clear()
- self._selector.blockSignals(old)
- self._stack_view.clear()
-
-
-class ArrayVolumePlot(qt.QWidget):
- """
- Widget for plotting a n-D array (n >= 3) as a 3D scalar field.
- Three axis arrays can be provided to calibrate the axes.
-
- The signal array can have an arbitrary number of dimensions, the only
- limitation being that the last 3 dimensions must have the same length as
- the axes arrays.
-
- Sliders are provided to select indices on the first (n - 3) dimensions of
- the signal array, and the plot is updated to load the stack corresponding
- to the selection.
- """
- def __init__(self, parent=None):
- """
-
- :param parent: Parent QWidget
- """
- super(ArrayVolumePlot, self).__init__(parent)
-
- self.__signal = None
- self.__signal_name = None
- # the Z, Y, X axes apply to the last three dimensions of the signal
- # (in that order)
- self.__z_axis = None
- self.__z_axis_name = None
- self.__y_axis = None
- self.__y_axis_name = None
- self.__x_axis = None
- self.__x_axis_name = None
-
- from ._VolumeWindow import VolumeWindow
-
- self._view = VolumeWindow(self)
-
- self._hline = qt.QFrame(self)
- self._hline.setFrameStyle(qt.QFrame.HLine)
- self._hline.setFrameShadow(qt.QFrame.Sunken)
- self._legend = qt.QLabel(self)
- self._selector = NumpyAxesSelector(self)
- self._selector.setNamedAxesSelectorVisibility(False)
- self.__selector_is_connected = False
-
- layout = qt.QVBoxLayout()
- layout.addWidget(self._view)
- layout.addWidget(self._hline)
- layout.addWidget(self._legend)
- layout.addWidget(self._selector)
-
- self.setLayout(layout)
-
- def getVolumeView(self):
- """Returns the plot used for the display
-
- :rtype: SceneWindow
- """
- return self._view
-
- def setData(self, signal,
- x_axis=None, y_axis=None, z_axis=None,
- signal_name=None,
- xlabel=None, ylabel=None, zlabel=None,
- title=None):
- """
-
- :param signal: n-D dataset, whose last 3 dimensions are used as the
- 3D stack values.
- :param x_axis: 1-D dataset used as the image's x coordinates. If
- provided, its lengths must be equal to the length of the last
- dimension of ``signal``.
- :param y_axis: 1-D dataset used as the image's y. If provided,
- its lengths must be equal to the length of the 2nd to last
- dimension of ``signal``.
- :param z_axis: 1-D dataset used as the image's z. If provided,
- its lengths must be equal to the length of the 3rd to last
- dimension of ``signal``.
- :param signal_name: Label used in the legend
- :param xlabel: Label for X axis
- :param ylabel: Label for Y axis
- :param zlabel: Label for Z axis
- :param title: Graph title
- """
- if self.__selector_is_connected:
- self._selector.selectionChanged.disconnect(self._updateVolume)
- self.__selector_is_connected = False
-
- self.__signal = signal
- self.__signal_name = signal_name or ""
- self.__x_axis = x_axis
- self.__x_axis_name = xlabel
- self.__y_axis = y_axis
- self.__y_axis_name = ylabel
- self.__z_axis = z_axis
- self.__z_axis_name = zlabel
-
- self._selector.setData(signal)
- self._selector.setAxisNames(["Y", "X", "Z"])
-
- self._updateVolume()
-
- # the legend label shows the selection slice producing the volume
- # (only interesting for ndim > 3)
- if signal.ndim > 3:
- self._selector.setVisible(True)
- self._legend.setVisible(True)
- self._hline.setVisible(True)
- else:
- self._selector.setVisible(False)
- self._legend.setVisible(False)
- self._hline.setVisible(False)
-
- if not self.__selector_is_connected:
- self._selector.selectionChanged.connect(self._updateVolume)
- self.__selector_is_connected = True
-
- def _updateVolume(self):
- """Update displayed stack according to the current axes selector
- data."""
- x_axis = self.__x_axis
- y_axis = self.__y_axis
- z_axis = self.__z_axis
-
- offset = []
- scale = []
- for axis in [x_axis, y_axis, z_axis]:
- if axis is None:
- calibration = NoCalibration()
- elif len(axis) == 2:
- calibration = LinearCalibration(
- y_intercept=axis[0], slope=axis[1])
- else:
- calibration = ArrayCalibration(axis)
- if not calibration.is_affine():
- _logger.warning("Axis has not linear values, ignored")
- offset.append(0.)
- scale.append(1.)
- else:
- offset.append(calibration(0))
- scale.append(calibration.get_slope())
-
- legend = self.__signal_name + "["
- for sl in self._selector.selection():
- if sl == slice(None):
- legend += ":, "
- else:
- legend += str(sl) + ", "
- legend = legend[:-2] + "]"
- self._legend.setText("Displayed data: " + legend)
-
- # Update SceneWidget
- data = self._selector.selectedData()
-
- volumeView = self.getVolumeView()
- volumeView.setData(data, offset=offset, scale=scale)
- volumeView.setAxesLabels(
- self.__x_axis_name, self.__y_axis_name, self.__z_axis_name)
-
- def clear(self):
- old = self._selector.blockSignals(True)
- self._selector.clear()
- self._selector.blockSignals(old)
- self.getVolumeView().clear()
diff --git a/silx/gui/data/RecordTableView.py b/silx/gui/data/RecordTableView.py
deleted file mode 100644
index 2c0011a..0000000
--- a/silx/gui/data/RecordTableView.py
+++ /dev/null
@@ -1,447 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module define model and widget to display 1D slices from numpy
-array using compound data types or hdf5 databases.
-"""
-from __future__ import division
-
-import itertools
-import numpy
-from silx.gui import qt
-import silx.io
-from .TextFormatter import TextFormatter
-from silx.gui.widgets.TableWidget import CopySelectedCellsAction
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "29/08/2018"
-
-
-class _MultiLineItem(qt.QItemDelegate):
- """Draw a multiline text without hiding anything.
-
- The paint method display a cell without any wrap. And an editor is
- available to scroll into the selected cell.
- """
-
- def __init__(self, parent=None):
- """
- Constructor
-
- :param qt.QWidget parent: Parent of the widget
- """
- qt.QItemDelegate.__init__(self, parent)
- self.__textOptions = qt.QTextOption()
- self.__textOptions.setFlags(qt.QTextOption.IncludeTrailingSpaces |
- qt.QTextOption.ShowTabsAndSpaces)
- self.__textOptions.setWrapMode(qt.QTextOption.NoWrap)
- self.__textOptions.setAlignment(qt.Qt.AlignTop | qt.Qt.AlignLeft)
-
- def paint(self, painter, option, index):
- """
- Write multiline text without using any wrap or any alignment according
- to the cell size.
-
- :param qt.QPainter painter: Painter context used to displayed the cell
- :param qt.QStyleOptionViewItem option: Control how the editor is shown
- :param qt.QIndex index: Index of the data to display
- """
- painter.save()
-
- # set colors
- painter.setPen(qt.QPen(qt.Qt.NoPen))
- if option.state & qt.QStyle.State_Selected:
- brush = option.palette.highlight()
- painter.setBrush(brush)
- else:
- brush = index.data(qt.Qt.BackgroundRole)
- if brush is None:
- # default background color for a cell
- brush = qt.Qt.white
- painter.setBrush(brush)
- painter.drawRect(option.rect)
-
- if index.isValid():
- if option.state & qt.QStyle.State_Selected:
- brush = option.palette.highlightedText()
- else:
- brush = index.data(qt.Qt.ForegroundRole)
- if brush is None:
- brush = option.palette.text()
- painter.setPen(qt.QPen(brush.color()))
- text = index.data(qt.Qt.DisplayRole)
- painter.drawText(qt.QRectF(option.rect), text, self.__textOptions)
-
- painter.restore()
-
- def createEditor(self, parent, option, index):
- """
- Returns the widget used to edit the item specified by index for editing.
-
- We use it not to edit the content but to show the content with a
- convenient scroll bar.
-
- :param qt.QWidget parent: Parent of the widget
- :param qt.QStyleOptionViewItem option: Control how the editor is shown
- :param qt.QIndex index: Index of the data to display
- """
- if not index.isValid():
- return super(_MultiLineItem, self).createEditor(parent, option, index)
-
- editor = qt.QTextEdit(parent)
- editor.setReadOnly(True)
- return editor
-
- def setEditorData(self, editor, index):
- """
- Read data from the model and feed the editor.
-
- :param qt.QWidget editor: Editor widget
- :param qt.QIndex index: Index of the data to display
- """
- text = index.model().data(index, qt.Qt.EditRole)
- editor.setText(text)
-
- def updateEditorGeometry(self, editor, option, index):
- """
- Update the geometry of the editor according to the changes of the view.
-
- :param qt.QWidget editor: Editor widget
- :param qt.QStyleOptionViewItem option: Control how the editor is shown
- :param qt.QIndex index: Index of the data to display
- """
- editor.setGeometry(option.rect)
-
-
-class RecordTableModel(qt.QAbstractTableModel):
- """This data model provides access to 1D slices from numpy array using
- compound data types or hdf5 databases.
-
- Each entries are displayed in a single row, and each columns contain a
- specific field of the compound type.
-
- It also allows to display 1D arrays of simple data types.
- array.
-
- :param qt.QObject parent: Parent object
- :param numpy.ndarray data: A numpy array or a h5py dataset
- """
-
- MAX_NUMBER_OF_ROWS = 10e6
- """Maximum number of display values of the dataset"""
-
- def __init__(self, parent=None, data=None):
- qt.QAbstractTableModel.__init__(self, parent)
-
- self.__data = None
- self.__is_array = False
- self.__fields = None
- self.__formatter = None
- self.__editFormatter = None
- self.setFormatter(TextFormatter(self))
-
- # set _data
- self.setArrayData(data)
-
- # Methods to be implemented to subclass QAbstractTableModel
- def rowCount(self, parent_idx=None):
- """Returns number of rows to be displayed in table"""
- if self.__data is None:
- return 0
- elif not self.__is_array:
- return 1
- else:
- return min(len(self.__data), self.MAX_NUMBER_OF_ROWS)
-
- def columnCount(self, parent_idx=None):
- """Returns number of columns to be displayed in table"""
- if self.__fields is None:
- return 1
- else:
- return len(self.__fields)
-
- def __clippedData(self, role=qt.Qt.DisplayRole):
- """Return data for cells representing clipped data"""
- if role == qt.Qt.DisplayRole:
- return "..."
- elif role == qt.Qt.ToolTipRole:
- return "Dataset is too large: display is clipped"
- else:
- return None
-
- def data(self, index, role=qt.Qt.DisplayRole):
- """QAbstractTableModel method to access data values
- in the format ready to be displayed"""
- if not index.isValid():
- return None
-
- if self.__data is None:
- return None
-
- # Special display of one before last data for clipped table
- if self.__isClipped() and index.row() == self.rowCount() - 2:
- return self.__clippedData(role)
-
- if self.__is_array:
- row = index.row()
- if row >= self.rowCount():
- return None
- elif self.__isClipped() and row == self.rowCount() - 1:
- # Clipped array, display last value at the end
- data = self.__data[-1]
- else:
- data = self.__data[row]
- else:
- if index.row() > 0:
- return None
- data = self.__data
-
- if self.__fields is not None:
- if index.column() >= len(self.__fields):
- return None
- key = self.__fields[index.column()][1]
- data = data[key[0]]
- if len(key) > 1:
- data = data[key[1]]
-
- # no dtype in case of 1D array of unicode objects (#2093)
- dtype = getattr(data, "dtype", None)
-
- if role == qt.Qt.DisplayRole:
- return self.__formatter.toString(data, dtype=dtype)
- elif role == qt.Qt.EditRole:
- return self.__editFormatter.toString(data, dtype=dtype)
- return None
-
- def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
- """Returns the 0-based row or column index, for display in the
- horizontal and vertical headers"""
- if section == -1:
- # PyQt4 send -1 when there is columns but no rows
- return None
-
- # Handle clipping of huge tables
- if (self.__isClipped() and
- orientation == qt.Qt.Vertical and
- section == self.rowCount() - 2):
- return self.__clippedData(role)
-
- if role == qt.Qt.DisplayRole:
- if orientation == qt.Qt.Vertical:
- if not self.__is_array:
- return "Scalar"
- elif section == self.MAX_NUMBER_OF_ROWS - 1:
- return str(len(self.__data) - 1)
- else:
- return str(section)
- if orientation == qt.Qt.Horizontal:
- if self.__fields is None:
- if section == 0:
- return "Data"
- else:
- return None
- else:
- if section < len(self.__fields):
- return self.__fields[section][0]
- else:
- return None
- return None
-
- def flags(self, index):
- """QAbstractTableModel method to inform the view whether data
- is editable or not.
- """
- return qt.QAbstractTableModel.flags(self, index)
-
- def __isClipped(self) -> bool:
- """Returns whether the displayed array is clipped or not"""
- return self.__data is not None and self.__is_array and len(self.__data) > self.MAX_NUMBER_OF_ROWS
-
- def setArrayData(self, data):
- """Set the data array and the viewing perspective.
-
- You can set ``copy=False`` if you need more performances, when dealing
- with a large numpy array. In this case, a simple reference to the data
- is used to access the data, rather than a copy of the array.
-
- .. warning::
-
- Any change to the data model will affect your original data
- array, when using a reference rather than a copy..
-
- :param data: 1D numpy array, or any object that can be
- converted to a numpy array using ``numpy.array(data)`` (e.g.
- a nested sequence).
- """
- if qt.qVersion() > "4.6":
- self.beginResetModel()
-
- self.__data = data
- if isinstance(data, numpy.ndarray):
- self.__is_array = True
- elif silx.io.is_dataset(data) and data.shape != tuple():
- self.__is_array = True
- else:
- self.__is_array = False
-
- self.__fields = []
- if data is not None:
- if data.dtype.fields is not None:
- fields = sorted(data.dtype.fields.items(), key=lambda e: e[1][1])
- for name, (dtype, _index) in fields:
- if dtype.shape != tuple():
- keys = itertools.product(*[range(x) for x in dtype.shape])
- for key in keys:
- label = "%s%s" % (name, list(key))
- array_key = (name, key)
- self.__fields.append((label, array_key))
- else:
- self.__fields.append((name, (name,)))
- else:
- self.__fields = None
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
- else:
- self.reset()
-
- def arrayData(self):
- """Returns the internal data.
-
- :rtype: numpy.ndarray of h5py.Dataset
- """
- return self.__data
-
- def setFormatter(self, formatter):
- """Set the formatter object to be used to display data from the model
-
- :param TextFormatter formatter: Formatter to use
- """
- if formatter is self.__formatter:
- return
-
- if qt.qVersion() > "4.6":
- self.beginResetModel()
-
- if self.__formatter is not None:
- self.__formatter.formatChanged.disconnect(self.__formatChanged)
-
- self.__formatter = formatter
- self.__editFormatter = TextFormatter(formatter)
- self.__editFormatter.setUseQuoteForText(False)
-
- if self.__formatter is not None:
- self.__formatter.formatChanged.connect(self.__formatChanged)
-
- if qt.qVersion() > "4.6":
- self.endResetModel()
- else:
- self.reset()
-
- def getFormatter(self):
- """Returns the text formatter used.
-
- :rtype: TextFormatter
- """
- return self.__formatter
-
- def __formatChanged(self):
- """Called when the format changed.
- """
- self.__editFormatter = TextFormatter(self, self.getFormatter())
- self.__editFormatter.setUseQuoteForText(False)
- self.reset()
-
-
-class _ShowEditorProxyModel(qt.QIdentityProxyModel):
- """
- Allow to custom the flag edit of the model
- """
-
- def __init__(self, parent=None):
- """
- Constructor
-
- :param qt.QObject arent: parent object
- """
- super(_ShowEditorProxyModel, self).__init__(parent)
- self.__forceEditable = False
-
- def flags(self, index):
- flag = qt.QIdentityProxyModel.flags(self, index)
- if self.__forceEditable:
- flag = flag | qt.Qt.ItemIsEditable
- return flag
-
- def forceCellEditor(self, show):
- """
- Enable the editable flag to allow to display cell editor.
- """
- if self.__forceEditable == show:
- return
- self.beginResetModel()
- self.__forceEditable = show
- self.endResetModel()
-
-
-class RecordTableView(qt.QTableView):
- """TableView using DatabaseTableModel as default model.
- """
- def __init__(self, parent=None):
- """
- Constructor
-
- :param qt.QWidget parent: parent QWidget
- """
- qt.QTableView.__init__(self, parent)
-
- model = _ShowEditorProxyModel(self)
- self._model = RecordTableModel()
- model.setSourceModel(self._model)
- self.setModel(model)
-
- self.__multilineView = _MultiLineItem(self)
- self.setEditTriggers(qt.QAbstractItemView.AllEditTriggers)
- self._copyAction = CopySelectedCellsAction(self)
- self.addAction(self._copyAction)
-
- def copy(self):
- self._copyAction.trigger()
-
- def setArrayData(self, data):
- model = self.model()
- sourceModel = model.sourceModel()
- sourceModel.setArrayData(data)
-
- if data is not None:
- if issubclass(data.dtype.type, (numpy.string_, numpy.unicode_)):
- # TODO it would be nice to also fix fields
- # but using it only for string array is already very useful
- self.setItemDelegateForColumn(0, self.__multilineView)
- model.forceCellEditor(True)
- else:
- self.setItemDelegateForColumn(0, None)
- model.forceCellEditor(False)
diff --git a/silx/gui/data/TextFormatter.py b/silx/gui/data/TextFormatter.py
deleted file mode 100644
index 8fd7c7c..0000000
--- a/silx/gui/data/TextFormatter.py
+++ /dev/null
@@ -1,395 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This package provides a class sharred by widget from the
-data module to format data as text in the same way."""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "24/07/2018"
-
-import logging
-import numbers
-
-import numpy
-import six
-
-from silx.gui import qt
-
-import h5py
-
-
-_logger = logging.getLogger(__name__)
-
-
-class TextFormatter(qt.QObject):
- """Formatter to convert data to string.
-
- The method :meth:`toString` returns a formatted string from an input data
- using parameters set to this object.
-
- It support most python and numpy data, expecting dictionary. Unsupported
- data are displayed using the string representation of the object (`str`).
-
- It provides a set of parameters to custom the formatting of integer and
- float values (:meth:`setIntegerFormat`, :meth:`setFloatFormat`).
-
- It also allows to custom the use of quotes to display text data
- (:meth:`setUseQuoteForText`), and custom unit used to display imaginary
- numbers (:meth:`setImaginaryUnit`).
-
- The object emit an event `formatChanged` every time a parametter is
- changed.
- """
-
- formatChanged = qt.Signal()
- """Emitted when properties of the formatter change."""
-
- def __init__(self, parent=None, formatter=None):
- """
- Constructor
-
- :param qt.QObject parent: Owner of the object
- :param TextFormatter formatter: Instantiate this object from the
- formatter
- """
- qt.QObject.__init__(self, parent)
- if formatter is not None:
- self.__integerFormat = formatter.integerFormat()
- self.__floatFormat = formatter.floatFormat()
- self.__useQuoteForText = formatter.useQuoteForText()
- self.__imaginaryUnit = formatter.imaginaryUnit()
- self.__enumFormat = formatter.enumFormat()
- else:
- self.__integerFormat = "%d"
- self.__floatFormat = "%g"
- self.__useQuoteForText = True
- self.__imaginaryUnit = u"j"
- self.__enumFormat = u"%(name)s(%(value)d)"
-
- def integerFormat(self):
- """Returns the format string controlling how the integer data
- are formated by this object.
-
- This is the C-style format string used by python when formatting
- strings with the modulus operator.
-
- :rtype: str
- """
- return self.__integerFormat
-
- def setIntegerFormat(self, value):
- """Set format string controlling how the integer data are
- formated by this object.
-
- :param str value: Format string (e.g. "%d", "%i", "%08i").
- This is the C-style format string used by python when formatting
- strings with the modulus operator.
- """
- if self.__integerFormat == value:
- return
- self.__integerFormat = value
- self.formatChanged.emit()
-
- def floatFormat(self):
- """Returns the format string controlling how the floating-point data
- are formated by this object.
-
- This is the C-style format string used by python when formatting
- strings with the modulus operator.
-
- :rtype: str
- """
- return self.__floatFormat
-
- def setFloatFormat(self, value):
- """Set format string controlling how the floating-point data are
- formated by this object.
-
- :param str value: Format string (e.g. "%.3f", "%d", "%-10.2f",
- "%10.3e").
- This is the C-style format string used by python when formatting
- strings with the modulus operator.
- """
- if self.__floatFormat == value:
- return
- self.__floatFormat = value
- self.formatChanged.emit()
-
- def useQuoteForText(self):
- """Returns true if the string data are formatted using double quotes.
-
- Else, no quotes are used.
- """
- return self.__integerFormat
-
- def setUseQuoteForText(self, useQuote):
- """Set the use of quotes to delimit string data.
-
- :param bool useQuote: True to use quotes.
- """
- if self.__useQuoteForText == useQuote:
- return
- self.__useQuoteForText = useQuote
- self.formatChanged.emit()
-
- def imaginaryUnit(self):
- """Returns the unit display for imaginary numbers.
-
- :rtype: str
- """
- return self.__imaginaryUnit
-
- def setImaginaryUnit(self, imaginaryUnit):
- """Set the unit display for imaginary numbers.
-
- :param str imaginaryUnit: Unit displayed after imaginary numbers
- """
- if self.__imaginaryUnit == imaginaryUnit:
- return
- self.__imaginaryUnit = imaginaryUnit
- self.formatChanged.emit()
-
- def setEnumFormat(self, value):
- """Set format string controlling how the enum data are
- formated by this object.
-
- :param str value: Format string (e.g. "%(name)s(%(value)d)").
- This is the C-style format string used by python when formatting
- strings with the modulus operator.
- """
- if self.__enumFormat == value:
- return
- self.__enumFormat = value
- self.formatChanged.emit()
-
- def enumFormat(self):
- """Returns the format string controlling how the enum data
- are formated by this object.
-
- This is the C-style format string used by python when formatting
- strings with the modulus operator.
-
- :rtype: str
- """
- return self.__enumFormat
-
- def __formatText(self, text):
- if self.__useQuoteForText:
- text = "\"%s\"" % text.replace("\\", "\\\\").replace("\"", "\\\"")
- return text
-
- def __formatBinary(self, data):
- if isinstance(data, numpy.void):
- if six.PY2:
- data = [ord(d) for d in data.data]
- else:
- data = data.item()
- if isinstance(data, numpy.ndarray):
- # Before numpy 1.15.0 the item API was returning a numpy array
- data = data.astype(numpy.uint8)
- else:
- # Now it is supposed to be a bytes type
- pass
- elif six.PY2:
- data = [ord(d) for d in data]
- # In python3 data is already a bytes array
- data = ["\\x%02X" % d for d in data]
- if self.__useQuoteForText:
- return "b\"%s\"" % "".join(data)
- else:
- return "".join(data)
-
- def __formatSafeAscii(self, data):
- if six.PY2:
- data = [ord(d) for d in data]
- data = [chr(d) if (d > 0x20 and d < 0x7F) else "\\x%02X" % d for d in data]
- if self.__useQuoteForText:
- data = [c if c != '"' else "\\" + c for c in data]
- return "b\"%s\"" % "".join(data)
- else:
- return "".join(data)
-
- def __formatCharString(self, data):
- """Format text of char.
-
- From the specifications we expect to have ASCII, but we also allow
- CP1252 in some ceases as fallback.
-
- If no encoding fits, it will display a readable ASCII chars, with
- escaped chars (using the python syntax) for non decoded characters.
-
- :param data: A binary string of char expected in ASCII
- :rtype: str
- """
- try:
- text = "%s" % data.decode("ascii")
- return self.__formatText(text)
- except UnicodeDecodeError:
- # Here we can spam errors, this is definitly a badly
- # generated file
- _logger.error("Invalid ASCII string %s.", data)
- if data == b"\xB0":
- _logger.error("Fallback using cp1252 encoding")
- return self.__formatText(u"\u00B0")
- return self.__formatSafeAscii(data)
-
- def __formatH5pyObject(self, data, dtype):
- # That's an HDF5 object
- ref = h5py.check_dtype(ref=dtype)
- if ref is not None:
- if bool(data):
- return "REF"
- else:
- return "NULL_REF"
- vlen = h5py.check_dtype(vlen=dtype)
- if vlen is not None:
- if vlen == six.text_type:
- # HDF5 UTF8
- # With h5py>=3 reading dataset returns bytes
- if isinstance(data, (bytes, numpy.bytes_)):
- try:
- data = data.decode("utf-8")
- except UnicodeDecodeError:
- self.__formatSafeAscii(data)
- return self.__formatText(data)
- elif vlen == six.binary_type:
- # HDF5 ASCII
- return self.__formatCharString(data)
- elif isinstance(vlen, numpy.dtype):
- return self.toString(data, vlen)
- return None
-
- def toString(self, data, dtype=None):
- """Format a data into a string using formatter options
-
- :param object data: Data to render
- :param dtype: enforce a dtype (mostly used to remember the h5py dtype,
- special h5py dtypes are not propagated from array to items)
- :rtype: str
- """
- if isinstance(data, tuple):
- text = [self.toString(d) for d in data]
- return "(" + " ".join(text) + ")"
- elif isinstance(data, list):
- text = [self.toString(d) for d in data]
- return "[" + " ".join(text) + "]"
- elif isinstance(data, numpy.ndarray):
- if dtype is None:
- dtype = data.dtype
- if data.shape == ():
- # it is a scaler
- return self.toString(data[()], dtype)
- else:
- text = [self.toString(d, dtype) for d in data]
- return "[" + " ".join(text) + "]"
- if dtype is not None and dtype.kind == 'O':
- text = self.__formatH5pyObject(data, dtype)
- if text is not None:
- return text
- elif isinstance(data, numpy.void):
- if dtype is None:
- dtype = data.dtype
- if dtype.fields is not None:
- text = []
- for index, field in enumerate(dtype.fields.items()):
- text.append(field[0] + ":" + self.toString(data[index], field[1][0]))
- return "(" + " ".join(text) + ")"
- return self.__formatBinary(data)
- elif isinstance(data, (numpy.unicode_, six.text_type)):
- return self.__formatText(data)
- elif isinstance(data, (numpy.string_, six.binary_type)):
- if dtype is None and hasattr(data, "dtype"):
- dtype = data.dtype
- if dtype is not None:
- # Maybe a sub item from HDF5
- if dtype.kind == 'S':
- return self.__formatCharString(data)
- elif dtype.kind == 'O':
- text = self.__formatH5pyObject(data, dtype)
- if text is not None:
- return text
- try:
- # Try ascii/utf-8
- text = "%s" % data.decode("utf-8")
- return self.__formatText(text)
- except UnicodeDecodeError:
- pass
- return self.__formatBinary(data)
- elif isinstance(data, six.string_types):
- text = "%s" % data
- return self.__formatText(text)
- elif isinstance(data, (numpy.integer)):
- if dtype is None:
- dtype = data.dtype
- enumType = h5py.check_dtype(enum=dtype)
- if enumType is not None:
- for key, value in enumType.items():
- if value == data:
- result = {}
- result["name"] = key
- result["value"] = data
- return self.__enumFormat % result
- return self.__integerFormat % data
- elif isinstance(data, (numbers.Integral)):
- return self.__integerFormat % data
- elif isinstance(data, (numbers.Real, numpy.floating)):
- # It have to be done before complex checking
- return self.__floatFormat % data
- elif isinstance(data, (numpy.complexfloating, numbers.Complex)):
- text = ""
- if data.real != 0:
- text += self.__floatFormat % data.real
- if data.real != 0 and data.imag != 0:
- if data.imag < 0:
- template = self.__floatFormat + " - " + self.__floatFormat + self.__imaginaryUnit
- params = (data.real, -data.imag)
- else:
- template = self.__floatFormat + " + " + self.__floatFormat + self.__imaginaryUnit
- params = (data.real, data.imag)
- else:
- if data.imag != 0:
- template = self.__floatFormat + self.__imaginaryUnit
- params = (data.imag)
- else:
- template = self.__floatFormat
- params = (data.real)
- return template % params
- elif isinstance(data, h5py.h5r.Reference):
- dtype = h5py.special_dtype(ref=h5py.Reference)
- text = self.__formatH5pyObject(data, dtype)
- return text
- elif isinstance(data, h5py.h5r.RegionReference):
- dtype = h5py.special_dtype(ref=h5py.RegionReference)
- text = self.__formatH5pyObject(data, dtype)
- return text
- elif isinstance(data, numpy.object_) or dtype is not None:
- if dtype is None:
- dtype = data.dtype
- text = self.__formatH5pyObject(data, dtype)
- if text is not None:
- return text
- # That's a numpy object
- return str(data)
- return str(data)
diff --git a/silx/gui/data/test/__init__.py b/silx/gui/data/test/__init__.py
deleted file mode 100644
index 08c044b..0000000
--- a/silx/gui/data/test/__init__.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-import unittest
-
-from . import test_arraywidget
-from . import test_numpyaxesselector
-from . import test_dataviewer
-from . import test_textformatter
-
-__authors__ = ["V. Valls", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "24/01/2017"
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTests(
- [test_arraywidget.suite(),
- test_numpyaxesselector.suite(),
- test_dataviewer.suite(),
- test_textformatter.suite(),
- ])
- return test_suite
diff --git a/silx/gui/data/test/test_arraywidget.py b/silx/gui/data/test/test_arraywidget.py
deleted file mode 100644
index 87081ed..0000000
--- a/silx/gui/data/test/test_arraywidget.py
+++ /dev/null
@@ -1,329 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "05/12/2016"
-
-import os
-import tempfile
-import unittest
-
-import numpy
-
-from silx.gui import qt
-from silx.gui.data import ArrayTableWidget
-from silx.gui.data.ArrayTableModel import ArrayTableModel
-from silx.gui.utils.testutils import TestCaseQt
-
-import h5py
-
-
-class TestArrayWidget(TestCaseQt):
- """Basic test for ArrayTableWidget with a numpy array"""
- def setUp(self):
- super(TestArrayWidget, self).setUp()
- self.aw = ArrayTableWidget.ArrayTableWidget()
-
- def tearDown(self):
- del self.aw
- super(TestArrayWidget, self).tearDown()
-
- def testShow(self):
- """test for errors"""
- self.aw.show()
- self.qWaitForWindowExposed(self.aw)
-
- def testSetData0D(self):
- a = 1
- self.aw.setArrayData(a)
- b = self.aw.getData(copy=True)
-
- self.assertTrue(numpy.array_equal(a, b))
-
- # scalar/0D data has no frame index
- self.assertEqual(len(self.aw.model._index), 0)
- # and no perspective
- self.assertEqual(len(self.aw.model._perspective), 0)
-
- def testSetData1D(self):
- a = [1, 2]
- self.aw.setArrayData(a)
- b = self.aw.getData(copy=True)
-
- self.assertTrue(numpy.array_equal(a, b))
-
- # 1D data has no frame index
- self.assertEqual(len(self.aw.model._index), 0)
- # and no perspective
- self.assertEqual(len(self.aw.model._perspective), 0)
-
- def testSetData4D(self):
- a = numpy.reshape(numpy.linspace(0.213, 1.234, 1250),
- (5, 5, 5, 10))
- self.aw.setArrayData(a)
-
- # default perspective (0, 1)
- self.assertEqual(list(self.aw.model._perspective),
- [0, 1])
- self.aw.setPerspective((1, 3))
- self.assertEqual(list(self.aw.model._perspective),
- [1, 3])
-
- b = self.aw.getData(copy=True)
- self.assertTrue(numpy.array_equal(a, b))
-
- # 4D data has a 2-tuple as frame index
- self.assertEqual(len(self.aw.model._index), 2)
- # default index is (0, 0)
- self.assertEqual(list(self.aw.model._index),
- [0, 0])
- self.aw.setFrameIndex((3, 1))
-
- self.assertEqual(list(self.aw.model._index),
- [3, 1])
-
- def testColors(self):
- a = numpy.arange(256, dtype=numpy.uint8)
- self.aw.setArrayData(a)
-
- bgcolor = numpy.empty(a.shape + (3,), dtype=numpy.uint8)
- # Black & white palette
- bgcolor[..., 0] = a
- bgcolor[..., 1] = a
- bgcolor[..., 2] = a
-
- fgcolor = numpy.bitwise_xor(bgcolor, 255)
-
- self.aw.setArrayColors(bgcolor, fgcolor)
-
- # test colors are as expected in model
- for i in range(256):
- # all RGB channels for BG equal to data value
- self.assertEqual(
- self.aw.model.data(self.aw.model.index(0, i),
- role=qt.Qt.BackgroundRole),
- qt.QColor(i, i, i),
- "Unexpected background color"
- )
-
- # all RGB channels for FG equal to XOR(data value, 255)
- self.assertEqual(
- self.aw.model.data(self.aw.model.index(0, i),
- role=qt.Qt.ForegroundRole),
- qt.QColor(i ^ 255, i ^ 255, i ^ 255),
- "Unexpected text color"
- )
-
- # test colors are reset to None when a new data array is loaded
- # with different shape
- self.aw.setArrayData(numpy.arange(300))
-
- for i in range(300):
- # all RGB channels for BG equal to data value
- self.assertIsNone(
- self.aw.model.data(self.aw.model.index(0, i),
- role=qt.Qt.BackgroundRole))
-
- def testDefaultFlagNotEditable(self):
- """editable should be False by default, in setArrayData"""
- self.aw.setArrayData([[0]])
- idx = self.aw.model.createIndex(0, 0)
- # model is editable
- self.assertFalse(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
-
- def testFlagEditable(self):
- self.aw.setArrayData([[0]], editable=True)
- idx = self.aw.model.createIndex(0, 0)
- # model is editable
- self.assertTrue(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
-
- def testFlagNotEditable(self):
- self.aw.setArrayData([[0]], editable=False)
- idx = self.aw.model.createIndex(0, 0)
- # model is editable
- self.assertFalse(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
-
- def testReferenceReturned(self):
- """when setting the data with copy=False and
- retrieving it with getData(copy=False), we should recover
- the same original object.
- """
- # n-D (n >=2)
- a0 = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
- (10, 10, 10))
- self.aw.setArrayData(a0, copy=False)
- a1 = self.aw.getData(copy=False)
-
- self.assertIs(a0, a1)
-
- # 1D
- b0 = numpy.linspace(0.213, 1.234, 1000)
- self.aw.setArrayData(b0, copy=False)
- b1 = self.aw.getData(copy=False)
- self.assertIs(b0, b1)
-
- def testClipping(self):
- """Test clipping of large arrays"""
- self.aw.show()
- self.qWaitForWindowExposed(self.aw)
-
- data = numpy.arange(ArrayTableModel.MAX_NUMBER_OF_SECTIONS + 10)
-
- for shape in [(1, -1), (-1, 1)]:
- with self.subTest(shape=shape):
- self.aw.setArrayData(data.reshape(shape), editable=True)
- self.qapp.processEvents()
-
-
-class TestH5pyArrayWidget(TestCaseQt):
- """Basic test for ArrayTableWidget with a dataset.
-
- Test flags, for dataset open in read-only or read-write modes"""
- def setUp(self):
- super(TestH5pyArrayWidget, self).setUp()
- self.aw = ArrayTableWidget.ArrayTableWidget()
- self.data = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
- (10, 10, 10))
- # create an h5py file with a dataset
- self.tempdir = tempfile.mkdtemp()
- self.h5_fname = os.path.join(self.tempdir, "array.h5")
- h5f = h5py.File(self.h5_fname, mode='w')
- h5f["my_array"] = self.data
- h5f["my_scalar"] = 3.14
- h5f["my_1D_array"] = numpy.array(numpy.arange(1000))
- h5f.close()
-
- def tearDown(self):
- del self.aw
- os.unlink(self.h5_fname)
- os.rmdir(self.tempdir)
- super(TestH5pyArrayWidget, self).tearDown()
-
- def testShow(self):
- self.aw.show()
- self.qWaitForWindowExposed(self.aw)
-
- def testReadOnly(self):
- """Open H5 dataset in read-only mode, ensure the model is not editable."""
- h5f = h5py.File(self.h5_fname, "r")
- a = h5f["my_array"]
- # ArrayTableModel relies on following condition
- self.assertTrue(a.file.mode == "r")
-
- self.aw.setArrayData(a, copy=False, editable=True)
-
- self.assertIsInstance(a, h5py.Dataset) # simple sanity check
- # internal representation must be a reference to original data (copy=False)
- self.assertIsInstance(self.aw.model._array, h5py.Dataset)
- self.assertTrue(self.aw.model._array.file.mode == "r")
-
- b = self.aw.getData()
- self.assertTrue(numpy.array_equal(self.data, b))
-
- # model must have detected read-only dataset and disabled editing
- self.assertFalse(self.aw.model._editable)
- idx = self.aw.model.createIndex(0, 0)
- self.assertFalse(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
-
- # force editing read-only datasets raises IOError
- self.assertRaises(IOError, self.aw.model.setData,
- idx, 123.4, role=qt.Qt.EditRole)
- h5f.close()
-
- def testReadWrite(self):
- h5f = h5py.File(self.h5_fname, "r+")
- a = h5f["my_array"]
- self.assertTrue(a.file.mode == "r+")
-
- self.aw.setArrayData(a, copy=False, editable=True)
- b = self.aw.getData(copy=False)
- self.assertTrue(numpy.array_equal(self.data, b))
-
- idx = self.aw.model.createIndex(0, 0)
- # model is editable
- self.assertTrue(
- self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
- h5f.close()
-
- def testSetData0D(self):
- h5f = h5py.File(self.h5_fname, "r+")
- a = h5f["my_scalar"]
- self.aw.setArrayData(a)
- b = self.aw.getData(copy=True)
-
- self.assertTrue(numpy.array_equal(a, b))
-
- h5f.close()
-
- def testSetData1D(self):
- h5f = h5py.File(self.h5_fname, "r+")
- a = h5f["my_1D_array"]
- self.aw.setArrayData(a)
- b = self.aw.getData(copy=True)
-
- self.assertTrue(numpy.array_equal(a, b))
-
- h5f.close()
-
- def testReferenceReturned(self):
- """when setting the data with copy=False and
- retrieving it with getData(copy=False), we should recover
- the same original object.
-
- This only works for array with at least 2D. For 1D and 0D
- arrays, a view is created at some point, which in the case
- of an hdf5 dataset creates a copy."""
- h5f = h5py.File(self.h5_fname, "r+")
-
- # n-D
- a0 = h5f["my_array"]
- self.aw.setArrayData(a0, copy=False)
- a1 = self.aw.getData(copy=False)
- self.assertIs(a0, a1)
-
- # 1D
- b0 = h5f["my_1D_array"]
- self.aw.setArrayData(b0, copy=False)
- b1 = self.aw.getData(copy=False)
- self.assertIs(b0, b1)
-
- h5f.close()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestArrayWidget))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestH5pyArrayWidget))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/data/test/test_dataviewer.py b/silx/gui/data/test/test_dataviewer.py
deleted file mode 100644
index dd01dd6..0000000
--- a/silx/gui/data/test/test_dataviewer.py
+++ /dev/null
@@ -1,314 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "19/02/2019"
-
-import os
-import tempfile
-import unittest
-from contextlib import contextmanager
-
-import numpy
-from ..DataViewer import DataViewer
-from ..DataViews import DataView
-from .. import DataViews
-
-from silx.gui import qt
-
-from silx.gui.data.DataViewerFrame import DataViewerFrame
-from silx.gui.utils.testutils import SignalListener
-from silx.gui.utils.testutils import TestCaseQt
-
-import h5py
-
-
-class _DataViewMock(DataView):
- """Dummy view to display nothing"""
-
- def __init__(self, parent):
- DataView.__init__(self, parent)
-
- def axesNames(self, data, info):
- return []
-
- def createWidget(self, parent):
- return qt.QLabel(parent)
-
- def getDataPriority(self, data, info):
- return 0
-
-
-class AbstractDataViewerTests(TestCaseQt):
-
- def create_widget(self):
- # Avoid to raise an error when testing the full module
- self.skipTest("Not implemented")
-
- @contextmanager
- def h5_temporary_file(self):
- # create tmp file
- fd, tmp_name = tempfile.mkstemp(suffix=".h5")
- os.close(fd)
- data = numpy.arange(3 * 3 * 3)
- data.shape = 3, 3, 3
- # create h5 data
- h5file = h5py.File(tmp_name, "w")
- h5file["data"] = data
- yield h5file
- # clean up
- h5file.close()
- os.unlink(tmp_name)
-
- def test_text_data(self):
- data_list = ["aaa", int, 8, self]
- widget = self.create_widget()
- for data in data_list:
- widget.setData(data)
- self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
-
- def test_plot_1d_data(self):
- data = numpy.arange(3 ** 1)
- data.shape = [3] * 1
- widget = self.create_widget()
- widget.setData(data)
- availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
- self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
- self.assertIn(DataViews.PLOT1D_MODE, availableModes)
-
- def test_image_data(self):
- data = numpy.arange(3 ** 2)
- data.shape = [3] * 2
- widget = self.create_widget()
- widget.setData(data)
- availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
- self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
- self.assertIn(DataViews.IMAGE_MODE, availableModes)
-
- def test_image_bool(self):
- data = numpy.zeros((10, 10), dtype=bool)
- data[::2, ::2] = True
- widget = self.create_widget()
- widget.setData(data)
- availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
- self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
- self.assertIn(DataViews.IMAGE_MODE, availableModes)
-
- def test_image_complex_data(self):
- data = numpy.arange(3 ** 2, dtype=numpy.complex64)
- data.shape = [3] * 2
- widget = self.create_widget()
- widget.setData(data)
- availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
- self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
- self.assertIn(DataViews.IMAGE_MODE, availableModes)
-
- def test_plot_3d_data(self):
- data = numpy.arange(3 ** 3)
- data.shape = [3] * 3
- widget = self.create_widget()
- widget.setData(data)
- availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
- try:
- import silx.gui.plot3d # noqa
- self.assertIn(DataViews.PLOT3D_MODE, availableModes)
- except ImportError:
- self.assertIn(DataViews.STACK_MODE, availableModes)
- self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
-
- def test_array_1d_data(self):
- data = numpy.array(["aaa"] * (3 ** 1))
- data.shape = [3] * 1
- widget = self.create_widget()
- widget.setData(data)
- self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
-
- def test_array_2d_data(self):
- data = numpy.array(["aaa"] * (3 ** 2))
- data.shape = [3] * 2
- widget = self.create_widget()
- widget.setData(data)
- self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
-
- def test_array_4d_data(self):
- data = numpy.array(["aaa"] * (3 ** 4))
- data.shape = [3] * 4
- widget = self.create_widget()
- widget.setData(data)
- self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
-
- def test_record_4d_data(self):
- data = numpy.zeros(3 ** 4, dtype='3int8, float32, (2,3)float64')
- data.shape = [3] * 4
- widget = self.create_widget()
- widget.setData(data)
- self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
-
- def test_3d_h5_dataset(self):
- with self.h5_temporary_file() as h5file:
- dataset = h5file["data"]
- widget = self.create_widget()
- widget.setData(dataset)
-
- def test_data_event(self):
- listener = SignalListener()
- widget = self.create_widget()
- widget.dataChanged.connect(listener)
- widget.setData(10)
- widget.setData(None)
- self.assertEqual(listener.callCount(), 2)
-
- def test_display_mode_event(self):
- listener = SignalListener()
- widget = self.create_widget()
- widget.displayedViewChanged.connect(listener)
- widget.setData(10)
- widget.setData(None)
- modes = [v.modeId() for v in listener.arguments(argumentIndex=0)]
- self.assertEqual(modes, [DataViews.RAW_MODE, DataViews.EMPTY_MODE])
- listener.clear()
-
- def test_change_display_mode(self):
- data = numpy.arange(10 ** 4)
- data.shape = [10] * 4
- widget = self.create_widget()
- widget.setData(data)
- widget.setDisplayMode(DataViews.PLOT1D_MODE)
- self.assertEqual(widget.displayedView().modeId(), DataViews.PLOT1D_MODE)
- widget.setDisplayMode(DataViews.IMAGE_MODE)
- self.assertEqual(widget.displayedView().modeId(), DataViews.IMAGE_MODE)
- widget.setDisplayMode(DataViews.RAW_MODE)
- self.assertEqual(widget.displayedView().modeId(), DataViews.RAW_MODE)
- widget.setDisplayMode(DataViews.EMPTY_MODE)
- self.assertEqual(widget.displayedView().modeId(), DataViews.EMPTY_MODE)
-
- def test_create_default_views(self):
- widget = self.create_widget()
- views = widget.createDefaultViews()
- self.assertTrue(len(views) > 0)
-
- def test_add_view(self):
- widget = self.create_widget()
- view = _DataViewMock(widget)
- widget.addView(view)
- self.assertTrue(view in widget.availableViews())
- self.assertTrue(view in widget.currentAvailableViews())
-
- def test_remove_view(self):
- widget = self.create_widget()
- widget.setData("foobar")
- view = widget.currentAvailableViews()[0]
- widget.removeView(view)
- self.assertTrue(view not in widget.availableViews())
- self.assertTrue(view not in widget.currentAvailableViews())
-
- def test_replace_view(self):
- widget = self.create_widget()
- view = _DataViewMock(widget)
- widget.replaceView(DataViews.RAW_MODE,
- view)
- self.assertIsNone(widget.getViewFromModeId(DataViews.RAW_MODE))
- self.assertTrue(view in widget.availableViews())
- self.assertTrue(view in widget.currentAvailableViews())
-
- def test_replace_view_in_composite(self):
- # replace a view that is a child of a composite view
- widget = self.create_widget()
- view = _DataViewMock(widget)
- replaced = widget.replaceView(DataViews.NXDATA_INVALID_MODE,
- view)
- self.assertTrue(replaced)
- nxdata_view = widget.getViewFromModeId(DataViews.NXDATA_MODE)
- self.assertNotIn(DataViews.NXDATA_INVALID_MODE,
- [v.modeId() for v in nxdata_view.getViews()])
- self.assertTrue(view in nxdata_view.getViews())
-
-
-class TestDataViewer(AbstractDataViewerTests):
- def create_widget(self):
- return DataViewer()
-
-
-class TestDataViewerFrame(AbstractDataViewerTests):
- def create_widget(self):
- return DataViewerFrame()
-
-
-class TestDataView(TestCaseQt):
-
- def createComplexData(self):
- line = [1, 2j, 3 + 3j, 4]
- image = [line, line, line, line]
- cube = [image, image, image, image]
- data = numpy.array(cube, dtype=numpy.complex64)
- return data
-
- def createDataViewWithData(self, dataViewClass, data):
- viewer = dataViewClass(None)
- widget = viewer.getWidget()
- viewer.setData(data)
- return widget
-
- def testCurveWithComplex(self):
- data = self.createComplexData()
- dataViewClass = DataViews._Plot1dView
- widget = self.createDataViewWithData(dataViewClass, data[0, 0])
- self.qWaitForWindowExposed(widget)
-
- def testImageWithComplex(self):
- data = self.createComplexData()
- dataViewClass = DataViews._Plot2dView
- widget = self.createDataViewWithData(dataViewClass, data[0])
- self.qWaitForWindowExposed(widget)
-
- def testCubeWithComplex(self):
- self.skipTest("OpenGL widget not yet tested")
- try:
- import silx.gui.plot3d # noqa
- except ImportError:
- self.skipTest("OpenGL not available")
- data = self.createComplexData()
- dataViewClass = DataViews._Plot3dView
- widget = self.createDataViewWithData(dataViewClass, data)
- self.qWaitForWindowExposed(widget)
-
- def testImageStackWithComplex(self):
- data = self.createComplexData()
- dataViewClass = DataViews._StackView
- widget = self.createDataViewWithData(dataViewClass, data)
- self.qWaitForWindowExposed(widget)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTestsFromTestCase = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTestsFromTestCase(TestDataViewer))
- test_suite.addTest(loadTestsFromTestCase(TestDataViewerFrame))
- test_suite.addTest(loadTestsFromTestCase(TestDataView))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/data/test/test_numpyaxesselector.py b/silx/gui/data/test/test_numpyaxesselector.py
deleted file mode 100644
index d37cff7..0000000
--- a/silx/gui/data/test/test_numpyaxesselector.py
+++ /dev/null
@@ -1,161 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "29/01/2018"
-
-import os
-import tempfile
-import unittest
-from contextlib import contextmanager
-
-import numpy
-
-from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
-from silx.gui.utils.testutils import SignalListener
-from silx.gui.utils.testutils import TestCaseQt
-
-import h5py
-
-
-class TestNumpyAxesSelector(TestCaseQt):
-
- def test_creation(self):
- data = numpy.arange(3 * 3 * 3)
- data.shape = 3, 3, 3
- widget = NumpyAxesSelector()
- widget.setVisible(True)
-
- def test_none(self):
- data = numpy.arange(3 * 3 * 3)
- widget = NumpyAxesSelector()
- widget.setData(data)
- widget.setData(None)
- result = widget.selectedData()
- self.assertIsNone(result)
-
- def test_output_samedim(self):
- data = numpy.arange(3 * 3 * 3)
- data.shape = 3, 3, 3
- expectedResult = data
-
- widget = NumpyAxesSelector()
- widget.setAxisNames(["x", "y", "z"])
- widget.setData(data)
- result = widget.selectedData()
- self.assertTrue(numpy.array_equal(result, expectedResult))
-
- def test_output_moredim(self):
- data = numpy.arange(3 * 3 * 3 * 3)
- data.shape = 3, 3, 3, 3
- expectedResult = data
-
- widget = NumpyAxesSelector()
- widget.setAxisNames(["x", "y", "z", "boum"])
- widget.setData(data[0])
- result = widget.selectedData()
- self.assertIsNone(result)
- widget.setData(data)
- result = widget.selectedData()
- self.assertTrue(numpy.array_equal(result, expectedResult))
-
- def test_output_lessdim(self):
- data = numpy.arange(3 * 3 * 3)
- data.shape = 3, 3, 3
- expectedResult = data[0]
-
- widget = NumpyAxesSelector()
- widget.setAxisNames(["y", "x"])
- widget.setData(data)
- result = widget.selectedData()
- self.assertTrue(numpy.array_equal(result, expectedResult))
-
- def test_output_1dim(self):
- data = numpy.arange(3 * 3 * 3)
- data.shape = 3, 3, 3
- expectedResult = data[0, 0, 0]
-
- widget = NumpyAxesSelector()
- widget.setData(data)
- result = widget.selectedData()
- self.assertTrue(numpy.array_equal(result, expectedResult))
-
- @contextmanager
- def h5_temporary_file(self):
- # create tmp file
- fd, tmp_name = tempfile.mkstemp(suffix=".h5")
- os.close(fd)
- data = numpy.arange(3 * 3 * 3)
- data.shape = 3, 3, 3
- # create h5 data
- h5file = h5py.File(tmp_name, "w")
- h5file["data"] = data
- yield h5file
- # clean up
- h5file.close()
- os.unlink(tmp_name)
-
- def test_h5py_dataset(self):
- with self.h5_temporary_file() as h5file:
- dataset = h5file["data"]
- expectedResult = dataset[0]
-
- widget = NumpyAxesSelector()
- widget.setData(dataset)
- widget.setAxisNames(["y", "x"])
- result = widget.selectedData()
- self.assertTrue(numpy.array_equal(result, expectedResult))
-
- def test_data_event(self):
- data = numpy.arange(3 * 3 * 3)
- widget = NumpyAxesSelector()
- listener = SignalListener()
- widget.dataChanged.connect(listener)
- widget.setData(data)
- widget.setData(None)
- self.assertEqual(listener.callCount(), 2)
-
- def test_selected_data_event(self):
- data = numpy.arange(3 * 3 * 3)
- data.shape = 3, 3, 3
- widget = NumpyAxesSelector()
- listener = SignalListener()
- widget.selectionChanged.connect(listener)
- widget.setData(data)
- widget.setAxisNames(["x"])
- widget.setData(None)
- self.assertEqual(listener.callCount(), 3)
- listener.clear()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestNumpyAxesSelector))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/data/test/test_textformatter.py b/silx/gui/data/test/test_textformatter.py
deleted file mode 100644
index d3050bf..0000000
--- a/silx/gui/data/test/test_textformatter.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "12/12/2017"
-
-import unittest
-import shutil
-import tempfile
-
-import numpy
-import six
-
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.utils.testutils import SignalListener
-from ..TextFormatter import TextFormatter
-from silx.io.utils import h5py_read_dataset
-
-import h5py
-
-
-class TestTextFormatter(TestCaseQt):
-
- def test_copy(self):
- formatter = TextFormatter()
- copy = TextFormatter(formatter=formatter)
- self.assertIsNot(formatter, copy)
- copy.setFloatFormat("%.3f")
- self.assertEqual(formatter.integerFormat(), copy.integerFormat())
- self.assertNotEqual(formatter.floatFormat(), copy.floatFormat())
- self.assertEqual(formatter.useQuoteForText(), copy.useQuoteForText())
- self.assertEqual(formatter.imaginaryUnit(), copy.imaginaryUnit())
-
- def test_event(self):
- listener = SignalListener()
- formatter = TextFormatter()
- formatter.formatChanged.connect(listener)
- formatter.setFloatFormat("%.3f")
- formatter.setIntegerFormat("%03i")
- formatter.setUseQuoteForText(False)
- formatter.setImaginaryUnit("z")
- self.assertEqual(listener.callCount(), 4)
-
- def test_int(self):
- formatter = TextFormatter()
- formatter.setIntegerFormat("%05i")
- result = formatter.toString(512)
- self.assertEqual(result, "00512")
-
- def test_float(self):
- formatter = TextFormatter()
- formatter.setFloatFormat("%.3f")
- result = formatter.toString(1.3)
- self.assertEqual(result, "1.300")
-
- def test_complex(self):
- formatter = TextFormatter()
- formatter.setFloatFormat("%.1f")
- formatter.setImaginaryUnit("i")
- result = formatter.toString(1.0 + 5j)
- result = result.replace(" ", "")
- self.assertEqual(result, "1.0+5.0i")
-
- def test_string(self):
- formatter = TextFormatter()
- formatter.setIntegerFormat("%.1f")
- formatter.setImaginaryUnit("z")
- result = formatter.toString("toto")
- self.assertEqual(result, '"toto"')
-
- def test_numpy_void(self):
- formatter = TextFormatter()
- result = formatter.toString(numpy.void(b"\xFF"))
- self.assertEqual(result, 'b"\\xFF"')
-
- def test_char_cp1252(self):
- # degree character in cp1252
- formatter = TextFormatter()
- result = formatter.toString(numpy.bytes_(b"\xB0"))
- self.assertEqual(result, u'"\u00B0"')
-
-
-class TestTextFormatterWithH5py(TestCaseQt):
-
- @classmethod
- def setUpClass(cls):
- super(TestTextFormatterWithH5py, cls).setUpClass()
-
- cls.tmpDirectory = tempfile.mkdtemp()
- cls.h5File = h5py.File("%s/formatter.h5" % cls.tmpDirectory, mode="w")
- cls.formatter = TextFormatter()
-
- @classmethod
- def tearDownClass(cls):
- super(TestTextFormatterWithH5py, cls).tearDownClass()
- cls.h5File.close()
- cls.h5File = None
- shutil.rmtree(cls.tmpDirectory)
-
- def create_dataset(self, data, dtype=None):
- testName = "%s" % self.id()
- dataset = self.h5File.create_dataset(testName, data=data, dtype=dtype)
- return dataset
-
- def read_dataset(self, d):
- return self.formatter.toString(d[()], dtype=d.dtype)
-
- def testAscii(self):
- d = self.create_dataset(data=b"abc")
- result = self.read_dataset(d)
- self.assertEqual(result, '"abc"')
-
- def testUnicode(self):
- d = self.create_dataset(data=u"i\u2661cookies")
- result = self.read_dataset(d)
- self.assertEqual(len(result), 11)
- self.assertEqual(result, u'"i\u2661cookies"')
-
- def testBadAscii(self):
- d = self.create_dataset(data=b"\xF0\x9F\x92\x94")
- result = self.read_dataset(d)
- self.assertEqual(result, 'b"\\xF0\\x9F\\x92\\x94"')
-
- def testVoid(self):
- d = self.create_dataset(data=numpy.void(b"abc\xF0"))
- result = self.read_dataset(d)
- self.assertEqual(result, 'b"\\x61\\x62\\x63\\xF0"')
-
- def testEnum(self):
- dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
- d = numpy.array(42, dtype=dtype)
- d = self.create_dataset(data=d)
- result = self.read_dataset(d)
- self.assertEqual(result, 'BLUE(42)')
-
- def testRef(self):
- dtype = h5py.special_dtype(ref=h5py.Reference)
- d = numpy.array(self.h5File.ref, dtype=dtype)
- d = self.create_dataset(data=d)
- result = self.read_dataset(d)
- self.assertEqual(result, 'REF')
-
- def testArrayAscii(self):
- d = self.create_dataset(data=[b"abc"])
- result = self.read_dataset(d)
- self.assertEqual(result, '["abc"]')
-
- def testArrayUnicode(self):
- dtype = h5py.special_dtype(vlen=six.text_type)
- d = numpy.array([u"i\u2661cookies"], dtype=dtype)
- d = self.create_dataset(data=d)
- result = self.read_dataset(d)
- self.assertEqual(len(result), 13)
- self.assertEqual(result, u'["i\u2661cookies"]')
-
- def testArrayBadAscii(self):
- d = self.create_dataset(data=[b"\xF0\x9F\x92\x94"])
- result = self.read_dataset(d)
- self.assertEqual(result, '[b"\\xF0\\x9F\\x92\\x94"]')
-
- def testArrayVoid(self):
- d = self.create_dataset(data=numpy.void([b"abc\xF0"]))
- result = self.read_dataset(d)
- self.assertEqual(result, '[b"\\x61\\x62\\x63\\xF0"]')
-
- def testArrayEnum(self):
- dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
- d = numpy.array([42, 1, 100], dtype=dtype)
- d = self.create_dataset(data=d)
- result = self.read_dataset(d)
- self.assertEqual(result, '[BLUE(42) GREEN(1) 100]')
-
- def testArrayRef(self):
- dtype = h5py.special_dtype(ref=h5py.Reference)
- d = numpy.array([self.h5File.ref, None], dtype=dtype)
- d = self.create_dataset(data=d)
- result = self.read_dataset(d)
- self.assertEqual(result, '[REF NULL_REF]')
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestTextFormatter))
- test_suite.addTest(loadTests(TestTextFormatterWithH5py))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/dialog/AbstractDataFileDialog.py b/silx/gui/dialog/AbstractDataFileDialog.py
deleted file mode 100644
index 29e7bb5..0000000
--- a/silx/gui/dialog/AbstractDataFileDialog.py
+++ /dev/null
@@ -1,1742 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module contains an :class:`AbstractDataFileDialog`.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "05/03/2019"
-
-
-import sys
-import os
-import logging
-import functools
-from distutils.version import LooseVersion
-
-import numpy
-import six
-
-import silx.io.url
-from silx.gui import qt
-from silx.gui.hdf5.Hdf5TreeModel import Hdf5TreeModel
-from . import utils
-from .FileTypeComboBox import FileTypeComboBox
-
-import fabio
-
-
-_logger = logging.getLogger(__name__)
-
-
-DEFAULT_SIDEBAR_URL = True
-"""Set it to false to disable initilializing of the sidebar urls with the
-default Qt list. This could allow to disable a behaviour known to segfault on
-some version of PyQt."""
-
-
-class _IconProvider(object):
-
- FileDialogToParentDir = qt.QStyle.SP_CustomBase + 1
-
- FileDialogToParentFile = qt.QStyle.SP_CustomBase + 2
-
- def __init__(self):
- self.__iconFileDialogToParentDir = None
- self.__iconFileDialogToParentFile = None
-
- def _createIconToParent(self, standardPixmap):
- """
-
- FIXME: It have to be tested for some OS (arrow icon do not have always
- the same direction)
- """
- style = qt.QApplication.style()
- baseIcon = style.standardIcon(qt.QStyle.SP_FileDialogToParent)
- backgroundIcon = style.standardIcon(standardPixmap)
- icon = qt.QIcon()
-
- sizes = baseIcon.availableSizes()
- sizes = sorted(sizes, key=lambda s: s.height())
- sizes = filter(lambda s: s.height() < 100, sizes)
- sizes = list(sizes)
- if len(sizes) > 0:
- baseSize = sizes[-1]
- else:
- baseSize = baseIcon.availableSizes()[0]
- size = qt.QSize(baseSize.width(), baseSize.height() * 3 // 2)
-
- modes = [qt.QIcon.Normal, qt.QIcon.Disabled]
- for mode in modes:
- pixmap = qt.QPixmap(size)
- pixmap.fill(qt.Qt.transparent)
- painter = qt.QPainter(pixmap)
- painter.drawPixmap(0, 0, backgroundIcon.pixmap(baseSize, mode=mode))
- painter.drawPixmap(0, size.height() // 3, baseIcon.pixmap(baseSize, mode=mode))
- painter.end()
- icon.addPixmap(pixmap, mode=mode)
-
- return icon
-
- def getFileDialogToParentDir(self):
- if self.__iconFileDialogToParentDir is None:
- self.__iconFileDialogToParentDir = self._createIconToParent(qt.QStyle.SP_DirIcon)
- return self.__iconFileDialogToParentDir
-
- def getFileDialogToParentFile(self):
- if self.__iconFileDialogToParentFile is None:
- self.__iconFileDialogToParentFile = self._createIconToParent(qt.QStyle.SP_FileIcon)
- return self.__iconFileDialogToParentFile
-
- def icon(self, kind):
- if kind == self.FileDialogToParentDir:
- return self.getFileDialogToParentDir()
- elif kind == self.FileDialogToParentFile:
- return self.getFileDialogToParentFile()
- else:
- style = qt.QApplication.style()
- icon = style.standardIcon(kind)
- return icon
-
-
-class _SideBar(qt.QListView):
- """Sidebar containing shortcuts for common directories"""
-
- def __init__(self, parent=None):
- super(_SideBar, self).__init__(parent)
- self.__iconProvider = qt.QFileIconProvider()
- self.setUniformItemSizes(True)
- model = qt.QStandardItemModel(self)
- self.setModel(model)
- self._initModel()
- self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
-
- def iconProvider(self):
- return self.__iconProvider
-
- def _initModel(self):
- urls = self._getDefaultUrls()
- self.setUrls(urls)
-
- def _getDefaultUrls(self):
- """Returns the default shortcuts.
-
- It uses the default QFileDialog shortcuts if it is possible, else
- provides a link to the computer's root and the user's home.
-
- :rtype: List[str]
- """
- urls = []
- version = LooseVersion(qt.qVersion())
- feed_sidebar = True
-
- if not DEFAULT_SIDEBAR_URL:
- _logger.debug("Skip default sidebar URLs (from setted variable)")
- feed_sidebar = False
- elif version.version[0] == 4 and sys.platform in ["win32"]:
- # Avoid locking the GUI 5min in case of use of network driver
- _logger.debug("Skip default sidebar URLs (avoid lock when using network drivers)")
- feed_sidebar = False
- elif version < LooseVersion("5.11.2") and qt.BINDING == "PyQt5" and sys.platform in ["linux", "linux2"]:
- # Avoid segfault on PyQt5 + gtk
- _logger.debug("Skip default sidebar URLs (avoid PyQt5 segfault)")
- feed_sidebar = False
-
- if feed_sidebar:
- # Get default shortcut
- # There is no other way
- d = qt.QFileDialog(self)
- # Needed to be able to reach the sidebar urls
- d.setOption(qt.QFileDialog.DontUseNativeDialog, True)
- urls = d.sidebarUrls()
- d.deleteLater()
- d = None
-
- if len(urls) == 0:
- urls.append(qt.QUrl("file://"))
- urls.append(qt.QUrl.fromLocalFile(qt.QDir.homePath()))
-
- return urls
-
- def setSelectedPath(self, path):
- selected = None
- model = self.model()
- for i in range(model.rowCount()):
- index = model.index(i, 0)
- url = model.data(index, qt.Qt.UserRole)
- urlPath = url.toLocalFile()
- if path == urlPath:
- selected = index
-
- selectionModel = self.selectionModel()
- if selected is not None:
- selectionModel.setCurrentIndex(selected, qt.QItemSelectionModel.ClearAndSelect)
- else:
- selectionModel.clear()
-
- def setUrls(self, urls):
- model = self.model()
- model.clear()
-
- names = {}
- names[qt.QDir.rootPath()] = "Computer"
- names[qt.QDir.homePath()] = "Home"
-
- style = qt.QApplication.style()
- iconProvider = self.iconProvider()
- for url in urls:
- path = url.toLocalFile()
- if path == "":
- if sys.platform != "win32":
- url = qt.QUrl(qt.QDir.rootPath())
- name = "Computer"
- icon = style.standardIcon(qt.QStyle.SP_ComputerIcon)
- else:
- fileInfo = qt.QFileInfo(path)
- name = names.get(path, fileInfo.fileName())
- icon = iconProvider.icon(fileInfo)
-
- if icon.isNull():
- icon = style.standardIcon(qt.QStyle.SP_MessageBoxCritical)
-
- item = qt.QStandardItem()
- item.setText(name)
- item.setIcon(icon)
- item.setData(url, role=qt.Qt.UserRole)
- model.appendRow(item)
-
- def urls(self):
- result = []
- model = self.model()
- for i in range(model.rowCount()):
- index = model.index(i, 0)
- url = model.data(index, qt.Qt.UserRole)
- result.append(url)
- return result
-
- def sizeHint(self):
- index = self.model().index(0, 0)
- return self.sizeHintForIndex(index) + qt.QSize(2 * self.frameWidth(), 2 * self.frameWidth())
-
-
-class _Browser(qt.QStackedWidget):
-
- activated = qt.Signal(qt.QModelIndex)
- selected = qt.Signal(qt.QModelIndex)
- rootIndexChanged = qt.Signal(qt.QModelIndex)
-
- def __init__(self, parent, listView, detailView):
- qt.QStackedWidget.__init__(self, parent)
- self.__listView = listView
- self.__detailView = detailView
- self.insertWidget(0, self.__listView)
- self.insertWidget(1, self.__detailView)
-
- self.__listView.activated.connect(self.__emitActivated)
- self.__detailView.activated.connect(self.__emitActivated)
-
- def __emitActivated(self, index):
- self.activated.emit(index)
-
- def __emitSelected(self, selected, deselected):
- index = self.selectedIndex()
- if index is not None:
- self.selected.emit(index)
-
- def selectedIndex(self):
- if self.currentIndex() == 0:
- selectionModel = self.__listView.selectionModel()
- else:
- selectionModel = self.__detailView.selectionModel()
-
- if selectionModel is None:
- return None
-
- indexes = selectionModel.selectedIndexes()
- # Filter non-main columns
- indexes = [i for i in indexes if i.column() == 0]
- if len(indexes) == 1:
- index = indexes[0]
- return index
- return None
-
- def model(self):
- """Returns the current model."""
- if self.currentIndex() == 0:
- return self.__listView.model()
- else:
- return self.__detailView.model()
-
- def selectIndex(self, index):
- if self.currentIndex() == 0:
- selectionModel = self.__listView.selectionModel()
- else:
- selectionModel = self.__detailView.selectionModel()
- if selectionModel is None:
- return
- selectionModel.setCurrentIndex(index, qt.QItemSelectionModel.ClearAndSelect)
-
- def viewMode(self):
- """Returns the current view mode.
-
- :rtype: qt.QFileDialog.ViewMode
- """
- if self.currentIndex() == 0:
- return qt.QFileDialog.List
- elif self.currentIndex() == 1:
- return qt.QFileDialog.Detail
- else:
- assert(False)
-
- def setViewMode(self, mode):
- """Set the current view mode.
-
- :param qt.QFileDialog.ViewMode mode: The new view mode
- """
- if mode == qt.QFileDialog.Detail:
- self.showDetails()
- elif mode == qt.QFileDialog.List:
- self.showList()
- else:
- assert(False)
-
- def showList(self):
- self.__listView.show()
- self.__detailView.hide()
- self.setCurrentIndex(0)
-
- def showDetails(self):
- self.__listView.hide()
- self.__detailView.show()
- self.setCurrentIndex(1)
- self.__detailView.updateGeometry()
-
- def clear(self):
- self.__listView.setRootIndex(qt.QModelIndex())
- self.__detailView.setRootIndex(qt.QModelIndex())
- selectionModel = self.__listView.selectionModel()
- if selectionModel is not None:
- selectionModel.selectionChanged.disconnect()
- selectionModel.clear()
- selectionModel = self.__detailView.selectionModel()
- if selectionModel is not None:
- selectionModel.selectionChanged.disconnect()
- selectionModel.clear()
- self.__listView.setModel(None)
- self.__detailView.setModel(None)
-
- def setRootIndex(self, index, model=None):
- """Sets the root item to the item at the given index.
- """
- rootIndex = self.__listView.rootIndex()
- newModel = model or index.model()
- assert(newModel is not None)
-
- if rootIndex is None or rootIndex.model() is not newModel:
- # update the model
- selectionModel = self.__listView.selectionModel()
- if selectionModel is not None:
- selectionModel.selectionChanged.disconnect()
- selectionModel.clear()
- selectionModel = self.__detailView.selectionModel()
- if selectionModel is not None:
- selectionModel.selectionChanged.disconnect()
- selectionModel.clear()
- pIndex = qt.QPersistentModelIndex(index)
- self.__listView.setModel(newModel)
- # changing the model of the tree view change the index mapping
- # that is why we are using a persistance model index
- self.__detailView.setModel(newModel)
- index = newModel.index(pIndex.row(), pIndex.column(), pIndex.parent())
- selectionModel = self.__listView.selectionModel()
- selectionModel.selectionChanged.connect(self.__emitSelected)
- selectionModel = self.__detailView.selectionModel()
- selectionModel.selectionChanged.connect(self.__emitSelected)
-
- self.__listView.setRootIndex(index)
- self.__detailView.setRootIndex(index)
- self.rootIndexChanged.emit(index)
-
- def rootIndex(self):
- """Returns the model index of the model's root item. The root item is
- the parent item to the view's toplevel items. The root can be invalid.
- """
- return self.__listView.rootIndex()
-
- __serialVersion = 1
- """Store the current version of the serialized data"""
-
- def visualRect(self, index):
- """Returns the rectangle on the viewport occupied by the item at index.
-
- :param qt.QModelIndex index: An index
- :rtype: QRect
- """
- if self.currentIndex() == 0:
- return self.__listView.visualRect(index)
- else:
- return self.__detailView.visualRect(index)
-
- def viewport(self):
- """Returns the viewport widget.
-
- :param qt.QModelIndex index: An index
- :rtype: QRect
- """
- if self.currentIndex() == 0:
- return self.__listView.viewport()
- else:
- return self.__detailView.viewport()
-
- def restoreState(self, state):
- """Restores the dialogs's layout, history and current directory to the
- state specified.
-
- :param qt.QByeArray state: Stream containing the new state
- :rtype: bool
- """
- stream = qt.QDataStream(state, qt.QIODevice.ReadOnly)
-
- nameId = stream.readQString()
- if nameId != "Browser":
- _logger.warning("Stored state contains an invalid name id. Browser restoration cancelled.")
- return False
-
- version = stream.readInt32()
- if version != self.__serialVersion:
- _logger.warning("Stored state contains an invalid version. Browser restoration cancelled.")
- return False
-
- headerData = stream.readQVariant()
- self.__detailView.header().restoreState(headerData)
-
- viewMode = stream.readInt32()
- self.setViewMode(viewMode)
- return True
-
- def saveState(self):
- """Saves the state of the dialog's layout.
-
- :rtype: qt.QByteArray
- """
- data = qt.QByteArray()
- stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
-
- nameId = u"Browser"
- stream.writeQString(nameId)
- stream.writeInt32(self.__serialVersion)
- stream.writeQVariant(self.__detailView.header().saveState())
- stream.writeInt32(self.viewMode())
-
- return data
-
-
-class _FabioData(object):
-
- def __init__(self, fabioFile):
- self.__fabioFile = fabioFile
-
- @property
- def dtype(self):
- # Let say it is a valid type
- return numpy.dtype("float")
-
- @property
- def shape(self):
- if self.__fabioFile.nframes == 0:
- return None
- if self.__fabioFile.nframes == 1:
- return [slice(None), slice(None)]
- return [self.__fabioFile.nframes, slice(None), slice(None)]
-
- def __getitem__(self, selector):
- if self.__fabioFile.nframes == 1 and selector == tuple():
- return self.__fabioFile.data
- if isinstance(selector, tuple) and len(selector) == 1:
- selector = selector[0]
-
- if isinstance(selector, six.integer_types):
- if 0 <= selector < self.__fabioFile.nframes:
- if self.__fabioFile.nframes == 1:
- return self.__fabioFile.data
- else:
- frame = self.__fabioFile.getframe(selector)
- return frame.data
- else:
- raise ValueError("Invalid selector %s" % selector)
- else:
- raise TypeError("Unsupported selector type %s" % type(selector))
-
-
-class _PathEdit(qt.QLineEdit):
- pass
-
-
-class _CatchResizeEvent(qt.QObject):
-
- resized = qt.Signal(qt.QResizeEvent)
-
- def __init__(self, parent, target):
- super(_CatchResizeEvent, self).__init__(parent)
- self.__target = target
- self.__target_oldResizeEvent = self.__target.resizeEvent
- self.__target.resizeEvent = self.__resizeEvent
-
- def __resizeEvent(self, event):
- result = self.__target_oldResizeEvent(event)
- self.resized.emit(event)
- return result
-
-
-class AbstractDataFileDialog(qt.QDialog):
- """The `AbstractFileDialog` provides a generic GUI to create a custom dialog
- allowing to access to file resources like HDF5 files or HDF5 datasets.
-
- .. image:: img/abstractdatafiledialog.png
-
- The dialog contains:
-
- - Shortcuts: It provides few links to have a fast access of browsing
- locations.
- - Browser: It provides a display to browse throw the file system and inside
- HDF5 files or fabio files. A file format selector is provided.
- - URL: Display the URL available to reach the data using
- :meth:`silx.io.get_data`, :meth:`silx.io.open`.
- - Data selector: A widget to apply a sub selection of the browsed dataset.
- This widget can be provided, else nothing will be used.
- - Data preview: A widget to preview the selected data, which is the result
- of the filter from the data selector.
- This widget can be provided, else nothing will be used.
- - Preview's toolbar: Provides tools used to custom data preview or data
- selector.
- This widget can be provided, else nothing will be used.
- - Buttons to validate the dialog
- """
-
- _defaultIconProvider = None
- """Lazy loaded default icon provider"""
-
- def __init__(self, parent=None):
- super(AbstractDataFileDialog, self).__init__(parent)
- self._init()
-
- def _init(self):
- self.setWindowTitle("Open")
-
- self.__openedFiles = []
- """Store the list of files opened by the model itself."""
- # FIXME: It should be managed one by one by Hdf5Item itself
-
- self.__directory = None
- self.__directoryLoadedFilter = None
- self.__errorWhileLoadingFile = None
- self.__selectedFile = None
- self.__selectedData = None
- self.__currentHistory = []
- """Store history of URLs, last index one is the latest one"""
- self.__currentHistoryLocation = -1
- """Store the location in the history. Bigger is older"""
-
- self.__processing = 0
- """Number of asynchronous processing tasks"""
- self.__h5 = None
- self.__fabio = None
-
- if qt.qVersion() < "5.0":
- # On Qt4 it is needed to provide a safe file system model
- _logger.debug("Uses SafeFileSystemModel")
- from .SafeFileSystemModel import SafeFileSystemModel
- self.__fileModel = SafeFileSystemModel(self)
- else:
- # On Qt5 a safe icon provider is still needed to avoid freeze
- _logger.debug("Uses default QFileSystemModel with a SafeFileIconProvider")
- self.__fileModel = qt.QFileSystemModel(self)
- from .SafeFileIconProvider import SafeFileIconProvider
- iconProvider = SafeFileIconProvider()
- self.__fileModel.setIconProvider(iconProvider)
-
- # The common file dialog filter only on Mac OS X
- self.__fileModel.setNameFilterDisables(sys.platform == "darwin")
- self.__fileModel.setReadOnly(True)
- self.__fileModel.directoryLoaded.connect(self.__directoryLoaded)
-
- self.__dataModel = Hdf5TreeModel(self)
-
- self.__createWidgets()
- self.__initLayout()
- self.__showAsListView()
-
- path = os.getcwd()
- self.__fileModel_setRootPath(path)
-
- self.__clearData()
- self.__updatePath()
-
- # Update the file model filter
- self.__fileTypeCombo.setCurrentIndex(0)
- self.__filterSelected(0)
-
- # It is not possible to override the QObject destructor nor
- # to access to the content of the Python object with the `destroyed`
- # signal cause the Python method was already removed with the QWidget,
- # while the QObject still exists.
- # We use a static method plus explicit references to objects to
- # release. The callback do not use any ref to self.
- onDestroy = functools.partial(self._closeFileList, self.__openedFiles)
- self.destroyed.connect(onDestroy)
-
- @staticmethod
- def _closeFileList(fileList):
- """Static method to close explicit references to internal objects."""
- _logger.debug("Clear AbstractDataFileDialog")
- for obj in fileList:
- _logger.debug("Close file %s", obj.filename)
- obj.close()
- fileList[:] = []
-
- def done(self, result):
- self._clear()
- super(AbstractDataFileDialog, self).done(result)
-
- def _clear(self):
- """Explicit method to clear data stored in the dialog.
- After this call it is not anymore possible to use the widget.
-
- This method is triggered by the destruction of the object and the
- QDialog :meth:`done`. Then it can be triggered more than once.
- """
- _logger.debug("Clear dialog")
- self.__errorWhileLoadingFile = None
- self.__clearData()
- if self.__fileModel is not None:
- # Cache the directory before cleaning the model
- self.__directory = self.directory()
- self.__browser.clear()
- self.__closeFile()
- self.__fileModel = None
- self.__dataModel = None
-
- def hasPendingEvents(self):
- """Returns true if the dialog have asynchronous tasks working on the
- background."""
- return self.__processing > 0
-
- # User interface
-
- def __createWidgets(self):
- self.__sidebar = self._createSideBar()
- if self.__sidebar is not None:
- sideBarModel = self.__sidebar.selectionModel()
- sideBarModel.selectionChanged.connect(self.__shortcutSelected)
- self.__sidebar.setSelectionMode(qt.QAbstractItemView.SingleSelection)
-
- listView = qt.QListView(self)
- listView.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
- listView.setSelectionMode(qt.QAbstractItemView.SingleSelection)
- listView.setResizeMode(qt.QListView.Adjust)
- listView.setWrapping(True)
- listView.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
- listView.setContextMenuPolicy(qt.Qt.CustomContextMenu)
- utils.patchToConsumeReturnKey(listView)
-
- treeView = qt.QTreeView(self)
- treeView.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
- treeView.setSelectionMode(qt.QAbstractItemView.SingleSelection)
- treeView.setRootIsDecorated(False)
- treeView.setItemsExpandable(False)
- treeView.setSortingEnabled(True)
- treeView.header().setSortIndicator(0, qt.Qt.AscendingOrder)
- treeView.header().setStretchLastSection(False)
- treeView.setTextElideMode(qt.Qt.ElideMiddle)
- treeView.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
- treeView.setContextMenuPolicy(qt.Qt.CustomContextMenu)
- treeView.setDragDropMode(qt.QAbstractItemView.InternalMove)
- utils.patchToConsumeReturnKey(treeView)
-
- self.__browser = _Browser(self, listView, treeView)
- self.__browser.activated.connect(self.__browsedItemActivated)
- self.__browser.selected.connect(self.__browsedItemSelected)
- self.__browser.rootIndexChanged.connect(self.__rootIndexChanged)
- self.__browser.setObjectName("browser")
-
- self.__previewWidget = self._createPreviewWidget(self)
-
- self.__fileTypeCombo = FileTypeComboBox(self)
- self.__fileTypeCombo.setObjectName("fileTypeCombo")
- self.__fileTypeCombo.setDuplicatesEnabled(False)
- self.__fileTypeCombo.setSizeAdjustPolicy(qt.QComboBox.AdjustToMinimumContentsLength)
- self.__fileTypeCombo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
- self.__fileTypeCombo.activated[int].connect(self.__filterSelected)
- self.__fileTypeCombo.setFabioUrlSupproted(self._isFabioFilesSupported())
-
- self.__pathEdit = _PathEdit(self)
- self.__pathEdit.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
- self.__pathEdit.textChanged.connect(self.__textChanged)
- self.__pathEdit.setObjectName("url")
- utils.patchToConsumeReturnKey(self.__pathEdit)
-
- self.__buttons = qt.QDialogButtonBox(self)
- self.__buttons.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
- types = qt.QDialogButtonBox.Open | qt.QDialogButtonBox.Cancel
- self.__buttons.setStandardButtons(types)
- self.__buttons.button(qt.QDialogButtonBox.Cancel).setObjectName("cancel")
- self.__buttons.button(qt.QDialogButtonBox.Open).setObjectName("open")
-
- self.__buttons.accepted.connect(self.accept)
- self.__buttons.rejected.connect(self.reject)
-
- self.__browseToolBar = self._createBrowseToolBar()
- self.__backwardAction.setEnabled(False)
- self.__forwardAction.setEnabled(False)
- self.__fileDirectoryAction.setEnabled(False)
- self.__parentFileDirectoryAction.setEnabled(False)
-
- self.__selectorWidget = self._createSelectorWidget(self)
- if self.__selectorWidget is not None:
- self.__selectorWidget.selectionChanged.connect(self.__selectorWidgetChanged)
-
- self.__previewToolBar = self._createPreviewToolbar(self, self.__previewWidget, self.__selectorWidget)
-
- self.__dataIcon = qt.QLabel(self)
- self.__dataIcon.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
- self.__dataIcon.setScaledContents(True)
- self.__dataIcon.setMargin(2)
- self.__dataIcon.setAlignment(qt.Qt.AlignCenter)
-
- self.__dataInfo = qt.QLabel(self)
- self.__dataInfo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
-
- def _createSideBar(self):
- sidebar = _SideBar(self)
- sidebar.setObjectName("sidebar")
- return sidebar
-
- def iconProvider(self):
- iconProvider = self.__class__._defaultIconProvider
- if iconProvider is None:
- iconProvider = _IconProvider()
- self.__class__._defaultIconProvider = iconProvider
- return iconProvider
-
- def _createBrowseToolBar(self):
- toolbar = qt.QToolBar(self)
- toolbar.setIconSize(qt.QSize(16, 16))
- iconProvider = self.iconProvider()
-
- backward = qt.QAction(toolbar)
- backward.setText("Back")
- backward.setObjectName("backwardAction")
- backward.setIcon(iconProvider.icon(qt.QStyle.SP_ArrowBack))
- backward.triggered.connect(self.__navigateBackward)
- self.__backwardAction = backward
-
- forward = qt.QAction(toolbar)
- forward.setText("Forward")
- forward.setObjectName("forwardAction")
- forward.setIcon(iconProvider.icon(qt.QStyle.SP_ArrowForward))
- forward.triggered.connect(self.__navigateForward)
- self.__forwardAction = forward
-
- parentDirectory = qt.QAction(toolbar)
- parentDirectory.setText("Go to parent")
- parentDirectory.setObjectName("toParentAction")
- parentDirectory.setIcon(iconProvider.icon(qt.QStyle.SP_FileDialogToParent))
- parentDirectory.triggered.connect(self.__navigateToParent)
- self.__toParentAction = parentDirectory
-
- fileDirectory = qt.QAction(toolbar)
- fileDirectory.setText("Root of the file")
- fileDirectory.setObjectName("toRootFileAction")
- fileDirectory.setIcon(iconProvider.icon(iconProvider.FileDialogToParentFile))
- fileDirectory.triggered.connect(self.__navigateToParentFile)
- self.__fileDirectoryAction = fileDirectory
-
- parentFileDirectory = qt.QAction(toolbar)
- parentFileDirectory.setText("Parent directory of the file")
- parentFileDirectory.setObjectName("toDirectoryAction")
- parentFileDirectory.setIcon(iconProvider.icon(iconProvider.FileDialogToParentDir))
- parentFileDirectory.triggered.connect(self.__navigateToParentDir)
- self.__parentFileDirectoryAction = parentFileDirectory
-
- listView = qt.QAction(toolbar)
- listView.setText("List view")
- listView.setObjectName("listModeAction")
- listView.setIcon(iconProvider.icon(qt.QStyle.SP_FileDialogListView))
- listView.triggered.connect(self.__showAsListView)
- listView.setCheckable(True)
-
- detailView = qt.QAction(toolbar)
- detailView.setText("Detail view")
- detailView.setObjectName("detailModeAction")
- detailView.setIcon(iconProvider.icon(qt.QStyle.SP_FileDialogDetailedView))
- detailView.triggered.connect(self.__showAsDetailedView)
- detailView.setCheckable(True)
-
- self.__listViewAction = listView
- self.__detailViewAction = detailView
-
- toolbar.addAction(backward)
- toolbar.addAction(forward)
- toolbar.addSeparator()
- toolbar.addAction(parentDirectory)
- toolbar.addAction(fileDirectory)
- toolbar.addAction(parentFileDirectory)
- toolbar.addSeparator()
- toolbar.addAction(listView)
- toolbar.addAction(detailView)
-
- toolbar.setStyleSheet("QToolBar { border: 0px }")
-
- return toolbar
-
- def __initLayout(self):
- sideBarLayout = qt.QVBoxLayout()
- sideBarLayout.setContentsMargins(0, 0, 0, 0)
- dummyToolBar = qt.QWidget(self)
- dummyToolBar.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
- dummyCombo = qt.QWidget(self)
- dummyCombo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
- sideBarLayout.addWidget(dummyToolBar)
- if self.__sidebar is not None:
- sideBarLayout.addWidget(self.__sidebar)
- sideBarLayout.addWidget(dummyCombo)
- sideBarWidget = qt.QWidget(self)
- sideBarWidget.setLayout(sideBarLayout)
-
- dummyCombo.setFixedHeight(self.__fileTypeCombo.height())
- self.__resizeCombo = _CatchResizeEvent(self, self.__fileTypeCombo)
- self.__resizeCombo.resized.connect(lambda e: dummyCombo.setFixedHeight(e.size().height()))
-
- dummyToolBar.setFixedHeight(self.__browseToolBar.height())
- self.__resizeToolbar = _CatchResizeEvent(self, self.__browseToolBar)
- self.__resizeToolbar.resized.connect(lambda e: dummyToolBar.setFixedHeight(e.size().height()))
-
- datasetSelection = qt.QWidget(self)
- layoutLeft = qt.QVBoxLayout()
- layoutLeft.setContentsMargins(0, 0, 0, 0)
- layoutLeft.addWidget(self.__browseToolBar)
- layoutLeft.addWidget(self.__browser)
- layoutLeft.addWidget(self.__fileTypeCombo)
- datasetSelection.setLayout(layoutLeft)
- datasetSelection.setSizePolicy(qt.QSizePolicy.MinimumExpanding, qt.QSizePolicy.Expanding)
-
- infoLayout = qt.QHBoxLayout()
- infoLayout.setContentsMargins(0, 0, 0, 0)
- infoLayout.addWidget(self.__dataIcon)
- infoLayout.addWidget(self.__dataInfo)
-
- dataFrame = qt.QFrame(self)
- dataFrame.setFrameShape(qt.QFrame.StyledPanel)
- layout = qt.QVBoxLayout()
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
- layout.addWidget(self.__previewWidget)
- layout.addLayout(infoLayout)
- dataFrame.setLayout(layout)
-
- dataSelection = qt.QWidget(self)
- dataLayout = qt.QVBoxLayout()
- dataLayout.setContentsMargins(0, 0, 0, 0)
- if self.__previewToolBar is not None:
- dataLayout.addWidget(self.__previewToolBar)
- else:
- # Add dummy space
- dummyToolbar2 = qt.QWidget(self)
- dummyToolbar2.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
- dummyToolbar2.setFixedHeight(self.__browseToolBar.height())
- self.__resizeToolbar = _CatchResizeEvent(self, self.__browseToolBar)
- self.__resizeToolbar.resized.connect(lambda e: dummyToolbar2.setFixedHeight(e.size().height()))
- dataLayout.addWidget(dummyToolbar2)
-
- dataLayout.addWidget(dataFrame)
- if self.__selectorWidget is not None:
- dataLayout.addWidget(self.__selectorWidget)
- else:
- # Add dummy space
- dummyCombo2 = qt.QWidget(self)
- dummyCombo2.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
- dummyCombo2.setFixedHeight(self.__fileTypeCombo.height())
- self.__resizeToolbar = _CatchResizeEvent(self, self.__fileTypeCombo)
- self.__resizeToolbar.resized.connect(lambda e: dummyCombo2.setFixedHeight(e.size().height()))
- dataLayout.addWidget(dummyCombo2)
- dataSelection.setLayout(dataLayout)
-
- self.__splitter = qt.QSplitter(self)
- self.__splitter.setContentsMargins(0, 0, 0, 0)
- self.__splitter.addWidget(sideBarWidget)
- self.__splitter.addWidget(datasetSelection)
- self.__splitter.addWidget(dataSelection)
- self.__splitter.setStretchFactor(1, 10)
-
- bottomLayout = qt.QHBoxLayout()
- bottomLayout.setContentsMargins(0, 0, 0, 0)
- bottomLayout.addWidget(self.__pathEdit)
- bottomLayout.addWidget(self.__buttons)
-
- layout = qt.QVBoxLayout(self)
- layout.addWidget(self.__splitter)
- layout.addLayout(bottomLayout)
-
- self.setLayout(layout)
- self.updateGeometry()
-
- # Logic
-
- def __navigateBackward(self):
- """Navigate through the history one step backward."""
- if len(self.__currentHistory) > 0 and self.__currentHistoryLocation > 0:
- self.__currentHistoryLocation -= 1
- url = self.__currentHistory[self.__currentHistoryLocation]
- self.selectUrl(url)
-
- def __navigateForward(self):
- """Navigate through the history one step forward."""
- if len(self.__currentHistory) > 0 and self.__currentHistoryLocation < len(self.__currentHistory) - 1:
- self.__currentHistoryLocation += 1
- url = self.__currentHistory[self.__currentHistoryLocation]
- self.selectUrl(url)
-
- def __navigateToParent(self):
- index = self.__browser.rootIndex()
- if index.model() is self.__fileModel:
- # browse throw the file system
- index = index.parent()
- path = self.__fileModel.filePath(index)
- self.__fileModel_setRootPath(path)
- self.__browser.selectIndex(qt.QModelIndex())
- self.__updatePath()
- elif index.model() is self.__dataModel:
- index = index.parent()
- if index.isValid():
- # browse throw the hdf5
- self.__browser.setRootIndex(index)
- self.__browser.selectIndex(qt.QModelIndex())
- self.__updatePath()
- else:
- # go back to the file system
- self.__navigateToParentDir()
- else:
- # Root of the file system (my computer)
- pass
-
- def __navigateToParentFile(self):
- index = self.__browser.rootIndex()
- if index.model() is self.__dataModel:
- index = self.__dataModel.indexFromH5Object(self.__h5)
- self.__browser.setRootIndex(index)
- self.__browser.selectIndex(qt.QModelIndex())
- self.__updatePath()
-
- def __navigateToParentDir(self):
- index = self.__browser.rootIndex()
- if index.model() is self.__dataModel:
- path = os.path.dirname(self.__h5.file.filename)
- index = self.__fileModel.index(path)
- self.__browser.setRootIndex(index)
- self.__browser.selectIndex(qt.QModelIndex())
- self.__closeFile()
- self.__updatePath()
-
- def viewMode(self):
- """Returns the current view mode.
-
- :rtype: qt.QFileDialog.ViewMode
- """
- return self.__browser.viewMode()
-
- def setViewMode(self, mode):
- """Set the current view mode.
-
- :param qt.QFileDialog.ViewMode mode: The new view mode
- """
- if mode == qt.QFileDialog.Detail:
- self.__browser.showDetails()
- self.__listViewAction.setChecked(False)
- self.__detailViewAction.setChecked(True)
- elif mode == qt.QFileDialog.List:
- self.__browser.showList()
- self.__listViewAction.setChecked(True)
- self.__detailViewAction.setChecked(False)
- else:
- assert(False)
-
- def __showAsListView(self):
- self.setViewMode(qt.QFileDialog.List)
-
- def __showAsDetailedView(self):
- self.setViewMode(qt.QFileDialog.Detail)
-
- def __shortcutSelected(self):
- self.__browser.selectIndex(qt.QModelIndex())
- self.__clearData()
- self.__updatePath()
- selectionModel = self.__sidebar.selectionModel()
- indexes = selectionModel.selectedIndexes()
- if len(indexes) == 1:
- index = indexes[0]
- url = self.__sidebar.model().data(index, role=qt.Qt.UserRole)
- path = url.toLocalFile()
- self.__fileModel_setRootPath(path)
-
- def __browsedItemActivated(self, index):
- if not index.isValid():
- return
- if index.model() is self.__fileModel:
- path = self.__fileModel.filePath(index)
- if self.__fileModel.isDir(index):
- self.__fileModel_setRootPath(path)
- if os.path.isfile(path):
- self.__fileActivated(index)
- elif index.model() is self.__dataModel:
- obj = self.__dataModel.data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
- if silx.io.is_group(obj):
- self.__browser.setRootIndex(index)
- else:
- assert(False)
-
- def __browsedItemSelected(self, index):
- self.__dataSelected(index)
- self.__updatePath()
-
- def __fileModel_setRootPath(self, path):
- """Set the root path of the fileModel with a filter on the
- directoryLoaded event.
-
- Without this filter an extra event is received (at least with PyQt4)
- when we use for the first time the sidebar.
-
- :param str path: Path to load
- """
- assert(path is not None)
- if path != "" and not os.path.exists(path):
- return
- if self.hasPendingEvents():
- # Make sure the asynchronous fileModel setRootPath is finished
- qt.QApplication.instance().processEvents()
-
- if self.__directoryLoadedFilter is not None:
- if utils.samefile(self.__directoryLoadedFilter, path):
- return
- self.__directoryLoadedFilter = path
- self.__processing += 1
- if self.__fileModel is None:
- return
- index = self.__fileModel.setRootPath(path)
- if not index.isValid():
- # There is a problem with this path
- # No asynchronous process will be waked up
- self.__processing -= 1
- self.__browser.setRootIndex(index, model=self.__fileModel)
- self.__clearData()
- self.__updatePath()
-
- def __directoryLoaded(self, path):
- if self.__directoryLoadedFilter is not None:
- if not utils.samefile(self.__directoryLoadedFilter, path):
- # Filter event which should not arrive in PyQt4
- # The first click on the sidebar sent 2 events
- self.__processing -= 1
- return
- if self.__fileModel is None:
- return
- index = self.__fileModel.index(path)
- self.__browser.setRootIndex(index, model=self.__fileModel)
- self.__updatePath()
- self.__processing -= 1
-
- def __closeFile(self):
- self.__openedFiles[:] = []
- self.__fileDirectoryAction.setEnabled(False)
- self.__parentFileDirectoryAction.setEnabled(False)
- if self.__h5 is not None:
- self.__dataModel.removeH5pyObject(self.__h5)
- self.__h5.close()
- self.__h5 = None
- if self.__fabio is not None:
- if hasattr(self.__fabio, "close"):
- self.__fabio.close()
- self.__fabio = None
-
- def __openFabioFile(self, filename):
- self.__closeFile()
- try:
- self.__fabio = fabio.open(filename)
- self.__openedFiles.append(self.__fabio)
- self.__selectedFile = filename
- except Exception as e:
- _logger.error("Error while loading file %s: %s", filename, e.args[0])
- _logger.debug("Backtrace", exc_info=True)
- self.__errorWhileLoadingFile = filename, e.args[0]
- return False
- else:
- return True
-
- def __openSilxFile(self, filename):
- self.__closeFile()
- try:
- self.__h5 = silx.io.open(filename)
- self.__openedFiles.append(self.__h5)
- self.__selectedFile = filename
- except IOError as e:
- _logger.error("Error while loading file %s: %s", filename, e.args[0])
- _logger.debug("Backtrace", exc_info=True)
- self.__errorWhileLoadingFile = filename, e.args[0]
- return False
- else:
- self.__fileDirectoryAction.setEnabled(True)
- self.__parentFileDirectoryAction.setEnabled(True)
- self.__dataModel.insertH5pyObject(self.__h5)
- return True
-
- def __isSilxHavePriority(self, filename):
- """Silx have priority when there is a specific decoder
- """
- _, ext = os.path.splitext(filename)
- ext = "*%s" % ext
- formats = silx.io.supported_extensions(flat_formats=False)
- for extensions in formats.values():
- if ext in extensions:
- return True
- return False
-
- def __openFile(self, filename):
- codec = self.__fileTypeCombo.currentCodec()
- openners = []
- if codec.is_autodetect():
- if self.__isSilxHavePriority(filename):
- openners.append(self.__openSilxFile)
- if self._isFabioFilesSupported():
- openners.append(self.__openFabioFile)
- else:
- if self._isFabioFilesSupported():
- openners.append(self.__openFabioFile)
- openners.append(self.__openSilxFile)
- elif codec.is_silx_codec():
- openners.append(self.__openSilxFile)
- elif self._isFabioFilesSupported() and codec.is_fabio_codec():
- # It is requested to use fabio, anyway fabio is here or not
- openners.append(self.__openFabioFile)
-
- for openner in openners:
- ref = openner(filename)
- if ref is not None:
- return True
- return False
-
- def __fileActivated(self, index):
- self.__selectedFile = None
- path = self.__fileModel.filePath(index)
- if os.path.isfile(path):
- loaded = self.__openFile(path)
- if loaded:
- if self.__h5 is not None:
- index = self.__dataModel.indexFromH5Object(self.__h5)
- self.__browser.setRootIndex(index)
- elif self.__fabio is not None:
- data = _FabioData(self.__fabio)
- self.__setData(data)
- self.__updatePath()
- else:
- self.__clearData()
-
- def __dataSelected(self, index):
- selectedData = None
- if index is not None:
- if index.model() is self.__dataModel:
- obj = self.__dataModel.data(index, self.__dataModel.H5PY_OBJECT_ROLE)
- if self._isDataSupportable(obj):
- selectedData = obj
- elif index.model() is self.__fileModel:
- self.__closeFile()
- if self._isFabioFilesSupported():
- path = self.__fileModel.filePath(index)
- if os.path.isfile(path):
- codec = self.__fileTypeCombo.currentCodec()
- is_fabio_decoder = codec.is_fabio_codec()
- is_fabio_have_priority = not codec.is_silx_codec() and not self.__isSilxHavePriority(path)
- if is_fabio_decoder or is_fabio_have_priority:
- # Then it's flat frame container
- self.__openFabioFile(path)
- if self.__fabio is not None:
- selectedData = _FabioData(self.__fabio)
- else:
- assert(False)
-
- self.__setData(selectedData)
-
- def __filterSelected(self, index):
- filters = self.__fileTypeCombo.itemExtensions(index)
- self.__fileModel.setNameFilters(list(filters))
-
- def __setData(self, data):
- self.__data = data
-
- if data is not None and self._isDataSupportable(data):
- if self.__selectorWidget is not None:
- self.__selectorWidget.setData(data)
- if not self.__selectorWidget.isUsed():
- # Needed to fake the fact we have to reset the zoom in preview
- self.__selectedData = None
- self.__setSelectedData(data)
- self.__selectorWidget.hide()
- else:
- self.__selectorWidget.setVisible(self.__selectorWidget.hasVisibleSelectors())
- # Needed to fake the fact we have to reset the zoom in preview
- self.__selectedData = None
- self.__selectorWidget.selectionChanged.emit()
- else:
- # Needed to fake the fact we have to reset the zoom in preview
- self.__selectedData = None
- self.__setSelectedData(data)
- else:
- self.__clearData()
- self.__updatePath()
-
- def _isDataSupported(self, data):
- """Check if the data can be returned by the dialog.
-
- If true, this data can be returned by the dialog and the open button
- while be enabled. If false the button will be disabled.
-
- :rtype: bool
- """
- raise NotImplementedError()
-
- def _isDataSupportable(self, data):
- """Check if the selected data can be supported at one point.
-
- If true, the data selector will be checked and it will update the data
- preview. Else the selecting is disabled.
-
- :rtype: bool
- """
- raise NotImplementedError()
-
- def __clearData(self):
- """Clear the data part of the GUI"""
- if self.__previewWidget is not None:
- self.__previewWidget.setData(None)
- if self.__selectorWidget is not None:
- self.__selectorWidget.setData(None)
- self.__selectorWidget.hide()
- self.__selectedData = None
- self.__data = None
- self.__updateDataInfo()
- button = self.__buttons.button(qt.QDialogButtonBox.Open)
- button.setEnabled(False)
-
- def __selectorWidgetChanged(self):
- data = self.__selectorWidget.getSelectedData(self.__data)
- self.__setSelectedData(data)
-
- def __setSelectedData(self, data):
- """Set the data selected by the dialog.
-
- If :meth:`_isDataSupported` returns false, this function will be
- inhibited and no data will be selected.
- """
- if isinstance(data, _FabioData):
- data = data[()]
- if self.__previewWidget is not None:
- fromDataSelector = self.__selectedData is not None
- self.__previewWidget.setData(data, fromDataSelector=fromDataSelector)
- if self._isDataSupported(data):
- self.__selectedData = data
- else:
- self.__clearData()
- return
- self.__updateDataInfo()
- self.__updatePath()
- button = self.__buttons.button(qt.QDialogButtonBox.Open)
- button.setEnabled(True)
-
- def __updateDataInfo(self):
- if self.__errorWhileLoadingFile is not None:
- filename, message = self.__errorWhileLoadingFile
- message = "<b>Error while loading file '%s'</b><hr/>%s" % (filename, message)
- size = self.__dataInfo.height()
- icon = self.style().standardIcon(qt.QStyle.SP_MessageBoxCritical)
- pixmap = icon.pixmap(size, size)
-
- self.__dataInfo.setText("Error while loading file")
- self.__dataInfo.setToolTip(message)
- self.__dataIcon.setToolTip(message)
- self.__dataIcon.setVisible(True)
- self.__dataIcon.setPixmap(pixmap)
-
- self.__errorWhileLoadingFile = None
- return
-
- self.__dataIcon.setVisible(False)
- self.__dataInfo.setToolTip("")
- if self.__selectedData is None:
- self.__dataInfo.setText("No data selected")
- else:
- text = self._displayedDataInfo(self.__data, self.__selectedData)
- self.__dataInfo.setVisible(text is not None)
- if text is not None:
- self.__dataInfo.setText(text)
-
- def _displayedDataInfo(self, dataBeforeSelection, dataAfterSelection):
- """Returns the text displayed under the data preview.
-
- This zone is used to display error in case or problem of data selection
- or problems with IO.
-
- :param numpy.ndarray dataAfterSelection: Data as it is after the
- selection widget (basically the data from the preview widget)
- :param numpy.ndarray dataAfterSelection: Data as it is before the
- selection widget (basically the data from the browsing widget)
- :rtype: bool
- """
- return None
-
- def __createUrlFromIndex(self, index, useSelectorWidget=True):
- if index.model() is self.__fileModel:
- filename = self.__fileModel.filePath(index)
- dataPath = None
- elif index.model() is self.__dataModel:
- obj = self.__dataModel.data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
- filename = obj.file.filename
- dataPath = obj.name
- else:
- # root of the computer
- filename = ""
- dataPath = None
-
- if useSelectorWidget and self.__selectorWidget is not None and self.__selectorWidget.isUsed():
- slicing = self.__selectorWidget.slicing()
- if slicing == tuple():
- slicing = None
- else:
- slicing = None
-
- if self.__fabio is not None:
- scheme = "fabio"
- elif self.__h5 is not None:
- scheme = "silx"
- else:
- if os.path.isfile(filename):
- codec = self.__fileTypeCombo.currentCodec()
- if codec.is_fabio_codec():
- scheme = "fabio"
- elif codec.is_silx_codec():
- scheme = "silx"
- else:
- scheme = None
- else:
- scheme = None
-
- url = silx.io.url.DataUrl(file_path=filename, data_path=dataPath, data_slice=slicing, scheme=scheme)
- return url
-
- def __updatePath(self):
- index = self.__browser.selectedIndex()
- if index is None:
- index = self.__browser.rootIndex()
- url = self.__createUrlFromIndex(index)
- if url.path() != self.__pathEdit.text():
- old = self.__pathEdit.blockSignals(True)
- self.__pathEdit.setText(url.path())
- self.__pathEdit.blockSignals(old)
-
- def __rootIndexChanged(self, index):
- url = self.__createUrlFromIndex(index, useSelectorWidget=False)
-
- currentUrl = None
- if 0 <= self.__currentHistoryLocation < len(self.__currentHistory):
- currentUrl = self.__currentHistory[self.__currentHistoryLocation]
-
- if currentUrl is None or currentUrl != url.path():
- # clean up the forward history
- self.__currentHistory = self.__currentHistory[0:self.__currentHistoryLocation + 1]
- self.__currentHistory.append(url.path())
- self.__currentHistoryLocation += 1
-
- if index.model() != self.__dataModel:
- if sys.platform == "win32":
- # path == ""
- isRoot = not index.isValid()
- else:
- # path in ["", "/"]
- isRoot = not index.isValid() or not index.parent().isValid()
- else:
- isRoot = False
-
- if index.isValid():
- self.__dataSelected(index)
- self.__toParentAction.setEnabled(not isRoot)
- self.__updateActionHistory()
- self.__updateSidebar()
-
- def __updateSidebar(self):
- """Called when the current directory location change"""
- if self.__sidebar is None:
- return
- selectionModel = self.__sidebar.selectionModel()
- selectionModel.selectionChanged.disconnect(self.__shortcutSelected)
- index = self.__browser.rootIndex()
- if index.model() == self.__fileModel:
- path = self.__fileModel.filePath(index)
- self.__sidebar.setSelectedPath(path)
- elif index.model() is None:
- path = ""
- self.__sidebar.setSelectedPath(path)
- else:
- selectionModel.clear()
- selectionModel.selectionChanged.connect(self.__shortcutSelected)
-
- def __updateActionHistory(self):
- self.__forwardAction.setEnabled(len(self.__currentHistory) - 1 > self.__currentHistoryLocation)
- self.__backwardAction.setEnabled(self.__currentHistoryLocation > 0)
-
- def __textChanged(self, text):
- self.__pathChanged()
-
- def _isFabioFilesSupported(self):
- """Returns true fabio files can be loaded.
- """
- return True
-
- def _isLoadableUrl(self, url):
- """Returns true if the URL is loadable by this dialog.
-
- :param DataUrl url: The requested URL
- """
- return True
-
- def __pathChanged(self):
- url = silx.io.url.DataUrl(path=self.__pathEdit.text())
- if url.is_valid() or url.path() == "":
- if url.path() in ["", "/"] or url.file_path() in ["", "/"]:
- self.__fileModel_setRootPath(qt.QDir.rootPath())
- elif os.path.exists(url.file_path()):
- rootIndex = None
- if os.path.isdir(url.file_path()):
- self.__fileModel_setRootPath(url.file_path())
- index = self.__fileModel.index(url.file_path())
- elif os.path.isfile(url.file_path()):
- if self._isLoadableUrl(url):
- if url.scheme() == "silx":
- loaded = self.__openSilxFile(url.file_path())
- elif url.scheme() == "fabio" and self._isFabioFilesSupported():
- loaded = self.__openFabioFile(url.file_path())
- else:
- loaded = self.__openFile(url.file_path())
- else:
- loaded = False
- if loaded:
- if self.__h5 is not None:
- rootIndex = self.__dataModel.indexFromH5Object(self.__h5)
- elif self.__fabio is not None:
- index = self.__fileModel.index(url.file_path())
- rootIndex = index
- if rootIndex is None:
- index = self.__fileModel.index(url.file_path())
- index = index.parent()
-
- if rootIndex is not None:
- if rootIndex.model() == self.__dataModel:
- if url.data_path() is not None:
- dataPath = url.data_path()
- if dataPath in self.__h5:
- obj = self.__h5[dataPath]
- else:
- path = utils.findClosestSubPath(self.__h5, dataPath)
- if path is None:
- path = "/"
- obj = self.__h5[path]
-
- if silx.io.is_file(obj):
- self.__browser.setRootIndex(rootIndex)
- elif silx.io.is_group(obj):
- index = self.__dataModel.indexFromH5Object(obj)
- self.__browser.setRootIndex(index)
- else:
- index = self.__dataModel.indexFromH5Object(obj)
- self.__browser.setRootIndex(index.parent())
- self.__browser.selectIndex(index)
- else:
- self.__browser.setRootIndex(rootIndex)
- self.__clearData()
- elif rootIndex.model() == self.__fileModel:
- # that's a fabio file
- self.__browser.setRootIndex(rootIndex.parent())
- self.__browser.selectIndex(rootIndex)
- # data = _FabioData(self.__fabio)
- # self.__setData(data)
- else:
- assert(False)
- else:
- self.__browser.setRootIndex(index, model=self.__fileModel)
- self.__clearData()
-
- if self.__selectorWidget is not None:
- self.__selectorWidget.selectSlicing(url.data_slice())
- else:
- self.__errorWhileLoadingFile = (url.file_path(), "File not found")
- self.__clearData()
- else:
- self.__errorWhileLoadingFile = (url.file_path(), "Path invalid")
- self.__clearData()
-
- def previewToolbar(self):
- return self.__previewToolbar
-
- def previewWidget(self):
- return self.__previewWidget
-
- def selectorWidget(self):
- return self.__selectorWidget
-
- def _createPreviewToolbar(self, parent, dataPreviewWidget, dataSelectorWidget):
- return None
-
- def _createPreviewWidget(self, parent):
- return None
-
- def _createSelectorWidget(self, parent):
- return None
-
- # Selected file
-
- def setDirectory(self, path):
- """Sets the data dialog's current directory."""
- self.__fileModel_setRootPath(path)
-
- def selectedFile(self):
- """Returns the file path containing the selected data.
-
- :rtype: str
- """
- return self.__selectedFile
-
- def selectFile(self, filename):
- """Sets the data dialog's current file."""
- self.__directoryLoadedFilter = ""
- old = self.__pathEdit.blockSignals(True)
- try:
- self.__pathEdit.setText(filename)
- finally:
- self.__pathEdit.blockSignals(old)
- self.__pathChanged()
-
- # Selected data
-
- def selectUrl(self, url):
- """Sets the data dialog's current data url.
-
- :param Union[str,DataUrl] url: URL identifying a data (it can be a
- `DataUrl` object)
- """
- if isinstance(url, silx.io.url.DataUrl):
- url = url.path()
- self.__directoryLoadedFilter = ""
- old = self.__pathEdit.blockSignals(True)
- try:
- self.__pathEdit.setText(url)
- finally:
- self.__pathEdit.blockSignals(old)
- self.__pathChanged()
-
- def selectedUrl(self):
- """Returns the URL from the file system to the data.
-
- If the dialog is not validated, the path can be an intermediat
- selected path, or an invalid path.
-
- :rtype: str
- """
- return self.__pathEdit.text()
-
- def selectedDataUrl(self):
- """Returns the URL as a :class:`DataUrl` from the file system to the
- data.
-
- If the dialog is not validated, the path can be an intermediat
- selected path, or an invalid path.
-
- :rtype: DataUrl
- """
- url = self.selectedUrl()
- return silx.io.url.DataUrl(url)
-
- def directory(self):
- """Returns the path from the current browsed directory.
-
- :rtype: str
- """
- if self.__directory is not None:
- # At post execution, returns the cache
- return self.__directory
-
- index = self.__browser.rootIndex()
- if index.model() is self.__fileModel:
- path = self.__fileModel.filePath(index)
- return path
- elif index.model() is self.__dataModel:
- path = os.path.dirname(self.__h5.file.filename)
- return path
- else:
- return ""
-
- def _selectedData(self):
- """Returns the internal selected data
-
- :rtype: numpy.ndarray
- """
- return self.__selectedData
-
- # Filters
-
- def selectedNameFilter(self):
- """Returns the filter that the user selected in the file dialog."""
- return self.__fileTypeCombo.currentText()
-
- # History
-
- def history(self):
- """Returns the browsing history of the filedialog as a list of paths.
-
- :rtype: List<str>
- """
- if len(self.__currentHistory) <= 1:
- return []
- history = self.__currentHistory[0:self.__currentHistoryLocation]
- return list(history)
-
- def setHistory(self, history):
- self.__currentHistory = []
- self.__currentHistory.extend(history)
- self.__currentHistoryLocation = len(self.__currentHistory) - 1
- self.__updateActionHistory()
-
- # Colormap
-
- def colormap(self):
- if self.__previewWidget is None:
- return None
- return self.__previewWidget.colormap()
-
- def setColormap(self, colormap):
- if self.__previewWidget is None:
- raise RuntimeError("No preview widget defined")
- self.__previewWidget.setColormap(colormap)
-
- # Sidebar
-
- def setSidebarUrls(self, urls):
- """Sets the urls that are located in the sidebar."""
- if self.__sidebar is None:
- return
- self.__sidebar.setUrls(urls)
-
- def sidebarUrls(self):
- """Returns a list of urls that are currently in the sidebar."""
- if self.__sidebar is None:
- return []
- return self.__sidebar.urls()
-
- # State
-
- __serialVersion = 1
- """Store the current version of the serialized data"""
-
- @classmethod
- def qualifiedName(cls):
- return "%s.%s" % (cls.__module__, cls.__name__)
-
- def restoreState(self, state):
- """Restores the dialogs's layout, history and current directory to the
- state specified.
-
- :param qt.QByteArray state: Stream containing the new state
- :rtype: bool
- """
- stream = qt.QDataStream(state, qt.QIODevice.ReadOnly)
-
- qualifiedName = stream.readQString()
- if qualifiedName != self.qualifiedName():
- _logger.warning("Stored state contains an invalid qualified name. %s restoration cancelled.", self.__class__.__name__)
- return False
-
- version = stream.readInt32()
- if version != self.__serialVersion:
- _logger.warning("Stored state contains an invalid version. %s restoration cancelled.", self.__class__.__name__)
- return False
-
- result = True
-
- splitterData = stream.readQVariant()
- sidebarUrls = stream.readQStringList()
- history = stream.readQStringList()
- workingDirectory = stream.readQString()
- browserData = stream.readQVariant()
- viewMode = stream.readInt32()
- colormapData = stream.readQVariant()
-
- result &= self.__splitter.restoreState(splitterData)
- sidebarUrls = [qt.QUrl(s) for s in sidebarUrls]
- self.setSidebarUrls(list(sidebarUrls))
- history = [s for s in history]
- self.setHistory(list(history))
- if workingDirectory is not None:
- self.setDirectory(workingDirectory)
- result &= self.__browser.restoreState(browserData)
- self.setViewMode(viewMode)
- colormap = self.colormap()
- if colormap is not None:
- result &= self.colormap().restoreState(colormapData)
-
- return result
-
- def saveState(self):
- """Saves the state of the dialog's layout, history and current
- directory.
-
- :rtype: qt.QByteArray
- """
- data = qt.QByteArray()
- stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
-
- s = self.qualifiedName()
- stream.writeQString(u"%s" % s)
- stream.writeInt32(self.__serialVersion)
- stream.writeQVariant(self.__splitter.saveState())
- strings = [u"%s" % s.toString() for s in self.sidebarUrls()]
- stream.writeQStringList(strings)
- strings = [u"%s" % s for s in self.history()]
- stream.writeQStringList(strings)
- stream.writeQString(u"%s" % self.directory())
- stream.writeQVariant(self.__browser.saveState())
- stream.writeInt32(self.viewMode())
- colormap = self.colormap()
- if colormap is not None:
- stream.writeQVariant(self.colormap().saveState())
- else:
- stream.writeQVariant(None)
-
- return data
diff --git a/silx/gui/dialog/ColormapDialog.py b/silx/gui/dialog/ColormapDialog.py
deleted file mode 100644
index ca7ee97..0000000
--- a/silx/gui/dialog/ColormapDialog.py
+++ /dev/null
@@ -1,1771 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""A QDialog widget to set-up the colormap.
-
-It uses a description of colormaps as dict compatible with :class:`Plot`.
-
-To run the following sample code, a QApplication must be initialized.
-
-Create the colormap dialog and set the colormap description and data range:
-
->>> from silx.gui.dialog.ColormapDialog import ColormapDialog
->>> from silx.gui.colors import Colormap
-
->>> dialog = ColormapDialog()
->>> colormap = Colormap(name='red', normalization='log',
-... vmin=1., vmax=2.)
-
->>> dialog.setColormap(colormap)
->>> colormap.setVRange(1., 100.) # This scale the width of the plot area
->>> dialog.show()
-
-Get the colormap description (compatible with :class:`Plot`) from the dialog:
-
->>> cmap = dialog.getColormap()
->>> cmap.getName()
-'red'
-
-It is also possible to display an histogram of the image in the dialog.
-This updates the data range with the range of the bins.
-
->>> import numpy
->>> image = numpy.random.normal(size=512 * 512).reshape(512, -1)
->>> hist, bin_edges = numpy.histogram(image, bins=10)
->>> dialog.setHistogram(hist, bin_edges)
-
-The updates of the colormap description are also available through the signal:
-:attr:`ColormapDialog.sigColormapChanged`.
-""" # noqa
-
-__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
-__license__ = "MIT"
-__date__ = "08/12/2020"
-
-import enum
-import logging
-
-import numpy
-
-from .. import qt
-from .. import utils
-from ..colors import Colormap, cursorColorForColormap
-from ..plot import PlotWidget
-from ..plot.items.axis import Axis
-from ..plot.items import BoundingRect
-from silx.gui.widgets.FloatEdit import FloatEdit
-import weakref
-from silx.math.combo import min_max
-from silx.gui.plot import items
-from silx.gui import icons
-from silx.gui.qt import inspect as qtinspect
-from silx.gui.widgets.ColormapNameComboBox import ColormapNameComboBox
-from silx.gui.widgets.WaitingPushButton import WaitingPushButton
-from silx.math.histogram import Histogramnd
-from silx.utils import deprecation
-from silx.gui.plot.items.roi import RectangleROI
-from silx.gui.plot.tools.roi import RegionOfInterestManager
-
-_logger = logging.getLogger(__name__)
-
-_colormapIconPreview = {}
-
-
-class _DataRefHolder(items.Item, items.ColormapMixIn):
- """Holder for a weakref of a numpy array.
-
- It provides features from `ColormapMixIn`.
- """
-
- def __init__(self, dataRef):
- items.Item.__init__(self)
- items.ColormapMixIn.__init__(self)
- self.__dataRef = dataRef
- self._updated(items.ItemChangedType.DATA)
-
- def getColormappedData(self, copy=True):
- return self.__dataRef()
-
-
-class _BoundaryWidget(qt.QWidget):
- """Widget to edit a boundary of the colormap (vmin or vmax)"""
-
- sigAutoScaleChanged = qt.Signal(object)
- """Signal emitted when the autoscale was changed
-
- True is sent as an argument if autoscale is set to true.
- """
-
- sigValueChanged = qt.Signal(object)
- """Signal emitted when value is changed
-
- The new value is sent as an argument.
- """
-
- def __init__(self, parent=None, value=0.0):
- qt.QWidget.__init__(self, parent=parent)
- self.setLayout(qt.QHBoxLayout())
- self.layout().setContentsMargins(0, 0, 0, 0)
- self._numVal = FloatEdit(parent=self, value=value)
- self.layout().addWidget(self._numVal)
- self._autoCB = qt.QCheckBox('auto', parent=self)
- self.layout().addWidget(self._autoCB)
- self._autoCB.setChecked(False)
- self._autoCB.setVisible(False)
-
- self._autoCB.toggled.connect(self._autoToggled)
- self._numVal.textEdited.connect(self.__textEdited)
- self._numVal.editingFinished.connect(self.__editingFinished)
- self.setFocusProxy(self._numVal)
-
- self.__textWasEdited = False
- """True if the text was edited, in order to send an event
- at the end of the user interaction"""
-
- self.__realValue = None
- """Store the real value set by setValue, to avoid
- rounding of the widget"""
-
- def __textEdited(self):
- self.__textWasEdited = True
-
- def __editingFinished(self):
- if self.__textWasEdited:
- value = self._numVal.value()
- self.__realValue = value
- with utils.blockSignals(self._numVal):
- # Fix the formatting
- self._numVal.setValue(self.__realValue)
- self.sigValueChanged.emit(value)
- self.__textWasEdited = False
-
- def isAutoChecked(self):
- return self._autoCB.isChecked()
-
- def getValue(self):
- """Returns the stored range. If autoscale is
- enabled, this returns None.
- """
- if self._autoCB.isChecked():
- return None
- if self.__realValue is not None:
- return self.__realValue
- return self._numVal.value()
-
- def _autoToggled(self, enabled):
- self._numVal.setEnabled(not enabled)
- self._updateDisplayedText()
- self.sigAutoScaleChanged.emit(enabled)
-
- def _updateDisplayedText(self):
- self.__textWasEdited = False
- if self._autoCB.isChecked() and self.__realValue is not None:
- with utils.blockSignals(self._numVal):
- self._numVal.setValue(self.__realValue)
-
- def setValue(self, value, isAuto=False):
- """Set the value of the boundary.
-
- :param float value: A finite value for the boundary
- :param bool isAuto: If true, the finite value was automatically computed
- from the data, else it is a fixed custom value.
- """
- assert value is not None
- self._autoCB.setChecked(isAuto)
- with utils.blockSignals(self._numVal):
- if isAuto or self.__realValue != value:
- if not self.__textWasEdited:
- self._numVal.setValue(value)
- self.__realValue = value
- self._numVal.setEnabled(not isAuto)
-
-
-class _AutoscaleModeComboBox(qt.QComboBox):
-
- DATA = {
- Colormap.MINMAX: ("Min/max", "Use the data min/max"),
- Colormap.STDDEV3: ("Mean ± 3 × stddev", "Use the data mean ± 3 × standard deviation"),
- }
-
- def __init__(self, parent: qt.QWidget):
- super(_AutoscaleModeComboBox, self).__init__(parent=parent)
- self.currentIndexChanged.connect(self.__updateTooltip)
- self._init()
-
- def _init(self):
- for mode in Colormap.AUTOSCALE_MODES:
- label, tooltip = self.DATA.get(mode, (mode, None))
- self.addItem(label, mode)
- if tooltip is not None:
- self.setItemData(self.count() - 1, tooltip, qt.Qt.ToolTipRole)
-
- def setCurrentIndex(self, index):
- self.__updateTooltip(index)
- super(_AutoscaleModeComboBox, self).setCurrentIndex(index)
-
- def __updateTooltip(self, index):
- if index > -1:
- tooltip = self.itemData(index, qt.Qt.ToolTipRole)
- else:
- tooltip = ""
- self.setToolTip(tooltip)
-
- def currentMode(self):
- index = self.currentIndex()
- return self.itemData(index)
-
- def setCurrentMode(self, mode):
- for index in range(self.count()):
- if mode == self.itemData(index):
- self.setCurrentIndex(index)
- return
- if mode is None:
- # If None was not a value
- self.setCurrentIndex(-1)
- return
- self.addItem(mode, mode)
- self.setCurrentIndex(self.count() - 1)
-
-
-class _AutoScaleButtons(qt.QWidget):
-
- autoRangeChanged = qt.Signal(object)
-
- def __init__(self, parent=None):
- qt.QWidget.__init__(self, parent=parent)
- layout = qt.QHBoxLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
-
- self.setFocusPolicy(qt.Qt.NoFocus)
-
- self._bothAuto = qt.QPushButton(self)
- self._bothAuto.setText("Autoscale")
- self._bothAuto.setToolTip("Enable/disable the autoscale for both min and max")
- self._bothAuto.setCheckable(True)
- self._bothAuto.toggled[bool].connect(self.__bothToggled)
- self._bothAuto.setFocusPolicy(qt.Qt.TabFocus)
-
- self._minAuto = qt.QCheckBox(self)
- self._minAuto.setText("")
- self._minAuto.setToolTip("Enable/disable the autoscale for min")
- self._minAuto.toggled[bool].connect(self.__minToggled)
- self._minAuto.setFocusPolicy(qt.Qt.TabFocus)
-
- self._maxAuto = qt.QCheckBox(self)
- self._maxAuto.setText("")
- self._maxAuto.setToolTip("Enable/disable the autoscale for max")
- self._maxAuto.toggled[bool].connect(self.__maxToggled)
- self._maxAuto.setFocusPolicy(qt.Qt.TabFocus)
-
- layout.addStretch(1)
- layout.addWidget(self._minAuto)
- layout.addSpacing(20)
- layout.addWidget(self._bothAuto)
- layout.addSpacing(20)
- layout.addWidget(self._maxAuto)
- layout.addStretch(1)
-
- def __bothToggled(self, checked):
- autoRange = checked, checked
- self.setAutoRange(autoRange)
- self.autoRangeChanged.emit(autoRange)
-
- def __minToggled(self, checked):
- autoRange = self.getAutoRange()
- self.setAutoRange(autoRange)
- self.autoRangeChanged.emit(autoRange)
-
- def __maxToggled(self, checked):
- autoRange = self.getAutoRange()
- self.setAutoRange(autoRange)
- self.autoRangeChanged.emit(autoRange)
-
- def setAutoRangeFromColormap(self, colormap):
- vRange = colormap.getVRange()
- autoRange = vRange[0] is None, vRange[1] is None
- self.setAutoRange(autoRange)
-
- def setAutoRange(self, autoRange):
- if autoRange[0] == autoRange[1]:
- with utils.blockSignals(self._bothAuto):
- self._bothAuto.setChecked(autoRange[0])
- else:
- with utils.blockSignals(self._bothAuto):
- self._bothAuto.setChecked(False)
- with utils.blockSignals(self._minAuto):
- self._minAuto.setChecked(autoRange[0])
- with utils.blockSignals(self._maxAuto):
- self._maxAuto.setChecked(autoRange[1])
-
- def getAutoRange(self):
- return self._minAuto.isChecked(), self._maxAuto.isChecked()
-
-
-@enum.unique
-class _DataInPlotMode(enum.Enum):
- """Enum for each mode of display of the data in the plot."""
- RANGE = 'range'
- HISTOGRAM = 'histogram'
-
-
-class _ColormapHistogram(qt.QWidget):
- """Display the colormap and the data as a plot."""
-
- sigRangeMoving = qt.Signal(object, object)
- """Emitted when a mouse interaction moves the location
- of the colormap range in the plot.
-
- This signal contains 2 elements:
-
- - vmin: A float value if this range was moved, else None
- - vmax: A float value if this range was moved, else None
- """
-
- sigRangeMoved = qt.Signal(object, object)
- """Emitted when a mouse interaction stop.
-
- This signal contains 2 elements:
-
- - vmin: A float value if this range was moved, else None
- - vmax: A float value if this range was moved, else None
- """
-
- def __init__(self, parent):
- qt.QWidget.__init__(self, parent=parent)
- self._dataInPlotMode = _DataInPlotMode.RANGE
- self._finiteRange = None, None
- self._initPlot()
-
- self._histogramData = {}
- """Histogram displayed in the plot"""
-
- self._dragging = False, False
- """True, if the min or the max handle is dragging"""
-
- self._dataRange = {}
- """Histogram displayed in the plot"""
-
- self._invalidated = False
-
- def paintEvent(self, event):
- if self._invalidated:
- self._updateDataInPlot()
- self._invalidated = False
- self._updateMarkerPosition()
- return super(_ColormapHistogram, self).paintEvent(event)
-
- def getFiniteRange(self):
- """Returns the colormap range as displayed in the plot."""
- return self._finiteRange
-
- def setFiniteRange(self, vRange):
- """Set the colormap range to use in the plot.
-
- Here there is no concept of auto. The values should
- not be None, except if there is no range or marker
- to display.
- """
- # Do not reset the limit for handle about to be dragged
- if self._dragging[0]:
- vRange = self._finiteRange[0], vRange[1]
- if self._dragging[1]:
- vRange = vRange[0], self._finiteRange[1]
-
- if vRange == self._finiteRange:
- return
-
- self._finiteRange = vRange
- self.update()
-
- def getColormap(self):
- return self.parent().getColormap()
-
- def _getNormalizedHistogram(self):
- """Return an histogram already normalized according to the colormap
- normalization.
-
- Returns a tuple edges, counts
- """
- norm = self._getNorm()
- histogram = self._histogramData.get(norm, None)
- if histogram is None:
- histogram = self._computeNormalizedHistogram()
- self._histogramData[norm] = histogram
- return histogram
-
- def _computeNormalizedHistogram(self):
- colormap = self.getColormap()
- if colormap is None:
- norm = Colormap.LINEAR
- else:
- norm = colormap.getNormalization()
-
- # Try to use the histogram defined in the dialog
- histo = self.parent()._getHistogram()
- if histo is not None:
- counts, edges = histo
- normalizer = Colormap(normalization=norm)._getNormalizer()
- mask = normalizer.isValid(edges[:-1]) # Check lower bin edges only
- firstValid = numpy.argmax(mask) # edges increases monotonically
- if firstValid == 0: # Mask is all False or all True
- return (counts, edges) if mask[0] else (None, None)
- else: # Clip to valid values
- return counts[firstValid:], edges[firstValid:]
-
- data = self.parent()._getArray()
- if data is None:
- return None, None
- dataRange = self._getNormalizedDataRange()
- if dataRange[0] is None or dataRange[1] is None:
- return None, None
- counts, edges = self.parent().computeHistogram(data, scale=norm, dataRange=dataRange)
- return counts, edges
-
- def _getNormalizedDataRange(self):
- """Return a data range already normalized according to the colormap
- normalization.
-
- Returns a tuple with min and max
- """
- norm = self._getNorm()
- dataRange = self._dataRange.get(norm, None)
- if dataRange is None:
- dataRange = self._computeNormalizedDataRange()
- self._dataRange[norm] = dataRange
- return dataRange
-
- def _computeNormalizedDataRange(self):
- colormap = self.getColormap()
- if colormap is None:
- norm = Colormap.LINEAR
- else:
- norm = colormap.getNormalization()
-
- # Try to use the one defined in the dialog
- dataRange = self.parent()._getDataRange()
- if dataRange is not None:
- if norm in (Colormap.LINEAR, Colormap.GAMMA, Colormap.ARCSINH):
- return dataRange[0], dataRange[2]
- elif norm == Colormap.LOGARITHM:
- return dataRange[1], dataRange[2]
- elif norm == Colormap.SQRT:
- return dataRange[1], dataRange[2]
- else:
- _logger.error("Undefined %s normalization", norm)
-
- # Try to use the histogram defined in the dialog
- histo = self.parent()._getHistogram()
- if histo is not None:
- _histo, edges = histo
- normalizer = Colormap(normalization=norm)._getNormalizer()
- edges = edges[normalizer.isValid(edges)]
- if edges.size == 0:
- return None, None
- else:
- dataRange = min_max(edges, finite=True)
- return dataRange.minimum, dataRange.maximum
-
- item = self.parent()._getItem()
- if item is not None:
- # Trick to reach data range using colormap cache
- cm = Colormap()
- cm.setVRange(None, None)
- cm.setNormalization(norm)
- dataRange = item._getColormapAutoscaleRange(cm)
- return dataRange
-
- # If there is no item, there is no data
- return None, None
-
- def _getDisplayableRange(self):
- """Returns the selected min/max range to apply to the data,
- according to the used scale.
-
- One or both limits can be None in case it is not displayable in the
- current axes scale.
-
- :returns: Tuple{float, float}
- """
- scale = self._plot.getXAxis().getScale()
-
- def isDisplayable(pos):
- if pos is None:
- return False
- if scale == Axis.LOGARITHMIC:
- return pos > 0.0
- return True
-
- posMin, posMax = self.getFiniteRange()
- if not isDisplayable(posMin):
- posMin = None
- if not isDisplayable(posMax):
- posMax = None
-
- return posMin, posMax
-
- def _initPlot(self):
- """Init the plot to display the range and the values"""
- self._plot = PlotWidget(self)
- self._plot.setDataMargins(0.125, 0.125, 0.125, 0.125)
- self._plot.getXAxis().setLabel("Data Values")
- self._plot.getYAxis().setLabel("")
- self._plot.setInteractiveMode('select', zoomOnWheel=False)
- self._plot.setActiveCurveHandling(False)
- self._plot.setMinimumSize(qt.QSize(250, 200))
- self._plot.sigPlotSignal.connect(self._plotEventReceived)
- palette = self.palette()
- color = palette.color(qt.QPalette.Normal, qt.QPalette.Window)
- self._plot.setBackgroundColor(color)
- self._plot.setDataBackgroundColor("white")
-
- lut = numpy.arange(256)
- lut.shape = 1, -1
- self._plot.addImage(lut, legend='lut')
- self._lutItem = self._plot._getItem("image", "lut")
- self._lutItem.setVisible(False)
-
- self._plot.addScatter(x=[], y=[], value=[], legend='lut2')
- self._lutItem2 = self._plot._getItem("scatter", "lut2")
- self._lutItem2.setVisible(False)
- self.__lutY = numpy.array([-0.05] * 256)
- self.__lutV = numpy.arange(256)
-
- self._bound = BoundingRect()
- self._plot.addItem(self._bound)
- self._bound.setVisible(True)
-
- # Add plot for histogram
- self._plotToolbar = qt.QToolBar(self)
- self._plotToolbar.setFloatable(False)
- self._plotToolbar.setMovable(False)
- self._plotToolbar.setIconSize(qt.QSize(8, 8))
- self._plotToolbar.setStyleSheet("QToolBar { border: 0px }")
- self._plotToolbar.setOrientation(qt.Qt.Vertical)
-
- group = qt.QActionGroup(self._plotToolbar)
- group.setExclusive(True)
-
- action = qt.QAction("Data range", self)
- action.setToolTip("Display the data range within the colormap range. A fast data processing have to be done.")
- action.setIcon(icons.getQIcon('colormap-range'))
- action.setCheckable(True)
- action.setData(_DataInPlotMode.RANGE)
- action.setChecked(action.data() == self._dataInPlotMode)
- self._plotToolbar.addAction(action)
- group.addAction(action)
- action = qt.QAction("Histogram", self)
- action.setToolTip("Display the data histogram within the colormap range. A slow data processing have to be done. ")
- action.setIcon(icons.getQIcon('colormap-histogram'))
- action.setCheckable(True)
- action.setData(_DataInPlotMode.HISTOGRAM)
- action.setChecked(action.data() == self._dataInPlotMode)
- self._plotToolbar.addAction(action)
- group.addAction(action)
- group.triggered.connect(self._displayDataInPlotModeChanged)
-
- plotBoxLayout = qt.QHBoxLayout()
- plotBoxLayout.setContentsMargins(0, 0, 0, 0)
- plotBoxLayout.setSpacing(2)
- plotBoxLayout.addWidget(self._plotToolbar)
- plotBoxLayout.addWidget(self._plot)
- plotBoxLayout.setSizeConstraint(qt.QLayout.SetMinimumSize)
- self.setLayout(plotBoxLayout)
-
- def _plotEventReceived(self, event):
- """Handle events from the plot"""
- kind = event['event']
-
- if kind == 'markerMoving':
- value = event['xdata']
- if event['label'] == 'Min':
- self._dragging = True, False
- self._finiteRange = value, self._finiteRange[1]
- self._last = value, None
- self.sigRangeMoving.emit(*self._last)
- elif event['label'] == 'Max':
- self._dragging = False, True
- self._finiteRange = self._finiteRange[0], value
- self._last = None, value
- self.sigRangeMoving.emit(*self._last)
- self._updateLutItem(self._finiteRange)
- elif kind == 'markerMoved':
- self.sigRangeMoved.emit(*self._last)
- self._plot.resetZoom()
- self._dragging = False, False
- else:
- pass
-
- def _updateMarkerPosition(self):
- colormap = self.getColormap()
- posMin, posMax = self._getDisplayableRange()
-
- if colormap is None:
- isDraggable = False
- else:
- isDraggable = colormap.isEditable()
-
- with utils.blockSignals(self):
- if posMin is not None and not self._dragging[0]:
- self._plot.addXMarker(
- posMin,
- legend='Min',
- text='Min',
- draggable=isDraggable,
- color="blue",
- constraint=self._plotMinMarkerConstraint)
- if posMax is not None and not self._dragging[1]:
- self._plot.addXMarker(
- posMax,
- legend='Max',
- text='Max',
- draggable=isDraggable,
- color="blue",
- constraint=self._plotMaxMarkerConstraint)
-
- self._updateLutItem((posMin, posMax))
- self._plot.resetZoom()
-
- def _updateLutItem(self, vRange):
- colormap = self.getColormap()
- if colormap is None:
- return
-
- if vRange is None:
- posMin, posMax = self._getDisplayableRange()
- else:
- posMin, posMax = vRange
- if posMin is None or posMax is None:
- self._lutItem.setVisible(False)
- pos = posMax if posMin is None else posMin
- if pos is not None:
- self._bound.setBounds((pos, pos, -0.1, 0))
- else:
- self._bound.setBounds((0, 0, -0.1, 0))
- else:
- norm = colormap.getNormalization()
- normColormap = colormap.copy()
- normColormap.setVRange(0, 255)
- normColormap.setNormalization(Colormap.LINEAR)
- if norm == Colormap.LINEAR:
- scale = (posMax - posMin) / 256
- self._lutItem.setColormap(normColormap)
- self._lutItem.setOrigin((posMin, -0.09))
- self._lutItem.setScale((scale, 0.08))
- self._lutItem.setVisible(True)
- self._lutItem2.setVisible(False)
- elif norm == Colormap.LOGARITHM:
- self._lutItem2.setVisible(False)
- self._lutItem2.setColormap(normColormap)
- xx = numpy.geomspace(posMin, posMax, 256)
- self._lutItem2.setData(x=xx,
- y=self.__lutY,
- value=self.__lutV,
- copy=False)
- self._lutItem2.setSymbol("|")
- self._lutItem2.setVisible(True)
- self._lutItem.setVisible(False)
- else:
- # Fallback: Display with linear axis and applied normalization
- self._lutItem2.setVisible(False)
- normColormap.setNormalization(norm)
- self._lutItem2.setColormap(normColormap)
- xx = numpy.linspace(posMin, posMax, 256, endpoint=True)
- self._lutItem2.setData(
- x=xx,
- y=self.__lutY,
- value=self.__lutV,
- copy=False)
- self._lutItem2.setSymbol("|")
- self._lutItem2.setVisible(True)
- self._lutItem.setVisible(False)
-
- self._bound.setBounds((posMin, posMax, -0.1, 1))
-
- def _plotMinMarkerConstraint(self, x, y):
- """Constraint of the min marker"""
- _vmin, vmax = self.getFiniteRange()
- if vmax is None:
- return x, y
- return min(x, vmax), y
-
- def _plotMaxMarkerConstraint(self, x, y):
- """Constraint of the max marker"""
- vmin, _vmax = self.getFiniteRange()
- if vmin is None:
- return x, y
- return max(x, vmin), y
-
- def _setDataInPlotMode(self, mode):
- if self._dataInPlotMode == mode:
- return
- self._dataInPlotMode = mode
- self._updateDataInPlot()
-
- def _displayDataInPlotModeChanged(self, action):
- mode = action.data()
- self._setDataInPlotMode(mode)
-
- def invalidateData(self):
- self._histogramData = {}
- self._dataRange = {}
- self._invalidated = True
- self.update()
-
- def _updateDataInPlot(self):
- mode = self._dataInPlotMode
-
- norm = self._getNorm()
- if norm == Colormap.LINEAR:
- scale = Axis.LINEAR
- elif norm == Colormap.LOGARITHM:
- scale = Axis.LOGARITHMIC
- else:
- scale = Axis.LINEAR
-
- axis = self._plot.getXAxis()
- axis.setScale(scale)
-
- if mode == _DataInPlotMode.RANGE:
- dataRange = self._getNormalizedDataRange()
- xmin, xmax = dataRange
- if xmax is None or xmin is None:
- self._plot.remove(legend='Data', kind='histogram')
- else:
- histogram = numpy.array([1])
- bin_edges = numpy.array([xmin, xmax])
- self._plot.addHistogram(histogram,
- bin_edges,
- legend="Data",
- color='gray',
- align='center',
- fill=True,
- z=1)
-
- elif mode == _DataInPlotMode.HISTOGRAM:
- histogram, bin_edges = self._getNormalizedHistogram()
- if histogram is None or bin_edges is None:
- self._plot.remove(legend='Data', kind='histogram')
- else:
- histogram = numpy.array(histogram, copy=True)
- bin_edges = numpy.array(bin_edges, copy=True)
- with numpy.errstate(invalid='ignore'):
- norm_histogram = histogram / numpy.nanmax(histogram)
- self._plot.addHistogram(norm_histogram,
- bin_edges,
- legend="Data",
- color='gray',
- align='center',
- fill=True,
- z=1)
- else:
- _logger.error("Mode unsupported")
-
- def sizeHint(self):
- return self.layout().minimumSize()
-
- def updateLut(self):
- self._updateLutItem(None)
-
- def _getNorm(self):
- colormap = self.getColormap()
- if colormap is None:
- return Axis.LINEAR
- else:
- norm = colormap.getNormalization()
- return norm
-
- def updateNormalization(self):
- self._updateDataInPlot()
- self.update()
-
-
-class ColormapDialog(qt.QDialog):
- """A QDialog widget to set the colormap.
-
- :param parent: See :class:`QDialog`
- :param str title: The QDialog title
- """
-
- visibleChanged = qt.Signal(bool)
- """This event is sent when the dialog visibility change"""
-
- def __init__(self, parent=None, title="Colormap Dialog"):
- qt.QDialog.__init__(self, parent)
- self.setWindowTitle(title)
-
- self.__aboutToDelete = False
- self._colormap = None
-
- self._data = None
- """Weak ref to an external numpy array
- """
- self._itemHolder = None
- """Hard ref to a private item (used as holder to the data)
- This allow to reuse the item cache
- """
- self._item = None
- """Weak ref to an external item"""
-
- self._colormapChange = utils.LockReentrant()
- """Used as a semaphore to avoid editing the colormap object when we are
- only attempt to display it.
- Used instead of n connect and disconnect of the sigChanged. The
- disconnection to sigChanged was also limiting when this colormapdialog
- is used in the colormapaction and associated to the activeImageChanged.
- (because the activeImageChanged is send when the colormap changed and
- the self.setcolormap is a callback)
- """
-
- self.__colormapInvalidated = False
- self.__dataInvalidated = False
-
- self._histogramData = None
-
- self._dataRange = None
- """If defined 3-tuple containing information from a data:
- minimum, positive minimum, maximum"""
-
- self._colormapStoredState = None
-
- # Colormap row
- self._comboBoxColormap = ColormapNameComboBox(parent=self)
- self._comboBoxColormap.currentIndexChanged[int].connect(self._comboBoxColormapUpdated)
-
- # Normalization row
- self._comboBoxNormalization = qt.QComboBox(parent=self)
- normalizations = [
- ('Linear', Colormap.LINEAR),
- ('Gamma correction', Colormap.GAMMA),
- ('Arcsinh', Colormap.ARCSINH),
- ('Logarithmic', Colormap.LOGARITHM),
- ('Square root', Colormap.SQRT)]
- for name, userData in normalizations:
- try:
- icon = icons.getQIcon("colormap-norm-%s" % userData)
- except:
- icon = qt.QIcon()
- self._comboBoxNormalization.addItem(icon, name, userData)
- self._comboBoxNormalization.currentIndexChanged[int].connect(
- self._normalizationUpdated)
-
- self._gammaSpinBox = qt.QDoubleSpinBox(parent=self)
- self._gammaSpinBox.setEnabled(False)
- self._gammaSpinBox.setRange(0., 1000.)
- self._gammaSpinBox.setDecimals(4)
- if hasattr(qt.QDoubleSpinBox, "setStepType"):
- # Introduced in Qt 5.12
- self._gammaSpinBox.setStepType(qt.QDoubleSpinBox.AdaptiveDecimalStepType)
- else:
- self._gammaSpinBox.setSingleStep(0.1)
- self._gammaSpinBox.valueChanged.connect(self._gammaUpdated)
- self._gammaSpinBox.setValue(2.)
-
- autoScaleCombo = _AutoscaleModeComboBox(self)
- autoScaleCombo.currentIndexChanged.connect(self._autoscaleModeUpdated)
- self._autoScaleCombo = autoScaleCombo
-
- # Min row
- self._minValue = _BoundaryWidget(parent=self, value=1.0)
- self._minValue.sigAutoScaleChanged.connect(self._minAutoscaleUpdated)
- self._minValue.sigValueChanged.connect(self._minValueUpdated)
-
- # Max row
- self._maxValue = _BoundaryWidget(parent=self, value=10.0)
- self._maxValue.sigAutoScaleChanged.connect(self._maxAutoscaleUpdated)
- self._maxValue.sigValueChanged.connect(self._maxValueUpdated)
-
- self._autoButtons = _AutoScaleButtons(self)
- self._autoButtons.autoRangeChanged.connect(self._autoRangeButtonsUpdated)
-
- rangeLayout = qt.QGridLayout()
- miniFont = qt.QFont(self.font())
- miniFont.setPixelSize(8)
- labelMin = qt.QLabel("Min", self)
- labelMin.setFont(miniFont)
- labelMin.setAlignment(qt.Qt.AlignHCenter)
- labelMax = qt.QLabel("Max", self)
- labelMax.setAlignment(qt.Qt.AlignHCenter)
- labelMax.setFont(miniFont)
- rangeLayout.addWidget(labelMin, 0, 0)
- rangeLayout.addWidget(labelMax, 0, 1)
- rangeLayout.addWidget(self._minValue, 1, 0)
- rangeLayout.addWidget(self._maxValue, 1, 1)
- rangeLayout.addWidget(self._autoButtons, 2, 0, 1, -1, qt.Qt.AlignCenter)
-
- self._histoWidget = _ColormapHistogram(self)
- self._histoWidget.sigRangeMoving.connect(self._histogramRangeMoving)
- self._histoWidget.sigRangeMoved.connect(self._histogramRangeMoved)
-
- # Scale to buttons
- self._visibleAreaButton = qt.QPushButton(self)
- self._visibleAreaButton.setEnabled(False)
- self._visibleAreaButton.setText("Visible Area")
- self._visibleAreaButton.clicked.connect(
- self._handleScaleToVisibleAreaClicked,
- type=qt.Qt.QueuedConnection)
-
- # Place-holder for selected area ROI manager
- self._roiForColormapManager = None
-
- self._selectedAreaButton = WaitingPushButton(self)
- self._selectedAreaButton.setEnabled(False)
- self._selectedAreaButton.setText("Selection")
- self._selectedAreaButton.setIcon(icons.getQIcon("add-shape-rectangle"))
- self._selectedAreaButton.setCheckable(True)
- self._selectedAreaButton.setDisabledWhenWaiting(False)
- self._selectedAreaButton.toggled.connect(
- self._handleScaleToSelectionToggled,
- type=qt.Qt.QueuedConnection)
-
- # define modal buttons
- types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
- self._buttonsModal = qt.QDialogButtonBox(parent=self)
- self._buttonsModal.setStandardButtons(types)
- self._buttonsModal.accepted.connect(self.accept)
- self._buttonsModal.rejected.connect(self.reject)
-
- # define non modal buttons
- types = qt.QDialogButtonBox.Close | qt.QDialogButtonBox.Reset
- self._buttonsNonModal = qt.QDialogButtonBox(parent=self)
- self._buttonsNonModal.setStandardButtons(types)
- button = self._buttonsNonModal.button(qt.QDialogButtonBox.Close)
- button.clicked.connect(self.accept)
- button.setDefault(True)
- button = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
- button.clicked.connect(self.resetColormap)
-
- self._buttonsModal.setFocus(qt.Qt.OtherFocusReason)
- self._buttonsNonModal.setFocus(qt.Qt.OtherFocusReason)
-
- # Set the colormap to default values
- self.setColormap(Colormap(name='gray', normalization='linear',
- vmin=None, vmax=None))
-
- self.setModal(self.isModal())
-
- formLayout = qt.QFormLayout(self)
- formLayout.setContentsMargins(10, 10, 10, 10)
- formLayout.addRow('Colormap:', self._comboBoxColormap)
- formLayout.addRow('Normalization:', self._comboBoxNormalization)
- formLayout.addRow('Gamma:', self._gammaSpinBox)
- formLayout.addRow(self._histoWidget)
- formLayout.addRow(rangeLayout)
- label = qt.QLabel('Mode:', self)
- self._autoscaleModeLabel = label
- label.setToolTip("Mode for autoscale. Algorithm used to find range in auto scale.")
- formLayout.addItem(qt.QSpacerItem(1, 1, qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed))
- formLayout.addRow(label, autoScaleCombo)
-
- layout = qt.QHBoxLayout()
- layout.setContentsMargins(0, 0, 0, 0)
- layout.addWidget(self._visibleAreaButton)
- layout.addWidget(self._selectedAreaButton)
- self._scaleToAreaGroup = qt.QGroupBox('Scale to:', self)
- self._scaleToAreaGroup.setLayout(layout)
- self._scaleToAreaGroup.setVisible(False)
- formLayout.addRow(self._scaleToAreaGroup)
-
- formLayout.addRow(self._buttonsModal)
- formLayout.addRow(self._buttonsNonModal)
- formLayout.setSizeConstraint(qt.QLayout.SetMinimumSize)
-
- self.setTabOrder(self._comboBoxColormap, self._comboBoxNormalization)
- self.setTabOrder(self._comboBoxNormalization, self._gammaSpinBox)
- self.setTabOrder(self._gammaSpinBox, self._minValue)
- self.setTabOrder(self._minValue, self._maxValue)
- self.setTabOrder(self._maxValue, self._autoButtons)
- self.setTabOrder(self._autoButtons, self._autoScaleCombo)
- self.setTabOrder(self._autoScaleCombo, self._visibleAreaButton)
- self.setTabOrder(self._visibleAreaButton, self._selectedAreaButton)
- self.setTabOrder(self._selectedAreaButton, self._buttonsModal)
- self.setTabOrder(self._buttonsModal, self._buttonsNonModal)
-
- self.setFixedSize(self.sizeHint())
- self._applyColormap()
-
- def _invalidateColormap(self):
- if self.isVisible():
- self._applyColormap()
- else:
- self.__colormapInvalidated = True
-
- def _invalidateData(self):
- if self.isVisible():
- self._updateWidgetRange()
- self._histoWidget.invalidateData()
- else:
- self.__dataInvalidated = True
-
- def _validate(self):
- if self.__colormapInvalidated:
- self._applyColormap()
- if self.__dataInvalidated:
- self._histoWidget.invalidateData()
- if self.__dataInvalidated or self.__colormapInvalidated:
- self._updateWidgetRange()
- self.__dataInvalidated = False
- self.__colormapInvalidated = False
-
- def showEvent(self, event):
- self.visibleChanged.emit(True)
- super(ColormapDialog, self).showEvent(event)
- if self.isVisible():
- self._validate()
-
- def closeEvent(self, event):
- if not self.isModal():
- self.accept()
- super(ColormapDialog, self).closeEvent(event)
-
- def hideEvent(self, event):
- self.visibleChanged.emit(False)
- super(ColormapDialog, self).hideEvent(event)
-
- def close(self):
- self.accept()
- qt.QDialog.close(self)
-
- def setModal(self, modal):
- assert type(modal) is bool
- self._buttonsNonModal.setVisible(not modal)
- self._buttonsModal.setVisible(modal)
- qt.QDialog.setModal(self, modal)
-
- def event(self, event):
- if event.type() == qt.QEvent.DeferredDelete:
- self.__aboutToDelete = True
- return super(ColormapDialog, self).event(event)
-
- def exec_(self):
- wasModal = self.isModal()
- self.setModal(True)
- result = super(ColormapDialog, self).exec_()
- if not self.__aboutToDelete:
- self.setModal(wasModal)
- return result
-
- def _getFiniteColormapRange(self):
- """Return a colormap range where auto ranges are fixed
- according to the available data.
- """
- colormap = self.getColormap()
- if colormap is None:
- return 1, 10
-
- item = self._getItem()
- if item is not None:
- return colormap.getColormapRange(item)
- # If there is not item, there is no data
- return colormap.getColormapRange(None)
-
- @staticmethod
- def computeDataRange(data):
- """Compute the data range as used by :meth:`setDataRange`.
-
- :param data: The data to process
- :rtype: List[Union[None,float]]
- """
- if data is None or len(data) == 0:
- return None, None, None
-
- dataRange = min_max(data, min_positive=True, finite=True)
- if dataRange.minimum is None:
- # Only non-finite data
- dataRange = None
-
- if dataRange is not None:
- dataRange = dataRange.minimum, dataRange.min_positive, dataRange.maximum
-
- if dataRange is None or len(dataRange) != 3:
- qt.QMessageBox.warning(
- None, "No Data",
- "Image data does not contain any real value")
- dataRange = 1., 1., 10.
-
- return dataRange
-
- @staticmethod
- def computeHistogram(data, scale=Axis.LINEAR, dataRange=None):
- """Compute the data histogram as used by :meth:`setHistogram`.
-
- :param data: The data to process
- :param dataRange: Optional range to compute the histogram, which is a
- tuple of min, max
- :rtype: Tuple(List(float),List(float)
- """
- # For compatibility
- if scale == Axis.LOGARITHMIC:
- scale = Colormap.LOGARITHM
-
- if data is None:
- return None, None
-
- if len(data) == 0:
- return None, None
-
- if data.ndim == 3: # RGB(A) images
- _logger.info('Converting current image from RGB(A) to grayscale\
- in order to compute the intensity distribution')
- data = (data[:,:, 0] * 0.299 +
- data[:,:, 1] * 0.587 +
- data[:,:, 2] * 0.114)
-
- # bad hack: get 256 continuous bins in the case we have a B&W
- normalizeData = True
- if numpy.issubdtype(data.dtype, numpy.ubyte):
- normalizeData = False
- elif numpy.issubdtype(data.dtype, numpy.integer):
- if dataRange is not None:
- xmin, xmax = dataRange
- if xmin is not None and xmax is not None:
- normalizeData = (xmax - xmin) > 255
-
- if normalizeData:
- if scale == Colormap.LOGARITHM:
- with numpy.errstate(divide='ignore', invalid='ignore'):
- data = numpy.log10(data)
-
- if dataRange is not None:
- xmin, xmax = dataRange
- if xmin is None:
- return None, None
- if normalizeData:
- if scale == Colormap.LOGARITHM:
- xmin, xmax = numpy.log10(xmin), numpy.log10(xmax)
- else:
- xmin, xmax = min_max(data, min_positive=False, finite=True)
-
- if xmin is None:
- return None, None
-
- nbins = min(256, int(numpy.sqrt(data.size)))
- data_range = xmin, xmax
-
- # bad hack: get 256 bins in the case we have a B&W
- if numpy.issubdtype(data.dtype, numpy.integer):
- if nbins > xmax - xmin:
- nbins = int(xmax - xmin)
-
- nbins = max(2, nbins)
- data = data.ravel().astype(numpy.float32)
-
- histogram = Histogramnd(data, n_bins=nbins, histo_range=data_range)
- bins = histogram.edges[0]
- if normalizeData:
- if scale == Colormap.LOGARITHM:
- bins = 10 ** bins
- return histogram.histo, bins
-
- def _getItem(self):
- if self._itemHolder is not None:
- return self._itemHolder
- if self._item is None:
- return None
- return self._item()
-
- def setItem(self, item):
- """Store the plot item.
-
- According to the state of the dialog, the item will be used to display
- the data range or the histogram of the data using :meth:`setDataRange`
- and :meth:`setHistogram`
- """
- # While event from items are not supported, we can't ignore dup items
- # old = self._getItem()
- # if old is item:
- # return
- self._data = None
- self._itemHolder = None
- try:
- if item is None:
- self._item = None
- else:
- if not isinstance(item, items.ColormapMixIn):
- self._item = None
- raise ValueError("Item %s is not supported" % item)
- self._item = weakref.ref(item, self._itemAboutToFinalize)
- finally:
- self._syncScaleToButtonsEnabled()
- self._dataRange = None
- self._histogramData = None
- self._invalidateData()
-
- def _getData(self):
- if self._data is None:
- return None
- return self._data()
-
- def setData(self, data):
- """Store the data
-
- According to the state of the dialog, the data will be used to display
- the data range or the histogram of the data using :meth:`setDataRange`
- and :meth:`setHistogram`
- """
- oldData = self._getData()
- if oldData is data:
- return
-
- self._item = None
- self._syncScaleToButtonsEnabled()
- if data is None:
- self._data = None
- self._itemHolder = None
- else:
- self._data = weakref.ref(data, self._dataAboutToFinalize)
- self._itemHolder = _DataRefHolder(self._data)
-
- self._dataRange = None
- self._histogramData = None
-
- self._invalidateData()
-
- def _getArray(self):
- data = self._getData()
- if data is not None:
- return data
- item = self._getItem()
- if item is not None:
- return item.getColormappedData(copy=False)
- return None
-
- def _colormapAboutToFinalize(self, weakrefColormap):
- """Callback when the data weakref is about to be finalized."""
- if self._colormap is weakrefColormap and qtinspect.isValid(self):
- self.setColormap(None)
-
- def _dataAboutToFinalize(self, weakrefData):
- """Callback when the data weakref is about to be finalized."""
- if self._data is weakrefData and qtinspect.isValid(self):
- self.setData(None)
-
- def _itemAboutToFinalize(self, weakref):
- """Callback when the data weakref is about to be finalized."""
- if self._item is weakref and qtinspect.isValid(self):
- self.setItem(None)
-
- @deprecation.deprecated(reason="It is private data", since_version="0.13")
- def getHistogram(self):
- histo = self._getHistogram()
- if histo is None:
- return None
- counts, bin_edges = histo
- return numpy.array(counts, copy=True), numpy.array(bin_edges, copy=True)
-
- def _getHistogram(self):
- """Returns the histogram defined by the dialog as metadata
- to describe the data in order to speed up the dialog.
-
- :return: (hist, bin_edges)
- :rtype: 2-tuple of numpy arrays"""
- return self._histogramData
-
- def setHistogram(self, hist=None, bin_edges=None):
- """Set the histogram to display.
-
- This update the data range with the bounds of the bins.
-
- :param hist: array-like of counts or None to hide histogram
- :param bin_edges: array-like of bins edges or None to hide histogram
- """
- if hist is None or bin_edges is None:
- self._histogramData = None
- else:
- self._histogramData = numpy.array(hist), numpy.array(bin_edges)
-
- self._invalidateData()
-
- def getColormap(self):
- """Return the colormap description.
-
- :rtype: ~silx.gui.colors.Colormap
- """
- if self._colormap is None:
- return None
- return self._colormap()
-
- def resetColormap(self):
- """
- Reset the colormap state before modification.
-
- ..note :: the colormap reference state is the state when set or the
- state when validated
- """
- colormap = self.getColormap()
- if colormap is not None and self._colormapStoredState is not None:
- if colormap != self._colormapStoredState:
- with self._colormapChange:
- colormap.setFromColormap(self._colormapStoredState)
- self._applyColormap()
-
- def _getDataRange(self):
- """Returns the data range defined by the dialog as metadata
- to describe the data in order to speed up the dialog.
-
- :return: (minimum, positiveMin, maximum)
- :rtype: 3-tuple of floats or None"""
- return self._dataRange
-
- def setDataRange(self, minimum=None, positiveMin=None, maximum=None):
- """Set the range of data to use for the range of the histogram area.
-
- :param float minimum: The minimum of the data
- :param float positiveMin: The positive minimum of the data
- :param float maximum: The maximum of the data
- """
- self._dataRange = minimum, positiveMin, maximum
- self._invalidateData()
-
- def _setColormapRange(self, xmin, xmax):
- """Set a new range to the held colormap and update the
- widget."""
- colormap = self.getColormap()
- if colormap is not None:
- with self._colormapChange:
- colormap.setVRange(xmin, xmax)
- self._updateWidgetRange()
-
- def setColormapRangeFromDataBounds(self, bounds):
- """Set the range of the colormap from current item and rect.
-
- If there is no ColormapMixIn item attached to the ColormapDialog,
- nothing is done.
-
- :param Union[List[float],None] bounds:
- (xmin, xmax, ymin, ymax) Rectangular region in data space
- """
- if bounds is None:
- return None # no-op
-
- colormap = self.getColormap()
- if colormap is None:
- return # no-op
-
- item = self._getItem()
- if not isinstance(item, items.ColormapMixIn):
- return None # no-op
-
- data = item.getColormappedData(copy=False)
-
- xmin, xmax, ymin, ymax = bounds
-
- if isinstance(item, items.ImageBase):
- ox, oy = item.getOrigin()
- sx, sy = item.getScale()
-
- ystart = max(0, int((ymin - oy) / sy))
- ystop = max(0, int(numpy.ceil((ymax - oy) / sy)))
- xstart = max(0, int((xmin - ox) / sx))
- xstop = max(0, int(numpy.ceil((xmax - ox) / sx)))
-
- subset = data[ystart:ystop, xstart:xstop]
-
- elif isinstance(item, items.Scatter):
- x = item.getXData(copy=False)
- y = item.getYData(copy=False)
- subset = data[
- numpy.logical_and(
- numpy.logical_and(xmin <= x, x <= xmax),
- numpy.logical_and(ymin <= y, y <= ymax))]
-
- if subset.size == 0:
- return # no-op
-
- vmin, vmax = colormap._computeAutoscaleRange(subset)
- self._setColormapRange(vmin, vmax)
-
- def _updateWidgetRange(self):
- """Update the colormap range displayed into the widget."""
- xmin, xmax = self._getFiniteColormapRange()
- colormap = self.getColormap()
- if colormap is not None:
- vRange = colormap.getVRange()
- autoMin, autoMax = (r is None for r in vRange)
- else:
- autoMin, autoMax = False, False
-
- with utils.blockSignals(self._minValue):
- self._minValue.setValue(xmin, autoMin)
- with utils.blockSignals(self._maxValue):
- self._maxValue.setValue(xmax, autoMax)
- with utils.blockSignals(self._histoWidget):
- self._histoWidget.setFiniteRange((xmin, xmax))
- with utils.blockSignals(self._autoButtons):
- self._autoButtons.setAutoRange((autoMin, autoMax))
- self._autoscaleModeLabel.setEnabled(autoMin or autoMax)
-
- def accept(self):
- self.storeCurrentState()
- qt.QDialog.accept(self)
-
- def storeCurrentState(self):
- """
- save the current value sof the colormap if the user want to undo is
- modifications
- """
- colormap = self.getColormap()
- if colormap is not None:
- self._colormapStoredState = colormap.copy()
- else:
- self._colormapStoredState = None
-
- def reject(self):
- self.resetColormap()
- qt.QDialog.reject(self)
-
- def setColormap(self, colormap):
- """Set the colormap description
-
- :param ~silx.gui.colors.Colormap colormap: the colormap to edit
- """
- assert colormap is None or isinstance(colormap, Colormap)
- if self._colormapChange.locked():
- return
-
- oldColormap = self.getColormap()
- if oldColormap is colormap:
- return
- if oldColormap is not None:
- oldColormap.sigChanged.disconnect(self._applyColormap)
-
- if colormap is not None:
- colormap.sigChanged.connect(self._applyColormap)
- colormap = weakref.ref(colormap, self._colormapAboutToFinalize)
-
- self._colormap = colormap
- self.storeCurrentState()
- self._invalidateColormap()
-
- def _updateResetButton(self):
- resetButton = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
- rStateEnabled = False
- colormap = self.getColormap()
- if colormap is not None and colormap.isEditable():
- # can reset only in the case the colormap changed
- rStateEnabled = colormap != self._colormapStoredState
- resetButton.setEnabled(rStateEnabled)
-
- def _applyColormap(self):
- self._updateResetButton()
- if self._colormapChange.locked():
- return
-
- self._syncScaleToButtonsEnabled()
-
- colormap = self.getColormap()
- if colormap is None:
- self._comboBoxColormap.setEnabled(False)
- self._comboBoxNormalization.setEnabled(False)
- self._gammaSpinBox.setEnabled(False)
- self._autoScaleCombo.setEnabled(False)
- self._minValue.setEnabled(False)
- self._maxValue.setEnabled(False)
- self._autoButtons.setEnabled(False)
- self._autoscaleModeLabel.setEnabled(False)
- self._histoWidget.setVisible(False)
- self._histoWidget.setFiniteRange((None, None))
- else:
- assert colormap.getNormalization() in Colormap.NORMALIZATIONS
- with utils.blockSignals(self._comboBoxColormap):
- self._comboBoxColormap.setCurrentLut(colormap)
- self._comboBoxColormap.setEnabled(colormap.isEditable())
- with utils.blockSignals(self._comboBoxNormalization):
- index = self._comboBoxNormalization.findData(
- colormap.getNormalization())
- if index < 0:
- _logger.error('Unsupported normalization: %s' %
- colormap.getNormalization())
- else:
- self._comboBoxNormalization.setCurrentIndex(index)
- self._comboBoxNormalization.setEnabled(colormap.isEditable())
- with utils.blockSignals(self._gammaSpinBox):
- self._gammaSpinBox.setValue(
- colormap.getGammaNormalizationParameter())
- self._gammaSpinBox.setEnabled(
- colormap.getNormalization() == 'gamma' and
- colormap.isEditable())
- with utils.blockSignals(self._autoScaleCombo):
- self._autoScaleCombo.setCurrentMode(colormap.getAutoscaleMode())
- self._autoScaleCombo.setEnabled(colormap.isEditable())
- with utils.blockSignals(self._autoButtons):
- self._autoButtons.setEnabled(colormap.isEditable())
- self._autoButtons.setAutoRangeFromColormap(colormap)
-
- vmin, vmax = colormap.getVRange()
- if vmin is None or vmax is None:
- # Compute it only if needed
- dataRange = self._getFiniteColormapRange()
- else:
- dataRange = vmin, vmax
-
- with utils.blockSignals(self._minValue):
- self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None)
- self._minValue.setEnabled(colormap.isEditable())
- with utils.blockSignals(self._maxValue):
- self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None)
- self._maxValue.setEnabled(colormap.isEditable())
- self._autoscaleModeLabel.setEnabled(vmin is None or vmax is None)
-
- with utils.blockSignals(self._histoWidget):
- self._histoWidget.setVisible(True)
- self._histoWidget.setFiniteRange(dataRange)
- self._histoWidget.updateNormalization()
-
- def _comboBoxColormapUpdated(self):
- """Callback executed when the combo box with the colormap LUT
- is updated by user input.
- """
- colormap = self.getColormap()
- if colormap is not None:
- with self._colormapChange:
- name = self._comboBoxColormap.getCurrentName()
- if name is not None:
- colormap.setName(name)
- else:
- lut = self._comboBoxColormap.getCurrentColors()
- colormap.setColormapLUT(lut)
- self._histoWidget.updateLut()
-
- def _autoRangeButtonsUpdated(self, autoRange):
- """Callback executed when the autoscale buttons widget
- is updated by user input.
- """
- dataRange = self._getFiniteColormapRange()
-
- # Final colormap range
- vmin = (dataRange[0] if not autoRange[0] else None)
- vmax = (dataRange[1] if not autoRange[1] else None)
-
- with self._colormapChange:
- colormap = self.getColormap()
- colormap.setVRange(vmin, vmax)
-
- with utils.blockSignals(self._minValue):
- self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None)
- with utils.blockSignals(self._maxValue):
- self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None)
-
- self._updateWidgetRange()
-
- def _normalizationUpdated(self, index):
- """Callback executed when the normalization widget
- is updated by user input.
- """
- colormap = self.getColormap()
- if colormap is not None:
- normalization = self._comboBoxNormalization.itemData(index)
- self._gammaSpinBox.setEnabled(normalization == 'gamma')
-
- with self._colormapChange:
- colormap.setNormalization(normalization)
- self._histoWidget.updateNormalization()
-
- self._updateWidgetRange()
-
- def _gammaUpdated(self, value):
- """Callback used to update the gamma normalization parameter"""
- colormap = self.getColormap()
- if colormap is not None:
- colormap.setGammaNormalizationParameter(value)
-
- def _autoscaleModeUpdated(self):
- """Callback executed when the autoscale mode widget
- is updated by user input.
- """
- mode = self._autoScaleCombo.currentMode()
-
- colormap = self.getColormap()
- if colormap is not None:
- with self._colormapChange:
- colormap.setAutoscaleMode(mode)
-
- self._updateWidgetRange()
-
- def _minAutoscaleUpdated(self, autoEnabled):
- """Callback executed when the min autoscale from
- the lineedit is updated by user input"""
- colormap = self.getColormap()
- xmin, xmax = colormap.getVRange()
- if autoEnabled:
- xmin = None
- else:
- xmin, _xmax = self._getFiniteColormapRange()
- self._setColormapRange(xmin, xmax)
-
- def _maxAutoscaleUpdated(self, autoEnabled):
- """Callback executed when the max autoscale from
- the lineedit is updated by user input"""
- colormap = self.getColormap()
- xmin, xmax = colormap.getVRange()
- if autoEnabled:
- xmax = None
- else:
- _xmin, xmax = self._getFiniteColormapRange()
- self._setColormapRange(xmin, xmax)
-
- def _minValueUpdated(self, value):
- """Callback executed when the lineedit min value is
- updated by user input"""
- xmin = value
- xmax = self._maxValue.getValue()
- if xmax is not None and xmin > xmax:
- # FIXME: This should be done in the widget itself
- xmin = xmax
- with utils.blockSignals(self._minValue):
- self._minValue.setValue(xmin)
- self._setColormapRange(xmin, xmax)
-
- def _maxValueUpdated(self, value):
- """Callback executed when the lineedit max value is
- updated by user input"""
- xmin = self._minValue.getValue()
- xmax = value
- if xmin is not None and xmin > xmax:
- # FIXME: This should be done in the widget itself
- xmax = xmin
- with utils.blockSignals(self._maxValue):
- self._maxValue.setValue(xmax)
- self._setColormapRange(xmin, xmax)
-
- def _histogramRangeMoving(self, vmin, vmax):
- """Callback executed when for colormap range displayed in
- the histogram widget is moving.
-
- :param vmin: Update of the minimum range, else None
- :param vmax: Update of the maximum range, else None
- """
- colormap = self.getColormap()
- if vmin is not None:
- with self._colormapChange:
- colormap.setVMin(vmin)
- self._minValue.setValue(vmin)
- if vmax is not None:
- with self._colormapChange:
- colormap.setVMax(vmax)
- self._maxValue.setValue(vmax)
-
- def _histogramRangeMoved(self, vmin, vmax):
- """Callback executed when for colormap range displayed in
- the histogram widget has finished to move
- """
- xmin = self._minValue.getValue()
- xmax = self._maxValue.getValue()
- if vmin is None:
- vmin = xmin
- if vmax is None:
- vmax = xmax
- self._setColormapRange(vmin, vmax)
-
- def _syncScaleToButtonsEnabled(self):
- """Set the state of scale to buttons according to current item and colormap"""
- colormap = self.getColormap()
- enabled = self._item is not None and colormap is not None and colormap.isEditable()
- self._scaleToAreaGroup.setVisible(enabled)
- self._visibleAreaButton.setEnabled(enabled)
- if not enabled:
- self._selectedAreaButton.setChecked(False)
- self._selectedAreaButton.setEnabled(enabled)
-
- def _handleScaleToVisibleAreaClicked(self):
- """Set colormap range from current item's visible area"""
- item = self._getItem()
- if item is None:
- return # no-op
-
- bounds = item.getVisibleBounds()
- if bounds is None:
- return # no-op
-
- self.setColormapRangeFromDataBounds(bounds)
-
- def _handleScaleToSelectionToggled(self, checked=False):
- """Handle toggle of scale to selected are button"""
- # Reset any previous ROI manager
- if self._roiForColormapManager is not None:
- self._roiForColormapManager.clear()
- self._roiForColormapManager.stop()
- self._roiForColormapManager = None
-
- if not checked: # Reset button status
- self._selectedAreaButton.setWaiting(False)
- self._selectedAreaButton.setText("Selection")
- return
-
- item = self._getItem()
- if item is None:
- self._selectedAreaButton.setChecked(False)
- return # no-op
-
- plotWidget = item.getPlot()
- if plotWidget is None:
- self._selectedAreaButton.setChecked(False)
- return # no-op
-
- self._selectedAreaButton.setWaiting(True)
- self._selectedAreaButton.setText("Draw Area...")
-
- self._roiForColormapManager = RegionOfInterestManager(parent=plotWidget)
- cmap = self.getColormap()
- self._roiForColormapManager.setColor(
- 'black' if cmap is None else cursorColorForColormap(cmap.getName()))
- self._roiForColormapManager.sigInteractiveModeFinished.connect(
- self.__roiInteractiveModeFinished)
- self._roiForColormapManager.sigInteractiveRoiFinalized.connect(self.__roiFinalized)
- self._roiForColormapManager.start(RectangleROI)
-
- def __roiInteractiveModeFinished(self):
- self._selectedAreaButton.setChecked(False)
-
- def __roiFinalized(self, roi):
- self._selectedAreaButton.setChecked(False)
- if roi is not None:
- ox, oy = roi.getOrigin()
- width, height = roi.getSize()
- self.setColormapRangeFromDataBounds((ox, ox+width, oy, oy+height))
-
- def keyPressEvent(self, event):
- """Override key handling.
-
- It disables leaving the dialog when editing a text field.
-
- But several press of Return key can be use to validate and close the
- dialog.
- """
- if event.key() in (qt.Qt.Key_Enter, qt.Qt.Key_Return):
- # Bypass QDialog keyPressEvent
- # To avoid leaving the dialog when pressing enter on a text field
- if self._minValue.hasFocus():
- nextFocus = self._maxValue
- elif self._maxValue.hasFocus():
- if self.isModal():
- nextFocus = self._buttonsModal.button(qt.QDialogButtonBox.Apply)
- else:
- nextFocus = self._buttonsNonModal.button(qt.QDialogButtonBox.Close)
- else:
- nextFocus = None
- if nextFocus is not None:
- nextFocus.setFocus(qt.Qt.OtherFocusReason)
- else:
- super(ColormapDialog, self).keyPressEvent(event)
diff --git a/silx/gui/dialog/DataFileDialog.py b/silx/gui/dialog/DataFileDialog.py
deleted file mode 100644
index 84605d9..0000000
--- a/silx/gui/dialog/DataFileDialog.py
+++ /dev/null
@@ -1,340 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module contains an :class:`DataFileDialog`.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "14/02/2018"
-
-import enum
-import logging
-from silx.gui import qt
-from silx.gui.hdf5.Hdf5Formatter import Hdf5Formatter
-import silx.io
-from .AbstractDataFileDialog import AbstractDataFileDialog
-
-import fabio
-
-
-_logger = logging.getLogger(__name__)
-
-
-class _DataPreview(qt.QWidget):
- """Provide a preview of the selected image"""
-
- def __init__(self, parent=None):
- super(_DataPreview, self).__init__(parent)
-
- self.__formatter = Hdf5Formatter(self)
- self.__data = None
- self.__info = qt.QTableView(self)
- self.__model = qt.QStandardItemModel(self)
- self.__info.setModel(self.__model)
- self.__info.horizontalHeader().hide()
- self.__info.horizontalHeader().setStretchLastSection(True)
- layout = qt.QVBoxLayout()
- layout.setContentsMargins(0, 0, 0, 0)
- layout.addWidget(self.__info)
- self.setLayout(layout)
-
- def colormap(self):
- return None
-
- def setColormap(self, colormap):
- # Ignored
- pass
-
- def sizeHint(self):
- return qt.QSize(200, 200)
-
- def setData(self, data, fromDataSelector=False):
- self.__info.setEnabled(data is not None)
- if data is None:
- self.__model.clear()
- else:
- self.__model.clear()
-
- if silx.io.is_dataset(data):
- kind = "Dataset"
- elif silx.io.is_group(data):
- kind = "Group"
- elif silx.io.is_file(data):
- kind = "File"
- else:
- kind = "Unknown"
-
- headers = []
-
- basename = data.name.split("/")[-1]
- if basename == "":
- basename = "/"
- headers.append("Basename")
- self.__model.appendRow([qt.QStandardItem(basename)])
- headers.append("Kind")
- self.__model.appendRow([qt.QStandardItem(kind)])
- if hasattr(data, "dtype"):
- headers.append("Type")
- text = self.__formatter.humanReadableType(data)
- self.__model.appendRow([qt.QStandardItem(text)])
- if hasattr(data, "shape"):
- headers.append("Shape")
- text = self.__formatter.humanReadableShape(data)
- self.__model.appendRow([qt.QStandardItem(text)])
- if hasattr(data, "attrs") and "NX_class" in data.attrs:
- headers.append("NX_class")
- value = data.attrs["NX_class"]
- formatter = self.__formatter.textFormatter()
- old = formatter.useQuoteForText()
- formatter.setUseQuoteForText(False)
- text = self.__formatter.textFormatter().toString(value)
- formatter.setUseQuoteForText(old)
- self.__model.appendRow([qt.QStandardItem(text)])
- self.__model.setVerticalHeaderLabels(headers)
- self.__data = data
-
- def __imageItem(self):
- image = self.__plot.getImage("data")
- return image
-
- def data(self):
- if self.__data is not None:
- if hasattr(self.__data, "name"):
- # in case of HDF5
- if self.__data.name is None:
- # The dataset was closed
- self.__data = None
- return self.__data
-
- def clear(self):
- self.__data = None
- self.__info.setText("")
-
-
-class DataFileDialog(AbstractDataFileDialog):
- """The `DataFileDialog` class provides a dialog that allow users to select
- any datasets or groups from an HDF5-like file.
-
- The `DataFileDialog` class enables a user to traverse the file system in
- order to select an HDF5-like file. Then to traverse the file to select an
- HDF5 node.
-
- .. image:: img/datafiledialog.png
-
- The selected data is any kind of group or dataset. It can be restricted
- to only existing datasets or only existing groups using
- :meth:`setFilterMode`. A callback can be defining using
- :meth:`setFilterCallback` to filter even more data which can be returned.
-
- Filtering data which can be returned by a `DataFileDialog` can be done like
- that:
-
- .. code-block:: python
-
- # Force to return only a dataset
- dialog = DataFileDialog()
- dialog.setFilterMode(DataFileDialog.FilterMode.ExistingDataset)
-
- .. code-block:: python
-
- def customFilter(obj):
- if "NX_class" in obj.attrs:
- return obj.attrs["NX_class"] in [b"NXentry", u"NXentry"]
- return False
-
- # Force to return an NX entry
- dialog = DataFileDialog()
- # 1st, filter out everything which is not a group
- dialog.setFilterMode(DataFileDialog.FilterMode.ExistingGroup)
- # 2nd, check what NX_class is an NXentry
- dialog.setFilterCallback(customFilter)
-
- Executing a `DataFileDialog` can be done like that:
-
- .. code-block:: python
-
- dialog = DataFileDialog()
- result = dialog.exec_()
- if result:
- print("Selection:")
- print(dialog.selectedFile())
- print(dialog.selectedUrl())
- else:
- print("Nothing selected")
-
- If the selection is a dataset you can access to the data using
- :meth:`selectedData`.
-
- If the selection is a group or if you want to read the selected object on
- your own you can use the `silx.io` API.
-
- .. code-block:: python
-
- url = dialog.selectedUrl()
- with silx.io.open(url) as data:
- pass
-
- Or by loading the file first
-
- .. code-block:: python
-
- url = dialog.selectedDataUrl()
- with silx.io.open(url.file_path()) as h5:
- data = h5[url.data_path()]
-
- Or by using `h5py` library
-
- .. code-block:: python
-
- url = dialog.selectedDataUrl()
- with h5py.File(url.file_path(), mode="r") as h5:
- data = h5[url.data_path()]
- """
-
- class FilterMode(enum.Enum):
- """This enum is used to indicate what the user may select in the
- dialog; i.e. what the dialog will return if the user clicks OK."""
-
- AnyNode = 0
- """Any existing node from an HDF5-like file."""
- ExistingDataset = 1
- """An existing HDF5-like dataset."""
- ExistingGroup = 2
- """An existing HDF5-like group. A file root is a group."""
-
- def __init__(self, parent=None):
- AbstractDataFileDialog.__init__(self, parent=parent)
- self.__filter = DataFileDialog.FilterMode.AnyNode
- self.__filterCallback = None
-
- def selectedData(self):
- """Returns the selected data by using the :meth:`silx.io.get_data`
- API with the selected URL provided by the dialog.
-
- If the URL identify a group of a file it will raise an exception. For
- group or file you have to use on your own the API :meth:`silx.io.open`.
-
- :rtype: numpy.ndarray
- :raise ValueError: If the URL do not link to a dataset
- """
- url = self.selectedUrl()
- return silx.io.get_data(url)
-
- def _createPreviewWidget(self, parent):
- previewWidget = _DataPreview(parent)
- previewWidget.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
- return previewWidget
-
- def _createSelectorWidget(self, parent):
- # There is no selector
- return None
-
- def _createPreviewToolbar(self, parent, dataPreviewWidget, dataSelectorWidget):
- # There is no toolbar
- return None
-
- def _isDataSupportable(self, data):
- """Check if the selected data can be supported at one point.
-
- If true, the data selector will be checked and it will update the data
- preview. Else the selecting is disabled.
-
- :rtype: bool
- """
- # Everything is supported
- return True
-
- def _isFabioFilesSupported(self):
- # Everything is supported
- return False
-
- def _isDataSupported(self, data):
- """Check if the data can be returned by the dialog.
-
- If true, this data can be returned by the dialog and the open button
- will be enabled. If false the button will be disabled.
-
- :rtype: bool
- """
- if self.__filter == DataFileDialog.FilterMode.AnyNode:
- accepted = True
- elif self.__filter == DataFileDialog.FilterMode.ExistingDataset:
- accepted = silx.io.is_dataset(data)
- elif self.__filter == DataFileDialog.FilterMode.ExistingGroup:
- accepted = silx.io.is_group(data)
- else:
- raise ValueError("Filter %s is not supported" % self.__filter)
- if not accepted:
- return False
- if self.__filterCallback is not None:
- try:
- return self.__filterCallback(data)
- except Exception:
- _logger.error("Error while executing custom callback", exc_info=True)
- return False
- return True
-
- def setFilterCallback(self, callback):
- """Set the filter callback. This filter is applied only if the filter
- mode (:meth:`filterMode`) first accepts the selected data.
-
- It is not supposed to be set while the dialog is being used.
-
- :param callable callback: Define a custom function returning a boolean
- and taking as argument an h5-like node. If the function returns true
- the dialog can return the associated URL.
- """
- self.__filterCallback = callback
-
- def setFilterMode(self, mode):
- """Set the filter mode.
-
- It is not supposed to be set while the dialog is being used.
-
- :param DataFileDialog.FilterMode mode: The new filter.
- """
- self.__filter = mode
-
- def fileMode(self):
- """Returns the filter mode.
-
- :rtype: DataFileDialog.FilterMode
- """
- return self.__filter
-
- def _displayedDataInfo(self, dataBeforeSelection, dataAfterSelection):
- """Returns the text displayed under the data preview.
-
- This zone is used to display error in case or problem of data selection
- or problems with IO.
-
- :param numpy.ndarray dataAfterSelection: Data as it is after the
- selection widget (basically the data from the preview widget)
- :param numpy.ndarray dataAfterSelection: Data as it is before the
- selection widget (basically the data from the browsing widget)
- :rtype: bool
- """
- return u""
diff --git a/silx/gui/dialog/DatasetDialog.py b/silx/gui/dialog/DatasetDialog.py
deleted file mode 100644
index 87fc89d..0000000
--- a/silx/gui/dialog/DatasetDialog.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a dialog widget to select a HDF5 dataset in a
-tree.
-
-.. autoclass:: DatasetDialog
- :members: addFile, addGroup, getSelectedDataUrl, setMode
-
-"""
-from .GroupDialog import _Hdf5ItemSelectionDialog
-import silx.io
-from silx.io.url import DataUrl
-
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "05/09/2018"
-
-
-class DatasetDialog(_Hdf5ItemSelectionDialog):
- """This :class:`QDialog` uses a :class:`silx.gui.hdf5.Hdf5TreeView` to
- provide a HDF5 dataset selection dialog.
-
- The information identifying the selected node is provided as a
- :class:`silx.io.url.DataUrl`.
-
- Example:
-
- .. code-block:: python
-
- dialog = DatasetDialog()
- dialog.addFile(filepath1)
- dialog.addFile(filepath2)
-
- if dialog.exec_():
- print("File path: %s" % dialog.getSelectedDataUrl().file_path())
- print("HDF5 dataset path : %s " % dialog.getSelectedDataUrl().data_path())
- else:
- print("Operation cancelled :(")
-
- """
- def __init__(self, parent=None):
- _Hdf5ItemSelectionDialog.__init__(self, parent)
-
- # customization for groups
- self.setWindowTitle("HDF5 dataset selection")
-
- self._header.setSections([self._model.NAME_COLUMN,
- self._model.NODE_COLUMN,
- self._model.LINK_COLUMN,
- self._model.TYPE_COLUMN,
- self._model.SHAPE_COLUMN])
- self._selectDatasetStatusText = "Select a dataset or type a new dataset name"
-
- def setMode(self, mode):
- """Set dialog mode DatasetDialog.SaveMode or DatasetDialog.LoadMode
-
- :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
- """
- _Hdf5ItemSelectionDialog.setMode(self, mode)
- if mode == DatasetDialog.SaveMode:
- self._selectDatasetStatusText = "Select a dataset or type a new dataset name"
- elif mode == DatasetDialog.LoadMode:
- self._selectDatasetStatusText = "Select a dataset"
-
- def _onActivation(self, idx):
- # double-click or enter press: filter for datasets
- nodes = list(self._tree.selectedH5Nodes())
- node = nodes[0]
- if silx.io.is_dataset(node.h5py_object):
- self.accept()
-
- def _updateUrl(self):
- # overloaded to filter for datasets
- nodes = list(self._tree.selectedH5Nodes())
- newDatasetName = self._lineEditNewItem.text()
- isDatasetSelected = False
- if nodes:
- node = nodes[0]
- if silx.io.is_dataset(node.h5py_object):
- data_path = node.local_name
- isDatasetSelected = True
- elif silx.io.is_group(node.h5py_object):
- data_path = node.local_name
- if newDatasetName.lstrip("/"):
- if not data_path.endswith("/"):
- data_path += "/"
- data_path += newDatasetName.lstrip("/")
- isDatasetSelected = True
-
- if isDatasetSelected:
- self._selectedUrl = DataUrl(file_path=node.local_filename,
- data_path=data_path)
- self._okButton.setEnabled(True)
- self._labelSelection.setText(
- self._selectedUrl.path())
- else:
- self._selectedUrl = None
- self._okButton.setEnabled(False)
- self._labelSelection.setText(self._selectDatasetStatusText)
diff --git a/silx/gui/dialog/GroupDialog.py b/silx/gui/dialog/GroupDialog.py
deleted file mode 100644
index 217a03c..0000000
--- a/silx/gui/dialog/GroupDialog.py
+++ /dev/null
@@ -1,230 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a dialog widget to select a HDF5 group in a
-tree.
-
-.. autoclass:: GroupDialog
- :members: addFile, addGroup, getSelectedDataUrl, setMode
-
-"""
-from silx.gui import qt
-from silx.gui.hdf5.Hdf5TreeView import Hdf5TreeView
-import silx.io
-from silx.io.url import DataUrl
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "22/03/2018"
-
-
-class _Hdf5ItemSelectionDialog(qt.QDialog):
- SaveMode = 1
- """Mode used to set the HDF5 item selection dialog to *save* mode.
- This adds a text field to type in a new item name."""
-
- LoadMode = 2
- """Mode used to set the HDF5 item selection dialog to *load* mode.
- Only existing items of the HDF5 file can be selected in this mode."""
-
- def __init__(self, parent=None):
- qt.QDialog.__init__(self, parent)
- self.setWindowTitle("HDF5 item selection")
-
- self._tree = Hdf5TreeView(self)
- self._tree.setSelectionMode(qt.QAbstractItemView.SingleSelection)
- self._tree.activated.connect(self._onActivation)
- self._tree.selectionModel().selectionChanged.connect(
- self._onSelectionChange)
-
- self._model = self._tree.findHdf5TreeModel()
-
- self._header = self._tree.header()
-
- self._newItemWidget = qt.QWidget(self)
- newItemLayout = qt.QVBoxLayout(self._newItemWidget)
- self._labelNewItem = qt.QLabel(self._newItemWidget)
- self._labelNewItem.setText("Create new item in selected group (optional):")
- self._lineEditNewItem = qt.QLineEdit(self._newItemWidget)
- self._lineEditNewItem.setToolTip(
- "Specify the name of a new item "
- "to be created in the selected group.")
- self._lineEditNewItem.textChanged.connect(
- self._onNewItemNameChange)
- newItemLayout.addWidget(self._labelNewItem)
- newItemLayout.addWidget(self._lineEditNewItem)
-
- _labelSelectionTitle = qt.QLabel(self)
- _labelSelectionTitle.setText("Current selection")
- self._labelSelection = qt.QLabel(self)
- self._labelSelection.setStyleSheet("color: gray")
- self._labelSelection.setWordWrap(True)
- self._labelSelection.setText("Select an item")
-
- buttonBox = qt.QDialogButtonBox()
- self._okButton = buttonBox.addButton(qt.QDialogButtonBox.Ok)
- self._okButton.setEnabled(False)
- buttonBox.addButton(qt.QDialogButtonBox.Cancel)
-
- buttonBox.accepted.connect(self.accept)
- buttonBox.rejected.connect(self.reject)
-
- vlayout = qt.QVBoxLayout(self)
- vlayout.addWidget(self._tree)
- vlayout.addWidget(self._newItemWidget)
- vlayout.addWidget(_labelSelectionTitle)
- vlayout.addWidget(self._labelSelection)
- vlayout.addWidget(buttonBox)
- self.setLayout(vlayout)
-
- self.setMinimumWidth(400)
-
- self._selectedUrl = None
-
- def _onSelectionChange(self, old, new):
- self._updateUrl()
-
- def _onNewItemNameChange(self, text):
- self._updateUrl()
-
- def _onActivation(self, idx):
- # double-click or enter press
- self.accept()
-
- def setMode(self, mode):
- """Set dialog mode DatasetDialog.SaveMode or DatasetDialog.LoadMode
-
- :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
- """
- if mode == self.LoadMode:
- # hide "Create new item" field
- self._lineEditNewItem.clear()
- self._newItemWidget.hide()
- elif mode == self.SaveMode:
- self._newItemWidget.show()
- else:
- raise ValueError("Invalid DatasetDialog mode %s" % mode)
-
- def addFile(self, path):
- """Add a HDF5 file to the tree.
- All groups it contains will be selectable in the dialog.
-
- :param str path: File path
- """
- self._model.insertFile(path)
-
- def addGroup(self, group):
- """Add a HDF5 group to the tree. This group and all its subgroups
- will be selectable in the dialog.
-
- :param h5py.Group group: HDF5 group
- """
- self._model.insertH5pyObject(group)
-
- def _updateUrl(self):
- nodes = list(self._tree.selectedH5Nodes())
- subgroupName = self._lineEditNewItem.text()
- if nodes:
- node = nodes[0]
- data_path = node.local_name
- if subgroupName.lstrip("/"):
- if not data_path.endswith("/"):
- data_path += "/"
- data_path += subgroupName.lstrip("/")
- self._selectedUrl = DataUrl(file_path=node.local_filename,
- data_path=data_path)
- self._okButton.setEnabled(True)
- self._labelSelection.setText(
- self._selectedUrl.path())
-
- def getSelectedDataUrl(self):
- """Return a :class:`DataUrl` with a file path and a data path.
- Return None if the dialog was cancelled.
-
- :return: :class:`silx.io.url.DataUrl` object pointing to the
- selected HDF5 item.
- """
- return self._selectedUrl
-
-
-class GroupDialog(_Hdf5ItemSelectionDialog):
- """This :class:`QDialog` uses a :class:`silx.gui.hdf5.Hdf5TreeView` to
- provide a HDF5 group selection dialog.
-
- The information identifying the selected node is provided as a
- :class:`silx.io.url.DataUrl`.
-
- Example:
-
- .. code-block:: python
-
- dialog = GroupDialog()
- dialog.addFile(filepath1)
- dialog.addFile(filepath2)
-
- if dialog.exec_():
- print("File path: %s" % dialog.getSelectedDataUrl().file_path())
- print("HDF5 group path : %s " % dialog.getSelectedDataUrl().data_path())
- else:
- print("Operation cancelled :(")
-
- """
- def __init__(self, parent=None):
- _Hdf5ItemSelectionDialog.__init__(self, parent)
-
- # customization for groups
- self.setWindowTitle("HDF5 group selection")
-
- self._header.setSections([self._model.NAME_COLUMN,
- self._model.NODE_COLUMN,
- self._model.LINK_COLUMN])
-
- def _onActivation(self, idx):
- # double-click or enter press: filter for groups
- nodes = list(self._tree.selectedH5Nodes())
- node = nodes[0]
- if silx.io.is_group(node.h5py_object):
- self.accept()
-
- def _updateUrl(self):
- # overloaded to filter for groups
- nodes = list(self._tree.selectedH5Nodes())
- subgroupName = self._lineEditNewItem.text()
- if nodes:
- node = nodes[0]
- if silx.io.is_group(node.h5py_object):
- data_path = node.local_name
- if subgroupName.lstrip("/"):
- if not data_path.endswith("/"):
- data_path += "/"
- data_path += subgroupName.lstrip("/")
- self._selectedUrl = DataUrl(file_path=node.local_filename,
- data_path=data_path)
- self._okButton.setEnabled(True)
- self._labelSelection.setText(
- self._selectedUrl.path())
- else:
- self._selectedUrl = None
- self._okButton.setEnabled(False)
- self._labelSelection.setText("Select a group")
diff --git a/silx/gui/dialog/ImageFileDialog.py b/silx/gui/dialog/ImageFileDialog.py
deleted file mode 100644
index d015bd2..0000000
--- a/silx/gui/dialog/ImageFileDialog.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module contains an :class:`ImageFileDialog`.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "05/03/2019"
-
-import logging
-from silx.gui.plot import actions
-from silx.gui import qt
-from silx.gui.plot.PlotWidget import PlotWidget
-from .AbstractDataFileDialog import AbstractDataFileDialog
-import silx.io
-
-
-_logger = logging.getLogger(__name__)
-
-
-class _ImageSelection(qt.QWidget):
- """Provide a widget allowing to select an image from an hypercube by
- selecting a slice."""
-
- selectionChanged = qt.Signal()
- """Emitted when the selection change."""
-
- def __init__(self, parent=None):
- qt.QWidget.__init__(self, parent)
- self.__shape = None
- self.__axis = []
- layout = qt.QVBoxLayout()
- self.setLayout(layout)
-
- def hasVisibleSelectors(self):
- return self.__visibleSliders > 0
-
- def isUsed(self):
- if self.__shape is None:
- return False
- return len(self.__shape) > 2
-
- def getSelectedData(self, data):
- slicing = self.slicing()
- image = data[slicing]
- return image
-
- def setData(self, data):
- if data is None:
- self.__visibleSliders = 0
- return
-
- shape = data.shape
- if self.__shape is not None:
- # clean up
- for widget in self.__axis:
- self.layout().removeWidget(widget)
- widget.deleteLater()
- self.__axis = []
-
- self.__shape = shape
- self.__visibleSliders = 0
-
- if shape is not None:
- # create expected axes
- for index in range(len(shape) - 2):
- axis = qt.QSlider(self)
- axis.setMinimum(0)
- axis.setMaximum(shape[index] - 1)
- axis.setOrientation(qt.Qt.Horizontal)
- if shape[index] == 1:
- axis.setVisible(False)
- else:
- self.__visibleSliders += 1
-
- axis.valueChanged.connect(self.__axisValueChanged)
- self.layout().addWidget(axis)
- self.__axis.append(axis)
-
- self.selectionChanged.emit()
-
- def __axisValueChanged(self):
- self.selectionChanged.emit()
-
- def slicing(self):
- slicing = []
- for axes in self.__axis:
- slicing.append(axes.value())
- return tuple(slicing)
-
- def setSlicing(self, slicing):
- for i, value in enumerate(slicing):
- if i > len(self.__axis):
- break
- self.__axis[i].setValue(value)
-
- def selectSlicing(self, slicing):
- """Select a slicing.
-
- The provided value could be unconsistent and therefore is not supposed
- to be retrivable with a getter.
-
- :param Union[None,Tuple[int]] slicing:
- """
- if slicing is None:
- # Create a default slicing
- needed = self.__visibleSliders
- slicing = (0,) * needed
- if len(slicing) < self.__visibleSliders:
- slicing = slicing + (0,) * (self.__visibleSliders - len(slicing))
- self.setSlicing(slicing)
-
-
-class _ImagePreview(qt.QWidget):
- """Provide a preview of the selected image"""
-
- def __init__(self, parent=None):
- super(_ImagePreview, self).__init__(parent)
-
- self.__data = None
- self.__plot = PlotWidget(self)
- self.__plot.setAxesDisplayed(False)
- self.__plot.setKeepDataAspectRatio(True)
- layout = qt.QVBoxLayout()
- layout.setContentsMargins(0, 0, 0, 0)
- layout.addWidget(self.__plot)
- self.setLayout(layout)
-
- def resizeEvent(self, event):
- self.__updateConstraints()
- return qt.QWidget.resizeEvent(self, event)
-
- def sizeHint(self):
- return qt.QSize(200, 200)
-
- def plot(self):
- return self.__plot
-
- def setData(self, data, fromDataSelector=False):
- if data is None:
- self.clear()
- return
-
- resetzoom = not fromDataSelector
- previousImage = self.data()
- if previousImage is not None and data.shape != previousImage.shape:
- resetzoom = True
-
- self.__plot.addImage(legend="data", data=data, resetzoom=resetzoom)
- self.__data = data
- self.__updateConstraints()
-
- def __updateConstraints(self):
- """
- Update the constraints depending on the size of the widget
- """
- image = self.data()
- if image is None:
- return
- size = self.size()
- if size.width() == 0 or size.height() == 0:
- return
-
- heightData, widthData = image.shape
-
- widthContraint = heightData * size.width() / size.height()
- if widthContraint > widthData:
- heightContraint = heightData
- else:
- heightContraint = heightData * size.height() / size.width()
- widthContraint = widthData
-
- midWidth, midHeight = widthData * 0.5, heightData * 0.5
- heightContraint, widthContraint = heightContraint * 0.5, widthContraint * 0.5
-
- axis = self.__plot.getXAxis()
- axis.setLimitsConstraints(midWidth - widthContraint, midWidth + widthContraint)
- axis = self.__plot.getYAxis()
- axis.setLimitsConstraints(midHeight - heightContraint, midHeight + heightContraint)
-
- def __imageItem(self):
- image = self.__plot.getImage("data")
- return image
-
- def data(self):
- if self.__data is not None:
- if hasattr(self.__data, "name"):
- # in case of HDF5
- if self.__data.name is None:
- # The dataset was closed
- self.__data = None
- return self.__data
-
- def colormap(self):
- image = self.__imageItem()
- if image is not None:
- return image.getColormap()
- return self.__plot.getDefaultColormap()
-
- def setColormap(self, colormap):
- self.__plot.setDefaultColormap(colormap)
-
- def clear(self):
- self.__data = None
- image = self.__imageItem()
- if image is not None:
- self.__plot.removeImage(legend="data")
-
-
-class ImageFileDialog(AbstractDataFileDialog):
- """The `ImageFileDialog` class provides a dialog that allow users to select
- an image from a file.
-
- The `ImageFileDialog` class enables a user to traverse the file system in
- order to select one file. Then to traverse the file to select a frame or
- a slice of a dataset.
-
- .. image:: img/imagefiledialog_h5.png
-
- It supports fast access to image files using `FabIO`. Which is not the case
- of the default silx API. Image files still also can be available using the
- NeXus layout, by editing the file type combo box.
-
- .. image:: img/imagefiledialog_edf.png
-
- The selected data is an numpy array with 2 dimension.
-
- Using an `ImageFileDialog` can be done like that.
-
- .. code-block:: python
-
- dialog = ImageFileDialog()
- result = dialog.exec_()
- if result:
- print("Selection:")
- print(dialog.selectedFile())
- print(dialog.selectedUrl())
- print(dialog.selectedImage())
- else:
- print("Nothing selected")
- """
-
- def selectedImage(self):
- """Returns the selected image data as numpy
-
- :rtype: numpy.ndarray
- """
- url = self.selectedUrl()
- return silx.io.get_data(url)
-
- def _createPreviewWidget(self, parent):
- previewWidget = _ImagePreview(parent)
- previewWidget.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
- return previewWidget
-
- def _createSelectorWidget(self, parent):
- return _ImageSelection(parent)
-
- def _createPreviewToolbar(self, parent, dataPreviewWidget, dataSelectorWidget):
- plot = dataPreviewWidget.plot()
- toolbar = qt.QToolBar(parent)
- toolbar.setIconSize(qt.QSize(16, 16))
- toolbar.setStyleSheet("QToolBar { border: 0px }")
- toolbar.addAction(actions.mode.ZoomModeAction(plot, parent))
- toolbar.addAction(actions.mode.PanModeAction(plot, parent))
- toolbar.addSeparator()
- toolbar.addAction(actions.control.ResetZoomAction(plot, parent))
- toolbar.addSeparator()
- toolbar.addAction(actions.control.ColormapAction(plot, parent))
- return toolbar
-
- def _isDataSupportable(self, data):
- """Check if the selected data can be supported at one point.
-
- If true, the data selector will be checked and it will update the data
- preview. Else the selecting is disabled.
-
- :rtype: bool
- """
- if not hasattr(data, "dtype"):
- # It is not an HDF5 dataset nor a fabio image wrapper
- return False
-
- if data is None or data.shape is None:
- return False
-
- if data.dtype.kind not in set(["f", "u", "i", "b"]):
- return False
-
- dim = len(data.shape)
- return dim >= 2
-
- def _isFabioFilesSupported(self):
- return True
-
- def _isDataSupported(self, data):
- """Check if the data can be returned by the dialog.
-
- If true, this data can be returned by the dialog and the open button
- while be enabled. If false the button will be disabled.
-
- :rtype: bool
- """
- dim = len(data.shape)
- return dim == 2
-
- def _displayedDataInfo(self, dataBeforeSelection, dataAfterSelection):
- """Returns the text displayed under the data preview.
-
- This zone is used to display error in case or problem of data selection
- or problems with IO.
-
- :param numpy.ndarray dataAfterSelection: Data as it is after the
- selection widget (basically the data from the preview widget)
- :param numpy.ndarray dataAfterSelection: Data as it is before the
- selection widget (basically the data from the browsing widget)
- :rtype: bool
- """
- destination = self.__formatShape(dataAfterSelection.shape)
- source = self.__formatShape(dataBeforeSelection.shape)
- return u"%s \u2192 %s" % (source, destination)
-
- def __formatShape(self, shape):
- result = []
- for s in shape:
- if isinstance(s, slice):
- v = u"\u2026"
- else:
- v = str(s)
- result.append(v)
- return u" \u00D7 ".join(result)
diff --git a/silx/gui/dialog/SafeFileSystemModel.py b/silx/gui/dialog/SafeFileSystemModel.py
deleted file mode 100644
index 26954e3..0000000
--- a/silx/gui/dialog/SafeFileSystemModel.py
+++ /dev/null
@@ -1,804 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module contains an :class:`SafeFileSystemModel`.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "22/11/2017"
-
-import sys
-import os.path
-import logging
-import weakref
-
-import six
-
-from silx.gui import qt
-from .SafeFileIconProvider import SafeFileIconProvider
-
-_logger = logging.getLogger(__name__)
-
-
-class _Item(object):
-
- def __init__(self, fileInfo):
- self.__fileInfo = fileInfo
- self.__parent = None
- self.__children = None
- self.__absolutePath = None
-
- def isDrive(self):
- if sys.platform == "win32":
- return self.parent().parent() is None
- else:
- return False
-
- def isRoot(self):
- return self.parent() is None
-
- def isFile(self):
- """
- Returns true if the path is a file.
-
- It avoid to access to the `Qt.QFileInfo` in case the file is a drive.
- """
- if self.isDrive():
- return False
- return self.__fileInfo.isFile()
-
- def isDir(self):
- """
- Returns true if the path is a directory.
-
- The default `qt.QFileInfo.isDir` can freeze the file system with
- network drives. This function avoid the freeze in case of browsing
- the root.
- """
- if self.isDrive():
- # A drive is a directory, we don't have to synchronize the
- # drive to know that
- return True
- return self.__fileInfo.isDir()
-
- def absoluteFilePath(self):
- """
- Returns an absolute path including the file name.
-
- This function uses in most cases the default
- `qt.QFileInfo.absoluteFilePath`. But it is known to freeze the file
- system with network drives.
-
- This function uses `qt.QFileInfo.filePath` in case of root drives, to
- avoid this kind of issues. In case of drive, the result is the same,
- while the file path is already absolute.
-
- :rtype: str
- """
- if self.__absolutePath is None:
- if self.isRoot():
- path = ""
- elif self.isDrive():
- path = self.__fileInfo.filePath()
- else:
- path = os.path.join(self.parent().absoluteFilePath(), self.__fileInfo.fileName())
- if path == "":
- return "/"
- self.__absolutePath = path
- return self.__absolutePath
-
- def child(self):
- self.populate()
- return self.__children
-
- def childAt(self, position):
- self.populate()
- return self.__children[position]
-
- def childCount(self):
- self.populate()
- return len(self.__children)
-
- def indexOf(self, item):
- self.populate()
- return self.__children.index(item)
-
- def parent(self):
- parent = self.__parent
- if parent is None:
- return None
- return parent()
-
- def filePath(self):
- return self.__fileInfo.filePath()
-
- def fileName(self):
- if self.isDrive():
- name = self.absoluteFilePath()
- if name[-1] == "/":
- name = name[:-1]
- return name
- return os.path.basename(self.absoluteFilePath())
-
- def fileInfo(self):
- """
- Returns the Qt file info.
-
- :rtype: Qt.QFileInfo
- """
- return self.__fileInfo
-
- def _setParent(self, parent):
- self.__parent = weakref.ref(parent)
-
- def findChildrenByPath(self, path):
- if path == "":
- return self
- path = path.replace("\\", "/")
- if path[-1] == "/":
- path = path[:-1]
- names = path.split("/")
- caseSensitive = qt.QFSFileEngine(path).caseSensitive()
- count = len(names)
- cursor = self
- for name in names:
- for item in cursor.child():
- if caseSensitive:
- same = item.fileName() == name
- else:
- same = item.fileName().lower() == name.lower()
- if same:
- cursor = item
- count -= 1
- break
- else:
- return None
- if count == 0:
- break
- else:
- return None
- return cursor
-
- def populate(self):
- if self.__children is not None:
- return
- self.__children = []
- if self.isRoot():
- items = qt.QDir.drives()
- else:
- directory = qt.QDir(self.absoluteFilePath())
- filters = qt.QDir.AllEntries | qt.QDir.Hidden | qt.QDir.System
- items = directory.entryInfoList(filters)
- for fileInfo in items:
- i = _Item(fileInfo)
- self.__children.append(i)
- i._setParent(self)
-
-
-class _RawFileSystemModel(qt.QAbstractItemModel):
- """
- This class implement a file system model and try to avoid freeze. On Qt4,
- :class:`qt.QFileSystemModel` is known to freeze the file system when
- network drives are available.
-
- To avoid this behaviour, this class does not use
- `qt.QFileInfo.absoluteFilePath` nor `qt.QFileInfo.canonicalPath` to reach
- information on drives.
-
- This model do not take care of sorting and filtering. This features are
- managed by another model, by composition.
-
- And because it is the end of life of Qt4, we do not implement asynchronous
- loading of files as it is done by :class:`qt.QFileSystemModel`, nor some
- useful features.
- """
-
- __directoryLoadedSync = qt.Signal(str)
- """This signal is connected asynchronously to a slot. It allows to
- emit directoryLoaded as an asynchronous signal."""
-
- directoryLoaded = qt.Signal(str)
- """This signal is emitted when the gatherer thread has finished to load the
- path."""
-
- rootPathChanged = qt.Signal(str)
- """This signal is emitted whenever the root path has been changed to a
- newPath."""
-
- NAME_COLUMN = 0
- SIZE_COLUMN = 1
- TYPE_COLUMN = 2
- LAST_MODIFIED_COLUMN = 3
-
- def __init__(self, parent=None):
- qt.QAbstractItemModel.__init__(self, parent)
- self.__computer = _Item(qt.QFileInfo())
- self.__header = "Name", "Size", "Type", "Last modification"
- self.__currentPath = ""
- self.__iconProvider = SafeFileIconProvider()
- self.__directoryLoadedSync.connect(self.__emitDirectoryLoaded, qt.Qt.QueuedConnection)
-
- def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
- if orientation == qt.Qt.Horizontal:
- if role == qt.Qt.DisplayRole:
- return self.__header[section]
- if role == qt.Qt.TextAlignmentRole:
- return qt.Qt.AlignRight if section == 1 else qt.Qt.AlignLeft
- return None
-
- def flags(self, index):
- if not index.isValid():
- return 0
- return qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable
-
- def columnCount(self, parent=qt.QModelIndex()):
- return len(self.__header)
-
- def rowCount(self, parent=qt.QModelIndex()):
- item = self.__getItem(parent)
- return item.childCount()
-
- def data(self, index, role=qt.Qt.DisplayRole):
- if not index.isValid():
- return None
-
- column = index.column()
- if role in [qt.Qt.DisplayRole, qt.Qt.EditRole]:
- if column == self.NAME_COLUMN:
- return self.__displayName(index)
- elif column == self.SIZE_COLUMN:
- return self.size(index)
- elif column == self.TYPE_COLUMN:
- return self.type(index)
- elif column == self.LAST_MODIFIED_COLUMN:
- return self.lastModified(index)
- else:
- _logger.warning("data: invalid display value column %d", index.column())
- elif role == qt.QFileSystemModel.FilePathRole:
- return self.filePath(index)
- elif role == qt.QFileSystemModel.FileNameRole:
- return self.fileName(index)
- elif role == qt.Qt.DecorationRole:
- if column == self.NAME_COLUMN:
- icon = self.fileIcon(index)
- if icon is None or icon.isNull():
- if self.isDir(index):
- self.__iconProvider.icon(qt.QFileIconProvider.Folder)
- else:
- self.__iconProvider.icon(qt.QFileIconProvider.File)
- return icon
- elif role == qt.Qt.TextAlignmentRole:
- if column == self.SIZE_COLUMN:
- return qt.Qt.AlignRight
- elif role == qt.QFileSystemModel.FilePermissions:
- return self.permissions(index)
-
- return None
-
- def index(self, *args, **kwargs):
- path_api = False
- path_api |= len(args) >= 1 and isinstance(args[0], six.string_types)
- path_api |= "path" in kwargs
-
- if path_api:
- return self.__indexFromPath(*args, **kwargs)
- else:
- return self.__index(*args, **kwargs)
-
- def __index(self, row, column, parent=qt.QModelIndex()):
- if parent.isValid() and parent.column() != 0:
- return None
-
- parentItem = self.__getItem(parent)
- item = parentItem.childAt(row)
- return self.createIndex(row, column, item)
-
- def __indexFromPath(self, path, column=0):
- """
- Uses the index(str) C++ API
-
- :rtype: qt.QModelIndex
- """
- if path == "":
- return qt.QModelIndex()
-
- item = self.__computer.findChildrenByPath(path)
- if item is None:
- return qt.QModelIndex()
-
- return self.createIndex(item.parent().indexOf(item), column, item)
-
- def parent(self, index):
- if not index.isValid():
- return qt.QModelIndex()
-
- item = self.__getItem(index)
- if index is None:
- return qt.QModelIndex()
-
- parent = item.parent()
- if parent is None or parent is self.__computer:
- return qt.QModelIndex()
-
- return self.createIndex(parent.parent().indexOf(parent), 0, parent)
-
- def __emitDirectoryLoaded(self, path):
- self.directoryLoaded.emit(path)
-
- def __emitRootPathChanged(self, path):
- self.rootPathChanged.emit(path)
-
- def __getItem(self, index):
- if not index.isValid():
- return self.__computer
- item = index.internalPointer()
- return item
-
- def fileIcon(self, index):
- item = self.__getItem(index)
- if self.__iconProvider is not None:
- fileInfo = item.fileInfo()
- result = self.__iconProvider.icon(fileInfo)
- else:
- style = qt.QApplication.instance().style()
- if item.isRoot():
- result = style.standardIcon(qt.QStyle.SP_ComputerIcon)
- elif item.isDrive():
- result = style.standardIcon(qt.QStyle.SP_DriveHDIcon)
- elif item.isDir():
- result = style.standardIcon(qt.QStyle.SP_DirIcon)
- else:
- result = style.standardIcon(qt.QStyle.SP_FileIcon)
- return result
-
- def _item(self, index):
- item = self.__getItem(index)
- return item
-
- def fileInfo(self, index):
- item = self.__getItem(index)
- result = item.fileInfo()
- return result
-
- def __fileIcon(self, index):
- item = self.__getItem(index)
- result = item.fileName()
- return result
-
- def __displayName(self, index):
- item = self.__getItem(index)
- result = item.fileName()
- return result
-
- def fileName(self, index):
- item = self.__getItem(index)
- result = item.fileName()
- return result
-
- def filePath(self, index):
- item = self.__getItem(index)
- result = item.fileInfo().filePath()
- return result
-
- def isDir(self, index):
- item = self.__getItem(index)
- result = item.isDir()
- return result
-
- def lastModified(self, index):
- item = self.__getItem(index)
- result = item.fileInfo().lastModified()
- return result
-
- def permissions(self, index):
- item = self.__getItem(index)
- result = item.fileInfo().permissions()
- return result
-
- def size(self, index):
- item = self.__getItem(index)
- result = item.fileInfo().size()
- return result
-
- def type(self, index):
- item = self.__getItem(index)
- if self.__iconProvider is not None:
- fileInfo = item.fileInfo()
- result = self.__iconProvider.type(fileInfo)
- else:
- if item.isRoot():
- result = "Computer"
- elif item.isDrive():
- result = "Drive"
- elif item.isDir():
- result = "Directory"
- else:
- fileInfo = item.fileInfo()
- result = fileInfo.suffix()
- return result
-
- # File manipulation
-
- # bool remove(const QModelIndex & index) const
- # bool rmdir(const QModelIndex & index) const
- # QModelIndex mkdir(const QModelIndex & parent, const QString & name)
-
- # Configuration
-
- def rootDirectory(self):
- return qt.QDir(self.rootPath())
-
- def rootPath(self):
- return self.__currentPath
-
- def setRootPath(self, path):
- if self.__currentPath == path:
- return
- self.__currentPath = path
- item = self.__computer.findChildrenByPath(path)
- self.__emitRootPathChanged(path)
- if item is None or item.parent() is None:
- return qt.QModelIndex()
- index = self.createIndex(item.parent().indexOf(item), 0, item)
- self.__directoryLoadedSync.emit(path)
- return index
-
- def iconProvider(self):
- # FIXME: invalidate the model
- return self.__iconProvider
-
- def setIconProvider(self, provider):
- # FIXME: invalidate the model
- self.__iconProvider = provider
-
- # bool resolveSymlinks() const
- # void setResolveSymlinks(bool enable)
-
- def setNameFilterDisables(self, enable):
- return None
-
- def nameFilterDisables(self):
- return None
-
- def myComputer(self, role=qt.Qt.DisplayRole):
- return None
-
- def setNameFilters(self, filters):
- return
-
- def nameFilters(self):
- return None
-
- def filter(self):
- return self.__filters
-
- def setFilter(self, filters):
- return
-
- def setReadOnly(self, enable):
- assert(enable is True)
-
- def isReadOnly(self):
- return False
-
-
-class SafeFileSystemModel(qt.QSortFilterProxyModel):
- """
- This class implement a file system model and try to avoid freeze. On Qt4,
- :class:`qt.QFileSystemModel` is known to freeze the file system when
- network drives are available.
-
- To avoid this behaviour, this class does not use
- `qt.QFileInfo.absoluteFilePath` nor `qt.QFileInfo.canonicalPath` to reach
- information on drives.
-
- And because it is the end of life of Qt4, we do not implement asynchronous
- loading of files as it is done by :class:`qt.QFileSystemModel`, nor some
- useful features.
- """
-
- def __init__(self, parent=None):
- qt.QSortFilterProxyModel.__init__(self, parent=parent)
- self.__nameFilterDisables = sys.platform == "darwin"
- self.__nameFilters = []
- self.__filters = qt.QDir.AllEntries | qt.QDir.NoDotAndDotDot | qt.QDir.AllDirs
- sourceModel = _RawFileSystemModel(self)
- self.setSourceModel(sourceModel)
-
- @property
- def directoryLoaded(self):
- return self.sourceModel().directoryLoaded
-
- @property
- def rootPathChanged(self):
- return self.sourceModel().rootPathChanged
-
- def index(self, *args, **kwargs):
- path_api = False
- path_api |= len(args) >= 1 and isinstance(args[0], six.string_types)
- path_api |= "path" in kwargs
-
- if path_api:
- return self.__indexFromPath(*args, **kwargs)
- else:
- return self.__index(*args, **kwargs)
-
- def __index(self, row, column, parent=qt.QModelIndex()):
- return qt.QSortFilterProxyModel.index(self, row, column, parent)
-
- def __indexFromPath(self, path, column=0):
- """
- Uses the index(str) C++ API
-
- :rtype: qt.QModelIndex
- """
- if path == "":
- return qt.QModelIndex()
-
- index = self.sourceModel().index(path, column)
- index = self.mapFromSource(index)
- return index
-
- def lessThan(self, leftSourceIndex, rightSourceIndex):
- sourceModel = self.sourceModel()
- sortColumn = self.sortColumn()
- if sortColumn == _RawFileSystemModel.NAME_COLUMN:
- leftItem = sourceModel._item(leftSourceIndex)
- rightItem = sourceModel._item(rightSourceIndex)
- if sys.platform != "darwin":
- # Sort directories before files
- leftIsDir = leftItem.isDir()
- rightIsDir = rightItem.isDir()
- if leftIsDir ^ rightIsDir:
- return leftIsDir
- return leftItem.fileName().lower() < rightItem.fileName().lower()
- elif sortColumn == _RawFileSystemModel.SIZE_COLUMN:
- left = sourceModel.fileInfo(leftSourceIndex)
- right = sourceModel.fileInfo(rightSourceIndex)
- return left.size() < right.size()
- elif sortColumn == _RawFileSystemModel.TYPE_COLUMN:
- left = sourceModel.type(leftSourceIndex)
- right = sourceModel.type(rightSourceIndex)
- return left < right
- elif sortColumn == _RawFileSystemModel.LAST_MODIFIED_COLUMN:
- left = sourceModel.fileInfo(leftSourceIndex)
- right = sourceModel.fileInfo(rightSourceIndex)
- return left.lastModified() < right.lastModified()
- else:
- _logger.warning("Unsupported sorted column %d", sortColumn)
-
- return False
-
- def __filtersAccepted(self, item, filters):
- """
- Check individual flag filters.
- """
- if not (filters & (qt.QDir.Dirs | qt.QDir.AllDirs)):
- # Hide dirs
- if item.isDir():
- return False
- if not (filters & qt.QDir.Files):
- # Hide files
- if item.isFile():
- return False
- if not (filters & qt.QDir.Drives):
- # Hide drives
- if item.isDrive():
- return False
-
- fileInfo = item.fileInfo()
- if fileInfo is None:
- return False
-
- filterPermissions = (filters & qt.QDir.PermissionMask) != 0
- if filterPermissions and (filters & (qt.QDir.Dirs | qt.QDir.Files)):
- if (filters & qt.QDir.Readable):
- # Hide unreadable
- if not fileInfo.isReadable():
- return False
- if (filters & qt.QDir.Writable):
- # Hide unwritable
- if not fileInfo.isWritable():
- return False
- if (filters & qt.QDir.Executable):
- # Hide unexecutable
- if not fileInfo.isExecutable():
- return False
-
- if (filters & qt.QDir.NoSymLinks):
- # Hide sym links
- if fileInfo.isSymLink():
- return False
-
- if not (filters & qt.QDir.System):
- # Hide system
- if not item.isDir() and not item.isFile():
- return False
-
- fileName = item.fileName()
- isDot = fileName == "."
- isDotDot = fileName == ".."
-
- if not (filters & qt.QDir.Hidden):
- # Hide hidden
- if not (isDot or isDotDot) and fileInfo.isHidden():
- return False
-
- if filters & (qt.QDir.NoDot | qt.QDir.NoDotDot | qt.QDir.NoDotAndDotDot):
- # Hide parent/self references
- if filters & qt.QDir.NoDot:
- if isDot:
- return False
- if filters & qt.QDir.NoDotDot:
- if isDotDot:
- return False
- if filters & qt.QDir.NoDotAndDotDot:
- if isDot or isDotDot:
- return False
-
- return True
-
- def filterAcceptsRow(self, sourceRow, sourceParent):
- if not sourceParent.isValid():
- return True
-
- sourceModel = self.sourceModel()
- index = sourceModel.index(sourceRow, 0, sourceParent)
- if not index.isValid():
- return True
- item = sourceModel._item(index)
-
- filters = self.__filters
-
- if item.isDrive():
- # Let say a user always have access to a drive
- # It avoid to access to fileInfo then avoid to freeze the file
- # system
- return True
-
- if not self.__filtersAccepted(item, filters):
- return False
-
- if self.__nameFilterDisables:
- return True
-
- if item.isDir() and (filters & qt.QDir.AllDirs):
- # dont apply the filters to directory names
- return True
-
- return self.__nameFiltersAccepted(item)
-
- def __nameFiltersAccepted(self, item):
- if len(self.__nameFilters) == 0:
- return True
-
- fileName = item.fileName()
- for reg in self.__nameFilters:
- if reg.exactMatch(fileName):
- return True
- return False
-
- def setNameFilterDisables(self, enable):
- self.__nameFilterDisables = enable
- self.invalidate()
-
- def nameFilterDisables(self):
- return self.__nameFilterDisables
-
- def myComputer(self, role=qt.Qt.DisplayRole):
- return self.sourceModel().myComputer(role)
-
- def setNameFilters(self, filters):
- self.__nameFilters = []
- isCaseSensitive = self.__filters & qt.QDir.CaseSensitive
- caseSensitive = qt.Qt.CaseSensitive if isCaseSensitive else qt.Qt.CaseInsensitive
- for f in filters:
- reg = qt.QRegExp(f, caseSensitive, qt.QRegExp.Wildcard)
- self.__nameFilters.append(reg)
- self.invalidate()
-
- def nameFilters(self):
- return [f.pattern() for f in self.__nameFilters]
-
- def filter(self):
- return self.__filters
-
- def setFilter(self, filters):
- self.__filters = filters
- # In case of change of case sensitivity
- self.setNameFilters(self.nameFilters())
- self.invalidate()
-
- def setReadOnly(self, enable):
- assert(enable is True)
-
- def isReadOnly(self):
- return False
-
- def rootPath(self):
- return self.sourceModel().rootPath()
-
- def setRootPath(self, path):
- index = self.sourceModel().setRootPath(path)
- index = self.mapFromSource(index)
- return index
-
- def flags(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- filters = sourceModel.flags(index)
-
- if self.__nameFilterDisables and not sourceModel.isDir(index):
- item = sourceModel._item(index)
- if not self.__nameFiltersAccepted(item):
- filters &= ~qt.Qt.ItemIsEnabled
-
- return filters
-
- def fileIcon(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- return sourceModel.fileIcon(index)
-
- def fileInfo(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- return sourceModel.fileInfo(index)
-
- def fileName(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- return sourceModel.fileName(index)
-
- def filePath(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- return sourceModel.filePath(index)
-
- def isDir(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- return sourceModel.isDir(index)
-
- def lastModified(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- return sourceModel.lastModified(index)
-
- def permissions(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- return sourceModel.permissions(index)
-
- def size(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- return sourceModel.size(index)
-
- def type(self, index):
- sourceModel = self.sourceModel()
- index = self.mapToSource(index)
- return sourceModel.type(index)
diff --git a/silx/gui/dialog/test/__init__.py b/silx/gui/dialog/test/__init__.py
deleted file mode 100644
index f43a37a..0000000
--- a/silx/gui/dialog/test/__init__.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for Qt dialogs"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-
-import logging
-import os
-import sys
-import unittest
-
-
-_logger = logging.getLogger(__name__)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- from . import test_imagefiledialog
- from . import test_datafiledialog
- from . import test_colormapdialog
- test_suite.addTest(test_imagefiledialog.suite())
- test_suite.addTest(test_datafiledialog.suite())
- test_suite.addTest(test_colormapdialog.suite())
- return test_suite
diff --git a/silx/gui/dialog/test/test_colormapdialog.py b/silx/gui/dialog/test/test_colormapdialog.py
deleted file mode 100644
index 61e6365..0000000
--- a/silx/gui/dialog/test/test_colormapdialog.py
+++ /dev/null
@@ -1,453 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for ColormapDialog"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "09/11/2018"
-
-
-import unittest
-
-from silx.gui import qt
-from silx.gui.dialog import ColormapDialog
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.colors import Colormap, preferredColormaps
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.plot.PlotWindow import PlotWindow
-from silx.gui.plot.items.image import ImageData
-
-import numpy.random
-
-
-class TestColormapDialog(TestCaseQt, ParametricTestCase):
- """Test the ColormapDialog."""
- def setUp(self):
- TestCaseQt.setUp(self)
- ParametricTestCase.setUp(self)
- self.colormap = Colormap(name='gray', vmin=10.0, vmax=20.0,
- normalization='linear')
-
- self.colormapDiag = ColormapDialog.ColormapDialog()
-
- def tearDown(self):
- self.qapp.processEvents()
- colormapDiag = self.colormapDiag
- self.colormapDiag = None
- if colormapDiag is not None:
- colormapDiag.close()
- colormapDiag.deleteLater()
- colormapDiag = None
- self.qapp.processEvents()
- ParametricTestCase.tearDown(self)
- TestCaseQt.tearDown(self)
-
- def testGUIEdition(self):
- """Make sure the colormap is correctly edited and also that the
- modification are correctly updated if an other colormapdialog is
- editing the same colormap"""
- colormapDiag2 = ColormapDialog.ColormapDialog()
- colormapDiag2.setColormap(self.colormap)
- colormapDiag2.show()
- self.colormapDiag.setColormap(self.colormap)
- self.colormapDiag.show()
- self.qapp.processEvents()
-
- self.colormapDiag._comboBoxColormap._setCurrentName('red')
- self.colormapDiag._comboBoxNormalization.setCurrentIndex(
- self.colormapDiag._comboBoxNormalization.findData(Colormap.LOGARITHM))
- self.assertTrue(self.colormap.getName() == 'red')
- self.assertTrue(self.colormapDiag.getColormap().getName() == 'red')
- self.assertTrue(self.colormap.getNormalization() == 'log')
- self.assertTrue(self.colormap.getVMin() == 10)
- self.assertTrue(self.colormap.getVMax() == 20)
- # checked second colormap dialog
- self.assertTrue(colormapDiag2._comboBoxColormap.getCurrentName() == 'red')
- self.assertEqual(colormapDiag2._comboBoxNormalization.currentData(),
- Colormap.LOGARITHM)
- self.assertTrue(int(colormapDiag2._minValue.getValue()) == 10)
- self.assertTrue(int(colormapDiag2._maxValue.getValue()) == 20)
- colormapDiag2.close()
-
- def testGUIModalOk(self):
- """Make sure the colormap is modified if gone through accept"""
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.setModal(True)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.colormapDiag._maxValue.sigAutoScaleChanged.emit(True)
- self.mouseClick(
- widget=self.colormapDiag._buttonsModal.button(qt.QDialogButtonBox.Ok),
- button=qt.Qt.LeftButton
- )
- self.assertTrue(self.colormap.getVMin() is None)
- self.assertTrue(self.colormap.getVMax() is None)
- self.assertTrue(self.colormap.isAutoscale() is True)
-
- def testGUIModalCancel(self):
- """Make sure the colormap is not modified if gone through reject"""
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.setModal(True)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.mouseClick(
- widget=self.colormapDiag._buttonsModal.button(qt.QDialogButtonBox.Cancel),
- button=qt.Qt.LeftButton
- )
- self.assertTrue(self.colormap.getVMin() is not None)
-
- def testGUIModalClose(self):
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.setModal(False)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.mouseClick(
- widget=self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Close),
- button=qt.Qt.LeftButton
- )
- self.assertTrue(self.colormap.getVMin() is None)
-
- def testGUIModalReset(self):
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.setModal(False)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.mouseClick(
- widget=self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset),
- button=qt.Qt.LeftButton
- )
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag.close()
-
- def testGUIClose(self):
- """Make sure the colormap is modify if go through reject"""
- assert self.colormap.isAutoscale() is False
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(self.colormap)
- self.assertTrue(self.colormap.getVMin() is not None)
- self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
- self.assertTrue(self.colormap.getVMin() is None)
- self.colormapDiag.close()
- self.qapp.processEvents()
- self.assertTrue(self.colormap.getVMin() is None)
-
- def testSetColormapIsCorrect(self):
- """Make sure the interface fir the colormap when set a new colormap"""
- self.colormap.setName('red')
- self.colormapDiag.show()
- self.qapp.processEvents()
- for norm in (Colormap.NORMALIZATIONS):
- for autoscale in (True, False):
- if autoscale is True:
- self.colormap.setVRange(None, None)
- else:
- self.colormap.setVRange(11, 101)
- self.colormap.setNormalization(norm)
- with self.subTest(colormap=self.colormap):
- self.colormapDiag.setColormap(self.colormap)
- self.assertEqual(
- self.colormapDiag._comboBoxNormalization.currentData(), norm)
- self.assertTrue(
- self.colormapDiag._comboBoxColormap.getCurrentName() == 'red')
- self.assertTrue(
- self.colormapDiag._minValue.isAutoChecked() == autoscale)
- self.assertTrue(
- self.colormapDiag._maxValue.isAutoChecked() == autoscale)
- if autoscale is False:
- self.assertTrue(self.colormapDiag._minValue.getValue() == 11)
- self.assertTrue(self.colormapDiag._maxValue.getValue() == 101)
- self.assertTrue(self.colormapDiag._minValue.isEnabled())
- self.assertTrue(self.colormapDiag._maxValue.isEnabled())
- else:
- self.assertFalse(self.colormapDiag._minValue._numVal.isEnabled())
- self.assertFalse(self.colormapDiag._maxValue._numVal.isEnabled())
-
- def testColormapDel(self):
- """Check behavior if the colormap has been deleted outside. For now
- we make sure the colormap is still running and nothing more"""
- self.colormapDiag.setColormap(self.colormap)
- self.colormapDiag.show()
- self.qapp.processEvents()
- del self.colormap
- self.assertTrue(self.colormapDiag.getColormap() is None)
- self.colormapDiag._comboBoxColormap._setCurrentName('blue')
-
- def testColormapEditedOutside(self):
- """Make sure the GUI is still up to date if the colormap is modified
- outside"""
- self.colormapDiag.setColormap(self.colormap)
- self.colormapDiag.show()
- self.qapp.processEvents()
-
- self.colormap.setName('red')
- self.assertTrue(
- self.colormapDiag._comboBoxColormap.getCurrentName() == 'red')
- self.colormap.setNormalization(Colormap.LOGARITHM)
- self.assertEqual(self.colormapDiag._comboBoxNormalization.currentData(),
- Colormap.LOGARITHM)
- self.colormap.setVRange(11, 201)
- self.assertTrue(self.colormapDiag._minValue.getValue() == 11)
- self.assertTrue(self.colormapDiag._maxValue.getValue() == 201)
- self.assertTrue(self.colormapDiag._minValue._numVal.isEnabled())
- self.assertTrue(self.colormapDiag._maxValue._numVal.isEnabled())
- self.assertFalse(self.colormapDiag._minValue.isAutoChecked())
- self.assertFalse(self.colormapDiag._maxValue.isAutoChecked())
- self.colormap.setVRange(None, None)
- self.assertFalse(self.colormapDiag._minValue._numVal.isEnabled())
- self.assertFalse(self.colormapDiag._maxValue._numVal.isEnabled())
- self.assertTrue(self.colormapDiag._minValue.isAutoChecked())
- self.assertTrue(self.colormapDiag._maxValue.isAutoChecked())
-
- def testSetColormapScenario(self):
- """Test of a simple scenario of a colormap dialog editing several
- colormap"""
- colormap1 = Colormap(name='gray', vmin=10.0, vmax=20.0,
- normalization='linear')
- colormap2 = Colormap(name='red', vmin=10.0, vmax=20.0,
- normalization='log')
- colormap3 = Colormap(name='blue', vmin=None, vmax=None,
- normalization='linear')
- self.colormapDiag.setColormap(self.colormap)
- self.colormapDiag.setColormap(colormap1)
- del colormap1
- self.colormapDiag.setColormap(colormap2)
- del colormap2
- self.colormapDiag.setColormap(colormap3)
- del colormap3
-
- def testNotPreferredColormap(self):
- """Test that the colormapEditor is able to edit a colormap which is not
- part of the 'prefered colormap'
- """
- def getFirstNotPreferredColormap():
- cms = Colormap.getSupportedColormaps()
- preferred = preferredColormaps()
- for cm in cms:
- if cm not in preferred:
- return cm
- return None
-
- colormapName = getFirstNotPreferredColormap()
- assert colormapName is not None
- colormap = Colormap(name=colormapName)
- self.colormapDiag.setColormap(colormap)
- self.colormapDiag.show()
- self.qapp.processEvents()
- cb = self.colormapDiag._comboBoxColormap
- self.assertTrue(cb.getCurrentName() == colormapName)
- cb.setCurrentIndex(0)
- index = cb.findLutName(colormapName)
- assert index != 0 # if 0 then the rest of the test has no sense
- cb.setCurrentIndex(index)
- self.assertTrue(cb.getCurrentName() == colormapName)
-
- def testColormapEditableMode(self):
- """Test that the colormapDialog is correctly updated when changing the
- colormap editable status"""
- colormap = Colormap(normalization='linear', vmin=1.0, vmax=10.0)
- self.colormapDiag.show()
- self.qapp.processEvents()
- self.colormapDiag.setColormap(colormap)
- for editable in (True, False):
- with self.subTest(editable=editable):
- colormap.setEditable(editable)
- self.assertTrue(
- self.colormapDiag._comboBoxColormap.isEnabled() is editable)
- self.assertTrue(
- self.colormapDiag._minValue.isEnabled() is editable)
- self.assertTrue(
- self.colormapDiag._maxValue.isEnabled() is editable)
- self.assertTrue(
- self.colormapDiag._comboBoxNormalization.isEnabled() is editable)
-
- # Make sure the reset button is also set to enable when edition mode is
- # False
- self.colormapDiag.setModal(False)
- colormap.setEditable(True)
- self.colormapDiag._comboBoxNormalization.setCurrentIndex(
- self.colormapDiag._comboBoxNormalization.findData(Colormap.LOGARITHM))
- resetButton = self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
- self.assertTrue(resetButton.isEnabled())
- colormap.setEditable(False)
- self.assertFalse(resetButton.isEnabled())
-
- def testImageData(self):
- data = numpy.random.rand(5, 5)
- self.colormapDiag.setData(data)
-
- def testEmptyData(self):
- data = numpy.empty((10, 0))
- self.colormapDiag.setData(data)
-
- def testNoneData(self):
- data = numpy.random.rand(5, 5)
- self.colormapDiag.setData(data)
- self.colormapDiag.setData(None)
-
- def testImageItem(self):
- """Check that an ImageData plot item can be used"""
- dialog = self.colormapDiag
- colormap = Colormap(name='gray', vmin=None, vmax=None)
- data = numpy.arange(3**2).reshape(3, 3)
- item = ImageData()
- item.setData(data, copy=False)
-
- dialog.setColormap(colormap)
- dialog.show()
- self.qapp.processEvents()
- dialog.setItem(item)
- vrange = dialog._getFiniteColormapRange()
- self.assertEqual(vrange, (0, 8))
-
- def testItemDel(self):
- """Check that the plot items are not hard linked to the dialog"""
- dialog = self.colormapDiag
- colormap = Colormap(name='gray', vmin=None, vmax=None)
- data = numpy.arange(3**2).reshape(3, 3)
- item = ImageData()
- item.setData(data, copy=False)
-
- dialog.setColormap(colormap)
- dialog.show()
- self.qapp.processEvents()
- dialog.setItem(item)
- previousRange = dialog._getFiniteColormapRange()
- del item
- vrange = dialog._getFiniteColormapRange()
- self.assertNotEqual(vrange, previousRange)
-
- def testDataDel(self):
- """Check that the data are not hard linked to the dialog"""
- dialog = self.colormapDiag
- colormap = Colormap(name='gray', vmin=None, vmax=None)
- data = numpy.arange(5)
-
- dialog.setColormap(colormap)
- dialog.show()
- self.qapp.processEvents()
- dialog.setData(data)
- previousRange = dialog._getFiniteColormapRange()
- del data
- vrange = dialog._getFiniteColormapRange()
- self.assertNotEqual(vrange, previousRange)
-
- def testDeleteWhileExec(self):
- colormapDiag = self.colormapDiag
- self.colormapDiag = None
- qt.QTimer.singleShot(1000, colormapDiag.deleteLater)
- result = colormapDiag.exec_()
- self.assertEqual(result, 0)
-
-
-class TestColormapAction(TestCaseQt):
- def setUp(self):
- TestCaseQt.setUp(self)
- self.plot = PlotWindow()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
-
- self.colormap1 = Colormap(name='blue', vmin=0.0, vmax=1.0,
- normalization='linear')
- self.colormap2 = Colormap(name='red', vmin=10.0, vmax=100.0,
- normalization='log')
- self.defaultColormap = self.plot.getDefaultColormap()
-
- self.plot.getColormapAction()._actionTriggered(checked=True)
- self.colormapDialog = self.plot.getColormapAction()._dialog
- self.colormapDialog.setAttribute(qt.Qt.WA_DeleteOnClose)
-
- def tearDown(self):
- self.colormapDialog.close()
- self.plot.close()
- del self.colormapDialog
- del self.plot
- TestCaseQt.tearDown(self)
-
- def testActiveColormap(self):
- self.assertTrue(self.colormapDialog.getColormap() is self.defaultColormap)
-
- self.plot.addImage(data=numpy.random.rand(10, 10), legend='img1',
- origin=(0, 0),
- colormap=self.colormap1)
- self.plot.setActiveImage('img1')
- self.assertTrue(self.colormapDialog.getColormap() is self.colormap1)
-
- self.plot.addImage(data=numpy.random.rand(10, 10), legend='img2',
- origin=(0, 0),
- colormap=self.colormap2)
- self.plot.addImage(data=numpy.random.rand(10, 10), legend='img3',
- origin=(0, 0))
-
- self.plot.setActiveImage('img3')
- self.assertTrue(self.colormapDialog.getColormap() is self.defaultColormap)
- self.plot.getActiveImage().setColormap(self.colormap2)
- self.assertTrue(self.colormapDialog.getColormap() is self.colormap2)
-
- self.plot.remove('img2')
- self.plot.remove('img3')
- self.plot.remove('img1')
- self.assertTrue(self.colormapDialog.getColormap() is self.defaultColormap)
-
- def testShowHideColormapDialog(self):
- self.plot.getColormapAction()._actionTriggered(checked=False)
- self.assertFalse(self.plot.getColormapAction().isChecked())
- self.plot.getColormapAction()._actionTriggered(checked=True)
- self.assertTrue(self.plot.getColormapAction().isChecked())
- self.plot.addImage(data=numpy.random.rand(10, 10), legend='img1',
- origin=(0, 0),
- colormap=self.colormap1)
- self.colormap1.setName('red')
- self.plot.getColormapAction()._actionTriggered()
- self.colormap1.setName('blue')
- self.colormapDialog.close()
- self.assertFalse(self.plot.getColormapAction().isChecked())
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for testClass in (TestColormapDialog, TestColormapAction):
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
- testClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/dialog/test/test_datafiledialog.py b/silx/gui/dialog/test/test_datafiledialog.py
deleted file mode 100644
index b60ea12..0000000
--- a/silx/gui/dialog/test/test_datafiledialog.py
+++ /dev/null
@@ -1,939 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test for silx.gui.hdf5 module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "08/03/2019"
-
-
-import unittest
-import tempfile
-import numpy
-import shutil
-import os
-import io
-import weakref
-import fabio
-import h5py
-import silx.io.url
-from silx.gui import qt
-from silx.gui.utils import testutils
-from ..DataFileDialog import DataFileDialog
-from silx.gui.hdf5 import Hdf5TreeModel
-
-_tmpDirectory = None
-
-
-def setUpModule():
- global _tmpDirectory
- _tmpDirectory = tempfile.mkdtemp(prefix=__name__)
-
- data = numpy.arange(100 * 100)
- data.shape = 100, 100
-
- filename = _tmpDirectory + "/singleimage.edf"
- image = fabio.edfimage.EdfImage(data=data)
- image.write(filename)
-
- filename = _tmpDirectory + "/data.h5"
- f = h5py.File(filename, "w")
- f["scalar"] = 10
- f["image"] = data
- f["cube"] = [data, data + 1, data + 2]
- f["complex_image"] = data * 1j
- f["group/image"] = data
- f["nxdata/foo"] = 10
- f["nxdata"].attrs["NX_class"] = u"NXdata"
- f.close()
-
- directory = os.path.join(_tmpDirectory, "data")
- os.mkdir(directory)
- filename = os.path.join(directory, "data.h5")
- f = h5py.File(filename, "w")
- f["scalar"] = 10
- f["image"] = data
- f["cube"] = [data, data + 1, data + 2]
- f["complex_image"] = data * 1j
- f["group/image"] = data
- f["nxdata/foo"] = 10
- f["nxdata"].attrs["NX_class"] = u"NXdata"
- f.close()
-
- filename = _tmpDirectory + "/badformat.h5"
- with io.open(filename, "wb") as f:
- f.write(b"{\nHello Nurse!")
-
-
-def tearDownModule():
- global _tmpDirectory
- shutil.rmtree(_tmpDirectory)
- _tmpDirectory = None
-
-
-class _UtilsMixin(object):
-
- def createDialog(self):
- self._deleteDialog()
- self._dialog = self._createDialog()
- return self._dialog
-
- def _createDialog(self):
- return DataFileDialog()
-
- def _deleteDialog(self):
- if not hasattr(self, "_dialog"):
- return
- if self._dialog is not None:
- ref = weakref.ref(self._dialog)
- self._dialog = None
- self.qWaitForDestroy(ref)
-
- def qWaitForPendingActions(self, dialog):
- for _ in range(20):
- if not dialog.hasPendingEvents():
- return
- self.qWait(10)
- raise RuntimeError("Still have pending actions")
-
- def assertSamePath(self, path1, path2):
- path1_ = os.path.normcase(path1)
- path2_ = os.path.normcase(path2)
- if path1_ != path2_:
- # Use the unittest API to log and display error
- self.assertEqual(path1, path2)
-
- def assertNotSamePath(self, path1, path2):
- path1_ = os.path.normcase(path1)
- path2_ = os.path.normcase(path2)
- if path1_ == path2_:
- # Use the unittest API to log and display error
- self.assertNotEqual(path1, path2)
-
-
-class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
-
- def tearDown(self):
- self._deleteDialog()
- testutils.TestCaseQt.tearDown(self)
-
- def testDisplayAndKeyEscape(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
-
- self.keyClick(dialog, qt.Qt.Key_Escape)
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Rejected)
-
- def testDisplayAndClickCancel(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="cancel")[0]
- self.mouseClick(button, qt.Qt.LeftButton)
- self.assertFalse(dialog.isVisible())
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Rejected)
-
- def testDisplayAndClickLockedOpen(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.mouseClick(button, qt.Qt.LeftButton)
- # open button locked, dialog is not closed
- self.assertTrue(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Rejected)
-
- def testSelectRoot_Activate(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/data.h5"
- dialog.selectFile(os.path.dirname(filename))
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertTrue(button.isEnabled())
- self.mouseClick(button, qt.Qt.LeftButton)
- url = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertTrue(url.data_path() is not None)
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Accepted)
-
- def testSelectGroup_Activate(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/data.h5"
- dialog.selectFile(os.path.dirname(filename))
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertTrue(button.isEnabled())
- self.mouseClick(button, qt.Qt.LeftButton)
- url = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertEqual(url.data_path(), "/group")
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Accepted)
-
- def testSelectDataset_Activate(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/data.h5"
- dialog.selectFile(os.path.dirname(filename))
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertTrue(button.isEnabled())
- self.mouseClick(button, qt.Qt.LeftButton)
- url = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertEqual(url.data_path(), "/scalar")
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Accepted)
-
- def testClickOnBackToParentTool(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- action = testutils.findChildren(dialog, qt.QAction, name="toParentAction")[0]
- toParentButton = testutils.getQToolButtonFromAction(action)
- filename = _tmpDirectory + "/data/data.h5"
-
- # init state
- path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
- dialog.selectUrl(path)
- self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
- self.assertSamePath(url.text(), path)
- # test
- self.mouseClick(toParentButton, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
- self.assertSamePath(url.text(), path)
-
- self.mouseClick(toParentButton, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(url.text(), _tmpDirectory + "/data")
-
- self.mouseClick(toParentButton, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(url.text(), _tmpDirectory)
-
- def testClickOnBackToRootTool(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- action = testutils.findChildren(dialog, qt.QAction, name="toRootFileAction")[0]
- button = testutils.getQToolButtonFromAction(action)
- filename = _tmpDirectory + "/data.h5"
-
- # init state
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
- dialog.selectUrl(path)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(url.text(), path)
- self.assertTrue(button.isEnabled())
- # test
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
- self.assertSamePath(url.text(), path)
- # self.assertFalse(button.isEnabled())
-
- def testClickOnBackToDirectoryTool(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- action = testutils.findChildren(dialog, qt.QAction, name="toDirectoryAction")[0]
- button = testutils.getQToolButtonFromAction(action)
- filename = _tmpDirectory + "/data.h5"
-
- # init state
- path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
- dialog.selectUrl(path)
- self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
- self.assertSamePath(url.text(), path)
- self.assertTrue(button.isEnabled())
- # test
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(url.text(), _tmpDirectory)
- self.assertFalse(button.isEnabled())
-
- # FIXME: There is an unreleased qt.QWidget without nameObject
- # No idea where it come from.
- self.allowedLeakingWidgets = 1
-
- def testClickOnHistoryTools(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- forwardAction = testutils.findChildren(dialog, qt.QAction, name="forwardAction")[0]
- backwardAction = testutils.findChildren(dialog, qt.QAction, name="backwardAction")[0]
- filename = _tmpDirectory + "/data.h5"
-
- dialog.setDirectory(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- # No way to use QTest.mouseDClick with QListView, QListWidget
- # Then we feed the history using selectPath
- dialog.selectUrl(filename)
- self.qWaitForPendingActions(dialog)
- path2 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
- dialog.selectUrl(path2)
- self.qWaitForPendingActions(dialog)
- path3 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group").path()
- dialog.selectUrl(path3)
- self.qWaitForPendingActions(dialog)
- self.assertFalse(forwardAction.isEnabled())
- self.assertTrue(backwardAction.isEnabled())
-
- button = testutils.getQToolButtonFromAction(backwardAction)
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertTrue(forwardAction.isEnabled())
- self.assertTrue(backwardAction.isEnabled())
- self.assertSamePath(url.text(), path2)
-
- button = testutils.getQToolButtonFromAction(forwardAction)
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertFalse(forwardAction.isEnabled())
- self.assertTrue(backwardAction.isEnabled())
- self.assertSamePath(url.text(), path3)
-
- def testSelectImageFromEdf(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/singleimage.edf"
- url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/scan_0/instrument/detector_0/data")
- dialog.selectUrl(url.path())
- self.assertEqual(dialog._selectedData().shape, (100, 100))
- self.assertSamePath(dialog.selectedFile(), filename)
- self.assertSamePath(dialog.selectedUrl(), url.path())
-
- def testSelectImage(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path()
- dialog.selectUrl(path)
- # test
- self.assertEqual(dialog._selectedData().shape, (100, 100))
- self.assertSamePath(dialog.selectedFile(), filename)
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectScalar(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/scalar").path()
- dialog.selectUrl(path)
- # test
- self.assertEqual(dialog._selectedData()[()], 10)
- self.assertSamePath(dialog.selectedFile(), filename)
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectGroup(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/data.h5"
- uri = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group")
- dialog.selectUrl(uri.path())
- self.qWaitForPendingActions(dialog)
- # test
- self.assertTrue(silx.io.is_group(dialog._selectedData()))
- self.assertSamePath(dialog.selectedFile(), filename)
- uri = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertSamePath(uri.data_path(), "/group")
-
- def testSelectRoot(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/data.h5"
- uri = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/")
- dialog.selectUrl(uri.path())
- self.qWaitForPendingActions(dialog)
- # test
- self.assertTrue(silx.io.is_file(dialog._selectedData()))
- self.assertSamePath(dialog.selectedFile(), filename)
- uri = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertSamePath(uri.data_path(), "/")
-
- def testSelectH5_Activate(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- dialog.selectUrl(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
- index = browser.rootIndex().model().index(filename)
- # click
- browser.selectIndex(index)
- # double click
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
- # test
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectBadFileFormat_Activate(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- dialog.selectUrl(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- filename = _tmpDirectory + "/badformat.h5"
- index = browser.model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
- # test
- self.assertSamePath(dialog.selectedUrl(), filename)
-
- def _countSelectableItems(self, model, rootIndex):
- selectable = 0
- for i in range(model.rowCount(rootIndex)):
- index = model.index(i, 0, rootIndex)
- flags = model.flags(index)
- isEnabled = (int(flags) & qt.Qt.ItemIsEnabled) != 0
- if isEnabled:
- selectable += 1
- return selectable
-
- def testFilterExtensions(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- dialog.selectUrl(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4)
-
-
-class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin):
-
- def tearDown(self):
- self._deleteDialog()
- testutils.TestCaseQt.tearDown(self)
-
- def _createDialog(self):
- dialog = DataFileDialog()
- dialog.setFilterMode(DataFileDialog.FilterMode.ExistingDataset)
- return dialog
-
- def testSelectGroup_Activate(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/data.h5"
- dialog.selectFile(os.path.dirname(filename))
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertFalse(button.isEnabled())
-
- def testSelectDataset_Activate(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/data.h5"
- dialog.selectFile(os.path.dirname(filename))
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertTrue(button.isEnabled())
- self.mouseClick(button, qt.Qt.LeftButton)
- url = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertEqual(url.data_path(), "/scalar")
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Accepted)
-
- data = dialog.selectedData()
- self.assertEqual(data, 10)
-
-
-class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin):
-
- def tearDown(self):
- self._deleteDialog()
- testutils.TestCaseQt.tearDown(self)
-
- def _createDialog(self):
- dialog = DataFileDialog()
- dialog.setFilterMode(DataFileDialog.FilterMode.ExistingGroup)
- return dialog
-
- def testSelectGroup_Activate(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/data.h5"
- dialog.selectFile(os.path.dirname(filename))
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertTrue(button.isEnabled())
- self.mouseClick(button, qt.Qt.LeftButton)
- url = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertEqual(url.data_path(), "/group")
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Accepted)
-
- self.assertRaises(Exception, dialog.selectedData)
-
- def testSelectDataset_Activate(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/data.h5"
- dialog.selectFile(os.path.dirname(filename))
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertFalse(button.isEnabled())
-
-
-class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
-
- def tearDown(self):
- self._deleteDialog()
- testutils.TestCaseQt.tearDown(self)
-
- def _createDialog(self):
- def customFilter(obj):
- if "NX_class" in obj.attrs:
- return obj.attrs["NX_class"] == u"NXdata"
- return False
-
- dialog = DataFileDialog()
- dialog.setFilterMode(DataFileDialog.FilterMode.ExistingGroup)
- dialog.setFilterCallback(customFilter)
- return dialog
-
- def testSelectGroupRefused_Activate(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/data.h5"
- dialog.selectFile(os.path.dirname(filename))
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertFalse(button.isEnabled())
-
- self.assertRaises(Exception, dialog.selectedData)
-
- def testSelectNXdataAccepted_Activate(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/data.h5"
- dialog.selectFile(os.path.dirname(filename))
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- # select, then double click on the file
- index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/nxdata"])
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertTrue(button.isEnabled())
- self.mouseClick(button, qt.Qt.LeftButton)
- url = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertEqual(url.data_path(), "/nxdata")
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Accepted)
-
-
-class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
-
- def tearDown(self):
- self._deleteDialog()
- testutils.TestCaseQt.tearDown(self)
-
- def _createDialog(self):
- dialog = DataFileDialog()
- return dialog
-
- def testSaveRestoreState(self):
- dialog = self.createDialog()
- dialog.setDirectory(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- state = dialog.saveState()
- dialog = None
-
- dialog2 = self.createDialog()
- result = dialog2.restoreState(state)
- self.assertTrue(result)
- dialog2 = None
-
- def printState(self):
- """
- Print state of the ImageFileDialog.
-
- Can be used to add or regenerate `STATE_VERSION1_QT4` or
- `STATE_VERSION1_QT5`.
-
- >>> ./run_tests.py -v silx.gui.dialog.test.test_datafiledialog.TestDataFileDialogApi.printState
- """
- dialog = self.createDialog()
- dialog.setDirectory("")
- dialog.setHistory([])
- dialog.setSidebarUrls([])
- state = dialog.saveState()
- string = ""
- strings = []
- for i in range(state.size()):
- d = state.data()[i]
- if not isinstance(d, int):
- d = ord(d)
- if d > 0x20 and d < 0x7F:
- string += chr(d)
- else:
- string += "\\x%02X" % d
- if len(string) > 60:
- strings.append(string)
- string = ""
- strings.append(string)
- strings = ["b'%s'" % s for s in strings]
- print()
- print("\\\n".join(strings))
-
- STATE_VERSION1_QT4 = b''\
- b'\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
- b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i'\
- b'\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00'\
- b'a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00'\
- b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00\xFF\x00\x00'\
- b'\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
- b'\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00\x00'\
- b'}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00\x00\x00'\
- b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00\xFF\x00\x00'\
- b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00\x00\x81'\
- b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00\x00\x00\x04'\
- b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
- b'\x01\xFF\xFF\xFF\xFF'
- """Serialized state on Qt4. Generated using :meth:`printState`"""
-
- STATE_VERSION1_QT5 = b''\
- b'\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
- b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i'\
- b'\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00'\
- b'a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00'\
- b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00\xFF\x00\x00'\
- b'\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
- b'\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00'\
- b'\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00'\
- b'\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87\x00\x00\x00\xFF'\
- b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00'\
- b'\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00d\x00\x00'\
- b'\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00'\
- b'\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00'\
- b'\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03\xE8\x00\xFF'\
- b'\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x01'
- """Serialized state on Qt5. Generated using :meth:`printState`"""
-
- def testAvoidRestoreRegression_Version1(self):
- version = qt.qVersion().split(".")[0]
- if version == "4":
- state = self.STATE_VERSION1_QT4
- elif version == "5":
- state = self.STATE_VERSION1_QT5
- else:
- self.skipTest("Resource not available")
-
- state = qt.QByteArray(state)
- dialog = self.createDialog()
- result = dialog.restoreState(state)
- self.assertTrue(result)
-
- def testRestoreRobusness(self):
- """What's happen if you try to open a config file with a different
- binding."""
- state = qt.QByteArray(self.STATE_VERSION1_QT4)
- dialog = self.createDialog()
- dialog.restoreState(state)
- state = qt.QByteArray(self.STATE_VERSION1_QT5)
- dialog = None
- dialog = self.createDialog()
- dialog.restoreState(state)
-
- def testRestoreNonExistingDirectory(self):
- directory = os.path.join(_tmpDirectory, "dir")
- os.mkdir(directory)
- dialog = self.createDialog()
- dialog.setDirectory(directory)
- self.qWaitForPendingActions(dialog)
- state = dialog.saveState()
- os.rmdir(directory)
- dialog = None
-
- dialog2 = self.createDialog()
- result = dialog2.restoreState(state)
- self.assertTrue(result)
- self.assertNotEqual(dialog2.directory(), directory)
-
- def testHistory(self):
- dialog = self.createDialog()
- history = dialog.history()
- dialog.setHistory([])
- self.assertEqual(dialog.history(), [])
- dialog.setHistory(history)
- self.assertEqual(dialog.history(), history)
-
- def testSidebarUrls(self):
- dialog = self.createDialog()
- urls = dialog.sidebarUrls()
- dialog.setSidebarUrls([])
- self.assertEqual(dialog.sidebarUrls(), [])
- dialog.setSidebarUrls(urls)
- self.assertEqual(dialog.sidebarUrls(), urls)
-
- def testDirectory(self):
- dialog = self.createDialog()
- self.qWaitForPendingActions(dialog)
- dialog.selectUrl(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(dialog.directory(), _tmpDirectory)
-
- def testBadFileFormat(self):
- dialog = self.createDialog()
- dialog.selectUrl(_tmpDirectory + "/badformat.h5")
- self.qWaitForPendingActions(dialog)
- self.assertIsNone(dialog._selectedData())
-
- def testBadPath(self):
- dialog = self.createDialog()
- dialog.selectUrl("#$%/#$%")
- self.qWaitForPendingActions(dialog)
- self.assertIsNone(dialog._selectedData())
-
- def testBadSubpath(self):
- dialog = self.createDialog()
- self.qWaitForPendingActions(dialog)
-
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
-
- filename = _tmpDirectory + "/data.h5"
- url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/foobar")
- dialog.selectUrl(url.path())
- self.qWaitForPendingActions(dialog)
- self.assertIsNotNone(dialog._selectedData())
-
- # an existing node is browsed, but the wrong path is selected
- index = browser.rootIndex()
- obj = index.model().data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
- self.assertEqual(obj.name, "/group")
- url = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertEqual(url.data_path(), "/group")
-
- def testUnsupportedSlicingPath(self):
- dialog = self.createDialog()
- self.qWaitForPendingActions(dialog)
- dialog.selectUrl(_tmpDirectory + "/data.h5?path=/cube&slice=0")
- self.qWaitForPendingActions(dialog)
- data = dialog._selectedData()
- if data is None:
- # Maybe nothing is selected
- self.assertTrue(True)
- else:
- # Maybe the cube is selected but not sliced
- self.assertEqual(len(data.shape), 3)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestDataFileDialogInteraction))
- test_suite.addTest(loadTests(TestDataFileDialogApi))
- test_suite.addTest(loadTests(TestDataFileDialog_FilterDataset))
- test_suite.addTest(loadTests(TestDataFileDialog_FilterGroup))
- test_suite.addTest(loadTests(TestDataFileDialog_FilterNXdata))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/dialog/test/test_imagefiledialog.py b/silx/gui/dialog/test/test_imagefiledialog.py
deleted file mode 100644
index 3cbb492..0000000
--- a/silx/gui/dialog/test/test_imagefiledialog.py
+++ /dev/null
@@ -1,784 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test for silx.gui.hdf5 module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "08/03/2019"
-
-
-import unittest
-import tempfile
-import numpy
-import shutil
-import os
-import io
-import weakref
-import fabio
-import h5py
-import silx.io.url
-from silx.gui import qt
-from silx.gui.utils import testutils
-from ..ImageFileDialog import ImageFileDialog
-from silx.gui.colors import Colormap
-from silx.gui.hdf5 import Hdf5TreeModel
-
-_tmpDirectory = None
-
-
-def setUpModule():
- global _tmpDirectory
- _tmpDirectory = tempfile.mkdtemp(prefix=__name__)
-
- data = numpy.arange(100 * 100)
- data.shape = 100, 100
-
- filename = _tmpDirectory + "/singleimage.edf"
- image = fabio.edfimage.EdfImage(data=data)
- image.write(filename)
-
- filename = _tmpDirectory + "/multiframe.edf"
- image = fabio.edfimage.EdfImage(data=data)
- image.append_frame(data=data + 1)
- image.append_frame(data=data + 2)
- image.write(filename)
-
- filename = _tmpDirectory + "/singleimage.msk"
- image = fabio.fit2dmaskimage.Fit2dMaskImage(data=data % 2 == 1)
- image.write(filename)
-
- filename = _tmpDirectory + "/data.h5"
- with h5py.File(filename, "w") as f:
- f["scalar"] = 10
- f["image"] = data
- f["cube"] = [data, data + 1, data + 2]
- f["single_frame"] = [data + 5]
- f["complex_image"] = data * 1j
- f["group/image"] = data
-
- directory = os.path.join(_tmpDirectory, "data")
- os.mkdir(directory)
- filename = os.path.join(directory, "data.h5")
- with h5py.File(filename, "w") as f:
- f["scalar"] = 10
- f["image"] = data
- f["cube"] = [data, data + 1, data + 2]
- f["single_frame"] = [data + 5]
- f["complex_image"] = data * 1j
- f["group/image"] = data
-
- filename = _tmpDirectory + "/badformat.edf"
- with io.open(filename, "wb") as f:
- f.write(b"{\nHello Nurse!")
-
-
-def tearDownModule():
- global _tmpDirectory
- shutil.rmtree(_tmpDirectory)
- _tmpDirectory = None
-
-
-class _UtilsMixin(object):
-
- def createDialog(self):
- self._deleteDialog()
- self._dialog = self._createDialog()
- return self._dialog
-
- def _createDialog(self):
- return ImageFileDialog()
-
- def _deleteDialog(self):
- if not hasattr(self, "_dialog"):
- return
- if self._dialog is not None:
- ref = weakref.ref(self._dialog)
- self._dialog = None
- self.qWaitForDestroy(ref)
-
- def qWaitForPendingActions(self, dialog):
- for _ in range(20):
- if not dialog.hasPendingEvents():
- return
- self.qWait(10)
- raise RuntimeError("Still have pending actions")
-
- def assertSamePath(self, path1, path2):
- path1_ = os.path.normcase(path1)
- path2_ = os.path.normcase(path2)
- if path1_ != path2_:
- # Use the unittest API to log and display error
- self.assertEqual(path1, path2)
-
- def assertNotSamePath(self, path1, path2):
- path1_ = os.path.normcase(path1)
- path2_ = os.path.normcase(path2)
- if path1_ == path2_:
- # Use the unittest API to log and display error
- self.assertNotEqual(path1, path2)
-
-
-class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
-
- def tearDown(self):
- self._deleteDialog()
- testutils.TestCaseQt.tearDown(self)
-
- def testDisplayAndKeyEscape(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
-
- self.keyClick(dialog, qt.Qt.Key_Escape)
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Rejected)
-
- def testDisplayAndClickCancel(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="cancel")[0]
- self.mouseClick(button, qt.Qt.LeftButton)
- self.assertFalse(dialog.isVisible())
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Rejected)
-
- def testDisplayAndClickLockedOpen(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.mouseClick(button, qt.Qt.LeftButton)
- # open button locked, dialog is not closed
- self.assertTrue(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Rejected)
-
- def testDisplayAndClickOpen(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- self.assertTrue(dialog.isVisible())
- filename = _tmpDirectory + "/singleimage.edf"
- dialog.selectFile(filename)
- self.qWaitForPendingActions(dialog)
-
- button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
- self.assertTrue(button.isEnabled())
- self.mouseClick(button, qt.Qt.LeftButton)
- self.assertFalse(dialog.isVisible())
- self.assertEqual(dialog.result(), qt.QDialog.Accepted)
-
- def testClickOnShortcut(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- sidebar = testutils.findChildren(dialog, qt.QListView, name="sidebar")[0]
- url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- dialog.setDirectory(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
-
- self.assertSamePath(url.text(), _tmpDirectory)
-
- urls = sidebar.urls()
- if len(urls) == 0:
- self.skipTest("No sidebar path")
- path = urls[0].path()
- if path != "" and not os.path.exists(path):
- self.skipTest("Sidebar path do not exists")
-
- index = sidebar.model().index(0, 0)
- # rect = sidebar.visualRect(index)
- # self.mouseClick(sidebar, qt.Qt.LeftButton, pos=rect.center())
- # Using mouse click is not working, let's use the selection API
- sidebar.selectionModel().select(index, qt.QItemSelectionModel.ClearAndSelect)
- self.qWaitForPendingActions(dialog)
-
- index = browser.rootIndex()
- if not index.isValid():
- path = ""
- else:
- path = index.model().filePath(index)
- self.assertNotSamePath(_tmpDirectory, path)
- self.assertNotSamePath(url.text(), _tmpDirectory)
-
- def testClickOnDetailView(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- action = testutils.findChildren(dialog, qt.QAction, name="detailModeAction")[0]
- detailModeButton = testutils.getQToolButtonFromAction(action)
- self.mouseClick(detailModeButton, qt.Qt.LeftButton)
- self.assertEqual(dialog.viewMode(), qt.QFileDialog.Detail)
-
- action = testutils.findChildren(dialog, qt.QAction, name="listModeAction")[0]
- listModeButton = testutils.getQToolButtonFromAction(action)
- self.mouseClick(listModeButton, qt.Qt.LeftButton)
- self.assertEqual(dialog.viewMode(), qt.QFileDialog.List)
-
- def testClickOnBackToParentTool(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- action = testutils.findChildren(dialog, qt.QAction, name="toParentAction")[0]
- toParentButton = testutils.getQToolButtonFromAction(action)
- filename = _tmpDirectory + "/data/data.h5"
-
- # init state
- path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
- dialog.selectUrl(path)
- self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
- self.assertSamePath(url.text(), path)
- # test
- self.mouseClick(toParentButton, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
- self.assertSamePath(url.text(), path)
-
- self.mouseClick(toParentButton, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(url.text(), _tmpDirectory + "/data")
-
- self.mouseClick(toParentButton, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(url.text(), _tmpDirectory)
-
- def testClickOnBackToRootTool(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- action = testutils.findChildren(dialog, qt.QAction, name="toRootFileAction")[0]
- button = testutils.getQToolButtonFromAction(action)
- filename = _tmpDirectory + "/data.h5"
-
- # init state
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
- dialog.selectUrl(path)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(url.text(), path)
- self.assertTrue(button.isEnabled())
- # test
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
- self.assertSamePath(url.text(), path)
- # self.assertFalse(button.isEnabled())
-
- def testClickOnBackToDirectoryTool(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- action = testutils.findChildren(dialog, qt.QAction, name="toDirectoryAction")[0]
- button = testutils.getQToolButtonFromAction(action)
- filename = _tmpDirectory + "/data.h5"
-
- # init state
- path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
- dialog.selectUrl(path)
- self.qWaitForPendingActions(dialog)
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
- self.assertSamePath(url.text(), path)
- self.assertTrue(button.isEnabled())
- # test
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(url.text(), _tmpDirectory)
- self.assertFalse(button.isEnabled())
-
- # FIXME: There is an unreleased qt.QWidget without nameObject
- # No idea where it come from.
- self.allowedLeakingWidgets = 1
-
- def testClickOnHistoryTools(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
- forwardAction = testutils.findChildren(dialog, qt.QAction, name="forwardAction")[0]
- backwardAction = testutils.findChildren(dialog, qt.QAction, name="backwardAction")[0]
- filename = _tmpDirectory + "/data.h5"
-
- dialog.setDirectory(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- # No way to use QTest.mouseDClick with QListView, QListWidget
- # Then we feed the history using selectPath
- dialog.selectUrl(filename)
- self.qWaitForPendingActions(dialog)
- path2 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
- dialog.selectUrl(path2)
- self.qWaitForPendingActions(dialog)
- path3 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group").path()
- dialog.selectUrl(path3)
- self.qWaitForPendingActions(dialog)
- self.assertFalse(forwardAction.isEnabled())
- self.assertTrue(backwardAction.isEnabled())
-
- button = testutils.getQToolButtonFromAction(backwardAction)
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertTrue(forwardAction.isEnabled())
- self.assertTrue(backwardAction.isEnabled())
- self.assertSamePath(url.text(), path2)
-
- button = testutils.getQToolButtonFromAction(forwardAction)
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qWaitForPendingActions(dialog)
- self.assertFalse(forwardAction.isEnabled())
- self.assertTrue(backwardAction.isEnabled())
- self.assertSamePath(url.text(), path3)
-
- def testSelectImageFromEdf(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/singleimage.edf"
- path = filename
- dialog.selectUrl(path)
- self.assertEqual(dialog.selectedImage().shape, (100, 100))
- self.assertSamePath(dialog.selectedFile(), filename)
- path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path()
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectImageFromEdf_Activate(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- dialog.selectUrl(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- filename = _tmpDirectory + "/singleimage.edf"
- path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path()
- index = browser.rootIndex().model().index(filename)
- # click
- browser.selectIndex(index)
- # double click
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
- # test
- self.assertEqual(dialog.selectedImage().shape, (100, 100))
- self.assertSamePath(dialog.selectedFile(), filename)
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectFrameFromEdf(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/multiframe.edf"
- path = silx.io.url.DataUrl(scheme="fabio", file_path=filename, data_slice=(1,)).path()
- dialog.selectUrl(path)
- # test
- image = dialog.selectedImage()
- self.assertEqual(image.shape, (100, 100))
- self.assertEqual(image[0, 0], 1)
- self.assertSamePath(dialog.selectedFile(), filename)
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectImageFromMsk(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/singleimage.msk"
- path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path()
- dialog.selectUrl(path)
- # test
- self.assertEqual(dialog.selectedImage().shape, (100, 100))
- self.assertSamePath(dialog.selectedFile(), filename)
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectImageFromH5(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path()
- dialog.selectUrl(path)
- # test
- self.assertEqual(dialog.selectedImage().shape, (100, 100))
- self.assertSamePath(dialog.selectedFile(), filename)
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectH5_Activate(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- dialog.selectUrl(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
- index = browser.rootIndex().model().index(filename)
- # click
- browser.selectIndex(index)
- # double click
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
- # test
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectFrameFromH5(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/cube", data_slice=(1, )).path()
- dialog.selectUrl(path)
- # test
- self.assertEqual(dialog.selectedImage().shape, (100, 100))
- self.assertEqual(dialog.selectedImage()[0, 0], 1)
- self.assertSamePath(dialog.selectedFile(), filename)
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectSingleFrameFromH5(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- filename = _tmpDirectory + "/data.h5"
- path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/single_frame", data_slice=(0, )).path()
- dialog.selectUrl(path)
- # test
- self.assertEqual(dialog.selectedImage().shape, (100, 100))
- self.assertEqual(dialog.selectedImage()[0, 0], 5)
- self.assertSamePath(dialog.selectedFile(), filename)
- self.assertSamePath(dialog.selectedUrl(), path)
-
- def testSelectBadFileFormat_Activate(self):
- dialog = self.createDialog()
- dialog.show()
- self.qWaitForWindowExposed(dialog)
-
- # init state
- dialog.selectUrl(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- filename = _tmpDirectory + "/badformat.edf"
- index = browser.model().index(filename)
- browser.selectIndex(index)
- browser.activated.emit(index)
- self.qWaitForPendingActions(dialog)
- # test
- self.assertSamePath(dialog.selectedUrl(), filename)
-
- def _countSelectableItems(self, model, rootIndex):
- selectable = 0
- for i in range(model.rowCount(rootIndex)):
- index = model.index(i, 0, rootIndex)
- flags = model.flags(index)
- isEnabled = (int(flags) & qt.Qt.ItemIsEnabled) != 0
- if isEnabled:
- selectable += 1
- return selectable
-
- def testFilterExtensions(self):
- dialog = self.createDialog()
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
- filters = testutils.findChildren(dialog, qt.QWidget, name="fileTypeCombo")[0]
- dialog.show()
- self.qWaitForWindowExposed(dialog)
- dialog.selectUrl(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 6)
-
- codecName = fabio.edfimage.EdfImage.codec_name()
- index = filters.indexFromCodec(codecName)
- filters.setCurrentIndex(index)
- filters.activated[int].emit(index)
- self.qWait(50)
- self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4)
-
- codecName = fabio.fit2dmaskimage.Fit2dMaskImage.codec_name()
- index = filters.indexFromCodec(codecName)
- filters.setCurrentIndex(index)
- filters.activated[int].emit(index)
- self.qWait(50)
- self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 2)
-
-
-class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
-
- def tearDown(self):
- self._deleteDialog()
- testutils.TestCaseQt.tearDown(self)
-
- def testSaveRestoreState(self):
- dialog = self.createDialog()
- dialog.setDirectory(_tmpDirectory)
- colormap = Colormap(normalization=Colormap.LOGARITHM)
- dialog.setColormap(colormap)
- self.qWaitForPendingActions(dialog)
- state = dialog.saveState()
- dialog = None
-
- dialog2 = self.createDialog()
- result = dialog2.restoreState(state)
- self.qWaitForPendingActions(dialog2)
- self.assertTrue(result)
- self.assertEqual(dialog2.colormap().getNormalization(), "log")
-
- def printState(self):
- """
- Print state of the ImageFileDialog.
-
- Can be used to add or regenerate `STATE_VERSION1_QT4` or
- `STATE_VERSION1_QT5`.
-
- >>> ./run_tests.py -v silx.gui.dialog.test.test_imagefiledialog.TestImageFileDialogApi.printState
- """
- dialog = self.createDialog()
- colormap = Colormap(normalization=Colormap.LOGARITHM)
- dialog.setDirectory("")
- dialog.setHistory([])
- dialog.setColormap(colormap)
- dialog.setSidebarUrls([])
- state = dialog.saveState()
- string = ""
- strings = []
- for i in range(state.size()):
- d = state.data()[i]
- if not isinstance(d, int):
- d = ord(d)
- if d > 0x20 and d < 0x7F:
- string += chr(d)
- else:
- string += "\\x%02X" % d
- if len(string) > 60:
- strings.append(string)
- string = ""
- strings.append(string)
- strings = ["b'%s'" % s for s in strings]
- print()
- print("\\\n".join(strings))
-
- STATE_VERSION1_QT4 = b''\
- b'\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
- b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F'\
- b'\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00'\
- b'a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g'\
- b'\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00'\
- b'\xFF\x00\x00\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
- b'\xFF\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00'\
- b'\x00\x00\x00}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00'\
- b'r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00'\
- b'\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00'\
- b'\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00'\
- b'\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00'\
- b'\x00\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00'\
- b'o\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00'\
- b'r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g'
- """Serialized state on Qt4. Generated using :meth:`printState`"""
-
- STATE_VERSION1_QT5 = b''\
- b'\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
- b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F'\
- b'\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00'\
- b'a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g'\
- b'\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00'\
- b'\xFF\x00\x00\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
- b'\xFF\xFF\xFF\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C'\
- b'\x00\x00\x00\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s'\
- b'\x00e\x00r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87'\
- b'\x00\x00\x00\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00'\
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF'\
- b'\xFF\xFF\x00\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00'\
- b'\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00'\
- b'\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00'\
- b'\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03'\
- b'\xE8\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00'\
- b'\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00o'\
- b'\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00'\
- b'r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g'
- """Serialized state on Qt5. Generated using :meth:`printState`"""
-
- def testAvoidRestoreRegression_Version1(self):
- version = qt.qVersion().split(".")[0]
- if version == "4":
- state = self.STATE_VERSION1_QT4
- elif version == "5":
- state = self.STATE_VERSION1_QT5
- else:
- self.skipTest("Resource not available")
-
- state = qt.QByteArray(state)
- dialog = self.createDialog()
- result = dialog.restoreState(state)
- self.assertTrue(result)
- colormap = dialog.colormap()
- self.assertEqual(colormap.getNormalization(), "log")
-
- def testRestoreRobusness(self):
- """What's happen if you try to open a config file with a different
- binding."""
- state = qt.QByteArray(self.STATE_VERSION1_QT4)
- dialog = self.createDialog()
- dialog.restoreState(state)
- state = qt.QByteArray(self.STATE_VERSION1_QT5)
- dialog = None
- dialog = self.createDialog()
- dialog.restoreState(state)
-
- def testRestoreNonExistingDirectory(self):
- directory = os.path.join(_tmpDirectory, "dir")
- os.mkdir(directory)
- dialog = self.createDialog()
- dialog.setDirectory(directory)
- self.qWaitForPendingActions(dialog)
- state = dialog.saveState()
- os.rmdir(directory)
- dialog = None
-
- dialog2 = self.createDialog()
- result = dialog2.restoreState(state)
- self.assertTrue(result)
- self.assertNotEqual(dialog2.directory(), directory)
-
- def testHistory(self):
- dialog = self.createDialog()
- history = dialog.history()
- dialog.setHistory([])
- self.assertEqual(dialog.history(), [])
- dialog.setHistory(history)
- self.assertEqual(dialog.history(), history)
-
- def testSidebarUrls(self):
- dialog = self.createDialog()
- urls = dialog.sidebarUrls()
- dialog.setSidebarUrls([])
- self.assertEqual(dialog.sidebarUrls(), [])
- dialog.setSidebarUrls(urls)
- self.assertEqual(dialog.sidebarUrls(), urls)
-
- def testColomap(self):
- dialog = self.createDialog()
- colormap = dialog.colormap()
- self.assertEqual(colormap.getNormalization(), "linear")
- colormap = Colormap(normalization=Colormap.LOGARITHM)
- dialog.setColormap(colormap)
- self.assertEqual(colormap.getNormalization(), "log")
-
- def testDirectory(self):
- dialog = self.createDialog()
- self.qWaitForPendingActions(dialog)
- dialog.selectUrl(_tmpDirectory)
- self.qWaitForPendingActions(dialog)
- self.assertSamePath(dialog.directory(), _tmpDirectory)
-
- def testBadDataType(self):
- dialog = self.createDialog()
- dialog.selectUrl(_tmpDirectory + "/data.h5::/complex_image")
- self.qWaitForPendingActions(dialog)
- self.assertIsNone(dialog._selectedData())
-
- def testBadDataShape(self):
- dialog = self.createDialog()
- dialog.selectUrl(_tmpDirectory + "/data.h5::/unknown")
- self.qWaitForPendingActions(dialog)
- self.assertIsNone(dialog._selectedData())
-
- def testBadDataFormat(self):
- dialog = self.createDialog()
- dialog.selectUrl(_tmpDirectory + "/badformat.edf")
- self.qWaitForPendingActions(dialog)
- self.assertIsNone(dialog._selectedData())
-
- def testBadPath(self):
- dialog = self.createDialog()
- dialog.selectUrl("#$%/#$%")
- self.qWaitForPendingActions(dialog)
- self.assertIsNone(dialog._selectedData())
-
- def testBadSubpath(self):
- dialog = self.createDialog()
- self.qWaitForPendingActions(dialog)
-
- browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
-
- filename = _tmpDirectory + "/data.h5"
- url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/foobar")
- dialog.selectUrl(url.path())
- self.qWaitForPendingActions(dialog)
- self.assertIsNone(dialog._selectedData())
-
- # an existing node is browsed, but the wrong path is selected
- index = browser.rootIndex()
- obj = index.model().data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
- self.assertEqual(obj.name, "/group")
- url = silx.io.url.DataUrl(dialog.selectedUrl())
- self.assertEqual(url.data_path(), "/group")
-
- def testBadSlicingPath(self):
- dialog = self.createDialog()
- self.qWaitForPendingActions(dialog)
- dialog.selectUrl(_tmpDirectory + "/data.h5::/cube[a;45,-90]")
- self.qWaitForPendingActions(dialog)
- self.assertIsNone(dialog._selectedData())
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestImageFileDialogInteraction))
- test_suite.addTest(loadTests(TestImageFileDialogApi))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/dialog/utils.py b/silx/gui/dialog/utils.py
deleted file mode 100644
index e2334f9..0000000
--- a/silx/gui/dialog/utils.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module contains utilitaries used by other dialog modules.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "25/10/2017"
-
-import os
-import sys
-import types
-
-import six
-
-from silx.gui import qt
-
-
-def samefile(path1, path2):
- """Portable :func:`os.path.samepath` function.
-
- :param str path1: A path to a file
- :param str path2: Another path to a file
- :rtype: bool
- """
- if six.PY2 and sys.platform == "win32":
- path1 = os.path.normcase(path1)
- path2 = os.path.normcase(path2)
- return path1 == path2
- if path1 == path2:
- return True
- if path1 == "":
- return False
- if path2 == "":
- return False
- return os.path.samefile(path1, path2)
-
-
-def findClosestSubPath(hdf5Object, path):
- """Find the closest existing path from the hdf5Object using a subset of the
- provided path.
-
- Returns None if no path found. It is possible if the path is a relative
- path.
-
- :param h5py.Node hdf5Object: An HDF5 node
- :param str path: A path
- :rtype: str
- """
- if path in ["", "/"]:
- return "/"
- names = path.split("/")
- if path[0] == "/":
- names.pop(0)
- for i in range(len(names)):
- n = len(names) - i
- path2 = "/".join(names[0:n])
- if path2 == "":
- return ""
- if path2 in hdf5Object:
- return path2
-
- if path[0] == "/":
- return "/"
- return None
-
-
-def patchToConsumeReturnKey(widget):
- """
- Monkey-patch a widget to consume the return key instead of propagating it
- to the dialog.
- """
- assert(not hasattr(widget, "_oldKeyPressEvent"))
-
- def keyPressEvent(self, event):
- k = event.key()
- result = self._oldKeyPressEvent(event)
- if k in [qt.Qt.Key_Return, qt.Qt.Key_Enter]:
- event.accept()
- return result
-
- widget._oldKeyPressEvent = widget.keyPressEvent
- widget.keyPressEvent = types.MethodType(keyPressEvent, widget)
diff --git a/silx/gui/fit/BackgroundWidget.py b/silx/gui/fit/BackgroundWidget.py
deleted file mode 100644
index 76bc043..0000000
--- a/silx/gui/fit/BackgroundWidget.py
+++ /dev/null
@@ -1,534 +0,0 @@
-# coding: utf-8
-#/*##########################################################################
-# Copyright (C) 2004-2020 V.A. Sole, European Synchrotron Radiation Facility
-#
-# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
-# the ESRF by the Software group.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# #########################################################################*/
-"""This module provides a background configuration widget
-:class:`BackgroundWidget` and a corresponding dialog window
-:class:`BackgroundDialog`.
-
-.. image:: img/BackgroundDialog.png
- :height: 300px
-"""
-import sys
-import numpy
-from silx.gui import qt
-from silx.gui.plot import PlotWidget
-from silx.math.fit import filters
-
-__authors__ = ["V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "28/06/2017"
-
-
-class HorizontalSpacer(qt.QWidget):
- def __init__(self, *args):
- qt.QWidget.__init__(self, *args)
- self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Fixed))
-
-
-class BackgroundParamWidget(qt.QWidget):
- """Background configuration composite widget.
-
- Strip and snip filters parameters can be adjusted using input widgets.
-
- Updating the widgets causes :attr:`sigBackgroundParamWidgetSignal` to
- be emitted.
- """
- sigBackgroundParamWidgetSignal = qt.pyqtSignal(object)
-
- def __init__(self, parent=None):
- qt.QWidget.__init__(self, parent)
-
- self.mainLayout = qt.QGridLayout(self)
- self.mainLayout.setColumnStretch(1, 1)
-
- # Algorithm choice ---------------------------------------------------
- self.algorithmComboLabel = qt.QLabel(self)
- self.algorithmComboLabel.setText("Background algorithm")
- self.algorithmCombo = qt.QComboBox(self)
- self.algorithmCombo.addItem("Strip")
- self.algorithmCombo.addItem("Snip")
- self.algorithmCombo.activated[int].connect(
- self._algorithmComboActivated)
-
- # Strip parameters ---------------------------------------------------
- self.stripWidthLabel = qt.QLabel(self)
- self.stripWidthLabel.setText("Strip Width")
-
- self.stripWidthSpin = qt.QSpinBox(self)
- self.stripWidthSpin.setMaximum(100)
- self.stripWidthSpin.setMinimum(1)
- self.stripWidthSpin.valueChanged[int].connect(self._emitSignal)
-
- self.stripIterLabel = qt.QLabel(self)
- self.stripIterLabel.setText("Strip Iterations")
- self.stripIterValue = qt.QLineEdit(self)
- validator = qt.QIntValidator(self.stripIterValue)
- self.stripIterValue._v = validator
- self.stripIterValue.setText("0")
- self.stripIterValue.editingFinished[()].connect(self._emitSignal)
- self.stripIterValue.setToolTip(
- "Number of iterations for strip algorithm.\n" +
- "If greater than 999, an 2nd pass of strip filter is " +
- "applied to remove artifacts created by first pass.")
-
- # Snip parameters ----------------------------------------------------
- self.snipWidthLabel = qt.QLabel(self)
- self.snipWidthLabel.setText("Snip Width")
-
- self.snipWidthSpin = qt.QSpinBox(self)
- self.snipWidthSpin.setMaximum(300)
- self.snipWidthSpin.setMinimum(0)
- self.snipWidthSpin.valueChanged[int].connect(self._emitSignal)
-
-
- # Smoothing parameters -----------------------------------------------
- self.smoothingFlagCheck = qt.QCheckBox(self)
- self.smoothingFlagCheck.setText("Smoothing Width (Savitsky-Golay)")
- self.smoothingFlagCheck.toggled.connect(self._smoothingToggled)
-
- self.smoothingSpin = qt.QSpinBox(self)
- self.smoothingSpin.setMinimum(3)
- #self.smoothingSpin.setMaximum(40)
- self.smoothingSpin.setSingleStep(2)
- self.smoothingSpin.valueChanged[int].connect(self._emitSignal)
-
- # Anchors ------------------------------------------------------------
-
- self.anchorsGroup = qt.QWidget(self)
- anchorsLayout = qt.QHBoxLayout(self.anchorsGroup)
- anchorsLayout.setSpacing(2)
- anchorsLayout.setContentsMargins(0, 0, 0, 0)
-
- self.anchorsFlagCheck = qt.QCheckBox(self.anchorsGroup)
- self.anchorsFlagCheck.setText("Use anchors")
- self.anchorsFlagCheck.setToolTip(
- "Define X coordinates of points that must remain fixed")
- self.anchorsFlagCheck.stateChanged[int].connect(
- self._anchorsToggled)
- anchorsLayout.addWidget(self.anchorsFlagCheck)
-
- maxnchannel = 16384 * 4 # Fixme ?
- self.anchorsList = []
- num_anchors = 4
- for i in range(num_anchors):
- anchorSpin = qt.QSpinBox(self.anchorsGroup)
- anchorSpin.setMinimum(0)
- anchorSpin.setMaximum(maxnchannel)
- anchorSpin.valueChanged[int].connect(self._emitSignal)
- anchorsLayout.addWidget(anchorSpin)
- self.anchorsList.append(anchorSpin)
-
- # Layout ------------------------------------------------------------
- self.mainLayout.addWidget(self.algorithmComboLabel, 0, 0)
- self.mainLayout.addWidget(self.algorithmCombo, 0, 2)
- self.mainLayout.addWidget(self.stripWidthLabel, 1, 0)
- self.mainLayout.addWidget(self.stripWidthSpin, 1, 2)
- self.mainLayout.addWidget(self.stripIterLabel, 2, 0)
- self.mainLayout.addWidget(self.stripIterValue, 2, 2)
- self.mainLayout.addWidget(self.snipWidthLabel, 3, 0)
- self.mainLayout.addWidget(self.snipWidthSpin, 3, 2)
- self.mainLayout.addWidget(self.smoothingFlagCheck, 4, 0)
- self.mainLayout.addWidget(self.smoothingSpin, 4, 2)
- self.mainLayout.addWidget(self.anchorsGroup, 5, 0, 1, 4)
-
- # Initialize interface -----------------------------------------------
- self._setAlgorithm("strip")
- self.smoothingFlagCheck.setChecked(False)
- self._smoothingToggled(is_checked=False)
- self.anchorsFlagCheck.setChecked(False)
- self._anchorsToggled(is_checked=False)
-
- def _algorithmComboActivated(self, algorithm_index):
- self._setAlgorithm("strip" if algorithm_index == 0 else "snip")
-
- def _setAlgorithm(self, algorithm):
- """Enable/disable snip and snip input widgets, depending on the
- chosen algorithm.
- :param algorithm: "snip" or "strip"
- """
- if algorithm not in ["strip", "snip"]:
- raise ValueError(
- "Unknown background filter algorithm %s" % algorithm)
-
- self.algorithm = algorithm
- self.stripWidthSpin.setEnabled(algorithm == "strip")
- self.stripIterValue.setEnabled(algorithm == "strip")
- self.snipWidthSpin.setEnabled(algorithm == "snip")
-
- def _smoothingToggled(self, is_checked):
- """Enable/disable smoothing input widgets, emit dictionary"""
- self.smoothingSpin.setEnabled(is_checked)
- self._emitSignal()
-
- def _anchorsToggled(self, is_checked):
- """Enable/disable all spin widgets defining anchor X coordinates,
- emit signal.
- """
- for anchor_spin in self.anchorsList:
- anchor_spin.setEnabled(is_checked)
- self._emitSignal()
-
- def setParameters(self, ddict):
- """Set values for all input widgets.
-
- :param dict ddict: Input dictionary, must have the same
- keys as the dictionary output by :meth:`getParameters`
- """
- if "algorithm" in ddict:
- self._setAlgorithm(ddict["algorithm"])
-
- if "SnipWidth" in ddict:
- self.snipWidthSpin.setValue(int(ddict["SnipWidth"]))
-
- if "StripWidth" in ddict:
- self.stripWidthSpin.setValue(int(ddict["StripWidth"]))
-
- if "StripIterations" in ddict:
- self.stripIterValue.setText("%d" % int(ddict["StripIterations"]))
-
- if "SmoothingFlag" in ddict:
- self.smoothingFlagCheck.setChecked(bool(ddict["SmoothingFlag"]))
-
- if "SmoothingWidth" in ddict:
- self.smoothingSpin.setValue(int(ddict["SmoothingWidth"]))
-
- if "AnchorsFlag" in ddict:
- self.anchorsFlagCheck.setChecked(bool(ddict["AnchorsFlag"]))
-
- if "AnchorsList" in ddict:
- anchorslist = ddict["AnchorsList"]
- if anchorslist in [None, 'None']:
- anchorslist = []
- for spin in self.anchorsList:
- spin.setValue(0)
-
- i = 0
- for value in anchorslist:
- self.anchorsList[i].setValue(int(value))
- i += 1
-
- def getParameters(self):
- """Return dictionary of parameters defined in the GUI
-
- The returned dictionary contains following values:
-
- - *algorithm*: *"strip"* or *"snip"*
- - *StripWidth*: width of strip iterator
- - *StripIterations*: number of iterations
- - *StripThreshold*: curvature parameter (currently fixed to 1.0)
- - *SnipWidth*: width of snip algorithm
- - *SmoothingFlag*: flag to enable/disable smoothing
- - *SmoothingWidth*: width of Savitsky-Golay smoothing filter
- - *AnchorsFlag*: flag to enable/disable anchors
- - *AnchorsList*: list of anchors (X coordinates of fixed values)
- """
- stripitertext = self.stripIterValue.text()
- stripiter = int(stripitertext) if len(stripitertext) else 0
-
- return {"algorithm": self.algorithm,
- "StripThreshold": 1.0,
- "SnipWidth": self.snipWidthSpin.value(),
- "StripIterations": stripiter,
- "StripWidth": self.stripWidthSpin.value(),
- "SmoothingFlag": self.smoothingFlagCheck.isChecked(),
- "SmoothingWidth": self.smoothingSpin.value(),
- "AnchorsFlag": self.anchorsFlagCheck.isChecked(),
- "AnchorsList": [spin.value() for spin in self.anchorsList]}
-
- def _emitSignal(self, dummy=None):
- self.sigBackgroundParamWidgetSignal.emit(
- {'event': 'ParametersChanged',
- 'parameters': self.getParameters()})
-
-
-class BackgroundWidget(qt.QWidget):
- """Background configuration widget, with a plot to preview the results.
-
- Strip and snip filters parameters can be adjusted using input widgets,
- and the computed backgrounds are plotted next to the original data to
- show the result."""
- def __init__(self, parent=None):
- qt.QWidget.__init__(self, parent)
- self.setWindowTitle("Strip and SNIP Configuration Window")
- self.mainLayout = qt.QVBoxLayout(self)
- self.mainLayout.setContentsMargins(0, 0, 0, 0)
- self.mainLayout.setSpacing(2)
- self.parametersWidget = BackgroundParamWidget(self)
- self.graphWidget = PlotWidget(parent=self)
- self.mainLayout.addWidget(self.parametersWidget)
- self.mainLayout.addWidget(self.graphWidget)
- self._x = None
- self._y = None
- self.parametersWidget.sigBackgroundParamWidgetSignal.connect(self._slot)
-
- def getParameters(self):
- """Return dictionary of parameters defined in the GUI
-
- The returned dictionary contains following values:
-
- - *algorithm*: *"strip"* or *"snip"*
- - *StripWidth*: width of strip iterator
- - *StripIterations*: number of iterations
- - *StripThreshold*: strip curvature (currently fixed to 1.0)
- - *SnipWidth*: width of snip algorithm
- - *SmoothingFlag*: flag to enable/disable smoothing
- - *SmoothingWidth*: width of Savitsky-Golay smoothing filter
- - *AnchorsFlag*: flag to enable/disable anchors
- - *AnchorsList*: list of anchors (X coordinates of fixed values)
- """
- return self.parametersWidget.getParameters()
-
- def setParameters(self, ddict):
- """Set values for all input widgets.
-
- :param dict ddict: Input dictionary, must have the same
- keys as the dictionary output by :meth:`getParameters`
- """
- return self.parametersWidget.setParameters(ddict)
-
- def setData(self, x, y, xmin=None, xmax=None):
- """Set data for the original curve, and _update strip and snip
- curves accordingly.
-
- :param x: Array or sequence of curve abscissa values
- :param y: Array or sequence of curve ordinate values
- :param xmin: Min value to be displayed on the X axis
- :param xmax: Max value to be displayed on the X axis
- """
- self._x = x
- self._y = y
- self._xmin = xmin
- self._xmax = xmax
- self._update(resetzoom=True)
-
- def _slot(self, ddict):
- self._update()
-
- def _update(self, resetzoom=False):
- """Compute strip and snip backgrounds, update the curves
- """
- if self._y is None:
- return
-
- pars = self.getParameters()
-
- # smoothed data
- y = numpy.ravel(numpy.array(self._y)).astype(numpy.float64)
- if pars["SmoothingFlag"]:
- ysmooth = filters.savitsky_golay(y, pars['SmoothingWidth'])
- f = [0.25, 0.5, 0.25]
- ysmooth[1:-1] = numpy.convolve(ysmooth, f, mode=0)
- ysmooth[0] = 0.5 * (ysmooth[0] + ysmooth[1])
- ysmooth[-1] = 0.5 * (ysmooth[-1] + ysmooth[-2])
- else:
- ysmooth = y
-
-
- # loop for anchors
- x = self._x
- niter = pars['StripIterations']
- anchors_indices = []
- if pars['AnchorsFlag'] and pars['AnchorsList'] is not None:
- ravelled = x
- for channel in pars['AnchorsList']:
- if channel <= ravelled[0]:
- continue
- index = numpy.nonzero(ravelled >= channel)[0]
- if len(index):
- index = min(index)
- if index > 0:
- anchors_indices.append(index)
-
- stripBackground = filters.strip(ysmooth,
- w=pars['StripWidth'],
- niterations=niter,
- factor=pars['StripThreshold'],
- anchors=anchors_indices)
-
- if niter >= 1000:
- # final smoothing
- stripBackground = filters.strip(stripBackground,
- w=1,
- niterations=50*pars['StripWidth'],
- factor=pars['StripThreshold'],
- anchors=anchors_indices)
-
- if len(anchors_indices) == 0:
- anchors_indices = [0, len(ysmooth)-1]
- anchors_indices.sort()
- snipBackground = 0.0 * ysmooth
- lastAnchor = 0
- for anchor in anchors_indices:
- if (anchor > lastAnchor) and (anchor < len(ysmooth)):
- snipBackground[lastAnchor:anchor] =\
- filters.snip1d(ysmooth[lastAnchor:anchor],
- pars['SnipWidth'])
- lastAnchor = anchor
- if lastAnchor < len(ysmooth):
- snipBackground[lastAnchor:] =\
- filters.snip1d(ysmooth[lastAnchor:],
- pars['SnipWidth'])
-
- self.graphWidget.addCurve(x, y,
- legend='Input Data',
- replace=True,
- resetzoom=resetzoom)
- self.graphWidget.addCurve(x, stripBackground,
- legend='Strip Background',
- resetzoom=False)
- self.graphWidget.addCurve(x, snipBackground,
- legend='SNIP Background',
- resetzoom=False)
- if self._xmin is not None and self._xmax is not None:
- self.graphWidget.getXAxis().setLimits(self._xmin, self._xmax)
-
-
-class BackgroundDialog(qt.QDialog):
- """QDialog window featuring a :class:`BackgroundWidget`"""
- def __init__(self, parent=None):
- qt.QDialog.__init__(self, parent)
- self.setWindowTitle("Strip and Snip Configuration Window")
- self.mainLayout = qt.QVBoxLayout(self)
- self.mainLayout.setContentsMargins(0, 0, 0, 0)
- self.mainLayout.setSpacing(2)
- self.parametersWidget = BackgroundWidget(self)
- self.mainLayout.addWidget(self.parametersWidget)
- hbox = qt.QWidget(self)
- hboxLayout = qt.QHBoxLayout(hbox)
- hboxLayout.setContentsMargins(0, 0, 0, 0)
- hboxLayout.setSpacing(2)
- self.okButton = qt.QPushButton(hbox)
- self.okButton.setText("OK")
- self.okButton.setAutoDefault(False)
- self.dismissButton = qt.QPushButton(hbox)
- self.dismissButton.setText("Cancel")
- self.dismissButton.setAutoDefault(False)
- hboxLayout.addWidget(HorizontalSpacer(hbox))
- hboxLayout.addWidget(self.okButton)
- hboxLayout.addWidget(self.dismissButton)
- self.mainLayout.addWidget(hbox)
- self.dismissButton.clicked.connect(self.reject)
- self.okButton.clicked.connect(self.accept)
-
- self.output = {}
- """Configuration dictionary containing following fields:
-
- - *SmoothingFlag*
- - *SmoothingWidth*
- - *StripWidth*
- - *StripIterations*
- - *StripThreshold*
- - *SnipWidth*
- - *AnchorsFlag*
- - *AnchorsList*
- """
-
- # self.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(self.updateOutput)
-
- # def updateOutput(self, ddict):
- # self.output = ddict
-
- def accept(self):
- """Update :attr:`output`, then call :meth:`QDialog.accept`
- """
- self.output = self.getParameters()
- super(BackgroundDialog, self).accept()
-
- def sizeHint(self):
- return qt.QSize(int(1.5*qt.QDialog.sizeHint(self).width()),
- qt.QDialog.sizeHint(self).height())
-
- def setData(self, x, y, xmin=None, xmax=None):
- """See :meth:`BackgroundWidget.setData`"""
- return self.parametersWidget.setData(x, y, xmin, xmax)
-
- def getParameters(self):
- """See :meth:`BackgroundWidget.getParameters`"""
- return self.parametersWidget.getParameters()
-
- def setParameters(self, ddict):
- """See :meth:`BackgroundWidget.setPrintGeometry`"""
- return self.parametersWidget.setParameters(ddict)
-
- def setDefault(self, ddict):
- """Alias for :meth:`setPrintGeometry`"""
- return self.setParameters(ddict)
-
-
-def getBgDialog(parent=None, default=None, modal=True):
- """Instantiate and return a bg configuration dialog, adapted
- for configuring standard background theories from
- :mod:`silx.math.fit.bgtheories`.
-
- :return: Instance of :class:`BackgroundDialog`
- """
- bgd = BackgroundDialog(parent=parent)
- # apply default to newly added pages
- bgd.setParameters(default)
-
- return bgd
-
-
-def main():
- # synthetic data
- from silx.math.fit.functions import sum_gauss
-
- x = numpy.arange(5000)
- # (height1, center1, fwhm1, ...) 5 peaks
- params1 = (50, 500, 100,
- 20, 2000, 200,
- 50, 2250, 100,
- 40, 3000, 75,
- 23, 4000, 150)
- y0 = sum_gauss(x, *params1)
-
- # random values between [-1;1]
- noise = 2 * numpy.random.random(5000) - 1
- # make it +- 5%
- noise *= 0.05
-
- # 2 gaussians with very large fwhm, as background signal
- actual_bg = sum_gauss(x, 15, 3500, 3000, 5, 1000, 1500)
-
- # Add 5% random noise to gaussians and add background
- y = y0 + numpy.average(y0) * noise + actual_bg
-
- # Open widget
- a = qt.QApplication(sys.argv)
- a.lastWindowClosed.connect(a.quit)
-
- def mySlot(ddict):
- print(ddict)
-
- w = BackgroundDialog()
- w.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(mySlot)
- w.setData(x, y)
- w.exec_()
- #a.exec_()
-
-if __name__ == "__main__":
- main()
diff --git a/silx/gui/fit/FitConfig.py b/silx/gui/fit/FitConfig.py
deleted file mode 100644
index 479e469..0000000
--- a/silx/gui/fit/FitConfig.py
+++ /dev/null
@@ -1,543 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2004-2018 V.A. Sole, European Synchrotron Radiation Facility
-#
-# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
-# the ESRF by the Software group.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ######################################################################### */
-"""This module defines widgets used to build a fit configuration dialog.
-The resulting dialog widget outputs a dictionary of configuration parameters.
-"""
-from silx.gui import qt
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "30/11/2016"
-
-
-class TabsDialog(qt.QDialog):
- """Dialog widget containing a QTabWidget :attr:`tabWidget`
- and a buttons:
-
- # - buttonHelp
- - buttonDefaults
- - buttonOk
- - buttonCancel
-
- This dialog defines a __len__ returning the number of tabs,
- and an __iter__ method yielding the tab widgets.
- """
- def __init__(self, parent=None):
- qt.QDialog.__init__(self, parent)
- self.tabWidget = qt.QTabWidget(self)
-
- layout = qt.QVBoxLayout(self)
- layout.addWidget(self.tabWidget)
-
- layout2 = qt.QHBoxLayout(None)
-
- # self.buttonHelp = qt.QPushButton(self)
- # self.buttonHelp.setText("Help")
- # layout2.addWidget(self.buttonHelp)
-
- self.buttonDefault = qt.QPushButton(self)
- self.buttonDefault.setText("Undo changes")
- layout2.addWidget(self.buttonDefault)
-
- spacer = qt.QSpacerItem(20, 20,
- qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Minimum)
- layout2.addItem(spacer)
-
- self.buttonOk = qt.QPushButton(self)
- self.buttonOk.setText("OK")
- layout2.addWidget(self.buttonOk)
-
- self.buttonCancel = qt.QPushButton(self)
- self.buttonCancel.setText("Cancel")
- layout2.addWidget(self.buttonCancel)
-
- layout.addLayout(layout2)
-
- self.buttonOk.clicked.connect(self.accept)
- self.buttonCancel.clicked.connect(self.reject)
-
- def __len__(self):
- """Return number of tabs"""
- return self.tabWidget.count()
-
- def __iter__(self):
- """Return the next tab widget in :attr:`tabWidget` every
- time this method is called.
-
- :return: Tab widget
- :rtype: QWidget
- """
- for widget_index in range(len(self)):
- yield self.tabWidget.widget(widget_index)
-
- def addTab(self, page, label):
- """Add a new tab
-
- :param page: Content of new page. Must be a widget with
- a get() method returning a dictionary.
- :param str label: Tab label
- """
- self.tabWidget.addTab(page, label)
-
- def getTabLabels(self):
- """
- Return a list of all tab labels in :attr:`tabWidget`
- """
- return [self.tabWidget.tabText(i) for i in range(len(self))]
-
-
-class TabsDialogData(TabsDialog):
- """This dialog adds a data attribute to :class:`TabsDialog`.
-
- Data input in widgets, such as text entries or checkboxes, is stored in an
- attribute :attr:`output` when the user clicks the OK button.
-
- A default dictionary can be supplied when this dialog is initialized, to
- be used as default data for :attr:`output`.
- """
- def __init__(self, parent=None, modal=True, default=None):
- """
-
- :param parent: Parent :class:`QWidget`
- :param modal: If `True`, dialog is modal, meaning this dialog remains
- in front of it's parent window and disables it until the user is
- done interacting with the dialog
- :param default: Default dictionary, used to initialize and reset
- :attr:`output`.
- """
- TabsDialog.__init__(self, parent)
- self.setModal(modal)
- self.setWindowTitle("Fit configuration")
-
- self.output = {}
-
- self.default = {} if default is None else default
-
- self.buttonDefault.clicked.connect(self._resetDefault)
- # self.keyPressEvent(qt.Qt.Key_Enter).
-
- def keyPressEvent(self, event):
- """Redefining this method to ignore Enter key
- (for some reason it activates buttonDefault callback which
- resets all widgets)
- """
- if event.key() in [qt.Qt.Key_Enter, qt.Qt.Key_Return]:
- return
- TabsDialog.keyPressEvent(self, event)
-
- def accept(self):
- """When *OK* is clicked, update :attr:`output` with data from
- various widgets
- """
- self.output.update(self.default)
-
- # loop over all tab widgets (uses TabsDialog.__iter__)
- for tabWidget in self:
- self.output.update(tabWidget.get())
-
- # avoid pathological None cases
- for key in self.output.keys():
- if self.output[key] is None:
- if key in self.default:
- self.output[key] = self.default[key]
- super(TabsDialogData, self).accept()
-
- def reject(self):
- """When the *Cancel* button is clicked, reinitialize :attr:`output`
- and quit
- """
- self.setDefault()
- super(TabsDialogData, self).reject()
-
- def _resetDefault(self, checked):
- self.setDefault()
-
- def setDefault(self, newdefault=None):
- """Reinitialize :attr:`output` with :attr:`default` or with
- new dictionary ``newdefault`` if provided.
- Call :meth:`setDefault` for each tab widget, if available.
- """
- self.output = {}
- if newdefault is None:
- newdefault = self.default
- else:
- self.default = newdefault
- self.output.update(newdefault)
-
- for tabWidget in self:
- if hasattr(tabWidget, "setDefault"):
- tabWidget.setDefault(self.output)
-
-
-class ConstraintsPage(qt.QGroupBox):
- """Checkable QGroupBox widget filled with QCheckBox widgets,
- to configure the fit estimation for standard fit theories.
- """
- def __init__(self, parent=None, title="Set constraints"):
- super(ConstraintsPage, self).__init__(parent)
- self.setTitle(title)
- self.setToolTip("Disable 'Set constraints' to remove all " +
- "constraints on all fit parameters")
- self.setCheckable(True)
-
- layout = qt.QVBoxLayout(self)
- self.setLayout(layout)
-
- self.positiveHeightCB = qt.QCheckBox("Force positive height/area", self)
- self.positiveHeightCB.setToolTip("Fit must find positive peaks")
- layout.addWidget(self.positiveHeightCB)
-
- self.positionInIntervalCB = qt.QCheckBox("Force position in interval", self)
- self.positionInIntervalCB.setToolTip(
- "Fit must position peak within X limits")
- layout.addWidget(self.positionInIntervalCB)
-
- self.positiveFwhmCB = qt.QCheckBox("Force positive FWHM", self)
- self.positiveFwhmCB.setToolTip("Fit must find a positive FWHM")
- layout.addWidget(self.positiveFwhmCB)
-
- self.sameFwhmCB = qt.QCheckBox("Force same FWHM for all peaks", self)
- self.sameFwhmCB.setToolTip("Fit must find same FWHM for all peaks")
- layout.addWidget(self.sameFwhmCB)
-
- self.quotedEtaCB = qt.QCheckBox("Force Eta between 0 and 1", self)
- self.quotedEtaCB.setToolTip(
- "Fit must find Eta between 0 and 1 for pseudo-Voigt function")
- layout.addWidget(self.quotedEtaCB)
-
- layout.addStretch()
-
- self.setDefault()
-
- def setDefault(self, default_dict=None):
- """Set default state for all widgets.
-
- :param default_dict: If a default config dictionary is provided as
- a parameter, its values are used as default state."""
- if default_dict is None:
- default_dict = {}
- # this one uses reverse logic: if checked, NoConstraintsFlag must be False
- self.setChecked(
- not default_dict.get('NoConstraintsFlag', False))
- self.positiveHeightCB.setChecked(
- default_dict.get('PositiveHeightAreaFlag', True))
- self.positionInIntervalCB.setChecked(
- default_dict.get('QuotedPositionFlag', False))
- self.positiveFwhmCB.setChecked(
- default_dict.get('PositiveFwhmFlag', True))
- self.sameFwhmCB.setChecked(
- default_dict.get('SameFwhmFlag', False))
- self.quotedEtaCB.setChecked(
- default_dict.get('QuotedEtaFlag', False))
-
- def get(self):
- """Return a dictionary of constraint flags, to be processed by the
- :meth:`configure` method of the selected fit theory."""
- ddict = {
- 'NoConstraintsFlag': not self.isChecked(),
- 'PositiveHeightAreaFlag': self.positiveHeightCB.isChecked(),
- 'QuotedPositionFlag': self.positionInIntervalCB.isChecked(),
- 'PositiveFwhmFlag': self.positiveFwhmCB.isChecked(),
- 'SameFwhmFlag': self.sameFwhmCB.isChecked(),
- 'QuotedEtaFlag': self.quotedEtaCB.isChecked(),
- }
- return ddict
-
-
-class SearchPage(qt.QWidget):
- def __init__(self, parent=None):
- super(SearchPage, self).__init__(parent)
- layout = qt.QVBoxLayout(self)
-
- self.manualFwhmGB = qt.QGroupBox("Define FWHM manually", self)
- self.manualFwhmGB.setCheckable(True)
- self.manualFwhmGB.setToolTip(
- "If disabled, the FWHM parameter used for peak search is " +
- "estimated based on the highest peak in the data")
- layout.addWidget(self.manualFwhmGB)
- # ------------ GroupBox fwhm--------------------------
- layout2 = qt.QHBoxLayout(self.manualFwhmGB)
- self.manualFwhmGB.setLayout(layout2)
-
- label = qt.QLabel("Fwhm Points", self.manualFwhmGB)
- layout2.addWidget(label)
-
- self.fwhmPointsSpin = qt.QSpinBox(self.manualFwhmGB)
- self.fwhmPointsSpin.setRange(0, 999999)
- self.fwhmPointsSpin.setToolTip("Typical peak fwhm (number of data points)")
- layout2.addWidget(self.fwhmPointsSpin)
- # ----------------------------------------------------
-
- self.manualScalingGB = qt.QGroupBox("Define scaling manually", self)
- self.manualScalingGB.setCheckable(True)
- self.manualScalingGB.setToolTip(
- "If disabled, the Y scaling used for peak search is " +
- "estimated automatically")
- layout.addWidget(self.manualScalingGB)
- # ------------ GroupBox scaling-----------------------
- layout3 = qt.QHBoxLayout(self.manualScalingGB)
- self.manualScalingGB.setLayout(layout3)
-
- label = qt.QLabel("Y Scaling", self.manualScalingGB)
- layout3.addWidget(label)
-
- self.yScalingEntry = qt.QLineEdit(self.manualScalingGB)
- self.yScalingEntry.setToolTip(
- "Data values will be multiplied by this value prior to peak" +
- " search")
- self.yScalingEntry.setValidator(qt.QDoubleValidator(self))
- layout3.addWidget(self.yScalingEntry)
- # ----------------------------------------------------
-
- # ------------------- grid layout --------------------
- containerWidget = qt.QWidget(self)
- layout4 = qt.QHBoxLayout(containerWidget)
- containerWidget.setLayout(layout4)
-
- label = qt.QLabel("Sensitivity", containerWidget)
- layout4.addWidget(label)
-
- self.sensitivityEntry = qt.QLineEdit(containerWidget)
- self.sensitivityEntry.setToolTip(
- "Peak search sensitivity threshold, expressed as a multiple " +
- "of the standard deviation of the noise.\nMinimum value is 1 " +
- "(to be detected, peak must be higher than the estimated noise)")
- sensivalidator = qt.QDoubleValidator(self)
- sensivalidator.setBottom(1.0)
- self.sensitivityEntry.setValidator(sensivalidator)
- layout4.addWidget(self.sensitivityEntry)
- # ----------------------------------------------------
- layout.addWidget(containerWidget)
-
- self.forcePeakPresenceCB = qt.QCheckBox("Force peak presence", self)
- self.forcePeakPresenceCB.setToolTip(
- "If peak search algorithm is unsuccessful, place one peak " +
- "at the maximum of the curve")
- layout.addWidget(self.forcePeakPresenceCB)
-
- layout.addStretch()
-
- self.setDefault()
-
- def setDefault(self, default_dict=None):
- """Set default values for all widgets.
-
- :param default_dict: If a default config dictionary is provided as
- a parameter, its values are used as default values."""
- if default_dict is None:
- default_dict = {}
- self.manualFwhmGB.setChecked(
- not default_dict.get('AutoFwhm', True))
- self.fwhmPointsSpin.setValue(
- default_dict.get('FwhmPoints', 8))
- self.sensitivityEntry.setText(
- str(default_dict.get('Sensitivity', 1.0)))
- self.manualScalingGB.setChecked(
- not default_dict.get('AutoScaling', False))
- self.yScalingEntry.setText(
- str(default_dict.get('Yscaling', 1.0)))
- self.forcePeakPresenceCB.setChecked(
- default_dict.get('ForcePeakPresence', False))
-
- def get(self):
- """Return a dictionary of peak search parameters, to be processed by
- the :meth:`configure` method of the selected fit theory."""
- ddict = {
- 'AutoFwhm': not self.manualFwhmGB.isChecked(),
- 'FwhmPoints': self.fwhmPointsSpin.value(),
- 'Sensitivity': safe_float(self.sensitivityEntry.text()),
- 'AutoScaling': not self.manualScalingGB.isChecked(),
- 'Yscaling': safe_float(self.yScalingEntry.text()),
- 'ForcePeakPresence': self.forcePeakPresenceCB.isChecked()
- }
- return ddict
-
-
-class BackgroundPage(qt.QGroupBox):
- """Background subtraction configuration, specific to fittheories
- estimation functions."""
- def __init__(self, parent=None,
- title="Subtract strip background prior to estimation"):
- super(BackgroundPage, self).__init__(parent)
- self.setTitle(title)
- self.setCheckable(True)
- self.setToolTip(
- "The strip algorithm strips away peaks to compute the " +
- "background signal.\nAt each iteration, a sample is compared " +
- "to the average of the two samples at a given distance in both" +
- " directions,\n and if its value is higher than the average,"
- "it is replaced by the average.")
-
- layout = qt.QGridLayout(self)
- self.setLayout(layout)
-
- for i, label_text in enumerate(
- ["Strip width (in samples)",
- "Number of iterations",
- "Strip threshold factor"]):
- label = qt.QLabel(label_text)
- layout.addWidget(label, i, 0)
-
- self.stripWidthSpin = qt.QSpinBox(self)
- self.stripWidthSpin.setToolTip(
- "Width, in number of samples, of the strip operator")
- self.stripWidthSpin.setRange(1, 999999)
-
- layout.addWidget(self.stripWidthSpin, 0, 1)
-
- self.numIterationsSpin = qt.QSpinBox(self)
- self.numIterationsSpin.setToolTip(
- "Number of iterations of the strip algorithm")
- self.numIterationsSpin.setRange(1, 999999)
- layout.addWidget(self.numIterationsSpin, 1, 1)
-
- self.thresholdFactorEntry = qt.QLineEdit(self)
- self.thresholdFactorEntry.setToolTip(
- "Factor used by the strip algorithm to decide whether a sample" +
- "value should be stripped.\nThe value must be higher than the " +
- "average of the 2 samples at +- w times this factor.\n")
- self.thresholdFactorEntry.setValidator(qt.QDoubleValidator(self))
- layout.addWidget(self.thresholdFactorEntry, 2, 1)
-
- self.smoothStripGB = qt.QGroupBox("Apply smoothing prior to strip", self)
- self.smoothStripGB.setCheckable(True)
- self.smoothStripGB.setToolTip(
- "Apply a smoothing before subtracting strip background" +
- " in fit and estimate processes")
- smoothlayout = qt.QHBoxLayout(self.smoothStripGB)
- label = qt.QLabel("Smoothing width (Savitsky-Golay)")
- smoothlayout.addWidget(label)
- self.smoothingWidthSpin = qt.QSpinBox(self)
- self.smoothingWidthSpin.setToolTip(
- "Width parameter for Savitsky-Golay smoothing (number of samples, must be odd)")
- self.smoothingWidthSpin.setRange(3, 101)
- self.smoothingWidthSpin.setSingleStep(2)
- smoothlayout.addWidget(self.smoothingWidthSpin)
-
- layout.addWidget(self.smoothStripGB, 3, 0, 1, 2)
-
- layout.setRowStretch(4, 1)
-
- self.setDefault()
-
- def setDefault(self, default_dict=None):
- """Set default values for all widgets.
-
- :param default_dict: If a default config dictionary is provided as
- a parameter, its values are used as default values."""
- if default_dict is None:
- default_dict = {}
-
- self.setChecked(
- default_dict.get('StripBackgroundFlag', True))
-
- self.stripWidthSpin.setValue(
- default_dict.get('StripWidth', 2))
- self.numIterationsSpin.setValue(
- default_dict.get('StripIterations', 5000))
- self.thresholdFactorEntry.setText(
- str(default_dict.get('StripThreshold', 1.0)))
- self.smoothStripGB.setChecked(
- default_dict.get('SmoothingFlag', False))
- self.smoothingWidthSpin.setValue(
- default_dict.get('SmoothingWidth', 3))
-
- def get(self):
- """Return a dictionary of background subtraction parameters, to be
- processed by the :meth:`configure` method of the selected fit theory.
- """
- ddict = {
- 'StripBackgroundFlag': self.isChecked(),
- 'StripWidth': self.stripWidthSpin.value(),
- 'StripIterations': self.numIterationsSpin.value(),
- 'StripThreshold': safe_float(self.thresholdFactorEntry.text()),
- 'SmoothingFlag': self.smoothStripGB.isChecked(),
- 'SmoothingWidth': self.smoothingWidthSpin.value()
- }
- return ddict
-
-
-def safe_float(string_, default=1.0):
- """Convert a string into a float.
- If the conversion fails, return the default value.
- """
- try:
- ret = float(string_)
- except ValueError:
- return default
- else:
- return ret
-
-
-def safe_int(string_, default=1):
- """Convert a string into a integer.
- If the conversion fails, return the default value.
- """
- try:
- ret = int(float(string_))
- except ValueError:
- return default
- else:
- return ret
-
-
-def getFitConfigDialog(parent=None, default=None, modal=True):
- """Instantiate and return a fit configuration dialog, adapted
- for configuring standard fit theories from
- :mod:`silx.math.fit.fittheories`.
-
- :return: Instance of :class:`TabsDialogData` with 3 tabs:
- :class:`ConstraintsPage`, :class:`SearchPage` and
- :class:`BackgroundPage`
- """
- tdd = TabsDialogData(parent=parent, default=default)
- tdd.addTab(ConstraintsPage(), label="Constraints")
- tdd.addTab(SearchPage(), label="Peak search")
- tdd.addTab(BackgroundPage(), label="Background")
- # apply default to newly added pages
- tdd.setDefault()
-
- return tdd
-
-
-def main():
- a = qt.QApplication([])
-
- mw = qt.QMainWindow()
- mw.show()
-
- tdd = getFitConfigDialog(mw, default={"a": 1})
- tdd.show()
- tdd.exec_()
- print("TabsDialogData result: ", tdd.result())
- print("TabsDialogData output: ", tdd.output)
-
- a.exec_()
-
-if __name__ == "__main__":
- main()
diff --git a/silx/gui/fit/FitWidget.py b/silx/gui/fit/FitWidget.py
deleted file mode 100644
index 08731f1..0000000
--- a/silx/gui/fit/FitWidget.py
+++ /dev/null
@@ -1,739 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
-# the ESRF by the Software group.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ######################################################################### */
-"""This module provides a widget designed to configure and run a fitting
-process with constraints on parameters.
-
-The main class is :class:`FitWidget`. It relies on
-:mod:`silx.math.fit.fitmanager`, which relies on :func:`silx.math.fit.leastsq`.
-
-The user can choose between functions before running the fit. These function can
-be user defined, or by default are loaded from
-:mod:`silx.math.fit.fittheories`.
-"""
-
-__authors__ = ["V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "17/07/2018"
-
-import logging
-import sys
-import traceback
-
-from silx.math.fit import fittheories
-from silx.math.fit import fitmanager, functions
-from silx.gui import qt
-from .FitWidgets import (FitActionsButtons, FitStatusLines,
- FitConfigWidget, ParametersTab)
-from .FitConfig import getFitConfigDialog
-from .BackgroundWidget import getBgDialog, BackgroundDialog
-from ...utils.deprecation import deprecated
-
-QTVERSION = qt.qVersion()
-DEBUG = 0
-_logger = logging.getLogger(__name__)
-
-
-__authors__ = ["V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "30/11/2016"
-
-
-class FitWidget(qt.QWidget):
- """This widget can be used to configure, run and display results of a
- fitting process.
-
- The standard steps for using this widget is to initialize it, then load
- the data to be fitted.
-
- Optionally, you can also load user defined fit theories. If you skip this
- step, a series of default fit functions will be presented (gaussian-like
- functions), and you can later load your custom fit theories from an
- external file using the GUI.
-
- A fit theory is a fit function and its associated features:
-
- - estimation function,
- - list of parameter names
- - numerical derivative algorithm
- - configuration widget
-
- Once the widget is up and running, the user may select a fit theory and a
- background theory, change configuration parameters specific to the theory
- run the estimation, set constraints on parameters and run the actual fit.
-
- The results are displayed in a table.
-
- .. image:: img/FitWidget.png
- """
- sigFitWidgetSignal = qt.Signal(object)
- """This signal is emitted by the estimation and fit methods.
- It carries a dictionary with two items:
-
- - *event*: one of the following strings
-
- - *EstimateStarted*,
- - *FitStarted*
- - *EstimateFinished*,
- - *FitFinished*
- - *EstimateFailed*
- - *FitFailed*
-
- - *data*: None, or fit/estimate results (see documentation for
- :attr:`silx.math.fit.fitmanager.FitManager.fit_results`)
- """
-
- def __init__(self, parent=None, title=None, fitmngr=None,
- enableconfig=True, enablestatus=True, enablebuttons=True):
- """
-
- :param parent: Parent widget
- :param title: Window title
- :param fitmngr: User defined instance of
- :class:`silx.math.fit.fitmanager.FitManager`, or ``None``
- :param enableconfig: If ``True``, activate widgets to modify the fit
- configuration (select between several fit functions or background
- functions, apply global constraints, peak search parameters…)
- :param enablestatus: If ``True``, add a fit status widget, to display
- a message when fit estimation is available and when fit results
- are available, as well as a measure of the fit error.
- :param enablebuttons: If ``True``, add buttons to run estimation and
- fitting.
- """
- if title is None:
- title = "FitWidget"
- qt.QWidget.__init__(self, parent)
-
- self.setWindowTitle(title)
- layout = qt.QVBoxLayout(self)
-
- self.fitmanager = self._setFitManager(fitmngr)
- """Instance of :class:`FitManager`.
- This is the underlying data model of this FitWidget.
-
- If no custom theories are defined, the default ones from
- :mod:`silx.math.fit.fittheories` are imported.
- """
-
- # reference fitmanager.configure method for direct access
- self.configure = self.fitmanager.configure
- self.fitconfig = self.fitmanager.fitconfig
-
- self.configdialogs = {}
- """This dictionary defines the fit configuration widgets
- associated with the fit theories in :attr:`fitmanager.theories`
-
- Keys must correspond to existing theory names, i.e. existing keys
- in :attr:`fitmanager.theories`.
-
- Values must be instances of QDialog widgets with an additional
- *output* attribute, a dictionary storing configuration parameters
- interpreted by the corresponding fit theory.
-
- The dialog can also define a *setDefault* method to initialize the
- widget values with values in a dictionary passed as a parameter.
- This will be executed first.
-
- In case the widget does not actually inherit :class:`QDialog`, it
- must at least implement the following methods (executed in this
- particular order):
-
- - :meth:`show`: should cause the widget to become visible to the
- user)
- - :meth:`exec_`: should run while the user is interacting with the
- widget, interrupting the rest of the program. It should
- typically end (*return*) when the user clicks an *OK*
- or a *Cancel* button.
- - :meth:`result`: must return ``True`` if the new configuration in
- attribute :attr:`output` is to be accepted (user clicked *OK*),
- or return ``False`` if :attr:`output` is to be rejected (user
- clicked *Cancel*)
-
- To associate a custom configuration widget with a fit theory, use
- :meth:`associateConfigDialog`. E.g.::
-
- fw = FitWidget()
- my_config_widget = MyGaussianConfigWidget(parent=fw)
- fw.associateConfigDialog(theory_name="Gaussians",
- config_widget=my_config_widget)
- """
-
- self.bgconfigdialogs = {}
- """Same as :attr:`configdialogs`, except that the widget is associated
- with a background theory in :attr:`fitmanager.bgtheories`"""
-
- self._associateDefaultConfigDialogs()
-
- self.guiConfig = None
- """Configuration widget at the top of FitWidget, to select
- fit function, background function, and open an advanced
- configuration dialog."""
-
- self.guiParameters = ParametersTab(self)
- """Table widget for display of fit parameters and constraints"""
-
- if enableconfig:
- self.guiConfig = FitConfigWidget(self)
- """Function selector and configuration widget"""
-
- self.guiConfig.FunConfigureButton.clicked.connect(
- self.__funConfigureGuiSlot)
- self.guiConfig.BgConfigureButton.clicked.connect(
- self.__bgConfigureGuiSlot)
-
- self.guiConfig.WeightCheckBox.setChecked(
- self.fitconfig.get("WeightFlag", False))
- self.guiConfig.WeightCheckBox.stateChanged[int].connect(self.weightEvent)
-
- self.guiConfig.BkgComBox.activated[str].connect(self.bkgEvent)
- self.guiConfig.FunComBox.activated[str].connect(self.funEvent)
- self._populateFunctions()
-
- layout.addWidget(self.guiConfig)
-
- layout.addWidget(self.guiParameters)
-
- if enablestatus:
- self.guistatus = FitStatusLines(self)
- """Status bar"""
- layout.addWidget(self.guistatus)
-
- if enablebuttons:
- self.guibuttons = FitActionsButtons(self)
- """Widget with estimate, start fit and dismiss buttons"""
- self.guibuttons.EstimateButton.clicked.connect(self.estimate)
- self.guibuttons.EstimateButton.setEnabled(False)
- self.guibuttons.StartFitButton.clicked.connect(self.startFit)
- self.guibuttons.StartFitButton.setEnabled(False)
- self.guibuttons.DismissButton.clicked.connect(self.dismiss)
- layout.addWidget(self.guibuttons)
-
- def _setFitManager(self, fitinstance):
- """Initialize a :class:`FitManager` instance, to be assigned to
- :attr:`fitmanager`, or use a custom FitManager instance.
-
- :param fitinstance: Existing instance of FitManager, possibly
- customized by the user, or None to load a default instance."""
- if isinstance(fitinstance, fitmanager.FitManager):
- # customized
- fitmngr = fitinstance
- else:
- # initialize default instance
- fitmngr = fitmanager.FitManager()
-
- # initialize the default fitting functions in case
- # none is present
- if not len(fitmngr.theories):
- fitmngr.loadtheories(fittheories)
-
- return fitmngr
-
- def _associateDefaultConfigDialogs(self):
- """Fill :attr:`bgconfigdialogs` and :attr:`configdialogs` by calling
- :meth:`associateConfigDialog` with default config dialog widgets.
- """
- # associate silx.gui.fit.FitConfig with all theories
- # Users can later associate their own custom dialogs to
- # replace the default.
- configdialog = getFitConfigDialog(parent=self,
- default=self.fitconfig)
- for theory in self.fitmanager.theories:
- self.associateConfigDialog(theory, configdialog)
- for bgtheory in self.fitmanager.bgtheories:
- self.associateConfigDialog(bgtheory, configdialog,
- theory_is_background=True)
-
- # associate silx.gui.fit.BackgroundWidget with Strip and Snip
- bgdialog = getBgDialog(parent=self,
- default=self.fitconfig)
- for bgtheory in ["Strip", "Snip"]:
- if bgtheory in self.fitmanager.bgtheories:
- self.associateConfigDialog(bgtheory, bgdialog,
- theory_is_background=True)
-
- def _populateFunctions(self):
- """Fill combo-boxes with fit theories and background theories
- loaded by :attr:`fitmanager`.
- Run :meth:`fitmanager.configure` to ensure the custom configuration
- of the selected theory has been loaded into :attr:`fitconfig`"""
- for theory_name in self.fitmanager.bgtheories:
- self.guiConfig.BkgComBox.addItem(theory_name)
- self.guiConfig.BkgComBox.setItemData(
- self.guiConfig.BkgComBox.findText(theory_name),
- self.fitmanager.bgtheories[theory_name].description,
- qt.Qt.ToolTipRole)
-
- for theory_name in self.fitmanager.theories:
- self.guiConfig.FunComBox.addItem(theory_name)
- self.guiConfig.FunComBox.setItemData(
- self.guiConfig.FunComBox.findText(theory_name),
- self.fitmanager.theories[theory_name].description,
- qt.Qt.ToolTipRole)
-
- # - activate selected fit theory (if any)
- # - activate selected bg theory (if any)
- configuration = self.fitmanager.configure()
- if self.fitmanager.selectedtheory is None:
- # take the first one by default
- self.guiConfig.FunComBox.setCurrentIndex(1)
- self.funEvent(list(self.fitmanager.theories.keys())[0])
- else:
- idx = list(self.fitmanager.theories).index(self.fitmanager.selectedtheory)
- self.guiConfig.FunComBox.setCurrentIndex(idx + 1)
- self.funEvent(self.fitmanager.selectedtheory)
-
- if self.fitmanager.selectedbg is None:
- self.guiConfig.BkgComBox.setCurrentIndex(1)
- self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0])
- else:
- idx = list(self.fitmanager.bgtheories).index(self.fitmanager.selectedbg)
- self.guiConfig.BkgComBox.setCurrentIndex(idx + 1)
- self.bkgEvent(self.fitmanager.selectedbg)
-
- configuration.update(self.configure())
-
- @deprecated(replacement='setData', since_version='0.3.0')
- def setdata(self, x, y, sigmay=None, xmin=None, xmax=None):
- self.setData(x, y, sigmay, xmin, xmax)
-
- def setData(self, x=None, y=None, sigmay=None, xmin=None, xmax=None):
- """Set data to be fitted.
-
- :param x: Abscissa data. If ``None``, :attr:`xdata`` is set to
- ``numpy.array([0.0, 1.0, 2.0, ..., len(y)-1])``
- :type x: Sequence or numpy array or None
- :param y: The dependant data ``y = f(x)``. ``y`` must have the same
- shape as ``x`` if ``x`` is not ``None``.
- :type y: Sequence or numpy array or None
- :param sigmay: The uncertainties in the ``ydata`` array. These are
- used as weights in the least-squares problem.
- If ``None``, the uncertainties are assumed to be 1.
- :type sigmay: Sequence or numpy array or None
- :param xmin: Lower value of x values to use for fitting
- :param xmax: Upper value of x values to use for fitting
- """
- if y is None:
- self.guibuttons.EstimateButton.setEnabled(False)
- self.guibuttons.StartFitButton.setEnabled(False)
- else:
- self.guibuttons.EstimateButton.setEnabled(True)
- self.guibuttons.StartFitButton.setEnabled(True)
- self.fitmanager.setdata(x=x, y=y, sigmay=sigmay,
- xmin=xmin, xmax=xmax)
- for config_dialog in self.bgconfigdialogs.values():
- if isinstance(config_dialog, BackgroundDialog):
- config_dialog.setData(x, y, xmin=xmin, xmax=xmax)
-
- def associateConfigDialog(self, theory_name, config_widget,
- theory_is_background=False):
- """Associate an instance of custom configuration dialog widget to
- a fit theory or to a background theory.
-
- This adds or modifies an item in the correspondence table
- :attr:`configdialogs` or :attr:`bgconfigdialogs`.
-
- :param str theory_name: Name of fit theory. This must be a key of dict
- :attr:`fitmanager.theories`
- :param config_widget: Custom configuration widget. See documentation
- for :attr:`configdialogs`
- :param bool theory_is_background: If flag is *True*, add dialog to
- :attr:`bgconfigdialogs` rather than :attr:`configdialogs`
- (default).
- :raise: KeyError if parameter ``theory_name`` does not match an
- existing fit theory or background theory in :attr:`fitmanager`.
- :raise: AttributeError if the widget does not implement the mandatory
- methods (*show*, *exec_*, *result*, *setDefault*) or the mandatory
- attribute (*output*).
- """
- theories = self.fitmanager.bgtheories if theory_is_background else\
- self.fitmanager.theories
-
- if theory_name not in theories:
- raise KeyError("%s does not match an existing fitmanager theory")
-
- if config_widget is not None:
- for mandatory_attr in ["show", "exec_", "result", "output"]:
- if not hasattr(config_widget, mandatory_attr):
- raise AttributeError(
- "Custom configuration widget must define " +
- "attribute or method " + mandatory_attr)
-
- if theory_is_background:
- self.bgconfigdialogs[theory_name] = config_widget
- else:
- self.configdialogs[theory_name] = config_widget
-
- def _emitSignal(self, ddict):
- """Emit pyqtSignal after estimation completed
- (``ddict = {'event': 'EstimateFinished', 'data': fit_results}``)
- and after fit completed
- (``ddict = {'event': 'FitFinished', 'data': fit_results}``)"""
- self.sigFitWidgetSignal.emit(ddict)
-
- def __funConfigureGuiSlot(self):
- """Open an advanced configuration dialog widget"""
- self.__configureGui(dialog_type="function")
-
- def __bgConfigureGuiSlot(self):
- """Open an advanced configuration dialog widget"""
- self.__configureGui(dialog_type="background")
-
- def __configureGui(self, newconfiguration=None, dialog_type="function"):
- """Open an advanced configuration dialog widget to get a configuration
- dictionary, or use a supplied configuration dictionary. Call
- :meth:`configure` with this dictionary as a parameter. Update the gui
- accordingly. Reinitialize the fit results in the table and in
- :attr:`fitmanager`.
-
- :param newconfiguration: User supplied configuration dictionary. If ``None``,
- open a dialog widget that returns a dictionary."""
- configuration = self.configure()
- # get new dictionary
- if newconfiguration is None:
- newconfiguration = self.configureDialog(configuration, dialog_type)
- # update configuration
- configuration.update(self.configure(**newconfiguration))
- # set fit function theory
- try:
- i = 1 + \
- list(self.fitmanager.theories.keys()).index(
- self.fitmanager.selectedtheory)
- self.guiConfig.FunComBox.setCurrentIndex(i)
- self.funEvent(self.fitmanager.selectedtheory)
- except ValueError:
- _logger.error("Function not in list %s",
- self.fitmanager.selectedtheory)
- self.funEvent(list(self.fitmanager.theories.keys())[0])
- # current background
- try:
- i = 1 + \
- list(self.fitmanager.bgtheories.keys()).index(
- self.fitmanager.selectedbg)
- self.guiConfig.BkgComBox.setCurrentIndex(i)
- self.bkgEvent(self.fitmanager.selectedbg)
- except ValueError:
- _logger.error("Background not in list %s",
- self.fitmanager.selectedbg)
- self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0])
-
- # update the Gui
- self.__initialParameters()
-
- def configureDialog(self, oldconfiguration, dialog_type="function"):
- """Display a dialog, allowing the user to define fit configuration
- parameters.
-
- By default, a common dialog is used for all fit theories. But if the
- defined a custom dialog using :meth:`associateConfigDialog`, it is
- used instead.
-
- :param dict oldconfiguration: Dictionary containing previous configuration
- :param str dialog_type: "function" or "background"
- :return: User defined parameters in a dictionary
- """
- newconfiguration = {}
- newconfiguration.update(oldconfiguration)
-
- if dialog_type == "function":
- theory = self.fitmanager.selectedtheory
- configdialog = self.configdialogs[theory]
- elif dialog_type == "background":
- theory = self.fitmanager.selectedbg
- configdialog = self.bgconfigdialogs[theory]
-
- # this should only happen if a user specifically associates None
- # with a theory, to have no configuration option
- if configdialog is None:
- return {}
-
- # update state of configdialog before showing it
- if hasattr(configdialog, "setDefault"):
- configdialog.setDefault(newconfiguration)
- configdialog.show()
- configdialog.exec_()
- if configdialog.result():
- newconfiguration.update(configdialog.output)
-
- return newconfiguration
-
- def estimate(self):
- """Run parameter estimation function then emit
- :attr:`sigFitWidgetSignal` with a dictionary containing a status
- message and a list of fit parameters estimations
- in the format defined in
- :attr:`silx.math.fit.fitmanager.FitManager.fit_results`
-
- The emitted dictionary has an *"event"* key that can have
- following values:
-
- - *'EstimateStarted'*
- - *'EstimateFailed'*
- - *'EstimateFinished'*
- """
- try:
- theory_name = self.fitmanager.selectedtheory
- estimation_function = self.fitmanager.theories[theory_name].estimate
- if estimation_function is not None:
- ddict = {'event': 'EstimateStarted',
- 'data': None}
- self._emitSignal(ddict)
- self.fitmanager.estimate(callback=self.fitStatus)
- else:
- msg = qt.QMessageBox(self)
- msg.setIcon(qt.QMessageBox.Information)
- text = "Function does not define a way to estimate\n"
- text += "the initial parameters. Please, fill them\n"
- text += "yourself in the table and press Start Fit\n"
- msg.setText(text)
- msg.setWindowTitle('FitWidget Message')
- msg.exec_()
- return
- except Exception as e: # noqa (we want to catch and report all errors)
- _logger.warning('Estimate error: %s', traceback.format_exc())
- msg = qt.QMessageBox(self)
- msg.setIcon(qt.QMessageBox.Critical)
- msg.setWindowTitle("Estimate Error")
- msg.setText("Error on estimate: %s" % e)
- msg.exec_()
- ddict = {
- 'event': 'EstimateFailed',
- 'data': None}
- self._emitSignal(ddict)
- return
-
- self.guiParameters.fillFromFit(
- self.fitmanager.fit_results, view='Fit')
- self.guiParameters.removeAllViews(keep='Fit')
- ddict = {
- 'event': 'EstimateFinished',
- 'data': self.fitmanager.fit_results}
- self._emitSignal(ddict)
-
- @deprecated(replacement='startFit', since_version='0.3.0')
- def startfit(self):
- self.startFit()
-
- def startFit(self):
- """Run fit, then emit :attr:`sigFitWidgetSignal` with a dictionary
- containing a status message and a list of fit
- parameters results in the format defined in
- :attr:`silx.math.fit.fitmanager.FitManager.fit_results`
-
- The emitted dictionary has an *"event"* key that can have
- following values:
-
- - *'FitStarted'*
- - *'FitFailed'*
- - *'FitFinished'*
- """
- self.fitmanager.fit_results = self.guiParameters.getFitResults()
- try:
- ddict = {'event': 'FitStarted',
- 'data': None}
- self._emitSignal(ddict)
- self.fitmanager.runfit(callback=self.fitStatus)
- except Exception as e: # noqa (we want to catch and report all errors)
- _logger.warning('Estimate error: %s', traceback.format_exc())
- msg = qt.QMessageBox(self)
- msg.setIcon(qt.QMessageBox.Critical)
- msg.setWindowTitle("Fit Error")
- msg.setText("Error on Fit: %s" % e)
- msg.exec_()
- ddict = {
- 'event': 'FitFailed',
- 'data': None
- }
- self._emitSignal(ddict)
- return
-
- self.guiParameters.fillFromFit(
- self.fitmanager.fit_results, view='Fit')
- self.guiParameters.removeAllViews(keep='Fit')
- ddict = {
- 'event': 'FitFinished',
- 'data': self.fitmanager.fit_results
- }
- self._emitSignal(ddict)
- return
-
- def bkgEvent(self, bgtheory):
- """Select background theory, then reinitialize parameters"""
- bgtheory = str(bgtheory)
- if bgtheory in self.fitmanager.bgtheories:
- self.fitmanager.setbackground(bgtheory)
- else:
- functionsfile = qt.QFileDialog.getOpenFileName(
- self, "Select python module with your function(s)", "",
- "Python Files (*.py);;All Files (*)")
-
- if len(functionsfile):
- try:
- self.fitmanager.loadbgtheories(functionsfile)
- except ImportError:
- qt.QMessageBox.critical(self, "ERROR",
- "Function not imported")
- return
- else:
- # empty the ComboBox
- while self.guiConfig.BkgComBox.count() > 1:
- self.guiConfig.BkgComBox.removeItem(1)
- # and fill it again
- for key in self.fitmanager.bgtheories:
- self.guiConfig.BkgComBox.addItem(str(key))
-
- i = 1 + \
- list(self.fitmanager.bgtheories.keys()).index(
- self.fitmanager.selectedbg)
- self.guiConfig.BkgComBox.setCurrentIndex(i)
- self.__initialParameters()
-
- def funEvent(self, theoryname):
- """Select a fit theory to be used for fitting. If this theory exists
- in :attr:`fitmanager`, use it. Then, reinitialize table.
-
- :param theoryname: Name of the fit theory to use for fitting. If this theory
- exists in :attr:`fitmanager`, use it. Else, open a file dialog to open
- a custom fit function definition file with
- :meth:`fitmanager.loadtheories`.
- """
- theoryname = str(theoryname)
- if theoryname in self.fitmanager.theories:
- self.fitmanager.settheory(theoryname)
- else:
- # open a load file dialog
- functionsfile = qt.QFileDialog.getOpenFileName(
- self, "Select python module with your function(s)", "",
- "Python Files (*.py);;All Files (*)")
-
- if len(functionsfile):
- try:
- self.fitmanager.loadtheories(functionsfile)
- except ImportError:
- qt.QMessageBox.critical(self, "ERROR",
- "Function not imported")
- return
- else:
- # empty the ComboBox
- while self.guiConfig.FunComBox.count() > 1:
- self.guiConfig.FunComBox.removeItem(1)
- # and fill it again
- for key in self.fitmanager.theories:
- self.guiConfig.FunComBox.addItem(str(key))
-
- i = 1 + \
- list(self.fitmanager.theories.keys()).index(
- self.fitmanager.selectedtheory)
- self.guiConfig.FunComBox.setCurrentIndex(i)
- self.__initialParameters()
-
- def weightEvent(self, flag):
- """This is called when WeightCheckBox is clicked, to configure the
- *WeightFlag* field in :attr:`fitmanager.fitconfig` and set weights
- in the least-square problem."""
- self.configure(WeightFlag=flag)
- if flag:
- self.fitmanager.enableweight()
- else:
- # set weights back to 1
- self.fitmanager.disableweight()
-
- def __initialParameters(self):
- """Fill the fit parameters names with names of the parameters of
- the selected background theory and the selected fit theory.
- Initialize :attr:`fitmanager.fit_results` with these names, and
- initialize the table with them. This creates a view called "Fit"
- in :attr:`guiParameters`"""
- self.fitmanager.parameter_names = []
- self.fitmanager.fit_results = []
- for pname in self.fitmanager.bgtheories[self.fitmanager.selectedbg].parameters:
- self.fitmanager.parameter_names.append(pname)
- self.fitmanager.fit_results.append({'name': pname,
- 'estimation': 0,
- 'group': 0,
- 'code': 'FREE',
- 'cons1': 0,
- 'cons2': 0,
- 'fitresult': 0.0,
- 'sigma': 0.0,
- 'xmin': None,
- 'xmax': None})
- if self.fitmanager.selectedtheory is not None:
- theory = self.fitmanager.selectedtheory
- for pname in self.fitmanager.theories[theory].parameters:
- self.fitmanager.parameter_names.append(pname + "1")
- self.fitmanager.fit_results.append({'name': pname + "1",
- 'estimation': 0,
- 'group': 1,
- 'code': 'FREE',
- 'cons1': 0,
- 'cons2': 0,
- 'fitresult': 0.0,
- 'sigma': 0.0,
- 'xmin': None,
- 'xmax': None})
-
- self.guiParameters.fillFromFit(
- self.fitmanager.fit_results, view='Fit')
-
- def fitStatus(self, data):
- """Set *status* and *chisq* in status bar"""
- if 'chisq' in data:
- if data['chisq'] is None:
- self.guistatus.ChisqLine.setText(" ")
- else:
- chisq = data['chisq']
- self.guistatus.ChisqLine.setText("%6.2f" % chisq)
-
- if 'status' in data:
- status = data['status']
- self.guistatus.StatusLine.setText(str(status))
-
- def dismiss(self):
- """Close FitWidget"""
- self.close()
-
-
-if __name__ == "__main__":
- import numpy
-
- x = numpy.arange(1500).astype(numpy.float64)
- constant_bg = 3.14
-
- p = [1000, 100., 30.0,
- 500, 300., 25.,
- 1700, 500., 35.,
- 750, 700., 30.0,
- 1234, 900., 29.5,
- 302, 1100., 30.5,
- 75, 1300., 21.]
- y = functions.sum_gauss(x, *p) + constant_bg
-
- a = qt.QApplication(sys.argv)
- w = FitWidget()
- w.setData(x=x, y=y)
- w.show()
- a.exec_()
diff --git a/silx/gui/fit/FitWidgets.py b/silx/gui/fit/FitWidgets.py
deleted file mode 100644
index 408666b..0000000
--- a/silx/gui/fit/FitWidgets.py
+++ /dev/null
@@ -1,559 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2004-2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ######################################################################### */
-"""Collection of widgets used to build
-:class:`silx.gui.fit.FitWidget.FitWidget`"""
-
-from collections import OrderedDict
-
-from silx.gui import qt
-from silx.gui.fit.Parameters import Parameters
-
-QTVERSION = qt.qVersion()
-
-__authors__ = ["V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "13/10/2016"
-
-
-class FitActionsButtons(qt.QWidget):
- """Widget with 3 ``QPushButton``:
-
- The buttons can be accessed as public attributes::
-
- - ``EstimateButton``
- - ``StartFitButton``
- - ``DismissButton``
-
- You will typically need to access these attributes to connect the buttons
- to actions. For instance, if you have 3 functions ``estimate``,
- ``runfit`` and ``dismiss``, you can connect them like this::
-
- >>> fit_actions_buttons = FitActionsButtons()
- >>> fit_actions_buttons.EstimateButton.clicked.connect(estimate)
- >>> fit_actions_buttons.StartFitButton.clicked.connect(runfit)
- >>> fit_actions_buttons.DismissButton.clicked.connect(dismiss)
-
- """
-
- def __init__(self, parent=None):
- qt.QWidget.__init__(self, parent)
-
- self.resize(234, 53)
-
- grid_layout = qt.QGridLayout(self)
- grid_layout.setContentsMargins(11, 11, 11, 11)
- grid_layout.setSpacing(6)
- layout = qt.QHBoxLayout(None)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(6)
-
- self.EstimateButton = qt.QPushButton(self)
- self.EstimateButton.setText("Estimate")
- layout.addWidget(self.EstimateButton)
- spacer = qt.QSpacerItem(20, 20,
- qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Minimum)
- layout.addItem(spacer)
-
- self.StartFitButton = qt.QPushButton(self)
- self.StartFitButton.setText("Start Fit")
- layout.addWidget(self.StartFitButton)
- spacer_2 = qt.QSpacerItem(20, 20,
- qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Minimum)
- layout.addItem(spacer_2)
-
- self.DismissButton = qt.QPushButton(self)
- self.DismissButton.setText("Dismiss")
- layout.addWidget(self.DismissButton)
-
- grid_layout.addLayout(layout, 0, 0)
-
-
-class FitStatusLines(qt.QWidget):
- """Widget with 2 greyed out write-only ``QLineEdit``.
-
- These text widgets can be accessed as public attributes::
-
- - ``StatusLine``
- - ``ChisqLine``
-
- You will typically need to access these widgets to update the displayed
- text::
-
- >>> fit_status_lines = FitStatusLines()
- >>> fit_status_lines.StatusLine.setText("Ready")
- >>> fit_status_lines.ChisqLine.setText("%6.2f" % 0.01)
-
- """
-
- def __init__(self, parent=None):
- qt.QWidget.__init__(self, parent)
-
- self.resize(535, 47)
-
- layout = qt.QHBoxLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(6)
-
- self.StatusLabel = qt.QLabel(self)
- self.StatusLabel.setText("Status:")
- layout.addWidget(self.StatusLabel)
-
- self.StatusLine = qt.QLineEdit(self)
- self.StatusLine.setText("Ready")
- self.StatusLine.setReadOnly(1)
- layout.addWidget(self.StatusLine)
-
- self.ChisqLabel = qt.QLabel(self)
- self.ChisqLabel.setText("Reduced chisq:")
- layout.addWidget(self.ChisqLabel)
-
- self.ChisqLine = qt.QLineEdit(self)
- self.ChisqLine.setMaximumSize(qt.QSize(16000, 32767))
- self.ChisqLine.setText("")
- self.ChisqLine.setReadOnly(1)
- layout.addWidget(self.ChisqLine)
-
-
-class FitConfigWidget(qt.QWidget):
- """Widget whose purpose is to select a fit theory and a background
- theory, load a new fit theory definition file and provide
- a "Configure" button to open an advanced configuration dialog.
-
- This is used in :class:`silx.gui.fit.FitWidget.FitWidget`, to offer
- an interface to quickly modify the main parameters prior to running a fit:
-
- - select a fitting function through :attr:`FunComBox`
- - select a background function through :attr:`BkgComBox`
- - open a dialog for modifying advanced parameters through
- :attr:`FunConfigureButton`
- """
- def __init__(self, parent=None):
- qt.QWidget.__init__(self, parent)
-
- self.setWindowTitle("FitConfigGUI")
-
- layout = qt.QGridLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(6)
-
- self.FunLabel = qt.QLabel(self)
- self.FunLabel.setText("Function")
- layout.addWidget(self.FunLabel, 0, 0)
-
- self.FunComBox = qt.QComboBox(self)
- self.FunComBox.addItem("Add Function(s)")
- self.FunComBox.setItemData(self.FunComBox.findText("Add Function(s)"),
- "Load fit theories from a file",
- qt.Qt.ToolTipRole)
- layout.addWidget(self.FunComBox, 0, 1)
-
- self.BkgLabel = qt.QLabel(self)
- self.BkgLabel.setText("Background")
- layout.addWidget(self.BkgLabel, 1, 0)
-
- self.BkgComBox = qt.QComboBox(self)
- self.BkgComBox.addItem("Add Background(s)")
- self.BkgComBox.setItemData(self.BkgComBox.findText("Add Background(s)"),
- "Load background theories from a file",
- qt.Qt.ToolTipRole)
- layout.addWidget(self.BkgComBox, 1, 1)
-
- self.FunConfigureButton = qt.QPushButton(self)
- self.FunConfigureButton.setText("Configure")
- self.FunConfigureButton.setToolTip(
- "Open a configuration dialog for the selected function")
- layout.addWidget(self.FunConfigureButton, 0, 2)
-
- self.BgConfigureButton = qt.QPushButton(self)
- self.BgConfigureButton.setText("Configure")
- self.BgConfigureButton.setToolTip(
- "Open a configuration dialog for the selected background")
- layout.addWidget(self.BgConfigureButton, 1, 2)
-
- self.WeightCheckBox = qt.QCheckBox(self)
- self.WeightCheckBox.setText("Weighted fit")
- self.WeightCheckBox.setToolTip(
- "Enable usage of weights in the least-square problem.\n Use" +
- " the uncertainties (sigma) if provided, else use sqrt(y).")
-
- layout.addWidget(self.WeightCheckBox, 0, 3, 2, 1)
-
- layout.setColumnStretch(4, 1)
-
-
-class ParametersTab(qt.QTabWidget):
- """This widget provides tabs to display and modify fit parameters. Each
- tab contains a table with fit data such as parameter names, estimated
- values, fit constraints, and final fit results.
-
- The usual way to initialize the table is to fill it with the fit
- parameters from a :class:`silx.math.fit.fitmanager.FitManager` object, after
- the estimation process or after the final fit.
-
- In the following example we use a :class:`ParametersTab` to display the
- results of two separate fits::
-
- from silx.math.fit import fittheories
- from silx.math.fit import fitmanager
- from silx.math.fit import functions
- from silx.gui import qt
- import numpy
-
- a = qt.QApplication([])
-
- # Create synthetic data
- x = numpy.arange(1000)
- y1 = functions.sum_gauss(x, 100, 400, 100)
-
- fit = fitmanager.FitManager(x=x, y=y1)
-
- fitfuns = fittheories.FitTheories()
- fit.addtheory(theory="Gaussian",
- function=functions.sum_gauss,
- parameters=("height", "peak center", "fwhm"),
- estimate=fitfuns.estimate_height_position_fwhm)
- fit.settheory('Gaussian')
- fit.configure(PositiveFwhmFlag=True,
- PositiveHeightAreaFlag=True,
- AutoFwhm=True,)
-
- # Fit
- fit.estimate()
- fit.runfit()
-
- # Show first fit result in a tab in our widget
- w = ParametersTab()
- w.show()
- w.fillFromFit(fit.fit_results, view='Gaussians')
-
- # new synthetic data
- y2 = functions.sum_splitgauss(x,
- 100, 400, 100, 40,
- 10, 600, 50, 500,
- 80, 850, 10, 50)
- fit.setData(x=x, y=y2)
-
- # Define new theory
- fit.addtheory(theory="Asymetric gaussian",
- function=functions.sum_splitgauss,
- parameters=("height", "peak center", "left fwhm", "right fwhm"),
- estimate=fitfuns.estimate_splitgauss)
- fit.settheory('Asymetric gaussian')
-
- # Fit
- fit.estimate()
- fit.runfit()
-
- # Show first fit result in another tab in our widget
- w.fillFromFit(fit.fit_results, view='Asymetric gaussians')
- a.exec_()
-
- """
-
- def __init__(self, parent=None, name="FitParameters"):
- """
-
- :param parent: Parent widget
- :param name: Widget title
- """
- qt.QTabWidget.__init__(self, parent)
- self.setWindowTitle(name)
- self.setContentsMargins(0, 0, 0, 0)
-
- self.views = OrderedDict()
- """Dictionary of views. Keys are view names,
- items are :class:`Parameters` widgets"""
-
- self.latest_view = None
- """Name of latest view"""
-
- # the widgets/tables themselves
- self.tables = {}
- """Dictionary of :class:`silx.gui.fit.parameters.Parameters` objects.
- These objects store fit results
- """
-
- self.setContentsMargins(10, 10, 10, 10)
-
- def setView(self, view=None, fitresults=None):
- """Add or update a table. Fill it with data from a fit
-
- :param view: Tab name to be added or updated. If ``None``, use the
- latest view.
- :param fitresults: Fit data to be added to the table
- :raise: KeyError if no view name specified and no latest view
- available.
- """
- if view is None:
- if self.latest_view is not None:
- view = self.latest_view
- else:
- raise KeyError(
- "No view available. You must specify a view" +
- " name the first time you call this method."
- )
-
- if view in self.tables.keys():
- table = self.tables[view]
- else:
- # create the parameters instance
- self.tables[view] = Parameters(self)
- table = self.tables[view]
- self.views[view] = table
- self.addTab(table, str(view))
-
- if fitresults is not None:
- table.fillFromFit(fitresults)
-
- self.setCurrentWidget(self.views[view])
- self.latest_view = view
-
- def renameView(self, oldname=None, newname=None):
- """Rename a view (tab)
-
- :param oldname: Name of the view to be renamed
- :param newname: New name of the view"""
- error = 1
- if newname is not None:
- if newname not in self.views.keys():
- if oldname in self.views.keys():
- parameterlist = self.tables[oldname].getFitResults()
- self.setView(view=newname, fitresults=parameterlist)
- self.removeView(oldname)
- error = 0
- return error
-
- def fillFromFit(self, fitparameterslist, view=None):
- """Update a view with data from a fit (alias for :meth:`setView`)
-
- :param view: Tab name to be added or updated (default: latest view)
- :param fitparameterslist: Fit data to be added to the table
- """
- self.setView(view=view, fitresults=fitparameterslist)
-
- def getFitResults(self, name=None):
- """Call :meth:`getFitResults` for the
- :class:`silx.gui.fit.parameters.Parameters` corresponding to the
- latest table or to the named table (if ``name`` is not
- ``None``). This return a list of dictionaries in the format used by
- :class:`silx.math.fit.fitmanager.FitManager` to store fit parameter
- results.
-
- :param name: View name.
- """
- if name is None:
- name = self.latest_view
- return self.tables[name].getFitResults()
-
- def removeView(self, name):
- """Remove a view by name.
-
- :param name: View name.
- """
- if name in self.views:
- index = self.indexOf(self.tables[name])
- self.removeTab(index)
- index = self.indexOf(self.views[name])
- self.removeTab(index)
- del self.tables[name]
- del self.views[name]
-
- def removeAllViews(self, keep=None):
- """Remove all views, except the one specified (argument
- ``keep``)
-
- :param keep: Name of the view to be kept."""
- for view in self.tables:
- if view != keep:
- self.removeView(view)
-
- def getHtmlText(self, name=None):
- """Return the table data as HTML
-
- :param name: View name."""
- if name is None:
- name = self.latest_view
- table = self.tables[name]
- lemon = ("#%x%x%x" % (255, 250, 205)).upper()
- hcolor = ("#%x%x%x" % (230, 240, 249)).upper()
- text = ""
- text += "<nobr>"
- text += "<table>"
- text += "<tr>"
- ncols = table.columnCount()
- for l in range(ncols):
- text += ('<td align="left" bgcolor="%s"><b>' % hcolor)
- if QTVERSION < '4.0.0':
- text += (str(table.horizontalHeader().label(l)))
- else:
- text += (str(table.horizontalHeaderItem(l).text()))
- text += "</b></td>"
- text += "</tr>"
- nrows = table.rowCount()
- for r in range(nrows):
- text += "<tr>"
- item = table.item(r, 0)
- newtext = ""
- if item is not None:
- newtext = str(item.text())
- if len(newtext):
- color = "white"
- b = "<b>"
- else:
- b = ""
- color = lemon
- try:
- # MyQTable item has color defined
- cc = table.item(r, 0).color
- cc = ("#%x%x%x" % (cc.red(), cc.green(), cc.blue())).upper()
- color = cc
- except:
- pass
- for c in range(ncols):
- item = table.item(r, c)
- newtext = ""
- if item is not None:
- newtext = str(item.text())
- if len(newtext):
- finalcolor = color
- else:
- finalcolor = "white"
- if c < 2:
- text += ('<td align="left" bgcolor="%s">%s' %
- (finalcolor, b))
- else:
- text += ('<td align="right" bgcolor="%s">%s' %
- (finalcolor, b))
- text += newtext
- if len(b):
- text += "</td>"
- else:
- text += "</b></td>"
- item = table.item(r, 0)
- newtext = ""
- if item is not None:
- newtext = str(item.text())
- if len(newtext):
- text += "</b>"
- text += "</tr>"
- text += "\n"
- text += "</table>"
- text += "</nobr>"
- return text
-
- def getText(self, name=None):
- """Return the table data as CSV formatted text, using tabulation
- characters as separators.
-
- :param name: View name."""
- if name is None:
- name = self.latest_view
- table = self.tables[name]
- text = ""
- ncols = table.columnCount()
- for l in range(ncols):
- text += (str(table.horizontalHeaderItem(l).text())) + "\t"
- text += "\n"
- nrows = table.rowCount()
- for r in range(nrows):
- for c in range(ncols):
- newtext = ""
- if c != 4:
- item = table.item(r, c)
- if item is not None:
- newtext = str(item.text())
- else:
- item = table.cellWidget(r, c)
- if item is not None:
- newtext = str(item.currentText())
- text += newtext + "\t"
- text += "\n"
- text += "\n"
- return text
-
-
-def test():
- from silx.math.fit import fittheories
- from silx.math.fit import fitmanager
- from silx.math.fit import functions
- from silx.gui.plot.PlotWindow import PlotWindow
- import numpy
-
- a = qt.QApplication([])
-
- x = numpy.arange(1000)
- y1 = functions.sum_gauss(x, 100, 400, 100)
-
- fit = fitmanager.FitManager(x=x, y=y1)
-
- fitfuns = fittheories.FitTheories()
- fit.addtheory(name="Gaussian",
- function=functions.sum_gauss,
- parameters=("height", "peak center", "fwhm"),
- estimate=fitfuns.estimate_height_position_fwhm)
- fit.settheory('Gaussian')
- fit.configure(PositiveFwhmFlag=True,
- PositiveHeightAreaFlag=True,
- AutoFwhm=True,)
-
- # Fit
- fit.estimate()
- fit.runfit()
-
- w = ParametersTab()
- w.show()
- w.fillFromFit(fit.fit_results, view='Gaussians')
-
- y2 = functions.sum_splitgauss(x,
- 100, 400, 100, 40,
- 10, 600, 50, 500,
- 80, 850, 10, 50)
- fit.setdata(x=x, y=y2)
-
- # Define new theory
- fit.addtheory(name="Asymetric gaussian",
- function=functions.sum_splitgauss,
- parameters=("height", "peak center", "left fwhm", "right fwhm"),
- estimate=fitfuns.estimate_splitgauss)
- fit.settheory('Asymetric gaussian')
-
- # Fit
- fit.estimate()
- fit.runfit()
-
- w.fillFromFit(fit.fit_results, view='Asymetric gaussians')
-
- # Plot
- pw = PlotWindow(control=True)
- pw.addCurve(x, y1, "Gaussians")
- pw.addCurve(x, y2, "Asymetric gaussians")
- pw.show()
-
- a.exec_()
-
-
-if __name__ == "__main__":
- test()
diff --git a/silx/gui/fit/Parameters.py b/silx/gui/fit/Parameters.py
deleted file mode 100644
index 62e3278..0000000
--- a/silx/gui/fit/Parameters.py
+++ /dev/null
@@ -1,882 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2004-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ######################################################################### */
-"""This module defines a table widget that is specialized in displaying fit
-parameter results and associated constraints."""
-__authors__ = ["V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "25/11/2016"
-
-import sys
-from collections import OrderedDict
-
-from silx.gui import qt
-from silx.gui.widgets.TableWidget import TableWidget
-
-
-def float_else_zero(sstring):
- """Return converted string to float. If conversion fail, return zero.
-
- :param sstring: String to be converted
- :return: ``float(sstrinq)`` if ``sstring`` can be converted to float
- (e.g. ``"3.14"``), else ``0``
- """
- try:
- return float(sstring)
- except ValueError:
- return 0
-
-
-class QComboTableItem(qt.QComboBox):
- """:class:`qt.QComboBox` augmented with a ``sigCellChanged`` signal
- to emit a tuple of ``(row, column)`` coordinates when the value is
- changed.
-
- This signal can be used to locate the modified combo box in a table.
-
- :param row: Row number of the table cell containing this widget
- :param col: Column number of the table cell containing this widget"""
- sigCellChanged = qt.Signal(int, int)
- """Signal emitted when this ``QComboBox`` is activated.
- A ``(row, column)`` tuple is passed."""
-
- def __init__(self, parent=None, row=None, col=None):
- self._row = row
- self._col = col
- qt.QComboBox.__init__(self, parent)
- self.activated[int].connect(self._cellChanged)
-
- def _cellChanged(self, idx): # noqa
- self.sigCellChanged.emit(self._row, self._col)
-
-
-class QCheckBoxItem(qt.QCheckBox):
- """:class:`qt.QCheckBox` augmented with a ``sigCellChanged`` signal
- to emit a tuple of ``(row, column)`` coordinates when the check box has
- been clicked on.
-
- This signal can be used to locate the modified check box in a table.
-
- :param row: Row number of the table cell containing this widget
- :param col: Column number of the table cell containing this widget"""
- sigCellChanged = qt.Signal(int, int)
- """Signal emitted when this ``QCheckBox`` is clicked.
- A ``(row, column)`` tuple is passed."""
-
- def __init__(self, parent=None, row=None, col=None):
- self._row = row
- self._col = col
- qt.QCheckBox.__init__(self, parent)
- self.clicked.connect(self._cellChanged)
-
- def _cellChanged(self):
- self.sigCellChanged.emit(self._row, self._col)
-
-
-class Parameters(TableWidget):
- """:class:`TableWidget` customized to display fit results
- and to interact with :class:`FitManager` objects.
-
- Data and references to cell widgets are kept in a dictionary
- attribute :attr:`parameters`.
-
- :param parent: Parent widget
- :param labels: Column headers. If ``None``, default headers will be used.
- :type labels: List of strings or None
- :param paramlist: List of fit parameters to be displayed for each fitted
- peak.
- :type paramlist: list[str] or None
- """
- def __init__(self, parent=None, paramlist=None):
- TableWidget.__init__(self, parent)
- self.setContentsMargins(0, 0, 0, 0)
-
- labels = ['Parameter', 'Estimation', 'Fit Value', 'Sigma',
- 'Constraints', 'Min/Parame', 'Max/Factor/Delta']
- tooltips = ["Fit parameter name",
- "Estimated value for fit parameter. You can edit this column.",
- "Actual value for parameter, after fit",
- "Uncertainty (same unit as the parameter)",
- "Constraint to be applied to the parameter for fit",
- "First parameter for constraint (name of another param or min value)",
- "Second parameter for constraint (max value, or factor/delta)"]
-
- self.columnKeys = ['name', 'estimation', 'fitresult',
- 'sigma', 'code', 'val1', 'val2']
- """This list assigns shorter keys to refer to columns than the
- displayed labels."""
-
- self.__configuring = False
-
- # column headers and associated tooltips
- self.setColumnCount(len(labels))
-
- for i, label in enumerate(labels):
- item = self.horizontalHeaderItem(i)
- if item is None:
- item = qt.QTableWidgetItem(label,
- qt.QTableWidgetItem.Type)
- self.setHorizontalHeaderItem(i, item)
-
- item.setText(label)
- if tooltips is not None:
- item.setToolTip(tooltips[i])
-
- # resize columns
- for col_key in ["name", "estimation", "sigma", "val1", "val2"]:
- col_idx = self.columnIndexByField(col_key)
- self.resizeColumnToContents(col_idx)
-
- # Initialize the table with one line per supplied parameter
- paramlist = paramlist if paramlist is not None else []
- self.parameters = OrderedDict()
- """This attribute stores all the data in an ordered dictionary.
- New data can be added using :meth:`newParameterLine`.
- Existing data can be modified using :meth:`configureLine`
-
- Keys of the dictionary are:
-
- - 'name': parameter name
- - 'line': line index for the parameter in the table
- - 'estimation'
- - 'fitresult'
- - 'sigma'
- - 'code': constraint code (one of the elements of
- :attr:`code_options`)
- - 'val1': first parameter related to constraint, formatted
- as a string, as typed in the table
- - 'val2': second parameter related to constraint, formatted
- as a string, as typed in the table
- - 'cons1': scalar representation of 'val1'
- (e.g. when val1 is the name of a fit parameter, cons1
- will be the line index of this parameter)
- - 'cons2': scalar representation of 'val2'
- - 'vmin': equal to 'val1' when 'code' is "QUOTED"
- - 'vmax': equal to 'val2' when 'code' is "QUOTED"
- - 'relatedto': name of related parameter when this parameter
- is constrained to another parameter (same as 'val1')
- - 'factor': same as 'val2' when 'code' is 'FACTOR'
- - 'delta': same as 'val2' when 'code' is 'DELTA'
- - 'sum': same as 'val2' when 'code' is 'SUM'
- - 'group': group index for the parameter
- - 'xmin': data range minimum
- - 'xmax': data range maximum
- """
- for line, param in enumerate(paramlist):
- self.newParameterLine(param, line)
-
- self.code_options = ["FREE", "POSITIVE", "QUOTED", "FIXED",
- "FACTOR", "DELTA", "SUM", "IGNORE", "ADD"]
- """Possible values in the combo boxes in the 'Constraints' column.
- """
-
- # connect signal
- self.cellChanged[int, int].connect(self.onCellChanged)
-
- def newParameterLine(self, param, line):
- """Add a line to the :class:`QTableWidget`.
-
- Each line represents one of the fit parameters for one of
- the fitted peaks.
-
- :param param: Name of the fit parameter
- :type param: str
- :param line: 0-based line index
- :type line: int
- """
- # get current number of lines
- nlines = self.rowCount()
- self.__configuring = True
- if line >= nlines:
- self.setRowCount(line + 1)
-
- # default configuration for fit parameters
- self.parameters[param] = OrderedDict((('line', line),
- ('estimation', '0'),
- ('fitresult', ''),
- ('sigma', ''),
- ('code', 'FREE'),
- ('val1', ''),
- ('val2', ''),
- ('cons1', 0),
- ('cons2', 0),
- ('vmin', '0'),
- ('vmax', '1'),
- ('relatedto', ''),
- ('factor', '1.0'),
- ('delta', '0.0'),
- ('sum', '0.0'),
- ('group', ''),
- ('name', param),
- ('xmin', None),
- ('xmax', None)))
- self.setReadWrite(param, 'estimation')
- self.setReadOnly(param, ['name', 'fitresult', 'sigma', 'val1', 'val2'])
-
- # Constraint codes
- a = []
- for option in self.code_options:
- a.append(option)
-
- code_column_index = self.columnIndexByField('code')
- cellWidget = self.cellWidget(line, code_column_index)
- if cellWidget is None:
- cellWidget = QComboTableItem(self, row=line,
- col=code_column_index)
- cellWidget.addItems(a)
- self.setCellWidget(line, code_column_index, cellWidget)
- cellWidget.sigCellChanged[int, int].connect(self.onCellChanged)
- self.parameters[param]['code_item'] = cellWidget
- self.parameters[param]['relatedto_item'] = None
- self.__configuring = False
-
- def columnIndexByField(self, field):
- """
-
- :param field: Field name (column key)
- :return: Index of the column with this field name
- """
- return self.columnKeys.index(field)
-
- def fillFromFit(self, fitresults):
- """Fill table with values from a list of dictionaries
- (see :attr:`silx.math.fit.fitmanager.FitManager.fit_results`)
-
- :param fitresults: List of parameters as recorded
- in the ``paramlist`` attribute of a :class:`FitManager` object
- :type fitresults: list[dict]
- """
- self.setRowCount(len(fitresults))
-
- # Reinitialize and fill self.parameters
- self.parameters = OrderedDict()
- for (line, param) in enumerate(fitresults):
- self.newParameterLine(param['name'], line)
-
- for param in fitresults:
- name = param['name']
- code = str(param['code'])
- if code not in self.code_options:
- # convert code from int to descriptive string
- code = self.code_options[int(code)]
- val1 = param['cons1']
- val2 = param['cons2']
- estimation = param['estimation']
- group = param['group']
- sigma = param['sigma']
- fitresult = param['fitresult']
-
- xmin = param.get('xmin')
- xmax = param.get('xmax')
-
- self.configureLine(name=name,
- code=code,
- val1=val1, val2=val2,
- estimation=estimation,
- fitresult=fitresult,
- sigma=sigma,
- group=group,
- xmin=xmin, xmax=xmax)
-
- def getConfiguration(self):
- """Return ``FitManager.paramlist`` dictionary
- encapsulated in another dictionary"""
- return {'parameters': self.getFitResults()}
-
- def setConfiguration(self, ddict):
- """Fill table with values from a ``FitManager.paramlist`` dictionary
- encapsulated in another dictionary"""
- self.fillFromFit(ddict['parameters'])
-
- def getFitResults(self):
- """Return fit parameters as a list of dictionaries in the format used
- by :class:`FitManager` (attribute ``paramlist``).
- """
- fitparameterslist = []
- for param in self.parameters:
- fitparam = {}
- name = param
- estimation, [code, cons1, cons2] = self.getEstimationConstraints(name)
- buf = str(self.parameters[param]['fitresult'])
- xmin = self.parameters[param]['xmin']
- xmax = self.parameters[param]['xmax']
- if len(buf):
- fitresult = float(buf)
- else:
- fitresult = 0.0
- buf = str(self.parameters[param]['sigma'])
- if len(buf):
- sigma = float(buf)
- else:
- sigma = 0.0
- buf = str(self.parameters[param]['group'])
- if len(buf):
- group = float(buf)
- else:
- group = 0
- fitparam['name'] = name
- fitparam['estimation'] = estimation
- fitparam['fitresult'] = fitresult
- fitparam['sigma'] = sigma
- fitparam['group'] = group
- fitparam['code'] = code
- fitparam['cons1'] = cons1
- fitparam['cons2'] = cons2
- fitparam['xmin'] = xmin
- fitparam['xmax'] = xmax
- fitparameterslist.append(fitparam)
- return fitparameterslist
-
- def onCellChanged(self, row, col):
- """Slot called when ``cellChanged`` signal is emitted.
- Checks the validity of the new text in the cell, then calls
- :meth:`configureLine` to update the internal ``self.parameters``
- dictionary.
-
- :param row: Row number of the changed cell (0-based index)
- :param col: Column number of the changed cell (0-based index)
- """
- if (col != self.columnIndexByField("code")) and (col != -1):
- if row != self.currentRow():
- return
- if col != self.currentColumn():
- return
- if self.__configuring:
- return
- param = list(self.parameters)[row]
- field = self.columnKeys[col]
- oldvalue = self.parameters[param][field]
- if col != 4:
- item = self.item(row, col)
- if item is not None:
- newvalue = item.text()
- else:
- newvalue = ''
- else:
- # this is the combobox
- widget = self.cellWidget(row, col)
- newvalue = widget.currentText()
- if self.validate(param, field, oldvalue, newvalue):
- paramdict = {"name": param, field: newvalue}
- self.configureLine(**paramdict)
- else:
- if field == 'code':
- # New code not valid, try restoring the old one
- index = self.code_options.index(oldvalue)
- self.__configuring = True
- try:
- self.parameters[param]['code_item'].setCurrentIndex(index)
- finally:
- self.__configuring = False
- else:
- paramdict = {"name": param, field: oldvalue}
- self.configureLine(**paramdict)
-
- def validate(self, param, field, oldvalue, newvalue):
- """Check validity of ``newvalue`` when a cell's value is modified.
-
- :param param: Fit parameter name
- :param field: Column name
- :param oldvalue: Cell value before change attempt
- :param newvalue: New value to be validated
- :return: True if new cell value is valid, else False
- """
- if field == 'code':
- return self.setCodeValue(param, oldvalue, newvalue)
- # FIXME: validate() shouldn't have side effects. Move this bit to configureLine()?
- if field == 'val1' and str(self.parameters[param]['code']) in ['DELTA', 'FACTOR', 'SUM']:
- _, candidates = self.getRelatedCandidates(param)
- # We expect val1 to be a fit parameter name
- if str(newvalue) in candidates:
- return True
- else:
- return False
- # except for code, val1 and name (which is read-only and does not need
- # validation), all fields must always be convertible to float
- else:
- try:
- float(str(newvalue))
- except ValueError:
- return False
- return True
-
- def setCodeValue(self, param, oldvalue, newvalue):
- """Update 'code' and 'relatedto' fields when code cell is
- changed.
-
- :param param: Fit parameter name
- :param oldvalue: Cell value before change attempt
- :param newvalue: New value to be validated
- :return: ``True`` if code was successfully updated
- """
-
- if str(newvalue) in ['FREE', 'POSITIVE', 'QUOTED', 'FIXED']:
- self.configureLine(name=param,
- code=newvalue)
- if str(oldvalue) == 'IGNORE':
- self.freeRestOfGroup(param)
- return True
- elif str(newvalue) in ['FACTOR', 'DELTA', 'SUM']:
- # I should check here that some parameter is set
- best, candidates = self.getRelatedCandidates(param)
- if len(candidates) == 0:
- return False
- self.configureLine(name=param,
- code=newvalue,
- relatedto=best)
- if str(oldvalue) == 'IGNORE':
- self.freeRestOfGroup(param)
- return True
-
- elif str(newvalue) == 'IGNORE':
- # I should check if the group can be ignored
- # for the time being I just fix all of them to ignore
- group = int(float(str(self.parameters[param]['group'])))
- candidates = []
- for param in self.parameters.keys():
- if group == int(float(str(self.parameters[param]['group']))):
- candidates.append(param)
- # print candidates
- # I should check here if there is any relation to them
- for param in candidates:
- self.configureLine(name=param,
- code=newvalue)
- return True
- elif str(newvalue) == 'ADD':
- group = int(float(str(self.parameters[param]['group'])))
- if group == 0:
- # One cannot add a background group
- return False
- i = 0
- for param in self.parameters:
- if i <= int(float(str(self.parameters[param]['group']))):
- i += 1
- if (group == 0) and (i == 1): # FIXME: why +1?
- i += 1
- self.addGroup(i, group)
- return False
- elif str(newvalue) == 'SHOW':
- print(self.getEstimationConstraints(param))
- return False
-
- def addGroup(self, newg, gtype):
- """Add a fit parameter group with the same fit parameters as an
- existing group.
-
- This function is called when the user selects "ADD" in the
- "constraints" combobox.
-
- :param int newg: New group number
- :param int gtype: Group number whose parameters we want to copy
-
- """
- newparam = []
- # loop through parameters until we encounter group number `gtype`
- for param in list(self.parameters):
- paramgroup = int(float(str(self.parameters[param]['group'])))
- # copy parameter names in group number `gtype`
- if paramgroup == gtype:
- # but replace `gtype` with `newg`
- newparam.append(param.rstrip("0123456789") + "%d" % newg)
-
- xmin = self.parameters[param]['xmin']
- xmax = self.parameters[param]['xmax']
-
- # Add new parameters (one table line per parameter) and configureLine each
- # one by updating xmin and xmax to the same values as group `gtype`
- line = len(list(self.parameters))
- for param in newparam:
- self.newParameterLine(param, line)
- line += 1
- for param in newparam:
- self.configureLine(name=param, group=newg, xmin=xmin, xmax=xmax)
-
- def freeRestOfGroup(self, workparam):
- """Set ``code`` to ``"FREE"`` for all fit parameters belonging to
- the same group as ``workparam``. This is done when the entire group
- of parameters was previously ignored and one of them has his code
- set to something different than ``"IGNORE"``.
-
- :param workparam: Fit parameter name
- """
- if workparam in self.parameters.keys():
- group = int(float(str(self.parameters[workparam]['group'])))
- for param in self.parameters:
- if param != workparam and\
- group == int(float(str(self.parameters[param]['group']))):
- self.configureLine(name=param,
- code='FREE',
- cons1=0,
- cons2=0,
- val1='',
- val2='')
-
- def getRelatedCandidates(self, workparam):
- """If fit parameter ``workparam`` has a constraint that involves other
- fit parameters, find possible candidates and try to guess which one
- is the most likely.
-
- :param workparam: Fit parameter name
- :return: (best_candidate, possible_candidates) tuple
- :rtype: (str, list[str])
- """
- candidates = []
- for param_name in self.parameters:
- if param_name != workparam:
- # ignore parameters that are fixed by a constraint
- if str(self.parameters[param_name]['code']) not in\
- ['IGNORE', 'FACTOR', 'DELTA', 'SUM']:
- candidates.append(param_name)
- # take the previous one (before code cell changed) if possible
- if str(self.parameters[workparam]['relatedto']) in candidates:
- best = str(self.parameters[workparam]['relatedto'])
- return best, candidates
- # take the first with same base name (after removing numbers)
- for param_name in candidates:
- basename = param_name.rstrip("0123456789")
- try:
- pos = workparam.index(basename)
- if pos == 0:
- best = param_name
- return best, candidates
- except ValueError:
- pass
- # take the first
- return candidates[0], candidates
-
- def setReadOnly(self, parameter, fields):
- """Make table cells read-only by setting it's flags and omitting
- flag ``qt.Qt.ItemIsEditable``
-
- :param parameter: Fit parameter names identifying the rows
- :type parameter: str or list[str]
- :param fields: Field names identifying the columns
- :type fields: str or list[str]
- """
- editflags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled
- self.setField(parameter, fields, editflags)
-
- def setReadWrite(self, parameter, fields):
- """Make table cells read-write by setting it's flags including
- flag ``qt.Qt.ItemIsEditable``
-
- :param parameter: Fit parameter names identifying the rows
- :type parameter: str or list[str]
- :param fields: Field names identifying the columns
- :type fields: str or list[str]
- """
- editflags = qt.Qt.ItemIsSelectable |\
- qt.Qt.ItemIsEnabled |\
- qt.Qt.ItemIsEditable
- self.setField(parameter, fields, editflags)
-
- def setField(self, parameter, fields, edit_flags):
- """Set text and flags in a table cell.
-
- :param parameter: Fit parameter names identifying the rows
- :type parameter: str or list[str]
- :param fields: Field names identifying the columns
- :type fields: str or list[str]
- :param edit_flags: Flag combination, e.g::
-
- qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled |
- qt.Qt.ItemIsEditable
- """
- if isinstance(parameter, list) or \
- isinstance(parameter, tuple):
- paramlist = parameter
- else:
- paramlist = [parameter]
- if isinstance(fields, list) or \
- isinstance(fields, tuple):
- fieldlist = fields
- else:
- fieldlist = [fields]
-
- # Set _configuring flag to ignore cellChanged signals in
- # self.onCellChanged
- _oldvalue = self.__configuring
- self.__configuring = True
-
- # 2D loop through parameter list and field list
- # to update their cells
- for param in paramlist:
- row = list(self.parameters.keys()).index(param)
- for field in fieldlist:
- col = self.columnIndexByField(field)
- if field != 'code':
- key = field + "_item"
- item = self.item(row, col)
- if item is None:
- item = qt.QTableWidgetItem()
- item.setText(self.parameters[param][field])
- self.setItem(row, col, item)
- else:
- item.setText(self.parameters[param][field])
- self.parameters[param][key] = item
- item.setFlags(edit_flags)
-
- # Restore previous _configuring flag
- self.__configuring = _oldvalue
-
- def configureLine(self, name, code=None, val1=None, val2=None,
- sigma=None, estimation=None, fitresult=None,
- group=None, xmin=None, xmax=None, relatedto=None,
- cons1=None, cons2=None):
- """This function updates values in a line of the table
-
- :param name: Name of the parameter (serves as unique identifier for
- a line).
- :param code: Constraint code *FREE, FIXED, POSITIVE, DELTA, FACTOR,
- SUM, QUOTED, IGNORE*
- :param val1: Constraint 1 (can be the index or name of another
- parameter for code *DELTA, FACTOR, SUM*, or a min value
- for code *QUOTED*)
- :param val2: Constraint 2
- :param sigma: Standard deviation for a fit parameter
- :param estimation: Estimated initial value for a fit parameter (used
- as input to iterative fit)
- :param fitresult: Final result of fit
- :param group: Group number of a fit parameter (peak number when doing
- multi-peak fitting, as each peak corresponds to a group
- of several consecutive parameters)
- :param xmin:
- :param xmax:
- :param relatedto: Index or name of another fit parameter
- to which this parameter is related to (constraints)
- :param cons1: similar meaning to ``val1``, but is always a number
- :param cons2: similar meaning to ``val2``, but is always a number
- :return:
- """
- paramlist = list(self.parameters.keys())
-
- if name not in self.parameters:
- raise KeyError("'%s' is not in the parameter list" % name)
-
- # update code first, if specified
- if code is not None:
- code = str(code)
- self.parameters[name]['code'] = code
- # update combobox
- index = self.parameters[name]['code_item'].findText(code)
- self.parameters[name]['code_item'].setCurrentIndex(index)
- else:
- # set code to previous value, used later for setting val1 val2
- code = self.parameters[name]['code']
-
- # val1 and sigma have special formats
- if val1 is not None:
- fmt = None if self.parameters[name]['code'] in\
- ['DELTA', 'FACTOR', 'SUM'] else "%8g"
- self._updateField(name, "val1", val1, fmat=fmt)
-
- if sigma is not None:
- self._updateField(name, "sigma", sigma, fmat="%6.3g")
-
- # other fields are formatted as "%8g"
- keys_params = (("val2", val2), ("estimation", estimation),
- ("fitresult", fitresult))
- for key, value in keys_params:
- if value is not None:
- self._updateField(name, key, value, fmat="%8g")
-
- # the rest of the parameters are treated as strings and don't need
- # validation
- keys_params = (("group", group), ("xmin", xmin),
- ("xmax", xmax), ("relatedto", relatedto),
- ("cons1", cons1), ("cons2", cons2))
- for key, value in keys_params:
- if value is not None:
- self.parameters[name][key] = str(value)
-
- # val1 and val2 have different meanings depending on the code
- if code == 'QUOTED':
- if val1 is not None:
- self.parameters[name]['vmin'] = self.parameters[name]['val1']
- else:
- self.parameters[name]['val1'] = self.parameters[name]['vmin']
- if val2 is not None:
- self.parameters[name]['vmax'] = self.parameters[name]['val2']
- else:
- self.parameters[name]['val2'] = self.parameters[name]['vmax']
-
- # cons1 and cons2 are scalar representations of val1 and val2
- self.parameters[name]['cons1'] =\
- float_else_zero(self.parameters[name]['val1'])
- self.parameters[name]['cons2'] =\
- float_else_zero(self.parameters[name]['val2'])
-
- # cons1, cons2 = min(val1, val2), max(val1, val2)
- if self.parameters[name]['cons1'] > self.parameters[name]['cons2']:
- self.parameters[name]['cons1'], self.parameters[name]['cons2'] =\
- self.parameters[name]['cons2'], self.parameters[name]['cons1']
-
- elif code in ['DELTA', 'SUM', 'FACTOR']:
- # For these codes, val1 is the fit parameter name on which the
- # constraint depends
- if val1 is not None and val1 in paramlist:
- self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
-
- elif val1 is not None:
- # val1 could be the index of the fit parameter
- try:
- self.parameters[name]['relatedto'] = paramlist[int(val1)]
- except ValueError:
- self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
-
- elif relatedto is not None:
- # code changed, val1 not specified but relatedto specified:
- # set val1 to relatedto (pre-fill best guess)
- self.parameters[name]["val1"] = relatedto
-
- # update fields "delta", "sum" or "factor"
- key = code.lower()
- self.parameters[name][key] = self.parameters[name]["val2"]
-
- # FIXME: val1 is sometimes specified as an index rather than a param name
- self.parameters[name]['val1'] = self.parameters[name]['relatedto']
-
- # cons1 is the index of the fit parameter in the ordered dictionary
- if self.parameters[name]['val1'] in paramlist:
- self.parameters[name]['cons1'] =\
- paramlist.index(self.parameters[name]['val1'])
-
- # cons2 is the constraint value (factor, delta or sum)
- try:
- self.parameters[name]['cons2'] =\
- float(str(self.parameters[name]['val2']))
- except ValueError:
- self.parameters[name]['cons2'] = 1.0 if code == "FACTOR" else 0.0
-
- elif code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
- self.parameters[name]['val1'] = ""
- self.parameters[name]['val2'] = ""
- self.parameters[name]['cons1'] = 0
- self.parameters[name]['cons2'] = 0
-
- self._updateCellRWFlags(name, code)
-
- def _updateField(self, name, field, value, fmat=None):
- """Update field in ``self.parameters`` dictionary, if the new value
- is valid.
-
- :param name: Fit parameter name
- :param field: Field name
- :param value: New value to assign
- :type value: String
- :param fmat: Format string (e.g. "%8g") to be applied if value represents
- a scalar. If ``None``, format is not modified. If ``value`` is an
- empty string, ``fmat`` is ignored.
- """
- if value is not None:
- oldvalue = self.parameters[name][field]
- if fmat is not None:
- newvalue = fmat % float(value) if value != "" else ""
- else:
- newvalue = value
- self.parameters[name][field] = newvalue if\
- self.validate(name, field, oldvalue, newvalue) else\
- oldvalue
-
- def _updateCellRWFlags(self, name, code=None):
- """Set read-only or read-write flags in a row,
- depending on the constraint code
-
- :param name: Fit parameter name identifying the row
- :param code: Constraint code, in `'FREE', 'POSITIVE', 'IGNORE',`
- `'FIXED', 'FACTOR', 'DELTA', 'SUM', 'ADD'`
- :return:
- """
- if code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
- self.setReadWrite(name, 'estimation')
- self.setReadOnly(name, ['fitresult', 'sigma', 'val1', 'val2'])
- else:
- self.setReadWrite(name, ['estimation', 'val1', 'val2'])
- self.setReadOnly(name, ['fitresult', 'sigma'])
-
- def getEstimationConstraints(self, param):
- """
- Return tuple ``(estimation, constraints)`` where ``estimation`` is the
- value in the ``estimate`` field and ``constraints`` are the relevant
- constraints according to the active code
- """
- estimation = None
- constraints = None
- if param in self.parameters.keys():
- buf = str(self.parameters[param]['estimation'])
- if len(buf):
- estimation = float(buf)
- else:
- estimation = 0
- if str(self.parameters[param]['code']) in self.code_options:
- code = self.code_options.index(
- str(self.parameters[param]['code']))
- else:
- code = str(self.parameters[param]['code'])
- cons1 = self.parameters[param]['cons1']
- cons2 = self.parameters[param]['cons2']
- constraints = [code, cons1, cons2]
- return estimation, constraints
-
-
-def main(args):
- from silx.math.fit import fittheories
- from silx.math.fit import fitmanager
- try:
- from PyMca5 import PyMcaDataDir
- except ImportError:
- raise ImportError("This demo requires PyMca data. Install PyMca5.")
- import numpy
- import os
- app = qt.QApplication(args)
- tab = Parameters(paramlist=['Height', 'Position', 'FWHM'])
- tab.showGrid()
- tab.configureLine(name='Height', estimation='1234', group=0)
- tab.configureLine(name='Position', code='FIXED', group=1)
- tab.configureLine(name='FWHM', group=1)
-
- y = numpy.loadtxt(os.path.join(PyMcaDataDir.PYMCA_DATA_DIR,
- "XRFSpectrum.mca")) # FIXME
-
- x = numpy.arange(len(y)) * 0.0502883 - 0.492773
- fit = fitmanager.FitManager()
- fit.setdata(x=x, y=y, xmin=20, xmax=150)
-
- fit.loadtheories(fittheories)
-
- fit.settheory('ahypermet')
- fit.configure(Yscaling=1.,
- PositiveFwhmFlag=True,
- PositiveHeightAreaFlag=True,
- FwhmPoints=16,
- QuotedPositionFlag=1,
- HypermetTails=1)
- fit.setbackground('Linear')
- fit.estimate()
- fit.runfit()
- tab.fillFromFit(fit.fit_results)
- tab.show()
- app.exec_()
-
-if __name__ == "__main__":
- main(sys.argv)
diff --git a/silx/gui/fit/test/__init__.py b/silx/gui/fit/test/__init__.py
deleted file mode 100644
index 2236d64..0000000
--- a/silx/gui/fit/test/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-import unittest
-
-from .testFitWidget import suite as testFitWidgetSuite
-from .testFitConfig import suite as testFitConfigSuite
-from .testBackgroundWidget import suite as testBackgroundWidgetSuite
-
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "21/07/2016"
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTests(
- [testFitWidgetSuite(),
- testFitConfigSuite(),
- testBackgroundWidgetSuite()])
- return test_suite
diff --git a/silx/gui/fit/test/testBackgroundWidget.py b/silx/gui/fit/test/testBackgroundWidget.py
deleted file mode 100644
index 03b17b9..0000000
--- a/silx/gui/fit/test/testBackgroundWidget.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-import unittest
-
-from silx.gui.utils.testutils import TestCaseQt
-
-from .. import BackgroundWidget
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "05/12/2016"
-
-
-class TestBackgroundWidget(TestCaseQt):
- def setUp(self):
- super(TestBackgroundWidget, self).setUp()
- self.bgdialog = BackgroundWidget.BackgroundDialog()
- self.bgdialog.setData(list([0, 1, 2, 3]),
- list([0, 1, 4, 8]))
- self.qWaitForWindowExposed(self.bgdialog)
-
- def tearDown(self):
- del self.bgdialog
- super(TestBackgroundWidget, self).tearDown()
-
- def testShow(self):
- self.bgdialog.show()
- self.bgdialog.hide()
-
- def testAccept(self):
- self.bgdialog.accept()
- self.assertTrue(self.bgdialog.result())
-
- def testReject(self):
- self.bgdialog.reject()
- self.assertFalse(self.bgdialog.result())
-
- def testDefaultOutput(self):
- self.bgdialog.accept()
- output = self.bgdialog.output
-
- for key in ["algorithm", "StripThreshold", "SnipWidth",
- "StripIterations", "StripWidth", "SmoothingFlag",
- "SmoothingWidth", "AnchorsFlag", "AnchorsList"]:
- self.assertIn(key, output)
-
- self.assertFalse(output["AnchorsFlag"])
- self.assertEqual(output["StripWidth"], 1)
- self.assertEqual(output["SmoothingFlag"], False)
- self.assertEqual(output["SmoothingWidth"], 3)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestBackgroundWidget))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/fit/test/testFitConfig.py b/silx/gui/fit/test/testFitConfig.py
deleted file mode 100644
index f89c099..0000000
--- a/silx/gui/fit/test/testFitConfig.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for :class:`FitConfig`"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "05/12/2016"
-
-import unittest
-
-from silx.gui.utils.testutils import TestCaseQt
-from .. import FitConfig
-
-
-class TestFitConfig(TestCaseQt):
- """Basic test for FitWidget"""
-
- def setUp(self):
- super(TestFitConfig, self).setUp()
- self.fit_config = FitConfig.getFitConfigDialog(modal=False)
- self.qWaitForWindowExposed(self.fit_config)
-
- def tearDown(self):
- del self.fit_config
- super(TestFitConfig, self).tearDown()
-
- def testShow(self):
- self.fit_config.show()
- self.fit_config.hide()
-
- def testAccept(self):
- self.fit_config.accept()
- self.assertTrue(self.fit_config.result())
-
- def testReject(self):
- self.fit_config.reject()
- self.assertFalse(self.fit_config.result())
-
- def testDefaultOutput(self):
- self.fit_config.accept()
- output = self.fit_config.output
-
- for key in ["AutoFwhm",
- "PositiveHeightAreaFlag",
- "QuotedPositionFlag",
- "PositiveFwhmFlag",
- "SameFwhmFlag",
- "QuotedEtaFlag",
- "NoConstraintsFlag",
- "FwhmPoints",
- "Sensitivity",
- "Yscaling",
- "ForcePeakPresence",
- "StripBackgroundFlag",
- "StripWidth",
- "StripIterations",
- "StripThreshold",
- "SmoothingFlag"]:
- self.assertIn(key, output)
-
- self.assertTrue(output["AutoFwhm"])
- self.assertEqual(output["StripWidth"], 2)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestFitConfig))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/fit/test/testFitWidget.py b/silx/gui/fit/test/testFitWidget.py
deleted file mode 100644
index cfd2bc9..0000000
--- a/silx/gui/fit/test/testFitWidget.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for :class:`FitWidget`"""
-
-import unittest
-
-from silx.gui.utils.testutils import TestCaseQt
-
-from ... import qt
-from .. import FitWidget
-
-from ....math.fit.fittheory import FitTheory
-from ....math.fit.fitmanager import FitManager
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "05/12/2016"
-
-
-class TestFitWidget(TestCaseQt):
- """Basic test for FitWidget"""
-
- def setUp(self):
- super(TestFitWidget, self).setUp()
- self.fit_widget = FitWidget()
- self.fit_widget.show()
- self.qWaitForWindowExposed(self.fit_widget)
-
- def tearDown(self):
- self.fit_widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.fit_widget.close()
- del self.fit_widget
- super(TestFitWidget, self).tearDown()
-
- def testShow(self):
- pass
-
- def testInteract(self):
- self.mouseClick(self.fit_widget, qt.Qt.LeftButton)
- self.keyClick(self.fit_widget, qt.Qt.Key_Enter)
- self.qapp.processEvents()
-
- def testCustomConfigWidget(self):
- class CustomConfigWidget(qt.QDialog):
- def __init__(self):
- qt.QDialog.__init__(self)
- self.setModal(True)
- self.ok = qt.QPushButton("ok", self)
- self.ok.clicked.connect(self.accept)
- cancel = qt.QPushButton("cancel", self)
- cancel.clicked.connect(self.reject)
- layout = qt.QVBoxLayout(self)
- layout.addWidget(self.ok)
- layout.addWidget(cancel)
- self.output = {"hello": "world"}
-
- def fitfun(x, a, b):
- return a * x + b
-
- x = list(range(0, 100))
- y = [fitfun(x_, 2, 3) for x_ in x]
-
- def conf(**kw):
- return {"spam": "eggs",
- "hello": "world!"}
-
- theory = FitTheory(
- function=fitfun,
- parameters=["a", "b"],
- configure=conf)
-
- fitmngr = FitManager()
- fitmngr.setdata(x, y)
- fitmngr.addtheory("foo", theory)
- fitmngr.addtheory("bar", theory)
- fitmngr.addbgtheory("spam", theory)
-
- fw = FitWidget(fitmngr=fitmngr)
- fw.associateConfigDialog("spam", CustomConfigWidget(),
- theory_is_background=True)
- fw.associateConfigDialog("foo", CustomConfigWidget())
- fw.show()
- self.qWaitForWindowExposed(fw)
-
- fw.bgconfigdialogs["spam"].accept()
- self.assertTrue(fw.bgconfigdialogs["spam"].result())
-
- self.assertEqual(fw.bgconfigdialogs["spam"].output,
- {"hello": "world"})
-
- fw.bgconfigdialogs["spam"].reject()
- self.assertFalse(fw.bgconfigdialogs["spam"].result())
-
- fw.configdialogs["foo"].accept()
- self.assertTrue(fw.configdialogs["foo"].result())
-
- # todo: figure out how to click fw.configdialog.ok to close dialog
- # open dialog
- # self.mouseClick(fw.guiConfig.FunConfigureButton, qt.Qt.LeftButton)
- # clove dialog
- # self.mouseClick(fw.configdialogs["foo"].ok, qt.Qt.LeftButton)
- # self.qapp.processEvents()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestFitWidget))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/hdf5/Hdf5Formatter.py b/silx/gui/hdf5/Hdf5Formatter.py
deleted file mode 100644
index 5754fe8..0000000
--- a/silx/gui/hdf5/Hdf5Formatter.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This package provides a class sharred by widgets to format HDF5 data as
-text."""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "06/06/2018"
-
-import numpy
-import six
-
-from silx.gui import qt
-from silx.gui.data.TextFormatter import TextFormatter
-
-import h5py
-
-
-class Hdf5Formatter(qt.QObject):
- """Formatter to convert HDF5 data to string.
- """
-
- formatChanged = qt.Signal()
- """Emitted when properties of the formatter change."""
-
- def __init__(self, parent=None, textFormatter=None):
- """
- Constructor
-
- :param qt.QObject parent: Owner of the object
- :param TextFormatter formatter: Text formatter
- """
- qt.QObject.__init__(self, parent)
- if textFormatter is not None:
- self.__formatter = textFormatter
- else:
- self.__formatter = TextFormatter(self)
- self.__formatter.formatChanged.connect(self.__formatChanged)
-
- def textFormatter(self):
- """Returns the used text formatter
-
- :rtype: TextFormatter
- """
- return self.__formatter
-
- def setTextFormatter(self, textFormatter):
- """Set the text formatter to be used
-
- :param TextFormatter textFormatter: The text formatter to use
- """
- if textFormatter is None:
- raise ValueError("Formatter expected but None found")
- if self.__formatter is textFormatter:
- return
- self.__formatter.formatChanged.disconnect(self.__formatChanged)
- self.__formatter = textFormatter
- self.__formatter.formatChanged.connect(self.__formatChanged)
- self.__formatChanged()
-
- def __formatChanged(self):
- self.formatChanged.emit()
-
- def humanReadableShape(self, dataset):
- if dataset.shape is None:
- return "none"
- if dataset.shape == tuple():
- return "scalar"
- shape = [str(i) for i in dataset.shape]
- text = u" \u00D7 ".join(shape)
- return text
-
- def humanReadableValue(self, dataset):
- if dataset.shape is None:
- return "No data"
-
- dtype = dataset.dtype
- if dataset.dtype.type == numpy.void:
- if dtype.fields is None:
- return "Raw data"
-
- if dataset.shape == tuple():
- numpy_object = dataset[()]
- text = self.__formatter.toString(numpy_object, dtype=dataset.dtype)
- else:
- if dataset.size < 5 and dataset.compression is None:
- numpy_object = dataset[0:5]
- text = self.__formatter.toString(numpy_object, dtype=dataset.dtype)
- else:
- dimension = len(dataset.shape)
- if dataset.compression is not None:
- text = "Compressed %dD data" % dimension
- else:
- text = "%dD data" % dimension
- return text
-
- def humanReadableType(self, dataset, full=False):
- if hasattr(dataset, "dtype"):
- dtype = dataset.dtype
- else:
- # Fallback...
- dtype = type(dataset)
- return self.humanReadableDType(dtype, full)
-
- def humanReadableDType(self, dtype, full=False):
- if dtype == six.binary_type or numpy.issubdtype(dtype, numpy.string_):
- text = "string"
- if full:
- text = "ASCII " + text
- return text
- elif dtype == six.text_type or numpy.issubdtype(dtype, numpy.unicode_):
- text = "string"
- if full:
- text = "UTF-8 " + text
- return text
- elif dtype.type == numpy.object_:
- ref = h5py.check_dtype(ref=dtype)
- if ref is not None:
- return "reference"
- vlen = h5py.check_dtype(vlen=dtype)
- if vlen is not None:
- text = self.humanReadableDType(vlen, full=full)
- if full:
- text = "variable-length " + text
- return text
- return "object"
- elif dtype.type == numpy.bool_:
- return "bool"
- elif dtype.type == numpy.void:
- if dtype.fields is None:
- return "opaque"
- else:
- if not full:
- return "compound"
- else:
- fields = sorted(dtype.fields.items(), key=lambda e: e[1][1])
- compound = [d[1][0] for d in fields]
- compound = [self.humanReadableDType(d) for d in compound]
- return "compound(%s)" % ", ".join(compound)
- elif numpy.issubdtype(dtype, numpy.integer):
- enumType = h5py.check_dtype(enum=dtype)
- if enumType is not None:
- return "enum"
-
- text = str(dtype.newbyteorder('N'))
- if numpy.issubdtype(dtype, numpy.floating):
- if hasattr(numpy, "float128") and dtype == numpy.float128:
- text = "float80"
- if full:
- text += " (padding 128bits)"
- elif hasattr(numpy, "float96") and dtype == numpy.float96:
- text = "float80"
- if full:
- text += " (padding 96bits)"
-
- if full:
- if dtype.byteorder == "<":
- text = "Little-endian " + text
- elif dtype.byteorder == ">":
- text = "Big-endian " + text
- elif dtype.byteorder == "=":
- text = "Native " + text
-
- dtype = dtype.newbyteorder('N')
- return text
-
- def humanReadableHdf5Type(self, dataset):
- """Format the internal HDF5 type as a string"""
- t = dataset.id.get_type()
- class_ = t.get_class()
- if class_ == h5py.h5t.NO_CLASS:
- return "NO_CLASS"
- elif class_ == h5py.h5t.INTEGER:
- return "INTEGER"
- elif class_ == h5py.h5t.FLOAT:
- return "FLOAT"
- elif class_ == h5py.h5t.TIME:
- return "TIME"
- elif class_ == h5py.h5t.STRING:
- charset = t.get_cset()
- strpad = t.get_strpad()
- text = ""
-
- if strpad == h5py.h5t.STR_NULLTERM:
- text += "NULLTERM"
- elif strpad == h5py.h5t.STR_NULLPAD:
- text += "NULLPAD"
- elif strpad == h5py.h5t.STR_SPACEPAD:
- text += "SPACEPAD"
- else:
- text += "UNKNOWN_STRPAD"
-
- if t.is_variable_str():
- text += " VARIABLE"
-
- if charset == h5py.h5t.CSET_ASCII:
- text += " ASCII"
- elif charset == h5py.h5t.CSET_UTF8:
- text += " UTF8"
- else:
- text += " UNKNOWN_CSET"
-
- return text + " STRING"
- elif class_ == h5py.h5t.BITFIELD:
- return "BITFIELD"
- elif class_ == h5py.h5t.OPAQUE:
- return "OPAQUE"
- elif class_ == h5py.h5t.COMPOUND:
- return "COMPOUND"
- elif class_ == h5py.h5t.REFERENCE:
- return "REFERENCE"
- elif class_ == h5py.h5t.ENUM:
- return "ENUM"
- elif class_ == h5py.h5t.VLEN:
- return "VLEN"
- elif class_ == h5py.h5t.ARRAY:
- return "ARRAY"
- else:
- return "UNKNOWN_CLASS"
diff --git a/silx/gui/hdf5/Hdf5HeaderView.py b/silx/gui/hdf5/Hdf5HeaderView.py
deleted file mode 100644
index 7baa6e0..0000000
--- a/silx/gui/hdf5/Hdf5HeaderView.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "16/06/2017"
-
-
-from .. import qt
-from .Hdf5TreeModel import Hdf5TreeModel
-
-QTVERSION = qt.qVersion()
-
-
-class Hdf5HeaderView(qt.QHeaderView):
- """
- Default HDF5 header
-
- Manage auto-resize and context menu to display/hide columns
- """
-
- def __init__(self, orientation, parent=None):
- """
- Constructor
-
- :param orientation qt.Qt.Orientation: Orientation of the header
- :param parent qt.QWidget: Parent of the widget
- """
- super(Hdf5HeaderView, self).__init__(orientation, parent)
- self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
- self.customContextMenuRequested.connect(self.__createContextMenu)
-
- # default initialization done by QTreeView for it's own header
- if QTVERSION < "5.0":
- self.setClickable(True)
- self.setMovable(True)
- else:
- self.setSectionsClickable(True)
- self.setSectionsMovable(True)
- self.setDefaultAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
- self.setStretchLastSection(True)
-
- self.__auto_resize = True
- self.__hide_columns_popup = True
-
- def setModel(self, model):
- """Override model to configure view when a model is expected
-
- `qt.QHeaderView.setResizeMode` expect already existing columns
- to work.
-
- :param model qt.QAbstractItemModel: A model
- """
- super(Hdf5HeaderView, self).setModel(model)
- self.__updateAutoResize()
-
- def __updateAutoResize(self):
- """Update the view according to the state of the auto-resize"""
- if QTVERSION < "5.0":
- setResizeMode = self.setResizeMode
- else:
- setResizeMode = self.setSectionResizeMode
-
- if self.__auto_resize:
- setResizeMode(Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.ResizeToContents)
- setResizeMode(Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.ResizeToContents)
- setResizeMode(Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.ResizeToContents)
- setResizeMode(Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive)
- setResizeMode(Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive)
- setResizeMode(Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.ResizeToContents)
- setResizeMode(Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.ResizeToContents)
- else:
- setResizeMode(Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.Interactive)
- setResizeMode(Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.Interactive)
- setResizeMode(Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.Interactive)
- setResizeMode(Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive)
- setResizeMode(Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive)
- setResizeMode(Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.Interactive)
- setResizeMode(Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.Interactive)
-
- def setAutoResizeColumns(self, autoResize):
- """Enable/disable auto-resize. When auto-resized, the header take care
- of the content of the column to set fixed size of some of them, or to
- auto fix the size according to the content.
-
- :param autoResize bool: Enable/disable auto-resize
- """
- if self.__auto_resize == autoResize:
- return
- self.__auto_resize = autoResize
- self.__updateAutoResize()
-
- def hasAutoResizeColumns(self):
- """Is auto-resize enabled.
-
- :rtype: bool
- """
- return self.__auto_resize
-
- autoResizeColumns = qt.Property(bool, hasAutoResizeColumns, setAutoResizeColumns)
- """Property to enable/disable auto-resize."""
-
- def setEnableHideColumnsPopup(self, enablePopup):
- """Enable/disable a popup to allow to hide/show each column of the
- model.
-
- :param bool enablePopup: Enable/disable popup to hide/show columns
- """
- self.__hide_columns_popup = enablePopup
-
- def hasHideColumnsPopup(self):
- """Is popup to hide/show columns is enabled.
-
- :rtype: bool
- """
- return self.__hide_columns_popup
-
- enableHideColumnsPopup = qt.Property(bool, hasHideColumnsPopup, setAutoResizeColumns)
- """Property to enable/disable popup allowing to hide/show columns."""
-
- def __genHideSectionEvent(self, column):
- """Generate a callback which change the column visibility according to
- the event parameter
-
- :param int column: logical id of the column
- :rtype: callable
- """
- return lambda checked: self.setSectionHidden(column, not checked)
-
- def __createContextMenu(self, pos):
- """Callback to create and display a context menu
-
- :param pos qt.QPoint: Requested position for the context menu
- """
- if not self.__hide_columns_popup:
- return
-
- model = self.model()
- if model.columnCount() > 1:
- menu = qt.QMenu(self)
- menu.setTitle("Display/hide columns")
-
- action = qt.QAction("Display/hide column", self)
- action.setEnabled(False)
- menu.addAction(action)
-
- for column in range(model.columnCount()):
- if column == 0:
- # skip the main column
- continue
- text = model.headerData(column, qt.Qt.Horizontal, qt.Qt.DisplayRole)
- action = qt.QAction("%s displayed" % text, self)
- action.setCheckable(True)
- action.setChecked(not self.isSectionHidden(column))
- action.toggled.connect(self.__genHideSectionEvent(column))
- menu.addAction(action)
-
- menu.popup(self.viewport().mapToGlobal(pos))
-
- def setSections(self, logicalIndexes):
- """
- Defines order of visible sections by logical indexes.
-
- Use `Hdf5TreeModel.NAME_COLUMN` to set the list.
-
- :param list logicalIndexes: List of logical indexes to display
- """
- for pos, column_id in enumerate(logicalIndexes):
- current_pos = self.visualIndex(column_id)
- self.moveSection(current_pos, pos)
- self.setSectionHidden(column_id, False)
- for column_id in set(range(self.model().columnCount())) - set(logicalIndexes):
- self.setSectionHidden(column_id, True)
diff --git a/silx/gui/hdf5/Hdf5TreeModel.py b/silx/gui/hdf5/Hdf5TreeModel.py
deleted file mode 100644
index 152f3e5..0000000
--- a/silx/gui/hdf5/Hdf5TreeModel.py
+++ /dev/null
@@ -1,778 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "12/03/2019"
-
-
-import os
-import logging
-import functools
-from .. import qt
-from .. import icons
-from .Hdf5Node import Hdf5Node
-from .Hdf5Item import Hdf5Item
-from .Hdf5LoadingItem import Hdf5LoadingItem
-from . import _utils
-from ... import io as silx_io
-
-_logger = logging.getLogger(__name__)
-
-"""Helpers to take care of None objects as signal parameters.
-PySide crash if a signal with a None parameter is emitted between threads.
-"""
-if qt.BINDING == 'PySide':
- class _NoneWraper(object):
- pass
- _NoneWraperInstance = _NoneWraper()
-
- def _wrapNone(x):
- """Wrap x if it is a None value, else returns x"""
- if x is None:
- return _NoneWraperInstance
- else:
- return x
-
- def _unwrapNone(x):
- """Unwrap x as a None if a None was stored by `wrapNone`, else returns
- x"""
- if x is _NoneWraperInstance:
- return None
- else:
- return x
-else:
- # Allow to fix None event params to avoid PySide crashes
- def _wrapNone(x):
- return x
-
- def _unwrapNone(x):
- return x
-
-
-def _createRootLabel(h5obj):
- """
- Create label for the very first npde of the tree.
-
- :param h5obj: The h5py object to display in the GUI
- :type h5obj: h5py-like object
- :rtpye: str
- """
- if silx_io.is_file(h5obj):
- label = os.path.basename(h5obj.filename)
- else:
- filename = os.path.basename(h5obj.file.filename)
- path = h5obj.name
- if path.startswith("/"):
- path = path[1:]
- label = "%s::%s" % (filename, path)
- return label
-
-
-class LoadingItemRunnable(qt.QRunnable):
- """Runner to process item loading from a file"""
-
- class __Signals(qt.QObject):
- """Signal holder"""
- itemReady = qt.Signal(object, object, object)
- runnerFinished = qt.Signal(object)
-
- def __init__(self, filename, item):
- """Constructor
-
- :param LoadingItemWorker worker: Object holding data and signals
- """
- super(LoadingItemRunnable, self).__init__()
- self.filename = filename
- self.oldItem = item
- self.signals = self.__Signals()
-
- def setFile(self, filename, item):
- self.filenames.append((filename, item))
-
- @property
- def itemReady(self):
- return self.signals.itemReady
-
- @property
- def runnerFinished(self):
- return self.signals.runnerFinished
-
- def __loadItemTree(self, oldItem, h5obj):
- """Create an item tree used by the GUI from an h5py object.
-
- :param Hdf5Node oldItem: The current item displayed the GUI
- :param h5py.File h5obj: The h5py object to display in the GUI
- :rtpye: Hdf5Node
- """
- text = _createRootLabel(h5obj)
- item = Hdf5Item(text=text, obj=h5obj, parent=oldItem.parent, populateAll=True)
- return item
-
- def run(self):
- """Process the file loading. The worker is used as holder
- of the data and the signal. The result is sent as a signal.
- """
- h5file = None
- try:
- h5file = silx_io.open(self.filename)
- newItem = self.__loadItemTree(self.oldItem, h5file)
- error = None
- except IOError as e:
- # Should be logged
- error = e
- newItem = None
- if h5file is not None:
- h5file.close()
-
- # Take care of None value in case of PySide
- newItem = _wrapNone(newItem)
- error = _wrapNone(error)
- self.itemReady.emit(self.oldItem, newItem, error)
- self.runnerFinished.emit(self)
-
- def autoDelete(self):
- return True
-
-
-class Hdf5TreeModel(qt.QAbstractItemModel):
- """Tree model storing a list of :class:`h5py.File` like objects.
-
- The main column display the :class:`h5py.File` list and there hierarchy.
- Other columns display information on node hierarchy.
- """
-
- H5PY_ITEM_ROLE = qt.Qt.UserRole
- """Role to reach h5py item from an item index"""
-
- H5PY_OBJECT_ROLE = qt.Qt.UserRole + 1
- """Role to reach h5py object from an item index"""
-
- USER_ROLE = qt.Qt.UserRole + 2
- """Start of range of available user role for derivative models"""
-
- NAME_COLUMN = 0
- """Column id containing HDF5 node names"""
-
- TYPE_COLUMN = 1
- """Column id containing HDF5 dataset types"""
-
- SHAPE_COLUMN = 2
- """Column id containing HDF5 dataset shapes"""
-
- VALUE_COLUMN = 3
- """Column id containing HDF5 dataset values"""
-
- DESCRIPTION_COLUMN = 4
- """Column id containing HDF5 node description/title/message"""
-
- NODE_COLUMN = 5
- """Column id containing HDF5 node type"""
-
- LINK_COLUMN = 6
- """Column id containing HDF5 link type"""
-
- COLUMN_IDS = [
- NAME_COLUMN,
- TYPE_COLUMN,
- SHAPE_COLUMN,
- VALUE_COLUMN,
- DESCRIPTION_COLUMN,
- NODE_COLUMN,
- LINK_COLUMN,
- ]
- """List of logical columns available"""
-
- sigH5pyObjectLoaded = qt.Signal(object)
- """Emitted when a new root item was loaded and inserted to the model."""
-
- sigH5pyObjectRemoved = qt.Signal(object)
- """Emitted when a root item is removed from the model."""
-
- sigH5pyObjectSynchronized = qt.Signal(object, object)
- """Emitted when an item was synchronized."""
-
- def __init__(self, parent=None, ownFiles=True):
- """
- Constructor
-
- :param qt.QWidget parent: Parent widget
- :param bool ownFiles: If true (default) the model will manage the files
- life cycle when they was added using path (like DnD).
- """
- super(Hdf5TreeModel, self).__init__(parent)
-
- self.header_labels = [None] * len(self.COLUMN_IDS)
- self.header_labels[self.NAME_COLUMN] = 'Name'
- self.header_labels[self.TYPE_COLUMN] = 'Type'
- self.header_labels[self.SHAPE_COLUMN] = 'Shape'
- self.header_labels[self.VALUE_COLUMN] = 'Value'
- self.header_labels[self.DESCRIPTION_COLUMN] = 'Description'
- self.header_labels[self.NODE_COLUMN] = 'Node'
- self.header_labels[self.LINK_COLUMN] = 'Link'
-
- # Create items
- self.__root = Hdf5Node()
- self.__fileDropEnabled = True
- self.__fileMoveEnabled = True
- self.__datasetDragEnabled = False
-
- self.__animatedIcon = icons.getWaitIcon()
- self.__animatedIcon.iconChanged.connect(self.__updateLoadingItems)
- self.__runnerSet = set([])
-
- # store used icons to avoid the cache to release it
- self.__icons = []
- self.__icons.append(icons.getQIcon("item-none"))
- self.__icons.append(icons.getQIcon("item-0dim"))
- self.__icons.append(icons.getQIcon("item-1dim"))
- self.__icons.append(icons.getQIcon("item-2dim"))
- self.__icons.append(icons.getQIcon("item-3dim"))
- self.__icons.append(icons.getQIcon("item-ndim"))
-
- self.__ownFiles = ownFiles
- self.__openedFiles = []
- """Store the list of files opened by the model itself."""
- # FIXME: It should be managed one by one by Hdf5Item itself
-
- # It is not possible to override the QObject destructor nor
- # to access to the content of the Python object with the `destroyed`
- # signal cause the Python method was already removed with the QWidget,
- # while the QObject still exists.
- # We use a static method plus explicit references to objects to
- # release. The callback do not use any ref to self.
- onDestroy = functools.partial(self._closeFileList, self.__openedFiles)
- self.destroyed.connect(onDestroy)
-
- @staticmethod
- def _closeFileList(fileList):
- """Static method to close explicit references to internal objects."""
- _logger.debug("Clear Hdf5TreeModel")
- for obj in fileList:
- _logger.debug("Close file %s", obj.filename)
- obj.close()
- fileList[:] = []
-
- def _closeOpened(self):
- """Close files which was opened by this model.
-
- File are opened by the model when it was inserted using
- `insertFileAsync`, `insertFile`, `appendFile`."""
- self._closeFileList(self.__openedFiles)
-
- def __updateLoadingItems(self, icon):
- for i in range(self.__root.childCount()):
- item = self.__root.child(i)
- if isinstance(item, Hdf5LoadingItem):
- index1 = self.index(i, 0, qt.QModelIndex())
- index2 = self.index(i, self.columnCount() - 1, qt.QModelIndex())
- self.dataChanged.emit(index1, index2)
-
- def __itemReady(self, oldItem, newItem, error):
- """Called at the end of a concurent file loading, when the loading
- item is ready. AN error is defined if an exception occured when
- loading the newItem .
-
- :param Hdf5Node oldItem: current displayed item
- :param Hdf5Node newItem: item loaded, or None if error is defined
- :param Exception error: An exception, or None if newItem is defined
- """
- # Take care of None value in case of PySide
- newItem = _unwrapNone(newItem)
- error = _unwrapNone(error)
- row = self.__root.indexOfChild(oldItem)
-
- rootIndex = qt.QModelIndex()
- self.beginRemoveRows(rootIndex, row, row)
- self.__root.removeChildAtIndex(row)
- self.endRemoveRows()
-
- if newItem is not None:
- rootIndex = qt.QModelIndex()
- if self.__ownFiles:
- self.__openedFiles.append(newItem.obj)
- self.beginInsertRows(rootIndex, row, row)
- self.__root.insertChild(row, newItem)
- self.endInsertRows()
-
- if isinstance(oldItem, Hdf5LoadingItem):
- self.sigH5pyObjectLoaded.emit(newItem.obj)
- else:
- self.sigH5pyObjectSynchronized.emit(oldItem.obj, newItem.obj)
-
- # FIXME the error must be displayed
-
- def isFileDropEnabled(self):
- return self.__fileDropEnabled
-
- def setFileDropEnabled(self, enabled):
- self.__fileDropEnabled = enabled
-
- fileDropEnabled = qt.Property(bool, isFileDropEnabled, setFileDropEnabled)
- """Property to enable/disable file dropping in the model."""
-
- def isDatasetDragEnabled(self):
- return self.__datasetDragEnabled
-
- def setDatasetDragEnabled(self, enabled):
- self.__datasetDragEnabled = enabled
-
- datasetDragEnabled = qt.Property(bool, isDatasetDragEnabled, setDatasetDragEnabled)
- """Property to enable/disable drag of datasets."""
-
- def isFileMoveEnabled(self):
- return self.__fileMoveEnabled
-
- def setFileMoveEnabled(self, enabled):
- self.__fileMoveEnabled = enabled
-
- fileMoveEnabled = qt.Property(bool, isFileMoveEnabled, setFileMoveEnabled)
- """Property to enable/disable drag-and-drop of files to
- change the ordering in the model."""
-
- def supportedDropActions(self):
- if self.__fileMoveEnabled or self.__fileDropEnabled:
- return qt.Qt.CopyAction | qt.Qt.MoveAction
- else:
- return 0
-
- def mimeTypes(self):
- types = []
- if self.__fileMoveEnabled or self.__datasetDragEnabled:
- types.append(_utils.Hdf5DatasetMimeData.MIME_TYPE)
- return types
-
- def mimeData(self, indexes):
- """
- Returns an object that contains serialized items of data corresponding
- to the list of indexes specified.
-
- :param List[qt.QModelIndex] indexes: List of indexes
- :rtype: qt.QMimeData
- """
- if len(indexes) == 0:
- return None
-
- indexes = [i for i in indexes if i.column() == 0]
- if len(indexes) > 1:
- raise NotImplementedError("Drag of multi rows is not implemented")
- if len(indexes) == 0:
- raise NotImplementedError("Drag of cell is not implemented")
-
- node = self.nodeFromIndex(indexes[0])
-
- if self.__fileMoveEnabled and node.parent is self.__root:
- mimeData = _utils.Hdf5DatasetMimeData(node=node, isRoot=True)
- elif self.__datasetDragEnabled:
- mimeData = _utils.Hdf5DatasetMimeData(node=node)
- else:
- mimeData = None
- return mimeData
-
- def flags(self, index):
- defaultFlags = qt.QAbstractItemModel.flags(self, index)
-
- if index.isValid():
- node = self.nodeFromIndex(index)
- if self.__fileMoveEnabled and node.parent is self.__root:
- # that's a root
- return qt.Qt.ItemIsDragEnabled | defaultFlags
- elif self.__datasetDragEnabled:
- return qt.Qt.ItemIsDragEnabled | defaultFlags
- return defaultFlags
- elif self.__fileDropEnabled or self.__fileMoveEnabled:
- return qt.Qt.ItemIsDropEnabled | defaultFlags
- else:
- return defaultFlags
-
- def dropMimeData(self, mimedata, action, row, column, parentIndex):
- if action == qt.Qt.IgnoreAction:
- return True
-
- if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5DatasetMimeData.MIME_TYPE):
- if mimedata.isRoot():
- dragNode = mimedata.node()
- parentNode = self.nodeFromIndex(parentIndex)
- if parentNode is not dragNode.parent:
- return False
-
- if row == -1:
- # append to the parent
- row = parentNode.childCount()
- else:
- # insert at row
- pass
-
- dragNodeParent = dragNode.parent
- sourceRow = dragNodeParent.indexOfChild(dragNode)
- self.moveRow(parentIndex, sourceRow, parentIndex, row)
- return True
-
- if self.__fileDropEnabled and mimedata.hasFormat("text/uri-list"):
-
- parentNode = self.nodeFromIndex(parentIndex)
- if parentNode is not self.__root:
- while(parentNode is not self.__root):
- node = parentNode
- parentNode = node.parent
- row = parentNode.indexOfChild(node)
- else:
- if row == -1:
- row = self.__root.childCount()
-
- messages = []
- for url in mimedata.urls():
- try:
- self.insertFileAsync(url.toLocalFile(), row)
- row += 1
- except IOError as e:
- messages.append(e.args[0])
- if len(messages) > 0:
- title = "Error occurred when loading files"
- message = "<html>%s:<ul><li>%s</li><ul></html>" % (title, "</li><li>".join(messages))
- qt.QMessageBox.critical(None, title, message)
- return True
-
- return False
-
- def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
- if orientation == qt.Qt.Horizontal:
- if role in [qt.Qt.DisplayRole, qt.Qt.EditRole]:
- return self.header_labels[section]
- return None
-
- def insertNode(self, row, node):
- if row == -1:
- row = self.__root.childCount()
- self.beginInsertRows(qt.QModelIndex(), row, row)
- self.__root.insertChild(row, node)
- self.endInsertRows()
-
- def moveRow(self, sourceParentIndex, sourceRow, destinationParentIndex, destinationRow):
- if sourceRow == destinationRow or sourceRow == destinationRow - 1:
- # abort move, same place
- return
- return self.moveRows(sourceParentIndex, sourceRow, 1, destinationParentIndex, destinationRow)
-
- def moveRows(self, sourceParentIndex, sourceRow, count, destinationParentIndex, destinationRow):
- self.beginMoveRows(sourceParentIndex, sourceRow, sourceRow, destinationParentIndex, destinationRow)
- sourceNode = self.nodeFromIndex(sourceParentIndex)
- destinationNode = self.nodeFromIndex(destinationParentIndex)
-
- if sourceNode is destinationNode and sourceRow < destinationRow:
- item = sourceNode.child(sourceRow)
- destinationNode.insertChild(destinationRow, item)
- sourceNode.removeChildAtIndex(sourceRow)
- else:
- item = sourceNode.removeChildAtIndex(sourceRow)
- destinationNode.insertChild(destinationRow, item)
-
- self.endMoveRows()
- return True
-
- def index(self, row, column, parent=qt.QModelIndex()):
- try:
- node = self.nodeFromIndex(parent)
- return self.createIndex(row, column, node.child(row))
- except IndexError:
- return qt.QModelIndex()
-
- def data(self, index, role=qt.Qt.DisplayRole):
- node = self.nodeFromIndex(index)
-
- if role == self.H5PY_ITEM_ROLE:
- return node
-
- if role == self.H5PY_OBJECT_ROLE:
- return node.obj
-
- if index.column() == self.NAME_COLUMN:
- return node.dataName(role)
- elif index.column() == self.TYPE_COLUMN:
- return node.dataType(role)
- elif index.column() == self.SHAPE_COLUMN:
- return node.dataShape(role)
- elif index.column() == self.VALUE_COLUMN:
- return node.dataValue(role)
- elif index.column() == self.DESCRIPTION_COLUMN:
- return node.dataDescription(role)
- elif index.column() == self.NODE_COLUMN:
- return node.dataNode(role)
- elif index.column() == self.LINK_COLUMN:
- return node.dataLink(role)
- else:
- return None
-
- def columnCount(self, parent=qt.QModelIndex()):
- return len(self.COLUMN_IDS)
-
- def hasChildren(self, parent=qt.QModelIndex()):
- node = self.nodeFromIndex(parent)
- if node is None:
- return 0
- return node.hasChildren()
-
- def rowCount(self, parent=qt.QModelIndex()):
- node = self.nodeFromIndex(parent)
- if node is None:
- return 0
- return node.childCount()
-
- def parent(self, child):
- if not child.isValid():
- return qt.QModelIndex()
-
- node = self.nodeFromIndex(child)
-
- if node is None:
- return qt.QModelIndex()
-
- parent = node.parent
-
- if parent is None:
- return qt.QModelIndex()
-
- grandparent = parent.parent
- if grandparent is None:
- return qt.QModelIndex()
- row = grandparent.indexOfChild(parent)
-
- assert row != - 1
- return self.createIndex(row, 0, parent)
-
- def nodeFromIndex(self, index):
- return index.internalPointer() if index.isValid() else self.__root
-
- def _closeFileIfOwned(self, node):
- """"Close the file if it was loaded from a filename or a
- drag-and-drop"""
- obj = node.obj
- for f in self.__openedFiles:
- if f is obj:
- _logger.debug("Close file %s", obj.filename)
- obj.close()
- self.__openedFiles.remove(obj)
-
- def synchronizeIndex(self, index):
- """
- Synchronize a file a given its index.
-
- Basically close it and load it again.
-
- :param qt.QModelIndex index: Index of the item to update
- """
- node = self.nodeFromIndex(index)
- if node.parent is not self.__root:
- return
-
- filename = node.obj.filename
- self.insertFileAsync(filename, index.row(), synchronizingNode=node)
-
- def h5pyObjectRow(self, h5pyObject):
- for row in range(self.__root.childCount()):
- item = self.__root.child(row)
- if item.obj == h5pyObject:
- return row
- return -1
-
- def synchronizeH5pyObject(self, h5pyObject):
- """
- Synchronize a h5py object in all the tree.
-
- Basically close it and load it again.
-
- :param h5py.File h5pyObject: A :class:`h5py.File` object.
- """
- index = 0
- while index < self.__root.childCount():
- item = self.__root.child(index)
- if item.obj == h5pyObject:
- qindex = self.index(index, 0, qt.QModelIndex())
- self.synchronizeIndex(qindex)
- index += 1
-
- def removeIndex(self, index):
- """
- Remove an item from the model using its index.
-
- :param qt.QModelIndex index: Index of the item to remove
- """
- node = self.nodeFromIndex(index)
- if node.parent != self.__root:
- return
- self._closeFileIfOwned(node)
- self.beginRemoveRows(qt.QModelIndex(), index.row(), index.row())
- self.__root.removeChildAtIndex(index.row())
- self.endRemoveRows()
- self.sigH5pyObjectRemoved.emit(node.obj)
-
- def removeH5pyObject(self, h5pyObject):
- """
- Remove an item from the model using the holding h5py object.
- It can remove more than one item.
-
- :param h5py.File h5pyObject: A :class:`h5py.File` object.
- """
- index = 0
- while index < self.__root.childCount():
- item = self.__root.child(index)
- if item.obj == h5pyObject:
- qindex = self.index(index, 0, qt.QModelIndex())
- self.removeIndex(qindex)
- else:
- index += 1
-
- def insertH5pyObject(self, h5pyObject, text=None, row=-1):
- """Append an HDF5 object from h5py to the tree.
-
- :param h5pyObject: File handle/descriptor for a :class:`h5py.File`
- or any other class of h5py file structure.
- """
- if text is None:
- text = _createRootLabel(h5pyObject)
- if row == -1:
- row = self.__root.childCount()
- self.insertNode(row, Hdf5Item(text=text, obj=h5pyObject, parent=self.__root))
-
- def hasPendingOperations(self):
- return len(self.__runnerSet) > 0
-
- def insertFileAsync(self, filename, row=-1, synchronizingNode=None):
- if not os.path.isfile(filename):
- raise IOError("Filename '%s' must be a file path" % filename)
-
- # create temporary item
- if synchronizingNode is None:
- text = os.path.basename(filename)
- item = Hdf5LoadingItem(text=text, parent=self.__root, animatedIcon=self.__animatedIcon)
- self.insertNode(row, item)
- else:
- item = synchronizingNode
-
- # start loading the real one
- runnable = LoadingItemRunnable(filename, item)
- runnable.itemReady.connect(self.__itemReady)
- runnable.runnerFinished.connect(self.__releaseRunner)
- self.__runnerSet.add(runnable)
- qt.silxGlobalThreadPool().start(runnable)
-
- def __releaseRunner(self, runner):
- self.__runnerSet.remove(runner)
-
- def insertFile(self, filename, row=-1):
- """Load a HDF5 file into the data model.
-
- :param filename: file path.
- """
- try:
- h5file = silx_io.open(filename)
- if self.__ownFiles:
- self.__openedFiles.append(h5file)
- self.sigH5pyObjectLoaded.emit(h5file)
- self.insertH5pyObject(h5file, row=row)
- except IOError:
- _logger.debug("File '%s' can't be read.", filename, exc_info=True)
- raise
-
- def clear(self):
- """Remove all the content of the model"""
- for _ in range(self.rowCount()):
- qindex = self.index(0, 0, qt.QModelIndex())
- self.removeIndex(qindex)
-
- def appendFile(self, filename):
- self.insertFile(filename, -1)
-
- def indexFromH5Object(self, h5Object):
- """Returns a model index from an h5py-like object.
-
- :param object h5Object: An h5py-like object
- :rtype: qt.QModelIndex
- """
- if h5Object is None:
- return qt.QModelIndex()
-
- filename = h5Object.file.filename
-
- # Seach for the right roots
- rootIndices = []
- for index in range(self.rowCount(qt.QModelIndex())):
- index = self.index(index, 0, qt.QModelIndex())
- obj = self.data(index, Hdf5TreeModel.H5PY_OBJECT_ROLE)
- if obj.file.filename == filename:
- # We can have many roots with different subtree of the same
- # root
- rootIndices.append(index)
-
- if len(rootIndices) == 0:
- # No root found
- return qt.QModelIndex()
-
- path = h5Object.name + "/"
- path = path.replace("//", "/")
-
- # Search for the right node
- found = False
- foundIndices = []
- for _ in range(1000 * len(rootIndices)):
- # Avoid too much iterations, in case of recurssive links
- if len(foundIndices) == 0:
- if len(rootIndices) == 0:
- # Nothing found
- break
- # Start fron a new root
- foundIndices.append(rootIndices.pop(0))
-
- obj = self.data(index, Hdf5TreeModel.H5PY_OBJECT_ROLE)
- p = obj.name + "/"
- p = p.replace("//", "/")
- if path == p:
- found = True
- break
-
- parentIndex = foundIndices[-1]
- for index in range(self.rowCount(parentIndex)):
- index = self.index(index, 0, parentIndex)
- obj = self.data(index, Hdf5TreeModel.H5PY_OBJECT_ROLE)
-
- p = obj.name + "/"
- p = p.replace("//", "/")
- if path == p:
- foundIndices.append(index)
- found = True
- break
- elif path.startswith(p):
- foundIndices.append(index)
- break
- else:
- # Nothing found, start again with another root
- foundIndices = []
-
- if found:
- break
-
- if found:
- return foundIndices[-1]
- return qt.QModelIndex()
diff --git a/silx/gui/hdf5/Hdf5TreeView.py b/silx/gui/hdf5/Hdf5TreeView.py
deleted file mode 100644
index a86140a..0000000
--- a/silx/gui/hdf5/Hdf5TreeView.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "30/04/2018"
-
-
-import logging
-from .. import qt
-from ...utils import weakref as silxweakref
-from .Hdf5TreeModel import Hdf5TreeModel
-from .Hdf5HeaderView import Hdf5HeaderView
-from .NexusSortFilterProxyModel import NexusSortFilterProxyModel
-from .Hdf5Item import Hdf5Item
-from . import _utils
-
-_logger = logging.getLogger(__name__)
-
-
-class Hdf5TreeView(qt.QTreeView):
- """TreeView which allow to browse HDF5 file structure.
-
- .. image:: img/Hdf5TreeView.png
-
- It provides columns width auto-resizing and additional
- signals.
-
- The default model is a :class:`NexusSortFilterProxyModel` sourcing
- a :class:`Hdf5TreeModel`. The :class:`Hdf5TreeModel` is reachable using
- :meth:`findHdf5TreeModel`. The default header is :class:`Hdf5HeaderView`.
-
- Context menu is managed by the :meth:`setContextMenuPolicy` with the value
- Qt.CustomContextMenu. This policy must not be changed, otherwise context
- menus will not work anymore. You can use :meth:`addContextMenuCallback` and
- :meth:`removeContextMenuCallback` to add your custum actions according
- to the selected objects.
- """
- def __init__(self, parent=None):
- """
- Constructor
-
- :param parent qt.QWidget: The parent widget
- """
- qt.QTreeView.__init__(self, parent)
-
- model = self.createDefaultModel()
- self.setModel(model)
-
- self.setHeader(Hdf5HeaderView(qt.Qt.Horizontal, self))
- self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
- self.sortByColumn(0, qt.Qt.AscendingOrder)
- # optimise the rendering
- self.setUniformRowHeights(True)
-
- self.setIconSize(qt.QSize(16, 16))
- self.setAcceptDrops(True)
- self.setDragEnabled(True)
- self.setDragDropMode(qt.QAbstractItemView.DragDrop)
- self.showDropIndicator()
-
- self.__context_menu_callbacks = silxweakref.WeakList()
- self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
- self.customContextMenuRequested.connect(self._createContextMenu)
-
- def createDefaultModel(self):
- """Creates and returns the default model.
-
- Inherite to custom the default model"""
- model = Hdf5TreeModel(self)
- proxy_model = NexusSortFilterProxyModel(self)
- proxy_model.setSourceModel(model)
- return proxy_model
-
- def __removeContextMenuProxies(self, ref):
- """Callback to remove dead proxy from the list"""
- self.__context_menu_callbacks.remove(ref)
-
- def _createContextMenu(self, pos):
- """
- Create context menu.
-
- :param pos qt.QPoint: Position of the context menu
- """
- actions = []
-
- menu = qt.QMenu(self)
-
- hovered_index = self.indexAt(pos)
- hovered_node = self.model().data(hovered_index, Hdf5TreeModel.H5PY_ITEM_ROLE)
- if hovered_node is None or not isinstance(hovered_node, Hdf5Item):
- return
-
- hovered_object = _utils.H5Node(hovered_node)
- event = _utils.Hdf5ContextMenuEvent(self, menu, hovered_object)
-
- for callback in self.__context_menu_callbacks:
- try:
- callback(event)
- except KeyboardInterrupt:
- raise
- except Exception:
- # make sure no user callback crash the application
- _logger.error("Error while calling callback", exc_info=True)
- pass
-
- if not menu.isEmpty():
- for action in actions:
- menu.addAction(action)
- menu.popup(self.viewport().mapToGlobal(pos))
-
- def addContextMenuCallback(self, callback):
- """Register a context menu callback.
-
- The callback will be called when a context menu is requested with the
- treeview and the list of selected h5py objects in parameters. The
- callback must return a list of :class:`qt.QAction` object.
-
- Callbacks are stored as saferef. The object must store a reference by
- itself.
- """
- self.__context_menu_callbacks.append(callback)
-
- def removeContextMenuCallback(self, callback):
- """Unregister a context menu callback"""
- self.__context_menu_callbacks.remove(callback)
-
- def findHdf5TreeModel(self):
- """Find the Hdf5TreeModel from the stack of model filters.
-
- :returns: A Hdf5TreeModel, else None
- :rtype: Hdf5TreeModel
- """
- model = self.model()
- while model is not None:
- if isinstance(model, qt.QAbstractProxyModel):
- model = model.sourceModel()
- else:
- break
- if model is None:
- return None
- if isinstance(model, Hdf5TreeModel):
- return model
- else:
- return None
-
- def dragEnterEvent(self, event):
- model = self.findHdf5TreeModel()
- if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
- self.setState(qt.QAbstractItemView.DraggingState)
- event.accept()
- else:
- qt.QTreeView.dragEnterEvent(self, event)
-
- def dragMoveEvent(self, event):
- model = self.findHdf5TreeModel()
- if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
- event.setDropAction(qt.Qt.CopyAction)
- event.accept()
- else:
- qt.QTreeView.dragMoveEvent(self, event)
-
- def selectedH5Nodes(self, ignoreBrokenLinks=True):
- """Returns selected h5py objects like :class:`h5py.File`,
- :class:`h5py.Group`, :class:`h5py.Dataset` or mimicked objects.
-
- :param ignoreBrokenLinks bool: Returns objects which are not not
- broken links.
- :rtype: iterator(:class:`_utils.H5Node`)
- """
- for index in self.selectedIndexes():
- if index.column() != 0:
- continue
- item = self.model().data(index, Hdf5TreeModel.H5PY_ITEM_ROLE)
- if item is None:
- continue
- if isinstance(item, Hdf5Item):
- if ignoreBrokenLinks and item.isBrokenObj():
- continue
- yield _utils.H5Node(item)
-
- def __intermediateModels(self, index):
- """Returns intermediate models from the view model to the
- model of the index."""
- models = []
- targetModel = index.model()
- model = self.model()
- while model is not None:
- if model is targetModel:
- # found
- return models
- models.append(model)
- if isinstance(model, qt.QAbstractProxyModel):
- model = model.sourceModel()
- else:
- break
- raise RuntimeError("Model from the requested index is not reachable from this view")
-
- def mapToModel(self, index):
- """Map an index from any model reachable by the view to an index from
- the very first model connected to the view.
-
- :param qt.QModelIndex index: Index from the Hdf5Tree model
- :rtype: qt.QModelIndex
- :return: Index from the model connected to the view
- """
- if not index.isValid():
- return index
- models = self.__intermediateModels(index)
- for model in reversed(models):
- index = model.mapFromSource(index)
- return index
-
- def setSelectedH5Node(self, h5Object):
- """
- Select the specified node of the tree using an h5py node.
-
- - If the item is found, parent items are expended, and then the item
- is selected.
- - If the item is not found, the selection do not change.
- - A none argument allow to deselect everything
-
- :param h5py.Node h5Object: The node to select
- """
- if h5Object is None:
- self.setCurrentIndex(qt.QModelIndex())
- return
-
- model = self.findHdf5TreeModel()
- index = model.indexFromH5Object(h5Object)
- index = self.mapToModel(index)
- if index.isValid():
- # Update the GUI
- i = index
- while i.isValid():
- self.expand(i)
- i = i.parent()
- self.setCurrentIndex(index)
-
- def mousePressEvent(self, event):
- """Override mousePressEvent to provide a consistante compatible API
- between Qt4 and Qt5
- """
- super(Hdf5TreeView, self).mousePressEvent(event)
- if event.button() != qt.Qt.LeftButton:
- # Qt5 only sends itemClicked on left button mouse click
- if qt.qVersion() > "5":
- qindex = self.indexAt(event.pos())
- self.clicked.emit(qindex)
diff --git a/silx/gui/hdf5/_utils.py b/silx/gui/hdf5/_utils.py
deleted file mode 100644
index aaab228..0000000
--- a/silx/gui/hdf5/_utils.py
+++ /dev/null
@@ -1,461 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This package provides a set of helper class and function used by the
-package `silx.gui.hdf5` package.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "17/01/2019"
-
-
-import logging
-import os.path
-
-import silx.io.utils
-import silx.io.url
-from .. import qt
-from silx.utils.html import escape
-
-_logger = logging.getLogger(__name__)
-
-
-class Hdf5ContextMenuEvent(object):
- """Hold information provided to context menu callbacks."""
-
- def __init__(self, source, menu, hoveredObject):
- """
- Constructor
-
- :param QWidget source: Widget source
- :param QMenu menu: Context menu which will be displayed
- :param H5Node hoveredObject: Hovered H5 node
- """
- self.__source = source
- self.__menu = menu
- self.__hoveredObject = hoveredObject
-
- def source(self):
- """Source of the event
-
- :rtype: Hdf5TreeView
- """
- return self.__source
-
- def menu(self):
- """Menu which will be displayed
-
- :rtype: qt.QMenu
- """
- return self.__menu
-
- def hoveredObject(self):
- """Item content hovered by the mouse when the context menu was
- requested
-
- :rtype: H5Node
- """
- return self.__hoveredObject
-
-
-def htmlFromDict(dictionary, title=None):
- """Generate a readable HTML from a dictionary
-
- :param dict dictionary: A Dictionary
- :rtype: str
- """
- result = """<html>
- <head>
- <style type="text/css">
- ul { -qt-list-indent: 0; list-style: none; }
- li > b {display: inline-block; min-width: 4em; font-weight: bold; }
- </style>
- </head>
- <body>
- """
- if title is not None:
- result += "<b>%s</b>" % escape(title)
- result += "<ul>"
- for key, value in dictionary.items():
- result += "<li><b>%s</b>: %s</li>" % (escape(key), escape(value))
- result += "</ul>"
- result += "</body></html>"
- return result
-
-
-class Hdf5DatasetMimeData(qt.QMimeData):
- """Mimedata class to identify an internal drag and drop of a Hdf5Node."""
-
- MIME_TYPE = "application/x-internal-h5py-dataset"
-
- SILX_URI_TYPE = "application/x-silx-uri"
-
- def __init__(self, node=None, dataset=None, isRoot=False):
- qt.QMimeData.__init__(self)
- self.__dataset = dataset
- self.__node = node
- self.__isRoot = isRoot
- self.setData(self.MIME_TYPE, "".encode(encoding='utf-8'))
- if node is not None:
- h5Node = H5Node(node)
- silxUrl = h5Node.url
- self.setText(silxUrl)
- self.setData(self.SILX_URI_TYPE, silxUrl.encode(encoding='utf-8'))
-
- def isRoot(self):
- return self.__isRoot
-
- def node(self):
- return self.__node
-
- def dataset(self):
- if self.__node is not None:
- return self.__node.obj
- return self.__dataset
-
-
-class H5Node(object):
- """Adapter over an h5py object to provide missing informations from h5py
- nodes, like internal node path and filename (which are not provided by
- :mod:`h5py` for soft and external links).
-
- It also provides an abstraction to reach node type for mimicked h5py
- objects.
- """
-
- def __init__(self, h5py_item=None):
- """Constructor
-
- :param Hdf5Item h5py_item: An Hdf5Item
- """
- self.__h5py_object = h5py_item.obj
- self.__h5py_target = None
- self.__h5py_item = h5py_item
-
- def __getattr__(self, name):
- if hasattr(self.__h5py_object, name):
- attr = getattr(self.__h5py_object, name)
- return attr
- raise AttributeError("H5Node has no attribute %s" % name)
-
- def __get_target(self, obj):
- """
- Return the actual physical target of the provided object.
-
- Objects can contains links in the middle of the path, this function
- check each groups and remove this prefix in case of the link by the
- link of the path.
-
- :param obj: A valid h5py object (File, group or dataset)
- :type obj: h5py.Dataset or h5py.Group or h5py.File
- :rtype: h5py.Dataset or h5py.Group or h5py.File
- """
- elements = obj.name.split("/")
- if obj.name == "/":
- return obj
- elif obj.name.startswith("/"):
- elements.pop(0)
- path = ""
- subpath = ""
- while len(elements) > 0:
- e = elements.pop(0)
- subpath = path + "/" + e
- link = obj.parent.get(subpath, getlink=True)
- classlink = silx.io.utils.get_h5_class(link)
-
- if classlink == silx.io.utils.H5Type.EXTERNAL_LINK:
- subpath = "/".join(elements)
- external_obj = obj.parent.get(self.basename + "/" + subpath)
- return self.__get_target(external_obj)
- elif classlink == silx.io.utils.H5Type.SOFT_LINK:
- # Restart from this stat
- root_elements = link.path.split("/")
- if link.path == "/":
- path = ""
- root_elements = []
- elif link.path.startswith("/"):
- path = ""
- root_elements.pop(0)
-
- for name in reversed(root_elements):
- elements.insert(0, name)
- else:
- path = subpath
-
- return obj.file[path]
-
- @property
- def h5py_target(self):
- if self.__h5py_target is not None:
- return self.__h5py_target
- self.__h5py_target = self.__get_target(self.__h5py_object)
- return self.__h5py_target
-
- @property
- def h5py_object(self):
- """Returns the internal h5py node.
-
- :rtype: h5py.File or h5py.Group or h5py.Dataset
- """
- return self.__h5py_object
-
- @property
- def h5type(self):
- """Returns the node type, as an H5Type.
-
- :rtype: H5Node
- """
- return silx.io.utils.get_h5_class(self.__h5py_object)
-
- @property
- def ntype(self):
- """Returns the node type, as an h5py class.
-
- :rtype:
- :class:`h5py.File`, :class:`h5py.Group` or :class:`h5py.Dataset`
- """
- type_ = self.h5type
- return silx.io.utils.h5type_to_h5py_class(type_)
-
- @property
- def basename(self):
- """Returns the basename of this h5py node. It is the last identifier of
- the path.
-
- :rtype: str
- """
- return self.__h5py_object.name.split("/")[-1]
-
- @property
- def is_broken(self):
- """Returns true if the node is a broken link.
-
- :rtype: bool
- """
- if self.__h5py_item is None:
- raise RuntimeError("h5py_item is not defined")
- return self.__h5py_item.isBrokenObj()
-
- @property
- def local_name(self):
- """Returns the path from the master file root to this node.
-
- For links, this path is not equal to the h5py one.
-
- :rtype: str
- """
- if self.__h5py_item is None:
- raise RuntimeError("h5py_item is not defined")
-
- result = []
- item = self.__h5py_item
- while item is not None:
- # stop before the root item (item without parent)
- if item.parent.parent is None:
- name = item.obj.name
- if name != "/":
- result.append(item.obj.name)
- break
- else:
- result.append(item.basename)
- item = item.parent
- if item is None:
- raise RuntimeError("The item does not have parent holding h5py.File")
- if result == []:
- return "/"
- if not result[-1].startswith("/"):
- result.append("")
- result.reverse()
- name = "/".join(result)
- return name
-
- def __get_local_file(self):
- """Returns the file of the root of this tree
-
- :rtype: h5py.File
- """
- item = self.__h5py_item
- while item.parent.parent is not None:
- class_ = silx.io.utils.get_h5_class(class_=item.h5pyClass)
- if class_ == silx.io.utils.H5Type.FILE:
- break
- item = item.parent
-
- class_ = silx.io.utils.get_h5_class(class_=item.h5pyClass)
- if class_ == silx.io.utils.H5Type.FILE:
- return item.obj
- else:
- return item.obj.file
-
- @property
- def local_file(self):
- """Returns the master file in which is this node.
-
- For path containing external links, this file is not equal to the h5py
- one.
-
- :rtype: h5py.File
- :raises RuntimeException: If no file are found
- """
- return self.__get_local_file()
-
- @property
- def local_filename(self):
- """Returns the filename from the master file of this node.
-
- For path containing external links, this path is not equal to the
- filename provided by h5py.
-
- :rtype: str
- :raises RuntimeException: If no file are found
- """
- return self.local_file.filename
-
- @property
- def local_basename(self):
- """Returns the basename from the master file root to this node.
-
- For path containing links, this basename can be different than the
- basename provided by h5py.
-
- :rtype: str
- """
- class_ = self.__h5py_item.h5Class
- if class_ is not None and class_ == silx.io.utils.H5Type.FILE:
- return ""
- return self.__h5py_item.basename
-
- @property
- def physical_file(self):
- """Returns the physical file in which is this node.
-
- .. versionadded:: 0.6
-
- :rtype: h5py.File
- :raises RuntimeError: If no file are found
- """
- class_ = silx.io.utils.get_h5_class(self.__h5py_object)
- if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
- # It means the link is broken
- raise RuntimeError("No file node found")
- if class_ == silx.io.utils.H5Type.SOFT_LINK:
- # It means the link is broken
- return self.local_file
-
- physical_obj = self.h5py_target
- return physical_obj.file
-
- @property
- def physical_name(self):
- """Returns the path from the location this h5py node is physically
- stored.
-
- For broken links, this filename can be different from the
- filename provided by h5py.
-
- :rtype: str
- """
- class_ = silx.io.utils.get_h5_class(self.__h5py_object)
- if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
- # It means the link is broken
- return self.__h5py_object.path
- if class_ == silx.io.utils.H5Type.SOFT_LINK:
- # It means the link is broken
- return self.__h5py_object.path
-
- physical_obj = self.h5py_target
- return physical_obj.name
-
- @property
- def physical_filename(self):
- """Returns the filename from the location this h5py node is physically
- stored.
-
- For broken links, this filename can be different from the
- filename provided by h5py.
-
- :rtype: str
- """
- class_ = silx.io.utils.get_h5_class(self.__h5py_object)
- if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
- # It means the link is broken
- return self.__h5py_object.filename
- if class_ == silx.io.utils.H5Type.SOFT_LINK:
- # It means the link is broken
- return self.local_file.filename
-
- return self.physical_file.filename
-
- @property
- def physical_basename(self):
- """Returns the basename from the location this h5py node is physically
- stored.
-
- For broken links, this basename can be different from the
- basename provided by h5py.
-
- :rtype: str
- """
- return self.physical_name.split("/")[-1]
-
- @property
- def data_url(self):
- """Returns a :class:`silx.io.url.DataUrl` object identify this node in the file
- system.
-
- :rtype: ~silx.io.url.DataUrl
- """
- absolute_filename = os.path.abspath(self.local_filename)
- return silx.io.url.DataUrl(scheme="silx",
- file_path=absolute_filename,
- data_path=self.local_name)
-
- @property
- def url(self):
- """Returns an URL object identifying this node in the file
- system.
-
- This URL can be used in different ways.
-
- .. code-block:: python
-
- # Parsing the URL
- import silx.io.url
- dataurl = silx.io.url.DataUrl(item.url)
- # dataurl provides access to URL fields
-
- # Open a numpy array
- import silx.io
- dataset = silx.io.get_data(item.url)
-
- # Open an hdf5 object (URL targetting a file or a group)
- import silx.io
- with silx.io.open(item.url) as h5:
- ...your stuff...
-
- :rtype: str
- """
- data_url = self.data_url
- return data_url.path()
diff --git a/silx/gui/hdf5/test/__init__.py b/silx/gui/hdf5/test/__init__.py
deleted file mode 100644
index 3000d96..0000000
--- a/silx/gui/hdf5/test/__init__.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-import unittest
-
-from . import test_hdf5
-
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "28/09/2016"
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTests(
- [test_hdf5.suite()])
- return test_suite
diff --git a/silx/gui/hdf5/test/test_hdf5.py b/silx/gui/hdf5/test/test_hdf5.py
deleted file mode 100755
index fcfc02c..0000000
--- a/silx/gui/hdf5/test/test_hdf5.py
+++ /dev/null
@@ -1,1140 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test for silx.gui.hdf5 module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "12/03/2019"
-
-
-import time
-import os
-import unittest
-import tempfile
-import numpy
-import shutil
-from contextlib import contextmanager
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import hdf5
-from silx.gui.utils.testutils import SignalListener
-from silx.io import commonh5
-import weakref
-
-import h5py
-
-
-_tmpDirectory = None
-
-
-def setUpModule():
- global _tmpDirectory
- _tmpDirectory = tempfile.mkdtemp(prefix=__name__)
-
- filename = _tmpDirectory + "/data.h5"
-
- # create h5 data
- f = h5py.File(filename, "w")
- g = f.create_group("arrays")
- g.create_dataset("scalar", data=10)
- f.close()
-
-
-def tearDownModule():
- global _tmpDirectory
- shutil.rmtree(_tmpDirectory)
- _tmpDirectory = None
-
-
-_called = 0
-
-
-class _Holder(object):
- def callback(self, *args, **kvargs):
- _called += 1
-
-
-def create_NXentry(group, name):
- attrs = {"NX_class": "NXentry"}
- node = commonh5.Group(name, parent=group, attrs=attrs)
- group.add_node(node)
- return node
-
-
-class TestHdf5TreeModel(TestCaseQt):
-
- def setUp(self):
- super(TestHdf5TreeModel, self).setUp()
-
- def waitForPendingOperations(self, model):
- for _ in range(10):
- if not model.hasPendingOperations():
- break
- self.qWait(10)
- else:
- raise RuntimeError("Still waiting for a pending operation")
-
- @contextmanager
- def h5TempFile(self):
- # create tmp file
- fd, tmp_name = tempfile.mkstemp(suffix=".h5")
- os.close(fd)
- # create h5 data
- h5file = h5py.File(tmp_name, "w")
- g = h5file.create_group("arrays")
- g.create_dataset("scalar", data=10)
- h5file.close()
- yield tmp_name
- # clean up
- os.unlink(tmp_name)
-
- def testCreate(self):
- model = hdf5.Hdf5TreeModel()
- self.assertIsNotNone(model)
-
- def testAppendFilename(self):
- filename = _tmpDirectory + "/data.h5"
- model = hdf5.Hdf5TreeModel()
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- model.appendFile(filename)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
- # clean up
- ref = weakref.ref(model)
- model = None
- self.qWaitForDestroy(ref)
-
- def testAppendBadFilename(self):
- model = hdf5.Hdf5TreeModel()
- self.assertRaises(IOError, model.appendFile, "#%$")
-
- def testInsertFilename(self):
- filename = _tmpDirectory + "/data.h5"
- try:
- model = hdf5.Hdf5TreeModel()
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- model.insertFile(filename)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
- # clean up
- index = model.index(0, 0, qt.QModelIndex())
- h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
- self.assertIsNotNone(h5File)
- finally:
- ref = weakref.ref(model)
- model = None
- self.qWaitForDestroy(ref)
-
- def testInsertFilenameAsync(self):
- filename = _tmpDirectory + "/data.h5"
- try:
- model = hdf5.Hdf5TreeModel()
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- model.insertFileAsync(filename)
- index = model.index(0, 0, qt.QModelIndex())
- self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5LoadingItem.Hdf5LoadingItem)
- self.waitForPendingOperations(model)
- index = model.index(0, 0, qt.QModelIndex())
- self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
- finally:
- ref = weakref.ref(model)
- model = None
- self.qWaitForDestroy(ref)
-
- def testInsertObject(self):
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- model = hdf5.Hdf5TreeModel()
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- model.insertH5pyObject(h5)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
-
- def testRemoveObject(self):
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- model = hdf5.Hdf5TreeModel()
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- model.insertH5pyObject(h5)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
- model.removeH5pyObject(h5)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
-
- def testSynchronizeObject(self):
- filename = _tmpDirectory + "/data.h5"
- h5 = h5py.File(filename, mode="r")
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(h5)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
- index = model.index(0, 0, qt.QModelIndex())
- node1 = model.nodeFromIndex(index)
- model.synchronizeH5pyObject(h5)
- self.waitForPendingOperations(model)
- # Now h5 was loaded from it's filename
- # Another ref is owned by the model
- h5.close()
-
- index = model.index(0, 0, qt.QModelIndex())
- node2 = model.nodeFromIndex(index)
- self.assertIsNot(node1, node2)
- # after sync
- time.sleep(0.1)
- self.qapp.processEvents()
- time.sleep(0.1)
- index = model.index(0, 0, qt.QModelIndex())
- self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
- # clean up
- index = model.index(0, 0, qt.QModelIndex())
- h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
- self.assertIsNotNone(h5File)
- h5File = None
- # delete the model
- ref = weakref.ref(model)
- model = None
- self.qWaitForDestroy(ref)
-
- def testFileMoveState(self):
- model = hdf5.Hdf5TreeModel()
- self.assertEqual(model.isFileMoveEnabled(), True)
- model.setFileMoveEnabled(False)
- self.assertEqual(model.isFileMoveEnabled(), False)
-
- def testFileDropState(self):
- model = hdf5.Hdf5TreeModel()
- self.assertEqual(model.isFileDropEnabled(), True)
- model.setFileDropEnabled(False)
- self.assertEqual(model.isFileDropEnabled(), False)
-
- def testSupportedDrop(self):
- model = hdf5.Hdf5TreeModel()
- self.assertNotEqual(model.supportedDropActions(), 0)
-
- model.setFileMoveEnabled(False)
- model.setFileDropEnabled(False)
- self.assertEqual(model.supportedDropActions(), 0)
-
- model.setFileMoveEnabled(False)
- model.setFileDropEnabled(True)
- self.assertNotEqual(model.supportedDropActions(), 0)
-
- model.setFileMoveEnabled(True)
- model.setFileDropEnabled(False)
- self.assertNotEqual(model.supportedDropActions(), 0)
-
- def testCloseFile(self):
- """A file inserted as a filename is open and closed internally."""
- filename = _tmpDirectory + "/data.h5"
- model = hdf5.Hdf5TreeModel()
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- model.insertFile(filename)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
- index = model.index(0, 0)
- h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
- model.removeIndex(index)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- self.assertFalse(bool(h5File.id.valid), "The HDF5 file was not closed")
-
- def testNotCloseFile(self):
- """A file inserted as an h5py object is not open (then not closed)
- internally."""
- filename = _tmpDirectory + "/data.h5"
- try:
- h5File = h5py.File(filename, mode="r")
- model = hdf5.Hdf5TreeModel()
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- model.insertH5pyObject(h5File)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
- index = model.index(0, 0)
- h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
- model.removeIndex(index)
- self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
- self.assertTrue(bool(h5File.id.valid), "The HDF5 file was unexpetedly closed")
- finally:
- h5File.close()
-
- def testDropExternalFile(self):
- filename = _tmpDirectory + "/data.h5"
- model = hdf5.Hdf5TreeModel()
- mimeData = qt.QMimeData()
- mimeData.setUrls([qt.QUrl.fromLocalFile(filename)])
- model.dropMimeData(mimeData, qt.Qt.CopyAction, 0, 0, qt.QModelIndex())
- self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
- # after sync
- self.waitForPendingOperations(model)
- index = model.index(0, 0, qt.QModelIndex())
- self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
- # clean up
- index = model.index(0, 0, qt.QModelIndex())
- h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
- self.assertIsNotNone(h5File)
- h5File = None
- ref = weakref.ref(model)
- model = None
- self.qWaitForDestroy(ref)
-
- def getRowDataAsDict(self, model, row):
- displayed = {}
- roles = [qt.Qt.DisplayRole, qt.Qt.DecorationRole, qt.Qt.ToolTipRole, qt.Qt.TextAlignmentRole]
- for column in range(0, model.columnCount(qt.QModelIndex())):
- index = model.index(0, column, qt.QModelIndex())
- for role in roles:
- datum = model.data(index, role)
- displayed[column, role] = datum
- return displayed
-
- def getItemName(self, model, row):
- index = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, qt.QModelIndex())
- return model.data(index, qt.Qt.DisplayRole)
-
- def testFileData(self):
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(h5)
- displayed = self.getRowDataAsDict(model, row=0)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock")
- self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], None)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "File")
-
- def testGroupData(self):
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- d = h5.create_group("foo")
- d.attrs["desc"] = "fooo"
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(d)
- displayed = self.getRowDataAsDict(model, row=0)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
- self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "fooo")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Group")
-
- def testDatasetData(self):
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- value = numpy.array([1, 2, 3])
- d = h5.create_dataset("foo", data=value)
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(d)
- displayed = self.getRowDataAsDict(model, row=0)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
- self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], value.dtype.name)
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "3")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "[1 2 3]")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "[1 2 3]")
- self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Dataset")
-
- def testDropLastAsFirst(self):
- model = hdf5.Hdf5TreeModel()
- h5_1 = commonh5.File("/foo/bar/1.mock", "w")
- h5_2 = commonh5.File("/foo/bar/2.mock", "w")
- model.insertH5pyObject(h5_1)
- model.insertH5pyObject(h5_2)
- self.assertEqual(self.getItemName(model, 0), "1.mock")
- self.assertEqual(self.getItemName(model, 1), "2.mock")
- index = model.index(1, 0, qt.QModelIndex())
- mimeData = model.mimeData([index])
- model.dropMimeData(mimeData, qt.Qt.MoveAction, 0, 0, qt.QModelIndex())
- self.assertEqual(self.getItemName(model, 0), "2.mock")
- self.assertEqual(self.getItemName(model, 1), "1.mock")
-
- def testDropFirstAsLast(self):
- model = hdf5.Hdf5TreeModel()
- h5_1 = commonh5.File("/foo/bar/1.mock", "w")
- h5_2 = commonh5.File("/foo/bar/2.mock", "w")
- model.insertH5pyObject(h5_1)
- model.insertH5pyObject(h5_2)
- self.assertEqual(self.getItemName(model, 0), "1.mock")
- self.assertEqual(self.getItemName(model, 1), "2.mock")
- index = model.index(0, 0, qt.QModelIndex())
- mimeData = model.mimeData([index])
- model.dropMimeData(mimeData, qt.Qt.MoveAction, 2, 0, qt.QModelIndex())
- self.assertEqual(self.getItemName(model, 0), "2.mock")
- self.assertEqual(self.getItemName(model, 1), "1.mock")
-
- def testRootParent(self):
- model = hdf5.Hdf5TreeModel()
- h5_1 = commonh5.File("/foo/bar/1.mock", "w")
- model.insertH5pyObject(h5_1)
- index = model.index(0, 0, qt.QModelIndex())
- index = model.parent(index)
- self.assertEqual(index, qt.QModelIndex())
-
-
-class TestHdf5TreeModelSignals(TestCaseQt):
-
- def setUp(self):
- TestCaseQt.setUp(self)
- self.model = hdf5.Hdf5TreeModel()
- filename = _tmpDirectory + "/data.h5"
- self.h5 = h5py.File(filename, mode='r')
- self.model.insertH5pyObject(self.h5)
-
- self.listener = SignalListener()
- self.model.sigH5pyObjectLoaded.connect(self.listener.partial(signal="loaded"))
- self.model.sigH5pyObjectRemoved.connect(self.listener.partial(signal="removed"))
- self.model.sigH5pyObjectSynchronized.connect(self.listener.partial(signal="synchronized"))
-
- def tearDown(self):
- self.signals = None
- ref = weakref.ref(self.model)
- self.model = None
- self.qWaitForDestroy(ref)
- self.h5.close()
- self.h5 = None
- TestCaseQt.tearDown(self)
-
- def waitForPendingOperations(self, model):
- for _ in range(10):
- if not model.hasPendingOperations():
- break
- self.qWait(10)
- else:
- raise RuntimeError("Still waiting for a pending operation")
-
- def testInsert(self):
- filename = _tmpDirectory + "/data.h5"
- h5 = h5py.File(filename, mode='r')
- self.model.insertH5pyObject(h5)
- self.assertEqual(self.listener.callCount(), 0)
-
- def testLoaded(self):
- filename = _tmpDirectory + "/data.h5"
- self.model.insertFile(filename)
- self.assertEqual(self.listener.callCount(), 1)
- self.assertEqual(self.listener.karguments(argumentName="signal")[0], "loaded")
- self.assertIsNot(self.listener.arguments(callIndex=0)[0], self.h5)
- self.assertEqual(self.listener.arguments(callIndex=0)[0].filename, filename)
-
- def testRemoved(self):
- self.model.removeH5pyObject(self.h5)
- self.assertEqual(self.listener.callCount(), 1)
- self.assertEqual(self.listener.karguments(argumentName="signal")[0], "removed")
- self.assertIs(self.listener.arguments(callIndex=0)[0], self.h5)
-
- def testSynchonized(self):
- self.model.synchronizeH5pyObject(self.h5)
- self.waitForPendingOperations(self.model)
- self.assertEqual(self.listener.callCount(), 1)
- self.assertEqual(self.listener.karguments(argumentName="signal")[0], "synchronized")
- self.assertIs(self.listener.arguments(callIndex=0)[0], self.h5)
- self.assertIsNot(self.listener.arguments(callIndex=0)[1], self.h5)
-
-
-class TestNexusSortFilterProxyModel(TestCaseQt):
-
- def getChildNames(self, model, index):
- count = model.rowCount(index)
- result = []
- for row in range(0, count):
- itemIndex = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, index)
- name = model.data(itemIndex, qt.Qt.DisplayRole)
- result.append(name)
- return result
-
- def testNXentryStartTime(self):
- """Test NXentry with start_time"""
- model = hdf5.Hdf5TreeModel()
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- create_NXentry(h5, "a").create_dataset("start_time", data=numpy.string_("2015"))
- create_NXentry(h5, "b").create_dataset("start_time", data=numpy.string_("2013"))
- create_NXentry(h5, "c").create_dataset("start_time", data=numpy.string_("2014"))
- model.insertH5pyObject(h5)
-
- proxy = hdf5.NexusSortFilterProxyModel()
- proxy.setSourceModel(model)
- proxy.sort(0, qt.Qt.DescendingOrder)
- names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
- self.assertListEqual(names, ["a", "c", "b"])
-
- def testNXentryStartTimeInArray(self):
- """Test NXentry with start_time"""
- model = hdf5.Hdf5TreeModel()
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- create_NXentry(h5, "a").create_dataset("start_time", data=numpy.array([numpy.string_("2015")]))
- create_NXentry(h5, "b").create_dataset("start_time", data=numpy.array([numpy.string_("2013")]))
- create_NXentry(h5, "c").create_dataset("start_time", data=numpy.array([numpy.string_("2014")]))
- model.insertH5pyObject(h5)
-
- proxy = hdf5.NexusSortFilterProxyModel()
- proxy.setSourceModel(model)
- proxy.sort(0, qt.Qt.DescendingOrder)
- names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
- self.assertListEqual(names, ["a", "c", "b"])
-
- def testNXentryEndTimeInArray(self):
- """Test NXentry with end_time"""
- model = hdf5.Hdf5TreeModel()
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- create_NXentry(h5, "a").create_dataset("end_time", data=numpy.array([numpy.string_("2015")]))
- create_NXentry(h5, "b").create_dataset("end_time", data=numpy.array([numpy.string_("2013")]))
- create_NXentry(h5, "c").create_dataset("end_time", data=numpy.array([numpy.string_("2014")]))
- model.insertH5pyObject(h5)
-
- proxy = hdf5.NexusSortFilterProxyModel()
- proxy.setSourceModel(model)
- proxy.sort(0, qt.Qt.DescendingOrder)
- names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
- self.assertListEqual(names, ["a", "c", "b"])
-
- def testNXentryName(self):
- """Test NXentry without start_time or end_time"""
- model = hdf5.Hdf5TreeModel()
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- create_NXentry(h5, "a")
- create_NXentry(h5, "c")
- create_NXentry(h5, "b")
- model.insertH5pyObject(h5)
-
- proxy = hdf5.NexusSortFilterProxyModel()
- proxy.setSourceModel(model)
- proxy.sort(0, qt.Qt.AscendingOrder)
- names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
- self.assertListEqual(names, ["a", "b", "c"])
-
- def testStartTime(self):
- """If it is not NXentry, start_time is not used"""
- model = hdf5.Hdf5TreeModel()
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- h5.create_group("a").create_dataset("start_time", data=numpy.string_("2015"))
- h5.create_group("b").create_dataset("start_time", data=numpy.string_("2013"))
- h5.create_group("c").create_dataset("start_time", data=numpy.string_("2014"))
- model.insertH5pyObject(h5)
-
- proxy = hdf5.NexusSortFilterProxyModel()
- proxy.setSourceModel(model)
- proxy.sort(0, qt.Qt.AscendingOrder)
- names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
- self.assertListEqual(names, ["a", "b", "c"])
-
- def testName(self):
- model = hdf5.Hdf5TreeModel()
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- h5.create_group("a")
- h5.create_group("c")
- h5.create_group("b")
- model.insertH5pyObject(h5)
-
- proxy = hdf5.NexusSortFilterProxyModel()
- proxy.setSourceModel(model)
- proxy.sort(0, qt.Qt.AscendingOrder)
- names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
- self.assertListEqual(names, ["a", "b", "c"])
-
- def testNumber(self):
- model = hdf5.Hdf5TreeModel()
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- h5.create_group("a1")
- h5.create_group("a20")
- h5.create_group("a3")
- model.insertH5pyObject(h5)
-
- proxy = hdf5.NexusSortFilterProxyModel()
- proxy.setSourceModel(model)
- proxy.sort(0, qt.Qt.AscendingOrder)
- names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
- self.assertListEqual(names, ["a1", "a3", "a20"])
-
- def testMultiNumber(self):
- model = hdf5.Hdf5TreeModel()
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- h5.create_group("a1-1")
- h5.create_group("a20-1")
- h5.create_group("a3-1")
- h5.create_group("a3-20")
- h5.create_group("a3-3")
- model.insertH5pyObject(h5)
-
- proxy = hdf5.NexusSortFilterProxyModel()
- proxy.setSourceModel(model)
- proxy.sort(0, qt.Qt.AscendingOrder)
- names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
- self.assertListEqual(names, ["a1-1", "a3-1", "a3-3", "a3-20", "a20-1"])
-
- def testUnconsistantTypes(self):
- model = hdf5.Hdf5TreeModel()
- h5 = commonh5.File("/foo/bar/1.mock", "w")
- h5.create_group("aaa100")
- h5.create_group("100aaa")
- model.insertH5pyObject(h5)
-
- proxy = hdf5.NexusSortFilterProxyModel()
- proxy.setSourceModel(model)
- proxy.sort(0, qt.Qt.AscendingOrder)
- names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
- self.assertListEqual(names, ["100aaa", "aaa100"])
-
-
-class _TestModelBase(TestCaseQt):
-
- @classmethod
- def setUpClass(cls):
- super(_TestModelBase, cls).setUpClass()
-
- cls.tmpDirectory = tempfile.mkdtemp()
- cls.h5Filename = cls.createResource(cls.tmpDirectory)
- cls.h5File = h5py.File(cls.h5Filename, mode="r")
- cls.model = cls.createModel(cls.h5File)
-
- @classmethod
- def createResource(cls, directory):
- filename = os.path.join(directory, "base.h5")
- extH5FileName = os.path.join(directory, "base__external.h5")
- extDatFileName = os.path.join(directory, "base__external.dat")
-
- externalh5 = h5py.File(extH5FileName, mode="w")
- externalh5["target/dataset"] = 50
- externalh5["target/link"] = h5py.SoftLink("/target/dataset")
- externalh5["/ext/vds0"] = [0, 1]
- externalh5["/ext/vds1"] = [2, 3]
- externalh5.close()
-
- numpy.array([0,1,10,10,2,3]).tofile(extDatFileName)
-
- h5 = h5py.File(filename, mode="w")
- h5["group/dataset"] = 50
- h5["link/soft_link"] = h5py.SoftLink("/group/dataset")
- h5["link/soft_link_to_group"] = h5py.SoftLink("/group")
- h5["link/soft_link_to_link"] = h5py.SoftLink("/link/soft_link")
- h5["link/soft_link_to_file"] = h5py.SoftLink("/")
- h5["group/soft_link_relative"] = h5py.SoftLink("dataset")
- h5["link/external_link"] = h5py.ExternalLink(extH5FileName, "/target/dataset")
- h5["link/external_link_to_link"] = h5py.ExternalLink(extH5FileName, "/target/link")
- h5["broken_link/external_broken_file"] = h5py.ExternalLink(extH5FileName + "_not_exists", "/target/link")
- h5["broken_link/external_broken_link"] = h5py.ExternalLink(extH5FileName, "/target/not_exists")
- h5["broken_link/soft_broken_link"] = h5py.SoftLink("/group/not_exists")
- h5["broken_link/soft_link_to_broken_link"] = h5py.SoftLink("/group/not_exists")
- layout = h5py.VirtualLayout((2,2), dtype=int)
- layout[0] = h5py.VirtualSource("base__external.h5", name="/ext/vds0", shape=(2,), dtype=int)
- layout[1] = h5py.VirtualSource("base__external.h5", name="/ext/vds1", shape=(2,), dtype=int)
- h5.create_group("/ext")
- h5["/ext"].create_virtual_dataset("virtual", layout)
- external = [("base__external.dat", 0, 2*8), ("base__external.dat", 4*8, 2*8)]
- h5["/ext"].create_dataset("raw", shape=(2,2), dtype=int, external=external)
- h5.close()
-
- return filename
-
- @classmethod
- def createModel(cls, h5pyFile):
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(h5pyFile)
- return model
-
- @classmethod
- def tearDownClass(cls):
- ref = weakref.ref(cls.model)
- cls.model = None
- cls.qWaitForDestroy(ref)
- cls.h5File.close()
- shutil.rmtree(cls.tmpDirectory)
- super(_TestModelBase, cls).tearDownClass()
-
- def getIndexFromPath(self, model, path):
- """
- :param qt.QAbstractItemModel: model
- """
- index = qt.QModelIndex()
- for name in path:
- for row in range(model.rowCount(index)):
- i = model.index(row, 0, index)
- label = model.data(i)
- if label == name:
- index = i
- break
- else:
- raise RuntimeError("Path not found")
- return index
-
- def getH5ItemFromPath(self, model, path):
- index = self.getIndexFromPath(model, path)
- return model.data(index, hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE)
-
-
-class TestH5Item(_TestModelBase):
-
- def testFile(self):
- path = ["base.h5"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
-
- def testGroup(self):
- path = ["base.h5", "group"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
-
- def testDataset(self):
- path = ["base.h5", "group", "dataset"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
-
- def testSoftLink(self):
- path = ["base.h5", "link", "soft_link"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
-
- def testSoftLinkToLink(self):
- path = ["base.h5", "link", "soft_link_to_link"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
-
- def testSoftLinkRelative(self):
- path = ["base.h5", "group", "soft_link_relative"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
-
- def testExternalLink(self):
- path = ["base.h5", "link", "external_link"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
-
- def testExternalLinkToLink(self):
- path = ["base.h5", "link", "external_link_to_link"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
-
- def testExternalBrokenFile(self):
- path = ["base.h5", "broken_link", "external_broken_file"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
-
- def testExternalBrokenLink(self):
- path = ["base.h5", "broken_link", "external_broken_link"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
-
- def testSoftBrokenLink(self):
- path = ["base.h5", "broken_link", "soft_broken_link"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
-
- def testSoftLinkToBrokenLink(self):
- path = ["base.h5", "broken_link", "soft_link_to_broken_link"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
-
- def testDatasetFromSoftLinkToGroup(self):
- path = ["base.h5", "link", "soft_link_to_group", "dataset"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
-
- def testDatasetFromSoftLinkToFile(self):
- path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
-
- def testExternalVirtual(self):
- path = ["base.h5", "ext", "virtual"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Virtual")
-
- def testExternalRaw(self):
- path = ["base.h5", "ext", "raw"]
- h5item = self.getH5ItemFromPath(self.model, path)
-
- self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "ExtRaw")
-
-
-class TestH5Node(_TestModelBase):
-
- def getH5NodeFromPath(self, model, path):
- item = self.getH5ItemFromPath(model, path)
- h5node = hdf5.H5Node(item)
- return h5node
-
- def testFile(self):
- path = ["base.h5"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "")
- self.assertEqual(h5node.physical_name, "/")
- self.assertEqual(h5node.local_basename, "")
- self.assertEqual(h5node.local_name, "/")
-
- def testGroup(self):
- path = ["base.h5", "group"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "group")
- self.assertEqual(h5node.physical_name, "/group")
- self.assertEqual(h5node.local_basename, "group")
- self.assertEqual(h5node.local_name, "/group")
-
- def testDataset(self):
- path = ["base.h5", "group", "dataset"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "dataset")
- self.assertEqual(h5node.physical_name, "/group/dataset")
- self.assertEqual(h5node.local_basename, "dataset")
- self.assertEqual(h5node.local_name, "/group/dataset")
-
- def testSoftLink(self):
- path = ["base.h5", "link", "soft_link"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "dataset")
- self.assertEqual(h5node.physical_name, "/group/dataset")
- self.assertEqual(h5node.local_basename, "soft_link")
- self.assertEqual(h5node.local_name, "/link/soft_link")
-
- def testSoftLinkToLink(self):
- path = ["base.h5", "link", "soft_link_to_link"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "dataset")
- self.assertEqual(h5node.physical_name, "/group/dataset")
- self.assertEqual(h5node.local_basename, "soft_link_to_link")
- self.assertEqual(h5node.local_name, "/link/soft_link_to_link")
-
- def testSoftLinkRelative(self):
- path = ["base.h5", "group", "soft_link_relative"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "dataset")
- self.assertEqual(h5node.physical_name, "/group/dataset")
- self.assertEqual(h5node.local_basename, "soft_link_relative")
- self.assertEqual(h5node.local_name, "/group/soft_link_relative")
-
- def testExternalLink(self):
- path = ["base.h5", "link", "external_link"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.local_filename)
- self.assertIn("base__external.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "dataset")
- self.assertEqual(h5node.physical_name, "/target/dataset")
- self.assertEqual(h5node.local_basename, "external_link")
- self.assertEqual(h5node.local_name, "/link/external_link")
-
- def testExternalLinkToLink(self):
- path = ["base.h5", "link", "external_link_to_link"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.local_filename)
- self.assertIn("base__external.h5", h5node.physical_filename)
-
- self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
- self.assertEqual(h5node.physical_basename, "dataset")
- self.assertEqual(h5node.physical_name, "/target/dataset")
- self.assertEqual(h5node.local_basename, "external_link_to_link")
- self.assertEqual(h5node.local_name, "/link/external_link_to_link")
-
- def testExternalBrokenFile(self):
- path = ["base.h5", "broken_link", "external_broken_file"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.local_filename)
- self.assertIn("not_exists", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "link")
- self.assertEqual(h5node.physical_name, "/target/link")
- self.assertEqual(h5node.local_basename, "external_broken_file")
- self.assertEqual(h5node.local_name, "/broken_link/external_broken_file")
-
- def testExternalBrokenLink(self):
- path = ["base.h5", "broken_link", "external_broken_link"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.local_filename)
- self.assertIn("__external", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "not_exists")
- self.assertEqual(h5node.physical_name, "/target/not_exists")
- self.assertEqual(h5node.local_basename, "external_broken_link")
- self.assertEqual(h5node.local_name, "/broken_link/external_broken_link")
-
- def testSoftBrokenLink(self):
- path = ["base.h5", "broken_link", "soft_broken_link"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "not_exists")
- self.assertEqual(h5node.physical_name, "/group/not_exists")
- self.assertEqual(h5node.local_basename, "soft_broken_link")
- self.assertEqual(h5node.local_name, "/broken_link/soft_broken_link")
-
- def testSoftLinkToBrokenLink(self):
- path = ["base.h5", "broken_link", "soft_link_to_broken_link"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "not_exists")
- self.assertEqual(h5node.physical_name, "/group/not_exists")
- self.assertEqual(h5node.local_basename, "soft_link_to_broken_link")
- self.assertEqual(h5node.local_name, "/broken_link/soft_link_to_broken_link")
-
- def testDatasetFromSoftLinkToGroup(self):
- path = ["base.h5", "link", "soft_link_to_group", "dataset"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "dataset")
- self.assertEqual(h5node.physical_name, "/group/dataset")
- self.assertEqual(h5node.local_basename, "dataset")
- self.assertEqual(h5node.local_name, "/link/soft_link_to_group/dataset")
-
- def testDatasetFromSoftLinkToFile(self):
- path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "dataset")
- self.assertEqual(h5node.physical_name, "/group/dataset")
- self.assertEqual(h5node.local_basename, "dataset")
- self.assertEqual(h5node.local_name, "/link/soft_link_to_file/link/soft_link_to_group/dataset")
-
- def testExternalVirtual(self):
- path = ["base.h5", "ext", "virtual"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "virtual")
- self.assertEqual(h5node.physical_name, "/ext/virtual")
- self.assertEqual(h5node.local_basename, "virtual")
- self.assertEqual(h5node.local_name, "/ext/virtual")
-
- def testExternalRaw(self):
- path = ["base.h5", "ext", "raw"]
- h5node = self.getH5NodeFromPath(self.model, path)
-
- self.assertEqual(h5node.physical_filename, h5node.local_filename)
- self.assertIn("base.h5", h5node.physical_filename)
- self.assertEqual(h5node.physical_basename, "raw")
- self.assertEqual(h5node.physical_name, "/ext/raw")
- self.assertEqual(h5node.local_basename, "raw")
- self.assertEqual(h5node.local_name, "/ext/raw")
-
-
-class TestHdf5TreeView(TestCaseQt):
- """Test to check that icons module."""
-
- def setUp(self):
- super(TestHdf5TreeView, self).setUp()
-
- def testCreate(self):
- view = hdf5.Hdf5TreeView()
- self.assertIsNotNone(view)
-
- def testContextMenu(self):
- view = hdf5.Hdf5TreeView()
- view._createContextMenu(qt.QPoint(0, 0))
-
- def testSelection_OriginalModel(self):
- tree = commonh5.File("/foo/bar/1.mock", "w")
- item = tree.create_group("a/b/c/d")
- item.create_group("e").create_group("f")
-
- view = hdf5.Hdf5TreeView()
- view.findHdf5TreeModel().insertH5pyObject(tree)
- view.setSelectedH5Node(item)
-
- selected = list(view.selectedH5Nodes())[0]
- self.assertIs(item, selected.h5py_object)
-
- def testSelection_Simple(self):
- tree = commonh5.File("/foo/bar/1.mock", "w")
- item = tree.create_group("a/b/c/d")
- item.create_group("e").create_group("f")
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(tree)
- view = hdf5.Hdf5TreeView()
- view.setModel(model)
- view.setSelectedH5Node(item)
-
- selected = list(view.selectedH5Nodes())[0]
- self.assertIs(item, selected.h5py_object)
-
- def testSelection_NotFound(self):
- tree2 = commonh5.File("/foo/bar/2.mock", "w")
- tree = commonh5.File("/foo/bar/1.mock", "w")
- item = tree.create_group("a/b/c/d")
- item.create_group("e").create_group("f")
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(tree)
- view = hdf5.Hdf5TreeView()
- view.setModel(model)
- view.setSelectedH5Node(tree2)
-
- selection = list(view.selectedH5Nodes())
- self.assertEqual(len(selection), 0)
-
- def testSelection_ManyGroupFromSameFile(self):
- tree = commonh5.File("/foo/bar/1.mock", "w")
- group1 = tree.create_group("a1")
- group2 = tree.create_group("a2")
- group3 = tree.create_group("a3")
- group1.create_group("b/c/d")
- item = group2.create_group("b/c/d")
- group3.create_group("b/c/d")
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(group1)
- model.insertH5pyObject(group2)
- model.insertH5pyObject(group3)
- view = hdf5.Hdf5TreeView()
- view.setModel(model)
- view.setSelectedH5Node(item)
-
- selected = list(view.selectedH5Nodes())[0]
- self.assertIs(item, selected.h5py_object)
-
- def testSelection_RootFromSubTree(self):
- tree = commonh5.File("/foo/bar/1.mock", "w")
- group = tree.create_group("a1")
- group.create_group("b/c/d")
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(group)
- view = hdf5.Hdf5TreeView()
- view.setModel(model)
- view.setSelectedH5Node(group)
-
- selected = list(view.selectedH5Nodes())[0]
- self.assertIs(group, selected.h5py_object)
-
- def testSelection_FileFromSubTree(self):
- tree = commonh5.File("/foo/bar/1.mock", "w")
- group = tree.create_group("a1")
- group.create_group("b").create_group("b").create_group("d")
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(group)
- view = hdf5.Hdf5TreeView()
- view.setModel(model)
- view.setSelectedH5Node(tree)
-
- selection = list(view.selectedH5Nodes())
- self.assertEqual(len(selection), 0)
-
- def testSelection_Tree(self):
- tree1 = commonh5.File("/foo/bar/1.mock", "w")
- tree2 = commonh5.File("/foo/bar/2.mock", "w")
- tree3 = commonh5.File("/foo/bar/3.mock", "w")
- tree1.create_group("a/b/c")
- tree2.create_group("a/b/c")
- tree3.create_group("a/b/c")
- item = tree2
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(tree1)
- model.insertH5pyObject(tree2)
- model.insertH5pyObject(tree3)
- view = hdf5.Hdf5TreeView()
- view.setModel(model)
- view.setSelectedH5Node(item)
-
- selected = list(view.selectedH5Nodes())[0]
- self.assertIs(item, selected.h5py_object)
-
- def testSelection_RecurssiveLink(self):
- """
- Recurssive link selection
-
- This example is not really working as expected cause commonh5 do not
- support recurssive links.
- But item.name == "/a/b" and the result is found.
- """
- tree = commonh5.File("/foo/bar/1.mock", "w")
- group = tree.create_group("a")
- group.add_node(commonh5.SoftLink("b", "/"))
-
- item = tree["/a/b/a/b/a/b/a/b/a/b/a/b/a/b/a/b"]
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(tree)
- view = hdf5.Hdf5TreeView()
- view.setModel(model)
- view.setSelectedH5Node(item)
-
- selected = list(view.selectedH5Nodes())[0]
- self.assertEqual(item.name, selected.h5py_object.name)
-
- def testSelection_SelectNone(self):
- tree = commonh5.File("/foo/bar/1.mock", "w")
-
- model = hdf5.Hdf5TreeModel()
- model.insertH5pyObject(tree)
- view = hdf5.Hdf5TreeView()
- view.setModel(model)
- view.setSelectedH5Node(tree)
- view.setSelectedH5Node(None)
-
- selection = list(view.selectedH5Nodes())
- self.assertEqual(len(selection), 0)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestHdf5TreeModel))
- test_suite.addTest(loadTests(TestHdf5TreeModelSignals))
- test_suite.addTest(loadTests(TestNexusSortFilterProxyModel))
- test_suite.addTest(loadTests(TestHdf5TreeView))
- test_suite.addTest(loadTests(TestH5Node))
- test_suite.addTest(loadTests(TestH5Item))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/AlphaSlider.py b/silx/gui/plot/AlphaSlider.py
deleted file mode 100644
index ab2e5aa..0000000
--- a/silx/gui/plot/AlphaSlider.py
+++ /dev/null
@@ -1,300 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module defines slider widgets interacting with the transparency
-of an image on a :class:`PlotWidget`
-
-Classes:
---------
-
-- :class:`BaseAlphaSlider` (abstract class)
-- :class:`NamedImageAlphaSlider`
-- :class:`ActiveImageAlphaSlider`
-
-Example:
---------
-
-This widget can, for instance, be added to a plot toolbar.
-
-.. code-block:: python
-
- import numpy
- from silx.gui import qt
- from silx.gui.plot import PlotWidget
- from silx.gui.plot.ImageAlphaSlider import NamedImageAlphaSlider
-
- app = qt.QApplication([])
- pw = PlotWidget()
-
- img0 = numpy.arange(200*150).reshape((200, 150))
- pw.addImage(img0, legend="my background", z=0, origin=(50, 50))
-
- x, y = numpy.meshgrid(numpy.linspace(-10, 10, 200),
- numpy.linspace(-10, 5, 150),
- indexing="ij")
- img1 = numpy.asarray(numpy.sin(x * y) / (x * y),
- dtype='float32')
-
- pw.addImage(img1, legend="my data", z=1,
- replace=False)
-
- alpha_slider = NamedImageAlphaSlider(parent=pw,
- plot=pw,
- legend="my data")
- alpha_slider.setOrientation(qt.Qt.Horizontal)
-
- toolbar = qt.QToolBar("plot", pw)
- toolbar.addWidget(alpha_slider)
- pw.addToolBar(toolbar)
-
- pw.show()
- app.exec_()
-
-"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "24/03/2017"
-
-import logging
-
-from silx.gui import qt
-
-_logger = logging.getLogger(__name__)
-
-
-class BaseAlphaSlider(qt.QSlider):
- """Slider widget to be used in a plot toolbar to control the
- transparency of a plot primitive (image, scatter or curve).
-
- Internally, the slider stores its state as an integer between
- 0 and 255. This is the value emitted by the :attr:`valueChanged`
- signal.
-
- The method :meth:`getAlpha` returns the corresponding opacity/alpha
- as a float between 0. and 1. (with a step of :math:`\frac{1}{255}`).
-
- You must subclass this class and implement :meth:`getItem`.
- """
- sigAlphaChanged = qt.Signal(float)
- """Emits the alpha value when the slider's value changes,
- as a float between 0. and 1."""
-
- def __init__(self, parent=None, plot=None):
- """
-
- :param parent: Parent QWidget
- :param plot: Parent plot widget
- """
- assert plot is not None
- super(BaseAlphaSlider, self).__init__(parent)
-
- self.plot = plot
-
- self.setRange(0, 255)
-
- # if already connected to an item, use its alpha as initial value
- if self.getItem() is None:
- self.setValue(255)
- self.setEnabled(False)
- else:
- alpha = self.getItem().getAlpha()
- self.setValue(round(255*alpha))
-
- self.valueChanged.connect(self._valueChanged)
-
- def getItem(self):
- """You must implement this class to define which item
- to work on. It must return an item that inherits
- :class:`silx.gui.plot.items.core.AlphaMixIn`.
-
- :return: Item on which to operate, or None
- :rtype: :class:`silx.plot.items.Item`
- """
- raise NotImplementedError(
- "BaseAlphaSlider must be subclassed to " +
- "implement getItem()")
-
- def getAlpha(self):
- """Get the opacity, as a float between 0. and 1.
-
- :return: Alpha value in [0., 1.]
- :rtype: float
- """
- return self.value() / 255.
-
- def _valueChanged(self, value):
- self._updateItem()
- self.sigAlphaChanged.emit(value / 255.)
-
- def _updateItem(self):
- """Update the item's alpha channel.
- """
- item = self.getItem()
- if item is not None:
- item.setAlpha(self.getAlpha())
-
-
-class ActiveImageAlphaSlider(BaseAlphaSlider):
- """Slider widget to be used in a plot toolbar to control the
- transparency of the **active image**.
-
- :param parent: Parent QWidget
- :param plot: Plot on which to operate
-
- See documentation of :class:`BaseAlphaSlider`
- """
- def __init__(self, parent=None, plot=None):
- """
-
- :param parent: Parent QWidget
- :param plot: Plot widget on which to operate
- """
- super(ActiveImageAlphaSlider, self).__init__(parent, plot)
- plot.sigActiveImageChanged.connect(self._activeImageChanged)
-
- def getItem(self):
- return self.plot.getActiveImage()
-
- def _activeImageChanged(self, previous, new):
- """Activate or deactivate slider depending on presence of a new
- active image.
- Apply transparency value to new active image.
-
- :param previous: Legend of previous active image, or None
- :param new: Legend of new active image, or None
- """
- if new is not None and not self.isEnabled():
- self.setEnabled(True)
- elif new is None and self.isEnabled():
- self.setEnabled(False)
-
- self._updateItem()
-
-
-class NamedItemAlphaSlider(BaseAlphaSlider):
- """Slider widget to be used in a plot toolbar to control the
- transparency of an item (defined by its kind and legend).
-
- :param parent: Parent QWidget
- :param plot: Plot on which to operate
- :param str kind: Kind of item whose transparency is to be
- controlled: "scatter", "image" or "curve".
- :param str legend: Legend of item whose transparency is to be
- controlled.
- """
- def __init__(self, parent=None, plot=None,
- kind=None, legend=None):
- self._item_legend = legend
- self._item_kind = kind
-
- super(NamedItemAlphaSlider, self).__init__(parent, plot)
-
- self._updateState()
- plot.sigContentChanged.connect(self._onContentChanged)
-
- def _onContentChanged(self, action, kind, legend):
- if legend == self._item_legend and kind == self._item_kind:
- if action == "add":
- self.setEnabled(True)
- elif action == "remove":
- self.setEnabled(False)
-
- def _updateState(self):
- """Enable or disable widget based on item's availability."""
- if self.getItem() is not None:
- self.setEnabled(True)
- else:
- self.setEnabled(False)
-
- def getItem(self):
- """Return plot item currently associated to this widget (can be
- a curve, an image, a scatter...)
-
- :rtype: subclass of :class:`silx.gui.plot.items.Item`"""
- if self._item_legend is None or self._item_kind is None:
- return None
- return self.plot._getItem(kind=self._item_kind,
- legend=self._item_legend)
-
- def setLegend(self, legend):
- """Associate a different item (of the same kind) to the slider.
-
- :param legend: New legend of item whose transparency is to be
- controlled.
- """
- self._item_legend = legend
- self._updateState()
-
- def getLegend(self):
- """Return legend of the item currently controlled by this slider.
-
- :return: Image legend associated to the slider
- """
- return self._item_kind
-
- def setItemKind(self, legend):
- """Associate a different item (of the same kind) to the slider.
-
- :param legend: New legend of item whose transparency is to be
- controlled.
- """
- self._item_legend = legend
- self._updateState()
-
- def getItemKind(self):
- """Return kind of the item currently controlled by this slider.
-
- :return: Item kind ("image", "scatter"...)
- :rtype: str on None
- """
- return self._item_kind
-
-
-class NamedImageAlphaSlider(NamedItemAlphaSlider):
- """Slider widget to be used in a plot toolbar to control the
- transparency of an image (defined by its legend).
-
- :param parent: Parent QWidget
- :param plot: Plot on which to operate
- :param str legend: Legend of image whose transparency is to be
- controlled.
- """
- def __init__(self, parent=None, plot=None, legend=None):
- NamedItemAlphaSlider.__init__(self, parent, plot,
- kind="image", legend=legend)
-
-
-class NamedScatterAlphaSlider(NamedItemAlphaSlider):
- """Slider widget to be used in a plot toolbar to control the
- transparency of a scatter (defined by its legend).
-
- :param parent: Parent QWidget
- :param plot: Plot on which to operate
- :param str legend: Legend of scatter whose transparency is to be
- controlled.
- """
- def __init__(self, parent=None, plot=None, legend=None):
- NamedItemAlphaSlider.__init__(self, parent, plot,
- kind="scatter", legend=legend)
diff --git a/silx/gui/plot/ColorBar.py b/silx/gui/plot/ColorBar.py
deleted file mode 100644
index d869825..0000000
--- a/silx/gui/plot/ColorBar.py
+++ /dev/null
@@ -1,881 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Module containing several widgets associated to a colormap.
-"""
-
-__authors__ = ["H. Payno", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-
-import logging
-import weakref
-import numpy
-
-from ._utils import ticklayout
-from .. import qt
-from ..qt import inspect as qt_inspect
-from silx.gui import colors
-
-_logger = logging.getLogger(__name__)
-
-
-class ColorBarWidget(qt.QWidget):
- """Colorbar widget displaying a colormap
-
- It uses a description of colormap as dict compatible with :class:`Plot`.
-
- .. image:: img/linearColorbar.png
- :width: 80px
- :align: center
-
- To run the following sample code, a QApplication must be initialized.
-
- >>> from silx.gui.plot import Plot2D
- >>> from silx.gui.plot.ColorBar import ColorBarWidget
-
- >>> plot = Plot2D() # Create a plot widget
- >>> plot.show()
-
- >>> colorbar = ColorBarWidget(plot=plot, legend='Colormap') # Associate the colorbar with it
- >>> colorbar.show()
-
- Initializer parameters:
-
- :param parent: See :class:`QWidget`
- :param plot: PlotWidget the colorbar is attached to (optional)
- :param str legend: the label to set to the colorbar
- """
- sigVisibleChanged = qt.Signal(bool)
- """Emitted when the property `visible` have changed."""
-
- def __init__(self, parent=None, plot=None, legend=None):
- self._isConnected = False
- self._plotRef = None
- self._colormap = None
- self._data = None
-
- super(ColorBarWidget, self).__init__(parent)
-
- self.__buildGUI()
- self.setLegend(legend)
- self.setPlot(plot)
-
- def __buildGUI(self):
- self.setLayout(qt.QHBoxLayout())
-
- # create color scale widget
- self._colorScale = ColorScaleBar(parent=self,
- colormap=None)
- self.layout().addWidget(self._colorScale)
-
- # legend (is the right group)
- self.legend = _VerticalLegend('', self)
- self.layout().addWidget(self.legend)
-
- self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
-
- def getPlot(self):
- """Returns the :class:`Plot` associated to this widget or None"""
- return None if self._plotRef is None else self._plotRef()
-
- def setPlot(self, plot):
- """Associate a plot to the ColorBar
-
- :param plot: the plot to associate with the colorbar.
- If None will remove any connection with a previous plot.
- """
- self._disconnectPlot()
- self._plotRef = None if plot is None else weakref.ref(plot)
- self._connectPlot()
-
- def _disconnectPlot(self):
- """Disconnect from Plot signals"""
- if self._isConnected:
- self._isConnected = False
- plot = self.getPlot()
- if plot is not None and qt_inspect.isValid(plot):
- plot.sigActiveImageChanged.disconnect(
- self._activeImageChanged)
- plot.sigActiveScatterChanged.disconnect(
- self._activeScatterChanged)
- plot.sigPlotSignal.disconnect(self._defaultColormapChanged)
-
- def _connectPlot(self):
- """Connect to Plot signals"""
- plot = self.getPlot()
- if plot is not None and not self._isConnected:
- activeImageLegend = plot.getActiveImage(just_legend=True)
- activeScatterLegend = plot._getActiveItem(
- kind='scatter', just_legend=True)
- if activeImageLegend is None and activeScatterLegend is None:
- # Show plot default colormap
- self._syncWithDefaultColormap()
- elif activeImageLegend is not None: # Show active image colormap
- self._activeImageChanged(None, activeImageLegend)
- elif activeScatterLegend is not None: # Show active scatter colormap
- self._activeScatterChanged(None, activeScatterLegend)
-
- plot.sigActiveImageChanged.connect(self._activeImageChanged)
- plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
- plot.sigPlotSignal.connect(self._defaultColormapChanged)
- self._isConnected = True
-
- def setVisible(self, isVisible):
- qt.QWidget.setVisible(self, isVisible)
- self.sigVisibleChanged.emit(isVisible)
-
- def showEvent(self, event):
- self._connectPlot()
-
- def hideEvent(self, event):
- self._disconnectPlot()
-
- def getColormap(self):
- """Returns the colormap displayed in the colorbar.
-
- :rtype: ~silx.gui.colors.Colormap
- """
- return self.getColorScaleBar().getColormap()
-
- def setColormap(self, colormap, data=None):
- """Set the colormap to be displayed.
-
- :param ~silx.gui.colors.Colormap colormap:
- The colormap to apply on the ColorBarWidget
- :param Union[numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
- The data to display or item, needed if the colormap require an autoscale
- """
- self._data = data
- self.getColorScaleBar().setColormap(colormap=colormap,
- data=data)
- if self._colormap is not None:
- self._colormap.sigChanged.disconnect(self._colormapHasChanged)
- self._colormap = colormap
- if self._colormap is not None:
- self._colormap.sigChanged.connect(self._colormapHasChanged)
-
- def _colormapHasChanged(self):
- """handler of the Colormap.sigChanged signal
- """
- assert self._colormap is not None
- self.setColormap(colormap=self._colormap,
- data=self._data)
-
- def setLegend(self, legend):
- """Set the legend displayed along the colorbar
-
- :param str legend: The label
- """
- if legend is None or legend == "":
- self.legend.hide()
- self.legend.setText("")
- else:
- assert type(legend) is str
- self.legend.show()
- self.legend.setText(legend)
-
- def getLegend(self):
- """
- Returns the legend displayed along the colorbar
-
- :return: return the legend displayed along the colorbar
- :rtype: str
- """
- return self.legend.text()
-
- def _activeScatterChanged(self, previous, legend):
- """Handle plot active scatter changed"""
- plot = self.getPlot()
-
- # Do not handle active scatter while there is an image
- if plot.getActiveImage() is not None:
- return
-
- if legend is None: # No active scatter, display no colormap
- self.setColormap(colormap=None)
- return
-
- # Sync with active scatter
- scatter = plot._getActiveItem(kind='scatter')
-
- self.setColormap(colormap=scatter.getColormap(),
- data=scatter)
-
- def _activeImageChanged(self, previous, legend):
- """Handle plot active image changed"""
- plot = self.getPlot()
-
- if legend is None: # No active image, try with active scatter
- activeScatterLegend = plot._getActiveItem(
- kind='scatter', just_legend=True)
- # No more active image, use active scatter if any
- self._activeScatterChanged(None, activeScatterLegend)
- else:
- # Sync with active image
- image = plot.getActiveImage()
-
- # RGB(A) image, display default colormap
- array = image.getData(copy=False)
- if array.ndim != 2:
- self.setColormap(colormap=None)
- return
-
- # data image, sync with image colormap
- # do we need the copy here : used in the case we are changing
- # vmin and vmax but should have already be done by the plot
- self.setColormap(colormap=image.getColormap(), data=image)
-
- def _defaultColormapChanged(self, event):
- """Handle plot default colormap changed"""
- if event['event'] == 'defaultColormapChanged':
- plot = self.getPlot()
- if (plot is not None and
- plot.getActiveImage() is None and
- plot._getActiveItem(kind='scatter') is None):
- # No active item, take default colormap update into account
- self._syncWithDefaultColormap()
-
- def _syncWithDefaultColormap(self):
- """Update colorbar according to plot default colormap"""
- self.setColormap(self.getPlot().getDefaultColormap())
-
- def getColorScaleBar(self):
- """
-
- :return: return the :class:`ColorScaleBar` used to display ColorScale
- and ticks"""
- return self._colorScale
-
-
-class _VerticalLegend(qt.QLabel):
- """Display vertically the given text
- """
- def __init__(self, text, parent=None):
- """
-
- :param text: the legend
- :param parent: the Qt parent if any
- """
- qt.QLabel.__init__(self, text, parent)
- self.setLayout(qt.QVBoxLayout())
- self.layout().setContentsMargins(0, 0, 0, 0)
-
- def paintEvent(self, event):
- painter = qt.QPainter(self)
- painter.setFont(self.font())
-
- painter.translate(0, self.rect().height())
- painter.rotate(270)
- newRect = qt.QRect(0, 0, self.rect().height(), self.rect().width())
-
- painter.drawText(newRect, qt.Qt.AlignHCenter, self.text())
-
- fm = qt.QFontMetrics(self.font())
- preferedHeight = fm.width(self.text())
- preferedWidth = fm.height()
- self.setFixedWidth(preferedWidth)
- self.setMinimumHeight(preferedHeight)
-
-
-class ColorScaleBar(qt.QWidget):
- """This class is making the composition of a :class:`_ColorScale` and a
- :class:`_TickBar`.
-
- It is the simplest widget displaying ticks and colormap gradient.
-
- .. image:: img/colorScaleBar.png
- :width: 150px
- :align: center
-
- To run the following sample code, a QApplication must be initialized.
-
- >>> colormap = Colormap(name='gray',
- ... norm='log',
- ... vmin=1,
- ... vmax=100000,
- ... )
- >>> colorscale = ColorScaleBar(parent=None,
- ... colormap=colormap )
- >>> colorscale.show()
-
- Initializer parameters :
-
- :param colormap: the colormap to be displayed
- :param parent: the Qt parent if any
- :param displayTicksValues: display the ticks value or only the '-'
- """
-
- _TEXT_MARGIN = 5
- """The tick bar need a margin to display all labels at the correct place.
- So the ColorScale should have the same margin in order for both to fit"""
-
- def __init__(self, parent=None, colormap=None, data=None,
- displayTicksValues=True):
- super(ColorScaleBar, self).__init__(parent)
-
- self.minVal = None
- """Value set to the _minLabel"""
- self.maxVal = None
- """Value set to the _maxLabel"""
-
- self.setLayout(qt.QGridLayout())
-
- # create the left side group (ColorScale)
- self.colorScale = _ColorScale(colormap=colormap,
- data=data,
- parent=self,
- margin=ColorScaleBar._TEXT_MARGIN)
- if colormap:
- vmin, vmax = colormap.getColormapRange(data)
- normalizer = colormap._getNormalizer()
- else:
- vmin, vmax = colors.DEFAULT_MIN_LIN, colors.DEFAULT_MAX_LIN
- normalizer = None
-
- self.tickbar = _TickBar(vmin=vmin,
- vmax=vmax,
- normalizer=normalizer,
- parent=self,
- displayValues=displayTicksValues,
- margin=ColorScaleBar._TEXT_MARGIN)
-
- self.layout().addWidget(self.tickbar, 1, 0, 1, 1, qt.Qt.AlignRight)
- self.layout().addWidget(self.colorScale, 1, 1, qt.Qt.AlignLeft)
-
- self.layout().setContentsMargins(0, 0, 0, 0)
- self.layout().setSpacing(0)
-
- # max label
- self._maxLabel = qt.QLabel(str(1.0), parent=self)
- self._maxLabel.setToolTip(str(0.0))
- self.layout().addWidget(self._maxLabel, 0, 0, 1, 2, qt.Qt.AlignRight)
-
- # min label
- self._minLabel = qt.QLabel(str(0.0), parent=self)
- self._minLabel.setToolTip(str(0.0))
- self.layout().addWidget(self._minLabel, 2, 0, 1, 2, qt.Qt.AlignRight)
-
- self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
- self.layout().setColumnStretch(0, 1)
- self.layout().setRowStretch(1, 1)
-
- def getTickBar(self):
- """
-
- :return: the instanciation of the :class:`_TickBar`
- """
- return self.tickbar
-
- def getColorScale(self):
- """
-
- :return: the instanciation of the :class:`_ColorScale`
- """
- return self.colorScale
-
- def getColormap(self):
- """
-
- :returns: the colormap.
- :rtype: :class:`.Colormap`
- """
- return self.colorScale.getColormap()
-
- def setColormap(self, colormap, data=None):
- """Set the new colormap to be displayed
-
- :param Colormap colormap: the colormap to set
- :param Union[numpy.ndarray,~silx.gui.plot.items.Item] data:
- The data or item to display, needed if the colormap requires an autoscale
- """
- self.colorScale.setColormap(colormap, data)
-
- if colormap is not None:
- vmin, vmax = colormap.getColormapRange(data)
- normalizer = colormap._getNormalizer()
- else:
- vmin, vmax = None, None
- normalizer = None
-
- self.tickbar.update(vmin=vmin,
- vmax=vmax,
- normalizer=normalizer)
- self._setMinMaxLabels(vmin, vmax)
-
- def setMinMaxVisible(self, val=True):
- """Change visibility of the min label and the max label
-
- :param val: if True, set the labels visible, otherwise set it not visible
- """
- self._minLabel.setVisible(val)
- self._maxLabel.setVisible(val)
-
- def _updateMinMax(self):
- """Update the min and max label if we are in the case of the
- configuration 'minMaxValueOnly'"""
- if self.minVal is None:
- text, tooltip = '', ''
- else:
- if self.minVal == 0 or 0 <= numpy.log10(abs(self.minVal)) < 7:
- text = '%.7g' % self.minVal
- else:
- text = '%.2e' % self.minVal
- tooltip = repr(self.minVal)
-
- self._minLabel.setText(text)
- self._minLabel.setToolTip(tooltip)
-
- if self.maxVal is None:
- text, tooltip = '', ''
- else:
- if self.maxVal == 0 or 0 <= numpy.log10(abs(self.maxVal)) < 7:
- text = '%.7g' % self.maxVal
- else:
- text = '%.2e' % self.maxVal
- tooltip = repr(self.maxVal)
-
- self._maxLabel.setText(text)
- self._maxLabel.setToolTip(tooltip)
-
- def _setMinMaxLabels(self, minVal, maxVal):
- """Change the value of the min and max labels to be displayed.
-
- :param minVal: the minimal value of the TickBar (not str)
- :param maxVal: the maximal value of the TickBar (not str)
- """
- # bad hack to try to display has much information as possible
- self.minVal = minVal
- self.maxVal = maxVal
- self._updateMinMax()
-
- def resizeEvent(self, event):
- qt.QWidget.resizeEvent(self, event)
- self._updateMinMax()
-
-
-class _ColorScale(qt.QWidget):
- """Widget displaying the colormap colorScale.
-
- Show matching value between the gradient color (from the colormap) at mouse
- position and value.
-
- .. image:: img/colorScale.png
- :width: 20px
- :align: center
-
-
- To run the following sample code, a QApplication must be initialized.
-
- >>> colormap = Colormap(name='viridis',
- ... norm='log',
- ... vmin=1,
- ... vmax=100000,
- ... )
- >>> colorscale = ColorScale(parent=None,
- ... colormap=colormap)
- >>> colorscale.show()
-
- Initializer parameters :
-
- :param colormap: the colormap to be displayed
- :param parent: the Qt parent if any
- :param int margin: the top and left margin to apply.
- :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
- The data or item to use for getting the range for autoscale colormap.
-
- .. warning:: Value drawing will be
- done at the center of ticks. So if no margin is done your values
- drawing might not be fully done for extrems values.
- """
-
- _NB_CONTROL_POINTS = 256
-
- def __init__(self, colormap, parent=None, margin=5, data=None):
- qt.QWidget.__init__(self, parent)
- self._colormap = None
- self.margin = margin
- self.setColormap(colormap, data)
-
- self.setLayout(qt.QVBoxLayout())
- self.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Expanding)
- # needed to get the mouse event without waiting for button click
- self.setMouseTracking(True)
- self.setMargin(margin)
- self.setContentsMargins(0, 0, 0, 0)
-
- self.setMinimumHeight(self._NB_CONTROL_POINTS // 2 + 2 * self.margin)
- self.setFixedWidth(25)
-
- def setColormap(self, colormap, data=None):
- """Set the new colormap to be displayed
-
- :param dict colormap: the colormap to set
- :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
- Optional data for which to compute colormap range.
- """
- self._colormap = colormap
- self.setEnabled(colormap is not None)
-
- if colormap is None:
- self.vmin, self.vmax = None, None
- else:
- assert colormap.getNormalization() in colors.Colormap.NORMALIZATIONS
- self.vmin, self.vmax = self._colormap.getColormapRange(data=data)
- self._updateColorGradient()
- self.update()
-
- def getColormap(self):
- """Returns the colormap
-
- :rtype: :class:`.Colormap`
- """
- return None if self._colormap is None else self._colormap
-
- def _updateColorGradient(self):
- """Compute the color gradient"""
- colormap = self.getColormap()
- if colormap is None:
- return
-
- indices = numpy.linspace(0., 1., self._NB_CONTROL_POINTS)
- colors = colormap.getNColors(nbColors=self._NB_CONTROL_POINTS)
- self._gradient = qt.QLinearGradient(0, 1, 0, 0)
- self._gradient.setCoordinateMode(qt.QGradient.StretchToDeviceMode)
- self._gradient.setStops(
- [(i, qt.QColor(*color)) for i, color in zip(indices, colors)]
- )
-
- def paintEvent(self, event):
- """"""
- painter = qt.QPainter(self)
- if self.getColormap() is not None:
- painter.setBrush(self._gradient)
- penColor = self.palette().color(qt.QPalette.Active,
- qt.QPalette.Foreground)
- else:
- penColor = self.palette().color(qt.QPalette.Disabled,
- qt.QPalette.Foreground)
- painter.setPen(penColor)
-
- painter.drawRect(qt.QRect(
- 0,
- self.margin,
- self.width() - 1,
- self.height() - 2 * self.margin - 1))
-
- def mouseMoveEvent(self, event):
- tooltip = str(self.getValueFromRelativePosition(
- self._getRelativePosition(event.y())))
- qt.QToolTip.showText(event.globalPos(), tooltip, self)
- super(_ColorScale, self).mouseMoveEvent(event)
-
- def _getRelativePosition(self, yPixel):
- """yPixel : pixel position into _ColorScale widget reference
- """
- # widgets are bottom-top referencial but we display in top-bottom referential
- return 1. - (yPixel - self.margin) / float(self.height() - 2 * self.margin)
-
- def getValueFromRelativePosition(self, value):
- """Return the value in the colorMap from a relative position in the
- ColorScaleBar (y)
-
- :param value: float value in [0, 1]
- :return: the value in [colormap['vmin'], colormap['vmax']]
- """
- colormap = self.getColormap()
- if colormap is None:
- return
-
- value = numpy.clip(value, 0., 1.)
- normalizer = colormap._getNormalizer()
- normMin, normMax = normalizer.apply([self.vmin, self.vmax], self.vmin, self.vmax)
-
- return normalizer.revert(
- normMin + (normMax - normMin) * value, self.vmin, self.vmax)
-
- def setMargin(self, margin):
- """Define the margin to fit with a TickBar object.
- This is needed since we can only paint on the viewport of the widget.
- Didn't work with a simple setContentsMargins
-
- :param int margin: the margin to apply on the top and bottom.
- """
- self.margin = int(margin)
- self.update()
-
-
-class _TickBar(qt.QWidget):
- """Bar grouping the ticks displayed
-
- To run the following sample code, a QApplication must be initialized.
-
- >>> bar = _TickBar(1, 1000, norm='log', parent=None, displayValues=True)
- >>> bar.show()
-
- .. image:: img/tickbar.png
- :width: 40px
- :align: center
-
- :param int vmin: smaller value of the range of values
- :param int vmax: higher value of the range of values
- :param normalizer: Normalization object.
- :param parent: the Qt parent if any
- :param bool displayValues: if True display the values close to the tick,
- Otherwise only signal it by '-'
- :param int nticks: the number of tick we want to display. Should be an
- unsigned int ot None. If None, let the Tick bar find the optimal
- number of ticks from the tick density.
- :param int margin: margin to set on the top and bottom
- """
- _WIDTH_DISP_VAL = 45
- """widget width when displayed with ticks labels"""
- _WIDTH_NO_DISP_VAL = 10
- """widget width when displayed without ticks labels"""
- _FONT_SIZE = 10
- """font size for ticks labels"""
- _LINE_WIDTH = 10
- """width of the line to mark a tick"""
-
- DEFAULT_TICK_DENSITY = 0.015
-
- def __init__(self, vmin, vmax, normalizer, parent=None, displayValues=True,
- nticks=None, margin=5):
- super(_TickBar, self).__init__(parent)
- self.margin = margin
- self._nticks = None
- self.ticks = ()
- self.subTicks = ()
- self._forcedDisplayType = None
- self.ticksDensity = _TickBar.DEFAULT_TICK_DENSITY
-
- self._vmin = vmin
- self._vmax = vmax
- self._normalizer = normalizer
- self.displayValues = displayValues
- self.setTicksNumber(nticks)
-
- self.setMargin(margin)
- self.setContentsMargins(0, 0, 0, 0)
-
- self._resetWidth()
-
- def setTicksValuesVisible(self, val):
- self.displayValues = val
- self._resetWidth()
-
- def _resetWidth(self):
- width = self._WIDTH_DISP_VAL if self.displayValues else self._WIDTH_NO_DISP_VAL
- self.setFixedWidth(width)
-
- def update(self, vmin, vmax, normalizer):
- self._vmin = vmin
- self._vmax = vmax
- self._normalizer = normalizer
- self.computeTicks()
- qt.QWidget.update(self)
-
- def setMargin(self, margin):
- """Define the margin to fit with a _ColorScale object.
- This is needed since we can only paint on the viewport of the widget
-
- :param int margin: the margin to apply on the top and bottom.
- """
- self.margin = margin
-
- def setTicksNumber(self, nticks):
- """Set the number of ticks to display.
-
- :param nticks: the number of tick to be display. Should be an
- unsigned int ot None. If None, let the :class:`_TickBar` find the
- optimal number of ticks from the tick density.
- """
- self._nticks = nticks
- self.computeTicks()
- qt.QWidget.update(self)
-
- def setTicksDensity(self, density):
- """If you let :class:`_TickBar` deal with the number of ticks
- (nticks=None) then you can specify a ticks density to be displayed.
- """
- if density < 0.0:
- raise ValueError('Density should be a positive value')
- self.ticksDensity = density
-
- def computeTicks(self):
- """This function compute ticks values labels. It is called at each
- update and each resize event.
- Deal only with linear and log scale.
- """
- nticks = self._nticks
- if nticks is None:
- nticks = self._getOptimalNbTicks()
-
- if self._vmin == self._vmax:
- # No range: no ticks
- self.ticks = ()
- self.subTicks = ()
- elif isinstance(self._normalizer, colors._LogarithmicNormalization):
- self._computeTicksLog(nticks)
- else: # Fallback: use linear
- self._computeTicksLin(nticks)
-
- # update the form
- font = qt.QFont()
- font.setPixelSize(_TickBar._FONT_SIZE)
-
- self.form = self._getFormat(font)
-
- def _computeTicksLog(self, nticks):
- logMin = numpy.log10(self._vmin)
- logMax = numpy.log10(self._vmax)
- lowBound, highBound, spacing, self._nfrac = ticklayout.niceNumbersForLog10(logMin,
- logMax,
- nticks)
- self.ticks = numpy.power(10., numpy.arange(lowBound, highBound, spacing))
- if spacing == 1:
- self.subTicks = ticklayout.computeLogSubTicks(ticks=self.ticks,
- lowBound=numpy.power(10., lowBound),
- highBound=numpy.power(10., highBound))
- else:
- self.subTicks = []
-
- def resizeEvent(self, event):
- qt.QWidget.resizeEvent(self, event)
- self.computeTicks()
-
- def _computeTicksLin(self, nticks):
- _min, _max, _spacing, self._nfrac = ticklayout.niceNumbers(self._vmin,
- self._vmax,
- nticks)
-
- self.ticks = numpy.arange(_min, _max, _spacing)
- self.subTicks = []
-
- def _getOptimalNbTicks(self):
- return max(2, int(round(self.ticksDensity * self.rect().height())))
-
- def paintEvent(self, event):
- painter = qt.QPainter(self)
- font = painter.font()
- font.setPixelSize(_TickBar._FONT_SIZE)
- painter.setFont(font)
-
- # paint ticks
- for val in self.ticks:
- self._paintTick(val, painter, majorTick=True)
-
- # paint subticks
- for val in self.subTicks:
- self._paintTick(val, painter, majorTick=False)
-
- def _getRelativePosition(self, val):
- """Return the relative position of val according to min and max value
- """
- if self._normalizer is None:
- return 0.
- normMin, normMax, normVal = self._normalizer.apply(
- [self._vmin, self._vmax, val],
- self._vmin,
- self._vmax)
-
- if normMin == normMax:
- return 0.
- else:
- return 1. - (normVal - normMin) / (normMax - normMin)
-
- def _paintTick(self, val, painter, majorTick=True):
- """
-
- :param bool majorTick: if False will never draw text and will set a line
- with a smaller width
- """
- fm = qt.QFontMetrics(painter.font())
- viewportHeight = self.rect().height() - self.margin * 2 - 1
- relativePos = self._getRelativePosition(val)
- height = int(viewportHeight * relativePos + self.margin)
- lineWidth = _TickBar._LINE_WIDTH
- if majorTick is False:
- lineWidth /= 2
-
- painter.drawLine(qt.QLine(int(self.width() - lineWidth),
- height,
- self.width(),
- height))
-
- if self.displayValues and majorTick is True:
- painter.drawText(qt.QPoint(0, int(height + fm.height() / 2)),
- self.form.format(val))
-
- def setDisplayType(self, disType):
- """Set the type of display we want to set for ticks labels
-
- :param str disType: The type of display we want to set. disType values
- can be :
-
- - 'std' for standard, meaning only a formatting on the number of
- digits is done
- - 'e' for scientific display
- - None to let the _TickBar guess the best display for this kind of data.
- """
- if disType not in (None, 'std', 'e'):
- raise ValueError("display type not recognized, value should be in (None, 'std', 'e'")
- self._forcedDisplayType = disType
-
- def _getStandardFormat(self):
- return "{0:.%sf}" % self._nfrac
-
- def _getFormat(self, font):
- if self._forcedDisplayType is None:
- return self._guessType(font)
- elif self._forcedDisplayType == 'std':
- return self._getStandardFormat()
- elif self._forcedDisplayType == 'e':
- return self._getScientificForm()
- else:
- err = 'Forced type for display %s is not recognized' % self._forcedDisplayType
- raise ValueError(err)
-
- def _getScientificForm(self):
- return "{0:.0e}"
-
- def _guessType(self, font):
- """Try fo find the better format to display the tick's labels
-
- :param QFont font: the font we want to use during the painting
- """
- form = self._getStandardFormat()
-
- fm = qt.QFontMetrics(font)
- width = 0
- for tick in self.ticks:
- width = max(fm.boundingRect(form.format(tick)).width(), width)
-
- # if the length of the string are too long we are moving to scientific
- # display
- if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH:
- return self._getScientificForm()
- else:
- return form
diff --git a/silx/gui/plot/CompareImages.py b/silx/gui/plot/CompareImages.py
deleted file mode 100644
index 3875be4..0000000
--- a/silx/gui/plot/CompareImages.py
+++ /dev/null
@@ -1,1249 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""A widget dedicated to compare 2 images.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "23/07/2018"
-
-
-import enum
-import logging
-import numpy
-import weakref
-import collections
-import math
-
-import silx.image.bilinear
-from silx.gui import qt
-from silx.gui import plot
-from silx.gui import icons
-from silx.gui.colors import Colormap
-from silx.gui.plot import tools
-
-_logger = logging.getLogger(__name__)
-
-from silx.opencl import ocl
-if ocl is not None:
- try:
- from silx.opencl import sift
- except ImportError:
- # sift module is not available (e.g., in official Debian packages)
- sift = None
-else: # No OpenCL device or no pyopencl
- sift = None
-
-
-@enum.unique
-class VisualizationMode(enum.Enum):
- """Enum for each visualization mode available."""
- ONLY_A = 'a'
- ONLY_B = 'b'
- VERTICAL_LINE = 'vline'
- HORIZONTAL_LINE = 'hline'
- COMPOSITE_RED_BLUE_GRAY = "rbgchannel"
- COMPOSITE_RED_BLUE_GRAY_NEG = "rbgnegchannel"
- COMPOSITE_A_MINUS_B = "aminusb"
-
-
-@enum.unique
-class AlignmentMode(enum.Enum):
- """Enum for each alignment mode available."""
- ORIGIN = 'origin'
- CENTER = 'center'
- STRETCH = 'stretch'
- AUTO = 'auto'
-
-
-AffineTransformation = collections.namedtuple("AffineTransformation",
- ["tx", "ty", "sx", "sy", "rot"])
-"""Contains a 2D affine transformation: translation, scale and rotation"""
-
-
-class CompareImagesToolBar(qt.QToolBar):
- """ToolBar containing specific tools to custom the configuration of a
- :class:`CompareImages` widget
-
- Use :meth:`setCompareWidget` to connect this toolbar to a specific
- :class:`CompareImages` widget.
-
- :param Union[qt.QWidget,None] parent: Parent of this widget.
- """
- def __init__(self, parent=None):
- qt.QToolBar.__init__(self, parent)
-
- self.__compareWidget = None
-
- menu = qt.QMenu(self)
- self.__visualizationAction = qt.QAction(self)
- self.__visualizationAction.setMenu(menu)
- self.__visualizationAction.setCheckable(False)
- self.addAction(self.__visualizationAction)
- self.__visualizationGroup = qt.QActionGroup(self)
- self.__visualizationGroup.setExclusive(True)
- self.__visualizationGroup.triggered.connect(self.__visualizationModeChanged)
-
- icon = icons.getQIcon("compare-mode-a")
- action = qt.QAction(icon, "Display the first image only", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_A))
- action.setProperty("mode", VisualizationMode.ONLY_A)
- menu.addAction(action)
- self.__aModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-b")
- action = qt.QAction(icon, "Display the second image only", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_B))
- action.setProperty("mode", VisualizationMode.ONLY_B)
- menu.addAction(action)
- self.__bModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-vline")
- action = qt.QAction(icon, "Vertical compare mode", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_V))
- action.setProperty("mode", VisualizationMode.VERTICAL_LINE)
- menu.addAction(action)
- self.__vlineModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-hline")
- action = qt.QAction(icon, "Horizontal compare mode", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_H))
- action.setProperty("mode", VisualizationMode.HORIZONTAL_LINE)
- menu.addAction(action)
- self.__hlineModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-rb-channel")
- action = qt.QAction(icon, "Blue/red compare mode (additive mode)", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_C))
- action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY)
- menu.addAction(action)
- self.__brChannelModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-rbneg-channel")
- action = qt.QAction(icon, "Yellow/cyan compare mode (subtractive mode)", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
- action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG)
- menu.addAction(action)
- self.__ycChannelModeAction = action
- self.__visualizationGroup.addAction(action)
-
- icon = icons.getQIcon("compare-mode-a-minus-b")
- action = qt.QAction(icon, "Raw A minus B compare mode", self)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
- action.setProperty("mode", VisualizationMode.COMPOSITE_A_MINUS_B)
- menu.addAction(action)
- self.__ycChannelModeAction = action
- self.__visualizationGroup.addAction(action)
-
- menu = qt.QMenu(self)
- self.__alignmentAction = qt.QAction(self)
- self.__alignmentAction.setMenu(menu)
- self.__alignmentAction.setIconVisibleInMenu(True)
- self.addAction(self.__alignmentAction)
- self.__alignmentGroup = qt.QActionGroup(self)
- self.__alignmentGroup.setExclusive(True)
- self.__alignmentGroup.triggered.connect(self.__alignmentModeChanged)
-
- icon = icons.getQIcon("compare-align-origin")
- action = qt.QAction(icon, "Align images on their upper-left pixel", self)
- action.setProperty("mode", AlignmentMode.ORIGIN)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- self.__originAlignAction = action
- menu.addAction(action)
- self.__alignmentGroup.addAction(action)
-
- icon = icons.getQIcon("compare-align-center")
- action = qt.QAction(icon, "Center images", self)
- action.setProperty("mode", AlignmentMode.CENTER)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- self.__centerAlignAction = action
- menu.addAction(action)
- self.__alignmentGroup.addAction(action)
-
- icon = icons.getQIcon("compare-align-stretch")
- action = qt.QAction(icon, "Stretch the second image on the first one", self)
- action.setProperty("mode", AlignmentMode.STRETCH)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- self.__stretchAlignAction = action
- menu.addAction(action)
- self.__alignmentGroup.addAction(action)
-
- icon = icons.getQIcon("compare-align-auto")
- action = qt.QAction(icon, "Auto-alignment of the second image", self)
- action.setProperty("mode", AlignmentMode.AUTO)
- action.setIconVisibleInMenu(True)
- action.setCheckable(True)
- self.__autoAlignAction = action
- menu.addAction(action)
- if sift is None:
- action.setEnabled(False)
- action.setToolTip("Sift module is not available")
- self.__alignmentGroup.addAction(action)
-
- icon = icons.getQIcon("compare-keypoints")
- action = qt.QAction(icon, "Display/hide alignment keypoints", self)
- action.setCheckable(True)
- action.triggered.connect(self.__keypointVisibilityChanged)
- self.addAction(action)
- self.__displayKeypoints = action
-
- def setCompareWidget(self, widget):
- """
- Connect this tool bar to a specific :class:`CompareImages` widget.
-
- :param Union[None,CompareImages] widget: The widget to connect with.
- """
- compareWidget = self.getCompareWidget()
- if compareWidget is not None:
- compareWidget.sigConfigurationChanged.disconnect(self.__updateSelectedActions)
- compareWidget = widget
- if compareWidget is None:
- self.__compareWidget = None
- else:
- self.__compareWidget = weakref.ref(compareWidget)
- if compareWidget is not None:
- widget.sigConfigurationChanged.connect(self.__updateSelectedActions)
- self.__updateSelectedActions()
-
- def getCompareWidget(self):
- """Returns the connected widget.
-
- :rtype: CompareImages
- """
- if self.__compareWidget is None:
- return None
- else:
- return self.__compareWidget()
-
- def __updateSelectedActions(self):
- """
- Update the state of this tool bar according to the state of the
- connected :class:`CompareImages` widget.
- """
- widget = self.getCompareWidget()
- if widget is None:
- return
-
- mode = widget.getVisualizationMode()
- action = None
- for a in self.__visualizationGroup.actions():
- actionMode = a.property("mode")
- if mode == actionMode:
- action = a
- break
- old = self.__visualizationGroup.blockSignals(True)
- if action is not None:
- # Check this action
- action.setChecked(True)
- else:
- action = self.__visualizationGroup.checkedAction()
- if action is not None:
- # Uncheck this action
- action.setChecked(False)
- self.__updateVisualizationMenu()
- self.__visualizationGroup.blockSignals(old)
-
- mode = widget.getAlignmentMode()
- action = None
- for a in self.__alignmentGroup.actions():
- actionMode = a.property("mode")
- if mode == actionMode:
- action = a
- break
- old = self.__alignmentGroup.blockSignals(True)
- if action is not None:
- # Check this action
- action.setChecked(True)
- else:
- action = self.__alignmentGroup.checkedAction()
- if action is not None:
- # Uncheck this action
- action.setChecked(False)
- self.__updateAlignmentMenu()
- self.__alignmentGroup.blockSignals(old)
-
- def __visualizationModeChanged(self, selectedAction):
- """Called when user requesting changes of the visualization mode.
- """
- self.__updateVisualizationMenu()
- widget = self.getCompareWidget()
- if widget is not None:
- mode = selectedAction.property("mode")
- widget.setVisualizationMode(mode)
-
- def __updateVisualizationMenu(self):
- """Update the state of the action containing visualization menu.
- """
- selectedAction = self.__visualizationGroup.checkedAction()
- if selectedAction is not None:
- self.__visualizationAction.setText(selectedAction.text())
- self.__visualizationAction.setIcon(selectedAction.icon())
- self.__visualizationAction.setToolTip(selectedAction.toolTip())
- else:
- self.__visualizationAction.setText("")
- self.__visualizationAction.setIcon(qt.QIcon())
- self.__visualizationAction.setToolTip("")
-
- def __alignmentModeChanged(self, selectedAction):
- """Called when user requesting changes of the alignment mode.
- """
- self.__updateAlignmentMenu()
- widget = self.getCompareWidget()
- if widget is not None:
- mode = selectedAction.property("mode")
- widget.setAlignmentMode(mode)
-
- def __updateAlignmentMenu(self):
- """Update the state of the action containing alignment menu.
- """
- selectedAction = self.__alignmentGroup.checkedAction()
- if selectedAction is not None:
- self.__alignmentAction.setText(selectedAction.text())
- self.__alignmentAction.setIcon(selectedAction.icon())
- self.__alignmentAction.setToolTip(selectedAction.toolTip())
- else:
- self.__alignmentAction.setText("")
- self.__alignmentAction.setIcon(qt.QIcon())
- self.__alignmentAction.setToolTip("")
-
- def __keypointVisibilityChanged(self):
- """Called when action managing keypoints visibility changes"""
- widget = self.getCompareWidget()
- if widget is not None:
- keypointsVisible = self.__displayKeypoints.isChecked()
- widget.setKeypointsVisible(keypointsVisible)
-
-
-class CompareImagesStatusBar(qt.QStatusBar):
- """StatusBar containing specific information contained in a
- :class:`CompareImages` widget
-
- Use :meth:`setCompareWidget` to connect this toolbar to a specific
- :class:`CompareImages` widget.
-
- :param Union[qt.QWidget,None] parent: Parent of this widget.
- """
- def __init__(self, parent=None):
- qt.QStatusBar.__init__(self, parent)
- self.setSizeGripEnabled(False)
- self.layout().setSpacing(0)
- self.__compareWidget = None
- self._label1 = qt.QLabel(self)
- self._label1.setFrameShape(qt.QFrame.WinPanel)
- self._label1.setFrameShadow(qt.QFrame.Sunken)
- self._label2 = qt.QLabel(self)
- self._label2.setFrameShape(qt.QFrame.WinPanel)
- self._label2.setFrameShadow(qt.QFrame.Sunken)
- self._transform = qt.QLabel(self)
- self._transform.setFrameShape(qt.QFrame.WinPanel)
- self._transform.setFrameShadow(qt.QFrame.Sunken)
- self.addWidget(self._label1)
- self.addWidget(self._label2)
- self.addWidget(self._transform)
- self._pos = None
- self._updateStatusBar()
-
- def setCompareWidget(self, widget):
- """
- Connect this tool bar to a specific :class:`CompareImages` widget.
-
- :param Union[None,CompareImages] widget: The widget to connect with.
- """
- compareWidget = self.getCompareWidget()
- if compareWidget is not None:
- compareWidget.getPlot().sigPlotSignal.disconnect(self.__plotSignalReceived)
- compareWidget.sigConfigurationChanged.disconnect(self.__dataChanged)
- compareWidget = widget
- if compareWidget is None:
- self.__compareWidget = None
- else:
- self.__compareWidget = weakref.ref(compareWidget)
- if compareWidget is not None:
- compareWidget.getPlot().sigPlotSignal.connect(self.__plotSignalReceived)
- compareWidget.sigConfigurationChanged.connect(self.__dataChanged)
-
- def getCompareWidget(self):
- """Returns the connected widget.
-
- :rtype: CompareImages
- """
- if self.__compareWidget is None:
- return None
- else:
- return self.__compareWidget()
-
- def __plotSignalReceived(self, event):
- """Called when old style signals at emmited from the plot."""
- if event["event"] == "mouseMoved":
- x, y = event["x"], event["y"]
- self.__mouseMoved(x, y)
-
- def __mouseMoved(self, x, y):
- """Called when mouse move over the plot."""
- self._pos = x, y
- self._updateStatusBar()
-
- def __dataChanged(self):
- """Called when internal data from the connected widget changes."""
- self._updateStatusBar()
-
- def _formatData(self, data):
- """Format pixel of an image.
-
- It supports intensity, RGB, and RGBA.
-
- :param Union[int,float,numpy.ndarray,str]: Value of a pixel
- :rtype: str
- """
- if data is None:
- return "No data"
- if isinstance(data, (int, numpy.integer)):
- return "%d" % data
- if isinstance(data, (float, numpy.floating)):
- return "%f" % data
- if isinstance(data, numpy.ndarray):
- # RGBA value
- if data.shape == (3,):
- return "R:%d G:%d B:%d" % (data[0], data[1], data[2])
- elif data.shape == (4,):
- return "R:%d G:%d B:%d A:%d" % (data[0], data[1], data[2], data[3])
- _logger.debug("Unsupported data format %s. Cast it to string.", type(data))
- return str(data)
-
- def _updateStatusBar(self):
- """Update the content of the status bar"""
- widget = self.getCompareWidget()
- if widget is None:
- self._label1.setText("Image1: NA")
- self._label2.setText("Image2: NA")
- self._transform.setVisible(False)
- else:
- transform = widget.getTransformation()
- self._transform.setVisible(transform is not None)
- if transform is not None:
- has_notable_translation = not numpy.isclose(transform.tx, 0.0, atol=0.01) \
- or not numpy.isclose(transform.ty, 0.0, atol=0.01)
- has_notable_scale = not numpy.isclose(transform.sx, 1.0, atol=0.01) \
- or not numpy.isclose(transform.sy, 1.0, atol=0.01)
- has_notable_rotation = not numpy.isclose(transform.rot, 0.0, atol=0.01)
-
- strings = []
- if has_notable_translation:
- strings.append("Translation")
- if has_notable_scale:
- strings.append("Scale")
- if has_notable_rotation:
- strings.append("Rotation")
- if strings == []:
- has_translation = not numpy.isclose(transform.tx, 0.0) \
- or not numpy.isclose(transform.ty, 0.0)
- has_scale = not numpy.isclose(transform.sx, 1.0) \
- or not numpy.isclose(transform.sy, 1.0)
- has_rotation = not numpy.isclose(transform.rot, 0.0)
- if has_translation or has_scale or has_rotation:
- text = "No big changes"
- else:
- text = "No changes"
- else:
- text = "+".join(strings)
- self._transform.setText("Align: " + text)
-
- strings = []
- if not numpy.isclose(transform.ty, 0.0):
- strings.append("Translation x: %0.3fpx" % transform.tx)
- if not numpy.isclose(transform.ty, 0.0):
- strings.append("Translation y: %0.3fpx" % transform.ty)
- if not numpy.isclose(transform.sx, 1.0):
- strings.append("Scale x: %0.3f" % transform.sx)
- if not numpy.isclose(transform.sy, 1.0):
- strings.append("Scale y: %0.3f" % transform.sy)
- if not numpy.isclose(transform.rot, 0.0):
- strings.append("Rotation: %0.3fdeg" % (transform.rot * 180 / numpy.pi))
- if strings == []:
- text = "No transformation"
- else:
- text = "\n".join(strings)
- self._transform.setToolTip(text)
-
- if self._pos is None:
- self._label1.setText("Image1: NA")
- self._label2.setText("Image2: NA")
- else:
- data1, data2 = widget.getRawPixelData(self._pos[0], self._pos[1])
- if isinstance(data1, str):
- self._label1.setToolTip(data1)
- text1 = "NA"
- else:
- self._label1.setToolTip("")
- text1 = self._formatData(data1)
- if isinstance(data2, str):
- self._label2.setToolTip(data2)
- text2 = "NA"
- else:
- self._label2.setToolTip("")
- text2 = self._formatData(data2)
- self._label1.setText("Image1: %s" % text1)
- self._label2.setText("Image2: %s" % text2)
-
-
-class CompareImages(qt.QMainWindow):
- """Widget providing tools to compare 2 images.
-
- .. image:: img/CompareImages.png
-
- :param Union[qt.QWidget,None] parent: Parent of this widget.
- :param backend: The backend to use, in:
- 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
- or a :class:`BackendBase.BackendBase` class
- :type backend: str or :class:`BackendBase.BackendBase`
- """
-
- VisualizationMode = VisualizationMode
- """Available visualization modes"""
-
- AlignmentMode = AlignmentMode
- """Available alignment modes"""
-
- sigConfigurationChanged = qt.Signal()
- """Emitted when the configuration of the widget (visualization mode,
- alignement mode...) have changed."""
-
- def __init__(self, parent=None, backend=None):
- qt.QMainWindow.__init__(self, parent)
- self._resetZoomActive = True
- self._colormap = Colormap()
- """Colormap shared by all modes, except the compose images (rgb image)"""
- self._colormapKeyPoints = Colormap('spring')
- """Colormap used for sift keypoints"""
-
- if parent is None:
- self.setWindowTitle('Compare images')
- else:
- self.setWindowFlags(qt.Qt.Widget)
-
- self.__transformation = None
- self.__raw1 = None
- self.__raw2 = None
- self.__data1 = None
- self.__data2 = None
- self.__previousSeparatorPosition = None
-
- self.__plot = plot.PlotWidget(parent=self, backend=backend)
- self.__plot.setDefaultColormap(self._colormap)
- self.__plot.getXAxis().setLabel('Columns')
- self.__plot.getYAxis().setLabel('Rows')
- if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
- self.__plot.getYAxis().setInverted(True)
-
- self.__plot.setKeepDataAspectRatio(True)
- self.__plot.sigPlotSignal.connect(self.__plotSlot)
- self.__plot.setAxesDisplayed(False)
-
- self.setCentralWidget(self.__plot)
-
- legend = VisualizationMode.VERTICAL_LINE.name
- self.__plot.addXMarker(
- 0,
- legend=legend,
- text='',
- draggable=True,
- color='blue',
- constraint=self.__separatorConstraint)
- self.__vline = self.__plot._getMarker(legend)
-
- legend = VisualizationMode.HORIZONTAL_LINE.name
- self.__plot.addYMarker(
- 0,
- legend=legend,
- text='',
- draggable=True,
- color='blue',
- constraint=self.__separatorConstraint)
- self.__hline = self.__plot._getMarker(legend)
-
- # default values
- self.__visualizationMode = ""
- self.__alignmentMode = ""
- self.__keypointsVisible = True
-
- self.setAlignmentMode(AlignmentMode.ORIGIN)
- self.setVisualizationMode(VisualizationMode.VERTICAL_LINE)
- self.setKeypointsVisible(False)
-
- # Toolbars
-
- self._createToolBars(self.__plot)
- if self._interactiveModeToolBar is not None:
- self.addToolBar(self._interactiveModeToolBar)
- if self._imageToolBar is not None:
- self.addToolBar(self._imageToolBar)
- if self._compareToolBar is not None:
- self.addToolBar(self._compareToolBar)
-
- # Statusbar
-
- self._createStatusBar(self.__plot)
- if self._statusBar is not None:
- self.setStatusBar(self._statusBar)
-
- def _createStatusBar(self, plot):
- self._statusBar = CompareImagesStatusBar(self)
- self._statusBar.setCompareWidget(self)
-
- def _createToolBars(self, plot):
- """Create tool bars displayed by the widget"""
- toolBar = tools.InteractiveModeToolBar(parent=self, plot=plot)
- self._interactiveModeToolBar = toolBar
- toolBar = tools.ImageToolBar(parent=self, plot=plot)
- self._imageToolBar = toolBar
- toolBar = CompareImagesToolBar(self)
- toolBar.setCompareWidget(self)
- self._compareToolBar = toolBar
-
- def getPlot(self):
- """Returns the plot which is used to display the images.
-
- :rtype: silx.gui.plot.PlotWidget
- """
- return self.__plot
-
- def getColormap(self):
- """
-
- :return: colormap used for compare image
- :rtype: silx.gui.colors.Colormap
- """
- return self._colormap
-
- def getRawPixelData(self, x, y):
- """Return the raw pixel of each image data from axes positions.
-
- If the coordinate is outside of the image it returns None element in
- the tuple.
-
- The pixel is reach from the raw data image without filter or
- transformation. But the coordinate x and y are in the reference of the
- current displayed mode.
-
- :param float x: X-coordinate of the pixel in the current displayed plot
- :param float y: Y-coordinate of the pixel in the current displayed plot
- :return: A tuple of for each images containing pixel information. It
- could be a scalar value or an array in case of RGB/RGBA informations.
- It also could be a string containing information is some cases.
- :rtype: Tuple(Union[int,float,numpy.ndarray,str],Union[int,float,numpy.ndarray,str])
- """
- data2 = None
- alignmentMode = self.__alignmentMode
- raw1, raw2 = self.__raw1, self.__raw2
- if alignmentMode == AlignmentMode.ORIGIN:
- x1 = x
- y1 = y
- x2 = x
- y2 = y
- elif alignmentMode == AlignmentMode.CENTER:
- yy = max(raw1.shape[0], raw2.shape[0])
- xx = max(raw1.shape[1], raw2.shape[1])
- x1 = x - (xx - raw1.shape[1]) * 0.5
- x2 = x - (xx - raw2.shape[1]) * 0.5
- y1 = y - (yy - raw1.shape[0]) * 0.5
- y2 = y - (yy - raw2.shape[0]) * 0.5
- elif alignmentMode == AlignmentMode.STRETCH:
- x1 = x
- y1 = y
- x2 = x * raw2.shape[1] / raw1.shape[1]
- y2 = x * raw2.shape[1] / raw1.shape[1]
- elif alignmentMode == AlignmentMode.AUTO:
- x1 = x
- y1 = y
- # Not implemented
- data2 = "Not implemented with sift"
- else:
- assert(False)
-
- x1, y1 = int(x1), int(y1)
- if raw1 is None or y1 < 0 or y1 >= raw1.shape[0] or x1 < 0 or x1 >= raw1.shape[1]:
- data1 = None
- else:
- data1 = raw1[y1, x1]
-
- if data2 is None:
- x2, y2 = int(x2), int(y2)
- if raw2 is None or y2 < 0 or y2 >= raw2.shape[0] or x2 < 0 or x2 >= raw2.shape[1]:
- data2 = None
- else:
- data2 = raw2[y2, x2]
-
- return data1, data2
-
- def setVisualizationMode(self, mode):
- """Set the visualization mode.
-
- :param str mode: New visualization to display the image comparison
- """
- if self.__visualizationMode == mode:
- return
- self.__visualizationMode = mode
- mode = self.getVisualizationMode()
- self.__vline.setVisible(mode == VisualizationMode.VERTICAL_LINE)
- self.__hline.setVisible(mode == VisualizationMode.HORIZONTAL_LINE)
- self.__updateData()
- self.sigConfigurationChanged.emit()
-
- def getVisualizationMode(self):
- """Returns the current interaction mode."""
- return self.__visualizationMode
-
- def setAlignmentMode(self, mode):
- """Set the alignment mode.
-
- :param str mode: New alignement to apply to images
- """
- if self.__alignmentMode == mode:
- return
- self.__alignmentMode = mode
- self.__updateData()
- self.sigConfigurationChanged.emit()
-
- def getAlignmentMode(self):
- """Returns the current selected alignemnt mode."""
- return self.__alignmentMode
-
- def setKeypointsVisible(self, isVisible):
- """Set keypoints visibility.
-
- :param bool isVisible: If True, keypoints are displayed (if some)
- """
- if self.__keypointsVisible == isVisible:
- return
- self.__keypointsVisible = isVisible
- self.__updateKeyPoints()
- self.sigConfigurationChanged.emit()
-
- def __setDefaultAlignmentMode(self):
- """Reset the alignemnt mode to the default value"""
- self.setAlignmentMode(AlignmentMode.ORIGIN)
-
- def __plotSlot(self, event):
- """Handle events from the plot"""
- if event['event'] in ('markerMoving', 'markerMoved'):
- mode = self.getVisualizationMode()
- legend = mode.name
- if event['label'] == legend:
- if mode == VisualizationMode.VERTICAL_LINE:
- value = int(float(str(event['xdata'])))
- elif mode == VisualizationMode.HORIZONTAL_LINE:
- value = int(float(str(event['ydata'])))
- else:
- assert(False)
- if self.__previousSeparatorPosition != value:
- self.__separatorMoved(value)
- self.__previousSeparatorPosition = value
-
- def __separatorConstraint(self, x, y):
- """Manage contains on the separators to clamp them inside the images."""
- if self.__data1 is None:
- return 0, 0
- x = int(x)
- if x < 0:
- x = 0
- elif x > self.__data1.shape[1]:
- x = self.__data1.shape[1]
- y = int(y)
- if y < 0:
- y = 0
- elif y > self.__data1.shape[0]:
- y = self.__data1.shape[0]
- return x, y
-
- def __updateSeparators(self):
- """Redraw images according to the current state of the separators.
- """
- mode = self.getVisualizationMode()
- if mode == VisualizationMode.VERTICAL_LINE:
- pos = self.__vline.getXPosition()
- self.__separatorMoved(pos)
- self.__previousSeparatorPosition = pos
- elif mode == VisualizationMode.HORIZONTAL_LINE:
- pos = self.__hline.getYPosition()
- self.__separatorMoved(pos)
- self.__previousSeparatorPosition = pos
- else:
- self.__image1.setOrigin((0, 0))
- self.__image2.setOrigin((0, 0))
-
- def __separatorMoved(self, pos):
- """Called when vertical or horizontal separators have moved.
-
- Update the displayed images.
- """
- if self.__data1 is None:
- return
-
- mode = self.getVisualizationMode()
- if mode == VisualizationMode.VERTICAL_LINE:
- pos = int(pos)
- if pos <= 0:
- pos = 0
- elif pos >= self.__data1.shape[1]:
- pos = self.__data1.shape[1]
- data1 = self.__data1[:, 0:pos]
- data2 = self.__data2[:, pos:]
- self.__image1.setData(data1, copy=False)
- self.__image2.setData(data2, copy=False)
- self.__image2.setOrigin((pos, 0))
- elif mode == VisualizationMode.HORIZONTAL_LINE:
- pos = int(pos)
- if pos <= 0:
- pos = 0
- elif pos >= self.__data1.shape[0]:
- pos = self.__data1.shape[0]
- data1 = self.__data1[0:pos, :]
- data2 = self.__data2[pos:, :]
- self.__image1.setData(data1, copy=False)
- self.__image2.setData(data2, copy=False)
- self.__image2.setOrigin((0, pos))
- else:
- assert(False)
-
- def setData(self, image1, image2):
- """Set images to compare.
-
- Images can contains floating-point or integer values, or RGB and RGBA
- values, but should have comparable intensities.
-
- RGB and RGBA images are provided as an array as `[width,height,channels]`
- of usigned integer 8-bits or floating-points between 0.0 to 1.0.
-
- :param numpy.ndarray image1: The first image
- :param numpy.ndarray image2: The second image
- """
- self.__raw1 = image1
- self.__raw2 = image2
- self.__updateData()
- if self.isAutoResetZoom():
- self.__plot.resetZoom()
-
- def setImage1(self, image1):
- """Set image1 to be compared.
-
- Images can contains floating-point or integer values, or RGB and RGBA
- values, but should have comparable intensities.
-
- RGB and RGBA images are provided as an array as `[width,height,channels]`
- of usigned integer 8-bits or floating-points between 0.0 to 1.0.
-
- :param numpy.ndarray image1: The first image
- """
- self.__raw1 = image1
- self.__updateData()
- if self.isAutoResetZoom():
- self.__plot.resetZoom()
-
- def setImage2(self, image2):
- """Set image2 to be compared.
-
- Images can contains floating-point or integer values, or RGB and RGBA
- values, but should have comparable intensities.
-
- RGB and RGBA images are provided as an array as `[width,height,channels]`
- of usigned integer 8-bits or floating-points between 0.0 to 1.0.
-
- :param numpy.ndarray image2: The second image
- """
- self.__raw2 = image2
- self.__updateData()
- if self.isAutoResetZoom():
- self.__plot.resetZoom()
-
- def __updateKeyPoints(self):
- """Update the displayed keypoints using cached keypoints.
- """
- if self.__keypointsVisible:
- data = self.__matching_keypoints
- else:
- data = [], [], []
- self.__plot.addScatter(x=data[0],
- y=data[1],
- z=1,
- value=data[2],
- colormap=self._colormapKeyPoints,
- legend="keypoints")
-
- def __updateData(self):
- """Compute aligned image when the alignment mode changes.
-
- This function cache input images which are used when
- vertical/horizontal separators moves.
- """
- raw1, raw2 = self.__raw1, self.__raw2
- if raw1 is None or raw2 is None:
- return
-
- alignmentMode = self.getAlignmentMode()
- self.__transformation = None
-
- if alignmentMode == AlignmentMode.ORIGIN:
- yy = max(raw1.shape[0], raw2.shape[0])
- xx = max(raw1.shape[1], raw2.shape[1])
- size = yy, xx
- data1 = self.__createMarginImage(raw1, size, transparent=True)
- data2 = self.__createMarginImage(raw2, size, transparent=True)
- self.__matching_keypoints = [0.0], [0.0], [1.0]
- elif alignmentMode == AlignmentMode.CENTER:
- yy = max(raw1.shape[0], raw2.shape[0])
- xx = max(raw1.shape[1], raw2.shape[1])
- size = yy, xx
- data1 = self.__createMarginImage(raw1, size, transparent=True, center=True)
- data2 = self.__createMarginImage(raw2, size, transparent=True, center=True)
- self.__matching_keypoints = ([data1.shape[1] // 2],
- [data1.shape[0] // 2],
- [1.0])
- elif alignmentMode == AlignmentMode.STRETCH:
- data1 = raw1
- data2 = self.__rescaleImage(raw2, data1.shape)
- self.__matching_keypoints = ([0, data1.shape[1], data1.shape[1], 0],
- [0, 0, data1.shape[0], data1.shape[0]],
- [1.0, 1.0, 1.0, 1.0])
- elif alignmentMode == AlignmentMode.AUTO:
- # TODO: sift implementation do not support RGBA images
- yy = max(raw1.shape[0], raw2.shape[0])
- xx = max(raw1.shape[1], raw2.shape[1])
- size = yy, xx
- data1 = self.__createMarginImage(raw1, size)
- data2 = self.__createMarginImage(raw2, size)
- self.__matching_keypoints = [0.0], [0.0], [1.0]
- try:
- data1, data2 = self.__createSiftData(data1, data2)
- if data2 is None:
- raise ValueError("Unexpected None value")
- except Exception as e:
- # TODO: Display it on the GUI
- _logger.error(e)
- self.__setDefaultAlignmentMode()
- return
- else:
- assert(False)
-
- mode = self.getVisualizationMode()
- if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
- data1 = self.__composeImage(data1, data2, mode)
- data2 = numpy.empty((0, 0))
- elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
- data1 = self.__composeImage(data1, data2, mode)
- data2 = numpy.empty((0, 0))
- elif mode == VisualizationMode.COMPOSITE_A_MINUS_B:
- data1 = self.__composeImage(data1, data2, mode)
- data2 = numpy.empty((0, 0))
- elif mode == VisualizationMode.ONLY_A:
- data2 = numpy.empty((0, 0))
- elif mode == VisualizationMode.ONLY_B:
- data1 = numpy.empty((0, 0))
-
- self.__data1, self.__data2 = data1, data2
- self.__plot.addImage(data1, z=0, legend="image1", resetzoom=False)
- self.__plot.addImage(data2, z=0, legend="image2", resetzoom=False)
- self.__image1 = self.__plot.getImage("image1")
- self.__image2 = self.__plot.getImage("image2")
- self.__updateKeyPoints()
-
- # Set the separator into the middle
- if self.__previousSeparatorPosition is None:
- value = self.__data1.shape[1] // 2
- self.__vline.setPosition(value, 0)
- value = self.__data1.shape[0] // 2
- self.__hline.setPosition(0, value)
- self.__updateSeparators()
-
- # Avoid to change the colormap range when the separator is moving
- # TODO: The colormap histogram will still be wrong
- mode1 = self.__getImageMode(data1)
- mode2 = self.__getImageMode(data2)
- if mode1 == "intensity" and mode1 == mode2:
- if self.__data1.size == 0:
- vmin = self.__data2.min()
- vmax = self.__data2.max()
- elif self.__data2.size == 0:
- vmin = self.__data1.min()
- vmax = self.__data1.max()
- else:
- vmin = min(self.__data1.min(), self.__data2.min())
- vmax = max(self.__data1.max(), self.__data2.max())
- colormap = self.getColormap()
- colormap.setVRange(vmin=vmin, vmax=vmax)
- self.__image1.setColormap(colormap)
- self.__image2.setColormap(colormap)
-
- def __getImageMode(self, image):
- """Returns a value identifying the way the image is stored in the
- array.
-
- :param numpy.ndarray image: Image to check
- :rtype: str
- """
- if len(image.shape) == 2:
- return "intensity"
- elif len(image.shape) == 3:
- if image.shape[2] == 3:
- return "rgb"
- elif image.shape[2] == 4:
- return "rgba"
- raise TypeError("'image' argument is not an image.")
-
- def __rescaleImage(self, image, shape):
- """Rescale an image to the requested shape.
-
- :rtype: numpy.ndarray
- """
- mode = self.__getImageMode(image)
- if mode == "intensity":
- data = self.__rescaleArray(image, shape)
- elif mode == "rgb":
- data = numpy.empty((shape[0], shape[1], 3), dtype=image.dtype)
- for c in range(3):
- data[:, :, c] = self.__rescaleArray(image[:, :, c], shape)
- elif mode == "rgba":
- data = numpy.empty((shape[0], shape[1], 4), dtype=image.dtype)
- for c in range(4):
- data[:, :, c] = self.__rescaleArray(image[:, :, c], shape)
- return data
-
- def __composeImage(self, data1, data2, mode):
- """Returns an RBG image containing composition of data1 and data2 in 2
- different channels
-
- :param numpy.ndarray data1: First image
- :param numpy.ndarray data1: Second image
- :param VisualizationMode mode: Composition mode.
- :rtype: numpy.ndarray
- """
- assert(data1.shape[0:2] == data2.shape[0:2])
- if mode == VisualizationMode.COMPOSITE_A_MINUS_B:
- # TODO: this calculation has no interest of generating a 'composed'
- # rgb image, this could be moved in an other function or doc
- # should be modified
- _type = data1.dtype
- result = data1.astype(numpy.float64) - data2.astype(numpy.float64)
- return result
- mode1 = self.__getImageMode(data1)
- if mode1 in ["rgb", "rgba"]:
- intensity1 = self.__luminosityImage(data1)
- vmin1, vmax1 = 0.0, 1.0
- else:
- intensity1 = data1
- vmin1, vmax1 = data1.min(), data1.max()
-
- mode2 = self.__getImageMode(data2)
- if mode2 in ["rgb", "rgba"]:
- intensity2 = self.__luminosityImage(data2)
- vmin2, vmax2 = 0.0, 1.0
- else:
- intensity2 = data2
- vmin2, vmax2 = data2.min(), data2.max()
-
- vmin, vmax = min(vmin1, vmin2) * 1.0, max(vmax1, vmax2) * 1.0
- shape = data1.shape
- result = numpy.empty((shape[0], shape[1], 3), dtype=numpy.uint8)
- a = (intensity1 - vmin) * (1.0 / (vmax - vmin)) * 255.0
- b = (intensity2 - vmin) * (1.0 / (vmax - vmin)) * 255.0
- if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
- result[:, :, 0] = a
- result[:, :, 1] = (a + b) / 2
- result[:, :, 2] = b
- elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
- result[:, :, 0] = 255 - b
- result[:, :, 1] = 255 - (a + b) / 2
- result[:, :, 2] = 255 - a
- return result
-
- def __luminosityImage(self, image):
- """Returns the luminosity channel from an RBG(A) image.
- The alpha channel is ignored.
-
- :rtype: numpy.ndarray
- """
- mode = self.__getImageMode(image)
- assert(mode in ["rgb", "rgba"])
- is_uint8 = image.dtype.type == numpy.uint8
- # luminosity
- image = 0.21 * image[..., 0] + 0.72 * image[..., 1] + 0.07 * image[..., 2]
- if is_uint8:
- image = image / 255.0
- return image
-
- def __rescaleArray(self, image, shape):
- """Rescale a 2D array to the requested shape.
-
- :rtype: numpy.ndarray
- """
- y, x = numpy.ogrid[:shape[0], :shape[1]]
- y, x = y * 1.0 * (image.shape[0] - 1) / (shape[0] - 1), x * 1.0 * (image.shape[1] - 1) / (shape[1] - 1)
- b = silx.image.bilinear.BilinearImage(image)
- # TODO: could be optimized using strides
- x2d = numpy.zeros_like(y) + x
- y2d = numpy.zeros_like(x) + y
- result = b.map_coordinates((y2d, x2d))
- return result
-
- def __createMarginImage(self, image, size, transparent=False, center=False):
- """Returns a new image with margin to respect the requested size.
-
- :rtype: numpy.ndarray
- """
- assert(image.shape[0] <= size[0])
- assert(image.shape[1] <= size[1])
- if image.shape == size:
- return image
- mode = self.__getImageMode(image)
-
- if center:
- pos0 = size[0] // 2 - image.shape[0] // 2
- pos1 = size[1] // 2 - image.shape[1] // 2
- else:
- pos0, pos1 = 0, 0
-
- if mode == "intensity":
- data = numpy.zeros(size, dtype=image.dtype)
- data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1]] = image
- # TODO: It is maybe possible to put NaN on the margin
- else:
- if transparent:
- data = numpy.zeros((size[0], size[1], 4), dtype=numpy.uint8)
- else:
- data = numpy.zeros((size[0], size[1], 3), dtype=numpy.uint8)
- depth = min(data.shape[2], image.shape[2])
- data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 0:depth] = image[:, :, 0:depth]
- if transparent and depth == 3:
- data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 3] = 255
- return data
-
- def __toAffineTransformation(self, sift_result):
- """Returns an affine transformation from the sift result.
-
- :param dict sift_result: Result of sift when using `all_result=True`
- :rtype: AffineTransformation
- """
- offset = sift_result["offset"]
- matrix = sift_result["matrix"]
-
- tx = offset[0]
- ty = offset[1]
- a = matrix[0, 0]
- b = matrix[0, 1]
- c = matrix[1, 0]
- d = matrix[1, 1]
- rot = math.atan2(-b, a)
- sx = (-1.0 if a < 0 else 1.0) * math.sqrt(a**2 + b**2)
- sy = (-1.0 if d < 0 else 1.0) * math.sqrt(c**2 + d**2)
- return AffineTransformation(tx, ty, sx, sy, rot)
-
- def getTransformation(self):
- """Retuns the affine transformation applied to the second image to align
- it to the first image.
-
- This result is only valid for sift alignment.
-
- :rtype: Union[None,AffineTransformation]
- """
- return self.__transformation
-
- def __createSiftData(self, image, second_image):
- """Generate key points and aligned images from 2 images.
-
- If no keypoints matches, unaligned data are anyway returns.
-
- :rtype: Tuple(numpy.ndarray,numpy.ndarray)
- """
- devicetype = "GPU"
-
- # Compute base image
- sift_ocl = sift.SiftPlan(template=image, devicetype=devicetype)
- keypoints = sift_ocl(image)
-
- # Check image compatibility
- second_keypoints = sift_ocl(second_image)
- mp = sift.MatchPlan()
- match = mp(keypoints, second_keypoints)
- _logger.info("Number of Keypoints within image 1: %i" % keypoints.size)
- _logger.info(" within image 2: %i" % second_keypoints.size)
-
- self.__matching_keypoints = (match[:].x[:, 0],
- match[:].y[:, 0],
- match[:].scale[:, 0])
- matching_keypoints = match.shape[0]
- _logger.info("Matching keypoints: %i" % matching_keypoints)
- if matching_keypoints == 0:
- return image, second_image
-
- # TODO: Problem here is we have to compute 2 time sift
- # The first time to extract matching keypoints, second time
- # to extract the aligned image.
-
- # Normalize the second image
- sa = sift.LinearAlign(image, devicetype=devicetype)
- data1 = image
- # TODO: Create a sift issue: if data1 is RGB and data2 intensity
- # it returns None, while extracting manually keypoints (above) works
- result = sa.align(second_image, return_all=True)
- data2 = result["result"]
- self.__transformation = self.__toAffineTransformation(result)
- return data1, data2
-
- def setAutoResetZoom(self, activate=True):
- """
-
- :param bool activate: True if we want to activate the automatic
- plot reset zoom when setting images.
- """
- self._resetZoomActive = activate
-
- def isAutoResetZoom(self):
- """
-
- :return: True if the automatic call to resetzoom is activated
- :rtype: bool
- """
- return self._resetZoomActive
diff --git a/silx/gui/plot/ComplexImageView.py b/silx/gui/plot/ComplexImageView.py
deleted file mode 100644
index dc6bf63..0000000
--- a/silx/gui/plot/ComplexImageView.py
+++ /dev/null
@@ -1,518 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a widget to view 2D complex data.
-
-The :class:`ComplexImageView` widget is dedicated to visualize a single 2D dataset
-of complex data.
-"""
-
-from __future__ import absolute_import
-
-__authors__ = ["Vincent Favre-Nicolin", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-
-import logging
-import collections
-import numpy
-
-from ...utils.deprecation import deprecated
-from .. import qt, icons
-from .PlotWindow import Plot2D
-from . import items
-from .items import ImageComplexData
-from silx.gui.widgets.FloatEdit import FloatEdit
-
-_logger = logging.getLogger(__name__)
-
-
-# Widgets
-
-class _AmplitudeRangeDialog(qt.QDialog):
- """QDialog asking for the amplitude range to display."""
-
- sigRangeChanged = qt.Signal(tuple)
- """Signal emitted when the range has changed.
-
- It provides the new range as a 2-tuple: (max, delta)
- """
-
- def __init__(self,
- parent=None,
- amplitudeRange=None,
- displayedRange=(None, 2)):
- super(_AmplitudeRangeDialog, self).__init__(parent)
- self.setWindowTitle('Set Displayed Amplitude Range')
-
- if amplitudeRange is not None:
- amplitudeRange = min(amplitudeRange), max(amplitudeRange)
- self._amplitudeRange = amplitudeRange
- self._defaultDisplayedRange = displayedRange
-
- layout = qt.QFormLayout()
- self.setLayout(layout)
-
- if self._amplitudeRange is not None:
- min_, max_ = self._amplitudeRange
- layout.addRow(
- qt.QLabel('Data Amplitude Range: [%g, %g]' % (min_, max_)))
-
- self._maxLineEdit = FloatEdit(parent=self)
- self._maxLineEdit.validator().setBottom(0.)
- self._maxLineEdit.setAlignment(qt.Qt.AlignRight)
-
- self._maxLineEdit.editingFinished.connect(self._rangeUpdated)
- layout.addRow('Displayed Max.:', self._maxLineEdit)
-
- self._autoscale = qt.QCheckBox('autoscale')
- self._autoscale.toggled.connect(self._autoscaleCheckBoxToggled)
- layout.addRow('', self._autoscale)
-
- self._deltaLineEdit = FloatEdit(parent=self)
- self._deltaLineEdit.validator().setBottom(1.)
- self._deltaLineEdit.setAlignment(qt.Qt.AlignRight)
- self._deltaLineEdit.editingFinished.connect(self._rangeUpdated)
- layout.addRow('Displayed delta (log10 unit):', self._deltaLineEdit)
-
- buttons = qt.QDialogButtonBox(self)
- buttons.addButton(qt.QDialogButtonBox.Ok)
- buttons.addButton(qt.QDialogButtonBox.Cancel)
- buttons.accepted.connect(self.accept)
- buttons.rejected.connect(self.reject)
- layout.addRow(buttons)
-
- # Set dialog from default values
- self._resetDialogToDefault()
-
- self.rejected.connect(self._handleRejected)
-
- def _resetDialogToDefault(self):
- """Set Widgets of the dialog from range information
- """
- max_, delta = self._defaultDisplayedRange
-
- if max_ is not None: # Not in autoscale
- displayedMax = max_
- elif self._amplitudeRange is not None: # Autoscale with data
- displayedMax = self._amplitudeRange[1]
- else: # Autoscale without data
- displayedMax = ''
- if displayedMax == "":
- self._maxLineEdit.setText("")
- else:
- self._maxLineEdit.setValue(displayedMax)
- self._maxLineEdit.setEnabled(max_ is not None)
-
- self._deltaLineEdit.setValue(delta)
-
- self._autoscale.setChecked(self._defaultDisplayedRange[0] is None)
-
- def getRangeInfo(self):
- """Returns the current range as a 2-tuple (max, delta (in log10))"""
- if self._autoscale.isChecked():
- max_ = None
- else:
- maxStr = self._maxLineEdit.text()
- max_ = self._maxLineEdit.value() if maxStr else None
- return max_, self._deltaLineEdit.value() if self._deltaLineEdit.text() else 2
-
- def _handleRejected(self):
- """Reset range info to default when rejected"""
- self._resetDialogToDefault()
- self._rangeUpdated()
-
- def _rangeUpdated(self):
- """Handle QLineEdit editing finised"""
- self.sigRangeChanged.emit(self.getRangeInfo())
-
- def _autoscaleCheckBoxToggled(self, checked):
- """Handle autoscale checkbox state changes"""
- if checked: # Use default values
- if self._amplitudeRange is None:
- max_ = ''
- else:
- max_ = self._amplitudeRange[1]
- if max_ == "":
- self._maxLineEdit.setText("")
- else:
- self._maxLineEdit.setValue(max_)
- self._maxLineEdit.setEnabled(not checked)
- self._rangeUpdated()
-
-
-class _ComplexDataToolButton(qt.QToolButton):
- """QToolButton providing choices of complex data visualization modes
-
- :param parent: See :class:`QToolButton`
- :param plot: The :class:`ComplexImageView` to control
- """
-
- _MODES = collections.OrderedDict([
- (ImageComplexData.ComplexMode.ABSOLUTE, ('math-amplitude', 'Amplitude')),
- (ImageComplexData.ComplexMode.SQUARE_AMPLITUDE,
- ('math-square-amplitude', 'Square amplitude')),
- (ImageComplexData.ComplexMode.PHASE, ('math-phase', 'Phase')),
- (ImageComplexData.ComplexMode.REAL, ('math-real', 'Real part')),
- (ImageComplexData.ComplexMode.IMAGINARY,
- ('math-imaginary', 'Imaginary part')),
- (ImageComplexData.ComplexMode.AMPLITUDE_PHASE,
- ('math-phase-color', 'Amplitude and Phase')),
- (ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE,
- ('math-phase-color-log', 'Log10(Amp.) and Phase'))
- ])
-
- _RANGE_DIALOG_TEXT = 'Set Amplitude Range...'
-
- def __init__(self, parent=None, plot=None):
- super(_ComplexDataToolButton, self).__init__(parent=parent)
-
- assert plot is not None
- self._plot2DComplex = plot
-
- menu = qt.QMenu(self)
- menu.triggered.connect(self._triggered)
- self.setMenu(menu)
-
- for mode, info in self._MODES.items():
- icon, text = info
- action = qt.QAction(icons.getQIcon(icon), text, self)
- action.setData(mode)
- action.setIconVisibleInMenu(True)
- menu.addAction(action)
-
- self._rangeDialogAction = qt.QAction(self)
- self._rangeDialogAction.setText(self._RANGE_DIALOG_TEXT)
- menu.addAction(self._rangeDialogAction)
-
- self.setPopupMode(qt.QToolButton.InstantPopup)
-
- self._modeChanged(self._plot2DComplex.getComplexMode())
- self._plot2DComplex.sigVisualizationModeChanged.connect(
- self._modeChanged)
-
- def _modeChanged(self, mode):
- """Handle change of visualization modes"""
- icon, text = self._MODES[mode]
- self.setIcon(icons.getQIcon(icon))
- self.setToolTip('Display the ' + text.lower())
- self._rangeDialogAction.setEnabled(
- mode == ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE)
-
- def _triggered(self, action):
- """Handle triggering of menu actions"""
- actionText = action.text()
-
- if actionText == self._RANGE_DIALOG_TEXT: # Show dialog
- # Get amplitude range
- data = self._plot2DComplex.getData(copy=False)
-
- if data.size > 0:
- absolute = numpy.absolute(data)
- dataRange = (numpy.nanmin(absolute), numpy.nanmax(absolute))
- else:
- dataRange = None
-
- # Show dialog
- dialog = _AmplitudeRangeDialog(
- parent=self,
- amplitudeRange=dataRange,
- displayedRange=self._plot2DComplex._getAmplitudeRangeInfo())
- dialog.sigRangeChanged.connect(self._rangeChanged)
- dialog.exec_()
- dialog.sigRangeChanged.disconnect(self._rangeChanged)
-
- else: # update mode
- mode = action.data()
- if isinstance(mode, ImageComplexData.ComplexMode):
- self._plot2DComplex.setComplexMode(mode)
-
- def _rangeChanged(self, range_):
- """Handle updates of range in the dialog"""
- self._plot2DComplex._setAmplitudeRangeInfo(*range_)
-
-
-class ComplexImageView(qt.QWidget):
- """Display an image of complex data and allow to choose the visualization.
-
- :param parent: See :class:`QMainWindow`
- """
-
- ComplexMode = ImageComplexData.ComplexMode
- """Complex Modes enumeration"""
-
- sigDataChanged = qt.Signal()
- """Signal emitted when data has changed."""
-
- sigVisualizationModeChanged = qt.Signal(object)
- """Signal emitted when the visualization mode has changed.
-
- It provides the new visualization mode.
- """
-
- def __init__(self, parent=None):
- super(ComplexImageView, self).__init__(parent)
- if parent is None:
- self.setWindowTitle('ComplexImageView')
-
- self._plot2D = Plot2D(self)
-
- layout = qt.QHBoxLayout(self)
- layout.setSpacing(0)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.addWidget(self._plot2D)
- self.setLayout(layout)
-
- # Create and add image to the plot
- self._plotImage = ImageComplexData()
- self._plotImage.setName('__ComplexImageView__complex_image__')
- self._plotImage.sigItemChanged.connect(self._itemChanged)
- self._plot2D.addItem(self._plotImage)
- self._plot2D.setActiveImage(self._plotImage.getName())
-
- toolBar = qt.QToolBar('Complex', self)
- toolBar.addWidget(
- _ComplexDataToolButton(parent=self, plot=self))
-
- self._plot2D.insertToolBar(self._plot2D.getProfileToolbar(), toolBar)
-
- def _itemChanged(self, event):
- """Handle item changed signal"""
- if event is items.ItemChangedType.DATA:
- self.sigDataChanged.emit()
- elif event is items.ItemChangedType.VISUALIZATION_MODE:
- mode = self.getComplexMode()
- self.sigVisualizationModeChanged.emit(mode)
-
- def getPlot(self):
- """Return the PlotWidget displaying the data"""
- return self._plot2D
-
- def setData(self, data=None, copy=True):
- """Set the complex data to display.
-
- :param numpy.ndarray data: 2D complex data
- :param bool copy: True (default) to copy the data,
- False to use provided data (do not modify!).
- """
- if data is None:
- data = numpy.zeros((0, 0), dtype=numpy.complex64)
-
- previousData = self._plotImage.getComplexData(copy=False)
-
- self._plotImage.setData(data, copy=copy)
-
- if previousData.shape != data.shape:
- self.getPlot().resetZoom()
-
- def getData(self, copy=True):
- """Get the currently displayed complex data.
-
- :param bool copy: True (default) to return a copy of the data,
- False to return internal data (do not modify!).
- :return: The complex data array.
- :rtype: numpy.ndarray of complex with 2 dimensions
- """
- return self._plotImage.getComplexData(copy=copy)
-
- def getDisplayedData(self, copy=True):
- """Returns the displayed data depending on the visualization mode
-
- WARNING: The returned data can be a uint8 RGBA image
-
- :param bool copy: True (default) to return a copy of the data,
- False to return internal data (do not modify!)
- :rtype: numpy.ndarray of float with 2 dims or RGBA image (uint8).
- """
- mode = self.getComplexMode()
- if mode in (self.ComplexMode.AMPLITUDE_PHASE,
- self.ComplexMode.LOG10_AMPLITUDE_PHASE):
- return self._plotImage.getRgbaImageData(copy=copy)
- else:
- return self._plotImage.getData(copy=copy)
-
- # Backward compatibility
-
- Mode = ComplexMode
-
- @classmethod
- @deprecated(replacement='supportedComplexModes', since_version='0.11.0')
- def getSupportedVisualizationModes(cls):
- return cls.supportedComplexModes()
-
- @deprecated(replacement='setComplexMode', since_version='0.11.0')
- def setVisualizationMode(self, mode):
- return self.setComplexMode(mode)
-
- @deprecated(replacement='getComplexMode', since_version='0.11.0')
- def getVisualizationMode(self):
- return self.getComplexMode()
-
- # Image item proxy
-
- @staticmethod
- def supportedComplexModes():
- """Returns the supported visualization modes.
-
- Supported visualization modes are:
-
- - amplitude: The absolute value provided by numpy.absolute
- - phase: The phase (or argument) provided by numpy.angle
- - real: Real part
- - imaginary: Imaginary part
- - amplitude_phase: Color-coded phase with amplitude as alpha.
- - log10_amplitude_phase:
- Color-coded phase with log10(amplitude) as alpha.
-
- :rtype: List[ComplexMode]
- """
- return ImageComplexData.supportedComplexModes()
-
- def setComplexMode(self, mode):
- """Set the mode of visualization of the complex data.
-
- See :meth:`supportedComplexModes` for the list of
- supported modes.
-
- How-to change visualization mode::
-
- widget = ComplexImageView()
- widget.setComplexMode(ComplexImageView.ComplexMode.PHASE)
- # or
- widget.setComplexMode('phase')
-
- :param Unions[ComplexMode,str] mode: The mode to use.
- """
- self._plotImage.setComplexMode(mode)
-
- def getComplexMode(self):
- """Get the current visualization mode of the complex data.
-
- :rtype: ComplexMode
- """
- return self._plotImage.getComplexMode()
-
- def _setAmplitudeRangeInfo(self, max_=None, delta=2):
- """Set the amplitude range to display for 'log10_amplitude_phase' mode.
-
- :param max_: Max of the amplitude range.
- If None it autoscales to data max.
- :param float delta: Delta range in log10 to display
- """
- self._plotImage._setAmplitudeRangeInfo(max_, delta)
-
- def _getAmplitudeRangeInfo(self):
- """Returns the amplitude range to use for 'log10_amplitude_phase' mode.
-
- :return: (max, delta), if max is None, then it autoscales to data max
- :rtype: 2-tuple"""
- return self._plotImage._getAmplitudeRangeInfo()
-
- def setColormap(self, colormap, mode=None):
- """Set the colormap to use for amplitude, phase, real or imaginary.
-
- WARNING: This colormap is not used when displaying both
- amplitude and phase.
-
- :param ~silx.gui.colors.Colormap colormap: The colormap
- :param ComplexMode mode: If specified, set the colormap of this specific mode
- """
- self._plotImage.setColormap(colormap, mode)
-
- def getColormap(self, mode=None):
- """Returns the colormap used to display the data.
-
- :param ComplexMode mode: If specified, set the colormap of this specific mode
- :rtype: ~silx.gui.colors.Colormap
- """
- return self._plotImage.getColormap(mode=mode)
-
- def getOrigin(self):
- """Returns the offset from origin at which to display the image.
-
- :rtype: 2-tuple of float
- """
- return self._plotImage.getOrigin()
-
- def setOrigin(self, origin):
- """Set the offset from origin at which to display the image.
-
- :param origin: (ox, oy) Offset from origin
- :type origin: float or 2-tuple of float
- """
- self._plotImage.setOrigin(origin)
-
- def getScale(self):
- """Returns the scale of the image in data coordinates.
-
- :rtype: 2-tuple of float
- """
- return self._plotImage.getScale()
-
- def setScale(self, scale):
- """Set the scale of the image
-
- :param scale: (sx, sy) Scale of the image
- :type scale: float or 2-tuple of float
- """
- self._plotImage.setScale(scale)
-
- # PlotWidget API proxy
-
- def getXAxis(self):
- """Returns the X axis
-
- :rtype: :class:`.items.Axis`
- """
- return self.getPlot().getXAxis()
-
- def getYAxis(self):
- """Returns an Y axis
-
- :rtype: :class:`.items.Axis`
- """
- return self.getPlot().getYAxis(axis='left')
-
- def getGraphTitle(self):
- """Return the plot main title as a str."""
- return self.getPlot().getGraphTitle()
-
- def setGraphTitle(self, title=""):
- """Set the plot main title.
-
- :param str title: Main title of the plot (default: '')
- """
- self.getPlot().setGraphTitle(title)
-
- def setKeepDataAspectRatio(self, flag):
- """Set whether the plot keeps data aspect ratio or not.
-
- :param bool flag: True to respect data aspect ratio
- """
- self.getPlot().setKeepDataAspectRatio(flag)
-
- def isKeepDataAspectRatio(self):
- """Returns whether the plot is keeping data aspect ratio or not."""
- return self.getPlot().isKeepDataAspectRatio()
diff --git a/silx/gui/plot/CurvesROIWidget.py b/silx/gui/plot/CurvesROIWidget.py
deleted file mode 100644
index 5c9033e..0000000
--- a/silx/gui/plot/CurvesROIWidget.py
+++ /dev/null
@@ -1,1584 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-Widget to handle regions of interest (:class:`ROI`) on curves displayed in a
-:class:`PlotWindow`.
-
-This widget is meant to work with :class:`PlotWindow`.
-"""
-
-__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
-__license__ = "MIT"
-__date__ = "13/03/2018"
-
-from collections import OrderedDict
-import logging
-import os
-import sys
-import functools
-import numpy
-from silx.io import dictdump
-from silx.utils import deprecation
-from silx.utils.weakref import WeakMethodProxy
-from silx.utils.proxy import docstring
-from .. import icons, qt
-from silx.math.combo import min_max
-import weakref
-from silx.gui.widgets.TableWidget import TableWidget
-from . import items
-from .items.roi import _RegionOfInterestBase
-
-
-_logger = logging.getLogger(__name__)
-
-
-class CurvesROIWidget(qt.QWidget):
- """
- Widget displaying a table of ROI information.
-
- Implements also the following behavior:
-
- * if the roiTable has no ROI when showing create the default ICR one
-
- :param parent: See :class:`QWidget`
- :param str name: The title of this widget
- """
-
- sigROIWidgetSignal = qt.Signal(object)
- """Signal of ROIs modifications.
-
- Modification information if given as a dict with an 'event' key
- providing the type of events.
-
- Type of events:
-
- - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict'
- - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader',
- 'rowheader'
- """
-
- sigROISignal = qt.Signal(object)
-
- def __init__(self, parent=None, name=None, plot=None):
- super(CurvesROIWidget, self).__init__(parent)
- if name is not None:
- self.setWindowTitle(name)
- self.__lastSigROISignal = None
- """Store the last value emitted for the sigRoiSignal. In the case the
- active curve change we need to add this extra step in order to make
- sure we won't send twice the sigROISignal.
- This come from the fact sigROISignal is connected to the
- activeROIChanged signal which is emitted when raw and net counts
- values are changing but are not embed in the sigROISignal.
- """
- assert plot is not None
- self._plotRef = weakref.ref(plot)
- self._showAllMarkers = False
- self.currentROI = None
-
- layout = qt.QVBoxLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
-
- self.headerLabel = qt.QLabel(self)
- self.headerLabel.setAlignment(qt.Qt.AlignHCenter)
- self.setHeader()
- layout.addWidget(self.headerLabel)
-
- widgetAllCheckbox = qt.QWidget(parent=self)
- self._showAllCheckBox = qt.QCheckBox("show all ROI",
- parent=widgetAllCheckbox)
- widgetAllCheckbox.setLayout(qt.QHBoxLayout())
- spacer = qt.QWidget(parent=widgetAllCheckbox)
- spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
- widgetAllCheckbox.layout().addWidget(spacer)
- widgetAllCheckbox.layout().addWidget(self._showAllCheckBox)
- layout.addWidget(widgetAllCheckbox)
-
- self.roiTable = ROITable(self, plot=plot)
- rheight = self.roiTable.horizontalHeader().sizeHint().height()
- self.roiTable.setMinimumHeight(4 * rheight)
- layout.addWidget(self.roiTable)
- self._roiFileDir = qt.QDir.home().absolutePath()
- self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers)
-
- hbox = qt.QWidget(self)
- hboxlayout = qt.QHBoxLayout(hbox)
- hboxlayout.setContentsMargins(0, 0, 0, 0)
- hboxlayout.setSpacing(0)
-
- hboxlayout.addStretch(0)
-
- self.addButton = qt.QPushButton(hbox)
- self.addButton.setText("Add ROI")
- self.addButton.setToolTip('Create a new ROI')
- self.delButton = qt.QPushButton(hbox)
- self.delButton.setText("Delete ROI")
- self.addButton.setToolTip('Remove the selected ROI')
- self.resetButton = qt.QPushButton(hbox)
- self.resetButton.setText("Reset")
- self.addButton.setToolTip('Clear all created ROIs. We only let the '
- 'default ROI')
-
- hboxlayout.addWidget(self.addButton)
- hboxlayout.addWidget(self.delButton)
- hboxlayout.addWidget(self.resetButton)
-
- hboxlayout.addStretch(0)
-
- self.loadButton = qt.QPushButton(hbox)
- self.loadButton.setText("Load")
- self.loadButton.setToolTip('Load ROIs from a .ini file')
- self.saveButton = qt.QPushButton(hbox)
- self.saveButton.setText("Save")
- self.loadButton.setToolTip('Save ROIs to a .ini file')
- hboxlayout.addWidget(self.loadButton)
- hboxlayout.addWidget(self.saveButton)
- layout.setStretchFactor(self.headerLabel, 0)
- layout.setStretchFactor(self.roiTable, 1)
- layout.setStretchFactor(hbox, 0)
-
- layout.addWidget(hbox)
-
- # Signal / Slot connections
- self.addButton.clicked.connect(self._add)
- self.delButton.clicked.connect(self._del)
- self.resetButton.clicked.connect(self._reset)
-
- self.loadButton.clicked.connect(self._load)
- self.saveButton.clicked.connect(self._save)
-
- self.roiTable.activeROIChanged.connect(self._emitCurrentROISignal)
-
- self._isConnected = False # True if connected to plot signals
- self._isInit = False
-
- # expose API
- self.getROIListAndDict = self.roiTable.getROIListAndDict
-
- def getPlotWidget(self):
- """Returns the associated PlotWidget or None
-
- :rtype: Union[~silx.gui.plot.PlotWidget,None]
- """
- return None if self._plotRef is None else self._plotRef()
-
- def showEvent(self, event):
- self._visibilityChangedHandler(visible=True)
- qt.QWidget.showEvent(self, event)
-
- @property
- def roiFileDir(self):
- """The directory from which to load/save ROI from/to files."""
- if not os.path.isdir(self._roiFileDir):
- self._roiFileDir = qt.QDir.home().absolutePath()
- return self._roiFileDir
-
- @roiFileDir.setter
- def roiFileDir(self, roiFileDir):
- self._roiFileDir = str(roiFileDir)
-
- def setRois(self, rois, order=None):
- return self.roiTable.setRois(rois, order)
-
- def getRois(self, order=None):
- return self.roiTable.getRois(order)
-
- def setMiddleROIMarkerFlag(self, flag=True):
- return self.roiTable.setMiddleROIMarkerFlag(flag)
-
- def _add(self):
- """Add button clicked handler"""
- def getNextRoiName():
- rois = self.roiTable.getRois(order=None)
- roisNames = []
- [roisNames.append(roiName) for roiName in rois]
- nrois = len(rois)
- if nrois == 0:
- return "ICR"
- else:
- i = 1
- newroi = "newroi %d" % i
- while newroi in roisNames:
- i += 1
- newroi = "newroi %d" % i
- return newroi
- roi = ROI(name=getNextRoiName())
-
- if roi.getName() == "ICR":
- roi.setType("Default")
- else:
- roi.setType(self.getPlotWidget().getXAxis().getLabel())
-
- xmin, xmax = self.getPlotWidget().getXAxis().getLimits()
- fromdata = xmin + 0.25 * (xmax - xmin)
- todata = xmin + 0.75 * (xmax - xmin)
- if roi.isICR():
- fromdata, dummy0, todata, dummy1 = self._getAllLimits()
- roi.setFrom(fromdata)
- roi.setTo(todata)
- self.roiTable.addRoi(roi)
-
- # back compatibility pymca roi signals
- ddict = {}
- ddict['event'] = "AddROI"
- ddict['roilist'] = self.roiTable.roidict.values()
- ddict['roidict'] = self.roiTable.roidict
- self.sigROIWidgetSignal.emit(ddict)
- # end back compatibility pymca roi signals
-
- def _del(self):
- """Delete button clicked handler"""
- self.roiTable.deleteActiveRoi()
-
- # back compatibility pymca roi signals
- ddict = {}
- ddict['event'] = "DelROI"
- ddict['roilist'] = self.roiTable.roidict.values()
- ddict['roidict'] = self.roiTable.roidict
- self.sigROIWidgetSignal.emit(ddict)
- # end back compatibility pymca roi signals
-
- def _reset(self):
- """Reset button clicked handler"""
- self.roiTable.clear()
- old = self.blockSignals(True) # avoid several sigROISignal emission
- self._add()
- self.blockSignals(old)
-
- # back compatibility pymca roi signals
- ddict = {}
- ddict['event'] = "ResetROI"
- ddict['roilist'] = self.roiTable.roidict.values()
- ddict['roidict'] = self.roiTable.roidict
- self.sigROIWidgetSignal.emit(ddict)
- # end back compatibility pymca roi signals
-
- def _load(self):
- """Load button clicked handler"""
- dialog = qt.QFileDialog(self)
- dialog.setNameFilters(
- ['INI File *.ini', 'JSON File *.json', 'All *.*'])
- dialog.setFileMode(qt.QFileDialog.ExistingFile)
- dialog.setDirectory(self.roiFileDir)
- if not dialog.exec_():
- dialog.close()
- return
-
- # pyflakes bug http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=666494
- outputFile = dialog.selectedFiles()[0]
- dialog.close()
-
- self.roiFileDir = os.path.dirname(outputFile)
- self.roiTable.load(outputFile)
-
- # back compatibility pymca roi signals
- ddict = {}
- ddict['event'] = "LoadROI"
- ddict['roilist'] = self.roiTable.roidict.values()
- ddict['roidict'] = self.roiTable.roidict
- self.sigROIWidgetSignal.emit(ddict)
- # end back compatibility pymca roi signals
-
- def load(self, filename):
- """Load ROI widget information from a file storing a dict of ROI.
-
- :param str filename: The file from which to load ROI
- """
- self.roiTable.load(filename)
-
- def _save(self):
- """Save button clicked handler"""
- dialog = qt.QFileDialog(self)
- dialog.setNameFilters(['INI File *.ini', 'JSON File *.json'])
- dialog.setFileMode(qt.QFileDialog.AnyFile)
- dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
- dialog.setDirectory(self.roiFileDir)
- if not dialog.exec_():
- dialog.close()
- return
-
- outputFile = dialog.selectedFiles()[0]
- extension = '.' + dialog.selectedNameFilter().split('.')[-1]
- dialog.close()
-
- if not outputFile.endswith(extension):
- outputFile += extension
-
- if os.path.exists(outputFile):
- try:
- os.remove(outputFile)
- except IOError:
- msg = qt.QMessageBox(self)
- msg.setIcon(qt.QMessageBox.Critical)
- msg.setText("Input Output Error: %s" % (sys.exc_info()[1]))
- msg.exec_()
- return
- self.roiFileDir = os.path.dirname(outputFile)
- self.save(outputFile)
-
- def save(self, filename):
- """Save current ROIs of the widget as a dict of ROI to a file.
-
- :param str filename: The file to which to save the ROIs
- """
- self.roiTable.save(filename)
-
- def setHeader(self, text='ROIs'):
- """Set the header text of this widget"""
- self.headerLabel.setText("<b>%s<\b>" % text)
-
- @deprecation.deprecated(replacement="calculateRois",
- reason="CamelCase convention",
- since_version="0.7")
- def calculateROIs(self, *args, **kw):
- self.calculateRois(*args, **kw)
-
- def calculateRois(self, roiList=None, roiDict=None):
- """Compute ROI information"""
- return self.roiTable.calculateRois()
-
- def showAllMarkers(self, _show=True):
- self.roiTable.showAllMarkers(_show)
-
- def _getAllLimits(self):
- """Retrieve the limits based on the curves."""
- plot = self.getPlotWidget()
- curves = () if plot is None else plot.getAllCurves()
- if not curves:
- return 1.0, 1.0, 100., 100.
-
- xmin, ymin = None, None
- xmax, ymax = None, None
-
- for curve in curves:
- x = curve.getXData(copy=False)
- y = curve.getYData(copy=False)
- if xmin is None:
- xmin = x.min()
- else:
- xmin = min(xmin, x.min())
- if xmax is None:
- xmax = x.max()
- else:
- xmax = max(xmax, x.max())
- if ymin is None:
- ymin = y.min()
- else:
- ymin = min(ymin, y.min())
- if ymax is None:
- ymax = y.max()
- else:
- ymax = max(ymax, y.max())
-
- return xmin, ymin, xmax, ymax
-
- def showEvent(self, event):
- self._visibilityChangedHandler(visible=True)
- qt.QWidget.showEvent(self, event)
-
- def hideEvent(self, event):
- self._visibilityChangedHandler(visible=False)
- qt.QWidget.hideEvent(self, event)
-
- def _visibilityChangedHandler(self, visible):
- """Handle widget's visibility updates.
-
- It is connected to plot signals only when visible.
- """
- if visible:
- # if no ROI existing yet, add the default one
- if self.roiTable.rowCount() == 0:
- old = self.blockSignals(True) # avoid several sigROISignal emission
- self._add()
- self.blockSignals(old)
- self.calculateRois()
-
- def fillFromROIDict(self, *args, **kwargs):
- self.roiTable.fillFromROIDict(*args, **kwargs)
-
- def _emitCurrentROISignal(self):
- ddict = {}
- ddict['event'] = "currentROISignal"
- if self.roiTable.activeRoi is not None:
- ddict['ROI'] = self.roiTable.activeRoi.toDict()
- ddict['current'] = self.roiTable.activeRoi.getName()
- else:
- ddict['current'] = None
-
- if self.__lastSigROISignal != ddict:
- self.__lastSigROISignal = ddict
- self.sigROISignal.emit(ddict)
-
- @property
- def currentRoi(self):
- return self.roiTable.activeRoi
-
-
-class _FloatItem(qt.QTableWidgetItem):
- """
- Simple QTableWidgetItem overloading the < operator to deal with ordering
- """
- def __init__(self):
- qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type)
-
- def __lt__(self, other):
- if self.text() in ('', ROITable.INFO_NOT_FOUND):
- return False
- if other.text() in ('', ROITable.INFO_NOT_FOUND):
- return True
- return float(self.text()) < float(other.text())
-
-
-class ROITable(TableWidget):
- """Table widget displaying ROI information.
-
- See :class:`QTableWidget` for constructor arguments.
-
- Behavior: listen at the active curve changed only when the widget is
- visible. Otherwise won't compute the row and net counts...
- """
-
- activeROIChanged = qt.Signal()
- """Signal emitted when the active roi changed or when the value of the
- active roi are changing"""
-
- COLUMNS_INDEX = OrderedDict([
- ('ID', 0),
- ('ROI', 1),
- ('Type', 2),
- ('From', 3),
- ('To', 4),
- ('Raw Counts', 5),
- ('Net Counts', 6),
- ('Raw Area', 7),
- ('Net Area', 8),
- ])
-
- COLUMNS = list(COLUMNS_INDEX.keys())
-
- INFO_NOT_FOUND = '????????'
-
- def __init__(self, parent=None, plot=None, rois=None):
- super(ROITable, self).__init__(parent)
- self._showAllMarkers = False
- self._userIsEditingRoi = False
- """bool used to avoid conflict when editing the ROI object"""
- self._isConnected = False
- self._roiToItems = {}
- self._roiDict = {}
- """dict of ROI object. Key is ROi id, value is the ROI object"""
- self._markersHandler = _RoiMarkerManager()
-
- """
- Associate for each marker legend used when the `_showAllMarkers` option
- is active a roi.
- """
- self.setColumnCount(len(self.COLUMNS))
- self.setPlot(plot)
- self.__setTooltip()
- self.setSortingEnabled(True)
- self.itemChanged.connect(self._itemChanged)
-
- @property
- def roidict(self):
- return self._getRoiDict()
-
- @property
- def activeRoi(self):
- return self._markersHandler._activeRoi
-
- def _getRoiDict(self):
- ddict = {}
- for id in self._roiDict:
- ddict[self._roiDict[id].getName()] = self._roiDict[id]
- return ddict
-
- def clear(self):
- """
- .. note:: clear the interface only. keep the roidict...
- """
- self._markersHandler.clear()
- self._roiToItems = {}
- self._roiDict = {}
-
- qt.QTableWidget.clear(self)
- self.setRowCount(0)
- self.setHorizontalHeaderLabels(self.COLUMNS)
- header = self.horizontalHeader()
- if hasattr(header, 'setSectionResizeMode'): # Qt5
- header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
- else: # Qt4
- header.setResizeMode(qt.QHeaderView.ResizeToContents)
- self.sortByColumn(0, qt.Qt.AscendingOrder)
- self.hideColumn(self.COLUMNS_INDEX['ID'])
-
- def setPlot(self, plot):
- self.clear()
- self.plot = plot
-
- def __setTooltip(self):
- self.horizontalHeaderItem(self.COLUMNS_INDEX['ROI']).setToolTip(
- 'Region of interest identifier')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['Type']).setToolTip(
- 'Type of the ROI')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['From']).setToolTip(
- 'X-value of the min point')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['To']).setToolTip(
- 'X-value of the max point')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['Raw Counts']).setToolTip(
- 'Estimation of the integral between y=0 and the selected curve')
- self.horizontalHeaderItem(self.COLUMNS_INDEX['Net Counts']).setToolTip(
- 'Estimation of the integral between the segment [maxPt, minPt] '
- 'and the selected curve')
-
- def setRois(self, rois, order=None):
- """Set the ROIs by providing a dictionary of ROI information.
-
- The dictionary keys are the ROI names.
- Each value is a sub-dictionary of ROI info with the following fields:
-
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
-
-
- :param roidict: Dictionary of ROIs
- :param str order: Field used for ordering the ROIs.
- One of "from", "to", "type".
- None (default) for no ordering, or same order as specified
- in parameter ``roidict`` if provided as an OrderedDict.
- """
- assert order in [None, "from", "to", "type"]
- self.clear()
-
- # backward compatibility since 0.10.0
- if isinstance(rois, dict):
- for roiName, roi in rois.items():
- if isinstance(roi, ROI):
- _roi = roi
- else:
- roi['name'] = roiName
- _roi = ROI._fromDict(roi)
- self.addRoi(_roi)
- else:
- for roi in rois:
- assert isinstance(roi, ROI)
- self.addRoi(roi)
- self._updateMarkers()
-
- def addRoi(self, roi):
- """
-
- :param :class:`ROI` roi: roi to add to the table
- """
- assert isinstance(roi, ROI)
- self._getItem(name='ID', row=None, roi=roi)
- self._roiDict[roi.getID()] = roi
- self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot))
- self._updateRoiInfo(roi.getID())
- callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
- roi.getID())
- roi.sigChanged.connect(callback)
- # set it as the active one
- self.setActiveRoi(roi)
-
- def _getItem(self, name, row, roi):
- if row:
- item = self.item(row, self.COLUMNS_INDEX[name])
- else:
- item = None
- if item:
- return item
- else:
- if name == 'ID':
- assert roi
- if roi.getID() in self._roiToItems:
- return self._roiToItems[roi.getID()]
- else:
- # create a new row
- row = self.rowCount()
- self.setRowCount(self.rowCount() + 1)
- item = qt.QTableWidgetItem(str(roi.getID()),
- type=qt.QTableWidgetItem.Type)
- self._roiToItems[roi.getID()] = item
- elif name == 'ROI':
- item = qt.QTableWidgetItem(roi.getName() if roi else '',
- type=qt.QTableWidgetItem.Type)
- if roi.getName().upper() in ('ICR', 'DEFAULT'):
- item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
- else:
- item.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled |
- qt.Qt.ItemIsEditable)
- elif name == 'Type':
- item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type)
- item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
- elif name in ('To', 'From'):
- item = _FloatItem()
- if roi.getName().upper() in ('ICR', 'DEFAULT'):
- item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
- else:
- item.setFlags(qt.Qt.ItemIsSelectable |
- qt.Qt.ItemIsEnabled |
- qt.Qt.ItemIsEditable)
- elif name in ('Raw Counts', 'Net Counts', 'Raw Area', 'Net Area'):
- item = _FloatItem()
- item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
- else:
- raise ValueError('item type not recognized')
-
- self.setItem(row, self.COLUMNS_INDEX[name], item)
- return item
-
- def _itemChanged(self, item):
- def getRoi():
- IDItem = self.item(item.row(), self.COLUMNS_INDEX['ID'])
- assert IDItem
- id = int(IDItem.text())
- assert id in self._roiDict
- roi = self._roiDict[id]
- return roi
-
- def signalChanged(roi):
- if self.activeRoi and roi.getID() == self.activeRoi.getID():
- self.activeROIChanged.emit()
-
- self._userIsEditingRoi = True
- if item.column() in (self.COLUMNS_INDEX['To'], self.COLUMNS_INDEX['From']):
- roi = getRoi()
-
- if item.text() not in ('', self.INFO_NOT_FOUND):
- try:
- value = float(item.text())
- except ValueError:
- value = 0
- changed = False
- if item.column() == self.COLUMNS_INDEX['To']:
- if value != roi.getTo():
- roi.setTo(value)
- changed = True
- else:
- assert(item.column() == self.COLUMNS_INDEX['From'])
- if value != roi.getFrom():
- roi.setFrom(value)
- changed = True
- if changed:
- self._updateMarker(roi.getName())
- signalChanged(roi)
-
- if item.column() is self.COLUMNS_INDEX['ROI']:
- roi = getRoi()
- if roi.getName() != item.text():
- roi.setName(item.text())
- self._markersHandler.getMarkerHandler(roi.getID()).updateTexts()
- signalChanged(roi)
-
- self._userIsEditingRoi = False
-
- def deleteActiveRoi(self):
- """
- remove the current active roi
- """
- activeItems = self.selectedItems()
- if len(activeItems) == 0:
- return
- old = self.blockSignals(True) # avoid several emission of sigROISignal
- roiToRm = set()
- for item in activeItems:
- row = item.row()
- itemID = self.item(row, self.COLUMNS_INDEX['ID'])
- roiToRm.add(self._roiDict[int(itemID.text())])
- [self.removeROI(roi) for roi in roiToRm]
- self.blockSignals(old)
- self.setActiveRoi(None)
-
- def removeROI(self, roi):
- """
- remove the requested roi
-
- :param str name: the name of the roi to remove from the table
- """
- if roi and roi.getID() in self._roiToItems:
- item = self._roiToItems[roi.getID()]
- self.removeRow(item.row())
- del self._roiToItems[roi.getID()]
-
- assert roi.getID() in self._roiDict
- del self._roiDict[roi.getID()]
- self._markersHandler.remove(roi)
-
- callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
- roi.getID())
- roi.sigChanged.connect(callback)
-
- def setActiveRoi(self, roi):
- """
- Define the given roi as the active one.
-
- .. warning:: this roi should already be registred / added to the table
-
- :param :class:`ROI` roi: the roi to defined as active
- """
- if roi is None:
- self.clearSelection()
- self._markersHandler.setActiveRoi(None)
- self.activeROIChanged.emit()
- else:
- assert isinstance(roi, ROI)
- if roi and roi.getID() in self._roiToItems.keys():
- # avoid several call back to setActiveROI
- old = self.blockSignals(True)
- self.selectRow(self._roiToItems[roi.getID()].row())
- self.blockSignals(old)
- self._markersHandler.setActiveRoi(roi)
- self.activeROIChanged.emit()
-
- def _updateRoiInfo(self, roiID):
- if self._userIsEditingRoi is True:
- return
- if roiID not in self._roiDict:
- return
- roi = self._roiDict[roiID]
- if roi.isICR():
- activeCurve = self.plot.getActiveCurve()
- if activeCurve:
- xData = activeCurve.getXData()
- if len(xData) > 0:
- min, max = min_max(xData)
- roi.blockSignals(True)
- roi.setFrom(min)
- roi.setTo(max)
- roi.blockSignals(False)
-
- itemID = self._getItem(name='ID', roi=roi, row=None)
- itemName = self._getItem(name='ROI', row=itemID.row(), roi=roi)
- itemName.setText(roi.getName())
-
- itemType = self._getItem(name='Type', row=itemID.row(), roi=roi)
- itemType.setText(roi.getType() or self.INFO_NOT_FOUND)
-
- itemFrom = self._getItem(name='From', row=itemID.row(), roi=roi)
- fromdata = str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND
- itemFrom.setText(fromdata)
-
- itemTo = self._getItem(name='To', row=itemID.row(), roi=roi)
- todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND
- itemTo.setText(todata)
-
- rawCounts, netCounts = roi.computeRawAndNetCounts(
- curve=self.plot.getActiveCurve(just_legend=False))
- itemRawCounts = self._getItem(name='Raw Counts', row=itemID.row(),
- roi=roi)
- rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND
- itemRawCounts.setText(rawCounts)
-
- itemNetCounts = self._getItem(name='Net Counts', row=itemID.row(),
- roi=roi)
- netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND
- itemNetCounts.setText(netCounts)
-
- rawArea, netArea = roi.computeRawAndNetArea(
- curve=self.plot.getActiveCurve(just_legend=False))
- itemRawArea = self._getItem(name='Raw Area', row=itemID.row(),
- roi=roi)
- rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND
- itemRawArea.setText(rawArea)
-
- itemNetArea = self._getItem(name='Net Area', row=itemID.row(),
- roi=roi)
- netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND
- itemNetArea.setText(netArea)
-
- if self.activeRoi and roi.getID() == self.activeRoi.getID():
- self.activeROIChanged.emit()
-
- def currentChanged(self, current, previous):
- if previous and current.row() != previous.row() and current.row() >= 0:
- roiItem = self.item(current.row(),
- self.COLUMNS_INDEX['ID'])
-
- assert roiItem
- self.setActiveRoi(self._roiDict[int(roiItem.text())])
- self._markersHandler.updateAllMarkers()
- qt.QTableWidget.currentChanged(self, current, previous)
-
- @deprecation.deprecated(reason="Removed",
- replacement="roidict and roidict.values()",
- since_version="0.10.0")
- def getROIListAndDict(self):
- """
-
- :return: the list of roi objects and the dictionary of roi name to roi
- object.
- """
- roidict = self._roiDict
- return list(roidict.values()), roidict
-
- def calculateRois(self, roiList=None, roiDict=None):
- """
- Update values of all registred rois (raw and net counts in particular)
-
- :param roiList: deprecated parameter
- :param roiDict: deprecated parameter
- """
- if roiDict:
- deprecation.deprecated_warning(name='roiDict', type_='Parameter',
- reason='Unused parameter',
- since_version="0.10.0")
- if roiList:
- deprecation.deprecated_warning(name='roiList', type_='Parameter',
- reason='Unused parameter',
- since_version="0.10.0")
-
- for roiID in self._roiDict:
- self._updateRoiInfo(roiID)
-
- def _updateMarker(self, roiID):
- """Make sure the marker of the given roi name is updated"""
- if self._showAllMarkers or (self.activeRoi
- and self.activeRoi.getName() == roiID):
- self._updateMarkers()
-
- def _updateMarkers(self):
- if self._showAllMarkers is True:
- self._markersHandler.updateMarkers()
- else:
- if not self.activeRoi or not self.plot:
- return
- assert isinstance(self.activeRoi, ROI)
- markerHandler = self._markersHandler.getMarkerHandler(self.activeRoi.getID())
- if markerHandler is not None:
- markerHandler.updateMarkers()
-
- def getRois(self, order):
- """
- Return the currently defined ROIs, as an ordered dict.
-
- The dictionary keys are the ROI names.
- Each value is a :class:`ROI` object..
-
- :param order: Field used for ordering the ROIs.
- One of "from", "to", "type", "netcounts", "rawcounts".
- None (default) to get the same order as displayed in the widget.
- :return: Ordered dictionary of ROI information
- """
-
- if order is None or order.lower() == "none":
- ordered_roilist = list(self._roiDict.values())
- res = OrderedDict([(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist])
- else:
- assert order in ["from", "to", "type", "netcounts", "rawcounts"]
- ordered_roilist = sorted(self._roiDict.keys(),
- key=lambda roi_id: self._roiDict[roi_id].get(order))
- res = OrderedDict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist])
-
- return res
-
- def save(self, filename):
- """
- Save current ROIs of the widget as a dict of ROI to a file.
-
- :param str filename: The file to which to save the ROIs
- """
- roilist = []
- roidict = {}
- for roiID, roi in self._roiDict.items():
- roilist.append(roi.toDict())
- roidict[roi.getName()] = roi.toDict()
- datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}}
- dictdump.dump(datadict, filename)
-
- def load(self, filename):
- """
- Load ROI widget information from a file storing a dict of ROI.
-
- :param str filename: The file from which to load ROI
- """
- roisDict = dictdump.load(filename)
- rois = []
-
- # Remove rawcounts and netcounts from ROIs
- for roiDict in roisDict['ROI']['roidict'].values():
- roiDict.pop('rawcounts', None)
- roiDict.pop('netcounts', None)
- rois.append(ROI._fromDict(roiDict))
-
- self.setRois(rois)
-
- def showAllMarkers(self, _show=True):
- """
-
- :param bool _show: if true show all the markers of all the ROIs
- boundaries otherwise will only show the one of
- the active ROI.
- """
- self._markersHandler.setShowAllMarkers(_show)
-
- def setMiddleROIMarkerFlag(self, flag=True):
- """
- Activate or deactivate middle marker.
-
- This allows shifting both min and max limits at once, by dragging
- a marker located in the middle.
-
- :param bool flag: True to activate middle ROI marker
- """
- self._markersHandler._middleROIMarkerFlag = flag
-
- def _handleROIMarkerEvent(self, ddict):
- """Handle plot signals related to marker events."""
- if ddict['event'] == 'markerMoved':
- label = ddict['label']
- roiID = self._markersHandler.getRoiID(markerID=label)
- if roiID is not None:
- # avoid several emission of sigROISignal
- old = self.blockSignals(True)
- self._markersHandler.changePosition(markerID=label,
- x=ddict['x'])
- self.blockSignals(old)
- self._updateRoiInfo(roiID)
-
- def showEvent(self, event):
- self._visibilityChangedHandler(visible=True)
- qt.QWidget.showEvent(self, event)
-
- def hideEvent(self, event):
- self._visibilityChangedHandler(visible=False)
- qt.QWidget.hideEvent(self, event)
-
- def _visibilityChangedHandler(self, visible):
- """Handle widget's visibility updates.
-
- It is connected to plot signals only when visible.
- """
- if visible:
- assert self.plot
- if self._isConnected is False:
- self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent)
- self.plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
- self._isConnected = True
- self.calculateRois()
- else:
- if self._isConnected:
- self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent)
- self.plot.sigActiveCurveChanged.disconnect(self._activeCurveChanged)
- self._isConnected = False
-
- def _activeCurveChanged(self, curve):
- self.calculateRois()
-
- def setCountsVisible(self, visible):
- """
- Display the columns relative to areas or not
-
- :param bool visible: True if the columns 'Raw Area' and 'Net Area'
- should be visible.
- """
- if visible is True:
- self.showColumn(self.COLUMNS_INDEX['Raw Counts'])
- self.showColumn(self.COLUMNS_INDEX['Net Counts'])
- else:
- self.hideColumn(self.COLUMNS_INDEX['Raw Counts'])
- self.hideColumn(self.COLUMNS_INDEX['Net Counts'])
-
- def setAreaVisible(self, visible):
- """
- Display the columns relative to areas or not
-
- :param bool visible: True if the columns 'Raw Area' and 'Net Area'
- should be visible.
- """
- if visible is True:
- self.showColumn(self.COLUMNS_INDEX['Raw Area'])
- self.showColumn(self.COLUMNS_INDEX['Net Area'])
- else:
- self.hideColumn(self.COLUMNS_INDEX['Raw Area'])
- self.hideColumn(self.COLUMNS_INDEX['Net Area'])
-
- def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None):
- """
- This function API is kept for compatibility.
- But `setRois` should be preferred.
-
- Set the ROIs by providing a list of ROI names and a dictionary
- of ROI information for each ROI.
- The ROI names must match an existing dictionary key.
- The name list is used to provide an order for the ROIs.
- The dictionary's values are sub-dictionaries containing 3
- mandatory fields:
-
- - ``"from"``: x coordinate of the left limit, as a float
- - ``"to"``: x coordinate of the right limit, as a float
- - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
-
- :param roilist: List of ROI names (keys of roidict)
- :type roilist: List
- :param dict roidict: Dict of ROI information
- :param currentroi: Name of the selected ROI or None (no selection)
- """
- if roidict is not None:
- self.setRois(roidict)
- else:
- self.setRois(roilist)
- if currentroi:
- self.setActiveRoi(currentroi)
-
-
-_indexNextROI = 0
-
-
-class ROI(_RegionOfInterestBase):
- """The Region Of Interest is defined by:
-
- - A name
- - A type. The type is the label of the x axis. This can be used to apply or
- not some ROI to a curve and do some post processing.
- - The x coordinate of the left limit (fromdata)
- - The x coordinate of the right limit (todata)
-
- :param str: name of the ROI
- :param fromdata: left limit of the roi
- :param todata: right limit of the roi
- :param type: type of the ROI
- """
-
- sigChanged = qt.Signal()
- """Signal emitted when the ROI is edited"""
-
- def __init__(self, name, fromdata=None, todata=None, type_=None):
- _RegionOfInterestBase.__init__(self)
- self.setName(name)
- global _indexNextROI
- self._id = _indexNextROI
- _indexNextROI += 1
-
- self._fromdata = fromdata
- self._todata = todata
- self._type = type_ or 'Default'
-
- self.sigItemChanged.connect(self.__itemChanged)
-
- def __itemChanged(self, event):
- """Handle name change"""
- if event == items.ItemChangedType.NAME:
- self.sigChanged.emit()
-
- def getID(self):
- """
-
- :return int: the unique ID of the ROI
- """
- return self._id
-
- def setType(self, type_):
- """
-
- :param str type_:
- """
- if self._type != type_:
- self._type = type_
- self.sigChanged.emit()
-
- def getType(self):
- """
-
- :return str: the type of the ROI.
- """
- return self._type
-
- def setFrom(self, frm):
- """
-
- :param frm: set x coordinate of the left limit
- """
- if self._fromdata != frm:
- self._fromdata = frm
- self.sigChanged.emit()
-
- def getFrom(self):
- """
-
- :return: x coordinate of the left limit
- """
- return self._fromdata
-
- def setTo(self, to):
- """
-
- :param to: x coordinate of the right limit
- """
- if self._todata != to:
- self._todata = to
- self.sigChanged.emit()
-
- def getTo(self):
- """
-
- :return: x coordinate of the right limit
- """
- return self._todata
-
- def getMiddle(self):
- """
-
- :return: middle position between 'from' and 'to' values
- """
- return 0.5 * (self.getFrom() + self.getTo())
-
- def toDict(self):
- """
-
- :return: dict containing the roi parameters
- """
- ddict = {
- 'type': self._type,
- 'name': self.getName(),
- 'from': self._fromdata,
- 'to': self._todata,
- }
- if hasattr(self, '_extraInfo'):
- ddict.update(self._extraInfo)
- return ddict
-
- @staticmethod
- def _fromDict(dic):
- assert 'name' in dic
- roi = ROI(name=dic['name'])
- roi._extraInfo = {}
- for key in dic:
- if key == 'from':
- roi.setFrom(dic['from'])
- elif key == 'to':
- roi.setTo(dic['to'])
- elif key == 'type':
- roi.setType(dic['type'])
- else:
- roi._extraInfo[key] = dic[key]
-
- return roi
-
- def isICR(self):
- """
-
- :return: True if the ROI is the `ICR`
- """
- return self.getName() == 'ICR'
-
- def computeRawAndNetCounts(self, curve):
- """Compute the Raw and net counts in the ROI for the given curve.
-
- - Raw count: Points values sum of the curve in the defined Region Of
- Interest.
-
- .. image:: img/rawCounts.png
-
- - Net count: Raw counts minus background
-
- .. image:: img/netCounts.png
-
- :param CurveItem curve:
- :return tuple: rawCount, netCount
- """
- assert isinstance(curve, items.Curve) or curve is None
-
- if curve is None:
- return None, None
-
- x = curve.getXData(copy=False)
- y = curve.getYData(copy=False)
-
- idx = numpy.nonzero((self._fromdata <= x) &
- (x <= self._todata))[0]
- if len(idx):
- xw = x[idx]
- yw = y[idx]
- rawCounts = yw.sum(dtype=numpy.float64)
- deltaX = xw[-1] - xw[0]
- deltaY = yw[-1] - yw[0]
- if deltaX > 0.0:
- slope = (deltaY / deltaX)
- background = yw[0] + slope * (xw - xw[0])
- netCounts = (rawCounts -
- background.sum(dtype=numpy.float64))
- else:
- netCounts = 0.0
- else:
- rawCounts = 0.0
- netCounts = 0.0
- return rawCounts, netCounts
-
- def computeRawAndNetArea(self, curve):
- """Compute the Raw and net counts in the ROI for the given curve.
-
- - Raw area: integral of the curve between the min ROI point and the
- max ROI point to the y = 0 line.
-
- .. image:: img/rawArea.png
-
- - Net area: Raw counts minus background
-
- .. image:: img/netArea.png
-
- :param CurveItem curve:
- :return tuple: rawArea, netArea
- """
- assert isinstance(curve, items.Curve) or curve is None
-
- if curve is None:
- return None, None
-
- x = curve.getXData(copy=False)
- y = curve.getYData(copy=False)
-
- y = y[(x >= self._fromdata) & (x <= self._todata)]
- x = x[(x >= self._fromdata) & (x <= self._todata)]
-
- if x.size == 0:
- return 0.0, 0.0
-
- rawArea = numpy.trapz(y, x=x)
- # to speed up and avoid an intersection calculation we are taking the
- # closest index to the ROI
- closestXLeftIndex = (numpy.abs(x - self.getFrom())).argmin()
- closestXRightIndex = (numpy.abs(x - self.getTo())).argmin()
- yBackground = y[closestXLeftIndex], y[closestXRightIndex]
- background = numpy.trapz(yBackground, x=x)
- netArea = rawArea - background
- return rawArea, netArea
-
- @docstring(_RegionOfInterestBase)
- def contains(self, position):
- return self._fromdata <= position[0] <= self._todata
-
-
-class _RoiMarkerManager(object):
- """
- Deal with all the ROI markers
- """
- def __init__(self):
- self._roiMarkerHandlers = {}
- self._middleROIMarkerFlag = False
- self._showAllMarkers = False
- self._activeRoi = None
-
- def setActiveRoi(self, roi):
- self._activeRoi = roi
- self.updateAllMarkers()
-
- def setShowAllMarkers(self, show):
- if show != self._showAllMarkers:
- self._showAllMarkers = show
- self.updateAllMarkers()
-
- def add(self, roi, markersHandler):
- assert isinstance(roi, ROI)
- assert isinstance(markersHandler, _RoiMarkerHandler)
- if roi.getID() in self._roiMarkerHandlers:
- raise ValueError('roi with the same ID already existing')
- else:
- self._roiMarkerHandlers[roi.getID()] = markersHandler
-
- def getMarkerHandler(self, roiID):
- if roiID in self._roiMarkerHandlers:
- return self._roiMarkerHandlers[roiID]
- else:
- return None
-
- def clear(self):
- roisHandler = list(self._roiMarkerHandlers.values())
- for roiHandler in roisHandler:
- self.remove(roiHandler.roi)
-
- def remove(self, roi):
- if roi is None:
- return
- assert isinstance(roi, ROI)
- if roi.getID() in self._roiMarkerHandlers:
- self._roiMarkerHandlers[roi.getID()].clear()
- del self._roiMarkerHandlers[roi.getID()]
-
- def hasMarker(self, markerID):
- assert type(markerID) is str
- return self.getMarker(markerID) is not None
-
- def changePosition(self, markerID, x):
- markerHandler = self.getMarker(markerID)
- if markerHandler is None:
- raise ValueError('Marker %s not register' % markerID)
- markerHandler.changePosition(markerID=markerID, x=x)
-
- def updateMarker(self, markerID):
- markerHandler = self.getMarker(markerID)
- if markerHandler is None:
- raise ValueError('Marker %s not register' % markerID)
- roiID = self.getRoiID(markerID)
- visible = (self._activeRoi and self._activeRoi.getID() == roiID) or self._showAllMarkers is True
- markerHandler.setVisible(visible)
- markerHandler.updateAllMarkers()
-
- def updateRoiMarkers(self, roiID):
- if roiID in self._roiMarkerHandlers:
- visible = ((self._activeRoi and self._activeRoi.getID() == roiID)
- or self._showAllMarkers is True)
- _roi = self._roiMarkerHandlers[roiID]._roi()
- if _roi and not _roi.isICR():
- self._roiMarkerHandlers[roiID].showMiddleMarker(self._middleROIMarkerFlag)
- self._roiMarkerHandlers[roiID].setVisible(visible)
- self._roiMarkerHandlers[roiID].updateMarkers()
-
- def getMarker(self, markerID):
- assert type(markerID) is str
- for marker in list(self._roiMarkerHandlers.values()):
- if marker.hasMarker(markerID):
- return marker
-
- def updateMarkers(self):
- for markerHandler in list(self._roiMarkerHandlers.values()):
- markerHandler.updateMarkers()
-
- def getRoiID(self, markerID):
- for roiID, markerHandler in self._roiMarkerHandlers.items():
- if markerHandler.hasMarker(markerID):
- return roiID
- return None
-
- def setShowMiddleMarkers(self, show):
- self._middleROIMarkerFlag = show
- self._roiMarkerHandlers.updateAllMarkers()
-
- def updateAllMarkers(self):
- for roiID in self._roiMarkerHandlers:
- self.updateRoiMarkers(roiID)
-
- def getVisibleRois(self):
- res = {}
- for roiID, roiHandler in self._roiMarkerHandlers.items():
- markers = (roiHandler.getMarker('min'), roiHandler.getMarker('max'),
- roiHandler.getMarker('middle'))
- for marker in markers:
- if marker.isVisible():
- if roiID not in res:
- res[roiID] = []
- res[roiID].append(marker)
- return res
-
-
-class _RoiMarkerHandler(object):
- """Used to deal with ROI markers used in ROITable"""
- def __init__(self, roi, plot):
- assert roi and isinstance(roi, ROI)
- assert plot
-
- self._roi = weakref.ref(roi)
- self._plot = weakref.ref(plot)
- self._draggable = False if roi.isICR() else True
- self._color = 'black' if roi.isICR() else 'blue'
- self._displayMidMarker = False
- self._visible = True
-
- @property
- def draggable(self):
- return self._draggable
-
- @property
- def plot(self):
- return self._plot()
-
- def clear(self):
- if self.plot and self.roi:
- self.plot.removeMarker(self._markerID('min'))
- self.plot.removeMarker(self._markerID('max'))
- self.plot.removeMarker(self._markerID('middle'))
-
- @property
- def roi(self):
- return self._roi()
-
- def setVisible(self, visible):
- if visible != self._visible:
- self._visible = visible
- self.updateMarkers()
-
- def showMiddleMarker(self, visible):
- if self.draggable is False and visible is True:
- _logger.warning("ROI is not draggable. Won't display middle marker")
- return
- self._displayMidMarker = visible
- self.getMarker('middle').setVisible(self._displayMidMarker)
-
- def updateMarkers(self):
- if self.roi is None:
- return
- self._updateMinMarkerPos()
- self._updateMaxMarkerPos()
- self._updateMiddleMarkerPos()
-
- def _updateMinMarkerPos(self):
- self.getMarker('min').setPosition(x=self.roi.getFrom(), y=None)
- self.getMarker('min').setVisible(self._visible)
-
- def _updateMaxMarkerPos(self):
- self.getMarker('max').setPosition(x=self.roi.getTo(), y=None)
- self.getMarker('max').setVisible(self._visible)
-
- def _updateMiddleMarkerPos(self):
- self.getMarker('middle').setPosition(x=self.roi.getMiddle(), y=None)
- self.getMarker('middle').setVisible(self._displayMidMarker and self._visible)
-
- def getMarker(self, markerType):
- if self.plot is None:
- return None
- assert markerType in ('min', 'max', 'middle')
- if self.plot._getMarker(self._markerID(markerType)) is None:
- assert self.roi
- if markerType == 'min':
- val = self.roi.getFrom()
- elif markerType == 'max':
- val = self.roi.getTo()
- else:
- val = self.roi.getMiddle()
-
- _color = self._color
- if markerType == 'middle':
- _color = 'yellow'
- self.plot.addXMarker(val,
- legend=self._markerID(markerType),
- text=self.getMarkerName(markerType),
- color=_color,
- draggable=self.draggable)
- return self.plot._getMarker(self._markerID(markerType))
-
- def _markerID(self, markerType):
- assert markerType in ('min', 'max', 'middle')
- assert self.roi
- return '_'.join((str(self.roi.getID()), markerType))
-
- def getMarkerName(self, markerType):
- assert markerType in ('min', 'max', 'middle')
- assert self.roi
- return ' '.join((self.roi.getName(), markerType))
-
- def updateTexts(self):
- self.getMarker('min').setText(self.getMarkerName('min'))
- self.getMarker('max').setText(self.getMarkerName('max'))
- self.getMarker('middle').setText(self.getMarkerName('middle'))
-
- def changePosition(self, markerID, x):
- assert self.hasMarker(markerID)
- markerType = self._getMarkerType(markerID)
- assert markerType is not None
- if self.roi is None:
- return
- if markerType == 'min':
- self.roi.setFrom(x)
- self._updateMiddleMarkerPos()
- elif markerType == 'max':
- self.roi.setTo(x)
- self._updateMiddleMarkerPos()
- else:
- delta = x - 0.5 * (self.roi.getFrom() + self.roi.getTo())
- self.roi.setFrom(self.roi.getFrom() + delta)
- self.roi.setTo(self.roi.getTo() + delta)
- self._updateMinMarkerPos()
- self._updateMaxMarkerPos()
-
- def hasMarker(self, marker):
- return marker in (self._markerID('min'),
- self._markerID('max'),
- self._markerID('middle'))
-
- def _getMarkerType(self, markerID):
- if markerID.endswith('_min'):
- return 'min'
- elif markerID.endswith('_max'):
- return 'max'
- elif markerID.endswith('_middle'):
- return 'middle'
- else:
- return None
-
-
-class CurvesROIDockWidget(qt.QDockWidget):
- """QDockWidget with a :class:`CurvesROIWidget` connected to a PlotWindow.
-
- It makes the link between the :class:`CurvesROIWidget` and the PlotWindow.
-
- :param parent: See :class:`QDockWidget`
- :param plot: :class:`.PlotWindow` instance on which to operate
- :param name: See :class:`QDockWidget`
- """
- sigROISignal = qt.Signal(object)
- """Deprecated signal for backward compatibility with silx < 0.7.
- Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal`
- """
-
- def __init__(self, parent=None, plot=None, name=None):
- super(CurvesROIDockWidget, self).__init__(name, parent)
-
- assert plot is not None
- self.plot = plot
- self.roiWidget = CurvesROIWidget(self, name, plot=plot)
- """Main widget of type :class:`CurvesROIWidget`"""
-
- # convenience methods to offer a simpler API allowing to ignore
- # the details of the underlying implementation
- # (ALL DEPRECATED)
- self.calculateROIs = self.calculateRois = self.roiWidget.calculateRois
- self.setRois = self.roiWidget.setRois
- self.getRois = self.roiWidget.getRois
-
- self.roiWidget.sigROISignal.connect(self._forwardSigROISignal)
-
- self.layout().setContentsMargins(0, 0, 0, 0)
- self.setWidget(self.roiWidget)
-
- self.setAreaVisible = self.roiWidget.roiTable.setAreaVisible
- self.setCountsVisible = self.roiWidget.roiTable.setCountsVisible
-
- def _forwardSigROISignal(self, ddict):
- # emit deprecated signal for backward compatibility (silx < 0.7)
- self.sigROISignal.emit(ddict)
-
- def toggleViewAction(self):
- """Returns a checkable action that shows or closes this widget.
-
- See :class:`QMainWindow`.
- """
- action = super(CurvesROIDockWidget, self).toggleViewAction()
- action.setIcon(icons.getQIcon('plot-roi'))
- return action
-
- def showEvent(self, event):
- """Make sure this widget is raised when it is shown
- (when it is first created as a tab in PlotWindow or when it is shown
- again after hiding).
- """
- self.raise_()
- qt.QDockWidget.showEvent(self, event)
-
- @property
- def currentROI(self):
- return self.roiWidget.currentRoi
diff --git a/silx/gui/plot/ImageStack.py b/silx/gui/plot/ImageStack.py
deleted file mode 100644
index fe4b451..0000000
--- a/silx/gui/plot/ImageStack.py
+++ /dev/null
@@ -1,636 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Image stack view with data prefetch capabilty."""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "04/03/2019"
-
-
-from silx.gui import icons, qt
-from silx.gui.plot import Plot2D
-from silx.gui.utils import concurrent
-from silx.io.url import DataUrl
-from silx.io.utils import get_data
-from collections import OrderedDict
-from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
-import time
-import threading
-import typing
-import logging
-
-_logger = logging.getLogger(__name__)
-
-
-class _PlotWithWaitingLabel(qt.QWidget):
- """Image plot widget with an overlay 'waiting' status.
- """
-
- class AnimationThread(threading.Thread):
- def __init__(self, label):
- self.running = True
- self._label = label
- self.animated_icon = icons.getWaitIcon()
- self.animated_icon.register(self._label)
- super(_PlotWithWaitingLabel.AnimationThread, self).__init__()
-
- def run(self):
- while self.running:
- time.sleep(0.05)
- icon = self.animated_icon.currentIcon()
- self.future_result = concurrent.submitToQtMainThread(
- self._label.setPixmap, icon.pixmap(30, state=qt.QIcon.On))
-
- def stop(self):
- """Stop the update thread"""
- self.animated_icon.unregister(self._label)
- self.running = False
- self.join(2)
-
- def __init__(self, parent):
- super(_PlotWithWaitingLabel, self).__init__(parent=parent)
- self._autoResetZoom = True
- layout = qt.QStackedLayout(self)
- layout.setStackingMode(qt.QStackedLayout.StackAll)
-
- self._waiting_label = qt.QLabel(parent=self)
- self._waiting_label.setAlignment(qt.Qt.AlignHCenter | qt.Qt.AlignVCenter)
- layout.addWidget(self._waiting_label)
-
- self._plot = Plot2D(parent=self)
- layout.addWidget(self._plot)
-
- self.updateThread = _PlotWithWaitingLabel.AnimationThread(self._waiting_label)
- self.updateThread.start()
-
- def close(self) -> bool:
- super(_PlotWithWaitingLabel, self).close()
- self.updateThread.stop()
-
- def setAutoResetZoom(self, reset):
- """
- Should we reset the zoom when adding an image (eq. when browsing)
-
- :param bool reset:
- """
- self._autoResetZoom = reset
- if self._autoResetZoom:
- self._plot.resetZoom()
-
- def isAutoResetZoom(self):
- """
-
- :return: True if a reset is done when the image change
- :rtype: bool
- """
- return self._autoResetZoom
-
- def setWaiting(self, activate=True):
- if activate is True:
- self._plot.clear()
- self._waiting_label.show()
- else:
- self._waiting_label.hide()
-
- def setData(self, data):
- self.setWaiting(activate=False)
- self._plot.addImage(data=data, resetzoom=self._autoResetZoom)
-
- def clear(self):
- self._plot.clear()
- self.setWaiting(False)
-
- def getPlotWidget(self):
- return self._plot
-
-
-class _HorizontalSlider(HorizontalSliderWithBrowser):
-
- sigCurrentUrlIndexChanged = qt.Signal(int)
-
- def __init__(self, parent):
- super(_HorizontalSlider, self).__init__(parent=parent)
- # connect signal / slot
- self.valueChanged.connect(self._urlChanged)
-
- def setUrlIndex(self, index):
- self.setValue(index)
- self.sigCurrentUrlIndexChanged.emit(index)
-
- def _urlChanged(self, value):
- self.sigCurrentUrlIndexChanged.emit(value)
-
-
-class UrlList(qt.QWidget):
- """List of URLs the user to select an URL"""
-
- sigCurrentUrlChanged = qt.Signal(str)
- """Signal emitted when the active/current url change"""
-
- def __init__(self, parent=None):
- super(UrlList, self).__init__(parent)
- self.setLayout(qt.QVBoxLayout())
- self.layout().setSpacing(0)
- self.layout().setContentsMargins(0, 0, 0, 0)
- self._listWidget = qt.QListWidget(parent=self)
- self.layout().addWidget(self._listWidget)
-
- # connect signal / Slot
- self._listWidget.currentItemChanged.connect(self._notifyCurrentUrlChanged)
-
- # expose API
- self.currentItem = self._listWidget.currentItem
-
- def setUrls(self, urls: list) -> None:
- url_names = []
- [url_names.append(url.path()) for url in urls]
- self._listWidget.addItems(url_names)
-
- def _notifyCurrentUrlChanged(self, current, previous):
- if current is None:
- pass
- else:
- self.sigCurrentUrlChanged.emit(current.text())
-
- def setUrl(self, url: DataUrl) -> None:
- assert isinstance(url, DataUrl)
- sel_items = self._listWidget.findItems(url.path(), qt.Qt.MatchExactly)
- if sel_items is None:
- _logger.warning(url.path(), ' is not registered in the list.')
- elif len(sel_items) > 0:
- item = sel_items[0]
- self._listWidget.setCurrentItem(item)
- self.sigCurrentUrlChanged.emit(item.text())
-
- def clear(self):
- self._listWidget.clear()
-
-
-class _ToggleableUrlSelectionTable(qt.QWidget):
-
- _BUTTON_ICON = qt.QStyle.SP_ToolBarHorizontalExtensionButton # noqa
-
- sigCurrentUrlChanged = qt.Signal(str)
- """Signal emitted when the active/current url change"""
-
- def __init__(self, parent=None) -> None:
- qt.QWidget.__init__(self, parent)
- self.setLayout(qt.QGridLayout())
- self._toggleButton = qt.QPushButton(parent=self)
- self.layout().addWidget(self._toggleButton, 0, 2, 1, 1)
- self._toggleButton.setSizePolicy(qt.QSizePolicy.Fixed,
- qt.QSizePolicy.Fixed)
-
- self._urlsTable = UrlList(parent=self)
- self.layout().addWidget(self._urlsTable, 1, 1, 1, 2)
-
- # set up
- self._setButtonIcon(show=True)
-
- # Signal / slot connection
- self._toggleButton.clicked.connect(self.toggleUrlSelectionTable)
- self._urlsTable.sigCurrentUrlChanged.connect(self._propagateSignal)
-
- # expose API
- self.setUrls = self._urlsTable.setUrls
- self.setUrl = self._urlsTable.setUrl
- self.currentItem = self._urlsTable.currentItem
-
- def toggleUrlSelectionTable(self):
- visible = not self.urlSelectionTableIsVisible()
- self._setButtonIcon(show=visible)
- self._urlsTable.setVisible(visible)
-
- def _setButtonIcon(self, show):
- style = qt.QApplication.instance().style()
- # return a QIcon
- icon = style.standardIcon(self._BUTTON_ICON)
- if show is False:
- pixmap = icon.pixmap(32, 32).transformed(qt.QTransform().scale(-1, 1))
- icon = qt.QIcon(pixmap)
- self._toggleButton.setIcon(icon)
-
- def urlSelectionTableIsVisible(self):
- return self._urlsTable.isVisible()
-
- def _propagateSignal(self, url):
- self.sigCurrentUrlChanged.emit(url)
-
- def clear(self):
- self._urlsTable.clear()
-
-
-class UrlLoader(qt.QThread):
- """
- Thread use to load DataUrl
- """
- def __init__(self, parent, url):
- super(UrlLoader, self).__init__(parent=parent)
- assert isinstance(url, DataUrl)
- self.url = url
- self.data = None
-
- def run(self):
- try:
- self.data = get_data(self.url)
- except IOError:
- self.data = None
-
-
-class ImageStack(qt.QMainWindow):
- """Widget loading on the fly images contained the given urls.
-
- It prefetches images close to the displayed one.
- """
-
- N_PRELOAD = 10
-
- sigLoaded = qt.Signal(str)
- """Signal emitted when new data is available"""
-
- sigCurrentUrlChanged = qt.Signal(str)
- """Signal emitted when the current url change"""
-
- def __init__(self, parent=None) -> None:
- super(ImageStack, self).__init__(parent)
- self.__n_prefetch = ImageStack.N_PRELOAD
- self._loadingThreads = []
- self.setWindowFlags(qt.Qt.Widget)
- self._current_url = None
- self._url_loader = UrlLoader
- "class to instantiate for loading urls"
-
- # main widget
- self._plot = _PlotWithWaitingLabel(parent=self)
- self._plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
- self.setWindowTitle("Image stack")
- self.setCentralWidget(self._plot)
-
- # dock widget: url table
- self._tableDockWidget = qt.QDockWidget(parent=self)
- self._urlsTable = _ToggleableUrlSelectionTable(parent=self)
- self._tableDockWidget.setWidget(self._urlsTable)
- self._tableDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable)
- self.addDockWidget(qt.Qt.RightDockWidgetArea, self._tableDockWidget)
- # dock widget: qslider
- self._sliderDockWidget = qt.QDockWidget(parent=self)
- self._slider = _HorizontalSlider(parent=self)
- self._sliderDockWidget.setWidget(self._slider)
- self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._sliderDockWidget)
- self._sliderDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable)
-
- self.reset()
-
- # connect signal / slot
- self._urlsTable.sigCurrentUrlChanged.connect(self.setCurrentUrl)
- self._slider.sigCurrentUrlIndexChanged.connect(self.setCurrentUrlIndex)
-
- def close(self) -> bool:
- self._freeLoadingThreads()
- self._plot.close()
- super(ImageStack, self).close()
-
- def setUrlLoaderClass(self, urlLoader: typing.Type[UrlLoader]) -> None:
- """
-
- :param urlLoader: define the class to call for loading urls.
- warning: this should be a class object and not a
- class instance.
- """
- assert isinstance(urlLoader, type(UrlLoader))
- self._url_loader = urlLoader
-
- def getUrlLoaderClass(self):
- """
-
- :return: class to instantiate for loading urls
- :rtype: typing.Type[UrlLoader]
- """
- return self._url_loader
-
- def _freeLoadingThreads(self):
- for thread in self._loadingThreads:
- thread.blockSignals(True)
- thread.wait(5)
- self._loadingThreads.clear()
-
- def getPlotWidget(self) -> Plot2D:
- """
- Returns the PlotWidget contained in this window
-
- :return: PlotWidget contained in this window
- :rtype: Plot2D
- """
- return self._plot.getPlotWidget()
-
- def reset(self) -> None:
- """Clear the plot and remove any link to url"""
- self._freeLoadingThreads()
- self._urls = None
- self._urlIndexes = None
- self._urlData = OrderedDict({})
- self._current_url = None
- self._plot.clear()
- self._urlsTable.clear()
- self._slider.setMaximum(-1)
-
- def _preFetch(self, urls: list) -> None:
- """Pre-fetch the given urls if necessary
-
- :param urls: list of DataUrl to prefetch
- :type: list
- """
- for url in urls:
- if url.path() not in self._urlData:
- self._load(url)
-
- def _load(self, url):
- """
- Launch background load of a DataUrl
-
- :param url:
- :type: DataUrl
- """
- assert isinstance(url, DataUrl)
- url_path = url.path()
- assert url_path in self._urlIndexes
- loader = self._url_loader(parent=self, url=url)
- loader.finished.connect(self._urlLoaded, qt.Qt.QueuedConnection)
- self._loadingThreads.append(loader)
- loader.start()
-
- def _urlLoaded(self) -> None:
- """
-
- :param url: restul of DataUrl.path() function
- :return:
- """
- sender = self.sender()
- assert isinstance(sender, UrlLoader)
- url = sender.url.path()
- if url in self._urlIndexes:
- self._urlData[url] = sender.data
- if self.getCurrentUrl().path() == url:
- self._plot.setData(self._urlData[url])
- if sender in self._loadingThreads:
- self._loadingThreads.remove(sender)
- self.sigLoaded.emit(url)
-
- def setNPrefetch(self, n: int) -> None:
- """
- Define the number of url to prefetch around
-
- :param int n: number of url to prefetch on left and right sides.
- In total n*2 DataUrl will be prefetch
- """
- self.__n_prefetch = n
- current_url = self.getCurrentUrl()
- if current_url is not None:
- self.setCurrentUrl(current_url)
-
- def getNPrefetch(self) -> int:
- """
-
- :return: number of url to prefetch on left and right sides. In total
- will load 2* NPrefetch DataUrls
- """
- return self.__n_prefetch
-
- def setUrls(self, urls: list) -> None:
- """list of urls within an index. Warning: urls should contain an image
- compatible with the silx.gui.plot.Plot class
-
- :param urls: urls we want to set in the stack. Key is the index
- (position in the stack), value is the DataUrl
- :type: list
- """
- def createUrlIndexes():
- indexes = OrderedDict()
- for index, url in enumerate(urls):
- indexes[index] = url
- return indexes
-
- urls_with_indexes = createUrlIndexes()
- urlsToIndex = self._urlsToIndex(urls_with_indexes)
- self.reset()
- self._urls = urls_with_indexes
- self._urlIndexes = urlsToIndex
-
- old_url_table = self._urlsTable.blockSignals(True)
- self._urlsTable.setUrls(urls=list(self._urls.values()))
- self._urlsTable.blockSignals(old_url_table)
-
- old_slider = self._slider.blockSignals(True)
- self._slider.setMinimum(0)
- self._slider.setMaximum(len(self._urls) - 1)
- self._slider.blockSignals(old_slider)
-
- if self.getCurrentUrl() in self._urls:
- self.setCurrentUrl(self.getCurrentUrl())
- else:
- if len(self._urls.keys()) > 0:
- first_url = self._urls[list(self._urls.keys())[0]]
- self.setCurrentUrl(first_url)
-
- def getUrls(self) -> tuple:
- """
-
- :return: tuple of urls
- :rtype: tuple
- """
- return tuple(self._urlIndexes.keys())
-
- def _getNextUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]:
- """
- return the next url in the stack
-
- :param url: url for which we want the next url
- :type: DataUrl
- :return: next url in the stack or None if `url` is the last one
- :rtype: Union[None, DataUrl]
- """
- assert isinstance(url, DataUrl)
- if self._urls is None:
- return None
- else:
- index = self._urlIndexes[url.path()]
- indexes = list(self._urls.keys())
- res = list(filter(lambda x: x > index, indexes))
- if len(res) == 0:
- return None
- else:
- return self._urls[res[0]]
-
- def _getPreviousUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]:
- """
- return the previous url in the stack
-
- :param url: url for which we want the previous url
- :type: DataUrl
- :return: next url in the stack or None if `url` is the last one
- :rtype: Union[None, DataUrl]
- """
- if self._urls is None:
- return None
- else:
- index = self._urlIndexes[url.path()]
- indexes = list(self._urls.keys())
- res = list(filter(lambda x: x < index, indexes))
- if len(res) == 0:
- return None
- else:
- return self._urls[res[-1]]
-
- def _getNNextUrls(self, n: int, url: DataUrl) -> list:
- """
- Deduce the next urls in the stack after `url`
-
- :param n: the number of url store after `url`
- :type: int
- :param url: url for which we want n next url
- :type: DataUrl
- :return: list of next urls.
- :rtype: list
- """
- res = []
- next_free = self._getNextUrl(url=url)
- while len(res) < n and next_free is not None:
- assert isinstance(next_free, DataUrl)
- res.append(next_free)
- next_free = self._getNextUrl(res[-1])
- return res
-
- def _getNPreviousUrls(self, n: int, url: DataUrl):
- """
- Deduce the previous urls in the stack after `url`
-
- :param n: the number of url store after `url`
- :type: int
- :param url: url for which we want n previous url
- :type: DataUrl
- :return: list of previous urls.
- :rtype: list
- """
- res = []
- next_free = self._getPreviousUrl(url=url)
- while len(res) < n and next_free is not None:
- res.insert(0, next_free)
- next_free = self._getPreviousUrl(res[0])
- return res
-
- def setCurrentUrlIndex(self, index: int):
- """
- Define the url to be displayed
-
- :param index: url to be displayed
- :type: int
- """
- if index < 0:
- return
- if self._urls is None:
- return
- elif index >= len(self._urls):
- raise ValueError('requested index out of bounds')
- else:
- return self.setCurrentUrl(self._urls[index])
-
- def setCurrentUrl(self, url: typing.Union[DataUrl, str]) -> None:
- """
- Define the url to be displayed
-
- :param url: url to be displayed
- :type: DataUrl
- """
- assert isinstance(url, (DataUrl, str))
- if isinstance(url, str):
- url = DataUrl(path=url)
- if url != self._current_url:
- self._current_url = url
- self.sigCurrentUrlChanged.emit(url.path())
-
- old_url_table = self._urlsTable.blockSignals(True)
- old_slider = self._slider.blockSignals(True)
-
- self._urlsTable.setUrl(url)
- self._slider.setUrlIndex(self._urlIndexes[url.path()])
- if self._current_url is None:
- self._plot.clear()
- else:
- if self._current_url.path() in self._urlData:
- self._plot.setData(self._urlData[url.path()])
- else:
- self._load(url)
- self._notifyLoading()
- self._preFetch(self._getNNextUrls(self.__n_prefetch, url))
- self._preFetch(self._getNPreviousUrls(self.__n_prefetch, url))
- self._urlsTable.blockSignals(old_url_table)
- self._slider.blockSignals(old_slider)
-
- def getCurrentUrl(self) -> typing.Union[None, DataUrl]:
- """
-
- :return: url currently displayed
- :rtype: Union[None, DataUrl]
- """
- return self._current_url
-
- def getCurrentUrlIndex(self) -> typing.Union[None, int]:
- """
-
- :return: index of the url currently displayed
- :rtype: Union[None, int]
- """
- if self._current_url is None:
- return None
- else:
- return self._urlIndexes[self._current_url.path()]
-
- @staticmethod
- def _urlsToIndex(urls):
- """util, return a dictionary with url as key and index as value"""
- res = {}
- for index, url in urls.items():
- res[url.path()] = index
- return res
-
- def _notifyLoading(self):
- """display a simple image of loading..."""
- self._plot.setWaiting(activate=True)
-
- def setAutoResetZoom(self, reset):
- """
- Should we reset the zoom when adding an image (eq. when browsing)
-
- :param bool reset:
- """
- self._plot.setAutoResetZoom(reset)
-
- def isAutoResetZoom(self) -> bool:
- """
-
- :return: True if a reset is done when the image change
- :rtype: bool
- """
- return self._plot.isAutoResetZoom()
diff --git a/silx/gui/plot/ImageView.py b/silx/gui/plot/ImageView.py
deleted file mode 100644
index 1befe58..0000000
--- a/silx/gui/plot/ImageView.py
+++ /dev/null
@@ -1,854 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""QWidget displaying a 2D image with histograms on its sides.
-
-The :class:`ImageView` implements this widget, and
-:class:`ImageViewMainWindow` provides a main window with additional toolbar
-and status bar.
-
-Basic usage of :class:`ImageView` is through the following methods:
-
-- :meth:`ImageView.getColormap`, :meth:`ImageView.setColormap` to update the
- default colormap to use and update the currently displayed image.
-- :meth:`ImageView.setImage` to update the displayed image.
-
-For an example of use, see `imageview.py` in :ref:`sample-code`.
-"""
-
-from __future__ import division
-
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "26/04/2018"
-
-
-import logging
-import numpy
-import collections
-from typing import Union
-import weakref
-
-import silx
-from .. import qt
-from .. import colors
-
-from . import items, PlotWindow, PlotWidget, actions
-from ..colors import Colormap
-from ..colors import cursorColorForColormap
-from .tools import LimitsToolBar
-from .Profile import ProfileToolBar
-from ...utils.proxy import docstring
-from ...utils.enum import Enum
-from .tools.RadarView import RadarView
-from .utils.axis import SyncAxes
-from ..utils import blockSignals
-from . import _utils
-from .tools.profile import manager
-from .tools.profile import rois
-
-_logger = logging.getLogger(__name__)
-
-
-ProfileSumResult = collections.namedtuple("ProfileResult",
- ["dataXRange", "dataYRange",
- 'histoH', 'histoHRange',
- 'histoV', 'histoVRange',
- "xCoords", "xData",
- "yCoords", "yData"])
-
-
-def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None):
- """
- Compute a full vertical and horizontal profile on an image item using a
- a range in the plot referential.
-
- Optionally takes a previous computed result to be able to skip the
- computation.
-
- :rtype: ProfileSumResult
- """
- data = imageItem.getValueData(copy=False)
- origin = imageItem.getOrigin()
- scale = imageItem.getScale()
- height, width = data.shape
-
- xMin, xMax = xRange
- yMin, yMax = yRange
-
- # Convert plot area limits to image coordinates
- # and work in image coordinates (i.e., in pixels)
- xMin = int((xMin - origin[0]) / scale[0])
- xMax = int((xMax - origin[0]) / scale[0])
- yMin = int((yMin - origin[1]) / scale[1])
- yMax = int((yMax - origin[1]) / scale[1])
-
- if (xMin >= width or xMax < 0 or
- yMin >= height or yMax < 0):
- return None
-
- # The image is at least partly in the plot area
- # Get the visible bounds in image coords (i.e., in pixels)
- subsetXMin = 0 if xMin < 0 else xMin
- subsetXMax = (width if xMax >= width else xMax) + 1
- subsetYMin = 0 if yMin < 0 else yMin
- subsetYMax = (height if yMax >= height else yMax) + 1
-
- if cache is not None:
- if ((subsetXMin, subsetXMax) == cache.dataXRange and
- (subsetYMin, subsetYMax) == cache.dataYRange):
- # The visible area of data is the same
- return cache
-
- # Rebuild histograms for visible area
- visibleData = data[subsetYMin:subsetYMax,
- subsetXMin:subsetXMax]
- histoHVisibleData = numpy.nansum(visibleData, axis=0)
- histoVVisibleData = numpy.nansum(visibleData, axis=1)
- histoHMin = numpy.nanmin(histoHVisibleData)
- histoHMax = numpy.nanmax(histoHVisibleData)
- histoVMin = numpy.nanmin(histoVVisibleData)
- histoVMax = numpy.nanmax(histoVVisibleData)
-
- # Convert to histogram curve and update plots
- # Taking into account origin and scale
- coords = numpy.arange(2 * histoHVisibleData.size)
- xCoords = (coords + 1) // 2 + subsetXMin
- xCoords = origin[0] + scale[0] * xCoords
- xData = numpy.take(histoHVisibleData, coords // 2)
- coords = numpy.arange(2 * histoVVisibleData.size)
- yCoords = (coords + 1) // 2 + subsetYMin
- yCoords = origin[1] + scale[1] * yCoords
- yData = numpy.take(histoVVisibleData, coords // 2)
-
- result = ProfileSumResult(
- dataXRange=(subsetXMin, subsetXMax),
- dataYRange=(subsetYMin, subsetYMax),
- histoH=histoHVisibleData,
- histoHRange=(histoHMin, histoHMax),
- histoV=histoVVisibleData,
- histoVRange=(histoVMin, histoVMax),
- xCoords=xCoords,
- xData=xData,
- yCoords=yCoords,
- yData=yData)
-
- return result
-
-
-class _SideHistogram(PlotWidget):
- """
- Widget displaying one of the side profile of the ImageView.
-
- Implement ProfileWindow
- """
-
- sigClose = qt.Signal()
-
- sigMouseMoved = qt.Signal(float, float)
-
- def __init__(self, parent=None, backend=None, direction=qt.Qt.Horizontal):
- super(_SideHistogram, self).__init__(parent=parent, backend=backend)
- self._direction = direction
- self.sigPlotSignal.connect(self._plotEvents)
- self._color = "blue"
- self.__profile = None
- self.__profileSum = None
-
- def _plotEvents(self, eventDict):
- """Callback for horizontal histogram plot events."""
- if eventDict['event'] == 'mouseMoved':
- self.sigMouseMoved.emit(eventDict['x'], eventDict['y'])
-
- def setProfileColor(self, color):
- self._color = color
-
- def setProfileSum(self, result):
- self.__profileSum = result
- if self.__profile is None:
- self.__drawProfileSum()
-
- def prepareWidget(self, roi):
- """Implements `ProfileWindow`"""
- pass
-
- def setRoiProfile(self, roi):
- """Implements `ProfileWindow`"""
- if roi is None:
- return
- self._roiColor = colors.rgba(roi.getColor())
-
- def getProfile(self):
- """Implements `ProfileWindow`"""
- return self.__profile
-
- def setProfile(self, data):
- """Implements `ProfileWindow`"""
- self.__profile = data
- if data is None:
- self.__drawProfileSum()
- else:
- self.__drawProfile()
-
- def __drawProfileSum(self):
- """Only draw the profile sum on the plot.
-
- Other elements are removed
- """
- profileSum = self.__profileSum
-
- try:
- self.removeCurve('profile')
- except Exception:
- pass
-
- if profileSum is None:
- try:
- self.removeCurve('profilesum')
- except Exception:
- pass
- return
-
- if self._direction == qt.Qt.Horizontal:
- xx, yy = profileSum.xCoords, profileSum.xData
- elif self._direction == qt.Qt.Vertical:
- xx, yy = profileSum.yData, profileSum.yCoords
- else:
- assert False
-
- self.addCurve(xx, yy,
- xlabel='', ylabel='',
- legend="profilesum",
- color=self._color,
- linestyle='-',
- selectable=False,
- resetzoom=False)
-
- self.__updateLimits()
-
- def __drawProfile(self):
- """Only draw the profile on the plot.
-
- Other elements are removed
- """
- profile = self.__profile
-
- try:
- self.removeCurve('profilesum')
- except Exception:
- pass
-
- if profile is None:
- try:
- self.removeCurve('profile')
- except Exception:
- pass
- self.setProfileSum(self.__profileSum)
- return
-
- if self._direction == qt.Qt.Horizontal:
- xx, yy = profile.coords, profile.profile
- elif self._direction == qt.Qt.Vertical:
- xx, yy = profile.profile, profile.coords
- else:
- assert False
-
- self.addCurve(xx,
- yy,
- legend="profile",
- color=self._roiColor,
- resetzoom=False)
-
- self.__updateLimits()
-
- def __updateLimits(self):
- if self.__profile:
- data = self.__profile.profile
- vMin = numpy.nanmin(data)
- vMax = numpy.nanmax(data)
- elif self.__profileSum is not None:
- if self._direction == qt.Qt.Horizontal:
- vMin, vMax = self.__profileSum.histoHRange
- elif self._direction == qt.Qt.Vertical:
- vMin, vMax = self.__profileSum.histoVRange
- else:
- assert False
- else:
- vMin, vMax = 0, 0
-
- # Tune the result using the data margins
- margins = self.getDataMargins()
- if self._direction == qt.Qt.Horizontal:
- _, _, vMin, vMax = _utils.addMarginsToLimits(margins, False, False, 0, 0, vMin, vMax)
- elif self._direction == qt.Qt.Vertical:
- vMin, vMax, _, _ = _utils.addMarginsToLimits(margins, False, False, vMin, vMax, 0, 0)
- else:
- assert False
-
- if self._direction == qt.Qt.Horizontal:
- dataAxis = self.getYAxis()
- elif self._direction == qt.Qt.Vertical:
- dataAxis = self.getXAxis()
- else:
- assert False
-
- with blockSignals(dataAxis):
- dataAxis.setLimits(vMin, vMax)
-
-
-class ImageView(PlotWindow):
- """Display a single image with horizontal and vertical histograms.
-
- Use :meth:`setImage` to control the displayed image.
- This class also provides the :class:`silx.gui.plot.Plot` API.
-
- The :class:`ImageView` inherits from :class:`.PlotWindow` (which provides
- the toolbars) and also exposes :class:`.PlotWidget` API for further
- plot control (plot title, axes labels, aspect ratio, ...).
-
- :param parent: The parent of this widget or None.
- :param backend: The backend to use for the plot (default: matplotlib).
- See :class:`.PlotWidget` for the list of supported backend.
- :type backend: str or :class:`BackendBase.BackendBase`
- """
-
- HISTOGRAMS_COLOR = 'blue'
- """Color to use for the side histograms."""
-
- HISTOGRAMS_HEIGHT = 200
- """Height in pixels of the side histograms."""
-
- IMAGE_MIN_SIZE = 200
- """Minimum size in pixels of the image area."""
-
- # Qt signals
- valueChanged = qt.Signal(float, float, float)
- """Signals that the data value under the cursor has changed.
-
- It provides: row, column, data value.
-
- When the cursor is over an histogram, either row or column is Nan
- and the provided data value is the histogram value
- (i.e., the sum along the corresponding row/column).
- Row and columns are either Nan or integer values.
- """
-
- class ProfileWindowBehavior(Enum):
- """ImageView's profile window behavior options"""
-
- POPUP = 'popup'
- """All profiles are displayed in pop-up windows"""
-
- EMBEDDED = 'embedded'
- """Horizontal, vertical and cross profiles are displayed in
- sides widgets, others are displayed in pop-up windows.
- """
-
- def __init__(self, parent=None, backend=None):
- self._imageLegend = '__ImageView__image' + str(id(self))
- self._cache = None # Store currently visible data information
-
- super(ImageView, self).__init__(parent=parent, backend=backend,
- resetzoom=True, autoScale=False,
- logScale=False, grid=False,
- curveStyle=False, colormap=True,
- aspectRatio=True, yInverted=True,
- copy=True, save=True, print_=True,
- control=False, position=False,
- roi=False, mask=True)
-
- # Enable mask synchronisation to use it in profiles
- maskToolsWidget = self.getMaskToolsDockWidget().widget()
- maskToolsWidget.setItemMaskUpdated(True)
-
- if parent is None:
- self.setWindowTitle('ImageView')
-
- if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
- self.getYAxis().setInverted(True)
-
- self._initWidgets(backend)
-
- self.__profileWindowBehavior = self.ProfileWindowBehavior.POPUP
- self.__profile = ProfileToolBar(plot=self)
- self.addToolBar(self.__profile)
-
- def _initWidgets(self, backend):
- """Set-up layout and plots."""
- self._histoHPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Horizontal)
- widgetHandle = self._histoHPlot.getWidgetHandle()
- widgetHandle.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
- widgetHandle.setMaximumHeight(self.HISTOGRAMS_HEIGHT)
- self._histoHPlot.setInteractiveMode('zoom')
- self._histoHPlot.setDataMargins(0., 0., 0.1, 0.1)
- self._histoHPlot.sigMouseMoved.connect(self._mouseMovedOnHistoH)
- self._histoHPlot.setProfileColor(self.HISTOGRAMS_COLOR)
-
- self._histoVPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Vertical)
- widgetHandle = self._histoVPlot.getWidgetHandle()
- widgetHandle.setMinimumWidth(self.HISTOGRAMS_HEIGHT)
- widgetHandle.setMaximumWidth(self.HISTOGRAMS_HEIGHT)
- self._histoVPlot.setInteractiveMode('zoom')
- self._histoVPlot.setDataMargins(0.1, 0.1, 0., 0.)
- self._histoVPlot.sigMouseMoved.connect(self._mouseMovedOnHistoV)
- self._histoVPlot.setProfileColor(self.HISTOGRAMS_COLOR)
-
- self.setPanWithArrowKeys(True)
- self.setInteractiveMode('zoom') # Color set in setColormap
- self.sigPlotSignal.connect(self._imagePlotCB)
- self.sigActiveImageChanged.connect(self._activeImageChangedSlot)
-
- self._radarView = RadarView(parent=self)
- self._radarView.setPlotWidget(self)
-
- self.__syncXAxis = SyncAxes([self.getXAxis(), self._histoHPlot.getXAxis()])
- self.__syncYAxis = SyncAxes([self.getYAxis(), self._histoVPlot.getYAxis()])
-
- self.__setCentralWidget()
-
- def __setCentralWidget(self):
- """Set central widget with all its content"""
- layout = qt.QGridLayout()
- layout.addWidget(self.getWidgetHandle(), 0, 0)
- layout.addWidget(self._histoVPlot.getWidgetHandle(), 0, 1)
- layout.addWidget(self._histoHPlot.getWidgetHandle(), 1, 0)
- layout.addWidget(self._radarView, 1, 1, 1, 2)
- layout.addWidget(self.getColorBarWidget(), 0, 2)
-
- layout.setColumnMinimumWidth(0, self.IMAGE_MIN_SIZE)
- layout.setColumnStretch(0, 1)
- layout.setColumnMinimumWidth(1, self.HISTOGRAMS_HEIGHT)
- layout.setColumnStretch(1, 0)
-
- layout.setRowMinimumHeight(0, self.IMAGE_MIN_SIZE)
- layout.setRowStretch(0, 1)
- layout.setRowMinimumHeight(1, self.HISTOGRAMS_HEIGHT)
- layout.setRowStretch(1, 0)
-
- layout.setSpacing(0)
- layout.setContentsMargins(0, 0, 0, 0)
-
- centralWidget = qt.QWidget(self)
- centralWidget.setLayout(layout)
- self.setCentralWidget(centralWidget)
-
- @docstring(PlotWidget)
- def setBackend(self, backend):
- # Use PlotWidget here since we override PlotWindow behavior
- PlotWidget.setBackend(self, backend)
- self.__setCentralWidget()
-
- def _dirtyCache(self):
- self._cache = None
-
- def _updateHistograms(self):
- """Update histograms content using current active image."""
- activeImage = self.getActiveImage()
- if activeImage is not None:
- xRange = self.getXAxis().getLimits()
- yRange = self.getYAxis().getLimits()
- result = computeProfileSumOnRange(activeImage, xRange, yRange, self._cache)
- self._cache = result
- self._histoHPlot.setProfileSum(result)
- self._histoVPlot.setProfileSum(result)
-
- # Plots event listeners
-
- def _imagePlotCB(self, eventDict):
- """Callback for imageView plot events."""
- if eventDict['event'] == 'mouseMoved':
- activeImage = self.getActiveImage()
- if activeImage is not None:
- data = activeImage.getData(copy=False)
- height, width = data.shape
-
- # Get corresponding coordinate in image
- origin = activeImage.getOrigin()
- scale = activeImage.getScale()
- if (eventDict['x'] >= origin[0] and
- eventDict['y'] >= origin[1]):
- x = int((eventDict['x'] - origin[0]) / scale[0])
- y = int((eventDict['y'] - origin[1]) / scale[1])
-
- if x >= 0 and x < width and y >= 0 and y < height:
- self.valueChanged.emit(float(x), float(y),
- data[y][x])
-
- elif eventDict['event'] == 'limitsChanged':
- self._updateHistograms()
-
- def _mouseMovedOnHistoH(self, x, y):
- if self._cache is None:
- return
- activeImage = self.getActiveImage()
- if activeImage is None:
- return
-
- xOrigin = activeImage.getOrigin()[0]
- xScale = activeImage.getScale()[0]
-
- minValue = xOrigin + xScale * self._cache.dataXRange[0]
-
- if x >= minValue:
- data = self._cache.histoH
- column = int((x - minValue) / xScale)
- if column >= 0 and column < data.shape[0]:
- self.valueChanged.emit(
- float('nan'),
- float(column + self._cache.dataXRange[0]),
- data[column])
-
- def _mouseMovedOnHistoV(self, x, y):
- if self._cache is None:
- return
- activeImage = self.getActiveImage()
- if activeImage is None:
- return
-
- yOrigin = activeImage.getOrigin()[1]
- yScale = activeImage.getScale()[1]
-
- minValue = yOrigin + yScale * self._cache.dataYRange[0]
-
- if y >= minValue:
- data = self._cache.histoV
- row = int((y - minValue) / yScale)
- if row >= 0 and row < data.shape[0]:
- self.valueChanged.emit(
- float(row + self._cache.dataYRange[0]),
- float('nan'),
- data[row])
-
- def _activeImageChangedSlot(self, previous, legend):
- """Handle Plot active image change.
-
- Resets side histograms cache
- """
- self._dirtyCache()
- self._updateHistograms()
-
- def setProfileWindowBehavior(self, behavior: Union[str, ProfileWindowBehavior]):
- """Set where profile widgets are displayed.
-
- :param ProfileWindowBehavior behavior:
- - 'popup': All profiles are displayed in pop-up windows
- - 'embedded': Horizontal, vertical and cross profiles are displayed in
- sides widgets, others are displayed in pop-up windows.
- """
- behavior = self.ProfileWindowBehavior.from_value(behavior)
- if behavior is not self.getProfileWindowBehavior():
- manager = self.__profile.getProfileManager()
- manager.clearProfile()
- manager.requestUpdateAllProfile()
-
- if behavior is self.ProfileWindowBehavior.EMBEDDED:
- horizontalProfileWindow = self._histoHPlot
- verticalProfileWindow = self._histoVPlot
- else:
- horizontalProfileWindow = None
- verticalProfileWindow = None
-
- manager.setSpecializedProfileWindow(
- rois.ProfileImageHorizontalLineROI, horizontalProfileWindow
- )
- manager.setSpecializedProfileWindow(
- rois.ProfileImageVerticalLineROI, verticalProfileWindow
- )
- self.__profileWindowBehavior = behavior
-
- def getProfileWindowBehavior(self) -> ProfileWindowBehavior:
- """Returns current profile display behavior.
-
- See :meth:`setProfileWindowBehavior` and :class:`ProfileWindowBehavior`
- """
- return self.__profileWindowBehavior
-
- def getProfileToolBar(self):
- """"Returns profile tools attached to this plot.
-
- :rtype: silx.gui.plot.PlotTools.ProfileToolBar
- """
- return self.__profile
-
- @property
- def profile(self):
- return self.getProfileToolBar()
-
- def getHistogram(self, axis):
- """Return the histogram and corresponding row or column extent.
-
- The returned value when an histogram is available is a dict with keys:
-
- - 'data': numpy array of the histogram values.
- - 'extent': (start, end) row or column index.
- end index is not included in the histogram.
-
- :param str axis: 'x' for horizontal, 'y' for vertical
- :return: The histogram and its extent as a dict or None.
- :rtype: dict
- """
- assert axis in ('x', 'y')
- if self._cache is None:
- return None
- else:
- if axis == 'x':
- return dict(
- data=numpy.array(self._cache.histoH, copy=True),
- extent=self._cache.dataXRange)
- else:
- return dict(
- data=numpy.array(self._cache.histoV, copy=True),
- extent=(self._cache.dataYRange))
-
- def radarView(self):
- """Get the lower right radarView widget."""
- return self._radarView
-
- def setRadarView(self, radarView):
- """Change the lower right radarView widget.
-
- :param RadarView radarView: Widget subclassing RadarView to replace
- the lower right corner widget.
- """
- self._radarView = radarView
- self._radarView.setPlotWidget(self)
- self.centralWidget().layout().addWidget(self._radarView, 1, 1)
-
- # High-level API
-
- def getColormap(self):
- """Get the default colormap description.
-
- :return: A description of the current colormap.
- See :meth:`setColormap` for details.
- :rtype: dict
- """
- return self.getDefaultColormap()
-
- def setColormap(self, colormap=None, normalization=None,
- autoscale=None, vmin=None, vmax=None, colors=None):
- """Set the default colormap and update active image.
-
- Parameters that are not provided are taken from the current colormap.
-
- The colormap parameter can also be a dict with the following keys:
-
- - *name*: string. The colormap to use:
- 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
- - *normalization*: string. The mapping to use for the colormap:
- either 'linear' or 'log'.
- - *autoscale*: bool. Whether to use autoscale (True)
- or range provided by keys 'vmin' and 'vmax' (False).
- - *vmin*: float. The minimum value of the range to use if 'autoscale'
- is False.
- - *vmax*: float. The maximum value of the range to use if 'autoscale'
- is False.
- - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
- List of RGB or RGBA colors to use (only if name is None)
-
- :param colormap: Name of the colormap in
- 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
- Or the description of the colormap as a dict.
- :type colormap: dict or str.
- :param str normalization: Colormap mapping: 'linear' or 'log'.
- :param bool autoscale: Whether to use autoscale (True)
- or [vmin, vmax] range (False).
- :param float vmin: The minimum value of the range to use if
- 'autoscale' is False.
- :param float vmax: The maximum value of the range to use if
- 'autoscale' is False.
- :param numpy.ndarray colors: Only used if name is None.
- Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
- """
- cmap = self.getDefaultColormap()
-
- if isinstance(colormap, Colormap):
- # Replace colormap
- cmap = colormap
-
- self.setDefaultColormap(cmap)
-
- # Update active image colormap
- activeImage = self.getActiveImage()
- if isinstance(activeImage, items.ColormapMixIn):
- activeImage.setColormap(cmap)
-
- elif isinstance(colormap, dict):
- # Support colormap parameter as a dict
- assert normalization is None
- assert autoscale is None
- assert vmin is None
- assert vmax is None
- assert colors is None
- cmap._setFromDict(colormap)
-
- else:
- if colormap is not None:
- cmap.setName(colormap)
- if normalization is not None:
- cmap.setNormalization(normalization)
- if autoscale:
- cmap.setVRange(None, None)
- else:
- if vmin is not None:
- cmap.setVMin(vmin)
- if vmax is not None:
- cmap.setVMax(vmax)
- if colors is not None:
- cmap.setColormapLUT(colors)
-
- cursorColor = cursorColorForColormap(cmap.getName())
- self.setInteractiveMode('zoom', color=cursorColor)
-
- def setImage(self, image, origin=(0, 0), scale=(1., 1.),
- copy=True, reset=True):
- """Set the image to display.
-
- :param image: A 2D array representing the image or None to empty plot.
- :type image: numpy.ndarray-like with 2 dimensions or None.
- :param origin: The (x, y) position of the origin of the image.
- Default: (0, 0).
- The origin is the lower left corner of the image when
- the Y axis is not inverted.
- :type origin: Tuple of 2 floats: (origin x, origin y).
- :param scale: The scale factor to apply to the image on X and Y axes.
- Default: (1, 1).
- It is the size of a pixel in the coordinates of the axes.
- Scales must be positive numbers.
- :type scale: Tuple of 2 floats: (scale x, scale y).
- :param bool copy: Whether to copy image data (default) or not.
- :param bool reset: Whether to reset zoom and ROI (default) or not.
- """
- self._dirtyCache()
-
- assert len(origin) == 2
- assert len(scale) == 2
- assert scale[0] > 0
- assert scale[1] > 0
-
- if image is None:
- self.remove(self._imageLegend, kind='image')
- return
-
- data = numpy.array(image, order='C', copy=copy)
- assert data.size != 0
- assert len(data.shape) == 2
-
- self.addImage(data,
- legend=self._imageLegend,
- origin=origin, scale=scale,
- colormap=self.getColormap(),
- resetzoom=False)
- self.setActiveImage(self._imageLegend)
- self._updateHistograms()
- if reset:
- self.resetZoom()
-
-
-# ImageViewMainWindow #########################################################
-
-class ImageViewMainWindow(ImageView):
- """:class:`ImageView` with additional toolbars
-
- Adds extra toolbar and a status bar to :class:`ImageView`.
- """
- def __init__(self, parent=None, backend=None):
- self._dataInfo = None
- super(ImageViewMainWindow, self).__init__(parent, backend)
- self.setWindowFlags(qt.Qt.Window)
-
- self.getXAxis().setLabel('X')
- self.getYAxis().setLabel('Y')
- self.setGraphTitle('Image')
-
- # Add toolbars and status bar
- self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self))
-
- self.statusBar()
-
- menu = self.menuBar().addMenu('File')
- menu.addAction(self.getOutputToolBar().getSaveAction())
- menu.addAction(self.getOutputToolBar().getPrintAction())
- menu.addSeparator()
- action = menu.addAction('Quit')
- action.triggered[bool].connect(qt.QApplication.instance().quit)
-
- menu = self.menuBar().addMenu('Edit')
- menu.addAction(self.getOutputToolBar().getCopyAction())
- menu.addSeparator()
- menu.addAction(self.getResetZoomAction())
- menu.addAction(self.getColormapAction())
- menu.addAction(actions.control.KeepAspectRatioAction(self, self))
- menu.addAction(actions.control.YAxisInvertedAction(self, self))
-
- self.__profileMenu = self.menuBar().addMenu('Profile')
- self.__updateProfileMenu()
-
- # Connect to ImageView's signal
- self.valueChanged.connect(self._statusBarSlot)
-
- def __updateProfileMenu(self):
- """Update actions available in 'Profile' menu"""
- profile = self.getProfileToolBar()
- self.__profileMenu.clear()
- self.__profileMenu.addAction(profile.hLineAction)
- self.__profileMenu.addAction(profile.vLineAction)
- self.__profileMenu.addAction(profile.crossAction)
- self.__profileMenu.addAction(profile.lineAction)
- self.__profileMenu.addAction(profile.clearAction)
-
- def _statusBarSlot(self, row, column, value):
- """Update status bar with coordinates/value from plots."""
- if numpy.isnan(row):
- msg = 'Column: %d, Sum: %g' % (int(column), value)
- elif numpy.isnan(column):
- msg = 'Row: %d, Sum: %g' % (int(row), value)
- else:
- msg = 'Position: (%d, %d), Value: %g' % (int(row), int(column),
- value)
- if self._dataInfo is not None:
- msg = self._dataInfo + ', ' + msg
-
- self.statusBar().showMessage(msg)
-
- @docstring(ImageView)
- def setProfileWindowBehavior(self, behavior: str):
- super().setProfileWindowBehavior(behavior)
- self.__updateProfileMenu()
-
- @docstring(ImageView)
- def setImage(self, image, *args, **kwargs):
- if hasattr(image, 'dtype') and hasattr(image, 'shape'):
- assert len(image.shape) == 2
- height, width = image.shape
- self._dataInfo = 'Data: %dx%d (%s)' % (width, height,
- str(image.dtype))
- self.statusBar().showMessage(self._dataInfo)
- else:
- self._dataInfo = None
-
- # Set the new image in ImageView widget
- super(ImageViewMainWindow, self).setImage(image, *args, **kwargs)
- self.setStatusBar(None)
diff --git a/silx/gui/plot/ItemsSelectionDialog.py b/silx/gui/plot/ItemsSelectionDialog.py
deleted file mode 100644
index ebd1c64..0000000
--- a/silx/gui/plot/ItemsSelectionDialog.py
+++ /dev/null
@@ -1,286 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a dialog widget to select plot items.
-
-.. autoclass:: ItemsSelectionDialog
-
-"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "28/06/2017"
-
-import logging
-
-from silx.gui import qt
-from silx.gui.plot.PlotWidget import PlotWidget
-
-_logger = logging.getLogger(__name__)
-
-
-class KindsSelector(qt.QListWidget):
- """List widget allowing to select plot item kinds
- ("curve", "scatter", "image"...)
- """
- sigSelectedKindsChanged = qt.Signal(list)
-
- def __init__(self, parent=None, kinds=None):
- """
-
- :param parent: Parent QWidget or None
- :param tuple(str) kinds: Sequence of kinds. If None, the default
- behavior is to provide a checkbox for all possible item kinds.
- """
- qt.QListWidget.__init__(self, parent)
-
- self.plot_item_kinds = []
-
- self.setAvailableKinds(kinds if kinds is not None else PlotWidget.ITEM_KINDS)
-
- self.setSelectionMode(qt.QAbstractItemView.ExtendedSelection)
- self.selectAll()
-
- self.itemSelectionChanged.connect(self.emitSigKindsSelectionChanged)
-
- def emitSigKindsSelectionChanged(self):
- self.sigSelectedKindsChanged.emit(self.selectedKinds)
-
- @property
- def selectedKinds(self):
- """Tuple of all selected kinds (as strings)."""
- # check for updates when self.itemSelectionChanged
- return [item.text() for item in self.selectedItems()]
-
- def setAvailableKinds(self, kinds):
- """Set a list of kinds to be displayed.
-
- :param list[str] kinds: Sequence of kinds
- """
- self.plot_item_kinds = kinds
-
- self.clear()
- for kind in self.plot_item_kinds:
- item = qt.QListWidgetItem(self)
- item.setText(kind)
- self.addItem(item)
-
- def selectAll(self):
- """Select all available kinds."""
- if self.selectionMode() in [qt.QAbstractItemView.SingleSelection,
- qt.QAbstractItemView.NoSelection]:
- raise RuntimeError("selectAll requires a multiple selection mode")
- for i in range(self.count()):
- self.item(i).setSelected(True)
-
-
-class PlotItemsSelector(qt.QTableWidget):
- """Table widget displaying the legend and kind of all
- plot items corresponding to a list of specified kinds.
-
- Selected plot items are provided as property :attr:`selectedPlotItems`.
- You can be warned of selection changes by listening to signal
- :attr:`itemSelectionChanged`.
- """
- def __init__(self, parent=None, plot=None):
- if plot is None or not isinstance(plot, PlotWidget):
- raise AttributeError("parameter plot is required")
- self.plot = plot
- """:class:`PlotWidget` instance"""
-
- self.plot_item_kinds = None
- """List of plot item kinds (strings)"""
-
- qt.QTableWidget.__init__(self, parent)
-
- self.setColumnCount(2)
-
- self.setSelectionBehavior(qt.QTableWidget.SelectRows)
-
- def _clear(self):
- self.clear()
- self.setHorizontalHeaderLabels(["legend", "type"])
-
- def setAllKindsFilter(self):
- """Display all kinds of plot items."""
- self.setKindsFilter(PlotWidget.ITEM_KINDS)
-
- def setKindsFilter(self, kinds):
- """Set list of all kinds of plot items to be displayed.
-
- :param list[str] kinds: Sequence of kinds
- """
- if not set(kinds) <= set(PlotWidget.ITEM_KINDS):
- raise KeyError("Illegal plot item kinds: %s" %
- set(kinds) - set(PlotWidget.ITEM_KINDS))
- self.plot_item_kinds = kinds
-
- self.updatePlotItems()
-
- def updatePlotItems(self):
- self._clear()
-
- # respect order of kinds as set in method setKindsFilter
- itemsAndKind = []
- for kind in self.plot_item_kinds:
- itemClasses = self.plot._KIND_TO_CLASSES[kind]
- for item in self.plot.getItems():
- if isinstance(item, itemClasses) and item.isVisible():
- itemsAndKind.append((item, kind))
-
- self.setRowCount(len(itemsAndKind))
-
- for index, (item, kind) in enumerate(itemsAndKind):
- legend_twitem = qt.QTableWidgetItem(item.getName())
- self.setItem(index, 0, legend_twitem)
-
- kind_twitem = qt.QTableWidgetItem(kind)
- self.setItem(index, 1, kind_twitem)
-
- @property
- def selectedPlotItems(self):
- """List of all selected items"""
- selection_model = self.selectionModel()
- selected_rows_idx = selection_model.selectedRows()
- selected_rows = [idx.row() for idx in selected_rows_idx]
-
- items = []
- for row in selected_rows:
- legend = self.item(row, 0).text()
- kind = self.item(row, 1).text()
- item = self.plot._getItem(kind, legend)
- if item is not None:
- items.append(item)
-
- return items
-
-
-class ItemsSelectionDialog(qt.QDialog):
- """This widget is a modal dialog allowing to select one or more plot
- items, in a table displaying their legend and kind.
-
- Public methods:
-
- - :meth:`getSelectedItems`
- - :meth:`setAvailableKinds`
- - :meth:`setItemsSelectionMode`
-
- This widget inherits QDialog and therefore implements the usual
- dialog methods, e.g. :meth:`exec_`.
-
- A trivial usage example would be::
-
- isd = ItemsSelectionDialog(plot=my_plot_widget)
- isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection)
- result = isd.exec_()
- if result:
- for item in isd.getSelectedItems():
- print(item.getName(), type(item))
- else:
- print("Selection cancelled")
- """
- def __init__(self, parent=None, plot=None):
- if plot is None or not isinstance(plot, PlotWidget):
- raise AttributeError("parameter plot is required")
- qt.QDialog.__init__(self, parent)
-
- self.setWindowTitle("Plot items selector")
-
- kind_selector_label = qt.QLabel("Filter item kinds:", self)
- item_selector_label = qt.QLabel("Select items:", self)
-
- self.kind_selector = KindsSelector(self)
- self.kind_selector.setToolTip(
- "select one or more item kinds to show them in the item list")
-
- self.item_selector = PlotItemsSelector(self, plot)
- self.item_selector.setToolTip("select items")
-
- self.item_selector.setKindsFilter(self.kind_selector.selectedKinds)
- self.kind_selector.sigSelectedKindsChanged.connect(
- self.item_selector.setKindsFilter
- )
-
- okb = qt.QPushButton("OK", self)
- okb.clicked.connect(self.accept)
-
- cancelb = qt.QPushButton("Cancel", self)
- cancelb.clicked.connect(self.reject)
-
- layout = qt.QGridLayout(self)
- layout.addWidget(kind_selector_label, 0, 0)
- layout.addWidget(item_selector_label, 0, 1)
- layout.addWidget(self.kind_selector, 1, 0)
- layout.addWidget(self.item_selector, 1, 1)
- layout.addWidget(okb, 2, 0)
- layout.addWidget(cancelb, 2, 1)
-
- self.setLayout(layout)
-
- def getSelectedItems(self):
- """Return a list of selected plot items
-
- :return: List of selected plot items
- :rtype: list[silx.gui.plot.items.Item]"""
- return self.item_selector.selectedPlotItems
-
- def setAvailableKinds(self, kinds):
- """Set a list of kinds to be displayed.
-
- :param list[str] kinds: Sequence of kinds
- """
- self.kind_selector.setAvailableKinds(kinds)
-
- def selectAllKinds(self):
- self.kind_selector.selectAll()
-
- def setItemsSelectionMode(self, mode):
- """Set selection mode for plot item (single item selection,
- multiple...).
-
- :param mode: One of :class:`QTableWidget` selection modes
- """
- if mode == self.item_selector.SingleSelection:
- self.item_selector.setToolTip(
- "Select one item by clicking on it.")
- elif mode == self.item_selector.MultiSelection:
- self.item_selector.setToolTip(
- "Select one or more items by clicking with the left mouse"
- " button.\nYou can unselect items by clicking them again.\n"
- "Multiple items can be toggled by dragging the mouse over them.")
- elif mode == self.item_selector.ExtendedSelection:
- self.item_selector.setToolTip(
- "Select one or more items. You can select multiple items "
- "by keeping the Ctrl key pushed when clicking.\nYou can "
- "select a range of items by clicking on the first and "
- "last while keeping the Shift key pushed.")
- elif mode == self.item_selector.ContiguousSelection:
- self.item_selector.setToolTip(
- "Select one item by clicking on it. If you press the Shift"
- " key while clicking on a second item,\nall items between "
- "the two will be selected.")
- elif mode == self.item_selector.NoSelection:
- raise ValueError("The NoSelection mode is not allowed "
- "in this context.")
- self.item_selector.setSelectionMode(mode)
diff --git a/silx/gui/plot/LegendSelector.py b/silx/gui/plot/LegendSelector.py
deleted file mode 100755
index 94112aa..0000000
--- a/silx/gui/plot/LegendSelector.py
+++ /dev/null
@@ -1,1036 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Widget displaying curves legends and allowing to operate on curves.
-
-This widget is meant to work with :class:`PlotWindow`.
-"""
-
-__authors__ = ["V.A. Sole", "T. Rueter", "T. Vincent"]
-__license__ = "MIT"
-__data__ = "16/10/2017"
-
-
-import logging
-import weakref
-
-import numpy
-
-from .. import qt, colors
-from ..widgets.LegendIconWidget import LegendIconWidget
-from . import items
-
-
-_logger = logging.getLogger(__name__)
-
-
-class LegendIcon(LegendIconWidget):
- """Object displaying a curve linestyle and symbol.
-
- :param QWidget parent: See :class:`QWidget`
- :param Union[~silx.gui.plot.items.Curve,None] curve:
- Curve with which to synchronize
- """
-
- def __init__(self, parent=None, curve=None):
- super(LegendIcon, self).__init__(parent)
- self._curveRef = None
- self.setCurve(curve)
-
- def getCurve(self):
- """Returns curve associated to this widget
-
- :rtype: Union[~silx.gui.plot.items.Curve,None]
- """
- return None if self._curveRef is None else self._curveRef()
-
- def setCurve(self, curve):
- """Set the curve with which to synchronize this widget.
-
- :param curve: Union[~silx.gui.plot.items.Curve,None]
- """
- assert curve is None or isinstance(curve, items.Curve)
-
- previousCurve = self.getCurve()
- if curve == previousCurve:
- return
-
- if previousCurve is not None:
- previousCurve.sigItemChanged.disconnect(self._curveChanged)
-
- self._curveRef = None if curve is None else weakref.ref(curve)
-
- if curve is not None:
- curve.sigItemChanged.connect(self._curveChanged)
-
- self._update()
-
- def _update(self):
- """Update widget according to current curve state.
- """
- curve = self.getCurve()
- if curve is None:
- _logger.error('Curve no more exists')
- self.setEnabled(False)
- return
-
- style = curve.getCurrentStyle()
-
- self.setEnabled(curve.isVisible())
- self.setSymbol(style.getSymbol())
- self.setLineWidth(style.getLineWidth())
- self.setLineStyle(style.getLineStyle())
-
- color = style.getColor()
- if numpy.array(color, copy=False).ndim != 1:
- # array of colors, use transparent black
- color = 0., 0., 0., 0.
- color = colors.rgba(color) # Make sure it is float in [0, 1]
- alpha = curve.getAlpha()
- color = qt.QColor.fromRgbF(
- color[0], color[1], color[2], color[3] * alpha)
- self.setLineColor(color)
- self.setSymbolColor(color)
- self.update() # TODO this should not be needed
-
- def _curveChanged(self, event):
- """Handle update of curve item
-
- :param event: Kind of change
- """
- if event in (items.ItemChangedType.VISIBLE,
- items.ItemChangedType.SYMBOL,
- items.ItemChangedType.SYMBOL_SIZE,
- items.ItemChangedType.LINE_WIDTH,
- items.ItemChangedType.LINE_STYLE,
- items.ItemChangedType.COLOR,
- items.ItemChangedType.ALPHA,
- items.ItemChangedType.HIGHLIGHTED,
- items.ItemChangedType.HIGHLIGHTED_STYLE):
- self._update()
-
-
-class LegendModel(qt.QAbstractListModel):
- """Data model of curve legends.
-
- It holds the information of the curve:
-
- - color
- - line width
- - line style
- - visibility of the lines
- - symbol
- - visibility of the symbols
- """
- iconColorRole = qt.Qt.UserRole + 0
- iconLineWidthRole = qt.Qt.UserRole + 1
- iconLineStyleRole = qt.Qt.UserRole + 2
- showLineRole = qt.Qt.UserRole + 3
- iconSymbolRole = qt.Qt.UserRole + 4
- showSymbolRole = qt.Qt.UserRole + 5
-
- def __init__(self, legendList=None, parent=None):
- super(LegendModel, self).__init__(parent)
- if legendList is None:
- legendList = []
- self.legendList = []
- self.insertLegendList(0, legendList)
- self._palette = qt.QPalette()
-
- def __getitem__(self, idx):
- if idx >= len(self.legendList):
- raise IndexError('list index out of range')
- return self.legendList[idx]
-
- def rowCount(self, modelIndex=None):
- return len(self.legendList)
-
- def flags(self, index):
- return (qt.Qt.ItemIsEditable |
- qt.Qt.ItemIsEnabled |
- qt.Qt.ItemIsSelectable)
-
- def data(self, modelIndex, role):
- if modelIndex.isValid:
- idx = modelIndex.row()
- else:
- return None
- if idx >= len(self.legendList):
- raise IndexError('list index out of range')
-
- item = self.legendList[idx]
- isActive = item[1].get("active", False)
- if role == qt.Qt.DisplayRole:
- # Data to be rendered in the form of text
- legend = str(item[0])
- return legend
- elif role == qt.Qt.SizeHintRole:
- # size = qt.QSize(200,50)
- _logger.warning('LegendModel -- size hint role not implemented')
- return qt.QSize()
- elif role == qt.Qt.TextAlignmentRole:
- alignment = qt.Qt.AlignVCenter | qt.Qt.AlignLeft
- return alignment
- elif role == qt.Qt.BackgroundRole:
- # Background color, must be QBrush
- if isActive:
- brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.Highlight)
- elif idx % 2:
- brush = qt.QBrush(qt.QColor(240, 240, 240))
- else:
- brush = qt.QBrush(qt.Qt.white)
- return brush
- elif role == qt.Qt.ForegroundRole:
- # ForegroundRole color, must be QBrush
- if isActive:
- brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.HighlightedText)
- else:
- brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.WindowText)
- return brush
- elif role == qt.Qt.CheckStateRole:
- return bool(item[2]) # item[2] == True
- elif role == qt.Qt.ToolTipRole or role == qt.Qt.StatusTipRole:
- return ''
- elif role == self.iconColorRole:
- return item[1]['color']
- elif role == self.iconLineWidthRole:
- return item[1]['linewidth']
- elif role == self.iconLineStyleRole:
- return item[1]['linestyle']
- elif role == self.iconSymbolRole:
- return item[1]['symbol']
- elif role == self.showLineRole:
- return item[3]
- elif role == self.showSymbolRole:
- return item[4]
- else:
- _logger.info('Unkown role requested: %s', str(role))
- return None
-
- def setData(self, modelIndex, value, role):
- if modelIndex.isValid:
- idx = modelIndex.row()
- else:
- return None
- if idx >= len(self.legendList):
- # raise IndexError('list index out of range')
- _logger.warning(
- 'setData -- List index out of range, idx: %d', idx)
- return None
-
- item = self.legendList[idx]
- try:
- if role == qt.Qt.DisplayRole:
- # Set legend
- item[0] = str(value)
- elif role == self.iconColorRole:
- item[1]['color'] = qt.QColor(value)
- elif role == self.iconLineWidthRole:
- item[1]['linewidth'] = int(value)
- elif role == self.iconLineStyleRole:
- item[1]['linestyle'] = str(value)
- elif role == self.iconSymbolRole:
- item[1]['symbol'] = str(value)
- elif role == qt.Qt.CheckStateRole:
- item[2] = value
- elif role == self.showLineRole:
- item[3] = value
- elif role == self.showSymbolRole:
- item[4] = value
- except ValueError:
- _logger.warning('Conversion failed:\n\tvalue: %s\n\trole: %s',
- str(value), str(role))
- # Can that be right? Read docs again..
- self.dataChanged.emit(modelIndex, modelIndex)
- return True
-
- def insertLegendList(self, row, llist):
- """
- :param int row: Determines after which row the items are inserted
- :param llist: Carries the new legend information
- :type llist: List
- """
- modelIndex = self.createIndex(row, 0)
- count = len(llist)
- super(LegendModel, self).beginInsertRows(modelIndex,
- row,
- row + count)
- head = self.legendList[0:row]
- tail = self.legendList[row:]
- new = []
- for (legend, icon) in llist:
- linestyle = icon.get('linestyle', None)
- if LegendIconWidget.isEmptyLineStyle(linestyle):
- # Curve had no line, give it one and hide it
- # So when toggle line, it will display a solid line
- showLine = False
- icon['linestyle'] = '-'
- else:
- showLine = True
-
- symbol = icon.get('symbol', None)
- if LegendIconWidget.isEmptySymbol(symbol):
- # Curve had no symbol, give it one and hide it
- # So when toggle symbol, it will display 'o'
- showSymbol = False
- icon['symbol'] = 'o'
- else:
- showSymbol = True
-
- selected = icon.get('selected', True)
- item = [legend,
- icon,
- selected,
- showLine,
- showSymbol]
- new.append(item)
- self.legendList = head + new + tail
- super(LegendModel, self).endInsertRows()
- return True
-
- def insertRows(self, row, count, modelIndex=qt.QModelIndex()):
- raise NotImplementedError('Use LegendModel.insertLegendList instead')
-
- def removeRow(self, row):
- return self.removeRows(row, 1)
-
- def removeRows(self, row, count, modelIndex=qt.QModelIndex()):
- length = len(self.legendList)
- if length == 0:
- # Nothing to do..
- return True
- if row < 0 or row >= length:
- raise IndexError('Index out of range -- ' +
- 'idx: %d, len: %d' % (row, length))
- if count == 0:
- return False
- super(LegendModel, self).beginRemoveRows(modelIndex,
- row,
- row + count)
- del(self.legendList[row:row + count])
- super(LegendModel, self).endRemoveRows()
- return True
-
- def setEditor(self, event, editor):
- """
- :param str event: String that identifies the editor
- :param editor: Widget used to change data in the underlying model
- :type editor: QWidget
- """
- if event not in self.eventList:
- raise ValueError('setEditor -- Event must be in %s' %
- str(self.eventList))
- self.editorDict[event] = editor
-
-
-class LegendListItemWidget(qt.QItemDelegate):
- """Object displaying a single item (i.e., a row) in the list."""
-
- # Notice: LegendListItem does NOT inherit
- # from QObject, it cannot emit signals!
-
- def __init__(self, parent=None, itemType=0):
- super(LegendListItemWidget, self).__init__(parent)
-
- # Dictionary to render checkboxes
- self.cbDict = {}
- self.labelDict = {}
- self.iconDict = {}
-
- # Keep checkbox and legend to get sizeHint
- self.checkbox = qt.QCheckBox()
- self.legend = qt.QLabel()
- self.icon = LegendIcon()
-
- # Context Menu and Editors
- self.contextMenu = None
-
- def paint(self, painter, option, modelIndex):
- """
- Here be docs..
-
- :param QPainter painter:
- :param QStyleOptionViewItem option:
- :param QModelIndex modelIndex:
- """
- painter.save()
- rect = option.rect
-
- # Calculate the icon rectangle
- iconSize = self.icon.sizeHint()
- # Calculate icon position
- x = rect.left() + 2
- y = rect.top() + int(.5 * (rect.height() - iconSize.height()))
- iconRect = qt.QRect(qt.QPoint(x, y), iconSize)
-
- # Calculate label rectangle
- legendSize = qt.QSize(rect.width() - iconSize.width() - 30,
- rect.height())
- # Calculate label position
- x = rect.left() + iconRect.width()
- y = rect.top()
- labelRect = qt.QRect(qt.QPoint(x, y), legendSize)
- labelRect.translate(qt.QPoint(10, 0))
-
- # Calculate the checkbox rectangle
- x = rect.right() - 30
- y = rect.top()
- chBoxRect = qt.QRect(qt.QPoint(x, y), rect.bottomRight())
-
- # Remember the rectangles
- idx = modelIndex.row()
- self.cbDict[idx] = chBoxRect
- self.iconDict[idx] = iconRect
- self.labelDict[idx] = labelRect
-
- # Draw background first!
- if option.state & qt.QStyle.State_MouseOver:
- backgroundBrush = option.palette.highlight()
- else:
- backgroundBrush = modelIndex.data(qt.Qt.BackgroundRole)
- painter.fillRect(rect, backgroundBrush)
-
- # Draw label
- legendText = modelIndex.data(qt.Qt.DisplayRole)
- textBrush = modelIndex.data(qt.Qt.ForegroundRole)
- textAlign = modelIndex.data(qt.Qt.TextAlignmentRole)
- painter.setBrush(textBrush)
- painter.setFont(self.legend.font())
- painter.setPen(textBrush.color())
- painter.drawText(labelRect, textAlign, legendText)
-
- # Draw icon
- iconColor = modelIndex.data(LegendModel.iconColorRole)
- iconLineWidth = modelIndex.data(LegendModel.iconLineWidthRole)
- iconLineStyle = modelIndex.data(LegendModel.iconLineStyleRole)
- iconSymbol = modelIndex.data(LegendModel.iconSymbolRole)
- icon = LegendIcon()
- icon.resize(iconRect.size())
- icon.move(iconRect.topRight())
- icon.showSymbol = modelIndex.data(LegendModel.showSymbolRole)
- icon.showLine = modelIndex.data(LegendModel.showLineRole)
- icon.setSymbolColor(iconColor)
- icon.setLineColor(iconColor)
- icon.setLineWidth(iconLineWidth)
- icon.setLineStyle(iconLineStyle)
- icon.setSymbol(iconSymbol)
- icon.symbolOutlineBrush = backgroundBrush
- icon.paint(painter, iconRect, option.palette)
-
- # Draw the checkbox
- if modelIndex.data(qt.Qt.CheckStateRole):
- checkState = qt.Qt.Checked
- else:
- checkState = qt.Qt.Unchecked
-
- self.drawCheck(
- painter, qt.QStyleOptionViewItem(), chBoxRect, checkState)
-
- painter.restore()
-
- def editorEvent(self, event, model, option, modelIndex):
- # From the docs:
- # Mouse events are sent to editorEvent()
- # even if they don't start editing of the item.
- if event.button() == qt.Qt.RightButton and self.contextMenu:
- self.contextMenu.exec_(event.globalPos(), modelIndex)
- return True
- elif event.button() == qt.Qt.LeftButton:
- # Check if checkbox was clicked
- idx = modelIndex.row()
- cbRect = self.cbDict[idx]
- if cbRect.contains(event.pos()):
- # Toggle checkbox
- model.setData(modelIndex,
- not modelIndex.data(qt.Qt.CheckStateRole),
- qt.Qt.CheckStateRole)
- event.ignore()
- return True
- else:
- return super(LegendListItemWidget, self).editorEvent(
- event, model, option, modelIndex)
-
- def createEditor(self, parent, option, idx):
- _logger.info('### Editor request ###')
-
- def sizeHint(self, option, idx):
- # return qt.QSize(68,24)
- iconSize = self.icon.sizeHint()
- legendSize = self.legend.sizeHint()
- checkboxSize = self.checkbox.sizeHint()
- height = max([iconSize.height(),
- legendSize.height(),
- checkboxSize.height()]) + 4
- width = iconSize.width() + legendSize.width() + checkboxSize.width()
- return qt.QSize(width, height)
-
-
-class LegendListView(qt.QListView):
- """Widget displaying a list of curve legends, line style and symbol."""
-
- sigLegendSignal = qt.Signal(object)
- """Signal emitting a dict when an action is triggered by the user."""
-
- __mouseClickedEvent = 'mouseClicked'
- __checkBoxClickedEvent = 'checkBoxClicked'
- __legendClickedEvent = 'legendClicked'
-
- def __init__(self, parent=None, model=None, contextMenu=None):
- super(LegendListView, self).__init__(parent)
- self.__lastButton = None
- self.__lastClickPos = None
- self.__lastModelIdx = None
- # Set default delegate
- self.setItemDelegate(LegendListItemWidget())
- # Set default editors
- # self.setSizePolicy(qt.QSizePolicy.MinimumExpanding,
- # qt.QSizePolicy.MinimumExpanding)
- # Set edit triggers by hand using self.edit(QModelIndex)
- # in mousePressEvent (better to control than signals)
- self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
-
- # Control layout
- # self.setBatchSize(2)
- # self.setLayoutMode(qt.QListView.Batched)
- # self.setFlow(qt.QListView.LeftToRight)
-
- # Control selection
- self.setSelectionMode(qt.QAbstractItemView.NoSelection)
-
- if model is None:
- model = LegendModel(parent=self)
- self.setModel(model)
- self.setContextMenu(contextMenu)
-
- def setLegendList(self, legendList, row=None):
- if row is not None:
- model = self.model()
- model.insertLegendList(row, legendList)
- elif len(legendList) != self.model().rowCount():
- self.clear()
- model = self.model()
- model.insertLegendList(0, legendList)
- else:
- model = self.model()
- for i, (new_legend, icon) in enumerate(legendList):
- modelIndex = model.index(i)
- legend = str(modelIndex.data(qt.Qt.DisplayRole))
- if new_legend != legend:
- model.setData(modelIndex, new_legend, qt.Qt.DisplayRole)
-
- color = modelIndex.data(LegendModel.iconColorRole)
- new_color = icon.get('color', None)
- if new_color != color:
- model.setData(modelIndex, new_color, LegendModel.iconColorRole)
-
- linewidth = modelIndex.data(LegendModel.iconLineWidthRole)
- new_linewidth = icon.get('linewidth', 1.0)
- if new_linewidth != linewidth:
- model.setData(modelIndex, new_linewidth, LegendModel.iconLineWidthRole)
-
- linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
- new_linestyle = icon.get('linestyle', None)
- visible = not LegendIconWidget.isEmptyLineStyle(new_linestyle)
- model.setData(modelIndex, visible, LegendModel.showLineRole)
- if new_linestyle != linestyle:
- model.setData(modelIndex, new_linestyle, LegendModel.iconLineStyleRole)
-
- symbol = modelIndex.data(LegendModel.iconSymbolRole)
- new_symbol = icon.get('symbol', None)
- visible = not LegendIconWidget.isEmptySymbol(new_symbol)
- model.setData(modelIndex, visible, LegendModel.showSymbolRole)
- if new_symbol != symbol:
- model.setData(modelIndex, new_symbol, LegendModel.iconSymbolRole)
-
- selected = modelIndex.data(qt.Qt.CheckStateRole)
- new_selected = icon.get('selected', True)
- if new_selected != selected:
- model.setData(modelIndex, new_selected, qt.Qt.CheckStateRole)
- _logger.debug('LegendListView.setLegendList(legendList) finished')
-
- def clear(self):
- model = self.model()
- model.removeRows(0, model.rowCount())
- _logger.debug('LegendListView.clear() finished')
-
- def setContextMenu(self, contextMenu=None):
- delegate = self.itemDelegate()
- if isinstance(delegate, LegendListItemWidget) and self.model():
- if contextMenu is None:
- delegate.contextMenu = LegendListContextMenu(self.model())
- delegate.contextMenu.sigContextMenu.connect(
- self._contextMenuSlot)
- else:
- delegate.contextMenu = contextMenu
-
- def __getitem__(self, idx):
- model = self.model()
- try:
- item = model[idx]
- except ValueError:
- item = None
- return item
-
- def _contextMenuSlot(self, ddict):
- self.sigLegendSignal.emit(ddict)
-
- def mousePressEvent(self, event):
- self.__lastButton = event.button()
- self.__lastPosition = event.pos()
- super(LegendListView, self).mousePressEvent(event)
- # call _handleMouseClick after editing was handled
- # If right click (context menu) is aborted, no
- # signal is emitted..
- self._handleMouseClick(self.indexAt(self.__lastPosition))
-
- def mouseDoubleClickEvent(self, event):
- self.__lastButton = event.button()
- self.__lastPosition = event.pos()
- super(LegendListView, self).mouseDoubleClickEvent(event)
- # call _handleMouseClick after editing was handled
- # If right click (context menu) is aborted, no
- # signal is emitted..
- self._handleMouseClick(self.indexAt(self.__lastPosition))
-
- def mouseMoveEvent(self, event):
- # LegendListView.mouseMoveEvent is overwritten
- # to suppress unwanted behavior in the delegate.
- pass
-
- def mouseReleaseEvent(self, event):
- # LegendListView.mouseReleaseEvent is overwritten
- # to subpress unwanted behavior in the delegate.
- pass
-
- def _handleMouseClick(self, modelIndex):
- """
- Distinguish between mouse click on Legend
- and mouse click on CheckBox by setting the
- currentCheckState attribute in LegendListItem.
-
- Emits signal sigLegendSignal(ddict)
-
- :param QModelIndex modelIndex: index of the clicked item
- """
- _logger.debug('self._handleMouseClick called')
- if self.__lastButton not in [qt.Qt.LeftButton,
- qt.Qt.RightButton]:
- return
- if not modelIndex.isValid():
- _logger.debug('_handleMouseClick -- Invalid QModelIndex')
- return
- # model = self.model()
- idx = modelIndex.row()
-
- delegate = self.itemDelegate()
- cbClicked = False
- if isinstance(delegate, LegendListItemWidget):
- for cbRect in delegate.cbDict.values():
- if cbRect.contains(self.__lastPosition):
- cbClicked = True
- break
-
- # TODO: Check for doubleclicks on legend/icon and spawn editors
-
- ddict = {
- 'legend': str(modelIndex.data(qt.Qt.DisplayRole)),
- 'icon': {
- 'linewidth': str(modelIndex.data(
- LegendModel.iconLineWidthRole)),
- 'linestyle': str(modelIndex.data(
- LegendModel.iconLineStyleRole)),
- 'symbol': str(modelIndex.data(LegendModel.iconSymbolRole))
- },
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data())
- }
- if self.__lastButton == qt.Qt.RightButton:
- _logger.debug('Right clicked')
- ddict['button'] = "right"
- ddict['event'] = self.__mouseClickedEvent
- elif cbClicked:
- _logger.debug('CheckBox clicked')
- ddict['button'] = "left"
- ddict['event'] = self.__checkBoxClickedEvent
- else:
- _logger.debug('Legend clicked')
- ddict['button'] = "left"
- ddict['event'] = self.__legendClickedEvent
- _logger.debug(' idx: %d\n ddict: %s', idx, str(ddict))
- self.sigLegendSignal.emit(ddict)
-
-
-class LegendListContextMenu(qt.QMenu):
- """Contextual menu associated to items in a :class:`LegendListView`."""
-
- sigContextMenu = qt.Signal(object)
- """Signal emitting a dict upon contextual menu actions."""
-
- def __init__(self, model):
- super(LegendListContextMenu, self).__init__(parent=None)
- self.model = model
-
- self.addAction('Set Active', self.setActiveAction)
- self.addAction('Map to left', self.mapToLeftAction)
- self.addAction('Map to right', self.mapToRightAction)
-
- self._pointsAction = self.addAction(
- 'Points', self.togglePointsAction)
- self._pointsAction.setCheckable(True)
-
- self._linesAction = self.addAction('Lines', self.toggleLinesAction)
- self._linesAction.setCheckable(True)
-
- self.addAction('Remove curve', self.removeItemAction)
- self.addAction('Rename curve', self.renameItemAction)
-
- def exec_(self, pos, idx):
- self.__currentIdx = idx
-
- # Set checkable action state
- modelIndex = self.currentIdx()
- self._pointsAction.setChecked(
- modelIndex.data(LegendModel.showSymbolRole))
- self._linesAction.setChecked(
- modelIndex.data(LegendModel.showLineRole))
-
- super(LegendListContextMenu, self).popup(pos)
-
- def currentIdx(self):
- return self.__currentIdx
-
- def mapToLeftAction(self):
- _logger.debug('LegendListContextMenu.mapToLeftAction called')
- modelIndex = self.currentIdx()
- legend = str(modelIndex.data(qt.Qt.DisplayRole))
- ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "mapToLeft"
- }
- self.sigContextMenu.emit(ddict)
-
- def mapToRightAction(self):
- _logger.debug('LegendListContextMenu.mapToRightAction called')
- modelIndex = self.currentIdx()
- legend = str(modelIndex.data(qt.Qt.DisplayRole))
- ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "mapToRight"
- }
- self.sigContextMenu.emit(ddict)
-
- def removeItemAction(self):
- _logger.debug('LegendListContextMenu.removeCurveAction called')
- modelIndex = self.currentIdx()
- legend = str(modelIndex.data(qt.Qt.DisplayRole))
- ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "removeCurve"
- }
- self.model.removeRow(modelIndex.row())
- self.sigContextMenu.emit(ddict)
-
- def renameItemAction(self):
- _logger.debug('LegendListContextMenu.renameCurveAction called')
- modelIndex = self.currentIdx()
- legend = str(modelIndex.data(qt.Qt.DisplayRole))
- ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "renameCurve"
- }
- self.sigContextMenu.emit(ddict)
-
- def toggleLinesAction(self):
- modelIndex = self.currentIdx()
- legend = str(modelIndex.data(qt.Qt.DisplayRole))
- ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- }
- linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
- visible = not modelIndex.data(LegendModel.showLineRole)
- _logger.debug('toggleLinesAction -- lines visible: %s', str(visible))
- ddict['event'] = "toggleLine"
- ddict['line'] = visible
- ddict['linestyle'] = linestyle if visible else ''
- self.model.setData(modelIndex, visible, LegendModel.showLineRole)
- self.sigContextMenu.emit(ddict)
-
- def togglePointsAction(self):
- modelIndex = self.currentIdx()
- legend = str(modelIndex.data(qt.Qt.DisplayRole))
- ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- }
- flag = modelIndex.data(LegendModel.showSymbolRole)
- symbol = modelIndex.data(LegendModel.iconSymbolRole)
- visible = not flag or LegendIconWidget.isEmptySymbol(symbol)
- _logger.debug(
- 'togglePointsAction -- Symbols visible: %s', str(visible))
-
- ddict['event'] = "togglePoints"
- ddict['points'] = visible
- ddict['symbol'] = symbol if visible else ''
- self.model.setData(modelIndex, visible, LegendModel.showSymbolRole)
- self.sigContextMenu.emit(ddict)
-
- def setActiveAction(self):
- modelIndex = self.currentIdx()
- legend = str(modelIndex.data(qt.Qt.DisplayRole))
- _logger.debug('setActiveAction -- active curve: %s', legend)
- ddict = {
- 'legend': legend,
- 'label': legend,
- 'selected': modelIndex.data(qt.Qt.CheckStateRole),
- 'type': str(modelIndex.data()),
- 'event': "setActiveCurve",
- }
- self.sigContextMenu.emit(ddict)
-
-
-class RenameCurveDialog(qt.QDialog):
- """Dialog box to input the name of a curve."""
-
- def __init__(self, parent=None, current="", curves=()):
- super(RenameCurveDialog, self).__init__(parent)
- self.setWindowTitle("Rename Curve %s" % current)
- self.curves = curves
- layout = qt.QVBoxLayout(self)
- self.lineEdit = qt.QLineEdit(self)
- self.lineEdit.setText(current)
- self.hbox = qt.QWidget(self)
- self.hboxLayout = qt.QHBoxLayout(self.hbox)
- self.hboxLayout.addStretch(1)
- self.okButton = qt.QPushButton(self.hbox)
- self.okButton.setText('OK')
- self.hboxLayout.addWidget(self.okButton)
- self.cancelButton = qt.QPushButton(self.hbox)
- self.cancelButton.setText('Cancel')
- self.hboxLayout.addWidget(self.cancelButton)
- self.hboxLayout.addStretch(1)
- layout.addWidget(self.lineEdit)
- layout.addWidget(self.hbox)
- self.okButton.clicked.connect(self.preAccept)
- self.cancelButton.clicked.connect(self.reject)
-
- def preAccept(self):
- text = str(self.lineEdit.text())
- addedText = ""
- if len(text):
- if text not in self.curves:
- self.accept()
- return
- else:
- addedText = "Curve already exists."
- text = "Invalid Curve Name"
- msg = qt.QMessageBox(self)
- msg.setIcon(qt.QMessageBox.Critical)
- msg.setWindowTitle(text)
- text += "\n%s" % addedText
- msg.setText(text)
- msg.exec_()
-
- def getText(self):
- return str(self.lineEdit.text())
-
-
-class LegendsDockWidget(qt.QDockWidget):
- """QDockWidget with a :class:`LegendSelector` connected to a PlotWindow.
-
- It makes the link between the LegendListView widget and the PlotWindow.
-
- :param parent: See :class:`QDockWidget`
- :param plot: :class:`.PlotWindow` instance on which to operate
- """
-
- def __init__(self, parent=None, plot=None):
- assert plot is not None
- self._plotRef = weakref.ref(plot)
- self._isConnected = False # True if widget connected to plot signals
-
- super(LegendsDockWidget, self).__init__("Legends", parent)
-
- self._legendWidget = LegendListView()
-
- self.layout().setContentsMargins(0, 0, 0, 0)
- self.setWidget(self._legendWidget)
-
- self.visibilityChanged.connect(
- self._visibilityChangedHandler)
-
- self._legendWidget.sigLegendSignal.connect(self._legendSignalHandler)
-
- @property
- def plot(self):
- """The :class:`.PlotWindow` this widget is attached to."""
- return self._plotRef()
-
- def renameCurve(self, oldLegend, newLegend):
- """Change the name of a curve using remove and addCurve
-
- :param str oldLegend: The legend of the curve to be changed
- :param str newLegend: The new legend of the curve
- """
- is_active = self.plot.getActiveCurve(just_legend=True) == oldLegend
- curve = self.plot.getCurve(oldLegend)
- self.plot.remove(oldLegend, kind='curve')
- self.plot.addCurve(curve.getXData(copy=False),
- curve.getYData(copy=False),
- legend=newLegend,
- info=curve.getInfo(),
- color=curve.getColor(),
- symbol=curve.getSymbol(),
- linewidth=curve.getLineWidth(),
- linestyle=curve.getLineStyle(),
- xlabel=curve.getXLabel(),
- ylabel=curve.getYLabel(),
- xerror=curve.getXErrorData(copy=False),
- yerror=curve.getYErrorData(copy=False),
- z=curve.getZValue(),
- selectable=curve.isSelectable(),
- fill=curve.isFill(),
- resetzoom=False)
- if is_active:
- self.plot.setActiveCurve(newLegend)
-
- def _legendSignalHandler(self, ddict):
- """Handles events from the LegendListView signal"""
- _logger.debug("Legend signal ddict = %s", str(ddict))
-
- if ddict['event'] == "legendClicked":
- if ddict['button'] == "left":
- self.plot.setActiveCurve(ddict['legend'])
-
- elif ddict['event'] == "removeCurve":
- self.plot.removeCurve(ddict['legend'])
-
- elif ddict['event'] == "renameCurve":
- curveList = self.plot.getAllCurves(just_legend=True)
- oldLegend = ddict['legend']
- dialog = RenameCurveDialog(self.plot, oldLegend, curveList)
- ret = dialog.exec_()
- if ret:
- newLegend = dialog.getText()
- self.renameCurve(oldLegend, newLegend)
-
- elif ddict['event'] == "setActiveCurve":
- self.plot.setActiveCurve(ddict['legend'])
-
- elif ddict['event'] == "checkBoxClicked":
- self.plot.hideCurve(ddict['legend'], not ddict['selected'])
-
- elif ddict['event'] in ["mapToRight", "mapToLeft"]:
- legend = ddict['legend']
- curve = self.plot.getCurve(legend)
- yaxis = 'right' if ddict['event'] == 'mapToRight' else 'left'
- self.plot.addCurve(x=curve.getXData(copy=False),
- y=curve.getYData(copy=False),
- legend=curve.getName(),
- info=curve.getInfo(),
- yaxis=yaxis)
-
- elif ddict['event'] == "togglePoints":
- legend = ddict['legend']
- curve = self.plot.getCurve(legend)
- symbol = ddict['symbol'] if ddict['points'] else ''
- self.plot.addCurve(x=curve.getXData(copy=False),
- y=curve.getYData(copy=False),
- legend=curve.getName(),
- info=curve.getInfo(),
- symbol=symbol)
-
- elif ddict['event'] == "toggleLine":
- legend = ddict['legend']
- curve = self.plot.getCurve(legend)
- linestyle = ddict['linestyle'] if ddict['line'] else ''
- self.plot.addCurve(x=curve.getXData(copy=False),
- y=curve.getYData(copy=False),
- legend=curve.getName(),
- info=curve.getInfo(),
- linestyle=linestyle)
-
- else:
- _logger.debug("unhandled event %s", str(ddict['event']))
-
- def updateLegends(self, *args):
- """Sync the LegendSelector widget displayed info with the plot.
- """
- legendList = []
- for curve in self.plot.getAllCurves(withhidden=True):
- legend = curve.getName()
- # Use active color if curve is active
- isActive = legend == self.plot.getActiveCurve(just_legend=True)
- style = curve.getCurrentStyle()
- color = style.getColor()
- if numpy.array(color, copy=False).ndim != 1:
- # array of colors, use transparent black
- color = 0., 0., 0., 0.
-
- curveInfo = {
- 'color': qt.QColor.fromRgbF(*color),
- 'linewidth': style.getLineWidth(),
- 'linestyle': style.getLineStyle(),
- 'symbol': style.getSymbol(),
- 'selected': not self.plot.isCurveHidden(legend),
- 'active': isActive}
- legendList.append((legend, curveInfo))
-
- self._legendWidget.setLegendList(legendList)
-
- def _visibilityChangedHandler(self, visible):
- if visible:
- self.updateLegends()
- if not self._isConnected:
- self.plot.sigContentChanged.connect(self.updateLegends)
- self.plot.sigActiveCurveChanged.connect(self.updateLegends)
- self._isConnected = True
- else:
- if self._isConnected:
- self.plot.sigContentChanged.disconnect(self.updateLegends)
- self.plot.sigActiveCurveChanged.disconnect(self.updateLegends)
- self._isConnected = False
-
- def showEvent(self, event):
- """Make sure this widget is raised when it is shown
- (when it is first created as a tab in PlotWindow or when it is shown
- again after hiding).
- """
- self.raise_()
diff --git a/silx/gui/plot/MaskToolsWidget.py b/silx/gui/plot/MaskToolsWidget.py
deleted file mode 100644
index 1ec1e7f..0000000
--- a/silx/gui/plot/MaskToolsWidget.py
+++ /dev/null
@@ -1,919 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Widget providing a set of tools to draw masks on a PlotWidget.
-
-This widget is meant to work with :class:`silx.gui.plot.PlotWidget`.
-
-- :class:`ImageMask`: Handle mask bitmap update and history
-- :class:`MaskToolsWidget`: GUI for :class:`Mask`
-- :class:`MaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
-"""
-from __future__ import division
-
-__authors__ = ["T. Vincent", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "08/12/2020"
-
-import os
-import sys
-import numpy
-import logging
-import collections
-import h5py
-
-from silx.image import shapes
-from silx.io.utils import NEXUS_HDF5_EXT, is_dataset
-from silx.gui.dialog.DatasetDialog import DatasetDialog
-
-from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
-from . import items
-from ..colors import cursorColorForColormap, rgba
-from .. import qt
-from ..utils import LockReentrant
-
-from silx.third_party.EdfFile import EdfFile
-from silx.third_party.TiffIO import TiffIO
-
-import fabio
-
-_logger = logging.getLogger(__name__)
-
-_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT])
-
-
-def _selectDataset(filename, mode=DatasetDialog.SaveMode):
- """Open a dialog to prompt the user to select a dataset in
- a hdf5 file.
-
- :param str filename: name of an existing HDF5 file
- :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
- :rtype: str
- :return: Name of selected dataset
- """
- dialog = DatasetDialog()
- dialog.addFile(filename)
- dialog.setWindowTitle("Select a 2D dataset")
- dialog.setMode(mode)
- if not dialog.exec_():
- return None
- return dialog.getSelectedDataUrl().data_path()
-
-
-class ImageMask(BaseMask):
- """A 2D mask field with update operations.
-
- Coords follows (row, column) convention and are in mask array coords.
-
- This is meant for internal use by :class:`MaskToolsWidget`.
- """
-
- def __init__(self, image=None):
- """
-
- :param image: :class:`silx.gui.plot.items.ImageBase` instance
- """
- BaseMask.__init__(self, image)
- self.reset(shape=(0, 0)) # Init the mask with a 2D shape
-
- def getDataValues(self):
- """Return image data as a 2D or 3D array (if it is a RGBA image).
-
- :rtype: 2D or 3D numpy.ndarray
- """
- return self._dataItem.getData(copy=False)
-
- def save(self, filename, kind):
- """Save current mask in a file
-
- :param str filename: The file where to save to mask
- :param str kind: The kind of file to save in 'edf', 'tif', 'npy', 'h5'
- or 'msk' (if FabIO is installed)
- :raise Exception: Raised if the file writing fail
- """
- if kind == 'edf':
- edfFile = EdfFile(filename, access="w+")
- header = {"program_name": "silx-mask", "masked_value": "nonzero"}
- edfFile.WriteImage(header, self.getMask(copy=False), Append=0)
-
- elif kind == 'tif':
- tiffFile = TiffIO(filename, mode='w')
- tiffFile.writeImage(self.getMask(copy=False), software='silx')
-
- elif kind == 'npy':
- try:
- numpy.save(filename, self.getMask(copy=False))
- except IOError:
- raise RuntimeError("Mask file can't be written")
-
- elif ("." + kind) in NEXUS_HDF5_EXT:
- self._saveToHdf5(filename, self.getMask(copy=False))
-
- elif kind == 'msk':
- try:
- data = self.getMask(copy=False)
- image = fabio.fabioimage.FabioImage(data=data)
- image = image.convert(fabio.fit2dmaskimage.Fit2dMaskImage)
- image.save(filename)
- except Exception:
- _logger.debug("Backtrace", exc_info=True)
- raise RuntimeError("Mask file can't be written")
- else:
- raise ValueError("Format '%s' is not supported" % kind)
-
- @staticmethod
- def _saveToHdf5(filename, mask):
- """Save a mask array to a HDF5 file.
-
- :param str filename: name of an existing HDF5 file
- :param numpy.ndarray mask: Mask array.
- :returns: True if operation succeeded, False otherwise.
- """
- if not os.path.exists(filename):
- # create new file
- with h5py.File(filename, "w") as _h5f:
- pass
- dataPath = _selectDataset(filename)
- if dataPath is None:
- return False
- with h5py.File(filename, "a") as h5f:
- existing_ds = h5f.get(dataPath)
- if existing_ds is not None:
- reply = qt.QMessageBox.question(
- None,
- "Confirm overwrite",
- "Do you want to overwrite an existing dataset?",
- qt.QMessageBox.Yes | qt.QMessageBox.No)
- if reply != qt.QMessageBox.Yes:
- return False
- del h5f[dataPath]
- try:
- h5f.create_dataset(dataPath, data=mask)
- except Exception:
- return False
- return True
-
- # Drawing operations
- def updateRectangle(self, level, row, col, height, width, mask=True):
- """Mask/Unmask a rectangle of the given mask level.
-
- :param int level: Mask level to update.
- :param int row: Starting row of the rectangle
- :param int col: Starting column of the rectangle
- :param int height:
- :param int width:
- :param bool mask: True to mask (default), False to unmask.
- """
- assert 0 < level < 256
- if row + height <= 0 or col + width <= 0:
- return # Rectangle outside image, avoid negative indices
- selection = self._mask[max(0, row):row + height + 1,
- max(0, col):col + width + 1]
- if mask:
- selection[:,:] = level
- else:
- selection[selection == level] = 0
- self._notify()
-
- def updatePolygon(self, level, vertices, mask=True):
- """Mask/Unmask a polygon of the given mask level.
-
- :param int level: Mask level to update.
- :param vertices: Nx2 array of polygon corners as (row, col)
- :param bool mask: True to mask (default), False to unmask.
- """
- fill = shapes.polygon_fill_mask(vertices, self._mask.shape)
- if mask:
- self._mask[fill != 0] = level
- else:
- self._mask[numpy.logical_and(fill != 0,
- self._mask == level)] = 0
- self._notify()
-
- def updatePoints(self, level, rows, cols, mask=True):
- """Mask/Unmask points with given coordinates.
-
- :param int level: Mask level to update.
- :param rows: Rows of selected points
- :type rows: 1D numpy.ndarray
- :param cols: Columns of selected points
- :type cols: 1D numpy.ndarray
- :param bool mask: True to mask (default), False to unmask.
- """
- valid = numpy.logical_and(
- numpy.logical_and(rows >= 0, cols >= 0),
- numpy.logical_and(rows < self._mask.shape[0],
- cols < self._mask.shape[1]))
- rows, cols = rows[valid], cols[valid]
-
- if mask:
- self._mask[rows, cols] = level
- else:
- inMask = self._mask[rows, cols] == level
- self._mask[rows[inMask], cols[inMask]] = 0
- self._notify()
-
- def updateDisk(self, level, crow, ccol, radius, mask=True):
- """Mask/Unmask a disk of the given mask level.
-
- :param int level: Mask level to update.
- :param int crow: Disk center row.
- :param int ccol: Disk center column.
- :param float radius: Radius of the disk in mask array unit
- :param bool mask: True to mask (default), False to unmask.
- """
- rows, cols = shapes.circle_fill(crow, ccol, radius)
- self.updatePoints(level, rows, cols, mask)
-
- def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
- """Mask/Unmask an ellipse of the given mask level.
-
- :param int level: Mask level to update.
- :param int crow: Row of the center of the ellipse
- :param int ccol: Column of the center of the ellipse
- :param float radius_r: Radius of the ellipse in the row
- :param float radius_c: Radius of the ellipse in the column
- :param bool mask: True to mask (default), False to unmask.
- """
- rows, cols = shapes.ellipse_fill(crow, ccol, radius_r, radius_c)
- self.updatePoints(level, rows, cols, mask)
-
- def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
- """Mask/Unmask a line of the given mask level.
-
- :param int level: Mask level to update.
- :param int row0: Row of the starting point.
- :param int col0: Column of the starting point.
- :param int row1: Row of the end point.
- :param int col1: Column of the end point.
- :param int width: Width of the line in mask array unit.
- :param bool mask: True to mask (default), False to unmask.
- """
- rows, cols = shapes.draw_line(row0, col0, row1, col1, width)
- self.updatePoints(level, rows, cols, mask)
-
-
-class MaskToolsWidget(BaseMaskToolsWidget):
- """Widget with tools for drawing mask on an image in a PlotWidget."""
-
- _maxLevelNumber = 255
-
- def __init__(self, parent=None, plot=None):
- super(MaskToolsWidget, self).__init__(parent, plot,
- mask=ImageMask())
- self._origin = (0., 0.) # Mask origin in plot
- self._scale = (1., 1.) # Mask scale in plot
- self._z = 1 # Mask layer in plot
- self._data = numpy.zeros((0, 0), dtype=numpy.uint8) # Store image
-
- self.__itemMaskUpdatedLock = LockReentrant()
- self.__itemMaskUpdated = False
-
- def __maskStateChanged(self) -> None:
- """Handle mask commit to update item mask"""
- item = self._mask.getDataItem()
- if item is not None:
- with self.__itemMaskUpdatedLock:
- item.setMaskData(self._mask.getMask(copy=True), copy=False)
-
- def setItemMaskUpdated(self, enabled: bool) -> None:
- """Toggle item mask and mask tool synchronisation.
-
- :param bool enabled: True to synchronise. Default: False
- """
- enabled = bool(enabled)
- if enabled != self.__itemMaskUpdated:
- if self.__itemMaskUpdated:
- self._mask.sigStateChanged.disconnect(self.__maskStateChanged)
- self.__itemMaskUpdated = enabled
- if self.__itemMaskUpdated:
- # Synchronize item and tool mask
- self._setMaskedImage(self._mask.getDataItem())
- self._mask.sigStateChanged.connect(self.__maskStateChanged)
-
- def isItemMaskUpdated(self) -> bool:
- """Returns whether or not item and mask tool masks are synchronised.
-
- :rtype: bool
- """
- return self.__itemMaskUpdated
-
- def setSelectionMask(self, mask, copy=True):
- """Set the mask to a new array.
-
- :param numpy.ndarray mask:
- The array to use for the mask or None to reset the mask.
- :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
- Array of other types are converted.
- :param bool copy: True (the default) to copy the array,
- False to use it as is if possible.
- :return: None if failed, shape of mask as 2-tuple if successful.
- The mask can be cropped or padded to fit active image,
- the returned shape is that of the active image.
- """
- if mask is None:
- self.resetSelectionMask()
- return self._data.shape[:2]
-
- mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
- if len(mask.shape) != 2:
- _logger.error('Not an image, shape: %d', len(mask.shape))
- return None
-
- # Handle mask with single level
- if self.multipleMasks() == 'single':
- mask = numpy.array(mask != 0, dtype=numpy.uint8)
-
- # if mask has not changed, do nothing
- if numpy.array_equal(mask, self.getSelectionMask()):
- return mask.shape
-
- if self._data.shape[0:2] == (0, 0) or mask.shape == self._data.shape[0:2]:
- self._mask.setMask(mask, copy=copy)
- self._mask.commit()
- return mask.shape
- else:
- _logger.warning('Mask has not the same size as current image.'
- ' Mask will be cropped or padded to fit image'
- ' dimensions. %s != %s',
- str(mask.shape), str(self._data.shape))
- resizedMask = numpy.zeros(self._data.shape[0:2],
- dtype=numpy.uint8)
- height = min(self._data.shape[0], mask.shape[0])
- width = min(self._data.shape[1], mask.shape[1])
- resizedMask[:height,:width] = mask[:height,:width]
- self._mask.setMask(resizedMask, copy=False)
- self._mask.commit()
- return resizedMask.shape
-
- # Handle mask refresh on the plot
- def _updatePlotMask(self):
- """Update mask image in plot"""
- mask = self.getSelectionMask(copy=False)
- if mask is not None:
- # get the mask from the plot
- maskItem = self.plot.getImage(self._maskName)
- mustBeAdded = maskItem is None
- if mustBeAdded:
- maskItem = items.MaskImageData()
- maskItem.setName(self._maskName)
- # update the items
- maskItem.setData(mask, copy=False)
- maskItem.setColormap(self._colormap)
- maskItem.setOrigin(self._origin)
- maskItem.setScale(self._scale)
- maskItem.setZValue(self._z)
-
- if mustBeAdded:
- self.plot.addItem(maskItem)
-
- elif self.plot.getImage(self._maskName):
- self.plot.remove(self._maskName, kind='image')
-
- def showEvent(self, event):
- try:
- self.plot.sigActiveImageChanged.disconnect(
- self._activeImageChangedAfterCare)
- except (RuntimeError, TypeError):
- pass
-
- # Sync with current active image
- self._setMaskedImage(self.plot.getActiveImage())
- self.plot.sigActiveImageChanged.connect(self._activeImageChanged)
-
- def hideEvent(self, event):
- try:
- self.plot.sigActiveImageChanged.disconnect(
- self._activeImageChanged)
- except (RuntimeError, TypeError):
- pass
-
- image = self.getMaskedItem()
- if image is not None:
- try:
- image.sigItemChanged.disconnect(self.__imageChanged)
- except (RuntimeError, TypeError):
- pass # TODO should not happen
-
- if self.isMaskInteractionActivated():
- # Disable drawing tool
- self.browseAction.trigger()
-
- if self.isItemMaskUpdated(): # No "after-care"
- self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
- self._mask.setDataItem(None)
- self._mask.reset()
-
- if self.plot.getImage(self._maskName):
- self.plot.remove(self._maskName, kind='image')
-
- elif self.getSelectionMask(copy=False) is not None:
- self.plot.sigActiveImageChanged.connect(
- self._activeImageChangedAfterCare)
-
- def _activeImageChanged(self, previous, current):
- """Reacts upon active image change.
-
- Only handle change of active image items here.
- """
- if previous != current:
- image = self.plot.getActiveImage()
- if image is not None and image.getName() == self._maskName:
- image = None # Active image is the mask
- self._setMaskedImage(image)
-
- def _setOverlayColorForImage(self, image):
- """Set the color of overlay adapted to image
-
- :param image: :class:`.items.ImageBase` object to set color for.
- """
- if isinstance(image, items.ColormapMixIn):
- colormap = image.getColormap()
- self._defaultOverlayColor = rgba(
- cursorColorForColormap(colormap['name']))
- else:
- self._defaultOverlayColor = rgba('black')
-
- def _activeImageChangedAfterCare(self, *args):
- """Check synchro of active image and mask when mask widget is hidden.
-
- If active image has no more the same size as the mask, the mask is
- removed, otherwise it is adjusted to origin, scale and z.
- """
- activeImage = self.plot.getActiveImage()
- if activeImage is None or activeImage.getName() == self._maskName:
- # No active image or active image is the mask...
- self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
- self._mask.setDataItem(None)
- self._mask.reset()
-
- if self.plot.getImage(self._maskName):
- self.plot.remove(self._maskName, kind='image')
-
- self.plot.sigActiveImageChanged.disconnect(
- self._activeImageChangedAfterCare)
- else:
- self._setOverlayColorForImage(activeImage)
- self._setMaskColors(self.levelSpinBox.value(),
- self.transparencySlider.value() /
- self.transparencySlider.maximum())
-
- self._origin = activeImage.getOrigin()
- self._scale = activeImage.getScale()
- self._z = activeImage.getZValue() + 1
- self._data = activeImage.getData(copy=False)
- if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
- # Image has not the same size, remove mask and stop listening
- if self.plot.getImage(self._maskName):
- self.plot.remove(self._maskName, kind='image')
-
- self.plot.sigActiveImageChanged.disconnect(
- self._activeImageChangedAfterCare)
- else:
- # Refresh in case origin, scale, z changed
- self._mask.setDataItem(activeImage)
- self._updatePlotMask()
-
- def _setMaskedImage(self, image):
- """Change the image that is used a reference to author the mask"""
- previous = self.getMaskedItem()
- if previous is not None and self.isVisible():
- # Disconnect from previous image
- try:
- previous.sigItemChanged.disconnect(self.__imageChanged)
- except TypeError:
- pass # TODO fixme should not happen
-
- # Set the image
- self._mask.setDataItem(image)
-
- if image is None: # No image, disable mask
- self.setEnabled(False)
-
- self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
- self._mask.reset()
- self._mask.commit()
-
- self._updateInteractiveMode()
-
- else: # Update and connect to image's sigItemChanged
- if self.isItemMaskUpdated():
- if image.getMaskData(copy=False) is None:
- # Image item has no mask: use current mask from the tool
- image.setMaskData(
- self.getSelectionMask(copy=False), copy=True)
- else: # Image item has a mask: set it in tool
- self.setSelectionMask(
- image.getMaskData(copy=False), copy=True)
- self._mask.resetHistory()
- self.__imageUpdated()
- if self.isVisible():
- image.sigItemChanged.connect(self.__imageChanged)
-
- def __imageChanged(self, event):
- """Reacts upon image item changes"""
- image = self._mask.getDataItem()
- if image is None:
- _logger.error("Mask is not attached to an image")
- return
-
- if event in (items.ItemChangedType.COLORMAP,
- items.ItemChangedType.DATA,
- items.ItemChangedType.POSITION,
- items.ItemChangedType.SCALE,
- items.ItemChangedType.VISIBLE,
- items.ItemChangedType.ZVALUE):
- self.__imageUpdated()
-
- elif (event == items.ItemChangedType.MASK and
- self.isItemMaskUpdated() and
- not self.__itemMaskUpdatedLock.locked()):
- # Update mask from the image item unless mask tool is updating it
- self.setSelectionMask(image.getMaskData(copy=False), copy=True)
-
- def __imageUpdated(self):
- """Synchronize mask with current state of the image"""
- image = self._mask.getDataItem()
- if image is None:
- _logger.error("No active image while expecting one")
- return
-
- self._setOverlayColorForImage(image)
-
- self._setMaskColors(self.levelSpinBox.value(),
- self.transparencySlider.value() /
- self.transparencySlider.maximum())
-
- self._origin = image.getOrigin()
- self._scale = image.getScale()
- self._z = image.getZValue() + 1
- self._data = image.getData(copy=False)
- self._mask.setDataItem(image)
- if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
- self._mask.reset(self._data.shape[:2])
- self._mask.commit()
- else:
- # Refresh in case origin, scale, z changed
- self._updatePlotMask()
-
- # Visible and with data
- self.setEnabled(image.isVisible() and self._data.size != 0)
-
- # Threshold tools only available for data with colormap
- self.thresholdGroup.setEnabled(self._data.ndim == 2)
-
- self._updateInteractiveMode()
-
- # Handle whole mask operations
- def load(self, filename):
- """Load a mask from an image file.
-
- :param str filename: File name from which to load the mask
- :raise Exception: An exception in case of failure
- :raise RuntimeWarning: In case the mask was applied but with some
- import changes to notice
- """
- _, extension = os.path.splitext(filename)
- extension = extension.lower()[1:]
-
- if extension == "npy":
- try:
- mask = numpy.load(filename)
- except IOError:
- _logger.error("Can't load filename '%s'", filename)
- _logger.debug("Backtrace", exc_info=True)
- raise RuntimeError('File "%s" is not a numpy file.', filename)
- elif extension in ["tif", "tiff"]:
- try:
- image = TiffIO(filename, mode="r")
- mask = image.getImage(0)
- except Exception as e:
- _logger.error("Can't load filename %s", filename)
- _logger.debug("Backtrace", exc_info=True)
- raise e
- elif extension == "edf":
- try:
- mask = EdfFile(filename, access='r').GetData(0)
- except Exception as e:
- _logger.error("Can't load filename %s", filename)
- _logger.debug("Backtrace", exc_info=True)
- raise e
- elif extension == "msk":
- try:
- mask = fabio.open(filename).data
- except Exception as e:
- _logger.error("Can't load fit2d mask file")
- _logger.debug("Backtrace", exc_info=True)
- raise e
- elif ("." + extension) in NEXUS_HDF5_EXT:
- mask = self._loadFromHdf5(filename)
- if mask is None:
- raise IOError("Could not load mask from HDF5 dataset")
- else:
- msg = "Extension '%s' is not supported."
- raise RuntimeError(msg % extension)
-
- effectiveMaskShape = self.setSelectionMask(mask, copy=False)
- if effectiveMaskShape is None:
- return
- if mask.shape != effectiveMaskShape:
- msg = 'Mask was resized from %s to %s'
- msg = msg % (str(mask.shape), str(effectiveMaskShape))
- raise RuntimeWarning(msg)
-
- def _loadMask(self):
- """Open load mask dialog"""
- dialog = qt.QFileDialog(self)
- dialog.setWindowTitle("Load Mask")
- dialog.setModal(1)
-
- extensions = collections.OrderedDict()
- extensions["EDF files"] = "*.edf"
- extensions["TIFF files"] = "*.tif *.tiff"
- extensions["NumPy binary files"] = "*.npy"
- extensions["HDF5 files"] = _HDF5_EXT_STR
- # Fit2D mask is displayed anyway fabio is here or not
- # to show to the user that the option exists
- extensions["Fit2D mask files"] = "*.msk"
-
- filters = []
- filters.append("All supported files (%s)" % " ".join(extensions.values()))
- for name, extension in extensions.items():
- filters.append("%s (%s)" % (name, extension))
- filters.append("All files (*)")
-
- dialog.setNameFilters(filters)
- dialog.setFileMode(qt.QFileDialog.ExistingFile)
- dialog.setDirectory(self.maskFileDir)
- if not dialog.exec_():
- dialog.close()
- return
-
- filename = dialog.selectedFiles()[0]
- dialog.close()
-
- # Update the directory according to the user selection
- self.maskFileDir = os.path.dirname(filename)
-
- try:
- self.load(filename)
- except RuntimeWarning as e:
- message = e.args[0]
- msg = qt.QMessageBox(self)
- msg.setIcon(qt.QMessageBox.Warning)
- msg.setText("Mask loaded but an operation was applied.\n" + message)
- msg.exec_()
- except Exception as e:
- message = e.args[0]
- msg = qt.QMessageBox(self)
- msg.setIcon(qt.QMessageBox.Critical)
- msg.setText("Cannot load mask from file. " + message)
- msg.exec_()
-
- @staticmethod
- def _loadFromHdf5(filename):
- """Load a mask array from a HDF5 file.
-
- :param str filename: name of an existing HDF5 file
- :returns: A mask as a numpy array, or None if the interactive dialog
- was cancelled
- """
- dataPath = _selectDataset(filename, mode=DatasetDialog.LoadMode)
- if dataPath is None:
- return None
-
- with h5py.File(filename, "r") as h5f:
- dataset = h5f.get(dataPath)
- if not is_dataset(dataset):
- raise IOError("%s is not a dataset" % dataPath)
- mask = dataset[()]
- return mask
-
- def _saveMask(self):
- """Open Save mask dialog"""
- dialog = qt.QFileDialog(self)
- dialog.setWindowTitle("Save Mask")
- dialog.setOption(dialog.DontUseNativeDialog)
- dialog.setModal(1)
- hdf5Filter = 'HDF5 (%s)' % _HDF5_EXT_STR
- filters = [
- 'EDF (*.edf)',
- 'TIFF (*.tif)',
- 'NumPy binary file (*.npy)',
- hdf5Filter,
- # Fit2D mask is displayed anyway fabio is here or not
- # to show to the user that the option exists
- 'Fit2D mask (*.msk)',
- ]
- dialog.setNameFilters(filters)
- dialog.setFileMode(qt.QFileDialog.AnyFile)
- dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
- dialog.setDirectory(self.maskFileDir)
-
- def onFilterSelection(filt_):
- # disable overwrite confirmation for HDF5,
- # because we append the data to existing files
- if filt_ == hdf5Filter:
- dialog.setOption(dialog.DontConfirmOverwrite)
- else:
- dialog.setOption(dialog.DontConfirmOverwrite, False)
-
- dialog.filterSelected.connect(onFilterSelection)
- if not dialog.exec_():
- dialog.close()
- return
-
- nameFilter = dialog.selectedNameFilter()
- filename = dialog.selectedFiles()[0]
- dialog.close()
-
- if "HDF5" in nameFilter:
- has_allowed_ext = False
- for ext in NEXUS_HDF5_EXT:
- if (len(filename) > len(ext) and
- filename[-len(ext):].lower() == ext.lower()):
- has_allowed_ext = True
- extension = ext
- if not has_allowed_ext:
- extension = ".h5"
- filename += ".h5"
- else:
- # convert filter name to extension name with the .
- extension = nameFilter.split()[-1][2:-1]
- if not filename.lower().endswith(extension):
- filename += extension
-
- if os.path.exists(filename) and "HDF5" not in nameFilter:
- try:
- os.remove(filename)
- except IOError as e:
- msg = qt.QMessageBox(self)
- msg.setWindowTitle("Removing existing file")
- msg.setIcon(qt.QMessageBox.Critical)
-
- if hasattr(e, "strerror"):
- strerror = e.strerror
- else:
- strerror = sys.exc_info()[1]
- msg.setText("Cannot save.\n"
- "Input Output Error: %s" % strerror)
- msg.exec_()
- return
-
- # Update the directory according to the user selection
- self.maskFileDir = os.path.dirname(filename)
-
- try:
- self.save(filename, extension[1:])
- except Exception as e:
- msg = qt.QMessageBox(self)
- msg.setWindowTitle("Saving mask file")
- msg.setIcon(qt.QMessageBox.Critical)
-
- if hasattr(e, "strerror"):
- strerror = e.strerror
- else:
- strerror = sys.exc_info()[1]
- msg.setText("Cannot save file %s\n%s" % (filename, strerror))
- msg.exec_()
-
- def resetSelectionMask(self):
- """Reset the mask"""
- self._mask.reset(shape=self._data.shape[:2])
- self._mask.commit()
-
- def _plotDrawEvent(self, event):
- """Handle draw events from the plot"""
- if (self._drawingMode is None or
- event['event'] not in ('drawingProgress', 'drawingFinished')):
- return
-
- if not len(self._data):
- return
-
- level = self.levelSpinBox.value()
-
- if self._drawingMode == 'rectangle':
- if event['event'] == 'drawingFinished':
- # Convert from plot to array coords
- doMask = self._isMasking()
- ox, oy = self._origin
- sx, sy = self._scale
-
- height = int(abs(event['height'] / sy))
- width = int(abs(event['width'] / sx))
-
- row = int((event['y'] - oy) / sy)
- if sy < 0:
- row -= height
-
- col = int((event['x'] - ox) / sx)
- if sx < 0:
- col -= width
-
- self._mask.updateRectangle(
- level,
- row=row,
- col=col,
- height=height,
- width=width,
- mask=doMask)
- self._mask.commit()
-
- elif self._drawingMode == 'ellipse':
- if event['event'] == 'drawingFinished':
- doMask = self._isMasking()
- # Convert from plot to array coords
- center = (event['points'][0] - self._origin) / self._scale
- size = event['points'][1] / self._scale
- center = center.astype(numpy.int64) # (row, col)
- self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask)
- self._mask.commit()
-
- elif self._drawingMode == 'polygon':
- if event['event'] == 'drawingFinished':
- doMask = self._isMasking()
- # Convert from plot to array coords
- vertices = (event['points'] - self._origin) / self._scale
- vertices = vertices.astype(numpy.int64)[:, (1, 0)] # (row, col)
- self._mask.updatePolygon(level, vertices, doMask)
- self._mask.commit()
-
- elif self._drawingMode == 'pencil':
- doMask = self._isMasking()
- # convert from plot to array coords
- col, row = (event['points'][-1] - self._origin) / self._scale
- col, row = int(col), int(row)
- brushSize = self._getPencilWidth()
-
- if self._lastPencilPos != (row, col):
- if self._lastPencilPos is not None:
- # Draw the line
- self._mask.updateLine(
- level,
- self._lastPencilPos[0], self._lastPencilPos[1],
- row, col,
- brushSize,
- doMask)
-
- # Draw the very first, or last point
- self._mask.updateDisk(level, row, col, brushSize / 2., doMask)
-
- if event['event'] == 'drawingFinished':
- self._mask.commit()
- self._lastPencilPos = None
- else:
- self._lastPencilPos = row, col
- else:
- _logger.error("Drawing mode %s unsupported", self._drawingMode)
-
- def _loadRangeFromColormapTriggered(self):
- """Set range from active image colormap range"""
- activeImage = self.plot.getActiveImage()
- if (isinstance(activeImage, items.ColormapMixIn) and
- activeImage.getName() != self._maskName):
- # Update thresholds according to colormap
- colormap = activeImage.getColormap()
- if colormap['autoscale']:
- min_ = numpy.nanmin(activeImage.getData(copy=False))
- max_ = numpy.nanmax(activeImage.getData(copy=False))
- else:
- min_, max_ = colormap['vmin'], colormap['vmax']
- self.minLineEdit.setText(str(min_))
- self.maxLineEdit.setText(str(max_))
-
-
-class MaskToolsDockWidget(BaseMaskToolsDockWidget):
- """:class:`MaskToolsWidget` embedded in a QDockWidget.
-
- For integration in a :class:`PlotWindow`.
-
- :param parent: See :class:`QDockWidget`
- :param plot: The PlotWidget this widget is operating on
- :paran str name: The title of this widget
- """
-
- def __init__(self, parent=None, plot=None, name='Mask'):
- widget = MaskToolsWidget(plot=plot)
- super(MaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/silx/gui/plot/PlotInteraction.py b/silx/gui/plot/PlotInteraction.py
deleted file mode 100644
index cfe140b..0000000
--- a/silx/gui/plot/PlotInteraction.py
+++ /dev/null
@@ -1,1748 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Implementation of the interaction for the :class:`Plot`."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "15/02/2019"
-
-
-import math
-import numpy
-import time
-import weakref
-
-from .. import colors
-from .. import qt
-from . import items
-from .Interaction import (ClickOrDrag, LEFT_BTN, RIGHT_BTN, MIDDLE_BTN,
- State, StateMachine)
-from .PlotEvents import (prepareCurveSignal, prepareDrawingSignal,
- prepareHoverSignal, prepareImageSignal,
- prepareMarkerSignal, prepareMouseSignal)
-
-from .backends.BackendBase import (CURSOR_POINTING, CURSOR_SIZE_HOR,
- CURSOR_SIZE_VER, CURSOR_SIZE_ALL)
-
-from ._utils import (FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX,
- applyZoomToPlot)
-
-
-# Base class ##################################################################
-
-class _PlotInteraction(object):
- """Base class for interaction handler.
-
- It provides a weakref to the plot and methods to set/reset overlay.
- """
- def __init__(self, plot):
- """Init.
-
- :param plot: The plot to apply modifications to.
- """
- self._needReplot = False
- self._selectionAreas = set()
- self._plot = weakref.ref(plot) # Avoid cyclic-ref
-
- @property
- def plot(self):
- plot = self._plot()
- assert plot is not None
- return plot
-
- def setSelectionArea(self, points, fill, color, name='', shape='polygon'):
- """Set a polygon selection area overlaid on the plot.
- Multiple simultaneous areas are supported through the name parameter.
-
- :param points: The 2D coordinates of the points of the polygon
- :type points: An iterable of (x, y) coordinates
- :param str fill: The fill mode: 'hatch', 'solid' or 'none'
- :param color: RGBA color to use or None to disable display
- :type color: list or tuple of 4 float in the range [0, 1]
- :param name: The key associated with this selection area
- :param str shape: Shape of the area in 'polygon', 'polylines'
- """
- assert shape in ('polygon', 'polylines')
-
- if color is None:
- return
-
- points = numpy.asarray(points)
-
- # TODO Not very nice, but as is for now
- legend = '__SELECTION_AREA__' + name
-
- fill = fill != 'none' # TODO not very nice either
-
- greyed = colors.greyed(color)[0]
- if greyed < 0.5:
- color2 = "white"
- else:
- color2 = "black"
-
- self.plot.addShape(points[:, 0], points[:, 1], legend=legend,
- replace=False,
- shape=shape, fill=fill,
- color=color, linebgcolor=color2, linestyle="--",
- overlay=True)
-
- self._selectionAreas.add(legend)
-
- def resetSelectionArea(self):
- """Remove all selection areas set by setSelectionArea."""
- for legend in self._selectionAreas:
- self.plot.remove(legend, kind='item')
- self._selectionAreas = set()
-
-
-# Zoom/Pan ####################################################################
-
-class _ZoomOnWheel(ClickOrDrag, _PlotInteraction):
- """:class:`ClickOrDrag` state machine with zooming on mouse wheel.
-
- Base class for :class:`Pan` and :class:`Zoom`
- """
-
- _DOUBLE_CLICK_TIMEOUT = 0.4
-
- class Idle(ClickOrDrag.Idle):
- def onWheel(self, x, y, angle):
- scaleF = 1.1 if angle > 0 else 1. / 1.1
- applyZoomToPlot(self.machine.plot, scaleF, (x, y))
-
- def click(self, x, y, btn):
- """Handle clicks by sending events
-
- :param int x: Mouse X position in pixels
- :param int y: Mouse Y position in pixels
- :param btn: Clicked mouse button
- """
- if btn == LEFT_BTN:
- lastClickTime, lastClickPos = self._lastClick
-
- # Signal mouse double clicked event first
- if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT:
- # Use position of first click
- eventDict = prepareMouseSignal('mouseDoubleClicked', 'left',
- *lastClickPos)
- self.plot.notify(**eventDict)
-
- self._lastClick = 0., None
- else:
- # Signal mouse clicked event
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
- eventDict = prepareMouseSignal('mouseClicked', 'left',
- dataPos[0], dataPos[1],
- x, y)
- self.plot.notify(**eventDict)
-
- self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y)
-
- elif btn == RIGHT_BTN:
- # Signal mouse clicked event
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
- eventDict = prepareMouseSignal('mouseClicked', 'right',
- dataPos[0], dataPos[1],
- x, y)
- self.plot.notify(**eventDict)
-
- def __init__(self, plot, **kwargs):
- """Init.
-
- :param plot: The plot to apply modifications to.
- """
- self._lastClick = 0., None
-
- _PlotInteraction.__init__(self, plot)
- ClickOrDrag.__init__(self, **kwargs)
-
-
-# Pan #########################################################################
-
-class Pan(_ZoomOnWheel):
- """Pan plot content and zoom on wheel state machine."""
-
- def _pixelToData(self, x, y):
- xData, yData = self.plot.pixelToData(x, y)
- _, y2Data = self.plot.pixelToData(x, y, axis='right')
- return xData, yData, y2Data
-
- def beginDrag(self, x, y, btn):
- self._previousDataPos = self._pixelToData(x, y)
-
- def drag(self, x, y, btn):
- xData, yData, y2Data = self._pixelToData(x, y)
- lastX, lastY, lastY2 = self._previousDataPos
-
- xMin, xMax = self.plot.getXAxis().getLimits()
- yMin, yMax = self.plot.getYAxis().getLimits()
- y2Min, y2Max = self.plot.getYAxis(axis='right').getLimits()
-
- if self.plot.getXAxis()._isLogarithmic():
- try:
- dx = math.log10(xData) - math.log10(lastX)
- newXMin = pow(10., (math.log10(xMin) - dx))
- newXMax = pow(10., (math.log10(xMax) - dx))
- except (ValueError, OverflowError):
- newXMin, newXMax = xMin, xMax
-
- # Makes sure both values stays in positive float32 range
- if newXMin < FLOAT32_MINPOS or newXMax > FLOAT32_SAFE_MAX:
- newXMin, newXMax = xMin, xMax
- else:
- dx = xData - lastX
- newXMin, newXMax = xMin - dx, xMax - dx
-
- # Makes sure both values stays in float32 range
- if newXMin < FLOAT32_SAFE_MIN or newXMax > FLOAT32_SAFE_MAX:
- newXMin, newXMax = xMin, xMax
-
- if self.plot.getYAxis()._isLogarithmic():
- try:
- dy = math.log10(yData) - math.log10(lastY)
- newYMin = pow(10., math.log10(yMin) - dy)
- newYMax = pow(10., math.log10(yMax) - dy)
-
- dy2 = math.log10(y2Data) - math.log10(lastY2)
- newY2Min = pow(10., math.log10(y2Min) - dy2)
- newY2Max = pow(10., math.log10(y2Max) - dy2)
- except (ValueError, OverflowError):
- newYMin, newYMax = yMin, yMax
- newY2Min, newY2Max = y2Min, y2Max
-
- # Makes sure y and y2 stays in positive float32 range
- if (newYMin < FLOAT32_MINPOS or newYMax > FLOAT32_SAFE_MAX or
- newY2Min < FLOAT32_MINPOS or newY2Max > FLOAT32_SAFE_MAX):
- newYMin, newYMax = yMin, yMax
- newY2Min, newY2Max = y2Min, y2Max
- else:
- dy = yData - lastY
- dy2 = y2Data - lastY2
- newYMin, newYMax = yMin - dy, yMax - dy
- newY2Min, newY2Max = y2Min - dy2, y2Max - dy2
-
- # Makes sure y and y2 stays in float32 range
- if (newYMin < FLOAT32_SAFE_MIN or
- newYMax > FLOAT32_SAFE_MAX or
- newY2Min < FLOAT32_SAFE_MIN or
- newY2Max > FLOAT32_SAFE_MAX):
- newYMin, newYMax = yMin, yMax
- newY2Min, newY2Max = y2Min, y2Max
-
- self.plot.setLimits(newXMin, newXMax,
- newYMin, newYMax,
- newY2Min, newY2Max)
-
- self._previousDataPos = self._pixelToData(x, y)
-
- def endDrag(self, startPos, endPos, btn):
- del self._previousDataPos
-
- def cancel(self):
- pass
-
-
-# Zoom ########################################################################
-
-class Zoom(_ZoomOnWheel):
- """Zoom-in/out state machine.
-
- Zoom-in on selected area, zoom-out on right click,
- and zoom on mouse wheel.
- """
-
- SURFACE_THRESHOLD = 5
-
- def __init__(self, plot, color):
- self.color = color
-
- super(Zoom, self).__init__(plot)
- self.plot.getLimitsHistory().clear()
-
- def _areaWithAspectRatio(self, x0, y0, x1, y1):
- _plotLeft, _plotTop, plotW, plotH = self.plot.getPlotBoundsInPixels()
-
- areaX0, areaY0, areaX1, areaY1 = x0, y0, x1, y1
-
- if plotH != 0.:
- plotRatio = plotW / float(plotH)
- width, height = math.fabs(x1 - x0), math.fabs(y1 - y0)
-
- if height != 0. and width != 0.:
- if width / height > plotRatio:
- areaHeight = width / plotRatio
- areaX0, areaX1 = x0, x1
- center = 0.5 * (y0 + y1)
- areaY0 = center - numpy.sign(y1 - y0) * 0.5 * areaHeight
- areaY1 = center + numpy.sign(y1 - y0) * 0.5 * areaHeight
- else:
- areaWidth = height * plotRatio
- areaY0, areaY1 = y0, y1
- center = 0.5 * (x0 + x1)
- areaX0 = center - numpy.sign(x1 - x0) * 0.5 * areaWidth
- areaX1 = center + numpy.sign(x1 - x0) * 0.5 * areaWidth
-
- return areaX0, areaY0, areaX1, areaY1
-
- def beginDrag(self, x, y, btn):
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
- self.x0, self.y0 = x, y
-
- def drag(self, x1, y1, btn):
- if self.color is None:
- return # Do not draw zoom area
-
- dataPos = self.plot.pixelToData(x1, y1)
- assert dataPos is not None
-
- if self.plot.isKeepDataAspectRatio():
- area = self._areaWithAspectRatio(self.x0, self.y0, x1, y1)
- areaX0, areaY0, areaX1, areaY1 = area
- areaPoints = ((areaX0, areaY0),
- (areaX1, areaY0),
- (areaX1, areaY1),
- (areaX0, areaY1))
- areaPoints = numpy.array([self.plot.pixelToData(
- x, y, check=False) for (x, y) in areaPoints])
-
- if self.color != 'video inverted':
- areaColor = list(self.color)
- areaColor[3] *= 0.25
- else:
- areaColor = [1., 1., 1., 1.]
-
- self.setSelectionArea(areaPoints,
- fill='none',
- color=areaColor,
- name="zoomedArea")
-
- corners = ((self.x0, self.y0),
- (self.x0, y1),
- (x1, y1),
- (x1, self.y0))
- corners = numpy.array([self.plot.pixelToData(x, y, check=False)
- for (x, y) in corners])
-
- self.setSelectionArea(corners, fill='none', color=self.color)
-
- def _zoom(self, x0, y0, x1, y1):
- """Zoom to the rectangle view x0,y0 x1,y1.
- """
- startPos = x0, y0
- endPos = x1, y1
-
- # Store current zoom state in stack
- self.plot.getLimitsHistory().push()
-
- if self.plot.isKeepDataAspectRatio():
- x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1)
-
- # Convert to data space and set limits
- x0, y0 = self.plot.pixelToData(x0, y0, check=False)
-
- dataPos = self.plot.pixelToData(
- startPos[0], startPos[1], axis="right", check=False)
- y2_0 = dataPos[1]
-
- x1, y1 = self.plot.pixelToData(x1, y1, check=False)
-
- dataPos = self.plot.pixelToData(
- endPos[0], endPos[1], axis="right", check=False)
- y2_1 = dataPos[1]
-
- xMin, xMax = min(x0, x1), max(x0, x1)
- yMin, yMax = min(y0, y1), max(y0, y1)
- y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1)
-
- self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
-
- def endDrag(self, startPos, endPos, btn):
- x0, y0 = startPos
- x1, y1 = endPos
-
- if abs(x0 - x1) * abs(y0 - y1) >= self.SURFACE_THRESHOLD:
- # Avoid empty zoom area
- self._zoom(x0, y0, x1, y1)
-
- self.resetSelectionArea()
-
- def cancel(self):
- if isinstance(self.state, self.states['drag']):
- self.resetSelectionArea()
-
-
-# Select ######################################################################
-
-class Select(StateMachine, _PlotInteraction):
- """Base class for drawing selection areas."""
-
- def __init__(self, plot, parameters, states, state):
- """Init a state machine.
-
- :param plot: The plot to apply changes to.
- :param dict parameters: A dict of parameters such as color.
- :param dict states: The states of the state machine.
- :param str state: The name of the initial state.
- """
- _PlotInteraction.__init__(self, plot)
- self.parameters = parameters
- StateMachine.__init__(self, states, state)
-
- def onWheel(self, x, y, angle):
- scaleF = 1.1 if angle > 0 else 1. / 1.1
- applyZoomToPlot(self.plot, scaleF, (x, y))
-
- @property
- def color(self):
- return self.parameters.get('color', None)
-
-
-class SelectPolygon(Select):
- """Drawing selection polygon area state machine."""
-
- DRAG_THRESHOLD_DIST = 4
-
- class Idle(State):
- def onPress(self, x, y, btn):
- if btn == LEFT_BTN:
- self.goto('select', x, y)
- return True
-
- class Select(State):
- def enterState(self, x, y):
- dataPos = self.machine.plot.pixelToData(x, y)
- assert dataPos is not None
- self._firstPos = dataPos
- self.points = [dataPos, dataPos]
-
- self.updateFirstPoint()
-
- def updateFirstPoint(self):
- """Update drawing first point, using self._firstPos"""
- x, y = self.machine.plot.dataToPixel(*self._firstPos, check=False)
-
- offset = self.machine.getDragThreshold()
- points = [(x - offset, y - offset),
- (x - offset, y + offset),
- (x + offset, y + offset),
- (x + offset, y - offset)]
- points = [self.machine.plot.pixelToData(xpix, ypix, check=False)
- for xpix, ypix in points]
- self.machine.setSelectionArea(points, fill=None,
- color=self.machine.color,
- name='first_point')
-
- def updateSelectionArea(self):
- """Update drawing selection area using self.points"""
- self.machine.setSelectionArea(self.points,
- fill='hatch',
- color=self.machine.color)
- eventDict = prepareDrawingSignal('drawingProgress',
- 'polygon',
- self.points,
- self.machine.parameters)
- self.machine.plot.notify(**eventDict)
-
- def validate(self):
- if len(self.points) > 2:
- self.closePolygon()
- else:
- # It would be nice to have a cancel event.
- # The plot is not aware that the interaction was cancelled
- self.machine.cancel()
-
- def closePolygon(self):
- self.machine.resetSelectionArea()
- self.points[-1] = self.points[0]
- eventDict = prepareDrawingSignal('drawingFinished',
- 'polygon',
- self.points,
- self.machine.parameters)
- self.machine.plot.notify(**eventDict)
- self.goto('idle')
-
- def onWheel(self, x, y, angle):
- self.machine.onWheel(x, y, angle)
- self.updateFirstPoint()
-
- def onRelease(self, x, y, btn):
- if btn == LEFT_BTN:
- # checking if the position is close to the first point
- # if yes : closing the "loop"
- firstPos = self.machine.plot.dataToPixel(*self._firstPos,
- check=False)
- dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
-
- threshold = self.machine.getDragThreshold()
-
- # Only allow to close polygon after first point
- if len(self.points) > 2 and dx <= threshold and dy <= threshold:
- self.closePolygon()
- return False
-
- # Update polygon last point not too close to previous one
- dataPos = self.machine.plot.pixelToData(x, y)
- assert dataPos is not None
- self.updateSelectionArea()
-
- # checking that the new points isnt the same (within range)
- # of the previous one
- # This has to be done because sometimes the mouse release event
- # is caught right after entering the Select state (i.e : press
- # in Idle state, but with a slightly different position that
- # the mouse press. So we had the two first vertices that were
- # almost identical.
- previousPos = self.machine.plot.dataToPixel(*self.points[-2],
- check=False)
- dx, dy = abs(previousPos[0] - x), abs(previousPos[1] - y)
- if dx >= threshold or dy >= threshold:
- self.points.append(dataPos)
- else:
- self.points[-1] = dataPos
-
- return True
- return False
-
- def onMove(self, x, y):
- firstPos = self.machine.plot.dataToPixel(*self._firstPos,
- check=False)
- dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
- threshold = self.machine.getDragThreshold()
-
- if dx <= threshold and dy <= threshold:
- x, y = firstPos # Snap to first point
-
- dataPos = self.machine.plot.pixelToData(x, y)
- assert dataPos is not None
- self.points[-1] = dataPos
- self.updateSelectionArea()
-
- def __init__(self, plot, parameters):
- states = {
- 'idle': SelectPolygon.Idle,
- 'select': SelectPolygon.Select
- }
- super(SelectPolygon, self).__init__(plot, parameters,
- states, 'idle')
-
- def cancel(self):
- if isinstance(self.state, self.states['select']):
- self.resetSelectionArea()
-
- def getDragThreshold(self):
- """Return dragging ratio with device to pixel ratio applied.
-
- :rtype: float
- """
- ratio = 1.
- if qt.BINDING in ('PyQt5', 'PySide2'):
- ratio = self.plot.window().windowHandle().devicePixelRatio()
- return self.DRAG_THRESHOLD_DIST * ratio
-
-
-class Select2Points(Select):
- """Base class for drawing selection based on 2 input points."""
- class Idle(State):
- def onPress(self, x, y, btn):
- if btn == LEFT_BTN:
- self.goto('start', x, y)
- return True
-
- class Start(State):
- def enterState(self, x, y):
- self.machine.beginSelect(x, y)
-
- def onMove(self, x, y):
- self.goto('select', x, y)
-
- def onRelease(self, x, y, btn):
- if btn == LEFT_BTN:
- self.goto('select', x, y)
- return True
-
- class Select(State):
- def enterState(self, x, y):
- self.onMove(x, y)
-
- def onMove(self, x, y):
- self.machine.select(x, y)
-
- def onRelease(self, x, y, btn):
- if btn == LEFT_BTN:
- self.machine.endSelect(x, y)
- self.goto('idle')
-
- def __init__(self, plot, parameters):
- states = {
- 'idle': Select2Points.Idle,
- 'start': Select2Points.Start,
- 'select': Select2Points.Select
- }
- super(Select2Points, self).__init__(plot, parameters,
- states, 'idle')
-
- def beginSelect(self, x, y):
- pass
-
- def select(self, x, y):
- pass
-
- def endSelect(self, x, y):
- pass
-
- def cancelSelect(self):
- pass
-
- def cancel(self):
- if isinstance(self.state, self.states['select']):
- self.cancelSelect()
-
-
-class SelectEllipse(Select2Points):
- """Drawing ellipse selection area state machine."""
- def beginSelect(self, x, y):
- self.center = self.plot.pixelToData(x, y)
- assert self.center is not None
-
- def _getEllipseSize(self, pointInEllipse):
- """
- Returns the size from the center to the bounding box of the ellipse.
-
- :param Tuple[float,float] pointInEllipse: A point of the ellipse
- :rtype: Tuple[float,float]
- """
- x = abs(self.center[0] - pointInEllipse[0])
- y = abs(self.center[1] - pointInEllipse[1])
- if x == 0 or y == 0:
- return x, y
- # Ellipse definitions
- # e: eccentricity
- # a: length fron center to bounding box width
- # b: length fron center to bounding box height
- # Equations
- # (1) b < a
- # (2) For x,y a point in the ellipse: x^2/a^2 + y^2/b^2 = 1
- # (3) b = a * sqrt(1-e^2)
- # (4) e = sqrt(a^2 - b^2) / a
-
- # The eccentricity of the ellipse defined by a,b=x,y is the same
- # as the one we are searching for.
- swap = x < y
- if swap:
- x, y = y, x
- e = math.sqrt(x**2 - y**2) / x
- # From (2) using (3) to replace b
- # a^2 = x^2 + y^2 / (1-e^2)
- a = math.sqrt(x**2 + y**2 / (1.0 - e**2))
- b = a * math.sqrt(1 - e**2)
- if swap:
- a, b = b, a
- return a, b
-
- def select(self, x, y):
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
- width, height = self._getEllipseSize(dataPos)
-
- # Circle used for circle preview
- nbpoints = 27.
- angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints
- circleShape = numpy.array((numpy.cos(angles) * width,
- numpy.sin(angles) * height)).T
- circleShape += numpy.array(self.center)
-
- self.setSelectionArea(circleShape,
- shape="polygon",
- fill='hatch',
- color=self.color)
-
- eventDict = prepareDrawingSignal('drawingProgress',
- 'ellipse',
- (self.center, (width, height)),
- self.parameters)
- self.plot.notify(**eventDict)
-
- def endSelect(self, x, y):
- self.resetSelectionArea()
-
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
- width, height = self._getEllipseSize(dataPos)
-
- eventDict = prepareDrawingSignal('drawingFinished',
- 'ellipse',
- (self.center, (width, height)),
- self.parameters)
- self.plot.notify(**eventDict)
-
- def cancelSelect(self):
- self.resetSelectionArea()
-
-
-class SelectRectangle(Select2Points):
- """Drawing rectangle selection area state machine."""
- def beginSelect(self, x, y):
- self.startPt = self.plot.pixelToData(x, y)
- assert self.startPt is not None
-
- def select(self, x, y):
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
-
- self.setSelectionArea((self.startPt,
- (self.startPt[0], dataPos[1]),
- dataPos,
- (dataPos[0], self.startPt[1])),
- fill='hatch',
- color=self.color)
-
- eventDict = prepareDrawingSignal('drawingProgress',
- 'rectangle',
- (self.startPt, dataPos),
- self.parameters)
- self.plot.notify(**eventDict)
-
- def endSelect(self, x, y):
- self.resetSelectionArea()
-
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
-
- eventDict = prepareDrawingSignal('drawingFinished',
- 'rectangle',
- (self.startPt, dataPos),
- self.parameters)
- self.plot.notify(**eventDict)
-
- def cancelSelect(self):
- self.resetSelectionArea()
-
-
-class SelectLine(Select2Points):
- """Drawing line selection area state machine."""
- def beginSelect(self, x, y):
- self.startPt = self.plot.pixelToData(x, y)
- assert self.startPt is not None
-
- def select(self, x, y):
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
-
- self.setSelectionArea((self.startPt, dataPos),
- fill='hatch',
- color=self.color)
-
- eventDict = prepareDrawingSignal('drawingProgress',
- 'line',
- (self.startPt, dataPos),
- self.parameters)
- self.plot.notify(**eventDict)
-
- def endSelect(self, x, y):
- self.resetSelectionArea()
-
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
-
- eventDict = prepareDrawingSignal('drawingFinished',
- 'line',
- (self.startPt, dataPos),
- self.parameters)
- self.plot.notify(**eventDict)
-
- def cancelSelect(self):
- self.resetSelectionArea()
-
-
-class Select1Point(Select):
- """Base class for drawing selection area based on one input point."""
- class Idle(State):
- def onPress(self, x, y, btn):
- if btn == LEFT_BTN:
- self.goto('select', x, y)
- return True
-
- class Select(State):
- def enterState(self, x, y):
- self.onMove(x, y)
-
- def onMove(self, x, y):
- self.machine.select(x, y)
-
- def onRelease(self, x, y, btn):
- if btn == LEFT_BTN:
- self.machine.endSelect(x, y)
- self.goto('idle')
-
- def onWheel(self, x, y, angle):
- self.machine.onWheel(x, y, angle) # Call select default wheel
- self.machine.select(x, y)
-
- def __init__(self, plot, parameters):
- states = {
- 'idle': Select1Point.Idle,
- 'select': Select1Point.Select
- }
- super(Select1Point, self).__init__(plot, parameters, states, 'idle')
-
- def select(self, x, y):
- pass
-
- def endSelect(self, x, y):
- pass
-
- def cancelSelect(self):
- pass
-
- def cancel(self):
- if isinstance(self.state, self.states['select']):
- self.cancelSelect()
-
-
-class SelectHLine(Select1Point):
- """Drawing a horizontal line selection area state machine."""
- def _hLine(self, y):
- """Return points in data coords of the segment visible in the plot.
-
- Supports non-orthogonal axes.
- """
- left, _top, width, _height = self.plot.getPlotBoundsInPixels()
-
- dataPos1 = self.plot.pixelToData(left, y, check=False)
- dataPos2 = self.plot.pixelToData(left + width, y, check=False)
- return dataPos1, dataPos2
-
- def select(self, x, y):
- points = self._hLine(y)
- self.setSelectionArea(points, fill='hatch', color=self.color)
-
- eventDict = prepareDrawingSignal('drawingProgress',
- 'hline',
- points,
- self.parameters)
- self.plot.notify(**eventDict)
-
- def endSelect(self, x, y):
- self.resetSelectionArea()
-
- eventDict = prepareDrawingSignal('drawingFinished',
- 'hline',
- self._hLine(y),
- self.parameters)
- self.plot.notify(**eventDict)
-
- def cancelSelect(self):
- self.resetSelectionArea()
-
-
-class SelectVLine(Select1Point):
- """Drawing a vertical line selection area state machine."""
- def _vLine(self, x):
- """Return points in data coords of the segment visible in the plot.
-
- Supports non-orthogonal axes.
- """
- _left, top, _width, height = self.plot.getPlotBoundsInPixels()
-
- dataPos1 = self.plot.pixelToData(x, top, check=False)
- dataPos2 = self.plot.pixelToData(x, top + height, check=False)
- return dataPos1, dataPos2
-
- def select(self, x, y):
- points = self._vLine(x)
- self.setSelectionArea(points, fill='hatch', color=self.color)
-
- eventDict = prepareDrawingSignal('drawingProgress',
- 'vline',
- points,
- self.parameters)
- self.plot.notify(**eventDict)
-
- def endSelect(self, x, y):
- self.resetSelectionArea()
-
- eventDict = prepareDrawingSignal('drawingFinished',
- 'vline',
- self._vLine(x),
- self.parameters)
- self.plot.notify(**eventDict)
-
- def cancelSelect(self):
- self.resetSelectionArea()
-
-
-class DrawFreeHand(Select):
- """Interaction for drawing pencil. It display the preview of the pencil
- before pressing the mouse.
- """
-
- class Idle(State):
- def onPress(self, x, y, btn):
- if btn == LEFT_BTN:
- self.goto('select', x, y)
- return True
-
- def onMove(self, x, y):
- self.machine.updatePencilShape(x, y)
-
- def onLeave(self):
- self.machine.cancel()
-
- class Select(State):
- def enterState(self, x, y):
- self.__isOut = False
- self.machine.setFirstPoint(x, y)
-
- def onMove(self, x, y):
- self.machine.updatePencilShape(x, y)
- self.machine.select(x, y)
-
- def onRelease(self, x, y, btn):
- if btn == LEFT_BTN:
- if self.__isOut:
- self.machine.resetSelectionArea()
- self.machine.endSelect(x, y)
- self.goto('idle')
-
- def onEnter(self):
- self.__isOut = False
-
- def onLeave(self):
- self.__isOut = True
-
- def __init__(self, plot, parameters):
- # Circle used for pencil preview
- angle = numpy.arange(13.) * numpy.pi * 2.0 / 13.
- size = parameters.get('width', 1.) * 0.5
- self._circle = size * numpy.array((numpy.cos(angle),
- numpy.sin(angle))).T
-
- states = {
- 'idle': DrawFreeHand.Idle,
- 'select': DrawFreeHand.Select
- }
- super(DrawFreeHand, self).__init__(plot, parameters, states, 'idle')
-
- @property
- def width(self):
- return self.parameters.get('width', None)
-
- def setFirstPoint(self, x, y):
- self._points = []
- self.select(x, y)
-
- def updatePencilShape(self, x, y):
- center = self.plot.pixelToData(x, y, check=False)
- assert center is not None
-
- polygon = center + self._circle
-
- self.setSelectionArea(polygon, fill='none', color=self.color)
-
- def select(self, x, y):
- pos = self.plot.pixelToData(x, y, check=False)
- if len(self._points) > 0:
- if self._points[-1] == pos:
- # Skip same points
- return
- self._points.append(pos)
- eventDict = prepareDrawingSignal('drawingProgress',
- 'polylines',
- self._points,
- self.parameters)
- self.plot.notify(**eventDict)
-
- def endSelect(self, x, y):
- pos = self.plot.pixelToData(x, y, check=False)
- if len(self._points) > 0:
- if self._points[-1] != pos:
- # Append if different
- self._points.append(pos)
-
- eventDict = prepareDrawingSignal('drawingFinished',
- 'polylines',
- self._points,
- self.parameters)
- self.plot.notify(**eventDict)
- self._points = None
-
- def cancelSelect(self):
- self.resetSelectionArea()
-
- def cancel(self):
- self.resetSelectionArea()
-
-
-class SelectFreeLine(ClickOrDrag, _PlotInteraction):
- """Base class for drawing free lines with tools such as pencil."""
-
- def __init__(self, plot, parameters):
- """Init a state machine.
-
- :param plot: The plot to apply changes to.
- :param dict parameters: A dict of parameters such as color.
- """
- # self.DRAG_THRESHOLD_SQUARE_DIST = 1 # Disable first move threshold
- self._points = []
- ClickOrDrag.__init__(self)
- _PlotInteraction.__init__(self, plot)
- self.parameters = parameters
-
- def onWheel(self, x, y, angle):
- scaleF = 1.1 if angle > 0 else 1. / 1.1
- applyZoomToPlot(self.plot, scaleF, (x, y))
-
- @property
- def color(self):
- return self.parameters.get('color', None)
-
- def click(self, x, y, btn):
- if btn == LEFT_BTN:
- self._processEvent(x, y, isLast=True)
-
- def beginDrag(self, x, y, btn):
- self._processEvent(x, y, isLast=False)
-
- def drag(self, x, y, btn):
- self._processEvent(x, y, isLast=False)
-
- def endDrag(self, startPos, endPos, btn):
- x, y = endPos
- self._processEvent(x, y, isLast=True)
-
- def cancel(self):
- self.resetSelectionArea()
- self._points = []
-
- def _processEvent(self, x, y, isLast):
- dataPos = self.plot.pixelToData(x, y, check=False)
- isNewPoint = not self._points or dataPos != self._points[-1]
-
- if isNewPoint:
- self._points.append(dataPos)
-
- if isNewPoint or isLast:
- eventDict = prepareDrawingSignal(
- 'drawingFinished' if isLast else 'drawingProgress',
- 'polylines',
- self._points,
- self.parameters)
- self.plot.notify(**eventDict)
-
- if not isLast:
- self.setSelectionArea(self._points, fill='none', color=self.color,
- shape='polylines')
- else:
- self.cancel()
-
-
-# ItemInteraction #############################################################
-
-class ItemsInteraction(ClickOrDrag, _PlotInteraction):
- """Interaction with items (markers, curves and images).
-
- This class provides selection and dragging of plot primitives
- that support those interaction.
- It is also meant to be combined with the zoom interaction.
- """
-
- class Idle(ClickOrDrag.Idle):
- def __init__(self, *args, **kw):
- super(ItemsInteraction.Idle, self).__init__(*args, **kw)
- self._hoverMarker = None
-
- def onWheel(self, x, y, angle):
- scaleF = 1.1 if angle > 0 else 1. / 1.1
- applyZoomToPlot(self.machine.plot, scaleF, (x, y))
-
- def onMove(self, x, y):
- marker = self.machine.plot._getMarkerAt(x, y)
-
- if marker is not None:
- dataPos = self.machine.plot.pixelToData(x, y)
- assert dataPos is not None
- eventDict = prepareHoverSignal(
- marker.getName(), 'marker',
- dataPos, (x, y),
- marker.isDraggable(),
- marker.isSelectable())
- self.machine.plot.notify(**eventDict)
-
- if marker != self._hoverMarker:
- self._hoverMarker = marker
-
- if marker is None:
- self.machine.plot.setGraphCursorShape()
-
- elif marker.isDraggable():
- if isinstance(marker, items.YMarker):
- self.machine.plot.setGraphCursorShape(CURSOR_SIZE_VER)
- elif isinstance(marker, items.XMarker):
- self.machine.plot.setGraphCursorShape(CURSOR_SIZE_HOR)
- else:
- self.machine.plot.setGraphCursorShape(CURSOR_SIZE_ALL)
-
- elif marker.isSelectable():
- self.machine.plot.setGraphCursorShape(CURSOR_POINTING)
- else:
- self.machine.plot.setGraphCursorShape()
-
- return True
-
- def __init__(self, plot):
- self._pan = Pan(plot)
-
- _PlotInteraction.__init__(self, plot)
- ClickOrDrag.__init__(self,
- clickButtons=(LEFT_BTN, RIGHT_BTN),
- dragButtons=(LEFT_BTN, MIDDLE_BTN))
-
- def click(self, x, y, btn):
- """Handle mouse click
-
- :param x: X position of the mouse in pixels
- :param y: Y position of the mouse in pixels
- :param btn: Pressed button id
- :return: True if click is catched by an item, False otherwise
- """
- # Signal mouse clicked event
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
- eventDict = prepareMouseSignal('mouseClicked', btn,
- dataPos[0], dataPos[1],
- x, y)
- self.plot.notify(**eventDict)
-
- eventDict = self._handleClick(x, y, btn)
- if eventDict is not None:
- self.plot.notify(**eventDict)
-
- def _handleClick(self, x, y, btn):
- """Perform picking and prepare event if click is handled here
-
- :param x: X position of the mouse in pixels
- :param y: Y position of the mouse in pixels
- :param btn: Pressed button id
- :return: event description to send of None if not handling event.
- :rtype: dict or None
- """
-
- if btn == LEFT_BTN:
- result = self.plot._pickTopMost(x, y, lambda i: i.isSelectable())
- if result is None:
- return None
-
- item = result.getItem()
-
- if isinstance(item, items.MarkerBase):
- xData, yData = item.getPosition()
- if xData is None:
- xData = [0, 1]
- if yData is None:
- yData = [0, 1]
-
- eventDict = prepareMarkerSignal('markerClicked',
- 'left',
- item.getName(),
- 'marker',
- item.isDraggable(),
- item.isSelectable(),
- (xData, yData),
- (x, y), None)
- return eventDict
-
- elif isinstance(item, items.Curve):
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
-
- xData = item.getXData(copy=False)
- yData = item.getYData(copy=False)
-
- indices = result.getIndices(copy=False)
- eventDict = prepareCurveSignal('left',
- item.getName(),
- 'curve',
- xData[indices],
- yData[indices],
- dataPos[0], dataPos[1],
- x, y)
- return eventDict
-
- elif isinstance(item, items.ImageBase):
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
-
- indices = result.getIndices(copy=False)
- row, column = indices[0][0], indices[1][0]
- eventDict = prepareImageSignal('left',
- item.getName(),
- 'image',
- column, row,
- dataPos[0], dataPos[1],
- x, y)
- return eventDict
-
- return None
-
- def _signalMarkerMovingEvent(self, eventType, marker, x, y):
- assert marker is not None
-
- xData, yData = marker.getPosition()
- if xData is None:
- xData = [0, 1]
- if yData is None:
- yData = [0, 1]
-
- posDataCursor = self.plot.pixelToData(x, y)
- assert posDataCursor is not None
-
- eventDict = prepareMarkerSignal(eventType,
- 'left',
- marker.getName(),
- 'marker',
- marker.isDraggable(),
- marker.isSelectable(),
- (xData, yData),
- (x, y),
- posDataCursor)
- self.plot.notify(**eventDict)
-
- @staticmethod
- def __isDraggableItem(item):
- return isinstance(item, items.DraggableMixIn) and item.isDraggable()
-
- def __terminateDrag(self):
- """Finalize a drag operation by reseting to initial state"""
- self.plot.setGraphCursorShape()
- self.draggedItemRef = None
-
- def beginDrag(self, x, y, btn):
- """Handle begining of drag interaction
-
- :param x: X position of the mouse in pixels
- :param y: Y position of the mouse in pixels
- :param str btn: The mouse button for which a drag is starting.
- :return: True if drag is catched by an item, False otherwise
- """
- if btn == LEFT_BTN:
- self._lastPos = self.plot.pixelToData(x, y)
- assert self._lastPos is not None
-
- result = self.plot._pickTopMost(x, y, self.__isDraggableItem)
- item = result.getItem() if result is not None else None
-
- self.draggedItemRef = None if item is None else weakref.ref(item)
-
- if item is None:
- self.__terminateDrag()
- return False
-
- if isinstance(item, items.MarkerBase):
- self._signalMarkerMovingEvent('markerMoving', item, x, y)
- item._startDrag()
-
- return True
- elif btn == MIDDLE_BTN:
- self._pan.beginDrag(x, y, btn)
- return True
-
- def drag(self, x, y, btn):
- if btn == LEFT_BTN:
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
-
- item = None if self.draggedItemRef is None else self.draggedItemRef()
- if item is not None:
- item.drag(self._lastPos, dataPos)
-
- if isinstance(item, items.MarkerBase):
- self._signalMarkerMovingEvent('markerMoving', item, x, y)
-
- self._lastPos = dataPos
- elif btn == MIDDLE_BTN:
- self._pan.drag(x, y, btn)
-
- def endDrag(self, startPos, endPos, btn):
- if btn == LEFT_BTN:
- item = None if self.draggedItemRef is None else self.draggedItemRef()
- if isinstance(item, items.MarkerBase):
- posData = list(item.getPosition())
- if posData[0] is None:
- posData[0] = 1.
- if posData[1] is None:
- posData[1] = 1.
-
- eventDict = prepareMarkerSignal(
- 'markerMoved',
- 'left',
- item.getLegend(),
- 'marker',
- item.isDraggable(),
- item.isSelectable(),
- posData)
- self.plot.notify(**eventDict)
- item._endDrag()
-
- self.__terminateDrag()
- elif btn == MIDDLE_BTN:
- self._pan.endDrag(startPos, endPos, btn)
-
- def cancel(self):
- self._pan.cancel()
- self.__terminateDrag()
-
-
-class ItemsInteractionForCombo(ItemsInteraction):
- """Interaction with items to combine through :class:`FocusManager`.
- """
-
- class Idle(ItemsInteraction.Idle):
- @staticmethod
- def __isItemSelectableOrDraggable(item):
- return (item.isSelectable() or (
- isinstance(item, items.DraggableMixIn) and item.isDraggable()))
-
- def onPress(self, x, y, btn):
- if btn == LEFT_BTN:
- result = self.machine.plot._pickTopMost(
- x, y, self.__isItemSelectableOrDraggable)
- if result is not None: # Request focus and handle interaction
- self.goto('clickOrDrag', x, y, btn)
- return True
- else: # Do not request focus
- return False
- else:
- return super().onPress(x, y, btn)
-
-
-# FocusManager ################################################################
-
-class FocusManager(StateMachine):
- """Manages focus across multiple event handlers
-
- On press an event handler can acquire focus.
- By default it looses focus when all buttons are released.
- """
- class Idle(State):
- def onPress(self, x, y, btn):
- if btn == LEFT_BTN:
- for eventHandler in self.machine.eventHandlers:
- requestFocus = eventHandler.handleEvent('press', x, y, btn)
- if requestFocus:
- self.goto('focus', eventHandler, btn)
- break
-
- def _processEvent(self, *args):
- for eventHandler in self.machine.eventHandlers:
- consumeEvent = eventHandler.handleEvent(*args)
- if consumeEvent:
- break
-
- def onMove(self, x, y):
- self._processEvent('move', x, y)
-
- def onRelease(self, x, y, btn):
- if btn == LEFT_BTN:
- self._processEvent('release', x, y, btn)
-
- def onWheel(self, x, y, angle):
- self._processEvent('wheel', x, y, angle)
-
- class Focus(State):
- def enterState(self, eventHandler, btn):
- self.eventHandler = eventHandler
- self.focusBtns = {btn}
-
- def validate(self):
- self.eventHandler.validate()
- self.goto('idle')
-
- def onPress(self, x, y, btn):
- if btn == LEFT_BTN:
- self.focusBtns.add(btn)
- self.eventHandler.handleEvent('press', x, y, btn)
-
- def onMove(self, x, y):
- self.eventHandler.handleEvent('move', x, y)
-
- def onRelease(self, x, y, btn):
- if btn == LEFT_BTN:
- self.focusBtns.discard(btn)
- requestFocus = self.eventHandler.handleEvent('release', x, y, btn)
- if len(self.focusBtns) == 0 and not requestFocus:
- self.goto('idle')
-
- def onWheel(self, x, y, angleInDegrees):
- self.eventHandler.handleEvent('wheel', x, y, angleInDegrees)
-
- def __init__(self, eventHandlers=()):
- self.eventHandlers = list(eventHandlers)
-
- states = {
- 'idle': FocusManager.Idle,
- 'focus': FocusManager.Focus
- }
- super(FocusManager, self).__init__(states, 'idle')
-
- def cancel(self):
- for handler in self.eventHandlers:
- handler.cancel()
-
-
-class ZoomAndSelect(ItemsInteraction):
- """Combine Zoom and ItemInteraction state machine.
-
- :param plot: The Plot to which this interaction is attached
- :param color: The color to use for the zoom area bounding box
- """
-
- def __init__(self, plot, color):
- super(ZoomAndSelect, self).__init__(plot)
- self._zoom = Zoom(plot, color)
- self._doZoom = False
-
- @property
- def color(self):
- """Color of the zoom area"""
- return self._zoom.color
-
- def click(self, x, y, btn):
- """Handle mouse click
-
- :param x: X position of the mouse in pixels
- :param y: Y position of the mouse in pixels
- :param btn: Pressed button id
- :return: True if click is catched by an item, False otherwise
- """
- eventDict = self._handleClick(x, y, btn)
-
- if eventDict is not None:
- # Signal mouse clicked event
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
- clickedEventDict = prepareMouseSignal('mouseClicked', btn,
- dataPos[0], dataPos[1],
- x, y)
- self.plot.notify(**clickedEventDict)
-
- self.plot.notify(**eventDict)
-
- else:
- self._zoom.click(x, y, btn)
-
- def beginDrag(self, x, y, btn):
- """Handle start drag and switching between zoom and item drag.
-
- :param x: X position in pixels
- :param y: Y position in pixels
- :param str btn: The mouse button for which a drag is starting.
- """
- self._doZoom = not super(ZoomAndSelect, self).beginDrag(x, y, btn)
- if self._doZoom:
- self._zoom.beginDrag(x, y, btn)
-
- def drag(self, x, y, btn):
- """Handle drag, eventually forwarding to zoom.
-
- :param x: X position in pixels
- :param y: Y position in pixels
- :param str btn: The mouse button for which a drag is in progress.
- """
- if self._doZoom:
- return self._zoom.drag(x, y, btn)
- else:
- return super(ZoomAndSelect, self).drag(x, y, btn)
-
- def endDrag(self, startPos, endPos, btn):
- """Handle end of drag, eventually forwarding to zoom.
-
- :param startPos: (x, y) position at the beginning of the drag
- :param endPos: (x, y) position at the end of the drag
- :param str btn: The mouse button for which a drag is done.
- """
- if self._doZoom:
- return self._zoom.endDrag(startPos, endPos, btn)
- else:
- return super(ZoomAndSelect, self).endDrag(startPos, endPos, btn)
-
-
-class PanAndSelect(ItemsInteraction):
- """Combine Pan and ItemInteraction state machine.
-
- :param plot: The Plot to which this interaction is attached
- """
-
- def __init__(self, plot):
- super(PanAndSelect, self).__init__(plot)
- self._pan = Pan(plot)
- self._doPan = False
-
- def click(self, x, y, btn):
- """Handle mouse click
-
- :param x: X position of the mouse in pixels
- :param y: Y position of the mouse in pixels
- :param btn: Pressed button id
- :return: True if click is catched by an item, False otherwise
- """
- eventDict = self._handleClick(x, y, btn)
-
- if eventDict is not None:
- # Signal mouse clicked event
- dataPos = self.plot.pixelToData(x, y)
- assert dataPos is not None
- clickedEventDict = prepareMouseSignal('mouseClicked', btn,
- dataPos[0], dataPos[1],
- x, y)
- self.plot.notify(**clickedEventDict)
-
- self.plot.notify(**eventDict)
-
- else:
- self._pan.click(x, y, btn)
-
- def beginDrag(self, x, y, btn):
- """Handle start drag and switching between zoom and item drag.
-
- :param x: X position in pixels
- :param y: Y position in pixels
- :param str btn: The mouse button for which a drag is starting.
- """
- self._doPan = not super(PanAndSelect, self).beginDrag(x, y, btn)
- if self._doPan:
- self._pan.beginDrag(x, y, btn)
-
- def drag(self, x, y, btn):
- """Handle drag, eventually forwarding to zoom.
-
- :param x: X position in pixels
- :param y: Y position in pixels
- :param str btn: The mouse button for which a drag is in progress.
- """
- if self._doPan:
- return self._pan.drag(x, y, btn)
- else:
- return super(PanAndSelect, self).drag(x, y, btn)
-
- def endDrag(self, startPos, endPos, btn):
- """Handle end of drag, eventually forwarding to zoom.
-
- :param startPos: (x, y) position at the beginning of the drag
- :param endPos: (x, y) position at the end of the drag
- :param str btn: The mouse button for which a drag is done.
- """
- if self._doPan:
- return self._pan.endDrag(startPos, endPos, btn)
- else:
- return super(PanAndSelect, self).endDrag(startPos, endPos, btn)
-
-
-# Interaction mode control ####################################################
-
-# Mapping of draw modes: event handler
-_DRAW_MODES = {
- 'polygon': SelectPolygon,
- 'rectangle': SelectRectangle,
- 'ellipse': SelectEllipse,
- 'line': SelectLine,
- 'vline': SelectVLine,
- 'hline': SelectHLine,
- 'polylines': SelectFreeLine,
- 'pencil': DrawFreeHand,
- }
-
-
-class DrawMode(FocusManager):
- """Interactive mode for draw and select"""
-
- def __init__(self, plot, shape, label, color, width):
- eventHandlerClass = _DRAW_MODES[shape]
- parameters = {
- 'shape': shape,
- 'label': label,
- 'color': color,
- 'width': width,
- }
- super().__init__((
- Pan(plot, clickButtons=(), dragButtons=(MIDDLE_BTN,)),
- eventHandlerClass(plot, parameters)))
-
- def getDescription(self):
- """Returns the dict describing this interactive mode"""
- params = self.eventHandlers[1].parameters.copy()
- params['mode'] = 'draw'
- return params
-
-
-class DrawSelectMode(FocusManager):
- """Interactive mode for draw and select"""
-
- def __init__(self, plot, shape, label, color, width):
- eventHandlerClass = _DRAW_MODES[shape]
- self._pan = Pan(plot)
- self._panStart = None
- parameters = {
- 'shape': shape,
- 'label': label,
- 'color': color,
- 'width': width,
- }
- super().__init__((
- ItemsInteractionForCombo(plot),
- eventHandlerClass(plot, parameters)))
-
- def handleEvent(self, eventName, *args, **kwargs):
- # Hack to add pan interaction to select-draw
- # See issue Refactor PlotWidget interaction #3292
- if eventName == 'press' and args[2] == MIDDLE_BTN:
- self._panStart = args[:2]
- self._pan.beginDrag(*args)
- return # Consume middle click events
- elif eventName == 'release' and args[2] == MIDDLE_BTN:
- self._panStart = None
- self._pan.endDrag(self._panStart, args[:2], MIDDLE_BTN)
- return # Consume middle click events
- elif self._panStart is not None and eventName == 'move':
- x, y = args[:2]
- self._pan.drag(x, y, MIDDLE_BTN)
-
- super().handleEvent(eventName, *args, **kwargs)
-
- def getDescription(self):
- """Returns the dict describing this interactive mode"""
- params = self.eventHandlers[1].parameters.copy()
- params['mode'] = 'select-draw'
- return params
-
-
-class PlotInteraction(object):
- """Proxy to currently use state machine for interaction.
-
- This allows to switch interactive mode.
-
- :param plot: The :class:`Plot` to apply interaction to
- """
-
- _DRAW_MODES = {
- 'polygon': SelectPolygon,
- 'rectangle': SelectRectangle,
- 'ellipse': SelectEllipse,
- 'line': SelectLine,
- 'vline': SelectVLine,
- 'hline': SelectHLine,
- 'polylines': SelectFreeLine,
- 'pencil': DrawFreeHand,
- }
-
- def __init__(self, plot):
- self._plot = weakref.ref(plot) # Avoid cyclic-ref
-
- self.zoomOnWheel = True
- """True to enable zoom on wheel, False otherwise."""
-
- # Default event handler
- self._eventHandler = ItemsInteraction(plot)
-
- def getInteractiveMode(self):
- """Returns the current interactive mode as a dict.
-
- The returned dict contains at least the key 'mode'.
- Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'.
- It can also contains extra keys (e.g., 'color') specific to a mode
- as provided to :meth:`setInteractiveMode`.
- """
- if isinstance(self._eventHandler, ZoomAndSelect):
- return {'mode': 'zoom', 'color': self._eventHandler.color}
-
- elif isinstance(self._eventHandler, (DrawMode, DrawSelectMode)):
- return self._eventHandler.getDescription()
-
- elif isinstance(self._eventHandler, PanAndSelect):
- return {'mode': 'pan'}
-
- else:
- return {'mode': 'select'}
-
- def validate(self):
- """Validate the current interaction if possible
-
- If was designed to close the polygon interaction.
- """
- self._eventHandler.validate()
-
- def setInteractiveMode(self, mode, color='black',
- shape='polygon', label=None, width=None):
- """Switch the interactive mode.
-
- :param str mode: The name of the interactive mode.
- In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
- :param color: Only for 'draw' and 'zoom' modes.
- Color to use for drawing selection area. Default black.
- If None, selection area is not drawn.
- :type color: Color description: The name as a str or
- a tuple of 4 floats or None.
- :param str shape: Only for 'draw' mode. The kind of shape to draw.
- In 'polygon', 'rectangle', 'line', 'vline', 'hline',
- 'polylines'.
- Default is 'polygon'.
- :param str label: Only for 'draw' mode.
- :param float width: Width of the pencil. Only for draw pencil mode.
- """
- assert mode in ('draw', 'pan', 'select', 'select-draw', 'zoom')
-
- plot = self._plot()
- assert plot is not None
-
- if isinstance(color, numpy.ndarray) or color not in (None, 'video inverted'):
- color = colors.rgba(color)
-
- if mode in ('draw', 'select-draw'):
- self._eventHandler.cancel()
- handlerClass = DrawMode if mode == 'draw' else DrawSelectMode
- self._eventHandler = handlerClass(plot, shape, label, color, width)
-
- elif mode == 'pan':
- # Ignores color, shape and label
- self._eventHandler.cancel()
- self._eventHandler = PanAndSelect(plot)
-
- elif mode == 'zoom':
- # Ignores shape and label
- self._eventHandler.cancel()
- self._eventHandler = ZoomAndSelect(plot, color)
-
- else: # Default mode: interaction with plot objects
- # Ignores color, shape and label
- self._eventHandler.cancel()
- self._eventHandler = ItemsInteraction(plot)
-
- def handleEvent(self, event, *args, **kwargs):
- """Forward event to current interactive mode state machine."""
- if not self.zoomOnWheel and event == 'wheel':
- return # Discard wheel events
- self._eventHandler.handleEvent(event, *args, **kwargs)
diff --git a/silx/gui/plot/PlotWidget.py b/silx/gui/plot/PlotWidget.py
deleted file mode 100755
index 2a211de..0000000
--- a/silx/gui/plot/PlotWidget.py
+++ /dev/null
@@ -1,3621 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-# ###########################################################################*/
-"""Qt widget providing plot API for 1D and 2D data.
-
-The :class:`PlotWidget` implements the plot API initially provided in PyMca.
-"""
-
-from __future__ import division
-
-
-__authors__ = ["V.A. Sole", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "21/12/2018"
-
-import logging
-
-_logger = logging.getLogger(__name__)
-
-
-from collections import OrderedDict, namedtuple
-from contextlib import contextmanager
-import datetime as dt
-import itertools
-import typing
-import warnings
-
-import numpy
-
-import silx
-from silx.utils.weakref import WeakMethodProxy
-from silx.utils.property import classproperty
-from silx.utils.deprecation import deprecated, deprecated_warning
-try:
- # Import matplotlib now to init matplotlib our way
- import silx.gui.utils.matplotlib # noqa
-except ImportError:
- _logger.debug("matplotlib not available")
-
-import six
-from ..colors import Colormap
-from .. import colors
-from . import PlotInteraction
-from . import PlotEvents
-from .LimitsHistory import LimitsHistory
-from . import _utils
-
-from . import items
-from .items.curve import CurveStyle
-from .items.axis import TickMode # noqa
-
-from .. import qt
-from ._utils.panzoom import ViewConstraints
-from ...gui.plot._utils.dtime_ticklayout import timestamp
-
-
-
-_COLORDICT = colors.COLORDICT
-_COLORLIST = silx.config.DEFAULT_PLOT_CURVE_COLORS
-
-"""
-Object returned when requesting the data range.
-"""
-_PlotDataRange = namedtuple('PlotDataRange',
- ['x', 'y', 'yright'])
-
-
-class _PlotWidgetSelection(qt.QObject):
- """Object managing a :class:`PlotWidget` selection.
-
- It is a wrapper over :class:`PlotWidget`'s active items API.
-
- :param PlotWidget parent:
- """
-
- sigCurrentItemChanged = qt.Signal(object, object)
- """This signal is emitted whenever the current item changes.
-
- It provides the current and previous items.
- """
-
- sigSelectedItemsChanged = qt.Signal()
- """Signal emitted whenever the list of selected items changes."""
-
- def __init__(self, parent):
- assert isinstance(parent, PlotWidget)
- super(_PlotWidgetSelection, self).__init__(parent=parent)
-
- # Init history
- self.__history = [ # Store active items from most recent to oldest
- item for item in (parent.getActiveCurve(),
- parent.getActiveImage(),
- parent.getActiveScatter())
- if item is not None]
-
- self.__current = self.__mostRecentActiveItem()
-
- parent.sigActiveImageChanged.connect(self._activeImageChanged)
- parent.sigActiveCurveChanged.connect(self._activeCurveChanged)
- parent.sigActiveScatterChanged.connect(self._activeScatterChanged)
-
- def __mostRecentActiveItem(self) -> typing.Optional[items.Item]:
- """Returns most recent active item."""
- return self.__history[0] if len(self.__history) >= 1 else None
-
- def getSelectedItems(self) -> typing.Tuple[items.Item]:
- """Returns the list of currently selected items in the :class:`PlotWidget`.
-
- The list is given from most recently current item to oldest one."""
- plot = self.parent()
- if plot is None:
- return ()
-
- active = tuple(self.__history)
-
- current = self.getCurrentItem()
- if current is not None and current not in active:
- # Current might not be an active item, if so add it
- active = (current,) + active
-
- return active
-
- def getCurrentItem(self) -> typing.Optional[items.Item]:
- """Returns the current item in the :class:`PlotWidget` or None. """
- return self.__current
-
- def setCurrentItem(self, item: typing.Optional[items.Item]):
- """Set the current item in the :class:`PlotWidget`.
-
- :param item:
- The new item to select or None to clear the selection.
- :raise ValueError: If the item is not the :class:`PlotWidget`
- """
- previous = self.getCurrentItem()
- if previous is item:
- return
-
- previousSelected = self.getSelectedItems()
-
- if item is None:
- self.__current = None
-
- # Reset all PlotWidget active items
- plot = self.parent()
- if plot is not None:
- for kind in PlotWidget._ACTIVE_ITEM_KINDS:
- if plot._getActiveItem(kind) is not None:
- plot._setActiveItem(kind, None)
-
- elif isinstance(item, items.Item):
- plot = self.parent()
- if plot is None or item.getPlot() is not plot:
- raise ValueError(
- "Item is not in the PlotWidget: %s" % str(item))
- self.__current = item
-
- kind = plot._itemKind(item)
-
- # Clean-up history to be safe
- self.__history = [item for item in self.__history
- if PlotWidget._itemKind(item) != kind]
-
- # Sync active item if needed
- if (kind in plot._ACTIVE_ITEM_KINDS and
- item is not plot._getActiveItem(kind)):
- plot._setActiveItem(kind, item.getName())
- else:
- raise ValueError("Not an Item: %s" % str(item))
-
- self.sigCurrentItemChanged.emit(previous, item)
-
- if previousSelected != self.getSelectedItems():
- self.sigSelectedItemsChanged.emit()
-
- def __activeItemChanged(self,
- kind: str,
- previous: typing.Optional[str],
- legend: typing.Optional[str]):
- """Set current item from kind and legend"""
- if previous == legend:
- return # No-op for update of item
-
- plot = self.parent()
- if plot is None:
- return
-
- previousSelected = self.getSelectedItems()
-
- # Remove items of this kind from the history
- self.__history = [item for item in self.__history
- if PlotWidget._itemKind(item) != kind]
-
- # Retrieve current item
- if legend is None: # Use most recent active item
- currentItem = self.__mostRecentActiveItem()
- else:
- currentItem = plot._getItem(kind=kind, legend=legend)
- if currentItem is None: # Fallback in case something went wrong
- currentItem = self.__mostRecentActiveItem()
-
- # Update history
- if currentItem is not None:
- while currentItem in self.__history:
- self.__history.remove(currentItem)
- self.__history.insert(0, currentItem)
-
- if currentItem != self.__current:
- previousItem = self.__current
- self.__current = currentItem
- self.sigCurrentItemChanged.emit(previousItem, currentItem)
-
- if previousSelected != self.getSelectedItems():
- self.sigSelectedItemsChanged.emit()
-
- def _activeImageChanged(self, previous, current):
- """Handle active image change"""
- self.__activeItemChanged('image', previous, current)
-
- def _activeCurveChanged(self, previous, current):
- """Handle active curve change"""
- self.__activeItemChanged('curve', previous, current)
-
- def _activeScatterChanged(self, previous, current):
- """Handle active scatter change"""
- self.__activeItemChanged('scatter', previous, current)
-
-
-class PlotWidget(qt.QMainWindow):
- """Qt Widget providing a 1D/2D plot.
-
- This widget is a QMainWindow.
- This class implements the plot API initially provided in PyMca.
-
- Supported backends:
-
- - 'matplotlib' and 'mpl': Matplotlib with Qt.
- - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
- - 'none': No backend, to run headless for testing purpose.
-
- :param parent: The parent of this widget or None (default).
- :param backend: The backend to use, in:
- 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
- or a :class:`BackendBase.BackendBase` class
- :type backend: str or :class:`BackendBase.BackendBase`
- """
-
- # TODO: Can be removed for silx 0.10
- @classproperty
- @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
- def DEFAULT_BACKEND(self):
- """Class attribute setting the default backend for all instances."""
- return silx.config.DEFAULT_PLOT_BACKEND
-
- colorList = _COLORLIST
- colorDict = _COLORDICT
-
- sigPlotSignal = qt.Signal(object)
- """Signal for all events of the plot.
-
- The signal information is provided as a dict.
- See the :ref:`plot signal documentation page <plot_signal>` for
- information about the content of the dict
- """
-
- sigSetKeepDataAspectRatio = qt.Signal(bool)
- """Signal emitted when plot keep aspect ratio has changed"""
-
- sigSetGraphGrid = qt.Signal(str)
- """Signal emitted when plot grid has changed"""
-
- sigSetGraphCursor = qt.Signal(bool)
- """Signal emitted when plot crosshair cursor has changed"""
-
- sigSetPanWithArrowKeys = qt.Signal(bool)
- """Signal emitted when pan with arrow keys has changed"""
-
- _sigAxesVisibilityChanged = qt.Signal(bool)
- """Signal emitted when the axes visibility changed"""
-
- sigContentChanged = qt.Signal(str, str, str)
- """Signal emitted when the content of the plot is changed.
-
- It provides the following information:
-
- - action: The change of the plot: 'add' or 'remove'
- - kind: The kind of primitive changed:
- 'curve', 'image', 'scatter', 'histogram', 'item' or 'marker'
- - legend: The legend of the primitive changed.
- """
-
- sigActiveCurveChanged = qt.Signal(object, object)
- """Signal emitted when the active curve has changed.
-
- It provides the following information:
-
- - previous: The legend of the previous active curve or None
- - legend: The legend of the new active curve or None if no curve is active
- """
-
- sigActiveImageChanged = qt.Signal(object, object)
- """Signal emitted when the active image has changed.
-
- It provides the following information:
-
- - previous: The legend of the previous active image or None
- - legend: The legend of the new active image or None if no image is active
- """
-
- sigActiveScatterChanged = qt.Signal(object, object)
- """Signal emitted when the active Scatter has changed.
-
- It provides the following information:
-
- - previous: The legend of the previous active scatter or None
- - legend: The legend of the new active image or None if no image is active
- """
-
- sigInteractiveModeChanged = qt.Signal(object)
- """Signal emitted when the interactive mode has changed
-
- It provides the source as passed to :meth:`setInteractiveMode`.
- """
-
- sigItemAdded = qt.Signal(items.Item)
- """Signal emitted when an item was just added to the plot
-
- It provides the added item.
- """
-
- sigItemAboutToBeRemoved = qt.Signal(items.Item)
- """Signal emitted right before an item is removed from the plot.
-
- It provides the item that will be removed.
- """
-
- sigItemRemoved = qt.Signal(items.Item)
- """Signal emitted right after an item was removed from the plot.
-
- It provides the item that was removed.
- """
-
- sigVisibilityChanged = qt.Signal(bool)
- """Signal emitted when the widget becomes visible (or invisible).
- This happens when the widget is hidden or shown.
-
- It provides the visible state.
- """
-
- _sigDefaultContextMenu = qt.Signal(qt.QMenu)
- """Signal emitted when the default context menu of the plot is feed.
-
- It provides the menu which will be displayed.
- """
-
- def __init__(self, parent=None, backend=None):
- self._autoreplot = False
- self._dirty = False
- self._cursorInPlot = False
- self.__muteActiveItemChanged = False
-
- self._panWithArrowKeys = True
- self._viewConstrains = None
-
- super(PlotWidget, self).__init__(parent)
- if parent is not None:
- # behave as a widget
- self.setWindowFlags(qt.Qt.Widget)
- else:
- self.setWindowTitle('PlotWidget')
-
- # Init the backend
- self._backend = self.__getBackendClass(backend)(self, self)
-
- self.setCallback() # set _callback
-
- # Items handling
- self._content = OrderedDict()
- self._contentToUpdate = [] # Used as an OrderedSet
-
- self._dataRange = None
-
- # line types
- self._styleList = ['-', '--', '-.', ':']
- self._colorIndex = 0
- self._styleIndex = 0
-
- self._activeCurveSelectionMode = "atmostone"
- self._activeCurveStyle = CurveStyle(color='#000000')
- self._activeLegend = {'curve': None, 'image': None,
- 'scatter': None}
-
- # plot colors (updated later to sync backend)
- self._foregroundColor = 0., 0., 0., 1.
- self._gridColor = .7, .7, .7, 1.
- self._backgroundColor = 1., 1., 1., 1.
- self._dataBackgroundColor = None
-
- # default properties
- self._cursorConfiguration = None
-
- self._xAxis = items.XAxis(self)
- self._yAxis = items.YAxis(self)
- self._yRightAxis = items.YRightAxis(self, self._yAxis)
-
- self._grid = None
- self._graphTitle = ''
- self.__graphCursorShape = 'default'
-
- # Set axes margins
- self.__axesDisplayed = True
- self.__axesMargins = 0., 0., 0., 0.
- self.setAxesMargins(.15, .1, .1, .15)
-
- self.setGraphTitle()
- self.setGraphXLabel()
- self.setGraphYLabel()
- self.setGraphYLabel('', axis='right')
-
- self.setDefaultColormap() # Init default colormap
-
- self.setDefaultPlotPoints(silx.config.DEFAULT_PLOT_CURVE_SYMBOL_MODE)
- self.setDefaultPlotLines(True)
-
- self._limitsHistory = LimitsHistory(self)
-
- self._eventHandler = PlotInteraction.PlotInteraction(self)
- self._eventHandler.setInteractiveMode('zoom', color=(0., 0., 0., 1.))
- self._previousDefaultMode = "zoom", True
-
- self._pressedButtons = [] # Currently pressed mouse buttons
-
- self._defaultDataMargins = (0., 0., 0., 0.)
-
- # Only activate autoreplot at the end
- # This avoids errors when loaded in Qt designer
- self._dirty = False
- self._autoreplot = True
-
- widget = self.getWidgetHandle()
- if widget is not None:
- self.setCentralWidget(widget)
- else:
- _logger.info("PlotWidget backend does not support widget")
-
- self.setFocusPolicy(qt.Qt.StrongFocus)
- self.setFocus(qt.Qt.OtherFocusReason)
-
- # Set default limits
- self.setGraphXLimits(0., 100.)
- self.setGraphYLimits(0., 100., axis='right')
- self.setGraphYLimits(0., 100., axis='left')
-
- # Sync backend colors with default ones
- self._foregroundColorsUpdated()
- self._backgroundColorsUpdated()
-
- # selection handling
- self.__selection = None
-
- def __getBackendClass(self, backend):
- """Returns backend class corresponding to backend.
-
- If multiple backends are provided, the first available one is used.
-
- :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend:
- The name of the backend or its class or an iterable of those.
- :rtype: BackendBase
- :raise ValueError: In case the backend is not supported
- :raise RuntimeError: If a backend is not available
- """
- if backend is None:
- backend = silx.config.DEFAULT_PLOT_BACKEND
-
- if callable(backend):
- return backend
-
- elif isinstance(backend, six.string_types):
- backend = backend.lower()
- if backend in ('matplotlib', 'mpl'):
- try:
- from .backends.BackendMatplotlib import \
- BackendMatplotlibQt as backendClass
- except ImportError:
- _logger.debug("Backtrace", exc_info=True)
- raise RuntimeError("matplotlib backend is not available")
-
- elif backend in ('gl', 'opengl'):
- from ..utils.glutils import isOpenGLAvailable
- checkOpenGL = isOpenGLAvailable(version=(2, 1), runtimeCheck=False)
- if not checkOpenGL:
- _logger.debug("OpenGL check failed")
- raise RuntimeError(
- "OpenGL backend is not available: %s" % checkOpenGL.error)
-
- try:
- from .backends.BackendOpenGL import \
- BackendOpenGL as backendClass
- except ImportError:
- _logger.debug("Backtrace", exc_info=True)
- raise RuntimeError("OpenGL backend is not available")
-
- elif backend == 'none':
- from .backends.BackendBase import BackendBase as backendClass
-
- else:
- raise ValueError("Backend not supported %s" % backend)
-
- return backendClass
-
- elif isinstance(backend, (tuple, list)):
- for b in backend:
- try:
- return self.__getBackendClass(b)
- except RuntimeError:
- pass
- else: # No backend was found
- raise RuntimeError("None of the request backends are available")
-
- raise ValueError("Backend not supported %s" % str(backend))
-
- def selection(self):
- """Returns the selection hander"""
- if self.__selection is None: # Lazy initialization
- self.__selection = _PlotWidgetSelection(parent=self)
- return self.__selection
-
- # TODO: Can be removed for silx 0.10
- @staticmethod
- @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
- def setDefaultBackend(backend):
- """Set system wide default plot backend.
-
- .. versionadded:: 0.6
-
- :param backend: The backend to use, in:
- 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
- or a :class:`BackendBase.BackendBase` class
- """
- silx.config.DEFAULT_PLOT_BACKEND = backend
-
- def setBackend(self, backend):
- """Set the backend to use for rendering.
-
- Supported backends:
-
- - 'matplotlib' and 'mpl': Matplotlib with Qt.
- - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
- - 'none': No backend, to run headless for testing purpose.
-
- :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend:
- The backend to use, in:
- 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none',
- a :class:`BackendBase.BackendBase` class.
- If multiple backends are provided, the first available one is used.
- :raises ValueError: Unsupported backend descriptor
- :raises RuntimeError: Error while loading a backend
- """
- backend = self.__getBackendClass(backend)(self, self)
-
- # First save state that is stored in the backend
- xaxis = self.getXAxis()
- xmin, xmax = xaxis.getLimits()
- ymin, ymax = self.getYAxis(axis='left').getLimits()
- y2min, y2max = self.getYAxis(axis='right').getLimits()
- isKeepDataAspectRatio = self.isKeepDataAspectRatio()
- xTimeZone = xaxis.getTimeZone()
- isXAxisTimeSeries = xaxis.getTickMode() == TickMode.TIME_SERIES
-
- isYAxisInverted = self.getYAxis().isInverted()
-
- # Remove all items from previous backend
- for item in self.getItems():
- item._removeBackendRenderer(self._backend)
-
- # Switch backend
- self._backend = backend
- widget = self._backend.getWidgetHandle()
- self.setCentralWidget(widget)
- if widget is None:
- _logger.info("PlotWidget backend does not support widget")
-
- # Mark as newly dirty
- self._dirty = False
- self._setDirtyPlot()
-
- # Synchronize/restore state
- self._foregroundColorsUpdated()
- self._backgroundColorsUpdated()
-
- self._backend.setGraphCursorShape(self.getGraphCursorShape())
- crosshairConfig = self.getGraphCursor()
- if crosshairConfig is None:
- self._backend.setGraphCursor(False, 'black', 1, '-')
- else:
- self._backend.setGraphCursor(True, *crosshairConfig)
-
- self._backend.setGraphTitle(self.getGraphTitle())
- self._backend.setGraphGrid(self.getGraphGrid())
- if self.isAxesDisplayed():
- self._backend.setAxesMargins(*self.getAxesMargins())
- else:
- self._backend.setAxesMargins(0., 0., 0., 0.)
-
- # Set axes
- xaxis = self.getXAxis()
- self._backend.setGraphXLabel(xaxis.getLabel())
- self._backend.setXAxisTimeZone(xTimeZone)
- self._backend.setXAxisTimeSeries(isXAxisTimeSeries)
- self._backend.setXAxisLogarithmic(
- xaxis.getScale() == items.Axis.LOGARITHMIC)
-
- for axis in ('left', 'right'):
- self._backend.setGraphYLabel(self.getYAxis(axis).getLabel(), axis)
- self._backend.setYAxisInverted(isYAxisInverted)
- self._backend.setYAxisLogarithmic(
- self.getYAxis().getScale() == items.Axis.LOGARITHMIC)
-
- # Finally restore aspect ratio and limits
- self._backend.setKeepDataAspectRatio(isKeepDataAspectRatio)
- self.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
-
- # Mark all items for update with new backend
- for item in self.getItems():
- item._updated()
-
- def getBackend(self):
- """Returns the backend currently used by :class:`PlotWidget`.
-
- :rtype: ~silx.gui.plot.backend.BackendBase.BackendBase
- """
- return self._backend
-
- def _getDirtyPlot(self):
- """Return the plot dirty flag.
-
- If False, the plot has not changed since last replot.
- If True, the full plot need to be redrawn.
- If 'overlay', only the overlay has changed since last replot.
-
- It can be accessed by backend to check the dirty state.
-
- :return: False, True, 'overlay'
- """
- return self._dirty
-
- # Default Qt context menu
-
- def contextMenuEvent(self, event):
- """Override QWidget.contextMenuEvent to implement the context menu"""
- menu = qt.QMenu(self)
- from .actions.control import ZoomBackAction # Avoid cyclic import
- zoomBackAction = ZoomBackAction(plot=self, parent=menu)
- menu.addAction(zoomBackAction)
-
- mode = self.getInteractiveMode()
- if "shape" in mode and mode["shape"] == "polygon":
- from .actions.control import ClosePolygonInteractionAction # Avoid cyclic import
- action = ClosePolygonInteractionAction(plot=self, parent=menu)
- menu.addAction(action)
-
- self._sigDefaultContextMenu.emit(menu)
-
- # Make sure the plot is updated, especially when the plot is in
- # draw interaction mode
- menu.aboutToHide.connect(self.__simulateMouseMove)
-
- menu.exec_(event.globalPos())
-
- def _setDirtyPlot(self, overlayOnly=False):
- """Mark the plot as needing redraw
-
- :param bool overlayOnly: True to redraw only the overlay,
- False to redraw everything
- """
- wasDirty = self._dirty
-
- if not self._dirty and overlayOnly:
- self._dirty = 'overlay'
- else:
- self._dirty = True
-
- if self._autoreplot and not wasDirty and self.isVisible():
- self._backend.postRedisplay()
-
- def _foregroundColorsUpdated(self):
- """Handle change of foreground/grid color"""
- if self._gridColor is None:
- gridColor = self._foregroundColor
- else:
- gridColor = self._gridColor
- self._backend.setForegroundColors(
- self._foregroundColor, gridColor)
- self._setDirtyPlot()
-
- def getForegroundColor(self):
- """Returns the RGBA colors used to display the foreground of this widget
-
- :rtype: qt.QColor
- """
- return qt.QColor.fromRgbF(*self._foregroundColor)
-
- def setForegroundColor(self, color):
- """Set the foreground color of this widget.
-
- :param Union[List[int],List[float],QColor] color:
- The new RGB(A) color.
- """
- color = colors.rgba(color)
- if self._foregroundColor != color:
- self._foregroundColor = color
- self._foregroundColorsUpdated()
-
- def getGridColor(self):
- """Returns the RGBA colors used to display the grid lines
-
- An invalid QColor is returned if there is no grid color,
- in which case the foreground color is used.
-
- :rtype: qt.QColor
- """
- if self._gridColor is None:
- return qt.QColor() # An invalid color
- else:
- return qt.QColor.fromRgbF(*self._gridColor)
-
- def setGridColor(self, color):
- """Set the grid lines color
-
- :param Union[List[int],List[float],QColor,None] color:
- The new RGB(A) color.
- """
- if isinstance(color, qt.QColor) and not color.isValid():
- color = None
- if color is not None:
- color = colors.rgba(color)
- if self._gridColor != color:
- self._gridColor = color
- self._foregroundColorsUpdated()
-
- def _backgroundColorsUpdated(self):
- """Handle change of background/data background color"""
- if self._dataBackgroundColor is None:
- dataBGColor = self._backgroundColor
- else:
- dataBGColor = self._dataBackgroundColor
- self._backend.setBackgroundColors(
- self._backgroundColor, dataBGColor)
- self._setDirtyPlot()
-
- def getBackgroundColor(self):
- """Returns the RGBA colors used to display the background of this widget.
-
- :rtype: qt.QColor
- """
- return qt.QColor.fromRgbF(*self._backgroundColor)
-
- def setBackgroundColor(self, color):
- """Set the background color of this widget.
-
- :param Union[List[int],List[float],QColor] color:
- The new RGB(A) color.
- """
- color = colors.rgba(color)
- if self._backgroundColor != color:
- self._backgroundColor = color
- self._backgroundColorsUpdated()
-
- def getDataBackgroundColor(self):
- """Returns the RGBA colors used to display the background of the plot
- view displaying the data.
-
- An invalid QColor is returned if there is no data background color.
-
- :rtype: qt.QColor
- """
- if self._dataBackgroundColor is None:
- # An invalid color
- return qt.QColor()
- else:
- return qt.QColor.fromRgbF(*self._dataBackgroundColor)
-
- def setDataBackgroundColor(self, color):
- """Set the background color of the plot area.
-
- Set to None or an invalid QColor to use the background color.
-
- :param Union[List[int],List[float],QColor,None] color:
- The new RGB(A) color.
- """
- if isinstance(color, qt.QColor) and not color.isValid():
- color = None
- if color is not None:
- color = colors.rgba(color)
- if self._dataBackgroundColor != color:
- self._dataBackgroundColor = color
- self._backgroundColorsUpdated()
-
- dataBackgroundColor = qt.Property(
- qt.QColor, getDataBackgroundColor, setDataBackgroundColor
- )
-
- backgroundColor = qt.Property(qt.QColor, getBackgroundColor, setBackgroundColor)
-
- foregroundColor = qt.Property(qt.QColor, getForegroundColor, setForegroundColor)
-
- gridColor = qt.Property(qt.QColor, getGridColor, setGridColor)
-
- def showEvent(self, event):
- if self._autoreplot and self._dirty:
- self._backend.postRedisplay()
- super(PlotWidget, self).showEvent(event)
- self.sigVisibilityChanged.emit(True)
-
- def hideEvent(self, event):
- super(PlotWidget, self).hideEvent(event)
- self.sigVisibilityChanged.emit(False)
-
- def _invalidateDataRange(self):
- """
- Notifies this PlotWidget instance that the range has changed
- and will have to be recomputed.
- """
- self._dataRange = None
-
- def _updateDataRange(self):
- """
- Recomputes the range of the data displayed on this PlotWidget.
- """
- xMin = yMinLeft = yMinRight = float('nan')
- xMax = yMaxLeft = yMaxRight = float('nan')
-
- for item in self.getItems():
- if item.isVisible():
- bounds = item.getBounds()
- if bounds is not None:
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
- # Ignore All-NaN slice encountered
- xMin = numpy.nanmin([xMin, bounds[0]])
- xMax = numpy.nanmax([xMax, bounds[1]])
- # Take care of right axis
- if (isinstance(item, items.YAxisMixIn) and
- item.getYAxis() == 'right'):
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
- # Ignore All-NaN slice encountered
- yMinRight = numpy.nanmin([yMinRight, bounds[2]])
- yMaxRight = numpy.nanmax([yMaxRight, bounds[3]])
- else:
- with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=RuntimeWarning)
- # Ignore All-NaN slice encountered
- yMinLeft = numpy.nanmin([yMinLeft, bounds[2]])
- yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]])
-
- def lGetRange(x, y):
- return None if numpy.isnan(x) and numpy.isnan(y) else (x, y)
- xRange = lGetRange(xMin, xMax)
- yLeftRange = lGetRange(yMinLeft, yMaxLeft)
- yRightRange = lGetRange(yMinRight, yMaxRight)
-
- self._dataRange = _PlotDataRange(x=xRange,
- y=yLeftRange,
- yright=yRightRange)
-
- def getDataRange(self):
- """
- Returns this PlotWidget's data range.
-
- :return: a namedtuple with the following members :
- x, y (left y axis), yright. Each member is a tuple (min, max)
- or None if no data is associated with the axis.
- :rtype: namedtuple
- """
- if self._dataRange is None:
- self._updateDataRange()
- return self._dataRange
-
- # Content management
-
- _KIND_TO_CLASSES = {
- 'curve': (items.Curve,),
- 'image': (items.ImageBase,),
- 'scatter': (items.Scatter,),
- 'marker': (items.MarkerBase,),
- 'item': (items.Shape,
- items.BoundingRect,
- items.XAxisExtent,
- items.YAxisExtent),
- 'histogram': (items.Histogram,),
- }
- """Mapping kind to item classes of this kind"""
-
- @classmethod
- def _itemKind(cls, item):
- """Returns the "kind" of a given item
-
- :param Item item: The item get the kind
- :rtype: str
- """
- for kind, itemClasses in cls._KIND_TO_CLASSES.items():
- if isinstance(item, itemClasses):
- return kind
- raise ValueError('Unsupported item type %s' % type(item))
-
- def _notifyContentChanged(self, item):
- self.notify('contentChanged', action='add',
- kind=self._itemKind(item), legend=item.getName())
-
- def _itemRequiresUpdate(self, item):
- """Called by items in the plot for asynchronous update
-
- :param Item item: The item that required update
- """
- assert item.getPlot() == self
- # Put item at the end of the list
- if item in self._contentToUpdate:
- self._contentToUpdate.remove(item)
- self._contentToUpdate.append(item)
- self._setDirtyPlot(overlayOnly=item.isOverlay())
-
- def addItem(self, item=None, *args, **kwargs):
- """Add an item to the plot content.
-
- :param ~silx.gui.plot.items.Item item: The item to add.
- :raises ValueError: If item is already in the plot.
- """
- if not isinstance(item, items.Item):
- deprecated_warning(
- 'Function',
- 'addItem',
- replacement='addShape',
- since_version='0.13')
- if item is None and not args: # Only kwargs
- return self.addShape(**kwargs)
- else:
- return self.addShape(item, *args, **kwargs)
-
- assert not args and not kwargs
- if item in self.getItems():
- raise ValueError('Item already in the plot')
-
- # Add item to plot
- self._content[(item.getName(), self._itemKind(item))] = item
- item._setPlot(self)
- self._itemRequiresUpdate(item)
- if isinstance(item, items.DATA_ITEMS):
- self._invalidateDataRange() # TODO handle this automatically
-
- self._notifyContentChanged(item)
- self.sigItemAdded.emit(item)
-
- def removeItem(self, item):
- """Remove the item from the plot.
-
- :param ~silx.gui.plot.items.Item item: Item to remove from the plot.
- :raises ValueError: If item is not in the plot.
- """
- if not isinstance(item, items.Item): # Previous method usage
- deprecated_warning(
- 'Function',
- 'removeItem',
- replacement='remove(legend, kind="item")',
- since_version='0.13')
- if item is None:
- return
- self.remove(item, kind='item')
- return
-
- if item not in self.getItems():
- raise ValueError('Item not in the plot')
-
- self.sigItemAboutToBeRemoved.emit(item)
-
- kind = self._itemKind(item)
-
- if kind in self._ACTIVE_ITEM_KINDS:
- if self._getActiveItem(kind) == item:
- # Reset active item
- self._setActiveItem(kind, None)
-
- # Remove item from plot
- self._content.pop((item.getName(), kind))
- if item in self._contentToUpdate:
- self._contentToUpdate.remove(item)
- if item.isVisible():
- self._setDirtyPlot(overlayOnly=item.isOverlay())
- if item.getBounds() is not None:
- self._invalidateDataRange()
- item._removeBackendRenderer(self._backend)
- item._setPlot(None)
-
- if (kind == 'curve' and not self.getAllCurves(just_legend=True,
- withhidden=True)):
- self._resetColorAndStyle()
-
- self.sigItemRemoved.emit(item)
-
- self.notify('contentChanged', action='remove',
- kind=kind, legend=item.getName())
-
- def discardItem(self, item) -> bool:
- """Remove the item from the plot.
-
- Same as :meth:`removeItem` but do not raise an exception.
-
- :param ~silx.gui.plot.items.Item item: Item to remove from the plot.
- :returns: True if the item was present, False otherwise.
- """
- try:
- self.removeItem(item)
- except ValueError:
- return False
- else:
- return True
-
- @deprecated(replacement='addItem', since_version='0.13')
- def _add(self, item):
- return self.addItem(item)
-
- @deprecated(replacement='removeItem', since_version='0.13')
- def _remove(self, item):
- return self.removeItem(item)
-
- def getItems(self):
- """Returns the list of items in the plot
-
- :rtype: List[silx.gui.plot.items.Item]
- """
- return tuple(self._content.values())
-
- @contextmanager
- def _muteActiveItemChangedSignal(self):
- self.__muteActiveItemChanged = True
- yield
- self.__muteActiveItemChanged = False
-
- # Add
-
- # add * input arguments management:
- # If an arg is set, then use it.
- # Else:
- # If a curve with the same legend exists, then use its arg value
- # Else, use a default value.
- # Store used value.
- # This value is used when curve is updated either internally or by user.
-
- def addCurve(self, x, y, legend=None, info=None,
- replace=False,
- color=None, symbol=None,
- linewidth=None, linestyle=None,
- xlabel=None, ylabel=None, yaxis=None,
- xerror=None, yerror=None, z=None, selectable=None,
- fill=None, resetzoom=True,
- histogram=None, copy=True,
- baseline=None):
- """Add a 1D curve given by x an y to the graph.
-
- Curves are uniquely identified by their legend.
- To add multiple curves, call :meth:`addCurve` multiple times with
- different legend argument.
- To replace an existing curve, call :meth:`addCurve` with the
- existing curve legend.
- If you want to display the curve values as an histogram see the
- histogram parameter or :meth:`addHistogram`.
-
- When curve parameters are not provided, if a curve with the
- same legend is displayed in the plot, its parameters are used.
-
- :param numpy.ndarray x: The data corresponding to the x coordinates.
- If you attempt to plot an histogram you can set edges values in x.
- In this case len(x) = len(y) + 1.
- If x contains datetime objects the XAxis tickMode is set to
- TickMode.TIME_SERIES.
- :param numpy.ndarray y: The data corresponding to the y coordinates
- :param str legend: The legend to be associated to the curve (or None)
- :param info: User-defined information associated to the curve
- :param bool replace: True to delete already existing curves
- (the default is False)
- :param color: color(s) to be used
- :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
- one of the predefined color names defined in colors.py
- :param str symbol: Symbol to be drawn at each (x, y) position::
-
- - 'o' circle
- - '.' point
- - ',' pixel
- - '+' cross
- - 'x' x-cross
- - 'd' diamond
- - 's' square
- - None (the default) to use default symbol
-
- :param float linewidth: The width of the curve in pixels (Default: 1).
- :param str linestyle: Type of line::
-
- - ' ' no line
- - '-' solid line
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
- - None (the default) to use default line style
-
- :param str xlabel: Label to show on the X axis when the curve is active
- or None to keep default axis label.
- :param str ylabel: Label to show on the Y axis when the curve is active
- or None to keep default axis label.
- :param str yaxis: The Y axis this curve is attached to.
- Either 'left' (the default) or 'right'
- :param xerror: Values with the uncertainties on the x values
- :type xerror: A float, or a numpy.ndarray of float32.
- If it is an array, it can either be a 1D array of
- same length as the data or a 2D array with 2 rows
- of same length as the data: row 0 for positive errors,
- row 1 for negative errors.
- :param yerror: Values with the uncertainties on the y values
- :type yerror: A float, or a numpy.ndarray of float32. See xerror.
- :param int z: Layer on which to draw the curve (default: 1)
- This allows to control the overlay.
- :param bool selectable: Indicate if the curve can be selected.
- (Default: True)
- :param bool fill: True to fill the curve, False otherwise (default).
- :param bool resetzoom: True (the default) to reset the zoom.
- :param str histogram: if not None then the curve will be draw as an
- histogram. The step for each values of the curve can be set to the
- left, center or right of the original x curve values.
- If histogram is not None and len(x) == len(y)+1 then x is directly
- take as edges of the histogram.
- Type of histogram::
-
- - None (default)
- - 'left'
- - 'right'
- - 'center'
- :param bool copy: True make a copy of the data (default),
- False to use provided arrays.
- :param baseline: curve baseline
- :type: Union[None,float,numpy.ndarray]
- :returns: The key string identify this curve
- """
- # This is an histogram, use addHistogram
- if histogram is not None:
- histoLegend = self.addHistogram(histogram=y,
- edges=x,
- legend=legend,
- color=color,
- fill=fill,
- align=histogram,
- copy=copy)
- histo = self.getHistogram(histoLegend)
-
- histo.setInfo(info)
- if linewidth is not None:
- histo.setLineWidth(linewidth)
- if linestyle is not None:
- histo.setLineStyle(linestyle)
- if xlabel is not None:
- _logger.warning(
- 'addCurve: Histogram does not support xlabel argument')
- if ylabel is not None:
- _logger.warning(
- 'addCurve: Histogram does not support ylabel argument')
- if yaxis is not None:
- histo.setYAxis(yaxis)
- if z is not None:
- histo.setZValue(z)
- if selectable is not None:
- _logger.warning(
- 'addCurve: Histogram does not support selectable argument')
-
- return
-
- legend = 'Unnamed curve 1.1' if legend is None else str(legend)
-
- # Check if curve was previously active
- wasActive = self.getActiveCurve(just_legend=True) == legend
-
- if replace:
- self._resetColorAndStyle()
-
- # Create/Update curve object
- curve = self.getCurve(legend)
- mustBeAdded = curve is None
- if curve is None:
- # No previous curve, create a default one and add it to the plot
- curve = items.Curve() if histogram is None else items.Histogram()
- curve.setName(legend)
- # Set default color, linestyle and symbol
- default_color, default_linestyle = self._getColorAndStyle()
- curve.setColor(default_color)
- curve.setLineStyle(default_linestyle)
- curve.setSymbol(self._defaultPlotPoints)
- curve._setBaseline(baseline=baseline)
-
- # Do not emit sigActiveCurveChanged,
- # it will be sent once with _setActiveItem
- with self._muteActiveItemChangedSignal():
- # Override previous/default values with provided ones
- curve.setInfo(info)
- if color is not None:
- curve.setColor(color)
- if symbol is not None:
- curve.setSymbol(symbol)
- if linewidth is not None:
- curve.setLineWidth(linewidth)
- if linestyle is not None:
- curve.setLineStyle(linestyle)
- if xlabel is not None:
- curve._setXLabel(xlabel)
- if ylabel is not None:
- curve._setYLabel(ylabel)
- if yaxis is not None:
- curve.setYAxis(yaxis)
- if z is not None:
- curve.setZValue(z)
- if selectable is not None:
- curve._setSelectable(selectable)
- if fill is not None:
- curve.setFill(fill)
-
- # Set curve data
- # If errors not provided, reuse previous ones
- # TODO: Issue if size of data change but not that of errors
- if xerror is None:
- xerror = curve.getXErrorData(copy=False)
- if yerror is None:
- yerror = curve.getYErrorData(copy=False)
-
- # Convert x to timestamps so that the internal representation
- # remains floating points. The user is expected to set the axis'
- # tickMode to TickMode.TIME_SERIES and, if necessary, set the axis
- # to the correct time zone.
- if len(x) > 0 and isinstance(x[0], dt.datetime):
- x = [timestamp(d) for d in x]
-
- curve.setData(x, y, xerror, yerror, baseline=baseline, copy=copy)
-
- if replace: # Then remove all other curves
- for c in self.getAllCurves(withhidden=True):
- if c is not curve:
- self.removeItem(c)
-
- if mustBeAdded:
- self.addItem(curve)
- else:
- self._notifyContentChanged(curve)
-
- if wasActive:
- self.setActiveCurve(curve.getName())
- elif self.getActiveCurveSelectionMode() == "legacy":
- if self.getActiveCurve(just_legend=True) is None:
- if len(self.getAllCurves(just_legend=True,
- withhidden=False)) == 1:
- if curve.isVisible():
- self.setActiveCurve(curve.getName())
-
- if resetzoom:
- # We ask for a zoom reset in order to handle the plot scaling
- # if the user does not want that, autoscale of the different
- # axes has to be set to off.
- self.resetZoom()
-
- return legend
-
- def addHistogram(self,
- histogram,
- edges,
- legend=None,
- color=None,
- fill=None,
- align='center',
- resetzoom=True,
- copy=True,
- z=None,
- baseline=None):
- """Add an histogram to the graph.
-
- This is NOT computing the histogram, this method takes as parameter
- already computed histogram values.
-
- Histogram are uniquely identified by their legend.
- To add multiple histograms, call :meth:`addHistogram` multiple times
- with different legend argument.
-
- When histogram parameters are not provided, if an histogram with the
- same legend is displayed in the plot, its parameters are used.
-
- :param numpy.ndarray histogram: The values of the histogram.
- :param numpy.ndarray edges:
- The bin edges of the histogram.
- If histogram and edges have the same length, the bin edges
- are computed according to the align parameter.
- :param str legend:
- The legend to be associated to the histogram (or None)
- :param color: color to be used
- :type color: str ("#RRGGBB") or RGB unsigned byte array or
- one of the predefined color names defined in colors.py
- :param bool fill: True to fill the curve, False otherwise (default).
- :param str align:
- In case histogram values and edges have the same length N,
- the N+1 bin edges are computed according to the alignment in:
- 'center' (default), 'left', 'right'.
- :param bool resetzoom: True (the default) to reset the zoom.
- :param bool copy: True make a copy of the data (default),
- False to use provided arrays.
- :param int z: Layer on which to draw the histogram
- :param baseline: histogram baseline
- :type: Union[None,float,numpy.ndarray]
- :returns: The key string identify this histogram
- """
- legend = 'Unnamed histogram' if legend is None else str(legend)
-
- # Create/Update histogram object
- histo = self.getHistogram(legend)
- mustBeAdded = histo is None
- if histo is None:
- # No previous histogram, create a default one and
- # add it to the plot
- histo = items.Histogram()
- histo.setName(legend)
- histo.setColor(self._getColorAndStyle()[0])
-
- # Override previous/default values with provided ones
- if color is not None:
- histo.setColor(color)
- if fill is not None:
- histo.setFill(fill)
- if z is not None:
- histo.setZValue(z=z)
-
- # Set histogram data
- histo.setData(histogram=histogram, edges=edges, baseline=baseline,
- align=align, copy=copy)
-
- if mustBeAdded:
- self.addItem(histo)
- else:
- self._notifyContentChanged(histo)
-
- if resetzoom:
- # We ask for a zoom reset in order to handle the plot scaling
- # if the user does not want that, autoscale of the different
- # axes has to be set to off.
- self.resetZoom()
-
- return legend
-
- def addImage(self, data, legend=None, info=None,
- replace=False,
- z=None,
- selectable=None, draggable=None,
- colormap=None, pixmap=None,
- xlabel=None, ylabel=None,
- origin=None, scale=None,
- resetzoom=True, copy=True):
- """Add a 2D dataset or an image to the plot.
-
- It displays either an array of data using a colormap or a RGB(A) image.
-
- Images are uniquely identified by their legend.
- To add multiple images, call :meth:`addImage` multiple times with
- different legend argument.
- To replace/update an existing image, call :meth:`addImage` with the
- existing image legend.
-
- When image parameters are not provided, if an image with the
- same legend is displayed in the plot, its parameters are used.
-
- :param numpy.ndarray data:
- (nrows, ncolumns) data or
- (nrows, ncolumns, RGBA) ubyte array
- Note: boolean values are converted to int8.
- :param str legend: The legend to be associated to the image (or None)
- :param info: User-defined information associated to the image
- :param bool replace:
- True to delete already existing images (Default: False).
- :param int z: Layer on which to draw the image (default: 0)
- This allows to control the overlay.
- :param bool selectable: Indicate if the image can be selected.
- (default: False)
- :param bool draggable: Indicate if the image can be moved.
- (default: False)
- :param colormap: Colormap object to use (or None).
- This is ignored if data is a RGB(A) image.
- :type colormap: Union[~silx.gui.colors.Colormap, dict]
- :param pixmap: Pixmap representation of the data (if any)
- :type pixmap: (nrows, ncolumns, RGBA) ubyte array or None (default)
- :param str xlabel: X axis label to show when this curve is active,
- or None to keep default axis label.
- :param str ylabel: Y axis label to show when this curve is active,
- or None to keep default axis label.
- :param origin: (origin X, origin Y) of the data.
- It is possible to pass a single float if both
- coordinates are equal.
- Default: (0., 0.)
- :type origin: float or 2-tuple of float
- :param scale: (scale X, scale Y) of the data.
- It is possible to pass a single float if both
- coordinates are equal.
- Default: (1., 1.)
- :type scale: float or 2-tuple of float
- :param bool resetzoom: True (the default) to reset the zoom.
- :param bool copy: True make a copy of the data (default),
- False to use provided arrays.
- :returns: The key string identify this image
- """
- legend = "Unnamed Image 1.1" if legend is None else str(legend)
-
- # Check if image was previously active
- wasActive = self.getActiveImage(just_legend=True) == legend
-
- data = numpy.array(data, copy=False)
- assert data.ndim in (2, 3)
-
- image = self.getImage(legend)
- if image is not None and image.getData(copy=False).ndim != data.ndim:
- # Update a data image with RGBA image or the other way around:
- # Remove previous image
- # In this case, we don't retrieve defaults from the previous image
- self.removeItem(image)
- image = None
-
- mustBeAdded = image is None
- if image is None:
- # No previous image, create a default one and add it to the plot
- if data.ndim == 2:
- image = items.ImageData()
- image.setColormap(self.getDefaultColormap())
- else:
- image = items.ImageRgba()
- image.setName(legend)
-
- # Do not emit sigActiveImageChanged,
- # it will be sent once with _setActiveItem
- with self._muteActiveItemChangedSignal():
- # Override previous/default values with provided ones
- image.setInfo(info)
- if origin is not None:
- image.setOrigin(origin)
- if scale is not None:
- image.setScale(scale)
- if z is not None:
- image.setZValue(z)
- if selectable is not None:
- image._setSelectable(selectable)
- if draggable is not None:
- image._setDraggable(draggable)
- if colormap is not None and isinstance(image, items.ColormapMixIn):
- if isinstance(colormap, dict):
- image.setColormap(Colormap._fromDict(colormap))
- else:
- assert isinstance(colormap, Colormap)
- image.setColormap(colormap)
- if xlabel is not None:
- image._setXLabel(xlabel)
- if ylabel is not None:
- image._setYLabel(ylabel)
-
- if data.ndim == 2:
- image.setData(data, alternative=pixmap, copy=copy)
- else: # RGB(A) image
- if pixmap is not None:
- _logger.warning(
- 'addImage: pixmap argument ignored when data is RGB(A)')
- image.setData(data, copy=copy)
-
- if replace:
- for img in self.getAllImages():
- if img is not image:
- self.removeItem(img)
-
- if mustBeAdded:
- self.addItem(image)
- else:
- self._notifyContentChanged(image)
-
- if len(self.getAllImages()) == 1 or wasActive:
- self.setActiveImage(legend)
-
- if resetzoom:
- # We ask for a zoom reset in order to handle the plot scaling
- # if the user does not want that, autoscale of the different
- # axes has to be set to off.
- self.resetZoom()
-
- return legend
-
- def addScatter(self, x, y, value, legend=None, colormap=None,
- info=None, symbol=None, xerror=None, yerror=None,
- z=None, copy=True):
- """Add a (x, y, value) scatter to the graph.
-
- Scatters are uniquely identified by their legend.
- To add multiple scatters, call :meth:`addScatter` multiple times with
- different legend argument.
- To replace/update an existing scatter, call :meth:`addScatter` with the
- existing scatter legend.
-
- When scatter parameters are not provided, if a scatter with the
- same legend is displayed in the plot, its parameters are used.
-
- :param numpy.ndarray x: The data corresponding to the x coordinates.
- :param numpy.ndarray y: The data corresponding to the y coordinates
- :param numpy.ndarray value: The data value associated with each point
- :param str legend: The legend to be associated to the scatter (or None)
- :param ~silx.gui.colors.Colormap colormap:
- Colormap object to be used for the scatter (or None)
- :param info: User-defined information associated to the curve
- :param str symbol: Symbol to be drawn at each (x, y) position::
-
- - 'o' circle
- - '.' point
- - ',' pixel
- - '+' cross
- - 'x' x-cross
- - 'd' diamond
- - 's' square
- - None (the default) to use default symbol
-
- :param xerror: Values with the uncertainties on the x values
- :type xerror: A float, or a numpy.ndarray of float32.
- If it is an array, it can either be a 1D array of
- same length as the data or a 2D array with 2 rows
- of same length as the data: row 0 for positive errors,
- row 1 for negative errors.
- :param yerror: Values with the uncertainties on the y values
- :type yerror: A float, or a numpy.ndarray of float32. See xerror.
- :param int z: Layer on which to draw the scatter (default: 1)
- This allows to control the overlay.
-
- :param bool copy: True make a copy of the data (default),
- False to use provided arrays.
- :returns: The key string identify this scatter
- """
- legend = 'Unnamed scatter 1.1' if legend is None else str(legend)
-
- # Check if scatter was previously active
- wasActive = self._getActiveItem(kind='scatter',
- just_legend=True) == legend
-
- # Create/Update curve object
- scatter = self._getItem(kind='scatter', legend=legend)
- mustBeAdded = scatter is None
- if scatter is None:
- # No previous scatter, create a default one and add it to the plot
- scatter = items.Scatter()
- scatter.setName(legend)
- scatter.setColormap(self.getDefaultColormap())
-
- # Do not emit sigActiveScatterChanged,
- # it will be sent once with _setActiveItem
- with self._muteActiveItemChangedSignal():
- # Override previous/default values with provided ones
- scatter.setInfo(info)
- if symbol is not None:
- scatter.setSymbol(symbol)
- if z is not None:
- scatter.setZValue(z)
- if colormap is not None:
- if isinstance(colormap, dict):
- scatter.setColormap(Colormap._fromDict(colormap))
- else:
- assert isinstance(colormap, Colormap)
- scatter.setColormap(colormap)
-
- # Set scatter data
- # If errors not provided, reuse previous ones
- if xerror is None:
- xerror = scatter.getXErrorData(copy=False)
- if xerror is not None and len(xerror) != len(x):
- xerror = None
- if yerror is None:
- yerror = scatter.getYErrorData(copy=False)
- if yerror is not None and len(yerror) != len(y):
- yerror = None
-
- scatter.setData(x, y, value, xerror, yerror, copy=copy)
-
- if mustBeAdded:
- self.addItem(scatter)
- else:
- self._notifyContentChanged(scatter)
-
- scatters = [item for item in self.getItems()
- if isinstance(item, items.Scatter) and item.isVisible()]
- if len(scatters) == 1 or wasActive:
- self._setActiveItem('scatter', scatter.getName())
-
- return legend
-
- def addShape(self, xdata, ydata, legend=None, info=None,
- replace=False,
- shape="polygon", color='black', fill=True,
- overlay=False, z=None, linestyle="-", linewidth=1.0,
- linebgcolor=None):
- """Add an item (i.e. a shape) to the plot.
-
- Items are uniquely identified by their legend.
- To add multiple items, call :meth:`addItem` multiple times with
- different legend argument.
- To replace/update an existing item, call :meth:`addItem` with the
- existing item legend.
-
- :param numpy.ndarray xdata: The X coords of the points of the shape
- :param numpy.ndarray ydata: The Y coords of the points of the shape
- :param str legend: The legend to be associated to the item
- :param info: User-defined information associated to the item
- :param bool replace: True (default) to delete already existing images
- :param str shape: Type of item to be drawn in
- hline, polygon (the default), rectangle, vline,
- polylines
- :param str color: Color of the item, e.g., 'blue', 'b', '#FF0000'
- (Default: 'black')
- :param bool fill: True (the default) to fill the shape
- :param bool overlay: True if item is an overlay (Default: False).
- This allows for rendering optimization if this
- item is changed often.
- :param int z: Layer on which to draw the item (default: 2)
- :param str linestyle: Style of the line.
- Only relevant for line markers where X or Y is None.
- Value in:
-
- - ' ' no line
- - '-' solid line
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
- :param float linewidth: Width of the line.
- Only relevant for line markers where X or Y is None.
- :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
- '#FF0000'. It is used to draw dotted line using a second color.
- :returns: The key string identify this item
- """
- # expected to receive the same parameters as the signal
-
- legend = "Unnamed Item 1.1" if legend is None else str(legend)
-
- z = int(z) if z is not None else 2
-
- if replace:
- self.remove(kind='item')
- else:
- self.remove(legend, kind='item')
-
- item = items.Shape(shape)
- item.setName(legend)
- item.setInfo(info)
- item.setColor(color)
- item.setFill(fill)
- item.setOverlay(overlay)
- item.setZValue(z)
- item.setPoints(numpy.array((xdata, ydata)).T)
- item.setLineStyle(linestyle)
- item.setLineWidth(linewidth)
- item.setLineBgColor(linebgcolor)
-
- self.addItem(item)
-
- return legend
-
- def addXMarker(self, x, legend=None,
- text=None,
- color=None,
- selectable=False,
- draggable=False,
- constraint=None,
- yaxis='left'):
- """Add a vertical line marker to the plot.
-
- Markers are uniquely identified by their legend.
- As opposed to curves, images and items, two calls to
- :meth:`addXMarker` without legend argument adds two markers with
- different identifying legends.
-
- :param x: Position of the marker on the X axis in data coordinates
- :type x: Union[None, float]
- :param str legend: Legend associated to the marker to identify it
- :param str text: Text to display on the marker.
- :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
- (Default: 'black')
- :param bool selectable: Indicate if the marker can be selected.
- (default: False)
- :param bool draggable: Indicate if the marker can be moved.
- (default: False)
- :param constraint: A function filtering marker displacement by
- dragging operations or None for no filter.
- This function is called each time a marker is
- moved.
- This parameter is only used if draggable is True.
- :type constraint: None or a callable that takes the coordinates of
- the current cursor position in the plot as input
- and that returns the filtered coordinates.
- :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
- :return: The key string identify this marker
- """
- return self._addMarker(x=x, y=None, legend=legend,
- text=text, color=color,
- selectable=selectable, draggable=draggable,
- symbol=None, constraint=constraint,
- yaxis=yaxis)
-
- def addYMarker(self, y,
- legend=None,
- text=None,
- color=None,
- selectable=False,
- draggable=False,
- constraint=None,
- yaxis='left'):
- """Add a horizontal line marker to the plot.
-
- Markers are uniquely identified by their legend.
- As opposed to curves, images and items, two calls to
- :meth:`addYMarker` without legend argument adds two markers with
- different identifying legends.
-
- :param float y: Position of the marker on the Y axis in data
- coordinates
- :param str legend: Legend associated to the marker to identify it
- :param str text: Text to display next to the marker.
- :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
- (Default: 'black')
- :param bool selectable: Indicate if the marker can be selected.
- (default: False)
- :param bool draggable: Indicate if the marker can be moved.
- (default: False)
- :param constraint: A function filtering marker displacement by
- dragging operations or None for no filter.
- This function is called each time a marker is
- moved.
- This parameter is only used if draggable is True.
- :type constraint: None or a callable that takes the coordinates of
- the current cursor position in the plot as input
- and that returns the filtered coordinates.
- :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
- :return: The key string identify this marker
- """
- return self._addMarker(x=None, y=y, legend=legend,
- text=text, color=color,
- selectable=selectable, draggable=draggable,
- symbol=None, constraint=constraint,
- yaxis=yaxis)
-
- def addMarker(self, x, y, legend=None,
- text=None,
- color=None,
- selectable=False,
- draggable=False,
- symbol='+',
- constraint=None,
- yaxis='left'):
- """Add a point marker to the plot.
-
- Markers are uniquely identified by their legend.
- As opposed to curves, images and items, two calls to
- :meth:`addMarker` without legend argument adds two markers with
- different identifying legends.
-
- :param float x: Position of the marker on the X axis in data
- coordinates
- :param float y: Position of the marker on the Y axis in data
- coordinates
- :param str legend: Legend associated to the marker to identify it
- :param str text: Text to display next to the marker
- :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
- (Default: 'black')
- :param bool selectable: Indicate if the marker can be selected.
- (default: False)
- :param bool draggable: Indicate if the marker can be moved.
- (default: False)
- :param str symbol: Symbol representing the marker in::
-
- - 'o' circle
- - '.' point
- - ',' pixel
- - '+' cross (the default)
- - 'x' x-cross
- - 'd' diamond
- - 's' square
-
- :param constraint: A function filtering marker displacement by
- dragging operations or None for no filter.
- This function is called each time a marker is
- moved.
- This parameter is only used if draggable is True.
- :type constraint: None or a callable that takes the coordinates of
- the current cursor position in the plot as input
- and that returns the filtered coordinates.
- :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
- :return: The key string identify this marker
- """
- if x is None:
- xmin, xmax = self._xAxis.getLimits()
- x = 0.5 * (xmax + xmin)
-
- if y is None:
- ymin, ymax = self._yAxis.getLimits()
- y = 0.5 * (ymax + ymin)
-
- return self._addMarker(x=x, y=y, legend=legend,
- text=text, color=color,
- selectable=selectable, draggable=draggable,
- symbol=symbol, constraint=constraint,
- yaxis=yaxis)
-
- def _addMarker(self, x, y, legend,
- text, color,
- selectable, draggable,
- symbol, constraint,
- yaxis=None):
- """Common method for adding point, vline and hline marker.
-
- See :meth:`addMarker` for argument documentation.
- """
- assert (x, y) != (None, None)
-
- if legend is None: # Find an unused legend
- markerLegends = [item.getName() for item in self.getItems()
- if isinstance(item, items.MarkerBase)]
- for index in itertools.count():
- legend = "Unnamed Marker %d" % index
- if legend not in markerLegends:
- break # Keep this legend
- legend = str(legend)
-
- if x is None:
- markerClass = items.YMarker
- elif y is None:
- markerClass = items.XMarker
- else:
- markerClass = items.Marker
-
- # Create/Update marker object
- marker = self._getMarker(legend)
- if marker is not None and not isinstance(marker, markerClass):
- _logger.warning('Adding marker with same legend'
- ' but different type replaces it')
- self.removeItem(marker)
- marker = None
-
- mustBeAdded = marker is None
- if marker is None:
- # No previous marker, create one
- marker = markerClass()
- marker.setName(legend)
-
- if text is not None:
- marker.setText(text)
- if color is not None:
- marker.setColor(color)
- if selectable is not None:
- marker._setSelectable(selectable)
- if draggable is not None:
- marker._setDraggable(draggable)
- if symbol is not None:
- marker.setSymbol(symbol)
- marker.setYAxis(yaxis)
-
- # TODO to improve, but this ensure constraint is applied
- marker.setPosition(x, y)
- if constraint is not None:
- marker._setConstraint(constraint)
- marker.setPosition(x, y)
-
- if mustBeAdded:
- self.addItem(marker)
- else:
- self._notifyContentChanged(marker)
-
- return legend
-
- # Hide
-
- def isCurveHidden(self, legend):
- """Returns True if the curve associated to legend is hidden, else False
-
- :param str legend: The legend key identifying the curve
- :return: True if the associated curve is hidden, False otherwise
- """
- curve = self._getItem('curve', legend)
- return curve is not None and not curve.isVisible()
-
- def hideCurve(self, legend, flag=True):
- """Show/Hide the curve associated to legend.
-
- Even when hidden, the curve is kept in the list of curves.
-
- :param str legend: The legend associated to the curve to be hidden
- :param bool flag: True (default) to hide the curve, False to show it
- """
- curve = self._getItem('curve', legend)
- if curve is None:
- _logger.warning('Curve not in plot: %s', legend)
- return
-
- isVisible = not flag
- if isVisible != curve.isVisible():
- curve.setVisible(isVisible)
-
- # Remove
-
- ITEM_KINDS = 'curve', 'image', 'scatter', 'item', 'marker', 'histogram'
- """List of supported kind of items in the plot."""
-
- _ACTIVE_ITEM_KINDS = 'curve', 'scatter', 'image'
- """List of item's kind which have a active item."""
-
- def remove(self, legend=None, kind=ITEM_KINDS):
- """Remove one or all element(s) of the given legend and kind.
-
- Examples:
-
- - ``remove()`` clears the plot
- - ``remove(kind='curve')`` removes all curves from the plot
- - ``remove('myCurve', kind='curve')`` removes the curve with
- legend 'myCurve' from the plot.
- - ``remove('myImage, kind='image')`` removes the image with
- legend 'myImage' from the plot.
- - ``remove('myImage')`` removes elements (for instance curve, image,
- item and marker) with legend 'myImage'.
-
- :param str legend: The legend associated to the element to remove,
- or None to remove
- :param kind: The kind of elements to remove from the plot.
- See :attr:`ITEM_KINDS`.
- By default, it removes all kind of elements.
- :type kind: str or tuple of str to specify multiple kinds.
- """
- if kind == 'all': # Replace all by tuple of all kinds
- kind = self.ITEM_KINDS
-
- if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
- kind = (kind,)
-
- for aKind in kind:
- assert aKind in self.ITEM_KINDS
-
- if legend is None: # This is a clear
- # Clear each given kind
- for aKind in kind:
- for item in self.getItems():
- if (isinstance(item, self._KIND_TO_CLASSES[aKind]) and
- item.getPlot() is self): # Make sure item is still in the plot
- self.removeItem(item)
-
- else: # This is removing a single element
- # Remove each given kind
- for aKind in kind:
- item = self._getItem(aKind, legend)
- if item is not None:
- self.removeItem(item)
-
- def removeCurve(self, legend):
- """Remove the curve associated to legend from the graph.
-
- :param str legend: The legend associated to the curve to be deleted
- """
- if legend is None:
- return
- self.remove(legend, kind='curve')
-
- def removeImage(self, legend):
- """Remove the image associated to legend from the graph.
-
- :param str legend: The legend associated to the image to be deleted
- """
- if legend is None:
- return
- self.remove(legend, kind='image')
-
- def removeMarker(self, legend):
- """Remove the marker associated to legend from the graph.
-
- :param str legend: The legend associated to the marker to be deleted
- """
- if legend is None:
- return
- self.remove(legend, kind='marker')
-
- # Clear
-
- def clear(self):
- """Remove everything from the plot."""
- for item in self.getItems():
- if item.getPlot() is self: # Make sure item is still in the plot
- self.removeItem(item)
-
- def clearCurves(self):
- """Remove all the curves from the plot."""
- self.remove(kind='curve')
-
- def clearImages(self):
- """Remove all the images from the plot."""
- self.remove(kind='image')
-
- def clearItems(self):
- """Remove all the items from the plot. """
- self.remove(kind='item')
-
- def clearMarkers(self):
- """Remove all the markers from the plot."""
- self.remove(kind='marker')
-
- # Interaction
-
- def getGraphCursor(self):
- """Returns the state of the crosshair cursor.
-
- See :meth:`setGraphCursor`.
-
- :return: None if the crosshair cursor is not active,
- else a tuple (color, linewidth, linestyle).
- """
- return self._cursorConfiguration
-
- def setGraphCursor(self, flag=False, color='black',
- linewidth=1, linestyle='-'):
- """Toggle the display of a crosshair cursor and set its attributes.
-
- :param bool flag: Toggle the display of a crosshair cursor.
- The crosshair cursor is hidden by default.
- :param color: The color to use for the crosshair.
- :type color: A string (either a predefined color name in colors.py
- or "#RRGGBB")) or a 4 columns unsigned byte array
- (Default: black).
- :param int linewidth: The width of the lines of the crosshair
- (Default: 1).
- :param str linestyle: Type of line::
-
- - ' ' no line
- - '-' solid line (the default)
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
- """
- if flag:
- self._cursorConfiguration = color, linewidth, linestyle
- else:
- self._cursorConfiguration = None
-
- self._backend.setGraphCursor(flag=flag, color=color,
- linewidth=linewidth, linestyle=linestyle)
- self._setDirtyPlot()
- self.notify('setGraphCursor',
- state=self._cursorConfiguration is not None)
-
- def pan(self, direction, factor=0.1):
- """Pan the graph in the given direction by the given factor.
-
- Warning: Pan of right Y axis not implemented!
-
- :param str direction: One of 'up', 'down', 'left', 'right'.
- :param float factor: Proportion of the range used to pan the graph.
- Must be strictly positive.
- """
- assert direction in ('up', 'down', 'left', 'right')
- assert factor > 0.
-
- if direction in ('left', 'right'):
- xFactor = factor if direction == 'right' else - factor
- xMin, xMax = self._xAxis.getLimits()
-
- xMin, xMax = _utils.applyPan(xMin, xMax, xFactor,
- self._xAxis.getScale() == self._xAxis.LOGARITHMIC)
- self._xAxis.setLimits(xMin, xMax)
-
- else: # direction in ('up', 'down')
- sign = -1. if self._yAxis.isInverted() else 1.
- yFactor = sign * (factor if direction == 'up' else -factor)
- yMin, yMax = self._yAxis.getLimits()
- yIsLog = self._yAxis.getScale() == self._yAxis.LOGARITHMIC
-
- yMin, yMax = _utils.applyPan(yMin, yMax, yFactor, yIsLog)
- self._yAxis.setLimits(yMin, yMax)
-
- y2Min, y2Max = self._yRightAxis.getLimits()
-
- y2Min, y2Max = _utils.applyPan(y2Min, y2Max, yFactor, yIsLog)
- self._yRightAxis.setLimits(y2Min, y2Max)
-
- # Active Curve/Image
-
- def isActiveCurveHandling(self):
- """Returns True if active curve selection is enabled.
-
- :rtype: bool
- """
- return self.getActiveCurveSelectionMode() != 'none'
-
- def setActiveCurveHandling(self, flag=True):
- """Enable/Disable active curve selection.
-
- :param bool flag: True to enable 'atmostone' active curve selection,
- False to disable active curve selection.
- """
- self.setActiveCurveSelectionMode('atmostone' if flag else 'none')
-
- def getActiveCurveStyle(self):
- """Returns the current style applied to active curve
-
- :rtype: CurveStyle
- """
- return self._activeCurveStyle
-
- def setActiveCurveStyle(self,
- color=None,
- linewidth=None,
- linestyle=None,
- symbol=None,
- symbolsize=None):
- """Set the style of active curve
-
- :param color: Color
- :param Union[str,None] linestyle: Style of the line
- :param Union[float,None] linewidth: Width of the line
- :param Union[str,None] symbol: Symbol of the markers
- :param Union[float,None] symbolsize: Size of the symbols
- """
- self._activeCurveStyle = CurveStyle(color=color,
- linewidth=linewidth,
- linestyle=linestyle,
- symbol=symbol,
- symbolsize=symbolsize)
- curve = self.getActiveCurve()
- if curve is not None:
- curve.setHighlightedStyle(self.getActiveCurveStyle())
-
- @deprecated(replacement="getActiveCurveStyle", since_version="0.9")
- def getActiveCurveColor(self):
- """Get the color used to display the currently active curve.
-
- See :meth:`setActiveCurveColor`.
- """
- return self._activeCurveStyle.getColor()
-
- @deprecated(replacement="setActiveCurveStyle", since_version="0.9")
- def setActiveCurveColor(self, color="#000000"):
- """Set the color to use to display the currently active curve.
-
- :param str color: Color of the active curve,
- e.g., 'blue', 'b', '#FF0000' (Default: 'black')
- """
- if color is None:
- color = "black"
- if color in self.colorDict:
- color = self.colorDict[color]
- self.setActiveCurveStyle(color=color)
-
- def getActiveCurve(self, just_legend=False):
- """Return the currently active curve.
-
- It returns None in case of not having an active curve.
-
- :param bool just_legend: True to get the legend of the curve,
- False (the default) to get the curve data
- and info.
- :return: Active curve's legend or corresponding
- :class:`.items.Curve`
- :rtype: str or :class:`.items.Curve` or None
- """
- if not self.isActiveCurveHandling():
- return None
-
- return self._getActiveItem(kind='curve', just_legend=just_legend)
-
- def setActiveCurve(self, legend):
- """Make the curve associated to legend the active curve.
-
- :param legend: The legend associated to the curve
- or None to have no active curve.
- :type legend: str or None
- """
- if not self.isActiveCurveHandling():
- return
- if legend is None and self.getActiveCurveSelectionMode() == "legacy":
- _logger.info(
- 'setActiveCurve(None) ignored due to active curve selection mode')
- return
-
- return self._setActiveItem(kind='curve', legend=legend)
-
- def setActiveCurveSelectionMode(self, mode):
- """Sets the current selection mode.
-
- :param str mode: The active curve selection mode to use.
- It can be: 'legacy', 'atmostone' or 'none'.
- """
- assert mode in ('legacy', 'atmostone', 'none')
-
- if mode != self._activeCurveSelectionMode:
- self._activeCurveSelectionMode = mode
- if mode == 'none': # reset active curve
- self._setActiveItem(kind='curve', legend=None)
-
- elif mode == 'legacy' and self.getActiveCurve() is None:
- # Select an active curve
- curves = self.getAllCurves(just_legend=False,
- withhidden=False)
- if len(curves) == 1:
- if curves[0].isVisible():
- self.setActiveCurve(curves[0].getName())
-
- def getActiveCurveSelectionMode(self):
- """Returns the current selection mode.
-
- It can be "atmostone", "legacy" or "none".
-
- :rtype: str
- """
- return self._activeCurveSelectionMode
-
- def getActiveImage(self, just_legend=False):
- """Returns the currently active image.
-
- It returns None in case of not having an active image.
-
- :param bool just_legend: True to get the legend of the image,
- False (the default) to get the image data
- and info.
- :return: Active image's legend or corresponding image object
- :rtype: str, :class:`.items.ImageData`, :class:`.items.ImageRgba`
- or None
- """
- return self._getActiveItem(kind='image', just_legend=just_legend)
-
- def setActiveImage(self, legend):
- """Make the image associated to legend the active image.
-
- :param str legend: The legend associated to the image
- or None to have no active image.
- """
- return self._setActiveItem(kind='image', legend=legend)
-
- def getActiveScatter(self, just_legend=False):
- """Returns the currently active scatter.
-
- It returns None in case of not having an active scatter.
-
- :param bool just_legend: True to get the legend of the scatter,
- False (the default) to get the scatter data
- and info.
- :return: Active scatter's legend or corresponding scatter object
- :rtype: str, :class:`.items.Scatter` or None
- """
- return self._getActiveItem(kind='scatter', just_legend=just_legend)
-
- def setActiveScatter(self, legend):
- """Make the scatter associated to legend the active scatter.
-
- :param str legend: The legend associated to the scatter
- or None to have no active scatter.
- """
- return self._setActiveItem(kind='scatter', legend=legend)
-
- def _getActiveItem(self, kind, just_legend=False):
- """Return the currently active item of that kind if any
-
- :param str kind: Type of item: 'curve', 'scatter' or 'image'
- :param bool just_legend: True to get the legend,
- False (default) to get the item
- :return: legend or item or None if no active item
- """
- assert kind in self._ACTIVE_ITEM_KINDS
-
- if self._activeLegend[kind] is None:
- return None
-
- item = self._getItem(kind, self._activeLegend[kind])
- if item is None:
- return None
-
- return item.getName() if just_legend else item
-
- def _setActiveItem(self, kind, legend):
- """Make the curve associated to legend the active curve.
-
- :param str kind: Type of item: 'curve' or 'image'
- :param legend: The legend associated to the curve
- or None to have no active curve.
- :type legend: str or None
- """
- assert kind in self._ACTIVE_ITEM_KINDS
-
- xLabel = None
- yLabel = None
- yRightLabel = None
-
- oldActiveItem = self._getActiveItem(kind=kind)
-
- if oldActiveItem is not None: # Stop listening previous active image
- oldActiveItem.sigItemChanged.disconnect(self._activeItemChanged)
-
- # Curve specific: Reset highlight of previous active curve
- if kind == 'curve' and oldActiveItem is not None:
- oldActiveItem.setHighlighted(False)
-
- if legend is None:
- self._activeLegend[kind] = None
- else:
- legend = str(legend)
- item = self._getItem(kind, legend)
- if item is None:
- _logger.warning("This %s does not exist: %s", kind, legend)
- self._activeLegend[kind] = None
- else:
- self._activeLegend[kind] = legend
-
- # Curve specific: handle highlight
- if kind == 'curve':
- item.setHighlightedStyle(self.getActiveCurveStyle())
- item.setHighlighted(True)
-
- if isinstance(item, items.LabelsMixIn):
- if item.getXLabel() is not None:
- xLabel = item.getXLabel()
- if item.getYLabel() is not None:
- if (isinstance(item, items.YAxisMixIn) and
- item.getYAxis() == 'right'):
- yRightLabel = item.getYLabel()
- else:
- yLabel = item.getYLabel()
-
- # Start listening new active item
- item.sigItemChanged.connect(self._activeItemChanged)
-
- # Store current labels and update plot
- self._xAxis._setCurrentLabel(xLabel)
- self._yAxis._setCurrentLabel(yLabel)
- self._yRightAxis._setCurrentLabel(yRightLabel)
-
- self._setDirtyPlot()
-
- activeLegend = self._activeLegend[kind]
- if oldActiveItem is not None or activeLegend is not None:
- if oldActiveItem is None:
- oldActiveLegend = None
- else:
- oldActiveLegend = oldActiveItem.getName()
- self.notify(
- 'active' + kind[0].upper() + kind[1:] + 'Changed',
- updated=oldActiveLegend != activeLegend,
- previous=oldActiveLegend,
- legend=activeLegend)
-
- return activeLegend
-
- def _activeItemChanged(self, type_):
- """Listen for active item changed signal and broadcast signal
-
- :param item.ItemChangedType type_: The type of item change
- """
- if not self.__muteActiveItemChanged:
- item = self.sender()
- if item is not None:
- kind = self._itemKind(item)
- self.notify(
- 'active' + kind[0].upper() + kind[1:] + 'Changed',
- updated=False,
- previous=item.getName(),
- legend=item.getName())
-
- # Getters
-
- def getAllCurves(self, just_legend=False, withhidden=False):
- """Returns all curves legend or info and data.
-
- It returns an empty list in case of not having any curve.
-
- If just_legend is False, it returns a list of :class:`items.Curve`
- objects describing the curves.
- If just_legend is True, it returns a list of curves' legend.
-
- :param bool just_legend: True to get the legend of the curves,
- False (the default) to get the curves' data
- and info.
- :param bool withhidden: False (default) to skip hidden curves.
- :return: list of curves' legend or :class:`.items.Curve`
- :rtype: list of str or list of :class:`.items.Curve`
- """
- curves = [item for item in self.getItems() if
- isinstance(item, items.Curve) and
- (withhidden or item.isVisible())]
- return [curve.getName() for curve in curves] if just_legend else curves
-
- def getCurve(self, legend=None):
- """Get the object describing a specific curve.
-
- It returns None in case no matching curve is found.
-
- :param str legend:
- The legend identifying the curve.
- If not provided or None (the default), the active curve is returned
- or if there is no active curve, the latest updated curve that is
- not hidden is returned if there are curves in the plot.
- :return: None or :class:`.items.Curve` object
- """
- return self._getItem(kind='curve', legend=legend)
-
- def getAllImages(self, just_legend=False):
- """Returns all images legend or objects.
-
- It returns an empty list in case of not having any image.
-
- If just_legend is False, it returns a list of :class:`items.ImageBase`
- objects describing the images.
- If just_legend is True, it returns a list of legends.
-
- :param bool just_legend: True to get the legend of the images,
- False (the default) to get the images'
- object.
- :return: list of images' legend or :class:`.items.ImageBase`
- :rtype: list of str or list of :class:`.items.ImageBase`
- """
- images = [item for item in self.getItems()
- if isinstance(item, items.ImageBase)]
- return [image.getName() for image in images] if just_legend else images
-
- def getImage(self, legend=None):
- """Get the object describing a specific image.
-
- It returns None in case no matching image is found.
-
- :param str legend:
- The legend identifying the image.
- If not provided or None (the default), the active image is returned
- or if there is no active image, the latest updated image
- is returned if there are images in the plot.
- :return: None or :class:`.items.ImageBase` object
- """
- return self._getItem(kind='image', legend=legend)
-
- def getScatter(self, legend=None):
- """Get the object describing a specific scatter.
-
- It returns None in case no matching scatter is found.
-
- :param str legend:
- The legend identifying the scatter.
- If not provided or None (the default), the active scatter is
- returned or if there is no active scatter, the latest updated
- scatter is returned if there are scatters in the plot.
- :return: None or :class:`.items.Scatter` object
- """
- return self._getItem(kind='scatter', legend=legend)
-
- def getHistogram(self, legend=None):
- """Get the object describing a specific histogram.
-
- It returns None in case no matching histogram is found.
-
- :param str legend:
- The legend identifying the histogram.
- If not provided or None (the default), the latest updated scatter
- is returned if there are histograms in the plot.
- :return: None or :class:`.items.Histogram` object
- """
- return self._getItem(kind='histogram', legend=legend)
-
- @deprecated(replacement='getItems', since_version='0.13')
- def _getItems(self, kind=ITEM_KINDS, just_legend=False, withhidden=False):
- """Retrieve all items of a kind in the plot
-
- :param kind: The kind of elements to retrieve from the plot.
- See :attr:`ITEM_KINDS`.
- By default, it removes all kind of elements.
- :type kind: str or tuple of str to specify multiple kinds.
- :param str kind: Type of item: 'curve' or 'image'
- :param bool just_legend: True to get the legend of the curves,
- False (the default) to get the curves' data
- and info.
- :param bool withhidden: False (default) to skip hidden curves.
- :return: list of legends or item objects
- """
- if kind == 'all': # Replace all by tuple of all kinds
- kind = self.ITEM_KINDS
-
- if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
- kind = (kind,)
-
- for aKind in kind:
- assert aKind in self.ITEM_KINDS
-
- output = []
- for item in self.getItems():
- type_ = self._itemKind(item)
- if type_ in kind and (withhidden or item.isVisible()):
- output.append(item.getName() if just_legend else item)
- return output
-
- def _getItem(self, kind, legend=None):
- """Get an item from the plot: either an image or a curve.
-
- Returns None if no match found.
-
- :param str kind: Type of item to retrieve,
- see :attr:`ITEM_KINDS`.
- :param str legend: Legend of the item or
- None to get active or last item
- :return: Object describing the item or None
- """
- assert kind in self.ITEM_KINDS
-
- if legend is not None:
- return self._content.get((legend, kind), None)
- else:
- if kind in self._ACTIVE_ITEM_KINDS:
- item = self._getActiveItem(kind=kind)
- if item is not None: # Return active item if available
- return item
- # Return last visible item if any
- itemClasses = self._KIND_TO_CLASSES[kind]
- allItems = [item for item in self.getItems()
- if isinstance(item, itemClasses) and item.isVisible()]
- return allItems[-1] if allItems else None
-
- # Limits
-
- def _notifyLimitsChanged(self, emitSignal=True):
- """Send an event when plot area limits are changed."""
- xRange = self._xAxis.getLimits()
- yRange = self._yAxis.getLimits()
- y2Range = self._yRightAxis.getLimits()
- if emitSignal:
- axes = self.getXAxis(), self.getYAxis(), self.getYAxis(axis="right")
- ranges = xRange, yRange, y2Range
- for axis, limits in zip(axes, ranges):
- axis.sigLimitsChanged.emit(*limits)
- event = PlotEvents.prepareLimitsChangedSignal(
- id(self.getWidgetHandle()), xRange, yRange, y2Range)
- self.notify(**event)
-
- def getLimitsHistory(self):
- """Returns the object handling the history of limits of the plot"""
- return self._limitsHistory
-
- def getGraphXLimits(self):
- """Get the graph X (bottom) limits.
-
- :return: Minimum and maximum values of the X axis
- """
- return self._backend.getGraphXLimits()
-
- def setGraphXLimits(self, xmin, xmax):
- """Set the graph X (bottom) limits.
-
- :param float xmin: minimum bottom axis value
- :param float xmax: maximum bottom axis value
- """
- self._xAxis.setLimits(xmin, xmax)
-
- def getGraphYLimits(self, axis='left'):
- """Get the graph Y limits.
-
- :param str axis: The axis for which to get the limits:
- Either 'left' or 'right'
- :return: Minimum and maximum values of the X axis
- """
- assert axis in ('left', 'right')
- yAxis = self._yAxis if axis == 'left' else self._yRightAxis
- return yAxis.getLimits()
-
- def setGraphYLimits(self, ymin, ymax, axis='left'):
- """Set the graph Y limits.
-
- :param float ymin: minimum bottom axis value
- :param float ymax: maximum bottom axis value
- :param str axis: The axis for which to get the limits:
- Either 'left' or 'right'
- """
- assert axis in ('left', 'right')
- yAxis = self._yAxis if axis == 'left' else self._yRightAxis
- return yAxis.setLimits(ymin, ymax)
-
- def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
- """Set the limits of the X and Y axes at once.
-
- If y2min or y2max is None, the right Y axis limits are not updated.
-
- :param float xmin: minimum bottom axis value
- :param float xmax: maximum bottom axis value
- :param float ymin: minimum left axis value
- :param float ymax: maximum left axis value
- :param float y2min: minimum right axis value or None (the default)
- :param float y2max: maximum right axis value or None (the default)
- """
- # Deal with incorrect values
- axis = self.getXAxis()
- xmin, xmax = axis._checkLimits(xmin, xmax)
- axis = self.getYAxis()
- ymin, ymax = axis._checkLimits(ymin, ymax)
-
- if y2min is None or y2max is None:
- # if one limit is None, both are ignored
- y2min, y2max = None, None
- else:
- axis = self.getYAxis(axis="right")
- y2min, y2max = axis._checkLimits(y2min, y2max)
-
- if self._viewConstrains:
- view = self._viewConstrains.normalize(xmin, xmax, ymin, ymax)
- xmin, xmax, ymin, ymax = view
-
- self._backend.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
- self._setDirtyPlot()
- self._notifyLimitsChanged()
-
- def _getViewConstraints(self):
- """Return the plot object managing constaints on the plot view.
-
- :rtype: ViewConstraints
- """
- if self._viewConstrains is None:
- self._viewConstrains = ViewConstraints()
- return self._viewConstrains
-
- # Title and labels
-
- def getGraphTitle(self):
- """Return the plot main title as a str."""
- return self._graphTitle
-
- def setGraphTitle(self, title=""):
- """Set the plot main title.
-
- :param str title: Main title of the plot (default: '')
- """
- self._graphTitle = str(title)
- self._backend.setGraphTitle(title)
- self._setDirtyPlot()
-
- def getGraphXLabel(self):
- """Return the current X axis label as a str."""
- return self._xAxis.getLabel()
-
- def setGraphXLabel(self, label="X"):
- """Set the plot X axis label.
-
- The provided label can be temporarily replaced by the X label of the
- active curve if any.
-
- :param str label: The X axis label (default: 'X')
- """
- self._xAxis.setLabel(label)
-
- def getGraphYLabel(self, axis='left'):
- """Return the current Y axis label as a str.
-
- :param str axis: The Y axis for which to get the label (left or right)
- """
- assert axis in ('left', 'right')
- yAxis = self._yAxis if axis == 'left' else self._yRightAxis
- return yAxis.getLabel()
-
- def setGraphYLabel(self, label="Y", axis='left'):
- """Set the plot Y axis label.
-
- The provided label can be temporarily replaced by the Y label of the
- active curve if any.
-
- :param str label: The Y axis label (default: 'Y')
- :param str axis: The Y axis for which to set the label (left or right)
- """
- assert axis in ('left', 'right')
- yAxis = self._yAxis if axis == 'left' else self._yRightAxis
- return yAxis.setLabel(label)
-
- # Axes
-
- def getXAxis(self):
- """Returns the X axis
-
- .. versionadded:: 0.6
-
- :rtype: :class:`.items.Axis`
- """
- return self._xAxis
-
- def getYAxis(self, axis="left"):
- """Returns an Y axis
-
- .. versionadded:: 0.6
-
- :param str axis: The Y axis to return
- ('left' or 'right').
- :rtype: :class:`.items.Axis`
- """
- assert(axis in ["left", "right"])
- return self._yAxis if axis == "left" else self._yRightAxis
-
- def setAxesDisplayed(self, displayed: bool):
- """Display or not the axes.
-
- :param bool displayed: If `True` axes are displayed. If `False` axes
- are not anymore visible and the margin used for them is removed.
- """
- if displayed != self.__axesDisplayed:
- self.__axesDisplayed = displayed
- if displayed:
- self._backend.setAxesMargins(*self.__axesMargins)
- else:
- self._backend.setAxesMargins(0., 0., 0., 0.)
- self._setDirtyPlot()
- self._sigAxesVisibilityChanged.emit(displayed)
-
- def isAxesDisplayed(self) -> bool:
- """Returns whether or not axes are currently displayed
-
- :rtype: bool
- """
- return self.__axesDisplayed
-
- def setAxesMargins(
- self, left: float, top: float, right: float, bottom: float):
- """Set ratios of margins surrounding data plot area.
-
- All ratios must be within [0., 1.].
- Sums of ratios of opposed side must be < 1.
-
- :param float left: Left-side margin ratio.
- :param float top: Top margin ratio
- :param float right: Right-side margin ratio
- :param float bottom: Bottom margin ratio
- :raises ValueError:
- """
- for value in (left, top, right, bottom):
- if value < 0. or value > 1.:
- raise ValueError("Margin ratios must be within [0., 1.]")
- if left + right >= 1. or top + bottom >= 1.:
- raise ValueError("Sum of ratios of opposed sides >= 1")
- margins = left, top, right, bottom
-
- if margins != self.__axesMargins:
- self.__axesMargins = margins
- if self.isAxesDisplayed(): # Only apply if axes are displayed
- self._backend.setAxesMargins(*margins)
- self._setDirtyPlot()
-
- def getAxesMargins(self):
- """Returns ratio of margins surrounding data plot area.
-
- :return: (left, top, right, bottom)
- :rtype: List[float]
- """
- return self.__axesMargins
-
- def setYAxisInverted(self, flag=True):
- """Set the Y axis orientation.
-
- :param bool flag: True for Y axis going from top to bottom,
- False for Y axis going from bottom to top
- """
- self._yAxis.setInverted(flag)
-
- def isYAxisInverted(self):
- """Return True if Y axis goes from top to bottom, False otherwise."""
- return self._yAxis.isInverted()
-
- def isXAxisLogarithmic(self):
- """Return True if X axis scale is logarithmic, False if linear."""
- return self._xAxis._isLogarithmic()
-
- def setXAxisLogarithmic(self, flag):
- """Set the bottom X axis scale (either linear or logarithmic).
-
- :param bool flag: True to use a logarithmic scale, False for linear.
- """
- self._xAxis._setLogarithmic(flag)
-
- def isYAxisLogarithmic(self):
- """Return True if Y axis scale is logarithmic, False if linear."""
- return self._yAxis._isLogarithmic()
-
- def setYAxisLogarithmic(self, flag):
- """Set the Y axes scale (either linear or logarithmic).
-
- :param bool flag: True to use a logarithmic scale, False for linear.
- """
- self._yAxis._setLogarithmic(flag)
-
- def isXAxisAutoScale(self):
- """Return True if X axis is automatically adjusting its limits."""
- return self._xAxis.isAutoScale()
-
- def setXAxisAutoScale(self, flag=True):
- """Set the X axis limits adjusting behavior of :meth:`resetZoom`.
-
- :param bool flag: True to resize limits automatically,
- False to disable it.
- """
- self._xAxis.setAutoScale(flag)
-
- def isYAxisAutoScale(self):
- """Return True if Y axes are automatically adjusting its limits."""
- return self._yAxis.isAutoScale()
-
- def setYAxisAutoScale(self, flag=True):
- """Set the Y axis limits adjusting behavior of :meth:`resetZoom`.
-
- :param bool flag: True to resize limits automatically,
- False to disable it.
- """
- self._yAxis.setAutoScale(flag)
-
- def isKeepDataAspectRatio(self):
- """Returns whether the plot is keeping data aspect ratio or not."""
- return self._backend.isKeepDataAspectRatio()
-
- def setKeepDataAspectRatio(self, flag=True):
- """Set whether the plot keeps data aspect ratio or not.
-
- :param bool flag: True to respect data aspect ratio
- """
- flag = bool(flag)
- if flag == self.isKeepDataAspectRatio():
- return
- self._backend.setKeepDataAspectRatio(flag=flag)
- self._setDirtyPlot()
- self._forceResetZoom()
- self.notify('setKeepDataAspectRatio', state=flag)
-
- def getGraphGrid(self):
- """Return the current grid mode, either None, 'major' or 'both'.
-
- See :meth:`setGraphGrid`.
- """
- return self._grid
-
- def setGraphGrid(self, which=True):
- """Set the type of grid to display.
-
- :param which: None or False to disable the grid,
- 'major' or True for grid on major ticks (the default),
- 'both' for grid on both major and minor ticks.
- :type which: str of bool
- """
- assert which in (None, True, False, 'both', 'major')
- if not which:
- which = None
- elif which is True:
- which = 'major'
- self._grid = which
- self._backend.setGraphGrid(which)
- self._setDirtyPlot()
- self.notify('setGraphGrid', which=str(which))
-
- # Defaults
-
- def isDefaultPlotPoints(self):
- """Return True if the default Curve symbol is set and False if not."""
- return self._defaultPlotPoints == silx.config.DEFAULT_PLOT_SYMBOL
-
- def setDefaultPlotPoints(self, flag):
- """Set the default symbol of all curves.
-
- When called, this reset the symbol of all existing curves.
-
- :param bool flag: True to use 'o' as the default curve symbol,
- False to use no symbol.
- """
- self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else ''
-
- # Reset symbol of all curves
- curves = self.getAllCurves(just_legend=False, withhidden=True)
-
- if curves:
- for curve in curves:
- curve.setSymbol(self._defaultPlotPoints)
-
- def isDefaultPlotLines(self):
- """Return True for line as default line style, False for no line."""
- return self._plotLines
-
- def setDefaultPlotLines(self, flag):
- """Toggle the use of lines as the default curve line style.
-
- :param bool flag: True to use a line as the default line style,
- False to use no line as the default line style.
- """
- self._plotLines = bool(flag)
-
- linestyle = '-' if self._plotLines else ' '
-
- # Reset linestyle of all curves
- curves = self.getAllCurves(withhidden=True)
-
- if curves:
- for curve in curves:
- curve.setLineStyle(linestyle)
-
- def getDefaultColormap(self):
- """Return the default colormap used by :meth:`addImage`.
-
- :rtype: ~silx.gui.colors.Colormap
- """
- return self._defaultColormap
-
- def setDefaultColormap(self, colormap=None):
- """Set the default colormap used by :meth:`addImage`.
-
- Setting the default colormap do not change any currently displayed
- image.
- It only affects future calls to :meth:`addImage` without the colormap
- parameter.
-
- :param ~silx.gui.colors.Colormap colormap:
- The description of the default colormap, or
- None to set the colormap to a linear
- autoscale gray colormap.
- """
- if colormap is None:
- colormap = Colormap(name=silx.config.DEFAULT_COLORMAP_NAME,
- normalization='linear',
- vmin=None,
- vmax=None)
- if isinstance(colormap, dict):
- self._defaultColormap = Colormap._fromDict(colormap)
- else:
- assert isinstance(colormap, Colormap)
- self._defaultColormap = colormap
- self.notify('defaultColormapChanged')
-
- @staticmethod
- def getSupportedColormaps():
- """Get the supported colormap names as a tuple of str.
-
- The list contains at least:
- ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue',
- 'magma', 'inferno', 'plasma', 'viridis')
- """
- return Colormap.getSupportedColormaps()
-
- def _resetColorAndStyle(self):
- self._colorIndex = 0
- self._styleIndex = 0
-
- def _getColorAndStyle(self):
- color = self.colorList[self._colorIndex]
- style = self._styleList[self._styleIndex]
-
- # Loop over color and then styles
- self._colorIndex += 1
- if self._colorIndex >= len(self.colorList):
- self._colorIndex = 0
- self._styleIndex = (self._styleIndex + 1) % len(self._styleList)
-
- # If color is the one of active curve, take the next one
- if colors.rgba(color) == self.getActiveCurveStyle().getColor():
- color, style = self._getColorAndStyle()
-
- if not self._plotLines:
- style = ' '
-
- return color, style
-
- # Misc.
-
- def getWidgetHandle(self):
- """Return the widget the plot is displayed in.
-
- This widget is owned by the backend.
- """
- return self._backend.getWidgetHandle()
-
- def notify(self, event, **kwargs):
- """Send an event to the listeners and send signals.
-
- Event are passed to the registered callback as a dict with an 'event'
- key for backward compatibility with PyMca.
-
- :param str event: The type of event
- :param kwargs: The information of the event.
- """
- eventDict = kwargs.copy()
- eventDict['event'] = event
- self.sigPlotSignal.emit(eventDict)
-
- if event == 'setKeepDataAspectRatio':
- self.sigSetKeepDataAspectRatio.emit(kwargs['state'])
- elif event == 'setGraphGrid':
- self.sigSetGraphGrid.emit(kwargs['which'])
- elif event == 'setGraphCursor':
- self.sigSetGraphCursor.emit(kwargs['state'])
- elif event == 'contentChanged':
- self.sigContentChanged.emit(
- kwargs['action'], kwargs['kind'], kwargs['legend'])
- elif event == 'activeCurveChanged':
- self.sigActiveCurveChanged.emit(
- kwargs['previous'], kwargs['legend'])
- elif event == 'activeImageChanged':
- self.sigActiveImageChanged.emit(
- kwargs['previous'], kwargs['legend'])
- elif event == 'activeScatterChanged':
- self.sigActiveScatterChanged.emit(
- kwargs['previous'], kwargs['legend'])
- elif event == 'interactiveModeChanged':
- self.sigInteractiveModeChanged.emit(kwargs['source'])
-
- eventDict = kwargs.copy()
- eventDict['event'] = event
- self._callback(eventDict)
-
- def setCallback(self, callbackFunction=None):
- """Attach a listener to the backend.
-
- Limitation: Only one listener at a time.
-
- :param callbackFunction: function accepting a dictionary as input
- to handle the graph events
- If None (default), use a default listener.
- """
- # TODO allow multiple listeners
- # allow register listener by event type
- if callbackFunction is None:
- callbackFunction = WeakMethodProxy(self.graphCallback)
- self._callback = callbackFunction
-
- def graphCallback(self, ddict=None):
- """This callback is going to receive all the events from the plot.
-
- Those events will consist on a dictionary and among the dictionary
- keys the key 'event' is mandatory to describe the type of event.
- This default implementation only handles setting the active curve.
- """
-
- if ddict is None:
- ddict = {}
- _logger.debug("Received dict keys = %s", str(ddict.keys()))
- _logger.debug(str(ddict))
- if ddict['event'] in ["legendClicked", "curveClicked"]:
- if ddict['button'] == "left":
- self.setActiveCurve(ddict['label'])
- qt.QToolTip.showText(self.cursor().pos(), ddict['label'])
- elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left':
- self.setActiveCurve(None)
-
- def saveGraph(self, filename, fileFormat=None, dpi=None):
- """Save a snapshot of the plot.
-
- Supported file formats depends on the backend in use.
- The following file formats are always supported: "png", "svg".
- The matplotlib backend supports more formats:
- "pdf", "ps", "eps", "tiff", "jpeg", "jpg".
-
- :param filename: Destination
- :type filename: str, StringIO or BytesIO
- :param str fileFormat: String specifying the format
- :return: False if cannot save the plot, True otherwise
- """
- if fileFormat is None:
- if not hasattr(filename, 'lower'):
- _logger.warning(
- 'saveGraph cancelled, cannot define file format.')
- return False
- else:
- fileFormat = (filename.split(".")[-1]).lower()
-
- supportedFormats = ("png", "svg", "pdf", "ps", "eps",
- "tif", "tiff", "jpeg", "jpg")
-
- if fileFormat not in supportedFormats:
- _logger.warning('Unsupported format %s', fileFormat)
- return False
- else:
- self._backend.saveGraph(filename,
- fileFormat=fileFormat,
- dpi=dpi)
- return True
-
- def getDataMargins(self):
- """Get the default data margin ratios, see :meth:`setDataMargins`.
-
- :return: The margin ratios for each side (xMin, xMax, yMin, yMax).
- :rtype: A 4-tuple of floats.
- """
- return self._defaultDataMargins
-
- def setDataMargins(self, xMinMargin=0., xMaxMargin=0.,
- yMinMargin=0., yMaxMargin=0.):
- """Set the default data margins to use in :meth:`resetZoom`.
-
- Set the default ratios of margins (as floats) to add around the data
- inside the plot area for each side.
- """
- self._defaultDataMargins = (xMinMargin, xMaxMargin,
- yMinMargin, yMaxMargin)
-
- def getAutoReplot(self):
- """Return True if replot is automatically handled, False otherwise.
-
- See :meth`setAutoReplot`.
- """
- return self._autoreplot
-
- def setAutoReplot(self, autoreplot=True):
- """Set automatic replot mode.
-
- When enabled, the plot is redrawn automatically when changed.
- When disabled, the plot is not redrawn when its content change.
- Instead, it :meth:`replot` must be called.
-
- :param bool autoreplot: True to enable it (default),
- False to disable it.
- """
- self._autoreplot = bool(autoreplot)
-
- # If the plot is dirty before enabling autoreplot,
- # then _backend.postRedisplay will never be called from _setDirtyPlot
- if self._autoreplot and self._getDirtyPlot():
- self._backend.postRedisplay()
-
- def replot(self):
- """Redraw the plot immediately."""
- for item in self._contentToUpdate:
- item._update(self._backend)
-
- self._contentToUpdate = []
- self._backend.replot()
- self._dirty = False # reset dirty flag
-
- def _forceResetZoom(self, dataMargins=None):
- """Reset the plot limits to the bounds of the data and redraw the plot.
-
- This method forces a reset zoom and does not check axis autoscale.
-
- Extra margins can be added around the data inside the plot area
- (see :meth:`setDataMargins`).
- Margins are given as one ratio of the data range per limit of the
- data (xMin, xMax, yMin and yMax limits).
- For log scale, extra margins are applied in log10 of the data.
-
- :param dataMargins: Ratios of margins to add around the data inside
- the plot area for each side (default: no margins).
- :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax).
- """
- if dataMargins is None:
- dataMargins = self._defaultDataMargins
-
- # Get data range
- ranges = self.getDataRange()
- xmin, xmax = (1., 100.) if ranges.x is None else ranges.x
- ymin, ymax = (1., 100.) if ranges.y is None else ranges.y
- if ranges.yright is None:
- ymin2, ymax2 = ymin, ymax
- else:
- ymin2, ymax2 = ranges.yright
- if ranges.y is None:
- ymin, ymax = ranges.yright
-
- # Add margins around data inside the plot area
- newLimits = list(_utils.addMarginsToLimits(
- dataMargins,
- self._xAxis._isLogarithmic(),
- self._yAxis._isLogarithmic(),
- xmin, xmax, ymin, ymax, ymin2, ymax2))
-
- if self.isKeepDataAspectRatio():
- # Use limits with margins to keep ratio
- xmin, xmax, ymin, ymax = newLimits[:4]
-
- # Compute bbox wth figure aspect ratio
- plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
- if plotWidth > 0 and plotHeight > 0:
- plotRatio = plotHeight / plotWidth
- dataRatio = (ymax - ymin) / (xmax - xmin)
- if dataRatio < plotRatio:
- # Increase y range
- ycenter = 0.5 * (ymax + ymin)
- yrange = (xmax - xmin) * plotRatio
- newLimits[2] = ycenter - 0.5 * yrange
- newLimits[3] = ycenter + 0.5 * yrange
-
- elif dataRatio > plotRatio:
- # Increase x range
- xcenter = 0.5 * (xmax + xmin)
- xrange_ = (ymax - ymin) / plotRatio
- newLimits[0] = xcenter - 0.5 * xrange_
- newLimits[1] = xcenter + 0.5 * xrange_
-
- self.setLimits(*newLimits)
-
- def resetZoom(self, dataMargins=None):
- """Reset the plot limits to the bounds of the data and redraw the plot.
-
- It automatically scale limits of axes that are in autoscale mode
- (see :meth:`getXAxis`, :meth:`getYAxis` and :meth:`Axis.setAutoScale`).
- It keeps current limits on axes that are not in autoscale mode.
-
- Extra margins can be added around the data inside the plot area
- (see :meth:`setDataMargins`).
- Margins are given as one ratio of the data range per limit of the
- data (xMin, xMax, yMin and yMax limits).
- For log scale, extra margins are applied in log10 of the data.
-
- :param dataMargins: Ratios of margins to add around the data inside
- the plot area for each side (default: no margins).
- :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax).
- """
- xLimits = self._xAxis.getLimits()
- yLimits = self._yAxis.getLimits()
- y2Limits = self._yRightAxis.getLimits()
-
- xAuto = self._xAxis.isAutoScale()
- yAuto = self._yAxis.isAutoScale()
-
- # With log axes, autoscale if limits are <= 0
- # This avoids issues with toggling log scale with matplotlib 2.1.0
- if self._xAxis.getScale() == self._xAxis.LOGARITHMIC and xLimits[0] <= 0:
- xAuto = True
- if self._yAxis.getScale() == self._yAxis.LOGARITHMIC and (yLimits[0] <= 0 or y2Limits[0] <= 0):
- yAuto = True
-
- if not xAuto and not yAuto:
- _logger.debug("Nothing to autoscale")
- else: # Some axes to autoscale
- self._forceResetZoom(dataMargins=dataMargins)
-
- # Restore limits for axis not in autoscale
- if not xAuto and yAuto:
- self.setGraphXLimits(*xLimits)
- elif xAuto and not yAuto:
- if y2Limits is not None:
- self.setGraphYLimits(
- y2Limits[0], y2Limits[1], axis='right')
- if yLimits is not None:
- self.setGraphYLimits(yLimits[0], yLimits[1], axis='left')
-
- if (xLimits != self._xAxis.getLimits() or
- yLimits != self._yAxis.getLimits() or
- y2Limits != self._yRightAxis.getLimits()):
- self._notifyLimitsChanged()
-
- # Coord conversion
-
- def dataToPixel(self, x=None, y=None, axis="left", check=True):
- """Convert a position in data coordinates to a position in pixels.
-
- :param float x: The X coordinate in data space. If None (default)
- the middle position of the displayed data is used.
- :param float y: The Y coordinate in data space. If None (default)
- the middle position of the displayed data is used.
- :param str axis: The Y axis to use for the conversion
- ('left' or 'right').
- :param bool check: True to return None if outside displayed area,
- False to convert to pixels anyway
- :returns: The corresponding position in pixels or
- None if the data position is not in the displayed area and
- check is True.
- :rtype: A tuple of 2 floats: (xPixel, yPixel) or None.
- """
- assert axis in ("left", "right")
-
- xmin, xmax = self._xAxis.getLimits()
- yAxis = self.getYAxis(axis=axis)
- ymin, ymax = yAxis.getLimits()
-
- if x is None:
- x = 0.5 * (xmax + xmin)
- if y is None:
- y = 0.5 * (ymax + ymin)
-
- if check:
- if x > xmax or x < xmin:
- return None
-
- if y > ymax or y < ymin:
- return None
-
- return self._backend.dataToPixel(x, y, axis=axis)
-
- def pixelToData(self, x, y, axis="left", check=False):
- """Convert a position in pixels to a position in data coordinates.
-
- :param float x: The X coordinate in pixels. If None (default)
- the center of the widget is used.
- :param float y: The Y coordinate in pixels. If None (default)
- the center of the widget is used.
- :param str axis: The Y axis to use for the conversion
- ('left' or 'right').
- :param bool check: Toggle checking if pixel is in plot area.
- If False, this method never returns None.
- :returns: The corresponding position in data space or
- None if the pixel position is not in the plot area.
- :rtype: A tuple of 2 floats: (xData, yData) or None.
- """
- assert axis in ("left", "right")
-
- if x is None:
- x = self.width() // 2
- if y is None:
- y = self.height() // 2
-
- if check:
- left, top, width, height = self.getPlotBoundsInPixels()
- if not (left <= x <= left + width and top <= y <= top + height):
- return None
-
- return self._backend.pixelToData(x, y, axis)
-
- def getPlotBoundsInPixels(self):
- """Plot area bounds in widget coordinates in pixels.
-
- :return: bounds as a 4-tuple of int: (left, top, width, height)
- """
- return self._backend.getPlotBoundsInPixels()
-
- # Interaction support
-
- def getGraphCursorShape(self):
- """Returns the current cursor shape.
-
- :rtype: str
- """
- return self.__graphCursorShape
-
- def setGraphCursorShape(self, cursor=None):
- """Set the cursor shape.
-
- :param str cursor: Name of the cursor shape
- """
- self.__graphCursorShape = cursor
- self._backend.setGraphCursorShape(cursor)
-
- @deprecated(replacement='getItems', since_version='0.13')
- def _getAllMarkers(self, just_legend=False):
- markers = [item for item in self.getItems() if isinstance(item, items.MarkerBase)]
- if just_legend:
- return [marker.getName() for marker in markers]
- else:
- return markers
-
- def _getMarkerAt(self, x, y):
- """Return the most interactive marker at a location, else None
-
- :param float x: X position in pixels
- :param float y: Y position in pixels
- :rtype: None of marker object
- """
- def checkDraggable(item):
- return isinstance(item, items.MarkerBase) and item.isDraggable()
- def checkSelectable(item):
- return isinstance(item, items.MarkerBase) and item.isSelectable()
- def check(item):
- return isinstance(item, items.MarkerBase)
-
- result = self._pickTopMost(x, y, checkDraggable)
- if not result:
- result = self._pickTopMost(x, y, checkSelectable)
- if not result:
- result = self._pickTopMost(x, y, check)
- marker = result.getItem() if result is not None else None
- return marker
-
- def _getMarker(self, legend=None):
- """Get the object describing a specific marker.
-
- It returns None in case no matching marker is found
-
- :param str legend: The legend of the marker to retrieve
- :rtype: None of marker object
- """
- return self._getItem(kind='marker', legend=legend)
-
- def pickItems(self, x, y, condition=None):
- """Generator of picked items in the plot at given position.
-
- Items are returned from front to back.
-
- :param float x: X position in pixels
- :param float y: Y position in pixels
- :param callable condition:
- Callable taking an item as input and returning False for items
- to skip during picking. If None (default) no item is skipped.
- :return: Iterable of :class:`PickingResult` objects at picked position.
- Items are ordered from front to back.
- """
- for item in reversed(self._backend.getItemsFromBackToFront(condition=condition)):
- result = item.pick(x, y)
- if result is not None:
- yield result
-
- def _pickTopMost(self, x, y, condition=None):
- """Returns top-most picked item in the plot at given position.
-
- Items are checked from front to back.
-
- :param float x: X position in pixels
- :param float y: Y position in pixels
- :param callable condition:
- Callable taking an item as input and returning False for items
- to skip during picking. If None (default) no item is skipped.
- :return: :class:`PickingResult` object at picked position.
- If no item is picked, it returns None
- :rtype: Union[None,PickingResult]
- """
- for result in self.pickItems(x, y, condition):
- return result
- return None
-
- # User event handling #
-
- def _isPositionInPlotArea(self, x, y):
- """Project position in pixel to the closest point in the plot area
-
- :param float x: X coordinate in widget coordinate (in pixel)
- :param float y: Y coordinate in widget coordinate (in pixel)
- :return: (x, y) in widget coord (in pixel) in the plot area
- """
- left, top, width, height = self.getPlotBoundsInPixels()
- xPlot = numpy.clip(x, left, left + width)
- yPlot = numpy.clip(y, top, top + height)
- return xPlot, yPlot
-
- def onMousePress(self, xPixel, yPixel, btn):
- """Handle mouse press event.
-
- :param float xPixel: X mouse position in pixels
- :param float yPixel: Y mouse position in pixels
- :param str btn: Mouse button in 'left', 'middle', 'right'
- """
- if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
- self._pressedButtons.append(btn)
- self._eventHandler.handleEvent('press', xPixel, yPixel, btn)
-
- def onMouseMove(self, xPixel, yPixel):
- """Handle mouse move event.
-
- :param float xPixel: X mouse position in pixels
- :param float yPixel: Y mouse position in pixels
- """
- inXPixel, inYPixel = self._isPositionInPlotArea(xPixel, yPixel)
- isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel
-
- if self._cursorInPlot != isCursorInPlot:
- self._cursorInPlot = isCursorInPlot
- self._eventHandler.handleEvent(
- 'enter' if self._cursorInPlot else 'leave')
-
- if isCursorInPlot:
- # Signal mouse move event
- dataPos = self.pixelToData(inXPixel, inYPixel)
- assert dataPos is not None
-
- btn = self._pressedButtons[-1] if self._pressedButtons else None
- event = PlotEvents.prepareMouseSignal(
- 'mouseMoved', btn, dataPos[0], dataPos[1], xPixel, yPixel)
- self.notify(**event)
-
- # Either button was pressed in the plot or cursor is in the plot
- if isCursorInPlot or self._pressedButtons:
- self._eventHandler.handleEvent('move', inXPixel, inYPixel)
-
- def onMouseRelease(self, xPixel, yPixel, btn):
- """Handle mouse release event.
-
- :param float xPixel: X mouse position in pixels
- :param float yPixel: Y mouse position in pixels
- :param str btn: Mouse button in 'left', 'middle', 'right'
- """
- try:
- self._pressedButtons.remove(btn)
- except ValueError:
- pass
- else:
- xPixel, yPixel = self._isPositionInPlotArea(xPixel, yPixel)
- self._eventHandler.handleEvent('release', xPixel, yPixel, btn)
-
- def onMouseWheel(self, xPixel, yPixel, angleInDegrees):
- """Handle mouse wheel event.
-
- :param float xPixel: X mouse position in pixels
- :param float yPixel: Y mouse position in pixels
- :param float angleInDegrees: Angle corresponding to wheel motion.
- Positive for movement away from the user,
- negative for movement toward the user.
- """
- if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
- self._eventHandler.handleEvent(
- 'wheel', xPixel, yPixel, angleInDegrees)
-
- def onMouseLeaveWidget(self):
- """Handle mouse leave widget event."""
- if self._cursorInPlot:
- self._cursorInPlot = False
- self._eventHandler.handleEvent('leave')
-
- # Interaction modes #
-
- def getInteractiveMode(self):
- """Returns the current interactive mode as a dict.
-
- The returned dict contains at least the key 'mode'.
- Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'.
- It can also contains extra keys (e.g., 'color') specific to a mode
- as provided to :meth:`setInteractiveMode`.
- """
- return self._eventHandler.getInteractiveMode()
-
- def resetInteractiveMode(self):
- """Reset the interactive mode to use the previous basic interactive
- mode used.
-
- It can be one of "zoom" or "pan".
- """
- mode, zoomOnWheel = self._previousDefaultMode
- self.setInteractiveMode(mode=mode, zoomOnWheel=zoomOnWheel)
-
- def setInteractiveMode(self, mode, color='black',
- shape='polygon', label=None,
- zoomOnWheel=True, source=None, width=None):
- """Switch the interactive mode.
-
- :param str mode: The name of the interactive mode.
- In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
- :param color: Only for 'draw' and 'zoom' modes.
- Color to use for drawing selection area. Default black.
- :type color: Color description: The name as a str or
- a tuple of 4 floats.
- :param str shape: Only for 'draw' mode. The kind of shape to draw.
- In 'polygon', 'rectangle', 'line', 'vline', 'hline',
- 'freeline'.
- Default is 'polygon'.
- :param str label: Only for 'draw' mode, sent in drawing events.
- :param bool zoomOnWheel: Toggle zoom on wheel support
- :param source: A user-defined object (typically the caller object)
- that will be send in the interactiveModeChanged event,
- to identify which object required a mode change.
- Default: None
- :param float width: Width of the pencil. Only for draw pencil mode.
- """
- self._eventHandler.setInteractiveMode(mode, color, shape, label, width)
- self._eventHandler.zoomOnWheel = zoomOnWheel
- if mode in ["pan", "zoom"]:
- self._previousDefaultMode = mode, zoomOnWheel
-
- self.notify(
- 'interactiveModeChanged', source=source)
-
- # Panning with arrow keys
-
- def isPanWithArrowKeys(self):
- """Returns whether or not panning the graph with arrow keys is enabled.
-
- See :meth:`setPanWithArrowKeys`.
- """
- return self._panWithArrowKeys
-
- def setPanWithArrowKeys(self, pan=False):
- """Enable/Disable panning the graph with arrow keys.
-
- This grabs the keyboard.
-
- :param bool pan: True to enable panning, False to disable.
- """
- pan = bool(pan)
- panHasChanged = self._panWithArrowKeys != pan
-
- self._panWithArrowKeys = pan
- if not self._panWithArrowKeys:
- self.setFocusPolicy(qt.Qt.NoFocus)
- else:
- self.setFocusPolicy(qt.Qt.StrongFocus)
- self.setFocus(qt.Qt.OtherFocusReason)
-
- if panHasChanged:
- self.sigSetPanWithArrowKeys.emit(pan)
-
- # Dict to convert Qt arrow key code to direction str.
- _ARROWS_TO_PAN_DIRECTION = {
- qt.Qt.Key_Left: 'left',
- qt.Qt.Key_Right: 'right',
- qt.Qt.Key_Up: 'up',
- qt.Qt.Key_Down: 'down'
- }
-
- def __simulateMouseMove(self):
- qapp = qt.QApplication.instance()
- event = qt.QMouseEvent(
- qt.QEvent.MouseMove,
- self.getWidgetHandle().mapFromGlobal(qt.QCursor.pos()),
- qt.Qt.NoButton,
- qapp.mouseButtons(),
- qapp.keyboardModifiers())
- qapp.sendEvent(self.getWidgetHandle(), event)
-
- def keyPressEvent(self, event):
- """Key event handler handling panning on arrow keys.
-
- Overrides base class implementation.
- """
- key = event.key()
- if self._panWithArrowKeys and key in self._ARROWS_TO_PAN_DIRECTION:
- self.pan(self._ARROWS_TO_PAN_DIRECTION[key], factor=0.1)
-
- # Send a mouse move event to the plot widget to take into account
- # that even if mouse didn't move on the screen, it moved relative
- # to the plotted data.
- self.__simulateMouseMove()
- else:
- # Only call base class implementation when key is not handled.
- # See QWidget.keyPressEvent for details.
- super(PlotWidget, self).keyPressEvent(event)
diff --git a/silx/gui/plot/PlotWindow.py b/silx/gui/plot/PlotWindow.py
deleted file mode 100644
index 3cd605f..0000000
--- a/silx/gui/plot/PlotWindow.py
+++ /dev/null
@@ -1,994 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""A :class:`.PlotWidget` with additional toolbars.
-
-The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`.
-"""
-
-__authors__ = ["V.A. Sole", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "12/04/2019"
-
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
-import logging
-import weakref
-
-import silx
-from silx.utils.weakref import WeakMethodProxy
-from silx.utils.deprecation import deprecated
-from silx.utils.proxy import docstring
-
-from . import PlotWidget
-from . import actions
-from . import items
-from .actions import medfilt as actions_medfilt
-from .actions import fit as actions_fit
-from .actions import control as actions_control
-from .actions import histogram as actions_histogram
-from . import PlotToolButtons
-from . import tools
-from .Profile import ProfileToolBar
-from .LegendSelector import LegendsDockWidget
-from .CurvesROIWidget import CurvesROIDockWidget
-from .MaskToolsWidget import MaskToolsDockWidget
-from .StatsWidget import BasicStatsWidget
-from .ColorBar import ColorBarWidget
-try:
- from ..console import IPythonDockWidget
-except ImportError:
- IPythonDockWidget = None
-
-from .. import qt
-
-
-_logger = logging.getLogger(__name__)
-
-
-class PlotWindow(PlotWidget):
- """Qt Widget providing a 1D/2D plot area and additional tools.
-
- This widgets inherits from :class:`.PlotWidget` and provides its plot API.
-
- Initialiser parameters:
-
- :param parent: The parent of this widget or None.
- :param backend: The backend to use for the plot (default: matplotlib).
- See :class:`.PlotWidget` for the list of supported backend.
- :type backend: str or :class:`BackendBase.BackendBase`
- :param bool resetzoom: Toggle visibility of reset zoom action.
- :param bool autoScale: Toggle visibility of axes autoscale actions.
- :param bool logScale: Toggle visibility of axes log scale actions.
- :param bool grid: Toggle visibility of grid mode action.
- :param bool curveStyle: Toggle visibility of curve style action.
- :param bool colormap: Toggle visibility of colormap action.
- :param bool aspectRatio: Toggle visibility of aspect ratio button.
- :param bool yInverted: Toggle visibility of Y axis direction button.
- :param bool copy: Toggle visibility of copy action.
- :param bool save: Toggle visibility of save action.
- :param bool print_: Toggle visibility of print action.
- :param bool control: True to display an Options button with a sub-menu
- to show legends, toggle crosshair and pan with arrows.
- (Default: False)
- :param position: True to display widget with (x, y) mouse position
- (Default: False).
- It also supports a list of (name, funct(x, y)->value)
- to customize the displayed values.
- See :class:`~silx.gui.plot.tools.PositionInfo`.
- :param bool roi: Toggle visibilty of ROI action.
- :param bool mask: Toggle visibilty of mask action.
- :param bool fit: Toggle visibilty of fit action.
- """
-
- def __init__(self, parent=None, backend=None,
- resetzoom=True, autoScale=True, logScale=True, grid=True,
- curveStyle=True, colormap=True,
- aspectRatio=True, yInverted=True,
- copy=True, save=True, print_=True,
- control=False, position=False,
- roi=True, mask=True, fit=False):
- super(PlotWindow, self).__init__(parent=parent, backend=backend)
- if parent is None:
- self.setWindowTitle('PlotWindow')
-
- self._dockWidgets = []
-
- # lazy loaded dock widgets
- self._legendsDockWidget = None
- self._curvesROIDockWidget = None
- self._maskToolsDockWidget = None
- self._consoleDockWidget = None
- self._statsDockWidget = None
-
- # Create color bar, hidden by default for backward compatibility
- self._colorbar = ColorBarWidget(parent=self, plot=self)
-
- # Init actions
- self.group = qt.QActionGroup(self)
- self.group.setExclusive(False)
-
- self.resetZoomAction = self.group.addAction(
- actions.control.ResetZoomAction(self, parent=self))
- self.resetZoomAction.setVisible(resetzoom)
- self.addAction(self.resetZoomAction)
-
- self.zoomInAction = actions.control.ZoomInAction(self, parent=self)
- self.addAction(self.zoomInAction)
-
- self.zoomOutAction = actions.control.ZoomOutAction(self, parent=self)
- self.addAction(self.zoomOutAction)
-
- self.xAxisAutoScaleAction = self.group.addAction(
- actions.control.XAxisAutoScaleAction(self, parent=self))
- self.xAxisAutoScaleAction.setVisible(autoScale)
- self.addAction(self.xAxisAutoScaleAction)
-
- self.yAxisAutoScaleAction = self.group.addAction(
- actions.control.YAxisAutoScaleAction(self, parent=self))
- self.yAxisAutoScaleAction.setVisible(autoScale)
- self.addAction(self.yAxisAutoScaleAction)
-
- self.xAxisLogarithmicAction = self.group.addAction(
- actions.control.XAxisLogarithmicAction(self, parent=self))
- self.xAxisLogarithmicAction.setVisible(logScale)
- self.addAction(self.xAxisLogarithmicAction)
-
- self.yAxisLogarithmicAction = self.group.addAction(
- actions.control.YAxisLogarithmicAction(self, parent=self))
- self.yAxisLogarithmicAction.setVisible(logScale)
- self.addAction(self.yAxisLogarithmicAction)
-
- self.gridAction = self.group.addAction(
- actions.control.GridAction(self, gridMode='both', parent=self))
- self.gridAction.setVisible(grid)
- self.addAction(self.gridAction)
-
- self.curveStyleAction = self.group.addAction(
- actions.control.CurveStyleAction(self, parent=self))
- self.curveStyleAction.setVisible(curveStyle)
- self.addAction(self.curveStyleAction)
-
- self.colormapAction = self.group.addAction(
- actions.control.ColormapAction(self, parent=self))
- self.colormapAction.setVisible(colormap)
- self.addAction(self.colormapAction)
-
- self.colorbarAction = self.group.addAction(
- actions_control.ColorBarAction(self, parent=self))
- self.colorbarAction.setVisible(False)
- self.addAction(self.colorbarAction)
- self._colorbar.setVisible(False)
-
- self.keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
- parent=self, plot=self)
- self.keepDataAspectRatioButton.setVisible(aspectRatio)
-
- self.yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton(
- parent=self, plot=self)
- self.yAxisInvertedButton.setVisible(yInverted)
-
- self.group.addAction(self.getRoiAction())
- self.getRoiAction().setVisible(roi)
-
- self.group.addAction(self.getMaskAction())
- self.getMaskAction().setVisible(mask)
-
- self._intensityHistoAction = self.group.addAction(
- actions_histogram.PixelIntensitiesHistoAction(self, parent=self))
- self._intensityHistoAction.setVisible(False)
-
- self._medianFilter2DAction = self.group.addAction(
- actions_medfilt.MedianFilter2DAction(self, parent=self))
- self._medianFilter2DAction.setVisible(False)
-
- self._medianFilter1DAction = self.group.addAction(
- actions_medfilt.MedianFilter1DAction(self, parent=self))
- self._medianFilter1DAction.setVisible(False)
-
- self.fitAction = self.group.addAction(actions_fit.FitAction(self, parent=self))
- self.fitAction.setVisible(fit)
- self.addAction(self.fitAction)
-
- # lazy loaded actions needed by the controlButton menu
- self._consoleAction = None
- self._statsAction = None
- self._panWithArrowKeysAction = None
- self._crosshairAction = None
-
- # Make colorbar background white
- self._colorbar.setAutoFillBackground(True)
- self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground)
- self._updateColorBarBackground()
-
- if control: # Create control button only if requested
- self.controlButton = qt.QToolButton()
- self.controlButton.setText("Options")
- self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon)
- self.controlButton.setAutoRaise(True)
- self.controlButton.setPopupMode(qt.QToolButton.InstantPopup)
- menu = qt.QMenu(self)
- menu.aboutToShow.connect(self._customControlButtonMenu)
- self.controlButton.setMenu(menu)
-
- self._positionWidget = None
- if position: # Add PositionInfo widget to the bottom of the plot
- if isinstance(position, abc.Iterable):
- # Use position as a set of converters
- converters = position
- else:
- converters = None
- self._positionWidget = tools.PositionInfo(
- plot=self, converters=converters)
- # Set a snapping mode that is consistent with legacy one
- self._positionWidget.setSnappingMode(
- tools.PositionInfo.SNAPPING_CROSSHAIR |
- tools.PositionInfo.SNAPPING_ACTIVE_ONLY |
- tools.PositionInfo.SNAPPING_SYMBOLS_ONLY |
- tools.PositionInfo.SNAPPING_CURVE |
- tools.PositionInfo.SNAPPING_SCATTER)
-
- self.__setCentralWidget()
-
- # Creating the toolbar also create actions for toolbuttons
- self._interactiveModeToolBar = tools.InteractiveModeToolBar(
- parent=self, plot=self)
- self.addToolBar(self._interactiveModeToolBar)
-
- self._toolbar = self._createToolBar(title='Plot', parent=self)
- self.addToolBar(self._toolbar)
-
- self._outputToolBar = tools.OutputToolBar(parent=self, plot=self)
- self._outputToolBar.getCopyAction().setVisible(copy)
- self._outputToolBar.getSaveAction().setVisible(save)
- self._outputToolBar.getPrintAction().setVisible(print_)
- self.addToolBar(self._outputToolBar)
-
- # Activate shortcuts in PlotWindow widget:
- for toolbar in (self._interactiveModeToolBar, self._outputToolBar):
- for action in toolbar.actions():
- self.addAction(action)
-
- def __setCentralWidget(self):
- """Set central widget to host plot backend, colorbar, and bottom bar"""
- gridLayout = qt.QGridLayout()
- gridLayout.setSpacing(0)
- gridLayout.setContentsMargins(0, 0, 0, 0)
- gridLayout.addWidget(self.getWidgetHandle(), 0, 0)
- gridLayout.addWidget(self._colorbar, 0, 1)
- gridLayout.setRowStretch(0, 1)
- gridLayout.setColumnStretch(0, 1)
- centralWidget = qt.QWidget(self)
- centralWidget.setLayout(gridLayout)
-
- if hasattr(self, "controlButton") or self._positionWidget is not None:
- hbox = qt.QHBoxLayout()
- hbox.setContentsMargins(0, 0, 0, 0)
-
- if hasattr(self, "controlButton"):
- hbox.addWidget(self.controlButton)
-
- if self._positionWidget is not None:
- hbox.addWidget(self._positionWidget)
-
- hbox.addStretch(1)
- bottomBar = qt.QWidget(centralWidget)
- bottomBar.setLayout(hbox)
-
- gridLayout.addWidget(bottomBar, 1, 0, 1, -1)
-
- self.setCentralWidget(centralWidget)
-
- @docstring(PlotWidget)
- def setBackend(self, backend):
- super(PlotWindow, self).setBackend(backend)
- self.__setCentralWidget() # Recreate PlotWindow's central widget
-
- @docstring(PlotWidget)
- def setBackgroundColor(self, color):
- super(PlotWindow, self).setBackgroundColor(color)
- self._updateColorBarBackground()
-
- @docstring(PlotWidget)
- def setDataBackgroundColor(self, color):
- super(PlotWindow, self).setDataBackgroundColor(color)
- self._updateColorBarBackground()
-
- @docstring(PlotWidget)
- def setForegroundColor(self, color):
- super(PlotWindow, self).setForegroundColor(color)
- self._updateColorBarBackground()
-
- def _updateColorBarBackground(self):
- """Update the colorbar background according to the state of the plot"""
- if self.isAxesDisplayed():
- color = self.getBackgroundColor()
- else:
- color = self.getDataBackgroundColor()
- if not color.isValid():
- # If no color defined, use the background one
- color = self.getBackgroundColor()
-
- foreground = self.getForegroundColor()
-
- palette = self._colorbar.palette()
- palette.setColor(qt.QPalette.Background, color)
- palette.setColor(qt.QPalette.Window, color)
- palette.setColor(qt.QPalette.WindowText, foreground)
- palette.setColor(qt.QPalette.Text, foreground)
- self._colorbar.setPalette(palette)
-
- def getInteractiveModeToolBar(self):
- """Returns QToolBar controlling interactive mode.
-
- :rtype: QToolBar
- """
- return self._interactiveModeToolBar
-
- def getOutputToolBar(self):
- """Returns QToolBar containing save, copy and print actions
-
- :rtype: QToolBar
- """
- return self._outputToolBar
-
- @property
- @deprecated(replacement="getPositionInfoWidget()", since_version="0.8.0")
- def positionWidget(self):
- return self.getPositionInfoWidget()
-
- def getPositionInfoWidget(self):
- """Returns the widget displaying current cursor position information
-
- :rtype: ~silx.gui.plot.tools.PositionInfo
- """
- return self._positionWidget
-
- def getSelectionMask(self):
- """Return the current mask handled by :attr:`maskToolsDockWidget`.
-
- :return: The array of the mask with dimension of the 'active' image.
- If there is no active image, an empty array is returned.
- :rtype: 2D numpy.ndarray of uint8
- """
- return self.getMaskToolsDockWidget().getSelectionMask()
-
- def setSelectionMask(self, mask):
- """Set the mask handled by :attr:`maskToolsDockWidget`.
-
- If the provided mask has not the same dimension as the 'active'
- image, it will by cropped or padded.
-
- :param mask: The array to use for the mask.
- :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
- Array of other types are converted.
- :return: True if success, False if failed
- """
- return bool(self.getMaskToolsDockWidget().setSelectionMask(mask))
-
- def _toggleConsoleVisibility(self, isChecked=False):
- """Create IPythonDockWidget if needed,
- show it or hide it."""
- # create widget if needed (first call)
- if self._consoleDockWidget is None:
- available_vars = {"plt": weakref.proxy(self)}
- banner = "The variable 'plt' is available. Use the 'whos' "
- banner += "and 'help(plt)' commands for more information.\n\n"
- self._consoleDockWidget = IPythonDockWidget(
- available_vars=available_vars,
- custom_banner=banner,
- parent=self)
- self.addTabbedDockWidget(self._consoleDockWidget)
- # self._consoleDockWidget.setVisible(True)
- self._consoleDockWidget.toggleViewAction().toggled.connect(
- self.getConsoleAction().setChecked)
-
- self._consoleDockWidget.setVisible(isChecked)
-
- def _toggleStatsVisibility(self, isChecked=False):
- self.getStatsWidget().parent().setVisible(isChecked)
-
- def _createToolBar(self, title, parent):
- """Create a QToolBar from the QAction of the PlotWindow.
-
- :param str title: The title of the QMenu
- :param qt.QWidget parent: See :class:`QToolBar`
- """
- toolbar = qt.QToolBar(title, parent)
-
- # Order widgets with actions
- objects = self.group.actions()
-
- # Add push buttons to list
- index = objects.index(self.colormapAction)
- objects.insert(index + 1, self.keepDataAspectRatioButton)
- objects.insert(index + 2, self.yAxisInvertedButton)
-
- for obj in objects:
- if isinstance(obj, qt.QAction):
- toolbar.addAction(obj)
- else:
- # Add action for toolbutton in order to allow changing
- # visibility (see doc QToolBar.addWidget doc)
- if obj is self.keepDataAspectRatioButton:
- self.keepDataAspectRatioAction = toolbar.addWidget(obj)
- elif obj is self.yAxisInvertedButton:
- self.yAxisInvertedAction = toolbar.addWidget(obj)
- else:
- raise RuntimeError()
- return toolbar
-
- def toolBar(self):
- """Return a QToolBar from the QAction of the PlotWindow.
- """
- return self._toolbar
-
- def menu(self, title='Plot', parent=None):
- """Return a QMenu from the QAction of the PlotWindow.
-
- :param str title: The title of the QMenu
- :param parent: See :class:`QMenu`
- """
- menu = qt.QMenu(title, parent)
- for action in self.group.actions():
- menu.addAction(action)
- return menu
-
- def _customControlButtonMenu(self):
- """Display Options button sub-menu."""
- controlMenu = self.controlButton.menu()
- controlMenu.clear()
- controlMenu.addAction(self.getLegendsDockWidget().toggleViewAction())
- controlMenu.addAction(self.getRoiAction())
- controlMenu.addAction(self.getStatsAction())
- controlMenu.addAction(self.getMaskAction())
- controlMenu.addAction(self.getConsoleAction())
-
- controlMenu.addSeparator()
- controlMenu.addAction(self.getCrosshairAction())
- controlMenu.addAction(self.getPanWithArrowKeysAction())
-
- def addTabbedDockWidget(self, dock_widget):
- """Add a dock widget as a new tab if there are already dock widgets
- in the plot. When the first tab is added, the area is chosen
- depending on the plot geometry:
- if the window is much wider than it is high, the right dock area
- is used, else the bottom dock area is used.
-
- :param dock_widget: Instance of :class:`QDockWidget` to be added.
- """
- if dock_widget not in self._dockWidgets:
- self._dockWidgets.append(dock_widget)
- if len(self._dockWidgets) == 1:
- # The first created dock widget must be added to a Widget area
- width = self.centralWidget().width()
- height = self.centralWidget().height()
- if width > (1.25 * height):
- area = qt.Qt.RightDockWidgetArea
- else:
- area = qt.Qt.BottomDockWidgetArea
- self.addDockWidget(area, dock_widget)
- else:
- # Other dock widgets are added as tabs to the same widget area
- self.tabifyDockWidget(self._dockWidgets[0],
- dock_widget)
-
- def removeDockWidget(self, dockwidget):
- """Removes the *dockwidget* from the main window layout and hides it.
-
- Note that the *dockwidget* is *not* deleted.
-
- :param QDockWidget dockwidget:
- """
- if dockwidget in self._dockWidgets:
- self._dockWidgets.remove(dockwidget)
- super(PlotWindow, self).removeDockWidget(dockwidget)
-
- def __handleFirstDockWidgetShow(self, visible):
- """Handle QDockWidget.visibilityChanged
-
- It calls :meth:`addTabbedDockWidget` for the `sender` widget.
- This allows to call `addTabbedDockWidget` lazily.
-
- It disconnect itself from the signal once done.
-
- :param bool visible:
- """
- if visible:
- dockWidget = self.sender()
- dockWidget.visibilityChanged.disconnect(
- self.__handleFirstDockWidgetShow)
- self.addTabbedDockWidget(dockWidget)
-
- def getColorBarWidget(self):
- """Returns the embedded :class:`ColorBarWidget` widget.
-
- :rtype: ColorBarWidget
- """
- return self._colorbar
-
- # getters for dock widgets
-
- def getLegendsDockWidget(self):
- """DockWidget with Legend panel"""
- if self._legendsDockWidget is None:
- self._legendsDockWidget = LegendsDockWidget(plot=self)
- self._legendsDockWidget.hide()
- self._legendsDockWidget.visibilityChanged.connect(
- self.__handleFirstDockWidgetShow)
- return self._legendsDockWidget
-
- def getCurvesRoiDockWidget(self):
- # Undocumented for a "soft deprecation" in version 0.7.0
- # (still used internally for lazy loading)
- if self._curvesROIDockWidget is None:
- self._curvesROIDockWidget = CurvesROIDockWidget(
- plot=self, name='Regions Of Interest')
- self._curvesROIDockWidget.hide()
- self._curvesROIDockWidget.visibilityChanged.connect(
- self.__handleFirstDockWidgetShow)
- return self._curvesROIDockWidget
-
- def getCurvesRoiWidget(self):
- """Return the :class:`CurvesROIWidget`.
-
- :class:`silx.gui.plot.CurvesROIWidget.CurvesROIWidget` offers a getter
- and a setter for the ROI data:
-
- - :meth:`CurvesROIWidget.getRois`
- - :meth:`CurvesROIWidget.setRois`
- """
- return self.getCurvesRoiDockWidget().roiWidget
-
- def getMaskToolsDockWidget(self):
- """DockWidget with image mask panel (lazy-loaded)."""
- if self._maskToolsDockWidget is None:
- self._maskToolsDockWidget = MaskToolsDockWidget(
- plot=self, name='Mask')
- self._maskToolsDockWidget.hide()
- self._maskToolsDockWidget.visibilityChanged.connect(
- self.__handleFirstDockWidgetShow)
- return self._maskToolsDockWidget
-
- def getStatsWidget(self):
- """Returns a BasicStatsWidget connected to this plot
-
- :rtype: BasicStatsWidget
- """
- if self._statsDockWidget is None:
- self._statsDockWidget = qt.QDockWidget()
- self._statsDockWidget.setWindowTitle("Curves stats")
- self._statsDockWidget.layout().setContentsMargins(0, 0, 0, 0)
- statsWidget = BasicStatsWidget(parent=self, plot=self)
- self._statsDockWidget.setWidget(statsWidget)
- statsWidget.sigVisibilityChanged.connect(
- self.getStatsAction().setChecked)
- self._statsDockWidget.hide()
- self._statsDockWidget.visibilityChanged.connect(
- self.__handleFirstDockWidgetShow)
- return self._statsDockWidget.widget()
-
- # getters for actions
- @property
- @deprecated(replacement="getInteractiveModeToolBar().getZoomModeAction()",
- since_version="0.8.0")
- def zoomModeAction(self):
- return self.getInteractiveModeToolBar().getZoomModeAction()
-
- @property
- @deprecated(replacement="getInteractiveModeToolBar().getPanModeAction()",
- since_version="0.8.0")
- def panModeAction(self):
- return self.getInteractiveModeToolBar().getPanModeAction()
-
- def getConsoleAction(self):
- """QAction handling the IPython console activation.
-
- By default, it is connected to a method that initializes the
- console widget the first time the user clicks the "Console" menu
- button. The following clicks, after initialization is done,
- will toggle the visibility of the console widget.
-
- :rtype: QAction
- """
- if self._consoleAction is None:
- self._consoleAction = qt.QAction('Console', self)
- self._consoleAction.setCheckable(True)
- if IPythonDockWidget is not None:
- self._consoleAction.toggled.connect(self._toggleConsoleVisibility)
- else:
- self._consoleAction.setEnabled(False)
- return self._consoleAction
-
- def getCrosshairAction(self):
- """Action toggling crosshair cursor mode.
-
- :rtype: actions.PlotAction
- """
- if self._crosshairAction is None:
- self._crosshairAction = actions.control.CrosshairAction(self, color='red')
- return self._crosshairAction
-
- def getMaskAction(self):
- """QAction toggling image mask dock widget
-
- :rtype: QAction
- """
- return self.getMaskToolsDockWidget().toggleViewAction()
-
- def getPanWithArrowKeysAction(self):
- """Action toggling pan with arrow keys.
-
- :rtype: actions.PlotAction
- """
- if self._panWithArrowKeysAction is None:
- self._panWithArrowKeysAction = actions.control.PanWithArrowKeysAction(self)
- return self._panWithArrowKeysAction
-
- def getStatsAction(self):
- if self._statsAction is None:
- self._statsAction = qt.QAction('Curves stats', self)
- self._statsAction.setCheckable(True)
- self._statsAction.setChecked(self.getStatsWidget().parent().isVisible())
- self._statsAction.toggled.connect(self._toggleStatsVisibility)
- return self._statsAction
-
- def getRoiAction(self):
- """QAction toggling curve ROI dock widget
-
- :rtype: QAction
- """
- return self.getCurvesRoiDockWidget().toggleViewAction()
-
- def getResetZoomAction(self):
- """Action resetting the zoom
-
- :rtype: actions.PlotAction
- """
- return self.resetZoomAction
-
- def getZoomInAction(self):
- """Action to zoom in
-
- :rtype: actions.PlotAction
- """
- return self.zoomInAction
-
- def getZoomOutAction(self):
- """Action to zoom out
-
- :rtype: actions.PlotAction
- """
- return self.zoomOutAction
-
- def getXAxisAutoScaleAction(self):
- """Action to toggle the X axis autoscale on zoom reset
-
- :rtype: actions.PlotAction
- """
- return self.xAxisAutoScaleAction
-
- def getYAxisAutoScaleAction(self):
- """Action to toggle the Y axis autoscale on zoom reset
-
- :rtype: actions.PlotAction
- """
- return self.yAxisAutoScaleAction
-
- def getXAxisLogarithmicAction(self):
- """Action to toggle logarithmic X axis
-
- :rtype: actions.PlotAction
- """
- return self.xAxisLogarithmicAction
-
- def getYAxisLogarithmicAction(self):
- """Action to toggle logarithmic Y axis
-
- :rtype: actions.PlotAction
- """
- return self.yAxisLogarithmicAction
-
- def getGridAction(self):
- """Action to toggle the grid visibility in the plot
-
- :rtype: actions.PlotAction
- """
- return self.gridAction
-
- def getCurveStyleAction(self):
- """Action to change curve line and markers styles
-
- :rtype: actions.PlotAction
- """
- return self.curveStyleAction
-
- def getColormapAction(self):
- """Action open a colormap dialog to change active image
- and default colormap.
-
- :rtype: actions.PlotAction
- """
- return self.colormapAction
-
- def getKeepDataAspectRatioButton(self):
- """Button to toggle aspect ratio preservation
-
- :rtype: PlotToolButtons.AspectToolButton
- """
- return self.keepDataAspectRatioButton
-
- def getKeepDataAspectRatioAction(self):
- """Action associated to keepDataAspectRatioButton.
- Use this to change the visibility of keepDataAspectRatioButton in the
- toolbar (See :meth:`QToolBar.addWidget` documentation).
-
- :rtype: actions.PlotAction
- """
- return self.keepDataAspectRatioButton
-
- def getYAxisInvertedButton(self):
- """Button to switch the Y axis orientation
-
- :rtype: PlotToolButtons.YAxisOriginToolButton
- """
- return self.yAxisInvertedButton
-
- def getYAxisInvertedAction(self):
- """Action associated to yAxisInvertedButton.
- Use this to change the visibility yAxisInvertedButton in the toolbar.
- (See :meth:`QToolBar.addWidget` documentation).
-
- :rtype: actions.PlotAction
- """
- return self.yAxisInvertedAction
-
- def getIntensityHistogramAction(self):
- """Action toggling the histogram intensity Plot widget
-
- :rtype: actions.PlotAction
- """
- return self._intensityHistoAction
-
- def getCopyAction(self):
- """Action to copy plot snapshot to clipboard
-
- :rtype: actions.PlotAction
- """
- return self.getOutputToolBar().getCopyAction()
-
- def getSaveAction(self):
- """Action to save plot
-
- :rtype: actions.PlotAction
- """
- return self.getOutputToolBar().getSaveAction()
-
- def getPrintAction(self):
- """Action to print plot
-
- :rtype: actions.PlotAction
- """
- return self.getOutputToolBar().getPrintAction()
-
- def getFitAction(self):
- """Action to fit selected curve
-
- :rtype: actions.PlotAction
- """
- return self.fitAction
-
- def getMedianFilter1DAction(self):
- """Action toggling the 1D median filter
-
- :rtype: actions.PlotAction
- """
- return self._medianFilter1DAction
-
- def getMedianFilter2DAction(self):
- """Action toggling the 2D median filter
-
- :rtype: actions.PlotAction
- """
- return self._medianFilter2DAction
-
- def getColorBarAction(self):
- """Action toggling the colorbar show/hide action
-
- .. warning:: to show/hide the plot colorbar call directly the ColorBar
- widget using getColorBarWidget()
-
- :rtype: actions.PlotAction
- """
- return self.colorbarAction
-
-
-class Plot1D(PlotWindow):
- """PlotWindow with tools specific for curves.
-
- This widgets provides the plot API of :class:`.PlotWidget`.
-
- :param parent: The parent of this widget
- :param backend: The backend to use for the plot (default: matplotlib).
- See :class:`.PlotWidget` for the list of supported backend.
- :type backend: str or :class:`BackendBase.BackendBase`
- """
-
- def __init__(self, parent=None, backend=None):
- super(Plot1D, self).__init__(parent=parent, backend=backend,
- resetzoom=True, autoScale=True,
- logScale=True, grid=True,
- curveStyle=True, colormap=False,
- aspectRatio=False, yInverted=False,
- copy=True, save=True, print_=True,
- control=True, position=True,
- roi=True, mask=False, fit=True)
- if parent is None:
- self.setWindowTitle('Plot1D')
- self.getXAxis().setLabel('X')
- self.getYAxis().setLabel('Y')
- action = self.getFitAction()
- action.setXRangeUpdatedOnZoom(True)
- action.setFittedItemUpdatedFromActiveCurve(True)
-
-
-class Plot2D(PlotWindow):
- """PlotWindow with a toolbar specific for images.
-
- This widgets provides the plot API of :~:`.PlotWidget`.
-
- :param parent: The parent of this widget
- :param backend: The backend to use for the plot (default: matplotlib).
- See :class:`.PlotWidget` for the list of supported backend.
- :type backend: str or :class:`BackendBase.BackendBase`
- """
-
- def __init__(self, parent=None, backend=None):
- # List of information to display at the bottom of the plot
- posInfo = [
- ('X', lambda x, y: x),
- ('Y', lambda x, y: y),
- ('Data', WeakMethodProxy(self._getImageValue)),
- ('Dims', WeakMethodProxy(self._getImageDims)),
- ]
-
- super(Plot2D, self).__init__(parent=parent, backend=backend,
- resetzoom=True, autoScale=False,
- logScale=False, grid=False,
- curveStyle=False, colormap=True,
- aspectRatio=True, yInverted=True,
- copy=True, save=True, print_=True,
- control=False, position=posInfo,
- roi=False, mask=True)
- if parent is None:
- self.setWindowTitle('Plot2D')
- self.getXAxis().setLabel('Columns')
- self.getYAxis().setLabel('Rows')
-
- if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
- self.getYAxis().setInverted(True)
-
- self.profile = ProfileToolBar(plot=self)
- self.addToolBar(self.profile)
-
- self.colorbarAction.setVisible(True)
- self.getColorBarWidget().setVisible(True)
-
- # Put colorbar action after colormap action
- actions = self.toolBar().actions()
- for action in actions:
- if action is self.getColormapAction():
- break
-
- self.sigActiveImageChanged.connect(self.__activeImageChanged)
-
- def __activeImageChanged(self, previous, legend):
- """Handle change of active image
-
- :param Union[str,None] previous: Legend of previous active image
- :param Union[str,None] legend: Legend of current active image
- """
- if previous is not None:
- item = self.getImage(previous)
- if item is not None:
- item.sigItemChanged.disconnect(self.__imageChanged)
-
- if legend is not None:
- item = self.getImage(legend)
- item.sigItemChanged.connect(self.__imageChanged)
-
- positionInfo = self.getPositionInfoWidget()
- if positionInfo is not None:
- positionInfo.updateInfo()
-
- def __imageChanged(self, event):
- """Handle update of active image item
-
- :param event: Type of changed event
- """
- if event == items.ItemChangedType.DATA:
- positionInfo = self.getPositionInfoWidget()
- if positionInfo is not None:
- positionInfo.updateInfo()
-
- def _getImageValue(self, x, y):
- """Get status bar value of top most image at position (x, y)
-
- :param float x: X position in plot coordinates
- :param float y: Y position in plot coordinates
- :return: The value at that point or '-'
- """
- pickedMask = None
- for picked in self.pickItems(
- *self.dataToPixel(x, y, check=False),
- lambda item: isinstance(item, items.ImageBase)):
- if isinstance(picked.getItem(), items.MaskImageData):
- if pickedMask is None: # Use top-most if many masks
- pickedMask = picked
- else:
- image = picked.getItem()
-
- indices = picked.getIndices(copy=False)
- if indices is not None:
- row, col = indices[0][0], indices[1][0]
- value = image.getData(copy=False)[row, col]
-
- if pickedMask is not None: # Check if masked
- maskItem = pickedMask.getItem()
- indices = pickedMask.getIndices()
- row, col = indices[0][0], indices[1][0]
- if maskItem.getData(copy=False)[row, col] != 0:
- return value, "Masked"
- return value
-
- return '-' # No image picked
-
- def _getImageDims(self, *args):
- activeImage = self.getActiveImage()
- if (activeImage is not None and
- activeImage.getData(copy=False) is not None):
- dims = activeImage.getData(copy=False).shape[1::-1]
- return 'x'.join(str(dim) for dim in dims)
- else:
- return '-'
-
- def getProfileToolbar(self):
- """Profile tools attached to this plot
-
- See :class:`silx.gui.plot.Profile.ProfileToolBar`
- """
- return self.profile
-
- @deprecated(replacement="getProfilePlot", since_version="0.5.0")
- def getProfileWindow(self):
- return self.getProfilePlot()
-
- def getProfilePlot(self):
- """Return plot window used to display profile curve.
-
- :return: :class:`Plot1D`
- """
- return self.profile.getProfilePlot()
diff --git a/silx/gui/plot/PrintPreviewToolButton.py b/silx/gui/plot/PrintPreviewToolButton.py
deleted file mode 100644
index d857c18..0000000
--- a/silx/gui/plot/PrintPreviewToolButton.py
+++ /dev/null
@@ -1,392 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This modules provides tool buttons to send the content of a plot to a
-print preview page.
-The plot content can then be moved on the page and resized prior to printing.
-
-Classes
--------
-
-- :class:`PrintPreviewToolButton`
-- :class:`SingletonPrintPreviewToolButton`
-
-Examples
---------
-
-Simple example
-++++++++++++++
-
-.. code-block:: python
-
- from silx.gui import qt
- from silx.gui.plot import PlotWidget
- from silx.gui.plot.PrintPreviewToolButton import PrintPreviewToolButton
- import numpy
-
- app = qt.QApplication([])
-
- pw = PlotWidget()
- toolbar = qt.QToolBar(pw)
- toolbutton = PrintPreviewToolButton(parent=toolbar, plot=pw)
- pw.addToolBar(toolbar)
- toolbar.addWidget(toolbutton)
- pw.show()
-
- x = numpy.arange(1000)
- y = x / numpy.sin(x)
- pw.addCurve(x, y)
-
- app.exec_()
-
-Singleton example
-+++++++++++++++++
-
-This example illustrates how to print the content of several different
-plots on the same page. The plots all instantiate a
-:class:`SingletonPrintPreviewToolButton`, which relies on a singleton widget
-(:class:`silx.gui.widgets.PrintPreview.SingletonPrintPreviewDialog`).
-
-.. image:: img/printPreviewMultiPlot.png
-
-.. code-block:: python
-
- from silx.gui import qt
- from silx.gui.plot import PlotWidget
- from silx.gui.plot.PrintPreviewToolButton import SingletonPrintPreviewToolButton
- import numpy
-
- app = qt.QApplication([])
-
- plot_widgets = []
-
- for i in range(3):
- pw = PlotWidget()
- toolbar = qt.QToolBar(pw)
- toolbutton = SingletonPrintPreviewToolButton(parent=toolbar,
- plot=pw)
- pw.addToolBar(toolbar)
- toolbar.addWidget(toolbutton)
- pw.show()
- plot_widgets.append(pw)
-
- x = numpy.arange(1000)
-
- plot_widgets[0].addCurve(x, numpy.sin(x * 2 * numpy.pi / 1000))
- plot_widgets[1].addCurve(x, numpy.cos(x * 2 * numpy.pi / 1000))
- plot_widgets[2].addCurve(x, numpy.tan(x * 2 * numpy.pi / 1000))
-
- app.exec_()
-
-"""
-from __future__ import absolute_import
-
-import logging
-from io import StringIO
-
-from .. import qt
-from .. import icons
-from . import PlotWidget
-from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog
-from ..widgets.PrintGeometryDialog import PrintGeometryDialog
-from silx.utils.deprecation import deprecated
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "20/12/2018"
-
-_logger = logging.getLogger(__name__)
-# _logger.setLevel(logging.DEBUG)
-
-
-class PrintPreviewToolButton(qt.QToolButton):
- """QToolButton to open a :class:`PrintPreviewDialog` (if not already open)
- and add the current plot to its page to be printed.
-
- :param parent: See :class:`QAction`
- :param plot: :class:`.PlotWidget` instance on which to operate
- """
- def __init__(self, parent=None, plot=None):
- super(PrintPreviewToolButton, self).__init__(parent)
-
- if not isinstance(plot, PlotWidget):
- raise TypeError("plot parameter must be a PlotWidget")
- self._plot = plot
-
- self.setIcon(icons.getQIcon('document-print'))
-
- printGeomAction = qt.QAction("Print geometry", self)
- printGeomAction.setToolTip("Define a print geometry prior to sending "
- "the plot to the print preview dialog")
- printGeomAction.setIcon(icons.getQIcon('shape-rectangle'))
- printGeomAction.triggered.connect(self._setPrintConfiguration)
-
- printPreviewAction = qt.QAction("Print preview", self)
- printPreviewAction.setToolTip("Send plot to the print preview dialog")
- printPreviewAction.setIcon(icons.getQIcon('document-print'))
- printPreviewAction.triggered.connect(self._plotToPrintPreview)
-
- menu = qt.QMenu(self)
- menu.addAction(printGeomAction)
- menu.addAction(printPreviewAction)
- self.setMenu(menu)
- self.setPopupMode(qt.QToolButton.InstantPopup)
-
- self._printPreviewDialog = None
- self._printConfigurationDialog = None
-
- self._printGeometry = {"xOffset": 0.1,
- "yOffset": 0.1,
- "width": 0.9,
- "height": 0.9,
- "units": "page",
- "keepAspectRatio": True}
-
- @property
- def printPreviewDialog(self):
- """Lazy loaded :class:`PrintPreviewDialog`"""
- # if changes are made here, don't forget making them in
- # SingletonPrintPreviewToolButton.printPreviewDialog as well
- if self._printPreviewDialog is None:
- self._printPreviewDialog = PrintPreviewDialog(self.parent())
- return self._printPreviewDialog
-
- def getTitle(self):
- """Implement this method to fetch the title in the plot.
-
- :return: Title to be printed above the plot, or None (no title added)
- :rtype: str or None
- """
- return None
-
- def getCommentAndPosition(self):
- """Implement this method to fetch the legend to be printed below the
- figure and its position.
-
- :return: Legend to be printed below the figure and its position:
- "CENTER", "LEFT" or "RIGHT"
- :rtype: (str, str) or (None, None)
- """
- return None, None
-
- @property
- @deprecated(since_version="0.10",
- replacement="getPlot()")
- def plot(self):
- return self._plot
-
- def getPlot(self):
- """Return the :class:`.PlotWidget` associated with this tool button.
-
- :rtype: :class:`.PlotWidget`
- """
- return self._plot
-
- def _plotToPrintPreview(self):
- """Grab the plot widget and send it to the print preview dialog.
- Make sure the print preview dialog is shown and raised."""
- if not self.printPreviewDialog.ensurePrinterIsSet():
- return
-
- comment, commentPosition = self.getCommentAndPosition()
-
- if qt.HAS_SVG:
- svgRenderer, viewBox = self._getSvgRendererAndViewbox()
- self.printPreviewDialog.addSvgItem(svgRenderer,
- title=self.getTitle(),
- comment=comment,
- commentPosition=commentPosition,
- viewBox=viewBox,
- keepRatio=self._printGeometry["keepAspectRatio"])
- else:
- _logger.warning("Missing QtSvg library, using a raster image")
- if qt.BINDING in ["PyQt4", "PySide"]:
- pixmap = qt.QPixmap.grabWidget(self._plot.centralWidget())
- else:
- # PyQt5 and hopefully PyQt6+
- pixmap = self._plot.centralWidget().grab()
- self.printPreviewDialog.addPixmap(pixmap,
- title=self.getTitle(),
- comment=comment,
- commentPosition=commentPosition)
- self.printPreviewDialog.show()
- self.printPreviewDialog.raise_()
-
- def _getSvgRendererAndViewbox(self):
- """Return a SVG renderer displaying the plot and its viewbox
- (interactively specified by the user the first time this is called).
-
- The size of the renderer is adjusted to the printer configuration
- and to the geometry configuration (width, height, ratio) specified
- by the user."""
- imgData = StringIO()
- assert self._plot.saveGraph(imgData, fileFormat="svg"), \
- "Unable to save graph"
- imgData.flush()
- imgData.seek(0)
- svgData = imgData.read()
-
- svgRenderer = qt.QSvgRenderer()
-
- viewbox = self._getViewBox()
-
- svgRenderer.setViewBox(viewbox)
-
- xml_stream = qt.QXmlStreamReader(svgData.encode(errors="replace"))
-
- # This is for PyMca compatibility, to share a print preview with PyMca plots
- svgRenderer._viewBox = viewbox
- svgRenderer._svgRawData = svgData.encode(errors="replace")
- svgRenderer._svgRendererData = xml_stream
-
- if not svgRenderer.load(xml_stream):
- raise RuntimeError("Cannot interpret svg data")
-
- return svgRenderer, viewbox
-
- def _getViewBox(self):
- """
- """
- printer = self.printPreviewDialog.printer
- dpix = printer.logicalDpiX()
- dpiy = printer.logicalDpiY()
- availableWidth = printer.width()
- availableHeight = printer.height()
-
- config = self._printGeometry
- width = config['width']
- height = config['height']
- xOffset = config['xOffset']
- yOffset = config['yOffset']
- units = config['units']
- keepAspectRatio = config['keepAspectRatio']
- aspectRatio = self._getPlotAspectRatio()
-
- # convert the offsets to dots
- if units.lower() in ['inch', 'inches']:
- xOffset = xOffset * dpix
- yOffset = yOffset * dpiy
- if width is not None:
- width = width * dpix
- if height is not None:
- height = height * dpiy
- elif units.lower() in ['cm', 'centimeters']:
- xOffset = (xOffset / 2.54) * dpix
- yOffset = (yOffset / 2.54) * dpiy
- if width is not None:
- width = (width / 2.54) * dpix
- if height is not None:
- height = (height / 2.54) * dpiy
- else:
- # page units
- xOffset = availableWidth * xOffset
- yOffset = availableHeight * yOffset
- if width is not None:
- width = availableWidth * width
- if height is not None:
- height = availableHeight * height
-
- availableWidth -= xOffset
- availableHeight -= yOffset
-
- if width is not None:
- if (availableWidth + 0.1) < width:
- txt = "Available width %f is less than requested width %f" % \
- (availableWidth, width)
- raise ValueError(txt)
- if height is not None:
- if (availableHeight + 0.1) < height:
- txt = "Available height %f is less than requested height %f" % \
- (availableHeight, height)
- raise ValueError(txt)
-
- if keepAspectRatio:
- bodyWidth = width or availableWidth
- bodyHeight = bodyWidth * aspectRatio
-
- if bodyHeight > availableHeight:
- bodyHeight = availableHeight
- bodyWidth = bodyHeight / aspectRatio
-
- else:
- bodyWidth = width or availableWidth
- bodyHeight = height or availableHeight
-
- return qt.QRectF(xOffset,
- yOffset,
- bodyWidth,
- bodyHeight)
-
- def _setPrintConfiguration(self):
- """Open a dialog to prompt the user to adjust print
- geometry parameters."""
- self.printPreviewDialog.ensurePrinterIsSet()
- if self._printConfigurationDialog is None:
- self._printConfigurationDialog = PrintGeometryDialog(self.parent())
-
- self._printConfigurationDialog.setPrintGeometry(self._printGeometry)
- if self._printConfigurationDialog.exec_():
- self._printGeometry = self._printConfigurationDialog.getPrintGeometry()
-
- def _getPlotAspectRatio(self):
- widget = self._plot.centralWidget()
- graphWidth = float(widget.width())
- graphHeight = float(widget.height())
- return graphHeight / graphWidth
-
-
-class SingletonPrintPreviewToolButton(PrintPreviewToolButton):
- """This class is similar to its parent class :class:`PrintPreviewToolButton`
- but it uses a singleton print preview widget.
-
- This allows for several plots to send their content to the
- same print page, and for users to arrange them."""
- def __init__(self, parent=None, plot=None):
- PrintPreviewToolButton.__init__(self, parent, plot)
-
- @property
- def printPreviewDialog(self):
- if self._printPreviewDialog is None:
- self._printPreviewDialog = SingletonPrintPreviewDialog(self.parent())
- return self._printPreviewDialog
-
-
-if __name__ == '__main__':
- import numpy
- app = qt.QApplication([])
-
- pw = PlotWidget()
- toolbar = qt.QToolBar(pw)
- toolbutton = PrintPreviewToolButton(parent=toolbar,
- plot=pw)
- pw.addToolBar(toolbar)
- toolbar.addWidget(toolbutton)
- pw.show()
-
- x = numpy.arange(1000)
- y = x / numpy.sin(x)
- pw.addCurve(x, y)
-
- app.exec_()
diff --git a/silx/gui/plot/ROIStatsWidget.py b/silx/gui/plot/ROIStatsWidget.py
deleted file mode 100644
index 094d66a..0000000
--- a/silx/gui/plot/ROIStatsWidget.py
+++ /dev/null
@@ -1,780 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides widget for displaying statistics relative to a
-Region of interest and an item
-"""
-
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "22/07/2019"
-
-
-from contextlib import contextmanager
-from silx.gui import qt
-from silx.gui import icons
-from silx.gui.plot.StatsWidget import _StatsWidgetBase, StatsTable, _Container
-from silx.gui.plot.StatsWidget import UpdateModeWidget, UpdateMode
-from silx.gui.widgets.TableWidget import TableWidget
-from silx.gui.plot.items.roi import RegionOfInterest
-from silx.gui.plot import items as plotitems
-from silx.gui.plot.items.core import ItemChangedType
-from silx.gui.plot3d import items as plot3ditems
-from silx.gui.plot.CurvesROIWidget import ROI
-from silx.gui.plot import stats as statsmdl
-from collections import OrderedDict
-from silx.utils.proxy import docstring
-import silx.gui.plot.items.marker
-import silx.gui.plot.items.shape
-import functools
-import logging
-
-_logger = logging.getLogger(__name__)
-
-
-class _GetROIItemCoupleDialog(qt.QDialog):
- """
- Dialog used to know which plot item and which roi he wants
- """
- _COMPATIBLE_KINDS = ('curve', 'image', 'scatter', 'histogram')
-
- def __init__(self, parent=None, plot=None, rois=None):
- qt.QDialog.__init__(self, parent=parent)
- assert plot is not None
- assert rois is not None
- self._plot = plot
- self._rois = rois
-
- self.setLayout(qt.QVBoxLayout())
-
- # define the selection widget
- self._selection_widget = qt.QWidget()
- self._selection_widget.setLayout(qt.QHBoxLayout())
- self._kindCB = qt.QComboBox(parent=self)
- self._selection_widget.layout().addWidget(self._kindCB)
- self._itemCB = qt.QComboBox(parent=self)
- self._selection_widget.layout().addWidget(self._itemCB)
- self._roiCB = qt.QComboBox(parent=self)
- self._selection_widget.layout().addWidget(self._roiCB)
- self.layout().addWidget(self._selection_widget)
-
- # define modal buttons
- types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
- self._buttonsModal = qt.QDialogButtonBox(parent=self)
- self._buttonsModal.setStandardButtons(types)
- self.layout().addWidget(self._buttonsModal)
- self._buttonsModal.accepted.connect(self.accept)
- self._buttonsModal.rejected.connect(self.reject)
-
- # connect signal / slot
- self._kindCB.currentIndexChanged.connect(self._updateValidItemAndRoi)
-
- def _getCompatibleRois(self, kind):
- """Return compatible rois for the given item kind"""
- def is_compatible(roi, kind):
- if isinstance(roi, RegionOfInterest):
- return kind in ('image', 'scatter')
- elif isinstance(roi, ROI):
- return kind in ('curve', 'histogram')
- else:
- raise ValueError('kind not managed')
- return list(filter(lambda x: is_compatible(x, kind), self._rois))
-
- def exec_(self):
- self._kindCB.clear()
- self._itemCB.clear()
- # filter kind without any items
- self._valid_kinds = {}
- # key is item type, value kinds
- self._valid_rois = {}
- # key is item type, value rois
- self._kind_name_to_roi = {}
- # key is (kind, roi name) value is roi
- self._kind_name_to_item = {}
- # key is (kind, legend name) value is item
- for kind in _GetROIItemCoupleDialog._COMPATIBLE_KINDS:
- def getItems(kind):
- output = []
- for item in self._plot.getItems():
- type_ = self._plot._itemKind(item)
- if type_ in kind and item.isVisible():
- output.append(item)
- return output
-
- items = getItems(kind=kind)
- rois = self._getCompatibleRois(kind=kind)
- if len(items) > 0 and len(rois) > 0:
- self._valid_kinds[kind] = items
- self._valid_rois[kind] = rois
- for roi in rois:
- name = roi.getName()
- self._kind_name_to_roi[(kind, name)] = roi
- for item in items:
- self._kind_name_to_item[(kind, item.getLegend())] = item
-
- # filter roi according to kinds
- if len(self._valid_kinds) == 0:
- _logger.warning('no couple item/roi detected for displaying stats')
- return self.reject()
-
- for kind in self._valid_kinds:
- self._kindCB.addItem(kind)
- self._updateValidItemAndRoi()
-
- return qt.QDialog.exec_(self)
-
- def _updateValidItemAndRoi(self, *args, **kwargs):
- self._itemCB.clear()
- self._roiCB.clear()
- kind = self._kindCB.currentText()
- for roi in self._valid_rois[kind]:
- self._roiCB.addItem(roi.getName())
- for item in self._valid_kinds[kind]:
- self._itemCB.addItem(item.getLegend())
-
- def getROI(self):
- kind = self._kindCB.currentText()
- roi_name = self._roiCB.currentText()
- return self._kind_name_to_roi[(kind, roi_name)]
-
- def getItem(self):
- kind = self._kindCB.currentText()
- item_name = self._itemCB.currentText()
- return self._kind_name_to_item[(kind, item_name)]
-
-
-class ROIStatsItemHelper(object):
- """Item utils to associate a plot item and a roi
-
- Display on one row statistics regarding the couple
- (Item (plot item) / roi).
-
- :param Item plot_item: item for which we want statistics
- :param Union[ROI,RegionOfInterest]: region of interest to use for
- statistics.
- """
- def __init__(self, plot_item, roi):
- self._plot_item = plot_item
- self._roi = roi
-
- @property
- def roi(self):
- """roi"""
- return self._roi
-
- def roi_name(self):
- if isinstance(self._roi, ROI):
- return self._roi.getName()
- elif isinstance(self._roi, RegionOfInterest):
- return self._roi.getName()
- else:
- raise TypeError('Unmanaged roi type')
-
- @property
- def roi_kind(self):
- """roi class"""
- return self._roi.__class__
-
- # TODO: should call a util function from the wrapper ?
- def item_kind(self):
- """item kind"""
- if isinstance(self._plot_item, plotitems.Curve):
- return 'curve'
- elif isinstance(self._plot_item, plotitems.ImageData):
- return 'image'
- elif isinstance(self._plot_item, plotitems.Scatter):
- return 'scatter'
- elif isinstance(self._plot_item, plotitems.Histogram):
- return 'histogram'
- elif isinstance(self._plot_item, (plot3ditems.ImageData,
- plot3ditems.ScalarField3D)):
- return 'image'
- elif isinstance(self._plot_item, (plot3ditems.Scatter2D,
- plot3ditems.Scatter3D)):
- return 'scatter'
-
- @property
- def item_legend(self):
- """legend of the plot Item"""
- return self._plot_item.getLegend()
-
- def id_key(self):
- """unique key to represent the couple (item, roi)"""
- return (self.item_kind(), self.item_legend, self.roi_kind,
- self.roi_name())
-
-
-class _StatsROITable(_StatsWidgetBase, TableWidget):
- """
- Table sued to display some statistics regarding a couple (item/roi)
- """
- _LEGEND_HEADER_DATA = 'legend'
-
- _KIND_HEADER_DATA = 'kind'
-
- _ROI_HEADER_DATA = 'roi'
-
- sigUpdateModeChanged = qt.Signal(object)
- """Signal emitted when the update mode changed"""
-
- def __init__(self, parent, plot):
- TableWidget.__init__(self, parent)
- _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
- displayOnlyActItem=False)
- self.__region_edition_callback = {}
- """We need to keep trace of the roi signals connection because
- the roi emits the sigChanged during roi edition"""
- self._items = {}
- self.setRowCount(0)
- self.setColumnCount(3)
-
- # Init headers
- headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
- headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
- self.setHorizontalHeaderItem(0, headerItem)
- headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
- headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
- self.setHorizontalHeaderItem(1, headerItem)
- headerItem = qt.QTableWidgetItem(self._ROI_HEADER_DATA.title())
- headerItem.setData(qt.Qt.UserRole, self._ROI_HEADER_DATA)
- self.setHorizontalHeaderItem(2, headerItem)
-
- self.setSortingEnabled(True)
- self.setPlot(plot)
-
- self.__plotItemToItems = {}
- """Key is plotItem, values is list of __RoiStatsItemWidget"""
- self.__roiToItems = {}
- """Key is roi, values is list of __RoiStatsItemWidget"""
- self.__roisKeyToRoi = {}
-
- def add(self, item):
- assert isinstance(item, ROIStatsItemHelper)
- if item.id_key() in self._items:
- _logger.warning(item.id_key(), 'is already present')
- return None
- self._items[item.id_key()] = item
- self._addItem(item)
- return item
-
- def _addItem(self, item):
- """
- Add a _RoiStatsItemWidget item to the table.
-
- :param item:
- :return: True if successfully added.
- """
- if not isinstance(item, ROIStatsItemHelper):
- # skipped because also receive all new plot item (Marker...) that
- # we don't want to manage in this case.
- return
- # plotItem = item.getItem()
- # roi = item.getROI()
- kind = item.item_kind()
- if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
- _logger.info("Item has not a supported type: %s", item)
- return False
-
- # register the roi and the kind
- self._registerPlotItem(item)
- self._registerROI(item)
-
- # Prepare table items
- tableItems = [
- qt.QTableWidgetItem(), # Legend
- qt.QTableWidgetItem(), # Kind
- qt.QTableWidgetItem()] # roi
-
- for column in range(3, self.columnCount()):
- header = self.horizontalHeaderItem(column)
- name = header.data(qt.Qt.UserRole)
-
- formatter = self._statsHandler.formatters[name]
- if formatter:
- tableItem = formatter.tabWidgetItemClass()
- else:
- tableItem = qt.QTableWidgetItem()
-
- tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
- if tooltip is not None:
- tableItem.setToolTip(tooltip)
-
- tableItems.append(tableItem)
-
- # Disable sorting while adding table items
- with self._disableSorting():
- # Add a row to the table
- self.setRowCount(self.rowCount() + 1)
-
- # Add table items to the last row
- row = self.rowCount() - 1
- for column, tableItem in enumerate(tableItems):
- tableItem.setData(qt.Qt.UserRole, _Container(item))
- tableItem.setFlags(
- qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
- self.setItem(row, column, tableItem)
-
- # Update table items content
- self._updateStats(item, data_changed=True)
-
- # Listen for item changes
- # Using queued connection to avoid issue with sender
- # being that of the signal calling the signal
- item._plot_item.sigItemChanged.connect(self._plotItemChanged,
- qt.Qt.QueuedConnection)
- return True
-
- def _removeAllItems(self):
- for row in range(self.rowCount()):
- tableItem = self.item(row, 0)
- # item = self._tableItemToItem(tableItem)
- # item.sigItemChanged.disconnect(self._plotItemChanged)
- self.clearContents()
- self.setRowCount(0)
-
- def clear(self):
- self._removeAllItems()
-
- def setStats(self, statsHandler):
- """Set which stats to display and the associated formatting.
-
- :param StatsHandler statsHandler:
- Set the statistics to be displayed and how to format them using
- """
- self._removeAllItems()
- _StatsWidgetBase.setStats(self, statsHandler)
-
- self.setRowCount(0)
- self.setColumnCount(len(self._statsHandler.stats) + 3) # + legend, kind and roi # noqa
-
- for index, stat in enumerate(self._statsHandler.stats.values()):
- headerItem = qt.QTableWidgetItem(stat.name.capitalize())
- headerItem.setData(qt.Qt.UserRole, stat.name)
- if stat.description is not None:
- headerItem.setToolTip(stat.description)
- self.setHorizontalHeaderItem(3 + index, headerItem)
-
- horizontalHeader = self.horizontalHeader()
- if hasattr(horizontalHeader, 'setSectionResizeMode'): # Qt5
- horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
- else: # Qt4
- horizontalHeader.setResizeMode(qt.QHeaderView.ResizeToContents)
-
- self._updateItemObserve()
-
- def _updateItemObserve(self, *args):
- pass
-
- def _dataChanged(self, item):
- pass
-
- def _updateStats(self, item, data_changed=False, roi_changed=False):
- assert isinstance(item, ROIStatsItemHelper)
- plotItem = item._plot_item
- roi = item._roi
- if item is None:
- return
- plot = self.getPlot()
- if plot is None:
- _logger.info("Plot not available")
- return
-
- row = self._itemToRow(item)
- if row is None:
- _logger.error("This item is not in the table: %s", str(item))
- return
-
- statsHandler = self.getStatsHandler()
- if statsHandler is not None:
- stats = statsHandler.calculate(plotItem, plot,
- onlimits=self._statsOnVisibleData,
- roi=roi, data_changed=data_changed,
- roi_changed=roi_changed)
- else:
- stats = {}
-
- with self._disableSorting():
- for name, tableItem in self._itemToTableItems(item).items():
- if name == self._LEGEND_HEADER_DATA:
- text = self._plotWrapper.getLabel(plotItem)
- tableItem.setText(text)
- elif name == self._KIND_HEADER_DATA:
- tableItem.setText(self._plotWrapper.getKind(plotItem))
- elif name == self._ROI_HEADER_DATA:
- name = roi.getName()
- tableItem.setText(name)
- else:
- value = stats.get(name)
- if value is None:
- _logger.error("Value not found for: %s", name)
- tableItem.setText('-')
- else:
- tableItem.setText(str(value))
-
- @contextmanager
- def _disableSorting(self):
- """Context manager that disables table sorting
-
- Previous state is restored when leaving
- """
- sorting = self.isSortingEnabled()
- if sorting:
- self.setSortingEnabled(False)
- yield
- if sorting:
- self.setSortingEnabled(sorting)
-
- def _itemToRow(self, item):
- """Find the row corresponding to a plot item
-
- :param item: The plot item
- :return: The corresponding row index
- :rtype: Union[int,None]
- """
- for row in range(self.rowCount()):
- tableItem = self.item(row, 0)
- if self._tableItemToItem(tableItem) == item:
- return row
- return None
-
- def _tableItemToItem(self, tableItem):
- """Find the plot item corresponding to a table item
-
- :param QTableWidgetItem tableItem:
- :rtype: QObject
- """
- container = tableItem.data(qt.Qt.UserRole)
- return container()
-
- def _itemToTableItems(self, item):
- """Find all table items corresponding to a plot item
-
- :param item: The plot item
- :return: An ordered dict of column name to QTableWidgetItem mapping
- for the given plot item.
- :rtype: OrderedDict
- """
- result = OrderedDict()
- row = self._itemToRow(item)
- if row is not None:
- for column in range(self.columnCount()):
- tableItem = self.item(row, column)
- if self._tableItemToItem(tableItem) != item:
- _logger.error("Table item/plot item mismatch")
- else:
- header = self.horizontalHeaderItem(column)
- name = header.data(qt.Qt.UserRole)
- result[name] = tableItem
- return result
-
- def _plotItemToItems(self, plotItem):
- """Return all _RoiStatsItemWidget associated to the plotItem
- Needed for updating on itemChanged signal
- """
- if plotItem in self.__plotItemToItems:
- return []
- else:
- return self.__plotItemToItems[plotItem]
-
- def _registerPlotItem(self, item):
- if item._plot_item not in self.__plotItemToItems:
- self.__plotItemToItems[item._plot_item] = set()
- self.__plotItemToItems[item._plot_item].add(item)
-
- def _roiToItems(self, roi):
- """Return all _RoiStatsItemWidget associated to the roi
- Needed for updating on roiChanged signal
- """
- if roi in self.__roiToItems:
- return []
- else:
- return self.__roiToItems[roi]
-
- def _registerROI(self, item):
- if item._roi not in self.__roiToItems:
- self.__roiToItems[item._roi] = set()
- # TODO: normalize also sig name
- if isinstance(item._roi, RegionOfInterest):
- # item connection within sigRegionChanged should only be
- # stopped during the region edition
- self.__region_edition_callback[item._roi] = functools.partial(
- self._updateAllStats, False, True)
- item._roi.sigRegionChanged.connect(self.__region_edition_callback[item._roi])
- item._roi.sigEditingStarted.connect(functools.partial(
- self._startFiltering, item._roi))
- item._roi.sigEditingFinished.connect(functools.partial(
- self._endFiltering, item._roi))
- else:
- item._roi.sigChanged.connect(functools.partial(
- self._updateAllStats, False, True))
- self.__roiToItems[item._roi].add(item)
-
- def _startFiltering(self, roi):
- roi.sigRegionChanged.disconnect(self.__region_edition_callback[roi])
-
- def _endFiltering(self, roi):
- roi.sigRegionChanged.connect(self.__region_edition_callback[roi])
- self._updateAllStats(roi_changed=True)
-
- def unregisterROI(self, roi):
- if roi in self.__roiToItems:
- del self.__roiToItems[roi]
- if isinstance(roi, RegionOfInterest):
- roi.sigRegionEditionStarted.disconnect(functools.partial(
- self._startFiltering, roi))
- roi.sigRegionEditionFinished.disconnect(functools.partial(
- self._startFiltering, roi))
- try:
- roi.sigRegionChanged.disconnect(self._updateAllStats)
- except:
- pass
- else:
- roi.sigChanged.disconnect(self._updateAllStats)
-
- def _plotItemChanged(self, event):
- """Handle modifications of the items.
-
- :param event:
- """
- if event is ItemChangedType.DATA:
- if self.getUpdateMode() is UpdateMode.MANUAL:
- return
- if self._skipPlotItemChangedEvent(event) is True:
- return
- else:
- sender = self.sender()
- for item in self.__plotItemToItems[sender]:
- # TODO: get all concerned items
- self._updateStats(item, data_changed=True)
- # deal with stat items visibility
- if event is ItemChangedType.VISIBLE:
- if len(self._itemToTableItems(item).items()) > 0:
- item_0 = list(self._itemToTableItems(item).values())[0]
- row_index = item_0.row()
- self.setRowHidden(row_index, not item.isVisible())
-
- def _removeItem(self, itemKey):
- if isinstance(itemKey, (silx.gui.plot.items.marker.Marker,
- silx.gui.plot.items.shape.Shape)):
- return
- if itemKey not in self._items:
- _logger.warning('key not recognized. Won\'t remove any item')
- return
- item = self._items[itemKey]
- row = self._itemToRow(item)
- if row is None:
- kind = self._plotWrapper.getKind(item)
- if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
- _logger.error("Removing item that is not in table: %s", str(item))
- return
- item._plot_item.sigItemChanged.disconnect(self._plotItemChanged)
- self.removeRow(row)
- del self._items[itemKey]
-
- def _updateAllStats(self, is_request=False, roi_changed=False):
- """Update stats for all rows in the table
-
- :param bool is_request: True if come from a manual request
- """
- if (self.getUpdateMode() is UpdateMode.MANUAL and
- not is_request and not roi_changed):
- return
-
- with self._disableSorting():
- for row in range(self.rowCount()):
- tableItem = self.item(row, 0)
- item = self._tableItemToItem(tableItem)
- self._updateStats(item, roi_changed=roi_changed,
- data_changed=is_request)
-
- def _plotCurrentChanged(self, *args):
- pass
-
- def _getRoi(self, kind, name):
- """return the roi fitting the requirement kind, name. This information
- is enough to be sure it is unique (in the widget)"""
- for roi in self.__roiToItems:
- roiName = roi.getName()
- if isinstance(roi, kind) and name == roiName:
- return roi
- return None
-
- def _getPlotItem(self, kind, legend):
- """return the plotItem fitting the requirement kind, legend.
- This information is enough to be sure it is unique (in the widget)"""
- for plotItem in self.__plotItemToItems:
- if legend == plotItem.getLegend() and self._plotWrapper.getKind(plotItem) == kind:
- return plotItem
- return None
-
-
-class ROIStatsWidget(qt.QMainWindow):
- """
- Widget used to define stats item for a couple(roi, plotItem).
- Stats will be computing on a given item (curve, image...) in the given
- region of interest.
-
- It also provide an interface for adding and removing items.
-
- .. snapshotqt:: img/ROIStatsWidget.png
- :width: 300px
- :align: center
-
- from silx.gui import qt
- from silx.gui.plot import Plot2D
- from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
- from silx.gui.plot.items.roi import RectangleROI
- import numpy
- plot = Plot2D()
- plot.addImage(numpy.arange(10000).reshape(100, 100), legend='img')
- plot.show()
- rectangleROI = RectangleROI()
- rectangleROI.setGeometry(origin=(0, 100), size=(20, 20))
- rectangleROI.setName('Initial ROI')
- widget = ROIStatsWidget(plot=plot)
- widget.setStats([('sum', numpy.sum), ('mean', numpy.mean)])
- widget.registerROI(rectangleROI)
- widget.addItem(roi=rectangleROI, plotItem=plot.getImage('img'))
- widget.show()
-
- :param Union[qt.QWidget,None] parent: parent qWidget
- :param PlotWindow plot: plot widget containing the items
- :param stats: stats to display
- :param tuple rois: tuple of rois to manage
- """
-
- def __init__(self, parent=None, plot=None, stats=None, rois=None):
- qt.QMainWindow.__init__(self, parent)
-
- toolbar = qt.QToolBar(self)
- icon = icons.getQIcon('add')
- self._rois = list(rois) if rois is not None else []
- self._addAction = qt.QAction(icon, 'add item/roi', toolbar)
- self._addAction.triggered.connect(self._addRoiStatsItem)
- icon = icons.getQIcon('rm')
- self._removeAction = qt.QAction(icon, 'remove item/roi', toolbar)
- self._removeAction.triggered.connect(self._removeCurrentRow)
-
- toolbar.addAction(self._addAction)
- toolbar.addAction(self._removeAction)
- self.addToolBar(toolbar)
-
- self._plot = plot
- self._statsROITable = _StatsROITable(parent=self, plot=self._plot)
- self.setStats(stats=stats)
- self.setCentralWidget(self._statsROITable)
- self.setWindowFlags(qt.Qt.Widget)
-
- # expose API
- self._setUpdateMode = self._statsROITable.setUpdateMode
- self._updateAllStats = self._statsROITable._updateAllStats
-
- # setup
- self._statsROITable.setSelectionBehavior(qt.QTableWidget.SelectRows)
-
- def registerROI(self, roi):
- """For now there is no direct link between roi and plot. That is why
- we need to add/register them to be able to associate them"""
- self._rois.append(roi)
-
- def setPlot(self, plot):
- """Define the plot to interact with
-
- :param Union[PlotWidget,SceneWidget,None] plot:
- The plot containing the items on which statistics are applied
- """
- self._plot = plot
-
- def getPlot(self):
- return self._plot
-
- @docstring(_StatsROITable)
- def setStats(self, stats):
- if stats is not None:
- self._statsROITable.setStats(statsHandler=stats)
-
- @docstring(_StatsROITable)
- def getStatsHandler(self):
- """
-
- :return:
- """
- return self._statsROITable.getStatsHandler()
-
- def _addRoiStatsItem(self):
- """Ask the user what couple ROI / item he want to display"""
- dialog = _GetROIItemCoupleDialog(parent=self, plot=self._plot,
- rois=self._rois)
- if dialog.exec_():
- self.addItem(roi=dialog.getROI(), plotItem=dialog.getItem())
-
- def addItem(self, plotItem, roi):
- """
- Add a row of statitstic regarding the couple (plotItem, roi)
-
- :param Item plotItem: item to use for statistics
- :param roi: region of interest to limit the statistic.
- :type: Union[ROI, RegionOfInterest]
- :return: None of failed to add the item
- :rtype: Union[None,ROIStatsItemHelper]
- """
- statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
- return self._statsROITable.add(item=statsItem)
-
- def removeItem(self, plotItem, roi):
- """
- Remove the row associated to the couple (plotItem, roi)
-
- :param Item plotItem: item to use for statistics
- :param roi: region of interest to limit the statistic.
- :type: Union[ROI,RegionOfInterest]
- """
- statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
- self._statsROITable._removeItem(itemKey=statsItem.id_key())
-
- def _removeCurrentRow(self):
- def is1DKind(kind):
- if kind in ('curve', 'histogram', 'scatter'):
- return True
- else:
- return False
-
- currentRow = self._statsROITable.currentRow()
- item_kind = self._statsROITable.item(currentRow, 1).text()
- item_legend = self._statsROITable.item(currentRow, 0).text()
-
- roi_name = self._statsROITable.item(currentRow, 2).text()
- roi_kind = ROI if is1DKind(item_kind) else RegionOfInterest
- roi = self._statsROITable._getRoi(kind=roi_kind, name=roi_name)
- if roi is None:
- _logger.warning('failed to retrieve the roi you want to remove')
- return False
- plot_item = self._statsROITable._getPlotItem(kind=item_kind,
- legend=item_legend)
- if plot_item is None:
- _logger.warning('failed to retrieve the plot item you want to'
- 'remove')
- return False
- return self.removeItem(plotItem=plot_item, roi=roi)
diff --git a/silx/gui/plot/ScatterMaskToolsWidget.py b/silx/gui/plot/ScatterMaskToolsWidget.py
deleted file mode 100644
index 5ae8653..0000000
--- a/silx/gui/plot/ScatterMaskToolsWidget.py
+++ /dev/null
@@ -1,621 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Widget providing a set of tools to draw masks on a PlotWidget.
-
-This widget is meant to work with a modified :class:`silx.gui.plot.PlotWidget`
-
-- :class:`ScatterMask`: Handle scatter mask update and history
-- :class:`ScatterMaskToolsWidget`: GUI for :class:`ScatterMask`
-- :class:`ScatterMaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
-"""
-
-from __future__ import division
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "15/02/2019"
-
-
-import math
-import logging
-import os
-import numpy
-import sys
-
-from .. import qt
-from ...math.combo import min_max
-from ...image import shapes
-
-from .items import ItemChangedType, Scatter
-from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
-from ..colors import cursorColorForColormap, rgba
-
-
-_logger = logging.getLogger(__name__)
-
-
-class ScatterMask(BaseMask):
- """A 1D mask for scatter data.
- """
- def __init__(self, scatter=None):
- """
-
- :param scatter: :class:`silx.gui.plot.items.Scatter` instance
- """
- BaseMask.__init__(self, scatter)
-
- def _getXY(self):
- x = self._dataItem.getXData(copy=False)
- y = self._dataItem.getYData(copy=False)
- return x, y
-
- def getDataValues(self):
- """Return scatter data values as a 1D array.
-
- :rtype: 1D numpy.ndarray
- """
- return self._dataItem.getValueData(copy=False)
-
- def save(self, filename, kind):
- if kind == 'npy':
- try:
- numpy.save(filename, self.getMask(copy=False))
- except IOError:
- raise RuntimeError("Mask file can't be written")
- elif kind in ["csv", "txt"]:
- try:
- numpy.savetxt(filename, self.getMask(copy=False))
- except IOError:
- raise RuntimeError("Mask file can't be written")
-
- def updatePoints(self, level, indices, mask=True):
- """Mask/Unmask points with given indices.
-
- :param int level: Mask level to update.
- :param indices: Sequence or 1D array of indices of points to be
- updated
- :param bool mask: True to mask (default), False to unmask.
- """
- if mask:
- self._mask[indices] = level
- else:
- # unmask only where mask level is the specified value
- indices_stencil = numpy.zeros_like(self._mask, dtype=bool)
- indices_stencil[indices] = True
- self._mask[numpy.logical_and(self._mask == level, indices_stencil)] = 0
- self._notify()
-
- # update shapes
- def updatePolygon(self, level, vertices, mask=True):
- """Mask/Unmask a polygon of the given mask level.
-
- :param int level: Mask level to update.
- :param vertices: Nx2 array of polygon corners as (y, x) or (row, col)
- :param bool mask: True to mask (default), False to unmask.
- """
- polygon = shapes.Polygon(vertices)
- x, y = self._getXY()
-
- # TODO: this could be optimized if necessary
- indices_in_polygon = [idx for idx in range(len(x)) if
- polygon.is_inside(y[idx], x[idx])]
-
- self.updatePoints(level, indices_in_polygon, mask)
-
- def updateRectangle(self, level, y, x, height, width, mask=True):
- """Mask/Unmask data inside a rectangle
-
- :param int level: Mask level to update.
- :param float y: Y coordinate of bottom left corner of the rectangle
- :param float x: X coordinate of bottom left corner of the rectangle
- :param float height:
- :param float width:
- :param bool mask: True to mask (default), False to unmask.
- """
- vertices = [(y, x),
- (y + height, x),
- (y + height, x + width),
- (y, x + width)]
- self.updatePolygon(level, vertices, mask)
-
- def updateDisk(self, level, cy, cx, radius, mask=True):
- """Mask/Unmask a disk of the given mask level.
-
- :param int level: Mask level to update.
- :param float cy: Disk center (y).
- :param float cx: Disk center (x).
- :param float radius: Radius of the disk in mask array unit
- :param bool mask: True to mask (default), False to unmask.
- """
- x, y = self._getXY()
- stencil = (y - cy)**2 + (x - cx)**2 < radius**2
- self.updateStencil(level, stencil, mask)
-
- def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
- """Mask/Unmask an ellipse of the given mask level.
-
- :param int level: Mask level to update.
- :param int crow: Row of the center of the ellipse
- :param int ccol: Column of the center of the ellipse
- :param float radius_r: Radius of the ellipse in the row
- :param float radius_c: Radius of the ellipse in the column
- :param bool mask: True to mask (default), False to unmask.
- """
- def is_inside(px, py):
- return (px - ccol)**2 / radius_c**2 + (py - crow)**2 / radius_r**2 <= 1.0
- x, y = self._getXY()
- indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])]
- self.updatePoints(level, indices_inside, mask)
-
- def updateLine(self, level, y0, x0, y1, x1, width, mask=True):
- """Mask/Unmask points inside a rectangle defined by a line (two
- end points) and a width.
-
- :param int level: Mask level to update.
- :param float y0: Row of the starting point.
- :param float x0: Column of the starting point.
- :param float row1: Row of the end point.
- :param float col1: Column of the end point.
- :param float width: Width of the line.
- :param bool mask: True to mask (default), False to unmask.
- """
- # theta is the angle between the horizontal and the line
- theta = math.atan((y1 - y0) / (x1 - x0)) if x1 - x0 else 0
- w_over_2_sin_theta = width / 2. * math.sin(theta)
- w_over_2_cos_theta = width / 2. * math.cos(theta)
-
- vertices = [(y0 - w_over_2_cos_theta, x0 + w_over_2_sin_theta),
- (y0 + w_over_2_cos_theta, x0 - w_over_2_sin_theta),
- (y1 + w_over_2_cos_theta, x1 - w_over_2_sin_theta),
- (y1 - w_over_2_cos_theta, x1 + w_over_2_sin_theta)]
-
- self.updatePolygon(level, vertices, mask)
-
-
-class ScatterMaskToolsWidget(BaseMaskToolsWidget):
- """Widget with tools for masking data points on a scatter in a
- :class:`PlotWidget`."""
-
- def __init__(self, parent=None, plot=None):
- super(ScatterMaskToolsWidget, self).__init__(parent, plot,
- mask=ScatterMask())
- self._z = 2 # Mask layer in plot
- self._data_scatter = None
- """plot Scatter item for data"""
-
- self._data_extent = None
- """Maximum extent of the data i.e., max(xMax-xMin, yMax-yMin)"""
-
- self._mask_scatter = None
- """plot Scatter item for representing the mask"""
-
- def setSelectionMask(self, mask, copy=True):
- """Set the mask to a new array.
-
- :param numpy.ndarray mask:
- The array to use for the mask or None to reset the mask.
- :type mask: numpy.ndarray of uint8, C-contiguous.
- Array of other types are converted.
- :param bool copy: True (the default) to copy the array,
- False to use it as is if possible.
- :return: None if failed, shape of mask as 1-tuple if successful.
- The mask can be cropped or padded to fit active scatter,
- the returned shape is that of the scatter data.
- """
- if self._data_scatter is None:
- # this can happen if the mask tools widget has never been shown
- self._data_scatter = self.plot._getActiveItem(kind="scatter")
- if self._data_scatter is None:
- return None
- self._adjustColorAndBrushSize(self._data_scatter)
-
- if mask is None:
- self.resetSelectionMask()
- return self._data_scatter.getXData(copy=False).shape
-
- mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
-
- if self._data_scatter.getXData(copy=False).shape == (0,) \
- or mask.shape == self._data_scatter.getXData(copy=False).shape:
- self._mask.setMask(mask, copy=copy)
- self._mask.commit()
- return mask.shape
- else:
- raise ValueError("Mask does not have the same shape as the data")
-
- # Handle mask refresh on the plot
-
- def _updatePlotMask(self):
- """Update mask image in plot"""
- mask = self.getSelectionMask(copy=False)
- if mask is not None:
- self.plot.addScatter(self._data_scatter.getXData(),
- self._data_scatter.getYData(),
- mask,
- legend=self._maskName,
- colormap=self._colormap,
- z=self._z)
- self._mask_scatter = self.plot._getItem(kind="scatter",
- legend=self._maskName)
- self._mask_scatter.setSymbolSize(
- self._data_scatter.getSymbolSize() + 2.0)
- self._mask_scatter.sigItemChanged.connect(self.__maskScatterChanged)
- elif self.plot._getItem(kind="scatter",
- legend=self._maskName) is not None:
- self.plot.remove(self._maskName, kind='scatter')
-
- def __maskScatterChanged(self, event):
- """Handles update of mask scatter"""
- if (event is ItemChangedType.VISUALIZATION_MODE and
- self._mask_scatter is not None):
- self._mask_scatter.setVisualization(Scatter.Visualization.POINTS)
-
- # track widget visibility and plot active image changes
-
- def showEvent(self, event):
- try:
- self.plot.sigActiveScatterChanged.disconnect(
- self._activeScatterChangedAfterCare)
- except (RuntimeError, TypeError):
- pass
- self._activeScatterChanged(None, None) # Init mask + enable/disable widget
- self.plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
-
- def hideEvent(self, event):
- try:
- # if the method is not connected this raises a TypeError and there is no way
- # to know the connected slots
- self.plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
- except (RuntimeError, TypeError):
- _logger.info(sys.exc_info()[1])
- if not self.browseAction.isChecked():
- self.browseAction.trigger() # Disable drawing tool
-
- if self.getSelectionMask(copy=False) is not None:
- self.plot.sigActiveScatterChanged.connect(
- self._activeScatterChangedAfterCare)
-
- def _adjustColorAndBrushSize(self, activeScatter):
- colormap = activeScatter.getColormap()
- self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name']))
- self._setMaskColors(self.levelSpinBox.value(),
- self.transparencySlider.value() /
- self.transparencySlider.maximum())
- self._z = activeScatter.getZValue() + 1
- self._data_scatter = activeScatter
-
- # Adjust brush size to data range
- xData = self._data_scatter.getXData(copy=False)
- yData = self._data_scatter.getYData(copy=False)
- # Adjust brush size to data range
- if xData.size > 0 and yData.size > 0:
- xMin, xMax = min_max(xData)
- yMin, yMax = min_max(yData)
- self._data_extent = max(xMax - xMin, yMax - yMin)
- else:
- self._data_extent = None
-
- def _activeScatterChangedAfterCare(self, previous, next):
- """Check synchro of active scatter and mask when mask widget is hidden.
-
- If active image has no more the same size as the mask, the mask is
- removed, otherwise it is adjusted to z.
- """
- # check that content changed was the active scatter
- activeScatter = self.plot._getActiveItem(kind="scatter")
-
- if activeScatter is None or activeScatter.getName() == self._maskName:
- # No active scatter or active scatter is the mask...
- self.plot.sigActiveScatterChanged.disconnect(
- self._activeScatterChangedAfterCare)
- self._data_extent = None
- self._data_scatter = None
-
- else:
- self._adjustColorAndBrushSize(activeScatter)
-
- if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape:
- # scatter has not the same size, remove mask and stop listening
- if self.plot._getItem(kind="scatter", legend=self._maskName):
- self.plot.remove(self._maskName, kind='scatter')
-
- self.plot.sigActiveScatterChanged.disconnect(
- self._activeScatterChangedAfterCare)
- self._data_extent = None
- self._data_scatter = None
-
- else:
- # Refresh in case z changed
- self._mask.setDataItem(self._data_scatter)
- self._updatePlotMask()
-
- def _activeScatterChanged(self, previous, next):
- """Update widget and mask according to active scatter changes"""
- activeScatter = self.plot._getActiveItem(kind="scatter")
-
- if activeScatter is None or activeScatter.getName() == self._maskName:
- # No active scatter or active scatter is the mask...
- self.setEnabled(False)
-
- self._data_scatter = None
- self._data_extent = None
- self._mask.reset()
- self._mask.commit()
-
- else: # There is an active scatter
- self.setEnabled(True)
- self._adjustColorAndBrushSize(activeScatter)
-
- self._mask.setDataItem(self._data_scatter)
- if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape:
- self._mask.reset(self._data_scatter.getXData(copy=False).shape)
- self._mask.commit()
- else:
- # Refresh in case z changed
- self._updatePlotMask()
-
- self._updateInteractiveMode()
-
- # Handle whole mask operations
-
- def load(self, filename):
- """Load a mask from an image file.
-
- :param str filename: File name from which to load the mask
- :raise Exception: An exception in case of failure
- :raise RuntimeWarning: In case the mask was applied but with some
- import changes to notice
- """
- _, extension = os.path.splitext(filename)
- extension = extension.lower()[1:]
- if extension == "npy":
- try:
- mask = numpy.load(filename)
- except IOError:
- _logger.error("Can't load filename '%s'", filename)
- _logger.debug("Backtrace", exc_info=True)
- raise RuntimeError('File "%s" is not a numpy file.',
- filename)
- elif extension in ["txt", "csv"]:
- try:
- mask = numpy.loadtxt(filename)
- except IOError:
- _logger.error("Can't load filename '%s'", filename)
- _logger.debug("Backtrace", exc_info=True)
- raise RuntimeError('File "%s" is not a numpy txt file.',
- filename)
- else:
- msg = "Extension '%s' is not supported."
- raise RuntimeError(msg % extension)
-
- self.setSelectionMask(mask, copy=False)
-
- def _loadMask(self):
- """Open load mask dialog"""
- dialog = qt.QFileDialog(self)
- dialog.setWindowTitle("Load Mask")
- dialog.setModal(1)
- filters = [
- 'NumPy binary file (*.npy)',
- 'CSV text file (*.csv)',
- ]
- dialog.setNameFilters(filters)
- dialog.setFileMode(qt.QFileDialog.ExistingFile)
- dialog.setDirectory(self.maskFileDir)
- if not dialog.exec_():
- dialog.close()
- return
-
- filename = dialog.selectedFiles()[0]
- dialog.close()
-
- # Update the directory according to the user selection
- self.maskFileDir = os.path.dirname(filename)
-
- try:
- self.load(filename)
- # except RuntimeWarning as e:
- # message = e.args[0]
- # msg = qt.QMessageBox(self)
- # msg.setIcon(qt.QMessageBox.Warning)
- # msg.setText("Mask loaded but an operation was applied.\n" + message)
- # msg.exec_()
- except Exception as e:
- message = e.args[0]
- msg = qt.QMessageBox(self)
- msg.setIcon(qt.QMessageBox.Critical)
- msg.setText("Cannot load mask from file. " + message)
- msg.exec_()
-
- def _saveMask(self):
- """Open Save mask dialog"""
- dialog = qt.QFileDialog(self)
- dialog.setWindowTitle("Save Mask")
- dialog.setModal(1)
- filters = [
- 'NumPy binary file (*.npy)',
- 'CSV text file (*.csv)',
- ]
- dialog.setNameFilters(filters)
- dialog.setFileMode(qt.QFileDialog.AnyFile)
- dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
- dialog.setDirectory(self.maskFileDir)
- if not dialog.exec_():
- dialog.close()
- return
-
- # convert filter name to extension name with the .
- extension = dialog.selectedNameFilter().split()[-1][2:-1]
- filename = dialog.selectedFiles()[0]
- dialog.close()
-
- if not filename.lower().endswith(extension):
- filename += extension
-
- if os.path.exists(filename):
- try:
- os.remove(filename)
- except IOError as e:
- msg = qt.QMessageBox(self)
- msg.setWindowTitle("Removing existing file")
- msg.setIcon(qt.QMessageBox.Critical)
-
- if hasattr(e, "strerror"):
- strerror = e.strerror
- else:
- strerror = sys.exc_info()[1]
- msg.setText("Cannot save.\n"
- "Input Output Error: %s" % strerror)
- msg.exec_()
- return
-
- # Update the directory according to the user selection
- self.maskFileDir = os.path.dirname(filename)
-
- try:
- self.save(filename, extension[1:])
- except Exception as e:
- msg = qt.QMessageBox(self)
- msg.setWindowTitle("Saving mask file")
- msg.setIcon(qt.QMessageBox.Critical)
-
- if hasattr(e, "strerror"):
- strerror = e.strerror
- else:
- strerror = sys.exc_info()[1]
- msg.setText("Cannot save file %s\n%s" % (filename, strerror))
- msg.exec_()
-
- def resetSelectionMask(self):
- """Reset the mask"""
- self._mask.reset(
- shape=self._data_scatter.getXData(copy=False).shape)
- self._mask.commit()
-
- def _getPencilWidth(self):
- """Returns the width of the pencil to use in data coordinates`
-
- :rtype: float
- """
- width = super(ScatterMaskToolsWidget, self)._getPencilWidth()
- if self._data_extent is not None:
- width *= 0.01 * self._data_extent
- return width
-
- def _plotDrawEvent(self, event):
- """Handle draw events from the plot"""
- if (self._drawingMode is None or
- event['event'] not in ('drawingProgress', 'drawingFinished')):
- return
-
- if not len(self._data_scatter.getXData(copy=False)):
- return
-
- level = self.levelSpinBox.value()
-
- if self._drawingMode == 'rectangle':
- if event['event'] == 'drawingFinished':
- doMask = self._isMasking()
-
- self._mask.updateRectangle(
- level,
- y=event['y'],
- x=event['x'],
- height=abs(event['height']),
- width=abs(event['width']),
- mask=doMask)
- self._mask.commit()
-
- elif self._drawingMode == 'ellipse':
- if event['event'] == 'drawingFinished':
- doMask = self._isMasking()
- center = event['points'][0]
- size = event['points'][1]
- self._mask.updateEllipse(level, center[1], center[0],
- size[1], size[0], doMask)
- self._mask.commit()
-
- elif self._drawingMode == 'polygon':
- if event['event'] == 'drawingFinished':
- doMask = self._isMasking()
- vertices = event['points']
- vertices = vertices[:, (1, 0)] # (y, x)
- self._mask.updatePolygon(level, vertices, doMask)
- self._mask.commit()
-
- elif self._drawingMode == 'pencil':
- doMask = self._isMasking()
- # convert from plot to array coords
- x, y = event['points'][-1]
-
- brushSize = self._getPencilWidth()
-
- if self._lastPencilPos != (y, x):
- if self._lastPencilPos is not None:
- # Draw the line
- self._mask.updateLine(
- level,
- self._lastPencilPos[0], self._lastPencilPos[1],
- y, x,
- brushSize,
- doMask)
-
- # Draw the very first, or last point
- self._mask.updateDisk(level, y, x, brushSize / 2., doMask)
-
- if event['event'] == 'drawingFinished':
- self._mask.commit()
- self._lastPencilPos = None
- else:
- self._lastPencilPos = y, x
- else:
- _logger.error("Drawing mode %s unsupported", self._drawingMode)
-
- def _loadRangeFromColormapTriggered(self):
- """Set range from active scatter colormap range"""
- if self._data_scatter is not None:
- # Update thresholds according to colormap
- colormap = self._data_scatter.getColormap()
- if colormap['autoscale']:
- min_ = numpy.nanmin(self._data_scatter.getValueData(copy=False))
- max_ = numpy.nanmax(self._data_scatter.getValueData(copy=False))
- else:
- min_, max_ = colormap['vmin'], colormap['vmax']
- self.minLineEdit.setText(str(min_))
- self.maxLineEdit.setText(str(max_))
-
-
-class ScatterMaskToolsDockWidget(BaseMaskToolsDockWidget):
- """:class:`ScatterMaskToolsWidget` embedded in a QDockWidget.
-
- For integration in a :class:`PlotWindow`.
-
- :param parent: See :class:`QDockWidget`
- :param plot: The PlotWidget this widget is operating on
- :paran str name: The title of this widget
- """
- def __init__(self, parent=None, plot=None, name='Mask'):
- widget = ScatterMaskToolsWidget(plot=plot)
- super(ScatterMaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/silx/gui/plot/ScatterView.py b/silx/gui/plot/ScatterView.py
deleted file mode 100644
index 0423648..0000000
--- a/silx/gui/plot/ScatterView.py
+++ /dev/null
@@ -1,405 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""A widget dedicated to display scatter plots
-
-It is based on a :class:`~silx.gui.plot.PlotWidget` with additional tools
-for scatter plots.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "14/06/2018"
-
-
-import logging
-import weakref
-
-import numpy
-
-from . import items
-from . import PlotWidget
-from . import tools
-from .actions import histogram as actions_histogram
-from .tools.profile import ScatterProfileToolBar
-from .ColorBar import ColorBarWidget
-from .ScatterMaskToolsWidget import ScatterMaskToolsWidget
-
-from ..widgets.BoxLayoutDockWidget import BoxLayoutDockWidget
-from .. import qt, icons
-from ...utils.proxy import docstring
-from ...utils.weakref import WeakMethodProxy
-
-
-_logger = logging.getLogger(__name__)
-
-
-class ScatterView(qt.QMainWindow):
- """Main window with a PlotWidget and tools specific for scatter plots.
-
- :param parent: The parent of this widget
- :param backend: The backend to use for the plot (default: matplotlib).
- See :class:`~silx.gui.plot.PlotWidget` for the list of supported backend.
- :type backend: Union[str,~silx.gui.plot.backends.BackendBase.BackendBase]
- """
-
- _SCATTER_LEGEND = ' '
- """Legend used for the scatter item"""
-
- def __init__(self, parent=None, backend=None):
- super(ScatterView, self).__init__(parent=parent)
- if parent is not None:
- # behave as a widget
- self.setWindowFlags(qt.Qt.Widget)
- else:
- self.setWindowTitle('ScatterView')
-
- # Create plot widget
- plot = PlotWidget(parent=self, backend=backend)
- self._plot = weakref.ref(plot)
-
- # Add an empty scatter
- self.__createEmptyScatter()
-
- # Create colorbar widget with white background
- self._colorbar = ColorBarWidget(parent=self, plot=plot)
- self._colorbar.setAutoFillBackground(True)
- palette = self._colorbar.palette()
- palette.setColor(qt.QPalette.Background, qt.Qt.white)
- palette.setColor(qt.QPalette.Window, qt.Qt.white)
- self._colorbar.setPalette(palette)
-
- # Create PositionInfo widget
- self.__lastPickingPos = None
- self.__pickingCache = None
- self._positionInfo = tools.PositionInfo(
- plot=plot,
- converters=(('X', WeakMethodProxy(self._getPickedX)),
- ('Y', WeakMethodProxy(self._getPickedY)),
- ('Data', WeakMethodProxy(self._getPickedValue)),
- ('Index', WeakMethodProxy(self._getPickedIndex))))
-
- # Combine plot, position info and colorbar into central widget
- gridLayout = qt.QGridLayout()
- gridLayout.setSpacing(0)
- gridLayout.setContentsMargins(0, 0, 0, 0)
- gridLayout.addWidget(plot, 0, 0)
- gridLayout.addWidget(self._colorbar, 0, 1)
- gridLayout.addWidget(self._positionInfo, 1, 0, 1, -1)
- gridLayout.setRowStretch(0, 1)
- gridLayout.setColumnStretch(0, 1)
- centralWidget = qt.QWidget(self)
- centralWidget.setLayout(gridLayout)
- self.setCentralWidget(centralWidget)
-
- # Create mask tool dock widget
- self._maskToolsWidget = ScatterMaskToolsWidget(parent=self, plot=plot)
- self._maskDock = BoxLayoutDockWidget()
- self._maskDock.setWindowTitle('Scatter Mask')
- self._maskDock.setWidget(self._maskToolsWidget)
- self._maskDock.setVisible(False)
- self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._maskDock)
-
- self._maskAction = self._maskDock.toggleViewAction()
- self._maskAction.setIcon(icons.getQIcon('image-mask'))
- self._maskAction.setToolTip("Display/hide mask tools")
-
- self._intensityHistoAction = actions_histogram.PixelIntensitiesHistoAction(plot=plot, parent=self)
-
- # Create toolbars
- self._interactiveModeToolBar = tools.InteractiveModeToolBar(
- parent=self, plot=plot)
-
- self._scatterToolBar = tools.ScatterToolBar(
- parent=self, plot=plot)
- self._scatterToolBar.addAction(self._maskAction)
- self._scatterToolBar.addAction(self._intensityHistoAction)
-
- self._profileToolBar = ScatterProfileToolBar(parent=self, plot=plot)
-
- self._outputToolBar = tools.OutputToolBar(parent=self, plot=plot)
-
- # Activate shortcuts in PlotWindow widget:
- for toolbar in (self._interactiveModeToolBar,
- self._scatterToolBar,
- self._profileToolBar,
- self._outputToolBar):
- self.addToolBar(toolbar)
- for action in toolbar.actions():
- self.addAction(action)
-
-
- def __createEmptyScatter(self):
- """Create an empty scatter item that is used to display the data
-
- :rtype: ~silx.gui.plot.items.Scatter
- """
- plot = self.getPlotWidget()
- plot.addScatter(x=(), y=(), value=(), legend=self._SCATTER_LEGEND)
- scatter = plot._getItem(
- kind='scatter', legend=self._SCATTER_LEGEND)
- # Profile is not selectable,
- # so it does not interfere with profile interaction
- scatter._setSelectable(False)
- return scatter
-
- def _pickScatterData(self, x, y):
- """Get data and index and value of top most scatter plot at position (x, y)
-
- :param float x: X position in plot coordinates
- :param float y: Y position in plot coordinates
- :return: The data index and value at that point or None
- """
- pickingPos = x, y
- if self.__lastPickingPos != pickingPos:
- self.__pickingCache = None
- self.__lastPickingPos = pickingPos
-
- plot = self.getPlotWidget()
- if plot is not None:
- pixelPos = plot.dataToPixel(x, y)
- if pixelPos is not None:
- # Start from top-most item
- result = plot._pickTopMost(
- pixelPos[0], pixelPos[1],
- lambda item: isinstance(item, items.Scatter))
- if result is not None:
- item = result.getItem()
- if item.getVisualization() is items.Scatter.Visualization.BINNED_STATISTIC:
- # Get highest index of closest points
- selected = result.getIndices(copy=False)[::-1]
- dataIndex = selected[numpy.argmin(
- (item.getXData(copy=False)[selected] - x)**2 +
- (item.getYData(copy=False)[selected] - y)**2)]
- else:
- # Get last index
- # with matplotlib it should be the top-most point
- dataIndex = result.getIndices(copy=False)[-1]
- self.__pickingCache = (
- dataIndex,
- item.getXData(copy=False)[dataIndex],
- item.getYData(copy=False)[dataIndex],
- item.getValueData(copy=False)[dataIndex])
-
- return self.__pickingCache
-
- def _getPickedIndex(self, x, y):
- """Get data index of top most scatter plot at position (x, y)
-
- :param float x: X position in plot coordinates
- :param float y: Y position in plot coordinates
- :return: The data index at that point or '-'
- """
- picking = self._pickScatterData(x, y)
- return '-' if picking is None else picking[0]
-
- def _getPickedX(self, x, y):
- """Returns X position snapped to scatter plot when close enough
-
- :param float x:
- :param float y:
- :rtype: float
- """
- picking = self._pickScatterData(x, y)
- return x if picking is None else picking[1]
-
- def _getPickedY(self, x, y):
- """Returns Y position snapped to scatter plot when close enough
-
- :param float x:
- :param float y:
- :rtype: float
- """
- picking = self._pickScatterData(x, y)
- return y if picking is None else picking[2]
-
- def _getPickedValue(self, x, y):
- """Get data value of top most scatter plot at position (x, y)
-
- :param float x: X position in plot coordinates
- :param float y: Y position in plot coordinates
- :return: The data value at that point or '-'
- """
- picking = self._pickScatterData(x, y)
- return '-' if picking is None else picking[3]
-
- def _mouseInPlotArea(self, x, y):
- """Clip mouse coordinates to plot area coordinates
-
- :param float x: X position in pixels
- :param float y: Y position in pixels
- :return: (x, y) in data coordinates
- """
- plot = self.getPlotWidget()
- left, top, width, height = plot.getPlotBoundsInPixels()
- xPlot = numpy.clip(x, left, left + width - 1)
- yPlot = numpy.clip(y, top, top + height - 1)
- return xPlot, yPlot
-
- def getPlotWidget(self):
- """Returns the :class:`~silx.gui.plot.PlotWidget` this window is based on.
-
- :rtype: ~silx.gui.plot.PlotWidget
- """
- return self._plot()
-
- def getPositionInfoWidget(self):
- """Returns the widget display mouse coordinates information.
-
- :rtype: ~silx.gui.plot.tools.PositionInfo
- """
- return self._positionInfo
-
- def getMaskToolsWidget(self):
- """Returns the widget controlling mask drawing
-
- :rtype: ~silx.gui.plot.ScatterMaskToolsWidget
- """
- return self._maskToolsWidget
-
- def getInteractiveModeToolBar(self):
- """Returns QToolBar controlling interactive mode.
-
- :rtype: ~silx.gui.plot.tools.InteractiveModeToolBar
- """
- return self._interactiveModeToolBar
-
- def getScatterToolBar(self):
- """Returns QToolBar providing scatter plot tools.
-
- :rtype: ~silx.gui.plot.tools.ScatterToolBar
- """
- return self._scatterToolBar
-
- def getScatterProfileToolBar(self):
- """Returns QToolBar providing scatter profile tools.
-
- :rtype: ~silx.gui.plot.tools.profile.ScatterProfileToolBar
- """
- return self._profileToolBar
-
- def getOutputToolBar(self):
- """Returns QToolBar containing save, copy and print actions
-
- :rtype: ~silx.gui.plot.tools.OutputToolBar
- """
- return self._outputToolBar
-
- def setColormap(self, colormap=None):
- """Set the colormap for the displayed scatter and the
- default plot colormap.
-
- :param ~silx.gui.colors.Colormap colormap:
- The description of the colormap.
- """
- self.getScatterItem().setColormap(colormap)
- # Resilient to call to PlotWidget API (e.g., clear)
- self.getPlotWidget().setDefaultColormap(colormap)
-
- def getColormap(self):
- """Return the colormap object in use.
-
- :return: Colormap currently in use
- :rtype: ~silx.gui.colors.Colormap
- """
- return self.getScatterItem().getColormap()
-
- # Control displayed scatter plot
-
- def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
- """Set the data of the scatter plot.
-
- To reset the scatter plot, set x, y and value to None.
-
- :param Union[numpy.ndarray,None] x: X coordinates.
- :param Union[numpy.ndarray,None] y: Y coordinates.
- :param Union[numpy.ndarray,None] value:
- The data corresponding to the value of the data points.
- :param xerror: Values with the uncertainties on the x values.
- If it is an array, it can either be a 1D array of
- same length as the data or a 2D array with 2 rows
- of same length as the data: row 0 for positive errors,
- row 1 for negative errors.
- :type xerror: A float, or a numpy.ndarray of float32.
-
- :param yerror: Values with the uncertainties on the y values
- :type yerror: A float, or a numpy.ndarray of float32. See xerror.
- :param alpha: Values with the transparency (between 0 and 1)
- :type alpha: A float, or a numpy.ndarray of float32
- :param bool copy: True make a copy of the data (default),
- False to use provided arrays.
- """
- x = () if x is None else x
- y = () if y is None else y
- value = () if value is None else value
-
- self.getScatterItem().setData(
- x=x, y=y, value=value, xerror=xerror, yerror=yerror, alpha=alpha, copy=copy)
-
- @docstring(items.Scatter)
- def getData(self, *args, **kwargs):
- return self.getScatterItem().getData(*args, **kwargs)
-
- def getScatterItem(self):
- """Returns the plot item displaying the scatter data.
-
- This allows to set the style of the displayed scatter.
-
- :rtype: ~silx.gui.plot.items.Scatter
- """
- plot = self.getPlotWidget()
- scatter = plot._getItem(kind='scatter', legend=self._SCATTER_LEGEND)
- if scatter is None: # Resilient to call to PlotWidget API (e.g., clear)
- scatter = self.__createEmptyScatter()
- return scatter
-
- # Convenient proxies
-
- @docstring(PlotWidget)
- def getXAxis(self, *args, **kwargs):
- return self.getPlotWidget().getXAxis(*args, **kwargs)
-
- @docstring(PlotWidget)
- def getYAxis(self, *args, **kwargs):
- return self.getPlotWidget().getYAxis(*args, **kwargs)
-
- @docstring(PlotWidget)
- def setGraphTitle(self, *args, **kwargs):
- return self.getPlotWidget().setGraphTitle(*args, **kwargs)
-
- @docstring(PlotWidget)
- def getGraphTitle(self, *args, **kwargs):
- return self.getPlotWidget().getGraphTitle(*args, **kwargs)
-
- @docstring(PlotWidget)
- def resetZoom(self, *args, **kwargs):
- return self.getPlotWidget().resetZoom(*args, **kwargs)
-
- @docstring(ScatterMaskToolsWidget)
- def getSelectionMask(self, *args, **kwargs):
- return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs)
-
- @docstring(ScatterMaskToolsWidget)
- def setSelectionMask(self, *args, **kwargs):
- return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs)
diff --git a/silx/gui/plot/StackView.py b/silx/gui/plot/StackView.py
deleted file mode 100644
index 40e0661..0000000
--- a/silx/gui/plot/StackView.py
+++ /dev/null
@@ -1,1254 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""QWidget displaying a 3D volume as a stack of 2D images.
-
-The :class:`StackView` class implements this widget.
-
-Basic usage of :class:`StackView` is through the following methods:
-
-- :meth:`StackView.getColormap`, :meth:`StackView.setColormap` to update the
- default colormap to use and update the currently displayed image.
-- :meth:`StackView.setStack` to update the displayed image.
-
-The :class:`StackView` uses :class:`PlotWindow` and also
-exposes a subset of the :class:`silx.gui.plot.Plot` API for further control
-(plot title, axes labels, ...).
-
-The :class:`StackViewMainWindow` class implements a widget that adds a status
-bar displaying the 3D index and the value under the mouse cursor.
-
-Example::
-
- import numpy
- import sys
- from silx.gui import qt
- from silx.gui.plot.StackView import StackViewMainWindow
-
-
- app = qt.QApplication(sys.argv[1:])
-
- # synthetic data, stack of 100 images of size 200x300
- mystack = numpy.fromfunction(
- lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
- (100, 200, 300)
- )
-
-
- sv = StackViewMainWindow()
- sv.setColormap("jet", autoscale=True)
- sv.setStack(mystack)
- sv.setLabels(["1st dim (0-99)", "2nd dim (0-199)",
- "3rd dim (0-299)"])
- sv.show()
-
- app.exec_()
-
-"""
-
-__authors__ = ["P. Knobel", "H. Payno"]
-__license__ = "MIT"
-__date__ = "10/10/2018"
-
-import numpy
-import logging
-
-import silx
-from silx.gui import qt
-from .. import icons
-from . import items, PlotWindow, actions
-from .items.image import ImageStack
-from ..colors import Colormap
-from ..colors import cursorColorForColormap
-from .tools import LimitsToolBar
-from .Profile import Profile3DToolBar
-from ..widgets.FrameBrowser import HorizontalSliderWithBrowser
-
-from silx.gui.plot.actions import control as actions_control
-from silx.gui.plot.actions import io as silx_io
-from silx.io.nxdata import save_NXdata
-from silx.utils.array_like import DatasetView, ListOfImages
-from silx.math import calibration
-from silx.utils.deprecation import deprecated_warning
-from silx.utils.deprecation import deprecated
-
-import h5py
-from silx.io.utils import is_dataset
-
-_logger = logging.getLogger(__name__)
-
-
-class StackView(qt.QMainWindow):
- """Stack view widget, to display and browse through stack of
- images.
-
- The profile tool can be switched to "3D" mode, to compute the profile
- on each image of the stack (not only the active image currently displayed)
- and display the result as a slice.
-
- :param QWidget parent: the Qt parent, or None
- :param backend: The backend to use for the plot (default: matplotlib).
- See :class:`.PlotWidget` for the list of supported backend.
- :type backend: str or :class:`BackendBase.BackendBase`
- :param bool resetzoom: Toggle visibility of reset zoom action.
- :param bool autoScale: Toggle visibility of axes autoscale actions.
- :param bool logScale: Toggle visibility of axes log scale actions.
- :param bool grid: Toggle visibility of grid mode action.
- :param bool colormap: Toggle visibility of colormap action.
- :param bool aspectRatio: Toggle visibility of aspect ratio button.
- :param bool yInverted: Toggle visibility of Y axis direction button.
- :param bool copy: Toggle visibility of copy action.
- :param bool save: Toggle visibility of save action.
- :param bool print_: Toggle visibility of print action.
- :param bool control: True to display an Options button with a sub-menu
- to show legends, toggle crosshair and pan with arrows.
- (Default: False)
- :param position: True to display widget with (x, y) mouse position
- (Default: False).
- It also supports a list of (name, funct(x, y)->value)
- to customize the displayed values.
- See :class:`silx.gui.plot.PlotTools.PositionInfo`.
- :param bool mask: Toggle visibilty of mask action.
- """
- # Qt signals
- valueChanged = qt.Signal(object, object, object)
- """Signals that the data value under the cursor has changed.
-
- It provides: row, column, data value.
- """
-
- sigPlaneSelectionChanged = qt.Signal(int)
- """Signal emitted when there is a change is perspective/displayed axes.
-
- It provides the perspective as an integer, with the following meaning:
-
- - 0: axis Y is the 2nd dimension, axis X is the 3rd dimension
- - 1: axis Y is the 1st dimension, axis X is the 3rd dimension
- - 2: axis Y is the 1st dimension, axis X is the 2nd dimension
- """
-
- sigStackChanged = qt.Signal(int)
- """Signal emitted when the stack is changed.
- This happens when a new volume is loaded, or when the current volume
- is transposed (change in perspective).
-
- The signal provides the size (number of pixels) of the stack.
- This will be 0 if the stack is cleared, else it will be a positive
- integer.
- """
-
- sigFrameChanged = qt.Signal(int)
- """Signal emitter when the frame number has changed.
-
- This signal provides the current frame number.
- """
-
- IMAGE_STACK_FILTER_NXDATA = 'Stack of images as NXdata (%s)' % silx_io._NEXUS_HDF5_EXT_STR
-
-
- def __init__(self, parent=None, resetzoom=True, backend=None,
- autoScale=False, logScale=False, grid=False,
- colormap=True, aspectRatio=True, yinverted=True,
- copy=True, save=True, print_=True, control=False,
- position=None, mask=True):
- qt.QMainWindow.__init__(self, parent)
- if parent is not None:
- # behave as a widget
- self.setWindowFlags(qt.Qt.Widget)
- else:
- self.setWindowTitle('StackView')
-
- self._stack = None
- """Loaded stack, as a 3D array, a 3D dataset or a list of 2D arrays."""
- self.__transposed_view = None
- """View on :attr:`_stack` with the axes sorted, to have
- the orthogonal dimension first"""
- self._perspective = 0
- """Orthogonal dimension (depth) in :attr:`_stack`"""
-
- self._stackItem = ImageStack()
- """Hold the item displaying the stack"""
- imageLegend = '__StackView__image' + str(id(self))
- self._stackItem.setName(imageLegend)
-
- self.__autoscaleCmap = False
- """Flag to disable/enable colormap auto-scaling
- based on the min/max values of the entire 3D volume"""
- self.__dimensionsLabels = ["Dimension 0", "Dimension 1",
- "Dimension 2"]
- """These labels are displayed on the X and Y axes.
- :meth:`setLabels` updates this attribute."""
-
- self._first_stack_dimension = 0
- """Used for dimension labels and combobox"""
-
- self._titleCallback = self._defaultTitleCallback
- """Function returning the plot title based on the frame index.
- It can be set to a custom function using :meth:`setTitleCallback`"""
-
- self.calibrations3D = (calibration.NoCalibration(),
- calibration.NoCalibration(),
- calibration.NoCalibration())
-
- central_widget = qt.QWidget(self)
-
- self._plot = PlotWindow(parent=central_widget, backend=backend,
- resetzoom=resetzoom, autoScale=autoScale,
- logScale=logScale, grid=grid,
- curveStyle=False, colormap=colormap,
- aspectRatio=aspectRatio, yInverted=yinverted,
- copy=copy, save=save, print_=print_,
- control=control, position=position,
- roi=False, mask=mask)
- self._plot.addItem(self._stackItem)
- self._plot.getIntensityHistogramAction().setVisible(True)
- self.sigInteractiveModeChanged = self._plot.sigInteractiveModeChanged
- self.sigActiveImageChanged = self._plot.sigActiveImageChanged
- self.sigPlotSignal = self._plot.sigPlotSignal
-
- if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
- self._plot.getYAxis().setInverted(True)
-
- self._addColorBarAction()
-
- self._profileToolBar = Profile3DToolBar(parent=self._plot,
- stackview=self)
- self._plot.addToolBar(self._profileToolBar)
- self._plot.getXAxis().setLabel('Columns')
- self._plot.getYAxis().setLabel('Rows')
- self._plot.sigPlotSignal.connect(self._plotCallback)
- self._plot.getSaveAction().setFileFilter('image', self.IMAGE_STACK_FILTER_NXDATA, func=self._saveImageStack, appendToFile=True)
-
- self.__planeSelection = PlanesWidget(self._plot)
- self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
-
- self._browser_label = qt.QLabel("Image index (Dim0):")
-
- self._browser = HorizontalSliderWithBrowser(central_widget)
- self._browser.setRange(0, 0)
- self._browser.valueChanged[int].connect(self.__updateFrameNumber)
- self._browser.setEnabled(False)
-
- layout = qt.QGridLayout()
- layout.setContentsMargins(0, 0, 0, 0)
- layout.addWidget(self._plot, 0, 0, 1, 3)
- layout.addWidget(self.__planeSelection, 1, 0)
- layout.addWidget(self._browser_label, 1, 1)
- layout.addWidget(self._browser, 1, 2)
-
- central_widget.setLayout(layout)
- self.setCentralWidget(central_widget)
-
- # clear profile lines when the perspective changes (plane browsed changed)
- self.__planeSelection.sigPlaneSelectionChanged.connect(
- self._profileToolBar.clearProfile)
-
- def _saveImageStack(self, plot, filename, nameFilter):
- """Save all images from the stack into a volume.
-
- :param str filename: The name of the file to write
- :param str nameFilter: The selected name filter
- :return: False if format is not supported or save failed,
- True otherwise.
- :raises: ValueError if nameFilter is invalid
- """
- if not nameFilter == self.IMAGE_STACK_FILTER_NXDATA:
- raise ValueError('Wrong callback')
- entryPath = silx_io.SaveAction._selectWriteableOutputGroup(filename, parent=self)
- if entryPath is None:
- return False
- return save_NXdata(filename,
- nxentry_name=entryPath,
- signal=self.getStack(copy=False, returnNumpyArray=True)[0],
- signal_name="image_stack")
-
- def _addColorBarAction(self):
- self._plot.getColorBarWidget().setVisible(True)
- actions = self._plot.toolBar().actions()
- for index, action in enumerate(actions):
- if action is self._plot.getColormapAction():
- break
- self._colorbarAction = actions_control.ColorBarAction(self._plot, self._plot)
- self._plot.toolBar().insertAction(actions[index + 1], self._colorbarAction)
-
- def _plotCallback(self, eventDict):
- """Callback for plot events.
-
- Emit :attr:`valueChanged` signal, with (x, y, value) tuple of the
- cursor location in the plot."""
- if eventDict['event'] == 'mouseMoved':
- activeImage = self.getActiveImage()
- if activeImage is not None:
- data = activeImage.getData()
- height, width = data.shape
-
- # Get corresponding coordinate in image
- origin = activeImage.getOrigin()
- scale = activeImage.getScale()
- x = int((eventDict['x'] - origin[0]) / scale[0])
- y = int((eventDict['y'] - origin[1]) / scale[1])
-
- if 0 <= x < width and 0 <= y < height:
- self.valueChanged.emit(float(x), float(y),
- data[y][x])
- else:
- self.valueChanged.emit(float(x), float(y),
- None)
-
- def getPerspective(self):
- """Returns the index of the dimension the stack is browsed with
-
- Possible values are: 0, 1, or 2.
-
- :rtype: int
- """
- return self._perspective
-
- def setPerspective(self, perspective):
- """Set the index of the dimension the stack is browsed with:
-
- - slice plane Dim1-Dim2: perspective 0
- - slice plane Dim0-Dim2: perspective 1
- - slice plane Dim0-Dim1: perspective 2
-
- :param int perspective: Orthogonal dimension number (0, 1, or 2)
- """
- if perspective == self._perspective:
- return
- else:
- if perspective > 2 or perspective < 0:
- raise ValueError(
- "Perspective must be 0, 1 or 2, not %s" % perspective)
-
- self._perspective = int(perspective)
- self.__createTransposedView()
- self.__updateFrameNumber(self._browser.value())
- self._plot.resetZoom()
- self.__updatePlotLabels()
- self._updateTitle()
- self._browser_label.setText("Image index (Dim%d):" %
- (self._first_stack_dimension + perspective))
-
- self.sigPlaneSelectionChanged.emit(perspective)
- self.sigStackChanged.emit(self._stack.size if
- self._stack is not None else 0)
- self.__planeSelection.sigPlaneSelectionChanged.disconnect(self.setPerspective)
- self.__planeSelection.setPerspective(self._perspective)
- self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
-
- def __updatePlotLabels(self):
- """Update plot axes labels depending on perspective"""
- y, x = (1, 2) if self._perspective == 0 else \
- (0, 2) if self._perspective == 1 else (0, 1)
- self.setGraphXLabel(self.__dimensionsLabels[x])
- self.setGraphYLabel(self.__dimensionsLabels[y])
-
- def __createTransposedView(self):
- """Create the new view on the stack depending on the perspective
- (set orthogonal axis browsed on the viewer as first dimension)
- """
- assert self._stack is not None
- assert 0 <= self._perspective < 3
-
- # ensure we have the stack encapsulated in an array-like object
- # having a transpose() method
- if isinstance(self._stack, numpy.ndarray):
- self.__transposed_view = self._stack
-
- elif is_dataset(self._stack) or isinstance(self._stack, DatasetView):
- self.__transposed_view = DatasetView(self._stack)
-
- elif isinstance(self._stack, ListOfImages):
- self.__transposed_view = ListOfImages(self._stack)
-
- # transpose the array-like object if necessary
- if self._perspective == 1:
- self.__transposed_view = self.__transposed_view.transpose((1, 0, 2))
- elif self._perspective == 2:
- self.__transposed_view = self.__transposed_view.transpose((2, 0, 1))
-
- self._browser.setRange(0, self.__transposed_view.shape[0] - 1)
- self._browser.setValue(0)
-
- # Update the item structure
- self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
- self._stackItem.setColormap(self.getColormap())
- self._stackItem.setOrigin(self._getImageOrigin())
- self._stackItem.setScale(self._getImageScale())
-
- def __updateFrameNumber(self, index):
- """Update the current image.
-
- :param index: index of the frame to be displayed
- """
- if self.__transposed_view is None:
- # no data set
- return
-
- self._stackItem.setStackPosition(index)
-
- self._updateTitle()
- self.sigFrameChanged.emit(index)
-
- def _set3DScaleAndOrigin(self, calibrations):
- """Set scale and origin for all 3 axes, to be used when plotting
- an image.
-
- See setStack for parameter documentation
- """
- if calibrations is None:
- self.calibrations3D = (calibration.NoCalibration(),
- calibration.NoCalibration(),
- calibration.NoCalibration())
- else:
- self.calibrations3D = []
- for i, calib in enumerate(calibrations):
- if hasattr(calib, "__len__") and len(calib) == 2:
- calib = calibration.LinearCalibration(calib[0], calib[1])
- elif calib is None:
- calib = calibration.NoCalibration()
- elif not isinstance(calib, calibration.AbstractCalibration):
- raise TypeError("calibration must be a 2-tuple, None or" +
- " an instance of an AbstractCalibration " +
- "subclass")
- elif not calib.is_affine():
- _logger.warning(
- "Calibration for dimension %d is not linear, "
- "it will be ignored for scaling the graph axes.",
- i)
- self.calibrations3D.append(calib)
-
- def getCalibrations(self, order='array'):
- """Returns currently used calibrations for each axis
-
- Returned calibrations might differ from the ones that were set as
- non-linear calibrations used for image axes are temporarily ignored.
-
- :param str order:
- 'array' to sort calibrations as data array (dim0, dim1, dim2),
- 'axes' to sort calibrations as currently selected x, y and z axes.
- :return: Calibrations ordered depending on order
- :rtype: List[~silx.math.calibration.AbstractCalibration]
- """
- assert order in ('array', 'axes')
- calibs = []
-
- # filter out non-linear calibration for graph axes
- for index, calib in enumerate(self.calibrations3D):
- if index != self._perspective and not calib.is_affine():
- calib = calibration.NoCalibration()
- calibs.append(calib)
-
- if order == 'axes': # Move 'z' axis to the end
- xy_dims = [d for d in (0, 1, 2) if d != self._perspective]
- calibs = [calibs[max(xy_dims)],
- calibs[min(xy_dims)],
- calibs[self._perspective]]
-
- return tuple(calibs)
-
- def _getImageScale(self):
- """
- :return: 2-tuple (XScale, YScale) for current image view
- """
- xcalib, ycalib, _zcalib = self.getCalibrations(order='axes')
- return xcalib.get_slope(), ycalib.get_slope()
-
- def _getImageOrigin(self):
- """
- :return: 2-tuple (XOrigin, YOrigin) for current image view
- """
- xcalib, ycalib, _zcalib = self.getCalibrations(order='axes')
- return xcalib(0), ycalib(0)
-
- def _getImageZ(self, index):
- """
- :param idx: 0-based image index in the stack
- :return: calibrated Z value corresponding to the image idx
- """
- _xcalib, _ycalib, zcalib = self.getCalibrations(order='axes')
- return zcalib(index)
-
- def _updateTitle(self):
- frame_idx = self._browser.value()
- self._plot.setGraphTitle(self._titleCallback(frame_idx))
-
- def _defaultTitleCallback(self, index):
- return "Image z=%g" % self._getImageZ(index)
-
- # public API, stack specific methods
- def setStack(self, stack, perspective=None, reset=True, calibrations=None):
- """Set the 3D stack.
-
- The perspective parameter is used to define which dimension of the 3D
- array is to be used as frame index. The lowest remaining dimension
- number is the row index of the displayed image (Y axis), and the highest
- remaining dimension is the column index (X axis).
-
- :param stack: 3D stack, or `None` to clear plot.
- :type stack: 3D numpy.ndarray, or 3D h5py.Dataset, or list/tuple of 2D
- numpy arrays, or None.
- :param int perspective: Dimension for the frame index: 0, 1 or 2.
- Use ``None`` to keep the current perspective (default).
- :param bool reset: Whether to reset zoom or not.
- :param calibrations: Sequence of 3 calibration objects for each axis.
- These objects can be a subclass of :class:`AbstractCalibration`,
- or 2-tuples *(a, b)* where *a* is the y-intercept and *b* is the
- slope of a linear calibration (:math:`x \\mapsto a + b x`)
- """
- if stack is None:
- self.clear()
- self.sigStackChanged.emit(0)
- return
-
- self._set3DScaleAndOrigin(calibrations)
-
- # stack as list of 2D arrays: must be converted into an array_like
- if not isinstance(stack, numpy.ndarray):
- if not is_dataset(stack):
- try:
- assert hasattr(stack, "__len__")
- for img in stack:
- assert hasattr(img, "shape")
- assert len(img.shape) == 2
- except AssertionError:
- raise ValueError(
- "Stack must be a 3D array/dataset or a list of " +
- "2D arrays.")
- stack = ListOfImages(stack)
-
- assert len(stack.shape) == 3, "data must be 3D"
-
- self._stack = stack
- self.__createTransposedView()
-
- perspective_changed = False
- if perspective not in [None, self._perspective]:
- perspective_changed = True
- self.setPerspective(perspective)
-
- if self.__autoscaleCmap:
- self.scaleColormapRangeToStack()
-
- # init plot
- self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
- self._stackItem.setColormap(self.getColormap())
- self._stackItem.setOrigin(self._getImageOrigin())
- self._stackItem.setScale(self._getImageScale())
- self._stackItem.setVisible(True)
-
- # Put back the item in the plot in case it was cleared
- exists = self._plot.getImage(self._stackItem.getName())
- if exists is None:
- self._plot.addItem(self._stackItem)
-
- self._plot.setActiveImage(self._stackItem.getName())
- self.__updatePlotLabels()
- self._updateTitle()
-
- if reset:
- self._plot.resetZoom()
-
- # enable and init browser
- self._browser.setEnabled(True)
-
- if not perspective_changed: # avoid double signal (see self.setPerspective)
- self.sigStackChanged.emit(stack.size)
-
- def getStack(self, copy=True, returnNumpyArray=False):
- """Get the original stack, as a 3D array or dataset.
-
- The output has the form: [data, params]
- where params is a dictionary containing display parameters.
-
- :param bool copy: If True (default), then the object is copied
- and returned as a numpy array.
- Else, a reference to original data is returned, if possible.
- If the original data is not a numpy array and parameter
- returnNumpyArray is True, a copy will be made anyway.
- :param bool returnNumpyArray: If True, the returned object is
- guaranteed to be a numpy array.
- :return: 3D stack and parameters.
- :rtype: (numpy.ndarray, dict)
- """
- if self._stack is None:
- return None
-
- image = self._stackItem
- colormap = image.getColormap()
-
- params = {
- 'info': image.getInfo(),
- 'origin': image.getOrigin(),
- 'scale': image.getScale(),
- 'z': image.getZValue(),
- 'selectable': image.isSelectable(),
- 'draggable': image.isDraggable(),
- 'colormap': colormap,
- 'xlabel': image.getXLabel(),
- 'ylabel': image.getYLabel(),
- }
- if returnNumpyArray or copy:
- return numpy.array(self._stack, copy=copy), params
-
- # if a list of 2D arrays was cast into a ListOfImages,
- # return the original list
- if isinstance(self._stack, ListOfImages):
- return self._stack.images, params
-
- return self._stack, params
-
- def getCurrentView(self, copy=True, returnNumpyArray=False):
- """Get the stack, as it is currently displayed.
-
- The first index of the returned stack is always the frame
- index. If the perspective has been changed in the widget since the
- data was first loaded, this will be reflected in the order of the
- dimensions of the returned object.
-
- The output has the form: [data, params]
- where params is a dictionary containing display parameters.
-
- :param bool copy: If True (default), then the object is copied
- and returned as a numpy array.
- Else, a reference to original data is returned, if possible.
- If the original data is not a numpy array and parameter
- `returnNumpyArray` is `True`, a copy will be made anyway.
- :param bool returnNumpyArray: If `True`, the returned object is
- guaranteed to be a numpy array.
- :return: 3D stack and parameters.
- :rtype: (numpy.ndarray, dict)
- """
- image = self.getActiveImage()
- if image is None:
- return None
-
- if isinstance(image, items.ColormapMixIn):
- colormap = image.getColormap()
- else:
- colormap = None
-
- params = {
- 'info': image.getInfo(),
- 'origin': image.getOrigin(),
- 'scale': image.getScale(),
- 'z': image.getZValue(),
- 'selectable': image.isSelectable(),
- 'draggable': image.isDraggable(),
- 'colormap': colormap,
- 'xlabel': image.getXLabel(),
- 'ylabel': image.getYLabel(),
- }
- if returnNumpyArray or copy:
- return numpy.array(self.__transposed_view, copy=copy), params
- return self.__transposed_view, params
-
- def setFrameNumber(self, number):
- """Set the frame selection to a specific value
-
- :param int number: Number of the frame
- """
- self._browser.setValue(number)
-
- def getFrameNumber(self):
- """Set the frame selection to a specific value
-
- :return: Index of currently displayed frame
- :rtype: int
- """
- return self._browser.value()
-
- def setFirstStackDimension(self, first_stack_dimension):
- """When viewing the last 3 dimensions of an n-D array (n>3), you can
- use this method to change the text in the combobox.
-
- For instance, for a 7-D array, first stack dim is 4, so the default
- "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
- numbers are 0-based).
-
- :param int first_stack_dim: First stack dimension (n-3) when viewing the
- last 3 dimensions of an n-D array.
- """
- old_state = self.__planeSelection.blockSignals(True)
- self.__planeSelection.setFirstStackDimension(first_stack_dimension)
- self.__planeSelection.blockSignals(old_state)
- self._first_stack_dimension = first_stack_dimension
- self._browser_label.setText("Image index (Dim%d):" % first_stack_dimension)
-
- def setTitleCallback(self, callback):
- """Set a user defined function to generate the plot title based on the
- image/frame index.
-
- The callback function must accept an integer as a its first positional
- parameter and must not require any other mandatory parameter.
- It must return a string.
-
- To switch back the default behavior, you can pass ``None``::
-
- mystackview.setTitleCallback(None)
-
- To have no title, pass a function that returns an empty string::
-
- mystackview.setTitleCallback(lambda idx: "")
-
- :param callback: Callback function generating the stack title based
- on the frame number.
- """
-
- if callback is None:
- self._titleCallback = self._defaultTitleCallback
- elif callable(callback):
- self._titleCallback = callback
- else:
- raise TypeError("Provided callback is not callable")
- self._updateTitle()
-
- def clear(self):
- """Clear the widget:
-
- - clear the plot
- - clear the loaded data volume
- """
- self._stack = None
- self.__transposed_view = None
- self._perspective = 0
- self._browser.setEnabled(False)
- # reset browser range
- self._browser.setRange(0, 0)
- self._plot.clear()
-
- def setLabels(self, labels=None):
- """Set the labels to be displayed on the plot axes.
-
- You must provide a sequence of 3 strings, corresponding to the 3
- dimensions of the original data volume.
- The proper label will automatically be selected for each plot axis
- when the volume is rotated (when different axes are selected as the
- X and Y axes).
-
- :param List[str] labels: 3 labels corresponding to the 3 dimensions
- of the data volumes.
- """
-
- default_labels = ["Dimension %d" % self._first_stack_dimension,
- "Dimension %d" % (self._first_stack_dimension + 1),
- "Dimension %d" % (self._first_stack_dimension + 2)]
- if labels is None:
- new_labels = default_labels
- else:
- # filter-out None
- new_labels = []
- for i, label in enumerate(labels):
- new_labels.append(label or default_labels[i])
-
- self.__dimensionsLabels = new_labels
- self.__updatePlotLabels()
-
- def getLabels(self):
- """Return dimension labels displayed on the plot axes
-
- :return: List of three strings corresponding to the 3 dimensions
- of the stack: (name_dim0, name_dim1, name_dim2)
- """
- return self.__dimensionsLabels
-
- def getColormap(self):
- """Get the current colormap description.
-
- :return: A description of the current colormap.
- See :meth:`setColormap` for details.
- :rtype: dict
- """
- # "default" colormap used by addImage when image is added without
- # specifying a special colormap
- return self._plot.getDefaultColormap()
-
- def scaleColormapRangeToStack(self):
- """Scale colormap range according to current stack data.
-
- If no stack has been set through :meth:`setStack`, this has no effect.
-
- The range scaling mode is given by current :class:`Colormap`'s
- :meth:`Colormap.getAutoscaleMode`.
- """
- stack = self.getStack(copy=False, returnNumpyArray=True)
- if stack is None:
- return # No-op
-
- colormap = self.getColormap()
- vmin, vmax = colormap.getColormapRange(data=stack[0])
- colormap.setVRange(vmin=vmin, vmax=vmax)
-
- def setColormap(self, colormap=None, normalization=None,
- autoscale=None, vmin=None, vmax=None, colors=None):
- """Set the colormap and update active image.
-
- Parameters that are not provided are taken from the current colormap.
-
- The colormap parameter can also be a dict with the following keys:
-
- - *name*: string. The colormap to use:
- 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
- - *normalization*: string. The mapping to use for the colormap:
- either 'linear' or 'log'.
- - *autoscale*: bool. Whether to use autoscale (True) or range
- provided by keys
- 'vmin' and 'vmax' (False).
- - *vmin*: float. The minimum value of the range to use if 'autoscale'
- is False.
- - *vmax*: float. The maximum value of the range to use if 'autoscale'
- is False.
- - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
- List of RGB or RGBA colors to use (only if name is None)
-
- :param colormap: Name of the colormap in
- 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
- Or a :class`.Colormap` object.
- :type colormap: dict or str.
- :param str normalization: Colormap mapping: 'linear' or 'log'.
- :param bool autoscale: Whether to use autoscale or [vmin, vmax] range.
- Default value of autoscale is False. This option is not compatible
- with h5py datasets.
- :param float vmin: The minimum value of the range to use if
- 'autoscale' is False.
- :param float vmax: The maximum value of the range to use if
- 'autoscale' is False.
- :param numpy.ndarray colors: Only used if name is None.
- Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
- """
- # if is a colormap object or a dictionary
- if isinstance(colormap, Colormap) or isinstance(colormap, dict):
- # Support colormap parameter as a dict
- errmsg = "If colormap is provided as a Colormap object, all other parameters"
- errmsg += " must not be specified when calling setColormap"
- assert normalization is None, errmsg
- assert autoscale is None, errmsg
- assert vmin is None, errmsg
- assert vmax is None, errmsg
- assert colors is None, errmsg
-
- if isinstance(colormap, dict):
- reason = 'colormap parameter should now be an object'
- replacement = 'Colormap()'
- since_version = '0.6'
- deprecated_warning(type_='function',
- name='setColormap',
- reason=reason,
- replacement=replacement,
- since_version=since_version)
- _colormap = Colormap._fromDict(colormap)
- else:
- _colormap = colormap
- else:
- norm = normalization if normalization is not None else 'linear'
- name = colormap if colormap is not None else 'gray'
- _colormap = Colormap(name=name,
- normalization=norm,
- vmin=vmin,
- vmax=vmax,
- colors=colors)
-
- if autoscale is not None:
- deprecated_warning(
- type_='function',
- name='setColormap',
- reason='autoscale argument is replaced by a method',
- replacement='scaleColormapRangeToStack',
- since_version='0.14')
- self.__autoscaleCmap = bool(autoscale)
-
- cursorColor = cursorColorForColormap(_colormap.getName())
- self._plot.setInteractiveMode('zoom', color=cursorColor)
-
- self._plot.setDefaultColormap(_colormap)
-
- # Update active image colormap
- activeImage = self.getActiveImage()
- if isinstance(activeImage, items.ColormapMixIn):
- activeImage.setColormap(self.getColormap())
-
- if self.__autoscaleCmap:
- # scaleColormapRangeToStack needs to be called **after**
- # setDefaultColormap so getColormap returns the right colormap
- self.scaleColormapRangeToStack()
-
-
- @deprecated(replacement="getPlotWidget", since_version="0.13")
- def getPlot(self):
- return self.getPlotWidget()
-
- def getPlotWidget(self):
- """Return the :class:`PlotWidget`.
-
- This gives access to advanced plot configuration options.
- Be warned that modifying the plot can cause issues, and some changes
- you make to the plot could be overwritten by the :class:`StackView`
- widget's internal methods and callbacks.
-
- :return: instance of :class:`PlotWidget` used in widget
- """
- return self._plot
-
- def setOptionVisible(self, isVisible):
- """
- Set the visibility of the browsing options.
-
- :param bool isVisible: True to have the options visible, else False
- """
- self._browser.setVisible(isVisible)
- self.__planeSelection.setVisible(isVisible)
-
- # proxies to PlotWidget or PlotWindow methods
- def getProfileToolbar(self):
- """Profile tools attached to this plot
- """
- return self._profileToolBar
-
- def getGraphTitle(self):
- """Return the plot main title as a str.
- """
- return self._plot.getGraphTitle()
-
- def setGraphTitle(self, title=""):
- """Set the plot main title.
-
- :param str title: Main title of the plot (default: '')
- """
- return self._plot.setGraphTitle(title)
-
- def getGraphXLabel(self):
- """Return the current horizontal axis label as a str.
- """
- return self._plot.getXAxis().getLabel()
-
- def setGraphXLabel(self, label=None):
- """Set the plot horizontal axis label.
-
- :param str label: The horizontal axis label
- """
- if label is None:
- label = self.__dimensionsLabels[1 if self._perspective == 2 else 2]
- self._plot.getXAxis().setLabel(label)
-
- def getGraphYLabel(self, axis='left'):
- """Return the current vertical axis label as a str.
-
- :param str axis: The Y axis for which to get the label (left or right)
- """
- return self._plot.getYAxis().getLabel(axis)
-
- def setGraphYLabel(self, label=None, axis='left'):
- """Set the vertical axis label on the plot.
-
- :param str label: The Y axis label
- :param str axis: The Y axis for which to set the label (left or right)
- """
- if label is None:
- label = self.__dimensionsLabels[1 if self._perspective == 0 else 0]
- self._plot.getYAxis(axis=axis).setLabel(label)
-
- def resetZoom(self):
- """Reset the plot limits to the bounds of the data and redraw the plot.
-
- This method is a simple proxy to the legacy :class:`PlotWidget` method
- of the same name. Using the object oriented approach is now
- preferred::
-
- stackview.getPlot().resetZoom()
- """
- self._plot.resetZoom()
-
- def setYAxisInverted(self, flag=True):
- """Set the Y axis orientation.
-
- This method is a simple proxy to the legacy :class:`PlotWidget` method
- of the same name. Using the object oriented approach is now
- preferred::
-
- stackview.getPlot().setYAxisInverted(flag)
-
- :param bool flag: True for Y axis going from top to bottom,
- False for Y axis going from bottom to top
- """
- self._plot.setYAxisInverted(flag)
-
- def isYAxisInverted(self):
- """Return True if Y axis goes from top to bottom, False otherwise.
-
- This method is a simple proxy to the legacy :class:`PlotWidget` method
- of the same name. Using the object oriented approach is now
- preferred::
-
- stackview.getPlot().isYAxisInverted()"""
- return self._plot.isYAxisInverted()
-
- def getSupportedColormaps(self):
- """Get the supported colormap names as a tuple of str.
-
- The list should at least contain and start by:
- ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
-
- This method is a simple proxy to the legacy :class:`PlotWidget` method
- of the same name. Using the object oriented approach is now
- preferred::
-
- stackview.getPlot().getSupportedColormaps()
- """
- return self._plot.getSupportedColormaps()
-
- def isKeepDataAspectRatio(self):
- """Returns whether the plot is keeping data aspect ratio or not.
-
- This method is a simple proxy to the legacy :class:`PlotWidget` method
- of the same name. Using the object oriented approach is now
- preferred::
-
- stackview.getPlot().isKeepDataAspectRatio()"""
- return self._plot.isKeepDataAspectRatio()
-
- def setKeepDataAspectRatio(self, flag=True):
- """Set whether the plot keeps data aspect ratio or not.
-
- This method is a simple proxy to the legacy :class:`PlotWidget` method
- of the same name. Using the object oriented approach is now
- preferred::
-
- stackview.getPlot().setKeepDataAspectRatio(flag)
-
- :param bool flag: True to respect data aspect ratio
- """
- self._plot.setKeepDataAspectRatio(flag)
-
- # kind of private methods, but needed by Profile
- def getActiveImage(self, just_legend=False):
- """Returns the stack image object.
- """
- if just_legend:
- return self._stackItem.getName()
- return self._stackItem
-
- def getColorBarAction(self):
- """Returns the action managing the visibility of the colorbar.
-
- .. warning:: to show/hide the plot colorbar call directly the ColorBar
- widget using getColorBarWidget()
-
- :rtype: QAction
- """
- return self._colorbarAction
-
- def remove(self, legend=None,
- kind=('curve', 'image', 'item', 'marker')):
- """See :meth:`Plot.Plot.remove`"""
- self._plot.remove(legend, kind)
-
- def setInteractiveMode(self, *args, **kwargs):
- """
- See :meth:`Plot.Plot.setInteractiveMode`
- """
- self._plot.setInteractiveMode(*args, **kwargs)
-
- @deprecated(replacement="addShape", since_version="0.13")
- def addItem(self, *args, **kwargs):
- self.addShape(*args, **kwargs)
-
- def addShape(self, *args, **kwargs):
- """
- See :meth:`Plot.Plot.addShape`
- """
- self._plot.addShape(*args, **kwargs)
-
-
-class PlanesWidget(qt.QWidget):
- """Widget for the plane/perspective selection
-
- :param parent: the parent QWidget
- """
- sigPlaneSelectionChanged = qt.Signal(int)
-
- def __init__(self, parent):
- super(PlanesWidget, self).__init__(parent)
-
- self.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
- layout0 = qt.QHBoxLayout()
- self.setLayout(layout0)
- layout0.setContentsMargins(0, 0, 0, 0)
-
- layout0.addWidget(qt.QLabel("Axes selection:"))
-
- # By default, the first dimension (dim0) is the frame index/depth/z,
- # the second dimension is the image row number/y axis
- # and the third dimension is the image column index/x axis
-
- # 1
- # | 0
- # |/__2
- self.qcbAxisSelection = qt.QComboBox(self)
- self._setCBChoices(first_stack_dimension=0)
- self.qcbAxisSelection.currentIndexChanged[int].connect(
- self.__planeSelectionChanged)
-
- layout0.addWidget(self.qcbAxisSelection)
-
- def __planeSelectionChanged(self, idx):
- """Callback function when the combobox selection changes
-
- idx is the dimension number orthogonal to the slice plane,
- following the convention:
-
- - slice plane Dim1-Dim2: perspective 0
- - slice plane Dim0-Dim2: perspective 1
- - slice plane Dim0-Dim1: perspective 2
- """
- self.sigPlaneSelectionChanged.emit(idx)
-
- def _setCBChoices(self, first_stack_dimension):
- self.qcbAxisSelection.clear()
-
- dim1dim2 = 'Dim%d-Dim%d' % (first_stack_dimension + 1,
- first_stack_dimension + 2)
- dim0dim2 = 'Dim%d-Dim%d' % (first_stack_dimension,
- first_stack_dimension + 2)
- dim0dim1 = 'Dim%d-Dim%d' % (first_stack_dimension,
- first_stack_dimension + 1)
-
- self.qcbAxisSelection.addItem(icons.getQIcon("cube-front"), dim1dim2)
- self.qcbAxisSelection.addItem(icons.getQIcon("cube-bottom"), dim0dim2)
- self.qcbAxisSelection.addItem(icons.getQIcon("cube-left"), dim0dim1)
-
- def setFirstStackDimension(self, first_stack_dim):
- """When viewing the last 3 dimensions of an n-D array (n>3), you can
- use this method to change the text in the combobox.
-
- For instance, for a 7-D array, first stack dim is 4, so the default
- "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
- numbers are 0-based).
-
- :param int first_stack_dim: First stack dimension (n-3) when viewing the
- last 3 dimensions of an n-D array.
- """
- self._setCBChoices(first_stack_dim)
-
- def setPerspective(self, perspective):
- """Update the combobox selection.
-
- - slice plane Dim1-Dim2: perspective 0
- - slice plane Dim0-Dim2: perspective 1
- - slice plane Dim0-Dim1: perspective 2
-
- :param perspective: Orthogonal dimension number (0, 1, or 2)
- """
- self.qcbAxisSelection.setCurrentIndex(perspective)
-
-
-class StackViewMainWindow(StackView):
- """This class is a :class:`StackView` with a menu, an additional toolbar
- to set the plot limits, and a status bar to display the value and 3D
- index of the data samples hovered by the mouse cursor.
-
- :param QWidget parent: Parent widget, or None
- """
- def __init__(self, parent=None):
- self._dataInfo = None
- super(StackViewMainWindow, self).__init__(parent)
- self.setWindowFlags(qt.Qt.Window)
-
- # Add toolbars and status bar
- self.addToolBar(qt.Qt.BottomToolBarArea,
- LimitsToolBar(plot=self._plot))
-
- self.statusBar()
-
- menu = self.menuBar().addMenu('File')
- menu.addAction(self._plot.getOutputToolBar().getSaveAction())
- menu.addAction(self._plot.getOutputToolBar().getPrintAction())
- menu.addSeparator()
- action = menu.addAction('Quit')
- action.triggered[bool].connect(qt.QApplication.instance().quit)
-
- menu = self.menuBar().addMenu('Edit')
- menu.addAction(self._plot.getOutputToolBar().getCopyAction())
- menu.addSeparator()
- menu.addAction(self._plot.getResetZoomAction())
- menu.addAction(self._plot.getColormapAction())
- menu.addAction(self.getColorBarAction())
-
- menu.addAction(actions.control.KeepAspectRatioAction(self._plot, self))
- menu.addAction(actions.control.YAxisInvertedAction(self._plot, self))
-
- menu = self.menuBar().addMenu('Profile')
- profileToolBar = self._profileToolBar
- menu.addAction(profileToolBar.hLineAction)
- menu.addAction(profileToolBar.vLineAction)
- menu.addAction(profileToolBar.lineAction)
- menu.addAction(profileToolBar.crossAction)
- menu.addSeparator()
- menu.addAction(profileToolBar._editor)
- menu.addSeparator()
- menu.addAction(profileToolBar.clearAction)
-
- # Connect to StackView's signal
- self.valueChanged.connect(self._statusBarSlot)
-
- def _statusBarSlot(self, x, y, value):
- """Update status bar with coordinates/value from plots."""
- # todo (after implementing calibration):
- # - use floats for (x, y, z)
- # - display both indices (dim0, dim1, dim2) and (x, y, z)
- msg = "Cursor out of range"
- if x is not None and y is not None:
- img_idx = self._browser.value()
-
- if self._perspective == 0:
- dim0, dim1, dim2 = img_idx, int(y), int(x)
- elif self._perspective == 1:
- dim0, dim1, dim2 = int(y), img_idx, int(x)
- elif self._perspective == 2:
- dim0, dim1, dim2 = int(y), int(x), img_idx
-
- msg = 'Position: (%d, %d, %d)' % (dim0, dim1, dim2)
- if value is not None:
- msg += ', Value: %g' % value
- if self._dataInfo is not None:
- msg = self._dataInfo + ', ' + msg
-
- self.statusBar().showMessage(msg)
-
- def setStack(self, stack, *args, **kwargs):
- """Set the displayed stack.
-
- See :meth:`StackView.setStack` for details.
- """
- if hasattr(stack, 'dtype') and hasattr(stack, 'shape'):
- assert len(stack.shape) == 3
- nframes, height, width = stack.shape
- self._dataInfo = 'Data: %dx%dx%d (%s)' % (nframes, height, width,
- str(stack.dtype))
- self.statusBar().showMessage(self._dataInfo)
- else:
- self._dataInfo = None
-
- # Set the new stack in StackView widget
- super(StackViewMainWindow, self).setStack(stack, *args, **kwargs)
- self.setStatusBar(None)
diff --git a/silx/gui/plot/StatsWidget.py b/silx/gui/plot/StatsWidget.py
deleted file mode 100644
index 6d8739e..0000000
--- a/silx/gui/plot/StatsWidget.py
+++ /dev/null
@@ -1,1661 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-Module containing widgets displaying stats from items of a plot.
-"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "24/07/2018"
-
-
-from collections import OrderedDict
-from contextlib import contextmanager
-import logging
-import weakref
-import functools
-import numpy
-import enum
-from silx.utils.proxy import docstring
-from silx.utils.enum import Enum as _Enum
-from silx.gui import qt
-from silx.gui import icons
-from silx.gui.plot import stats as statsmdl
-from silx.gui.widgets.TableWidget import TableWidget
-from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter
-from silx.gui.plot.items.core import ItemChangedType
-from silx.gui.widgets.FlowLayout import FlowLayout
-from . import PlotWidget
-from . import items as plotitems
-
-
-_logger = logging.getLogger(__name__)
-
-
-@enum.unique
-class UpdateMode(_Enum):
- AUTO = 'auto'
- MANUAL = 'manual'
-
-
-# Helper class to handle specific calls to PlotWidget and SceneWidget
-
-
-class _Wrapper(qt.QObject):
- """Base class for connection with PlotWidget and SceneWidget.
-
- This class is used when no PlotWidget or SceneWidget is connected.
-
- :param plot: The plot to be used
- """
-
- sigItemAdded = qt.Signal(object)
- """Signal emitted when a new item is added.
-
- It provides the added item.
- """
-
- sigItemRemoved = qt.Signal(object)
- """Signal emitted when an item is (about to be) removed.
-
- It provides the removed item.
- """
-
- sigCurrentChanged = qt.Signal(object)
- """Signal emitted when the current item has changed.
-
- It provides the current item.
- """
-
- sigVisibleDataChanged = qt.Signal()
- """Signal emitted when the visible data area has changed"""
-
- def __init__(self, plot=None):
- super(_Wrapper, self).__init__(parent=None)
- self._plotRef = None if plot is None else weakref.ref(plot)
-
- def getPlot(self):
- """Returns the plot attached to this widget"""
- return None if self._plotRef is None else self._plotRef()
-
- def getItems(self):
- """Returns the list of items in the plot
-
- :rtype: List[object]
- """
- return ()
-
- def getSelectedItems(self):
- """Returns the list of selected items in the plot
-
- :rtype: List[object]
- """
- return ()
-
- def setCurrentItem(self, item):
- """Set the current/active item in the plot
-
- :param item: The plot item to set as active/current
- """
- pass
-
- def getLabel(self, item):
- """Returns the label of the given item.
-
- :param item:
- :rtype: str
- """
- return ''
-
- def getKind(self, item):
- """Returns the kind of an item or None if not supported
-
- :param item:
- :rtype: Union[str,None]
- """
- return None
-
-
-class _PlotWidgetWrapper(_Wrapper):
- """Class handling PlotWidget specific calls and signal connections
-
- See :class:`._Wrapper` for documentation
-
- :param PlotWidget plot:
- """
-
- def __init__(self, plot):
- assert isinstance(plot, PlotWidget)
- super(_PlotWidgetWrapper, self).__init__(plot)
- plot.sigItemAdded.connect(self.sigItemAdded.emit)
- plot.sigItemAboutToBeRemoved.connect(self.sigItemRemoved.emit)
- plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
- plot.sigActiveImageChanged.connect(self._activeImageChanged)
- plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
- plot.sigPlotSignal.connect(self._limitsChanged)
-
- def _activeChanged(self, kind):
- """Handle change of active curve/image/scatter"""
- plot = self.getPlot()
- if plot is not None:
- item = plot._getActiveItem(kind=kind)
- if item is None or self.getKind(item) is not None:
- self.sigCurrentChanged.emit(item)
-
- def _activeCurveChanged(self, previous, current):
- self._activeChanged(kind='curve')
-
- def _activeImageChanged(self, previous, current):
- self._activeChanged(kind='image')
-
- def _activeScatterChanged(self, previous, current):
- self._activeChanged(kind='scatter')
-
- def _limitsChanged(self, event):
- """Handle change of plot area limits."""
- if event['event'] == 'limitsChanged':
- self.sigVisibleDataChanged.emit()
-
- def getItems(self):
- plot = self.getPlot()
- if plot is None:
- return ()
- else:
- return [item for item in plot.getItems() if item.isVisible()]
-
- def getSelectedItems(self):
- plot = self.getPlot()
- items = []
- if plot is not None:
- for kind in plot._ACTIVE_ITEM_KINDS:
- item = plot._getActiveItem(kind=kind)
- if item is not None:
- items.append(item)
- return tuple(items)
-
- def setCurrentItem(self, item):
- plot = self.getPlot()
- if plot is not None:
- kind = self.getKind(item)
- if kind in plot._ACTIVE_ITEM_KINDS:
- if plot._getActiveItem(kind) != item:
- plot._setActiveItem(kind, item.getName())
-
- def getLabel(self, item):
- return item.getName()
-
- def getKind(self, item):
- if isinstance(item, plotitems.Curve):
- return 'curve'
- elif isinstance(item, plotitems.ImageData):
- return 'image'
- elif isinstance(item, plotitems.Scatter):
- return 'scatter'
- elif isinstance(item, plotitems.Histogram):
- return 'histogram'
- else:
- return None
-
-
-class _SceneWidgetWrapper(_Wrapper):
- """Class handling SceneWidget specific calls and signal connections
-
- See :class:`._Wrapper` for documentation
-
- :param SceneWidget plot:
- """
-
- def __init__(self, plot):
- # Lazy-import to avoid circular imports
- from ..plot3d.SceneWidget import SceneWidget
-
- assert isinstance(plot, SceneWidget)
- super(_SceneWidgetWrapper, self).__init__(plot)
- plot.getSceneGroup().sigItemAdded.connect(self.sigItemAdded)
- plot.getSceneGroup().sigItemRemoved.connect(self.sigItemRemoved)
- plot.selection().sigCurrentChanged.connect(self._currentChanged)
- # sigVisibleDataChanged is never emitted
-
- def _currentChanged(self, current, previous):
- self.sigCurrentChanged.emit(current)
-
- def getItems(self):
- plot = self.getPlot()
- return () if plot is None else tuple(plot.getSceneGroup().visit())
-
- def getSelectedItems(self):
- plot = self.getPlot()
- return () if plot is None else (plot.selection().getCurrentItem(),)
-
- def setCurrentItem(self, item):
- plot = self.getPlot()
- if plot is not None:
- plot.selection().setCurrentItem(item)
-
- def getLabel(self, item):
- return item.getLabel()
-
- def getKind(self, item):
- from ..plot3d import items as plot3ditems
-
- if isinstance(item, (plot3ditems.ImageData,
- plot3ditems.ScalarField3D)):
- return 'image'
- elif isinstance(item, (plot3ditems.Scatter2D,
- plot3ditems.Scatter3D)):
- return 'scatter'
- else:
- return None
-
-
-class _ScalarFieldViewWrapper(_Wrapper):
- """Class handling ScalarFieldView specific calls and signal connections
-
- See :class:`._Wrapper` for documentation
-
- :param SceneWidget plot:
- """
-
- def __init__(self, plot):
- # Lazy-import to avoid circular imports
- from ..plot3d.ScalarFieldView import ScalarFieldView
- from ..plot3d.items import ScalarField3D
-
- assert isinstance(plot, ScalarFieldView)
- super(_ScalarFieldViewWrapper, self).__init__(plot)
- self._item = ScalarField3D()
- self._dataChanged()
- plot.sigDataChanged.connect(self._dataChanged)
- # sigItemAdded, sigItemRemoved, sigVisibleDataChanged are never emitted
-
- def _dataChanged(self):
- plot = self.getPlot()
- if plot is not None:
- self._item.setData(plot.getData(copy=False), copy=False)
- self.sigCurrentChanged.emit(self._item)
-
- def getItems(self):
- plot = self.getPlot()
- return () if plot is None else (self._item,)
-
- def getSelectedItems(self):
- return self.getItems()
-
- def setCurrentItem(self, item):
- pass
-
- def getLabel(self, item):
- return 'Data'
-
- def getKind(self, item):
- return 'image'
-
-
-class _Container(object):
- """Class to contain a plot item.
-
- This is apparently needed for compatibility with PySide2,
-
- :param QObject obj:
- """
- def __init__(self, obj):
- self._obj = obj
-
- def __call__(self):
- return self._obj
-
-
-class _StatsWidgetBase(object):
- """
- Base class for all widgets which want to display statistics
- """
-
- def __init__(self, statsOnVisibleData, displayOnlyActItem):
- self._displayOnlyActItem = displayOnlyActItem
- self._statsOnVisibleData = statsOnVisibleData
- self._statsHandler = None
- self._updateMode = UpdateMode.AUTO
-
- self.__default_skipped_events = (
- ItemChangedType.ALPHA,
- ItemChangedType.COLOR,
- ItemChangedType.COLORMAP,
- ItemChangedType.SYMBOL,
- ItemChangedType.SYMBOL_SIZE,
- ItemChangedType.LINE_WIDTH,
- ItemChangedType.LINE_STYLE,
- ItemChangedType.LINE_BG_COLOR,
- ItemChangedType.FILL,
- ItemChangedType.HIGHLIGHTED_COLOR,
- ItemChangedType.HIGHLIGHTED_STYLE,
- ItemChangedType.TEXT,
- ItemChangedType.OVERLAY,
- ItemChangedType.VISUALIZATION_MODE,
- )
-
- self._plotWrapper = _Wrapper()
- self._dealWithPlotConnection(create=True)
-
- def setPlot(self, plot):
- """Define the plot to interact with
-
- :param Union[PlotWidget,SceneWidget,None] plot:
- The plot containing the items on which statistics are applied
- """
- try:
- import OpenGL
- except ImportError:
- has_opengl = False
- else:
- has_opengl = True
- from ..plot3d.SceneWidget import SceneWidget # Lazy import
- self._dealWithPlotConnection(create=False)
- self.clear()
- if plot is None:
- self._plotWrapper = _Wrapper()
- elif isinstance(plot, PlotWidget):
- self._plotWrapper = _PlotWidgetWrapper(plot)
- else:
- if has_opengl is True:
- if isinstance(plot, SceneWidget):
- self._plotWrapper = _SceneWidgetWrapper(plot)
- else: # Expect a ScalarFieldView
- self._plotWrapper = _ScalarFieldViewWrapper(plot)
- else:
- _logger.warning('OpenGL not installed, %s not managed' % ('SceneWidget qnd ScalarFieldView'))
- self._dealWithPlotConnection(create=True)
-
- def setStats(self, statsHandler):
- """Set which stats to display and the associated formatting.
-
- :param StatsHandler statsHandler:
- Set the statistics to be displayed and how to format them using
- """
- if statsHandler is None:
- statsHandler = StatsHandler(statFormatters=())
- elif isinstance(statsHandler, (list, tuple)):
- statsHandler = StatsHandler(statsHandler)
- assert isinstance(statsHandler, StatsHandler)
-
- self._statsHandler = statsHandler
-
- def getStatsHandler(self):
- """Returns the :class:`StatsHandler` in use.
-
- :rtype: StatsHandler
- """
- return self._statsHandler
-
- def getPlot(self):
- """Returns the plot attached to this widget
-
- :rtype: Union[PlotWidget,SceneWidget,None]
- """
- return self._plotWrapper.getPlot()
-
- def _dealWithPlotConnection(self, create=True):
- """Manage connection to plot signals
-
- Note: connection on Item are managed by _addItem and _removeItem methods
- """
- connections = [] # List of (signal, slot) to connect/disconnect
- if self._statsOnVisibleData:
- connections.append(
- (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats))
-
- if self._displayOnlyActItem:
- connections.append(
- (self._plotWrapper.sigCurrentChanged, self._updateCurrentItem))
- else:
- connections += [
- (self._plotWrapper.sigItemAdded, self._addItem),
- (self._plotWrapper.sigItemRemoved, self._removeItem),
- (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged)]
-
- for signal, slot in connections:
- if create:
- signal.connect(slot)
- else:
- signal.disconnect(slot)
-
- def _updateItemObserve(self, *args):
- """Reload table depending on mode"""
- raise NotImplementedError('Base class')
-
- def _updateCurrentItem(self, *args):
- """specific callback for the sigCurrentChanged and with the
- _displayOnlyActItem option."""
- raise NotImplementedError('Base class')
-
- def _updateStats(self, item, data_changed=False, roi_changed=False):
- """Update displayed information for given plot item
-
- :param item: The plot item
- :param bool data_changed: is the item data changed.
- :param bool roi_changed: is the associated roi changed.
- """
- raise NotImplementedError('Base class')
-
- def _updateAllStats(self):
- """Update stats for all rows in the table"""
- raise NotImplementedError('Base class')
-
- def setDisplayOnlyActiveItem(self, displayOnlyActItem):
- """Toggle display off all items or only the active/selected one
-
- :param bool displayOnlyActItem:
- True if we want to only show active item
- """
- self._displayOnlyActItem = displayOnlyActItem
-
- def setStatsOnVisibleData(self, b):
- """Toggle computation of statistics on whole data or only visible ones.
-
- .. warning:: When visible data is activated we will process to a simple
- filtering of visible data by the user. The filtering is a
- simple data sub-sampling. No interpolation is made to fit
- data to boundaries.
-
- :param bool b: True if we want to apply statistics only on visible data
- """
- if self._statsOnVisibleData != b:
- self._dealWithPlotConnection(create=False)
- self._statsOnVisibleData = b
- self._dealWithPlotConnection(create=True)
- self._updateAllStats()
-
- def _addItem(self, item):
- """Add a plot item to the table
-
- If item is not supported, it is ignored.
-
- :param item: The plot item
- :returns: True if the item is added to the widget.
- :rtype: bool
- """
- raise NotImplementedError('Base class')
-
- def _removeItem(self, item):
- """Remove table items corresponding to given plot item from the table.
-
- :param item: The plot item
- """
- raise NotImplementedError('Base class')
-
- def _plotCurrentChanged(self, current):
- """Handle change of current item and update selection in table
-
- :param current:
- """
- raise NotImplementedError('Base class')
-
- def clear(self):
- """clear GUI"""
- pass
-
- def _skipPlotItemChangedEvent(self, event):
- """
-
- :param ItemChangedtype event: event to filter or not
- :return: True if we want to ignore this ItemChangedtype
- :rtype: bool
- """
- return event in self.__default_skipped_events
-
- def setUpdateMode(self, mode):
- """Set the way to update the displayed statistics.
-
- :param mode: mode requested for update
- :type mode: Union[str,UpdateMode]
- """
- mode = UpdateMode.from_value(mode)
- if mode != self._updateMode:
- self._updateMode = mode
- self._updateModeHasChanged()
-
- def getUpdateMode(self):
- """Returns update mode (See :meth:`setUpdateMode`).
-
- :return: update mode
- :rtype: UpdateMode
- """
- return self._updateMode
-
- def _updateModeHasChanged(self):
- """callback when the update mode has changed"""
- pass
-
-
-class StatsTable(_StatsWidgetBase, TableWidget):
- """
- TableWidget displaying for each items contained by the Plot some
- information:
-
- * legend
- * minimal value
- * maximal value
- * standard deviation (std)
-
- :param QWidget parent: The widget's parent.
- :param Union[PlotWidget,SceneWidget] plot:
- :class:`PlotWidget` or :class:`SceneWidget` instance on which to operate
- """
-
- _LEGEND_HEADER_DATA = 'legend'
- _KIND_HEADER_DATA = 'kind'
-
- sigUpdateModeChanged = qt.Signal(object)
- """Signal emitted when the update mode changed"""
-
- def __init__(self, parent=None, plot=None):
- TableWidget.__init__(self, parent)
- _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
- displayOnlyActItem=False)
-
- # Init for _displayOnlyActItem == False
- assert self._displayOnlyActItem is False
- self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
- self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
- self.currentItemChanged.connect(self._currentItemChanged)
-
- self.setRowCount(0)
- self.setColumnCount(2)
-
- # Init headers
- headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
- headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
- self.setHorizontalHeaderItem(0, headerItem)
- headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
- headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
- self.setHorizontalHeaderItem(1, headerItem)
-
- self.setSortingEnabled(True)
- self.setPlot(plot)
-
- @contextmanager
- def _disableSorting(self):
- """Context manager that disables table sorting
-
- Previous state is restored when leaving
- """
- sorting = self.isSortingEnabled()
- if sorting:
- self.setSortingEnabled(False)
- yield
- if sorting:
- self.setSortingEnabled(sorting)
-
- def setStats(self, statsHandler):
- """Set which stats to display and the associated formatting.
-
- :param StatsHandler statsHandler:
- Set the statistics to be displayed and how to format them using
- """
- self._removeAllItems()
- _StatsWidgetBase.setStats(self, statsHandler)
-
- self.setRowCount(0)
- self.setColumnCount(len(self._statsHandler.stats) + 2) # + legend and kind
-
- for index, stat in enumerate(self._statsHandler.stats.values()):
- headerItem = qt.QTableWidgetItem(stat.name.capitalize())
- headerItem.setData(qt.Qt.UserRole, stat.name)
- if stat.description is not None:
- headerItem.setToolTip(stat.description)
- self.setHorizontalHeaderItem(2 + index, headerItem)
-
- horizontalHeader = self.horizontalHeader()
- if hasattr(horizontalHeader, 'setSectionResizeMode'): # Qt5
- horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
- else: # Qt4
- horizontalHeader.setResizeMode(qt.QHeaderView.ResizeToContents)
-
- self._updateItemObserve()
-
- def setPlot(self, plot):
- """Define the plot to interact with
-
- :param Union[PlotWidget,SceneWidget,None] plot:
- The plot containing the items on which statistics are applied
- """
- _StatsWidgetBase.setPlot(self, plot)
- self._updateItemObserve()
-
- def clear(self):
- """Define the plot to interact with
-
- :param Union[PlotWidget,SceneWidget,None] plot:
- The plot containing the items on which statistics are applied
- """
- self._removeAllItems()
-
- def _updateItemObserve(self, *args):
- """Reload table depending on mode"""
- self._removeAllItems()
-
- # Get selected or all items from the plot
- if self._displayOnlyActItem: # Only selected
- items = self._plotWrapper.getSelectedItems()
- else: # All items
- items = self._plotWrapper.getItems()
-
- # Add items to the plot
- for item in items:
- self._addItem(item)
-
- def _updateCurrentItem(self, *args):
- """specific callback for the sigCurrentChanged and with the
- _displayOnlyActItem option.
-
- Behavior: create the tableItems if does not exists.
- If exists, update it only when we are in 'auto' mode"""
- if self.getUpdateMode() is UpdateMode.MANUAL:
- # when sigCurrentChanged is giving the current item
- if len(args) > 0 and isinstance(args[0], (plotitems.Curve, plotitems.Histogram, plotitems.ImageData, plotitems.Scatter)):
- item = args[0]
- tableItems = self._itemToTableItems(item)
- # if the table does not exists yet
- if len(tableItems) == 0:
- self._updateItemObserve()
- else:
- # in this case no current item
- self._updateItemObserve(args)
- else:
- # auto mode
- self._updateItemObserve(args)
-
- def _plotCurrentChanged(self, current):
- """Handle change of current item and update selection in table
-
- :param current:
- """
- row = self._itemToRow(current)
- if row is None:
- if self.currentRow() >= 0:
- self.setCurrentCell(-1, -1)
- elif row != self.currentRow():
- self.setCurrentCell(row, 0)
-
- def _tableItemToItem(self, tableItem):
- """Find the plot item corresponding to a table item
-
- :param QTableWidgetItem tableItem:
- :rtype: QObject
- """
- container = tableItem.data(qt.Qt.UserRole)
- return container()
-
- def _itemToRow(self, item):
- """Find the row corresponding to a plot item
-
- :param item: The plot item
- :return: The corresponding row index
- :rtype: Union[int,None]
- """
- for row in range(self.rowCount()):
- tableItem = self.item(row, 0)
- if self._tableItemToItem(tableItem) == item:
- return row
- return None
-
- def _itemToTableItems(self, item):
- """Find all table items corresponding to a plot item
-
- :param item: The plot item
- :return: An ordered dict of column name to QTableWidgetItem mapping
- for the given plot item.
- :rtype: OrderedDict
- """
- result = OrderedDict()
- row = self._itemToRow(item)
- if row is not None:
- for column in range(self.columnCount()):
- tableItem = self.item(row, column)
- if self._tableItemToItem(tableItem) != item:
- _logger.error("Table item/plot item mismatch")
- else:
- header = self.horizontalHeaderItem(column)
- name = header.data(qt.Qt.UserRole)
- result[name] = tableItem
- return result
-
- def _plotItemChanged(self, event):
- """Handle modifications of the items.
-
- :param event:
- """
- if self.getUpdateMode() is UpdateMode.MANUAL:
- return
- if self._skipPlotItemChangedEvent(event) is True:
- return
- else:
- item = self.sender()
- self._updateStats(item, data_changed=True)
- # deal with stat items visibility
- if event is ItemChangedType.VISIBLE:
- if len(self._itemToTableItems(item).items()) > 0:
- item_0 = list(self._itemToTableItems(item).values())[0]
- row_index = item_0.row()
- self.setRowHidden(row_index, not item.isVisible())
-
- def _addItem(self, item):
- """Add a plot item to the table
-
- If item is not supported, it is ignored.
-
- :param item: The plot item
- :returns: True if the item is added to the widget.
- :rtype: bool
- """
- if self._itemToRow(item) is not None:
- _logger.info("Item already present in the table")
- self._updateStats(item)
- return True
-
- kind = self._plotWrapper.getKind(item)
- if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
- _logger.info("Item has not a supported type: %s", item)
- return False
-
- # Prepare table items
- tableItems = [
- qt.QTableWidgetItem(), # Legend
- qt.QTableWidgetItem()] # Kind
-
- for column in range(2, self.columnCount()):
- header = self.horizontalHeaderItem(column)
- name = header.data(qt.Qt.UserRole)
-
- formatter = self._statsHandler.formatters[name]
- if formatter:
- tableItem = formatter.tabWidgetItemClass()
- else:
- tableItem = qt.QTableWidgetItem()
-
- tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
- if tooltip is not None:
- tableItem.setToolTip(tooltip)
-
- tableItems.append(tableItem)
-
- # Disable sorting while adding table items
- with self._disableSorting():
- # Add a row to the table
- self.setRowCount(self.rowCount() + 1)
-
- # Add table items to the last row
- row = self.rowCount() - 1
- for column, tableItem in enumerate(tableItems):
- tableItem.setData(qt.Qt.UserRole, _Container(item))
- tableItem.setFlags(
- qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
- self.setItem(row, column, tableItem)
-
- # Update table items content
- self._updateStats(item, data_changed=True)
-
- # Listen for item changes
- # Using queued connection to avoid issue with sender
- # being that of the signal calling the signal
- item.sigItemChanged.connect(self._plotItemChanged,
- qt.Qt.QueuedConnection)
-
- return True
-
- def _removeItem(self, item):
- """Remove table items corresponding to given plot item from the table.
-
- :param item: The plot item
- """
- row = self._itemToRow(item)
- if row is None:
- kind = self._plotWrapper.getKind(item)
- if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
- _logger.error("Removing item that is not in table: %s", str(item))
- return
- item.sigItemChanged.disconnect(self._plotItemChanged)
- self.removeRow(row)
-
- def _removeAllItems(self):
- """Remove content of the table"""
- for row in range(self.rowCount()):
- tableItem = self.item(row, 0)
- item = self._tableItemToItem(tableItem)
- item.sigItemChanged.disconnect(self._plotItemChanged)
- self.clearContents()
- self.setRowCount(0)
-
- def _updateStats(self, item, data_changed=False, roi_changed=False):
- """Update displayed information for given plot item
-
- :param item: The plot item
- :param bool data_changed: is the item data changed.
- :param bool roi_changed: is the associated roi changed.
- """
- if item is None:
- return
- plot = self.getPlot()
- if plot is None:
- _logger.info("Plot not available")
- return
-
- row = self._itemToRow(item)
- if row is None:
- _logger.error("This item is not in the table: %s", str(item))
- return
-
- statsHandler = self.getStatsHandler()
- if statsHandler is not None:
- # _updateStats is call when the plot visible area change.
- # to force stats update we consider roi changed
- if self._statsOnVisibleData:
- roi_changed = True
- else:
- roi_changed = False
- stats = statsHandler.calculate(
- item, plot, self._statsOnVisibleData,
- data_changed=data_changed, roi_changed=roi_changed)
- else:
- stats = {}
-
- with self._disableSorting():
- for name, tableItem in self._itemToTableItems(item).items():
- if name == self._LEGEND_HEADER_DATA:
- text = self._plotWrapper.getLabel(item)
- tableItem.setText(text)
- elif name == self._KIND_HEADER_DATA:
- tableItem.setText(self._plotWrapper.getKind(item))
- else:
- value = stats.get(name)
- if value is None:
- _logger.error("Value not found for: %s", name)
- tableItem.setText('-')
- else:
- tableItem.setText(str(value))
-
- def _updateAllStats(self, is_request=False):
- """Update stats for all rows in the table
-
- :param bool is_request: True if come from a manual request
- """
- if self.getUpdateMode() is UpdateMode.MANUAL and not is_request:
- return
- with self._disableSorting():
- for row in range(self.rowCount()):
- tableItem = self.item(row, 0)
- item = self._tableItemToItem(tableItem)
- self._updateStats(item, data_changed=is_request)
-
- def _currentItemChanged(self, current, previous):
- """Handle change of selection in table and sync plot selection
-
- :param QTableWidgetItem current:
- :param QTableWidgetItem previous:
- """
- if current and current.row() >= 0:
- item = self._tableItemToItem(current)
- self._plotWrapper.setCurrentItem(item)
-
- def setDisplayOnlyActiveItem(self, displayOnlyActItem):
- """Toggle display off all items or only the active/selected one
-
- :param bool displayOnlyActItem:
- True if we want to only show active item
- """
- if self._displayOnlyActItem == displayOnlyActItem:
- return
- self._dealWithPlotConnection(create=False)
- if not self._displayOnlyActItem:
- self.currentItemChanged.disconnect(self._currentItemChanged)
-
- _StatsWidgetBase.setDisplayOnlyActiveItem(self, displayOnlyActItem)
-
- self._updateItemObserve()
- self._dealWithPlotConnection(create=True)
-
- if not self._displayOnlyActItem:
- self.currentItemChanged.connect(self._currentItemChanged)
- self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
- else:
- self.setSelectionMode(qt.QAbstractItemView.NoSelection)
-
- def _updateModeHasChanged(self):
- self.sigUpdateModeChanged.emit(self._updateMode)
-
-
-class UpdateModeWidget(qt.QWidget):
- """Widget used to select the mode of update"""
- sigUpdateModeChanged = qt.Signal(object)
- """signal emitted when the mode for update changed"""
- sigUpdateRequested = qt.Signal()
- """signal emitted when an manual request for example is activate"""
-
- def __init__(self, parent=None):
- qt.QWidget.__init__(self, parent)
- self.setLayout(qt.QHBoxLayout())
- self._buttonGrp = qt.QButtonGroup(parent=self)
- self._buttonGrp.setExclusive(True)
-
- spacer = qt.QSpacerItem(20, 20,
- qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Minimum)
- self.layout().addItem(spacer)
-
- self._autoRB = qt.QRadioButton('auto', parent=self)
- self.layout().addWidget(self._autoRB)
- self._buttonGrp.addButton(self._autoRB)
-
- self._manualRB = qt.QRadioButton('manual', parent=self)
- self.layout().addWidget(self._manualRB)
- self._buttonGrp.addButton(self._manualRB)
- self._manualRB.setChecked(True)
-
- refresh_icon = icons.getQIcon('view-refresh')
- self._updatePB = qt.QPushButton(refresh_icon, '', parent=self)
- self.layout().addWidget(self._updatePB)
-
- # connect signal / SLOT
- self._updatePB.clicked.connect(self._updateRequested)
- self._manualRB.toggled.connect(self._manualButtonToggled)
- self._autoRB.toggled.connect(self._autoButtonToggled)
-
- def _manualButtonToggled(self, checked):
- if checked:
- self.setUpdateMode(UpdateMode.MANUAL)
- self.sigUpdateModeChanged.emit(self.getUpdateMode())
-
- def _autoButtonToggled(self, checked):
- if checked:
- self.setUpdateMode(UpdateMode.AUTO)
- self.sigUpdateModeChanged.emit(self.getUpdateMode())
-
- def _updateRequested(self):
- if self.getUpdateMode() is UpdateMode.MANUAL:
- self.sigUpdateRequested.emit()
-
- def setUpdateMode(self, mode):
- """Set the way to update the displayed statistics.
-
- :param mode: mode requested for update
- :type mode: Union[str,UpdateMode]
- """
- mode = UpdateMode.from_value(mode)
-
- if mode is UpdateMode.AUTO:
- if not self._autoRB.isChecked():
- self._autoRB.setChecked(True)
- elif mode is UpdateMode.MANUAL:
- if not self._manualRB.isChecked():
- self._manualRB.setChecked(True)
- else:
- raise ValueError('mode', mode, 'is not recognized')
-
- def getUpdateMode(self):
- """Returns update mode (See :meth:`setUpdateMode`).
-
- :return: the active update mode
- :rtype: UpdateMode
- """
- if self._manualRB.isChecked():
- return UpdateMode.MANUAL
- elif self._autoRB.isChecked():
- return UpdateMode.AUTO
- else:
- raise RuntimeError("No mode selected")
-
- def showRadioButtons(self, show):
- """show / hide the QRadioButtons
-
- :param bool show: if True make RadioButton visible
- """
- self._autoRB.setVisible(show)
- self._manualRB.setVisible(show)
-
-
-class _OptionsWidget(qt.QToolBar):
-
- def __init__(self, parent=None, updateMode=None, displayOnlyActItem=False):
- assert updateMode is not None
- qt.QToolBar.__init__(self, parent)
- self.setIconSize(qt.QSize(16, 16))
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-active-items"))
- action.setText("Active items only")
- action.setToolTip("Display stats for active items only.")
- action.setCheckable(True)
- action.setChecked(displayOnlyActItem)
- self.__displayActiveItems = action
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-whole-items"))
- action.setText("All items")
- action.setToolTip("Display stats for all available items.")
- action.setCheckable(True)
- self.__displayWholeItems = action
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-visible-data"))
- action.setText("Use the visible data range")
- action.setToolTip("Use the visible data range.<br/>"
- "If activated the data is filtered to only use"
- "visible data of the plot."
- "The filtering is a data sub-sampling."
- "No interpolation is made to fit data to"
- "boundaries.")
- action.setCheckable(True)
- self.__useVisibleData = action
-
- action = qt.QAction(self)
- action.setIcon(icons.getQIcon("stats-whole-data"))
- action.setText("Use the full data range")
- action.setToolTip("Use the full data range.")
- action.setCheckable(True)
- action.setChecked(True)
- self.__useWholeData = action
-
- self.addAction(self.__displayWholeItems)
- self.addAction(self.__displayActiveItems)
- self.addSeparator()
- self.addAction(self.__useVisibleData)
- self.addAction(self.__useWholeData)
-
- self.itemSelection = qt.QActionGroup(self)
- self.itemSelection.setExclusive(True)
- self.itemSelection.addAction(self.__displayActiveItems)
- self.itemSelection.addAction(self.__displayWholeItems)
-
- self.dataRangeSelection = qt.QActionGroup(self)
- self.dataRangeSelection.setExclusive(True)
- self.dataRangeSelection.addAction(self.__useWholeData)
- self.dataRangeSelection.addAction(self.__useVisibleData)
-
- self.__updateStatsAction = qt.QAction(self)
- self.__updateStatsAction.setIcon(icons.getQIcon("view-refresh"))
- self.__updateStatsAction.setText("update statistics")
- self.__updateStatsAction.setToolTip("update statistics")
- self.__updateStatsAction.setCheckable(False)
- self._updateStatsSep = self.addSeparator()
- self.addAction(self.__updateStatsAction)
-
- self._setUpdateMode(mode=updateMode)
-
- # expose API
- self.sigUpdateStats = self.__updateStatsAction.triggered
-
- def isActiveItemMode(self):
- return self.itemSelection.checkedAction() is self.__displayActiveItems
-
- def setDisplayActiveItems(self, only_active):
- self.__displayActiveItems.setChecked(only_active)
- self.__displayWholeItems.setChecked(not only_active)
-
- def isVisibleDataRangeMode(self):
- return self.dataRangeSelection.checkedAction() is self.__useVisibleData
-
- def setVisibleDataRangeModeEnabled(self, enabled):
- """Enable/Disable the visible data range mode
-
- :param bool enabled: True to allow user to choose
- stats on visible data
- """
- self.__useVisibleData.setEnabled(enabled)
- if not enabled:
- self.__useWholeData.setChecked(True)
-
- def _setUpdateMode(self, mode):
- self.__updateStatsAction.setVisible(mode == UpdateMode.MANUAL)
- self._updateStatsSep.setVisible(mode == UpdateMode.MANUAL)
-
- def getUpdateStatsAction(self):
- """
-
- :return: the action for the automatic mode
- :rtype: QAction
- """
- return self.__updateStatsAction
-
-
-class StatsWidget(qt.QWidget):
- """
- Widget displaying a set of :class:`Stat` to be displayed on a
- :class:`StatsTable` and to be apply on items contained in the :class:`Plot`
- Also contains options to:
-
- * compute statistics on all the data or on visible data only
- * show statistics of all items or only the active one
-
- :param QWidget parent: Qt parent
- :param Union[PlotWidget,SceneWidget] plot:
- The plot containing items on which we want statistics.
- :param StatsHandler stats:
- Set the statistics to be displayed and how to format them using
- """
-
- sigVisibilityChanged = qt.Signal(bool)
- """Signal emitted when the visibility of this widget changes.
-
- It Provides the visibility of the widget.
- """
-
- NUMBER_FORMAT = '{0:.3f}'
-
- def __init__(self, parent=None, plot=None, stats=None):
- qt.QWidget.__init__(self, parent)
- self.setLayout(qt.QVBoxLayout())
- self.layout().setContentsMargins(0, 0, 0, 0)
- self._options = _OptionsWidget(parent=self, updateMode=UpdateMode.MANUAL)
- self.layout().addWidget(self._options)
- self._statsTable = StatsTable(parent=self, plot=plot)
- self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode())
- self._options._setUpdateMode(mode=self._statsTable.getUpdateMode())
- self.setStats(stats)
-
- self.layout().addWidget(self._statsTable)
-
- old = self._statsTable.blockSignals(True)
- self._options.itemSelection.triggered.connect(
- self._optSelectionChanged)
- self._options.dataRangeSelection.triggered.connect(
- self._optDataRangeChanged)
- self._optDataRangeChanged()
- self._statsTable.blockSignals(old)
-
- self._statsTable.sigUpdateModeChanged.connect(self._options._setUpdateMode)
- callback = functools.partial(self._getStatsTable()._updateAllStats, is_request=True)
- self._options.sigUpdateStats.connect(callback)
-
- def _getStatsTable(self):
- """Returns the :class:`StatsTable` used by this widget.
-
- :rtype: StatsTable
- """
- return self._statsTable
-
- def showEvent(self, event):
- self.sigVisibilityChanged.emit(True)
- qt.QWidget.showEvent(self, event)
-
- def hideEvent(self, event):
- self.sigVisibilityChanged.emit(False)
- qt.QWidget.hideEvent(self, event)
-
- def _optSelectionChanged(self, action=None):
- self._getStatsTable().setDisplayOnlyActiveItem(
- self._options.isActiveItemMode())
-
- def _optDataRangeChanged(self, action=None):
- self._getStatsTable().setStatsOnVisibleData(
- self._options.isVisibleDataRangeMode())
-
- # Proxy methods
-
- @docstring(StatsTable)
- def setStats(self, statsHandler):
- return self._getStatsTable().setStats(statsHandler=statsHandler)
-
- @docstring(StatsTable)
- def setPlot(self, plot):
- self._options.setVisibleDataRangeModeEnabled(
- plot is None or isinstance(plot, PlotWidget))
- return self._getStatsTable().setPlot(plot=plot)
-
- @docstring(StatsTable)
- def getPlot(self):
- return self._getStatsTable().getPlot()
-
- @docstring(StatsTable)
- def setDisplayOnlyActiveItem(self, displayOnlyActItem):
- old = self._options.blockSignals(True)
- # update the options
- self._options.setDisplayActiveItems(displayOnlyActItem)
- self._options.blockSignals(old)
- return self._getStatsTable().setDisplayOnlyActiveItem(
- displayOnlyActItem=displayOnlyActItem)
-
- @docstring(StatsTable)
- def setStatsOnVisibleData(self, b):
- return self._getStatsTable().setStatsOnVisibleData(b=b)
-
- @docstring(StatsTable)
- def getUpdateMode(self):
- return self._statsTable.getUpdateMode()
-
- @docstring(StatsTable)
- def setUpdateMode(self, mode):
- self._statsTable.setUpdateMode(mode)
-
-
-DEFAULT_STATS = StatsHandler((
- (statsmdl.StatMin(), StatFormatter()),
- statsmdl.StatCoordMin(),
- (statsmdl.StatMax(), StatFormatter()),
- statsmdl.StatCoordMax(),
- statsmdl.StatCOM(),
- (('mean', numpy.mean), StatFormatter()),
- (('std', numpy.std), StatFormatter()),
-))
-
-
-class BasicStatsWidget(StatsWidget):
- """
- Widget defining a simple set of :class:`Stat` to be displayed on a
- :class:`StatsWidget`.
-
- :param QWidget parent: Qt parent
- :param PlotWidget plot:
- The plot containing items on which we want statistics.
- :param StatsHandler stats:
- Set the statistics to be displayed and how to format them using
-
- .. snapshotqt:: img/BasicStatsWidget.png
- :width: 300px
- :align: center
-
- from silx.gui.plot import Plot1D
- from silx.gui.plot.StatsWidget import BasicStatsWidget
-
- plot = Plot1D()
- x = range(100)
- y = x
- plot.addCurve(x, y, legend='curve_0')
- plot.setActiveCurve('curve_0')
-
- widget = BasicStatsWidget(plot=plot)
- widget.show()
- """
- def __init__(self, parent=None, plot=None):
- StatsWidget.__init__(self, parent=parent, plot=plot,
- stats=DEFAULT_STATS)
-
-
-class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
- """
- Widget made to display stats into a QLayout with couple (QLabel, QLineEdit)
- created for each stats.
- The layout can be defined prior of adding any statistic.
-
- :param QWidget parent: Qt parent
- :param Union[PlotWidget,SceneWidget] plot:
- The plot containing items on which we want statistics.
- :param str kind: the kind of plotitems we want to display
- :param StatsHandler stats:
- Set the statistics to be displayed and how to format them using
- :param bool statsOnVisibleData: compute statistics for the whole data or
- only visible ones.
- """
-
- sigUpdateModeChanged = qt.Signal(object)
- """Signal emitted when the update mode changed"""
-
- def __init__(self, parent=None, plot=None, kind='curve', stats=None,
- statsOnVisibleData=False):
- self._item_kind = kind
- """The item displayed"""
- self._statQlineEdit = {}
- """list of legends actually displayed"""
- self._n_statistics_per_line = 4
- """number of statistics displayed per line in the grid layout"""
- qt.QWidget.__init__(self, parent)
- _StatsWidgetBase.__init__(self,
- statsOnVisibleData=statsOnVisibleData,
- displayOnlyActItem=True)
- self.setLayout(self._createLayout())
- self.setPlot(plot)
- if stats is not None:
- self.setStats(stats)
-
- def _addItemForStatistic(self, statistic):
- assert isinstance(statistic, statsmdl.StatBase)
- assert statistic.name in self._statsHandler.stats
-
- self.layout().setSpacing(2)
- self.layout().setContentsMargins(2, 2, 2, 2)
-
- if isinstance(self.layout(), qt.QGridLayout):
- parent = self
- else:
- widget = qt.QWidget(parent=self)
- parent = widget
-
- qLabel = qt.QLabel(statistic.name + ':', parent=parent)
- qLineEdit = qt.QLineEdit('', parent=parent)
- qLineEdit.setReadOnly(True)
-
- self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit)
- self._statQlineEdit[statistic.name] = qLineEdit
-
- def setPlot(self, plot):
- """Define the plot to interact with
-
- :param Union[PlotWidget,SceneWidget,None] plot:
- The plot containing the items on which statistics are applied
- """
- _StatsWidgetBase.setPlot(self, plot)
- self._updateAllStats()
-
- def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
- raise NotImplementedError('Base class')
-
- def setStats(self, statsHandler):
- """Set which stats to display and the associated formatting.
-
- :param StatsHandler statsHandler:
- Set the statistics to be displayed and how to format them using
- """
- _StatsWidgetBase.setStats(self, statsHandler)
- for statName, stat in list(self._statsHandler.stats.items()):
- self._addItemForStatistic(stat)
- self._updateAllStats()
-
- def _activeItemChanged(self, kind, previous, current):
- if self.getUpdateMode() is UpdateMode.MANUAL:
- return
- if kind == self._item_kind:
- self._updateAllStats()
-
- def _updateAllStats(self):
- plot = self.getPlot()
- if plot is not None:
- _items = self._plotWrapper.getSelectedItems()
-
- def kind_filter(_item):
- return self._plotWrapper.getKind(_item) == self.getKind()
- items = list(filter(kind_filter, _items))
- assert len(items) in (0, 1)
- if len(items) == 1:
- self._setItem(items[0])
-
- def setKind(self, kind):
- """Change the kind of active item to display
- :param str kind: kind of item to display information for ('curve' ...)
- """
- if self._item_kind != kind:
- self._item_kind = kind
- self._updateItemObserve()
-
- def getKind(self):
- """
- :return: kind of item we want to compute statistic for
- :rtype: str
- """
- return self._item_kind
-
- def _setItem(self, item, data_changed=True):
- if item is None:
- for stat_name, stat_widget in self._statQlineEdit.items():
- stat_widget.setText('')
- elif (self._statsHandler is not None and len(
- self._statsHandler.stats) > 0):
- plot = self.getPlot()
- if plot is not None:
- statsValDict = self._statsHandler.calculate(item,
- plot,
- self._statsOnVisibleData,
- data_changed=data_changed)
- for statName, statVal in list(statsValDict.items()):
- self._statQlineEdit[statName].setText(statVal)
-
- def _updateItemObserve(self, *argv):
- if self.getUpdateMode() is UpdateMode.MANUAL:
- return
- assert self._displayOnlyActItem
- _items = self._plotWrapper.getSelectedItems()
-
- def kind_filter(_item):
- return self._plotWrapper.getKind(_item) == self.getKind()
- items = list(filter(kind_filter, _items))
- assert len(items) in (0, 1)
- _item = items[0] if len(items) == 1 else None
- self._setItem(_item, data_changed=True)
-
- def _updateCurrentItem(self):
- self._updateItemObserve()
-
- def _createLayout(self):
- """create an instance of the main QLayout"""
- raise NotImplementedError('Base class')
-
- def _addItem(self, item):
- raise NotImplementedError('Display only the active item')
-
- def _removeItem(self, item):
- raise NotImplementedError('Display only the active item')
-
- def _plotCurrentChanged(self, current):
- raise NotImplementedError('Display only the active item')
-
- def _updateModeHasChanged(self):
- self.sigUpdateModeChanged.emit(self._updateMode)
-
-
-class _BasicLineStatsWidget(_BaseLineStatsWidget):
- def __init__(self, parent=None, plot=None, kind='curve',
- stats=DEFAULT_STATS, statsOnVisibleData=False):
- _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
- plot=plot, stats=stats,
- statsOnVisibleData=statsOnVisibleData)
-
- def _createLayout(self):
- return FlowLayout()
-
- def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
- # create a mother widget to make sure both qLabel & qLineEdit will
- # always be displayed side by side
- widget = qt.QWidget(parent=self)
- widget.setLayout(qt.QHBoxLayout())
- widget.layout().setSpacing(0)
- widget.layout().setContentsMargins(0, 0, 0, 0)
-
- widget.layout().addWidget(qLabel)
- widget.layout().addWidget(qLineEdit)
-
- self.layout().addWidget(widget)
-
- def _addOptionsWidget(self, widget):
- self.layout().addWidget(widget)
-
-
-class BasicLineStatsWidget(qt.QWidget):
- """
- Widget defining a simple set of :class:`Stat` to be displayed on a
- :class:`LineStatsWidget`.
-
- :param QWidget parent: Qt parent
- :param Union[PlotWidget,SceneWidget] plot:
- The plot containing items on which we want statistics.
- :param str kind: the kind of plotitems we want to display
- :param StatsHandler stats:
- Set the statistics to be displayed and how to format them using
- :param bool statsOnVisibleData: compute statistics for the whole data or
- only visible ones.
- """
- def __init__(self, parent=None, plot=None, kind='curve',
- stats=DEFAULT_STATS, statsOnVisibleData=False):
- qt.QWidget.__init__(self, parent)
- self.setLayout(qt.QHBoxLayout())
- self.layout().setSpacing(0)
- self.layout().setContentsMargins(0, 0, 0, 0)
- self._lineStatsWidget = _BasicLineStatsWidget(parent=self, plot=plot,
- kind=kind, stats=stats,
- statsOnVisibleData=statsOnVisibleData)
- self.layout().addWidget(self._lineStatsWidget)
-
- self._options = UpdateModeWidget()
- self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode())
- self._options.showRadioButtons(False)
- self.layout().addWidget(self._options)
-
- # connect Signal ? SLOT
- self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode)
- self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode)
- self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats)
-
- def showControl(self, visible):
- self._options.setVisible(visible)
-
- # Proxy methods
-
- @docstring(_BasicLineStatsWidget)
- def setUpdateMode(self, mode):
- self._lineStatsWidget.setUpdateMode(mode=mode)
-
- @docstring(_BasicLineStatsWidget)
- def getUpdateMode(self):
- return self._lineStatsWidget.getUpdateMode()
-
- @docstring(_BasicLineStatsWidget)
- def setPlot(self, plot):
- self._lineStatsWidget.setPlot(plot=plot)
-
- @docstring(_BasicLineStatsWidget)
- def setStats(self, statsHandler):
- self._lineStatsWidget.setStats(statsHandler=statsHandler)
-
- @docstring(_BasicLineStatsWidget)
- def setKind(self, kind):
- self._lineStatsWidget.setKind(kind=kind)
-
- @docstring(_BasicLineStatsWidget)
- def getKind(self):
- return self._lineStatsWidget.getKind()
-
- @docstring(_BasicLineStatsWidget)
- def setStatsOnVisibleData(self, b):
- self._lineStatsWidget.setStatsOnVisibleData(b)
-
- @docstring(UpdateModeWidget)
- def showRadioButtons(self, show):
- self._options.showRadioButtons(show=show)
-
-
-class _BasicGridStatsWidget(_BaseLineStatsWidget):
- def __init__(self, parent=None, plot=None, kind='curve',
- stats=DEFAULT_STATS, statsOnVisibleData=False,
- statsPerLine=4):
- _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
- plot=plot, stats=stats,
- statsOnVisibleData=statsOnVisibleData)
- self._n_statistics_per_line = statsPerLine
-
- def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
- column = len(self._statQlineEdit) % self._n_statistics_per_line
- row = len(self._statQlineEdit) // self._n_statistics_per_line
- self.layout().addWidget(qLabel, row, column * 2)
- self.layout().addWidget(qLineEdit, row, column * 2 + 1)
-
- def _createLayout(self):
- return qt.QGridLayout()
-
-
-class BasicGridStatsWidget(qt.QWidget):
- """
- pymca design like widget
-
- :param QWidget parent: Qt parent
- :param Union[PlotWidget,SceneWidget] plot:
- The plot containing items on which we want statistics.
- :param StatsHandler stats:
- Set the statistics to be displayed and how to format them using
- :param str kind: the kind of plotitems we want to display
- :param bool statsOnVisibleData: compute statistics for the whole data or
- only visible ones.
- :param int statsPerLine: number of statistic to be displayed per line
-
- .. snapshotqt:: img/BasicGridStatsWidget.png
- :width: 600px
- :align: center
-
- from silx.gui.plot import Plot1D
- from silx.gui.plot.StatsWidget import BasicGridStatsWidget
-
- plot = Plot1D()
- x = range(100)
- y = x
- plot.addCurve(x, y, legend='curve_0')
- plot.setActiveCurve('curve_0')
-
- widget = BasicGridStatsWidget(plot=plot, kind='curve')
- widget.show()
- """
-
- def __init__(self, parent=None, plot=None, kind='curve',
- stats=DEFAULT_STATS, statsOnVisibleData=False):
- qt.QWidget.__init__(self, parent)
- self.setLayout(qt.QVBoxLayout())
- self.layout().setSpacing(0)
- self.layout().setContentsMargins(0, 0, 0, 0)
-
- self._options = UpdateModeWidget()
- self._options.showRadioButtons(False)
- self.layout().addWidget(self._options)
-
- self._lineStatsWidget = _BasicGridStatsWidget(parent=self, plot=plot,
- kind=kind, stats=stats,
- statsOnVisibleData=statsOnVisibleData)
- self.layout().addWidget(self._lineStatsWidget)
-
- # tune options
- self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode())
-
- # connect Signal ? SLOT
- self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode)
- self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode)
- self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats)
-
- def showControl(self, visible):
- self._options.setVisible(visible)
-
- @docstring(_BasicGridStatsWidget)
- def setUpdateMode(self, mode):
- self._lineStatsWidget.setUpdateMode(mode=mode)
-
- @docstring(_BasicGridStatsWidget)
- def getUpdateMode(self):
- return self._lineStatsWidget.getUpdateMode()
-
- @docstring(_BasicGridStatsWidget)
- def setPlot(self, plot):
- self._lineStatsWidget.setPlot(plot=plot)
-
- @docstring(_BasicGridStatsWidget)
- def setStats(self, statsHandler):
- self._lineStatsWidget.setStats(statsHandler=statsHandler)
-
- @docstring(_BasicGridStatsWidget)
- def setKind(self, kind):
- self._lineStatsWidget.setKind(kind=kind)
-
- @docstring(_BasicGridStatsWidget)
- def getKind(self):
- return self._lineStatsWidget.getKind()
-
- @docstring(_BasicGridStatsWidget)
- def setStatsOnVisibleData(self, b):
- self._lineStatsWidget.setStatsOnVisibleData(b)
-
- @docstring(UpdateModeWidget)
- def showRadioButtons(self, show):
- self._options.showRadioButtons(show=show)
diff --git a/silx/gui/plot/_utils/__init__.py b/silx/gui/plot/_utils/__init__.py
deleted file mode 100644
index 3c2dfa4..0000000
--- a/silx/gui/plot/_utils/__init__.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Miscellaneous utility functions for the Plot"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "21/03/2017"
-
-
-import numpy
-
-from .panzoom import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX
-from .panzoom import applyZoomToPlot, applyPan
-
-
-def addMarginsToLimits(margins, isXLog, isYLog,
- xMin, xMax, yMin, yMax, y2Min=None, y2Max=None):
- """Returns updated limits by extending them with margins.
-
- :param margins: The ratio of the margins to add or None for no margins.
- :type margins: A 4-tuple of floats as
- (xMinMargin, xMaxMargin, yMinMargin, yMaxMargin)
-
- :return: The updated limits
- :rtype: tuple of 4 or 6 floats: Either (xMin, xMax, yMin, yMax) or
- (xMin, xMax, yMin, yMax, y2Min, y2Max) if y2Min and y2Max
- are provided.
- """
- if margins is not None:
- xMinMargin, xMaxMargin, yMinMargin, yMaxMargin = margins
-
- if not isXLog:
- xRange = xMax - xMin
- xMin -= xMinMargin * xRange
- xMax += xMaxMargin * xRange
-
- elif xMin > 0. and xMax > 0.: # Log scale
- # Do not apply margins if limits < 0
- xMinLog, xMaxLog = numpy.log10(xMin), numpy.log10(xMax)
- xRangeLog = xMaxLog - xMinLog
- xMin = pow(10., xMinLog - xMinMargin * xRangeLog)
- xMax = pow(10., xMaxLog + xMaxMargin * xRangeLog)
-
- if not isYLog:
- yRange = yMax - yMin
- yMin -= yMinMargin * yRange
- yMax += yMaxMargin * yRange
- elif yMin > 0. and yMax > 0.: # Log scale
- # Do not apply margins if limits < 0
- yMinLog, yMaxLog = numpy.log10(yMin), numpy.log10(yMax)
- yRangeLog = yMaxLog - yMinLog
- yMin = pow(10., yMinLog - yMinMargin * yRangeLog)
- yMax = pow(10., yMaxLog + yMaxMargin * yRangeLog)
-
- if y2Min is not None and y2Max is not None:
- if not isYLog:
- yRange = y2Max - y2Min
- y2Min -= yMinMargin * yRange
- y2Max += yMaxMargin * yRange
- elif y2Min > 0. and y2Max > 0.: # Log scale
- # Do not apply margins if limits < 0
- yMinLog, yMaxLog = numpy.log10(y2Min), numpy.log10(y2Max)
- yRangeLog = yMaxLog - yMinLog
- y2Min = pow(10., yMinLog - yMinMargin * yRangeLog)
- y2Max = pow(10., yMaxLog + yMaxMargin * yRangeLog)
-
- if y2Min is None or y2Max is None:
- return xMin, xMax, yMin, yMax
- else:
- return xMin, xMax, yMin, yMax, y2Min, y2Max
-
diff --git a/silx/gui/plot/_utils/panzoom.py b/silx/gui/plot/_utils/panzoom.py
deleted file mode 100644
index 3946a04..0000000
--- a/silx/gui/plot/_utils/panzoom.py
+++ /dev/null
@@ -1,292 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Functions to apply pan and zoom on a Plot"""
-
-__authors__ = ["T. Vincent", "V. Valls"]
-__license__ = "MIT"
-__date__ = "08/08/2017"
-
-
-import math
-import numpy
-
-
-# Float 32 info ###############################################################
-# Using min/max value below limits of float32
-# so operation with such value (e.g., max - min) do not overflow
-
-FLOAT32_SAFE_MIN = -1e37
-FLOAT32_MINPOS = numpy.finfo(numpy.float32).tiny
-FLOAT32_SAFE_MAX = 1e37
-# TODO double support
-
-
-def scale1DRange(min_, max_, center, scale, isLog):
- """Scale a 1D range given a scale factor and an center point.
-
- Keeps the values in a smaller range than float32.
-
- :param float min_: The current min value of the range.
- :param float max_: The current max value of the range.
- :param float center: The center of the zoom (i.e., invariant point).
- :param float scale: The scale to use for zoom
- :param bool isLog: Whether using log scale or not.
- :return: The zoomed range.
- :rtype: tuple of 2 floats: (min, max)
- """
- if isLog:
- # Min and center can be < 0 when
- # autoscale is off and switch to log scale
- # max_ < 0 should not happen
- min_ = numpy.log10(min_) if min_ > 0. else FLOAT32_MINPOS
- center = numpy.log10(center) if center > 0. else FLOAT32_MINPOS
- max_ = numpy.log10(max_) if max_ > 0. else FLOAT32_MINPOS
-
- if min_ == max_:
- return min_, max_
-
- offset = (center - min_) / (max_ - min_)
- range_ = (max_ - min_) / scale
- newMin = center - offset * range_
- newMax = center + (1. - offset) * range_
-
- if isLog:
- # No overflow as exponent is log10 of a float32
- newMin = pow(10., newMin)
- newMax = pow(10., newMax)
- newMin = numpy.clip(newMin, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
- newMax = numpy.clip(newMax, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
- else:
- newMin = numpy.clip(newMin, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
- newMax = numpy.clip(newMax, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
- return newMin, newMax
-
-
-def applyZoomToPlot(plot, scaleF, center=None):
- """Zoom in/out plot given a scale and a center point.
-
- :param plot: The plot on which to apply zoom.
- :param float scaleF: Scale factor of zoom.
- :param center: (x, y) coords in pixel coordinates of the zoom center.
- :type center: 2-tuple of float
- """
- xMin, xMax = plot.getXAxis().getLimits()
- yMin, yMax = plot.getYAxis().getLimits()
-
- if center is None:
- left, top, width, height = plot.getPlotBoundsInPixels()
- cx, cy = left + width // 2, top + height // 2
- else:
- cx, cy = center
-
- dataCenterPos = plot.pixelToData(cx, cy)
- assert dataCenterPos is not None
-
- xMin, xMax = scale1DRange(xMin, xMax, dataCenterPos[0], scaleF,
- plot.getXAxis()._isLogarithmic())
-
- yMin, yMax = scale1DRange(yMin, yMax, dataCenterPos[1], scaleF,
- plot.getYAxis()._isLogarithmic())
-
- dataPos = plot.pixelToData(cx, cy, axis="right")
- assert dataPos is not None
- y2Center = dataPos[1]
- y2Min, y2Max = plot.getYAxis(axis="right").getLimits()
- y2Min, y2Max = scale1DRange(y2Min, y2Max, y2Center, scaleF,
- plot.getYAxis()._isLogarithmic())
-
- plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
-
-
-def applyPan(min_, max_, panFactor, isLog10):
- """Returns a new range with applied panning.
-
- Moves the range according to panFactor.
- If isLog10 is True, converts to log10 before moving.
-
- :param float min_: Min value of the data range to pan.
- :param float max_: Max value of the data range to pan.
- Must be >= min.
- :param float panFactor: Signed proportion of the range to use for pan.
- :param bool isLog10: True if log10 scale, False if linear scale.
- :return: New min and max value with pan applied.
- :rtype: 2-tuple of float.
- """
- if isLog10 and min_ > 0.:
- # Negative range and log scale can happen with matplotlib
- logMin, logMax = math.log10(min_), math.log10(max_)
- logOffset = panFactor * (logMax - logMin)
- newMin = pow(10., logMin + logOffset)
- newMax = pow(10., logMax + logOffset)
-
- # Takes care of out-of-range values
- if newMin > 0. and newMax < float('inf'):
- min_, max_ = newMin, newMax
-
- else:
- offset = panFactor * (max_ - min_)
- newMin, newMax = min_ + offset, max_ + offset
-
- # Takes care of out-of-range values
- if newMin > - float('inf') and newMax < float('inf'):
- min_, max_ = newMin, newMax
- return min_, max_
-
-
-class _Unset(object):
- """To be able to have distinction between None and unset"""
- pass
-
-
-class ViewConstraints(object):
- """
- Store constraints applied on the view box and compute the resulting view box.
- """
-
- def __init__(self):
- self._min = [None, None]
- self._max = [None, None]
- self._minRange = [None, None]
- self._maxRange = [None, None]
-
- def update(self, xMin=_Unset, xMax=_Unset,
- yMin=_Unset, yMax=_Unset,
- minXRange=_Unset, maxXRange=_Unset,
- minYRange=_Unset, maxYRange=_Unset):
- """
- Update the constraints managed by the object
-
- The constraints are the same as the ones provided by PyQtGraph.
-
- :param float xMin: Minimum allowed x-axis value.
- (default do not change the stat, None remove the constraint)
- :param float xMax: Maximum allowed x-axis value.
- (default do not change the stat, None remove the constraint)
- :param float yMin: Minimum allowed y-axis value.
- (default do not change the stat, None remove the constraint)
- :param float yMax: Maximum allowed y-axis value.
- (default do not change the stat, None remove the constraint)
- :param float minXRange: Minimum allowed left-to-right span across the
- view (default do not change the stat, None remove the constraint)
- :param float maxXRange: Maximum allowed left-to-right span across the
- view (default do not change the stat, None remove the constraint)
- :param float minYRange: Minimum allowed top-to-bottom span across the
- view (default do not change the stat, None remove the constraint)
- :param float maxYRange: Maximum allowed top-to-bottom span across the
- view (default do not change the stat, None remove the constraint)
- :return: True if the constraints was changed
- """
- updated = False
-
- minRange = [minXRange, minYRange]
- maxRange = [maxXRange, maxYRange]
- minPos = [xMin, yMin]
- maxPos = [xMax, yMax]
-
- for axis in range(2):
-
- value = minPos[axis]
- if value is not _Unset and value != self._min[axis]:
- self._min[axis] = value
- updated = True
-
- value = maxPos[axis]
- if value is not _Unset and value != self._max[axis]:
- self._max[axis] = value
- updated = True
-
- value = minRange[axis]
- if value is not _Unset and value != self._minRange[axis]:
- self._minRange[axis] = value
- updated = True
-
- value = maxRange[axis]
- if value is not _Unset and value != self._maxRange[axis]:
- self._maxRange[axis] = value
- updated = True
-
- # Sanity checks
-
- for axis in range(2):
- if self._maxRange[axis] is not None and self._min[axis] is not None and self._max[axis] is not None:
- # max range cannot be larger than bounds
- diff = self._max[axis] - self._min[axis]
- self._maxRange[axis] = min(self._maxRange[axis], diff)
- updated = True
-
- return updated
-
- def normalize(self, xMin, xMax, yMin, yMax, allow_scaling=True):
- """Normalize a view range defined by x and y corners using predefined
- containts.
-
- :param float xMin: Min position of the x-axis
- :param float xMax: Max position of the x-axis
- :param float yMin: Min position of the y-axis
- :param float yMax: Max position of the y-axis
- :param bool allow_scaling: Allow or not to apply scaling for the
- normalization. Used according to the interaction mode.
- :return: A normalized tuple of (xMin, xMax, yMin, yMax)
- """
- viewRange = [[xMin, xMax], [yMin, yMax]]
-
- for axis in range(2):
- # clamp xRange and yRange
- if allow_scaling:
- diff = viewRange[axis][1] - viewRange[axis][0]
- delta = None
- if self._maxRange[axis] is not None and diff > self._maxRange[axis]:
- delta = self._maxRange[axis] - diff
- elif self._minRange[axis] is not None and diff < self._minRange[axis]:
- delta = self._minRange[axis] - diff
- if delta is not None:
- viewRange[axis][0] -= delta * 0.5
- viewRange[axis][1] += delta * 0.5
-
- # clamp min and max positions
- outMin = self._min[axis] is not None and viewRange[axis][0] < self._min[axis]
- outMax = self._max[axis] is not None and viewRange[axis][1] > self._max[axis]
-
- if outMin and outMax:
- if allow_scaling:
- # we can clamp both sides
- viewRange[axis][0] = self._min[axis]
- viewRange[axis][1] = self._max[axis]
- else:
- # center the result
- delta = viewRange[axis][1] - viewRange[axis][0]
- mid = self._min[axis] + self._max[axis] - self._min[axis]
- viewRange[axis][0] = mid - delta
- viewRange[axis][1] = mid + delta
- elif outMin:
- delta = self._min[axis] - viewRange[axis][0]
- viewRange[axis][0] += delta
- viewRange[axis][1] += delta
- elif outMax:
- delta = self._max[axis] - viewRange[axis][1]
- viewRange[axis][0] += delta
- viewRange[axis][1] += delta
-
- return viewRange[0][0], viewRange[0][1], viewRange[1][0], viewRange[1][1]
diff --git a/silx/gui/plot/_utils/test/__init__.py b/silx/gui/plot/_utils/test/__init__.py
deleted file mode 100644
index 624dbcb..0000000
--- a/silx/gui/plot/_utils/test/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-from __future__ import absolute_import, division, unicode_literals
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "18/10/2016"
-
-
-import unittest
-
-from .test_dtime_ticklayout import suite as test_dtime_ticklayout_suite
-from .test_ticklayout import suite as test_ticklayout_suite
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(test_dtime_ticklayout_suite())
- testsuite.addTest(test_ticklayout_suite())
- return testsuite
diff --git a/silx/gui/plot/_utils/test/test_dtime_ticklayout.py b/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
deleted file mode 100644
index 2b87148..0000000
--- a/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-from __future__ import absolute_import, division, unicode_literals
-
-__authors__ = ["P. Kenter"]
-__license__ = "MIT"
-__date__ = "06/04/2018"
-
-
-import datetime as dt
-import unittest
-
-
-from silx.gui.plot._utils.dtime_ticklayout import (
- calcTicks, DtUnit, SECONDS_PER_YEAR)
-
-
-class DtTestTickLayout(unittest.TestCase):
- """Test ticks layout algorithms"""
-
- def testSmallMonthlySpacing(self):
- """ Tests a range that did result in a spacing of less than 1 month.
- It is impossible to add fractional month so the unit must be in days
- """
- from dateutil import parser
- d1 = parser.parse("2017-01-03 13:15:06.000044")
- d2 = parser.parse("2017-03-08 09:16:16.307584")
- _ticks, _units, spacing = calcTicks(d1, d2, nTicks=4)
-
- self.assertEqual(spacing, DtUnit.DAYS)
-
-
- def testNoCrash(self):
- """ Creates many combinations of and number-of-ticks and end-dates;
- tests that it doesn't give an exception and returns a reasonable number
- of ticks.
- """
- d1 = dt.datetime(2017, 1, 3, 13, 15, 6, 44)
-
- value = 100e-6 # Start at 100 micro sec range.
-
- while value <= 200 * SECONDS_PER_YEAR:
-
- d2 = d1 + dt.timedelta(microseconds=value*1e6) # end date range
-
- for numTicks in range(2, 12):
- ticks, _, _ = calcTicks(d1, d2, numTicks)
-
- margin = 2.5
- self.assertTrue(
- numTicks/margin <= len(ticks) <= numTicks*margin,
- "Condition {} <= {} <= {} failed for # ticks={} and d2={}:"
- .format(numTicks/margin, len(ticks), numTicks * margin,
- numTicks, d2))
-
- value = value * 1.5 # let date period grow exponentially
-
-
-
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(DtTestTickLayout))
- return testsuite
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/silx/gui/plot/_utils/test/test_ticklayout.py b/silx/gui/plot/_utils/test/test_ticklayout.py
deleted file mode 100644
index 927ffb6..0000000
--- a/silx/gui/plot/_utils/test/test_ticklayout.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-from __future__ import absolute_import, division, unicode_literals
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import unittest
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-
-from silx.gui.plot._utils import ticklayout
-
-
-class TestTickLayout(ParametricTestCase):
- """Test ticks layout algorithms"""
-
- def testTicks(self):
- """Test of :func:`ticks`"""
- tests = { # (vmin, vmax): ref_ticks
- (1., 1.): (1.,),
- (0.5, 10.5): (2.0, 4.0, 6.0, 8.0, 10.0),
- (0.001, 0.005): (0.001, 0.002, 0.003, 0.004, 0.005)
- }
-
- for (vmin, vmax), ref_ticks in tests.items():
- with self.subTest(vmin=vmin, vmax=vmax):
- ticks, labels = ticklayout.ticks(vmin, vmax)
- self.assertTrue(numpy.allclose(ticks, ref_ticks))
-
- def testNiceNumbers(self):
- """Minimalistic tests of :func:`niceNumbers`"""
- tests = { # (vmin, vmax): ref_ticks
- (0.5, 10.5): (0.0, 12.0, 2.0, 0),
- (10000., 10000.5): (10000.0, 10000.5, 0.1, 1),
- (0.001, 0.005): (0.001, 0.005, 0.001, 3)
- }
-
- for (vmin, vmax), ref_ticks in tests.items():
- with self.subTest(vmin=vmin, vmax=vmax):
- ticks = ticklayout.niceNumbers(vmin, vmax)
- self.assertEqual(ticks, ref_ticks)
-
- def testNiceNumbersLog(self):
- """Minimalistic tests of :func:`niceNumbersForLog10`"""
- tests = { # (log10(min), log10(max): ref_ticks
- (0., 3.): (0, 3, 1, 0),
- (-3., 3): (-3, 3, 1, 0),
- (-32., 0.): (-36, 0, 6, 0)
- }
-
- for (vmin, vmax), ref_ticks in tests.items():
- with self.subTest(vmin=vmin, vmax=vmax):
- ticks = ticklayout.niceNumbersForLog10(vmin, vmax)
- self.assertEqual(ticks, ref_ticks)
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestTickLayout))
- return testsuite
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/silx/gui/plot/actions/fit.py b/silx/gui/plot/actions/fit.py
deleted file mode 100644
index f3c9e1c..0000000
--- a/silx/gui/plot/actions/fit.py
+++ /dev/null
@@ -1,403 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-:mod:`silx.gui.plot.actions.fit` module provides actions relative to fit.
-
-The following QAction are available:
-
-- :class:`.FitAction`
-
-.. autoclass:`.FitAction`
-"""
-
-from __future__ import division
-
-__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "10/10/2018"
-
-import logging
-
-import numpy
-
-from .PlotToolAction import PlotToolAction
-from .. import items
-from ....utils.deprecation import deprecated
-from silx.gui import qt
-from silx.gui.plot.ItemsSelectionDialog import ItemsSelectionDialog
-
-_logger = logging.getLogger(__name__)
-
-
-def _getUniqueCurveOrHistogram(plot):
- """Returns unique :class:`Curve` or :class:`Histogram` in a `PlotWidget`.
-
- If there is an active curve, returns it, else return curve or histogram
- only if alone in the plot.
-
- :param PlotWidget plot:
- :rtype: Union[None,~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram]
- """
- curve = plot.getActiveCurve()
- if curve is not None:
- return curve
-
- histograms = [item for item in plot.getItems()
- if isinstance(item, items.Histogram) and item.isVisible()]
- curves = [item for item in plot.getItems()
- if isinstance(item, items.Curve) and item.isVisible()]
-
- if len(histograms) == 1 and len(curves) == 0:
- return histograms[0]
- elif len(curves) == 1 and len(histograms) == 0:
- return curves[0]
- else:
- return None
-
-
-class FitAction(PlotToolAction):
- """QAction to open a :class:`FitWidget` and set its data to the
- active curve if any, or to the first curve.
-
- :param plot: :class:`.PlotWidget` instance on which to operate
- :param parent: See :class:`QAction`
- """
-
- def __init__(self, plot, parent=None):
- self.__item = None
- self.__activeCurveSynchroEnabled = False
- self.__range = 0, 1
- self.__rangeAutoUpdate = False
- self.__x, self.__y = None, None # Data to fit
- self.__curveParams = {} # Store curve parameters to use for fit result
- self.__legend = None
-
- super(FitAction, self).__init__(
- plot, icon='math-fit', text='Fit curve',
- tooltip='Open a fit dialog',
- parent=parent)
-
- @property
- @deprecated(replacement='getXRange()[0]', since_version='0.13.0')
- def xmin(self):
- return self.getXRange()[0]
-
- @property
- @deprecated(replacement='getXRange()[1]', since_version='0.13.0')
- def xmax(self):
- return self.getXRange()[1]
-
- @property
- @deprecated(replacement='getXData()', since_version='0.13.0')
- def x(self):
- return self.getXData()
-
- @property
- @deprecated(replacement='getYData()', since_version='0.13.0')
- def y(self):
- return self.getYData()
-
- @property
- @deprecated(since_version='0.13.0')
- def xlabel(self):
- return self.__curveParams.get('xlabel', None)
-
- @property
- @deprecated(since_version='0.13.0')
- def ylabel(self):
- return self.__curveParams.get('ylabel', None)
-
- @property
- @deprecated(since_version='0.13.0')
- def legend(self):
- return self.__legend
-
- def _createToolWindow(self):
- # import done here rather than at module level to avoid circular import
- # FitWidget -> BackgroundWidget -> PlotWindow -> actions -> fit -> FitWidget
- from ...fit.FitWidget import FitWidget
-
- window = FitWidget(parent=self.plot)
- window.setWindowFlags(qt.Qt.Dialog)
- window.sigFitWidgetSignal.connect(self.handle_signal)
- return window
-
- def _connectPlot(self, window):
- if self.isXRangeUpdatedOnZoom():
- self.__setAutoXRangeEnabled(True)
- else:
- plot = self.plot
- if plot is None:
- _logger.error("No associated PlotWidget")
- return
- self._setXRange(*plot.getXAxis().getLimits())
-
- if self.isFittedItemUpdatedFromActiveCurve():
- self.__setFittedItemAutoUpdateEnabled(True)
- else:
- # Wait for the next iteration, else the plot is not yet initialized
- # No curve available
- qt.QTimer.singleShot(10, self._initFit)
-
- def _disconnectPlot(self, window):
- if self.isXRangeUpdatedOnZoom():
- self.__setAutoXRangeEnabled(False)
-
- if self.isFittedItemUpdatedFromActiveCurve():
- self.__setFittedItemAutoUpdateEnabled(False)
-
- def _initFit(self):
- plot = self.plot
- if plot is None:
- _logger.error("No associated PlotWidget")
- return
-
- item = _getUniqueCurveOrHistogram(plot)
- if item is None:
- # ambiguous case, we need to ask which plot item to fit
- isd = ItemsSelectionDialog(parent=plot, plot=plot)
- isd.setWindowTitle("Select item to be fitted")
- isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection)
- isd.setAvailableKinds(["curve", "histogram"])
- isd.selectAllKinds()
-
- if not isd.exec_(): # Cancel
- self._getToolWindow().setVisible(False)
- else:
- selectedItems = isd.getSelectedItems()
- item = selectedItems[0] if len(selectedItems) == 1 else None
-
- self._setXRange(*plot.getXAxis().getLimits())
- self._setFittedItem(item)
-
- def __updateFitWidget(self):
- """Update the data/range used by the FitWidget"""
- fitWidget = self._getToolWindow()
-
- item = self._getFittedItem()
- xdata = self.getXData(copy=False)
- ydata = self.getYData(copy=False)
- if item is None or xdata is None or ydata is None:
- fitWidget.setData(y=None)
- fitWidget.setWindowTitle("No curve selected")
-
- else:
- xmin, xmax = self.getXRange()
- fitWidget.setData(
- xdata, ydata, xmin=xmin, xmax=xmax)
- fitWidget.setWindowTitle(
- "Fitting " + item.getName() +
- " on x range %f-%f" % (xmin, xmax))
-
- # X Range management
-
- def getXRange(self):
- """Returns the range on the X axis on which to perform the fit."""
- return self.__range
-
- def _setXRange(self, xmin, xmax):
- """Set the range on which the fit is done.
-
- :param float xmin:
- :param float xmax:
- """
- range_ = float(xmin), float(xmax)
- if self.__range != range_:
- self.__range = range_
- self.__updateFitWidget()
-
- def __setAutoXRangeEnabled(self, enabled):
- """Implement the change of update mode of the X range.
-
- :param bool enabled:
- """
- plot = self.plot
- if plot is None:
- _logger.error("No associated PlotWidget")
- return
-
- if enabled:
- self._setXRange(*plot.getXAxis().getLimits())
- plot.getXAxis().sigLimitsChanged.connect(self._setXRange)
- else:
- plot.getXAxis().sigLimitsChanged.disconnect(self._setXRange)
-
- def setXRangeUpdatedOnZoom(self, enabled):
- """Set whether or not to update the X range on zoom change.
-
- :param bool enabled:
- """
- if enabled != self.__rangeAutoUpdate:
- self.__rangeAutoUpdate = enabled
- if self._getToolWindow().isVisible():
- self.__setAutoXRangeEnabled(enabled)
-
- def isXRangeUpdatedOnZoom(self):
- """Returns the current mode of fitted data X range update.
-
- :rtype: bool
- """
- return self.__rangeAutoUpdate
-
- # Fitted item update
-
- def getXData(self, copy=True):
- """Returns the X data used for the fit or None if undefined.
-
- :param bool copy:
- True to get a copy of the data, False to get the internal data.
- :rtype: Union[numpy.ndarray,None]
- """
- return None if self.__x is None else numpy.array(self.__x, copy=copy)
-
- def getYData(self, copy=True):
- """Returns the Y data used for the fit or None if undefined.
-
- :param bool copy:
- True to get a copy of the data, False to get the internal data.
- :rtype: Union[numpy.ndarray,None]
- """
- return None if self.__y is None else numpy.array(self.__y, copy=copy)
-
- def _getFittedItem(self):
- """Returns the current item used for the fit
-
- :rtype: Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None]
- """
- return self.__item
-
- def _setFittedItem(self, item):
- """Set the curve to use for fitting.
-
- :param Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None] item:
- """
- plot = self.plot
- if plot is None:
- _logger.error("No associated PlotWidget")
-
- if plot is None or item is None:
- self.__item = None
- self.__curveParams = {}
- self.__updateFitWidget()
- return
-
- axis = item.getYAxis() if isinstance(item, items.YAxisMixIn) else 'left'
- self.__curveParams = {
- 'yaxis': axis,
- 'xlabel': plot.getXAxis().getLabel(),
- 'ylabel': plot.getYAxis(axis).getLabel(),
- }
- self.__legend = item.getName()
-
- if isinstance(item, items.Histogram):
- bin_edges = item.getBinEdgesData(copy=False)
- # take the middle coordinate between adjacent bin edges
- self.__x = (bin_edges[1:] + bin_edges[:-1]) / 2
- self.__y = item.getValueData(copy=False)
- # else take the active curve, or else the unique curve
- elif isinstance(item, items.Curve):
- self.__x = item.getXData(copy=False)
- self.__y = item.getYData(copy=False)
-
- self.__item = item
- self.__updateFitWidget()
-
- def __activeCurveChanged(self, previous, current):
- """Handle change of active curve in the PlotWidget
- """
- if current is None:
- self._setFittedItem(None)
- else:
- item = self.plot.getCurve(current)
- self._setFittedItem(item)
-
- def __setFittedItemAutoUpdateEnabled(self, enabled):
- """Implement the change of fitted item update mode
-
- :param bool enabled:
- """
- plot = self.plot
- if plot is None:
- _logger.error("No associated PlotWidget")
- return
-
- if enabled:
- self._setFittedItem(plot.getActiveCurve())
- plot.sigActiveCurveChanged.connect(self.__activeCurveChanged)
-
- else:
- plot.sigActiveCurveChanged.disconnect(
- self.__activeCurveChanged)
-
- def setFittedItemUpdatedFromActiveCurve(self, enabled):
- """Toggle fitted data synchronization with plot active curve.
-
- :param bool enabled:
- """
- enabled = bool(enabled)
- if enabled != self.__activeCurveSynchroEnabled:
- self.__activeCurveSynchroEnabled = enabled
- if self._getToolWindow().isVisible():
- self.__setFittedItemAutoUpdateEnabled(enabled)
-
- def isFittedItemUpdatedFromActiveCurve(self):
- """Returns True if fitted data is synchronized with plot.
-
- :rtype: bool
- """
- return self.__activeCurveSynchroEnabled
-
- # Handle fit completed
-
- def handle_signal(self, ddict):
- xdata = self.getXData(copy=False)
- if xdata is None:
- _logger.error("No reference data to display fit result for")
- return
-
- xmin, xmax = self.getXRange()
- x_fit = xdata[xmin <= xdata]
- x_fit = x_fit[x_fit <= xmax]
- fit_legend = "Fit <%s>" % self.__legend
- fit_curve = self.plot.getCurve(fit_legend)
-
- if ddict["event"] == "FitFinished":
- fit_widget = self._getToolWindow()
- if fit_widget is None:
- return
- y_fit = fit_widget.fitmanager.gendata()
- if fit_curve is None:
- self.plot.addCurve(x_fit, y_fit,
- fit_legend,
- resetzoom=False,
- **self.__curveParams)
- else:
- fit_curve.setData(x_fit, y_fit)
- fit_curve.setVisible(True)
- fit_curve.setYAxis(self.__curveParams.get('yaxis', 'left'))
-
- if ddict["event"] in ["FitStarted", "FitFailed"]:
- if fit_curve is not None:
- fit_curve.setVisible(False)
diff --git a/silx/gui/plot/actions/histogram.py b/silx/gui/plot/actions/histogram.py
deleted file mode 100644
index 0bba558..0000000
--- a/silx/gui/plot/actions/histogram.py
+++ /dev/null
@@ -1,392 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-:mod:`silx.gui.plot.actions.histogram` provides actions relative to histograms
-for :class:`.PlotWidget`.
-
-The following QAction are available:
-
-- :class:`PixelIntensitiesHistoAction`
-"""
-
-from __future__ import division
-
-__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
-__date__ = "01/12/2020"
-__license__ = "MIT"
-
-import numpy
-import logging
-import typing
-import weakref
-
-from .PlotToolAction import PlotToolAction
-
-from silx.math.histogram import Histogramnd
-from silx.math.combo import min_max
-from silx.gui import qt
-from silx.gui.plot import items
-from silx.gui.widgets.ElidedLabel import ElidedLabel
-from silx.utils.deprecation import deprecated
-
-_logger = logging.getLogger(__name__)
-
-
-class _ElidedLabel(ElidedLabel):
- """QLabel with a default size larger than what is displayed."""
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
-
- def sizeHint(self):
- hint = super().sizeHint()
- nbchar = max(len(self.getText()), 12)
- width = self.fontMetrics().boundingRect('#' * nbchar).width()
- return qt.QSize(max(hint.width(), width), hint.height())
-
-
-class _StatWidget(qt.QWidget):
- """Widget displaying a name and a value
-
- :param parent:
- :param name:
- """
-
- def __init__(self, parent=None, name: str=''):
- super().__init__(parent)
- layout = qt.QHBoxLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
-
- keyWidget = qt.QLabel(parent=self)
- keyWidget.setText("<b>" + name.capitalize() + ":<b>")
- layout.addWidget(keyWidget)
- self.__valueWidget = _ElidedLabel(parent=self)
- self.__valueWidget.setText("-")
- self.__valueWidget.setTextInteractionFlags(
- qt.Qt.TextSelectableByMouse | qt.Qt.TextSelectableByKeyboard)
- layout.addWidget(self.__valueWidget)
-
- def setValue(self, value: typing.Optional[float]):
- """Set the displayed value
-
- :param value:
- """
- self.__valueWidget.setText(
- "-" if value is None else "{:.5g}".format(value))
-
-
-class HistogramWidget(qt.QWidget):
- """Widget displaying a histogram and some statistic indicators"""
-
- _SUPPORTED_ITEM_CLASS = items.ImageBase, items.Scatter
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.setWindowTitle('Histogram')
-
- self.__itemRef = None # weakref on the item to track
-
- layout = qt.QVBoxLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
-
- # Plot
- # Lazy import to avoid circular dependencies
- from silx.gui.plot.PlotWindow import Plot1D
- self.__plot = Plot1D(self)
- layout.addWidget(self.__plot)
-
- self.__plot.setDataMargins(0.1, 0.1, 0.1, 0.1)
- self.__plot.getXAxis().setLabel("Value")
- self.__plot.getYAxis().setLabel("Count")
- posInfo = self.__plot.getPositionInfoWidget()
- posInfo.setSnappingMode(posInfo.SNAPPING_CURVE)
-
- # Stats display
- statsWidget = qt.QWidget(self)
- layout.addWidget(statsWidget)
- statsLayout = qt.QHBoxLayout(statsWidget)
- statsLayout.setContentsMargins(4, 4, 4, 4)
-
- self.__statsWidgets = dict(
- (name, _StatWidget(parent=statsWidget, name=name))
- for name in ("min", "max", "mean", "std", "sum"))
-
- for widget in self.__statsWidgets.values():
- statsLayout.addWidget(widget)
- statsLayout.addStretch(1)
-
- def getPlotWidget(self):
- """Returns :class:`PlotWidget` use to display the histogram"""
- return self.__plot
-
- def resetZoom(self):
- """Reset PlotWidget zoom"""
- self.getPlotWidget().resetZoom()
-
- def reset(self):
- """Clear displayed information"""
- self.getPlotWidget().clear()
- self.setStatistics()
-
- def getItem(self) -> typing.Optional[items.Item]:
- """Returns item used to display histogram and statistics."""
- return None if self.__itemRef is None else self.__itemRef()
-
- def setItem(self, item: typing.Optional[items.Item]):
- """Set item from which to display histogram and statistics.
-
- :param item:
- """
- previous = self.getItem()
- if previous is not None:
- previous.sigItemChanged.disconnect(self.__itemChanged)
-
- self.__itemRef = None if item is None else weakref.ref(item)
- if item is not None:
- if isinstance(item, self._SUPPORTED_ITEM_CLASS):
- # Only listen signal for supported items
- item.sigItemChanged.connect(self.__itemChanged)
- self._updateFromItem()
-
- def __itemChanged(self, event):
- """Handle update of the item"""
- if event in (items.ItemChangedType.DATA, items.ItemChangedType.MASK):
- self._updateFromItem()
-
- def _updateFromItem(self):
- """Update histogram and stats from the item"""
- item = self.getItem()
-
- if item is None:
- self.reset()
- return
-
- if not isinstance(item, self._SUPPORTED_ITEM_CLASS):
- _logger.error("Unsupported item", item)
- self.reset()
- return
-
- # Compute histogram and stats
- array = item.getValueData(copy=False)
-
- if array.size == 0:
- self.reset()
- return
-
- xmin, xmax = min_max(array, min_positive=False, finite=True)
- nbins = min(1024, int(numpy.sqrt(array.size)))
- data_range = xmin, xmax
-
- # bad hack: get 256 bins in the case we have a B&W
- if numpy.issubdtype(array.dtype, numpy.integer):
- if nbins > xmax - xmin:
- nbins = xmax - xmin
-
- nbins = max(2, nbins)
-
- data = array.ravel().astype(numpy.float32)
- histogram = Histogramnd(data, n_bins=nbins, histo_range=data_range)
- if len(histogram.edges) != 1:
- _logger.error("Error while computing the histogram")
- self.reset()
- return
-
- self.setHistogram(histogram.histo, histogram.edges[0])
- self.resetZoom()
- self.setStatistics(
- min_=xmin,
- max_=xmax,
- mean=numpy.nanmean(array),
- std=numpy.nanstd(array),
- sum_=numpy.nansum(array))
-
- def setHistogram(self, histogram, edges):
- """Set displayed histogram
-
- :param histogram: Bin values (N)
- :param edges: Bin edges (N+1)
- """
- self.getPlotWidget().addHistogram(
- histogram=histogram,
- edges=edges,
- legend='histogram',
- fill=True,
- color='#66aad7',
- resetzoom=False)
-
- def getHistogram(self, copy: bool=True):
- """Returns currently displayed histogram.
-
- :param copy: True to get a copy,
- False to get internal representation (Do not modify!)
- :return: (histogram, edges) or None
- """
- for item in self.getPlotWidget().getItems():
- if item.getName() == 'histogram':
- return (item.getValueData(copy=copy),
- item.getBinEdgesData(copy=copy))
- else:
- return None
-
- def setStatistics(self,
- min_: typing.Optional[float] = None,
- max_: typing.Optional[float] = None,
- mean: typing.Optional[float] = None,
- std: typing.Optional[float] = None,
- sum_: typing.Optional[float] = None):
- """Set displayed statistic indicators."""
- self.__statsWidgets['min'].setValue(min_)
- self.__statsWidgets['max'].setValue(max_)
- self.__statsWidgets['mean'].setValue(mean)
- self.__statsWidgets['std'].setValue(std)
- self.__statsWidgets['sum'].setValue(sum_)
-
-
-class _LastActiveItem(qt.QObject):
-
- sigActiveItemChanged = qt.Signal(object, object)
- """Emitted when the active plot item have changed"""
-
- def __init__(self, parent, plot):
- assert plot is not None
- super(_LastActiveItem, self).__init__(parent=parent)
- self.__plot = weakref.ref(plot)
- self.__item = None
- item = self.__findActiveItem()
- self.setActiveItem(item)
- plot.sigActiveImageChanged.connect(self._activeImageChanged)
- plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
-
- def getPlotWidget(self):
- return self.__plot()
-
- def __findActiveItem(self):
- plot = self.getPlotWidget()
- image = plot.getActiveImage()
- if image is not None:
- return image
- scatter = plot.getActiveScatter()
- if scatter is not None:
- return scatter
-
- def getActiveItem(self):
- if self.__item is None:
- return None
- item = self.__item()
- if item is None:
- self.__item = None
- return item
-
- def setActiveItem(self, item):
- previous = self.getActiveItem()
- if previous is item:
- return
- if item is None:
- self.__item = None
- else:
- self.__item = weakref.ref(item)
- self.sigActiveItemChanged.emit(previous, item)
-
- def _activeImageChanged(self, previous, current):
- """Handle active image change"""
- plot = self.getPlotWidget()
- if current is None: # Fall-back to active scatter if any
- self.setActiveItem(plot.getActiveScatter())
- else:
- item = plot.getImage(current)
- if item is None:
- self.setActiveItem(None)
- elif isinstance(item, items.ImageBase):
- self.setActiveItem(item)
- else:
- # Do not touch anything, which is consistent with silx v0.12 behavior
- pass
-
- def _activeScatterChanged(self, previous, current):
- """Handle active scatter change"""
- plot = self.getPlotWidget()
- if current is None: # Fall-back to active image if any
- self.setActiveItem(plot.getActiveImage())
- else:
- item = plot.getScatter(current)
- self.setActiveItem(item)
-
-
-class PixelIntensitiesHistoAction(PlotToolAction):
- """QAction to plot the pixels intensities diagram
-
- :param plot: :class:`.PlotWidget` instance on which to operate
- :param parent: See :class:`QAction`
- """
-
- def __init__(self, plot, parent=None):
- PlotToolAction.__init__(self,
- plot,
- icon='pixel-intensities',
- text='pixels intensity',
- tooltip='Compute image intensity distribution',
- parent=parent)
- self._lastItemFilter = _LastActiveItem(self, plot)
-
- def _connectPlot(self, window):
- self._lastItemFilter.sigActiveItemChanged.connect(self._activeItemChanged)
- item = self._lastItemFilter.getActiveItem()
- self.getHistogramWidget().setItem(item)
- PlotToolAction._connectPlot(self, window)
-
- def _disconnectPlot(self, window):
- self._lastItemFilter.sigActiveItemChanged.disconnect(self._activeItemChanged)
- PlotToolAction._disconnectPlot(self, window)
- self.getHistogramWidget().setItem(None)
-
- def _activeItemChanged(self, previous, current):
- if self._isWindowInUse():
- self.getHistogramWidget().setItem(current)
-
- @deprecated(since_version='0.15.0')
- def computeIntensityDistribution(self):
- self.getHistogramWidget()._updateFromItem()
-
- def getHistogramWidget(self):
- """Returns the widget displaying the histogram"""
- return self._getToolWindow()
-
- @deprecated(since_version='0.15.0',
- replacement='getHistogramWidget().getPlotWidget()')
- def getHistogramPlotWidget(self):
- return self._getToolWindow().getPlotWidget()
-
- def _createToolWindow(self):
- return HistogramWidget(self.plot, qt.Qt.Window)
-
- def getHistogram(self) -> typing.Optional[numpy.ndarray]:
- """Return the last computed histogram
-
- :return: the histogram displayed in the HistogramWidget
- """
- histogram = self.getHistogramWidget().getHistogram()
- return None if histogram is None else histogram[0]
diff --git a/silx/gui/plot/actions/io.py b/silx/gui/plot/actions/io.py
deleted file mode 100644
index f728b7a..0000000
--- a/silx/gui/plot/actions/io.py
+++ /dev/null
@@ -1,818 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-:mod:`silx.gui.plot.actions.io` provides a set of QAction relative of inputs
-and outputs for a :class:`.PlotWidget`.
-
-The following QAction are available:
-
-- :class:`CopyAction`
-- :class:`PrintAction`
-- :class:`SaveAction`
-"""
-
-from __future__ import division
-
-__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "25/09/2020"
-
-from . import PlotAction
-from silx.io.utils import save1D, savespec, NEXUS_HDF5_EXT
-from silx.io.nxdata import save_NXdata
-import logging
-import sys
-import os.path
-from collections import OrderedDict
-import traceback
-import numpy
-from silx.utils.deprecation import deprecated
-from silx.gui import qt, printer
-from silx.gui.dialog.GroupDialog import GroupDialog
-from silx.third_party.EdfFile import EdfFile
-from silx.third_party.TiffIO import TiffIO
-from ...utils.image import convertArrayToQImage
-if sys.version_info[0] == 3:
- from io import BytesIO
-else:
- import cStringIO as _StringIO
- BytesIO = _StringIO.StringIO
-
-_logger = logging.getLogger(__name__)
-
-_NEXUS_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT])
-
-
-def selectOutputGroup(h5filename):
- """Open a dialog to prompt the user to select a group in
- which to output data.
-
- :param str h5filename: name of an existing HDF5 file
- :rtype: str
- :return: Name of output group, or None if the dialog was cancelled
- """
- dialog = GroupDialog()
- dialog.addFile(h5filename)
- dialog.setWindowTitle("Select an output group")
- if not dialog.exec_():
- return None
- return dialog.getSelectedDataUrl().data_path()
-
-
-class SaveAction(PlotAction):
- """QAction for saving Plot content.
-
- It opens a Save as... dialog.
-
- :param plot: :class:`.PlotWidget` instance on which to operate.
- :param parent: See :class:`QAction`.
- """
-
- SNAPSHOT_FILTER_SVG = 'Plot Snapshot as SVG (*.svg)'
- SNAPSHOT_FILTER_PNG = 'Plot Snapshot as PNG (*.png)'
-
- DEFAULT_ALL_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG)
-
- # Dict of curve filters with CSV-like format
- # Using ordered dict to guarantee filters order
- # Note: '%.18e' is numpy.savetxt default format
- CURVE_FILTERS_TXT = OrderedDict((
- ('Curve as Raw ASCII (*.txt)',
- {'fmt': '%.18e', 'delimiter': ' ', 'header': False}),
- ('Curve as ";"-separated CSV (*.csv)',
- {'fmt': '%.18e', 'delimiter': ';', 'header': True}),
- ('Curve as ","-separated CSV (*.csv)',
- {'fmt': '%.18e', 'delimiter': ',', 'header': True}),
- ('Curve as tab-separated CSV (*.csv)',
- {'fmt': '%.18e', 'delimiter': '\t', 'header': True}),
- ('Curve as OMNIC CSV (*.csv)',
- {'fmt': '%.7E', 'delimiter': ',', 'header': False}),
- ('Curve as SpecFile (*.dat)',
- {'fmt': '%.10g', 'delimiter': '', 'header': False})
- ))
-
- CURVE_FILTER_NPY = 'Curve as NumPy binary file (*.npy)'
-
- CURVE_FILTER_NXDATA = 'Curve as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
-
- DEFAULT_CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [
- CURVE_FILTER_NPY, CURVE_FILTER_NXDATA]
-
- DEFAULT_ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)",)
-
- IMAGE_FILTER_EDF = 'Image data as EDF (*.edf)'
- IMAGE_FILTER_TIFF = 'Image data as TIFF (*.tif)'
- IMAGE_FILTER_NUMPY = 'Image data as NumPy binary file (*.npy)'
- IMAGE_FILTER_ASCII = 'Image data as ASCII (*.dat)'
- IMAGE_FILTER_CSV_COMMA = 'Image data as ,-separated CSV (*.csv)'
- IMAGE_FILTER_CSV_SEMICOLON = 'Image data as ;-separated CSV (*.csv)'
- IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)'
- IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)'
- IMAGE_FILTER_NXDATA = 'Image as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
-
- DEFAULT_IMAGE_FILTERS = (IMAGE_FILTER_EDF,
- IMAGE_FILTER_TIFF,
- IMAGE_FILTER_NUMPY,
- IMAGE_FILTER_ASCII,
- IMAGE_FILTER_CSV_COMMA,
- IMAGE_FILTER_CSV_SEMICOLON,
- IMAGE_FILTER_CSV_TAB,
- IMAGE_FILTER_RGB_PNG,
- IMAGE_FILTER_NXDATA)
-
- SCATTER_FILTER_NXDATA = 'Scatter as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
- DEFAULT_SCATTER_FILTERS = (SCATTER_FILTER_NXDATA,)
-
- # filters for which we don't want an "overwrite existing file" warning
- DEFAULT_APPEND_FILTERS = (CURVE_FILTER_NXDATA, IMAGE_FILTER_NXDATA,
- SCATTER_FILTER_NXDATA)
-
- def __init__(self, plot, parent=None):
- self._filters = {
- 'all': OrderedDict(),
- 'curve': OrderedDict(),
- 'curves': OrderedDict(),
- 'image': OrderedDict(),
- 'scatter': OrderedDict()}
-
- self._appendFilters = list(self.DEFAULT_APPEND_FILTERS)
-
- # Initialize filters
- for nameFilter in self.DEFAULT_ALL_FILTERS:
- self.setFileFilter(
- dataKind='all', nameFilter=nameFilter, func=self._saveSnapshot)
-
- for nameFilter in self.DEFAULT_CURVE_FILTERS:
- self.setFileFilter(
- dataKind='curve', nameFilter=nameFilter, func=self._saveCurve)
-
- for nameFilter in self.DEFAULT_ALL_CURVES_FILTERS:
- self.setFileFilter(
- dataKind='curves', nameFilter=nameFilter, func=self._saveCurves)
-
- for nameFilter in self.DEFAULT_IMAGE_FILTERS:
- self.setFileFilter(
- dataKind='image', nameFilter=nameFilter, func=self._saveImage)
-
- for nameFilter in self.DEFAULT_SCATTER_FILTERS:
- self.setFileFilter(
- dataKind='scatter', nameFilter=nameFilter, func=self._saveScatter)
-
- super(SaveAction, self).__init__(
- plot, icon='document-save', text='Save as...',
- tooltip='Save curve/image/plot snapshot dialog',
- triggered=self._actionTriggered,
- checkable=False, parent=parent)
- self.setShortcut(qt.QKeySequence.Save)
- self.setShortcutContext(qt.Qt.WidgetShortcut)
-
- @staticmethod
- def _errorMessage(informativeText='', parent=None):
- """Display an error message."""
- # TODO issue with QMessageBox size fixed and too small
- msg = qt.QMessageBox(parent)
- msg.setIcon(qt.QMessageBox.Critical)
- msg.setInformativeText(informativeText + ' ' + str(sys.exc_info()[1]))
- msg.setDetailedText(traceback.format_exc())
- msg.exec_()
-
- def _saveSnapshot(self, plot, filename, nameFilter):
- """Save a snapshot of the :class:`PlotWindow` widget.
-
- :param str filename: The name of the file to write
- :param str nameFilter: The selected name filter
- :return: False if format is not supported or save failed,
- True otherwise.
- """
- if nameFilter == self.SNAPSHOT_FILTER_PNG:
- fileFormat = 'png'
- elif nameFilter == self.SNAPSHOT_FILTER_SVG:
- fileFormat = 'svg'
- else: # Format not supported
- _logger.error(
- 'Saving plot snapshot failed: format not supported')
- return False
-
- plot.saveGraph(filename, fileFormat=fileFormat)
- return True
-
- def _getAxesLabels(self, item):
- # If curve has no associated label, get the default from the plot
- xlabel = item.getXLabel() or self.plot.getXAxis().getLabel()
- ylabel = item.getYLabel() or self.plot.getYAxis().getLabel()
- return xlabel, ylabel
-
- def _get1dData(self, item):
- "provide xdata, [ydata], xlabel, [ylabel] and manages error bars"
- xlabel, ylabel = self._getAxesLabels(item)
- x_data = item.getXData(copy=False)
- y_data = item.getYData(copy=False)
- x_err = item.getXErrorData(copy=False)
- y_err = item.getYErrorData(copy=False)
- labels = [ylabel]
- data = [y_data]
-
- if x_err is not None:
- if numpy.isscalar(x_err):
- data.append(numpy.zeros_like(y_data) + x_err)
- labels.append(xlabel + "_errors")
- elif x_err.ndim == 1:
- data.append(x_err)
- labels.append(xlabel + "_errors")
- elif x_err.ndim == 2:
- data.append(x_err[0])
- labels.append(xlabel + "_errors_below")
- data.append(x_err[1])
- labels.append(xlabel + "_errors_above")
-
- if y_err is not None:
- if numpy.isscalar(y_err):
- data.append(numpy.zeros_like(y_data) + y_err)
- labels.append(ylabel + "_errors")
- elif y_err.ndim == 1:
- data.append(y_err)
- labels.append(ylabel + "_errors")
- elif y_err.ndim == 2:
- data.append(y_err[0])
- labels.append(ylabel + "_errors_below")
- data.append(y_err[1])
- labels.append(ylabel + "_errors_above")
- return x_data, data, xlabel, labels
-
- @staticmethod
- def _selectWriteableOutputGroup(filename, parent):
- if os.path.exists(filename) and os.path.isfile(filename) \
- and os.access(filename, os.W_OK):
- entryPath = selectOutputGroup(filename)
- if entryPath is None:
- _logger.info("Save operation cancelled")
- return None
- return entryPath
- elif not os.path.exists(filename):
- # create new entry in new file
- return "/entry"
- else:
- SaveAction._errorMessage('Save failed (file access issue)\n', parent=parent)
- return None
-
- def _saveCurveAsNXdata(self, curve, filename):
- entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
- if entryPath is None:
- return False
-
- xlabel, ylabel = self._getAxesLabels(curve)
-
- return save_NXdata(
- filename,
- nxentry_name=entryPath,
- signal=curve.getYData(copy=False),
- axes=[curve.getXData(copy=False)],
- signal_name="y",
- axes_names=["x"],
- signal_long_name=ylabel,
- axes_long_names=[xlabel],
- signal_errors=curve.getYErrorData(copy=False),
- axes_errors=[curve.getXErrorData(copy=True)],
- title=self.plot.getGraphTitle())
-
- def _saveCurve(self, plot, filename, nameFilter):
- """Save a curve from the plot.
-
- :param str filename: The name of the file to write
- :param str nameFilter: The selected name filter
- :return: False if format is not supported or save failed,
- True otherwise.
- """
- if nameFilter not in self.DEFAULT_CURVE_FILTERS:
- return False
-
- # Check if a curve is to be saved
- curve = plot.getActiveCurve()
- # before calling _saveCurve, if there is no selected curve, we
- # make sure there is only one curve on the graph
- if curve is None:
- curves = plot.getAllCurves()
- if not curves:
- self._errorMessage("No curve to be saved", parent=self.plot)
- return False
- curve = curves[0]
-
- if nameFilter in self.CURVE_FILTERS_TXT:
- filter_ = self.CURVE_FILTERS_TXT[nameFilter]
- fmt = filter_['fmt']
- csvdelim = filter_['delimiter']
- autoheader = filter_['header']
- else:
- # .npy or nxdata
- fmt, csvdelim, autoheader = ("", "", False)
-
- if nameFilter == self.CURVE_FILTER_NXDATA:
- return self._saveCurveAsNXdata(curve, filename)
-
- xdata, data, xlabel, labels = self._get1dData(curve)
-
- try:
- save1D(filename,
- xdata, data,
- xlabel, labels,
- fmt=fmt, csvdelim=csvdelim,
- autoheader=autoheader)
- except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
- return False
-
- return True
-
- def _saveCurves(self, plot, filename, nameFilter):
- """Save all curves from the plot.
-
- :param str filename: The name of the file to write
- :param str nameFilter: The selected name filter
- :return: False if format is not supported or save failed,
- True otherwise.
- """
- if nameFilter not in self.DEFAULT_ALL_CURVES_FILTERS:
- return False
-
- curves = plot.getAllCurves()
- if not curves:
- self._errorMessage("No curves to be saved", parent=self.plot)
- return False
-
- curve = curves[0]
- scanno = 1
- try:
- xdata, data, xlabel, labels = self._get1dData(curve)
-
- specfile = savespec(filename,
- xdata, data,
- xlabel, labels,
- fmt="%.7g", scan_number=1, mode="w",
- write_file_header=True,
- close_file=False)
- except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
- return False
-
- for curve in curves[1:]:
- try:
- scanno += 1
- xdata, data, xlabel, labels = self._get1dData(curve)
- specfile = savespec(specfile,
- xdata, data,
- xlabel, labels,
- fmt="%.7g", scan_number=scanno,
- write_file_header=False,
- close_file=False)
- except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
- return False
- specfile.close()
-
- return True
-
- def _saveImage(self, plot, filename, nameFilter):
- """Save an image from the plot.
-
- :param str filename: The name of the file to write
- :param str nameFilter: The selected name filter
- :return: False if format is not supported or save failed,
- True otherwise.
- """
- if nameFilter not in self.DEFAULT_IMAGE_FILTERS:
- return False
-
- image = plot.getActiveImage()
- if image is None:
- qt.QMessageBox.warning(
- plot, "No Data", "No image to be saved")
- return False
-
- data = image.getData(copy=False)
-
- # TODO Use silx.io for writing files
- if nameFilter == self.IMAGE_FILTER_EDF:
- edfFile = EdfFile(filename, access="w+")
- edfFile.WriteImage({}, data, Append=0)
- return True
-
- elif nameFilter == self.IMAGE_FILTER_TIFF:
- tiffFile = TiffIO(filename, mode='w')
- tiffFile.writeImage(data, software='silx')
- return True
-
- elif nameFilter == self.IMAGE_FILTER_NUMPY:
- try:
- numpy.save(filename, data)
- except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
- return False
- return True
-
- elif nameFilter == self.IMAGE_FILTER_NXDATA:
- entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
- if entryPath is None:
- return False
- xorigin, yorigin = image.getOrigin()
- xscale, yscale = image.getScale()
- xaxis = xorigin + xscale * numpy.arange(data.shape[1])
- yaxis = yorigin + yscale * numpy.arange(data.shape[0])
- xlabel, ylabel = self._getAxesLabels(image)
- interpretation = "image" if len(data.shape) == 2 else "rgba-image"
-
- return save_NXdata(filename,
- nxentry_name=entryPath,
- signal=data,
- axes=[yaxis, xaxis],
- signal_name="image",
- axes_names=["y", "x"],
- axes_long_names=[ylabel, xlabel],
- title=plot.getGraphTitle(),
- interpretation=interpretation)
-
- elif nameFilter in (self.IMAGE_FILTER_ASCII,
- self.IMAGE_FILTER_CSV_COMMA,
- self.IMAGE_FILTER_CSV_SEMICOLON,
- self.IMAGE_FILTER_CSV_TAB):
- csvdelim, filetype = {
- self.IMAGE_FILTER_ASCII: (' ', 'txt'),
- self.IMAGE_FILTER_CSV_COMMA: (',', 'csv'),
- self.IMAGE_FILTER_CSV_SEMICOLON: (';', 'csv'),
- self.IMAGE_FILTER_CSV_TAB: ('\t', 'csv'),
- }[nameFilter]
-
- height, width = data.shape
- rows, cols = numpy.mgrid[0:height, 0:width]
- try:
- save1D(filename, rows.ravel(), (cols.ravel(), data.ravel()),
- filetype=filetype,
- xlabel='row',
- ylabels=['column', 'value'],
- csvdelim=csvdelim,
- autoheader=True)
-
- except IOError:
- self._errorMessage('Save failed\n', parent=self.plot)
- return False
- return True
-
- elif nameFilter == self.IMAGE_FILTER_RGB_PNG:
- # Get displayed image
- rgbaImage = image.getRgbaImageData(copy=False)
- # Convert RGB QImage
- qimage = convertArrayToQImage(rgbaImage[:, :, :3])
-
- if qimage.save(filename, 'PNG'):
- return True
- else:
- _logger.error('Failed to save image as %s', filename)
- qt.QMessageBox.critical(
- self.parent(),
- 'Save image as',
- 'Failed to save image')
-
- return False
-
- def _saveScatter(self, plot, filename, nameFilter):
- """Save an image from the plot.
-
- :param str filename: The name of the file to write
- :param str nameFilter: The selected name filter
- :return: False if format is not supported or save failed,
- True otherwise.
- """
- if nameFilter not in self.DEFAULT_SCATTER_FILTERS:
- return False
-
- if nameFilter == self.SCATTER_FILTER_NXDATA:
- entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
- if entryPath is None:
- return False
- scatter = plot.getScatter()
-
- x = scatter.getXData(copy=False)
- y = scatter.getYData(copy=False)
- z = scatter.getValueData(copy=False)
-
- xerror = scatter.getXErrorData(copy=False)
- if isinstance(xerror, float):
- xerror = xerror * numpy.ones(x.shape, dtype=numpy.float32)
-
- yerror = scatter.getYErrorData(copy=False)
- if isinstance(yerror, float):
- yerror = yerror * numpy.ones(x.shape, dtype=numpy.float32)
-
- xlabel = plot.getGraphXLabel()
- ylabel = plot.getGraphYLabel()
-
- return save_NXdata(
- filename,
- nxentry_name=entryPath,
- signal=z,
- axes=[x, y],
- signal_name="values",
- axes_names=["x", "y"],
- axes_long_names=[xlabel, ylabel],
- axes_errors=[xerror, yerror],
- title=plot.getGraphTitle())
-
- def setFileFilter(self, dataKind, nameFilter, func, index=None, appendToFile=False):
- """Set a name filter to add/replace a file format support
-
- :param str dataKind:
- The kind of data for which the provided filter is valid.
- One of: 'all', 'curve', 'curves', 'image', 'scatter'
- :param str nameFilter: The name filter in the QFileDialog.
- See :meth:`QFileDialog.setNameFilters`.
- :param callable func: The function to call to perform saving.
- Expected signature is:
- bool func(PlotWidget plot, str filename, str nameFilter)
- :param bool appendToFile: True to append the data into the selected
- file.
- :param integer index: Index of the filter in the final list (or None)
- """
- assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
-
- if appendToFile:
- self._appendFilters.append(nameFilter)
-
- # first append or replace the new filter to prevent colissions
- self._filters[dataKind][nameFilter] = func
- if index is None:
- # we are already done
- return
-
- # get the current ordered list of keys
- keyList = list(self._filters[dataKind].keys())
-
- # deal with negative indices
- if index < 0:
- index = len(keyList) + index
- if index < 0:
- index = 0
-
- if index >= len(keyList):
- # nothing to be done, already at the end
- txt = 'Requested index %d impossible, already at the end' % index
- _logger.info(txt)
- return
-
- # get the new ordered list
- oldIndex = keyList.index(nameFilter)
- del keyList[oldIndex]
- keyList.insert(index, nameFilter)
-
- # build the new filters
- newFilters = OrderedDict()
- for key in keyList:
- newFilters[key] = self._filters[dataKind][key]
-
- # and update the filters
- self._filters[dataKind] = newFilters
- return
-
- def getFileFilters(self, dataKind):
- """Returns the nameFilter and associated function for a kind of data.
-
- :param str dataKind:
- The kind of data for which the provided filter is valid.
- On of: 'all', 'curve', 'curves', 'image', 'scatter'
- :return: {nameFilter: function} associations.
- :rtype: collections.OrderedDict
- """
- assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
-
- return self._filters[dataKind].copy()
-
- def _actionTriggered(self, checked=False):
- """Handle save action."""
- # Set-up filters
- filters = OrderedDict()
-
- # Add image filters if there is an active image
- if self.plot.getActiveImage() is not None:
- filters.update(self._filters['image'].items())
-
- # Add curve filters if there is a curve to save
- if (self.plot.getActiveCurve() is not None or
- len(self.plot.getAllCurves()) == 1):
- filters.update(self._filters['curve'].items())
- if len(self.plot.getAllCurves()) >= 1:
- filters.update(self._filters['curves'].items())
-
- # Add scatter filters if there is a scatter
- # todo: CSV
- if self.plot.getScatter() is not None:
- filters.update(self._filters['scatter'].items())
-
- filters.update(self._filters['all'].items())
-
- # Create and run File dialog
- dialog = qt.QFileDialog(self.plot)
- dialog.setOption(dialog.DontUseNativeDialog)
- dialog.setWindowTitle("Output File Selection")
- dialog.setModal(1)
- dialog.setNameFilters(list(filters.keys()))
-
- dialog.setFileMode(dialog.AnyFile)
- dialog.setAcceptMode(dialog.AcceptSave)
-
- def onFilterSelection(filt_):
- # disable overwrite confirmation for NXdata types,
- # because we append the data to existing files
- if filt_ in self._appendFilters:
- dialog.setOption(dialog.DontConfirmOverwrite)
- else:
- dialog.setOption(dialog.DontConfirmOverwrite, False)
-
- dialog.filterSelected.connect(onFilterSelection)
-
- if not dialog.exec_():
- return False
-
- nameFilter = dialog.selectedNameFilter()
- filename = dialog.selectedFiles()[0]
- dialog.close()
-
- if '(' in nameFilter and ')' == nameFilter.strip()[-1]:
- # Check for correct file extension
- # Extract file extensions as .something
- extensions = [ext[ext.find('.'):] for ext in
- nameFilter[nameFilter.find('(') + 1:-1].split()]
- for ext in extensions:
- if (len(filename) > len(ext) and
- filename[-len(ext):].lower() == ext.lower()):
- break
- else: # filename has no extension supported in nameFilter, add one
- if len(extensions) >= 1:
- filename += extensions[0]
-
- # Handle save
- func = filters.get(nameFilter, None)
- if func is not None:
- return func(self.plot, filename, nameFilter)
- else:
- _logger.error('Unsupported file filter: %s', nameFilter)
- return False
-
-
-def _plotAsPNG(plot):
- """Save a :class:`Plot` as PNG and return the payload.
-
- :param plot: The :class:`Plot` to save
- """
- pngFile = BytesIO()
- plot.saveGraph(pngFile, fileFormat='png')
- pngFile.flush()
- pngFile.seek(0)
- data = pngFile.read()
- pngFile.close()
- return data
-
-
-class PrintAction(PlotAction):
- """QAction for printing the plot.
-
- It opens a Print dialog.
-
- Current implementation print a bitmap of the plot area and not vector
- graphics, so printing quality is not great.
-
- :param plot: :class:`.PlotWidget` instance on which to operate.
- :param parent: See :class:`QAction`.
- """
-
- def __init__(self, plot, parent=None):
- super(PrintAction, self).__init__(
- plot, icon='document-print', text='Print...',
- tooltip='Open print dialog',
- triggered=self.printPlot,
- checkable=False, parent=parent)
- self.setShortcut(qt.QKeySequence.Print)
- self.setShortcutContext(qt.Qt.WidgetShortcut)
-
- def getPrinter(self):
- """The QPrinter instance used by the PrintAction.
-
- :rtype: QPrinter
- """
- return printer.getDefaultPrinter()
-
- @property
- @deprecated(replacement="getPrinter()", since_version="0.8.0")
- def printer(self):
- return self.getPrinter()
-
- def printPlotAsWidget(self):
- """Open the print dialog and print the plot.
-
- Use :meth:`QWidget.render` to print the plot
-
- :return: True if successful
- """
- dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
- dialog.setWindowTitle('Print Plot')
- if not dialog.exec_():
- return False
-
- # Print a snapshot of the plot widget at the top of the page
- widget = self.plot.centralWidget()
-
- painter = qt.QPainter()
- if not painter.begin(self.getPrinter()):
- return False
-
- pageRect = self.getPrinter().pageRect()
- xScale = pageRect.width() / widget.width()
- yScale = pageRect.height() / widget.height()
- scale = min(xScale, yScale)
-
- painter.translate(pageRect.width() / 2., 0.)
- painter.scale(scale, scale)
- painter.translate(-widget.width() / 2., 0.)
- widget.render(painter)
- painter.end()
-
- return True
-
- def printPlot(self):
- """Open the print dialog and print the plot.
-
- Use :meth:`Plot.saveGraph` to print the plot.
-
- :return: True if successful
- """
- # Init printer and start printer dialog
- dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
- dialog.setWindowTitle('Print Plot')
- if not dialog.exec_():
- return False
-
- # Save Plot as PNG and make a pixmap from it with default dpi
- pngData = _plotAsPNG(self.plot)
-
- pixmap = qt.QPixmap()
- pixmap.loadFromData(pngData, 'png')
-
- xScale = self.getPrinter().pageRect().width() / pixmap.width()
- yScale = self.getPrinter().pageRect().height() / pixmap.height()
- scale = min(xScale, yScale)
-
- # Draw pixmap with painter
- painter = qt.QPainter()
- if not painter.begin(self.getPrinter()):
- return False
-
- painter.drawPixmap(0, 0,
- pixmap.width() * scale,
- pixmap.height() * scale,
- pixmap)
- painter.end()
-
- return True
-
-
-class CopyAction(PlotAction):
- """QAction to copy :class:`.PlotWidget` content to clipboard.
-
- :param plot: :class:`.PlotWidget` instance on which to operate
- :param parent: See :class:`QAction`
- """
-
- def __init__(self, plot, parent=None):
- super(CopyAction, self).__init__(
- plot, icon='edit-copy', text='Copy plot',
- tooltip='Copy a snapshot of the plot into the clipboard',
- triggered=self.copyPlot,
- checkable=False, parent=parent)
- self.setShortcut(qt.QKeySequence.Copy)
- self.setShortcutContext(qt.Qt.WidgetShortcut)
-
- def copyPlot(self):
- """Copy plot content to the clipboard as a bitmap."""
- # Save Plot as PNG and make a QImage from it with default dpi
- pngData = _plotAsPNG(self.plot)
- image = qt.QImage.fromData(pngData, 'png')
- qt.QApplication.clipboard().setImage(image)
diff --git a/silx/gui/plot/backends/BackendBase.py b/silx/gui/plot/backends/BackendBase.py
deleted file mode 100755
index 6fc1aa7..0000000
--- a/silx/gui/plot/backends/BackendBase.py
+++ /dev/null
@@ -1,578 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Base class for Plot backends.
-
-It documents the Plot backend API.
-
-This API is a simplified version of PyMca PlotBackend API.
-"""
-
-__authors__ = ["V.A. Sole", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "21/12/2018"
-
-import weakref
-from ... import qt
-
-
-# Names for setCursor
-CURSOR_DEFAULT = 'default'
-CURSOR_POINTING = 'pointing'
-CURSOR_SIZE_HOR = 'size horizontal'
-CURSOR_SIZE_VER = 'size vertical'
-CURSOR_SIZE_ALL = 'size all'
-
-
-class BackendBase(object):
- """Class defining the API a backend of the Plot should provide."""
-
- def __init__(self, plot, parent=None):
- """Init.
-
- :param Plot plot: The Plot this backend is attached to
- :param parent: The parent widget of the plot widget.
- """
- self.__xLimits = 1., 100.
- self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)}
- self.__yAxisInverted = False
- self.__keepDataAspectRatio = False
- self.__xAxisTimeSeries = False
- self._xAxisTimeZone = None
- # Store a weakref to get access to the plot state.
- self._setPlot(plot)
-
- @property
- def _plot(self):
- """The plot this backend is attached to."""
- if self._plotRef is None:
- raise RuntimeError('This backend is not attached to a Plot')
-
- plot = self._plotRef()
- if plot is None:
- raise RuntimeError('This backend is no more attached to a Plot')
- return plot
-
- def _setPlot(self, plot):
- """Allow to set plot after init.
-
- Use with caution, basically **immediately** after init.
- """
- self._plotRef = weakref.ref(plot)
-
- # Add methods
-
- def addCurve(self, x, y,
- color, symbol, linewidth, linestyle,
- yaxis,
- xerror, yerror,
- fill, alpha, symbolsize, baseline):
- """Add a 1D curve given by x an y to the graph.
-
- :param numpy.ndarray x: The data corresponding to the x axis
- :param numpy.ndarray y: The data corresponding to the y axis
- :param color: color(s) to be used
- :type color: string ("#RRGGBB") or (npoints, 4) unsigned byte array or
- one of the predefined color names defined in colors.py
- :param str symbol: Symbol to be drawn at each (x, y) position::
-
- - ' ' or '' no symbol
- - 'o' circle
- - '.' point
- - ',' pixel
- - '+' cross
- - 'x' x-cross
- - 'd' diamond
- - 's' square
-
- :param float linewidth: The width of the curve in pixels
- :param str linestyle: Type of line::
-
- - ' ' or '' no line
- - '-' solid line
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
-
- :param str yaxis: The Y axis this curve belongs to in: 'left', 'right'
- :param xerror: Values with the uncertainties on the x values
- :type xerror: numpy.ndarray or None
- :param yerror: Values with the uncertainties on the y values
- :type yerror: numpy.ndarray or None
- :param bool fill: True to fill the curve, False otherwise
- :param float alpha: Curve opacity, as a float in [0., 1.]
- :param float symbolsize: Size of the symbol (if any) drawn
- at each (x, y) position.
- :returns: The handle used by the backend to univocally access the curve
- """
- return object()
-
- def addImage(self, data,
- origin, scale,
- colormap, alpha):
- """Add an image to the plot.
-
- :param numpy.ndarray data: (nrows, ncolumns) data or
- (nrows, ncolumns, RGBA) ubyte array
- :param origin: (origin X, origin Y) of the data.
- Default: (0., 0.)
- :type origin: 2-tuple of float
- :param scale: (scale X, scale Y) of the data.
- Default: (1., 1.)
- :type scale: 2-tuple of float
- :param ~silx.gui.colors.Colormap colormap: Colormap object to use.
- Ignored if data is RGB(A).
- :param float alpha: Opacity of the image, as a float in range [0, 1].
- :returns: The handle used by the backend to univocally access the image
- """
- return object()
-
- def addTriangles(self, x, y, triangles,
- color, alpha):
- """Add a set of triangles.
-
- :param numpy.ndarray x: The data corresponding to the x axis
- :param numpy.ndarray y: The data corresponding to the y axis
- :param numpy.ndarray triangles: The indices to make triangles
- as a (Ntriangle, 3) array
- :param numpy.ndarray color: color(s) as (npoints, 4) array
- :param float alpha: Opacity as a float in [0., 1.]
- :returns: The triangles' unique identifier used by the backend
- """
- return object()
-
- def addShape(self, x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor):
- """Add an item (i.e. a shape) to the plot.
-
- :param numpy.ndarray x: The X coords of the points of the shape
- :param numpy.ndarray y: The Y coords of the points of the shape
- :param str shape: Type of item to be drawn in
- hline, polygon, rectangle, vline, polylines
- :param str color: Color of the item
- :param bool fill: True to fill the shape
- :param bool overlay: True if item is an overlay, False otherwise
- :param str linestyle: Style of the line.
- Only relevant for line markers where X or Y is None.
- Value in:
-
- - ' ' no line
- - '-' solid line
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
- :param float linewidth: Width of the line.
- Only relevant for line markers where X or Y is None.
- :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
- '#FF0000'. It is used to draw dotted line using a second color.
- :returns: The handle used by the backend to univocally access the item
- """
- return object()
-
- def addMarker(self, x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis):
- """Add a point, vertical line or horizontal line marker to the plot.
-
- :param float x: Horizontal position of the marker in graph coordinates.
- If None, the marker is a horizontal line.
- :param float y: Vertical position of the marker in graph coordinates.
- If None, the marker is a vertical line.
- :param str text: Text associated to the marker (or None for no text)
- :param str color: Color to be used for instance 'blue', 'b', '#FF0000'
- :param str symbol: Symbol representing the marker.
- Only relevant for point markers where X and Y are not None.
- Value in:
-
- - 'o' circle
- - '.' point
- - ',' pixel
- - '+' cross
- - 'x' x-cross
- - 'd' diamond
- - 's' square
- :param str linestyle: Style of the line.
- Only relevant for line markers where X or Y is None.
- Value in:
-
- - ' ' no line
- - '-' solid line
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
- :param float linewidth: Width of the line.
- Only relevant for line markers where X or Y is None.
- :param constraint: A function filtering marker displacement by
- dragging operations or None for no filter.
- This function is called each time a marker is
- moved.
- :type constraint: None or a callable that takes the coordinates of
- the current cursor position in the plot as input
- and that returns the filtered coordinates.
- :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
- :return: Handle used by the backend to univocally access the marker
- """
- return object()
-
- # Remove methods
-
- def remove(self, item):
- """Remove an existing item from the plot.
-
- :param item: A backend specific item handle returned by a add* method
- """
- pass
-
- # Interaction methods
-
- def setGraphCursorShape(self, cursor):
- """Set the cursor shape.
-
- To override in interactive backends.
-
- :param str cursor: Name of the cursor shape or None
- """
- pass
-
- def setGraphCursor(self, flag, color, linewidth, linestyle):
- """Toggle the display of a crosshair cursor and set its attributes.
-
- To override in interactive backends.
-
- :param bool flag: Toggle the display of a crosshair cursor.
- :param color: The color to use for the crosshair.
- :type color: A string (either a predefined color name in colors.py
- or "#RRGGBB")) or a 4 columns unsigned byte array.
- :param int linewidth: The width of the lines of the crosshair.
- :param linestyle: Type of line::
-
- - ' ' no line
- - '-' solid line
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
-
- :type linestyle: None or one of the predefined styles.
- """
- pass
-
- def getItemsFromBackToFront(self, condition=None):
- """Returns the list of plot items order as rendered by the backend.
-
- This is the order used for rendering.
- By default, it takes into account overlays, z value and order of addition of items,
- but backends can override it.
-
- :param callable condition:
- Callable taking an item as input and returning False for items to skip.
- If None (default), no item is skipped.
- :rtype: List[~silx.gui.plot.items.Item]
- """
- # Sort items: Overlays first, then others
- # and in each category ordered by z and then by order of addition
- # as content keeps this order.
- content = self._plot.getItems()
- if condition is not None:
- content = [item for item in content if condition(item)]
-
- return sorted(
- content,
- key=lambda i: ((1 if i.isOverlay() else 0), i.getZValue()))
-
- def pickItem(self, x, y, item):
- """Return picked indices if any, or None.
-
- :param float x: The x pixel coord where to pick.
- :param float y: The y pixel coord where to pick.
- :param item: A backend item created with add* methods.
- :return: None if item was not picked, else returns
- picked indices information.
- :rtype: Union[None,List]
- """
- return None
-
- # Update curve
-
- def setCurveColor(self, curve, color):
- """Set the color of a curve.
-
- :param curve: The curve handle
- :param str color: The color to use.
- """
- pass
-
- # Misc.
-
- def getWidgetHandle(self):
- """Return the widget this backend is drawing to."""
- return None
-
- def postRedisplay(self):
- """Trigger a :meth:`Plot.replot`.
-
- Default implementation triggers a synchronous replot if plot is dirty.
- This method should be overridden by the embedding widget in order to
- provide an asynchronous call to replot in order to optimize the number
- replot operations.
- """
- # This method can be deferred and it might happen that plot has been
- # destroyed in between, especially with unittests
-
- plot = self._plotRef()
- if plot is not None and plot._getDirtyPlot():
- plot.replot()
-
- def replot(self):
- """Redraw the plot."""
- pass
-
- def saveGraph(self, fileName, fileFormat, dpi):
- """Save the graph to a file (or a StringIO)
-
- At least "png", "svg" are supported.
-
- :param fileName: Destination
- :type fileName: String or StringIO or BytesIO
- :param str fileFormat: String specifying the format
- :param int dpi: The resolution to use or None.
- """
- pass
-
- # Graph labels
-
- def setGraphTitle(self, title):
- """Set the main title of the plot.
-
- :param str title: Title associated to the plot
- """
- pass
-
- def setGraphXLabel(self, label):
- """Set the X axis label.
-
- :param str label: label associated to the plot bottom X axis
- """
- pass
-
- def setGraphYLabel(self, label, axis):
- """Set the left Y axis label.
-
- :param str label: label associated to the plot left Y axis
- :param str axis: The axis for which to get the limits: left or right
- """
- pass
-
- # Graph limits
-
- def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
- """Set the limits of the X and Y axes at once.
-
- :param float xmin: minimum bottom axis value
- :param float xmax: maximum bottom axis value
- :param float ymin: minimum left axis value
- :param float ymax: maximum left axis value
- :param float y2min: minimum right axis value
- :param float y2max: maximum right axis value
- """
- self.__xLimits = xmin, xmax
- self.__yLimits['left'] = ymin, ymax
- if y2min is not None and y2max is not None:
- self.__yLimits['right'] = y2min, y2max
-
- def getGraphXLimits(self):
- """Get the graph X (bottom) limits.
-
- :return: Minimum and maximum values of the X axis
- """
- return self.__xLimits
-
- def setGraphXLimits(self, xmin, xmax):
- """Set the limits of X axis.
-
- :param float xmin: minimum bottom axis value
- :param float xmax: maximum bottom axis value
- """
- self.__xLimits = xmin, xmax
-
- def getGraphYLimits(self, axis):
- """Get the graph Y (left) limits.
-
- :param str axis: The axis for which to get the limits: left or right
- :return: Minimum and maximum values of the Y axis
- """
- return self.__yLimits[axis]
-
- def setGraphYLimits(self, ymin, ymax, axis):
- """Set the limits of the Y axis.
-
- :param float ymin: minimum left axis value
- :param float ymax: maximum left axis value
- :param str axis: The axis for which to get the limits: left or right
- """
- self.__yLimits[axis] = ymin, ymax
-
- # Graph axes
-
-
- def getXAxisTimeZone(self):
- """Returns tzinfo that is used if the X-Axis plots date-times.
-
- None means the datetimes are interpreted as local time.
-
- :rtype: datetime.tzinfo of None.
- """
- return self._xAxisTimeZone
-
- def setXAxisTimeZone(self, tz):
- """Sets tzinfo that is used if the X-Axis plots date-times.
-
- Use None to let the datetimes be interpreted as local time.
-
- :rtype: datetime.tzinfo of None.
- """
- self._xAxisTimeZone = tz
-
- def isXAxisTimeSeries(self):
- """Return True if the X-axis scale shows datetime objects.
-
- :rtype: bool
- """
- return self.__xAxisTimeSeries
-
- def setXAxisTimeSeries(self, isTimeSeries):
- """Set whether the X-axis is a time series
-
- :param bool flag: True to switch to time series, False for regular axis.
- """
- self.__xAxisTimeSeries = bool(isTimeSeries)
-
- def setXAxisLogarithmic(self, flag):
- """Set the X axis scale between linear and log.
-
- :param bool flag: If True, the bottom axis will use a log scale
- """
- pass
-
- def setYAxisLogarithmic(self, flag):
- """Set the Y axis scale between linear and log.
-
- :param bool flag: If True, the left axis will use a log scale
- """
- pass
-
- def setYAxisInverted(self, flag):
- """Invert the Y axis.
-
- :param bool flag: If True, put the vertical axis origin on the top
- """
- self.__yAxisInverted = bool(flag)
-
- def isYAxisInverted(self):
- """Return True if left Y axis is inverted, False otherwise."""
- return self.__yAxisInverted
-
- def isKeepDataAspectRatio(self):
- """Returns whether the plot is keeping data aspect ratio or not."""
- return self.__keepDataAspectRatio
-
- def setKeepDataAspectRatio(self, flag):
- """Set whether to keep data aspect ratio or not.
-
- :param flag: True to respect data aspect ratio
- :type flag: Boolean, default True
- """
- self.__keepDataAspectRatio = bool(flag)
-
- def setGraphGrid(self, which):
- """Set grid.
-
- :param which: None to disable grid, 'major' for major grid,
- 'both' for major and minor grid
- """
- pass
-
- # Data <-> Pixel coordinates conversion
-
- def dataToPixel(self, x, y, axis):
- """Convert a position in data space to a position in pixels
- in the widget.
-
- :param float x: The X coordinate in data space.
- :param float y: The Y coordinate in data space.
- :param str axis: The Y axis to use for the conversion
- ('left' or 'right').
- :returns: The corresponding position in pixels or
- None if the data position is not in the displayed area.
- :rtype: A tuple of 2 floats: (xPixel, yPixel) or None.
- """
- raise NotImplementedError()
-
- def pixelToData(self, x, y, axis):
- """Convert a position in pixels in the widget to a position in
- the data space.
-
- :param float x: The X coordinate in pixels.
- :param float y: The Y coordinate in pixels.
- :param str axis: The Y axis to use for the conversion
- ('left' or 'right').
- :returns: The corresponding position in data space or
- None if the pixel position is not in the plot area.
- :rtype: A tuple of 2 floats: (xData, yData) or None.
- """
- raise NotImplementedError()
-
- def getPlotBoundsInPixels(self):
- """Plot area bounds in widget coordinates in pixels.
-
- :return: bounds as a 4-tuple of int: (left, top, width, height)
- """
- raise NotImplementedError()
-
- def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
- """Set the size of plot margins as ratios.
-
- Values are expected in [0., 1.]
-
- :param float left:
- :param float top:
- :param float right:
- :param float bottom:
- """
- pass
-
- def setForegroundColors(self, foregroundColor, gridColor):
- """Set foreground and grid colors used to display this widget.
-
- :param List[float] foregroundColor: RGBA foreground color of the widget
- :param List[float] gridColor: RGBA grid color of the data view
- """
- pass
-
- def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
- """Set background colors used to display this widget.
-
- :param List[float] backgroundColor: RGBA background color of the widget
- :param Union[Tuple[float],None] dataBackgroundColor:
- RGBA background color of the data view
- """
- pass
diff --git a/silx/gui/plot/backends/BackendMatplotlib.py b/silx/gui/plot/backends/BackendMatplotlib.py
deleted file mode 100755
index 432b0b0..0000000
--- a/silx/gui/plot/backends/BackendMatplotlib.py
+++ /dev/null
@@ -1,1544 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Matplotlib Plot backend."""
-
-from __future__ import division
-
-__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"]
-__license__ = "MIT"
-__date__ = "21/12/2018"
-
-
-import logging
-import datetime as dt
-from typing import Tuple
-import numpy
-
-from pkg_resources import parse_version as _parse_version
-
-
-_logger = logging.getLogger(__name__)
-
-
-from ... import qt
-
-# First of all init matplotlib and set its backend
-from ...utils.matplotlib import FigureCanvasQTAgg
-import matplotlib
-from matplotlib.container import Container
-from matplotlib.figure import Figure
-from matplotlib.patches import Rectangle, Polygon
-from matplotlib.image import AxesImage
-from matplotlib.backend_bases import MouseEvent
-from matplotlib.lines import Line2D
-from matplotlib.text import Text
-from matplotlib.collections import PathCollection, LineCollection
-from matplotlib.ticker import Formatter, ScalarFormatter, Locator
-from matplotlib.tri import Triangulation
-from matplotlib.collections import TriMesh
-from matplotlib import path as mpath
-
-from . import BackendBase
-from .. import items
-from .._utils import FLOAT32_MINPOS
-from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp
-
-_PATCH_LINESTYLE = {
- "-": 'solid',
- "--": 'dashed',
- '-.': 'dashdot',
- ':': 'dotted',
- '': "solid",
- None: "solid",
-}
-"""Patches do not uses the same matplotlib syntax"""
-
-_MARKER_PATHS = {}
-"""Store cached extra marker paths"""
-
-_SPECIAL_MARKERS = {
- 'tickleft': 0,
- 'tickright': 1,
- 'tickup': 2,
- 'tickdown': 3,
- 'caretleft': 4,
- 'caretright': 5,
- 'caretup': 6,
- 'caretdown': 7,
-}
-
-
-def normalize_linestyle(linestyle):
- """Normalize known old-style linestyle, else return the provided value."""
- return _PATCH_LINESTYLE.get(linestyle, linestyle)
-
-def get_path_from_symbol(symbol):
- """Get the path representation of a symbol, else None if
- it is not provided.
-
- :param str symbol: Symbol description used by silx
- :rtype: Union[None,matplotlib.path.Path]
- """
- if symbol == u'\u2665':
- path = _MARKER_PATHS.get(symbol, None)
- if path is not None:
- return path
- vertices = numpy.array([
- [0,-99],
- [31,-73], [47,-55], [55,-46],
- [63,-37], [94,-2], [94,33],
- [94,69], [71,89], [47,89],
- [24,89], [8,74], [0,58],
- [-8,74], [-24,89], [-47,89],
- [-71,89], [-94,69], [-94,33],
- [-94,-2], [-63,-37], [-55,-46],
- [-47,-55], [-31,-73], [0,-99],
- [0,-99]])
- codes = [mpath.Path.CURVE4] * len(vertices)
- codes[0] = mpath.Path.MOVETO
- codes[-1] = mpath.Path.CLOSEPOLY
- path = mpath.Path(vertices, codes)
- _MARKER_PATHS[symbol] = path
- return path
- return None
-
-class NiceDateLocator(Locator):
- """
- Matplotlib Locator that uses Nice Numbers algorithm (adapted to dates)
- to find the tick locations. This results in the same number behaviour
- as when using the silx Open GL backend.
-
- Expects the data to be posix timestampes (i.e. seconds since 1970)
- """
- def __init__(self, numTicks=5, tz=None):
- """
- :param numTicks: target number of ticks
- :param datetime.tzinfo tz: optional time zone. None is local time.
- """
- super(NiceDateLocator, self).__init__()
- self.numTicks = numTicks
-
- self._spacing = None
- self._unit = None
- self.tz = tz
-
- @property
- def spacing(self):
- """ The current spacing. Will be updated when new tick value are made"""
- return self._spacing
-
- @property
- def unit(self):
- """ The current DtUnit. Will be updated when new tick value are made"""
- return self._unit
-
- def __call__(self):
- """Return the locations of the ticks"""
- vmin, vmax = self.axis.get_view_interval()
- return self.tick_values(vmin, vmax)
-
- def tick_values(self, vmin, vmax):
- """ Calculates tick values
- """
- if vmax < vmin:
- vmin, vmax = vmax, vmin
-
- # vmin and vmax should be timestamps (i.e. seconds since 1 Jan 1970)
- dtMin = dt.datetime.fromtimestamp(vmin, tz=self.tz)
- dtMax = dt.datetime.fromtimestamp(vmax, tz=self.tz)
- dtTicks, self._spacing, self._unit = \
- calcTicks(dtMin, dtMax, self.numTicks)
-
- # Convert datetime back to time stamps.
- ticks = [timestamp(dtTick) for dtTick in dtTicks]
- return ticks
-
-
-class NiceAutoDateFormatter(Formatter):
- """
- Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the
- best possible formats given the locators current spacing an date unit.
- """
-
- def __init__(self, locator, tz=None):
- """
- :param niceDateLocator: a NiceDateLocator object
- :param datetime.tzinfo tz: optional time zone. None is local time.
- """
- super(NiceAutoDateFormatter, self).__init__()
- self.locator = locator
- self.tz = tz
-
- @property
- def formatString(self):
- if self.locator.spacing is None or self.locator.unit is None:
- # Locator has no spacing or units yet. Return elaborate fmtString
- return "Y-%m-%d %H:%M:%S"
- else:
- return bestFormatString(self.locator.spacing, self.locator.unit)
-
- def __call__(self, x, pos=None):
- """Return the format for tick val *x* at position *pos*
- Expects x to be a POSIX timestamp (seconds since 1 Jan 1970)
- """
- dateTime = dt.datetime.fromtimestamp(x, tz=self.tz)
- tickStr = dateTime.strftime(self.formatString)
- return tickStr
-
-
-class _PickableContainer(Container):
- """Artists container with a :meth:`contains` method"""
-
- def __init__(self, *args, **kwargs):
- Container.__init__(self, *args, **kwargs)
- self.__zorder = None
-
- @property
- def axes(self):
- """Mimin Artist.axes"""
- for child in self.get_children():
- if hasattr(child, 'axes'):
- return child.axes
- return None
-
- def draw(self, *args, **kwargs):
- """artist-like draw to broadcast draw to children"""
- for child in self.get_children():
- child.draw(*args, **kwargs)
-
- def get_zorder(self):
- """Mimic Artist.get_zorder"""
- return self.__zorder
-
- def set_zorder(self, z):
- """Mimic Artist.set_zorder to broadcast to children"""
- if z != self.__zorder:
- self.__zorder = z
- for child in self.get_children():
- child.set_zorder(z)
-
- def contains(self, mouseevent):
- """Mimic Artist.contains, and call it on all children.
-
- :param mouseevent:
- :return: Picking status and associated information as a dict
- :rtype: (bool,dict)
- """
- # Goes through children from front to back and return first picked one.
- for child in reversed(self.get_children()):
- picked, info = child.contains(mouseevent)
- if picked:
- return picked, info
- return False, {}
-
-
-class _TextWithOffset(Text):
- """Text object which can be displayed at a specific position
- of the plot, but with a pixel offset"""
-
- def __init__(self, *args, **kwargs):
- Text.__init__(self, *args, **kwargs)
- self.pixel_offset = (0, 0)
- self.__cache = None
-
- def draw(self, renderer):
- self.__cache = None
- return Text.draw(self, renderer)
-
- def __get_xy(self):
- if self.__cache is not None:
- return self.__cache
-
- align = self.get_horizontalalignment()
- if align == "left":
- xoffset = self.pixel_offset[0]
- elif align == "right":
- xoffset = -self.pixel_offset[0]
- else:
- xoffset = 0
-
- align = self.get_verticalalignment()
- if align == "top":
- yoffset = -self.pixel_offset[1]
- elif align == "bottom":
- yoffset = self.pixel_offset[1]
- else:
- yoffset = 0
-
- trans = self.get_transform()
- x = super(_TextWithOffset, self).convert_xunits(self._x)
- y = super(_TextWithOffset, self).convert_xunits(self._y)
- pos = x, y
-
- try:
- invtrans = trans.inverted()
- except numpy.linalg.LinAlgError:
- # Cannot inverse transform, fallback: pos without offset
- self.__cache = None
- return pos
-
- proj = trans.transform_point(pos)
- proj = proj + numpy.array((xoffset, yoffset))
- pos = invtrans.transform_point(proj)
- self.__cache = pos
- return pos
-
- def convert_xunits(self, x):
- """Return the pixel position of the annotated point."""
- return self.__get_xy()[0]
-
- def convert_yunits(self, y):
- """Return the pixel position of the annotated point."""
- return self.__get_xy()[1]
-
-
-class _MarkerContainer(_PickableContainer):
- """Marker artists container supporting draw/remove and text position update
-
- :param artists:
- Iterable with either one Line2D or a Line2D and a Text.
- The use of an iterable if enforced by Container being
- a subclass of tuple that defines a specific __new__.
- :param x: X coordinate of the marker (None for horizontal lines)
- :param y: Y coordinate of the marker (None for vertical lines)
- """
-
- def __init__(self, artists, symbol, x, y, yAxis):
- self.line = artists[0]
- self.text = artists[1] if len(artists) > 1 else None
- self.symbol = symbol
- self.x = x
- self.y = y
- self.yAxis = yAxis
-
- _PickableContainer.__init__(self, artists)
-
- def draw(self, *args, **kwargs):
- """artist-like draw to broadcast draw to line and text"""
- self.line.draw(*args, **kwargs)
- if self.text is not None:
- self.text.draw(*args, **kwargs)
-
- def updateMarkerText(self, xmin, xmax, ymin, ymax, yinverted):
- """Update marker text position and visibility according to plot limits
-
- :param xmin: X axis lower limit
- :param xmax: X axis upper limit
- :param ymin: Y axis lower limit
- :param ymax: Y axis upper limit
- :param yinverted: True if the y axis is inverted
- """
- if self.text is not None:
- visible = ((self.x is None or xmin <= self.x <= xmax) and
- (self.y is None or ymin <= self.y <= ymax))
- self.text.set_visible(visible)
-
- if self.x is not None and self.y is not None:
- if self.symbol is None:
- valign = 'baseline'
- else:
- if yinverted:
- valign = 'bottom'
- else:
- valign = 'top'
- self.text.set_verticalalignment(valign)
-
- elif self.y is None: # vertical line
- # Always display it on top
- center = (ymax + ymin) * 0.5
- pos = (ymax - ymin) * 0.5 * 0.99
- if yinverted:
- pos = -pos
- self.text.set_y(center + pos)
-
- elif self.x is None: # Horizontal line
- delta = abs(xmax - xmin)
- if xmin > xmax:
- xmax = xmin
- xmax -= 0.005 * delta
- self.text.set_x(xmax)
-
- def contains(self, mouseevent):
- """Mimic Artist.contains, and call it on the line Artist.
-
- :param mouseevent:
- :return: Picking status and associated information as a dict
- :rtype: (bool,dict)
- """
- return self.line.contains(mouseevent)
-
-
-class _DoubleColoredLinePatch(matplotlib.patches.Patch):
- """Matplotlib patch to display any patch using double color."""
-
- def __init__(self, patch):
- super(_DoubleColoredLinePatch, self).__init__()
- self.__patch = patch
- self.linebgcolor = None
-
- def __getattr__(self, name):
- return getattr(self.__patch, name)
-
- def draw(self, renderer):
- oldLineStype = self.__patch.get_linestyle()
- if self.linebgcolor is not None and oldLineStype != "solid":
- oldLineColor = self.__patch.get_edgecolor()
- oldHatch = self.__patch.get_hatch()
- self.__patch.set_linestyle("solid")
- self.__patch.set_edgecolor(self.linebgcolor)
- self.__patch.set_hatch(None)
- self.__patch.draw(renderer)
- self.__patch.set_linestyle(oldLineStype)
- self.__patch.set_edgecolor(oldLineColor)
- self.__patch.set_hatch(oldHatch)
- self.__patch.draw(renderer)
-
- def set_transform(self, transform):
- self.__patch.set_transform(transform)
-
- def get_path(self):
- return self.__patch.get_path()
-
- def contains(self, mouseevent, radius=None):
- return self.__patch.contains(mouseevent, radius)
-
- def contains_point(self, point, radius=None):
- return self.__patch.contains_point(point, radius)
-
-
-class Image(AxesImage):
- """An AxesImage with a fast path for uint8 RGBA images.
-
- :param List[float] silx_origin: (ox, oy) Offset of the image.
- :param List[float] silx_scale: (sx, sy) Scale of the image.
- """
-
- def __init__(self, *args,
- silx_origin=(0., 0.),
- silx_scale=(1., 1.),
- **kwargs):
- super().__init__(*args, **kwargs)
- self.__silx_origin = silx_origin
- self.__silx_scale = silx_scale
-
- def contains(self, mouseevent):
- """Overridden to fill 'ind' with row and column"""
- inside, info = super().contains(mouseevent)
- if inside:
- x, y = mouseevent.xdata, mouseevent.ydata
- ox, oy = self.__silx_origin
- sx, sy = self.__silx_scale
- height, width = self.get_size()
- column = numpy.clip(int((x - ox) / sx), 0, width - 1)
- row = numpy.clip(int((y - oy) / sy), 0, height - 1)
- info['ind'] = (row,), (column,)
- return inside, info
-
- def set_data(self, A):
- """Overridden to add a fast path for RGBA unit8 images"""
- A = numpy.array(A, copy=False)
- if A.ndim != 3 or A.shape[2] != 4 or A.dtype != numpy.uint8:
- super(Image, self).set_data(A)
- else:
- # Call AxesImage.set_data with small data to set attributes
- super(Image, self).set_data(numpy.zeros((2, 2, 4), dtype=A.dtype))
- self._A = A # Override stored data
-
-
-class BackendMatplotlib(BackendBase.BackendBase):
- """Base class for Matplotlib backend without a FigureCanvas.
-
- For interactive on screen plot, see :class:`BackendMatplotlibQt`.
-
- See :class:`BackendBase.BackendBase` for public API documentation.
- """
-
- def __init__(self, plot, parent=None):
- super(BackendMatplotlib, self).__init__(plot, parent)
-
- # matplotlib is handling keep aspect ratio at draw time
- # When keep aspect ratio is on, and one changes the limits and
- # ask them *before* next draw has been performed he will get the
- # limits without applying keep aspect ratio.
- # This attribute is used to ensure consistent values returned
- # when getting the limits at the expense of a replot
- self._dirtyLimits = True
- self._axesDisplayed = True
- self._matplotlibVersion = _parse_version(matplotlib.__version__)
-
- self.fig = Figure()
- self.fig.set_facecolor("w")
-
- self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
- self.ax2 = self.ax.twinx()
- self.ax2.set_label("right")
- # Make sure background of Axes is displayed
- self.ax2.patch.set_visible(False)
- self.ax.patch.set_visible(True)
-
- # Set axis zorder=0.5 so grid is displayed at 0.5
- self.ax.set_axisbelow(True)
-
- # disable the use of offsets
- try:
- axes = [
- self.ax.get_yaxis().get_major_formatter(),
- self.ax.get_xaxis().get_major_formatter(),
- self.ax2.get_yaxis().get_major_formatter(),
- self.ax2.get_xaxis().get_major_formatter(),
- ]
- for axis in axes:
- axis.set_useOffset(False)
- axis.set_scientific(False)
- except:
- _logger.warning('Cannot disabled axes offsets in %s '
- % matplotlib.__version__)
-
- self.ax2.set_autoscaley_on(True)
-
- # this works but the figure color is left
- if self._matplotlibVersion < _parse_version('2'):
- self.ax.set_axis_bgcolor('none')
- else:
- self.ax.set_facecolor('none')
- self.fig.sca(self.ax)
-
- self._background = None
-
- self._colormaps = {}
-
- self._graphCursor = tuple()
-
- self._enableAxis('right', False)
- self._isXAxisTimeSeries = False
-
- def getItemsFromBackToFront(self, condition=None):
- """Order as BackendBase + take into account matplotlib Axes structure"""
- def axesOrder(item):
- if item.isOverlay():
- return 2
- elif isinstance(item, items.YAxisMixIn) and item.getYAxis() == 'right':
- return 1
- else:
- return 0
-
- return sorted(
- BackendBase.BackendBase.getItemsFromBackToFront(
- self, condition=condition),
- key=axesOrder)
-
- def _overlayItems(self):
- """Generator of backend renderer for overlay items"""
- for item in self._plot.getItems():
- if (item.isOverlay() and
- item.isVisible() and
- item._backendRenderer is not None):
- yield item._backendRenderer
-
- def _hasOverlays(self):
- """Returns whether there is an overlay layer or not.
-
- The overlay layers contains overlay items and the crosshair.
-
- :rtype: bool
- """
- if self._graphCursor:
- return True # There is the crosshair
-
- for item in self._overlayItems():
- return True # There is at least one overlay item
- return False
-
- # Add methods
-
- def _getMarkerFromSymbol(self, symbol):
- """Returns a marker that can be displayed by matplotlib.
-
- :param str symbol: A symbol description used by silx
- :rtype: Union[str,int,matplotlib.path.Path]
- """
- path = get_path_from_symbol(symbol)
- if path is not None:
- return path
- num = _SPECIAL_MARKERS.get(symbol, None)
- if num is not None:
- return num
- # This symbol must be supported by matplotlib
- return symbol
-
- def addCurve(self, x, y,
- color, symbol, linewidth, linestyle,
- yaxis,
- xerror, yerror,
- fill, alpha, symbolsize, baseline):
- for parameter in (x, y, color, symbol, linewidth, linestyle,
- yaxis, fill, alpha, symbolsize):
- assert parameter is not None
- assert yaxis in ('left', 'right')
-
- if (len(color) == 4 and
- type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
- color = numpy.array(color, dtype=numpy.float64) / 255.
-
- if yaxis == "right":
- axes = self.ax2
- self._enableAxis("right", True)
- else:
- axes = self.ax
-
- pickradius = 3
-
- artists = [] # All the artists composing the curve
-
- # First add errorbars if any so they are behind the curve
- if xerror is not None or yerror is not None:
- if hasattr(color, 'dtype') and len(color) == len(x):
- errorbarColor = 'k'
- else:
- errorbarColor = color
-
- # Nx1 error array deprecated in matplotlib >=3.1 (removed in 3.3)
- if (isinstance(xerror, numpy.ndarray) and xerror.ndim == 2 and
- xerror.shape[1] == 1):
- xerror = numpy.ravel(xerror)
- if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and
- yerror.shape[1] == 1):
- yerror = numpy.ravel(yerror)
-
- errorbars = axes.errorbar(x, y,
- xerr=xerror, yerr=yerror,
- linestyle=' ', color=errorbarColor)
- artists += list(errorbars.get_children())
-
- if hasattr(color, 'dtype') and len(color) == len(x):
- # scatter plot
- if color.dtype not in [numpy.float32, numpy.float64]:
- actualColor = color / 255.
- else:
- actualColor = color
-
- if linestyle not in ["", " ", None]:
- # scatter plot with an actual line ...
- # we need to assign a color ...
- curveList = axes.plot(x, y,
- linestyle=linestyle,
- color=actualColor[0],
- linewidth=linewidth,
- picker=True,
- pickradius=pickradius,
- marker=None)
- artists += list(curveList)
-
- marker = self._getMarkerFromSymbol(symbol)
- scatter = axes.scatter(x, y,
- color=actualColor,
- marker=marker,
- picker=True,
- pickradius=pickradius,
- s=symbolsize**2)
- artists.append(scatter)
-
- if fill:
- if baseline is None:
- _baseline = FLOAT32_MINPOS
- else:
- _baseline = baseline
- artists.append(axes.fill_between(
- x, _baseline, y, facecolor=actualColor[0], linestyle=''))
-
- else: # Curve
- curveList = axes.plot(x, y,
- linestyle=linestyle,
- color=color,
- linewidth=linewidth,
- marker=symbol,
- picker=True,
- pickradius=pickradius,
- markersize=symbolsize)
- artists += list(curveList)
-
- if fill:
- if baseline is None:
- _baseline = FLOAT32_MINPOS
- else:
- _baseline = baseline
- artists.append(
- axes.fill_between(x, _baseline, y, facecolor=color))
-
- for artist in artists:
- if alpha < 1:
- artist.set_alpha(alpha)
-
- return _PickableContainer(artists)
-
- def addImage(self, data, origin, scale, colormap, alpha):
- # Non-uniform image
- # http://wiki.scipy.org/Cookbook/Histograms
- # Non-linear axes
- # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
- for parameter in (data, origin, scale):
- assert parameter is not None
-
- origin = float(origin[0]), float(origin[1])
- scale = float(scale[0]), float(scale[1])
- height, width = data.shape[0:2]
-
- # All image are shown as RGBA image
- image = Image(self.ax,
- interpolation='nearest',
- picker=True,
- origin='lower',
- silx_origin=origin,
- silx_scale=scale)
-
- if alpha < 1:
- image.set_alpha(alpha)
-
- # Set image extent
- xmin = origin[0]
- xmax = xmin + scale[0] * width
- if scale[0] < 0.:
- xmin, xmax = xmax, xmin
-
- ymin = origin[1]
- ymax = ymin + scale[1] * height
- if scale[1] < 0.:
- ymin, ymax = ymax, ymin
-
- image.set_extent((xmin, xmax, ymin, ymax))
-
- # Set image data
- if scale[0] < 0. or scale[1] < 0.:
- # For negative scale, step by -1
- xstep = 1 if scale[0] >= 0. else -1
- ystep = 1 if scale[1] >= 0. else -1
- data = data[::ystep, ::xstep]
-
- if data.ndim == 2: # Data image, convert to RGBA image
- data = colormap.applyToData(data)
- elif data.dtype == numpy.uint16:
- # Normalize uint16 data to have a similar behavior as opengl backend
- data = data.astype(numpy.float32)
- data /= 65535
-
- image.set_data(data)
- self.ax.add_artist(image)
- return image
-
- def addTriangles(self, x, y, triangles, color, alpha):
- for parameter in (x, y, triangles, color, alpha):
- assert parameter is not None
-
- color = numpy.array(color, copy=False)
- assert color.ndim == 2 and len(color) == len(x)
-
- if color.dtype not in [numpy.float32, numpy.float64]:
- color = color.astype(numpy.float32) / 255.
-
- collection = TriMesh(
- Triangulation(x, y, triangles),
- alpha=alpha,
- pickradius=0) # 0 enables picking on filled triangle
- collection.set_color(color)
- self.ax.add_collection(collection)
-
- return collection
-
- def addShape(self, x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor):
- if (linebgcolor is not None and
- shape not in ('rectangle', 'polygon', 'polylines')):
- _logger.warning(
- 'linebgcolor not implemented for %s with matplotlib backend',
- shape)
- xView = numpy.array(x, copy=False)
- yView = numpy.array(y, copy=False)
-
- linestyle = normalize_linestyle(linestyle)
-
- if shape == "line":
- item = self.ax.plot(x, y, color=color,
- linestyle=linestyle, linewidth=linewidth,
- marker=None)[0]
-
- elif shape == "hline":
- if hasattr(y, "__len__"):
- y = y[-1]
- item = self.ax.axhline(y, color=color,
- linestyle=linestyle, linewidth=linewidth)
-
- elif shape == "vline":
- if hasattr(x, "__len__"):
- x = x[-1]
- item = self.ax.axvline(x, color=color,
- linestyle=linestyle, linewidth=linewidth)
-
- elif shape == 'rectangle':
- xMin = numpy.nanmin(xView)
- xMax = numpy.nanmax(xView)
- yMin = numpy.nanmin(yView)
- yMax = numpy.nanmax(yView)
- w = xMax - xMin
- h = yMax - yMin
- item = Rectangle(xy=(xMin, yMin),
- width=w,
- height=h,
- fill=False,
- color=color,
- linestyle=linestyle,
- linewidth=linewidth)
- if fill:
- item.set_hatch('.')
-
- if linestyle != "solid" and linebgcolor is not None:
- item = _DoubleColoredLinePatch(item)
- item.linebgcolor = linebgcolor
-
- self.ax.add_patch(item)
-
- elif shape in ('polygon', 'polylines'):
- points = numpy.array((xView, yView)).T
- if shape == 'polygon':
- closed = True
- else: # shape == 'polylines'
- closed = numpy.all(numpy.equal(points[0], points[-1]))
- item = Polygon(points,
- closed=closed,
- fill=False,
- color=color,
- linestyle=linestyle,
- linewidth=linewidth)
- if fill and shape == 'polygon':
- item.set_hatch('/')
-
- if linestyle != "solid" and linebgcolor is not None:
- item = _DoubleColoredLinePatch(item)
- item.linebgcolor = linebgcolor
-
- self.ax.add_patch(item)
-
- else:
- raise NotImplementedError("Unsupported item shape %s" % shape)
-
- if overlay:
- item.set_animated(True)
-
- return item
-
- def addMarker(self, x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis):
- textArtist = None
-
- xmin, xmax = self.getGraphXLimits()
- ymin, ymax = self.getGraphYLimits(axis=yaxis)
-
- if yaxis == 'left':
- ax = self.ax
- elif yaxis == 'right':
- ax = self.ax2
- else:
- assert(False)
-
- marker = self._getMarkerFromSymbol(symbol)
- if x is not None and y is not None:
- line = ax.plot(x, y,
- linestyle=" ",
- color=color,
- marker=marker,
- markersize=10.)[-1]
-
- if text is not None:
- textArtist = _TextWithOffset(x, y, text,
- color=color,
- horizontalalignment='left')
- if symbol is not None:
- textArtist.pixel_offset = 10, 3
- elif x is not None:
- line = ax.axvline(x,
- color=color,
- linewidth=linewidth,
- linestyle=linestyle)
- if text is not None:
- # Y position will be updated in updateMarkerText call
- textArtist = _TextWithOffset(x, 1., text,
- color=color,
- horizontalalignment='left',
- verticalalignment='top')
- textArtist.pixel_offset = 5, 3
- elif y is not None:
- line = ax.axhline(y,
- color=color,
- linewidth=linewidth,
- linestyle=linestyle)
-
- if text is not None:
- # X position will be updated in updateMarkerText call
- textArtist = _TextWithOffset(1., y, text,
- color=color,
- horizontalalignment='right',
- verticalalignment='top')
- textArtist.pixel_offset = 5, 3
- else:
- raise RuntimeError('A marker must at least have one coordinate')
-
- line.set_picker(True)
- line.set_pickradius(5)
-
- # All markers are overlays
- line.set_animated(True)
- if textArtist is not None:
- ax.add_artist(textArtist)
- textArtist.set_animated(True)
-
- artists = [line] if textArtist is None else [line, textArtist]
- container = _MarkerContainer(artists, symbol, x, y, yaxis)
- container.updateMarkerText(xmin, xmax, ymin, ymax, self.isYAxisInverted())
-
- return container
-
- def _updateMarkers(self):
- xmin, xmax = self.ax.get_xbound()
- ymin1, ymax1 = self.ax.get_ybound()
- ymin2, ymax2 = self.ax2.get_ybound()
- yinverted = self.isYAxisInverted()
- for item in self._overlayItems():
- if isinstance(item, _MarkerContainer):
- if item.yAxis == 'left':
- item.updateMarkerText(xmin, xmax, ymin1, ymax1, yinverted)
- else:
- item.updateMarkerText(xmin, xmax, ymin2, ymax2, yinverted)
-
- # Remove methods
-
- def remove(self, item):
- try:
- item.remove()
- except ValueError:
- pass # Already removed e.g., in set[X|Y]AxisLogarithmic
-
- # Interaction methods
-
- def setGraphCursor(self, flag, color, linewidth, linestyle):
- if flag:
- lineh = self.ax.axhline(
- self.ax.get_ybound()[0], visible=False, color=color,
- linewidth=linewidth, linestyle=linestyle)
- lineh.set_animated(True)
-
- linev = self.ax.axvline(
- self.ax.get_xbound()[0], visible=False, color=color,
- linewidth=linewidth, linestyle=linestyle)
- linev.set_animated(True)
-
- self._graphCursor = lineh, linev
- else:
- if self._graphCursor:
- lineh, linev = self._graphCursor
- lineh.remove()
- linev.remove()
- self._graphCursor = tuple()
-
- # Active curve
-
- def setCurveColor(self, curve, color):
- # Store Line2D and PathCollection
- for artist in curve.get_children():
- if isinstance(artist, (Line2D, LineCollection)):
- artist.set_color(color)
- elif isinstance(artist, PathCollection):
- artist.set_facecolors(color)
- artist.set_edgecolors(color)
- else:
- _logger.warning(
- 'setActiveCurve ignoring artist %s', str(artist))
-
- # Misc.
-
- def getWidgetHandle(self):
- return self.fig.canvas
-
- def _enableAxis(self, axis, flag=True):
- """Show/hide Y axis
-
- :param str axis: Axis name: 'left' or 'right'
- :param bool flag: Default, True
- """
- assert axis in ('right', 'left')
- axes = self.ax2 if axis == 'right' else self.ax
- axes.get_yaxis().set_visible(flag)
-
- def replot(self):
- """Do not perform rendering.
-
- Override in subclass to actually draw something.
- """
- # TODO images, markers? scatter plot? move in remove?
- # Right Y axis only support curve for now
- # Hide right Y axis if no line is present
- self._dirtyLimits = False
- if not self.ax2.lines:
- self._enableAxis('right', False)
-
- def _drawOverlays(self):
- """Draw overlays if any."""
- def condition(item):
- return (item.isVisible() and
- item._backendRenderer is not None and
- item.isOverlay())
-
- for item in self.getItemsFromBackToFront(condition=condition):
- if (isinstance(item, items.YAxisMixIn) and
- item.getYAxis() == 'right'):
- axes = self.ax2
- else:
- axes = self.ax
- axes.draw_artist(item._backendRenderer)
-
- for item in self._graphCursor:
- self.ax.draw_artist(item)
-
- def updateZOrder(self):
- """Reorder all items with z order from 0 to 1"""
- items = self.getItemsFromBackToFront(
- lambda item: item.isVisible() and item._backendRenderer is not None)
- count = len(items)
- for index, item in enumerate(items):
- if item.getZValue() < 0.5:
- # Make sure matplotlib z order is below the grid (with z=0.5)
- zorder = 0.5 * index / count
- else: # Make sure matplotlib z order is above the grid (> 0.5)
- zorder = 1. + index / count
- if zorder != item._backendRenderer.get_zorder():
- item._backendRenderer.set_zorder(zorder)
-
- def saveGraph(self, fileName, fileFormat, dpi):
- self.updateZOrder()
-
- # fileName can be also a StringIO or file instance
- if dpi is not None:
- self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
- else:
- self.fig.savefig(fileName, format=fileFormat)
- self._plot._setDirtyPlot()
-
- # Graph labels
-
- def setGraphTitle(self, title):
- self.ax.set_title(title)
-
- def setGraphXLabel(self, label):
- self.ax.set_xlabel(label)
-
- def setGraphYLabel(self, label, axis):
- axes = self.ax if axis == 'left' else self.ax2
- axes.set_ylabel(label)
-
- # Graph limits
-
- def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
- # Let matplotlib taking care of keep aspect ratio if any
- self._dirtyLimits = True
- self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
-
- if y2min is not None and y2max is not None:
- if not self.isYAxisInverted():
- self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
- else:
- self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))
-
- if not self.isYAxisInverted():
- self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
- else:
- self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))
-
- self._updateMarkers()
-
- def getGraphXLimits(self):
- if self._dirtyLimits and self.isKeepDataAspectRatio():
- self.ax.apply_aspect()
- self.ax2.apply_aspect()
- self._dirtyLimits = False
- return self.ax.get_xbound()
-
- def setGraphXLimits(self, xmin, xmax):
- self._dirtyLimits = True
- self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
- self._updateMarkers()
-
- def getGraphYLimits(self, axis):
- assert axis in ('left', 'right')
- ax = self.ax2 if axis == 'right' else self.ax
-
- if not ax.get_visible():
- return None
-
- if self._dirtyLimits and self.isKeepDataAspectRatio():
- self.ax.apply_aspect()
- self.ax2.apply_aspect()
- self._dirtyLimits = False
-
- return ax.get_ybound()
-
- def setGraphYLimits(self, ymin, ymax, axis):
- ax = self.ax2 if axis == 'right' else self.ax
- if ymax < ymin:
- ymin, ymax = ymax, ymin
- self._dirtyLimits = True
-
- if self.isKeepDataAspectRatio():
- # matplotlib keeps limits of shared axis when keeping aspect ratio
- # So x limits are kept when changing y limits....
- # Change x limits first by taking into account aspect ratio
- # and then change y limits.. so matplotlib does not need
- # to make change (to y) to keep aspect ratio
- xmin, xmax = ax.get_xbound()
- curYMin, curYMax = ax.get_ybound()
-
- newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin)
- xcenter = 0.5 * (xmin + xmax)
- ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange)
-
- if not self.isYAxisInverted():
- ax.set_ylim(ymin, ymax)
- else:
- ax.set_ylim(ymax, ymin)
-
- self._updateMarkers()
-
- # Graph axes
-
- def setXAxisTimeZone(self, tz):
- super(BackendMatplotlib, self).setXAxisTimeZone(tz)
-
- # Make new formatter and locator with the time zone.
- self.setXAxisTimeSeries(self.isXAxisTimeSeries())
-
- def isXAxisTimeSeries(self):
- return self._isXAxisTimeSeries
-
- def setXAxisTimeSeries(self, isTimeSeries):
- self._isXAxisTimeSeries = isTimeSeries
- if self._isXAxisTimeSeries:
- # We can't use a matplotlib.dates.DateFormatter because it expects
- # the data to be in datetimes. Silx works internally with
- # timestamps (floats).
- locator = NiceDateLocator(tz=self.getXAxisTimeZone())
- self.ax.xaxis.set_major_locator(locator)
- self.ax.xaxis.set_major_formatter(
- NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone()))
- else:
- try:
- scalarFormatter = ScalarFormatter(useOffset=False)
- except:
- _logger.warning('Cannot disabled axes offsets in %s ' %
- matplotlib.__version__)
- scalarFormatter = ScalarFormatter()
- self.ax.xaxis.set_major_formatter(scalarFormatter)
-
- def setXAxisLogarithmic(self, flag):
- # Workaround for matplotlib 2.1.0 when one tries to set an axis
- # to log scale with both limits <= 0
- # In this case a draw with positive limits is needed first
- if flag and self._matplotlibVersion >= _parse_version('2.1.0'):
- xlim = self.ax.get_xlim()
- if xlim[0] <= 0 and xlim[1] <= 0:
- self.ax.set_xlim(1, 10)
- self.draw()
-
- self.ax2.set_xscale('log' if flag else 'linear')
- self.ax.set_xscale('log' if flag else 'linear')
-
- def setYAxisLogarithmic(self, flag):
- # Workaround for matplotlib 2.0 issue with negative bounds
- # before switching to log scale
- if flag and self._matplotlibVersion >= _parse_version('2.0.0'):
- redraw = False
- for axis, dataRangeIndex in ((self.ax, 1), (self.ax2, 2)):
- ylim = axis.get_ylim()
- if ylim[0] <= 0 or ylim[1] <= 0:
- dataRange = self._plot.getDataRange()[dataRangeIndex]
- if dataRange is None:
- dataRange = 1, 100 # Fallback
- axis.set_ylim(*dataRange)
- redraw = True
- if redraw:
- self.draw()
-
- self.ax2.set_yscale('log' if flag else 'linear')
- self.ax.set_yscale('log' if flag else 'linear')
-
- def setYAxisInverted(self, flag):
- if self.ax.yaxis_inverted() != bool(flag):
- self.ax.invert_yaxis()
- self._updateMarkers()
-
- def isYAxisInverted(self):
- return self.ax.yaxis_inverted()
-
- def isKeepDataAspectRatio(self):
- return self.ax.get_aspect() in (1.0, 'equal')
-
- def setKeepDataAspectRatio(self, flag):
- self.ax.set_aspect(1.0 if flag else 'auto')
- self.ax2.set_aspect(1.0 if flag else 'auto')
-
- def setGraphGrid(self, which):
- self.ax.grid(False, which='both') # Disable all grid first
- if which is not None:
- self.ax.grid(True, which=which)
-
- # Data <-> Pixel coordinates conversion
-
- def _getDevicePixelRatio(self) -> float:
- """Compatibility wrapper for devicePixelRatioF"""
- return 1.
-
- def _mplToQtPosition(self, x: float, y: float) -> Tuple[float, float]:
- """Convert matplotlib "display" space coord to Qt widget logical pixel
- """
- ratio = self._getDevicePixelRatio()
- # Convert from matplotlib origin (bottom) to Qt origin (top)
- # and apply device pixel ratio
- return x / ratio, (self.fig.get_window_extent().height - y) / ratio
-
- def _qtToMplPosition(self, x: float, y: float) -> Tuple[float, float]:
- """Convert Qt widget logical pixel to matplotlib "display" space coord
- """
- ratio = self._getDevicePixelRatio()
- # Apply device pixel ration and
- # convert from Qt origin (top) to matplotlib origin (bottom)
- return x * ratio, self.fig.get_window_extent().height - (y * ratio)
-
- def dataToPixel(self, x, y, axis):
- ax = self.ax2 if axis == "right" else self.ax
- displayPos = ax.transData.transform_point((x, y)).transpose()
- return self._mplToQtPosition(*displayPos)
-
- def pixelToData(self, x, y, axis):
- ax = self.ax2 if axis == "right" else self.ax
- displayPos = self._qtToMplPosition(x, y)
- return tuple(ax.transData.inverted().transform_point(displayPos))
-
- def getPlotBoundsInPixels(self):
- bbox = self.ax.get_window_extent()
- # Warning this is not returning int...
- ratio = self._getDevicePixelRatio()
- return tuple(int(value / ratio) for value in (
- bbox.xmin,
- self.fig.get_window_extent().height - bbox.ymax,
- bbox.width,
- bbox.height))
-
- def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
- width, height = 1. - left - right, 1. - top - bottom
- position = left, bottom, width, height
-
- # Toggle display of axes and viewbox rect
- isFrameOn = position != (0., 0., 1., 1.)
- self.ax.set_frame_on(isFrameOn)
- self.ax2.set_frame_on(isFrameOn)
-
- self.ax.set_position(position)
- self.ax2.set_position(position)
-
- self._synchronizeBackgroundColors()
- self._synchronizeForegroundColors()
- self._plot._setDirtyPlot()
-
- def _synchronizeBackgroundColors(self):
- backgroundColor = self._plot.getBackgroundColor().getRgbF()
-
- dataBackgroundColor = self._plot.getDataBackgroundColor()
- if dataBackgroundColor.isValid():
- dataBackgroundColor = dataBackgroundColor.getRgbF()
- else:
- dataBackgroundColor = backgroundColor
-
- if self.ax.get_frame_on():
- self.fig.patch.set_facecolor(backgroundColor)
- if self._matplotlibVersion < _parse_version('2'):
- self.ax.set_axis_bgcolor(dataBackgroundColor)
- else:
- self.ax.set_facecolor(dataBackgroundColor)
- else:
- self.fig.patch.set_facecolor(dataBackgroundColor)
-
- def _synchronizeForegroundColors(self):
- foregroundColor = self._plot.getForegroundColor().getRgbF()
-
- gridColor = self._plot.getGridColor()
- if gridColor.isValid():
- gridColor = gridColor.getRgbF()
- else:
- gridColor = foregroundColor
-
- for axes in (self.ax, self.ax2):
- if axes.get_frame_on():
- axes.spines['bottom'].set_color(foregroundColor)
- axes.spines['top'].set_color(foregroundColor)
- axes.spines['right'].set_color(foregroundColor)
- axes.spines['left'].set_color(foregroundColor)
- axes.tick_params(axis='x', colors=foregroundColor)
- axes.tick_params(axis='y', colors=foregroundColor)
- axes.yaxis.label.set_color(foregroundColor)
- axes.xaxis.label.set_color(foregroundColor)
- axes.title.set_color(foregroundColor)
-
- for line in axes.get_xgridlines():
- line.set_color(gridColor)
-
- for line in axes.get_ygridlines():
- line.set_color(gridColor)
- # axes.grid().set_markeredgecolor(gridColor)
-
- def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
- self._synchronizeBackgroundColors()
-
- def setForegroundColors(self, foregroundColor, gridColor):
- self._synchronizeForegroundColors()
-
-
-class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
- """QWidget matplotlib backend using a QtAgg canvas.
-
- It adds fast overlay drawing and mouse event management.
- """
-
- _sigPostRedisplay = qt.Signal()
- """Signal handling automatic asynchronous replot"""
-
- def __init__(self, plot, parent=None):
- BackendMatplotlib.__init__(self, plot, parent)
- FigureCanvasQTAgg.__init__(self, self.fig)
- self.setParent(parent)
-
- self._limitsBeforeResize = None
-
- FigureCanvasQTAgg.setSizePolicy(
- self, qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
- FigureCanvasQTAgg.updateGeometry(self)
-
- # Make postRedisplay asynchronous using Qt signal
- self._sigPostRedisplay.connect(
- super(BackendMatplotlibQt, self).postRedisplay,
- qt.Qt.QueuedConnection)
-
- self._picked = None
-
- self.mpl_connect('button_press_event', self._onMousePress)
- self.mpl_connect('button_release_event', self._onMouseRelease)
- self.mpl_connect('motion_notify_event', self._onMouseMove)
- self.mpl_connect('scroll_event', self._onMouseWheel)
-
- def postRedisplay(self):
- self._sigPostRedisplay.emit()
-
- def _getDevicePixelRatio(self) -> float:
- """Compatibility wrapper for devicePixelRatioF"""
- if hasattr(self, 'devicePixelRatioF'):
- ratio = self.devicePixelRatioF()
- else: # Qt < 5.6 compatibility
- ratio = float(self.devicePixelRatio())
- # Safety net: avoid returning 0
- return ratio if ratio != 0. else 1.
-
- # Mouse event forwarding
-
- _MPL_TO_PLOT_BUTTONS = {1: 'left', 2: 'middle', 3: 'right'}
-
- def _onMousePress(self, event):
- button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
- if button is not None:
- x, y = self._mplToQtPosition(event.x, event.y)
- self._plot.onMousePress(int(x), int(y), button)
-
- def _onMouseMove(self, event):
- x, y = self._mplToQtPosition(event.x, event.y)
- if self._graphCursor:
- position = self._plot.pixelToData(
- x, y, axis='left', check=True)
- lineh, linev = self._graphCursor
- if position is not None:
- linev.set_visible(True)
- linev.set_xdata((position[0], position[0]))
- lineh.set_visible(True)
- lineh.set_ydata((position[1], position[1]))
- self._plot._setDirtyPlot(overlayOnly=True)
- elif lineh.get_visible():
- lineh.set_visible(False)
- linev.set_visible(False)
- self._plot._setDirtyPlot(overlayOnly=True)
- # onMouseMove must trigger replot if dirty flag is raised
-
- self._plot.onMouseMove(int(x), int(y))
-
- def _onMouseRelease(self, event):
- button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
- if button is not None:
- x, y = self._mplToQtPosition(event.x, event.y)
- self._plot.onMouseRelease(int(x), int(y), button)
-
- def _onMouseWheel(self, event):
- x, y = self._mplToQtPosition(event.x, event.y)
- self._plot.onMouseWheel(int(x), int(y), event.step)
-
- def leaveEvent(self, event):
- """QWidget event handler"""
- try:
- plot = self._plot
- except RuntimeError:
- pass
- else:
- plot.onMouseLeaveWidget()
-
- # picking
-
- def pickItem(self, x, y, item):
- xDisplay, yDisplay = self._qtToMplPosition(x, y)
- mouseEvent = MouseEvent(
- 'button_press_event', self, int(xDisplay), int(yDisplay))
- # Override axes and data position with the axes
- mouseEvent.inaxes = item.axes
- mouseEvent.xdata, mouseEvent.ydata = self.pixelToData(
- x, y, axis='left' if item.axes is self.ax else 'right')
- picked, info = item.contains(mouseEvent)
-
- if not picked:
- return None
-
- elif isinstance(item, TriMesh):
- # Convert selected triangle to data point indices
- triangulation = item._triangulation
- indices = triangulation.get_masked_triangles()[info['ind'][0]]
-
- # Sort picked triangle points by distance to mouse
- # from furthest to closest to put closest point last
- # This is to be somewhat consistent with last scatter point
- # being the top one.
- xdata, ydata = self.pixelToData(x, y, axis='left')
- dists = ((triangulation.x[indices] - xdata) ** 2 +
- (triangulation.y[indices] - ydata) ** 2)
- return indices[numpy.flip(numpy.argsort(dists), axis=0)]
-
- else: # Returns indices if any
- return info.get('ind', ())
-
- # replot control
-
- def resizeEvent(self, event):
- # Store current limits
- self._limitsBeforeResize = (
- self.ax.get_xbound(), self.ax.get_ybound(), self.ax2.get_ybound())
-
- FigureCanvasQTAgg.resizeEvent(self, event)
- if self.isKeepDataAspectRatio() or self._hasOverlays():
- # This is needed with matplotlib 1.5.x and 2.0.x
- self._plot._setDirtyPlot()
-
- def draw(self):
- """Overload draw
-
- It performs a full redraw (including overlays) of the plot.
- It also resets background and emit limits changed signal.
-
- This is directly called by matplotlib for widget resize.
- """
- self.updateZOrder()
-
- # Starting with mpl 2.1.0, toggling autoscale raises a ValueError
- # in some situations. See #1081, #1136, #1163,
- if self._matplotlibVersion >= _parse_version("2.0.0"):
- try:
- FigureCanvasQTAgg.draw(self)
- except ValueError as err:
- _logger.debug(
- "ValueError caught while calling FigureCanvasQTAgg.draw: "
- "'%s'", err)
- else:
- FigureCanvasQTAgg.draw(self)
-
- if self._hasOverlays():
- # Save background
- self._background = self.copy_from_bbox(self.fig.bbox)
- else:
- self._background = None # Reset background
-
- # Check if limits changed due to a resize of the widget
- if self._limitsBeforeResize is not None:
- xLimits, yLimits, yRightLimits = self._limitsBeforeResize
- self._limitsBeforeResize = None
-
- if (xLimits != self.ax.get_xbound() or
- yLimits != self.ax.get_ybound()):
- self._updateMarkers()
-
- if xLimits != self.ax.get_xbound():
- self._plot.getXAxis()._emitLimitsChanged()
- if yLimits != self.ax.get_ybound():
- self._plot.getYAxis(axis='left')._emitLimitsChanged()
- if yRightLimits != self.ax2.get_ybound():
- self._plot.getYAxis(axis='right')._emitLimitsChanged()
-
- self._drawOverlays()
-
- def replot(self):
- BackendMatplotlib.replot(self)
-
- dirtyFlag = self._plot._getDirtyPlot()
-
- if dirtyFlag == 'overlay':
- # Only redraw overlays using fast rendering path
- if self._background is None:
- self._background = self.copy_from_bbox(self.fig.bbox)
- self.restore_region(self._background)
- self._drawOverlays()
- self.blit(self.fig.bbox)
-
- elif dirtyFlag: # Need full redraw
- self.draw()
-
- # Workaround issue of rendering overlays with some matplotlib versions
- if (_parse_version('1.5') <= self._matplotlibVersion < _parse_version('2.1') and
- not hasattr(self, '_firstReplot')):
- self._firstReplot = False
- if self._hasOverlays():
- qt.QTimer.singleShot(0, self.draw) # Request async draw
-
- # cursor
-
- _QT_CURSORS = {
- BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
- BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
- BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
- BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
- BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
- }
-
- def setGraphCursorShape(self, cursor):
- if cursor is None:
- FigureCanvasQTAgg.unsetCursor(self)
- else:
- cursor = self._QT_CURSORS[cursor]
- FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor))
diff --git a/silx/gui/plot/backends/BackendOpenGL.py b/silx/gui/plot/backends/BackendOpenGL.py
deleted file mode 100755
index 6fde9df..0000000
--- a/silx/gui/plot/backends/BackendOpenGL.py
+++ /dev/null
@@ -1,1420 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""OpenGL Plot backend."""
-
-from __future__ import division
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "21/12/2018"
-
-import logging
-import weakref
-
-import numpy
-
-from .. import items
-from .._utils import FLOAT32_MINPOS
-from . import BackendBase
-from ... import colors
-from ... import qt
-
-from ..._glutils import gl
-from ... import _glutils as glu
-from . import glutils
-from .glutils.PlotImageFile import saveImageToFile
-
-_logger = logging.getLogger(__name__)
-
-
-# TODO idea: BackendQtMixIn class to share code between mpl and gl
-# TODO check if OpenGL is available
-# TODO make an off-screen mesa backend
-
-# Content #####################################################################
-
-class _ShapeItem(dict):
- def __init__(self, x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor):
- super(_ShapeItem, self).__init__()
-
- if shape not in ('polygon', 'rectangle', 'line',
- 'vline', 'hline', 'polylines'):
- raise NotImplementedError("Unsupported shape {0}".format(shape))
-
- x = numpy.array(x, copy=False)
- y = numpy.array(y, copy=False)
-
- if shape == 'rectangle':
- xMin, xMax = x
- x = numpy.array((xMin, xMin, xMax, xMax))
- yMin, yMax = y
- y = numpy.array((yMin, yMax, yMax, yMin))
-
- # Ignore fill for polylines to mimic matplotlib
- fill = fill if shape != 'polylines' else False
-
- self.update({
- 'shape': shape,
- 'color': colors.rgba(color),
- 'fill': 'hatch' if fill else None,
- 'x': x,
- 'y': y,
- 'linestyle': linestyle,
- 'linewidth': linewidth,
- 'linebgcolor': linebgcolor,
- })
-
-
-class _MarkerItem(dict):
- def __init__(self, x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis):
- super(_MarkerItem, self).__init__()
-
- if symbol is None:
- symbol = '+'
-
- # Apply constraint to provided position
- isConstraint = (constraint is not None and
- x is not None and y is not None)
- if isConstraint:
- x, y = constraint(x, y)
-
- self.update({
- 'x': x,
- 'y': y,
- 'text': text,
- 'color': colors.rgba(color),
- 'constraint': constraint if isConstraint else None,
- 'symbol': symbol,
- 'linestyle': linestyle,
- 'linewidth': linewidth,
- 'yaxis': yaxis,
- })
-
-
-# shaders #####################################################################
-
-_baseVertShd = """
- attribute vec2 position;
- uniform mat4 matrix;
- uniform bvec2 isLog;
-
- const float oneOverLog10 = 0.43429448190325176;
-
- void main(void) {
- vec2 posTransformed = position;
- if (isLog.x) {
- posTransformed.x = oneOverLog10 * log(position.x);
- }
- if (isLog.y) {
- posTransformed.y = oneOverLog10 * log(position.y);
- }
- gl_Position = matrix * vec4(posTransformed, 0.0, 1.0);
- }
- """
-
-_baseFragShd = """
- uniform vec4 color;
- uniform int hatchStep;
- uniform float tickLen;
-
- void main(void) {
- if (tickLen != 0.) {
- if (mod((gl_FragCoord.x + gl_FragCoord.y) / tickLen, 2.) < 1.) {
- gl_FragColor = color;
- } else {
- discard;
- }
- } else if (hatchStep == 0 ||
- mod(gl_FragCoord.x - gl_FragCoord.y, float(hatchStep)) == 0.) {
- gl_FragColor = color;
- } else {
- discard;
- }
- }
- """
-
-_texVertShd = """
- attribute vec2 position;
- attribute vec2 texCoords;
- uniform mat4 matrix;
-
- varying vec2 coords;
-
- void main(void) {
- gl_Position = matrix * vec4(position, 0.0, 1.0);
- coords = texCoords;
- }
- """
-
-_texFragShd = """
- uniform sampler2D tex;
-
- varying vec2 coords;
-
- void main(void) {
- gl_FragColor = texture2D(tex, coords);
- gl_FragColor.a = 1.0;
- }
- """
-
-# BackendOpenGL ###############################################################
-
-
-class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
- """OpenGL-based Plot backend.
-
- WARNINGS:
- Unless stated otherwise, this API is NOT thread-safe and MUST be
- called from the main thread.
- When numpy arrays are passed as arguments to the API (through
- :func:`addCurve` and :func:`addImage`), they are copied only if
- required.
- So, the caller should not modify these arrays afterwards.
- """
-
- _sigPostRedisplay = qt.Signal()
- """Signal handling automatic asynchronous replot"""
-
- def __init__(self, plot, parent=None, f=qt.Qt.WindowFlags()):
- glu.OpenGLWidget.__init__(self, parent,
- alphaBufferSize=8,
- depthBufferSize=0,
- stencilBufferSize=0,
- version=(2, 1),
- f=f)
- BackendBase.BackendBase.__init__(self, plot, parent)
-
- self._backgroundColor = 1., 1., 1., 1.
- self._dataBackgroundColor = 1., 1., 1., 1.
-
- self.matScreenProj = glutils.mat4Identity()
-
- self._progBase = glu.Program(
- _baseVertShd, _baseFragShd, attrib0='position')
- self._progTex = glu.Program(
- _texVertShd, _texFragShd, attrib0='position')
- self._plotFBOs = weakref.WeakKeyDictionary()
-
- self._keepDataAspectRatio = False
-
- self._crosshairCursor = None
- self._mousePosInPixels = None
-
- self._glGarbageCollector = []
-
- self._plotFrame = glutils.GLPlotFrame2D(
- foregroundColor=(0., 0., 0., 1.),
- gridColor=(.7, .7, .7, 1.),
- marginRatios=(.15, .1, .1, .15))
- self._plotFrame.size = ( # Init size with size int
- int(self.getDevicePixelRatio() * 640),
- int(self.getDevicePixelRatio() * 480))
-
- # Make postRedisplay asynchronous using Qt signal
- self._sigPostRedisplay.connect(
- super(BackendOpenGL, self).postRedisplay,
- qt.Qt.QueuedConnection)
-
- self.setAutoFillBackground(False)
- self.setMouseTracking(True)
-
- # QWidget
-
- _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'}
-
- def sizeHint(self):
- return qt.QSize(8 * 80, 6 * 80) # Mimic MatplotlibBackend
-
- def mousePressEvent(self, event):
- if event.button() not in self._MOUSE_BTNS:
- return super(BackendOpenGL, self).mousePressEvent(event)
- self._plot.onMousePress(
- event.x(), event.y(), self._MOUSE_BTNS[event.button()])
- event.accept()
-
- def mouseMoveEvent(self, event):
- qtPos = event.x(), event.y()
-
- previousMousePosInPixels = self._mousePosInPixels
- if qtPos == self._mouseInPlotArea(*qtPos):
- devicePixelRatio = self.getDevicePixelRatio()
- devicePos = qtPos[0] * devicePixelRatio, qtPos[1] * devicePixelRatio
- self._mousePosInPixels = devicePos # Mouse in plot area
- else:
- self._mousePosInPixels = None # Mouse outside plot area
-
- if (self._crosshairCursor is not None and
- previousMousePosInPixels != self._mousePosInPixels):
- # Avoid replot when cursor remains outside plot area
- self._plot._setDirtyPlot(overlayOnly=True)
-
- self._plot.onMouseMove(*qtPos)
- event.accept()
-
- def mouseReleaseEvent(self, event):
- if event.button() not in self._MOUSE_BTNS:
- return super(BackendOpenGL, self).mouseReleaseEvent(event)
- self._plot.onMouseRelease(
- event.x(), event.y(), self._MOUSE_BTNS[event.button()])
- event.accept()
-
- def wheelEvent(self, event):
- if hasattr(event, 'angleDelta'): # Qt 5
- delta = event.angleDelta().y()
- else: # Qt 4 support
- delta = event.delta()
- angleInDegrees = delta / 8.
- self._plot.onMouseWheel(event.x(), event.y(), angleInDegrees)
- event.accept()
-
- def leaveEvent(self, _):
- self._plot.onMouseLeaveWidget()
-
- # OpenGLWidget API
-
- def initializeGL(self):
- gl.testGL()
-
- gl.glClearStencil(0)
-
- gl.glEnable(gl.GL_BLEND)
- # gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)
- gl.glBlendFuncSeparate(gl.GL_SRC_ALPHA,
- gl.GL_ONE_MINUS_SRC_ALPHA,
- gl.GL_ONE,
- gl.GL_ONE)
-
- # For lines
- gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
-
- # For points
- gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
- gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
- # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
-
- def _paintDirectGL(self):
- self._renderPlotAreaGL()
- self._plotFrame.render()
- self._renderOverlayGL()
-
- def _paintFBOGL(self):
- context = glu.Context.getCurrent()
- plotFBOTex = self._plotFBOs.get(context)
- if (self._plot._getDirtyPlot() or self._plotFrame.isDirty or
- plotFBOTex is None):
- self._plotVertices = (
- # Vertex coordinates
- numpy.array(((-1., -1.), (1., -1.), (-1., 1.), (1., 1.)),
- dtype=numpy.float32),
- # Texture coordinates
- numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)),
- dtype=numpy.float32))
- if plotFBOTex is None or \
- plotFBOTex.shape[1] != self._plotFrame.size[0] or \
- plotFBOTex.shape[0] != self._plotFrame.size[1]:
- if plotFBOTex is not None:
- plotFBOTex.discard()
- plotFBOTex = glu.FramebufferTexture(
- gl.GL_RGBA,
- shape=(self._plotFrame.size[1],
- self._plotFrame.size[0]),
- minFilter=gl.GL_NEAREST,
- magFilter=gl.GL_NEAREST,
- wrap=(gl.GL_CLAMP_TO_EDGE,
- gl.GL_CLAMP_TO_EDGE))
- self._plotFBOs[context] = plotFBOTex
-
- with plotFBOTex:
- gl.glClearColor(*self._backgroundColor)
- gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
- self._renderPlotAreaGL()
- self._plotFrame.render()
-
- # Render plot in screen coords
- gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
-
- self._progTex.use()
- texUnit = 0
-
- gl.glUniform1i(self._progTex.uniforms['tex'], texUnit)
- gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE,
- glutils.mat4Identity().astype(numpy.float32))
-
- gl.glEnableVertexAttribArray(self._progTex.attributes['position'])
- gl.glVertexAttribPointer(self._progTex.attributes['position'],
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0,
- self._plotVertices[0])
-
- gl.glEnableVertexAttribArray(self._progTex.attributes['texCoords'])
- gl.glVertexAttribPointer(self._progTex.attributes['texCoords'],
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0,
- self._plotVertices[1])
-
- with plotFBOTex.texture:
- gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices[0]))
-
- self._renderOverlayGL()
-
- def paintGL(self):
- with glu.Context.current(self.context()):
- # Release OpenGL resources
- for item in self._glGarbageCollector:
- item.discard()
- self._glGarbageCollector = []
-
- gl.glClearColor(*self._backgroundColor)
- gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
-
- # Check if window is large enough
- if self._plotFrame.plotSize <= (2, 2):
- return
-
- # Sync plot frame with window
- self._plotFrame.devicePixelRatio = self.getDevicePixelRatio()
- # self._paintDirectGL()
- self._paintFBOGL()
-
- def _renderItems(self, overlay=False):
- """Render items according to :class:`PlotWidget` order
-
- Note: Scissor test should already be set.
-
- :param bool overlay:
- False (the default) to render item that are not overlays.
- True to render items that are overlays.
- """
- # Values that are often used
- plotWidth, plotHeight = self._plotFrame.plotSize
- isXLog = self._plotFrame.xAxis.isLog
- isYLog = self._plotFrame.yAxis.isLog
- isYInverted = self._plotFrame.isYAxisInverted
-
- # Used by marker rendering
- labels = []
- pixelOffset = 3
-
- context = glutils.RenderContext(
- isXLog=isXLog, isYLog=isYLog, dpi=self.getDotsPerInch())
-
- for plotItem in self.getItemsFromBackToFront(
- condition=lambda i: i.isVisible() and i.isOverlay() == overlay):
- if plotItem._backendRenderer is None:
- continue
-
- item = plotItem._backendRenderer
-
- if isinstance(item, glutils.GLPlotItem): # Render data items
- gl.glViewport(self._plotFrame.margins.left,
- self._plotFrame.margins.bottom,
- plotWidth, plotHeight)
- # Set matrix
- if item.yaxis == 'right':
- context.matrix = self._plotFrame.transformedDataY2ProjMat
- else:
- context.matrix = self._plotFrame.transformedDataProjMat
- item.render(context)
-
- elif isinstance(item, _ShapeItem): # Render shape items
- gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
-
- if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or
- (isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)):
- # Ignore items <= 0. on log axes
- continue
-
- if item['shape'] == 'hline':
- width = self._plotFrame.size[0]
- _, yPixel = self._plotFrame.dataToPixel(
- 0.5 * sum(self._plotFrame.dataRanges[0]),
- item['y'],
- axis='left')
- subShapes = [numpy.array(((0., yPixel), (width, yPixel)),
- dtype=numpy.float32)]
-
- elif item['shape'] == 'vline':
- xPixel, _ = self._plotFrame.dataToPixel(
- item['x'],
- 0.5 * sum(self._plotFrame.dataRanges[1]),
- axis='left')
- height = self._plotFrame.size[1]
- subShapes = [numpy.array(((xPixel, 0), (xPixel, height)),
- dtype=numpy.float32)]
-
- else:
- # Split sub-shapes at not finite values
- splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
- numpy.isfinite(item['x']), numpy.isfinite(item['y']))))[0]
- splits = numpy.concatenate(([-1], splits, [len(item['x'])]))
- subShapes = []
- for begin, end in zip(splits[:-1] + 1, splits[1:]):
- if end > begin:
- subShapes.append(numpy.array([
- self._plotFrame.dataToPixel(x, y, axis='left')
- for (x, y) in zip(item['x'][begin:end], item['y'][begin:end])]))
-
- for points in subShapes: # Draw each sub-shape
- # Draw the fill
- if (item['fill'] is not None and
- item['shape'] not in ('hline', 'vline')):
- self._progBase.use()
- gl.glUniformMatrix4fv(
- self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
- self.matScreenProj.astype(numpy.float32))
- gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
- gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
-
- shape2D = glutils.FilledShape2D(
- points, style=item['fill'], color=item['color'])
- shape2D.render(
- posAttrib=self._progBase.attributes['position'],
- colorUnif=self._progBase.uniforms['color'],
- hatchStepUnif=self._progBase.uniforms['hatchStep'])
-
- # Draw the stroke
- if item['linestyle'] not in ('', ' ', None):
- if item['shape'] != 'polylines':
- # close the polyline
- points = numpy.append(points,
- numpy.atleast_2d(points[0]), axis=0)
-
- lines = glutils.GLLines2D(
- points[:, 0], points[:, 1],
- style=item['linestyle'],
- color=item['color'],
- dash2ndColor=item['linebgcolor'],
- width=item['linewidth'])
- context.matrix = self.matScreenProj
- lines.render(context)
-
- elif isinstance(item, _MarkerItem):
- gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
-
- xCoord, yCoord, yAxis = item['x'], item['y'], item['yaxis']
-
- if ((isXLog and xCoord is not None and xCoord <= 0) or
- (isYLog and yCoord is not None and yCoord <= 0)):
- # Do not render markers with negative coords on log axis
- continue
-
- if xCoord is None or yCoord is None:
- if xCoord is None: # Horizontal line in data space
- pixelPos = self._plotFrame.dataToPixel(
- 0.5 * sum(self._plotFrame.dataRanges[0]),
- yCoord,
- axis=yAxis)
-
- if item['text'] is not None:
- x = self._plotFrame.size[0] - \
- self._plotFrame.margins.right - pixelOffset
- y = pixelPos[1] - pixelOffset
- label = glutils.Text2D(
- item['text'], x, y,
- color=item['color'],
- bgColor=(1., 1., 1., 0.5),
- align=glutils.RIGHT,
- valign=glutils.BOTTOM,
- devicePixelRatio=self.getDevicePixelRatio())
- labels.append(label)
-
- width = self._plotFrame.size[0]
- lines = glutils.GLLines2D(
- (0, width), (pixelPos[1], pixelPos[1]),
- style=item['linestyle'],
- color=item['color'],
- width=item['linewidth'])
- context.matrix = self.matScreenProj
- lines.render(context)
-
- else: # yCoord is None: vertical line in data space
- yRange = self._plotFrame.dataRanges[1 if yAxis == 'left' else 2]
- pixelPos = self._plotFrame.dataToPixel(
- xCoord, 0.5 * sum(yRange), axis=yAxis)
-
- if item['text'] is not None:
- x = pixelPos[0] + pixelOffset
- y = self._plotFrame.margins.top + pixelOffset
- label = glutils.Text2D(
- item['text'], x, y,
- color=item['color'],
- bgColor=(1., 1., 1., 0.5),
- align=glutils.LEFT,
- valign=glutils.TOP,
- devicePixelRatio=self.getDevicePixelRatio())
- labels.append(label)
-
- height = self._plotFrame.size[1]
- lines = glutils.GLLines2D(
- (pixelPos[0], pixelPos[0]), (0, height),
- style=item['linestyle'],
- color=item['color'],
- width=item['linewidth'])
- context.matrix = self.matScreenProj
- lines.render(context)
-
- else:
- xmin, xmax = self._plot.getXAxis().getLimits()
- ymin, ymax = self._plot.getYAxis(axis=yAxis).getLimits()
- if not xmin < xCoord < xmax or not ymin < yCoord < ymax:
- # Do not render markers outside visible plot area
- continue
- pixelPos = self._plotFrame.dataToPixel(
- xCoord, yCoord, axis=yAxis)
-
- if isYInverted:
- valign = glutils.BOTTOM
- vPixelOffset = -pixelOffset
- else:
- valign = glutils.TOP
- vPixelOffset = pixelOffset
-
- if item['text'] is not None:
- x = pixelPos[0] + pixelOffset
- y = pixelPos[1] + vPixelOffset
- label = glutils.Text2D(
- item['text'], x, y,
- color=item['color'],
- bgColor=(1., 1., 1., 0.5),
- align=glutils.LEFT,
- valign=valign,
- devicePixelRatio=self.getDevicePixelRatio())
- labels.append(label)
-
- # For now simple implementation: using a curve for each marker
- # Should pack all markers to a single set of points
- markerCurve = glutils.GLPlotCurve2D(
- numpy.array((pixelPos[0],), dtype=numpy.float64),
- numpy.array((pixelPos[1],), dtype=numpy.float64),
- marker=item['symbol'],
- markerColor=item['color'],
- markerSize=11)
-
- context = glutils.RenderContext(
- matrix=self.matScreenProj,
- isXLog=False,
- isYLog=False,
- dpi=self.getDotsPerInch())
- markerCurve.render(context)
- markerCurve.discard()
-
- else:
- _logger.error('Unsupported item: %s', str(item))
- continue
-
- # Render marker labels
- gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
- for label in labels:
- label.render(self.matScreenProj)
-
- def _renderOverlayGL(self):
- """Render overlay layer: overlay items and crosshair."""
- plotWidth, plotHeight = self._plotFrame.plotSize
-
- # Scissor to plot area
- gl.glScissor(self._plotFrame.margins.left,
- self._plotFrame.margins.bottom,
- plotWidth, plotHeight)
- gl.glEnable(gl.GL_SCISSOR_TEST)
-
- self._renderItems(overlay=True)
-
- # Render crosshair cursor
- if self._crosshairCursor is not None and self._mousePosInPixels is not None:
- self._progBase.use()
- gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
- gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
- posAttrib = self._progBase.attributes['position']
- matrixUnif = self._progBase.uniforms['matrix']
- colorUnif = self._progBase.uniforms['color']
- hatchStepUnif = self._progBase.uniforms['hatchStep']
-
- gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
-
- gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE,
- self.matScreenProj.astype(numpy.float32))
-
- color, lineWidth = self._crosshairCursor
- gl.glUniform4f(colorUnif, *color)
- gl.glUniform1i(hatchStepUnif, 0)
-
- xPixel, yPixel = self._mousePosInPixels
- xPixel, yPixel = xPixel + 0.5, yPixel + 0.5
- vertices = numpy.array(((0., yPixel),
- (self._plotFrame.size[0], yPixel),
- (xPixel, 0.),
- (xPixel, self._plotFrame.size[1])),
- dtype=numpy.float32)
-
- gl.glEnableVertexAttribArray(posAttrib)
- gl.glVertexAttribPointer(posAttrib,
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, vertices)
- gl.glLineWidth(lineWidth)
- gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
-
- gl.glDisable(gl.GL_SCISSOR_TEST)
-
- def _renderPlotAreaGL(self):
- """Render base layer of plot area.
-
- It renders the background, grid and items except overlays
- """
- plotWidth, plotHeight = self._plotFrame.plotSize
-
- gl.glScissor(self._plotFrame.margins.left,
- self._plotFrame.margins.bottom,
- plotWidth, plotHeight)
- gl.glEnable(gl.GL_SCISSOR_TEST)
-
- if self._dataBackgroundColor != self._backgroundColor:
- gl.glClearColor(*self._dataBackgroundColor)
- gl.glClear(gl.GL_COLOR_BUFFER_BIT)
-
- self._plotFrame.renderGrid()
-
- # Matrix
- trBounds = self._plotFrame.transformedDataRanges
- if trBounds.x[0] != trBounds.x[1] and trBounds.y[0] != trBounds.y[1]:
- # Do rendering of items
- self._renderItems(overlay=False)
-
- gl.glDisable(gl.GL_SCISSOR_TEST)
-
- def resizeGL(self, width, height):
- if width == 0 or height == 0: # Do not resize
- return
-
- self._plotFrame.size = (
- int(self.getDevicePixelRatio() * width),
- int(self.getDevicePixelRatio() * height))
-
- self.matScreenProj = glutils.mat4Ortho(
- 0, self._plotFrame.size[0],
- self._plotFrame.size[1], 0,
- 1, -1)
-
- # Store current ranges
- previousXRange = self.getGraphXLimits()
- previousYRange = self.getGraphYLimits(axis='left')
- previousYRightRange = self.getGraphYLimits(axis='right')
-
- (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
- self._plotFrame.dataRanges
- self.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
-
- # If plot range has changed, then emit signal
- if previousXRange != self.getGraphXLimits():
- self._plot.getXAxis()._emitLimitsChanged()
- if previousYRange != self.getGraphYLimits(axis='left'):
- self._plot.getYAxis(axis='left')._emitLimitsChanged()
- if previousYRightRange != self.getGraphYLimits(axis='right'):
- self._plot.getYAxis(axis='right')._emitLimitsChanged()
-
- # Add methods
-
- @staticmethod
- def _castArrayTo(v):
- """Returns best floating type to cast the array to.
-
- :param numpy.ndarray v: Array to cast
- :rtype: numpy.dtype
- :raise ValueError: If dtype is not supported
- """
- if numpy.issubdtype(v.dtype, numpy.floating):
- return numpy.float32 if v.itemsize <= 4 else numpy.float64
- elif numpy.issubdtype(v.dtype, numpy.integer):
- return numpy.float32 if v.itemsize <= 2 else numpy.float64
- else:
- raise ValueError('Unsupported data type')
-
- def addCurve(self, x, y,
- color, symbol, linewidth, linestyle,
- yaxis,
- xerror, yerror,
- fill, alpha, symbolsize, baseline):
- for parameter in (x, y, color, symbol, linewidth, linestyle,
- yaxis, fill, symbolsize):
- assert parameter is not None
- assert yaxis in ('left', 'right')
-
- # Convert input data
- x = numpy.array(x, copy=False)
- y = numpy.array(y, copy=False)
-
- # Check if float32 is enough
- if (self._castArrayTo(x) is numpy.float32 and
- self._castArrayTo(y) is numpy.float32):
- dtype = numpy.float32
- else:
- dtype = numpy.float64
-
- x = numpy.array(x, dtype=dtype, copy=False, order='C')
- y = numpy.array(y, dtype=dtype, copy=False, order='C')
-
- # Convert errors to float32
- if xerror is not None:
- xerror = numpy.array(
- xerror, dtype=numpy.float32, copy=False, order='C')
- if yerror is not None:
- yerror = numpy.array(
- yerror, dtype=numpy.float32, copy=False, order='C')
-
- # Handle axes log scale: convert data
-
- if self._plotFrame.xAxis.isLog:
- logX = numpy.log10(x)
-
- if xerror is not None:
- # Transform xerror so that
- # log10(x) +/- xerror' = log10(x +/- xerror)
- if hasattr(xerror, 'shape') and len(xerror.shape) == 2:
- xErrorMinus, xErrorPlus = xerror[0], xerror[1]
- else:
- xErrorMinus, xErrorPlus = xerror, xerror
- with numpy.errstate(divide='ignore', invalid='ignore'):
- # Ignore divide by zero, invalid value encountered in log10
- xErrorMinus = logX - numpy.log10(x - xErrorMinus)
- xErrorPlus = numpy.log10(x + xErrorPlus) - logX
- xerror = numpy.array((xErrorMinus, xErrorPlus),
- dtype=numpy.float32)
-
- x = logX
-
- isYLog = (yaxis == 'left' and self._plotFrame.yAxis.isLog) or (
- yaxis == 'right' and self._plotFrame.y2Axis.isLog)
-
- if isYLog:
- logY = numpy.log10(y)
-
- if yerror is not None:
- # Transform yerror so that
- # log10(y) +/- yerror' = log10(y +/- yerror)
- if hasattr(yerror, 'shape') and len(yerror.shape) == 2:
- yErrorMinus, yErrorPlus = yerror[0], yerror[1]
- else:
- yErrorMinus, yErrorPlus = yerror, yerror
- with numpy.errstate(divide='ignore', invalid='ignore'):
- # Ignore divide by zero, invalid value encountered in log10
- yErrorMinus = logY - numpy.log10(y - yErrorMinus)
- yErrorPlus = numpy.log10(y + yErrorPlus) - logY
- yerror = numpy.array((yErrorMinus, yErrorPlus),
- dtype=numpy.float32)
-
- y = logY
-
- # TODO check if need more filtering of error (e.g., clip to positive)
-
- # TODO check and improve this
- if (len(color) == 4 and
- type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
- color = numpy.array(color, dtype=numpy.float32) / 255.
-
- if isinstance(color, numpy.ndarray) and color.ndim == 2:
- colorArray = color
- color = None
- else:
- colorArray = None
- color = colors.rgba(color)
-
- if alpha < 1.: # Apply image transparency
- if colorArray is not None and colorArray.shape[1] == 4:
- # multiply alpha channel
- colorArray[:, 3] = colorArray[:, 3] * alpha
- if color is not None:
- color = color[0], color[1], color[2], color[3] * alpha
-
- fillColor = None
- if fill is True:
- fillColor = color
- curve = glutils.GLPlotCurve2D(
- x, y, colorArray,
- xError=xerror,
- yError=yerror,
- lineStyle=linestyle,
- lineColor=color,
- lineWidth=linewidth,
- marker=symbol,
- markerColor=color,
- markerSize=symbolsize,
- fillColor=fillColor,
- baseline=baseline,
- isYLog=isYLog)
- curve.yaxis = 'left' if yaxis is None else yaxis
-
- if yaxis == "right":
- self._plotFrame.isY2Axis = True
-
- return curve
-
- def addImage(self, data,
- origin, scale,
- colormap, alpha):
- for parameter in (data, origin, scale):
- assert parameter is not None
-
- if data.ndim == 2:
- # Ensure array is contiguous and eventually convert its type
- dtypes = [dtype for dtype in (
- numpy.float32, numpy.float16, numpy.uint8, numpy.uint16)
- if glu.isSupportedGLType(dtype)]
- if data.dtype in dtypes:
- data = numpy.array(data, copy=False, order='C')
- else:
- _logger.info(
- 'addImage: Convert %s data to float32', str(data.dtype))
- data = numpy.array(data, dtype=numpy.float32, order='C')
-
- normalization = colormap.getNormalization()
- if normalization in glutils.GLPlotColormap.SUPPORTED_NORMALIZATIONS:
- # Fast path applying colormap on the GPU
- cmapRange = colormap.getColormapRange(data=data)
- colormapLut = colormap.getNColors(nbColors=256)
- gamma = colormap.getGammaNormalizationParameter()
- nanColor = colors.rgba(colormap.getNaNColor())
-
- image = glutils.GLPlotColormap(
- data,
- origin,
- scale,
- colormapLut,
- normalization,
- gamma,
- cmapRange,
- alpha,
- nanColor)
-
- else: # Fallback applying colormap on CPU
- rgba = colormap.applyToData(data)
- image = glutils.GLPlotRGBAImage(rgba, origin, scale, alpha)
-
- elif len(data.shape) == 3:
- # For RGB, RGBA data
- assert data.shape[2] in (3, 4)
-
- if numpy.issubdtype(data.dtype, numpy.floating):
- data = numpy.array(data, dtype=numpy.float32, copy=False)
- elif data.dtype in [numpy.uint8, numpy.uint16]:
- pass
- elif numpy.issubdtype(data.dtype, numpy.integer):
- data = numpy.array(data, dtype=numpy.uint8, copy=False)
- else:
- raise ValueError('Unsupported data type')
-
- image = glutils.GLPlotRGBAImage(data, origin, scale, alpha)
-
- else:
- raise RuntimeError("Unsupported data shape {0}".format(data.shape))
-
- # TODO is this needed?
- if self._plotFrame.xAxis.isLog and image.xMin <= 0.:
- raise RuntimeError(
- 'Cannot add image with X <= 0 with X axis log scale')
- if self._plotFrame.yAxis.isLog and image.yMin <= 0.:
- raise RuntimeError(
- 'Cannot add image with Y <= 0 with Y axis log scale')
-
- return image
-
- def addTriangles(self, x, y, triangles,
- color, alpha):
- # Handle axes log scale: convert data
- if self._plotFrame.xAxis.isLog:
- x = numpy.log10(x)
- if self._plotFrame.yAxis.isLog:
- y = numpy.log10(y)
-
- triangles = glutils.GLPlotTriangles(x, y, color, triangles, alpha)
-
- return triangles
-
- def addShape(self, x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor):
- x = numpy.array(x, copy=False)
- y = numpy.array(y, copy=False)
-
- # TODO is this needed?
- if self._plotFrame.xAxis.isLog and x.min() <= 0.:
- raise RuntimeError(
- 'Cannot add item with X <= 0 with X axis log scale')
- if self._plotFrame.yAxis.isLog and y.min() <= 0.:
- raise RuntimeError(
- 'Cannot add item with Y <= 0 with Y axis log scale')
-
- return _ShapeItem(x, y, shape, color, fill, overlay,
- linestyle, linewidth, linebgcolor)
-
- def addMarker(self, x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis):
- return _MarkerItem(x, y, text, color,
- symbol, linestyle, linewidth, constraint, yaxis)
-
- # Remove methods
-
- def remove(self, item):
- if isinstance(item, glutils.GLPlotItem):
- if item.yaxis == 'right':
- # Check if some curves remains on the right Y axis
- y2AxisItems = (item for item in self._plot.getItems()
- if isinstance(item, items.YAxisMixIn) and
- item.getYAxis() == 'right')
- self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None
-
- if item.isInitialized():
- self._glGarbageCollector.append(item)
-
- elif isinstance(item, (_MarkerItem, _ShapeItem)):
- pass # No-op
-
- else:
- _logger.error('Unsupported item: %s', str(item))
-
- # Interaction methods
-
- _QT_CURSORS = {
- BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
- BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
- BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
- BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
- BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
- }
-
- def setGraphCursorShape(self, cursor):
- if cursor is None:
- super(BackendOpenGL, self).unsetCursor()
- else:
- cursor = self._QT_CURSORS[cursor]
- super(BackendOpenGL, self).setCursor(qt.QCursor(cursor))
-
- def setGraphCursor(self, flag, color, linewidth, linestyle):
- if linestyle != '-':
- _logger.warning(
- "BackendOpenGL.setGraphCursor linestyle parameter ignored")
-
- if flag:
- color = colors.rgba(color)
- crosshairCursor = color, linewidth
- else:
- crosshairCursor = None
-
- if crosshairCursor != self._crosshairCursor:
- self._crosshairCursor = crosshairCursor
-
- _PICK_OFFSET = 3 # Offset in pixel used for picking
-
- def _mouseInPlotArea(self, x, y):
- """Returns closest visible position in the plot.
-
- This is performed in Qt widget pixel, not device pixel.
-
- :param float x: X coordinate in Qt widget pixel
- :param float y: Y coordinate in Qt widget pixel
- :return: (x, y) closest point in the plot.
- :rtype: List[float]
- """
- left, top, width, height = self.getPlotBoundsInPixels()
- return (numpy.clip(x, left, left + width - 1), # TODO -1?
- numpy.clip(y, top, top + height - 1))
-
- def __pickCurves(self, item, x, y):
- """Perform picking on a curve item.
-
- :param GLPlotCurve2D item:
- :param float x: X position of the mouse in widget coordinates
- :param float y: Y position of the mouse in widget coordinates
- :return: List of indices of picked points or None if not picked
- :rtype: Union[List[int],None]
- """
- offset = self._PICK_OFFSET
- if item.marker is not None:
- # Convert markerSize from points to qt pixels
- qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
- size = item.markerSize / 72. * qtDpi
- offset = max(size / 2., offset)
- if item.lineStyle is not None:
- # Convert line width from points to qt pixels
- qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
- lineWidth = item.lineWidth / 72. * qtDpi
- offset = max(lineWidth / 2., offset)
-
- inAreaPos = self._mouseInPlotArea(x - offset, y - offset)
- dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
- axis=item.yaxis, check=True)
- if dataPos is None:
- return None
- xPick0, yPick0 = dataPos
-
- inAreaPos = self._mouseInPlotArea(x + offset, y + offset)
- dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
- axis=item.yaxis, check=True)
- if dataPos is None:
- return None
- xPick1, yPick1 = dataPos
-
- if xPick0 < xPick1:
- xPickMin, xPickMax = xPick0, xPick1
- else:
- xPickMin, xPickMax = xPick1, xPick0
-
- if yPick0 < yPick1:
- yPickMin, yPickMax = yPick0, yPick1
- else:
- yPickMin, yPickMax = yPick1, yPick0
-
- # Apply log scale if axis is log
- if self._plotFrame.xAxis.isLog:
- xPickMin = numpy.log10(xPickMin)
- xPickMax = numpy.log10(xPickMax)
-
- if (item.yaxis == 'left' and self._plotFrame.yAxis.isLog) or (
- item.yaxis == 'right' and self._plotFrame.y2Axis.isLog):
- yPickMin = numpy.log10(yPickMin)
- yPickMax = numpy.log10(yPickMax)
-
- return item.pick(xPickMin, yPickMin,
- xPickMax, yPickMax)
-
- def pickItem(self, x, y, item):
- # Picking is performed in Qt widget pixels not device pixels
- dataPos = self._plot.pixelToData(x, y, axis='left', check=True)
- if dataPos is None:
- return None # Outside plot area
-
- if item is None:
- _logger.error("No item provided for picking")
- return None
-
- # Pick markers
- if isinstance(item, _MarkerItem):
- yaxis = item['yaxis']
- pixelPos = self._plot.dataToPixel(
- item['x'], item['y'], axis=yaxis, check=False)
- if pixelPos is None:
- return None # negative coord on a log axis
-
- if item['x'] is None: # Horizontal line
- pt1 = self._plot.pixelToData(
- x, y - self._PICK_OFFSET, axis=yaxis, check=False)
- pt2 = self._plot.pixelToData(
- x, y + self._PICK_OFFSET, axis=yaxis, check=False)
- isPicked = (min(pt1[1], pt2[1]) <= item['y'] <=
- max(pt1[1], pt2[1]))
-
- elif item['y'] is None: # Vertical line
- pt1 = self._plot.pixelToData(
- x - self._PICK_OFFSET, y, axis=yaxis, check=False)
- pt2 = self._plot.pixelToData(
- x + self._PICK_OFFSET, y, axis=yaxis, check=False)
- isPicked = (min(pt1[0], pt2[0]) <= item['x'] <=
- max(pt1[0], pt2[0]))
-
- else:
- isPicked = (
- numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET and
- numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET)
-
- return (0,) if isPicked else None
-
- # Pick image, curve, triangles
- elif isinstance(item, glutils.GLPlotItem):
- if isinstance(item, glutils.GLPlotCurve2D):
- return self.__pickCurves(item, x, y)
- else:
- return item.pick(*dataPos) # Might be None
-
- # Update curve
-
- def setCurveColor(self, curve, color):
- pass # TODO
-
- # Misc.
-
- def getWidgetHandle(self):
- return self
-
- def postRedisplay(self):
- self._sigPostRedisplay.emit()
-
- def replot(self):
- self.update() # async redraw
- # self.repaint() # immediate redraw
-
- def saveGraph(self, fileName, fileFormat, dpi):
- if dpi is not None:
- _logger.warning("saveGraph ignores dpi parameter")
-
- if fileFormat not in ['png', 'ppm', 'svg', 'tiff']:
- raise NotImplementedError('Unsupported format: %s' % fileFormat)
-
- if not self.isValid():
- _logger.error('OpenGL 2.1 not available, cannot save OpenGL image')
- width, height = self._plotFrame.size
- data = numpy.zeros((height, width, 3), dtype=numpy.uint8)
- else:
- self.makeCurrent()
-
- data = numpy.empty(
- (self._plotFrame.size[1], self._plotFrame.size[0], 3),
- dtype=numpy.uint8, order='C')
-
- context = self.context()
- framebufferTexture = self._plotFBOs.get(context)
- if framebufferTexture is None:
- # Fallback, supports direct rendering mode: _paintDirectGL
- # might have issues as it can read on-screen framebuffer
- fboName = self.defaultFramebufferObject()
- width, height = self._plotFrame.size
- else:
- fboName = framebufferTexture.name
- height, width = framebufferTexture.shape
-
- previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
- gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fboName)
- gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
- gl.glReadPixels(0, 0, width, height,
- gl.GL_RGB, gl.GL_UNSIGNED_BYTE, data)
- gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer)
-
- # glReadPixels gives bottom to top,
- # while images are stored as top to bottom
- data = numpy.flipud(data)
-
- # fileName is either a file-like object or a str
- saveImageToFile(data, fileName, fileFormat)
-
- # Graph labels
-
- def setGraphTitle(self, title):
- self._plotFrame.title = title
-
- def setGraphXLabel(self, label):
- self._plotFrame.xAxis.title = label
-
- def setGraphYLabel(self, label, axis):
- if axis == 'left':
- self._plotFrame.yAxis.title = label
- else: # right axis
- self._plotFrame.y2Axis.title = label
-
- # Graph limits
-
- def _setDataRanges(self, xlim=None, ylim=None, y2lim=None):
- """Set the visible range of data in the plot frame.
-
- This clips the ranges to possible values (takes care of float32
- range + positive range for log).
- This also takes care of non-orthogonal axes.
-
- This should be moved to PlotFrame.
- """
- # Update axes range with a clipped range if too wide
- self._plotFrame.setDataRanges(xlim, ylim, y2lim)
-
- def _ensureAspectRatio(self, keepDim=None):
- """Update plot bounds in order to keep aspect ratio.
-
- Warning: keepDim on right Y axis is not implemented !
-
- :param str keepDim: The dimension to maintain: 'x', 'y' or None.
- If None (the default), the dimension with the largest range.
- """
- plotWidth, plotHeight = self._plotFrame.plotSize
- if plotWidth <= 2 or plotHeight <= 2:
- return
-
- if keepDim is None:
- ranges = self._plot.getDataRange()
- if (ranges.y is not None and
- ranges.x is not None and
- (ranges.y[1] - ranges.y[0]) != 0.):
- dataRatio = (ranges.x[1] - ranges.x[0]) / float(ranges.y[1] - ranges.y[0])
- plotRatio = plotWidth / float(plotHeight) # Test != 0 before
-
- keepDim = 'x' if dataRatio > plotRatio else 'y'
- else: # Limit case
- keepDim = 'x'
-
- (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
- self._plotFrame.dataRanges
- if keepDim == 'y':
- dataW = (yMax - yMin) * plotWidth / float(plotHeight)
- xCenter = 0.5 * (xMin + xMax)
- xMin = xCenter - 0.5 * dataW
- xMax = xCenter + 0.5 * dataW
- elif keepDim == 'x':
- dataH = (xMax - xMin) * plotHeight / float(plotWidth)
- yCenter = 0.5 * (yMin + yMax)
- yMin = yCenter - 0.5 * dataH
- yMax = yCenter + 0.5 * dataH
- y2Center = 0.5 * (y2Min + y2Max)
- y2Min = y2Center - 0.5 * dataH
- y2Max = y2Center + 0.5 * dataH
- else:
- raise RuntimeError('Unsupported dimension to keep: %s' % keepDim)
-
- # Update plot frame bounds
- self._setDataRanges(xlim=(xMin, xMax),
- ylim=(yMin, yMax),
- y2lim=(y2Min, y2Max))
-
- def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None,
- keepDim=None):
- # Update axes range with a clipped range if too wide
- self._setDataRanges(xlim=xRange,
- ylim=yRange,
- y2lim=y2Range)
-
- # Keep data aspect ratio
- if self.isKeepDataAspectRatio():
- self._ensureAspectRatio(keepDim)
-
- def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
- assert xmin < xmax
- assert ymin < ymax
-
- if y2min is None or y2max is None:
- y2Range = None
- else:
- assert y2min < y2max
- y2Range = y2min, y2max
- self._setPlotBounds((xmin, xmax), (ymin, ymax), y2Range)
-
- def getGraphXLimits(self):
- return self._plotFrame.dataRanges.x
-
- def setGraphXLimits(self, xmin, xmax):
- assert xmin < xmax
- self._setPlotBounds(xRange=(xmin, xmax), keepDim='x')
-
- def getGraphYLimits(self, axis):
- assert axis in ("left", "right")
- if axis == "left":
- return self._plotFrame.dataRanges.y
- else:
- return self._plotFrame.dataRanges.y2
-
- def setGraphYLimits(self, ymin, ymax, axis):
- assert ymin < ymax
- assert axis in ("left", "right")
-
- if axis == "left":
- self._setPlotBounds(yRange=(ymin, ymax), keepDim='y')
- else:
- self._setPlotBounds(y2Range=(ymin, ymax), keepDim='y')
-
- # Graph axes
-
- def getXAxisTimeZone(self):
- return self._plotFrame.xAxis.timeZone
-
- def setXAxisTimeZone(self, tz):
- self._plotFrame.xAxis.timeZone = tz
-
- def isXAxisTimeSeries(self):
- return self._plotFrame.xAxis.isTimeSeries
-
- def setXAxisTimeSeries(self, isTimeSeries):
- self._plotFrame.xAxis.isTimeSeries = isTimeSeries
-
- def setXAxisLogarithmic(self, flag):
- if flag != self._plotFrame.xAxis.isLog:
- if flag and self._keepDataAspectRatio:
- _logger.warning(
- "KeepDataAspectRatio is ignored with log axes")
-
- self._plotFrame.xAxis.isLog = flag
-
- def setYAxisLogarithmic(self, flag):
- if (flag != self._plotFrame.yAxis.isLog or
- flag != self._plotFrame.y2Axis.isLog):
- if flag and self._keepDataAspectRatio:
- _logger.warning(
- "KeepDataAspectRatio is ignored with log axes")
-
- self._plotFrame.yAxis.isLog = flag
- self._plotFrame.y2Axis.isLog = flag
-
- def setYAxisInverted(self, flag):
- if flag != self._plotFrame.isYAxisInverted:
- self._plotFrame.isYAxisInverted = flag
-
- def isYAxisInverted(self):
- return self._plotFrame.isYAxisInverted
-
- def isKeepDataAspectRatio(self):
- if self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog:
- return False
- else:
- return self._keepDataAspectRatio
-
- def setKeepDataAspectRatio(self, flag):
- if flag and (self._plotFrame.xAxis.isLog or
- self._plotFrame.yAxis.isLog):
- _logger.warning("KeepDataAspectRatio is ignored with log axes")
-
- self._keepDataAspectRatio = flag
-
- def setGraphGrid(self, which):
- assert which in (None, 'major', 'both')
- self._plotFrame.grid = which is not None # TODO True grid support
-
- # Data <-> Pixel coordinates conversion
-
- def dataToPixel(self, x, y, axis):
- result = self._plotFrame.dataToPixel(x, y, axis)
- if result is None:
- return None
- else:
- devicePixelRatio = self.getDevicePixelRatio()
- return tuple(value/devicePixelRatio for value in result)
-
- def pixelToData(self, x, y, axis):
- devicePixelRatio = self.getDevicePixelRatio()
- return self._plotFrame.pixelToData(
- x * devicePixelRatio, y * devicePixelRatio, axis)
-
- def getPlotBoundsInPixels(self):
- devicePixelRatio = self.getDevicePixelRatio()
- return tuple(int(value / devicePixelRatio)
- for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize)
-
- def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
- self._plotFrame.marginRatios = left, top, right, bottom
-
- def setForegroundColors(self, foregroundColor, gridColor):
- self._plotFrame.foregroundColor = foregroundColor
- self._plotFrame.gridColor = gridColor
-
- def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
- self._backgroundColor = backgroundColor
- self._dataBackgroundColor = dataBackgroundColor
diff --git a/silx/gui/plot/backends/glutils/GLPlotCurve.py b/silx/gui/plot/backends/glutils/GLPlotCurve.py
deleted file mode 100644
index 34844c6..0000000
--- a/silx/gui/plot/backends/glutils/GLPlotCurve.py
+++ /dev/null
@@ -1,1375 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-This module provides classes to render 2D lines and scatter plots
-"""
-
-from __future__ import division
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "03/04/2017"
-
-
-import math
-import logging
-
-import numpy
-
-from silx.math.combo import min_max
-
-from ...._glutils import gl
-from ...._glutils import Program, vertexBuffer, VertexBufferAttrib
-from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate
-from .GLPlotImage import GLPlotItem
-
-
-_logger = logging.getLogger(__name__)
-
-
-_MPL_NONES = None, 'None', '', ' '
-"""Possible values for None"""
-
-
-def _notNaNSlices(array, length=1):
- """Returns slices of none NaN values in the array.
-
- :param numpy.ndarray array: 1D array from which to get slices
- :param int length: Slices shorter than length gets discarded
- :return: Array of (start, end) slice indices
- :rtype: numpy.ndarray
- """
- isnan = numpy.isnan(numpy.array(array, copy=False).reshape(-1))
- notnan = numpy.logical_not(isnan)
- start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1
- if notnan[0]:
- start = numpy.append(0, start)
- end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1
- if notnan[-1]:
- end = numpy.append(end, len(array))
- slices = numpy.transpose((start, end))
- if length > 1:
- # discard slices with less than length values
- slices = slices[numpy.diff(slices, axis=1).ravel() >= length]
- return slices
-
-
-# fill ########################################################################
-
-class _Fill2D(object):
- """Object rendering curve filling as polygons
-
- :param numpy.ndarray xData: X coordinates of points
- :param numpy.ndarray yData: Y coordinates of points
- :param float baseline: Y value of the 'bottom' of the fill.
- 0 for linear Y scale, -38 for log Y scale
- :param List[float] color: RGBA color as 4 float in [0, 1]
- :param List[float] offset: Translation of coordinates (ox, oy)
- """
-
- _PROGRAM = Program(
- vertexShader="""
- #version 120
-
- uniform mat4 matrix;
- attribute float xPos;
- attribute float yPos;
-
- void main(void) {
- gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0);
- }
- """,
- fragmentShader="""
- #version 120
-
- uniform vec4 color;
-
- void main(void) {
- gl_FragColor = color;
- }
- """,
- attrib0='xPos')
-
- def __init__(self, xData=None, yData=None,
- baseline=0,
- color=(0., 0., 0., 1.),
- offset=(0., 0.)):
- self.xData = xData
- self.yData = yData
- self._xFillVboData = None
- self._yFillVboData = None
- self.color = color
- self.offset = offset
-
- # Offset baseline
- self.baseline = baseline - self.offset[1]
-
- def prepare(self):
- """Rendering preparation: build indices and bounding box vertices"""
- if (self._xFillVboData is None and
- self.xData is not None and self.yData is not None):
-
- # Get slices of not NaN values longer than 1 element
- isnan = numpy.logical_or(numpy.isnan(self.xData), numpy.isnan(self.yData))
- notnan = numpy.logical_not(isnan)
- start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1
- if notnan[0]:
- start = numpy.append(0, start)
- end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1
- if notnan[-1]:
- end = numpy.append(end, len(isnan))
- slices = numpy.transpose((start, end))
- # discard slices with less than length values
- slices = slices[numpy.diff(slices, axis=1).reshape(-1) >= 2]
-
- # Number of points: slice + 2 * leading and trailing points
- # Twice leading and trailing points to produce degenerated triangles
- nbPoints = numpy.sum(numpy.diff(slices, axis=1)) * 2 + 4 * len(slices)
- points = numpy.empty((nbPoints, 2), dtype=numpy.float32)
-
- offset = 0
- # invert baseline for filling
- new_y_data = numpy.append(self.yData, self.baseline)
- for start, end in slices:
- # Duplicate first point for connecting degenerated triangle
- points[offset:offset+2] = self.xData[start], new_y_data[start]
-
- # 2nd point of the polygon is last point
- points[offset+2] = self.xData[start], self.baseline[start]
-
- indices = numpy.append(numpy.arange(start, end),
- numpy.arange(len(self.xData) + end-1, len(self.xData) + start-1, -1))
- indices = indices[buildFillMaskIndices(len(indices))]
-
- points[offset+3:offset+3+len(indices), 0] = self.xData[indices % len(self.xData)]
- points[offset+3:offset+3+len(indices), 1] = new_y_data[indices]
-
- # Duplicate last point for connecting degenerated triangle
- points[offset+3+len(indices)] = points[offset+3+len(indices)-1]
-
- offset += len(indices) + 4
-
- self._xFillVboData, self._yFillVboData = vertexBuffer(points.T)
-
- def render(self, context):
- """Perform rendering
-
- :param RenderContext context:
- """
- self.prepare()
-
- if self._xFillVboData is None:
- return # Nothing to display
-
- self._PROGRAM.use()
-
- gl.glUniformMatrix4fv(
- self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE,
- numpy.dot(context.matrix,
- mat4Translate(*self.offset)).astype(numpy.float32))
-
- gl.glUniform4f(self._PROGRAM.uniforms['color'], *self.color)
-
- xPosAttrib = self._PROGRAM.attributes['xPos']
- yPosAttrib = self._PROGRAM.attributes['yPos']
-
- gl.glEnableVertexAttribArray(xPosAttrib)
- self._xFillVboData.setVertexAttrib(xPosAttrib)
-
- gl.glEnableVertexAttribArray(yPosAttrib)
- self._yFillVboData.setVertexAttrib(yPosAttrib)
-
- # Prepare fill mask
- gl.glEnable(gl.GL_STENCIL_TEST)
- gl.glStencilMask(1)
- gl.glStencilFunc(gl.GL_ALWAYS, 1, 1)
- gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT)
- gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
- gl.glDepthMask(gl.GL_FALSE)
-
- gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self._xFillVboData.size)
-
- gl.glStencilFunc(gl.GL_EQUAL, 1, 1)
- # Reset stencil while drawing
- gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO)
- gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
- gl.glDepthMask(gl.GL_TRUE)
-
- # Draw directly in NDC
- gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE,
- mat4Identity().astype(numpy.float32))
-
- # NDC vertices
- gl.glVertexAttribPointer(
- xPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
- numpy.array((-1., -1., 1., 1.), dtype=numpy.float32))
- gl.glVertexAttribPointer(
- yPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
- numpy.array((-1., 1., -1., 1.), dtype=numpy.float32))
-
- gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4)
-
- gl.glDisable(gl.GL_STENCIL_TEST)
-
- def discard(self):
- """Release VBOs"""
- if self.isInitialized():
- self._xFillVboData.vbo.discard()
-
- self._xFillVboData = None
- self._yFillVboData = None
-
- def isInitialized(self):
- return self._xFillVboData is not None
-
-
-# line ########################################################################
-
-SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':'
-
-
-class GLLines2D(object):
- """Object rendering curve as a polyline
-
- :param xVboData: X coordinates VBO
- :param yVboData: Y coordinates VBO
- :param colorVboData: VBO of colors
- :param distVboData: VBO of distance along the polyline
- :param str style: Line style in: '-', '--', '-.', ':'
- :param List[float] color: RGBA color as 4 float in [0, 1]
- :param float width: Line width
- :param float dashPeriod: Period of dashes
- :param drawMode: OpenGL drawing mode
- :param List[float] offset: Translation of coordinates (ox, oy)
- """
-
- STYLES = SOLID, DASHED, DASHDOT, DOTTED
- """Supported line styles"""
-
- _SOLID_PROGRAM = Program(
- vertexShader="""
- #version 120
-
- uniform mat4 matrix;
- attribute float xPos;
- attribute float yPos;
- attribute vec4 color;
-
- varying vec4 vColor;
-
- void main(void) {
- gl_Position = matrix * vec4(xPos, yPos, 0., 1.) ;
- vColor = color;
- }
- """,
- fragmentShader="""
- #version 120
-
- varying vec4 vColor;
-
- void main(void) {
- gl_FragColor = vColor;
- }
- """,
- attrib0='xPos')
-
- # Limitation: Dash using an estimate of distance in screen coord
- # to avoid computing distance when viewport is resized
- # results in inequal dashes when viewport aspect ratio is far from 1
- _DASH_PROGRAM = Program(
- vertexShader="""
- #version 120
-
- uniform mat4 matrix;
- uniform vec2 halfViewportSize;
- attribute float xPos;
- attribute float yPos;
- attribute vec4 color;
- attribute float distance;
-
- varying float vDist;
- varying vec4 vColor;
-
- void main(void) {
- gl_Position = matrix * vec4(xPos, yPos, 0., 1.);
- //Estimate distance in pixels
- vec2 probe = vec2(matrix * vec4(1., 1., 0., 0.)) *
- halfViewportSize;
- float pixelPerDataEstimate = length(probe)/sqrt(2.);
- vDist = distance * pixelPerDataEstimate;
- vColor = color;
- }
- """,
- fragmentShader="""
- #version 120
-
- /* Dashes: [0, x], [y, z]
- Dash period: w */
- uniform vec4 dash;
- uniform vec4 dash2ndColor;
-
- varying float vDist;
- varying vec4 vColor;
-
- void main(void) {
- float dist = mod(vDist, dash.w);
- if ((dist > dash.x && dist < dash.y) || dist > dash.z) {
- if (dash2ndColor.a == 0.) {
- discard; // Discard full transparent bg color
- } else {
- gl_FragColor = dash2ndColor;
- }
- } else {
- gl_FragColor = vColor;
- }
- }
- """,
- attrib0='xPos')
-
- def __init__(self, xVboData=None, yVboData=None,
- colorVboData=None, distVboData=None,
- style=SOLID, color=(0., 0., 0., 1.), dash2ndColor=None,
- width=1, dashPeriod=10., drawMode=None,
- offset=(0., 0.)):
- if (xVboData is not None and
- not isinstance(xVboData, VertexBufferAttrib)):
- xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32)
- self.xVboData = xVboData
-
- if (yVboData is not None and
- not isinstance(yVboData, VertexBufferAttrib)):
- yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32)
- self.yVboData = yVboData
-
- # Compute distances if not given while providing numpy array coordinates
- if (isinstance(self.xVboData, numpy.ndarray) and
- isinstance(self.yVboData, numpy.ndarray) and
- distVboData is None):
- distVboData = distancesFromArrays(self.xVboData, self.yVboData)
-
- if (distVboData is not None and
- not isinstance(distVboData, VertexBufferAttrib)):
- distVboData = numpy.array(
- distVboData, copy=False, dtype=numpy.float32)
- self.distVboData = distVboData
-
- if colorVboData is not None:
- assert isinstance(colorVboData, VertexBufferAttrib)
- self.colorVboData = colorVboData
- self.useColorVboData = colorVboData is not None
-
- self.color = color
- self.dash2ndColor = dash2ndColor
- self.width = width
- self._style = None
- self.style = style
- self.dashPeriod = dashPeriod
- self.offset = offset
-
- self._drawMode = drawMode if drawMode is not None else gl.GL_LINE_STRIP
-
- @property
- def style(self):
- """Line style (Union[str,None])"""
- return self._style
-
- @style.setter
- def style(self, style):
- if style in _MPL_NONES:
- self._style = None
- else:
- assert style in self.STYLES
- self._style = style
-
- @classmethod
- def init(cls):
- """OpenGL context initialization"""
- gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
-
- def render(self, context):
- """Perform rendering
-
- :param RenderContext context:
- """
- width = self.width / 72. * context.dpi
-
- style = self.style
- if style is None:
- return
-
- elif style == SOLID:
- program = self._SOLID_PROGRAM
- program.use()
-
- else: # DASHED, DASHDOT, DOTTED
- program = self._DASH_PROGRAM
- program.use()
-
- x, y, viewWidth, viewHeight = gl.glGetFloatv(gl.GL_VIEWPORT)
- gl.glUniform2f(program.uniforms['halfViewportSize'],
- 0.5 * viewWidth, 0.5 * viewHeight)
-
- dashPeriod = self.dashPeriod * width
- if self.style == DOTTED:
- dash = (0.2 * dashPeriod,
- 0.5 * dashPeriod,
- 0.7 * dashPeriod,
- dashPeriod)
- elif self.style == DASHDOT:
- dash = (0.3 * dashPeriod,
- 0.5 * dashPeriod,
- 0.6 * dashPeriod,
- dashPeriod)
- else:
- dash = (0.5 * dashPeriod,
- dashPeriod,
- dashPeriod,
- dashPeriod)
-
- gl.glUniform4f(program.uniforms['dash'], *dash)
-
- if self.dash2ndColor is None:
- # Use fully transparent color which gets discarded in shader
- dash2ndColor = (0., 0., 0., 0.)
- else:
- dash2ndColor = self.dash2ndColor
- gl.glUniform4f(program.uniforms['dash2ndColor'], *dash2ndColor)
-
- distAttrib = program.attributes['distance']
- gl.glEnableVertexAttribArray(distAttrib)
- if isinstance(self.distVboData, VertexBufferAttrib):
- self.distVboData.setVertexAttrib(distAttrib)
- else:
- gl.glVertexAttribPointer(distAttrib,
- 1,
- gl.GL_FLOAT,
- False,
- 0,
- self.distVboData)
-
- if width != 1:
- gl.glEnable(gl.GL_LINE_SMOOTH)
-
- matrix = numpy.dot(context.matrix,
- mat4Translate(*self.offset)).astype(numpy.float32)
- gl.glUniformMatrix4fv(program.uniforms['matrix'],
- 1, gl.GL_TRUE, matrix)
-
- colorAttrib = program.attributes['color']
- if self.useColorVboData and self.colorVboData is not None:
- gl.glEnableVertexAttribArray(colorAttrib)
- self.colorVboData.setVertexAttrib(colorAttrib)
- else:
- gl.glDisableVertexAttribArray(colorAttrib)
- gl.glVertexAttrib4f(colorAttrib, *self.color)
-
- xPosAttrib = program.attributes['xPos']
- gl.glEnableVertexAttribArray(xPosAttrib)
- if isinstance(self.xVboData, VertexBufferAttrib):
- self.xVboData.setVertexAttrib(xPosAttrib)
- else:
- gl.glVertexAttribPointer(xPosAttrib,
- 1,
- gl.GL_FLOAT,
- False,
- 0,
- self.xVboData)
-
- yPosAttrib = program.attributes['yPos']
- gl.glEnableVertexAttribArray(yPosAttrib)
- if isinstance(self.yVboData, VertexBufferAttrib):
- self.yVboData.setVertexAttrib(yPosAttrib)
- else:
- gl.glVertexAttribPointer(yPosAttrib,
- 1,
- gl.GL_FLOAT,
- False,
- 0,
- self.yVboData)
-
- gl.glLineWidth(width)
- gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
-
- gl.glDisable(gl.GL_LINE_SMOOTH)
-
-
-def distancesFromArrays(xData, yData):
- """Returns distances between each points
-
- :param numpy.ndarray xData: X coordinate of points
- :param numpy.ndarray yData: Y coordinate of points
- :rtype: numpy.ndarray
- """
- # Split array into sub-shapes at not finite points
- splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
- numpy.isfinite(xData), numpy.isfinite(yData))))[0]
- splits = numpy.concatenate(([-1], splits, [len(xData) - 1]))
-
- # Compute distance independently for each sub-shapes,
- # putting not finite points as last points of sub-shapes
- distances = []
- for begin, end in zip(splits[:-1] + 1, splits[1:] + 1):
- if begin == end: # Empty shape
- continue
- elif end - begin == 1: # Single element
- distances.append([0])
- else:
- deltas = numpy.dstack((
- numpy.ediff1d(xData[begin:end], to_begin=numpy.float32(0.)),
- numpy.ediff1d(yData[begin:end], to_begin=numpy.float32(0.))))[0]
- distances.append(
- numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1))))
- return numpy.concatenate(distances)
-
-
-# points ######################################################################
-
-DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK = \
- 'd', 'o', 's', '+', 'x', '.', ',', '*'
-
-H_LINE, V_LINE, HEART = '_', '|', u'\u2665'
-
-TICK_LEFT = "tickleft"
-TICK_RIGHT = "tickright"
-TICK_UP = "tickup"
-TICK_DOWN = "tickdown"
-CARET_LEFT = "caretleft"
-CARET_RIGHT = "caretright"
-CARET_UP = "caretup"
-CARET_DOWN = "caretdown"
-
-
-class _Points2D(object):
- """Object rendering curve markers
-
- :param xVboData: X coordinates VBO
- :param yVboData: Y coordinates VBO
- :param colorVboData: VBO of colors
- :param str marker: Kind of symbol to use, see :attr:`MARKERS`.
- :param List[float] color: RGBA color as 4 float in [0, 1]
- :param float size: Marker size
- :param List[float] offset: Translation of coordinates (ox, oy)
- """
-
- MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK,
- H_LINE, V_LINE, HEART, TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN,
- CARET_LEFT, CARET_RIGHT, CARET_UP, CARET_DOWN)
- """List of supported markers"""
-
- _VERTEX_SHADER = """
- #version 120
-
- uniform mat4 matrix;
- uniform int transform;
- uniform float size;
- attribute float xPos;
- attribute float yPos;
- attribute vec4 color;
-
- varying vec4 vColor;
-
- void main(void) {
- gl_Position = matrix * vec4(xPos, yPos, 0., 1.);
- vColor = color;
- gl_PointSize = size;
- }
- """
-
- _FRAGMENT_SHADER_SYMBOLS = {
- DIAMOND: """
- float alphaSymbol(vec2 coord, float size) {
- vec2 centerCoord = abs(coord - vec2(0.5, 0.5));
- float f = centerCoord.x + centerCoord.y;
- return clamp(size * (0.5 - f), 0.0, 1.0);
- }
- """,
- CIRCLE: """
- float alphaSymbol(vec2 coord, float size) {
- float radius = 0.5;
- float r = distance(coord, vec2(0.5, 0.5));
- return clamp(size * (radius - r), 0.0, 1.0);
- }
- """,
- SQUARE: """
- float alphaSymbol(vec2 coord, float size) {
- return 1.0;
- }
- """,
- PLUS: """
- float alphaSymbol(vec2 coord, float size) {
- vec2 d = abs(size * (coord - vec2(0.5, 0.5)));
- if (min(d.x, d.y) < 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
- }
- """,
- X_MARKER: """
- float alphaSymbol(vec2 coord, float size) {
- vec2 pos = floor(size * coord) + 0.5;
- vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
- if (min(d_x.x, d_x.y) <= 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
- }
- """,
- ASTERISK: """
- float alphaSymbol(vec2 coord, float size) {
- /* Combining +, x and circle */
- vec2 d_plus = abs(size * (coord - vec2(0.5, 0.5)));
- vec2 pos = floor(size * coord) + 0.5;
- vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
- if (min(d_plus.x, d_plus.y) < 0.5) {
- return 1.0;
- } else if (min(d_x.x, d_x.y) <= 0.5) {
- float r = distance(coord, vec2(0.5, 0.5));
- return clamp(size * (0.5 - r), 0.0, 1.0);
- } else {
- return 0.0;
- }
- }
- """,
- H_LINE: """
- float alphaSymbol(vec2 coord, float size) {
- float dy = abs(size * (coord.y - 0.5));
- if (dy < 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
- }
- """,
- V_LINE: """
- float alphaSymbol(vec2 coord, float size) {
- float dx = abs(size * (coord.x - 0.5));
- if (dx < 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
- }
- """,
- HEART: """
- float alphaSymbol(vec2 coord, float size) {
- coord = (coord - 0.5) * 2.;
- coord *= 0.75;
- coord.y += 0.25;
- float a = atan(coord.x,-coord.y)/3.141593;
- float r = length(coord);
- float h = abs(a);
- float d = (13.0*h - 22.0*h*h + 10.0*h*h*h)/(6.0-5.0*h);
- float res = clamp(r-d, 0., 1.);
- // antialiasing
- res = smoothstep(0.1, 0.001, res);
- return res;
- }
- """,
- TICK_LEFT: """
- float alphaSymbol(vec2 coord, float size) {
- coord = size * (coord - 0.5);
- float dy = abs(coord.y);
- if (dy < 0.5 && coord.x < 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
- }
- """,
- TICK_RIGHT: """
- float alphaSymbol(vec2 coord, float size) {
- coord = size * (coord - 0.5);
- float dy = abs(coord.y);
- if (dy < 0.5 && coord.x > -0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
- }
- """,
- TICK_UP: """
- float alphaSymbol(vec2 coord, float size) {
- coord = size * (coord - 0.5);
- float dx = abs(coord.x);
- if (dx < 0.5 && coord.y < 0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
- }
- """,
- TICK_DOWN: """
- float alphaSymbol(vec2 coord, float size) {
- coord = size * (coord - 0.5);
- float dx = abs(coord.x);
- if (dx < 0.5 && coord.y > -0.5) {
- return 1.0;
- } else {
- return 0.0;
- }
- }
- """,
- CARET_LEFT: """
- float alphaSymbol(vec2 coord, float size) {
- coord = size * (coord - 0.5);
- float d = abs(coord.x) - abs(coord.y);
- if (d >= -0.1 && coord.x > 0.5) {
- return smoothstep(-0.1, 0.1, d);
- } else {
- return 0.0;
- }
- }
- """,
- CARET_RIGHT: """
- float alphaSymbol(vec2 coord, float size) {
- coord = size * (coord - 0.5);
- float d = abs(coord.x) - abs(coord.y);
- if (d >= -0.1 && coord.x < 0.5) {
- return smoothstep(-0.1, 0.1, d);
- } else {
- return 0.0;
- }
- }
- """,
- CARET_UP: """
- float alphaSymbol(vec2 coord, float size) {
- coord = size * (coord - 0.5);
- float d = abs(coord.y) - abs(coord.x);
- if (d >= -0.1 && coord.y > 0.5) {
- return smoothstep(-0.1, 0.1, d);
- } else {
- return 0.0;
- }
- }
- """,
- CARET_DOWN: """
- float alphaSymbol(vec2 coord, float size) {
- coord = size * (coord - 0.5);
- float d = abs(coord.y) - abs(coord.x);
- if (d >= -0.1 && coord.y < 0.5) {
- return smoothstep(-0.1, 0.1, d);
- } else {
- return 0.0;
- }
- }
- """,
- }
-
- _FRAGMENT_SHADER_TEMPLATE = """
- #version 120
-
- uniform float size;
-
- varying vec4 vColor;
-
- %s
-
- void main(void) {
- float alpha = alphaSymbol(gl_PointCoord, size);
- if (alpha <= 0.0) {
- discard;
- } else {
- gl_FragColor = vec4(vColor.rgb, alpha * clamp(vColor.a, 0.0, 1.0));
- }
- }
- """
-
- _PROGRAMS = {}
-
- def __init__(self, xVboData=None, yVboData=None, colorVboData=None,
- marker=SQUARE, color=(0., 0., 0., 1.), size=7,
- offset=(0., 0.)):
- self.color = color
- self._marker = None
- self.marker = marker
- self.size = size
- self.offset = offset
-
- self.xVboData = xVboData
- self.yVboData = yVboData
- self.colorVboData = colorVboData
- self.useColorVboData = colorVboData is not None
-
- @property
- def marker(self):
- """Symbol used to display markers (str)"""
- return self._marker
-
- @marker.setter
- def marker(self, marker):
- if marker in _MPL_NONES:
- self._marker = None
- else:
- assert marker in self.MARKERS
- self._marker = marker
-
- @classmethod
- def _getProgram(cls, marker):
- """On-demand shader program creation."""
- if marker == PIXEL:
- marker = SQUARE
- elif marker == POINT:
- marker = CIRCLE
-
- if marker not in cls._PROGRAMS:
- cls._PROGRAMS[marker] = Program(
- vertexShader=cls._VERTEX_SHADER,
- fragmentShader=(cls._FRAGMENT_SHADER_TEMPLATE %
- cls._FRAGMENT_SHADER_SYMBOLS[marker]),
- attrib0='xPos')
-
- return cls._PROGRAMS[marker]
-
- @classmethod
- def init(cls):
- """OpenGL context initialization"""
- version = gl.glGetString(gl.GL_VERSION)
- majorVersion = int(version[0])
- assert majorVersion >= 2
- gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
- gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
- if majorVersion >= 3: # OpenGL 3
- gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
-
- def render(self, context):
- """Perform rendering
-
- :param RenderContext context:
- """
- if self.marker is None:
- return
-
- program = self._getProgram(self.marker)
- program.use()
-
- matrix = numpy.dot(context.matrix,
- mat4Translate(*self.offset)).astype(numpy.float32)
- gl.glUniformMatrix4fv(program.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
-
- if self.marker == PIXEL:
- size = 1
- elif self.marker == POINT:
- size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point
- else:
- size = self.size
- size = size / 72. * context.dpi
-
- if self.marker in (PLUS, H_LINE, V_LINE,
- TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN):
- # Convert to nearest odd number
- size = size // 2 * 2 + 1.
-
- gl.glUniform1f(program.uniforms['size'], size)
- # gl.glPointSize(self.size)
-
- cAttrib = program.attributes['color']
- if self.useColorVboData and self.colorVboData is not None:
- gl.glEnableVertexAttribArray(cAttrib)
- self.colorVboData.setVertexAttrib(cAttrib)
- else:
- gl.glDisableVertexAttribArray(cAttrib)
- gl.glVertexAttrib4f(cAttrib, *self.color)
-
- xAttrib = program.attributes['xPos']
- gl.glEnableVertexAttribArray(xAttrib)
- self.xVboData.setVertexAttrib(xAttrib)
-
- yAttrib = program.attributes['yPos']
- gl.glEnableVertexAttribArray(yAttrib)
- self.yVboData.setVertexAttrib(yAttrib)
-
- gl.glDrawArrays(gl.GL_POINTS, 0, self.xVboData.size)
-
- gl.glUseProgram(0)
-
-
-# error bars ##################################################################
-
-class _ErrorBars(object):
- """Display errors bars.
-
- This is using its own VBO as opposed to fill/points/lines.
- There is no picking on error bars.
-
- It uses 2 vertices per error bars and uses :class:`GLLines2D` to
- render error bars and :class:`_Points2D` to render the ends.
-
- :param numpy.ndarray xData: X coordinates of the data.
- :param numpy.ndarray yData: Y coordinates of the data.
- :param xError: The absolute error on the X axis.
- :type xError: A float, or a numpy.ndarray of float32.
- If it is an array, it can either be a 1D array of
- same length as the data or a 2D array with 2 rows
- of same length as the data: row 0 for negative errors,
- row 1 for positive errors.
- :param yError: The absolute error on the Y axis.
- :type yError: A float, or a numpy.ndarray of float32. See xError.
- :param float xMin: The min X value already computed by GLPlotCurve2D.
- :param float yMin: The min Y value already computed by GLPlotCurve2D.
- :param List[float] color: RGBA color as 4 float in [0, 1]
- :param List[float] offset: Translation of coordinates (ox, oy)
- """
-
- def __init__(self, xData, yData, xError, yError,
- xMin, yMin,
- color=(0., 0., 0., 1.),
- offset=(0., 0.)):
- self._attribs = None
- self._xMin, self._yMin = xMin, yMin
- self.offset = offset
-
- if xError is not None or yError is not None:
- self._xData = numpy.array(
- xData, order='C', dtype=numpy.float32, copy=False)
- self._yData = numpy.array(
- yData, order='C', dtype=numpy.float32, copy=False)
-
- # This also works if xError, yError is a float/int
- self._xError = numpy.array(
- xError, order='C', dtype=numpy.float32, copy=False)
- self._yError = numpy.array(
- yError, order='C', dtype=numpy.float32, copy=False)
- else:
- self._xData, self._yData = None, None
- self._xError, self._yError = None, None
-
- self._lines = GLLines2D(
- None, None, color=color, drawMode=gl.GL_LINES, offset=offset)
- self._xErrPoints = _Points2D(
- None, None, color=color, marker=V_LINE, offset=offset)
- self._yErrPoints = _Points2D(
- None, None, color=color, marker=H_LINE, offset=offset)
-
- def _buildVertices(self):
- """Generates error bars vertices"""
- nbLinesPerDataPts = (0 if self._xError is None else 2) + \
- (0 if self._yError is None else 2)
-
- nbDataPts = len(self._xData)
-
- # interleave coord+error, coord-error.
- # xError vertices first if any, then yError vertices if any.
- xCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
- dtype=numpy.float32)
- yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
- dtype=numpy.float32)
-
- if self._xError is not None: # errors on the X axis
- if len(self._xError.shape) == 2:
- xErrorMinus, xErrorPlus = self._xError[0], self._xError[1]
- else:
- # numpy arrays of len 1 or len(xData)
- xErrorMinus, xErrorPlus = self._xError, self._xError
-
- # Interleave vertices for xError
- endXError = 4 * nbDataPts
- xCoords[0:endXError-3:4] = self._xData + xErrorPlus
- xCoords[1:endXError-2:4] = self._xData
- xCoords[2:endXError-1:4] = self._xData
- xCoords[3:endXError:4] = self._xData - xErrorMinus
-
- yCoords[0:endXError-3:4] = self._yData
- yCoords[1:endXError-2:4] = self._yData
- yCoords[2:endXError-1:4] = self._yData
- yCoords[3:endXError:4] = self._yData
-
- else:
- endXError = 0
-
- if self._yError is not None: # errors on the Y axis
- if len(self._yError.shape) == 2:
- yErrorMinus, yErrorPlus = self._yError[0], self._yError[1]
- else:
- # numpy arrays of len 1 or len(yData)
- yErrorMinus, yErrorPlus = self._yError, self._yError
-
- # Interleave vertices for yError
- xCoords[endXError::4] = self._xData
- xCoords[endXError+1::4] = self._xData
- xCoords[endXError+2::4] = self._xData
- xCoords[endXError+3::4] = self._xData
-
- yCoords[endXError::4] = self._yData + yErrorPlus
- yCoords[endXError+1::4] = self._yData
- yCoords[endXError+2::4] = self._yData
- yCoords[endXError+3::4] = self._yData - yErrorMinus
-
- return xCoords, yCoords
-
- def prepare(self):
- """Rendering preparation: build indices and bounding box vertices"""
- if self._xData is None:
- return
-
- if self._attribs is None:
- xCoords, yCoords = self._buildVertices()
-
- xAttrib, yAttrib = vertexBuffer((xCoords, yCoords))
- self._attribs = xAttrib, yAttrib
-
- self._lines.xVboData = xAttrib
- self._lines.yVboData = yAttrib
-
- # Set xError points using the same VBO as lines
- self._xErrPoints.xVboData = xAttrib.copy()
- self._xErrPoints.xVboData.size //= 2
- self._xErrPoints.yVboData = yAttrib.copy()
- self._xErrPoints.yVboData.size //= 2
-
- # Set yError points using the same VBO as lines
- self._yErrPoints.xVboData = xAttrib.copy()
- self._yErrPoints.xVboData.size //= 2
- self._yErrPoints.xVboData.offset += (xAttrib.itemsize *
- xAttrib.size // 2)
- self._yErrPoints.yVboData = yAttrib.copy()
- self._yErrPoints.yVboData.size //= 2
- self._yErrPoints.yVboData.offset += (yAttrib.itemsize *
- yAttrib.size // 2)
-
- def render(self, context):
- """Perform rendering
-
- :param RenderContext context:
- """
- self.prepare()
-
- if self._attribs is not None:
- self._lines.render(context)
- self._xErrPoints.render(context)
- self._yErrPoints.render(context)
-
- def discard(self):
- """Release VBOs"""
- if self.isInitialized():
- self._lines.xVboData, self._lines.yVboData = None, None
- self._xErrPoints.xVboData, self._xErrPoints.yVboData = None, None
- self._yErrPoints.xVboData, self._yErrPoints.yVboData = None, None
- self._attribs[0].vbo.discard()
- self._attribs = None
-
- def isInitialized(self):
- return self._attribs is not None
-
-
-# curves ######################################################################
-
-def _proxyProperty(*componentsAttributes):
- """Create a property to access an attribute of attribute(s).
- Useful for composition.
- Supports multiple components this way:
- getter returns the first found, setter sets all
- """
- def getter(self):
- for compName, attrName in componentsAttributes:
- try:
- component = getattr(self, compName)
- except AttributeError:
- pass
- else:
- return getattr(component, attrName)
-
- def setter(self, value):
- for compName, attrName in componentsAttributes:
- component = getattr(self, compName)
- setattr(component, attrName, value)
- return property(getter, setter)
-
-
-class GLPlotCurve2D(GLPlotItem):
- def __init__(self, xData, yData, colorData=None,
- xError=None, yError=None,
- lineStyle=SOLID,
- lineColor=(0., 0., 0., 1.),
- lineWidth=1,
- lineDashPeriod=20,
- marker=SQUARE,
- markerColor=(0., 0., 0., 1.),
- markerSize=7,
- fillColor=None,
- baseline=None,
- isYLog=False):
- super().__init__()
- self.colorData = colorData
-
- # Compute x bounds
- if xError is None:
- self.xMin, self.xMax = min_max(xData, min_positive=False)
- else:
- # Takes the error into account
- if hasattr(xError, 'shape') and len(xError.shape) == 2:
- xErrorMinus, xErrorPlus = xError[0], xError[1]
- else:
- xErrorMinus, xErrorPlus = xError, xError
- self.xMin = numpy.nanmin(xData - xErrorMinus)
- self.xMax = numpy.nanmax(xData + xErrorPlus)
-
- # Compute y bounds
- if yError is None:
- self.yMin, self.yMax = min_max(yData, min_positive=False)
- else:
- # Takes the error into account
- if hasattr(yError, 'shape') and len(yError.shape) == 2:
- yErrorMinus, yErrorPlus = yError[0], yError[1]
- else:
- yErrorMinus, yErrorPlus = yError, yError
- self.yMin = numpy.nanmin(yData - yErrorMinus)
- self.yMax = numpy.nanmax(yData + yErrorPlus)
-
- # Handle data offset
- if xData.itemsize > 4 or yData.itemsize > 4: # Use normalization
- # offset data, do not offset error as it is relative
- self.offset = self.xMin, self.yMin
- self.xData = (xData - self.offset[0]).astype(numpy.float32)
- self.yData = (yData - self.offset[1]).astype(numpy.float32)
-
- else: # float32
- self.offset = 0., 0.
- self.xData = xData
- self.yData = yData
- if fillColor is not None:
- def deduce_baseline(baseline):
- if baseline is None:
- _baseline = 0
- else:
- _baseline = baseline
- if not isinstance(_baseline, numpy.ndarray):
- _baseline = numpy.repeat(_baseline,
- len(self.xData))
- if isYLog is True:
- with numpy.errstate(divide='ignore', invalid='ignore'):
- log_val = numpy.log10(_baseline)
- _baseline = numpy.where(_baseline>0.0, log_val, -38)
- return _baseline
-
- _baseline = deduce_baseline(baseline)
-
- # Use different baseline depending of Y log scale
- self.fill = _Fill2D(self.xData, self.yData,
- baseline=_baseline,
- color=fillColor,
- offset=self.offset)
- else:
- self.fill = None
-
- self._errorBars = _ErrorBars(self.xData, self.yData,
- xError, yError,
- self.xMin, self.yMin,
- offset=self.offset)
-
- self.lines = GLLines2D()
- self.lines.style = lineStyle
- self.lines.color = lineColor
- self.lines.width = lineWidth
- self.lines.dashPeriod = lineDashPeriod
- self.lines.offset = self.offset
-
- self.points = _Points2D()
- self.points.marker = marker
- self.points.color = markerColor
- self.points.size = markerSize
- self.points.offset = self.offset
-
- xVboData = _proxyProperty(('lines', 'xVboData'), ('points', 'xVboData'))
-
- yVboData = _proxyProperty(('lines', 'yVboData'), ('points', 'yVboData'))
-
- colorVboData = _proxyProperty(('lines', 'colorVboData'),
- ('points', 'colorVboData'))
-
- useColorVboData = _proxyProperty(('lines', 'useColorVboData'),
- ('points', 'useColorVboData'))
-
- distVboData = _proxyProperty(('lines', 'distVboData'))
-
- lineStyle = _proxyProperty(('lines', 'style'))
-
- lineColor = _proxyProperty(('lines', 'color'))
-
- lineWidth = _proxyProperty(('lines', 'width'))
-
- lineDashPeriod = _proxyProperty(('lines', 'dashPeriod'))
-
- marker = _proxyProperty(('points', 'marker'))
-
- markerColor = _proxyProperty(('points', 'color'))
-
- markerSize = _proxyProperty(('points', 'size'))
-
- @classmethod
- def init(cls):
- """OpenGL context initialization"""
- GLLines2D.init()
- _Points2D.init()
-
- def prepare(self):
- """Rendering preparation: build indices and bounding box vertices"""
- if self.xVboData is None:
- xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None
- if self.lineStyle in (DASHED, DASHDOT, DOTTED):
- dists = distancesFromArrays(self.xData, self.yData)
- if self.colorData is None:
- xAttrib, yAttrib, dAttrib = vertexBuffer(
- (self.xData, self.yData, dists))
- else:
- xAttrib, yAttrib, cAttrib, dAttrib = vertexBuffer(
- (self.xData, self.yData, self.colorData, dists))
- elif self.colorData is None:
- xAttrib, yAttrib = vertexBuffer((self.xData, self.yData))
- else:
- xAttrib, yAttrib, cAttrib = vertexBuffer(
- (self.xData, self.yData, self.colorData))
-
- self.xVboData = xAttrib
- self.yVboData = yAttrib
- self.distVboData = dAttrib
-
- if cAttrib is not None and self.colorData.dtype.kind == 'u':
- cAttrib.normalization = True # Normalize uint to [0, 1]
- self.colorVboData = cAttrib
- self.useColorVboData = cAttrib is not None
-
- def render(self, context):
- """Perform rendering
-
- :param RenderContext context: Rendering information
- """
- self.prepare()
- if self.fill is not None:
- self.fill.render(context)
- self._errorBars.render(context)
- self.lines.render(context)
- self.points.render(context)
-
- def discard(self):
- """Release VBOs"""
- if self.xVboData is not None:
- self.xVboData.vbo.discard()
-
- self.xVboData = None
- self.yVboData = None
- self.colorVboData = None
- self.distVboData = None
-
- self._errorBars.discard()
- if self.fill is not None:
- self.fill.discard()
-
- def isInitialized(self):
- return (self.xVboData is not None or
- self._errorBars.isInitialized() or
- (self.fill is not None and self.fill.isInitialized()))
-
- def pick(self, xPickMin, yPickMin, xPickMax, yPickMax):
- """Perform picking on the curve according to its rendering.
-
- The picking area is [xPickMin, xPickMax], [yPickMin, yPickMax].
-
- In case a segment between 2 points with indices i, i+1 is picked,
- only its lower index end point (i.e., i) is added to the result.
- In case an end point with index i is picked it is added to the result,
- and the segment [i-1, i] is not tested for picking.
-
- :return: The indices of the picked data
- :rtype: Union[List[int],None]
- """
- if (self.marker is None and self.lineStyle is None) or \
- self.xMin > xPickMax or xPickMin > self.xMax or \
- self.yMin > yPickMax or yPickMin > self.yMax:
- return None
-
- # offset picking bounds
- xPickMin = xPickMin - self.offset[0]
- xPickMax = xPickMax - self.offset[0]
- yPickMin = yPickMin - self.offset[1]
- yPickMax = yPickMax - self.offset[1]
-
- if self.lineStyle is not None:
- # Using Cohen-Sutherland algorithm for line clipping
- with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings
- codes = ((self.yData > yPickMax) << 3) | \
- ((self.yData < yPickMin) << 2) | \
- ((self.xData > xPickMax) << 1) | \
- (self.xData < xPickMin)
-
- notNaN = numpy.logical_not(numpy.logical_or(
- numpy.isnan(self.xData), numpy.isnan(self.yData)))
-
- # Add all points that are inside the picking area
- indices = numpy.nonzero(
- numpy.logical_and(codes == 0, notNaN))[0].tolist()
-
- # Segment that might cross the area with no end point inside it
- segToTestIdx = numpy.nonzero((codes[:-1] != 0) &
- (codes[1:] != 0) &
- ((codes[:-1] & codes[1:]) == 0))[0]
-
- TOP, BOTTOM, RIGHT, LEFT = (1 << 3), (1 << 2), (1 << 1), (1 << 0)
-
- for index in segToTestIdx:
- if index not in indices:
- x0, y0 = self.xData[index], self.yData[index]
- x1, y1 = self.xData[index + 1], self.yData[index + 1]
- code1 = codes[index + 1]
-
- # check for crossing with horizontal bounds
- # y0 == y1 is a never event:
- # => pt0 and pt1 in same vertical area are not in segToTest
- if code1 & TOP:
- x = x0 + (x1 - x0) * (yPickMax - y0) / (y1 - y0)
- elif code1 & BOTTOM:
- x = x0 + (x1 - x0) * (yPickMin - y0) / (y1 - y0)
- else:
- x = None # No horizontal bounds intersection test
-
- if x is not None and xPickMin <= x <= xPickMax:
- # Intersection
- indices.append(index)
-
- else:
- # check for crossing with vertical bounds
- # x0 == x1 is a never event (see remark for y)
- if code1 & RIGHT:
- y = y0 + (y1 - y0) * (xPickMax - x0) / (x1 - x0)
- elif code1 & LEFT:
- y = y0 + (y1 - y0) * (xPickMin - x0) / (x1 - x0)
- else:
- y = None # No vertical bounds intersection test
-
- if y is not None and yPickMin <= y <= yPickMax:
- # Intersection
- indices.append(index)
-
- indices.sort()
-
- else:
- with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings
- indices = numpy.nonzero((self.xData >= xPickMin) &
- (self.xData <= xPickMax) &
- (self.yData >= yPickMin) &
- (self.yData <= yPickMax))[0].tolist()
-
- return tuple(indices) if len(indices) > 0 else None
diff --git a/silx/gui/plot/backends/glutils/GLPlotFrame.py b/silx/gui/plot/backends/glutils/GLPlotFrame.py
deleted file mode 100644
index c5ee75b..0000000
--- a/silx/gui/plot/backends/glutils/GLPlotFrame.py
+++ /dev/null
@@ -1,1219 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2014-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-This modules provides the rendering of plot titles, axes and grid.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "03/04/2017"
-
-
-# TODO
-# keep aspect ratio managed here?
-# smarter dirty flag handling?
-
-import datetime as dt
-import math
-import weakref
-import logging
-from collections import namedtuple
-
-import numpy
-
-from ...._glutils import gl, Program
-from ..._utils import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX
-from .GLSupport import mat4Ortho
-from .GLText import Text2D, CENTER, BOTTOM, TOP, LEFT, RIGHT, ROTATE_270
-from ..._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10
-from ..._utils.dtime_ticklayout import calcTicksAdaptive, bestFormatString
-from ..._utils.dtime_ticklayout import timestamp
-
-_logger = logging.getLogger(__name__)
-
-
-# PlotAxis ####################################################################
-
-class PlotAxis(object):
- """Represents a 1D axis of the plot.
- This class is intended to be used with :class:`GLPlotFrame`.
- """
-
- def __init__(self, plotFrame,
- tickLength=(0., 0.),
- foregroundColor=(0., 0., 0., 1.0),
- labelAlign=CENTER, labelVAlign=CENTER,
- titleAlign=CENTER, titleVAlign=CENTER,
- titleRotate=0, titleOffset=(0., 0.)):
- self._ticks = None
-
- self._plotFrameRef = weakref.ref(plotFrame)
-
- self._isDateTime = False
- self._timeZone = None
- self._isLog = False
- self._dataRange = 1., 100.
- self._displayCoords = (0., 0.), (1., 0.)
- self._title = ''
-
- self._tickLength = tickLength
- self._foregroundColor = foregroundColor
- self._labelAlign = labelAlign
- self._labelVAlign = labelVAlign
- self._titleAlign = titleAlign
- self._titleVAlign = titleVAlign
- self._titleRotate = titleRotate
- self._titleOffset = titleOffset
-
- @property
- def dataRange(self):
- """The range of the data represented on the axis as a tuple
- of 2 floats: (min, max)."""
- return self._dataRange
-
- @dataRange.setter
- def dataRange(self, dataRange):
- assert len(dataRange) == 2
- assert dataRange[0] <= dataRange[1]
- dataRange = float(dataRange[0]), float(dataRange[1])
-
- if dataRange != self._dataRange:
- self._dataRange = dataRange
- self._dirtyTicks()
-
- @property
- def isLog(self):
- """Whether the axis is using a log10 scale or not as a bool."""
- return self._isLog
-
- @isLog.setter
- def isLog(self, isLog):
- isLog = bool(isLog)
- if isLog != self._isLog:
- self._isLog = isLog
- self._dirtyTicks()
-
- @property
- def timeZone(self):
- """Returnss datetime.tzinfo that is used if this axis plots date times."""
- return self._timeZone
-
- @timeZone.setter
- def timeZone(self, tz):
- """Sets dateetime.tzinfo that is used if this axis plots date times."""
- self._timeZone = tz
- self._dirtyTicks()
-
- @property
- def isTimeSeries(self):
- """Whether the axis is showing floats as datetime objects"""
- return self._isDateTime
-
- @isTimeSeries.setter
- def isTimeSeries(self, isTimeSeries):
- isTimeSeries = bool(isTimeSeries)
- if isTimeSeries != self._isDateTime:
- self._isDateTime = isTimeSeries
- self._dirtyTicks()
-
- @property
- def displayCoords(self):
- """The coordinates of the start and end points of the axis
- in display space (i.e., in pixels) as a tuple of 2 tuples of
- 2 floats: ((x0, y0), (x1, y1)).
- """
- return self._displayCoords
-
- @displayCoords.setter
- def displayCoords(self, displayCoords):
- assert len(displayCoords) == 2
- assert len(displayCoords[0]) == 2
- assert len(displayCoords[1]) == 2
- displayCoords = tuple(displayCoords[0]), tuple(displayCoords[1])
- if displayCoords != self._displayCoords:
- self._displayCoords = displayCoords
- self._dirtyTicks()
-
- @property
- def devicePixelRatio(self):
- """Returns the ratio between qt pixels and device pixels."""
- plotFrame = self._plotFrameRef()
- return plotFrame.devicePixelRatio if plotFrame is not None else 1.
-
- @property
- def title(self):
- """The text label associated with this axis as a str in latin-1."""
- return self._title
-
- @title.setter
- def title(self, title):
- if title != self._title:
- self._title = title
- self._dirtyPlotFrame()
-
- @property
- def titleOffset(self):
- """Title offset in pixels (x: int, y: int)"""
- return self._titleOffset
-
- @titleOffset.setter
- def titleOffset(self, offset):
- if offset != self._titleOffset:
- self._titleOffset = offset
- self._dirtyTicks()
-
- @property
- def foregroundColor(self):
- """Color used for frame and labels"""
- return self._foregroundColor
-
- @foregroundColor.setter
- def foregroundColor(self, color):
- """Color used for frame and labels"""
- assert len(color) == 4, \
- "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
- if self._foregroundColor != color:
- self._foregroundColor = color
- self._dirtyTicks()
-
- @property
- def ticks(self):
- """Ticks as tuples: ((x, y) in display, dataPos, textLabel)."""
- if self._ticks is None:
- self._ticks = tuple(self._ticksGenerator())
- return self._ticks
-
- def getVerticesAndLabels(self):
- """Create the list of vertices for axis and associated text labels.
-
- :returns: A tuple: List of 2D line vertices, List of Text2D labels.
- """
- vertices = list(self.displayCoords) # Add start and end points
- labels = []
- tickLabelsSize = [0., 0.]
-
- xTickLength, yTickLength = self._tickLength
- xTickLength *= self.devicePixelRatio
- yTickLength *= self.devicePixelRatio
- for (xPixel, yPixel), dataPos, text in self.ticks:
- if text is None:
- tickScale = 0.5
- else:
- tickScale = 1.
-
- label = Text2D(text=text,
- color=self._foregroundColor,
- x=xPixel - xTickLength,
- y=yPixel - yTickLength,
- align=self._labelAlign,
- valign=self._labelVAlign,
- devicePixelRatio=self.devicePixelRatio)
-
- width, height = label.size
- if width > tickLabelsSize[0]:
- tickLabelsSize[0] = width
- if height > tickLabelsSize[1]:
- tickLabelsSize[1] = height
-
- labels.append(label)
-
- vertices.append((xPixel, yPixel))
- vertices.append((xPixel + tickScale * xTickLength,
- yPixel + tickScale * yTickLength))
-
- (x0, y0), (x1, y1) = self.displayCoords
- xAxisCenter = 0.5 * (x0 + x1)
- yAxisCenter = 0.5 * (y0 + y1)
-
- xOffset, yOffset = self.titleOffset
-
- # Adaptative title positioning:
- # tickNorm = math.sqrt(xTickLength ** 2 + yTickLength ** 2)
- # xOffset = -tickLabelsSize[0] * xTickLength / tickNorm
- # xOffset -= 3 * xTickLength
- # yOffset = -tickLabelsSize[1] * yTickLength / tickNorm
- # yOffset -= 3 * yTickLength
-
- axisTitle = Text2D(text=self.title,
- color=self._foregroundColor,
- x=xAxisCenter + xOffset,
- y=yAxisCenter + yOffset,
- align=self._titleAlign,
- valign=self._titleVAlign,
- rotate=self._titleRotate,
- devicePixelRatio=self.devicePixelRatio)
- labels.append(axisTitle)
-
- return vertices, labels
-
- def _dirtyPlotFrame(self):
- """Dirty parent GLPlotFrame"""
- plotFrame = self._plotFrameRef()
- if plotFrame is not None:
- plotFrame._dirty()
-
- def _dirtyTicks(self):
- """Mark ticks as dirty and notify listener (i.e., background)."""
- self._ticks = None
- self._dirtyPlotFrame()
-
- @staticmethod
- def _frange(start, stop, step):
- """range for float (including stop)."""
- while start <= stop:
- yield start
- start += step
-
- def _ticksGenerator(self):
- """Generator of ticks as tuples:
- ((x, y) in display, dataPos, textLabel).
- """
- dataMin, dataMax = self.dataRange
- if self.isLog and dataMin <= 0.:
- _logger.warning(
- 'Getting ticks while isLog=True and dataRange[0]<=0.')
- dataMin = 1.
- if dataMax < dataMin:
- dataMax = 1.
-
- if dataMin != dataMax: # data range is not null
- (x0, y0), (x1, y1) = self.displayCoords
-
- if self.isLog:
-
- if self.isTimeSeries:
- _logger.warning("Time series not implemented for log-scale")
-
- logMin, logMax = math.log10(dataMin), math.log10(dataMax)
- tickMin, tickMax, step, _ = niceNumbersForLog10(logMin, logMax)
-
- xScale = (x1 - x0) / (logMax - logMin)
- yScale = (y1 - y0) / (logMax - logMin)
-
- for logPos in self._frange(tickMin, tickMax, step):
- if logMin <= logPos <= logMax:
- dataPos = 10 ** logPos
- xPixel = x0 + (logPos - logMin) * xScale
- yPixel = y0 + (logPos - logMin) * yScale
- text = '1e%+03d' % logPos
- yield ((xPixel, yPixel), dataPos, text)
-
- if step == 1:
- ticks = list(self._frange(tickMin, tickMax, step))[:-1]
- for logPos in ticks:
- dataOrigPos = 10 ** logPos
- for index in range(2, 10):
- dataPos = dataOrigPos * index
- if dataMin <= dataPos <= dataMax:
- logSubPos = math.log10(dataPos)
- xPixel = x0 + (logSubPos - logMin) * xScale
- yPixel = y0 + (logSubPos - logMin) * yScale
- yield ((xPixel, yPixel), dataPos, None)
-
- else:
- xScale = (x1 - x0) / (dataMax - dataMin)
- yScale = (y1 - y0) / (dataMax - dataMin)
-
- nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio
-
- # Density of 1.3 label per 92 pixels
- # i.e., 1.3 label per inch on a 92 dpi screen
- tickDensity = 1.3 / 92
-
- if not self.isTimeSeries:
- tickMin, tickMax, step, nbFrac = niceNumbersAdaptative(
- dataMin, dataMax, nbPixels, tickDensity)
-
- for dataPos in self._frange(tickMin, tickMax, step):
- if dataMin <= dataPos <= dataMax:
- xPixel = x0 + (dataPos - dataMin) * xScale
- yPixel = y0 + (dataPos - dataMin) * yScale
-
- if nbFrac == 0:
- text = '%g' % dataPos
- else:
- text = ('%.' + str(nbFrac) + 'f') % dataPos
- yield ((xPixel, yPixel), dataPos, text)
- else:
- # Time series
- dtMin = dt.datetime.fromtimestamp(dataMin, tz=self.timeZone)
- dtMax = dt.datetime.fromtimestamp(dataMax, tz=self.timeZone)
-
- tickDateTimes, spacing, unit = calcTicksAdaptive(
- dtMin, dtMax, nbPixels, tickDensity)
-
- for tickDateTime in tickDateTimes:
- if dtMin <= tickDateTime <= dtMax:
-
- dataPos = timestamp(tickDateTime)
- xPixel = x0 + (dataPos - dataMin) * xScale
- yPixel = y0 + (dataPos - dataMin) * yScale
-
- fmtStr = bestFormatString(spacing, unit)
- text = tickDateTime.strftime(fmtStr)
-
- yield ((xPixel, yPixel), dataPos, text)
-
-
-# GLPlotFrame #################################################################
-
-class GLPlotFrame(object):
- """Base class for rendering a 2D frame surrounded by axes."""
-
- _TICK_LENGTH_IN_PIXELS = 5
- _LINE_WIDTH = 1
-
- _SHADERS = {
- 'vertex': """
- attribute vec2 position;
- uniform mat4 matrix;
-
- void main(void) {
- gl_Position = matrix * vec4(position, 0.0, 1.0);
- }
- """,
- 'fragment': """
- uniform vec4 color;
- uniform float tickFactor; /* = 1./tickLength or 0. for solid line */
-
- void main(void) {
- if (mod(tickFactor * (gl_FragCoord.x + gl_FragCoord.y), 2.) < 1.) {
- gl_FragColor = color;
- } else {
- discard;
- }
- }
- """
- }
-
- _Margins = namedtuple('Margins', ('left', 'right', 'top', 'bottom'))
-
- # Margins used when plot frame is not displayed
- _NoDisplayMargins = _Margins(0, 0, 0, 0)
-
- def __init__(self, marginRatios, foregroundColor, gridColor):
- """
- :param List[float] marginRatios:
- The ratios of margins around plot area for axis and labels.
- (left, top, right, bottom) as float in [0., 1.]
- :param foregroundColor: color used for the frame and labels.
- :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
- :param gridColor: color used for grid lines.
- :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
- """
- self._renderResources = None
-
- self.__marginRatios = marginRatios
- self.__marginsCache = None
-
- self._foregroundColor = foregroundColor
- self._gridColor = gridColor
-
- self.axes = [] # List of PlotAxis to be updated by subclasses
-
- self._grid = False
- self._size = 0., 0.
- self._title = ''
-
- self._devicePixelRatio = 1.
-
- @property
- def isDirty(self):
- """True if it need to refresh graphic rendering, False otherwise."""
- return self._renderResources is None
-
- GRID_NONE = 0
- GRID_MAIN_TICKS = 1
- GRID_SUB_TICKS = 2
- GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS)
-
- @property
- def foregroundColor(self):
- """Color used for frame and labels"""
- return self._foregroundColor
-
- @foregroundColor.setter
- def foregroundColor(self, color):
- """Color used for frame and labels"""
- assert len(color) == 4, \
- "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
- if self._foregroundColor != color:
- self._foregroundColor = color
- for axis in self.axes:
- axis.foregroundColor = color
- self._dirty()
-
- @property
- def gridColor(self):
- """Color used for frame and labels"""
- return self._gridColor
-
- @gridColor.setter
- def gridColor(self, color):
- """Color used for frame and labels"""
- assert len(color) == 4, \
- "gridColor must have length 4, got {}".format(len(self._gridColor))
- if self._gridColor != color:
- self._gridColor = color
- self._dirty()
-
- @property
- def marginRatios(self):
- """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1].
- """
- return self.__marginRatios
-
- @marginRatios.setter
- def marginRatios(self, ratios):
- ratios = tuple(float(v) for v in ratios)
- assert len(ratios) == 4
- for value in ratios:
- assert 0. <= value <= 1.
- assert ratios[0] + ratios[2] < 1.
- assert ratios[1] + ratios[3] < 1.
-
- if self.__marginRatios != ratios:
- self.__marginRatios = ratios
- self.__marginsCache = None # Clear cached margins
- self._dirty()
-
- @property
- def margins(self):
- """Margins in pixels around the plot."""
- if self.__marginsCache is None:
- width, height = self.size
- left, top, right, bottom = self.marginRatios
- self.__marginsCache = self._Margins(
- left=int(left*width),
- right=int(right*width),
- top=int(top*height),
- bottom=int(bottom*height))
- return self.__marginsCache
-
- @property
- def devicePixelRatio(self):
- return self._devicePixelRatio
-
- @devicePixelRatio.setter
- def devicePixelRatio(self, ratio):
- if ratio != self._devicePixelRatio:
- self._devicePixelRatio = ratio
- self._dirty()
-
- @property
- def grid(self):
- """Grid display mode:
- - 0: No grid.
- - 1: Grid on main ticks.
- - 2: Grid on sub-ticks for log scale axes.
- - 3: Grid on main and sub ticks."""
- return self._grid
-
- @grid.setter
- def grid(self, grid):
- assert grid in (self.GRID_NONE, self.GRID_MAIN_TICKS,
- self.GRID_SUB_TICKS, self.GRID_ALL_TICKS)
- if grid != self._grid:
- self._grid = grid
- self._dirty()
-
- @property
- def size(self):
- """Size in device pixels of the plot area including margins."""
- return self._size
-
- @size.setter
- def size(self, size):
- assert len(size) == 2
- size = tuple(size)
- if size != self._size:
- self._size = size
- self.__marginsCache = None # Clear cached margins
- self._dirty()
-
- @property
- def plotOrigin(self):
- """Plot area origin (left, top) in widget coordinates in pixels."""
- return self.margins.left, self.margins.top
-
- @property
- def plotSize(self):
- """Plot area size (width, height) in pixels."""
- w, h = self.size
- w -= self.margins.left + self.margins.right
- h -= self.margins.top + self.margins.bottom
- return w, h
-
- @property
- def title(self):
- """Main title as a str in latin-1."""
- return self._title
-
- @title.setter
- def title(self, title):
- if title != self._title:
- self._title = title
- self._dirty()
-
- # In-place update
- # if self._renderResources is not None:
- # self._renderResources[-1][-1].text = title
-
- def _dirty(self):
- # When Text2D require discard we need to handle it
- self._renderResources = None
-
- def _buildGridVertices(self):
- if self._grid == self.GRID_NONE:
- return []
-
- elif self._grid == self.GRID_MAIN_TICKS:
- def test(text):
- return text is not None
- elif self._grid == self.GRID_SUB_TICKS:
- def test(text):
- return text is None
- elif self._grid == self.GRID_ALL_TICKS:
- def test(_):
- return True
- else:
- logging.warning('Wrong grid mode: %d' % self._grid)
- return []
-
- return self._buildGridVerticesWithTest(test)
-
- def _buildGridVerticesWithTest(self, test):
- """Override in subclass to generate grid vertices"""
- return []
-
- def _buildVerticesAndLabels(self):
- # To fill with copy of axes lists
- vertices = []
- labels = []
-
- for axis in self.axes:
- axisVertices, axisLabels = axis.getVerticesAndLabels()
- vertices += axisVertices
- labels += axisLabels
-
- vertices = numpy.array(vertices, dtype=numpy.float32)
-
- # Add main title
- xTitle = (self.size[0] + self.margins.left -
- self.margins.right) // 2
- yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS
- labels.append(Text2D(text=self.title,
- color=self._foregroundColor,
- x=xTitle,
- y=yTitle,
- align=CENTER,
- valign=BOTTOM,
- devicePixelRatio=self.devicePixelRatio))
-
- # grid
- gridVertices = numpy.array(self._buildGridVertices(),
- dtype=numpy.float32)
-
- self._renderResources = (vertices, gridVertices, labels)
-
- _program = Program(
- _SHADERS['vertex'], _SHADERS['fragment'], attrib0='position')
-
- def render(self):
- if self.margins == self._NoDisplayMargins:
- return
-
- if self._renderResources is None:
- self._buildVerticesAndLabels()
- vertices, gridVertices, labels = self._renderResources
-
- width, height = self.size
- matProj = mat4Ortho(0, width, height, 0, 1, -1)
-
- gl.glViewport(0, 0, width, height)
-
- prog = self._program
- prog.use()
-
- gl.glLineWidth(self._LINE_WIDTH)
-
- gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- matProj.astype(numpy.float32))
- gl.glUniform4f(prog.uniforms['color'], *self._foregroundColor)
- gl.glUniform1f(prog.uniforms['tickFactor'], 0.)
-
- gl.glEnableVertexAttribArray(prog.attributes['position'])
- gl.glVertexAttribPointer(prog.attributes['position'],
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, vertices)
-
- gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
-
- for label in labels:
- label.render(matProj)
-
- def renderGrid(self):
- if self._grid == self.GRID_NONE:
- return
-
- if self._renderResources is None:
- self._buildVerticesAndLabels()
- vertices, gridVertices, labels = self._renderResources
-
- width, height = self.size
- matProj = mat4Ortho(0, width, height, 0, 1, -1)
-
- gl.glViewport(0, 0, width, height)
-
- prog = self._program
- prog.use()
-
- gl.glLineWidth(self._LINE_WIDTH)
- gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
- matProj.astype(numpy.float32))
- gl.glUniform4f(prog.uniforms['color'], *self._gridColor)
- gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen
-
- gl.glEnableVertexAttribArray(prog.attributes['position'])
- gl.glVertexAttribPointer(prog.attributes['position'],
- 2,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0, gridVertices)
-
- gl.glDrawArrays(gl.GL_LINES, 0, len(gridVertices))
-
-
-# GLPlotFrame2D ###############################################################
-
-class GLPlotFrame2D(GLPlotFrame):
- def __init__(self, marginRatios, foregroundColor, gridColor):
- """
- :param List[float] marginRatios:
- The ratios of margins around plot area for axis and labels.
- (left, top, right, bottom) as float in [0., 1.]
- :param foregroundColor: color used for the frame and labels.
- :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
- :param gridColor: color used for grid lines.
- :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
-
- """
- super(GLPlotFrame2D, self).__init__(marginRatios, foregroundColor, gridColor)
- self.axes.append(PlotAxis(self,
- tickLength=(0., -5.),
- foregroundColor=self._foregroundColor,
- labelAlign=CENTER, labelVAlign=TOP,
- titleAlign=CENTER, titleVAlign=TOP,
- titleRotate=0))
-
- self._x2AxisCoords = ()
-
- self.axes.append(PlotAxis(self,
- tickLength=(5., 0.),
- foregroundColor=self._foregroundColor,
- labelAlign=RIGHT, labelVAlign=CENTER,
- titleAlign=CENTER, titleVAlign=BOTTOM,
- titleRotate=ROTATE_270))
-
- self._y2Axis = PlotAxis(self,
- tickLength=(-5., 0.),
- foregroundColor=self._foregroundColor,
- labelAlign=LEFT, labelVAlign=CENTER,
- titleAlign=CENTER, titleVAlign=TOP,
- titleRotate=ROTATE_270)
-
- self._isYAxisInverted = False
-
- self._dataRanges = {
- 'x': (1., 100.), 'y': (1., 100.), 'y2': (1., 100.)}
-
- self._baseVectors = (1., 0.), (0., 1.)
-
- self._transformedDataRanges = None
- self._transformedDataProjMat = None
- self._transformedDataY2ProjMat = None
-
- def _dirty(self):
- super(GLPlotFrame2D, self)._dirty()
- self._transformedDataRanges = None
- self._transformedDataProjMat = None
- self._transformedDataY2ProjMat = None
-
- @property
- def isDirty(self):
- """True if it need to refresh graphic rendering, False otherwise."""
- return (super(GLPlotFrame2D, self).isDirty or
- self._transformedDataRanges is None or
- self._transformedDataProjMat is None or
- self._transformedDataY2ProjMat is None)
-
- @property
- def xAxis(self):
- return self.axes[0]
-
- @property
- def yAxis(self):
- return self.axes[1]
-
- @property
- def y2Axis(self):
- return self._y2Axis
-
- @property
- def isY2Axis(self):
- """Whether to display the left Y axis or not."""
- return len(self.axes) == 3
-
- @isY2Axis.setter
- def isY2Axis(self, isY2Axis):
- if isY2Axis != self.isY2Axis:
- if isY2Axis:
- self.axes.append(self._y2Axis)
- else:
- self.axes = self.axes[:2]
-
- self._dirty()
-
- @property
- def isYAxisInverted(self):
- """Whether Y axes are inverted or not as a bool."""
- return self._isYAxisInverted
-
- @isYAxisInverted.setter
- def isYAxisInverted(self, value):
- value = bool(value)
- if value != self._isYAxisInverted:
- self._isYAxisInverted = value
- self._dirty()
-
- DEFAULT_BASE_VECTORS = (1., 0.), (0., 1.)
- """Values of baseVectors for orthogonal axes."""
-
- @property
- def baseVectors(self):
- """Coordinates of the X and Y axes in the orthogonal plot coords.
-
- Raises ValueError if corresponding matrix is singular.
-
- 2 tuples of 2 floats: (xx, xy), (yx, yy)
- """
- return self._baseVectors
-
- @baseVectors.setter
- def baseVectors(self, baseVectors):
- self._dirty()
-
- (xx, xy), (yx, yy) = baseVectors
- vectors = (float(xx), float(xy)), (float(yx), float(yy))
-
- det = (vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1])
- if det == 0.:
- raise ValueError("Singular matrix for base vectors: " +
- str(vectors))
-
- if vectors != self._baseVectors:
- self._baseVectors = vectors
- self._dirty()
-
- def _updateTitleOffset(self):
- """Update axes title offset according to margins"""
- margins = self.margins
- self.xAxis.titleOffset = 0, margins.bottom // 2
- self.yAxis.titleOffset = -3 * margins.left // 4, 0
- self.y2Axis.titleOffset = 3 * margins.right // 4, 0
-
- # Override size and marginRatios setters to update titleOffsets
- @GLPlotFrame.size.setter
- def size(self, size):
- GLPlotFrame.size.fset(self, size)
- self._updateTitleOffset()
-
- @GLPlotFrame.marginRatios.setter
- def marginRatios(self, ratios):
- GLPlotFrame.marginRatios.fset(self, ratios)
- self._updateTitleOffset()
-
- @property
- def dataRanges(self):
- """Ranges of data visible in the plot on x, y and y2 axes.
-
- This is different to the axes range when axes are not orthogonal.
-
- Type: ((xMin, xMax), (yMin, yMax), (y2Min, y2Max))
- """
- return self._DataRanges(self._dataRanges['x'],
- self._dataRanges['y'],
- self._dataRanges['y2'])
-
- @staticmethod
- def _clipToSafeRange(min_, max_, isLog):
- # Clip range if needed
- minLimit = FLOAT32_MINPOS if isLog else FLOAT32_SAFE_MIN
- min_ = numpy.clip(min_, minLimit, FLOAT32_SAFE_MAX)
- max_ = numpy.clip(max_, minLimit, FLOAT32_SAFE_MAX)
- assert min_ < max_
- return min_, max_
-
- def setDataRanges(self, x=None, y=None, y2=None):
- """Set data range over each axes.
-
- The provided ranges are clipped to possible values
- (i.e., 32 float range + positive range for log scale).
-
- :param x: (min, max) data range over X axis
- :param y: (min, max) data range over Y axis
- :param y2: (min, max) data range over Y2 axis
- """
- if x is not None:
- self._dataRanges['x'] = \
- self._clipToSafeRange(x[0], x[1], self.xAxis.isLog)
-
- if y is not None:
- self._dataRanges['y'] = \
- self._clipToSafeRange(y[0], y[1], self.yAxis.isLog)
-
- if y2 is not None:
- self._dataRanges['y2'] = \
- self._clipToSafeRange(y2[0], y2[1], self.y2Axis.isLog)
-
- self.xAxis.dataRange = self._dataRanges['x']
- self.yAxis.dataRange = self._dataRanges['y']
- self.y2Axis.dataRange = self._dataRanges['y2']
-
- _DataRanges = namedtuple('dataRanges', ('x', 'y', 'y2'))
-
- @property
- def transformedDataRanges(self):
- """Bounds of the displayed area in transformed data coordinates
- (i.e., log scale applied if any as well as skew)
-
- 3-tuple of 2-tuple (min, max) for each axis: x, y, y2.
- """
- if self._transformedDataRanges is None:
- (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self.dataRanges
-
- if self.xAxis.isLog:
- try:
- xMin = math.log10(xMin)
- except ValueError:
- _logger.info('xMin: warning log10(%f)', xMin)
- xMin = 0.
- try:
- xMax = math.log10(xMax)
- except ValueError:
- _logger.info('xMax: warning log10(%f)', xMax)
- xMax = 0.
-
- if self.yAxis.isLog:
- try:
- yMin = math.log10(yMin)
- except ValueError:
- _logger.info('yMin: warning log10(%f)', yMin)
- yMin = 0.
- try:
- yMax = math.log10(yMax)
- except ValueError:
- _logger.info('yMax: warning log10(%f)', yMax)
- yMax = 0.
-
- try:
- y2Min = math.log10(y2Min)
- except ValueError:
- _logger.info('yMin: warning log10(%f)', y2Min)
- y2Min = 0.
- try:
- y2Max = math.log10(y2Max)
- except ValueError:
- _logger.info('yMax: warning log10(%f)', y2Max)
- y2Max = 0.
-
- self._transformedDataRanges = self._DataRanges(
- (xMin, xMax), (yMin, yMax), (y2Min, y2Max))
-
- return self._transformedDataRanges
-
- @property
- def transformedDataProjMat(self):
- """Orthographic projection matrix for rendering transformed data
-
- :type: numpy.matrix
- """
- if self._transformedDataProjMat is None:
- xMin, xMax = self.transformedDataRanges.x
- yMin, yMax = self.transformedDataRanges.y
-
- if self.isYAxisInverted:
- mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1)
- else:
- mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1)
- self._transformedDataProjMat = mat
-
- return self._transformedDataProjMat
-
- @property
- def transformedDataY2ProjMat(self):
- """Orthographic projection matrix for rendering transformed data
- for the 2nd Y axis
-
- :type: numpy.matrix
- """
- if self._transformedDataY2ProjMat is None:
- xMin, xMax = self.transformedDataRanges.x
- y2Min, y2Max = self.transformedDataRanges.y2
-
- if self.isYAxisInverted:
- mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1)
- else:
- mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1)
- self._transformedDataY2ProjMat = mat
-
- return self._transformedDataY2ProjMat
-
- def dataToPixel(self, x, y, axis='left'):
- """Convert data coordinate to widget pixel coordinate.
- """
- assert axis in ('left', 'right')
-
- trBounds = self.transformedDataRanges
-
- if self.xAxis.isLog:
- if x < FLOAT32_MINPOS:
- return None
- xDataTr = math.log10(x)
- else:
- xDataTr = x
-
- if self.yAxis.isLog:
- if y < FLOAT32_MINPOS:
- return None
- yDataTr = math.log10(y)
- else:
- yDataTr = y
-
- # Non-orthogonal axes
- if self.baseVectors != self.DEFAULT_BASE_VECTORS:
- (xx, xy), (yx, yy) = self.baseVectors
- skew_mat = numpy.array(((xx, yx), (xy, yy)))
-
- coords = numpy.dot(skew_mat, numpy.array((xDataTr, yDataTr)))
- xDataTr, yDataTr = coords
-
- plotWidth, plotHeight = self.plotSize
-
- xPixel = int(self.margins.left +
- plotWidth * (xDataTr - trBounds.x[0]) /
- (trBounds.x[1] - trBounds.x[0]))
-
- usedAxis = trBounds.y if axis == "left" else trBounds.y2
- yOffset = (plotHeight * (yDataTr - usedAxis[0]) /
- (usedAxis[1] - usedAxis[0]))
-
- if self.isYAxisInverted:
- yPixel = int(self.margins.top + yOffset)
- else:
- yPixel = int(self.size[1] - self.margins.bottom - yOffset)
-
- return xPixel, yPixel
-
- def pixelToData(self, x, y, axis="left"):
- """Convert pixel position to data coordinates.
-
- :param float x: X coord
- :param float y: Y coord
- :param str axis: Y axis to use in ('left', 'right')
- :return: (x, y) position in data coords
- """
- assert axis in ("left", "right")
-
- plotWidth, plotHeight = self.plotSize
-
- trBounds = self.transformedDataRanges
-
- xData = (x - self.margins.left + 0.5) / float(plotWidth)
- xData = trBounds.x[0] + xData * (trBounds.x[1] - trBounds.x[0])
-
- usedAxis = trBounds.y if axis == "left" else trBounds.y2
- if self.isYAxisInverted:
- yData = (y - self.margins.top + 0.5) / float(plotHeight)
- yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
- else:
- yData = self.size[1] - self.margins.bottom - y - 0.5
- yData /= float(plotHeight)
- yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
-
- # non-orthogonal axis
- if self.baseVectors != self.DEFAULT_BASE_VECTORS:
- (xx, xy), (yx, yy) = self.baseVectors
- skew_mat = numpy.array(((xx, yx), (xy, yy)))
- skew_mat = numpy.linalg.inv(skew_mat)
-
- coords = numpy.dot(skew_mat, numpy.array((xData, yData)))
- xData, yData = coords
-
- if self.xAxis.isLog:
- xData = pow(10, xData)
- if self.yAxis.isLog:
- yData = pow(10, yData)
-
- return xData, yData
-
- def _buildGridVerticesWithTest(self, test):
- vertices = []
-
- if self.baseVectors == self.DEFAULT_BASE_VECTORS:
- for axis in self.axes:
- for (xPixel, yPixel), data, text in axis.ticks:
- if test(text):
- vertices.append((xPixel, yPixel))
- if axis == self.xAxis:
- vertices.append((xPixel, self.margins.top))
- elif axis == self.yAxis:
- vertices.append((self.size[0] - self.margins.right,
- yPixel))
- else: # axis == self.y2Axis
- vertices.append((self.margins.left, yPixel))
-
- else:
- # Get plot corners in data coords
- plotLeft, plotTop = self.plotOrigin
- plotWidth, plotHeight = self.plotSize
-
- corners = [(plotLeft, plotTop),
- (plotLeft, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop + plotHeight),
- (plotLeft + plotWidth, plotTop)]
-
- for axis in self.axes:
- if axis == self.xAxis:
- cornersInData = numpy.array([
- self.pixelToData(x, y) for (x, y) in corners])
- borders = ((cornersInData[0], cornersInData[3]), # top
- (cornersInData[1], cornersInData[0]), # left
- (cornersInData[3], cornersInData[2])) # right
-
- for (xPixel, yPixel), data, text in axis.ticks:
- if test(text):
- for (x0, y0), (x1, y1) in borders:
- if min(x0, x1) <= data < max(x0, x1):
- yIntersect = (data - x0) * \
- (y1 - y0) / (x1 - x0) + y0
-
- pixelPos = self.dataToPixel(
- data, yIntersect)
- if pixelPos is not None:
- vertices.append((xPixel, yPixel))
- vertices.append(pixelPos)
- break # Stop at first intersection
-
- else: # y or y2 axes
- if axis == self.yAxis:
- axis_name = 'left'
- cornersInData = numpy.array([
- self.pixelToData(x, y) for (x, y) in corners])
- borders = (
- (cornersInData[3], cornersInData[2]), # right
- (cornersInData[0], cornersInData[3]), # top
- (cornersInData[2], cornersInData[1])) # bottom
-
- else: # axis == self.y2Axis
- axis_name = 'right'
- corners = numpy.array([self.pixelToData(
- x, y, axis='right') for (x, y) in corners])
- borders = (
- (cornersInData[1], cornersInData[0]), # left
- (cornersInData[0], cornersInData[3]), # top
- (cornersInData[2], cornersInData[1])) # bottom
-
- for (xPixel, yPixel), data, text in axis.ticks:
- if test(text):
- for (x0, y0), (x1, y1) in borders:
- if min(y0, y1) <= data < max(y0, y1):
- xIntersect = (data - y0) * \
- (x1 - x0) / (y1 - y0) + x0
-
- pixelPos = self.dataToPixel(
- xIntersect, data, axis=axis_name)
- if pixelPos is not None:
- vertices.append((xPixel, yPixel))
- vertices.append(pixelPos)
- break # Stop at first intersection
-
- return vertices
-
- def _buildVerticesAndLabels(self):
- width, height = self.size
-
- xCoords = (self.margins.left - 0.5,
- width - self.margins.right + 0.5)
- yCoords = (height - self.margins.bottom + 0.5,
- self.margins.top - 0.5)
-
- self.axes[0].displayCoords = ((xCoords[0], yCoords[0]),
- (xCoords[1], yCoords[0]))
-
- self._x2AxisCoords = ((xCoords[0], yCoords[1]),
- (xCoords[1], yCoords[1]))
-
- if self.isYAxisInverted:
- # Y axes are inverted, axes coordinates are inverted
- yCoords = yCoords[1], yCoords[0]
-
- self.axes[1].displayCoords = ((xCoords[0], yCoords[0]),
- (xCoords[0], yCoords[1]))
-
- self._y2Axis.displayCoords = ((xCoords[1], yCoords[0]),
- (xCoords[1], yCoords[1]))
-
- super(GLPlotFrame2D, self)._buildVerticesAndLabels()
-
- vertices, gridVertices, labels = self._renderResources
-
- # Adds vertices for borders without axis
- extraVertices = []
- extraVertices += self._x2AxisCoords
- if not self.isY2Axis:
- extraVertices += self._y2Axis.displayCoords
-
- extraVertices = numpy.array(
- extraVertices, copy=False, dtype=numpy.float32)
- vertices = numpy.append(vertices, extraVertices, axis=0)
-
- self._renderResources = (vertices, gridVertices, labels)
-
- @property
- def foregroundColor(self):
- """Color used for frame and labels"""
- return self._foregroundColor
-
- @foregroundColor.setter
- def foregroundColor(self, color):
- """Color used for frame and labels"""
- assert len(color) == 4, \
- "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
- if self._foregroundColor != color:
- self._y2Axis.foregroundColor = color
- GLPlotFrame.foregroundColor.fset(self, color) # call parent property
diff --git a/silx/gui/plot/items/__init__.py b/silx/gui/plot/items/__init__.py
deleted file mode 100644
index 0484025..0000000
--- a/silx/gui/plot/items/__init__.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This package provides classes that describes :class:`.PlotWidget` content.
-
-Instances of those classes are returned by :class:`.PlotWidget` methods that give
-access to its content such as :meth:`.PlotWidget.getCurve`, :meth:`.PlotWidget.getImage`.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "22/06/2017"
-
-from .core import (Item, DataItem, # noqa
- LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa
- SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa
- AlphaMixIn, LineMixIn, ScatterVisualizationMixIn, # noqa
- ComplexMixIn, ItemChangedType, PointsBase) # noqa
-from .complex import ImageComplexData # noqa
-from .curve import Curve, CurveStyle # noqa
-from .histogram import Histogram # noqa
-from .image import ImageBase, ImageData, ImageRgba, ImageStack, MaskImageData # noqa
-from .shape import Shape, BoundingRect, XAxisExtent, YAxisExtent # noqa
-from .scatter import Scatter # noqa
-from .marker import MarkerBase, Marker, XMarker, YMarker # noqa
-from .axis import Axis, XAxis, YAxis, YRightAxis
-
-DATA_ITEMS = (ImageComplexData, Curve, Histogram, ImageBase, Scatter,
- BoundingRect, XAxisExtent, YAxisExtent)
-"""Classes of items representing data and to consider to compute data bounds.
-"""
diff --git a/silx/gui/plot/items/axis.py b/silx/gui/plot/items/axis.py
deleted file mode 100644
index be85e6a..0000000
--- a/silx/gui/plot/items/axis.py
+++ /dev/null
@@ -1,569 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides the class for axes of the :class:`PlotWidget`.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "22/11/2018"
-
-import datetime as dt
-import enum
-import logging
-
-import dateutil.tz
-
-from ... import qt
-
-
-_logger = logging.getLogger(__name__)
-
-
-class TickMode(enum.Enum):
- """Determines if ticks are regular number or datetimes."""
- DEFAULT = 0 # Ticks are regular numbers
- TIME_SERIES = 1 # Ticks are datetime objects
-
-
-class Axis(qt.QObject):
- """This class describes and controls a plot axis.
-
- Note: This is an abstract class.
- """
- # States are half-stored on the backend of the plot, and half-stored on this
- # object.
- # TODO It would be good to store all the states of an axis in this object.
- # i.e. vmin and vmax
-
- LINEAR = "linear"
- """Constant defining a linear scale"""
-
- LOGARITHMIC = "log"
- """Constant defining a logarithmic scale"""
-
- _SCALES = set([LINEAR, LOGARITHMIC])
-
- sigInvertedChanged = qt.Signal(bool)
- """Signal emitted when axis orientation has changed"""
-
- sigScaleChanged = qt.Signal(str)
- """Signal emitted when axis scale has changed"""
-
- _sigLogarithmicChanged = qt.Signal(bool)
- """Signal emitted when axis scale has changed to or from logarithmic"""
-
- sigAutoScaleChanged = qt.Signal(bool)
- """Signal emitted when axis autoscale has changed"""
-
- sigLimitsChanged = qt.Signal(float, float)
- """Signal emitted when axis limits have changed"""
-
- def __init__(self, plot):
- """Constructor
-
- :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this
- axis
- """
- qt.QObject.__init__(self, parent=plot)
- self._scale = self.LINEAR
- self._isAutoScale = True
- # Store default labels provided to setGraph[X|Y]Label
- self._defaultLabel = ''
- # Store currently displayed labels
- # Current label can differ from input one with active curve handling
- self._currentLabel = ''
-
- def _getPlot(self):
- """Returns the PlotWidget this Axis belongs to.
-
- :rtype: PlotWidget
- """
- plot = self.parent()
- if plot is None:
- raise RuntimeError("Axis no longer attached to a PlotWidget")
- return plot
-
- def _getBackend(self):
- """Returns the backend
-
- :rtype: BackendBase
- """
- return self._getPlot()._backend
-
- def getLimits(self):
- """Get the limits of this axis.
-
- :return: Minimum and maximum values of this axis as tuple
- """
- return self._internalGetLimits()
-
- def setLimits(self, vmin, vmax):
- """Set this axis limits.
-
- :param float vmin: minimum axis value
- :param float vmax: maximum axis value
- """
- vmin, vmax = self._checkLimits(vmin, vmax)
- if self.getLimits() == (vmin, vmax):
- return
-
- self._internalSetLimits(vmin, vmax)
- self._getPlot()._setDirtyPlot()
-
- self._emitLimitsChanged()
-
- def _emitLimitsChanged(self):
- """Emit axis sigLimitsChanged and PlotWidget limitsChanged event"""
- vmin, vmax = self.getLimits()
- self.sigLimitsChanged.emit(vmin, vmax)
- self._getPlot()._notifyLimitsChanged(emitSignal=False)
-
- def _checkLimits(self, vmin, vmax):
- """Makes sure axis range is not empty
-
- :param float vmin: Min axis value
- :param float vmax: Max axis value
- :return: (min, max) making sure min < max
- :rtype: 2-tuple of float
- """
- if vmax < vmin:
- _logger.debug('%s axis: max < min, inverting limits.', self._defaultLabel)
- vmin, vmax = vmax, vmin
- elif vmax == vmin:
- _logger.debug('%s axis: max == min, expanding limits.', self._defaultLabel)
- if vmin == 0.:
- vmin, vmax = -0.1, 0.1
- elif vmin < 0:
- vmin, vmax = vmin * 1.1, vmin * 0.9
- else: # xmin > 0
- vmin, vmax = vmin * 0.9, vmin * 1.1
-
- return vmin, vmax
-
- def isInverted(self):
- """Return True if the axis is inverted (top to bottom for the y-axis),
- False otherwise. It is always False for the X axis.
-
- :rtype: bool
- """
- return False
-
- def setInverted(self, isInverted):
- """Set the axis orientation.
-
- This is only available for the Y axis.
-
- :param bool flag: True for Y axis going from top to bottom,
- False for Y axis going from bottom to top
- """
- if isInverted == self.isInverted():
- return
- raise NotImplementedError()
-
- def getLabel(self):
- """Return the current displayed label of this axis.
-
- :param str axis: The Y axis for which to get the label (left or right)
- :rtype: str
- """
- return self._currentLabel
-
- def setLabel(self, label):
- """Set the label displayed on the plot for this axis.
-
- The provided label can be temporarily replaced by the label of the
- active curve if any.
-
- :param str label: The axis label
- """
- self._defaultLabel = label
- self._setCurrentLabel(label)
- self._getPlot()._setDirtyPlot()
-
- def _setCurrentLabel(self, label):
- """Define the label currently displayed.
-
- If the label is None or empty the default label is used.
-
- :param str label: Currently displayed label
- """
- if label is None or label == '':
- label = self._defaultLabel
- if label is None:
- label = ''
- self._currentLabel = label
- self._internalSetCurrentLabel(label)
-
- def getScale(self):
- """Return the name of the scale used by this axis.
-
- :rtype: str
- """
- return self._scale
-
- def setScale(self, scale):
- """Set the scale to be used by this axis.
-
- :param str scale: Name of the scale ("log", or "linear")
- """
- assert(scale in self._SCALES)
- if self._scale == scale:
- return
-
- # For the backward compatibility signal
- emitLog = self._scale == self.LOGARITHMIC or scale == self.LOGARITHMIC
-
- self._scale = scale
-
- # TODO hackish way of forcing update of curves and images
- plot = self._getPlot()
- for item in plot.getItems():
- item._updated()
- plot._invalidateDataRange()
-
- if scale == self.LOGARITHMIC:
- self._internalSetLogarithmic(True)
- elif scale == self.LINEAR:
- self._internalSetLogarithmic(False)
- else:
- raise ValueError("Scale %s unsupported" % scale)
-
- plot._forceResetZoom()
-
- self.sigScaleChanged.emit(self._scale)
- if emitLog:
- self._sigLogarithmicChanged.emit(self._scale == self.LOGARITHMIC)
-
- def _isLogarithmic(self):
- """Return True if this axis scale is logarithmic, False if linear.
-
- :rtype: bool
- """
- return self._scale == self.LOGARITHMIC
-
- def _setLogarithmic(self, flag):
- """Set the scale of this axes (either linear or logarithmic).
-
- :param bool flag: True to use a logarithmic scale, False for linear.
- """
- flag = bool(flag)
- self.setScale(self.LOGARITHMIC if flag else self.LINEAR)
-
- def getTimeZone(self):
- """Sets tzinfo that is used if this axis plots date times.
-
- None means the datetimes are interpreted as local time.
-
- :rtype: datetime.tzinfo of None.
- """
- raise NotImplementedError()
-
- def setTimeZone(self, tz):
- """Sets tzinfo that is used if this axis' tickMode is TIME_SERIES
-
- The tz must be a descendant of the datetime.tzinfo class, "UTC" or None.
- Use None to let the datetimes be interpreted as local time.
- Use the string "UTC" to let the date datetimes be in UTC time.
-
- :param tz: datetime.tzinfo, "UTC" or None.
- """
- raise NotImplementedError()
-
- def getTickMode(self):
- """Determines if axis ticks are number or datetimes.
-
- :rtype: TickMode enum.
- """
- raise NotImplementedError()
-
- def setTickMode(self, tickMode):
- """Determines if axis ticks are number or datetimes.
-
- :param TickMode tickMode: tick mode enum.
- """
- raise NotImplementedError()
-
- def isAutoScale(self):
- """Return True if axis is automatically adjusting its limits.
-
- :rtype: bool
- """
- return self._isAutoScale
-
- def setAutoScale(self, flag=True):
- """Set the axis limits adjusting behavior of :meth:`resetZoom`.
-
- :param bool flag: True to resize limits automatically,
- False to disable it.
- """
- self._isAutoScale = bool(flag)
- self.sigAutoScaleChanged.emit(self._isAutoScale)
-
- def _setLimitsConstraints(self, minPos=None, maxPos=None):
- raise NotImplementedError()
-
- def setLimitsConstraints(self, minPos=None, maxPos=None):
- """
- Set a constraint on the position of the axes.
-
- :param float minPos: Minimum allowed axis value.
- :param float maxPos: Maximum allowed axis value.
- :return: True if the constaints was updated
- :rtype: bool
- """
- updated = self._setLimitsConstraints(minPos, maxPos)
- if updated:
- plot = self._getPlot()
- xMin, xMax = plot.getXAxis().getLimits()
- yMin, yMax = plot.getYAxis().getLimits()
- y2Min, y2Max = plot.getYAxis('right').getLimits()
- plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
- return updated
-
- def _setRangeConstraints(self, minRange=None, maxRange=None):
- raise NotImplementedError()
-
- def setRangeConstraints(self, minRange=None, maxRange=None):
- """
- Set a constraint on the position of the axes.
-
- :param float minRange: Minimum allowed left-to-right span across the
- view
- :param float maxRange: Maximum allowed left-to-right span across the
- view
- :return: True if the constaints was updated
- :rtype: bool
- """
- updated = self._setRangeConstraints(minRange, maxRange)
- if updated:
- plot = self._getPlot()
- xMin, xMax = plot.getXAxis().getLimits()
- yMin, yMax = plot.getYAxis().getLimits()
- y2Min, y2Max = plot.getYAxis('right').getLimits()
- plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
- return updated
-
-
-class XAxis(Axis):
- """Axis class defining primitives for the X axis"""
-
- # TODO With some changes on the backend, it will be able to remove all this
- # specialised implementations (prefixel by '_internal')
-
- def getTimeZone(self):
- return self._getBackend().getXAxisTimeZone()
-
- def setTimeZone(self, tz):
- if isinstance(tz, str) and tz.upper() == "UTC":
- tz = dateutil.tz.tzutc()
- elif not(tz is None or isinstance(tz, dt.tzinfo)):
- raise TypeError("tz must be a dt.tzinfo object, None or 'UTC'.")
-
- self._getBackend().setXAxisTimeZone(tz)
- self._getPlot()._setDirtyPlot()
-
- def getTickMode(self):
- if self._getBackend().isXAxisTimeSeries():
- return TickMode.TIME_SERIES
- else:
- return TickMode.DEFAULT
-
- def setTickMode(self, tickMode):
- if tickMode == TickMode.DEFAULT:
- self._getBackend().setXAxisTimeSeries(False)
- elif tickMode == TickMode.TIME_SERIES:
- self._getBackend().setXAxisTimeSeries(True)
- else:
- raise ValueError("Unexpected TickMode: {}".format(tickMode))
-
- def _internalSetCurrentLabel(self, label):
- self._getBackend().setGraphXLabel(label)
-
- def _internalGetLimits(self):
- return self._getBackend().getGraphXLimits()
-
- def _internalSetLimits(self, xmin, xmax):
- self._getBackend().setGraphXLimits(xmin, xmax)
-
- def _internalSetLogarithmic(self, flag):
- self._getBackend().setXAxisLogarithmic(flag)
-
- def _setLimitsConstraints(self, minPos=None, maxPos=None):
- constrains = self._getPlot()._getViewConstraints()
- updated = constrains.update(xMin=minPos, xMax=maxPos)
- return updated
-
- def _setRangeConstraints(self, minRange=None, maxRange=None):
- constrains = self._getPlot()._getViewConstraints()
- updated = constrains.update(minXRange=minRange, maxXRange=maxRange)
- return updated
-
-
-class YAxis(Axis):
- """Axis class defining primitives for the Y axis"""
-
- # TODO With some changes on the backend, it will be able to remove all this
- # specialised implementations (prefixel by '_internal')
-
- def _internalSetCurrentLabel(self, label):
- self._getBackend().setGraphYLabel(label, axis='left')
-
- def _internalGetLimits(self):
- return self._getBackend().getGraphYLimits(axis='left')
-
- def _internalSetLimits(self, ymin, ymax):
- self._getBackend().setGraphYLimits(ymin, ymax, axis='left')
-
- def _internalSetLogarithmic(self, flag):
- self._getBackend().setYAxisLogarithmic(flag)
-
- def setInverted(self, flag=True):
- """Set the axis orientation.
-
- This is only available for the Y axis.
-
- :param bool flag: True for Y axis going from top to bottom,
- False for Y axis going from bottom to top
- """
- flag = bool(flag)
- if self.isInverted() == flag:
- return
- self._getBackend().setYAxisInverted(flag)
- self._getPlot()._setDirtyPlot()
- self.sigInvertedChanged.emit(flag)
-
- def isInverted(self):
- """Return True if the axis is inverted (top to bottom for the y-axis),
- False otherwise. It is always False for the X axis.
-
- :rtype: bool
- """
- return self._getBackend().isYAxisInverted()
-
- def _setLimitsConstraints(self, minPos=None, maxPos=None):
- constrains = self._getPlot()._getViewConstraints()
- updated = constrains.update(yMin=minPos, yMax=maxPos)
- return updated
-
- def _setRangeConstraints(self, minRange=None, maxRange=None):
- constrains = self._getPlot()._getViewConstraints()
- updated = constrains.update(minYRange=minRange, maxYRange=maxRange)
- return updated
-
-
-class YRightAxis(Axis):
- """Proxy axis for the secondary Y axes. It manages it own label and limit
- but share the some state like scale and direction with the main axis."""
-
- # TODO With some changes on the backend, it will be able to remove all this
- # specialised implementations (prefixel by '_internal')
-
- def __init__(self, plot, mainAxis):
- """Constructor
-
- :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this
- axis
- :param Axis mainAxis: Axis which sharing state with this axis
- """
- Axis.__init__(self, plot)
- self.__mainAxis = mainAxis
-
- @property
- def sigInvertedChanged(self):
- """Signal emitted when axis orientation has changed"""
- return self.__mainAxis.sigInvertedChanged
-
- @property
- def sigScaleChanged(self):
- """Signal emitted when axis scale has changed"""
- return self.__mainAxis.sigScaleChanged
-
- @property
- def _sigLogarithmicChanged(self):
- """Signal emitted when axis scale has changed to or from logarithmic"""
- return self.__mainAxis._sigLogarithmicChanged
-
- @property
- def sigAutoScaleChanged(self):
- """Signal emitted when axis autoscale has changed"""
- return self.__mainAxis.sigAutoScaleChanged
-
- def _internalSetCurrentLabel(self, label):
- self._getBackend().setGraphYLabel(label, axis='right')
-
- def _internalGetLimits(self):
- return self._getBackend().getGraphYLimits(axis='right')
-
- def _internalSetLimits(self, ymin, ymax):
- self._getBackend().setGraphYLimits(ymin, ymax, axis='right')
-
- def setInverted(self, flag=True):
- """Set the Y axis orientation.
-
- :param bool flag: True for Y axis going from top to bottom,
- False for Y axis going from bottom to top
- """
- return self.__mainAxis.setInverted(flag)
-
- def isInverted(self):
- """Return True if Y axis goes from top to bottom, False otherwise."""
- return self.__mainAxis.isInverted()
-
- def getScale(self):
- """Return the name of the scale used by this axis.
-
- :rtype: str
- """
- return self.__mainAxis.getScale()
-
- def setScale(self, scale):
- """Set the scale to be used by this axis.
-
- :param str scale: Name of the scale ("log", or "linear")
- """
- self.__mainAxis.setScale(scale)
-
- def _isLogarithmic(self):
- """Return True if Y axis scale is logarithmic, False if linear."""
- return self.__mainAxis._isLogarithmic()
-
- def _setLogarithmic(self, flag):
- """Set the Y axes scale (either linear or logarithmic).
-
- :param bool flag: True to use a logarithmic scale, False for linear.
- """
- return self.__mainAxis._setLogarithmic(flag)
-
- def isAutoScale(self):
- """Return True if Y axes are automatically adjusting its limits."""
- return self.__mainAxis.isAutoScale()
-
- def setAutoScale(self, flag=True):
- """Set the Y axis limits adjusting behavior of :meth:`PlotWidget.resetZoom`.
-
- :param bool flag: True to resize limits automatically,
- False to disable it.
- """
- return self.__mainAxis.setAutoScale(flag)
diff --git a/silx/gui/plot/items/core.py b/silx/gui/plot/items/core.py
deleted file mode 100644
index 95a65ad..0000000
--- a/silx/gui/plot/items/core.py
+++ /dev/null
@@ -1,1734 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides the base class for items of the :class:`Plot`.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "08/12/2020"
-
-import collections
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
-from copy import deepcopy
-import logging
-import enum
-from typing import Optional, Tuple
-import warnings
-import weakref
-
-import numpy
-import six
-
-from ....utils.deprecation import deprecated
-from ....utils.proxy import docstring
-from ....utils.enum import Enum as _Enum
-from ....math.combo import min_max
-from ... import qt
-from ... import colors
-from ...colors import Colormap
-from ._pick import PickingResult
-
-from silx import config
-
-_logger = logging.getLogger(__name__)
-
-
-@enum.unique
-class ItemChangedType(enum.Enum):
- """Type of modification provided by :attr:`Item.sigItemChanged` signal."""
- # Private setters and setInfo are not emitting sigItemChanged signal.
- # Signals to consider:
- # COLORMAP_SET emitted when setColormap is called but not forward colormap object signal
- # CURRENT_COLOR_CHANGED emitted current color changed because highlight changed,
- # highlighted color changed or color changed depending on hightlight state.
-
- VISIBLE = 'visibleChanged'
- """Item's visibility changed flag."""
-
- ZVALUE = 'zValueChanged'
- """Item's Z value changed flag."""
-
- COLORMAP = 'colormapChanged' # Emitted when set + forward events from the colormap object
- """Item's colormap changed flag.
-
- This is emitted both when setting a new colormap and
- when the current colormap object is updated.
- """
-
- SYMBOL = 'symbolChanged'
- """Item's symbol changed flag."""
-
- SYMBOL_SIZE = 'symbolSizeChanged'
- """Item's symbol size changed flag."""
-
- LINE_WIDTH = 'lineWidthChanged'
- """Item's line width changed flag."""
-
- LINE_STYLE = 'lineStyleChanged'
- """Item's line style changed flag."""
-
- COLOR = 'colorChanged'
- """Item's color changed flag."""
-
- LINE_BG_COLOR = 'lineBgColorChanged'
- """Item's line background color changed flag."""
-
- YAXIS = 'yAxisChanged'
- """Item's Y axis binding changed flag."""
-
- FILL = 'fillChanged'
- """Item's fill changed flag."""
-
- ALPHA = 'alphaChanged'
- """Item's transparency alpha changed flag."""
-
- DATA = 'dataChanged'
- """Item's data changed flag"""
-
- MASK = 'maskChanged'
- """Item's mask changed flag"""
-
- HIGHLIGHTED = 'highlightedChanged'
- """Item's highlight state changed flag."""
-
- HIGHLIGHTED_COLOR = 'highlightedColorChanged'
- """Deprecated, use HIGHLIGHTED_STYLE instead."""
-
- HIGHLIGHTED_STYLE = 'highlightedStyleChanged'
- """Item's highlighted style changed flag."""
-
- SCALE = 'scaleChanged'
- """Item's scale changed flag."""
-
- TEXT = 'textChanged'
- """Item's text changed flag."""
-
- POSITION = 'positionChanged'
- """Item's position changed flag.
-
- This is emitted when a marker position changed and
- when an image origin changed.
- """
-
- OVERLAY = 'overlayChanged'
- """Item's overlay state changed flag."""
-
- VISUALIZATION_MODE = 'visualizationModeChanged'
- """Item's visualization mode changed flag."""
-
- COMPLEX_MODE = 'complexModeChanged'
- """Item's complex data visualization mode changed flag."""
-
- NAME = 'nameChanged'
- """Item's name changed flag."""
-
- EDITABLE = 'editableChanged'
- """Item's editable state changed flags."""
-
- SELECTABLE = 'selectableChanged'
- """Item's selectable state changed flags."""
-
-
-class Item(qt.QObject):
- """Description of an item of the plot"""
-
- _DEFAULT_Z_LAYER = 0
- """Default layer for overlay rendering"""
-
- _DEFAULT_SELECTABLE = False
- """Default selectable state of items"""
-
- sigItemChanged = qt.Signal(object)
- """Signal emitted when the item has changed.
-
- It provides a flag describing which property of the item has changed.
- See :class:`ItemChangedType` for flags description.
- """
-
- _sigVisibleBoundsChanged = qt.Signal()
- """Signal emitted when the visible extent of the item in the plot has changed.
-
- This signal is emitted only if visible extent tracking is enabled
- (see :meth:`_setVisibleBoundsTracking`).
- """
-
- def __init__(self):
- qt.QObject.__init__(self)
- self._dirty = True
- self._plotRef = None
- self._visible = True
- self._selectable = self._DEFAULT_SELECTABLE
- self._z = self._DEFAULT_Z_LAYER
- self._info = None
- self._xlabel = None
- self._ylabel = None
- self.__name = ''
-
- self.__visibleBoundsTracking = False
- self.__previousVisibleBounds = None
-
- self._backendRenderer = None
-
- def getPlot(self):
- """Returns the ~silx.gui.plot.PlotWidget this item belongs to.
-
- :rtype: Union[~silx.gui.plot.PlotWidget,None]
- """
- return None if self._plotRef is None else self._plotRef()
-
- def _setPlot(self, plot):
- """Set the plot this item belongs to.
-
- WARNING: This should only be called from the Plot.
-
- :param Union[~silx.gui.plot.PlotWidget,None] plot: The Plot instance.
- """
- if plot is not None and self._plotRef is not None:
- raise RuntimeError('Trying to add a node at two places.')
- self.__disconnectFromPlotWidget()
- self._plotRef = None if plot is None else weakref.ref(plot)
- self.__connectToPlotWidget()
- self._updated()
-
- def getBounds(self): # TODO return a Bounds object rather than a tuple
- """Returns the bounding box of this item in data coordinates
-
- :returns: (xmin, xmax, ymin, ymax) or None
- :rtype: 4-tuple of float or None
- """
- return self._getBounds()
-
- def _getBounds(self):
- """:meth:`getBounds` implementation to override by sub-class"""
- return None
-
- def isVisible(self):
- """True if item is visible, False otherwise
-
- :rtype: bool
- """
- return self._visible
-
- def setVisible(self, visible):
- """Set visibility of item.
-
- :param bool visible: True to display it, False otherwise
- """
- visible = bool(visible)
- if visible != self._visible:
- self._visible = visible
- # When visibility has changed, always mark as dirty
- self._updated(ItemChangedType.VISIBLE,
- checkVisibility=False)
-
- def isOverlay(self):
- """Return true if item is drawn as an overlay.
-
- :rtype: bool
- """
- return False
-
- def getName(self):
- """Returns the name of the item which is used as legend.
-
- :rtype: str
- """
- return self.__name
-
- def setName(self, name):
- """Set the name of the item which is used as legend.
-
- :param str name: New name of the item
- :raises RuntimeError: If item belongs to a PlotWidget.
- """
- name = str(name)
- if self.__name != name:
- if self.getPlot() is not None:
- raise RuntimeError(
- "Cannot change name while item is in a PlotWidget")
-
- self.__name = name
- self._updated(ItemChangedType.NAME)
-
- def getLegend(self): # Replaced by getName for API consistency
- return self.getName()
-
- @deprecated(replacement='setName', since_version='0.13')
- def _setLegend(self, legend):
- legend = str(legend) if legend is not None else ''
- self.setName(legend)
-
- def isSelectable(self):
- """Returns true if item is selectable (bool)"""
- return self._selectable
-
- def _setSelectable(self, selectable): # TODO support update
- """Set whether item is selectable or not.
-
- This is private for now as change is not handled.
-
- :param bool selectable: True to make item selectable
- """
- self._selectable = bool(selectable)
-
- def getZValue(self):
- """Returns the layer on which to draw this item (int)"""
- return self._z
-
- def setZValue(self, z):
- z = int(z) if z is not None else self._DEFAULT_Z_LAYER
- if z != self._z:
- self._z = z
- self._updated(ItemChangedType.ZVALUE)
-
- def getInfo(self, copy=True):
- """Returns the info associated to this item
-
- :param bool copy: True to get a deepcopy, False otherwise.
- """
- return deepcopy(self._info) if copy else self._info
-
- def setInfo(self, info, copy=True):
- if copy:
- info = deepcopy(info)
- self._info = info
-
- def getVisibleBounds(self) -> Optional[Tuple[float, float, float, float]]:
- """Returns visible bounds of the item bounding box in the plot area.
-
- :returns:
- (xmin, xmax, ymin, ymax) in data coordinates of the visible area or
- None if item is not visible in the plot area.
- :rtype: Union[List[float],None]
- """
- plot = self.getPlot()
- bounds = self.getBounds()
- if plot is None or bounds is None or not self.isVisible():
- return None
-
- xmin, xmax = numpy.clip(bounds[:2], *plot.getXAxis().getLimits())
- ymin, ymax = numpy.clip(
- bounds[2:], *plot.getYAxis(self.__getYAxis()).getLimits())
-
- if xmin == xmax or ymin == ymax: # Outside the plot area
- return None
- else:
- return xmin, xmax, ymin, ymax
-
- def _isVisibleBoundsTracking(self) -> bool:
- """Returns True if visible bounds changes are tracked.
-
- When enabled, :attr:`_sigVisibleBoundsChanged` is emitted upon changes.
- :rtype: bool
- """
- return self.__visibleBoundsTracking
-
- def _setVisibleBoundsTracking(self, enable: bool) -> None:
- """Set whether or not to track visible bounds changes.
-
- :param bool enable:
- """
- if enable != self.__visibleBoundsTracking:
- self.__disconnectFromPlotWidget()
- self.__previousVisibleBounds = None
- self.__visibleBoundsTracking = enable
- self.__connectToPlotWidget()
-
- def __getYAxis(self) -> str:
- """Returns current Y axis ('left' or 'right')"""
- return self.getYAxis() if isinstance(self, YAxisMixIn) else 'left'
-
- def __connectToPlotWidget(self) -> None:
- """Connect to PlotWidget signals and install event filter"""
- if not self._isVisibleBoundsTracking():
- return
-
- plot = self.getPlot()
- if plot is not None:
- for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
- axis.sigLimitsChanged.connect(self._visibleBoundsChanged)
-
- plot.installEventFilter(self)
-
- self._visibleBoundsChanged()
-
- def __disconnectFromPlotWidget(self) -> None:
- """Disconnect from PlotWidget signals and remove event filter"""
- if not self._isVisibleBoundsTracking():
- return
-
- plot = self.getPlot()
- if plot is not None:
- for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
- axis.sigLimitsChanged.disconnect(self._visibleBoundsChanged)
-
- plot.removeEventFilter(self)
-
- def _visibleBoundsChanged(self, *args) -> None:
- """Check if visible extent actually changed and emit signal"""
- if not self._isVisibleBoundsTracking():
- return # No visible extent tracking
-
- plot = self.getPlot()
- if plot is None or not plot.isVisible():
- return # No plot or plot not visible
-
- extent = self.getVisibleBounds()
- if extent != self.__previousVisibleBounds:
- self.__previousVisibleBounds = extent
- self._sigVisibleBoundsChanged.emit()
-
- def eventFilter(self, watched, event):
- """Event filter to handle PlotWidget show events"""
- if watched is self.getPlot() and event.type() == qt.QEvent.Show:
- self._visibleBoundsChanged()
- return super().eventFilter(watched, event)
-
- def _updated(self, event=None, checkVisibility=True):
- """Mark the item as dirty (i.e., needing update).
-
- This also triggers Plot.replot.
-
- :param event: The event to send to :attr:`sigItemChanged` signal.
- :param bool checkVisibility: True to only mark as dirty if visible,
- False to always mark as dirty.
- """
- if not checkVisibility or self.isVisible():
- if not self._dirty:
- self._dirty = True
- # TODO: send event instead of explicit call
- plot = self.getPlot()
- if plot is not None:
- plot._itemRequiresUpdate(self)
- if event is not None:
- self.sigItemChanged.emit(event)
-
- def _update(self, backend):
- """Called by Plot to update the backend for this item.
-
- This is meant to be called asynchronously from _updated.
- This optimizes the number of call to _update.
-
- :param backend: The backend to update
- """
- if self._dirty:
- # Remove previous renderer from backend if any
- self._removeBackendRenderer(backend)
-
- # If not visible, do not add renderer to backend
- if self.isVisible():
- self._backendRenderer = self._addBackendRenderer(backend)
-
- self._dirty = False
-
- def _addBackendRenderer(self, backend):
- """Override in subclass to add specific backend renderer.
-
- :param BackendBase backend: The backend to update
- :return: The renderer handle to store or None if no renderer in backend
- """
- return None
-
- def _removeBackendRenderer(self, backend):
- """Override in subclass to remove specific backend renderer.
-
- :param BackendBase backend: The backend to update
- """
- if self._backendRenderer is not None:
- backend.remove(self._backendRenderer)
- self._backendRenderer = None
-
- def pick(self, x, y):
- """Run picking test on this item
-
- :param float x: The x pixel coord where to pick.
- :param float y: The y pixel coord where to pick.
- :return: None if not picked, else the picked position information
- :rtype: Union[None,PickingResult]
- """
- if not self.isVisible() or self._backendRenderer is None:
- return None
- plot = self.getPlot()
- if plot is None:
- return None
-
- indices = plot._backend.pickItem(x, y, self._backendRenderer)
- if indices is None:
- return None
- else:
- return PickingResult(self, indices)
-
-
-class DataItem(Item):
- """Item with a data extent in the plot"""
-
- def _boundsChanged(self, checkVisibility: bool=True) -> None:
- """Call this method in subclass when data bounds has changed.
-
- :param bool checkVisibility:
- """
- if not checkVisibility or self.isVisible():
- self._visibleBoundsChanged()
-
- # TODO hackish data range implementation
- plot = self.getPlot()
- if plot is not None:
- plot._invalidateDataRange()
-
- @docstring(Item)
- def setVisible(self, visible: bool):
- if visible != self.isVisible():
- self._boundsChanged(checkVisibility=False)
- super().setVisible(visible)
-
-# Mix-in classes ##############################################################
-
-
-class ItemMixInBase(object):
- """Base class for Item mix-in"""
-
- def _updated(self, event=None, checkVisibility=True):
- """This is implemented in :class:`Item`.
-
- Mark the item as dirty (i.e., needing update).
- This also triggers Plot.replot.
-
- :param event: The event to send to :attr:`sigItemChanged` signal.
- :param bool checkVisibility: True to only mark as dirty if visible,
- False to always mark as dirty.
- """
- raise RuntimeError(
- "Issue with Mix-In class inheritance order")
-
-
-class LabelsMixIn(ItemMixInBase):
- """Mix-in class for items with x and y labels
-
- Setters are private, otherwise it needs to check the plot
- current active curve and access the internal current labels.
- """
-
- def __init__(self):
- self._xlabel = None
- self._ylabel = None
-
- def getXLabel(self):
- """Return the X axis label associated to this curve
-
- :rtype: str or None
- """
- return self._xlabel
-
- def _setXLabel(self, label):
- """Set the X axis label associated with this curve
-
- :param str label: The X axis label
- """
- self._xlabel = str(label)
-
- def getYLabel(self):
- """Return the Y axis label associated to this curve
-
- :rtype: str or None
- """
- return self._ylabel
-
- def _setYLabel(self, label):
- """Set the Y axis label associated with this curve
-
- :param str label: The Y axis label
- """
- self._ylabel = str(label)
-
-
-class DraggableMixIn(ItemMixInBase):
- """Mix-in class for draggable items"""
-
- def __init__(self):
- self._draggable = False
-
- def isDraggable(self):
- """Returns true if image is draggable
-
- :rtype: bool
- """
- return self._draggable
-
- def _setDraggable(self, draggable): # TODO support update
- """Set if image is draggable or not.
-
- This is private for not as it does not support update.
-
- :param bool draggable:
- """
- self._draggable = bool(draggable)
-
- def drag(self, from_, to):
- """Perform a drag of the item.
-
- :param List[float] from_: (x, y) previous position in data coordinates
- :param List[float] to: (x, y) current position in data coordinates
- """
- raise NotImplementedError("Must be implemented in subclass")
-
-
-class ColormapMixIn(ItemMixInBase):
- """Mix-in class for items with colormap"""
-
- def __init__(self):
- self._colormap = Colormap()
- self._colormap.sigChanged.connect(self._colormapChanged)
- self.__data = None
- self.__cacheColormapRange = {} # Store {normalization: range}
-
- def getColormap(self):
- """Return the used colormap"""
- return self._colormap
-
- def setColormap(self, colormap):
- """Set the colormap of this item
-
- :param silx.gui.colors.Colormap colormap: colormap description
- """
- if self._colormap is colormap:
- return
- if isinstance(colormap, dict):
- colormap = Colormap._fromDict(colormap)
-
- if self._colormap is not None:
- self._colormap.sigChanged.disconnect(self._colormapChanged)
- self._colormap = colormap
- if self._colormap is not None:
- self._colormap.sigChanged.connect(self._colormapChanged)
- self._colormapChanged()
-
- def _colormapChanged(self):
- """Handle updates of the colormap"""
- self._updated(ItemChangedType.COLORMAP)
-
- def _setColormappedData(self, data, copy=True,
- min_=None, minPositive=None, max_=None):
- """Set the data used to compute the colormapped display.
-
- It also resets the cache of data ranges.
-
- This method MUST be called by inheriting classes when data is updated.
-
- :param Union[None,numpy.ndarray] data:
- :param Union[None,float] min_: Minimum value of the data
- :param Union[None,float] minPositive:
- Minimum of strictly positive values of the data
- :param Union[None,float] max_: Maximum value of the data
- """
- self.__data = None if data is None else numpy.array(data, copy=copy)
- self.__cacheColormapRange = {} # Reset cache
-
- # Fill-up colormap range cache if values are provided
- if max_ is not None and numpy.isfinite(max_):
- if min_ is not None and numpy.isfinite(min_):
- self.__cacheColormapRange[Colormap.LINEAR, Colormap.MINMAX] = min_, max_
- if minPositive is not None and numpy.isfinite(minPositive):
- self.__cacheColormapRange[Colormap.LOGARITHM, Colormap.MINMAX] = minPositive, max_
-
- colormap = self.getColormap()
- if None in (colormap.getVMin(), colormap.getVMax()):
- self._colormapChanged()
-
- def getColormappedData(self, copy=True):
- """Returns the data used to compute the displayed colors
-
- :param bool copy: True to get a copy,
- False to get internal data (do not modify!).
- :rtype: Union[None,numpy.ndarray]
- """
- if self.__data is None:
- return None
- else:
- return numpy.array(self.__data, copy=copy)
-
- def _getColormapAutoscaleRange(self, colormap=None):
- """Returns the autoscale range for current data and colormap.
-
- :param Union[None,~silx.gui.colors.Colormap] colormap:
- The colormap for which to compute the autoscale range.
- If None, the default, the colormap of the item is used
- :return: (vmin, vmax) range (vmin and /or vmax might be `None`)
- """
- if colormap is None:
- colormap = self.getColormap()
-
- data = self.getColormappedData(copy=False)
- if colormap is None or data is None:
- return None, None
-
- normalization = colormap.getNormalization()
- autoscaleMode = colormap.getAutoscaleMode()
- key = normalization, autoscaleMode
- vRange = self.__cacheColormapRange.get(key, None)
- if vRange is None:
- vRange = colormap._computeAutoscaleRange(data)
- self.__cacheColormapRange[key] = vRange
- return vRange
-
-
-class SymbolMixIn(ItemMixInBase):
- """Mix-in class for items with symbol type"""
-
- _DEFAULT_SYMBOL = None
- """Default marker of the item"""
-
- _DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE
- """Default marker size of the item"""
-
- _SUPPORTED_SYMBOLS = collections.OrderedDict((
- ('o', 'Circle'),
- ('d', 'Diamond'),
- ('s', 'Square'),
- ('+', 'Plus'),
- ('x', 'Cross'),
- ('.', 'Point'),
- (',', 'Pixel'),
- ('|', 'Vertical line'),
- ('_', 'Horizontal line'),
- ('tickleft', 'Tick left'),
- ('tickright', 'Tick right'),
- ('tickup', 'Tick up'),
- ('tickdown', 'Tick down'),
- ('caretleft', 'Caret left'),
- ('caretright', 'Caret right'),
- ('caretup', 'Caret up'),
- ('caretdown', 'Caret down'),
- (u'\u2665', 'Heart'),
- ('', 'None')))
- """Dict of supported symbols"""
-
- def __init__(self):
- if self._DEFAULT_SYMBOL is None: # Use default from config
- self._symbol = config.DEFAULT_PLOT_SYMBOL
- else:
- self._symbol = self._DEFAULT_SYMBOL
-
- if self._DEFAULT_SYMBOL_SIZE is None: # Use default from config
- self._symbol_size = config.DEFAULT_PLOT_SYMBOL_SIZE
- else:
- self._symbol_size = self._DEFAULT_SYMBOL_SIZE
-
- @classmethod
- def getSupportedSymbols(cls):
- """Returns the list of supported symbol names.
-
- :rtype: tuple of str
- """
- return tuple(cls._SUPPORTED_SYMBOLS.keys())
-
- @classmethod
- def getSupportedSymbolNames(cls):
- """Returns the list of supported symbol human-readable names.
-
- :rtype: tuple of str
- """
- return tuple(cls._SUPPORTED_SYMBOLS.values())
-
- def getSymbolName(self, symbol=None):
- """Returns human-readable name for a symbol.
-
- :param str symbol: The symbol from which to get the name.
- Default: current symbol.
- :rtype: str
- :raise KeyError: if symbol is not in :meth:`getSupportedSymbols`.
- """
- if symbol is None:
- symbol = self.getSymbol()
- return self._SUPPORTED_SYMBOLS[symbol]
-
- def getSymbol(self):
- """Return the point marker type.
-
- Marker type::
-
- - 'o' circle
- - '.' point
- - ',' pixel
- - '+' cross
- - 'x' x-cross
- - 'd' diamond
- - 's' square
-
- :rtype: str
- """
- return self._symbol
-
- def setSymbol(self, symbol):
- """Set the marker type
-
- See :meth:`getSymbol`.
-
- :param str symbol: Marker type or marker name
- """
- if symbol is None:
- symbol = self._DEFAULT_SYMBOL
-
- elif symbol not in self.getSupportedSymbols():
- for symbolCode, name in self._SUPPORTED_SYMBOLS.items():
- if name.lower() == symbol.lower():
- symbol = symbolCode
- break
- else:
- raise ValueError('Unsupported symbol %s' % str(symbol))
-
- if symbol != self._symbol:
- self._symbol = symbol
- self._updated(ItemChangedType.SYMBOL)
-
- def getSymbolSize(self):
- """Return the point marker size in points.
-
- :rtype: float
- """
- return self._symbol_size
-
- def setSymbolSize(self, size):
- """Set the point marker size in points.
-
- See :meth:`getSymbolSize`.
-
- :param str symbol: Marker type
- """
- if size is None:
- size = self._DEFAULT_SYMBOL_SIZE
- if size != self._symbol_size:
- self._symbol_size = size
- self._updated(ItemChangedType.SYMBOL_SIZE)
-
-
-class LineMixIn(ItemMixInBase):
- """Mix-in class for item with line"""
-
- _DEFAULT_LINEWIDTH = 1.
- """Default line width"""
-
- _DEFAULT_LINESTYLE = '-'
- """Default line style"""
-
- _SUPPORTED_LINESTYLE = '', ' ', '-', '--', '-.', ':', None
- """Supported line styles"""
-
- def __init__(self):
- self._linewidth = self._DEFAULT_LINEWIDTH
- self._linestyle = self._DEFAULT_LINESTYLE
-
- @classmethod
- def getSupportedLineStyles(cls):
- """Returns list of supported line styles.
-
- :rtype: List[str,None]
- """
- return cls._SUPPORTED_LINESTYLE
-
- def getLineWidth(self):
- """Return the curve line width in pixels
-
- :rtype: float
- """
- return self._linewidth
-
- def setLineWidth(self, width):
- """Set the width in pixel of the curve line
-
- See :meth:`getLineWidth`.
-
- :param float width: Width in pixels
- """
- width = float(width)
- if width != self._linewidth:
- self._linewidth = width
- self._updated(ItemChangedType.LINE_WIDTH)
-
- def getLineStyle(self):
- """Return the type of the line
-
- Type of line::
-
- - ' ' no line
- - '-' solid line
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
-
- :rtype: str
- """
- return self._linestyle
-
- def setLineStyle(self, style):
- """Set the style of the curve line.
-
- See :meth:`getLineStyle`.
-
- :param str style: Line style
- """
- style = str(style)
- assert style in self.getSupportedLineStyles()
- if style is None:
- style = self._DEFAULT_LINESTYLE
- if style != self._linestyle:
- self._linestyle = style
- self._updated(ItemChangedType.LINE_STYLE)
-
-
-class ColorMixIn(ItemMixInBase):
- """Mix-in class for item with color"""
-
- _DEFAULT_COLOR = (0., 0., 0., 1.)
- """Default color of the item"""
-
- def __init__(self):
- self._color = self._DEFAULT_COLOR
-
- def getColor(self):
- """Returns the RGBA color of the item
-
- :rtype: 4-tuple of float in [0, 1] or array of colors
- """
- return self._color
-
- def setColor(self, color, copy=True):
- """Set item color
-
- :param color: color(s) to be used
- :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
- one of the predefined color names defined in colors.py
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- """
- if isinstance(color, six.string_types):
- color = colors.rgba(color)
- elif isinstance(color, qt.QColor):
- color = colors.rgba(color)
- else:
- color = numpy.array(color, copy=copy)
- # TODO more checks + improve color array support
- if color.ndim == 1: # Single RGBA color
- color = colors.rgba(color)
- else: # Array of colors
- assert color.ndim == 2
-
- self._color = color
- self._updated(ItemChangedType.COLOR)
-
-
-class YAxisMixIn(ItemMixInBase):
- """Mix-in class for item with yaxis"""
-
- _DEFAULT_YAXIS = 'left'
- """Default Y axis the item belongs to"""
-
- def __init__(self):
- self._yaxis = self._DEFAULT_YAXIS
-
- def getYAxis(self):
- """Returns the Y axis this curve belongs to.
-
- Either 'left' or 'right'.
-
- :rtype: str
- """
- return self._yaxis
-
- def setYAxis(self, yaxis):
- """Set the Y axis this curve belongs to.
-
- :param str yaxis: 'left' or 'right'
- """
- yaxis = str(yaxis)
- assert yaxis in ('left', 'right')
- if yaxis != self._yaxis:
- self._yaxis = yaxis
- # Handle data extent changed for DataItem
- if isinstance(self, DataItem):
- self._boundsChanged()
-
- # Handle visible extent changed
- if self._isVisibleBoundsTracking():
- # Switch Y axis signal connection
- plot = self.getPlot()
- if plot is not None:
- previousYAxis = 'left' if self.getXAxis() == 'right' else 'right'
- plot.getYAxis(previousYAxis).sigLimitsChanged.disconnect(
- self._visibleBoundsChanged)
- plot.getYAxis(self.getYAxis()).sigLimitsChanged.connect(
- self._visibleBoundsChanged)
- self._visibleBoundsChanged()
-
- self._updated(ItemChangedType.YAXIS)
-
-
-class FillMixIn(ItemMixInBase):
- """Mix-in class for item with fill"""
-
- def __init__(self):
- self._fill = False
-
- def isFill(self):
- """Returns whether the item is filled or not.
-
- :rtype: bool
- """
- return self._fill
-
- def setFill(self, fill):
- """Set whether to fill the item or not.
-
- :param bool fill:
- """
- fill = bool(fill)
- if fill != self._fill:
- self._fill = fill
- self._updated(ItemChangedType.FILL)
-
-
-class AlphaMixIn(ItemMixInBase):
- """Mix-in class for item with opacity"""
-
- def __init__(self):
- self._alpha = 1.
-
- def getAlpha(self):
- """Returns the opacity of the item
-
- :rtype: float in [0, 1.]
- """
- return self._alpha
-
- def setAlpha(self, alpha):
- """Set the opacity of the item
-
- .. note::
-
- If the colormap already has some transparency, this alpha
- adds additional transparency. The alpha channel of the colormap
- is multiplied by this value.
-
- :param alpha: Opacity of the item, between 0 (full transparency)
- and 1. (full opacity)
- :type alpha: float
- """
- alpha = float(alpha)
- alpha = max(0., min(alpha, 1.)) # Clip alpha to [0., 1.] range
- if alpha != self._alpha:
- self._alpha = alpha
- self._updated(ItemChangedType.ALPHA)
-
-
-class ComplexMixIn(ItemMixInBase):
- """Mix-in class for complex data mode"""
-
- _SUPPORTED_COMPLEX_MODES = None
- """Override to only support a subset of all ComplexMode"""
-
- class ComplexMode(_Enum):
- """Identify available display mode for complex"""
- NONE = 'none'
- ABSOLUTE = 'amplitude'
- PHASE = 'phase'
- REAL = 'real'
- IMAGINARY = 'imaginary'
- AMPLITUDE_PHASE = 'amplitude_phase'
- LOG10_AMPLITUDE_PHASE = 'log10_amplitude_phase'
- SQUARE_AMPLITUDE = 'square_amplitude'
-
- def __init__(self):
- self.__complex_mode = self.ComplexMode.ABSOLUTE
-
- def getComplexMode(self):
- """Returns the current complex visualization mode.
-
- :rtype: ComplexMode
- """
- return self.__complex_mode
-
- def setComplexMode(self, mode):
- """Set the complex visualization mode.
-
- :param ComplexMode mode: The visualization mode in:
- 'real', 'imaginary', 'phase', 'amplitude'
- :return: True if value was set, False if is was already set
- :rtype: bool
- """
- mode = self.ComplexMode.from_value(mode)
- assert mode in self.supportedComplexModes()
-
- if mode != self.__complex_mode:
- self.__complex_mode = mode
- self._updated(ItemChangedType.COMPLEX_MODE)
- return True
- else:
- return False
-
- def _convertComplexData(self, data, mode=None):
- """Convert complex data to the specific mode.
-
- :param Union[ComplexMode,None] mode:
- The kind of value to compute.
- If None (the default), the current complex mode is used.
- :return: The converted dataset
- :rtype: Union[numpy.ndarray[float],None]
- """
- if data is None:
- return None
-
- if mode is None:
- mode = self.getComplexMode()
-
- if mode is self.ComplexMode.REAL:
- return numpy.real(data)
- elif mode is self.ComplexMode.IMAGINARY:
- return numpy.imag(data)
- elif mode is self.ComplexMode.ABSOLUTE:
- return numpy.absolute(data)
- elif mode is self.ComplexMode.PHASE:
- return numpy.angle(data)
- elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
- return numpy.absolute(data) ** 2
- else:
- raise ValueError('Unsupported conversion mode: %s', str(mode))
-
- @classmethod
- def supportedComplexModes(cls):
- """Returns the list of supported complex visualization modes.
-
- See :class:`ComplexMode` and :meth:`setComplexMode`.
-
- :rtype: List[ComplexMode]
- """
- if cls._SUPPORTED_COMPLEX_MODES is None:
- return cls.ComplexMode.members()
- else:
- return cls._SUPPORTED_COMPLEX_MODES
-
-
-class ScatterVisualizationMixIn(ItemMixInBase):
- """Mix-in class for scatter plot visualization modes"""
-
- _SUPPORTED_SCATTER_VISUALIZATION = None
- """Allows to override supported Visualizations"""
-
- @enum.unique
- class Visualization(_Enum):
- """Different modes of scatter plot visualizations"""
-
- POINTS = 'points'
- """Display scatter plot as a point cloud"""
-
- LINES = 'lines'
- """Display scatter plot as a wireframe.
-
- This is based on Delaunay triangulation
- """
-
- SOLID = 'solid'
- """Display scatter plot as a set of filled triangles.
-
- This is based on Delaunay triangulation
- """
-
- REGULAR_GRID = 'regular_grid'
- """Display scatter plot as an image.
-
- It expects the points to be the intersection of a regular grid,
- and the order of points following that of an image.
- First line, then second one, and always in the same direction
- (either all lines from left to right or all from right to left).
- """
-
- IRREGULAR_GRID = 'irregular_grid'
- """Display scatter plot as contiguous quadrilaterals.
-
- It expects the points to be the intersection of an irregular grid,
- and the order of points following that of an image.
- First line, then second one, and always in the same direction
- (either all lines from left to right or all from right to left).
- """
-
- BINNED_STATISTIC = 'binned_statistic'
- """Display scatter plot as 2D binned statistic (i.e., generalized histogram).
- """
-
- @enum.unique
- class VisualizationParameter(_Enum):
- """Different parameter names for scatter plot visualizations"""
-
- GRID_MAJOR_ORDER = 'grid_major_order'
- """The major order of points in the regular grid.
-
- Either 'row' (row-major, fast X) or 'column' (column-major, fast Y).
- """
-
- GRID_BOUNDS = 'grid_bounds'
- """The expected range in data coordinates of the regular grid.
-
- A 2-tuple of 2-tuple: (begin (x, y), end (x, y)).
- This provides the data coordinates of the first point and the expected
- last on.
- As for `GRID_SHAPE`, this can be wider than the current data.
- """
-
- GRID_SHAPE = 'grid_shape'
- """The expected size of the regular grid (height, width).
-
- The given shape can be wider than the number of points,
- in which case the grid is not fully filled.
- """
-
- BINNED_STATISTIC_SHAPE = 'binned_statistic_shape'
- """The number of bins in each dimension (height, width).
- """
-
- BINNED_STATISTIC_FUNCTION = 'binned_statistic_function'
- """The reduction function to apply to each bin (str).
-
- Available reduction functions are: 'mean' (default), 'count', 'sum'.
- """
-
- DATA_BOUNDS_HINT = 'data_bounds_hint'
- """The expected bounds of the data in data coordinates.
-
- A 2-tuple of 2-tuple: ((ymin, ymax), (xmin, xmax)).
- This provides a hint for the data ranges in both dimensions.
- It is eventually enlarged with actually data ranges.
-
- WARNING: dimension 0 i.e., Y first.
- """
-
- _SUPPORTED_VISUALIZATION_PARAMETER_VALUES = {
- VisualizationParameter.GRID_MAJOR_ORDER: ('row', 'column'),
- VisualizationParameter.BINNED_STATISTIC_FUNCTION: ('mean', 'count', 'sum'),
- }
- """Supported visualization parameter values.
-
- Defined for parameters with a set of acceptable values.
- """
-
- def __init__(self):
- self.__visualization = self.Visualization.POINTS
- self.__parameters = dict(# Init parameters to None
- (parameter, None) for parameter in self.VisualizationParameter)
- self.__parameters[self.VisualizationParameter.BINNED_STATISTIC_FUNCTION] = 'mean'
-
- @classmethod
- def supportedVisualizations(cls):
- """Returns the list of supported scatter visualization modes.
-
- See :meth:`setVisualization`
-
- :rtype: List[Visualization]
- """
- if cls._SUPPORTED_SCATTER_VISUALIZATION is None:
- return cls.Visualization.members()
- else:
- return cls._SUPPORTED_SCATTER_VISUALIZATION
-
- @classmethod
- def supportedVisualizationParameterValues(cls, parameter):
- """Returns the list of supported scatter visualization modes.
-
- See :meth:`VisualizationParameters`
-
- :param VisualizationParameter parameter:
- This parameter for which to retrieve the supported values.
- :returns: tuple of supported of values or None if not defined.
- """
- parameter = cls.VisualizationParameter(parameter)
- return cls._SUPPORTED_VISUALIZATION_PARAMETER_VALUES.get(
- parameter, None)
-
- def setVisualization(self, mode):
- """Set the scatter plot visualization mode to use.
-
- See :class:`Visualization` for all possible values,
- and :meth:`supportedVisualizations` for supported ones.
-
- :param Union[str,Visualization] mode:
- The visualization mode to use.
- :return: True if value was set, False if is was already set
- :rtype: bool
- """
- mode = self.Visualization.from_value(mode)
- assert mode in self.supportedVisualizations()
-
- if mode != self.__visualization:
- self.__visualization = mode
-
- self._updated(ItemChangedType.VISUALIZATION_MODE)
- return True
- else:
- return False
-
- def getVisualization(self):
- """Returns the scatter plot visualization mode in use.
-
- :rtype: Visualization
- """
- return self.__visualization
-
- def setVisualizationParameter(self, parameter, value=None):
- """Set the given visualization parameter.
-
- :param Union[str,VisualizationParameter] parameter:
- The name of the parameter to set
- :param value: The value to use for this parameter
- Set to None to automatically set the parameter
- :raises ValueError: If parameter is not supported
- :return: True if parameter was set, False if is was already set
- :rtype: bool
- :raise ValueError: If value is not supported
- """
- parameter = self.VisualizationParameter.from_value(parameter)
-
- if self.__parameters[parameter] != value:
- validValues = self.supportedVisualizationParameterValues(parameter)
- if validValues is not None and value not in validValues:
- raise ValueError("Unsupported parameter value: %s" % str(value))
-
- self.__parameters[parameter] = value
- self._updated(ItemChangedType.VISUALIZATION_MODE)
- return True
- return False
-
- def getVisualizationParameter(self, parameter):
- """Returns the value of the given visualization parameter.
-
- This method returns the parameter as set by
- :meth:`setVisualizationParameter`.
-
- :param parameter: The name of the parameter to retrieve
- :returns: The value previously set or None if automatically set
- :raises ValueError: If parameter is not supported
- """
- if parameter not in self.VisualizationParameter:
- raise ValueError("parameter not supported: %s", parameter)
-
- return self.__parameters[parameter]
-
- def getCurrentVisualizationParameter(self, parameter):
- """Returns the current value of the given visualization parameter.
-
- If the parameter was set by :meth:`setVisualizationParameter` to
- a value that is not None, this value is returned;
- else the current value that is automatically computed is returned.
-
- :param parameter: The name of the parameter to retrieve
- :returns: The current value (either set or automatically computed)
- :raises ValueError: If parameter is not supported
- """
- # Override in subclass to provide automatically computed parameters
- return self.getVisualizationParameter(parameter)
-
-
-class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
- """Base class for :class:`Curve` and :class:`Scatter`"""
- # note: _logFilterData must be overloaded if you overload
- # getData to change its signature
-
- _DEFAULT_Z_LAYER = 1
- """Default overlay layer for points,
- on top of images."""
-
- def __init__(self):
- DataItem.__init__(self)
- SymbolMixIn.__init__(self)
- AlphaMixIn.__init__(self)
- self._x = ()
- self._y = ()
- self._xerror = None
- self._yerror = None
-
- # Store filtered data for x > 0 and/or y > 0
- self._filteredCache = {}
- self._clippedCache = {}
-
- # Store bounds depending on axes filtering >0:
- # key is (isXPositiveFilter, isYPositiveFilter)
- self._boundsCache = {}
-
- @staticmethod
- def _logFilterError(value, error):
- """Filter/convert error values if they go <= 0.
-
- Replace error leading to negative values by nan
-
- :param numpy.ndarray value: 1D array of values
- :param numpy.ndarray error:
- Array of errors: scalar, N, Nx1 or 2xN or None.
- :return: Filtered error so error bars are never negative
- """
- if error is not None:
- # Convert Nx1 to N
- if error.ndim == 2 and error.shape[1] == 1 and len(value) != 1:
- error = numpy.ravel(error)
-
- # Supports error being scalar, N or 2xN array
- valueMinusError = value - numpy.atleast_2d(error)[0]
- errorClipped = numpy.isnan(valueMinusError)
- mask = numpy.logical_not(errorClipped)
- errorClipped[mask] = valueMinusError[mask] <= 0
-
- if numpy.any(errorClipped): # Need filtering
-
- # expand errorbars to 2xN
- if error.size == 1: # Scalar
- error = numpy.full(
- (2, len(value)), error, dtype=numpy.float64)
-
- elif error.ndim == 1: # N array
- newError = numpy.empty((2, len(value)),
- dtype=numpy.float64)
- newError[0,:] = error
- newError[1,:] = error
- error = newError
-
- elif error.size == 2 * len(value): # 2xN array
- error = numpy.array(
- error, copy=True, dtype=numpy.float64)
-
- else:
- _logger.error("Unhandled error array")
- return error
-
- error[0, errorClipped] = numpy.nan
-
- return error
-
- def _getClippingBoolArray(self, xPositive, yPositive):
- """Compute a boolean array to filter out points with negative
- coordinates on log axes.
-
- :param bool xPositive: True to filter arrays according to X coords.
- :param bool yPositive: True to filter arrays according to Y coords.
- :rtype: boolean numpy.ndarray
- """
- assert xPositive or yPositive
- if (xPositive, yPositive) not in self._clippedCache:
- xclipped, yclipped = False, False
-
- if xPositive:
- x = self.getXData(copy=False)
- with numpy.errstate(invalid='ignore'): # Ignore NaN warnings
- xclipped = x <= 0
-
- if yPositive:
- y = self.getYData(copy=False)
- with numpy.errstate(invalid='ignore'): # Ignore NaN warnings
- yclipped = y <= 0
-
- self._clippedCache[(xPositive, yPositive)] = \
- numpy.logical_or(xclipped, yclipped)
- return self._clippedCache[(xPositive, yPositive)]
-
- def _logFilterData(self, xPositive, yPositive):
- """Filter out values with x or y <= 0 on log axes
-
- :param bool xPositive: True to filter arrays according to X coords.
- :param bool yPositive: True to filter arrays according to Y coords.
- :return: The filter arrays or unchanged object if filtering not needed
- :rtype: (x, y, xerror, yerror)
- """
- x = self.getXData(copy=False)
- y = self.getYData(copy=False)
- xerror = self.getXErrorData(copy=False)
- yerror = self.getYErrorData(copy=False)
-
- if xPositive or yPositive:
- clipped = self._getClippingBoolArray(xPositive, yPositive)
-
- if numpy.any(clipped):
- # copy to keep original array and convert to float
- x = numpy.array(x, copy=True, dtype=numpy.float64)
- x[clipped] = numpy.nan
- y = numpy.array(y, copy=True, dtype=numpy.float64)
- y[clipped] = numpy.nan
-
- if xPositive and xerror is not None:
- xerror = self._logFilterError(x, xerror)
-
- if yPositive and yerror is not None:
- yerror = self._logFilterError(y, yerror)
-
- return x, y, xerror, yerror
-
- def _getBounds(self):
- if self.getXData(copy=False).size == 0: # Empty data
- return None
-
- plot = self.getPlot()
- if plot is not None:
- xPositive = plot.getXAxis()._isLogarithmic()
- yPositive = plot.getYAxis()._isLogarithmic()
- else:
- xPositive = False
- yPositive = False
-
- # TODO bounds do not take error bars into account
- if (xPositive, yPositive) not in self._boundsCache:
- # use the getData class method because instance method can be
- # overloaded to return additional arrays
- data = PointsBase.getData(self, copy=False, displayed=True)
- if len(data) == 5:
- # hack to avoid duplicating caching mechanism in Scatter
- # (happens when cached data is used, caching done using
- # Scatter._logFilterData)
- x, y, _xerror, _yerror = data[0], data[1], data[3], data[4]
- else:
- x, y, _xerror, _yerror = data
-
- xmin, xmax = min_max(x, finite=True)
- ymin, ymax = min_max(y, finite=True)
- self._boundsCache[(xPositive, yPositive)] = tuple([
- (bound if bound is not None else numpy.nan)
- for bound in (xmin, xmax, ymin, ymax)])
- return self._boundsCache[(xPositive, yPositive)]
-
- def _getCachedData(self):
- """Return cached filtered data if applicable,
- i.e. if any axis is in log scale.
- Return None if caching is not applicable."""
- plot = self.getPlot()
- if plot is not None:
- xPositive = plot.getXAxis()._isLogarithmic()
- yPositive = plot.getYAxis()._isLogarithmic()
- if xPositive or yPositive:
- # At least one axis has log scale, filter data
- if (xPositive, yPositive) not in self._filteredCache:
- self._filteredCache[(xPositive, yPositive)] = \
- self._logFilterData(xPositive, yPositive)
- return self._filteredCache[(xPositive, yPositive)]
- return None
-
- def getData(self, copy=True, displayed=False):
- """Returns the x, y values of the curve points and xerror, yerror
-
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :param bool displayed: True to only get curve points that are displayed
- in the plot. Default: False
- Note: If plot has log scale, negative points
- are not displayed.
- :returns: (x, y, xerror, yerror)
- :rtype: 4-tuple of numpy.ndarray
- """
- if displayed: # filter data according to plot state
- cached_data = self._getCachedData()
- if cached_data is not None:
- return cached_data
-
- return (self.getXData(copy),
- self.getYData(copy),
- self.getXErrorData(copy),
- self.getYErrorData(copy))
-
- def getXData(self, copy=True):
- """Returns the x coordinates of the data points
-
- :param copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: numpy.ndarray
- """
- return numpy.array(self._x, copy=copy)
-
- def getYData(self, copy=True):
- """Returns the y coordinates of the data points
-
- :param copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: numpy.ndarray
- """
- return numpy.array(self._y, copy=copy)
-
- def getXErrorData(self, copy=True):
- """Returns the x error of the points
-
- :param copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: numpy.ndarray, float or None
- """
- if isinstance(self._xerror, numpy.ndarray):
- return numpy.array(self._xerror, copy=copy)
- else:
- return self._xerror # float or None
-
- def getYErrorData(self, copy=True):
- """Returns the y error of the points
-
- :param copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: numpy.ndarray, float or None
- """
- if isinstance(self._yerror, numpy.ndarray):
- return numpy.array(self._yerror, copy=copy)
- else:
- return self._yerror # float or None
-
- def setData(self, x, y, xerror=None, yerror=None, copy=True):
- """Set the data of the curve.
-
- :param numpy.ndarray x: The data corresponding to the x coordinates.
- :param numpy.ndarray y: The data corresponding to the y coordinates.
- :param xerror: Values with the uncertainties on the x values
- :type xerror: A float, or a numpy.ndarray of float32.
- If it is an array, it can either be a 1D array of
- same length as the data or a 2D array with 2 rows
- of same length as the data: row 0 for positive errors,
- row 1 for negative errors.
- :param yerror: Values with the uncertainties on the y values.
- :type yerror: A float, or a numpy.ndarray of float32. See xerror.
- :param bool copy: True make a copy of the data (default),
- False to use provided arrays.
- """
- x = numpy.array(x, copy=copy)
- y = numpy.array(y, copy=copy)
- assert len(x) == len(y)
- assert x.ndim == y.ndim == 1
-
- # Convert complex data
- if numpy.iscomplexobj(x):
- _logger.warning(
- 'Converting x data to absolute value to plot it.')
- x = numpy.absolute(x)
- if numpy.iscomplexobj(y):
- _logger.warning(
- 'Converting y data to absolute value to plot it.')
- y = numpy.absolute(y)
-
- if xerror is not None:
- if isinstance(xerror, abc.Iterable):
- xerror = numpy.array(xerror, copy=copy)
- if numpy.iscomplexobj(xerror):
- _logger.warning(
- 'Converting xerror data to absolute value to plot it.')
- xerror = numpy.absolute(xerror)
- else:
- xerror = float(xerror)
- if yerror is not None:
- if isinstance(yerror, abc.Iterable):
- yerror = numpy.array(yerror, copy=copy)
- if numpy.iscomplexobj(yerror):
- _logger.warning(
- 'Converting yerror data to absolute value to plot it.')
- yerror = numpy.absolute(yerror)
- else:
- yerror = float(yerror)
- # TODO checks on xerror, yerror
- self._x, self._y = x, y
- self._xerror, self._yerror = xerror, yerror
-
- self._boundsCache = {} # Reset cached bounds
- self._filteredCache = {} # Reset cached filtered data
- self._clippedCache = {} # Reset cached clipped bool array
-
- self._boundsChanged()
- self._updated(ItemChangedType.DATA)
-
-
-class BaselineMixIn(object):
- """Base class for Baseline mix-in"""
-
- def __init__(self, baseline=None):
- self._baseline = baseline
-
- def _setBaseline(self, baseline):
- """
- Set baseline value
-
- :param baseline: baseline value(s)
- :type: Union[None,float,numpy.ndarray]
- """
- if (isinstance(baseline, abc.Iterable)):
- baseline = numpy.array(baseline)
- self._baseline = baseline
-
- def getBaseline(self, copy=True):
- """
-
- :param bool copy:
- :return: histogram baseline
- :rtype: Union[None,float,numpy.ndarray]
- """
- if isinstance(self._baseline, numpy.ndarray):
- return numpy.array(self._baseline, copy=True)
- else:
- return self._baseline
-
-
-class _Style:
- """Object which store styles"""
-
-
-class HighlightedMixIn(ItemMixInBase):
-
- def __init__(self):
- self._highlightStyle = self._DEFAULT_HIGHLIGHT_STYLE
- self._highlighted = False
-
- def isHighlighted(self):
- """Returns True if curve is highlighted.
-
- :rtype: bool
- """
- return self._highlighted
-
- def setHighlighted(self, highlighted):
- """Set the highlight state of the curve
-
- :param bool highlighted:
- """
- highlighted = bool(highlighted)
- if highlighted != self._highlighted:
- self._highlighted = highlighted
- # TODO inefficient: better to use backend's setCurveColor
- self._updated(ItemChangedType.HIGHLIGHTED)
-
- def getHighlightedStyle(self):
- """Returns the highlighted style in use
-
- :rtype: CurveStyle
- """
- return self._highlightStyle
-
- def setHighlightedStyle(self, style):
- """Set the style to use for highlighting
-
- :param CurveStyle style: New style to use
- """
- previous = self.getHighlightedStyle()
- if style != previous:
- assert isinstance(style, _Style)
- self._highlightStyle = style
- self._updated(ItemChangedType.HIGHLIGHTED_STYLE)
-
- # Backward compatibility event
- if previous.getColor() != style.getColor():
- self._updated(ItemChangedType.HIGHLIGHTED_COLOR)
diff --git a/silx/gui/plot/items/curve.py b/silx/gui/plot/items/curve.py
deleted file mode 100644
index 75e7f01..0000000
--- a/silx/gui/plot/items/curve.py
+++ /dev/null
@@ -1,326 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides the :class:`Curve` item of the :class:`Plot`.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-
-import logging
-
-import numpy
-import six
-
-from ....utils.deprecation import deprecated
-from ... import colors
-from .core import (PointsBase, LabelsMixIn, ColorMixIn, YAxisMixIn,
- FillMixIn, LineMixIn, SymbolMixIn, ItemChangedType,
- BaselineMixIn, HighlightedMixIn, _Style)
-
-
-_logger = logging.getLogger(__name__)
-
-
-class CurveStyle(_Style):
- """Object storing the style of a curve.
-
- Set a value to None to use the default
-
- :param color: Color
- :param Union[str,None] linestyle: Style of the line
- :param Union[float,None] linewidth: Width of the line
- :param Union[str,None] symbol: Symbol for markers
- :param Union[float,None] symbolsize: Size of the markers
- """
-
- def __init__(self, color=None, linestyle=None, linewidth=None,
- symbol=None, symbolsize=None):
- if color is None:
- self._color = None
- else:
- if isinstance(color, six.string_types):
- color = colors.rgba(color)
- else: # array-like expected
- color = numpy.array(color, copy=False)
- if color.ndim == 1: # Array is 1D, this is a single color
- color = colors.rgba(color)
- self._color = color
-
- if linestyle is not None:
- assert linestyle in LineMixIn.getSupportedLineStyles()
- self._linestyle = linestyle
-
- self._linewidth = None if linewidth is None else float(linewidth)
-
- if symbol is not None:
- assert symbol in SymbolMixIn.getSupportedSymbols()
- self._symbol = symbol
-
- self._symbolsize = None if symbolsize is None else float(symbolsize)
-
- def getColor(self, copy=True):
- """Returns the color or None if not set.
-
- :param bool copy: True to get a copy (default),
- False to get internal representation (do not modify!)
-
- :rtype: Union[List[float],None]
- """
- if isinstance(self._color, numpy.ndarray):
- return numpy.array(self._color, copy=copy)
- else:
- return self._color
-
- def getLineStyle(self):
- """Return the type of the line or None if not set.
-
- Type of line::
-
- - ' ' no line
- - '-' solid line
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
-
- :rtype: Union[str,None]
- """
- return self._linestyle
-
- def getLineWidth(self):
- """Return the curve line width in pixels or None if not set.
-
- :rtype: Union[float,None]
- """
- return self._linewidth
-
- def getSymbol(self):
- """Return the point marker type.
-
- Marker type::
-
- - 'o' circle
- - '.' point
- - ',' pixel
- - '+' cross
- - 'x' x-cross
- - 'd' diamond
- - 's' square
-
- :rtype: Union[str,None]
- """
- return self._symbol
-
- def getSymbolSize(self):
- """Return the point marker size in points.
-
- :rtype: Union[float,None]
- """
- return self._symbolsize
-
- def __eq__(self, other):
- if isinstance(other, CurveStyle):
- return (numpy.array_equal(self.getColor(), other.getColor()) and
- self.getLineStyle() == other.getLineStyle() and
- self.getLineWidth() == other.getLineWidth() and
- self.getSymbol() == other.getSymbol() and
- self.getSymbolSize() == other.getSymbolSize())
- else:
- return False
-
-
-class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
- LineMixIn, BaselineMixIn, HighlightedMixIn):
- """Description of a curve"""
-
- _DEFAULT_Z_LAYER = 1
- """Default overlay layer for curves"""
-
- _DEFAULT_SELECTABLE = True
- """Default selectable state for curves"""
-
- _DEFAULT_LINEWIDTH = 1.
- """Default line width of the curve"""
-
- _DEFAULT_LINESTYLE = '-'
- """Default line style of the curve"""
-
- _DEFAULT_HIGHLIGHT_STYLE = CurveStyle(color='black')
- """Default highlight style of the item"""
-
- _DEFAULT_BASELINE = None
-
- def __init__(self):
- PointsBase.__init__(self)
- ColorMixIn.__init__(self)
- YAxisMixIn.__init__(self)
- FillMixIn.__init__(self)
- LabelsMixIn.__init__(self)
- LineMixIn.__init__(self)
- BaselineMixIn.__init__(self)
- HighlightedMixIn.__init__(self)
-
- self._setBaseline(Curve._DEFAULT_BASELINE)
-
- def _addBackendRenderer(self, backend):
- """Update backend renderer"""
- # Filter-out values <= 0
- xFiltered, yFiltered, xerror, yerror = self.getData(
- copy=False, displayed=True)
-
- if len(xFiltered) == 0 or not numpy.any(numpy.isfinite(xFiltered)):
- return None # No data to display, do not add renderer to backend
-
- style = self.getCurrentStyle()
-
- return backend.addCurve(xFiltered, yFiltered,
- color=style.getColor(),
- symbol=style.getSymbol(),
- linestyle=style.getLineStyle(),
- linewidth=style.getLineWidth(),
- yaxis=self.getYAxis(),
- xerror=xerror,
- yerror=yerror,
- fill=self.isFill(),
- alpha=self.getAlpha(),
- symbolsize=style.getSymbolSize(),
- baseline=self.getBaseline(copy=False))
-
- def __getitem__(self, item):
- """Compatibility with PyMca and silx <= 0.4.0"""
- if isinstance(item, slice):
- return [self[index] for index in range(*item.indices(5))]
- elif item == 0:
- return self.getXData(copy=False)
- elif item == 1:
- return self.getYData(copy=False)
- elif item == 2:
- return self.getName()
- elif item == 3:
- info = self.getInfo(copy=False)
- return {} if info is None else info
- elif item == 4:
- params = {
- 'info': self.getInfo(),
- 'color': self.getColor(),
- 'symbol': self.getSymbol(),
- 'linewidth': self.getLineWidth(),
- 'linestyle': self.getLineStyle(),
- 'xlabel': self.getXLabel(),
- 'ylabel': self.getYLabel(),
- 'yaxis': self.getYAxis(),
- 'xerror': self.getXErrorData(copy=False),
- 'yerror': self.getYErrorData(copy=False),
- 'z': self.getZValue(),
- 'selectable': self.isSelectable(),
- 'fill': self.isFill(),
- }
- return params
- else:
- raise IndexError("Index out of range: %s", str(item))
-
- @deprecated(replacement='Curve.getHighlightedStyle().getColor()',
- since_version='0.9.0')
- def getHighlightedColor(self):
- """Returns the RGBA highlight color of the item
-
- :rtype: 4-tuple of float in [0, 1]
- """
- return self.getHighlightedStyle().getColor()
-
- @deprecated(replacement='Curve.setHighlightedStyle()',
- since_version='0.9.0')
- def setHighlightedColor(self, color):
- """Set the color to use when highlighted
-
- :param color: color(s) to be used for highlight
- :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
- one of the predefined color names defined in colors.py
- """
- self.setHighlightedStyle(CurveStyle(color))
-
- def getCurrentStyle(self):
- """Returns the current curve style.
-
- Curve style depends on curve highlighting
-
- :rtype: CurveStyle
- """
- if self.isHighlighted():
- style = self.getHighlightedStyle()
- color = style.getColor()
- linestyle = style.getLineStyle()
- linewidth = style.getLineWidth()
- symbol = style.getSymbol()
- symbolsize = style.getSymbolSize()
-
- return CurveStyle(
- color=self.getColor() if color is None else color,
- linestyle=self.getLineStyle() if linestyle is None else linestyle,
- linewidth=self.getLineWidth() if linewidth is None else linewidth,
- symbol=self.getSymbol() if symbol is None else symbol,
- symbolsize=self.getSymbolSize() if symbolsize is None else symbolsize)
-
- else:
- return CurveStyle(color=self.getColor(),
- linestyle=self.getLineStyle(),
- linewidth=self.getLineWidth(),
- symbol=self.getSymbol(),
- symbolsize=self.getSymbolSize())
-
- @deprecated(replacement='Curve.getCurrentStyle()',
- since_version='0.9.0')
- def getCurrentColor(self):
- """Returns the current color of the curve.
-
- This color is either the color of the curve or the highlighted color,
- depending on the highlight state.
-
- :rtype: 4-tuple of float in [0, 1]
- """
- return self.getCurrentStyle().getColor()
-
- def setData(self, x, y, xerror=None, yerror=None, baseline=None, copy=True):
- """Set the data of the curve.
-
- :param numpy.ndarray x: The data corresponding to the x coordinates.
- :param numpy.ndarray y: The data corresponding to the y coordinates.
- :param xerror: Values with the uncertainties on the x values
- :type xerror: A float, or a numpy.ndarray of float32.
- If it is an array, it can either be a 1D array of
- same length as the data or a 2D array with 2 rows
- of same length as the data: row 0 for positive errors,
- row 1 for negative errors.
- :param yerror: Values with the uncertainties on the y values.
- :type yerror: A float, or a numpy.ndarray of float32. See xerror.
- :param baseline: curve baseline
- :type baseline: Union[None,float,numpy.ndarray]
- :param bool copy: True make a copy of the data (default),
- False to use provided arrays.
- """
- PointsBase.setData(self, x=x, y=y, xerror=xerror, yerror=yerror,
- copy=copy)
- self._setBaseline(baseline=baseline)
diff --git a/silx/gui/plot/items/image.py b/silx/gui/plot/items/image.py
deleted file mode 100644
index 0d9c9a4..0000000
--- a/silx/gui/plot/items/image.py
+++ /dev/null
@@ -1,617 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides the :class:`ImageData` and :class:`ImageRgba` items
-of the :class:`Plot`.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "08/12/2020"
-
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
-import logging
-
-import numpy
-
-from ....utils.proxy import docstring
-from .core import (DataItem, LabelsMixIn, DraggableMixIn, ColormapMixIn,
- AlphaMixIn, ItemChangedType)
-
-_logger = logging.getLogger(__name__)
-
-
-def _convertImageToRgba32(image, copy=True):
- """Convert an RGB or RGBA image to RGBA32.
-
- It converts from floats in [0, 1], bool, integer and uint in [0, 255]
-
- If the input image is already an RGBA32 image,
- the returned image shares the same data.
-
- :param image: Image to convert to
- :type image: numpy.ndarray with 3 dimensions: height, width, color channels
- :param bool copy: True (Default) to get a copy, False, avoid copy if possible
- :return: The image converted to RGBA32 with dimension: (height, width, 4)
- :rtype: numpy.ndarray of uint8
- """
- assert image.ndim == 3
- assert image.shape[-1] in (3, 4)
-
- # Convert type to uint8
- if image.dtype.name != 'uint8':
- if image.dtype.kind == 'f': # Float in [0, 1]
- image = (numpy.clip(image, 0., 1.) * 255).astype(numpy.uint8)
- elif image.dtype.kind == 'b': # boolean
- image = image.astype(numpy.uint8) * 255
- elif image.dtype.kind in ('i', 'u'): # int, uint
- image = numpy.clip(image, 0, 255).astype(numpy.uint8)
- else:
- raise ValueError('Unsupported image dtype: %s', image.dtype.name)
- copy = False # A copy as already been done, avoid next one
-
- # Convert RGB to RGBA
- if image.shape[-1] == 3:
- new_image = numpy.empty((image.shape[0], image.shape[1], 4),
- dtype=numpy.uint8)
- new_image[:,:,:3] = image
- new_image[:,:, 3] = 255
- return new_image # This is a copy anyway
- else:
- return numpy.array(image, copy=copy)
-
-
-class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
- """Description of an image
-
- :param numpy.ndarray data: Initial image data
- """
-
- def __init__(self, data=None, mask=None):
- DataItem.__init__(self)
- LabelsMixIn.__init__(self)
- DraggableMixIn.__init__(self)
- AlphaMixIn.__init__(self)
- if data is None:
- data = numpy.zeros((0, 0, 4), dtype=numpy.uint8)
- self._data = data
- self._mask = mask
- self.__valueDataCache = None # Store default data
- self._origin = (0., 0.)
- self._scale = (1., 1.)
-
- def __getitem__(self, item):
- """Compatibility with PyMca and silx <= 0.4.0"""
- if isinstance(item, slice):
- return [self[index] for index in range(*item.indices(5))]
- elif item == 0:
- return self.getData(copy=False)
- elif item == 1:
- return self.getName()
- elif item == 2:
- info = self.getInfo(copy=False)
- return {} if info is None else info
- elif item == 3:
- return None
- elif item == 4:
- params = {
- 'info': self.getInfo(),
- 'origin': self.getOrigin(),
- 'scale': self.getScale(),
- 'z': self.getZValue(),
- 'selectable': self.isSelectable(),
- 'draggable': self.isDraggable(),
- 'colormap': None,
- 'xlabel': self.getXLabel(),
- 'ylabel': self.getYLabel(),
- }
- return params
- else:
- raise IndexError("Index out of range: %s" % str(item))
-
- def _isPlotLinear(self, plot):
- """Return True if plot only uses linear scale for both of x and y
- axes."""
- linear = plot.getXAxis().LINEAR
- if plot.getXAxis().getScale() != linear:
- return False
- if plot.getYAxis().getScale() != linear:
- return False
- return True
-
- def _getBounds(self):
- if self.getData(copy=False).size == 0: # Empty data
- return None
-
- height, width = self.getData(copy=False).shape[:2]
- origin = self.getOrigin()
- scale = self.getScale()
- # Taking care of scale might be < 0
- xmin, xmax = origin[0], origin[0] + width * scale[0]
- if xmin > xmax:
- xmin, xmax = xmax, xmin
- # Taking care of scale might be < 0
- ymin, ymax = origin[1], origin[1] + height * scale[1]
- if ymin > ymax:
- ymin, ymax = ymax, ymin
-
- plot = self.getPlot()
- if plot is not None and not self._isPlotLinear(plot):
- return None
- else:
- return xmin, xmax, ymin, ymax
-
- @docstring(DraggableMixIn)
- def drag(self, from_, to):
- origin = self.getOrigin()
- self.setOrigin((origin[0] + to[0] - from_[0],
- origin[1] + to[1] - from_[1]))
-
- def getData(self, copy=True):
- """Returns the image data
-
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: numpy.ndarray
- """
- return numpy.array(self._data, copy=copy)
-
- def setData(self, data):
- """Set the image data
-
- :param numpy.ndarray data:
- """
- previousShape = self._data.shape
- self._data = data
- self._valueDataChanged()
- self._boundsChanged()
- self._updated(ItemChangedType.DATA)
-
- if (self.getMaskData(copy=False) is not None and
- previousShape != self._data.shape):
- # Data shape changed, so mask shape changes.
- # Send event, mask is lazily updated in getMaskData
- self._updated(ItemChangedType.MASK)
-
- def getMaskData(self, copy=True):
- """Returns the mask data
-
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: Union[None,numpy.ndarray]
- """
- if self._mask is None:
- return None
-
- # Update mask if it does not match data shape
- shape = self.getData(copy=False).shape[:2]
- if self._mask.shape != shape:
- # Clip/extend mask to match data
- newMask = numpy.zeros(shape, dtype=self._mask.dtype)
- newMask[:self._mask.shape[0], :self._mask.shape[1]] = self._mask[:shape[0], :shape[1]]
- self._mask = newMask
-
- return numpy.array(self._mask, copy=copy)
-
- def setMaskData(self, mask, copy=True):
- """Set the image data
-
- :param numpy.ndarray data:
- :param bool copy: True (Default) to make a copy,
- False to use as is (do not modify!)
- """
- if mask is not None:
- mask = numpy.array(mask, copy=copy)
-
- shape = self.getData(copy=False).shape[:2]
- if mask.shape != shape:
- _logger.warning("Inconsistent shape between mask and data %s, %s", mask.shape, shape)
- # Clip/extent is done lazily in getMaskData
- elif self._mask is None:
- return # No update
-
- self._mask = mask
- self._valueDataChanged()
- self._updated(ItemChangedType.MASK)
-
- def _valueDataChanged(self):
- """Clear cache of default data array"""
- self.__valueDataCache = None
-
- def _getValueData(self, copy=True):
- """Return data used by :meth:`getValueData`
-
- :param bool copy:
- :rtype: numpy.ndarray
- """
- return self.getData(copy=copy)
-
- def getValueData(self, copy=True):
- """Return data (converted to int or float) with mask applied.
-
- Masked values are set to Not-A-Number.
- It returns a 2D array of values (int or float).
-
- :param bool copy:
- :rtype: numpy.ndarray
- """
- if self.__valueDataCache is None:
- data = self._getValueData(copy=False)
- mask = self.getMaskData(copy=False)
- if mask is not None:
- if numpy.issubdtype(data.dtype, numpy.floating):
- dtype = data.dtype
- else:
- dtype = numpy.float64
- data = numpy.array(data, dtype=dtype, copy=True)
- data[mask != 0] = numpy.NaN
- self.__valueDataCache = data
- return numpy.array(self.__valueDataCache, copy=copy)
-
- def getRgbaImageData(self, copy=True):
- """Get the displayed RGB(A) image
-
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :returns: numpy.ndarray of uint8 of shape (height, width, 4)
- """
- raise NotImplementedError('This MUST be implemented in sub-class')
-
- def getOrigin(self):
- """Returns the offset from origin at which to display the image.
-
- :rtype: 2-tuple of float
- """
- return self._origin
-
- def setOrigin(self, origin):
- """Set the offset from origin at which to display the image.
-
- :param origin: (ox, oy) Offset from origin
- :type origin: float or 2-tuple of float
- """
- if isinstance(origin, abc.Sequence):
- origin = float(origin[0]), float(origin[1])
- else: # single value origin
- origin = float(origin), float(origin)
- if origin != self._origin:
- self._origin = origin
- self._boundsChanged()
- self._updated(ItemChangedType.POSITION)
-
- def getScale(self):
- """Returns the scale of the image in data coordinates.
-
- :rtype: 2-tuple of float
- """
- return self._scale
-
- def setScale(self, scale):
- """Set the scale of the image
-
- :param scale: (sx, sy) Scale of the image
- :type scale: float or 2-tuple of float
- """
- if isinstance(scale, abc.Sequence):
- scale = float(scale[0]), float(scale[1])
- else: # single value scale
- scale = float(scale), float(scale)
-
- if scale != self._scale:
- self._scale = scale
- self._boundsChanged()
- self._updated(ItemChangedType.SCALE)
-
-
-class ImageData(ImageBase, ColormapMixIn):
- """Description of a data image with a colormap"""
-
- def __init__(self):
- ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.float32))
- ColormapMixIn.__init__(self)
- self._alternativeImage = None
- self.__alpha = None
-
- def _addBackendRenderer(self, backend):
- """Update backend renderer"""
- plot = self.getPlot()
- assert plot is not None
- if not self._isPlotLinear(plot):
- # Do not render with non linear scales
- return None
-
- if (self.getAlternativeImageData(copy=False) is not None or
- self.getAlphaData(copy=False) is not None):
- dataToUse = self.getRgbaImageData(copy=False)
- else:
- dataToUse = self.getData(copy=False)
-
- if dataToUse.size == 0:
- return None # No data to display
-
- colormap = self.getColormap()
- if colormap.isAutoscale():
- # Avoid backend to compute autoscale: use item cache
- colormap = colormap.copy()
- colormap.setVRange(*colormap.getColormapRange(self))
-
- return backend.addImage(dataToUse,
- origin=self.getOrigin(),
- scale=self.getScale(),
- colormap=colormap,
- alpha=self.getAlpha())
-
- def __getitem__(self, item):
- """Compatibility with PyMca and silx <= 0.4.0"""
- if item == 3:
- return self.getAlternativeImageData(copy=False)
-
- params = ImageBase.__getitem__(self, item)
- if item == 4:
- params['colormap'] = self.getColormap()
-
- return params
-
- def getRgbaImageData(self, copy=True):
- """Get the displayed RGB(A) image
-
- :returns: Array of uint8 of shape (height, width, 4)
- :rtype: numpy.ndarray
- """
- alternative = self.getAlternativeImageData(copy=False)
- if alternative is not None:
- return _convertImageToRgba32(alternative, copy=copy)
- else:
- # Apply colormap, in this case an new array is always returned
- colormap = self.getColormap()
- image = colormap.applyToData(self)
- alphaImage = self.getAlphaData(copy=False)
- if alphaImage is not None:
- # Apply transparency
- image[:,:, 3] = image[:,:, 3] * alphaImage
- return image
-
- def getAlternativeImageData(self, copy=True):
- """Get the optional RGBA image that is displayed instead of the data
-
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: Union[None,numpy.ndarray]
- """
- if self._alternativeImage is None:
- return None
- else:
- return numpy.array(self._alternativeImage, copy=copy)
-
- def getAlphaData(self, copy=True):
- """Get the optional transparency image applied on the data
-
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: Union[None,numpy.ndarray]
- """
- if self.__alpha is None:
- return None
- else:
- return numpy.array(self.__alpha, copy=copy)
-
- def setData(self, data, alternative=None, alpha=None, copy=True):
- """"Set the image data and optionally an alternative RGB(A) representation
-
- :param numpy.ndarray data: Data array with 2 dimensions (h, w)
- :param alternative: RGB(A) image to display instead of data,
- shape: (h, w, 3 or 4)
- :type alternative: Union[None,numpy.ndarray]
- :param alpha: An array of transparency value in [0, 1] to use for
- display with shape: (h, w)
- :type alpha: Union[None,numpy.ndarray]
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- """
- data = numpy.array(data, copy=copy)
- assert data.ndim == 2
- if data.dtype.kind == 'b':
- _logger.warning(
- 'Converting boolean image to int8 to plot it.')
- data = numpy.array(data, copy=False, dtype=numpy.int8)
- elif numpy.iscomplexobj(data):
- _logger.warning(
- 'Converting complex image to absolute value to plot it.')
- data = numpy.absolute(data)
-
- if alternative is not None:
- alternative = numpy.array(alternative, copy=copy)
- assert alternative.ndim == 3
- assert alternative.shape[2] in (3, 4)
- assert alternative.shape[:2] == data.shape[:2]
- self._alternativeImage = alternative
-
- if alpha is not None:
- alpha = numpy.array(alpha, copy=copy)
- assert alpha.shape == data.shape
- if alpha.dtype.kind != 'f':
- alpha = alpha.astype(numpy.float32)
- if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)):
- alpha = numpy.clip(alpha, 0., 1.)
- self.__alpha = alpha
-
- super().setData(data)
-
- def _updated(self, event=None, checkVisibility=True):
- # Synchronizes colormapped data if changed
- if event in (ItemChangedType.DATA, ItemChangedType.MASK):
- self._setColormappedData(
- self.getValueData(copy=False),
- copy=False)
- super()._updated(event=event, checkVisibility=checkVisibility)
-
-
-class ImageRgba(ImageBase):
- """Description of an RGB(A) image"""
-
- def __init__(self):
- ImageBase.__init__(self, numpy.zeros((0, 0, 4), dtype=numpy.uint8))
-
- def _addBackendRenderer(self, backend):
- """Update backend renderer"""
- plot = self.getPlot()
- assert plot is not None
- if not self._isPlotLinear(plot):
- # Do not render with non linear scales
- return None
-
- data = self.getData(copy=False)
-
- if data.size == 0:
- return None # No data to display
-
- return backend.addImage(data,
- origin=self.getOrigin(),
- scale=self.getScale(),
- colormap=None,
- alpha=self.getAlpha())
-
- def getRgbaImageData(self, copy=True):
- """Get the displayed RGB(A) image
-
- :returns: numpy.ndarray of uint8 of shape (height, width, 4)
- """
- return _convertImageToRgba32(self.getData(copy=False), copy=copy)
-
- def setData(self, data, copy=True):
- """Set the image data
-
- :param data: RGB(A) image data to set
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- """
- data = numpy.array(data, copy=copy)
- assert data.ndim == 3
- assert data.shape[-1] in (3, 4)
- super().setData(data)
-
- def _getValueData(self, copy=True):
- """Compute the intensity of the RGBA image as default data.
-
- Conversion: https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
-
- :param bool copy:
- """
- rgba = self.getRgbaImageData(copy=False).astype(numpy.float32)
- intensity = (rgba[:, :, 0] * 0.299 +
- rgba[:, :, 1] * 0.587 +
- rgba[:, :, 2] * 0.114)
- intensity *= rgba[:, :, 3] / 255.
- return intensity
-
-
-class MaskImageData(ImageData):
- """Description of an image used as a mask.
-
- This class is used to flag mask items. This information is used to improve
- internal silx widgets.
- """
- pass
-
-
-class ImageStack(ImageData):
- """Item to store a stack of images and to show it in the plot as one
- of the images of the stack.
-
- The stack is a 3D array ordered this way: `frame id, y, x`.
- So the first image of the stack can be reached this way: `stack[0, :, :]`
- """
-
- def __init__(self):
- ImageData.__init__(self)
- self.__stack = None
- """A 3D numpy array (or a mimic one, see ListOfImages)"""
- self.__stackPosition = None
- """Displayed position in the cube"""
-
- def setStackData(self, stack, position=None, copy=True):
- """Set the stack data
-
- :param stack: A 3D numpy array like
- :param int position: The position of the displayed image in the stack
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- """
- if self.__stack is stack:
- return
- if copy:
- stack = numpy.array(stack)
- assert stack.ndim == 3
- self.__stack = stack
- if position is not None:
- self.__stackPosition = position
- if self.__stackPosition is None:
- self.__stackPosition = 0
- self.__updateDisplayedData()
-
- def getStackData(self, copy=True):
- """Get the stored stack array.
-
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: A 3D numpy array, or numpy array like
- """
- if copy:
- return numpy.array(self.__stack)
- else:
- return self.__stack
-
- def setStackPosition(self, pos):
- """Set the displayed position on the stack.
-
- This function will clamp the stack position according to
- the real size of the first axis of the stack.
-
- :param int pos: A position on the first axis of the stack.
- """
- if self.__stackPosition == pos:
- return
- self.__stackPosition = pos
- self.__updateDisplayedData()
-
- def getStackPosition(self):
- """Get the displayed position of the stack.
-
- :rtype: int
- """
- return self.__stackPosition
-
- def __updateDisplayedData(self):
- """Update the displayed frame whenever the stack or the stack
- position are updated."""
- if self.__stack is None or self.__stackPosition is None:
- empty = numpy.array([]).reshape(0, 0)
- self.setData(empty, copy=False)
- return
- size = len(self.__stack)
- self.__stackPosition = numpy.clip(self.__stackPosition, 0, size)
- self.setData(self.__stack[self.__stackPosition], copy=False)
diff --git a/silx/gui/plot/items/scatter.py b/silx/gui/plot/items/scatter.py
deleted file mode 100644
index 2d54223..0000000
--- a/silx/gui/plot/items/scatter.py
+++ /dev/null
@@ -1,973 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides the :class:`Scatter` item of the :class:`Plot`.
-"""
-
-from __future__ import division
-
-
-__authors__ = ["T. Vincent", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "29/03/2017"
-
-
-from collections import namedtuple
-import logging
-import threading
-import numpy
-
-from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, CancelledError
-
-from ....utils.proxy import docstring
-from ....math.combo import min_max
-from ....math.histogram import Histogramnd
-from ....utils.weakref import WeakList
-from .._utils.delaunay import delaunay
-from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn
-from .axis import Axis
-from ._pick import PickingResult
-
-
-_logger = logging.getLogger(__name__)
-
-
-class _GreedyThreadPoolExecutor(ThreadPoolExecutor):
- """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method.
- """
-
- def __init__(self, *args, **kwargs):
- super(_GreedyThreadPoolExecutor, self).__init__(*args, **kwargs)
- self.__futures = defaultdict(WeakList)
- self.__lock = threading.RLock()
-
- def submit_greedy(self, queue, fn, *args, **kwargs):
- """Same as :meth:`submit` but cancel previous tasks in given queue.
-
- This means that when a new task is submitted for a given queue,
- all other pending tasks of that queue are cancelled.
-
- :param queue: Identifier of the queue. This must be hashable.
- :param callable fn: The callable to call with provided extra arguments
- :return: Future corresponding to this task
- :rtype: concurrent.futures.Future
- """
- with self.__lock:
- # Cancel previous tasks in given queue
- for future in self.__futures.pop(queue, []):
- if not future.done():
- future.cancel()
-
- future = super(_GreedyThreadPoolExecutor, self).submit(
- fn, *args, **kwargs)
- self.__futures[queue].append(future)
-
- return future
-
-
-# Functions to guess grid shape from coordinates
-
-def _get_z_line_length(array):
- """Return length of line if array is a Z-like 2D regular grid.
-
- :param numpy.ndarray array: The 1D array of coordinates to check
- :return: 0 if no line length could be found,
- else the number of element per line.
- :rtype: int
- """
- sign = numpy.sign(numpy.diff(array))
- if len(sign) == 0 or sign[0] == 0: # We don't handle that
- return 0
- # Check this way to account for 0 sign (i.e., diff == 0)
- beginnings = numpy.where(sign == - sign[0])[0] + 1
- if len(beginnings) == 0:
- return 0
- length = beginnings[0]
- if numpy.all(numpy.equal(numpy.diff(beginnings), length)):
- return length
- return 0
-
-
-def _guess_z_grid_shape(x, y):
- """Guess the shape of a grid from (x, y) coordinates.
-
- The grid might contain more elements than x and y,
- as the last line might be partly filled.
-
- :param numpy.ndarray x:
- :paran numpy.ndarray y:
- :returns: (order, (height, width)) of the regular grid,
- or None if could not guess one.
- 'order' is 'row' if X (i.e., column) is the fast dimension, else 'column'.
- :rtype: Union[List(str,int),None]
- """
- width = _get_z_line_length(x)
- if width != 0:
- return 'row', (int(numpy.ceil(len(x) / width)), width)
- else:
- height = _get_z_line_length(y)
- if height != 0:
- return 'column', (height, int(numpy.ceil(len(y) / height)))
- return None
-
-
-def is_monotonic(array):
- """Returns whether array is monotonic (increasing or decreasing).
-
- :param numpy.ndarray array: 1D array-like container.
- :returns: 1 if array is monotonically increasing,
- -1 if array is monotonically decreasing,
- 0 if array is not monotonic
- :rtype: int
- """
- diff = numpy.diff(numpy.ravel(array))
- with numpy.errstate(invalid='ignore'):
- if numpy.all(diff >= 0):
- return 1
- elif numpy.all(diff <= 0):
- return -1
- else:
- return 0
-
-
-def _guess_grid(x, y):
- """Guess a regular grid from the points.
-
- Result convention is (x, y)
-
- :param numpy.ndarray x: X coordinates of the points
- :param numpy.ndarray y: Y coordinates of the points
- :returns: (order, (height, width)
- order is 'row' or 'column'
- :rtype: Union[List[str,List[int]],None]
- """
- x, y = numpy.ravel(x), numpy.ravel(y)
-
- guess = _guess_z_grid_shape(x, y)
- if guess is not None:
- return guess
-
- else:
- # Cannot guess a regular grid
- # Let's assume it's a single line
- order = 'row' # or 'column' doesn't matter for a single line
- y_monotonic = is_monotonic(y)
- if is_monotonic(x) or y_monotonic: # we can guess a line
- x_min, x_max = min_max(x)
- y_min, y_max = min_max(y)
-
- if not y_monotonic or x_max - x_min >= y_max - y_min:
- # x only is monotonic or both are and X varies more
- # line along X
- shape = 1, len(x)
- else:
- # y only is monotonic or both are and Y varies more
- # line along Y
- shape = len(y), 1
-
- else: # Cannot guess a line from the points
- return None
-
- return order, shape
-
-
-def _quadrilateral_grid_coords(points):
- """Compute an irregular grid of quadrilaterals from a set of points
-
- The input points are expected to lie on a grid.
-
- :param numpy.ndarray points:
- 3D data set of 2D input coordinates (height, width, 2)
- height and width must be at least 2.
- :return: 3D dataset of 2D coordinates of the grid (height+1, width+1, 2)
- """
- assert points.ndim == 3
- assert points.shape[0] >= 2
- assert points.shape[1] >= 2
- assert points.shape[2] == 2
-
- dim0, dim1 = points.shape[:2]
- grid_points = numpy.zeros((dim0 + 1, dim1 + 1, 2), dtype=numpy.float64)
-
- # Compute inner points as mean of 4 neighbours
- neighbour_view = numpy.lib.stride_tricks.as_strided(
- points,
- shape=(dim0 - 1, dim1 - 1, 2, 2, points.shape[2]),
- strides=points.strides[:2] + points.strides[:2] + points.strides[-1:], writeable=False)
- inner_points = numpy.mean(neighbour_view, axis=(2, 3))
- grid_points[1:-1, 1:-1] = inner_points
-
- # Compute 'vertical' sides
- # Alternative: grid_points[1:-1, [0, -1]] = points[:-1, [0, -1]] + points[1:, [0, -1]] - inner_points[:, [0, -1]]
- grid_points[1:-1, [0, -1], 0] = points[:-1, [0, -1], 0] + points[1:, [0, -1], 0] - inner_points[:, [0, -1], 0]
- grid_points[1:-1, [0, -1], 1] = inner_points[:, [0, -1], 1]
-
- # Compute 'horizontal' sides
- grid_points[[0, -1], 1:-1, 0] = inner_points[[0, -1], :, 0]
- grid_points[[0, -1], 1:-1, 1] = points[[0, -1], :-1, 1] + points[[0, -1], 1:, 1] - inner_points[[0, -1], :, 1]
-
- # Compute corners
- d0, d1 = [0, 0, -1, -1], [0, -1, -1, 0]
- grid_points[d0, d1] = 2 * points[d0, d1] - inner_points[d0, d1]
- return grid_points
-
-
-def _quadrilateral_grid_as_triangles(points):
- """Returns the points and indices to make a grid of quadirlaterals
-
- :param numpy.ndarray points:
- 3D array of points (height, width, 2)
- :return: triangle corners (4 * N, 2), triangle indices (2 * N, 3)
- With N = height * width, the number of input points
- """
- nbpoints = numpy.prod(points.shape[:2])
-
- grid = _quadrilateral_grid_coords(points)
- coords = numpy.empty((4 * nbpoints, 2), dtype=grid.dtype)
- coords[::4] = grid[:-1, :-1].reshape(-1, 2)
- coords[1::4] = grid[1:, :-1].reshape(-1, 2)
- coords[2::4] = grid[:-1, 1:].reshape(-1, 2)
- coords[3::4] = grid[1:, 1:].reshape(-1, 2)
-
- indices = numpy.empty((2 * nbpoints, 3), dtype=numpy.uint32)
- indices[::2, 0] = numpy.arange(0, 4 * nbpoints, 4)
- indices[::2, 1] = numpy.arange(1, 4 * nbpoints, 4)
- indices[::2, 2] = numpy.arange(2, 4 * nbpoints, 4)
- indices[1::2, 0] = indices[::2, 1]
- indices[1::2, 1] = indices[::2, 2]
- indices[1::2, 2] = numpy.arange(3, 4 * nbpoints, 4)
-
- return coords, indices
-
-
-_RegularGridInfo = namedtuple(
- '_RegularGridInfo', ['bounds', 'origin', 'scale', 'shape', 'order'])
-
-
-_HistogramInfo = namedtuple(
- '_HistogramInfo', ['mean', 'count', 'sum', 'origin', 'scale', 'shape'])
-
-
-class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
- """Description of a scatter"""
-
- _DEFAULT_SELECTABLE = True
- """Default selectable state for scatter plots"""
-
- _SUPPORTED_SCATTER_VISUALIZATION = (
- ScatterVisualizationMixIn.Visualization.POINTS,
- ScatterVisualizationMixIn.Visualization.SOLID,
- ScatterVisualizationMixIn.Visualization.REGULAR_GRID,
- ScatterVisualizationMixIn.Visualization.IRREGULAR_GRID,
- ScatterVisualizationMixIn.Visualization.BINNED_STATISTIC,
- )
- """Overrides supported Visualizations"""
-
- def __init__(self):
- PointsBase.__init__(self)
- ColormapMixIn.__init__(self)
- ScatterVisualizationMixIn.__init__(self)
- self._value = ()
- self.__alpha = None
- # Cache Delaunay triangulation future object
- self.__delaunayFuture = None
- # Cache interpolator future object
- self.__interpolatorFuture = None
- self.__executor = None
-
- # Cache triangles: x, y, indices
- self.__cacheTriangles = None, None, None
-
- # Cache regular grid and histogram info
- self.__cacheRegularGridInfo = None
- self.__cacheHistogramInfo = None
-
- def _updateColormappedData(self):
- """Update the colormapped data, to be called when changed"""
- if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
- histoInfo = self.__getHistogramInfo()
- if histoInfo is None:
- data = None
- else:
- data = getattr(
- histoInfo,
- self.getVisualizationParameter(
- self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
- else:
- data = self.getValueData(copy=False)
- self._setColormappedData(data, copy=False)
-
- @docstring(ScatterVisualizationMixIn)
- def setVisualization(self, mode):
- previous = self.getVisualization()
- if super().setVisualization(mode):
- if (bool(mode is self.Visualization.BINNED_STATISTIC) ^
- bool(previous is self.Visualization.BINNED_STATISTIC)):
- self._updateColormappedData()
- return True
- else:
- return False
-
- @docstring(ScatterVisualizationMixIn)
- def setVisualizationParameter(self, parameter, value):
- parameter = self.VisualizationParameter.from_value(parameter)
-
- if super(Scatter, self).setVisualizationParameter(parameter, value):
- if parameter in (self.VisualizationParameter.GRID_BOUNDS,
- self.VisualizationParameter.GRID_MAJOR_ORDER,
- self.VisualizationParameter.GRID_SHAPE):
- self.__cacheRegularGridInfo = None
-
- if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
- self.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
- self.VisualizationParameter.DATA_BOUNDS_HINT):
- if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
- self.VisualizationParameter.DATA_BOUNDS_HINT):
- self.__cacheHistogramInfo = None # Clean-up cache
- if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
- self._updateColormappedData()
- return True
- else:
- return False
-
- @docstring(ScatterVisualizationMixIn)
- def getCurrentVisualizationParameter(self, parameter):
- value = self.getVisualizationParameter(parameter)
- if (parameter is self.VisualizationParameter.DATA_BOUNDS_HINT or
- value is not None):
- return value # Value has been set, return it
-
- elif parameter is self.VisualizationParameter.GRID_BOUNDS:
- grid = self.__getRegularGridInfo()
- return None if grid is None else grid.bounds
-
- elif parameter is self.VisualizationParameter.GRID_MAJOR_ORDER:
- grid = self.__getRegularGridInfo()
- return None if grid is None else grid.order
-
- elif parameter is self.VisualizationParameter.GRID_SHAPE:
- grid = self.__getRegularGridInfo()
- return None if grid is None else grid.shape
-
- elif parameter is self.VisualizationParameter.BINNED_STATISTIC_SHAPE:
- info = self.__getHistogramInfo()
- return None if info is None else info.shape
-
- else:
- raise NotImplementedError()
-
- def __getRegularGridInfo(self):
- """Get grid info"""
- if self.__cacheRegularGridInfo is None:
- shape = self.getVisualizationParameter(
- self.VisualizationParameter.GRID_SHAPE)
- order = self.getVisualizationParameter(
- self.VisualizationParameter.GRID_MAJOR_ORDER)
- if shape is None or order is None:
- guess = _guess_grid(self.getXData(copy=False),
- self.getYData(copy=False))
- if guess is None:
- _logger.warning(
- 'Cannot guess a grid: Cannot display as regular grid image')
- return None
- if shape is None:
- shape = guess[1]
- if order is None:
- order = guess[0]
-
- nbpoints = len(self.getXData(copy=False))
- if nbpoints > shape[0] * shape[1]:
- # More data points that provided grid shape: enlarge grid
- _logger.warning(
- "More data points than provided grid shape size: extends grid")
- dim0, dim1 = shape
- if order == 'row': # keep dim1, enlarge dim0
- dim0 = nbpoints // dim1 + (1 if nbpoints % dim1 else 0)
- else: # keep dim0, enlarge dim1
- dim1 = nbpoints // dim0 + (1 if nbpoints % dim0 else 0)
- shape = dim0, dim1
-
- bounds = self.getVisualizationParameter(
- self.VisualizationParameter.GRID_BOUNDS)
- if bounds is None:
- x, y = self.getXData(copy=False), self.getYData(copy=False)
- min_, max_ = min_max(x)
- xRange = (min_, max_) if (x[0] - min_) < (max_ - x[0]) else (max_, min_)
- min_, max_ = min_max(y)
- yRange = (min_, max_) if (y[0] - min_) < (max_ - y[0]) else (max_, min_)
- bounds = (xRange[0], yRange[0]), (xRange[1], yRange[1])
-
- begin, end = bounds
- scale = ((end[0] - begin[0]) / max(1, shape[1] - 1),
- (end[1] - begin[1]) / max(1, shape[0] - 1))
- if scale[0] == 0 and scale[1] == 0:
- scale = 1., 1.
- elif scale[0] == 0:
- scale = scale[1], scale[1]
- elif scale[1] == 0:
- scale = scale[0], scale[0]
-
- origin = begin[0] - 0.5 * scale[0], begin[1] - 0.5 * scale[1]
-
- self.__cacheRegularGridInfo = _RegularGridInfo(
- bounds=bounds, origin=origin, scale=scale, shape=shape, order=order)
-
- return self.__cacheRegularGridInfo
-
- def __getHistogramInfo(self):
- """Get histogram info"""
- if self.__cacheHistogramInfo is None:
- shape = self.getVisualizationParameter(
- self.VisualizationParameter.BINNED_STATISTIC_SHAPE)
- if shape is None:
- shape = 100, 100 # TODO compute auto shape
-
- x, y, values = self.getData(copy=False)[:3]
- if len(x) == 0: # No histogram
- return None
-
- if not numpy.issubdtype(x.dtype, numpy.floating):
- x = x.astype(numpy.float64)
- if not numpy.issubdtype(y.dtype, numpy.floating):
- y = y.astype(numpy.float64)
- if not numpy.issubdtype(values.dtype, numpy.floating):
- values = values.astype(numpy.float64)
-
- ranges = (tuple(min_max(y, finite=True)),
- tuple(min_max(x, finite=True)))
- rangesHint = self.getVisualizationParameter(
- self.VisualizationParameter.DATA_BOUNDS_HINT)
- if rangesHint is not None:
- ranges = tuple((min(dataMin, hintMin), max(dataMax, hintMax))
- for (dataMin, dataMax), (hintMin, hintMax) in zip(ranges, rangesHint))
-
- points = numpy.transpose(numpy.array((y, x)))
- counts, sums, bin_edges = Histogramnd(
- points,
- histo_range=ranges,
- n_bins=shape,
- weights=values)
- yEdges, xEdges = bin_edges
- origin = xEdges[0], yEdges[0]
- scale = ((xEdges[-1] - xEdges[0]) / (len(xEdges) - 1),
- (yEdges[-1] - yEdges[0]) / (len(yEdges) - 1))
-
- with numpy.errstate(divide='ignore', invalid='ignore'):
- histo = sums / counts
-
- self.__cacheHistogramInfo = _HistogramInfo(
- mean=histo, count=counts, sum=sums,
- origin=origin, scale=scale, shape=shape)
-
- return self.__cacheHistogramInfo
-
- def _addBackendRenderer(self, backend):
- """Update backend renderer"""
- # Filter-out values <= 0
- xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData(
- copy=False, displayed=True)
-
- # Remove not finite numbers (this includes filtered out x, y <= 0)
- mask = numpy.logical_and(numpy.isfinite(xFiltered), numpy.isfinite(yFiltered))
- xFiltered = xFiltered[mask]
- yFiltered = yFiltered[mask]
-
- if len(xFiltered) == 0:
- return None # No data to display, do not add renderer to backend
-
- visualization = self.getVisualization()
-
- if visualization is self.Visualization.BINNED_STATISTIC:
- plot = self.getPlot()
- if (plot is None or
- plot.getXAxis().getScale() != Axis.LINEAR or
- plot.getYAxis().getScale() != Axis.LINEAR):
- # Those visualizations are not available with log scaled axes
- return None
-
- histoInfo = self.__getHistogramInfo()
- if histoInfo is None:
- return None
- data = getattr(histoInfo, self.getVisualizationParameter(
- self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
-
- return backend.addImage(
- data=data,
- origin=histoInfo.origin,
- scale=histoInfo.scale,
- colormap=self.getColormap(),
- alpha=self.getAlpha())
-
- # Compute colors
- cmap = self.getColormap()
- rgbacolors = cmap.applyToData(self)
-
- if self.__alpha is not None:
- rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8)
-
- visualization = self.getVisualization()
-
- if visualization is self.Visualization.POINTS:
- return backend.addCurve(xFiltered, yFiltered,
- color=rgbacolors[mask],
- symbol=self.getSymbol(),
- linewidth=0,
- linestyle="",
- yaxis='left',
- xerror=xerror,
- yerror=yerror,
- fill=False,
- alpha=self.getAlpha(),
- symbolsize=self.getSymbolSize(),
- baseline=None)
-
- else:
- plot = self.getPlot()
- if (plot is None or
- plot.getXAxis().getScale() != Axis.LINEAR or
- plot.getYAxis().getScale() != Axis.LINEAR):
- # Those visualizations are not available with log scaled axes
- return None
-
- if visualization is self.Visualization.SOLID:
- triangulation = self._getDelaunay().result()
- if triangulation is None:
- _logger.warning(
- 'Cannot get a triangulation: Cannot display as solid surface')
- return None
- else:
- triangles = triangulation.simplices.astype(numpy.int32)
- return backend.addTriangles(xFiltered,
- yFiltered,
- triangles,
- color=rgbacolors[mask],
- alpha=self.getAlpha())
-
- elif visualization is self.Visualization.REGULAR_GRID:
- gridInfo = self.__getRegularGridInfo()
- if gridInfo is None:
- return None
-
- dim0, dim1 = gridInfo.shape
- if gridInfo.order == 'column': # transposition needed
- dim0, dim1 = dim1, dim0
-
- if len(rgbacolors) == dim0 * dim1:
- image = rgbacolors.reshape(dim0, dim1, -1)
- else:
- # The points do not fill the whole image
- image = numpy.empty((dim0 * dim1, 4), dtype=rgbacolors.dtype)
- image[:len(rgbacolors)] = rgbacolors
- image[len(rgbacolors):] = 0, 0, 0, 0 # Transparent pixels
- image.shape = dim0, dim1, -1
-
- if gridInfo.order == 'column':
- image = numpy.transpose(image, axes=(1, 0, 2))
-
- return backend.addImage(
- data=image,
- origin=gridInfo.origin,
- scale=gridInfo.scale,
- colormap=None,
- alpha=self.getAlpha())
-
- elif visualization is self.Visualization.IRREGULAR_GRID:
- gridInfo = self.__getRegularGridInfo()
- if gridInfo is None:
- return None
-
- shape = gridInfo.shape
- if shape is None: # No shape, no display
- return None
-
- nbpoints = len(xFiltered)
- if nbpoints == 1:
- # single point, render as a square points
- return backend.addCurve(xFiltered, yFiltered,
- color=rgbacolors[mask],
- symbol='s',
- linewidth=0,
- linestyle="",
- yaxis='left',
- xerror=None,
- yerror=None,
- fill=False,
- alpha=self.getAlpha(),
- symbolsize=7,
- baseline=None)
-
- # Make shape include all points
- gridOrder = gridInfo.order
- if nbpoints != numpy.prod(shape):
- if gridOrder == 'row':
- shape = int(numpy.ceil(nbpoints / shape[1])), shape[1]
- else: # column-major order
- shape = shape[0], int(numpy.ceil(nbpoints / shape[0]))
-
- if shape[0] < 2 or shape[1] < 2: # Single line, at least 2 points
- points = numpy.ones((2, nbpoints, 2), dtype=numpy.float64)
- # Use row/column major depending on shape, not on info value
- gridOrder = 'row' if shape[0] == 1 else 'column'
-
- if gridOrder == 'row':
- points[0, :, 0] = xFiltered
- points[0, :, 1] = yFiltered
- else: # column-major order
- points[0, :, 0] = yFiltered
- points[0, :, 1] = xFiltered
-
- # Add a second line that will be clipped in the end
- points[1, :-1] = points[0, :-1] + numpy.cross(
- points[0, 1:] - points[0, :-1], (0., 0., 1.))[:, :2]
- points[1, -1] = points[0, -1] + numpy.cross(
- points[0, -1] - points[0, -2], (0., 0., 1.))[:2]
-
- points.shape = 2, nbpoints, 2 # Use same shape for both orders
- coords, indices = _quadrilateral_grid_as_triangles(points)
-
- elif gridOrder == 'row': # row-major order
- if nbpoints != numpy.prod(shape):
- points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
- points[:nbpoints, 0] = xFiltered
- points[:nbpoints, 1] = yFiltered
- # Index of last element of last fully filled row
- index = (nbpoints // shape[1]) * shape[1]
- points[nbpoints:, 0] = xFiltered[index - (numpy.prod(shape) - nbpoints):index]
- points[nbpoints:, 1] = yFiltered[-1]
- else:
- points = numpy.transpose((xFiltered, yFiltered))
- points.shape = shape[0], shape[1], 2
-
- else: # column-major order
- if nbpoints != numpy.prod(shape):
- points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
- points[:nbpoints, 0] = yFiltered
- points[:nbpoints, 1] = xFiltered
- # Index of last element of last fully filled column
- index = (nbpoints // shape[0]) * shape[0]
- points[nbpoints:, 0] = yFiltered[index - (numpy.prod(shape) - nbpoints):index]
- points[nbpoints:, 1] = xFiltered[-1]
- else:
- points = numpy.transpose((yFiltered, xFiltered))
- points.shape = shape[1], shape[0], 2
-
- coords, indices = _quadrilateral_grid_as_triangles(points)
-
- # Remove unused extra triangles
- coords = coords[:4*nbpoints]
- indices = indices[:2*nbpoints]
-
- if gridOrder == 'row':
- x, y = coords[:, 0], coords[:, 1]
- else: # column-major order
- y, x = coords[:, 0], coords[:, 1]
-
- rgbacolors = rgbacolors[mask] # Filter-out not finite points
- gridcolors = numpy.empty(
- (4 * nbpoints, rgbacolors.shape[-1]), dtype=rgbacolors.dtype)
- for first in range(4):
- gridcolors[first::4] = rgbacolors[:nbpoints]
-
- return backend.addTriangles(x,
- y,
- indices,
- color=gridcolors,
- alpha=self.getAlpha())
-
- else:
- _logger.error("Unhandled visualization %s", visualization)
- return None
-
- @docstring(PointsBase)
- def pick(self, x, y):
- result = super(Scatter, self).pick(x, y)
-
- if result is not None:
- visualization = self.getVisualization()
-
- if visualization is self.Visualization.IRREGULAR_GRID:
- # Specific handling of picking for the irregular grid mode
- index = result.getIndices(copy=False)[0] // 4
- result = PickingResult(self, (index,))
-
- elif visualization is self.Visualization.REGULAR_GRID:
- # Specific handling of picking for the regular grid mode
- picked = result.getIndices(copy=False)
- if picked is None:
- return None
- row, column = picked[0][0], picked[1][0]
-
- gridInfo = self.__getRegularGridInfo()
- if gridInfo is None:
- return None
-
- if gridInfo.order == 'row':
- index = row * gridInfo.shape[1] + column
- else:
- index = row + column * gridInfo.shape[0]
- if index >= len(self.getXData(copy=False)): # OK as long as not log scale
- return None # Image can be larger than scatter
-
- result = PickingResult(self, (index,))
-
- elif visualization is self.Visualization.BINNED_STATISTIC:
- picked = result.getIndices(copy=False)
- if picked is None or len(picked) == 0 or len(picked[0]) == 0:
- return None
- row, col = picked[0][0], picked[1][0]
- histoInfo = self.__getHistogramInfo()
- if histoInfo is None:
- return None
- sx, sy = histoInfo.scale
- ox, oy = histoInfo.origin
- xdata = self.getXData(copy=False)
- ydata = self.getYData(copy=False)
- indices = numpy.nonzero(numpy.logical_and(
- numpy.logical_and(xdata >= ox + sx * col, xdata < ox + sx * (col + 1)),
- numpy.logical_and(ydata >= oy + sy * row, ydata < oy + sy * (row + 1))))[0]
- result = None if len(indices) == 0 else PickingResult(self, indices)
-
- return result
-
- def __getExecutor(self):
- """Returns async greedy executor
-
- :rtype: _GreedyThreadPoolExecutor
- """
- if self.__executor is None:
- self.__executor = _GreedyThreadPoolExecutor(max_workers=2)
- return self.__executor
-
- def _getDelaunay(self):
- """Returns a :class:`Future` which result is the Delaunay object.
-
- :rtype: concurrent.futures.Future
- """
- if self.__delaunayFuture is None or self.__delaunayFuture.cancelled():
- # Need to init a new delaunay
- x, y = self.getData(copy=False)[:2]
- # Remove not finite points
- mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
-
- self.__delaunayFuture = self.__getExecutor().submit_greedy(
- 'delaunay', delaunay, x[mask], y[mask])
-
- return self.__delaunayFuture
-
- @staticmethod
- def __initInterpolator(delaunayFuture, values):
- """Returns an interpolator for the given data points
-
- :param concurrent.futures.Future delaunayFuture:
- Future object which result is a Delaunay object
- :param numpy.ndarray values: The data value of valid points.
- :rtype: Union[callable,None]
- """
- # Wait for Delaunay to complete
- try:
- triangulation = delaunayFuture.result()
- except CancelledError:
- triangulation = None
-
- if triangulation is None:
- interpolator = None # Error case
- else:
- # Lazy-loading of interpolator
- try:
- from scipy.interpolate import LinearNDInterpolator
- except ImportError:
- LinearNDInterpolator = None
-
- if LinearNDInterpolator is not None:
- interpolator = LinearNDInterpolator(triangulation, values)
-
- # First call takes a while, do it here
- interpolator([(0., 0.)])
-
- else:
- # Fallback using matplotlib interpolator
- import matplotlib.tri
-
- x, y = triangulation.points.T
- tri = matplotlib.tri.Triangulation(
- x, y, triangles=triangulation.simplices)
- mplInterpolator = matplotlib.tri.LinearTriInterpolator(
- tri, values)
-
- # Wrap interpolator to have same API as scipy's one
- def interpolator(points):
- return mplInterpolator(*points.T)
-
- return interpolator
-
- def _getInterpolator(self):
- """Returns a :class:`Future` which result is the interpolator.
-
- The interpolator is a callable taking an array Nx2 of points
- as a single argument.
- The :class:`Future` result is None in case the interpolator cannot
- be initialized.
-
- :rtype: concurrent.futures.Future
- """
- if (self.__interpolatorFuture is None or
- self.__interpolatorFuture.cancelled()):
- # Need to init a new interpolator
- x, y, values = self.getData(copy=False)[:3]
- # Remove not finite points
- mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
- x, y, values = x[mask], y[mask], values[mask]
-
- self.__interpolatorFuture = self.__getExecutor().submit_greedy(
- 'interpolator',
- self.__initInterpolator, self._getDelaunay(), values)
- return self.__interpolatorFuture
-
- def _logFilterData(self, xPositive, yPositive):
- """Filter out values with x or y <= 0 on log axes
-
- :param bool xPositive: True to filter arrays according to X coords.
- :param bool yPositive: True to filter arrays according to Y coords.
- :return: The filtered arrays or unchanged object if not filtering needed
- :rtype: (x, y, value, xerror, yerror)
- """
- # overloaded from PointsBase to filter also value.
- value = self.getValueData(copy=False)
-
- if xPositive or yPositive:
- clipped = self._getClippingBoolArray(xPositive, yPositive)
-
- if numpy.any(clipped):
- # copy to keep original array and convert to float
- value = numpy.array(value, copy=True, dtype=numpy.float64)
- value[clipped] = numpy.nan
-
- x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive)
-
- return x, y, value, xerror, yerror
-
- def getValueData(self, copy=True):
- """Returns the value assigned to the scatter data points.
-
- :param copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: numpy.ndarray
- """
- return numpy.array(self._value, copy=copy)
-
- def getAlphaData(self, copy=True):
- """Returns the alpha (transparency) assigned to the scatter data points.
-
- :param copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :rtype: numpy.ndarray
- """
- return numpy.array(self.__alpha, copy=copy)
-
- def getData(self, copy=True, displayed=False):
- """Returns the x, y coordinates and the value of the data points
-
- :param copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :param bool displayed: True to only get curve points that are displayed
- in the plot. Default: False.
- Note: If plot has log scale, negative points
- are not displayed.
- :returns: (x, y, value, xerror, yerror)
- :rtype: 5-tuple of numpy.ndarray
- """
- if displayed:
- data = self._getCachedData()
- if data is not None:
- assert len(data) == 5
- return data
-
- return (self.getXData(copy),
- self.getYData(copy),
- self.getValueData(copy),
- self.getXErrorData(copy),
- self.getYErrorData(copy))
-
- # reimplemented from PointsBase to handle `value`
- def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
- """Set the data of the scatter.
-
- :param numpy.ndarray x: The data corresponding to the x coordinates.
- :param numpy.ndarray y: The data corresponding to the y coordinates.
- :param numpy.ndarray value: The data corresponding to the value of
- the data points.
- :param xerror: Values with the uncertainties on the x values
- :type xerror: A float, or a numpy.ndarray of float32.
- If it is an array, it can either be a 1D array of
- same length as the data or a 2D array with 2 rows
- of same length as the data: row 0 for positive errors,
- row 1 for negative errors.
- :param yerror: Values with the uncertainties on the y values
- :type yerror: A float, or a numpy.ndarray of float32. See xerror.
- :param alpha: Values with the transparency (between 0 and 1)
- :type alpha: A float, or a numpy.ndarray of float32
- :param bool copy: True make a copy of the data (default),
- False to use provided arrays.
- """
- value = numpy.array(value, copy=copy)
- assert value.ndim == 1
- assert len(x) == len(value)
-
- # Convert complex data
- if numpy.iscomplexobj(value):
- _logger.warning(
- 'Converting value data to absolute value to plot it.')
- value = numpy.absolute(value)
-
- # Reset triangulation and interpolator
- if self.__delaunayFuture is not None:
- self.__delaunayFuture.cancel()
- self.__delaunayFuture = None
- if self.__interpolatorFuture is not None:
- self.__interpolatorFuture.cancel()
- self.__interpolatorFuture = None
-
- # Data changed, this needs update
- self.__cacheRegularGridInfo = None
- self.__cacheHistogramInfo = None
-
- self._value = value
- self._updateColormappedData()
-
- if alpha is not None:
- # Make sure alpha is an array of float in [0, 1]
- alpha = numpy.array(alpha, copy=copy)
- assert alpha.ndim == 1
- assert len(x) == len(alpha)
- if alpha.dtype.kind != 'f':
- alpha = alpha.astype(numpy.float32)
- if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)):
- alpha = numpy.clip(alpha, 0., 1.)
- self.__alpha = alpha
-
- # set x, y, xerror, yerror
-
- # call self._updated + plot._invalidateDataRange()
- PointsBase.setData(self, x, y, xerror, yerror, copy)
diff --git a/silx/gui/plot/items/shape.py b/silx/gui/plot/items/shape.py
deleted file mode 100644
index 955dfe3..0000000
--- a/silx/gui/plot/items/shape.py
+++ /dev/null
@@ -1,288 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides the :class:`Shape` item of the :class:`Plot`.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "21/12/2018"
-
-
-import logging
-
-import numpy
-import six
-
-from ... import colors
-from .core import (
- Item, DataItem,
- ColorMixIn, FillMixIn, ItemChangedType, LineMixIn, YAxisMixIn)
-
-
-_logger = logging.getLogger(__name__)
-
-
-# TODO probably make one class for each kind of shape
-# TODO check fill:polygon/polyline + fill = duplicated
-class Shape(Item, ColorMixIn, FillMixIn, LineMixIn):
- """Description of a shape item
-
- :param str type_: The type of shape in:
- 'hline', 'polygon', 'rectangle', 'vline', 'polylines'
- """
-
- def __init__(self, type_):
- Item.__init__(self)
- ColorMixIn.__init__(self)
- FillMixIn.__init__(self)
- LineMixIn.__init__(self)
- self._overlay = False
- assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polylines')
- self._type = type_
- self._points = ()
- self._lineBgColor = None
-
- self._handle = None
-
- def _addBackendRenderer(self, backend):
- """Update backend renderer"""
- points = self.getPoints(copy=False)
- x, y = points.T[0], points.T[1]
- return backend.addShape(x,
- y,
- shape=self.getType(),
- color=self.getColor(),
- fill=self.isFill(),
- overlay=self.isOverlay(),
- linestyle=self.getLineStyle(),
- linewidth=self.getLineWidth(),
- linebgcolor=self.getLineBgColor())
-
- def isOverlay(self):
- """Return true if shape is drawn as an overlay
-
- :rtype: bool
- """
- return self._overlay
-
- def setOverlay(self, overlay):
- """Set the overlay state of the shape
-
- :param bool overlay: True to make it an overlay
- """
- overlay = bool(overlay)
- if overlay != self._overlay:
- self._overlay = overlay
- self._updated(ItemChangedType.OVERLAY)
-
- def getType(self):
- """Returns the type of shape to draw.
-
- One of: 'hline', 'polygon', 'rectangle', 'vline', 'polylines'
-
- :rtype: str
- """
- return self._type
-
- def getPoints(self, copy=True):
- """Get the control points of the shape.
-
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :return: Array of point coordinates
- :rtype: numpy.ndarray with 2 dimensions
- """
- return numpy.array(self._points, copy=copy)
-
- def setPoints(self, points, copy=True):
- """Set the point coordinates
-
- :param numpy.ndarray points: Array of point coordinates
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- :return:
- """
- self._points = numpy.array(points, copy=copy)
- self._updated(ItemChangedType.DATA)
-
- def getLineBgColor(self):
- """Returns the RGBA color of the item
- :rtype: 4-tuple of float in [0, 1] or array of colors
- """
- return self._lineBgColor
-
- def setLineBgColor(self, color, copy=True):
- """Set item color
- :param color: color(s) to be used
- :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
- one of the predefined color names defined in colors.py
- :param bool copy: True (Default) to get a copy,
- False to use internal representation (do not modify!)
- """
- if color is not None:
- if isinstance(color, six.string_types):
- color = colors.rgba(color)
- else:
- color = numpy.array(color, copy=copy)
- # TODO more checks + improve color array support
- if color.ndim == 1: # Single RGBA color
- color = colors.rgba(color)
- else: # Array of colors
- assert color.ndim == 2
-
- self._lineBgColor = color
- self._updated(ItemChangedType.LINE_BG_COLOR)
-
-
-class BoundingRect(DataItem, YAxisMixIn):
- """An invisible shape which enforce the plot view to display the defined
- space on autoscale.
-
- This item do not display anything. But if the visible property is true,
- this bounding box is used by the plot, if not, the bounding box is
- ignored. That's the default behaviour for plot items.
-
- It can be applied on the "left" or "right" axes. Not both at the same time.
- """
-
- def __init__(self):
- DataItem.__init__(self)
- YAxisMixIn.__init__(self)
- self.__bounds = None
-
- def setBounds(self, rect):
- """Set the bounding box of this item in data coordinates
-
- :param Union[None,List[float]] rect: (xmin, xmax, ymin, ymax) or None
- """
- if rect is not None:
- rect = float(rect[0]), float(rect[1]), float(rect[2]), float(rect[3])
- assert rect[0] <= rect[1]
- assert rect[2] <= rect[3]
-
- if rect != self.__bounds:
- self.__bounds = rect
- self._boundsChanged()
- self._updated(ItemChangedType.DATA)
-
- def _getBounds(self):
- if self.__bounds is None:
- return None
- plot = self.getPlot()
- if plot is not None:
- xPositive = plot.getXAxis()._isLogarithmic()
- yPositive = plot.getYAxis()._isLogarithmic()
- if xPositive or yPositive:
- bounds = list(self.__bounds)
- if xPositive and bounds[1] <= 0:
- return None
- if xPositive and bounds[0] <= 0:
- bounds[0] = bounds[1]
- if yPositive and bounds[3] <= 0:
- return None
- if yPositive and bounds[2] <= 0:
- bounds[2] = bounds[3]
- return tuple(bounds)
-
- return self.__bounds
-
-
-class _BaseExtent(DataItem):
- """Base class for :class:`XAxisExtent` and :class:`YAxisExtent`.
-
- :param str axis: Either 'x' or 'y'.
- """
-
- def __init__(self, axis='x'):
- assert axis in ('x', 'y')
- DataItem.__init__(self)
- self.__axis = axis
- self.__range = 1., 100.
-
- def setRange(self, min_, max_):
- """Set the range of the extent of this item in data coordinates.
-
- :param float min_: Lower bound of the extent
- :param float max_: Upper bound of the extent
- :raises ValueError: If min > max or not finite bounds
- """
- range_ = float(min_), float(max_)
- if not numpy.all(numpy.isfinite(range_)):
- raise ValueError("min_ and max_ must be finite numbers.")
- if range_[0] > range_[1]:
- raise ValueError("min_ must be lesser or equal to max_")
-
- if range_ != self.__range:
- self.__range = range_
- self._boundsChanged()
- self._updated(ItemChangedType.DATA)
-
- def getRange(self):
- """Returns the range (min, max) of the extent in data coordinates.
-
- :rtype: List[float]
- """
- return self.__range
-
- def _getBounds(self):
- min_, max_ = self.getRange()
-
- plot = self.getPlot()
- if plot is not None:
- axis = plot.getXAxis() if self.__axis == 'x' else plot.getYAxis()
- if axis._isLogarithmic():
- if max_ <= 0:
- return None
- if min_ <= 0:
- min_ = max_
-
- if self.__axis == 'x':
- return min_, max_, float('nan'), float('nan')
- else:
- return float('nan'), float('nan'), min_, max_
-
-
-class XAxisExtent(_BaseExtent):
- """Invisible item with a settable horizontal data extent.
-
- This item do not display anything, but it behaves as a data
- item with a horizontal extent regarding plot data bounds, i.e.,
- :meth:`PlotWidget.resetZoom` will take this horizontal extent into account.
- """
- def __init__(self):
- _BaseExtent.__init__(self, axis='x')
-
-
-class YAxisExtent(_BaseExtent, YAxisMixIn):
- """Invisible item with a settable vertical data extent.
-
- This item do not display anything, but it behaves as a data
- item with a vertical extent regarding plot data bounds, i.e.,
- :meth:`PlotWidget.resetZoom` will take this vertical extent into account.
- """
-
- def __init__(self):
- _BaseExtent.__init__(self, axis='y')
- YAxisMixIn.__init__(self)
diff --git a/silx/gui/plot/test/__init__.py b/silx/gui/plot/test/__init__.py
deleted file mode 100644
index dfb7c2e..0000000
--- a/silx/gui/plot/test/__init__.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "23/07/2018"
-
-
-import unittest
-
-from .._utils import test
-from . import testColorBar
-from . import testCurvesROIWidget
-from . import testStats
-from . import testAlphaSlider
-from . import testInteraction
-from . import testLegendSelector
-from . import testMaskToolsWidget
-from . import testScatterMaskToolsWidget
-from . import testPlotInteraction
-from . import testPlotWidgetNoBackend
-from . import testPlotWidget
-from . import testPlotWindow
-from . import testStackView
-from . import testImageStack
-from . import testItem
-from . import testUtilsAxis
-from . import testLimitConstraints
-from . import testComplexImageView
-from . import testImageView
-from . import testSaveAction
-from . import testScatterView
-from . import testPixelIntensityHistoAction
-from . import testCompareImages
-from . import testRoiStatsWidget
-
-
-def suite():
- # Lazy-loading to avoid cyclic reference
- from ..tools import test as testTools
-
- test_suite = unittest.TestSuite()
- test_suite.addTests(
- [test.suite(),
- testTools.suite(),
- testColorBar.suite(),
- testCurvesROIWidget.suite(),
- testStats.suite(),
- testAlphaSlider.suite(),
- testInteraction.suite(),
- testLegendSelector.suite(),
- testMaskToolsWidget.suite(),
- testScatterMaskToolsWidget.suite(),
- testPlotInteraction.suite(),
- testPlotWidgetNoBackend.suite(),
- testPlotWidget.suite(),
- testPlotWindow.suite(),
- testStackView.suite(),
- testImageStack.suite(),
- testItem.suite(),
- testUtilsAxis.suite(),
- testLimitConstraints.suite(),
- testComplexImageView.suite(),
- testImageView.suite(),
- testSaveAction.suite(),
- testScatterView.suite(),
- testPixelIntensityHistoAction.suite(),
- testCompareImages.suite(),
- testRoiStatsWidget.suite(),
- ])
- return test_suite
diff --git a/silx/gui/plot/test/testAlphaSlider.py b/silx/gui/plot/test/testAlphaSlider.py
deleted file mode 100644
index 01e6969..0000000
--- a/silx/gui/plot/test/testAlphaSlider.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for ImageAlphaSlider"""
-
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "28/03/2017"
-
-import numpy
-import unittest
-
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.plot import PlotWidget
-from silx.gui.plot import AlphaSlider
-
-
-class TestActiveImageAlphaSlider(TestCaseQt):
- def setUp(self):
- super(TestActiveImageAlphaSlider, self).setUp()
- self.plot = PlotWidget()
- self.aslider = AlphaSlider.ActiveImageAlphaSlider(plot=self.plot)
- self.aslider.setOrientation(qt.Qt.Horizontal)
-
- toolbar = qt.QToolBar("plot", self.plot)
- toolbar.addWidget(self.aslider)
- self.plot.addToolBar(toolbar)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- self.mouseMove(self.plot) # Move to center
- self.qapp.processEvents()
-
- def tearDown(self):
- self.qapp.processEvents()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- del self.aslider
-
- super(TestActiveImageAlphaSlider, self).tearDown()
-
- def testWidgetEnabled(self):
- # no active image initially, slider must be deactivate
- self.assertFalse(self.aslider.isEnabled())
-
- self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
- # now we have an active image
- self.assertTrue(self.aslider.isEnabled())
-
- self.plot.setActiveImage(None)
- self.assertFalse(self.aslider.isEnabled())
-
- def testGetImage(self):
- self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
- self.assertEqual(self.plot.getActiveImage(),
- self.aslider.getItem())
-
- self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
- self.plot.setActiveImage("2")
- self.assertEqual(self.plot.getImage("2"),
- self.aslider.getItem())
-
- def testGetAlpha(self):
- self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
- self.aslider.setValue(137)
- self.assertAlmostEqual(self.aslider.getAlpha(),
- 137. / 255)
-
-
-class TestNamedImageAlphaSlider(TestCaseQt):
- def setUp(self):
- super(TestNamedImageAlphaSlider, self).setUp()
- self.plot = PlotWidget()
- self.aslider = AlphaSlider.NamedImageAlphaSlider(plot=self.plot)
- self.aslider.setOrientation(qt.Qt.Horizontal)
-
- toolbar = qt.QToolBar("plot", self.plot)
- toolbar.addWidget(self.aslider)
- self.plot.addToolBar(toolbar)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- self.mouseMove(self.plot) # Move to center
- self.qapp.processEvents()
-
- def tearDown(self):
- self.qapp.processEvents()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- del self.aslider
-
- super(TestNamedImageAlphaSlider, self).tearDown()
-
- def testWidgetEnabled(self):
- # no image set initially, slider must be deactivate
- self.assertFalse(self.aslider.isEnabled())
-
- self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
- self.aslider.setLegend("1")
- # now we have an image set
- self.assertTrue(self.aslider.isEnabled())
-
- def testGetImage(self):
- self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
- self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
- self.aslider.setLegend("1")
- self.assertEqual(self.plot.getImage("1"),
- self.aslider.getItem())
-
- self.aslider.setLegend("2")
- self.assertEqual(self.plot.getImage("2"),
- self.aslider.getItem())
-
- def testGetAlpha(self):
- self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
- self.aslider.setLegend("1")
- self.aslider.setValue(128)
- self.assertAlmostEqual(self.aslider.getAlpha(),
- 128. / 255)
-
-
-class TestNamedScatterAlphaSlider(TestCaseQt):
- def setUp(self):
- super(TestNamedScatterAlphaSlider, self).setUp()
- self.plot = PlotWidget()
- self.aslider = AlphaSlider.NamedScatterAlphaSlider(plot=self.plot)
- self.aslider.setOrientation(qt.Qt.Horizontal)
-
- toolbar = qt.QToolBar("plot", self.plot)
- toolbar.addWidget(self.aslider)
- self.plot.addToolBar(toolbar)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- self.mouseMove(self.plot) # Move to center
- self.qapp.processEvents()
-
- def tearDown(self):
- self.qapp.processEvents()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- del self.aslider
-
- super(TestNamedScatterAlphaSlider, self).tearDown()
-
- def testWidgetEnabled(self):
- # no Scatter set initially, slider must be deactivate
- self.assertFalse(self.aslider.isEnabled())
-
- self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
- legend="1")
- self.aslider.setLegend("1")
- # now we have an image set
- self.assertTrue(self.aslider.isEnabled())
-
- def testGetScatter(self):
- self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
- legend="1")
- self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
- legend="2")
- self.aslider.setLegend("1")
- self.assertEqual(self.plot.getScatter("1"),
- self.aslider.getItem())
-
- self.aslider.setLegend("2")
- self.assertEqual(self.plot.getScatter("2"),
- self.aslider.getItem())
-
- def testGetAlpha(self):
- self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
- legend="1")
- self.aslider.setLegend("1")
- self.aslider.setValue(128)
- self.assertAlmostEqual(self.aslider.getAlpha(),
- 128. / 255)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- # test_suite.addTest(positionInfoTestSuite)
- for testClass in (TestActiveImageAlphaSlider, TestNamedImageAlphaSlider,
- TestNamedScatterAlphaSlider):
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
- testClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testColorBar.py b/silx/gui/plot/test/testColorBar.py
deleted file mode 100644
index a6f141c..0000000
--- a/silx/gui/plot/test/testColorBar.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for ColorBar featues and sub widgets of Colorbar module"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-import unittest
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.plot.ColorBar import _ColorScale
-from silx.gui.plot.ColorBar import ColorBarWidget
-from silx.gui.colors import Colormap
-from silx.gui import colors
-from silx.gui.plot import Plot2D
-from silx.gui import qt
-import numpy
-
-
-class TestColorScale(TestCaseQt):
- """Test that interaction with the colorScale is correct"""
- def setUp(self):
- super(TestColorScale, self).setUp()
- self.colorScaleWidget = _ColorScale(colormap=None, parent=None)
- self.colorScaleWidget.show()
- self.qWaitForWindowExposed(self.colorScaleWidget)
-
- def tearDown(self):
- self.qapp.processEvents()
- self.colorScaleWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.colorScaleWidget.close()
- del self.colorScaleWidget
- super(TestColorScale, self).tearDown()
-
- def testNoColormap(self):
- """Test _ColorScale without a colormap"""
- colormap = self.colorScaleWidget.getColormap()
- self.assertIsNone(colormap)
-
- def testRelativePositionLinear(self):
- self.colorMapLin1 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=0.0,
- vmax=1.0)
- self.colorScaleWidget.setColormap(self.colorMapLin1)
-
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25)
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(0.5) == 0.5)
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(1.0) == 1.0)
-
- self.colorMapLin2 = Colormap(name='viridis',
- normalization=Colormap.LINEAR,
- vmin=-10,
- vmax=0)
- self.colorScaleWidget.setColormap(self.colorMapLin2)
-
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5)
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(0.5) == -5.0)
- self.assertTrue(
- self.colorScaleWidget.getValueFromRelativePosition(1.0) == 0.0)
-
- def testRelativePositionLog(self):
- self.colorMapLog1 = Colormap(name='temperature',
- normalization=Colormap.LOGARITHM,
- vmin=1.0,
- vmax=100.0)
-
- self.colorScaleWidget.setColormap(self.colorMapLog1)
-
- val = self.colorScaleWidget.getValueFromRelativePosition(1.0)
- self.assertAlmostEqual(val, 100.0)
-
- val = self.colorScaleWidget.getValueFromRelativePosition(0.5)
- self.assertAlmostEqual(val, 10.0)
-
- val = self.colorScaleWidget.getValueFromRelativePosition(0.0)
- self.assertTrue(val == 1.0)
-
-
-class TestNoAutoscale(TestCaseQt):
- """Test that ticks and color displayed are correct in the case of a colormap
- with no autoscale
- """
-
- def setUp(self):
- super(TestNoAutoscale, self).setUp()
- self.plot = Plot2D()
- self.colorBar = self.plot.getColorBarWidget()
- self.colorBar.setVisible(True) # Makes sure the colormap is visible
- self.tickBar = self.colorBar.getColorScaleBar().getTickBar()
- self.colorScale = self.colorBar.getColorScaleBar().getColorScale()
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- def tearDown(self):
- self.qapp.processEvents()
- self.tickBar = None
- self.colorScale = None
- del self.colorBar
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- super(TestNoAutoscale, self).tearDown()
-
- def testLogNormNoAutoscale(self):
- colormapLog = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=1.0,
- vmax=100.0)
-
- data = numpy.linspace(10, 1e10, 9).reshape(3, 3)
- self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
- self.plot.setActiveImage('toto')
-
- # test Ticks
- self.tickBar.setTicksNumber(10)
- self.tickBar.computeTicks()
-
- ticksTh = numpy.linspace(1.0, 100.0, 10)
- ticksTh = 10**ticksTh
- numpy.array_equal(self.tickBar.ticks, ticksTh)
-
- # test ColorScale
- val = self.colorScale.getValueFromRelativePosition(1.0)
- self.assertAlmostEqual(val, 100.0)
-
- val = self.colorScale.getValueFromRelativePosition(0.0)
- self.assertTrue(val == 1.0)
-
- def testLinearNormNoAutoscale(self):
- colormapLog = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=-4,
- vmax=5)
-
- data = numpy.linspace(1, 9, 9).reshape(3, 3)
- self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
- self.plot.setActiveImage('toto')
-
- # test Ticks
- self.tickBar.setTicksNumber(10)
- self.tickBar.computeTicks()
-
- numpy.array_equal(self.tickBar.ticks, numpy.linspace(-4, 5, 10))
-
- # test ColorScale
- val = self.colorScale.getValueFromRelativePosition(1.0)
- self.assertTrue(val == 5.0)
-
- val = self.colorScale.getValueFromRelativePosition(0.0)
- self.assertTrue(val == -4.0)
-
-
-class TestColorBarWidget(TestCaseQt):
- """Test interaction with the ColorBarWidget"""
-
- def setUp(self):
- super(TestColorBarWidget, self).setUp()
- self.plot = Plot2D()
- self.colorBar = self.plot.getColorBarWidget()
- self.colorBar.setVisible(True) # Makes sure the colormap is visible
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- def tearDown(self):
- self.qapp.processEvents()
- del self.colorBar
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- super(TestColorBarWidget, self).tearDown()
-
- def testEmptyColorBar(self):
- colorBar = ColorBarWidget(parent=None)
- colorBar.show()
- self.qWaitForWindowExposed(colorBar)
-
- def testNegativeColormaps(self):
- """test the behavior of the ColorBarWidget in the case of negative
- values
-
- Note : colorbar is modified by the Plot directly not ColorBarWidget
- """
- colormapLog = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=None)
-
- data = numpy.array([-5, -4, 0, 2, 3, 5, 10, 20, 30])
- data = data.reshape(3, 3)
- self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
- self.plot.setActiveImage('toto')
-
- # default behavior when with log and negative values: should set vmin
- # to 1 and vmax to 10
- self.assertTrue(self.colorBar.getColorScaleBar().minVal == 2)
- self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 30)
-
- # if data is positive
- data[data < 1] = data.max()
- self.plot.addImage(data=data,
- colormap=colormapLog,
- legend='toto',
- replace=True)
- self.plot.setActiveImage('toto')
-
- self.assertTrue(self.colorBar.getColorScaleBar().minVal == data.min())
- self.assertTrue(self.colorBar.getColorScaleBar().maxVal == data.max())
-
- def testPlotAssocation(self):
- """Make sure the ColorBarWidget is properly connected with the plot"""
- colormap = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=None)
-
- # make sure that default settings are the same (but a copy of the
- self.colorBar.setPlot(self.plot)
- self.assertTrue(
- self.colorBar.getColormap() is self.plot.getDefaultColormap())
-
- data = numpy.linspace(0, 10, 100).reshape(10, 10)
- self.plot.addImage(data=data, colormap=colormap, legend='toto')
- self.plot.setActiveImage('toto')
-
- # make sure the modification of the colormap has been done
- self.assertFalse(
- self.colorBar.getColormap() is self.plot.getDefaultColormap())
- self.assertTrue(
- self.colorBar.getColormap() is colormap)
-
- # test that colorbar is updated when default plot colormap changes
- self.plot.clear()
- plotColormap = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=None)
- self.plot.setDefaultColormap(plotColormap)
- self.assertTrue(self.colorBar.getColormap() is plotColormap)
-
- def testColormapWithoutRange(self):
- """Test with a colormap with vmin==vmax"""
- colormap = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=1.0,
- vmax=1.0)
- self.colorBar.setColormap(colormap)
-
-
-class TestColorBarUpdate(TestCaseQt):
- """Test that the ColorBar is correctly updated when the signal 'sigChanged'
- of the colormap is emitted
- """
-
- def setUp(self):
- super(TestColorBarUpdate, self).setUp()
- self.plot = Plot2D()
- self.colorBar = self.plot.getColorBarWidget()
- self.colorBar.setVisible(True) # Makes sure the colormap is visible
- self.colorBar.setPlot(self.plot)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
- self.data = numpy.random.rand(9).reshape(3, 3)
-
- def tearDown(self):
- self.qapp.processEvents()
- del self.colorBar
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- super(TestColorBarUpdate, self).tearDown()
-
- def testUpdateColorMap(self):
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=0,
- vmax=1)
-
- # check inital state
- self.plot.addImage(data=self.data, colormap=colormap, legend='toto')
- self.plot.setActiveImage('toto')
-
- self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0)
- self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 1)
- self.assertTrue(
- self.colorBar.getColorScaleBar().getTickBar()._vmin == 0)
- self.assertTrue(
- self.colorBar.getColorScaleBar().getTickBar()._vmax == 1)
- self.assertIsInstance(
- self.colorBar.getColorScaleBar().getTickBar()._normalizer,
- colors._LinearNormalization)
-
- # update colormap
- colormap.setVMin(0.5)
- self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0.5)
- self.assertTrue(
- self.colorBar.getColorScaleBar().getTickBar()._vmin == 0.5)
-
- colormap.setVMax(0.8)
- self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 0.8)
- self.assertTrue(
- self.colorBar.getColorScaleBar().getTickBar()._vmax == 0.8)
-
- colormap.setNormalization('log')
- self.assertIsInstance(
- self.colorBar.getColorScaleBar().getTickBar()._normalizer,
- colors._LogarithmicNormalization)
-
- # TODO : should also check that if the colormap is changing then values (especially in log scale)
- # should be coherent if in autoscale
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for ui in (TestColorScale, TestNoAutoscale, TestColorBarWidget,
- TestColorBarUpdate):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(ui))
-
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testCompareImages.py b/silx/gui/plot/test/testCompareImages.py
deleted file mode 100644
index ed6942a..0000000
--- a/silx/gui/plot/test/testCompareImages.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for CompareImages widget"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "23/07/2018"
-
-import unittest
-import numpy
-import weakref
-
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.plot.CompareImages import CompareImages
-
-
-class TestCompareImages(TestCaseQt):
- """Test that CompareImages widget is working in some cases"""
-
- def setUp(self):
- super(TestCompareImages, self).setUp()
- self.widget = CompareImages()
-
- def tearDown(self):
- ref = weakref.ref(self.widget)
- self.widget = None
- self.qWaitForDestroy(ref)
- super(TestCompareImages, self).tearDown()
-
- def testIntensityImage(self):
- image1 = numpy.random.rand(10, 10)
- image2 = numpy.random.rand(10, 10)
- self.widget.setData(image1, image2)
-
- def testRgbImage(self):
- image1 = numpy.random.randint(0, 255, size=(10, 10, 3))
- image2 = numpy.random.randint(0, 255, size=(10, 10, 3))
- self.widget.setData(image1, image2)
-
- def testRgbaImage(self):
- image1 = numpy.random.randint(0, 255, size=(10, 10, 4))
- image2 = numpy.random.randint(0, 255, size=(10, 10, 4))
- self.widget.setData(image1, image2)
-
- def testVizualisations(self):
- image1 = numpy.random.rand(10, 10)
- image2 = numpy.random.rand(10, 10)
- self.widget.setData(image1, image2)
- for mode in CompareImages.VisualizationMode:
- self.widget.setVisualizationMode(mode)
-
- def testAlignemnt(self):
- image1 = numpy.random.rand(10, 10)
- image2 = numpy.random.rand(5, 5)
- self.widget.setData(image1, image2)
- for mode in CompareImages.AlignmentMode:
- self.widget.setAlignmentMode(mode)
-
- def testGetPixel(self):
- image1 = numpy.random.rand(11, 11)
- image2 = numpy.random.rand(5, 5)
- image1[5, 5] = 111.111
- image2[2, 2] = 222.222
- self.widget.setData(image1, image2)
- expectedValue = {}
- expectedValue[CompareImages.AlignmentMode.CENTER] = 222.222
- expectedValue[CompareImages.AlignmentMode.STRETCH] = 222.222
- expectedValue[CompareImages.AlignmentMode.ORIGIN] = None
- for mode in expectedValue.keys():
- self.widget.setAlignmentMode(mode)
- data = self.widget.getRawPixelData(11 / 2.0, 11 / 2.0)
- data1, data2 = data
- self.assertEqual(data1, 111.111)
- self.assertEqual(data2, expectedValue[mode])
-
- def testImageEmpty(self):
- self.widget.setData(image1=None, image2=None)
- self.assertTrue(self.widget.getRawPixelData(11 / 2.0, 11 / 2.0) == (None, None))
-
- def testSetImageSeparately(self):
- self.widget.setImage1(numpy.random.rand(10, 10))
- self.widget.setImage2(numpy.random.rand(10, 10))
- for mode in CompareImages.VisualizationMode:
- self.widget.setVisualizationMode(mode)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestCompareImages))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testComplexImageView.py b/silx/gui/plot/test/testComplexImageView.py
deleted file mode 100644
index 4ac3488..0000000
--- a/silx/gui/plot/test/testComplexImageView.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test suite for :class:`ComplexImageView`"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import unittest
-import logging
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.plot import ComplexImageView
-
-from .utils import PlotWidgetTestCase
-
-
-logger = logging.getLogger(__name__)
-
-
-class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase):
- """Test suite of ComplexImageView widget"""
-
- def _createPlot(self):
- return ComplexImageView.ComplexImageView()
-
- def testPlot2DComplex(self):
- """Test API of ComplexImageView widget"""
- data = numpy.array(((0, 1j), (1, 1 + 1j)), dtype=numpy.complex64)
- self.plot.setData(data)
- self.plot.setKeepDataAspectRatio(True)
- self.plot.getPlot().resetZoom()
- self.qWait(100)
-
- # Test colormap API
- colormap = self.plot.getColormap().copy()
- colormap.setName('magma')
- self.plot.setColormap(colormap)
- self.qWait(100)
-
- # Test all modes
- modes = self.plot.supportedComplexModes()
- for mode in modes:
- with self.subTest(mode=mode):
- self.plot.setComplexMode(mode)
- self.qWait(100)
-
- # Test origin and scale API
- self.plot.setScale((2, 1))
- self.qWait(100)
- self.plot.setOrigin((1, 1))
- self.qWait(100)
-
- # Test no data
- self.plot.setData(numpy.zeros((0, 0), dtype=numpy.complex64))
- self.qWait(100)
-
- # Test float data
- self.plot.setData(numpy.arange(100, dtype=numpy.float64).reshape(10, 10))
- self.qWait(100)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
- TestComplexImageView))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testCurvesROIWidget.py b/silx/gui/plot/test/testCurvesROIWidget.py
deleted file mode 100644
index 6a0ab8c..0000000
--- a/silx/gui/plot/test/testCurvesROIWidget.py
+++ /dev/null
@@ -1,469 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for CurvesROIWidget"""
-
-__authors__ = ["T. Vincent", "P. Knobel", "H. Payno"]
-__license__ = "MIT"
-__date__ = "16/11/2017"
-
-
-import logging
-import os.path
-import unittest
-from collections import OrderedDict
-import numpy
-
-from silx.gui import qt
-from silx.gui.plot import items
-from silx.gui.plot import Plot1D
-from silx.test.utils import temp_dir
-from silx.gui.utils.testutils import TestCaseQt, SignalListener
-from silx.gui.plot import PlotWindow, CurvesROIWidget
-from silx.gui.plot.CurvesROIWidget import ROITable
-from silx.gui.utils.testutils import getQToolButtonFromAction
-from silx.gui.plot.PlotInteraction import ItemsInteraction
-
-_logger = logging.getLogger(__name__)
-
-
-class TestCurvesROIWidget(TestCaseQt):
- """Basic test for CurvesROIWidget"""
-
- def setUp(self):
- super(TestCurvesROIWidget, self).setUp()
- self.plot = PlotWindow()
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- self.widget = self.plot.getCurvesRoiDockWidget()
-
- self.widget.show()
- self.qWaitForWindowExposed(self.widget)
-
- def tearDown(self):
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
-
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- del self.widget
-
- super(TestCurvesROIWidget, self).tearDown()
-
- def testDummyAPI(self):
- """Simple test of the getRois and setRois API"""
- roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
- todata=-10, type_='X')
- roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
- todata=20, type_='X')
-
- self.widget.roiWidget.setRois((roi_pos, roi_neg))
-
- rois_defs = self.widget.roiWidget.getRois()
- self.widget.roiWidget.setRois(rois=rois_defs)
-
- def testWithCurves(self):
- """Plot with curves: test all ROI widget buttons"""
- for offset in range(2):
- self.plot.addCurve(numpy.arange(1000),
- offset + numpy.random.random(1000),
- legend=str(offset))
-
- # Add two ROI
- self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
- self.qWait(200)
- self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
- self.qWait(200)
-
- # Change active curve
- self.plot.setActiveCurve(str(1))
-
- # Delete a ROI
- self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton)
- self.qWait(200)
-
- with temp_dir() as tmpDir:
- self.tmpFile = os.path.join(tmpDir, 'test.ini')
-
- # Save ROIs
- self.widget.roiWidget.save(self.tmpFile)
- self.assertTrue(os.path.isfile(self.tmpFile))
- self.assertEqual(len(self.widget.getRois()), 2)
-
- # Reset ROIs
- self.mouseClick(self.widget.roiWidget.resetButton,
- qt.Qt.LeftButton)
- self.qWait(200)
- rois = self.widget.getRois()
- self.assertEqual(len(rois), 1)
- roiID = list(rois.keys())[0]
- self.assertEqual(rois[roiID].getName(), 'ICR')
-
- # Load ROIs
- self.widget.roiWidget.load(self.tmpFile)
- self.assertEqual(len(self.widget.getRois()), 2)
-
- del self.tmpFile
-
- def testMiddleMarker(self):
- """Test with middle marker enabled"""
- self.widget.roiWidget.roiTable.setMiddleROIMarkerFlag(True)
-
- # Add a ROI
- self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
-
- for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers:
- handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[roiID]
- assert handler.getMarker('min')
- xleftMarker = handler.getMarker('min').getXPosition()
- xMiddleMarker = handler.getMarker('middle').getXPosition()
- xRightMarker = handler.getMarker('max').getXPosition()
- thValue = xleftMarker + (xRightMarker - xleftMarker) / 2.
- self.assertAlmostEqual(xMiddleMarker, thValue)
-
- def testAreaCalculation(self):
- """Test result of area calculation"""
- x = numpy.arange(100.)
- y = numpy.arange(100.)
-
- # Add two curves
- self.plot.addCurve(x, y, legend="positive")
- self.plot.addCurve(-x, y, legend="negative")
-
- # Make sure there is an active curve and it is the positive one
- self.plot.setActiveCurve("positive")
-
- # Add two ROIs
- roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
- todata=-10, type_='X')
- roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
- todata=20, type_='X')
-
- self.widget.roiWidget.setRois((roi_pos, roi_neg))
-
- posCurve = self.plot.getCurve('positive')
- negCurve = self.plot.getCurve('negative')
-
- self.assertEqual(roi_pos.computeRawAndNetArea(posCurve),
- (numpy.trapz(y=[10, 20], x=[10, 20]),
- 0.0))
- self.assertEqual(roi_pos.computeRawAndNetArea(negCurve),
- (0.0, 0.0))
- self.assertEqual(roi_neg.computeRawAndNetArea(posCurve),
- ((0.0), 0.0))
- self.assertEqual(roi_neg.computeRawAndNetArea(negCurve),
- ((-150.0), 0.0))
-
- def testCountsCalculation(self):
- """Test result of count calculation"""
- x = numpy.arange(100.)
- y = numpy.arange(100.)
-
- # Add two curves
- self.plot.addCurve(x, y, legend="positive")
- self.plot.addCurve(-x, y, legend="negative")
-
- # Make sure there is an active curve and it is the positive one
- self.plot.setActiveCurve("positive")
-
- # Add two ROIs
- roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
- todata=-10, type_='X')
- roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
- todata=20, type_='X')
-
- self.widget.roiWidget.setRois((roi_pos, roi_neg))
-
- posCurve = self.plot.getCurve('positive')
- negCurve = self.plot.getCurve('negative')
-
- self.assertEqual(roi_pos.computeRawAndNetCounts(posCurve),
- (y[10:21].sum(), 0.0))
- self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve),
- (0.0, 0.0))
- self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve),
- ((0.0), 0.0))
- self.assertEqual(roi_neg.computeRawAndNetCounts(negCurve),
- (y[10:21].sum(), 0.0))
-
- def testDeferedInit(self):
- """Test behavior of the deferedInit"""
- x = numpy.arange(100.)
- y = numpy.arange(100.)
- self.plot.addCurve(x=x, y=y, legend="name", replace="True")
- roisDefs = OrderedDict([
- ["range1",
- OrderedDict([["from", 20], ["to", 200], ["type", "energy"]])],
- ["range2",
- OrderedDict([["from", 300], ["to", 500], ["type", "energy"]])]
- ])
-
- roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
- self.plot.getCurvesRoiDockWidget().setRois(roisDefs)
- self.assertEqual(len(roiWidget.getRois()), len(roisDefs))
- self.plot.getCurvesRoiDockWidget().setVisible(True)
- self.assertEqual(len(roiWidget.getRois()), len(roisDefs))
-
- def testDictCompatibility(self):
- """Test that ROI api is valid with dict and not information is lost"""
- roiDict = {'from': 20, 'to': 200, 'type': 'energy', 'comment': 'no',
- 'name': 'myROI', 'calibration': [1, 2, 3]}
- roi = CurvesROIWidget.ROI._fromDict(roiDict)
- self.assertEqual(roi.toDict(), roiDict)
-
- def testShowAllROI(self):
- """Test the show allROI action"""
- x = numpy.arange(100.)
- y = numpy.arange(100.)
- self.plot.addCurve(x=x, y=y, legend="name", replace="True")
-
- roisDefsDict = {
- "range1": {"from": 20, "to": 200,"type": "energy"},
- "range2": {"from": 300, "to": 500, "type": "energy"}
- }
-
- roisDefsObj = (
- CurvesROIWidget.ROI(name='range3', fromdata=20, todata=200,
- type_='energy'),
- CurvesROIWidget.ROI(name='range4', fromdata=300, todata=500,
- type_='energy')
- )
- self.widget.roiWidget.showAllMarkers(True)
- roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
- roiWidget.setRois(roisDefsDict)
- markers = [item for item in self.plot.getItems()
- if isinstance(item, items.MarkerBase)]
- self.assertEqual(len(markers), 2*3)
-
- markersHandler = self.widget.roiWidget.roiTable._markersHandler
- roiWidget.showAllMarkers(True)
- ICRROI = markersHandler.getVisibleRois()
- self.assertEqual(len(ICRROI), 2)
-
- roiWidget.showAllMarkers(False)
- ICRROI = markersHandler.getVisibleRois()
- self.assertEqual(len(ICRROI), 1)
-
- roiWidget.setRois(roisDefsObj)
- self.qapp.processEvents()
- markers = [item for item in self.plot.getItems()
- if isinstance(item, items.MarkerBase)]
- self.assertEqual(len(markers), 2*3)
-
- markersHandler = self.widget.roiWidget.roiTable._markersHandler
- roiWidget.showAllMarkers(True)
- ICRROI = markersHandler.getVisibleRois()
- self.assertEqual(len(ICRROI), 2)
-
- roiWidget.showAllMarkers(False)
- ICRROI = markersHandler.getVisibleRois()
- self.assertEqual(len(ICRROI), 1)
-
- def testRoiEdition(self):
- """Make sure if the ROI object is edited the ROITable will be updated
- """
- roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
- self.widget.roiWidget.setRois((roi, ))
-
- x = (0, 1, 1, 2, 2, 3)
- y = (1, 1, 2, 2, 1, 1)
- self.plot.addCurve(x=x, y=y, legend='linearCurve')
- self.plot.setActiveCurve(legend='linearCurve')
- self.widget.calculateROIs()
-
- roiTable = self.widget.roiWidget.roiTable
- indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX
- itemRawCounts = roiTable.item(0, indexesColumns['Raw Counts'])
- itemNetCounts = roiTable.item(0, indexesColumns['Net Counts'])
-
- self.assertTrue(itemRawCounts.text() == '8.0')
- self.assertTrue(itemNetCounts.text() == '2.0')
-
- itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
- itemNetArea = roiTable.item(0, indexesColumns['Net Area'])
-
- self.assertTrue(itemRawArea.text() == '4.0')
- self.assertTrue(itemNetArea.text() == '1.0')
-
- roi.setTo(2)
- itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
- self.assertTrue(itemRawArea.text() == '3.0')
- roi.setFrom(1)
- itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
- self.assertTrue(itemRawArea.text() == '2.0')
-
- def testRemoveActiveROI(self):
- """Test widget behavior when removing the active ROI"""
- roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
- self.widget.roiWidget.setRois((roi,))
-
- self.widget.roiWidget.roiTable.setActiveRoi(None)
- self.assertEqual(len(self.widget.roiWidget.roiTable.selectedItems()), 0)
- self.widget.roiWidget.setRois((roi,))
- self.plot.setActiveCurve(legend='linearCurve')
- self.widget.calculateROIs()
-
- def testEmitCurrentROI(self):
- """Test behavior of the CurvesROIWidget.sigROISignal"""
- roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
- self.widget.roiWidget.setRois((roi,))
- signalListener = SignalListener()
- self.widget.roiWidget.sigROISignal.connect(signalListener.partial())
- self.widget.show()
- self.qapp.processEvents()
- self.assertEqual(signalListener.callCount(), 0)
- self.assertIs(self.widget.roiWidget.roiTable.activeRoi, roi)
- roi.setFrom(0.0)
- self.qapp.processEvents()
- self.assertEqual(signalListener.callCount(), 0)
- roi.setFrom(0.3)
- self.qapp.processEvents()
- self.assertEqual(signalListener.callCount(), 1)
-
-
-class TestRoiWidgetSignals(TestCaseQt):
- """Test Signals emitted by the RoiWidgetSignals"""
-
- def setUp(self):
- self.plot = Plot1D()
- x = range(20)
- y = range(20)
- self.plot.addCurve(x, y, legend='curve0')
- self.listener = SignalListener()
- self.curves_roi_widget = self.plot.getCurvesRoiWidget()
- self.curves_roi_widget.sigROISignal.connect(self.listener)
- assert self.curves_roi_widget.isVisible() is False
- assert self.listener.callCount() == 0
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- toolButton = getQToolButtonFromAction(self.plot.getRoiAction())
- self.mouseClick(widget=toolButton, button=qt.Qt.LeftButton)
-
- self.curves_roi_widget.show()
- self.qWaitForWindowExposed(self.curves_roi_widget)
-
- def tearDown(self):
- self.plot = None
-
- def testSigROISignalAddRmRois(self):
- """Test SigROISignal when adding and removing ROIS"""
- self.assertEqual(self.listener.callCount(), 1)
- self.listener.clear()
-
- roi1 = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
- self.curves_roi_widget.roiTable.registerROI(roi1)
- self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
- self.listener.clear()
-
- roi2 = CurvesROIWidget.ROI(name='linear2', fromdata=0, todata=5)
- self.curves_roi_widget.roiTable.registerROI(roi2)
- self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear2')
- self.listener.clear()
-
- self.curves_roi_widget.roiTable.removeROI(roi2)
- self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
- self.listener.clear()
-
- self.curves_roi_widget.roiTable.deleteActiveRoi()
- self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.curves_roi_widget.roiTable.activeRoi is None)
- self.assertTrue(self.listener.arguments()[0][0]['current'] is None)
- self.listener.clear()
-
- self.curves_roi_widget.roiTable.registerROI(roi1)
- self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
- self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
- self.listener.clear()
- self.qapp.processEvents()
-
- self.curves_roi_widget.roiTable.removeROI(roi1)
- self.qapp.processEvents()
- self.assertEqual(self.listener.callCount(), 1)
- self.assertTrue(self.listener.arguments()[0][0]['current'] == 'ICR')
- self.listener.clear()
-
- def testSigROISignalModifyROI(self):
- """Test SigROISignal when modifying it"""
- self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True)
- roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
- self.curves_roi_widget.roiTable.registerROI(roi1)
- self.curves_roi_widget.roiTable.setActiveRoi(roi1)
-
- # test modify the roi2 object
- self.listener.clear()
- roi1.setFrom(0.56)
- self.assertEqual(self.listener.callCount(), 1)
- self.listener.clear()
- roi1.setTo(2.56)
- self.assertEqual(self.listener.callCount(), 1)
- self.listener.clear()
- roi1.setName('linear2')
- self.assertEqual(self.listener.callCount(), 1)
- self.listener.clear()
- roi1.setType('new type')
- self.assertEqual(self.listener.callCount(), 1)
-
- # modify roi limits (from the gui)
- roi_marker_handler = self.curves_roi_widget.roiTable._markersHandler.getMarkerHandler(roi1.getID())
- for marker_type in ('min', 'max', 'middle'):
- with self.subTest(marker_type=marker_type):
- self.listener.clear()
- marker = roi_marker_handler.getMarker(marker_type)
- self.qapp.processEvents()
- items_interaction = ItemsInteraction(plot=self.plot)
- x_pix, y_pix = self.plot.dataToPixel(marker.getXPosition(), 1)
- items_interaction.beginDrag(x_pix, y_pix)
- self.qapp.processEvents()
- items_interaction.endDrag(x_pix+10, y_pix)
- self.qapp.processEvents()
- self.assertEqual(self.listener.callCount(), 1)
-
- def testSetActiveCurve(self):
- """Test sigRoiSignal when set an active curve"""
- roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
- self.curves_roi_widget.roiTable.registerROI(roi1)
- self.curves_roi_widget.roiTable.setActiveRoi(roi1)
- self.listener.clear()
- self.plot.setActiveCurve('curve0')
- self.assertEqual(self.listener.callCount(), 0)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestClass in (TestCurvesROIWidget,):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testImageStack.py b/silx/gui/plot/test/testImageStack.py
deleted file mode 100644
index 9c21469..0000000
--- a/silx/gui/plot/test/testImageStack.py
+++ /dev/null
@@ -1,197 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for ImageStack"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "15/01/2020"
-
-
-import unittest
-import tempfile
-import numpy
-import h5py
-
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-from silx.io.url import DataUrl
-from silx.gui.plot.ImageStack import ImageStack
-from silx.gui.utils.testutils import SignalListener
-from collections import OrderedDict
-import os
-import time
-import shutil
-
-
-class TestImageStack(TestCaseQt):
- """Simple test of the Image stack"""
-
- def setUp(self):
- TestCaseQt.setUp(self)
- self.urls = OrderedDict()
- self._raw_data = {}
- self._folder = tempfile.mkdtemp()
- self._n_urls = 10
- file_name = os.path.join(self._folder, 'test_inage_stack_file.h5')
- with h5py.File(file_name, 'w') as h5f:
- for i in range(self._n_urls):
- width = numpy.random.randint(10, 40)
- height = numpy.random.randint(10, 40)
- raw_data = numpy.random.random((width, height))
- self._raw_data[i] = raw_data
- h5f[str(i)] = raw_data
- self.urls[i] = DataUrl(file_path=file_name,
- data_path=str(i),
- scheme='silx')
- self.widget = ImageStack()
-
- self.urlLoadedListener = SignalListener()
- self.widget.sigLoaded.connect(self.urlLoadedListener)
-
- self.currentUrlChangedListener = SignalListener()
- self.widget.sigCurrentUrlChanged.connect(self.currentUrlChangedListener)
-
- def tearDown(self):
- shutil.rmtree(self._folder)
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose, True)
- self.widget.close()
- TestCaseQt.setUp(self)
-
- def testControls(self):
- """Test that selection using the url table and the slider are working
- """
- self.widget.show()
- self.assertEqual(self.widget.getCurrentUrl(), None)
- self.assertEqual(self.widget.getCurrentUrlIndex(), None)
- self.widget.setUrls(list(self.urls.values()))
-
- # wait for image to be loaded
- self._waitUntilUrlLoaded()
-
- self.assertEqual(self.widget.getCurrentUrl(), self.urls[0])
-
- # make sure all image are loaded
- self.assertEqual(self.urlLoadedListener.callCount(), self._n_urls)
- numpy.testing.assert_array_equal(
- self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
- self._raw_data[0])
- self.assertEqual(self.widget._slider.value(), 0)
-
- self.widget._urlsTable.setUrl(self.urls[4])
- numpy.testing.assert_array_equal(
- self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
- self._raw_data[4])
- self.assertEqual(self.widget._slider.value(), 4)
- self.assertEqual(self.widget.getCurrentUrl(), self.urls[4])
- self.assertEqual(self.widget.getCurrentUrlIndex(), 4)
-
- self.widget._slider.setUrlIndex(6)
- numpy.testing.assert_array_equal(
- self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
- self._raw_data[6])
- self.assertEqual(self.widget._urlsTable.currentItem().text(),
- self.urls[6].path())
-
- def testCurrentUrlSignals(self):
- """Test emission of 'currentUrlChangedListener'"""
- # check initialization
- self.assertEqual(self.currentUrlChangedListener.callCount(), 0)
- self.widget.setUrls(list(self.urls.values()))
- self.qapp.processEvents()
- time.sleep(0.5)
- self.qapp.processEvents()
- # once loaded the two signals should have been sended
- self.assertEqual(self.currentUrlChangedListener.callCount(), 1)
- # if the slider is stuck to the same position no signal should be
- # emitted
- self.qapp.processEvents()
- time.sleep(0.5)
- self.qapp.processEvents()
- self.assertEqual(self.widget._slider.value(), 0)
- self.assertEqual(self.currentUrlChangedListener.callCount(), 1)
- # if slider position is changed, one of each signal should have been
- # emitted
- self.widget._urlsTable.setUrl(self.urls[4])
- self.qapp.processEvents()
- time.sleep(1.5)
- self.qapp.processEvents()
- self.assertEqual(self.currentUrlChangedListener.callCount(), 2)
-
- def testUtils(self):
- """Test that some utils functions are working"""
- self.widget.show()
- self.widget.setUrls(list(self.urls.values()))
- self.assertEqual(len(self.widget.getUrls()), len(self.urls))
-
- # wait for image to be loaded
- self._waitUntilUrlLoaded()
-
- urls_values = list(self.urls.values())
- self.assertEqual(urls_values[0], self.urls[0])
- self.assertEqual(urls_values[7], self.urls[7])
-
- self.assertEqual(self.widget._getNextUrl(urls_values[2]).path(),
- urls_values[3].path())
- self.assertEqual(self.widget._getPreviousUrl(urls_values[0]), None)
- self.assertEqual(self.widget._getPreviousUrl(urls_values[6]).path(),
- urls_values[5].path())
-
- self.assertEqual(self.widget._getNNextUrls(2, urls_values[0]),
- urls_values[1:3])
- self.assertEqual(self.widget._getNNextUrls(5, urls_values[7]),
- urls_values[8:])
- self.assertEqual(self.widget._getNPreviousUrls(3, urls_values[2]),
- urls_values[:2])
- self.assertEqual(self.widget._getNPreviousUrls(5, urls_values[8]),
- urls_values[3:8])
-
- def _waitUntilUrlLoaded(self, timeout=2.0):
- """Wait until all image urls are loaded"""
- loop_duration = 0.2
- remaining_duration = timeout
- while(len(self.widget._loadingThreads) > 0 and remaining_duration > 0):
- remaining_duration -= loop_duration
- time.sleep(loop_duration)
- self.qapp.processEvents()
-
- if remaining_duration <= 0.0:
- remaining_urls = []
- for thread_ in self.widget._loadingThreads:
- remaining_urls.append(thread_.url.path())
- mess = 'All images are not loaded after the time out. ' \
- 'Remaining urls are: ' + str(remaining_urls)
- raise TimeoutError(mess)
- return True
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestImageStack))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testImageView.py b/silx/gui/plot/test/testImageView.py
deleted file mode 100644
index 3c8d84c..0000000
--- a/silx/gui/plot/test/testImageView.py
+++ /dev/null
@@ -1,136 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for PlotWindow"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-
-import unittest
-import numpy
-
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-
-from silx.gui.plot import ImageView
-from silx.gui.colors import Colormap
-
-
-class TestImageView(TestCaseQt):
- """Tests of ImageView widget."""
-
- def setUp(self):
- super(TestImageView, self).setUp()
- self.plot = ImageView()
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- def tearDown(self):
- self.qapp.processEvents()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- self.qapp.processEvents()
- super(TestImageView, self).tearDown()
-
- def testSetImage(self):
- """Test setImage"""
- image = numpy.arange(100).reshape(10, 10)
-
- self.plot.setImage(image, reset=True)
- self.qWait(100)
- self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
- self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
-
- # With reset=False
- self.plot.setImage(image[::2, ::2], reset=False)
- self.qWait(100)
- self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
- self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
-
- self.plot.setImage(image, origin=(10, 20), scale=(2, 4), reset=False)
- self.qWait(100)
- self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
- self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
-
- # With reset=True
- self.plot.setImage(image, origin=(1, 2), scale=(1, 0.5), reset=True)
- self.qWait(100)
- self.assertEqual(self.plot.getXAxis().getLimits(), (1, 11))
- self.assertEqual(self.plot.getYAxis().getLimits(), (2, 7))
-
- self.plot.setImage(image[::2, ::2], reset=True)
- self.qWait(100)
- self.assertEqual(self.plot.getXAxis().getLimits(), (0, 5))
- self.assertEqual(self.plot.getYAxis().getLimits(), (0, 5))
-
- def testColormap(self):
- """Test get|setColormap"""
- image = numpy.arange(100).reshape(10, 10)
- self.plot.setImage(image)
-
- # Colormap as dict
- self.plot.setColormap({'name': 'viridis',
- 'normalization': 'log',
- 'autoscale': False,
- 'vmin': 0,
- 'vmax': 1})
- colormap = self.plot.getColormap()
- self.assertEqual(colormap.getName(), 'viridis')
- self.assertEqual(colormap.getNormalization(), 'log')
- self.assertEqual(colormap.getVMin(), 0)
- self.assertEqual(colormap.getVMax(), 1)
-
- # Colormap as keyword arguments
- self.plot.setColormap(colormap='magma',
- normalization='linear',
- autoscale=True,
- vmin=1,
- vmax=2)
- self.assertEqual(colormap.getName(), 'magma')
- self.assertEqual(colormap.getNormalization(), 'linear')
- self.assertEqual(colormap.getVMin(), None)
- self.assertEqual(colormap.getVMax(), None)
-
- # Update colormap with keyword argument
- self.plot.setColormap(normalization='log')
- self.assertEqual(colormap.getNormalization(), 'log')
-
- # Colormap as Colormap object
- cmap = Colormap()
- self.plot.setColormap(cmap)
- self.assertIs(self.plot.getColormap(), cmap)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestImageView))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testInteraction.py b/silx/gui/plot/test/testInteraction.py
deleted file mode 100644
index a47337e..0000000
--- a/silx/gui/plot/test/testInteraction.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests from interaction state machines"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "18/02/2016"
-
-
-import unittest
-
-from silx.gui.plot import Interaction
-
-
-class TestInteraction(unittest.TestCase):
- def testClickOrDrag(self):
- """Minimalistic test for click or drag state machine."""
- events = []
-
- class TestClickOrDrag(Interaction.ClickOrDrag):
- def click(self, x, y, btn):
- events.append(('click', x, y, btn))
-
- def beginDrag(self, x, y, btn):
- events.append(('beginDrag', x, y, btn))
-
- def drag(self, x, y, btn):
- events.append(('drag', x, y, btn))
-
- def endDrag(self, start, end, btn):
- events.append(('endDrag', start, end, btn))
-
- clickOrDrag = TestClickOrDrag()
-
- # click
- clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
- self.assertEqual(len(events), 0)
-
- clickOrDrag.handleEvent('release', 10, 10, Interaction.LEFT_BTN)
- self.assertEqual(len(events), 1)
- self.assertEqual(events[0], ('click', 10, 10, Interaction.LEFT_BTN))
-
- # drag
- events = []
- clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
- self.assertEqual(len(events), 0)
- clickOrDrag.handleEvent('move', 15, 10)
- self.assertEqual(len(events), 2) # Received beginDrag and drag
- self.assertEqual(events[0], ('beginDrag', 10, 10, Interaction.LEFT_BTN))
- self.assertEqual(events[1], ('drag', 15, 10, Interaction.LEFT_BTN))
- clickOrDrag.handleEvent('move', 20, 10)
- self.assertEqual(len(events), 3)
- self.assertEqual(events[-1], ('drag', 20, 10, Interaction.LEFT_BTN))
- clickOrDrag.handleEvent('release', 20, 10, Interaction.LEFT_BTN)
- self.assertEqual(len(events), 4)
- self.assertEqual(events[-1], ('endDrag', (10, 10), (20, 10), Interaction.LEFT_BTN))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestInteraction))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testItem.py b/silx/gui/plot/test/testItem.py
deleted file mode 100644
index 8dacdea..0000000
--- a/silx/gui/plot/test/testItem.py
+++ /dev/null
@@ -1,340 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for PlotWidget items."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "01/09/2017"
-
-
-import unittest
-
-import numpy
-
-from silx.gui.utils.testutils import SignalListener
-from silx.gui.plot.items import ItemChangedType
-from silx.gui.plot import items
-from .utils import PlotWidgetTestCase
-
-
-class TestSigItemChangedSignal(PlotWidgetTestCase):
- """Test item's sigItemChanged signal"""
-
- def testCurveChanged(self):
- """Test sigItemChanged for curve"""
- self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test')
- curve = self.plot.getCurve('test')
-
- listener = SignalListener()
- curve.sigItemChanged.connect(listener)
-
- # Test for signal in Item class
- curve.setVisible(False)
- curve.setVisible(True)
- curve.setZValue(100)
-
- # Test for signals in PointsBase class
- curve.setData(numpy.arange(100), numpy.arange(100))
-
- # SymbolMixIn
- curve.setSymbol('Circle')
- curve.setSymbol('d')
- curve.setSymbolSize(20)
-
- # AlphaMixIn
- curve.setAlpha(0.5)
-
- # Test for signals in Curve class
- # ColorMixIn
- curve.setColor('yellow')
- # YAxisMixIn
- curve.setYAxis('right')
- # FillMixIn
- curve.setFill(True)
- # LineMixIn
- curve.setLineStyle(':')
- curve.setLineStyle(':') # Not sending event
- curve.setLineWidth(2)
-
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.VISIBLE,
- ItemChangedType.VISIBLE,
- ItemChangedType.ZVALUE,
- ItemChangedType.DATA,
- ItemChangedType.SYMBOL,
- ItemChangedType.SYMBOL,
- ItemChangedType.SYMBOL_SIZE,
- ItemChangedType.ALPHA,
- ItemChangedType.COLOR,
- ItemChangedType.YAXIS,
- ItemChangedType.FILL,
- ItemChangedType.LINE_STYLE,
- ItemChangedType.LINE_WIDTH])
-
- def testHistogramChanged(self):
- """Test sigItemChanged for Histogram"""
- self.plot.addHistogram(
- numpy.arange(10), edges=numpy.arange(11), legend='test')
- histogram = self.plot.getHistogram('test')
- listener = SignalListener()
- histogram.sigItemChanged.connect(listener)
-
- # Test signals in Histogram class
- histogram.setData(numpy.zeros(10), numpy.arange(11))
-
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.DATA])
-
- def testImageDataChanged(self):
- """Test sigItemChanged for ImageData"""
- self.plot.addImage(numpy.arange(100).reshape(10, 10), legend='test')
- image = self.plot.getImage('test')
-
- listener = SignalListener()
- image.sigItemChanged.connect(listener)
-
- # ColormapMixIn
- colormap = self.plot.getDefaultColormap().copy()
- image.setColormap(colormap)
- image.getColormap().setName('viridis')
-
- # Test of signals in ImageBase class
- image.setOrigin(10)
- image.setScale(2)
-
- # Test of signals in ImageData class
- image.setData(numpy.ones((10, 10)))
-
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.COLORMAP,
- ItemChangedType.COLORMAP,
- ItemChangedType.POSITION,
- ItemChangedType.SCALE,
- ItemChangedType.COLORMAP,
- ItemChangedType.DATA])
-
- def testImageRgbaChanged(self):
- """Test sigItemChanged for ImageRgba"""
- self.plot.addImage(numpy.ones((10, 10, 3)), legend='rgb')
- image = self.plot.getImage('rgb')
-
- listener = SignalListener()
- image.sigItemChanged.connect(listener)
-
- # Test of signals in ImageRgba class
- image.setData(numpy.zeros((10, 10, 3)))
-
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.DATA])
-
- def testMarkerChanged(self):
- """Test sigItemChanged for markers"""
- self.plot.addMarker(10, 20, legend='test')
- marker = self.plot._getMarker('test')
-
- listener = SignalListener()
- marker.sigItemChanged.connect(listener)
-
- # Test signals in _BaseMarker
- marker.setPosition(10, 10)
- marker.setPosition(10, 10) # Not sending event
- marker.setText('toto')
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.POSITION,
- ItemChangedType.TEXT])
-
- # XMarker
- self.plot.addXMarker(10, legend='x')
- marker = self.plot._getMarker('x')
-
- listener = SignalListener()
- marker.sigItemChanged.connect(listener)
- marker.setPosition(20, 20)
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.POSITION])
-
- # YMarker
- self.plot.addYMarker(10, legend='x')
- marker = self.plot._getMarker('x')
-
- listener = SignalListener()
- marker.sigItemChanged.connect(listener)
- marker.setPosition(20, 20)
- self.assertEqual(listener.arguments(argumentIndex=0),
- [ItemChangedType.POSITION])
-
- def testScatterChanged(self):
- """Test sigItemChanged for scatter"""
- data = numpy.arange(10)
- self.plot.addScatter(data, data, data, legend='test')
- scatter = self.plot.getScatter('test')
-
- listener = SignalListener()
- scatter.sigItemChanged.connect(listener)
-
- # ColormapMixIn
- scatter.getColormap().setName('viridis')
-
- # Test of signals in Scatter class
- scatter.setData((0, 1, 2), (1, 0, 2), (0, 1, 2))
-
- # Visualization mode changed
- scatter.setVisualization(scatter.Visualization.SOLID)
-
- self.assertEqual(listener.arguments(),
- [(ItemChangedType.COLORMAP,),
- (ItemChangedType.COLORMAP,),
- (ItemChangedType.DATA,),
- (ItemChangedType.VISUALIZATION_MODE,)])
-
- def testShapeChanged(self):
- """Test sigItemChanged for shape"""
- data = numpy.array((1., 10.))
- self.plot.addShape(data, data, legend='test', shape='rectangle')
- shape = self.plot._getItem(kind='item', legend='test')
-
- listener = SignalListener()
- shape.sigItemChanged.connect(listener)
-
- shape.setOverlay(True)
- shape.setPoints(((2., 2.), (3., 3.)))
-
- self.assertEqual(listener.arguments(),
- [(ItemChangedType.OVERLAY,),
- (ItemChangedType.DATA,)])
-
-
-class TestSymbol(PlotWidgetTestCase):
- """Test item's symbol """
-
- def test(self):
- """Test sigItemChanged for curve"""
- self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test')
- curve = self.plot.getCurve('test')
-
- # SymbolMixIn
- curve.setSymbol('o')
- name = curve.getSymbolName()
- self.assertEqual('Circle', name)
-
- name = curve.getSymbolName('d')
- self.assertEqual('Diamond', name)
-
-
-class TestVisibleExtent(PlotWidgetTestCase):
- """Test item's visible extent feature"""
-
- def testGetVisibleBounds(self):
- """Test Item.getVisibleBounds"""
-
- # Create test items (with a bounding box of x: [1,3], y: [0,2])
- curve = items.Curve()
- curve.setData((1, 2, 3), (0, 1, 2))
-
- histogram = items.Histogram()
- histogram.setData((0, 1, 2), (1, 5/3, 7/3, 3))
-
- image = items.ImageData()
- image.setOrigin((1, 0))
- image.setData(numpy.arange(4).reshape(2, 2))
-
- scatter = items.Scatter()
- scatter.setData((1, 2, 3), (0, 1, 2), (1, 2, 3))
-
- bbox = items.BoundingRect()
- bbox.setBounds((1, 3, 0, 2))
-
- xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis()
- for item in (curve, histogram, image, scatter, bbox):
- with self.subTest(item=item):
- xaxis.setLimits(0, 100)
- yaxis.setLimits(0, 100)
- self.plot.addItem(item)
- self.assertEqual(item.getVisibleBounds(), (1., 3., 0., 2.))
-
- xaxis.setLimits(0.5, 2.5)
- self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0., 2.))
-
- yaxis.setLimits(0.5, 1.5)
- self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.5, 1.5))
-
- item.setVisible(False)
- self.assertIsNone(item.getVisibleBounds())
-
- self.plot.clear()
-
- def testVisibleExtentTracking(self):
- """Test Item's visible extent tracking"""
- image = items.ImageData()
- image.setData(numpy.arange(6).reshape(2, 3))
-
- listener = SignalListener()
- image._sigVisibleBoundsChanged.connect(listener)
- image._setVisibleBoundsTracking(True)
- self.assertTrue(image._isVisibleBoundsTracking())
-
- self.plot.addItem(image)
- self.assertEqual(listener.callCount(), 1)
-
- self.plot.getXAxis().setLimits(0, 1)
- self.assertEqual(listener.callCount(), 2)
-
- self.plot.hide()
- self.qapp.processEvents()
- # No event here
- self.assertEqual(listener.callCount(), 2)
-
- self.plot.getXAxis().setLimits(1, 2)
- # No event since PlotWidget is hidden, delayed to PlotWidget show
- self.assertEqual(listener.callCount(), 2)
-
- self.plot.show()
- self.qapp.processEvents()
- # Receives delayed event now
- self.assertEqual(listener.callCount(), 3)
-
- image.setOrigin((-1, -1))
- self.assertEqual(listener.callCount(), 4)
-
- image.setVisible(False)
- image.setOrigin((0, 0))
- # No event since item is not visible
- self.assertEqual(listener.callCount(), 4)
-
- image.setVisible(True)
- # Receives delayed event now
- self.assertEqual(listener.callCount(), 5)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- for klass in (TestSigItemChangedSignal, TestSymbol, TestVisibleExtent):
- test_suite.addTest(loadTests(klass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testLegendSelector.py b/silx/gui/plot/test/testLegendSelector.py
deleted file mode 100644
index de5ffde..0000000
--- a/silx/gui/plot/test/testLegendSelector.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for PlotWidget"""
-
-__authors__ = ["T. Rueter", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "15/05/2017"
-
-
-import logging
-import unittest
-
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.plot import LegendSelector
-
-
-_logger = logging.getLogger(__name__)
-
-
-class TestLegendSelector(TestCaseQt):
- """Basic test for LegendSelector"""
-
- def testLegendSelector(self):
- """Test copied from __main__ of LegendSelector in PyMca"""
- class Notifier(qt.QObject):
- def __init__(self):
- qt.QObject.__init__(self)
- self.chk = True
-
- def signalReceived(self, **kw):
- obj = self.sender()
- _logger.info('NOTIFIER -- signal received\n\tsender: %s',
- str(obj))
-
- notifier = Notifier()
-
- legends = ['Legend0',
- 'Legend1',
- 'Long Legend 2',
- 'Foo Legend 3',
- 'Even Longer Legend 4',
- 'Short Leg 5',
- 'Dot symbol 6',
- 'Comma symbol 7']
- colors = [qt.Qt.darkRed, qt.Qt.green, qt.Qt.yellow, qt.Qt.darkCyan,
- qt.Qt.blue, qt.Qt.darkBlue, qt.Qt.red, qt.Qt.darkYellow]
- symbols = ['o', 't', '+', 'x', 's', 'd', '.', ',']
-
- win = LegendSelector.LegendListView()
- # win = LegendListContextMenu()
- # win = qt.QWidget()
- # layout = qt.QVBoxLayout()
- # layout.setContentsMargins(0,0,0,0)
- llist = []
-
- for _idx, (l, c, s) in enumerate(zip(legends, colors, symbols)):
- ddict = {
- 'color': qt.QColor(c),
- 'linewidth': 4,
- 'symbol': s,
- }
- legend = l
- llist.append((legend, ddict))
- # item = qt.QListWidgetItem(win)
- # legendWidget = LegendListItemWidget(l)
- # legendWidget.icon.setSymbol(s)
- # legendWidget.icon.setColor(qt.QColor(c))
- # layout.addWidget(legendWidget)
- # win.setItemWidget(item, legendWidget)
-
- # win = LegendListItemWidget('Some Legend 1')
- # print(llist)
- model = LegendSelector.LegendModel(legendList=llist)
- win.setModel(model)
- win.setSelectionModel(qt.QItemSelectionModel(model))
- win.setContextMenu()
- # print('Edit triggers: %d'%win.editTriggers())
-
- # win = LegendListWidget(None, legends)
- # win[0].updateItem(ddict)
- # win.setLayout(layout)
- win.sigLegendSignal.connect(notifier.signalReceived)
- win.show()
-
- win.clear()
- win.setLegendList(llist)
-
- self.qWaitForWindowExposed(win)
-
-
-class TestRenameCurveDialog(TestCaseQt):
- """Basic test for RenameCurveDialog"""
-
- def testDialog(self):
- """Create dialog, change name and press OK"""
- self.dialog = LegendSelector.RenameCurveDialog(
- None, 'curve1', ['curve1', 'curve2', 'curve3'])
- self.dialog.open()
- self.qWaitForWindowExposed(self.dialog)
- self.keyClicks(self.dialog.lineEdit, 'changed')
- self.mouseClick(self.dialog.okButton, qt.Qt.LeftButton)
- self.qapp.processEvents()
- ret = self.dialog.result()
- self.assertEqual(ret, qt.QDialog.Accepted)
- newName = self.dialog.getText()
- self.assertEqual(newName, 'curve1changed')
- del self.dialog
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestClass in (TestLegendSelector, TestRenameCurveDialog):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testLimitConstraints.py b/silx/gui/plot/test/testLimitConstraints.py
deleted file mode 100644
index 5e7e0b1..0000000
--- a/silx/gui/plot/test/testLimitConstraints.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test setLimitConstaints on the PlotWidget"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "30/08/2017"
-
-
-import unittest
-from silx.gui.plot import PlotWidget
-
-
-class TestLimitConstaints(unittest.TestCase):
- """Tests setLimitConstaints class"""
-
- def setUp(self):
- self.plot = PlotWidget()
-
- def tearDown(self):
- self.plot = None
-
- def testApi(self):
- """Test availability of the API"""
- self.plot.getXAxis().setLimitsConstraints(minPos=1, maxPos=10)
- self.plot.getXAxis().setRangeConstraints(minRange=1, maxRange=1)
- self.plot.getYAxis().setLimitsConstraints(minPos=1, maxPos=10)
- self.plot.getYAxis().setRangeConstraints(minRange=1, maxRange=1)
-
- def testXMinMax(self):
- """Test limit constains on x-axis"""
- self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100)
- self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
- self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100))
- self.assertEqual(self.plot.getYAxis().getLimits(), (-1, 101))
-
- def testYMinMax(self):
- """Test limit constains on y-axis"""
- self.plot.getYAxis().setLimitsConstraints(minPos=0, maxPos=100)
- self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
- self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101))
- self.assertEqual(self.plot.getYAxis().getLimits(), (0, 100))
-
- def testMinXRange(self):
- """Test min range constains on x-axis"""
- self.plot.getXAxis().setRangeConstraints(minRange=100)
- self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99)
- limits = self.plot.getXAxis().getLimits()
- self.assertEqual(limits[1] - limits[0], 100)
- limits = self.plot.getYAxis().getLimits()
- self.assertNotEqual(limits[1] - limits[0], 100)
-
- def testMaxXRange(self):
- """Test max range constains on x-axis"""
- self.plot.getXAxis().setRangeConstraints(maxRange=100)
- self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
- limits = self.plot.getXAxis().getLimits()
- self.assertEqual(limits[1] - limits[0], 100)
- limits = self.plot.getYAxis().getLimits()
- self.assertNotEqual(limits[1] - limits[0], 100)
-
- def testMinYRange(self):
- """Test min range constains on y-axis"""
- self.plot.getYAxis().setRangeConstraints(minRange=100)
- self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99)
- limits = self.plot.getXAxis().getLimits()
- self.assertNotEqual(limits[1] - limits[0], 100)
- limits = self.plot.getYAxis().getLimits()
- self.assertEqual(limits[1] - limits[0], 100)
-
- def testMaxYRange(self):
- """Test max range constains on y-axis"""
- self.plot.getYAxis().setRangeConstraints(maxRange=100)
- self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
- limits = self.plot.getXAxis().getLimits()
- self.assertNotEqual(limits[1] - limits[0], 100)
- limits = self.plot.getYAxis().getLimits()
- self.assertEqual(limits[1] - limits[0], 100)
-
- def testChangeOfConstraints(self):
- """Test changing of the constraints"""
- self.plot.getXAxis().setRangeConstraints(minRange=10, maxRange=10)
- # There is no more constraints on the range
- self.plot.getXAxis().setRangeConstraints(minRange=None, maxRange=None)
- self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
- self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101))
-
- def testSettingConstraints(self):
- """Test setting a constaint (setLimits first then the constaint)"""
- self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
- self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100)
- self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestLimitConstaints))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testMaskToolsWidget.py b/silx/gui/plot/test/testMaskToolsWidget.py
deleted file mode 100644
index c22975f..0000000
--- a/silx/gui/plot/test/testMaskToolsWidget.py
+++ /dev/null
@@ -1,316 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for MaskToolsWidget"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import logging
-import os.path
-import unittest
-
-import numpy
-
-from silx.gui import qt
-from silx.test.utils import temp_dir
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import getQToolButtonFromAction
-from silx.gui.plot import PlotWindow, MaskToolsWidget
-from .utils import PlotWidgetTestCase
-
-import fabio
-
-
-_logger = logging.getLogger(__name__)
-
-
-class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
- """Basic test for MaskToolsWidget"""
-
- def _createPlot(self):
- return PlotWindow()
-
- def setUp(self):
- super(TestMaskToolsWidget, self).setUp()
- self.widget = MaskToolsWidget.MaskToolsDockWidget(plot=self.plot, name='TEST')
- self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
- self.maskWidget = self.widget.widget()
-
- def tearDown(self):
- del self.maskWidget
- del self.widget
- super(TestMaskToolsWidget, self).tearDown()
-
- def testEmptyPlot(self):
- """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
- self.maskWidget.setMultipleMasks('single')
- self.qapp.processEvents()
-
- self.maskWidget.setMultipleMasks('exclusive')
- self.qapp.processEvents()
-
- def _drag(self):
- """Drag from plot center to offset position"""
- plot = self.plot.getWidgetHandle()
- xCenter, yCenter = plot.width() // 2, plot.height() // 2
- offset = min(plot.width(), plot.height()) // 10
-
- pos0 = xCenter, yCenter
- pos1 = xCenter + offset, yCenter + offset
-
- self.mouseMove(plot, pos=(0, 0))
- self.mouseMove(plot, pos=pos0)
- self.qapp.processEvents()
- self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
- self.qapp.processEvents()
- self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
- self.mouseMove(plot, pos=pos1)
- self.qapp.processEvents()
- self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
- self.qapp.processEvents()
- self.mouseMove(plot, pos=(0, 0))
-
- def _drawPolygon(self):
- """Draw a star polygon in the plot"""
- plot = self.plot.getWidgetHandle()
- x, y = plot.width() // 2, plot.height() // 2
- offset = min(plot.width(), plot.height()) // 10
-
- star = [(x, y + offset),
- (x - offset, y - offset),
- (x + offset, y),
- (x - offset, y),
- (x + offset, y - offset),
- (x, y + offset)] # Close polygon
-
- self.mouseMove(plot, pos=(0, 0))
- for pos in star:
- self.mouseMove(plot, pos=pos)
- self.qapp.processEvents()
- self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
- self.qapp.processEvents()
- self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
- self.qapp.processEvents()
-
- def _drawPencil(self):
- """Draw a star polygon in the plot"""
- plot = self.plot.getWidgetHandle()
- x, y = plot.width() // 2, plot.height() // 2
- offset = min(plot.width(), plot.height()) // 10
-
- star = [(x, y + offset),
- (x - offset, y - offset),
- (x + offset, y),
- (x - offset, y),
- (x + offset, y - offset)]
-
- self.mouseMove(plot, pos=(0, 0))
- self.mouseMove(plot, pos=star[0])
- self.mousePress(plot, qt.Qt.LeftButton, pos=star[0])
- for pos in star[1:]:
- self.mouseMove(plot, pos=pos)
- self.mouseRelease(
- plot, qt.Qt.LeftButton, pos=star[-1])
-
- def _isMaskItemSync(self):
- """Check if masks from item and tools are sync or not"""
- if self.maskWidget.isItemMaskUpdated():
- return numpy.all(numpy.equal(
- self.maskWidget.getSelectionMask(),
- self.plot.getActiveImage().getMaskData(copy=False)))
- else:
- return True
-
- def testWithAnImage(self):
- """Plot with an image: test MaskToolsWidget interactions"""
-
- # Add and remove a image (this should enable/disable GUI + change mask)
- self.plot.addImage(numpy.random.random(1024**2).reshape(1024, 1024),
- legend='test')
- self.qapp.processEvents()
-
- self.plot.remove('test', kind='image')
- self.qapp.processEvents()
-
- tests = [((0, 0), (1, 1)),
- ((1000, 1000), (1, 1)),
- ((0, 0), (-1, -1)),
- ((1000, 1000), (-1, -1))]
-
- for itemMaskUpdated in (False, True):
- for origin, scale in tests:
- with self.subTest(origin=origin, scale=scale):
- self.maskWidget.setItemMaskUpdated(itemMaskUpdated)
- self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
- legend='test',
- origin=origin,
- scale=scale)
- self.qapp.processEvents()
-
- self.assertEqual(
- self.maskWidget.isItemMaskUpdated(), itemMaskUpdated)
-
- # Test draw rectangle #
- toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- # mask
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drag()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
- self.assertTrue(self._isMaskItemSync())
-
- # unmask same region
- self.maskWidget.maskStateGroup.button(0).click()
- self.qapp.processEvents()
- self._drag()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
- self.assertTrue(self._isMaskItemSync())
-
- # Test draw polygon #
- toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- # mask
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drawPolygon()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
- self.assertTrue(self._isMaskItemSync())
-
- # unmask same region
- self.maskWidget.maskStateGroup.button(0).click()
- self.qapp.processEvents()
- self._drawPolygon()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
- self.assertTrue(self._isMaskItemSync())
-
- # Test draw pencil #
- toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- self.maskWidget.pencilSpinBox.setValue(30)
- self.qapp.processEvents()
-
- # mask
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drawPencil()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
- self.assertTrue(self._isMaskItemSync())
-
- # unmask same region
- self.maskWidget.maskStateGroup.button(0).click()
- self.qapp.processEvents()
- self._drawPencil()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
- self.assertTrue(self._isMaskItemSync())
-
- # Test no draw tool #
- toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- self.plot.clear()
-
- def __loadSave(self, file_format):
- """Plot with an image: test MaskToolsWidget operations"""
- self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
- legend='test')
- self.qapp.processEvents()
-
- # Draw a polygon mask
- toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
- self._drawPolygon()
-
- ref_mask = self.maskWidget.getSelectionMask()
- self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
-
- with temp_dir() as tmp:
- mask_filename = os.path.join(tmp, 'mask.' + file_format)
- self.maskWidget.save(mask_filename, file_format)
-
- self.maskWidget.resetSelectionMask()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- self.maskWidget.load(mask_filename)
- self.assertTrue(numpy.all(numpy.equal(
- self.maskWidget.getSelectionMask(), ref_mask)))
-
- def testLoadSaveNpy(self):
- self.__loadSave("npy")
-
- def testLoadSaveFit2D(self):
- self.__loadSave("msk")
-
- def testSigMaskChangedEmitted(self):
- self.plot.addImage(numpy.arange(512**2).reshape(512, 512),
- legend='test')
- self.plot.resetZoom()
- self.qapp.processEvents()
-
- l = []
-
- def slot():
- l.append(1)
-
- self.maskWidget.sigMaskChanged.connect(slot)
-
- # rectangle mask
- toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drag()
-
- self.assertGreater(len(l), 0)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestClass in (TestMaskToolsWidget,):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPixelIntensityHistoAction.py b/silx/gui/plot/test/testPixelIntensityHistoAction.py
deleted file mode 100644
index ac29952..0000000
--- a/silx/gui/plot/test/testPixelIntensityHistoAction.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for PixelIntensitiesHistoAction"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "02/03/2018"
-
-
-import numpy
-import unittest
-
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
-from silx.gui import qt
-from silx.gui.plot import Plot2D
-
-
-class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
- """Tests for PixelIntensitiesHistoAction widget."""
-
- def setUp(self):
- super(TestPixelIntensitiesHisto, self).setUp()
- self.image = numpy.random.rand(10, 10)
- self.plotImage = Plot2D()
- self.plotImage.getIntensityHistogramAction().setVisible(True)
-
- def tearDown(self):
- del self.plotImage
- super(TestPixelIntensitiesHisto, self).tearDown()
-
- def testShowAndHide(self):
- """Simple test that the plot is showing and hiding when activating the
- action"""
- self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
- self.plotImage.show()
-
- histoAction = self.plotImage.getIntensityHistogramAction()
-
- # test the pixel intensity diagram is showing
- button = getQToolButtonFromAction(histoAction)
- self.assertIsNot(button, None)
- self.mouseMove(button)
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qapp.processEvents()
- self.assertTrue(histoAction.getHistogramWidget().isVisible())
-
- # test the pixel intensity diagram is hiding
- self.qapp.setActiveWindow(self.plotImage)
- self.qapp.processEvents()
- self.mouseMove(button)
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qapp.processEvents()
- self.assertFalse(histoAction.getHistogramWidget().isVisible())
-
- def testImageFormatInput(self):
- """Test multiple type as image input"""
- typesToTest = [numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
- numpy.float32, numpy.float64]
- self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
- self.plotImage.show()
- button = getQToolButtonFromAction(
- self.plotImage.getIntensityHistogramAction())
- self.mouseMove(button)
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qapp.processEvents()
- for typeToTest in typesToTest:
- with self.subTest(typeToTest=typeToTest):
- self.plotImage.addImage(self.image.astype(typeToTest),
- origin=(0, 0), legend='sino')
-
- def testScatter(self):
- """Test that an histogram from a scatter is displayed"""
- xx = numpy.arange(10)
- yy = numpy.arange(10)
- value = numpy.sin(xx)
- self.plotImage.addScatter(xx, yy, value)
- self.plotImage.show()
-
- histoAction = self.plotImage.getIntensityHistogramAction()
-
- # test the pixel intensity diagram is showing
- button = getQToolButtonFromAction(histoAction)
- self.assertIsNot(button, None)
- self.mouseMove(button)
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qapp.processEvents()
-
- widget = histoAction.getHistogramWidget()
- self.assertTrue(widget.isVisible())
- items = widget.getPlotWidget().getItems()
- self.assertEqual(len(items), 1)
-
- def testChangeItem(self):
- """Test that histogram changes it the item changes"""
- xx = numpy.arange(10)
- yy = numpy.arange(10)
- value = numpy.sin(xx)
- self.plotImage.addScatter(xx, yy, value)
- self.plotImage.show()
-
- histoAction = self.plotImage.getIntensityHistogramAction()
-
- # test the pixel intensity diagram is showing
- button = getQToolButtonFromAction(histoAction)
- self.assertIsNot(button, None)
- self.mouseMove(button)
- self.mouseClick(button, qt.Qt.LeftButton)
- self.qapp.processEvents()
-
- # Reach histogram from the first item
- widget = histoAction.getHistogramWidget()
- self.assertTrue(widget.isVisible())
- items = widget.getPlotWidget().getItems()
- data1 = items[0].getValueData(copy=False)
-
- # Set another item to the plot
- self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
- self.qapp.processEvents()
- data2 = items[0].getValueData(copy=False)
-
- # Histogram is not the same
- self.assertFalse(numpy.array_equal(data1, data2))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(
- TestPixelIntensitiesHisto))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPlotInteraction.py b/silx/gui/plot/test/testPlotInteraction.py
deleted file mode 100644
index 7a30434..0000000
--- a/silx/gui/plot/test/testPlotInteraction.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016=2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests of plot interaction, through a PlotWidget"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "01/09/2017"
-
-
-import unittest
-from silx.gui import qt
-from .utils import PlotWidgetTestCase
-
-
-class _SignalDump(object):
- """Callable object that store passed arguments in a list"""
-
- def __init__(self):
- self._received = []
-
- def __call__(self, *args):
- self._received.append(args)
-
- @property
- def received(self):
- """Return a shallow copy of the list of received arguments"""
- return list(self._received)
-
-
-class TestSelectPolygon(PlotWidgetTestCase):
- """Test polygon selection interaction"""
-
- def _interactionModeChanged(self, source):
- """Check that source received in event is the correct one"""
- self.assertEqual(source, self)
-
- def _draw(self, polygon):
- """Draw a polygon in the plot
-
- :param polygon: List of points (x, y) of the polygon (closed)
- """
- plot = self.plot.getWidgetHandle()
-
- dump = _SignalDump()
- self.plot.sigPlotSignal.connect(dump)
-
- for pos in polygon:
- self.mouseMove(plot, pos=pos)
- self.qapp.processEvents()
- self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
- self.qapp.processEvents()
- self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
- self.qapp.processEvents()
-
- self.plot.sigPlotSignal.disconnect(dump)
- return [args[0] for args in dump.received]
-
- def test(self):
- """Test draw polygons + events"""
- self.plot.sigInteractiveModeChanged.connect(
- self._interactionModeChanged)
-
- self.plot.setInteractiveMode(
- 'draw', shape='polygon', label='test', source=self)
- interaction = self.plot.getInteractiveMode()
-
- self.assertEqual(interaction['mode'], 'draw')
- self.assertEqual(interaction['shape'], 'polygon')
-
- self.plot.sigInteractiveModeChanged.disconnect(
- self._interactionModeChanged)
-
- plot = self.plot.getWidgetHandle()
- xCenter, yCenter = plot.width() // 2, plot.height() // 2
- offset = min(plot.width(), plot.height()) // 10
-
- # Star polygon
- star = [(xCenter, yCenter + offset),
- (xCenter - offset, yCenter - offset),
- (xCenter + offset, yCenter),
- (xCenter - offset, yCenter),
- (xCenter + offset, yCenter - offset),
- (xCenter, yCenter + offset)] # Close polygon
-
- # Draw while dumping signals
- events = self._draw(star)
-
- # Test last event
- drawEvents = [event for event in events
- if event['event'].startswith('drawing')]
- self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
- self.assertEqual(len(drawEvents[-1]['points']), 6)
-
- # Large square
- largeSquare = [(xCenter - offset, yCenter - offset),
- (xCenter + offset, yCenter - offset),
- (xCenter + offset, yCenter + offset),
- (xCenter - offset, yCenter + offset),
- (xCenter - offset, yCenter - offset)] # Close polygon
-
- # Draw while dumping signals
- events = self._draw(largeSquare)
-
- # Test last event
- drawEvents = [event for event in events
- if event['event'].startswith('drawing')]
- self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
- self.assertEqual(len(drawEvents[-1]['points']), 5)
-
- # Rectangle too thin along X: Some points are ignored
- thinRectX = [(xCenter, yCenter - offset),
- (xCenter, yCenter + offset),
- (xCenter + 1, yCenter + offset),
- (xCenter + 1, yCenter - offset)] # Close polygon
-
- # Draw while dumping signals
- events = self._draw(thinRectX)
-
- # Test last event
- drawEvents = [event for event in events
- if event['event'].startswith('drawing')]
- self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
- self.assertEqual(len(drawEvents[-1]['points']), 3)
-
- # Rectangle too thin along Y: Some points are ignored
- thinRectY = [(xCenter - offset, yCenter),
- (xCenter + offset, yCenter),
- (xCenter + offset, yCenter + 1),
- (xCenter - offset, yCenter + 1)] # Close polygon
-
- # Draw while dumping signals
- events = self._draw(thinRectY)
-
- # Test last event
- drawEvents = [event for event in events
- if event['event'].startswith('drawing')]
- self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
- self.assertEqual(len(drawEvents[-1]['points']), 3)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestClass in (TestSelectPolygon,):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPlotWidget.py b/silx/gui/plot/test/testPlotWidget.py
deleted file mode 100755
index b55260e..0000000
--- a/silx/gui/plot/test/testPlotWidget.py
+++ /dev/null
@@ -1,2072 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for PlotWidget"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "03/01/2019"
-
-
-import unittest
-import logging
-import numpy
-import sys
-
-from silx.utils.testutils import ParametricTestCase, parameterize
-from silx.gui.utils.testutils import SignalListener
-from silx.gui.utils.testutils import TestCaseQt
-
-from silx.test.utils import test_options
-
-from silx.gui import qt
-from silx.gui.plot import PlotWidget
-from silx.gui.plot.items.curve import CurveStyle
-from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent, Axis
-from silx.gui.colors import Colormap
-
-from .utils import PlotWidgetTestCase
-
-
-SIZE = 1024
-"""Size of the test image"""
-
-DATA_2D = numpy.arange(SIZE ** 2).reshape(SIZE, SIZE)
-"""Image data set"""
-
-
-logger = logging.getLogger(__name__)
-
-
-class TestSpecialBackend(PlotWidgetTestCase, ParametricTestCase):
-
- def __init__(self, methodName='runTest', backend=None):
- TestCaseQt.__init__(self, methodName=methodName)
- self.__backend = backend
-
- def _createPlot(self):
- return PlotWidget(backend=self.__backend)
-
- def testPlot(self):
- self.assertIsNotNone(self.plot)
-
-
-class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
- """Basic tests for PlotWidget"""
-
- def testShow(self):
- """Most basic test"""
- pass
-
- def testSetTitleLabels(self):
- """Set title and axes labels"""
-
- title, xlabel, ylabel = 'the title', 'x label', 'y label'
- self.plot.setGraphTitle(title)
- self.plot.getXAxis().setLabel(xlabel)
- self.plot.getYAxis().setLabel(ylabel)
- self.qapp.processEvents()
-
- self.assertEqual(self.plot.getGraphTitle(), title)
- self.assertEqual(self.plot.getXAxis().getLabel(), xlabel)
- self.assertEqual(self.plot.getYAxis().getLabel(), ylabel)
-
- def _checkLimits(self,
- expectedXLim=None,
- expectedYLim=None,
- expectedRatio=None):
- """Assert that limits are as expected"""
- xlim = self.plot.getXAxis().getLimits()
- ylim = self.plot.getYAxis().getLimits()
- ratio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
-
- if expectedXLim is not None:
- self.assertEqual(expectedXLim, xlim)
-
- if expectedYLim is not None:
- self.assertEqual(expectedYLim, ylim)
-
- if expectedRatio is not None:
- self.assertTrue(
- numpy.allclose(expectedRatio, ratio, atol=0.01))
-
- def testChangeLimitsWithAspectRatio(self):
- self.plot.setKeepDataAspectRatio()
- self.qapp.processEvents()
- xlim = self.plot.getXAxis().getLimits()
- ylim = self.plot.getYAxis().getLimits()
- defaultRatio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
-
- self.plot.getXAxis().setLimits(1., 10.)
- self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
- self.qapp.processEvents()
- self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
-
- self.plot.getYAxis().setLimits(1., 10.)
- self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
- self.qapp.processEvents()
- self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
-
- def testResizeWidget(self):
- """Test resizing the widget and receiving limitsChanged events"""
- self.plot.resize(200, 200)
- self.qapp.processEvents()
- self.qWait(100)
-
- xlim = self.plot.getXAxis().getLimits()
- ylim = self.plot.getYAxis().getLimits()
-
- listener = SignalListener()
- self.plot.getXAxis().sigLimitsChanged.connect(listener.partial('x'))
- self.plot.getYAxis().sigLimitsChanged.connect(listener.partial('y'))
-
- # Resize without aspect ratio
- self.plot.resize(200, 300)
- self.qapp.processEvents()
- self.qWait(100)
- self._checkLimits(expectedXLim=xlim, expectedYLim=ylim)
- self.assertEqual(listener.callCount(), 0)
-
- # Resize with aspect ratio
- self.plot.setKeepDataAspectRatio(True)
- self.qapp.processEvents()
- self.qWait(1000)
- listener.clear() # Clean-up received signal
-
- self.plot.resize(200, 200)
- self.qapp.processEvents()
- self.qWait(100)
- self.assertNotEqual(listener.callCount(), 0)
-
- def testAddRemoveItemSignals(self):
- """Test sigItemAdded and sigItemAboutToBeRemoved"""
- listener = SignalListener()
- self.plot.sigItemAdded.connect(listener.partial('add'))
- self.plot.sigItemAboutToBeRemoved.connect(listener.partial('remove'))
-
- self.plot.addCurve((1, 2, 3), (3, 2, 1), legend='curve')
- self.assertEqual(listener.callCount(), 1)
-
- curve = self.plot.getCurve('curve')
- self.plot.remove('curve')
- self.assertEqual(listener.callCount(), 2)
- self.assertEqual(listener.arguments(callIndex=0), ('add', curve))
- self.assertEqual(listener.arguments(callIndex=1), ('remove', curve))
-
- def testGetItems(self):
- """Test getItems method"""
- curve_x = 1, 2
- self.plot.addCurve(curve_x, (3, 4))
- image = (0, 1), (2, 3)
- self.plot.addImage(image)
- scatter_x = 10, 11
- self.plot.addScatter(scatter_x, (12, 13), (0, 1))
- marker_pos = 5, 5
- self.plot.addMarker(*marker_pos)
- marker_x = 6
- self.plot.addXMarker(marker_x)
- self.plot.addShape((0, 5), (2, 10), shape='rectangle')
-
- items = self.plot.getItems()
- self.assertEqual(len(items), 6)
- self.assertTrue(numpy.all(numpy.equal(items[0].getXData(), curve_x)))
- self.assertTrue(numpy.all(numpy.equal(items[1].getData(), image)))
- self.assertTrue(numpy.all(numpy.equal(items[2].getXData(), scatter_x)))
- self.assertTrue(numpy.all(numpy.equal(items[3].getPosition(), marker_pos)))
- self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x)))
- self.assertEqual(items[5].getType(), 'rectangle')
-
- def testRemoveDiscardItem(self):
- """Test removeItem and discardItem"""
- self.plot.addCurve((1, 2, 3), (1, 2, 3))
- curve = self.plot.getItems()[0]
- self.plot.removeItem(curve)
- with self.assertRaises(ValueError):
- self.plot.removeItem(curve)
-
- self.plot.addCurve((1, 2, 3), (1, 2, 3))
- curve = self.plot.getItems()[0]
- result = self.plot.discardItem(curve)
- self.assertTrue(result)
- result = self.plot.discardItem(curve)
- self.assertFalse(result)
-
- def testBackGroundColors(self):
- self.plot.setVisible(True)
- self.qWaitForWindowExposed(self.plot)
- self.qapp.processEvents()
-
- # Custom the full background
- color = self.plot.getBackgroundColor()
- self.assertTrue(color.isValid())
- self.assertEqual(color, qt.QColor(255, 255, 255))
- self.plot.setBackgroundColor("red")
- color = self.plot.getBackgroundColor()
- self.assertTrue(color.isValid())
- self.qapp.processEvents()
-
- # Custom the data background
- color = self.plot.getDataBackgroundColor()
- self.assertFalse(color.isValid())
- self.plot.setDataBackgroundColor("red")
- color = self.plot.getDataBackgroundColor()
- self.assertTrue(color.isValid())
- self.qapp.processEvents()
-
- # Back to default
- self.plot.setBackgroundColor('white')
- self.plot.setDataBackgroundColor(None)
- color = self.plot.getBackgroundColor()
- self.assertTrue(color.isValid())
- self.assertEqual(color, qt.QColor(255, 255, 255))
- color = self.plot.getDataBackgroundColor()
- self.assertFalse(color.isValid())
- self.qapp.processEvents()
-
-
-class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
- """Basic tests for addImage"""
-
- def setUp(self):
- super(TestPlotImage, self).setUp()
-
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
-
- def testPlotColormapTemperature(self):
- self.plot.setGraphTitle('Temp. Linear')
-
- colormap = Colormap(name='temperature',
- normalization='linear',
- vmin=None,
- vmax=None)
- self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
-
- def testPlotColormapGray(self):
- self.plot.setKeepDataAspectRatio(False)
- self.plot.setGraphTitle('Gray Linear')
-
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=None,
- vmax=None)
- self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
-
- def testPlotColormapTemperatureLog(self):
- self.plot.setGraphTitle('Temp. Log')
-
- colormap = Colormap(name='temperature',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=None)
- self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
-
- def testPlotRgbRgba(self):
- self.plot.setKeepDataAspectRatio(False)
- self.plot.setGraphTitle('RGB + RGBA')
-
- rgb = numpy.array(
- (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
- ((0, 128, 0), (0, 128, 128), (0, 128, 255))),
- dtype=numpy.uint8)
-
- self.plot.addImage(rgb, legend="rgb_uint8",
- origin=(0, 0), scale=(1, 1),
- resetzoom=False)
-
- rgb = numpy.array(
- (((0, 0, 0), (32768, 0, 0), (65535, 0, 0)),
- ((0, 32768, 0), (0, 32768, 32768), (0, 32768, 65535))),
- dtype=numpy.uint16)
-
- self.plot.addImage(rgb, legend="rgb_uint16",
- origin=(3, 2), scale=(2, 2),
- resetzoom=False)
-
- rgba = numpy.array(
- (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
- ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
- dtype=numpy.float32)
-
- self.plot.addImage(rgba, legend="rgba_float32",
- origin=(9, 6), scale=(1, 1),
- resetzoom=False)
-
- self.plot.resetZoom()
-
- def testPlotColormapCustom(self):
- self.plot.setKeepDataAspectRatio(False)
- self.plot.setGraphTitle('Custom colormap')
-
- colormap = Colormap(name=None,
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=None,
- colors=((0., 0., 0.), (1., 0., 0.),
- (0., 1., 0.), (0., 0., 1.)))
- self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap,
- resetzoom=False)
-
- colormap = Colormap(name=None,
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=None,
- colors=numpy.array(
- ((0, 0, 0, 0), (0, 0, 0, 128),
- (128, 128, 128, 128), (255, 255, 255, 255)),
- dtype=numpy.uint8))
- self.plot.addImage(DATA_2D, legend="image 2", colormap=colormap,
- origin=(DATA_2D.shape[0], 0),
- resetzoom=False)
- self.plot.resetZoom()
-
- def testPlotColormapNaNColor(self):
- self.plot.setKeepDataAspectRatio(False)
- self.plot.setGraphTitle('Colormap with NaN color')
-
- colormap = Colormap()
- colormap.setNaNColor('red')
- self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
- data = DATA_2D.astype(numpy.float32)
- data[len(data)//2:] = numpy.nan
- self.plot.addImage(data, legend="image 1", colormap=colormap,
- resetzoom=False)
- self.plot.resetZoom()
-
- colormap.setNaNColor((0., 1., 0., 1.))
- self.assertEqual(colormap.getNaNColor(), qt.QColor(0, 255, 0))
- self.qapp.processEvents()
-
- def testImageOriginScale(self):
- """Test of image with different origin and scale"""
- self.plot.setGraphTitle('origin and scale')
-
- tests = [ # (origin, scale)
- ((10, 20), (1, 1)),
- ((10, 20), (-1, -1)),
- ((-10, 20), (2, 1)),
- ((10, -20), (-1, -2)),
- (100, 2),
- (-100, (1, 1)),
- ((10, 20), 2),
- ]
-
- for origin, scale in tests:
- with self.subTest(origin=origin, scale=scale):
- self.plot.addImage(DATA_2D, origin=origin, scale=scale)
-
- try:
- ox, oy = origin
- except TypeError:
- ox, oy = origin, origin
- try:
- sx, sy = scale
- except TypeError:
- sx, sy = scale, scale
- xbounds = ox, ox + DATA_2D.shape[1] * sx
- ybounds = oy, oy + DATA_2D.shape[0] * sy
-
- # Check limits without aspect ratio
- xmin, xmax = self.plot.getXAxis().getLimits()
- ymin, ymax = self.plot.getYAxis().getLimits()
- self.assertEqual(xmin, min(xbounds))
- self.assertEqual(xmax, max(xbounds))
- self.assertEqual(ymin, min(ybounds))
- self.assertEqual(ymax, max(ybounds))
-
- # Check limits with aspect ratio
- self.plot.setKeepDataAspectRatio(True)
- xmin, xmax = self.plot.getXAxis().getLimits()
- ymin, ymax = self.plot.getYAxis().getLimits()
- self.assertTrue(round(xmin, 7) <= min(xbounds))
- self.assertTrue(round(xmax, 7) >= max(xbounds))
- self.assertTrue(round(ymin, 7) <= min(ybounds))
- self.assertTrue(round(ymax, 7) >= max(ybounds))
-
- self.plot.setKeepDataAspectRatio(False) # Reset aspect ratio
- self.plot.clear()
- self.plot.resetZoom()
-
- def testPlotColormapDictAPI(self):
- """Test that the addImage API using a colormap dictionary is still
- working"""
- self.plot.setGraphTitle('Temp. Log')
-
- colormap = {
- 'name': 'temperature',
- 'normalization': 'log',
- 'vmin': None,
- 'vmax': None
- }
- self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
-
- def testPlotComplexImage(self):
- """Test that a complex image is displayed as its absolute value."""
- data = numpy.linspace(1, 1j, 100).reshape(10, 10)
- self.plot.addImage(data, legend='complex')
-
- image = self.plot.getActiveImage()
- retrievedData = image.getData(copy=False)
- self.assertTrue(
- numpy.all(numpy.equal(retrievedData, numpy.absolute(data))))
-
- def testPlotBooleanImage(self):
- """Test that a boolean image is displayed and converted to int8."""
- data = numpy.zeros((10, 10), dtype=bool)
- data[::2, ::2] = True
- self.plot.addImage(data, legend='boolean')
-
- image = self.plot.getActiveImage()
- retrievedData = image.getData(copy=False)
- self.assertTrue(numpy.all(numpy.equal(retrievedData, data)))
- self.assertIs(retrievedData.dtype.type, numpy.int8)
-
- def testPlotAlphaImage(self):
- """Test with an alpha image layer"""
- data = numpy.random.random((10, 10))
- alpha = numpy.linspace(0, 1, 100).reshape(10, 10)
- self.plot.addImage(data, legend='image')
- image = self.plot.getActiveImage()
- image.setData(data, alpha=alpha)
- self.qapp.processEvents()
- self.assertTrue(numpy.array_equal(alpha, image.getAlphaData()))
-
-
-class TestPlotCurve(PlotWidgetTestCase):
- """Basic tests for addCurve."""
-
- # Test data sets
- xData = numpy.arange(1000)
- yData = -500 + 100 * numpy.sin(xData)
- xData2 = xData + 1000
- yData2 = xData - 1000 + 200 * numpy.random.random(1000)
-
- def setUp(self):
- super(TestPlotCurve, self).setUp()
- self.plot.setGraphTitle('Curve')
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
-
- self.plot.setActiveCurveHandling(False)
-
- def testPlotCurveInfinite(self):
- """Test plot curves with not finite data"""
- tests = {
- 'y all not finite': ([0, 1, 2], [numpy.inf, numpy.nan, -numpy.inf]),
- 'x all not finite': ([numpy.inf, numpy.nan, -numpy.inf], [0, 1, 2]),
- 'x some inf': ([0, numpy.inf, 2], [0, 1, 2]),
- 'y some inf': ([0, 1, 2], [0, numpy.inf, 2])
- }
- for name, args in tests.items():
- with self.subTest(name):
- self.plot.addCurve(*args)
- self.plot.resetZoom()
- self.qapp.processEvents()
- self.plot.clear()
-
- def testPlotCurveColorFloat(self):
- color = numpy.array(numpy.random.random(3 * 1000),
- dtype=numpy.float32).reshape(1000, 3)
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 1",
- replace=False, resetzoom=False,
- color=color,
- linestyle="", symbol="s")
- self.plot.addCurve(self.xData2, self.yData2,
- legend="curve 2",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
- self.plot.resetZoom()
-
- def testPlotCurveColorByte(self):
- color = numpy.array(255 * numpy.random.random(3 * 1000),
- dtype=numpy.uint8).reshape(1000, 3)
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 1",
- replace=False, resetzoom=False,
- color=color,
- linestyle="", symbol="s")
- self.plot.addCurve(self.xData2, self.yData2,
- legend="curve 2",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
- self.plot.resetZoom()
-
- def testPlotCurveColors(self):
- color = numpy.array(numpy.random.random(3 * 1000),
- dtype=numpy.float32).reshape(1000, 3)
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 2",
- replace=False, resetzoom=False,
- color=color, linestyle="-", symbol='o')
- self.plot.resetZoom()
-
- # Test updating color array
-
- # From array to array
- newColors = numpy.ones((len(self.xData), 3), dtype=numpy.float32)
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 2",
- replace=False, resetzoom=False,
- color=newColors, symbol='o')
-
- # Array to single color
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 2",
- replace=False, resetzoom=False,
- color='green', symbol='o')
-
- # single color to array
- self.plot.addCurve(self.xData, self.yData,
- legend="curve 2",
- replace=False, resetzoom=False,
- color=color, symbol='o')
-
- def testPlotBaselineNumpyArray(self):
- """simple test of the API with baseline as a numpy array"""
- x = numpy.arange(0, 10, step=0.1)
- my_sin = numpy.sin(x)
- y = numpy.arange(-4, 6, step=0.1) + my_sin
- baseline = y - 1.0
-
- self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
- baseline=baseline)
-
- def testPlotBaselineScalar(self):
- """simple test of the API with baseline as an int"""
- x = numpy.arange(0, 10, step=0.1)
- my_sin = numpy.sin(x)
- y = numpy.arange(-4, 6, step=0.1) + my_sin
-
- self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
- baseline=0)
-
- def testPlotBaselineList(self):
- """simple test of the API with baseline as an int"""
- x = numpy.arange(0, 10, step=0.1)
- my_sin = numpy.sin(x)
- y = numpy.arange(-4, 6, step=0.1) + my_sin
-
- self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
- baseline=list(range(0, 100, 1)))
-
- def testPlotCurveComplexData(self):
- """Test curve with complex data"""
- data = numpy.arange(100.) + 1j
- self.plot.addCurve(x=data, y=data, xerror=data, yerror=data)
-
-
-class TestPlotHistogram(PlotWidgetTestCase):
- """Basic tests for add Histogram"""
- def setUp(self):
- super(TestPlotHistogram, self).setUp()
- self.edges = numpy.arange(0, 10, step=1)
- self.histogram = numpy.random.random(len(self.edges))
-
- def testPlot(self):
- self.plot.addHistogram(histogram=self.histogram,
- edges=self.edges,
- legend='histogram1')
-
- def testPlotBaseline(self):
- self.plot.addHistogram(histogram=self.histogram,
- edges=self.edges,
- legend='histogram1',
- color='blue',
- baseline=-2,
- z=2,
- fill=True)
-
-
-class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
- """Basic tests for addScatter"""
-
- def testScatter(self):
- x = numpy.arange(100)
- y = numpy.arange(100)
- value = numpy.arange(100)
- self.plot.addScatter(x, y, value)
- self.plot.resetZoom()
-
- def testScatterComplexData(self):
- """Test scatter item with complex data"""
- data = numpy.arange(100.) + 1j
- self.plot.addScatter(
- x=data, y=data, value=data, xerror=data, yerror=data)
- self.plot.resetZoom()
-
- def testScatterVisualization(self):
- self.plot.addScatter((0, 1, 0, 1), (0, 0, 2, 2), (0, 1, 2, 3))
- self.plot.resetZoom()
- self.qapp.processEvents()
-
- scatter = self.plot.getItems()[0]
-
- for visualization in ('solid',
- 'points',
- 'regular_grid',
- 'irregular_grid',
- 'binned_statistic',
- scatter.Visualization.SOLID,
- scatter.Visualization.POINTS,
- scatter.Visualization.REGULAR_GRID,
- scatter.Visualization.IRREGULAR_GRID,
- scatter.Visualization.BINNED_STATISTIC):
- with self.subTest(visualization=visualization):
- scatter.setVisualization(visualization)
- self.qapp.processEvents()
-
- def testGridVisualization(self):
- """Test regular and irregular grid mode with different points"""
- points = { # name: (x, y, order)
- 'single point': ((1.,), (1.,), 'row'),
- 'horizontal line': ((0, 1, 2), (0, 0, 0), 'row'),
- 'horizontal line backward': ((2, 1, 0), (0, 0, 0), 'row'),
- 'vertical line': ((0, 0, 0), (0, 1, 2), 'row'),
- 'vertical line backward': ((0, 0, 0), (2, 1, 0), 'row'),
- 'grid fast x, +x +y': ((0, 1, 2, 0, 1, 2), (0, 0, 0, 1, 1, 1), 'row'),
- 'grid fast x, +x -y': ((0, 1, 2, 0, 1, 2), (1, 1, 1, 0, 0, 0), 'row'),
- 'grid fast x, -x -y': ((2, 1, 0, 2, 1, 0), (1, 1, 1, 0, 0, 0), 'row'),
- 'grid fast x, -x +y': ((2, 1, 0, 2, 1, 0), (0, 0, 0, 1, 1, 1), 'row'),
- 'grid fast y, +x +y': ((0, 0, 0, 1, 1, 1), (0, 1, 2, 0, 1, 2), 'column'),
- 'grid fast y, +x -y': ((0, 0, 0, 1, 1, 1), (2, 1, 0, 2, 1, 0), 'column'),
- 'grid fast y, -x -y': ((1, 1, 1, 0, 0, 0), (2, 1, 0, 2, 1, 0), 'column'),
- 'grid fast y, -x +y': ((1, 1, 1, 0, 0, 0), (0, 1, 2, 0, 1, 2), 'column'),
- }
-
- self.plot.addScatter((), (), ())
- scatter = self.plot.getItems()[0]
-
- self.qapp.processEvents()
-
- for visualization in (scatter.Visualization.REGULAR_GRID,
- scatter.Visualization.IRREGULAR_GRID):
- scatter.setVisualization(visualization)
- self.assertIs(scatter.getVisualization(), visualization)
-
- for name, (x, y, ref_order) in points.items():
- with self.subTest(name=name, visualization=visualization.name):
- scatter.setData(x, y, numpy.arange(len(x)))
- self.plot.setGraphTitle(name)
- self.plot.resetZoom()
- self.qapp.processEvents()
-
- order = scatter.getCurrentVisualizationParameter(
- scatter.VisualizationParameter.GRID_MAJOR_ORDER)
- self.assertEqual(ref_order, order)
-
- ref_bounds = (x[0], y[0]), (x[-1], y[-1])
- bounds = scatter.getCurrentVisualizationParameter(
- scatter.VisualizationParameter.GRID_BOUNDS)
- self.assertEqual(ref_bounds, bounds)
-
- shape = scatter.getCurrentVisualizationParameter(
- scatter.VisualizationParameter.GRID_SHAPE)
-
- self.plot.getXAxis().setLimits(numpy.min(x) - 1, numpy.max(x) + 1)
- self.plot.getYAxis().setLimits(numpy.min(y) - 1, numpy.max(y) + 1)
- self.qapp.processEvents()
-
- for index, position in enumerate(zip(x, y)):
- xpixel, ypixel = self.plot.dataToPixel(*position)
- result = scatter.pick(xpixel, ypixel)
- self.assertIsNotNone(result)
- self.assertIs(result.getItem(), scatter)
- self.assertEqual(result.getIndices(), (index,))
-
- def testBinnedStatisticVisualization(self):
- """Test binned display"""
- self.plot.addScatter((), (), ())
- scatter = self.plot.getItems()[0]
- scatter.setVisualization(scatter.Visualization.BINNED_STATISTIC)
- self.assertIs(scatter.getVisualization(),
- scatter.Visualization.BINNED_STATISTIC)
- self.assertEqual(
- scatter.getVisualizationParameter(
- scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION),
- 'mean')
-
- self.qapp.processEvents()
-
- scatter.setData(*numpy.random.random(3000).reshape(3, -1))
-
- for reduction in ('count', 'sum', 'mean'):
- with self.subTest(reduction=reduction):
- scatter.setVisualizationParameter(
- scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
- reduction)
- self.assertEqual(
- scatter.getVisualizationParameter(
- scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION),
- reduction)
-
- self.qapp.processEvents()
-
-
-class TestPlotMarker(PlotWidgetTestCase):
- """Basic tests for add*Marker"""
-
- def setUp(self):
- super(TestPlotMarker, self).setUp()
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
-
- self.plot.getXAxis().setAutoScale(False)
- self.plot.getYAxis().setAutoScale(False)
- self.plot.setKeepDataAspectRatio(False)
- self.plot.setLimits(0., 100., -100., 100.)
-
- def testPlotMarkerX(self):
- self.plot.setGraphTitle('Markers X')
-
- markers = [
- (10., 'blue', False, False),
- (20., 'red', False, False),
- (40., 'green', True, False),
- (60., 'gray', True, True),
- (80., 'black', False, True),
- ]
-
- for x, color, select, drag in markers:
- name = str(x)
- if select:
- name += " sel."
- if drag:
- name += " drag"
- self.plot.addXMarker(x, name, name, color, select, drag)
- self.plot.resetZoom()
-
- def testPlotMarkerY(self):
- self.plot.setGraphTitle('Markers Y')
-
- markers = [
- (-50., 'blue', False, False),
- (-30., 'red', False, False),
- (0., 'green', True, False),
- (10., 'gray', True, True),
- (80., 'black', False, True),
- ]
-
- for y, color, select, drag in markers:
- name = str(y)
- if select:
- name += " sel."
- if drag:
- name += " drag"
- self.plot.addYMarker(y, name, name, color, select, drag)
- self.plot.resetZoom()
-
- def testPlotMarkerPt(self):
- self.plot.setGraphTitle('Markers Pt')
-
- markers = [
- (10., -50., 'blue', False, False),
- (40., -30., 'red', False, False),
- (50., 0., 'green', True, False),
- (50., 20., 'gray', True, True),
- (70., 50., 'black', False, True),
- ]
- for x, y, color, select, drag in markers:
- name = "{0},{1}".format(x, y)
- if select:
- name += " sel."
- if drag:
- name += " drag"
- self.plot.addMarker(x, y, name, name, color, select, drag)
-
- self.plot.resetZoom()
-
- def testPlotMarkerWithoutLegend(self):
- self.plot.setGraphTitle('Markers without legend')
- self.plot.getYAxis().setInverted(True)
-
- # Markers without legend
- self.plot.addMarker(10, 10)
- self.plot.addMarker(10, 20)
- self.plot.addMarker(40, 50, text='test', symbol=None)
- self.plot.addMarker(40, 50, text='test', symbol='+')
- self.plot.addXMarker(25)
- self.plot.addXMarker(35)
- self.plot.addXMarker(45, text='test')
- self.plot.addYMarker(55)
- self.plot.addYMarker(65)
- self.plot.addYMarker(75, text='test')
-
- self.plot.resetZoom()
-
- def testPlotMarkerYAxis(self):
- # Check only the API
-
- legend = self.plot.addMarker(10, 10)
- item = self.plot._getMarker(legend)
- self.assertEqual(item.getYAxis(), "left")
-
- legend = self.plot.addMarker(10, 10, yaxis="right")
- item = self.plot._getMarker(legend)
- self.assertEqual(item.getYAxis(), "right")
-
- legend = self.plot.addMarker(10, 10, yaxis="left")
- item = self.plot._getMarker(legend)
- self.assertEqual(item.getYAxis(), "left")
-
- legend = self.plot.addXMarker(10, yaxis="right")
- item = self.plot._getMarker(legend)
- self.assertEqual(item.getYAxis(), "right")
-
- legend = self.plot.addXMarker(10, yaxis="left")
- item = self.plot._getMarker(legend)
- self.assertEqual(item.getYAxis(), "left")
-
- legend = self.plot.addYMarker(10, yaxis="right")
- item = self.plot._getMarker(legend)
- self.assertEqual(item.getYAxis(), "right")
-
- legend = self.plot.addYMarker(10, yaxis="left")
- item = self.plot._getMarker(legend)
- self.assertEqual(item.getYAxis(), "left")
-
- self.plot.resetZoom()
-
-
-# TestPlotItem ################################################################
-
-class TestPlotItem(PlotWidgetTestCase):
- """Basic tests for addItem."""
-
- # Polygon coordinates and color
- POLYGONS = [ # legend, x coords, y coords, color
- ('triangle', numpy.array((10, 30, 50)),
- numpy.array((55, 70, 55)), 'red'),
- ('square', numpy.array((10, 10, 50, 50)),
- numpy.array((10, 50, 50, 10)), 'green'),
- ('star', numpy.array((60, 70, 80, 60, 80)),
- numpy.array((25, 50, 25, 40, 40)), 'blue'),
- ('2 triangles-simple',
- numpy.array((90., 95., 100., numpy.nan, 90., 95., 100.)),
- numpy.array((25., 5., 25., numpy.nan, 30., 50., 30.)),
- 'pink'),
- ('2 triangles-extra NaN',
- numpy.array((numpy.nan, 90., 95., 100., numpy.nan, 0., 90., 95., 100., numpy.nan)),
- numpy.array((0., 55., 70., 55., numpy.nan, numpy.nan, 75., 90., 75., numpy.nan)),
- 'black'),
- ]
-
- # Rectangle coordinantes and color
- RECTANGLES = [ # legend, x coords, y coords, color
- ('square 1', numpy.array((1., 10.)),
- numpy.array((1., 10.)), 'red'),
- ('square 2', numpy.array((10., 20.)),
- numpy.array((10., 20.)), 'green'),
- ('square 3', numpy.array((20., 30.)),
- numpy.array((20., 30.)), 'blue'),
- ('rect 1', numpy.array((1., 30.)),
- numpy.array((35., 40.)), 'black'),
- ('line h', numpy.array((1., 30.)),
- numpy.array((45., 45.)), 'darkRed'),
- ]
-
- SCALES = Axis.LINEAR, Axis.LOGARITHMIC
-
- def setUp(self):
- super(TestPlotItem, self).setUp()
-
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
- self.plot.getXAxis().setAutoScale(False)
- self.plot.getYAxis().setAutoScale(False)
- self.plot.setKeepDataAspectRatio(False)
- self.plot.setLimits(0., 100., -100., 100.)
-
- def testPlotItemPolygonFill(self):
- for scale in self.SCALES:
- with self.subTest(scale=scale):
- self.plot.clear()
- self.plot.getXAxis().setScale(scale)
- self.plot.getYAxis().setScale(scale)
- self.plot.setGraphTitle('Item Fill %s' % scale)
-
- for legend, xList, yList, color in self.POLYGONS:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False, linestyle='--',
- shape="polygon", fill=True, color=color)
- self.plot.resetZoom()
-
- def testPlotItemPolygonNoFill(self):
- for scale in self.SCALES:
- with self.subTest(scale=scale):
- self.plot.clear()
- self.plot.getXAxis().setScale(scale)
- self.plot.getYAxis().setScale(scale)
- self.plot.setGraphTitle('Item No Fill %s' % scale)
-
- for legend, xList, yList, color in self.POLYGONS:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False, linestyle='--',
- shape="polygon", fill=False, color=color)
- self.plot.resetZoom()
-
- def testPlotItemRectangleFill(self):
- for scale in self.SCALES:
- with self.subTest(scale=scale):
- self.plot.clear()
- self.plot.getXAxis().setScale(scale)
- self.plot.getYAxis().setScale(scale)
- self.plot.setGraphTitle('Rectangle Fill %s' % scale)
-
- for legend, xList, yList, color in self.RECTANGLES:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="rectangle", fill=True, color=color)
- self.plot.resetZoom()
-
- def testPlotItemRectangleNoFill(self):
- for scale in self.SCALES:
- with self.subTest(scale=scale):
- self.plot.clear()
- self.plot.getXAxis().setScale(scale)
- self.plot.getYAxis().setScale(scale)
- self.plot.setGraphTitle('Rectangle No Fill %s' % scale)
-
- for legend, xList, yList, color in self.RECTANGLES:
- self.plot.addShape(xList, yList, legend=legend,
- replace=False,
- shape="rectangle", fill=False, color=color)
- self.plot.resetZoom()
-
-
-class TestPlotActiveCurveImage(PlotWidgetTestCase):
- """Basic tests for active curve and image handling"""
- xData = numpy.arange(1000)
- yData = -500 + 100 * numpy.sin(xData)
- xData2 = xData + 1000
- yData2 = xData - 1000 + 200 * numpy.random.random(1000)
-
- def tearDown(self):
- self.plot.setActiveCurveHandling(False)
- super(TestPlotActiveCurveImage, self).tearDown()
-
- def testActiveCurveAndLabels(self):
- # Active curve handling off, no label change
- self.plot.setActiveCurveHandling(False)
- self.plot.getXAxis().setLabel('XLabel')
- self.plot.getYAxis().setLabel('YLabel')
- self.plot.addCurve((1, 2), (1, 2))
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- self.plot.addCurve((1, 2), (2, 3), xlabel='x1', ylabel='y1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- self.plot.clear()
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- # Active curve handling on, label changes
- self.plot.setActiveCurveHandling(True)
- self.plot.getXAxis().setLabel('XLabel')
- self.plot.getYAxis().setLabel('YLabel')
-
- # labels changed as active curve
- self.plot.addCurve((1, 2), (1, 2), legend='1',
- xlabel='x1', ylabel='y1')
- self.plot.setActiveCurve('1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- # labels not changed as not active curve
- self.plot.addCurve((1, 2), (2, 3), legend='2')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- # labels changed
- self.plot.setActiveCurve('2')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- self.plot.setActiveCurve('1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- self.plot.clear()
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- def testPlotActiveCurveSelectionMode(self):
- self.plot.clear()
- self.plot.setActiveCurveHandling(True)
- legend = "curve 1"
- self.plot.addCurve(self.xData, self.yData,
- legend=legend,
- color="green")
-
- # active curve should be None
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
-
- # active curve should be None when None is set as active curve
- self.plot.setActiveCurve(legend)
- current = self.plot.getActiveCurve(just_legend=True)
- self.assertEqual(current, legend)
- self.plot.setActiveCurve(None)
- current = self.plot.getActiveCurve(just_legend=True)
- self.assertEqual(current, None)
-
- # testing it automatically toggles if there is only one
- self.plot.setActiveCurveSelectionMode("legacy")
- current = self.plot.getActiveCurve(just_legend=True)
- self.assertEqual(current, legend)
-
- # active curve should not change when None set as active curve
- self.assertEqual(self.plot.getActiveCurveSelectionMode(), "legacy")
- self.plot.setActiveCurve(None)
- current = self.plot.getActiveCurve(just_legend=True)
- self.assertEqual(current, legend)
-
- # situation where no curve is active
- self.plot.clear()
- self.plot.setActiveCurveHandling(True)
- self.assertEqual(self.plot.getActiveCurveSelectionMode(), "atmostone")
- self.plot.addCurve(self.xData, self.yData,
- legend=legend,
- color="green")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
- self.plot.addCurve(self.xData2, self.yData2,
- legend="curve 2",
- color="red")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
- self.plot.setActiveCurveSelectionMode("legacy")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
-
- # the first curve added should be active
- self.plot.clear()
- self.plot.addCurve(self.xData, self.yData,
- legend=legend,
- color="green")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend)
- self.plot.addCurve(self.xData2, self.yData2,
- legend="curve 2",
- color="red")
- self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend)
-
- def testActiveCurveStyle(self):
- """Test change of active curve style"""
- self.plot.setActiveCurveHandling(True)
- self.plot.setActiveCurveStyle(color='black')
- style = self.plot.getActiveCurveStyle()
- self.assertEqual(style.getColor(), (0., 0., 0., 1.))
- self.assertIsNone(style.getLineStyle())
- self.assertIsNone(style.getLineWidth())
- self.assertIsNone(style.getSymbol())
- self.assertIsNone(style.getSymbolSize())
-
- self.plot.addCurve(x=self.xData, y=self.yData, legend="curve1")
- curve = self.plot.getCurve("curve1")
- curve.setColor('blue')
- curve.setLineStyle('-')
- curve.setLineWidth(1)
- curve.setSymbol('o')
- curve.setSymbolSize(5)
-
- # Check default current style
- defaultStyle = curve.getCurrentStyle()
- self.assertEqual(defaultStyle, CurveStyle(color='blue',
- linestyle='-',
- linewidth=1,
- symbol='o',
- symbolsize=5))
-
- # Activate curve with highlight color=black
- self.plot.setActiveCurve("curve1")
- style = curve.getCurrentStyle()
- self.assertEqual(style.getColor(), (0., 0., 0., 1.))
- self.assertEqual(style.getLineStyle(), '-')
- self.assertEqual(style.getLineWidth(), 1)
- self.assertEqual(style.getSymbol(), 'o')
- self.assertEqual(style.getSymbolSize(), 5)
-
- # Change highlight to linewidth=2
- self.plot.setActiveCurveStyle(linewidth=2)
- style = curve.getCurrentStyle()
- self.assertEqual(style.getColor(), (0., 0., 1., 1.))
- self.assertEqual(style.getLineStyle(), '-')
- self.assertEqual(style.getLineWidth(), 2)
- self.assertEqual(style.getSymbol(), 'o')
- self.assertEqual(style.getSymbolSize(), 5)
-
- self.plot.setActiveCurve(None)
- self.assertEqual(curve.getCurrentStyle(), defaultStyle)
-
- def testActiveImageAndLabels(self):
- # Active image handling always on, no API for toggling it
- self.plot.getXAxis().setLabel('XLabel')
- self.plot.getYAxis().setLabel('YLabel')
-
- # labels changed as active curve
- self.plot.addImage(numpy.arange(100).reshape(10, 10),
- legend='1', xlabel='x1', ylabel='y1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- # labels not changed as not active curve
- self.plot.addImage(numpy.arange(100).reshape(10, 10),
- legend='2')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- # labels changed
- self.plot.setActiveImage('2')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
- self.plot.setActiveImage('1')
- self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
-
- self.plot.clear()
- self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
- self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
-
-
-##############################################################################
-# Log
-##############################################################################
-
-class TestPlotEmptyLog(PlotWidgetTestCase):
- """Basic tests for log plot"""
- def testEmptyPlotTitleLabelsLog(self):
- self.plot.setGraphTitle('Empty Log Log')
- self.plot.getXAxis().setLabel('X')
- self.plot.getYAxis().setLabel('Y')
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.getYAxis()._setLogarithmic(True)
- self.plot.resetZoom()
-
-
-class TestPlotAxes(TestCaseQt, ParametricTestCase):
-
- # Test data
- xData = numpy.arange(1, 10)
- yData = xData ** 2
-
- def __init__(self, methodName='runTest', backend=None):
- unittest.TestCase.__init__(self, methodName)
- self.__backend = backend
-
- def setUp(self):
- super(TestPlotAxes, self).setUp()
- self.plot = PlotWidget(backend=self.__backend)
- # It is not needed to display the plot
- # It saves a lot of time
- # self.plot.show()
- # self.qWaitForWindowExposed(self.plot)
-
- def tearDown(self):
- self.qapp.processEvents()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- super(TestPlotAxes, self).tearDown()
-
- def testDefaultAxes(self):
- axis = self.plot.getXAxis()
- self.assertEqual(axis.getScale(), axis.LINEAR)
- axis = self.plot.getYAxis()
- self.assertEqual(axis.getScale(), axis.LINEAR)
- axis = self.plot.getYAxis(axis="right")
- self.assertEqual(axis.getScale(), axis.LINEAR)
-
- def testOldPlotAxis_getterSetter(self):
- """Test silx API prior to silx 0.6"""
- x = self.plot.getXAxis()
- y = self.plot.getYAxis()
- p = self.plot
-
- tests = [
- # setters
- (p.setGraphXLimits, (10, 20), x.getLimits, (10, 20)),
- (p.setGraphYLimits, (10, 20), y.getLimits, (10, 20)),
- (p.setGraphXLabel, "foox", x.getLabel, "foox"),
- (p.setGraphYLabel, "fooy", y.getLabel, "fooy"),
- (p.setYAxisInverted, True, y.isInverted, True),
- (p.setXAxisLogarithmic, True, x.getScale, x.LOGARITHMIC),
- (p.setYAxisLogarithmic, True, y.getScale, y.LOGARITHMIC),
- (p.setXAxisAutoScale, False, x.isAutoScale, False),
- (p.setYAxisAutoScale, False, y.isAutoScale, False),
- # getters
- (x.setLimits, (11, 20), p.getGraphXLimits, (11, 20)),
- (y.setLimits, (11, 20), p.getGraphYLimits, (11, 20)),
- (x.setLabel, "fooxx", p.getGraphXLabel, "fooxx"),
- (y.setLabel, "fooyy", p.getGraphYLabel, "fooyy"),
- (y.setInverted, False, p.isYAxisInverted, False),
- (x.setScale, x.LINEAR, p.isXAxisLogarithmic, False),
- (y.setScale, y.LINEAR, p.isYAxisLogarithmic, False),
- (x.setAutoScale, True, p.isXAxisAutoScale, True),
- (y.setAutoScale, True, p.isYAxisAutoScale, True),
- ]
- for testCase in tests:
- setter, value, getter, expected = testCase
- with self.subTest():
- if setter is not None:
- if not isinstance(value, tuple):
- value = (value, )
- setter(*value)
- if getter is not None:
- self.assertEqual(getter(), expected)
-
- def testOldPlotAxis_Logarithmic(self):
- """Test silx API prior to silx 0.6"""
- x = self.plot.getXAxis()
- y = self.plot.getYAxis()
- yright = self.plot.getYAxis(axis="right")
-
- self.assertEqual(x.getScale(), x.LINEAR)
- self.assertEqual(y.getScale(), x.LINEAR)
- self.assertEqual(yright.getScale(), x.LINEAR)
-
- self.plot.setXAxisLogarithmic(True)
- self.assertEqual(x.getScale(), x.LOGARITHMIC)
- self.assertEqual(y.getScale(), x.LINEAR)
- self.assertEqual(yright.getScale(), x.LINEAR)
- self.assertEqual(self.plot.isXAxisLogarithmic(), True)
- self.assertEqual(self.plot.isYAxisLogarithmic(), False)
-
- self.plot.setYAxisLogarithmic(True)
- self.assertEqual(x.getScale(), x.LOGARITHMIC)
- self.assertEqual(y.getScale(), x.LOGARITHMIC)
- self.assertEqual(yright.getScale(), x.LOGARITHMIC)
- self.assertEqual(self.plot.isXAxisLogarithmic(), True)
- self.assertEqual(self.plot.isYAxisLogarithmic(), True)
-
- yright.setScale(yright.LINEAR)
- self.assertEqual(x.getScale(), x.LOGARITHMIC)
- self.assertEqual(y.getScale(), x.LINEAR)
- self.assertEqual(yright.getScale(), x.LINEAR)
- self.assertEqual(self.plot.isXAxisLogarithmic(), True)
- self.assertEqual(self.plot.isYAxisLogarithmic(), False)
-
- def testOldPlotAxis_AutoScale(self):
- """Test silx API prior to silx 0.6"""
- x = self.plot.getXAxis()
- y = self.plot.getYAxis()
- yright = self.plot.getYAxis(axis="right")
-
- self.assertEqual(x.isAutoScale(), True)
- self.assertEqual(y.isAutoScale(), True)
- self.assertEqual(yright.isAutoScale(), True)
-
- self.plot.setXAxisAutoScale(False)
- self.assertEqual(x.isAutoScale(), False)
- self.assertEqual(y.isAutoScale(), True)
- self.assertEqual(yright.isAutoScale(), True)
- self.assertEqual(self.plot.isXAxisAutoScale(), False)
- self.assertEqual(self.plot.isYAxisAutoScale(), True)
-
- self.plot.setYAxisAutoScale(False)
- self.assertEqual(x.isAutoScale(), False)
- self.assertEqual(y.isAutoScale(), False)
- self.assertEqual(yright.isAutoScale(), False)
- self.assertEqual(self.plot.isXAxisAutoScale(), False)
- self.assertEqual(self.plot.isYAxisAutoScale(), False)
-
- yright.setAutoScale(True)
- self.assertEqual(x.isAutoScale(), False)
- self.assertEqual(y.isAutoScale(), True)
- self.assertEqual(yright.isAutoScale(), True)
- self.assertEqual(self.plot.isXAxisAutoScale(), False)
- self.assertEqual(self.plot.isYAxisAutoScale(), True)
-
- def testOldPlotAxis_Inverted(self):
- """Test silx API prior to silx 0.6"""
- x = self.plot.getXAxis()
- y = self.plot.getYAxis()
- yright = self.plot.getYAxis(axis="right")
-
- self.assertEqual(x.isInverted(), False)
- self.assertEqual(y.isInverted(), False)
- self.assertEqual(yright.isInverted(), False)
-
- self.plot.setYAxisInverted(True)
- self.assertEqual(x.isInverted(), False)
- self.assertEqual(y.isInverted(), True)
- self.assertEqual(yright.isInverted(), True)
- self.assertEqual(self.plot.isYAxisInverted(), True)
-
- yright.setInverted(False)
- self.assertEqual(x.isInverted(), False)
- self.assertEqual(y.isInverted(), False)
- self.assertEqual(yright.isInverted(), False)
- self.assertEqual(self.plot.isYAxisInverted(), False)
-
- def testLogXWithData(self):
- self.plot.setGraphTitle('Curve X: Log Y: Linear')
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
- axis = self.plot.getXAxis()
- axis.setScale(axis.LOGARITHMIC)
-
- self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
-
- def testLogYWithData(self):
- self.plot.setGraphTitle('Curve X: Linear Y: Log')
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
- axis = self.plot.getYAxis()
- axis.setScale(axis.LOGARITHMIC)
-
- self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
- axis = self.plot.getYAxis(axis="right")
- self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
-
- def testLogYRightWithData(self):
- self.plot.setGraphTitle('Curve X: Linear Y: Log')
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
- axis = self.plot.getYAxis(axis="right")
- axis.setScale(axis.LOGARITHMIC)
-
- self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
- axis = self.plot.getYAxis()
- self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
-
- def testLimitsChanged_setLimits(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
- listener = SignalListener()
- self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
- self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
- self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2"))
- self.plot.setLimits(0, 1, 0, 1, 0, 1)
- # at least one event per axis
- self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
-
- def testLimitsChanged_resetZoom(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
- listener = SignalListener()
- self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
- self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
- self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2"))
- self.plot.resetZoom()
- # at least one event per axis
- self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
-
- def testLimitsChanged_setXLimit(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
- listener = SignalListener()
- axis = self.plot.getXAxis()
- axis.sigLimitsChanged.connect(listener)
- axis.setLimits(20, 30)
- # at least one event per axis
- self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
- self.assertEqual(axis.getLimits(), (20.0, 30.0))
-
- def testLimitsChanged_setYLimit(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
- listener = SignalListener()
- axis = self.plot.getYAxis()
- axis.sigLimitsChanged.connect(listener)
- axis.setLimits(20, 30)
- # at least one event per axis
- self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
- self.assertEqual(axis.getLimits(), (20.0, 30.0))
-
- def testLimitsChanged_setYRightLimit(self):
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=False,
- color='green', linestyle="-", symbol='o')
- listener = SignalListener()
- axis = self.plot.getYAxis(axis="right")
- axis.sigLimitsChanged.connect(listener)
- axis.setLimits(20, 30)
- # at least one event per axis
- self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
- self.assertEqual(axis.getLimits(), (20.0, 30.0))
-
- def testScaleProxy(self):
- listener = SignalListener()
- y = self.plot.getYAxis()
- yright = self.plot.getYAxis(axis="right")
- y.sigScaleChanged.connect(listener.partial("left"))
- yright.sigScaleChanged.connect(listener.partial("right"))
- yright.setScale(yright.LOGARITHMIC)
-
- self.assertEqual(y.getScale(), y.LOGARITHMIC)
- events = listener.arguments()
- self.assertEqual(len(events), 2)
- self.assertIn(("left", y.LOGARITHMIC), events)
- self.assertIn(("right", y.LOGARITHMIC), events)
-
- def testAutoScaleProxy(self):
- listener = SignalListener()
- y = self.plot.getYAxis()
- yright = self.plot.getYAxis(axis="right")
- y.sigAutoScaleChanged.connect(listener.partial("left"))
- yright.sigAutoScaleChanged.connect(listener.partial("right"))
- yright.setAutoScale(False)
-
- self.assertEqual(y.isAutoScale(), False)
- events = listener.arguments()
- self.assertEqual(len(events), 2)
- self.assertIn(("left", False), events)
- self.assertIn(("right", False), events)
-
- def testInvertedProxy(self):
- listener = SignalListener()
- y = self.plot.getYAxis()
- yright = self.plot.getYAxis(axis="right")
- y.sigInvertedChanged.connect(listener.partial("left"))
- yright.sigInvertedChanged.connect(listener.partial("right"))
- yright.setInverted(True)
-
- self.assertEqual(y.isInverted(), True)
- events = listener.arguments()
- self.assertEqual(len(events), 2)
- self.assertIn(("left", True), events)
- self.assertIn(("right", True), events)
-
- def testAxesDisplayedFalse(self):
- """Test coverage on setAxesDisplayed(False)"""
- self.plot.setAxesDisplayed(False)
-
- def testAxesDisplayedTrue(self):
- """Test coverage on setAxesDisplayed(True)"""
- self.plot.setAxesDisplayed(True)
-
- def testAxesMargins(self):
- """Test PlotWidget's getAxesMargins and setAxesMargins"""
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- margins = self.plot.getAxesMargins()
- self.assertEqual(margins, (.15, .1, .1, .15))
-
- for margins in ((0., 0., 0., 0.), (.15, .1, .1, .15)):
- with self.subTest(margins=margins):
- self.plot.setAxesMargins(*margins)
- self.qapp.processEvents()
- self.assertEqual(self.plot.getAxesMargins(), margins)
-
- def testBoundingRectItem(self):
- item = BoundingRect()
- item.setBounds((-1000, 1000, -2000, 2000))
- self.plot.addItem(item)
- self.plot.resetZoom()
- limits = numpy.array(self.plot.getXAxis().getLimits())
- numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000]))
- limits = numpy.array(self.plot.getYAxis().getLimits())
- numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000]))
-
- def testBoundingRectRightItem(self):
- item = BoundingRect()
- item.setYAxis("right")
- item.setBounds((-1000, 1000, -2000, 2000))
- self.plot.addItem(item)
- self.plot.resetZoom()
- limits = numpy.array(self.plot.getXAxis().getLimits())
- numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000]))
- limits = numpy.array(self.plot.getYAxis("right").getLimits())
- numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000]))
-
- def testBoundingRectArguments(self):
- item = BoundingRect()
- with self.assertRaises(Exception):
- item.setBounds((1000, -1000, -2000, 2000))
- with self.assertRaises(Exception):
- item.setBounds((-1000, 1000, 2000, -2000))
-
- def testBoundingRectWithLog(self):
- item = BoundingRect()
- self.plot.addItem(item)
-
- item.setBounds((-1000, 1000, -2000, 2000))
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.getYAxis()._setLogarithmic(False)
- self.assertEqual(item.getBounds(), (1000, 1000, -2000, 2000))
-
- item.setBounds((-1000, 1000, -2000, 2000))
- self.plot.getXAxis()._setLogarithmic(False)
- self.plot.getYAxis()._setLogarithmic(True)
- self.assertEqual(item.getBounds(), (-1000, 1000, 2000, 2000))
-
- item.setBounds((-1000, 0, -2000, 2000))
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.getYAxis()._setLogarithmic(False)
- self.assertIsNone(item.getBounds())
-
- def testAxisExtent(self):
- """Test XAxisExtent and yAxisExtent"""
- for cls, axis in ((XAxisExtent, self.plot.getXAxis()),
- (YAxisExtent, self.plot.getYAxis())):
- for range_, logRange in (((2, 3), (2, 3)),
- ((-2, -1), (1, 100)),
- ((-1, 3), (3. * 0.9, 3. * 1.1))):
- extent = cls()
- extent.setRange(*range_)
- self.plot.addItem(extent)
-
- for isLog, plotRange in ((False, range_), (True, logRange)):
- with self.subTest(
- cls=cls.__name__, range=range_, isLog=isLog):
- axis._setLogarithmic(isLog)
- self.plot.resetZoom()
- self.qapp.processEvents()
- self.assertEqual(axis.getLimits(), plotRange)
-
- axis._setLogarithmic(False)
- self.plot.clear()
-
-
-class TestPlotCurveLog(PlotWidgetTestCase, ParametricTestCase):
- """Basic tests for addCurve with log scale axes"""
-
- # Test data
- xData = numpy.arange(1000) + 1
- yData = xData ** 2
-
- def _setLabels(self):
- self.plot.getXAxis().setLabel('X')
- self.plot.getYAxis().setLabel('X * X')
-
- def testPlotCurveLogX(self):
- self._setLabels()
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('Curve X: Log Y: Linear')
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
-
- def testPlotCurveLogY(self):
- self._setLabels()
- self.plot.getYAxis()._setLogarithmic(True)
-
- self.plot.setGraphTitle('Curve X: Linear Y: Log')
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
-
- def testPlotCurveLogXY(self):
- self._setLabels()
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.getYAxis()._setLogarithmic(True)
-
- self.plot.setGraphTitle('Curve X: Log Y: Log')
-
- self.plot.addCurve(self.xData, self.yData,
- legend="curve",
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
-
- def testPlotCurveErrorLogXY(self):
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.getYAxis()._setLogarithmic(True)
-
- # Every second error leads to negative number
- errors = numpy.ones_like(self.xData)
- errors[::2] = self.xData[::2] + 1
-
- tests = [ # name, xerror, yerror
- ('xerror=3', 3, None),
- ('xerror=N array', errors, None),
- ('xerror=Nx1 array', errors.reshape(len(errors), 1), None),
- ('xerror=2xN array', numpy.array((errors, errors)), None),
- ('yerror=6', None, 6),
- ('yerror=N array', None, errors ** 2),
- ('yerror=Nx1 array', None, (errors ** 2).reshape(len(errors), 1)),
- ('yerror=2xN array', None, numpy.array((errors, errors)) ** 2),
- ]
-
- for name, xError, yError in tests:
- with self.subTest(name):
- self.plot.setGraphTitle(name)
- self.plot.addCurve(self.xData, self.yData,
- legend=name,
- xerror=xError, yerror=yError,
- replace=False, resetzoom=True,
- color='green', linestyle="-", symbol='o')
-
- self.qapp.processEvents()
-
- self.plot.clear()
- self.plot.resetZoom()
- self.qapp.processEvents()
-
- def testPlotCurveToggleLog(self):
- """Add a curve with negative data and toggle log axis"""
- arange = numpy.arange(1000) + 1
- tests = [ # name, xData, yData
- ('x>0, some negative y', arange, arange - 500),
- ('x>0, y<0', arange, -arange),
- ('some negative x, y>0', arange - 500, arange),
- ('x<0, y>0', -arange, arange),
- ('some negative x and y', arange - 500, arange - 500),
- ('x<0, y<0', -arange, -arange),
- ]
-
- for name, xData, yData in tests:
- with self.subTest(name):
- self.plot.addCurve(xData, yData, resetzoom=True)
- self.qapp.processEvents()
-
- # no log axis
- xLim = self.plot.getXAxis().getLimits()
- self.assertEqual(xLim, (min(xData), max(xData)))
- yLim = self.plot.getYAxis().getLimits()
- self.assertEqual(yLim, (min(yData), max(yData)))
-
- # x axis log
- self.plot.getXAxis()._setLogarithmic(True)
- self.qapp.processEvents()
-
- xLim = self.plot.getXAxis().getLimits()
- yLim = self.plot.getYAxis().getLimits()
- positives = xData > 0
- if numpy.any(positives):
- self.assertTrue(numpy.allclose(
- xLim, (min(xData[positives]), max(xData[positives]))))
- self.assertEqual(
- yLim, (min(yData[positives]), max(yData[positives])))
- else: # No positive x in the curve
- self.assertEqual(xLim, (1., 100.))
- self.assertEqual(yLim, (1., 100.))
-
- # x axis and y axis log
- self.plot.getYAxis()._setLogarithmic(True)
- self.qapp.processEvents()
-
- xLim = self.plot.getXAxis().getLimits()
- yLim = self.plot.getYAxis().getLimits()
- positives = numpy.logical_and(xData > 0, yData > 0)
- if numpy.any(positives):
- self.assertTrue(numpy.allclose(
- xLim, (min(xData[positives]), max(xData[positives]))))
- self.assertTrue(numpy.allclose(
- yLim, (min(yData[positives]), max(yData[positives]))))
- else: # No positive x and y in the curve
- self.assertEqual(xLim, (1., 100.))
- self.assertEqual(yLim, (1., 100.))
-
- # y axis log
- self.plot.getXAxis()._setLogarithmic(False)
- self.qapp.processEvents()
-
- xLim = self.plot.getXAxis().getLimits()
- yLim = self.plot.getYAxis().getLimits()
- positives = yData > 0
- if numpy.any(positives):
- self.assertEqual(
- xLim, (min(xData[positives]), max(xData[positives])))
- self.assertTrue(numpy.allclose(
- yLim, (min(yData[positives]), max(yData[positives]))))
- else: # No positive y in the curve
- self.assertEqual(xLim, (1., 100.))
- self.assertEqual(yLim, (1., 100.))
-
- # no log axis
- self.plot.getYAxis()._setLogarithmic(False)
- self.qapp.processEvents()
-
- xLim = self.plot.getXAxis().getLimits()
- self.assertEqual(xLim, (min(xData), max(xData)))
- yLim = self.plot.getYAxis().getLimits()
- self.assertEqual(yLim, (min(yData), max(yData)))
-
- self.plot.clear()
- self.plot.resetZoom()
- self.qapp.processEvents()
-
-
-class TestPlotImageLog(PlotWidgetTestCase):
- """Basic tests for addImage with log scale axes."""
-
- def setUp(self):
- super(TestPlotImageLog, self).setUp()
-
- self.plot.getXAxis().setLabel('Columns')
- self.plot.getYAxis().setLabel('Rows')
-
- def testPlotColormapGrayLogX(self):
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('CMap X: Log Y: Linear')
-
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=None,
- vmax=None)
- self.plot.addImage(DATA_2D, legend="image 1",
- origin=(1., 1.), scale=(1., 1.),
- resetzoom=False, colormap=colormap)
- self.plot.resetZoom()
-
- def testPlotColormapGrayLogY(self):
- self.plot.getYAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('CMap X: Linear Y: Log')
-
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=None,
- vmax=None)
- self.plot.addImage(DATA_2D, legend="image 1",
- origin=(1., 1.), scale=(1., 1.),
- resetzoom=False, colormap=colormap)
- self.plot.resetZoom()
-
- def testPlotColormapGrayLogXY(self):
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.getYAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('CMap X: Log Y: Log')
-
- colormap = Colormap(name='gray',
- normalization='linear',
- vmin=None,
- vmax=None)
- self.plot.addImage(DATA_2D, legend="image 1",
- origin=(1., 1.), scale=(1., 1.),
- resetzoom=False, colormap=colormap)
- self.plot.resetZoom()
-
- def testPlotRgbRgbaLogXY(self):
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.getYAxis()._setLogarithmic(True)
- self.plot.setGraphTitle('RGB + RGBA X: Log Y: Log')
-
- rgb = numpy.array(
- (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
- ((0, 128, 0), (0, 128, 128), (0, 128, 256))),
- dtype=numpy.uint8)
-
- self.plot.addImage(rgb, legend="rgb",
- origin=(1, 1), scale=(10, 10),
- resetzoom=False)
-
- rgba = numpy.array(
- (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
- ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
- dtype=numpy.float32)
-
- self.plot.addImage(rgba, legend="rgba",
- origin=(5., 5.), scale=(10., 10.),
- resetzoom=False)
- self.plot.resetZoom()
-
-
-class TestPlotMarkerLog(PlotWidgetTestCase):
- """Basic tests for markers on log scales"""
-
- # Test marker parameters
- markers = [ # x, y, color, selectable, draggable
- (10., 10., 'blue', False, False),
- (20., 20., 'red', False, False),
- (40., 100., 'green', True, False),
- (40., 500., 'gray', True, True),
- (60., 800., 'black', False, True),
- ]
-
- def setUp(self):
- super(TestPlotMarkerLog, self).setUp()
-
- self.plot.getYAxis().setLabel('Rows')
- self.plot.getXAxis().setLabel('Columns')
- self.plot.getXAxis().setAutoScale(False)
- self.plot.getYAxis().setAutoScale(False)
- self.plot.setKeepDataAspectRatio(False)
- self.plot.setLimits(1., 100., 1., 1000.)
- self.plot.getXAxis()._setLogarithmic(True)
- self.plot.getYAxis()._setLogarithmic(True)
-
- def testPlotMarkerXLog(self):
- self.plot.setGraphTitle('Markers X, Log axes')
-
- for x, _, color, select, drag in self.markers:
- name = str(x)
- if select:
- name += " sel."
- if drag:
- name += " drag"
- self.plot.addXMarker(x, name, name, color, select, drag)
- self.plot.resetZoom()
-
- def testPlotMarkerYLog(self):
- self.plot.setGraphTitle('Markers Y, Log axes')
-
- for _, y, color, select, drag in self.markers:
- name = str(y)
- if select:
- name += " sel."
- if drag:
- name += " drag"
- self.plot.addYMarker(y, name, name, color, select, drag)
- self.plot.resetZoom()
-
- def testPlotMarkerPtLog(self):
- self.plot.setGraphTitle('Markers Pt, Log axes')
-
- for x, y, color, select, drag in self.markers:
- name = "{0},{1}".format(x, y)
- if select:
- name += " sel."
- if drag:
- name += " drag"
- self.plot.addMarker(x, y, name, name, color, select, drag)
- self.plot.resetZoom()
-
-
-class TestPlotWidgetSwitchBackend(PlotWidgetTestCase):
- """Test [get|set]Backend to switch backend"""
-
- def testSwitchBackend(self):
- """Test switching a plot with a few items"""
- backends = {'none': 'BackendBase', 'mpl': 'BackendMatplotlibQt'}
- if test_options.WITH_GL_TEST:
- backends['gl'] = 'BackendOpenGL'
-
- self.plot.addImage(numpy.arange(100).reshape(10, 10))
- self.plot.addCurve((-3, -2, -1), (1, 2, 3))
- self.plot.resetZoom()
- xlimits = self.plot.getXAxis().getLimits()
- ylimits = self.plot.getYAxis().getLimits()
- items = self.plot.getItems()
- self.assertEqual(len(items), 2)
-
- for backend, className in backends.items():
- with self.subTest(backend=backend):
- self.plot.setBackend(backend)
- self.plot.replot()
-
- retrievedBackend = self.plot.getBackend()
- self.assertEqual(type(retrievedBackend).__name__, className)
- self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
- self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
- self.assertEqual(self.plot.getItems(), items)
-
-
-class TestPlotWidgetSelection(PlotWidgetTestCase):
- """Test PlotWidget.selection and active items handling"""
-
- def _checkSelection(self, selection, current=None, selected=()):
- """Check current item and selected items."""
- self.assertIs(selection.getCurrentItem(), current)
- self.assertEqual(selection.getSelectedItems(), selected)
-
- def testSyncWithActiveItems(self):
- """Test update of PlotWidgetSelection according to active items"""
- listener = SignalListener()
-
- selection = self.plot.selection()
- selection.sigCurrentItemChanged.connect(listener)
- self._checkSelection(selection)
-
- # Active item is current
- self.plot.addImage(((0, 1), (2, 3)), legend='image')
- image = self.plot.getActiveImage()
- self.assertEqual(listener.callCount(), 1)
- self._checkSelection(selection, image, (image,))
-
- # No active = no current
- self.plot.setActiveImage(None)
- self.assertEqual(listener.callCount(), 2)
- self._checkSelection(selection)
-
- # Active item is current
- self.plot.setActiveImage('image')
- self.assertEqual(listener.callCount(), 3)
- self._checkSelection(selection, image, (image,))
-
- # Mosted recently "actived" item is current
- self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
- scatter = self.plot.getActiveScatter()
- self.assertEqual(listener.callCount(), 4)
- self._checkSelection(selection, scatter, (scatter, image))
-
- # Previously mosted recently "actived" item is current
- self.plot.setActiveScatter(None)
- self.assertEqual(listener.callCount(), 5)
- self._checkSelection(selection, image, (image,))
-
- # Mosted recently "actived" item is current
- self.plot.setActiveScatter('scatter')
- self.assertEqual(listener.callCount(), 6)
- self._checkSelection(selection, scatter, (scatter, image))
-
- # No active = no current
- self.plot.setActiveImage(None)
- self.plot.setActiveScatter(None)
- self.assertEqual(listener.callCount(), 7)
- self._checkSelection(selection)
-
- # Mosted recently "actived" item is current
- self.plot.setActiveScatter('scatter')
- self.assertEqual(listener.callCount(), 8)
- self.plot.setActiveImage('image')
- self.assertEqual(listener.callCount(), 9)
- self._checkSelection(selection, image, (image, scatter))
-
- # Add a curve which is not active by default
- self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
- curve = self.plot.getCurve('curve')
- self.assertEqual(listener.callCount(), 9)
- self._checkSelection(selection, image, (image, scatter))
-
- # Mosted recently "actived" item is current
- self.plot.setActiveCurve('curve')
- self.assertEqual(listener.callCount(), 10)
- self._checkSelection(selection, curve, (curve, image, scatter))
-
- # Add a curve which is not active by default
- self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve2')
- curve2 = self.plot.getCurve('curve2')
- self.assertEqual(listener.callCount(), 10)
- self._checkSelection(selection, curve, (curve, image, scatter))
-
- # Mosted recently "actived" item is current, previous curve is removed
- self.plot.setActiveCurve('curve2')
- self.assertEqual(listener.callCount(), 11)
- self._checkSelection(selection, curve2, (curve2, image, scatter))
-
- # No items = no current
- self.plot.clear()
- self.assertEqual(listener.callCount(), 12)
- self._checkSelection(selection)
-
- def testPlotWidgetWithItems(self):
- """Test init of selection on a plot with items"""
- self.plot.addImage(((0, 1), (2, 3)), legend='image')
- self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
- self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
- self.plot.setActiveCurve('curve')
-
- selection = self.plot.selection()
- self.assertIsNotNone(selection.getCurrentItem())
- selected = selection.getSelectedItems()
- self.assertEqual(len(selected), 3)
- self.assertIn(self.plot.getActiveCurve(), selected)
- self.assertIn(self.plot.getActiveImage(), selected)
- self.assertIn(self.plot.getActiveScatter(), selected)
-
- def testSetCurrentItem(self):
- """Test setCurrentItem"""
- # Add items to the plot
- self.plot.addImage(((0, 1), (2, 3)), legend='image')
- image = self.plot.getActiveImage()
- self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
- scatter = self.plot.getActiveScatter()
- self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
- self.plot.setActiveCurve('curve')
- curve = self.plot.getActiveCurve()
-
- selection = self.plot.selection()
- self.assertIsNotNone(selection.getCurrentItem())
- self.assertEqual(len(selection.getSelectedItems()), 3)
-
- # Set current to None reset all active items
- selection.setCurrentItem(None)
- self._checkSelection(selection)
- self.assertIsNone(self.plot.getActiveCurve())
- self.assertIsNone(self.plot.getActiveImage())
- self.assertIsNone(self.plot.getActiveScatter())
-
- # Set current to an item makes it active
- selection.setCurrentItem(image)
- self._checkSelection(selection, image, (image,))
- self.assertIsNone(self.plot.getActiveCurve())
- self.assertIs(self.plot.getActiveImage(), image)
- self.assertIsNone(self.plot.getActiveScatter())
-
- # Set current to an item makes it active and keeps other active
- selection.setCurrentItem(curve)
- self._checkSelection(selection, curve, (curve, image))
- self.assertIs(self.plot.getActiveCurve(), curve)
- self.assertIs(self.plot.getActiveImage(), image)
- self.assertIsNone(self.plot.getActiveScatter())
-
- # Set current to an item makes it active and keeps other active
- selection.setCurrentItem(scatter)
- self._checkSelection(selection, scatter, (scatter, curve, image))
- self.assertIs(self.plot.getActiveCurve(), curve)
- self.assertIs(self.plot.getActiveImage(), image)
- self.assertIs(self.plot.getActiveScatter(), scatter)
-
-
-def suite():
- testClasses = (TestPlotWidget,
- TestPlotImage,
- TestPlotCurve,
- TestPlotHistogram,
- TestPlotScatter,
- TestPlotMarker,
- TestPlotItem,
- TestPlotAxes,
- TestPlotActiveCurveImage,
- TestPlotEmptyLog,
- TestPlotCurveLog,
- TestPlotImageLog,
- TestPlotMarkerLog,
- TestPlotWidgetSelection)
-
- test_suite = unittest.TestSuite()
-
- # Tests with matplotlib
- for testClass in testClasses:
- test_suite.addTest(parameterize(testClass, backend=None))
-
- test_suite.addTest(parameterize(TestSpecialBackend, backend=u"mpl"))
- if sys.version_info[0] == 2:
- test_suite.addTest(parameterize(TestSpecialBackend, backend=b"mpl"))
-
- if test_options.WITH_GL_TEST:
- # Tests with OpenGL backend
- for testClass in testClasses:
- test_suite.addTest(parameterize(testClass, backend='gl'))
-
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
- TestPlotWidgetSwitchBackend))
-
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPlotWidgetNoBackend.py b/silx/gui/plot/test/testPlotWidgetNoBackend.py
deleted file mode 100644
index edd3cd7..0000000
--- a/silx/gui/plot/test/testPlotWidgetNoBackend.py
+++ /dev/null
@@ -1,631 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for PlotWidget with 'none' backend"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import unittest
-from functools import reduce
-from silx.utils.testutils import ParametricTestCase
-
-import numpy
-
-from silx.gui.plot.PlotWidget import PlotWidget
-from silx.gui.plot.items.histogram import _getHistogramCurve, _computeEdges
-
-
-class TestPlot(unittest.TestCase):
- """Basic tests of Plot without backend"""
-
- def testPlotTitleLabels(self):
- """Create a Plot and set the labels"""
-
- plot = PlotWidget(backend='none')
-
- title, xlabel, ylabel = 'the title', 'x label', 'y label'
- plot.setGraphTitle(title)
- plot.getXAxis().setLabel(xlabel)
- plot.getYAxis().setLabel(ylabel)
-
- self.assertEqual(plot.getGraphTitle(), title)
- self.assertEqual(plot.getXAxis().getLabel(), xlabel)
- self.assertEqual(plot.getYAxis().getLabel(), ylabel)
-
- def testAddNoRemove(self):
- """add objects to the Plot"""
-
- plot = PlotWidget(backend='none')
- plot.addCurve(x=(1, 2, 3), y=(3, 2, 1))
- plot.addImage(numpy.arange(100.).reshape(10, -1))
- plot.addShape(numpy.array((1., 10.)),
- numpy.array((10., 10.)),
- shape="rectangle")
- plot.addXMarker(10.)
-
-
-class TestPlotRanges(ParametricTestCase):
- """Basic tests of Plot data ranges without backend"""
-
- _getValidValues = {True: lambda ar: ar > 0,
- False: lambda ar: numpy.ones(shape=ar.shape,
- dtype=bool)}
-
- @staticmethod
- def _getRanges(arrays, are_logs):
- gen = (TestPlotRanges._getValidValues[is_log](ar)
- for (ar, is_log) in zip(arrays, are_logs))
- indices = numpy.where(reduce(numpy.logical_and, gen))[0]
- if len(indices) > 0:
- ranges = [(ar[indices[0]], ar[indices[-1]]) for ar in arrays]
- else:
- ranges = [None] * len(arrays)
-
- return ranges
-
- @staticmethod
- def _getRangesMinmax(ranges):
- # TODO : error if None in ranges.
- rangeMin = numpy.min([rng[0] for rng in ranges])
- rangeMax = numpy.max([rng[1] for rng in ranges])
- return rangeMin, rangeMax
-
- def testDataRangeNoPlot(self):
- """empty plot data range"""
-
- plot = PlotWidget(backend='none')
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
- with self.subTest(logX=logX, logY=logY):
- plot.getXAxis()._setLogarithmic(logX)
- plot.getYAxis()._setLogarithmic(logY)
- dataRange = plot.getDataRange()
- self.assertIsNone(dataRange.x)
- self.assertIsNone(dataRange.y)
- self.assertIsNone(dataRange.yright)
-
- def testDataRangeLeft(self):
- """left axis range"""
-
- plot = PlotWidget(backend='none')
-
- xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
- yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
-
- plot.addCurve(x=xData,
- y=yData,
- legend='plot_0',
- yaxis='left')
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
- with self.subTest(logX=logX, logY=logY):
- plot.getXAxis()._setLogarithmic(logX)
- plot.getYAxis()._setLogarithmic(logY)
- dataRange = plot.getDataRange()
- xRange, yRange = self._getRanges([xData, yData],
- [logX, logY])
- self.assertSequenceEqual(dataRange.x, xRange)
- self.assertSequenceEqual(dataRange.y, yRange)
- self.assertIsNone(dataRange.yright)
-
- def testDataRangeRight(self):
- """right axis range"""
-
- plot = PlotWidget(backend='none')
- xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
- yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
- plot.addCurve(x=xData,
- y=yData,
- legend='plot_0',
- yaxis='right')
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
- with self.subTest(logX=logX, logY=logY):
- plot.getXAxis()._setLogarithmic(logX)
- plot.getYAxis()._setLogarithmic(logY)
- dataRange = plot.getDataRange()
- xRange, yRange = self._getRanges([xData, yData],
- [logX, logY])
- self.assertSequenceEqual(dataRange.x, xRange)
- self.assertIsNone(dataRange.y)
- self.assertSequenceEqual(dataRange.yright, yRange)
-
- def testDataRangeImage(self):
- """image data range"""
-
- origin = (-10, 25)
- scale = (3., 8.)
- image = numpy.arange(100.).reshape(20, 5)
-
- plot = PlotWidget(backend='none')
- plot.addImage(image,
- origin=origin, scale=scale)
-
- xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
- yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
-
- ranges = {(False, False): (xRange, yRange),
- (True, False): (None, None),
- (True, True): (None, None),
- (False, True): (None, None)}
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
- with self.subTest(logX=logX, logY=logY):
- plot.getXAxis()._setLogarithmic(logX)
- plot.getYAxis()._setLogarithmic(logY)
- dataRange = plot.getDataRange()
- xRange, yRange = ranges[logX, logY]
- self.assertTrue(numpy.array_equal(dataRange.x, xRange),
- msg='{0} != {1}'.format(dataRange.x, xRange))
- self.assertTrue(numpy.array_equal(dataRange.y, yRange),
- msg='{0} != {1}'.format(dataRange.y, yRange))
- self.assertIsNone(dataRange.yright)
-
- def testDataRangeLeftRight(self):
- """right+left axis range"""
-
- plot = PlotWidget(backend='none')
-
- xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
- yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
- plot.addCurve(x=xData_l,
- y=yData_l,
- legend='plot_l',
- yaxis='left')
-
- xData_r = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
- yData_r = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
- plot.addCurve(x=xData_r,
- y=yData_r,
- legend='plot_r',
- yaxis='right')
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
- with self.subTest(logX=logX, logY=logY):
- plot.getXAxis()._setLogarithmic(logX)
- plot.getYAxis()._setLogarithmic(logY)
- dataRange = plot.getDataRange()
- xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
- [logX, logY])
- xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
- [logX, logY])
- xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
- self.assertSequenceEqual(dataRange.x, xRangeLR)
- self.assertSequenceEqual(dataRange.y, yRangeL)
- self.assertSequenceEqual(dataRange.yright, yRangeR)
-
- def testDataRangeCurveImage(self):
- """right+left+image axis range"""
-
- # overlapping ranges :
- # image sets x min and y max
- # plot_left sets y min
- # plot_right sets x max (and yright)
- plot = PlotWidget(backend='none')
-
- origin = (-10, 5)
- scale = (3., 8.)
- image = numpy.arange(100.).reshape(20, 5)
-
- plot.addImage(image,
- origin=origin, scale=scale, legend='image')
-
- xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
- yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
- plot.addCurve(x=xData_l,
- y=yData_l,
- legend='plot_l',
- yaxis='left')
-
- xData_r = numpy.arange(10) + 4.1 # range : 4.1 , 13.1
- yData_r = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
- plot.addCurve(x=xData_r,
- y=yData_r,
- legend='plot_r',
- yaxis='right')
-
- imgXRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
- imgYRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
- with self.subTest(logX=logX, logY=logY):
- plot.getXAxis()._setLogarithmic(logX)
- plot.getYAxis()._setLogarithmic(logY)
- dataRange = plot.getDataRange()
- xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
- [logX, logY])
- xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
- [logX, logY])
- if logX or logY:
- xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
- else:
- xRangeLR = self._getRangesMinmax([xRangeL,
- xRangeR,
- imgXRange])
- yRangeL = self._getRangesMinmax([yRangeL, imgYRange])
- self.assertSequenceEqual(dataRange.x, xRangeLR)
- self.assertSequenceEqual(dataRange.y, yRangeL)
- self.assertSequenceEqual(dataRange.yright, yRangeR)
-
- def testDataRangeImageNegativeScaleX(self):
- """image data range, negative scale"""
-
- origin = (-10, 25)
- scale = (-3., 8.)
- image = numpy.arange(100.).reshape(20, 5)
-
- plot = PlotWidget(backend='none')
- plot.addImage(image,
- origin=origin, scale=scale)
-
- xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
- xRange.sort() # negative scale!
- yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
-
- ranges = {(False, False): (xRange, yRange),
- (True, False): (None, None),
- (True, True): (None, None),
- (False, True): (None, None)}
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
- with self.subTest(logX=logX, logY=logY):
- plot.getXAxis()._setLogarithmic(logX)
- plot.getYAxis()._setLogarithmic(logY)
- dataRange = plot.getDataRange()
- xRange, yRange = ranges[logX, logY]
- self.assertTrue(numpy.array_equal(dataRange.x, xRange),
- msg='{0} != {1}'.format(dataRange.x, xRange))
- self.assertTrue(numpy.array_equal(dataRange.y, yRange),
- msg='{0} != {1}'.format(dataRange.y, yRange))
- self.assertIsNone(dataRange.yright)
-
- def testDataRangeImageNegativeScaleY(self):
- """image data range, negative scale"""
-
- origin = (-10, 25)
- scale = (3., -8.)
- image = numpy.arange(100.).reshape(20, 5)
-
- plot = PlotWidget(backend='none')
- plot.addImage(image,
- origin=origin, scale=scale)
-
- xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
- yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
- yRange.sort() # negative scale!
-
- ranges = {(False, False): (xRange, yRange),
- (True, False): (None, None),
- (True, True): (None, None),
- (False, True): (None, None)}
-
- for logX, logY in ((False, False),
- (True, False),
- (True, True),
- (False, True),
- (False, False)):
- with self.subTest(logX=logX, logY=logY):
- plot.getXAxis()._setLogarithmic(logX)
- plot.getYAxis()._setLogarithmic(logY)
- dataRange = plot.getDataRange()
- xRange, yRange = ranges[logX, logY]
- self.assertTrue(numpy.array_equal(dataRange.x, xRange),
- msg='{0} != {1}'.format(dataRange.x, xRange))
- self.assertTrue(numpy.array_equal(dataRange.y, yRange),
- msg='{0} != {1}'.format(dataRange.y, yRange))
- self.assertIsNone(dataRange.yright)
-
- def testDataRangeHiddenCurve(self):
- """curves with a hidden curve"""
- plot = PlotWidget(backend='none')
- plot.addCurve((0, 1), (0, 1), legend='shown')
- plot.addCurve((0, 1, 2), (5, 5, 5), legend='hidden')
- range1 = plot.getDataRange()
- self.assertEqual(range1.x, (0, 2))
- self.assertEqual(range1.y, (0, 5))
- plot.hideCurve('hidden')
- range2 = plot.getDataRange()
- self.assertEqual(range2.x, (0, 1))
- self.assertEqual(range2.y, (0, 1))
-
-
-class TestPlotGetCurveImage(unittest.TestCase):
- """Test of plot getCurve and getImage methods"""
-
- def testGetCurve(self):
- """PlotWidget.getCurve and Plot.getActiveCurve tests"""
-
- plot = PlotWidget(backend='none')
-
- # No curve
- curve = plot.getCurve()
- self.assertIsNone(curve) # No curve
-
- plot.setActiveCurveHandling(True)
- plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 0')
- plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 1')
- plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 2')
- plot.setActiveCurve('curve 0')
-
- # Active curve
- active = plot.getActiveCurve()
- self.assertEqual(active.getName(), 'curve 0')
- curve = plot.getCurve()
- self.assertEqual(curve.getName(), 'curve 0')
-
- # No active curve and curves
- plot.setActiveCurveHandling(False)
- active = plot.getActiveCurve()
- self.assertIsNone(active) # No active curve
- curve = plot.getCurve()
- self.assertEqual(curve.getName(), 'curve 2') # Last added curve
-
- # Last curve hidden
- plot.hideCurve('curve 2', True)
- curve = plot.getCurve()
- self.assertEqual(curve.getName(), 'curve 1') # Last added curve
-
- # All curves hidden
- plot.hideCurve('curve 1', True)
- plot.hideCurve('curve 0', True)
- curve = plot.getCurve()
- self.assertIsNone(curve)
-
- def testGetCurveOldApi(self):
- """old API PlotWidget.getCurve and Plot.getActiveCurve tests"""
-
- plot = PlotWidget(backend='none')
-
- # No curve
- curve = plot.getCurve()
- self.assertIsNone(curve) # No curve
-
- plot.setActiveCurveHandling(True)
- x = numpy.arange(10.).astype(numpy.float32)
- y = x * x
- plot.addCurve(x=x, y=y, legend='curve 0', info=["whatever"])
- plot.addCurve(x=x, y=2*x, legend='curve 1', info="anything")
- plot.setActiveCurve('curve 0')
-
- # Active curve (4 elements)
- xOut, yOut, legend, info = plot.getActiveCurve()[:4]
- self.assertEqual(legend, 'curve 0')
- self.assertTrue(numpy.allclose(xOut, x), 'curve 0 wrong x data')
- self.assertTrue(numpy.allclose(yOut, y), 'curve 0 wrong y data')
-
- # Active curve (5 elements)
- xOut, yOut, legend, info, params = plot.getCurve("curve 1")
- self.assertEqual(legend, 'curve 1')
- self.assertEqual(info, 'anything')
- self.assertTrue(numpy.allclose(xOut, x), 'curve 1 wrong x data')
- self.assertTrue(numpy.allclose(yOut, 2 * x), 'curve 1 wrong y data')
-
- def testGetImage(self):
- """PlotWidget.getImage and PlotWidget.getActiveImage tests"""
-
- plot = PlotWidget(backend='none')
-
- # No image
- image = plot.getImage()
- self.assertIsNone(image)
-
- plot.addImage(((0, 1), (2, 3)), legend='image 0')
- plot.addImage(((0, 1), (2, 3)), legend='image 1')
-
- # Active image
- active = plot.getActiveImage()
- self.assertEqual(active.getName(), 'image 0')
- image = plot.getImage()
- self.assertEqual(image.getName(), 'image 0')
-
- # No active image
- plot.addImage(((0, 1), (2, 3)), legend='image 2')
- plot.setActiveImage(None)
- active = plot.getActiveImage()
- self.assertIsNone(active)
- image = plot.getImage()
- self.assertEqual(image.getName(), 'image 2')
-
- # Active image
- plot.setActiveImage('image 1')
- active = plot.getActiveImage()
- self.assertEqual(active.getName(), 'image 1')
- image = plot.getImage()
- self.assertEqual(image.getName(), 'image 1')
-
- def testGetImageOldApi(self):
- """PlotWidget.getImage and PlotWidget.getActiveImage old API tests"""
-
- plot = PlotWidget(backend='none')
-
- # No image
- image = plot.getImage()
- self.assertIsNone(image)
-
- image = numpy.arange(10).astype(numpy.float32)
- image.shape = 5, 2
-
- plot.addImage(image, legend='image 0', info=["Hi!"])
-
- # Active image
- data, legend, info, something, params = plot.getActiveImage()
- self.assertEqual(legend, 'image 0')
- self.assertEqual(info, ["Hi!"])
- self.assertTrue(numpy.allclose(data, image), "image 0 data not correct")
-
- def testGetAllImages(self):
- """PlotWidget.getAllImages test"""
-
- plot = PlotWidget(backend='none')
-
- # No image
- images = plot.getAllImages()
- self.assertEqual(len(images), 0)
-
- # 2 images
- data = numpy.arange(100).reshape(10, 10)
- plot.addImage(data, legend='1')
- plot.addImage(data, origin=(10, 10), legend='2')
- images = plot.getAllImages(just_legend=True)
- self.assertEqual(list(images), ['1', '2'])
- images = plot.getAllImages(just_legend=False)
- self.assertEqual(len(images), 2)
- self.assertEqual(images[0].getName(), '1')
- self.assertEqual(images[1].getName(), '2')
-
-
-class TestPlotAddScatter(unittest.TestCase):
- """Test of plot addScatter"""
-
- def testAddGetScatter(self):
-
- plot = PlotWidget(backend='none')
-
- # No curve
- scatter = plot._getItem(kind="scatter")
- self.assertIsNone(scatter) # No curve
-
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
- plot._setActiveItem('scatter', 'scatter 0')
-
- # Active scatter
- active = plot._getActiveItem(kind='scatter')
- self.assertEqual(active.getName(), 'scatter 0')
-
- # check default values
- self.assertAlmostEqual(active.getSymbolSize(), active._DEFAULT_SYMBOL_SIZE)
- self.assertEqual(active.getSymbol(), "o")
- self.assertAlmostEqual(active.getAlpha(), 1.0)
-
- # modify parameters
- active.setSymbolSize(20.5)
- active.setSymbol("d")
- active.setAlpha(0.777)
-
- s0 = plot.getScatter("scatter 0")
-
- self.assertAlmostEqual(s0.getSymbolSize(), 20.5)
- self.assertEqual(s0.getSymbol(), "d")
- self.assertAlmostEqual(s0.getAlpha(), 0.777)
-
- scatter1 = plot._getItem(kind='scatter', legend='scatter 1')
- self.assertEqual(scatter1.getName(), 'scatter 1')
-
- def testGetAllScatters(self):
- """PlotWidget.getAllImages test"""
-
- plot = PlotWidget(backend='none')
-
- items = plot.getItems()
- self.assertEqual(len(items), 0)
-
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
- plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
-
- items = plot.getItems()
- self.assertEqual(len(items), 3)
- self.assertEqual(items[0].getName(), 'scatter 0')
- self.assertEqual(items[1].getName(), 'scatter 1')
- self.assertEqual(items[2].getName(), 'scatter 2')
-
-
-class TestPlotHistogram(unittest.TestCase):
- """Basic tests for histogram."""
-
- def testEdges(self):
- x = numpy.array([0, 1, 2])
- edgesRight = numpy.array([0, 1, 2, 3])
- edgesLeft = numpy.array([-1, 0, 1, 2])
- edgesCenter = numpy.array([-0.5, 0.5, 1.5, 2.5])
-
- # testing x values for right
- edges = _computeEdges(x, 'right')
- numpy.testing.assert_array_equal(edges, edgesRight)
-
- edges = _computeEdges(x, 'center')
- numpy.testing.assert_array_equal(edges, edgesCenter)
-
- edges = _computeEdges(x, 'left')
- numpy.testing.assert_array_equal(edges, edgesLeft)
-
- def testHistogramCurve(self):
- y = numpy.array([3, 2, 5])
- edges = numpy.array([0, 1, 2, 3])
-
- xHisto, yHisto = _getHistogramCurve(y, edges)
- numpy.testing.assert_array_equal(
- yHisto, numpy.array([3, 3, 2, 2, 5, 5]))
-
- y = numpy.array([-3, 2, 5, 0])
- edges = numpy.array([-2, -1, 0, 1, 2])
- xHisto, yHisto = _getHistogramCurve(y, edges)
- numpy.testing.assert_array_equal(
- yHisto, numpy.array([-3, -3, 2, 2, 5, 5, 0, 0]))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestClass in (TestPlot, TestPlotRanges, TestPlotGetCurveImage,
- TestPlotHistogram, TestPlotAddScatter):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testPlotWindow.py b/silx/gui/plot/test/testPlotWindow.py
deleted file mode 100644
index e12b756..0000000
--- a/silx/gui/plot/test/testPlotWindow.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for PlotWindow"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "27/06/2017"
-
-
-import unittest
-import numpy
-
-from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
-from silx.test.utils import test_options
-
-from silx.gui import qt
-from silx.gui.plot import PlotWindow
-from silx.gui.colors import Colormap
-
-class TestPlotWindow(TestCaseQt):
- """Base class for tests of PlotWindow."""
-
- def setUp(self):
- super(TestPlotWindow, self).setUp()
- self.plot = PlotWindow()
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- def tearDown(self):
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- super(TestPlotWindow, self).tearDown()
-
- def testActions(self):
- """Test the actions QToolButtons"""
- self.plot.setLimits(1, 100, 1, 100)
-
- checkList = [ # QAction, Plot state getter
- (self.plot.xAxisAutoScaleAction, self.plot.getXAxis().isAutoScale),
- (self.plot.yAxisAutoScaleAction, self.plot.getYAxis().isAutoScale),
- (self.plot.xAxisLogarithmicAction, self.plot.getXAxis()._isLogarithmic),
- (self.plot.yAxisLogarithmicAction, self.plot.getYAxis()._isLogarithmic),
- (self.plot.gridAction, self.plot.getGraphGrid),
- ]
-
- for action, getter in checkList:
- self.mouseMove(self.plot)
- initialState = getter()
- toolButton = getQToolButtonFromAction(action)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
- self.assertNotEqual(getter(), initialState,
- msg='"%s" state not changed' % action.text())
-
- self.mouseClick(toolButton, qt.Qt.LeftButton)
- self.assertEqual(getter(), initialState,
- msg='"%s" state not changed' % action.text())
-
- # Trigger a zoom reset
- self.mouseMove(self.plot)
- resetZoomAction = self.plot.resetZoomAction
- toolButton = getQToolButtonFromAction(resetZoomAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- def testDockWidgets(self):
- """Test add/remove dock widgets"""
- dock1 = qt.QDockWidget('Test 1')
- dock1.setWidget(qt.QLabel('Test 1'))
-
- self.plot.addTabbedDockWidget(dock1)
- self.qapp.processEvents()
-
- self.plot.removeDockWidget(dock1)
- self.qapp.processEvents()
-
- dock2 = qt.QDockWidget('Test 2')
- dock2.setWidget(qt.QLabel('Test 2'))
-
- self.plot.addTabbedDockWidget(dock2)
- self.qapp.processEvents()
-
- if qt.BINDING != 'PySide2':
- # Weird bug with PySide2 later upon gc.collect() when getting the layout
- self.assertNotEqual(self.plot.layout().indexOf(dock2),
- -1,
- "dock2 not properly displayed")
-
- def testToolAspectRatio(self):
- self.plot.toolBar()
- self.plot.keepDataAspectRatioButton.keepDataAspectRatio()
- self.assertTrue(self.plot.isKeepDataAspectRatio())
- self.plot.keepDataAspectRatioButton.dontKeepDataAspectRatio()
- self.assertFalse(self.plot.isKeepDataAspectRatio())
-
- def testToolYAxisOrigin(self):
- self.plot.toolBar()
- self.plot.yAxisInvertedButton.setYAxisUpward()
- self.assertFalse(self.plot.getYAxis().isInverted())
- self.plot.yAxisInvertedButton.setYAxisDownward()
- self.assertTrue(self.plot.getYAxis().isInverted())
-
- def testColormapAutoscaleCache(self):
- # Test that the min/max cache is not computed twice
-
- old = Colormap._computeAutoscaleRange
- self._count = 0
- def _computeAutoscaleRange(colormap, data):
- self._count = self._count + 1
- return 10, 20
- Colormap._computeAutoscaleRange = _computeAutoscaleRange
- try:
- colormap = Colormap(name='red')
- self.plot.setVisible(True)
-
- # Add an image
- data = numpy.arange(8**2).reshape(8, 8)
- self.plot.addImage(data, legend="foo", colormap=colormap)
- self.plot.setActiveImage("foo")
-
- # Use the colorbar
- self.plot.getColorBarWidget().setVisible(True)
- self.qWait(50)
-
- # Remove and add again the same item
- image = self.plot.getImage("foo")
- self.plot.removeImage("foo")
- self.plot.addItem(image)
- self.qWait(50)
- finally:
- Colormap._computeAutoscaleRange = old
- self.assertEqual(self._count, 1)
- del self._count
-
- @unittest.skipUnless(test_options.WITH_GL_TEST,
- test_options.WITH_QT_TEST_REASON)
- def testSwitchBackend(self):
- """Test switching an empty plot"""
- self.plot.resetZoom()
- xlimits = self.plot.getXAxis().getLimits()
- ylimits = self.plot.getYAxis().getLimits()
- isKeepAspectRatio = self.plot.isKeepDataAspectRatio()
-
- for backend in ('gl', 'mpl'):
- with self.subTest():
- self.plot.setBackend(backend)
- self.plot.replot()
- self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
- self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
- self.assertEqual(
- self.plot.isKeepDataAspectRatio(), isKeepAspectRatio)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestPlotWindow))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testRoiStatsWidget.py b/silx/gui/plot/test/testRoiStatsWidget.py
deleted file mode 100644
index 378d499..0000000
--- a/silx/gui/plot/test/testRoiStatsWidget.py
+++ /dev/null
@@ -1,290 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for ROIStatsWidget"""
-
-
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import qt
-from silx.gui.plot import PlotWindow
-from silx.gui.plot.stats.stats import Stats
-from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
-from silx.gui.plot.CurvesROIWidget import ROI
-from silx.gui.plot.items.roi import RectangleROI, PolygonROI
-from silx.gui.plot.StatsWidget import UpdateMode
-import unittest
-import numpy
-
-
-
-class _TestRoiStatsBase(TestCaseQt):
- """Base class for several unittest relative to ROIStatsWidget"""
- def setUp(self):
- TestCaseQt.setUp(self)
- # define plot
- self.plot = PlotWindow()
- self.plot.addImage(numpy.arange(10000).reshape(100, 100),
- legend='img1')
- self.img_item = self.plot.getImage('img1')
- self.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.arange(56),
- legend='curve1')
- self.curve_item = self.plot.getCurve('curve1')
- self.plot.addHistogram(edges=numpy.linspace(0, 10, 56),
- histogram=numpy.arange(56), legend='histo1')
- self.histogram_item = self.plot.getHistogram(legend='histo1')
- self.plot.addScatter(x=numpy.linspace(0, 10, 56),
- y=numpy.linspace(0, 10, 56),
- value=numpy.arange(56),
- legend='scatter1')
- self.scatter_item = self.plot.getScatter(legend='scatter1')
-
- # stats widget
- self.statsWidget = ROIStatsWidget(plot=self.plot)
-
- # define stats
- stats = [
- ('sum', numpy.sum),
- ('mean', numpy.mean),
- ]
- self.statsWidget.setStats(stats=stats)
-
- # define rois
- self.roi1D = ROI(name='range1', fromdata=0, todata=4, type_='energy')
- self.rectangle_roi = RectangleROI()
- self.rectangle_roi.setGeometry(origin=(0, 0), size=(20, 20))
- self.rectangle_roi.setName('Initial ROI')
- self.polygon_roi = PolygonROI()
- points = numpy.array([[0, 5], [5, 0], [10, 5], [5, 10]])
- self.polygon_roi.setPoints(points)
-
- def statsTable(self):
- return self.statsWidget._statsROITable
-
- def tearDown(self):
- Stats._getContext.cache_clear()
- self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose, True)
- self.statsWidget.close()
- self.statsWidget = None
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
- self.plot.close()
- self.plot = None
- TestCaseQt.tearDown(self)
-
-
-class TestRoiStatsCouple(_TestRoiStatsBase):
- """
- Test different possible couple (roi, plotItem).
- Check that:
-
- * computation is correct if couple is valid
- * raise an error if couple is invalid
- """
- def testROICurve(self):
- """
- Test that the couple (ROI, curveItem) can be used for stats
- """
- item = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.curve_item)
- assert item is not None
- tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '253')
- self.assertEqual(tableItems['mean'].text(), '11.0')
-
- def testRectangleImage(self):
- """
- Test that the couple (RectangleROI, imageItem) can be used for stats
- """
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
- assert item is not None
- self.plot.addImage(numpy.ones(10000).reshape(100, 100),
- legend='img1')
- self.qapp.processEvents()
- tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), str(float(21*21)))
- self.assertEqual(tableItems['mean'].text(), '1.0')
-
- def testPolygonImage(self):
- """
- Test that the couple (PolygonROI, imageItem) can be used for stats
- """
- item = self.statsWidget.addItem(roi=self.polygon_roi,
- plotItem=self.img_item)
- assert item is not None
- tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '22750')
- self.assertEqual(tableItems['mean'].text(), '455.0')
-
- def testROIImage(self):
- """
- Test that the couple (ROI, imageItem) is raising an error
- """
- with self.assertRaises(TypeError):
- self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.img_item)
-
- def testRectangleCurve(self):
- """
- Test that the couple (rectangleROI, curveItem) is raising an error
- """
- with self.assertRaises(TypeError):
- self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.curve_item)
-
- def testROIHistogram(self):
- """
- Test that the couple (PolygonROI, imageItem) can be used for stats
- """
- item = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.histogram_item)
- assert item is not None
- tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '253')
- self.assertEqual(tableItems['mean'].text(), '11.0')
-
- def testROIScatter(self):
- """
- Test that the couple (PolygonROI, imageItem) can be used for stats
- """
- item = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.scatter_item)
- assert item is not None
- tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '253')
- self.assertEqual(tableItems['mean'].text(), '11.0')
-
-
-class TestRoiStatsAddRemoveItem(_TestRoiStatsBase):
- """Test adding and removing (roi, plotItem) items"""
- def testAddRemoveItems(self):
- item1 = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.scatter_item)
- self.assertTrue(item1 is not None)
- self.assertEqual(self.statsTable().rowCount(), 1)
- item2 = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.histogram_item)
- self.assertTrue(item2 is not None)
- self.assertEqual(self.statsTable().rowCount(), 2)
- # try to add twice the same item
- item3 = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.histogram_item)
- self.assertTrue(item3 is None)
- self.assertEqual(self.statsTable().rowCount(), 2)
- item4 = self.statsWidget.addItem(roi=self.roi1D,
- plotItem=self.curve_item)
- self.assertTrue(item4 is not None)
- self.assertEqual(self.statsTable().rowCount(), 3)
-
- self.statsWidget.removeItem(plotItem=item4._plot_item,
- roi=item4._roi)
- self.assertEqual(self.statsTable().rowCount(), 2)
- # try to remove twice the same item
- self.statsWidget.removeItem(plotItem=item4._plot_item,
- roi=item4._roi)
- self.assertEqual(self.statsTable().rowCount(), 2)
- self.statsWidget.removeItem(plotItem=item2._plot_item,
- roi=item2._roi)
- self.statsWidget.removeItem(plotItem=item1._plot_item,
- roi=item1._roi)
- self.assertEqual(self.statsTable().rowCount(), 0)
-
-
-class TestRoiStatsRoiUpdate(_TestRoiStatsBase):
- """Test that the stats will be updated if the roi is updated"""
- def testChangeRoi(self):
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
- assert item is not None
- tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '445410')
- self.assertEqual(tableItems['mean'].text(), '1010.0')
-
- # update roi
- self.rectangle_roi.setOrigin(position=(10, 10))
- self.assertNotEqual(tableItems['sum'].text(), '445410')
- self.assertNotEqual(tableItems['mean'].text(), '1010.0')
-
- def testUpdateModeScenario(self):
- """Test update according to a simple scenario"""
- self.statsWidget._setUpdateMode(UpdateMode.AUTO)
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
-
- assert item is not None
- tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['sum'].text(), '445410')
- self.assertEqual(tableItems['mean'].text(), '1010.0')
- self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
- self.rectangle_roi.setOrigin(position=(10, 10))
- self.qapp.processEvents()
- self.assertNotEqual(tableItems['sum'].text(), '445410')
- self.assertNotEqual(tableItems['mean'].text(), '1010.0')
- self.statsWidget._updateAllStats(is_request=True)
- self.assertNotEqual(tableItems['sum'].text(), '445410')
- self.assertNotEqual(tableItems['mean'].text(), '1010.0')
-
-
-class TestRoiStatsPlotItemUpdate(_TestRoiStatsBase):
- """Test that the stats will be updated if the plot item is updated"""
- def testChangeImage(self):
- self.statsWidget._setUpdateMode(UpdateMode.AUTO)
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
-
- assert item is not None
- tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['mean'].text(), '1010.0')
-
- # update plot
- self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
- legend='img1')
- self.assertNotEqual(tableItems['mean'].text(), '1059.5')
-
- def testUpdateModeScenario(self):
- """Test update according to a simple scenario"""
- self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
- item = self.statsWidget.addItem(roi=self.rectangle_roi,
- plotItem=self.img_item)
-
- assert item is not None
- tableItems = self.statsTable()._itemToTableItems(item)
- self.assertEqual(tableItems['mean'].text(), '1010.0')
- self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
- legend='img1')
- self.assertEqual(tableItems['mean'].text(), '1010.0')
- self.statsWidget._updateAllStats(is_request=True)
- self.assertEqual(tableItems['mean'].text(), '1110.0')
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestClass in (TestRoiStatsCouple, TestRoiStatsRoiUpdate,
- TestRoiStatsPlotItemUpdate):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testSaveAction.py b/silx/gui/plot/test/testSaveAction.py
deleted file mode 100644
index 0eb129d..0000000
--- a/silx/gui/plot/test/testSaveAction.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test the plot's save action (consistency of output)"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "28/11/2017"
-
-
-import unittest
-import tempfile
-import os
-
-from silx.gui.plot.test.utils import PlotWidgetTestCase
-
-from silx.gui.plot import PlotWidget
-from silx.gui.plot.actions.io import SaveAction
-
-
-class TestSaveActionSaveCurvesAsSpec(unittest.TestCase):
-
- def setUp(self):
- self.plot = PlotWidget(backend='none')
- self.saveAction = SaveAction(plot=self.plot)
-
- self.tempdir = tempfile.mkdtemp()
- self.out_fname = os.path.join(self.tempdir, "out.dat")
-
- def tearDown(self):
- os.unlink(self.out_fname)
- os.rmdir(self.tempdir)
-
- def testSaveMultipleCurvesAsSpec(self):
- """Test that labels are properly used."""
- self.plot.setGraphXLabel("graph x label")
- self.plot.setGraphYLabel("graph y label")
-
- self.plot.addCurve([0, 1], [1, 2], "curve with labels",
- xlabel="curve0 X", ylabel="curve0 Y")
- self.plot.addCurve([-1, 3], [-6, 2], "curve with X label",
- xlabel="curve1 X")
- self.plot.addCurve([-2, 0], [8, 12], "curve with Y label",
- ylabel="curve2 Y")
- self.plot.addCurve([3, 1], [7, 6], "curve with no labels")
-
- self.saveAction._saveCurves(self.plot,
- self.out_fname,
- SaveAction.DEFAULT_ALL_CURVES_FILTERS[0]) # "All curves as SpecFile (*.dat)"
-
- with open(self.out_fname, "rb") as f:
- file_content = f.read()
- if hasattr(file_content, "decode"):
- file_content = file_content.decode()
-
- # case with all curve labels specified
- self.assertIn("#S 1 curve0 Y", file_content)
- self.assertIn("#L curve0 X curve0 Y", file_content)
-
- # graph X&Y labels are used when no curve label is specified
- self.assertIn("#S 2 graph y label", file_content)
- self.assertIn("#L curve1 X graph y label", file_content)
-
- self.assertIn("#S 3 curve2 Y", file_content)
- self.assertIn("#L graph x label curve2 Y", file_content)
-
- self.assertIn("#S 4 graph y label", file_content)
- self.assertIn("#L graph x label graph y label", file_content)
-
-
-class TestSaveActionExtension(PlotWidgetTestCase):
- """Test SaveAction file filter API"""
-
- def _dummySaveFunction(self, plot, filename, nameFilter):
- pass
-
- def testFileFilterAPI(self):
- """Test addition/update of a file filter"""
- saveAction = SaveAction(plot=self.plot, parent=self.plot)
-
- # Add a new file filter
- nameFilter = 'Dummy file (*.dummy)'
- saveAction.setFileFilter('all', nameFilter, self._dummySaveFunction)
- self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
- self.assertEqual(saveAction.getFileFilters('all')[nameFilter],
- self._dummySaveFunction)
-
- # Add a new file filter at a particular position
- nameFilter = 'Dummy file2 (*.dummy)'
- saveAction.setFileFilter('all', nameFilter,
- self._dummySaveFunction, index=3)
- self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
- filters = saveAction.getFileFilters('all')
- self.assertEqual(filters[nameFilter], self._dummySaveFunction)
- self.assertEqual(list(filters.keys()).index(nameFilter),3)
-
- # Update an existing file filter
- nameFilter = SaveAction.IMAGE_FILTER_EDF
- saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction)
- self.assertEqual(saveAction.getFileFilters('image')[nameFilter],
- self._dummySaveFunction)
-
- # Change the position of an existing file filter
- nameFilter = 'Dummy file2 (*.dummy)'
- oldIndex = list(saveAction.getFileFilters('all')).index(nameFilter)
- newIndex = oldIndex - 1
- saveAction.setFileFilter('all', nameFilter,
- self._dummySaveFunction, index=newIndex)
- filters = saveAction.getFileFilters('all')
- self.assertEqual(filters[nameFilter], self._dummySaveFunction)
- self.assertEqual(list(filters.keys()).index(nameFilter), newIndex)
-
-def suite():
- test_suite = unittest.TestSuite()
- for cls in (TestSaveActionSaveCurvesAsSpec, TestSaveActionExtension):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(cls))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testScatterMaskToolsWidget.py b/silx/gui/plot/test/testScatterMaskToolsWidget.py
deleted file mode 100644
index 800f30e..0000000
--- a/silx/gui/plot/test/testScatterMaskToolsWidget.py
+++ /dev/null
@@ -1,318 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for MaskToolsWidget"""
-
-__authors__ = ["T. Vincent", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import logging
-import os.path
-import unittest
-
-import numpy
-
-from silx.gui import qt
-from silx.test.utils import temp_dir
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import getQToolButtonFromAction
-from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget
-from .utils import PlotWidgetTestCase
-
-import fabio
-
-
-_logger = logging.getLogger(__name__)
-
-
-class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
- """Basic test for MaskToolsWidget"""
-
- def _createPlot(self):
- return PlotWindow()
-
- def setUp(self):
- super(TestScatterMaskToolsWidget, self).setUp()
- self.widget = ScatterMaskToolsWidget.ScatterMaskToolsDockWidget(
- plot=self.plot, name='TEST')
- self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
-
- self.maskWidget = self.widget.widget()
-
- def tearDown(self):
- del self.maskWidget
- del self.widget
- super(TestScatterMaskToolsWidget, self).tearDown()
-
- def testEmptyPlot(self):
- """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
- self.maskWidget.setMultipleMasks('single')
- self.qapp.processEvents()
-
- self.maskWidget.setMultipleMasks('exclusive')
- self.qapp.processEvents()
-
- def _drag(self):
- """Drag from plot center to offset position"""
- plot = self.plot.getWidgetHandle()
- xCenter, yCenter = plot.width() // 2, plot.height() // 2
- offset = min(plot.width(), plot.height()) // 10
-
- pos0 = xCenter, yCenter
- pos1 = xCenter + offset, yCenter + offset
-
- self.mouseMove(plot, pos=(0, 0))
- self.mouseMove(plot, pos=pos0)
- self.qapp.processEvents()
- self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
- self.qapp.processEvents()
-
- self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
- self.mouseMove(plot, pos=pos1)
- self.qapp.processEvents()
- self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
- self.qapp.processEvents()
- self.mouseMove(plot, pos=(0, 0))
-
- def _drawPolygon(self):
- """Draw a star polygon in the plot"""
- plot = self.plot.getWidgetHandle()
- x, y = plot.width() // 2, plot.height() // 2
- offset = min(plot.width(), plot.height()) // 10
-
- star = [(x, y + offset),
- (x - offset, y - offset),
- (x + offset, y),
- (x - offset, y),
- (x + offset, y - offset),
- (x, y + offset)] # Close polygon
-
- self.mouseMove(plot, pos=[0, 0])
- for pos in star:
- self.mouseMove(plot, pos=pos)
- self.qapp.processEvents()
- self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
- self.qapp.processEvents()
- self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
- self.qapp.processEvents()
-
- def _drawPencil(self):
- """Draw a star polygon in the plot"""
- plot = self.plot.getWidgetHandle()
- x, y = plot.width() // 2, plot.height() // 2
- offset = min(plot.width(), plot.height()) // 10
-
- star = [(x, y + offset),
- (x - offset, y - offset),
- (x + offset, y),
- (x - offset, y),
- (x + offset, y - offset)]
-
- self.mouseMove(plot, pos=[0, 0])
- self.mouseMove(plot, pos=star[0])
- self.mousePress(plot, qt.Qt.LeftButton, pos=star[0])
- for pos in star[1:]:
- self.mouseMove(plot, pos=pos)
- self.mouseRelease(
- plot, qt.Qt.LeftButton, pos=star[-1])
-
- def testWithAScatter(self):
- """Plot with a Scatter: test MaskToolsWidget interactions"""
-
- # Add and remove a scatter (this should enable/disable GUI + change mask)
- self.plot.addScatter(
- x=numpy.arange(256),
- y=numpy.arange(256),
- value=numpy.random.random(256),
- legend='test')
- self.plot._setActiveItem(kind="scatter", legend="test")
- self.qapp.processEvents()
-
- self.plot.remove('test', kind='scatter')
- self.qapp.processEvents()
-
- self.plot.addScatter(
- x=numpy.arange(1000),
- y=1000 * (numpy.arange(1000) % 20),
- value=numpy.random.random(1000),
- legend='test')
- self.plot._setActiveItem(kind="scatter", legend="test")
- self.plot.resetZoom()
- self.qapp.processEvents()
-
- # Test draw rectangle #
- toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- # mask
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drag()
-
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # unmask same region
- self.maskWidget.maskStateGroup.button(0).click()
- self.qapp.processEvents()
- self._drag()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # Test draw polygon #
- toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- # mask
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drawPolygon()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # unmask same region
- self.maskWidget.maskStateGroup.button(0).click()
- self.qapp.processEvents()
- self._drawPolygon()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # Test draw pencil #
- toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- self.maskWidget.pencilSpinBox.setValue(30)
- self.qapp.processEvents()
-
- # mask
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drawPencil()
- self.assertFalse(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # unmask same region
- self.maskWidget.maskStateGroup.button(0).click()
- self.qapp.processEvents()
- self._drawPencil()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- # Test no draw tool #
- toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
-
- self.plot.clear()
-
- def __loadSave(self, file_format):
- self.plot.addScatter(
- x=numpy.arange(256),
- y=25 * (numpy.arange(256) % 10),
- value=numpy.random.random(256),
- legend='test')
- self.plot._setActiveItem(kind="scatter", legend="test")
- self.plot.resetZoom()
- self.qapp.processEvents()
-
- # Draw a polygon mask
- toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
- self._drawPolygon()
-
- ref_mask = self.maskWidget.getSelectionMask()
- self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
-
- with temp_dir() as tmp:
- mask_filename = os.path.join(tmp, 'mask.' + file_format)
- self.maskWidget.save(mask_filename, file_format)
-
- self.maskWidget.resetSelectionMask()
- self.assertTrue(
- numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
-
- self.maskWidget.load(mask_filename)
- self.assertTrue(numpy.all(numpy.equal(
- self.maskWidget.getSelectionMask(), ref_mask)))
-
- def testLoadSaveNpy(self):
- self.__loadSave("npy")
-
- def testLoadSaveCsv(self):
- self.__loadSave("csv")
-
- def testSigMaskChangedEmitted(self):
- self.qapp.processEvents()
- self.plot.addScatter(
- x=numpy.arange(1000),
- y=1000 * (numpy.arange(1000) % 20),
- value=numpy.ones((1000,)),
- legend='test')
- self.plot._setActiveItem(kind="scatter", legend="test")
- self.plot.resetZoom()
- self.qapp.processEvents()
-
- self.plot.remove('test', kind='scatter')
- self.qapp.processEvents()
-
- self.plot.addScatter(
- x=numpy.arange(1000),
- y=1000 * (numpy.arange(1000) % 20),
- value=numpy.random.random(1000),
- legend='test')
-
- l = []
-
- def slot():
- l.append(1)
-
- self.maskWidget.sigMaskChanged.connect(slot)
-
- # rectangle mask
- toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
- self.assertIsNot(toolButton, None)
- self.mouseClick(toolButton, qt.Qt.LeftButton)
- self.maskWidget.maskStateGroup.button(1).click()
- self.qapp.processEvents()
- self._drag()
-
- self.assertGreater(len(l), 0)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestClass in (TestScatterMaskToolsWidget,):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testScatterView.py b/silx/gui/plot/test/testScatterView.py
deleted file mode 100644
index 583e3ed..0000000
--- a/silx/gui/plot/test/testScatterView.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for ScatterView"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "06/03/2018"
-
-
-import unittest
-
-import numpy
-
-from silx.gui.plot.items import Axis, Scatter
-from silx.gui.plot import ScatterView
-from silx.gui.plot.test.utils import PlotWidgetTestCase
-
-
-class TestScatterView(PlotWidgetTestCase):
- """Test of ScatterView widget"""
-
- def _createPlot(self):
- return ScatterView()
-
- def test(self):
- """Simple tests"""
- x = numpy.arange(100)
- y = numpy.arange(100)
- value = numpy.arange(100)
- self.plot.setData(x, y, value)
- self.qapp.processEvents()
-
- data = self.plot.getData()
- self.assertEqual(len(data), 5)
- self.assertTrue(numpy.all(numpy.equal(x, data[0])))
- self.assertTrue(numpy.all(numpy.equal(y, data[1])))
- self.assertTrue(numpy.all(numpy.equal(value, data[2])))
- self.assertIsNone(data[3]) # xerror
- self.assertIsNone(data[4]) # yerror
-
- # Test access to scatter item
- self.assertIsInstance(self.plot.getScatterItem(), Scatter)
-
- # Test toolbar actions
-
- action = self.plot.getScatterToolBar().getXAxisLogarithmicAction()
- action.trigger()
- self.qapp.processEvents()
-
- maskAction = self.plot.getScatterToolBar().actions()[-1]
- maskAction.trigger()
- self.qapp.processEvents()
-
- # Test proxy API
-
- self.plot.resetZoom()
- self.qapp.processEvents()
-
- scale = self.plot.getXAxis().getScale()
- self.assertEqual(scale, Axis.LOGARITHMIC)
-
- scale = self.plot.getYAxis().getScale()
- self.assertEqual(scale, Axis.LINEAR)
-
- title = 'Test ScatterView'
- self.plot.setGraphTitle(title)
- self.assertEqual(self.plot.getGraphTitle(), title)
-
- self.qapp.processEvents()
-
- # Reset scatter data
-
- self.plot.setData(None, None, None)
- self.qapp.processEvents()
-
- data = self.plot.getData()
- self.assertEqual(len(data), 5)
- self.assertEqual(len(data[0]), 0) # x
- self.assertEqual(len(data[1]), 0) # y
- self.assertEqual(len(data[2]), 0) # value
- self.assertIsNone(data[3]) # xerror
- self.assertIsNone(data[4]) # yerror
-
- def testAlpha(self):
- """Test alpha transparency in setData"""
- _pts = 100
- _levels = 100
- _fwhm = 50
- x = numpy.random.rand(_pts)*_levels
- y = numpy.random.rand(_pts)*_levels
- value = numpy.random.rand(_pts)*_levels
- x0 = x[int(_pts/2)]
- y0 = x[int(_pts/2)]
- #2D Gaussian kernel
- alpha = numpy.exp(-4*numpy.log(2) * ((x-x0)**2 + (y-y0)**2) / _fwhm**2)
-
- self.plot.setData(x, y, value, alpha=alpha)
- self.qapp.processEvents()
-
- alphaData = self.plot.getScatterItem().getAlphaData()
- self.assertTrue(numpy.all(numpy.equal(alpha, alphaData)))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestScatterView))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testStackView.py b/silx/gui/plot/test/testStackView.py
deleted file mode 100644
index 7605bbc..0000000
--- a/silx/gui/plot/test/testStackView.py
+++ /dev/null
@@ -1,261 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for StackView"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "20/03/2017"
-
-
-import unittest
-import numpy
-
-from silx.gui.utils.testutils import TestCaseQt, SignalListener
-
-from silx.gui import qt
-from silx.gui.plot import StackView
-from silx.gui.plot.StackView import StackViewMainWindow
-
-from silx.utils.array_like import ListOfImages
-
-
-class TestStackView(TestCaseQt):
- """Base class for tests of StackView."""
-
- def setUp(self):
- super(TestStackView, self).setUp()
- self.stackview = StackView()
- self.stackview.show()
- self.qWaitForWindowExposed(self.stackview)
- self.mystack = numpy.fromfunction(
- lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
- (10, 20, 30)
- )
-
- def tearDown(self):
- self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.stackview.close()
- del self.stackview
- super(TestStackView, self).tearDown()
-
- def testScaleColormapRangeToStack(self):
- """Test scaleColormapRangeToStack"""
- self.stackview.setStack(self.mystack)
- self.stackview.setColormap("viridis")
- colormap = self.stackview.getColormap()
-
- # Colormap autoscale to image
- self.assertEqual(colormap.getVRange(), (None, None))
- self.stackview.scaleColormapRangeToStack()
-
- # Colormap range set according to stack range
- self.assertEqual(colormap.getVRange(), (self.mystack.min(), self.mystack.max()))
-
- def testSetStack(self):
- self.stackview.setStack(self.mystack)
- self.stackview.setColormap("viridis", autoscale=True)
- my_trans_stack, params = self.stackview.getStack()
- self.assertEqual(my_trans_stack.shape, self.mystack.shape)
- self.assertTrue(numpy.array_equal(self.mystack,
- my_trans_stack))
- self.assertEqual(params["colormap"]["name"],
- "viridis")
-
- def testSetStackPerspective(self):
- self.stackview.setStack(self.mystack, perspective=1)
- # my_orig_stack, params = self.stackview.getStack()
- my_trans_stack, params = self.stackview.getCurrentView()
-
- # get stack returns the transposed data, depending on the perspective
- self.assertEqual(my_trans_stack.shape,
- (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
- self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
- my_trans_stack))
-
- def testSetStackListOfImages(self):
- loi = [self.mystack[i] for i in range(self.mystack.shape[0])]
-
- self.stackview.setStack(loi)
- my_orig_stack, params = self.stackview.getStack(returnNumpyArray=True)
- my_trans_stack, params = self.stackview.getStack(returnNumpyArray=True)
- self.assertEqual(my_trans_stack.shape, self.mystack.shape)
- self.assertTrue(numpy.array_equal(self.mystack,
- my_trans_stack))
- self.assertTrue(numpy.array_equal(self.mystack,
- my_orig_stack))
- self.assertIsInstance(my_trans_stack, numpy.ndarray)
-
- self.stackview.setStack(loi, perspective=2)
- my_orig_stack, params = self.stackview.getStack(copy=False)
- my_trans_stack, params = self.stackview.getCurrentView(copy=False)
- # getStack(copy=False) must return the object set in setStack
- self.assertIs(my_orig_stack, loi)
- # getCurrentView(copy=False) returns a ListOfImages whose .images
- # attr is the original data
- self.assertEqual(my_trans_stack.shape,
- (self.mystack.shape[2], self.mystack.shape[0], self.mystack.shape[1]))
- self.assertTrue(numpy.array_equal(numpy.array(my_trans_stack),
- numpy.transpose(self.mystack, axes=(2, 0, 1))))
- self.assertIsInstance(my_trans_stack,
- ListOfImages) # returnNumpyArray=False by default in getStack
- self.assertIs(my_trans_stack.images, loi)
-
- def testPerspective(self):
- self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)))
- self.assertEqual(self.stackview._perspective, 0,
- "Default perspective is not 0 (dim1-dim2).")
-
- self.stackview._StackView__planeSelection.setPerspective(1)
- self.assertEqual(self.stackview._perspective, 1,
- "Plane selection combobox not updating perspective")
-
- self.stackview.setStack(numpy.arange(6).reshape((1, 2, 3)))
- self.assertEqual(self.stackview._perspective, 1,
- "Perspective not preserved when calling setStack "
- "without specifying the perspective parameter.")
-
- self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)), perspective=2)
- self.assertEqual(self.stackview._perspective, 2,
- "Perspective not set in setStack(..., perspective=2).")
-
- def testDefaultTitle(self):
- """Test that the plot title contains the proper Z information"""
- self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)),
- calibrations=[(0, 1), (-10, 10), (3.14, 3.14)])
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=0")
- self.stackview.setFrameNumber(2)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=2")
-
- self.stackview._StackView__planeSelection.setPerspective(1)
- self.stackview.setFrameNumber(0)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=-10")
- self.stackview.setFrameNumber(2)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=10")
-
- self.stackview._StackView__planeSelection.setPerspective(2)
- self.stackview.setFrameNumber(0)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=3.14")
- self.stackview.setFrameNumber(1)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Image z=6.28")
-
- def testCustomTitle(self):
- """Test setting the plot title with a user defined callback"""
- self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)),
- calibrations=[(0, 1), (-10, 10), (3.14, 3.14)])
-
- def title_callback(frame_idx):
- return "Cubed index title %d" % (frame_idx**3)
-
- self.stackview.setTitleCallback(title_callback)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Cubed index title 0")
- self.stackview.setFrameNumber(2)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Cubed index title 8")
-
- # perspective should not matter, only frame index
- self.stackview._StackView__planeSelection.setPerspective(1)
- self.stackview.setFrameNumber(0)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Cubed index title 0")
- self.stackview.setFrameNumber(2)
- self.assertEqual(self.stackview._plot.getGraphTitle(),
- "Cubed index title 8")
-
- with self.assertRaises(TypeError):
- # setTitleCallback should not accept non-callable objects like strings
- self.stackview.setTitleCallback(
- "Là, vous faites sirop de vingt-et-un et vous dites : "
- "beau sirop, mi-sirop, siroté, gagne-sirop, sirop-grelot,"
- " passe-montagne, sirop au bon goût.")
-
- def testStackFrameNumber(self):
- self.stackview.setStack(self.mystack)
- self.assertEqual(self.stackview.getFrameNumber(), 0)
-
- listener = SignalListener()
- self.stackview.sigFrameChanged.connect(listener)
-
- self.stackview.setFrameNumber(1)
- self.assertEqual(self.stackview.getFrameNumber(), 1)
- self.assertEqual(listener.arguments(), [(1,)])
-
-
-class TestStackViewMainWindow(TestCaseQt):
- """Base class for tests of StackView."""
-
- def setUp(self):
- super(TestStackViewMainWindow, self).setUp()
- self.stackview = StackViewMainWindow()
- self.stackview.show()
- self.qWaitForWindowExposed(self.stackview)
- self.mystack = numpy.fromfunction(
- lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
- (10, 20, 30)
- )
-
- def tearDown(self):
- self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.stackview.close()
- del self.stackview
- super(TestStackViewMainWindow, self).tearDown()
-
- def testSetStack(self):
- self.stackview.setStack(self.mystack)
- self.stackview.setColormap("viridis", autoscale=True)
- my_trans_stack, params = self.stackview.getStack()
- self.assertEqual(my_trans_stack.shape, self.mystack.shape)
- self.assertTrue(numpy.array_equal(self.mystack,
- my_trans_stack))
- self.assertEqual(params["colormap"]["name"],
- "viridis")
-
- def testSetStackPerspective(self):
- self.stackview.setStack(self.mystack, perspective=1)
- my_trans_stack, params = self.stackview.getCurrentView()
- # get stack returns the transposed data, depending on the perspective
- self.assertEqual(my_trans_stack.shape,
- (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
- self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
- my_trans_stack))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestStackView))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestStackViewMainWindow))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py
deleted file mode 100644
index d5046ba..0000000
--- a/silx/gui/plot/test/testStats.py
+++ /dev/null
@@ -1,1058 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for CurvesROIWidget"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "07/03/2018"
-
-
-from silx.gui import qt
-from silx.gui.plot.stats import stats
-from silx.gui.plot import StatsWidget
-from silx.gui.plot.stats import statshandler
-from silx.gui.utils.testutils import TestCaseQt, SignalListener
-from silx.gui.plot import Plot1D, Plot2D
-from silx.gui.plot3d.SceneWidget import SceneWidget
-from silx.gui.plot.items.roi import RectangleROI, PolygonROI
-from silx.gui.plot.tools.roi import RegionOfInterestManager
-from silx.gui.plot.stats.stats import Stats
-from silx.gui.plot.CurvesROIWidget import ROI
-from silx.utils.testutils import ParametricTestCase
-import unittest
-import logging
-import numpy
-
-_logger = logging.getLogger(__name__)
-
-
-class TestStatsBase(object):
- """Base class for stats TestCase"""
- def setUp(self):
- self.createCurveContext()
- self.createImageContext()
- self.createScatterContext()
-
- def tearDown(self):
- self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot1d.close()
- del self.plot1d
- self.plot2d.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot2d.close()
- del self.plot2d
- self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.scatterPlot.close()
- del self.scatterPlot
-
- def createCurveContext(self):
- self.plot1d = Plot1D()
- x = range(20)
- y = range(20)
- self.plot1d.addCurve(x, y, legend='curve0')
-
- self.curveContext = stats._CurveContext(
- item=self.plot1d.getCurve('curve0'),
- plot=self.plot1d,
- onlimits=False,
- roi=None)
-
- def createScatterContext(self):
- self.scatterPlot = Plot2D()
- lgd = 'scatter plot'
- self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36])
- self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18])
- self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5])
- self.scatterPlot.addScatter(self.xScatterData, self.yScatterData,
- self.valuesScatterData, legend=lgd)
- self.scatterContext = stats._ScatterContext(
- item=self.scatterPlot.getScatter(lgd),
- plot=self.scatterPlot,
- onlimits=False,
- roi=None
- )
-
- def createImageContext(self):
- self.plot2d = Plot2D()
- self._imgLgd = 'test image'
- self.imageData = numpy.arange(32*128).reshape(32, 128)
- self.plot2d.addImage(data=self.imageData,
- legend=self._imgLgd, replace=False)
- self.imageContext = stats._ImageContext(
- item=self.plot2d.getImage(self._imgLgd),
- plot=self.plot2d,
- onlimits=False,
- roi=None
- )
-
- def getBasicStats(self):
- return {
- 'min': stats.StatMin(),
- 'minCoords': stats.StatCoordMin(),
- 'max': stats.StatMax(),
- 'maxCoords': stats.StatCoordMax(),
- 'std': stats.Stat(name='std', fct=numpy.std),
- 'mean': stats.Stat(name='mean', fct=numpy.mean),
- 'com': stats.StatCOM()
- }
-
-
-class TestStats(TestStatsBase, TestCaseQt):
- """
- Test :class:`BaseClass` class and inheriting classes
- """
- def setUp(self):
- TestCaseQt.setUp(self)
- TestStatsBase.setUp(self)
-
- def tearDown(self):
- TestStatsBase.tearDown(self)
- TestCaseQt.tearDown(self)
-
- def testBasicStatsCurve(self):
- """Test result for simple stats on a curve"""
- _stats = self.getBasicStats()
- xData = yData = numpy.array(range(20))
- self.assertEqual(_stats['min'].calculate(self.curveContext), 0)
- self.assertEqual(_stats['max'].calculate(self.curveContext), 19)
- self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,))
- self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,))
- self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData))
- self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData))
- com = numpy.sum(xData * yData) / numpy.sum(yData)
- self.assertEqual(_stats['com'].calculate(self.curveContext), com)
-
- def testBasicStatsImage(self):
- """Test result for simple stats on an image"""
- _stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.imageContext), 0)
- self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1)
- self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0))
- self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31))
- self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData))
- self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData))
-
- yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1)
- xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0)
- dataXRange = range(self.imageData.shape[1])
- dataYRange = range(self.imageData.shape[0])
-
- ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
- xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
-
- self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
-
- def testStatsImageAdv(self):
- """Test that scale and origin are taking into account for images"""
-
- image2Data = numpy.arange(32 * 128).reshape(32, 128)
- self.plot2d.addImage(data=image2Data, legend=self._imgLgd,
- replace=True, origin=(100, 10), scale=(2, 0.5))
- image2Context = stats._ImageContext(
- item=self.plot2d.getImage(self._imgLgd),
- plot=self.plot2d,
- onlimits=False,
- roi=None,
- )
- _stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(image2Context), 0)
- self.assertEqual(
- _stats['max'].calculate(image2Context), 128 * 32 - 1)
- self.assertEqual(
- _stats['minCoords'].calculate(image2Context), (100, 10))
- self.assertEqual(
- _stats['maxCoords'].calculate(image2Context), (127*2. + 100,
- 31 * 0.5 + 10))
- self.assertEqual(_stats['std'].calculate(image2Context),
- numpy.std(self.imageData))
- self.assertEqual(_stats['mean'].calculate(image2Context),
- numpy.mean(self.imageData))
-
- yData = numpy.sum(self.imageData, axis=1)
- xData = numpy.sum(self.imageData, axis=0)
- dataXRange = numpy.arange(self.imageData.shape[1], dtype=numpy.float64)
- dataYRange = numpy.arange(self.imageData.shape[0], dtype=numpy.float64)
-
- ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
- ycom = (ycom * 0.5) + 10
- xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
- xcom = (xcom * 2.) + 100
- self.assertTrue(numpy.allclose(
- _stats['com'].calculate(image2Context), (xcom, ycom)))
-
- def testBasicStatsScatter(self):
- """Test result for simple stats on a scatter"""
- _stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.scatterContext), 5)
- self.assertEqual(_stats['max'].calculate(self.scatterContext), 90)
- self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2))
- self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69))
- self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData))
- self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData))
-
- data = self.valuesScatterData.astype(numpy.float64)
- comx = numpy.sum(self.xScatterData * data) / numpy.sum(data)
- comy = numpy.sum(self.yScatterData * data) / numpy.sum(data)
- self.assertEqual(_stats['com'].calculate(self.scatterContext),
- (comx, comy))
-
- def testKindNotManagedByStat(self):
- """Make sure an exception is raised if we try to execute calculate
- of the base class"""
- b = stats.StatBase(name='toto', compatibleKinds='curve')
- with self.assertRaises(NotImplementedError):
- b.calculate(self.imageContext)
-
- def testKindNotManagedByContext(self):
- """
- Make sure an error is raised if we try to calculate a statistic with
- a context not managed
- """
- myStat = stats.Stat(name='toto', fct=numpy.std, kinds=('curve'))
- myStat.calculate(self.curveContext)
- with self.assertRaises(ValueError):
- myStat.calculate(self.scatterContext)
- with self.assertRaises(ValueError):
- myStat.calculate(self.imageContext)
-
- def testOnLimits(self):
- stat = stats.StatMin()
-
- self.plot1d.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
- curveContextOnLimits = stats._CurveContext(
- item=self.plot1d.getCurve('curve0'),
- plot=self.plot1d,
- onlimits=True,
- roi=None)
- self.assertEqual(stat.calculate(curveContextOnLimits), 2)
-
- self.plot2d.getXAxis().setLimitsConstraints(minPos=32)
- imageContextOnLimits = stats._ImageContext(
- item=self.plot2d.getImage('test image'),
- plot=self.plot2d,
- onlimits=True,
- roi=None)
- self.assertEqual(stat.calculate(imageContextOnLimits), 32)
-
- self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40)
- scatterContextOnLimits = stats._ScatterContext(
- item=self.scatterPlot.getScatter('scatter plot'),
- plot=self.scatterPlot,
- onlimits=True,
- roi=None)
- self.assertEqual(stat.calculate(scatterContextOnLimits), 20)
-
-
-class TestStatsFormatter(TestCaseQt):
- """Simple test to check usage of the :class:`StatsFormatter`"""
- def setUp(self):
- TestCaseQt.setUp(self)
- self.plot1d = Plot1D()
- x = range(20)
- y = range(20)
- self.plot1d.addCurve(x, y, legend='curve0')
-
- self.curveContext = stats._CurveContext(
- item=self.plot1d.getCurve('curve0'),
- plot=self.plot1d,
- onlimits=False,
- roi=None)
-
- self.stat = stats.StatMin()
-
- def tearDown(self):
- self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot1d.close()
- del self.plot1d
- TestCaseQt.tearDown(self)
-
- def testEmptyFormatter(self):
- """Make sure a formatter with no formatter definition will return a
- simple cast to str"""
- emptyFormatter = statshandler.StatFormatter()
- self.assertEqual(
- emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000')
-
- def testSettedFormatter(self):
- """Make sure a formatter with no formatter definition will return a
- simple cast to str"""
- formatter= statshandler.StatFormatter(formatter='{0:.3f}')
- self.assertEqual(
- formatter.format(self.stat.calculate(self.curveContext)), '0.000')
-
-
-class TestStatsHandler(TestCaseQt):
- """Make sure the StatHandler is correctly making the link between
- :class:`StatBase` and :class:`StatFormatter` and checking the API is valid
- """
- def setUp(self):
- TestCaseQt.setUp(self)
- self.plot1d = Plot1D()
- x = range(20)
- y = range(20)
- self.plot1d.addCurve(x, y, legend='curve0')
- self.curveItem = self.plot1d.getCurve('curve0')
-
- self.stat = stats.StatMin()
-
- def tearDown(self):
- Stats._getContext.cache_clear()
- self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot1d.close()
- self.plot1d = None
- TestCaseQt.tearDown(self)
-
- def testConstructor(self):
- """Make sure the constructor can deal will all possible arguments:
-
- * tuple of :class:`StatBase` derivated classes
- * tuple of tuples (:class:`StatBase`, :class:`StatFormatter`)
- * tuple of tuples (str, pointer to function, kind)
- """
- handler0 = statshandler.StatsHandler(
- (stats.StatMin(), stats.StatMax())
- )
-
- res = handler0.calculate(item=self.curveItem, plot=self.plot1d,
- onlimits=False)
- self.assertTrue('min' in res)
- self.assertEqual(res['min'], '0')
- self.assertTrue('max' in res)
- self.assertEqual(res['max'], '19')
-
- handler1 = statshandler.StatsHandler(
- (
- (stats.StatMin(), statshandler.StatFormatter(formatter=None)),
- (stats.StatMax(), statshandler.StatFormatter())
- )
- )
-
- res = handler1.calculate(item=self.curveItem, plot=self.plot1d,
- onlimits=False)
- self.assertTrue('min' in res)
- self.assertEqual(res['min'], '0')
- self.assertTrue('max' in res)
- self.assertEqual(res['max'], '19.000')
-
- handler2 = statshandler.StatsHandler(
- (
- (stats.StatMin(), None),
- (stats.StatMax(), statshandler.StatFormatter())
- ))
-
- res = handler2.calculate(item=self.curveItem, plot=self.plot1d,
- onlimits=False)
- self.assertTrue('min' in res)
- self.assertEqual(res['min'], '0')
- self.assertTrue('max' in res)
- self.assertEqual(res['max'], '19.000')
-
- handler3 = statshandler.StatsHandler((
- (('amin', numpy.argmin), statshandler.StatFormatter()),
- ('amax', numpy.argmax)
- ))
-
- res = handler3.calculate(item=self.curveItem, plot=self.plot1d,
- onlimits=False)
- self.assertTrue('amin' in res)
- self.assertEqual(res['amin'], '0.000')
- self.assertTrue('amax' in res)
- self.assertEqual(res['amax'], '19')
-
- with self.assertRaises(ValueError):
- statshandler.StatsHandler(('name'))
-
-
-class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
- """Basic test for StatsWidget with curves"""
- def setUp(self):
- TestCaseQt.setUp(self)
- self.plot = Plot1D()
- self.plot.show()
- x = range(20)
- y = range(20)
- self.plot.addCurve(x, y, legend='curve0')
- y = range(12, 32)
- self.plot.addCurve(x, y, legend='curve1')
- y = range(-2, 18)
- self.plot.addCurve(x, y, legend='curve2')
- self.widget = StatsWidget.StatsWidget(plot=self.plot)
- self.statsTable = self.widget._statsTable
-
- mystats = statshandler.StatsHandler((
- stats.StatMin(),
- (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- stats.StatMax(),
- (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- stats.StatDelta(),
- ('std', numpy.std),
- ('mean', numpy.mean),
- stats.StatCOM()
- ))
-
- self.statsTable.setStats(mystats)
-
- def tearDown(self):
- Stats._getContext.cache_clear()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- self.statsTable = None
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- self.widget = None
- self.plot = None
- TestCaseQt.tearDown(self)
-
- def testDisplayActiveItemsSyncOptions(self):
- """
- Test that the several option of the sync options are well
- synchronized between the different object"""
- widget = StatsWidget.StatsWidget(plot=self.plot)
- table = StatsWidget.StatsTable(plot=self.plot)
-
- def check_display_only_active_item(only_active):
- # check internal value
- self.assertIs(widget._statsTable._displayOnlyActItem, only_active)
- # self.assertTrue(table._displayOnlyActItem is only_active)
- # check gui display
- self.assertEqual(widget._options.isActiveItemMode(), only_active)
-
- for displayOnlyActiveItems in (True, False):
- with self.subTest(displayOnlyActiveItems=displayOnlyActiveItems):
- widget.setDisplayOnlyActiveItem(displayOnlyActiveItems)
- # table.setDisplayOnlyActiveItem(displayOnlyActiveItems)
- check_display_only_active_item(displayOnlyActiveItems)
-
- check_display_only_active_item(only_active=False)
- widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- table.setAttribute(qt.Qt.WA_DeleteOnClose)
- widget.close()
- table.close()
-
- def testInit(self):
- """Make sure all the curves are registred on initialization"""
- self.assertEqual(self.statsTable.rowCount(), 3)
-
- def testRemoveCurve(self):
- """Make sure the Curves stats take into account the curve removal from
- plot"""
- self.plot.removeCurve('curve2')
- self.assertEqual(self.statsTable.rowCount(), 2)
- for iRow in range(2):
- self.assertTrue(self.statsTable.item(iRow, 0).text() in ('curve0', 'curve1'))
-
- self.plot.removeCurve('curve0')
- self.assertEqual(self.statsTable.rowCount(), 1)
- self.plot.removeCurve('curve1')
- self.assertEqual(self.statsTable.rowCount(), 0)
-
- def testAddCurve(self):
- """Make sure the Curves stats take into account the add curve action"""
- self.plot.addCurve(legend='curve3', x=range(10), y=range(10))
- self.assertEqual(self.statsTable.rowCount(), 4)
-
- def testUpdateCurveFromAddCurve(self):
- """Make sure the stats of the cuve will be removed after updating a
- curve"""
- self.plot.addCurve(legend='curve0', x=range(10), y=range(10))
- self.qapp.processEvents()
- self.assertEqual(self.statsTable.rowCount(), 3)
- curve = self.plot._getItem(kind='curve', legend='curve0')
- tableItems = self.statsTable._itemToTableItems(curve)
- self.assertEqual(tableItems['max'].text(), '9')
-
- def testUpdateCurveFromCurveObj(self):
- self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
- self.qapp.processEvents()
- self.assertEqual(self.statsTable.rowCount(), 3)
- curve = self.plot._getItem(kind='curve', legend='curve0')
- tableItems = self.statsTable._itemToTableItems(curve)
- self.assertEqual(tableItems['max'].text(), '3')
-
- def testSetAnotherPlot(self):
- plot2 = Plot1D()
- plot2.addCurve(x=range(26), y=range(26), legend='new curve')
- self.statsTable.setPlot(plot2)
- self.assertEqual(self.statsTable.rowCount(), 1)
- self.qapp.processEvents()
- plot2.setAttribute(qt.Qt.WA_DeleteOnClose)
- plot2.close()
- plot2 = None
-
- def testUpdateMode(self):
- """Make sure the update modes are well take into account"""
- self.plot.setActiveCurve('curve0')
- for display_only_active in (True, False):
- with self.subTest(display_only_active=display_only_active):
- self.widget.setDisplayOnlyActiveItem(display_only_active)
- self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
- self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
- update_stats_action = self.widget._options.getUpdateStatsAction()
- # test from api
- self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO)
- self.widget.show()
- # check stats change in auto mode
- self.plot.getCurve('curve0').setData(x=range(4), y=range(-1, 3))
- self.qapp.processEvents()
- tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
- curve0_min = tableItems['min'].text()
- self.assertTrue(float(curve0_min) == -1.)
-
- self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5))
- self.qapp.processEvents()
- tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
- curve0_min = tableItems['min'].text()
- self.assertTrue(float(curve0_min) == 1.)
-
- # check stats change in manual mode only if requested
- self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
- self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL)
-
- self.plot.getCurve('curve0').setData(x=range(4), y=range(2, 6))
- self.qapp.processEvents()
- tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
- curve0_min = tableItems['min'].text()
- self.assertTrue(float(curve0_min) == 1.)
-
- update_stats_action.trigger()
- tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
- curve0_min = tableItems['min'].text()
- self.assertTrue(float(curve0_min) == 2.)
-
- def testItemHidden(self):
- """Test if an item is hide, then the associated stats item is also
- hide"""
- curve0 = self.plot.getCurve('curve0')
- curve1 = self.plot.getCurve('curve1')
- curve2 = self.plot.getCurve('curve2')
-
- self.plot.show()
- self.widget.show()
- self.qWaitForWindowExposed(self.widget)
- self.assertFalse(self.statsTable.isRowHidden(0))
- self.assertFalse(self.statsTable.isRowHidden(1))
- self.assertFalse(self.statsTable.isRowHidden(2))
-
- curve0.setVisible(False)
- self.qapp.processEvents()
- self.assertTrue(self.statsTable.isRowHidden(0))
- curve0.setVisible(True)
- self.qapp.processEvents()
- self.assertFalse(self.statsTable.isRowHidden(0))
- curve1.setVisible(False)
- self.qapp.processEvents()
- self.assertTrue(self.statsTable.isRowHidden(1))
- tableItems = self.statsTable._itemToTableItems(curve2)
- curve2_min = tableItems['min'].text()
- self.assertTrue(float(curve2_min) == -2.)
-
- curve0.setVisible(False)
- curve1.setVisible(False)
- curve2.setVisible(False)
- self.qapp.processEvents()
- self.assertTrue(self.statsTable.isRowHidden(0))
- self.assertTrue(self.statsTable.isRowHidden(1))
- self.assertTrue(self.statsTable.isRowHidden(2))
-
-
-class TestStatsWidgetWithImages(TestCaseQt):
- """Basic test for StatsWidget with images"""
-
- IMAGE_LEGEND = 'test image'
-
- def setUp(self):
- TestCaseQt.setUp(self)
- self.plot = Plot2D()
-
- self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128),
- legend=self.IMAGE_LEGEND, replace=False)
-
- self.widget = StatsWidget.StatsTable(plot=self.plot)
-
- mystats = statshandler.StatsHandler((
- (stats.StatMin(), statshandler.StatFormatter()),
- (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- (stats.StatMax(), statshandler.StatFormatter()),
- (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- (stats.StatDelta(), statshandler.StatFormatter()),
- ('std', numpy.std),
- ('mean', numpy.mean),
- (stats.StatCOM(), statshandler.StatFormatter(None))
- ))
-
- self.widget.setStats(mystats)
-
- def tearDown(self):
- Stats._getContext.cache_clear()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- self.widget = None
- self.plot = None
- TestCaseQt.tearDown(self)
-
- def test(self):
- image = self.plot._getItem(
- kind='image', legend=self.IMAGE_LEGEND)
- tableItems = self.widget._itemToTableItems(image)
-
- maxText = '{0:.3f}'.format((128 * 128) - 1)
- self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND)
- self.assertEqual(tableItems['min'].text(), '0.000')
- self.assertEqual(tableItems['max'].text(), maxText)
- self.assertEqual(tableItems['delta'].text(), maxText)
- self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0')
- self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0')
-
- def testItemHidden(self):
- """Test if an item is hide, then the associated stats item is also
- hide"""
- self.widget.show()
- self.plot.show()
- self.qWaitForWindowExposed(self.widget)
- self.assertFalse(self.widget.isRowHidden(0))
- self.plot.getImage(self.IMAGE_LEGEND).setVisible(False)
- self.qapp.processEvents()
- self.assertTrue(self.widget.isRowHidden(0))
-
-
-class TestStatsWidgetWithScatters(TestCaseQt):
-
- SCATTER_LEGEND = 'scatter plot'
-
- def setUp(self):
- TestCaseQt.setUp(self)
- self.scatterPlot = Plot2D()
- self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60],
- [2, 3, 4, 26, 69, 6],
- [5, 6, 7, 10, 90, 20],
- legend=self.SCATTER_LEGEND)
- self.widget = StatsWidget.StatsTable(plot=self.scatterPlot)
-
- mystats = statshandler.StatsHandler((
- stats.StatMin(),
- (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- stats.StatMax(),
- (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
- stats.StatDelta(),
- ('std', numpy.std),
- ('mean', numpy.mean),
- stats.StatCOM()
- ))
-
- self.widget.setStats(mystats)
-
- def tearDown(self):
- Stats._getContext.cache_clear()
- self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.scatterPlot.close()
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- self.widget = None
- self.scatterPlot = None
- TestCaseQt.tearDown(self)
-
- def testStats(self):
- scatter = self.scatterPlot._getItem(
- kind='scatter', legend=self.SCATTER_LEGEND)
- tableItems = self.widget._itemToTableItems(scatter)
- self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND)
- self.assertEqual(tableItems['min'].text(), '5')
- self.assertEqual(tableItems['coords min'].text(), '0, 2')
- self.assertEqual(tableItems['max'].text(), '90')
- self.assertEqual(tableItems['coords max'].text(), '50, 69')
- self.assertEqual(tableItems['delta'].text(), '85')
-
-
-class TestEmptyStatsWidget(TestCaseQt):
- def test(self):
- widget = StatsWidget.StatsWidget()
- widget.show()
- self.qWaitForWindowExposed(widget)
-
-
-# skip unit test for pyqt4 because there is some unrealised widget without
-# apparent reason
-@unittest.skipIf(qt.qVersion().split('.')[0] == '4', reason='PyQt4 not tested')
-class TestLineWidget(TestCaseQt):
- """Some test for the StatsLineWidget."""
- def setUp(self):
- TestCaseQt.setUp(self)
-
- mystats = statshandler.StatsHandler((
- (stats.StatMin(), statshandler.StatFormatter()),
- ))
-
- self.plot = Plot1D()
- self.plot.show()
- self.x = range(20)
- self.y0 = range(20)
- self.curve0 = self.plot.addCurve(self.x, self.y0, legend='curve0')
- self.y1 = range(12, 32)
- self.plot.addCurve(self.x, self.y1, legend='curve1')
- self.y2 = range(-2, 18)
- self.plot.addCurve(self.x, self.y2, legend='curve2')
- self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot,
- kind='curve',
- stats=mystats)
-
- def tearDown(self):
- Stats._getContext.cache_clear()
- self.qapp.processEvents()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- self.widget.setPlot(None)
- self.widget._lineStatsWidget._statQlineEdit.clear()
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- self.widget = None
- self.plot = None
- TestCaseQt.tearDown(self)
-
- def testProcessing(self):
- self.widget._lineStatsWidget.setStatsOnVisibleData(False)
- self.qapp.processEvents()
- self.plot.setActiveCurve(legend='curve0')
- self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.000')
- self.plot.setActiveCurve(legend='curve1')
- self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '12.000')
- self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
- self.widget.setStatsOnVisibleData(True)
- self.qapp.processEvents()
- self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000')
- self.plot.setActiveCurve(None)
- self.assertIsNone(self.plot.getActiveCurve())
- self.widget.setStatsOnVisibleData(False)
- self.qapp.processEvents()
- self.assertFalse(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000')
- self.widget.setKind('image')
- self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312)
- self.qapp.processEvents()
- self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.312')
-
- def testUpdateMode(self):
- """Make sure the update modes are well take into account"""
- self.plot.setActiveCurve(self.curve0)
- _autoRB = self.widget._options._autoRB
- _manualRB = self.widget._options._manualRB
- # test from api
- self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
- self.assertTrue(_autoRB.isChecked())
- self.assertFalse(_manualRB.isChecked())
-
- # check stats change in auto mode
- curve0_min = self.widget._lineStatsWidget._statQlineEdit['min'].text()
- new_y = numpy.array(self.y0) - 2.56
- self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0)
- curve0_min2 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
- self.assertTrue(curve0_min != curve0_min2)
-
- # check stats change in manual mode only if requested
- self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
- self.assertFalse(_autoRB.isChecked())
- self.assertTrue(_manualRB.isChecked())
-
- new_y = numpy.array(self.y0) - 1.2
- self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0)
- curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
- self.assertTrue(curve0_min3 == curve0_min2)
- self.widget._options._updateRequested()
- curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
- self.assertTrue(curve0_min3 != curve0_min2)
-
- # test from gui
- self.widget.showRadioButtons(True)
- self.widget._options._autoRB.toggle()
- self.assertTrue(_autoRB.isChecked())
- self.assertFalse(_manualRB.isChecked())
-
- self.widget._options._manualRB.toggle()
- self.assertFalse(_autoRB.isChecked())
- self.assertTrue(_manualRB.isChecked())
-
-
-class TestUpdateModeWidget(TestCaseQt):
- """Test UpdateModeWidget"""
- def setUp(self):
- TestCaseQt.setUp(self)
- self.widget = StatsWidget.UpdateModeWidget(parent=None)
-
- def tearDown(self):
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- self.widget = None
- TestCaseQt.tearDown(self)
-
- def testSignals(self):
- """Test the signal emission of the widget"""
- self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
- modeChangedListener = SignalListener()
- manualUpdateListener = SignalListener()
- self.widget.sigUpdateModeChanged.connect(modeChangedListener)
- self.widget.sigUpdateRequested.connect(manualUpdateListener)
- self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
- self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO)
- self.assertEqual(modeChangedListener.callCount(), 0)
- self.qapp.processEvents()
-
- self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
- self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL)
- self.qapp.processEvents()
- self.assertEqual(modeChangedListener.callCount(), 1)
- self.assertEqual(manualUpdateListener.callCount(), 0)
- self.widget._updatePB.click()
- self.widget._updatePB.click()
- self.assertEqual(manualUpdateListener.callCount(), 2)
-
- self.widget._autoRB.setChecked(True)
- self.assertEqual(modeChangedListener.callCount(), 2)
- self.widget._updatePB.click()
- self.assertEqual(manualUpdateListener.callCount(), 2)
-
-
-class TestStatsROI(TestStatsBase, TestCaseQt):
- """
- Test stats based on ROI
- """
- def setUp(self):
- TestCaseQt.setUp(self)
- self.createRois()
- TestStatsBase.setUp(self)
- self.createHistogramContext()
-
- self.roiManager = RegionOfInterestManager(self.plot2d)
- self.roiManager.addRoi(self._2Droi_rect)
- self.roiManager.addRoi(self._2Droi_poly)
-
- def tearDown(self):
- self.roiManager.clear()
- self.roiManager = None
- self._1Droi = None
- self._2Droi_rect = None
- self._2Droi_poly = None
- self.plotHisto.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plotHisto.close()
- self.plotHisto = None
- TestStatsBase.tearDown(self)
- TestCaseQt.tearDown(self)
-
- def createRois(self):
- self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0)
- self._2Droi_rect = RectangleROI()
- self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0))
- self._2Droi_poly = PolygonROI()
- points = numpy.array(((0, 20), (0, 0), (10, 0)))
- self._2Droi_poly.setPoints(points=points)
-
- def createCurveContext(self):
- TestStatsBase.createCurveContext(self)
- self.curveContext = stats._CurveContext(
- item=self.plot1d.getCurve('curve0'),
- plot=self.plot1d,
- onlimits=False,
- roi=self._1Droi)
-
- def createHistogramContext(self):
- self.plotHisto = Plot1D()
- x = range(20)
- y = range(20)
- self.plotHisto.addHistogram(x, y, legend='histo0')
-
- self.histoContext = stats._HistogramContext(
- item=self.plotHisto.getHistogram('histo0'),
- plot=self.plotHisto,
- onlimits=False,
- roi=self._1Droi)
-
- def createScatterContext(self):
- TestStatsBase.createScatterContext(self)
- self.scatterContext = stats._ScatterContext(
- item=self.scatterPlot.getScatter('scatter plot'),
- plot=self.scatterPlot,
- onlimits=False,
- roi=self._1Droi
- )
-
- def createImageContext(self):
- TestStatsBase.createImageContext(self)
-
- self.imageContext = stats._ImageContext(
- item=self.plot2d.getImage(self._imgLgd),
- plot=self.plot2d,
- onlimits=False,
- roi=self._2Droi_rect
- )
-
- self.imageContext_2 = stats._ImageContext(
- item=self.plot2d.getImage(self._imgLgd),
- plot=self.plot2d,
- onlimits=False,
- roi=self._2Droi_poly
- )
-
- def testErrors(self):
- # test if onlimits is True and give also a roi
- with self.assertRaises(ValueError):
- stats._CurveContext(item=self.plot1d.getCurve('curve0'),
- plot=self.plot1d,
- onlimits=True,
- roi=self._1Droi)
-
- # test if is a curve context and give an invalid 2D roi
- with self.assertRaises(TypeError):
- stats._CurveContext(item=self.plot1d.getCurve('curve0'),
- plot=self.plot1d,
- onlimits=False,
- roi=self._2Droi_rect)
-
- def testBasicStatsCurve(self):
- """Test result for simple stats on a curve"""
- _stats = self.getBasicStats()
- xData = yData = numpy.array(range(0, 10))
- self.assertEqual(_stats['min'].calculate(self.curveContext), 2)
- self.assertEqual(_stats['max'].calculate(self.curveContext), 5)
- self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,))
- self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,))
- self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6]))
- self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6]))
- com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6])
- self.assertEqual(_stats['com'].calculate(self.curveContext), com)
-
- def testBasicStatsImageRectRoi(self):
- """Test result for simple stats on an image"""
- self.assertEqual(self.imageContext.values.compressed().size, 121)
- _stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.imageContext), 10)
- self.assertEqual(_stats['max'].calculate(self.imageContext), 1300)
- self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0))
- self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0))
- self.assertAlmostEqual(_stats['std'].calculate(self.imageContext),
- numpy.std(self.imageData[0:11, 10:21]))
- self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext),
- numpy.mean(self.imageData[0:11, 10:21]))
-
- compressed_values = self.imageContext.values.compressed()
- compressed_values = compressed_values.reshape(11, 11)
- yData = numpy.sum(compressed_values.astype(numpy.float64), axis=1)
- xData = numpy.sum(compressed_values.astype(numpy.float64), axis=0)
-
- dataYRange = range(11)
- dataXRange = range(10, 21)
-
- ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
- xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
- self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
-
- def testBasicStatsImagePolyRoi(self):
- """Test a simple rectangle ROI"""
- _stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0)
- self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432)
- self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0))
- # not 0.0, 19.0 because not fully in. Should all pixel have a weight,
- # on to manage them in stats. For now 0 if the center is not in, else 1
- self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0))
-
- def testBasicStatsScatter(self):
- self.assertEqual(self.scatterContext.values.compressed().size, 2)
- _stats = self.getBasicStats()
- self.assertEqual(_stats['min'].calculate(self.scatterContext), 6)
- self.assertEqual(_stats['max'].calculate(self.scatterContext), 7)
- self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3))
- self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4))
- self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7]))
- self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7]))
-
- def testBasicHistogram(self):
- _stats = self.getBasicStats()
- xData = yData = numpy.array(range(2, 6))
- self.assertEqual(_stats['min'].calculate(self.histoContext), 2)
- self.assertEqual(_stats['max'].calculate(self.histoContext), 5)
- self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,))
- self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,))
- self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData))
- self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData))
- com = numpy.sum(xData * yData) / numpy.sum(yData)
- self.assertEqual(_stats['com'].calculate(self.histoContext), com)
-
-
-class TestAdvancedROIImageContext(TestCaseQt):
- """Test stats result on an image context with different scale and
- origins"""
-
- def setUp(self):
- TestCaseQt.setUp(self)
- self.data_dims = (100, 100)
- self.data = numpy.random.rand(*self.data_dims)
- self.plot = Plot2D()
-
- def test(self):
- """Test stats result on an image context with different scale and
- origins"""
- roi_origins = [(0, 0), (2, 10), (14, 20)]
- img_origins = [(0, 0), (14, 20), (2, 10)]
- img_scales = [1.0, 0.5, 2.0]
- _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), }
- for roi_origin in roi_origins:
- for img_origin in img_origins:
- for img_scale in img_scales:
- with self.subTest(roi_origin=roi_origin,
- img_origin=img_origin,
- img_scale=img_scale):
- self.plot.addImage(self.data, legend='img',
- origin=img_origin,
- scale=img_scale)
- roi = RectangleROI()
- roi.setGeometry(origin=roi_origin, size=(20, 20))
- context = stats._ImageContext(
- item=self.plot.getImage('img'),
- plot=self.plot,
- onlimits=False,
- roi=roi)
- x_start = int((roi_origin[0] - img_origin[0]) / img_scale)
- x_end = int(x_start + (20 / img_scale)) + 1
- y_start = int((roi_origin[1] - img_origin[1])/ img_scale)
- y_end = int(y_start + (20 / img_scale)) + 1
- x_start = max(x_start, 0)
- x_end = min(max(x_end, 0), self.data_dims[1])
- y_start = max(y_start, 0)
- y_end = min(max(y_end, 0), self.data_dims[0])
- th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end])
- self.assertAlmostEqual(_stats['sum'].calculate(context),
- th_sum)
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters,
- TestStatsWidgetWithImages, TestStatsWidgetWithCurves,
- TestStatsFormatter, TestEmptyStatsWidget, TestStatsROI,
- TestLineWidget, TestUpdateModeWidget, ):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/testUtilsAxis.py b/silx/gui/plot/test/testUtilsAxis.py
deleted file mode 100644
index 64373b8..0000000
--- a/silx/gui/plot/test/testUtilsAxis.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for PlotWidget"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "20/11/2018"
-
-
-import unittest
-from silx.gui.plot import PlotWidget
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.plot.utils.axis import SyncAxes
-
-
-class TestAxisSync(TestCaseQt):
- """Tests AxisSync class"""
-
- def setUp(self):
- TestCaseQt.setUp(self)
- self.plot1 = PlotWidget()
- self.plot2 = PlotWidget()
- self.plot3 = PlotWidget()
-
- def tearDown(self):
- self.plot1 = None
- self.plot2 = None
- self.plot3 = None
- TestCaseQt.tearDown(self)
-
- def testMoveFirstAxis(self):
- """Test synchronization after construction"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
-
- self.plot1.getXAxis().setLimits(10, 500)
- self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
-
- def testMoveSecondAxis(self):
- """Test synchronization after construction"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
-
- self.plot2.getXAxis().setLimits(10, 500)
- self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
-
- def testMoveTwoAxes(self):
- """Test synchronization after construction"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
-
- self.plot1.getXAxis().setLimits(1, 50)
- self.plot2.getXAxis().setLimits(10, 500)
- self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
-
- def testDestruction(self):
- """Test synchronization when sync object is destroyed"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
- del sync
-
- self.plot1.getXAxis().setLimits(10, 500)
- self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
- self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500))
- self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
-
- def testAxisDestruction(self):
- """Test synchronization when an axis disappear"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
-
- # Destroy the plot is possible
- import weakref
- plot = weakref.ref(self.plot2)
- self.plot2 = None
- result = self.qWaitForDestroy(plot)
- if not result:
- # We can't test
- self.skipTest("Object not destroyed")
-
- self.plot1.getXAxis().setLimits(10, 500)
- self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
-
- def testStop(self):
- """Test synchronization after calling stop"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
- sync.stop()
-
- self.plot1.getXAxis().setLimits(10, 500)
- self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
- self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500))
- self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
-
- def testStopMovingStart(self):
- """Test synchronization after calling stop, moving an axis, then start again"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
- sync.stop()
- self.plot1.getXAxis().setLimits(10, 500)
- self.plot2.getXAxis().setLimits(1, 50)
- self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
- sync.start()
-
- # The first axis is the reference
- self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
-
- def testDoubleStop(self):
- """Test double stop"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
- sync.stop()
- self.assertRaises(RuntimeError, sync.stop)
-
- def testDoubleStart(self):
- """Test double stop"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
- self.assertRaises(RuntimeError, sync.start)
-
- def testScale(self):
- """Test scale change"""
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
- self.plot1.getXAxis().setScale(self.plot1.getXAxis().LOGARITHMIC)
- self.assertEqual(self.plot1.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
- self.assertEqual(self.plot2.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
- self.assertEqual(self.plot3.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
-
- def testDirection(self):
- """Test direction change"""
- _sync = SyncAxes([self.plot1.getYAxis(), self.plot2.getYAxis(), self.plot3.getYAxis()])
- self.plot1.getYAxis().setInverted(True)
- self.assertEqual(self.plot1.getYAxis().isInverted(), True)
- self.assertEqual(self.plot2.getYAxis().isInverted(), True)
- self.assertEqual(self.plot3.getYAxis().isInverted(), True)
-
- def testSyncCenter(self):
- """Test direction change"""
- # Not the same scale
- self.plot1.getXAxis().setLimits(0, 200)
- self.plot2.getXAxis().setLimits(0, 20)
- self.plot3.getXAxis().setLimits(0, 2)
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
- syncLimits=False, syncCenter=True)
-
- self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
- self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10))
- self.assertEqual(self.plot3.getXAxis().getLimits(), (100 - 1, 100 + 1))
-
- def testSyncCenterAndZoom(self):
- """Test direction change"""
- # Not the same scale
- self.plot1.getXAxis().setLimits(0, 200)
- self.plot2.getXAxis().setLimits(0, 20)
- self.plot3.getXAxis().setLimits(0, 2)
- _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
- syncLimits=False, syncCenter=True, syncZoom=True)
-
- # Supposing all the plots use the same size
- self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
- self.assertEqual(self.plot2.getXAxis().getLimits(), (0, 200))
- self.assertEqual(self.plot3.getXAxis().getLimits(), (0, 200))
-
- def testAddAxis(self):
- """Test synchronization after construction"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis()])
- sync.addAxis(self.plot3.getXAxis())
-
- self.plot1.getXAxis().setLimits(10, 500)
- self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
-
- def testRemoveAxis(self):
- """Test synchronization after construction"""
- sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
- sync.removeAxis(self.plot3.getXAxis())
-
- self.plot1.getXAxis().setLimits(10, 500)
- self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
- self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
- self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestAxisSync))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/test/utils.py b/silx/gui/plot/test/utils.py
deleted file mode 100644
index ed1917a..0000000
--- a/silx/gui/plot/test/utils.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for PlotWidget"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "26/01/2018"
-
-
-import logging
-
-from silx.gui.utils.testutils import TestCaseQt
-
-from silx.gui import qt
-from silx.gui.plot import PlotWidget
-
-
-logger = logging.getLogger(__name__)
-
-
-class PlotWidgetTestCase(TestCaseQt):
- """Base class for tests of PlotWidget, not a TestCase in itself.
-
- plot attribute is the PlotWidget created for the test.
- """
-
- __screenshot_already_taken = False
-
- def __init__(self, methodName='runTest', backend=None):
- TestCaseQt.__init__(self, methodName=methodName)
- self.__backend = backend
-
- def _createPlot(self):
- return PlotWidget(backend=self.__backend)
-
- def setUp(self):
- super(PlotWidgetTestCase, self).setUp()
- self.plot = self._createPlot()
- self.plot.show()
- self.plotAlive = True
- self.qWaitForWindowExposed(self.plot)
- TestCaseQt.mouseClick(self, self.plot, button=qt.Qt.LeftButton, pos=(0, 0))
-
- def __onPlotDestroyed(self):
- self.plotAlive = False
-
- def _waitForPlotClosed(self):
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.destroyed.connect(self.__onPlotDestroyed)
- self.plot.close()
- del self.plot
- for _ in range(100):
- if not self.plotAlive:
- break
- self.qWait(10)
- else:
- logger.error("Plot is still alive")
-
- def tearDown(self):
- if not self._currentTestSucceeded():
- # MPL is the only widget which uses the real system mouse.
- # In case of a the windows is outside of the screen, minimzed,
- # overlapped by a system popup, the MPL widget will not receive the
- # mouse event.
- # Taking a screenshot help debuging this cases in the continuous
- # integration environement.
- if not PlotWidgetTestCase.__screenshot_already_taken:
- PlotWidgetTestCase.__screenshot_already_taken = True
- self.logScreenShot()
- self.qapp.processEvents()
- self._waitForPlotClosed()
- super(PlotWidgetTestCase, self).tearDown()
diff --git a/silx/gui/plot/tools/PositionInfo.py b/silx/gui/plot/tools/PositionInfo.py
deleted file mode 100644
index 81d312a..0000000
--- a/silx/gui/plot/tools/PositionInfo.py
+++ /dev/null
@@ -1,376 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a widget displaying mouse coordinates in a PlotWidget.
-
-It can be configured to provide more information.
-"""
-
-from __future__ import division
-
-__authors__ = ["V.A. Sole", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "16/10/2017"
-
-
-import logging
-import numbers
-import traceback
-import weakref
-
-import numpy
-
-from ....utils.deprecation import deprecated
-from ... import qt
-from .. import items
-from ...widgets.ElidedLabel import ElidedLabel
-
-
-_logger = logging.getLogger(__name__)
-
-
-class _PositionInfoLabel(ElidedLabel):
- """QLabel with a default size larger than what is displayed."""
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
-
- def sizeHint(self):
- hint = super().sizeHint()
- width = self.fontMetrics().boundingRect('##############').width()
- return qt.QSize(max(hint.width(), width), hint.height())
-
-
-# PositionInfo ################################################################
-
-class PositionInfo(qt.QWidget):
- """QWidget displaying coords converted from data coords of the mouse.
-
- Provide this widget with a list of couple:
-
- - A name to display before the data
- - A function that takes (x, y) as arguments and returns something that
- gets converted to a string.
- If the result is a float it is converted with '%.7g' format.
-
- To run the following sample code, a QApplication must be initialized.
- First, create a PlotWindow and add a QToolBar where to place the
- PositionInfo widget.
-
- >>> from silx.gui.plot import PlotWindow
- >>> from silx.gui import qt
-
- >>> plot = PlotWindow() # Create a PlotWindow to add the widget to
- >>> toolBar = qt.QToolBar() # Create a toolbar to place the widget in
- >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) # Add it to plot
-
- Then, create the PositionInfo widget and add it to the toolbar.
- The PositionInfo widget is created with a list of converters, here
- to display polar coordinates of the mouse position.
-
- >>> import numpy
- >>> from silx.gui.plot.tools import PositionInfo
-
- >>> position = PositionInfo(plot=plot, converters=[
- ... ('Radius', lambda x, y: numpy.sqrt(x*x + y*y)),
- ... ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))])
- >>> toolBar.addWidget(position) # Add the widget to the toolbar
- <...>
- >>> plot.show() # To display the PlotWindow with the position widget
-
- :param plot: The PlotWidget this widget is displaying data coords from.
- :param converters:
- List of 2-tuple: name to display and conversion function from (x, y)
- in data coords to displayed value.
- If None, the default, it displays X and Y.
- :param parent: Parent widget
- """
-
- SNAP_THRESHOLD_DIST = 5
-
- def __init__(self, parent=None, plot=None, converters=None):
- assert plot is not None
- self._plotRef = weakref.ref(plot)
- self._snappingMode = self.SNAPPING_DISABLED
-
- super(PositionInfo, self).__init__(parent)
-
- if converters is None:
- converters = (('X', lambda x, y: x), ('Y', lambda x, y: y))
-
- self._fields = [] # To store (QLineEdit, name, function (x, y)->v)
-
- # Create a new layout with new widgets
- layout = qt.QHBoxLayout()
- layout.setContentsMargins(0, 0, 0, 0)
- # layout.setSpacing(0)
-
- # Create all QLabel and store them with the corresponding converter
- for name, func in converters:
- layout.addWidget(qt.QLabel('<b>' + name + ':</b>'))
-
- contentWidget = _PositionInfoLabel(self)
- contentWidget.setText('------')
- layout.addWidget(contentWidget)
- self._fields.append((contentWidget, name, func))
-
- layout.addStretch(1)
- self.setLayout(layout)
-
- # Connect to Plot events
- plot.sigPlotSignal.connect(self._plotEvent)
-
- def getPlotWidget(self):
- """Returns the PlotWidget this widget is attached to or None.
-
- :rtype: Union[~silx.gui.plot.PlotWidget,None]
- """
- return self._plotRef()
-
- @property
- @deprecated(replacement='getPlotWidget', since_version='0.8.0')
- def plot(self):
- return self.getPlotWidget()
-
- def getConverters(self):
- """Return the list of converters as 2-tuple (name, function)."""
- return [(name, func) for _label, name, func in self._fields]
-
- def _plotEvent(self, event):
- """Handle events from the Plot.
-
- :param dict event: Plot event
- """
- if event['event'] == 'mouseMoved':
- x, y = event['x'], event['y']
- xPixel, yPixel = event['xpixel'], event['ypixel']
- self._updateStatusBar(x, y, xPixel, yPixel)
-
- def updateInfo(self):
- """Update displayed information"""
- plot = self.getPlotWidget()
- if plot is None:
- _logger.error("Trying to update PositionInfo "
- "while PlotWidget no longer exists")
- return
-
- widget = plot.getWidgetHandle()
- position = widget.mapFromGlobal(qt.QCursor.pos())
- xPixel, yPixel = position.x(), position.y()
- dataPos = plot.pixelToData(xPixel, yPixel, check=True)
- if dataPos is not None: # Inside plot area
- x, y = dataPos
- self._updateStatusBar(x, y, xPixel, yPixel)
-
- def _updateStatusBar(self, x, y, xPixel, yPixel):
- """Update information from the status bar using the definitions.
-
- :param float x: Position-x in data
- :param float y: Position-y in data
- :param float xPixel: Position-x in pixels
- :param float yPixel: Position-y in pixels
- """
- plot = self.getPlotWidget()
- if plot is None:
- return
-
- styleSheet = "color: rgb(0, 0, 0);" # Default style
- xData, yData = x, y
-
- snappingMode = self.getSnappingMode()
-
- # Snapping when crosshair either not requested or active
- if (snappingMode & (self.SNAPPING_CURVE | self.SNAPPING_SCATTER) and
- (not (snappingMode & self.SNAPPING_CROSSHAIR) or
- plot.getGraphCursor())):
- styleSheet = "color: rgb(255, 0, 0);" # Style far from item
-
- if snappingMode & self.SNAPPING_ACTIVE_ONLY:
- selectedItems = []
-
- if snappingMode & self.SNAPPING_CURVE:
- activeCurve = plot.getActiveCurve()
- if activeCurve:
- selectedItems.append(activeCurve)
-
- if snappingMode & self.SNAPPING_SCATTER:
- activeScatter = plot._getActiveItem(kind='scatter')
- if activeScatter:
- selectedItems.append(activeScatter)
-
- else:
- kinds = []
- if snappingMode & self.SNAPPING_CURVE:
- kinds.append(items.Curve)
- kinds.append(items.Histogram)
- if snappingMode & self.SNAPPING_SCATTER:
- kinds.append(items.Scatter)
- selectedItems = [item for item in plot.getItems()
- if isinstance(item, tuple(kinds)) and item.isVisible()]
-
- # Compute distance threshold
- if qt.BINDING in ('PyQt5', 'PySide2'):
- window = plot.window()
- windowHandle = window.windowHandle()
- if windowHandle is not None:
- ratio = windowHandle.devicePixelRatio()
- else:
- ratio = qt.QGuiApplication.primaryScreen().devicePixelRatio()
- else:
- ratio = 1.
-
- # Baseline squared distance threshold
- distInPixels = (self.SNAP_THRESHOLD_DIST * ratio)**2
-
- for item in selectedItems:
- if (snappingMode & self.SNAPPING_SYMBOLS_ONLY and (
- not isinstance(item, items.SymbolMixIn) or
- not item.getSymbol())):
- # Only handled if item symbols are visible
- continue
-
- if isinstance(item, items.Histogram):
- result = item.pick(xPixel, yPixel)
- if result is not None: # Histogram picked
- index = result.getIndices()[0]
- edges = item.getBinEdgesData(copy=False)
-
- # Snap to bin center and value
- xData = 0.5 * (edges[index] + edges[index + 1])
- yData = item.getValueData(copy=False)[index]
-
- # Update label style sheet
- styleSheet = "color: rgb(0, 0, 0);"
- break
-
- else: # Curve, Scatter
- xArray = item.getXData(copy=False)
- yArray = item.getYData(copy=False)
- closestIndex = numpy.argmin(
- pow(xArray - x, 2) + pow(yArray - y, 2))
-
- xClosest = xArray[closestIndex]
- yClosest = yArray[closestIndex]
-
- if isinstance(item, items.YAxisMixIn):
- axis = item.getYAxis()
- else:
- axis = 'left'
-
- closestInPixels = plot.dataToPixel(
- xClosest, yClosest, axis=axis)
- if closestInPixels is not None:
- curveDistInPixels = (
- (closestInPixels[0] - xPixel)**2 +
- (closestInPixels[1] - yPixel)**2)
-
- if curveDistInPixels <= distInPixels:
- # Update label style sheet
- styleSheet = "color: rgb(0, 0, 0);"
-
- # if close enough, snap to data point coord
- xData, yData = xClosest, yClosest
- distInPixels = curveDistInPixels
-
- for label, name, func in self._fields:
- label.setStyleSheet(styleSheet)
-
- try:
- value = func(xData, yData)
- text = self.valueToString(value)
- label.setText(text)
- except:
- label.setText('Error')
- _logger.error(
- "Error while converting coordinates (%f, %f)"
- "with converter '%s'" % (xPixel, yPixel, name))
- _logger.error(traceback.format_exc())
-
- def valueToString(self, value):
- if isinstance(value, (tuple, list)):
- value = [self.valueToString(v) for v in value]
- return ", ".join(value)
- elif isinstance(value, numbers.Real):
- # Use this for floats and int
- return '%.7g' % value
- else:
- # Fallback for other types
- return str(value)
-
- # Snapping mode
-
- SNAPPING_DISABLED = 0
- """No snapping occurs"""
-
- SNAPPING_CROSSHAIR = 1 << 0
- """Snapping only enabled when crosshair cursor is enabled"""
-
- SNAPPING_ACTIVE_ONLY = 1 << 1
- """Snapping only enabled for active item"""
-
- SNAPPING_SYMBOLS_ONLY = 1 << 2
- """Snapping only when symbols are visible"""
-
- SNAPPING_CURVE = 1 << 3
- """Snapping on curves"""
-
- SNAPPING_SCATTER = 1 << 4
- """Snapping on scatter"""
-
- def setSnappingMode(self, mode):
- """Set the snapping mode.
-
- The mode is a mask.
-
- :param int mode: The mode to use
- """
- if mode != self._snappingMode:
- self._snappingMode = mode
- self.updateInfo()
-
- def getSnappingMode(self):
- """Returns the snapping mode as a mask
-
- :rtype: int
- """
- return self._snappingMode
-
- _SNAPPING_LEGACY = (SNAPPING_CROSSHAIR |
- SNAPPING_ACTIVE_ONLY |
- SNAPPING_SYMBOLS_ONLY |
- SNAPPING_CURVE |
- SNAPPING_SCATTER)
- """Legacy snapping mode"""
-
- @property
- @deprecated(replacement="getSnappingMode", since_version="0.8")
- def autoSnapToActiveCurve(self):
- return self.getSnappingMode() == self._SNAPPING_LEGACY
-
- @autoSnapToActiveCurve.setter
- @deprecated(replacement="setSnappingMode", since_version="0.8")
- def autoSnapToActiveCurve(self, flag):
- self.setSnappingMode(
- self._SNAPPING_LEGACY if flag else self.SNAPPING_DISABLED)
diff --git a/silx/gui/plot/tools/profile/manager.py b/silx/gui/plot/tools/profile/manager.py
deleted file mode 100644
index 68db9a6..0000000
--- a/silx/gui/plot/tools/profile/manager.py
+++ /dev/null
@@ -1,1076 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a manager to compute and display profiles.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "28/06/2018"
-
-import logging
-import weakref
-
-from silx.gui import qt
-from silx.gui import colors
-from silx.gui import utils
-
-from silx.utils.weakref import WeakMethodProxy
-from silx.gui import icons
-from silx.gui.plot import PlotWidget
-from silx.gui.plot.tools.roi import RegionOfInterestManager
-from silx.gui.plot.tools.roi import CreateRoiModeAction
-from silx.gui.plot import items
-from silx.gui.qt import silxGlobalThreadPool
-from silx.gui.qt import inspect
-from . import rois
-from . import core
-from . import editors
-
-
-_logger = logging.getLogger(__name__)
-
-
-class _RunnableComputeProfile(qt.QRunnable):
- """Runner to process profiles
-
- :param qt.QThreadPool threadPool: The thread which will be used to
- execute this runner. It is used to update the used signals
- :param ~silx.gui.plot.items.Item item: Item in which the profile is
- computed
- :param ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn roi: ROI
- defining the profile shape and other characteristics
- """
-
- class _Signals(qt.QObject):
- """Signal holder"""
- resultReady = qt.Signal(object, object)
- runnerFinished = qt.Signal(object)
-
- def __init__(self, threadPool, item, roi):
- """Constructor
- """
- super(_RunnableComputeProfile, self).__init__()
- self._signals = self._Signals()
- self._signals.moveToThread(threadPool.thread())
- self._item = item
- self._roi = roi
- self._cancelled = False
-
- def _lazyCancel(self):
- """Cancel the runner if it is not yet started.
-
- The threadpool will still execute the runner, but this will process
- nothing.
-
- This is only used with Qt<5.9 where QThreadPool.tryTake is not available.
- """
- self._cancelled = True
-
- def autoDelete(self):
- return False
-
- def getRoi(self):
- """Returns the ROI in which the runner will compute a profile.
-
- :rtype: ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn
- """
- return self._roi
-
- @property
- def resultReady(self):
- """Signal emitted when the result of the computation is available.
-
- This signal provides 2 values: The ROI, and the computation result.
- """
- return self._signals.resultReady
-
- @property
- def runnerFinished(self):
- """Signal emitted when runner have finished.
-
- This signal provides a single value: the runner itself.
- """
- return self._signals.runnerFinished
-
- def run(self):
- """Process the profile computation.
- """
- if not self._cancelled:
- try:
- profileData = self._roi.computeProfile(self._item)
- except Exception:
- _logger.error("Error while computing profile", exc_info=True)
- else:
- self.resultReady.emit(self._roi, profileData)
- self.runnerFinished.emit(self)
-
-
-class ProfileWindow(qt.QMainWindow):
- """
- Display a computed profile.
-
- The content can be described using :meth:`setRoiProfile` if the source of
- the profile is a profile ROI, and :meth:`setProfile` for the data content.
- """
-
- sigClose = qt.Signal()
- """Emitted by :meth:`closeEvent` (e.g. when the window is closed
- through the window manager's close icon)."""
-
- def __init__(self, parent=None, backend=None):
- qt.QMainWindow.__init__(self, parent=parent, flags=qt.Qt.Dialog)
-
- self.setWindowTitle('Profile window')
- self._plot1D = None
- self._plot2D = None
- self._backend = backend
- self._data = None
-
- widget = qt.QWidget()
- self._layout = qt.QStackedLayout(widget)
- self._layout.setContentsMargins(0, 0, 0, 0)
- self.setCentralWidget(widget)
-
- def prepareWidget(self, roi):
- """Called before the show to prepare the window to use with
- a specific ROI."""
- if isinstance(roi, rois._DefaultImageStackProfileRoiMixIn):
- profileType = roi.getProfileType()
- else:
- profileType = "1D"
- if profileType == "1D":
- self.getPlot1D()
- elif profileType == "2D":
- self.getPlot2D()
-
- def createPlot1D(self, parent, backend):
- """Inherit this function to create your own plot to render 1D
- profiles. The default value is a `Plot1D`.
-
- :param parent: The parent of this widget or None.
- :param backend: The backend to use for the plot.
- See :class:`PlotWidget` for the list of supported backend.
- :rtype: PlotWidget
- """
- # import here to avoid circular import
- from ...PlotWindow import Plot1D
- plot = Plot1D(parent=parent, backend=backend)
- plot.setDataMargins(yMinMargin=0.1, yMaxMargin=0.1)
- plot.setGraphYLabel('Profile')
- plot.setGraphXLabel('')
- return plot
-
- def createPlot2D(self, parent, backend):
- """Inherit this function to create your own plot to render 2D
- profiles. The default value is a `Plot2D`.
-
- :param parent: The parent of this widget or None.
- :param backend: The backend to use for the plot.
- See :class:`PlotWidget` for the list of supported backend.
- :rtype: PlotWidget
- """
- # import here to avoid circular import
- from ...PlotWindow import Plot2D
- return Plot2D(parent=parent, backend=backend)
-
- def getPlot1D(self, init=True):
- """Return the current plot used to display curves and create it if it
- does not yet exists and `init` is True. Else returns None."""
- if not init:
- return self._plot1D
- if self._plot1D is None:
- self._plot1D = self.createPlot1D(self, self._backend)
- self._layout.addWidget(self._plot1D)
- return self._plot1D
-
- def _showPlot1D(self):
- plot = self.getPlot1D()
- self._layout.setCurrentWidget(plot)
-
- def getPlot2D(self, init=True):
- """Return the current plot used to display image and create it if it
- does not yet exists and `init` is True. Else returns None."""
- if not init:
- return self._plot2D
- if self._plot2D is None:
- self._plot2D = self.createPlot2D(parent=self, backend=self._backend)
- self._layout.addWidget(self._plot2D)
- return self._plot2D
-
- def _showPlot2D(self):
- plot = self.getPlot2D()
- self._layout.setCurrentWidget(plot)
-
- def getCurrentPlotWidget(self):
- return self._layout.currentWidget()
-
- def closeEvent(self, qCloseEvent):
- self.sigClose.emit()
- qCloseEvent.accept()
-
- def setRoiProfile(self, roi):
- """Set the profile ROI which it the source of the following data
- to display.
-
- :param ProfileRoiMixIn roi: The profile ROI data source
- """
- if roi is None:
- return
- self.__color = colors.rgba(roi.getColor())
-
- def _setImageProfile(self, data):
- """
- Setup the window to display a new profile data which is represented
- by an image.
-
- :param core.ImageProfileData data: Computed data profile
- """
- plot = self.getPlot2D()
-
- plot.clear()
- plot.setGraphTitle(data.title)
- plot.getXAxis().setLabel(data.xLabel)
-
-
- coords = data.coords
- colormap = data.colormap
- profileScale = (coords[-1] - coords[0]) / data.profile.shape[1], 1
- plot.addImage(data.profile,
- legend="profile",
- colormap=colormap,
- origin=(coords[0], 0),
- scale=profileScale)
- plot.getYAxis().setLabel("Frame index (depth)")
-
- self._showPlot2D()
-
- def _setCurveProfile(self, data):
- """
- Setup the window to display a new profile data which is represented
- by a curve.
-
- :param core.CurveProfileData data: Computed data profile
- """
- plot = self.getPlot1D()
-
- plot.clear()
- plot.setGraphTitle(data.title)
- plot.getXAxis().setLabel(data.xLabel)
- plot.getYAxis().setLabel(data.yLabel)
-
- plot.addCurve(data.coords,
- data.profile,
- legend="level",
- color=self.__color)
-
- self._showPlot1D()
-
- def _setRgbaProfile(self, data):
- """
- Setup the window to display a new profile data which is represented
- by a curve.
-
- :param core.RgbaProfileData data: Computed data profile
- """
- plot = self.getPlot1D()
-
- plot.clear()
- plot.setGraphTitle(data.title)
- plot.getXAxis().setLabel(data.xLabel)
- plot.getYAxis().setLabel(data.yLabel)
-
- self._showPlot1D()
-
- plot.addCurve(data.coords, data.profile,
- legend="level", color="black")
- plot.addCurve(data.coords, data.profile_r,
- legend="red", color="red")
- plot.addCurve(data.coords, data.profile_g,
- legend="green", color="green")
- plot.addCurve(data.coords, data.profile_b,
- legend="blue", color="blue")
- if data.profile_a is not None:
- plot.addCurve(data.coords, data.profile_a, legend="alpha", color="gray")
-
- def clear(self):
- """Clear the window profile"""
- plot = self.getPlot1D(init=False)
- if plot is not None:
- plot.clear()
- plot = self.getPlot2D(init=False)
- if plot is not None:
- plot.clear()
-
- def getProfile(self):
- """Returns the profile data which is displayed"""
- return self.__data
-
- def setProfile(self, data):
- """
- Setup the window to display a new profile data.
-
- This method dispatch the result to a specific method according to the
- data type.
-
- :param data: Computed data profile
- """
- self.__data = data
- if data is None:
- self.clear()
- elif isinstance(data, core.ImageProfileData):
- self._setImageProfile(data)
- elif isinstance(data, core.RgbaProfileData):
- self._setRgbaProfile(data)
- elif isinstance(data, core.CurveProfileData):
- self._setCurveProfile(data)
- else:
- raise TypeError("Unsupported type %s" % type(data))
-
-
-class _ClearAction(qt.QAction):
- """Action to clear the profile manager
-
- The action is only enabled if something can be cleaned up.
- """
-
- def __init__(self, parent, profileManager):
- super(_ClearAction, self).__init__(parent)
- self.__profileManager = weakref.ref(profileManager)
- icon = icons.getQIcon('profile-clear')
- self.setIcon(icon)
- self.setText('Clear profile')
- self.setToolTip('Clear the profiles')
- self.setCheckable(False)
- self.setEnabled(False)
- self.triggered.connect(profileManager.clearProfile)
- plot = profileManager.getPlotWidget()
- roiManager = profileManager.getRoiManager()
- plot.sigInteractiveModeChanged.connect(self.__modeUpdated)
- roiManager.sigRoiChanged.connect(self.__roiListUpdated)
-
- def getProfileManager(self):
- return self.__profileManager()
-
- def __roiListUpdated(self):
- self.__update()
-
- def __modeUpdated(self, source):
- self.__update()
-
- def __update(self):
- profileManager = self.getProfileManager()
- if profileManager is None:
- return
- roiManager = profileManager.getRoiManager()
- if roiManager is None:
- return
- enabled = roiManager.isStarted() or len(roiManager.getRois()) > 0
- self.setEnabled(enabled)
-
-
-class _StoreLastParamBehavior(qt.QObject):
- """This object allow to store and restore the properties of the ROI
- profiles"""
-
- def __init__(self, parent):
- assert isinstance(parent, ProfileManager)
- super(_StoreLastParamBehavior, self).__init__(parent=parent)
- self.__properties = {}
- self.__profileRoi = None
- self.__filter = utils.LockReentrant()
-
- def _roi(self):
- """Return the spied ROI"""
- if self.__profileRoi is None:
- return None
- roi = self.__profileRoi()
- if roi is None:
- self.__profileRoi = None
- return roi
-
- def setProfileRoi(self, roi):
- """Set a profile ROI to spy.
-
- :param ProfileRoiMixIn roi: A profile ROI
- """
- previousRoi = self._roi()
- if previousRoi is roi:
- return
- if previousRoi is not None:
- previousRoi.sigProfilePropertyChanged.disconnect(self._profilePropertyChanged)
- self.__profileRoi = None if roi is None else weakref.ref(roi)
- if roi is not None:
- roi.sigProfilePropertyChanged.connect(self._profilePropertyChanged)
-
- def _profilePropertyChanged(self):
- """Handle changes on the properties defining the profile ROI.
- """
- if self.__filter.locked():
- return
- roi = self.sender()
- self.storeProperties(roi)
-
- def storeProperties(self, roi):
- if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
- rois.ProfileImageStackCrossROI)):
- self.__properties["method"] = roi.getProfileMethod()
- self.__properties["line-width"] = roi.getProfileLineWidth()
- self.__properties["type"] = roi.getProfileType()
- elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
- rois.ProfileImageCrossROI)):
- self.__properties["method"] = roi.getProfileMethod()
- self.__properties["line-width"] = roi.getProfileLineWidth()
- elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
- rois.ProfileScatterCrossROI)):
- self.__properties["npoints"] = roi.getNPoints()
-
- def restoreProperties(self, roi):
- with self.__filter:
- if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
- rois.ProfileImageStackCrossROI)):
- value = self.__properties.get("method", None)
- if value is not None:
- roi.setProfileMethod(value)
- value = self.__properties.get("line-width", None)
- if value is not None:
- roi.setProfileLineWidth(value)
- value = self.__properties.get("type", None)
- if value is not None:
- roi.setProfileType(value)
- elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
- rois.ProfileImageCrossROI)):
- value = self.__properties.get("method", None)
- if value is not None:
- roi.setProfileMethod(value)
- value = self.__properties.get("line-width", None)
- if value is not None:
- roi.setProfileLineWidth(value)
- elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
- rois.ProfileScatterCrossROI)):
- value = self.__properties.get("npoints", None)
- if value is not None:
- roi.setNPoints(value)
-
-
-class ProfileManager(qt.QObject):
- """Base class for profile management tools
-
- :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate.
- :param plot: :class:`~silx.gui.plot.tools.roi.RegionOfInterestManager`
- on which to operate.
- """
- def __init__(self, parent=None, plot=None, roiManager=None):
- super(ProfileManager, self).__init__(parent)
-
- assert isinstance(plot, PlotWidget)
- self._plotRef = weakref.ref(
- plot, WeakMethodProxy(self.__plotDestroyed))
-
- # Set-up interaction manager
- if roiManager is None:
- roiManager = RegionOfInterestManager(plot)
-
- self._roiManagerRef = weakref.ref(roiManager)
- self._rois = []
- self._pendingRunners = []
- """List of ROIs which have to be updated"""
-
- self.__reentrantResults = {}
- """Store reentrant result to avoid to skip some of them
- cause the implementation uses a QEventLoop."""
-
- self._profileWindowClass = ProfileWindow
- """Class used to display the profile results"""
-
- self._computedProfiles = 0
- """Statistics for tests"""
-
- self.__itemTypes = []
- """Kind of items to use"""
-
- self.__tracking = False
- """Is the plot active items are tracked"""
-
- self.__useColorFromCursor = True
- """If true, force the ROI color with the colormap marker color"""
-
- self._item = None
- """The selected item"""
-
- self.__singleProfileAtATime = True
- """When it's true, only a single profile is displayed at a time."""
-
- self._previousWindowGeometry = []
-
- self._storeProperties = _StoreLastParamBehavior(self)
- """If defined the profile properties of the last ROI are reused to the
- new created ones"""
-
- # Listen to plot limits changed
- plot.getXAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile)
- plot.getYAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile)
-
- roiManager.sigInteractiveModeFinished.connect(self.__interactionFinished)
- roiManager.sigInteractiveRoiCreated.connect(self.__roiCreated)
- roiManager.sigRoiAdded.connect(self.__roiAdded)
- roiManager.sigRoiAboutToBeRemoved.connect(self.__roiRemoved)
-
- def setSingleProfile(self, enable):
- """
- Enable or disable the single profile mode.
-
- In single mode, the manager enforce a single ROI at the same
- time. A new one will remove the previous one.
-
- If this mode is not enabled, many ROIs can be created, and many
- profile windows will be displayed.
- """
- self.__singleProfileAtATime = enable
-
- def isSingleProfile(self):
- """
- Returns true if the manager is in a single profile mode.
-
- :rtype: bool
- """
- return self.__singleProfileAtATime
-
- def __interactionFinished(self):
- """Handle end of interactive mode"""
- pass
-
- def __roiAdded(self, roi):
- """Handle new ROI"""
- # Filter out non profile ROIs
- if not isinstance(roi, core.ProfileRoiMixIn):
- return
- self.__addProfile(roi)
-
- def __roiRemoved(self, roi):
- """Handle removed ROI"""
- # Filter out non profile ROIs
- if not isinstance(roi, core.ProfileRoiMixIn):
- return
- self.__removeProfile(roi)
-
- def createProfileAction(self, profileRoiClass, parent=None):
- """Create an action from a class of ProfileRoi
-
- :param core.ProfileRoiMixIn profileRoiClass: A class of a profile ROI
- :param qt.QObject parent: The parent of the created action.
- :rtype: qt.QAction
- """
- if not issubclass(profileRoiClass, core.ProfileRoiMixIn):
- raise TypeError("Type %s not expected" % type(profileRoiClass))
- roiManager = self.getRoiManager()
- action = CreateRoiModeAction(parent, roiManager, profileRoiClass)
- if hasattr(profileRoiClass, "ICON"):
- action.setIcon(icons.getQIcon(profileRoiClass.ICON))
- if hasattr(profileRoiClass, "NAME"):
- def articulify(word):
- """Add an an/a article in the front of the word"""
- first = word[1] if word[0] == 'h' else word[0]
- if first in "aeiou":
- return "an " + word
- return "a " + word
- action.setText('Define %s' % articulify(profileRoiClass.NAME))
- action.setToolTip('Enables %s selection mode' % profileRoiClass.NAME)
- action.setSingleShot(True)
- return action
-
- def createClearAction(self, parent):
- """Create an action to clean up the plot from the profile ROIs.
-
- :param qt.QObject parent: The parent of the created action.
- :rtype: qt.QAction
- """
- action = _ClearAction(parent, self)
- return action
-
- def createImageActions(self, parent):
- """Create actions designed for image items. This actions created
- new ROIs.
-
- :param qt.QObject parent: The parent of the created action.
- :rtype: List[qt.QAction]
- """
- profileClasses = [
- rois.ProfileImageHorizontalLineROI,
- rois.ProfileImageVerticalLineROI,
- rois.ProfileImageLineROI,
- rois.ProfileImageDirectedLineROI,
- rois.ProfileImageCrossROI,
- ]
- return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
-
- def createScatterActions(self, parent):
- """Create actions designed for scatter items. This actions created
- new ROIs.
-
- :param qt.QObject parent: The parent of the created action.
- :rtype: List[qt.QAction]
- """
- profileClasses = [
- rois.ProfileScatterHorizontalLineROI,
- rois.ProfileScatterVerticalLineROI,
- rois.ProfileScatterLineROI,
- rois.ProfileScatterCrossROI,
- ]
- return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
-
- def createScatterSliceActions(self, parent):
- """Create actions designed for regular scatter items. This actions
- created new ROIs.
-
- This ROIs was designed to use the input data without interpolation,
- like you could do with an image.
-
- :param qt.QObject parent: The parent of the created action.
- :rtype: List[qt.QAction]
- """
- profileClasses = [
- rois.ProfileScatterHorizontalSliceROI,
- rois.ProfileScatterVerticalSliceROI,
- rois.ProfileScatterCrossSliceROI,
- ]
- return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
-
- def createImageStackActions(self, parent):
- """Create actions designed for stack image items. This actions
- created new ROIs.
-
- This ROIs was designed to create both profile on the displayed image
- and profile on the full stack (2D result).
-
- :param qt.QObject parent: The parent of the created action.
- :rtype: List[qt.QAction]
- """
- profileClasses = [
- rois.ProfileImageStackHorizontalLineROI,
- rois.ProfileImageStackVerticalLineROI,
- rois.ProfileImageStackLineROI,
- rois.ProfileImageStackCrossROI,
- ]
- return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
-
- def createEditorAction(self, parent):
- """Create an action containing GUI to edit the selected profile ROI.
-
- :param qt.QObject parent: The parent of the created action.
- :rtype: qt.QAction
- """
- action = editors.ProfileRoiEditorAction(parent)
- action.setRoiManager(self.getRoiManager())
- return action
-
- def setItemType(self, image=False, scatter=False):
- """Set the item type to use and select the active one.
-
- :param bool image: Image item are allowed
- :param bool scatter: Scatter item are allowed
- """
- self.__itemTypes = []
- plot = self.getPlotWidget()
- item = None
- if image:
- self.__itemTypes.append("image")
- item = plot.getActiveImage()
- if scatter:
- self.__itemTypes.append("scatter")
- if item is None:
- item = plot.getActiveScatter()
- self.setPlotItem(item)
-
- def setProfileWindowClass(self, profileWindowClass):
- """Set the class which will be instantiated to display profile result.
- """
- self._profileWindowClass = profileWindowClass
-
- def setActiveItemTracking(self, tracking):
- """Enable/disable the tracking of the active item of the plot.
-
- :param bool tracking: Tracking mode
- """
- if self.__tracking == tracking:
- return
- plot = self.getPlotWidget()
- if self.__tracking:
- plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
- plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
- self.__tracking = tracking
- if self.__tracking:
- plot.sigActiveImageChanged.connect(self.__activeImageChanged)
- plot.sigActiveScatterChanged.connect(self.__activeScatterChanged)
-
- def setDefaultColorFromCursorColor(self, enabled):
- """Enabled/disable the use of the colormap cursor color to display the
- ROIs.
-
- If set, the manager will update the color of the profile ROIs using the
- current colormap cursor color from the selected item.
- """
- self.__useColorFromCursor = enabled
-
- def __activeImageChanged(self, previous, legend):
- """Handle plot item selection"""
- if "image" in self.__itemTypes:
- plot = self.getPlotWidget()
- item = plot.getImage(legend)
- self.setPlotItem(item)
-
- def __activeScatterChanged(self, previous, legend):
- """Handle plot item selection"""
- if "scatter" in self.__itemTypes:
- plot = self.getPlotWidget()
- item = plot.getScatter(legend)
- self.setPlotItem(item)
-
- def __roiCreated(self, roi):
- """Handle ROI creation"""
- # Filter out non profile ROIs
- if isinstance(roi, core.ProfileRoiMixIn):
- if self._storeProperties is not None:
- # Initialize the properties with the previous ones
- self._storeProperties.restoreProperties(roi)
-
- def __addProfile(self, profileRoi):
- """Add a new ROI to the manager."""
- if profileRoi.getFocusProxy() is None:
- if self._storeProperties is not None:
- # Follow changes on properties
- self._storeProperties.setProfileRoi(profileRoi)
- if self.__singleProfileAtATime:
- # FIXME: It would be good to reuse the windows to avoid blinking
- self.clearProfile()
-
- profileRoi._setProfileManager(self)
- self._updateRoiColor(profileRoi)
- self._rois.append(profileRoi)
- self.requestUpdateProfile(profileRoi)
-
- def __removeProfile(self, profileRoi):
- """Remove a ROI from the manager."""
- window = self._disconnectProfileWindow(profileRoi)
- if window is not None:
- geometry = window.geometry()
- if not geometry.isEmpty():
- self._previousWindowGeometry.append(geometry)
- self.clearProfileWindow(window)
- if profileRoi in self._rois:
- self._rois.remove(profileRoi)
-
- def _disconnectProfileWindow(self, profileRoi):
- """Handle profile window close."""
- window = profileRoi.getProfileWindow()
- profileRoi.setProfileWindow(None)
- return window
-
- def clearProfile(self):
- """Clear the associated ROI profile"""
- roiManager = self.getRoiManager()
- for roi in list(self._rois):
- if roi.getFocusProxy() is not None:
- # Skip sub ROIs, it will be removed by their parents
- continue
- roiManager.removeRoi(roi)
-
- if not roiManager.isDrawing():
- # Clean the selected mode
- roiManager.stop()
-
- def hasPendingOperations(self):
- """Returns true if a thread is still computing or displaying a profile.
-
- :rtype: bool
- """
- return len(self.__reentrantResults) > 0 or len(self._pendingRunners) > 0
-
- def requestUpdateAllProfile(self):
- """Request to update the profile of all the managed ROIs.
- """
- for roi in self._rois:
- self.requestUpdateProfile(roi)
-
- def requestUpdateProfile(self, profileRoi):
- """Request to update a specific profile ROI.
-
- :param ~core.ProfileRoiMixIn profileRoi:
- """
- if profileRoi.computeProfile is None:
- return
- threadPool = silxGlobalThreadPool()
-
- # Clean up deprecated runners
- for runner in list(self._pendingRunners):
- if not inspect.isValid(runner):
- self._pendingRunners.remove(runner)
- continue
- if runner.getRoi() is profileRoi:
- if hasattr(threadPool, "tryTake"):
- if threadPool.tryTake(runner):
- self._pendingRunners.remove(runner)
- else: # Support Qt<5.9
- runner._lazyCancel()
-
- item = self.getPlotItem()
- if item is None or not isinstance(item, profileRoi.ITEM_KIND):
- # This item is not compatible with this profile
- profileRoi._setPlotItem(None)
- profileWindow = profileRoi.getProfileWindow()
- if profileWindow is not None:
- profileWindow.setProfile(None)
- return
-
- profileRoi._setPlotItem(item)
- runner = _RunnableComputeProfile(threadPool, item, profileRoi)
- runner.runnerFinished.connect(self.__cleanUpRunner)
- runner.resultReady.connect(self.__displayResult)
- self._pendingRunners.append(runner)
- threadPool.start(runner)
-
- def __cleanUpRunner(self, runner):
- """Remove a thread pool runner from the list of hold tasks.
-
- Called at the termination of the runner.
- """
- if runner in self._pendingRunners:
- self._pendingRunners.remove(runner)
-
- def __displayResult(self, roi, profileData):
- """Display the result of a ROI.
-
- :param ~core.ProfileRoiMixIn profileRoi: A managed ROI
- :param ~core.CurveProfileData profileData: Computed data profile
- """
- if roi in self.__reentrantResults:
- # Store the data to process it in the main loop
- # And not a sub loop created by initProfileWindow
- # This also remove the duplicated requested
- self.__reentrantResults[roi] = profileData
- return
-
- self.__reentrantResults[roi] = profileData
- self._computedProfiles = self._computedProfiles + 1
- window = roi.getProfileWindow()
- if window is None:
- plot = self.getPlotWidget()
- window = self.createProfileWindow(plot, roi)
- # roi.profileWindow have to be set before initializing the window
- # Cause the initialization is using QEventLoop
- roi.setProfileWindow(window)
- self.initProfileWindow(window, roi)
- window.show()
-
- lastData = self.__reentrantResults.pop(roi)
- window.setProfile(lastData)
-
- def __plotDestroyed(self, ref):
- """Handle finalization of PlotWidget
-
- :param ref: weakref to the plot
- """
- self._plotRef = None
- self._roiManagerRef = None
- self._pendingRunners = []
-
- def setPlotItem(self, item):
- """Set the plot item focused by the profile manager.
-
- :param ~silx.gui.plot.items.Item item: A plot item
- """
- previous = self.getPlotItem()
- if previous is item:
- return
- if item is None:
- self._item = None
- else:
- item.sigItemChanged.connect(self.__itemChanged)
- self._item = weakref.ref(item)
- self._updateRoiColors()
- self.requestUpdateAllProfile()
-
- def getDefaultColor(self, item):
- """Returns the default ROI color to use according to the given item.
-
- :param ~silx.gui.plot.items.item.Item item: AN item
- :rtype: qt.QColor
- """
- color = 'pink'
- if isinstance(item, items.ColormapMixIn):
- colormap = item.getColormap()
- name = colormap.getName()
- if name is not None:
- color = colors.cursorColorForColormap(name)
- color = colors.asQColor(color)
- return color
-
- def _updateRoiColors(self):
- """Update ROI color according to the item selection"""
- if not self.__useColorFromCursor:
- return
- item = self.getPlotItem()
- color = self.getDefaultColor(item)
- for roi in self._rois:
- roi.setColor(color)
-
- def _updateRoiColor(self, roi):
- """Update a specific ROI according to the current selected item.
-
- :param RegionOfInterest roi: The ROI to update
- """
- if not self.__useColorFromCursor:
- return
- item = self.getPlotItem()
- color = self.getDefaultColor(item)
- roi.setColor(color)
-
- def __itemChanged(self, changeType):
- """Handle item changes.
- """
- if changeType in (items.ItemChangedType.DATA,
- items.ItemChangedType.MASK,
- items.ItemChangedType.POSITION,
- items.ItemChangedType.SCALE):
- self.requestUpdateAllProfile()
- elif changeType == (items.ItemChangedType.COLORMAP):
- self._updateRoiColors()
-
- def getPlotItem(self):
- """Returns the item focused by the profile manager.
-
- :rtype: ~silx.gui.plot.items.Item
- """
- if self._item is None:
- return None
- item = self._item()
- if item is None:
- self._item = None
- return item
-
- def getPlotWidget(self):
- """The plot associated to the profile manager.
-
- :rtype: ~silx.gui.plot.PlotWidget
- """
- if self._plotRef is None:
- return None
- plot = self._plotRef()
- if plot is None:
- self._plotRef = None
- return plot
-
- def getCurrentRoi(self):
- """Returns the currently selected ROI, else None.
-
- :rtype: core.ProfileRoiMixIn
- """
- roiManager = self.getRoiManager()
- if roiManager is None:
- return None
- roi = roiManager.getCurrentRoi()
- if not isinstance(roi, core.ProfileRoiMixIn):
- return None
- return roi
-
- def getRoiManager(self):
- """Returns the used ROI manager
-
- :rtype: RegionOfInterestManager
- """
- return self._roiManagerRef()
-
- def createProfileWindow(self, plot, roi):
- """Create a new profile window.
-
- :param ~core.ProfileRoiMixIn roi: The plot containing the raw data
- :param ~core.ProfileRoiMixIn roi: A managed ROI
- :rtype: ~ProfileWindow
- """
- return self._profileWindowClass(plot)
-
- def initProfileWindow(self, profileWindow, roi):
- """This function is called just after the profile window creation in
- order to initialize the window location.
-
- :param ~ProfileWindow profileWindow:
- The profile window to initialize.
- """
- # Enforce the use of one of the widgets
- # To have the correct window size
- profileWindow.prepareWidget(roi)
- profileWindow.adjustSize()
-
- # Trick to avoid blinking while retrieving the right window size
- # Display the window, hide it and wait for some event loops
- profileWindow.show()
- profileWindow.hide()
- eventLoop = qt.QEventLoop(self)
- for _ in range(10):
- if not eventLoop.processEvents():
- break
-
- profileWindow.show()
- if len(self._previousWindowGeometry) > 0:
- geometry = self._previousWindowGeometry.pop()
- profileWindow.setGeometry(geometry)
- return
-
- window = self.getPlotWidget().window()
- winGeom = window.frameGeometry()
- qapp = qt.QApplication.instance()
- desktop = qapp.desktop()
- screenGeom = desktop.availableGeometry(window)
- spaceOnLeftSide = winGeom.left()
- spaceOnRightSide = screenGeom.width() - winGeom.right()
-
- profileGeom = profileWindow.frameGeometry()
- profileWidth = profileGeom.width()
-
- # Align vertically to the center of the window
- top = winGeom.top() + (winGeom.height() - profileGeom.height()) // 2
-
- margin = 5
- if profileWidth < spaceOnRightSide:
- # Place profile on the right
- left = winGeom.right() + margin
- elif profileWidth < spaceOnLeftSide:
- # Place profile on the left
- left = max(0, winGeom.left() - profileWidth - margin)
- else:
- # Move it as much as possible where there is more space
- if spaceOnLeftSide > spaceOnRightSide:
- left = 0
- else:
- left = screenGeom.width() - profileGeom.width()
- profileWindow.move(left, top)
-
-
- def clearProfileWindow(self, profileWindow):
- """Called when a profile window is not anymore needed.
-
- By default the window will be closed. But it can be
- inherited to change this behavior.
- """
- profileWindow.deleteLater()
diff --git a/silx/gui/plot/tools/profile/rois.py b/silx/gui/plot/tools/profile/rois.py
deleted file mode 100644
index eb7e975..0000000
--- a/silx/gui/plot/tools/profile/rois.py
+++ /dev/null
@@ -1,1156 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module define ROIs for profile tools.
-
-.. inheritance-diagram::
- silx.gui.plot.tools.profile.rois
- :top-classes: silx.gui.plot.tools.profile.core.ProfileRoiMixIn, silx.gui.plot.items.roi.RegionOfInterest
- :parts: 1
- :private-bases:
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "01/12/2020"
-
-import numpy
-import weakref
-from concurrent.futures import CancelledError
-
-from silx.gui import colors
-
-from silx.gui.plot import items
-from silx.gui.plot.items import roi as roi_items
-from . import core
-from silx.gui import utils
-from .....utils.proxy import docstring
-
-
-def _relabelAxes(plot, text):
- """Relabel {xlabel} and {ylabel} from this text using the corresponding
- plot axis label. If the axis label is empty, label it with "X" and "Y".
-
- :rtype: str
- """
- xLabel = plot.getXAxis().getLabel()
- if not xLabel:
- xLabel = "X"
- yLabel = plot.getYAxis().getLabel()
- if not yLabel:
- yLabel = "Y"
- return text.format(xlabel=xLabel, ylabel=yLabel)
-
-
-def _lineProfileTitle(x0, y0, x1, y1):
- """Compute corresponding plot title
-
- This can be overridden to change title behavior.
-
- :param float x0: Profile start point X coord
- :param float y0: Profile start point Y coord
- :param float x1: Profile end point X coord
- :param float y1: Profile end point Y coord
- :return: Title to use
- :rtype: str
- """
- if x0 == x1:
- title = '{xlabel} = %g; {ylabel} = [%g, %g]' % (x0, y0, y1)
- elif y0 == y1:
- title = '{ylabel} = %g; {xlabel} = [%g, %g]' % (y0, x0, x1)
- else:
- m = (y1 - y0) / (x1 - x0)
- b = y0 - m * x0
- title = '{ylabel} = %g * {xlabel} %+g' % (m, b)
-
- return title
-
-
-class _ImageProfileArea(items.Shape):
- """This shape displays the location of pixels used to compute the
- profile."""
-
- def __init__(self, parentRoi):
- items.Shape.__init__(self, "polygon")
- color = colors.rgba(parentRoi.getColor())
- self.setColor(color)
- self.setFill(True)
- self.setOverlay(True)
- self.setPoints([[0, 0], [0, 0]]) # Else it segfault
-
- self.__parentRoi = weakref.ref(parentRoi)
- parentRoi.sigItemChanged.connect(self._updateAreaProperty)
- parentRoi.sigRegionChanged.connect(self._updateArea)
- parentRoi.sigProfilePropertyChanged.connect(self._updateArea)
- parentRoi.sigPlotItemChanged.connect(self._updateArea)
-
- def getParentRoi(self):
- if self.__parentRoi is None:
- return None
- parentRoi = self.__parentRoi()
- if parentRoi is None:
- self.__parentRoi = None
- return parentRoi
-
- def _updateAreaProperty(self, event=None, checkVisibility=True):
- parentRoi = self.sender()
- if event == items.ItemChangedType.COLOR:
- parentRoi._updateItemProperty(event, parentRoi, self)
- elif event == items.ItemChangedType.VISIBLE:
- if self.getPlotItem() is not None:
- parentRoi._updateItemProperty(event, parentRoi, self)
-
- def _updateArea(self):
- roi = self.getParentRoi()
- item = roi.getPlotItem()
- if item is None:
- self.setVisible(False)
- return
- polygon = self._computePolygon(item)
- self.setVisible(True)
- polygon = numpy.array(polygon).T
- self.setLineStyle("--")
- self.setPoints(polygon, copy=False)
-
- def _computePolygon(self, item):
- if not isinstance(item, items.ImageBase):
- raise TypeError("Unexpected class %s" % type(item))
-
- currentData = item.getValueData(copy=False)
-
- roi = self.getParentRoi()
- origin = item.getOrigin()
- scale = item.getScale()
- _coords, _profile, area, _profileName, _xLabel = core.createProfile(
- roiInfo=roi._getRoiInfo(),
- currentData=currentData,
- origin=origin,
- scale=scale,
- lineWidth=roi.getProfileLineWidth(),
- method="none")
- return area
-
-
-class _SliceProfileArea(items.Shape):
- """This shape displays the location a profile in a scatter.
-
- Each point used to compute the slice are linked together.
- """
-
- def __init__(self, parentRoi):
- items.Shape.__init__(self, "polygon")
- color = colors.rgba(parentRoi.getColor())
- self.setColor(color)
- self.setFill(True)
- self.setOverlay(True)
- self.setPoints([[0, 0], [0, 0]]) # Else it segfault
-
- self.__parentRoi = weakref.ref(parentRoi)
- parentRoi.sigItemChanged.connect(self._updateAreaProperty)
- parentRoi.sigRegionChanged.connect(self._updateArea)
- parentRoi.sigProfilePropertyChanged.connect(self._updateArea)
- parentRoi.sigPlotItemChanged.connect(self._updateArea)
-
- def getParentRoi(self):
- if self.__parentRoi is None:
- return None
- parentRoi = self.__parentRoi()
- if parentRoi is None:
- self.__parentRoi = None
- return parentRoi
-
- def _updateAreaProperty(self, event=None, checkVisibility=True):
- parentRoi = self.sender()
- if event == items.ItemChangedType.COLOR:
- parentRoi._updateItemProperty(event, parentRoi, self)
- elif event == items.ItemChangedType.VISIBLE:
- if self.getPlotItem() is not None:
- parentRoi._updateItemProperty(event, parentRoi, self)
-
- def _updateArea(self):
- roi = self.getParentRoi()
- item = roi.getPlotItem()
- if item is None:
- self.setVisible(False)
- return
- polylines = self._computePolylines(roi, item)
- if polylines is None:
- self.setVisible(False)
- return
- self.setVisible(True)
- self.setLineStyle("--")
- self.setPoints(polylines, copy=False)
-
- def _computePolylines(self, roi, item):
- slicing = roi._getSlice(item)
- if slicing is None:
- return None
- xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False)
- xx, yy = xx[slicing], yy[slicing]
- polylines = numpy.array((xx, yy)).T
- if len(polylines) == 0:
- return None
- return polylines
-
-
-class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
- """Provide common behavior for silx default image profile ROI.
- """
-
- ITEM_KIND = items.ImageBase
-
- def __init__(self, parent=None):
- core.ProfileRoiMixIn.__init__(self, parent=parent)
- self.__method = "mean"
- self.__width = 1
- self.sigRegionChanged.connect(self.__regionChanged)
- self.sigPlotItemChanged.connect(self.__updateArea)
- self.__area = _ImageProfileArea(self)
- self.addItem(self.__area)
-
- def __regionChanged(self):
- self.invalidateProfile()
- self.__updateArea()
-
- def setProfileMethod(self, method):
- """
- :param str method: method to compute the profile. Can be 'mean' or 'sum'
- """
- if self.__method == method:
- return
- self.__method = method
- self.invalidateProperties()
- self.invalidateProfile()
-
- def getProfileMethod(self):
- return self.__method
-
- def setProfileLineWidth(self, width):
- if self.__width == width:
- return
- self.__width = width
- self.__updateArea()
- self.invalidateProperties()
- self.invalidateProfile()
-
- def getProfileLineWidth(self):
- return self.__width
-
- def __updateArea(self):
- plotItem = self.getPlotItem()
- if plotItem is None:
- self.setLineStyle("-")
- else:
- self.setLineStyle("--")
-
- def _getRoiInfo(self):
- """Wrapper to allow to reuse the previous Profile code.
-
- It would be good to remove it at one point.
- """
- if isinstance(self, roi_items.HorizontalLineROI):
- lineProjectionMode = 'X'
- y = self.getPosition()
- roiStart = (0, y)
- roiEnd = (1, y)
- elif isinstance(self, roi_items.VerticalLineROI):
- lineProjectionMode = 'Y'
- x = self.getPosition()
- roiStart = (x, 0)
- roiEnd = (x, 1)
- elif isinstance(self, roi_items.LineROI):
- lineProjectionMode = 'D'
- roiStart, roiEnd = self.getEndPoints()
- else:
- assert False
-
- return roiStart, roiEnd, lineProjectionMode
-
- def computeProfile(self, item):
- if not isinstance(item, items.ImageBase):
- raise TypeError("Unexpected class %s" % type(item))
-
- origin = item.getOrigin()
- scale = item.getScale()
- method = self.getProfileMethod()
- lineWidth = self.getProfileLineWidth()
-
- def createProfile2(currentData):
- coords, profile, _area, profileName, xLabel = core.createProfile(
- roiInfo=self._getRoiInfo(),
- currentData=currentData,
- origin=origin,
- scale=scale,
- lineWidth=lineWidth,
- method=method)
- return coords, profile, profileName, xLabel
-
- currentData = item.getValueData(copy=False)
-
- yLabel = "%s" % str(method).capitalize()
- coords, profile, title, xLabel = createProfile2(currentData)
- title = title + "; width = %d" % lineWidth
-
- # Use the axis names from the original plot
- profileManager = self.getProfileManager()
- plot = profileManager.getPlotWidget()
- title = _relabelAxes(plot, title)
- xLabel = _relabelAxes(plot, xLabel)
-
- if isinstance(item, items.ImageRgba):
- rgba = item.getData(copy=False)
- _coords, r, _profileName, _xLabel = createProfile2(rgba[..., 0])
- _coords, g, _profileName, _xLabel = createProfile2(rgba[..., 1])
- _coords, b, _profileName, _xLabel = createProfile2(rgba[..., 2])
- if rgba.shape[-1] == 4:
- _coords, a, _profileName, _xLabel = createProfile2(rgba[..., 3])
- else:
- a = [None]
- data = core.RgbaProfileData(
- coords=coords,
- profile=profile[0],
- profile_r=r[0],
- profile_g=g[0],
- profile_b=b[0],
- profile_a=a[0],
- title=title,
- xLabel=xLabel,
- yLabel=yLabel,
- )
- else:
- data = core.CurveProfileData(
- coords=coords,
- profile=profile[0],
- title=title,
- xLabel=xLabel,
- yLabel=yLabel,
- )
- return data
-
-
-class ProfileImageHorizontalLineROI(roi_items.HorizontalLineROI,
- _DefaultImageProfileRoiMixIn):
- """ROI for an horizontal profile at a location of an image"""
-
- ICON = 'shape-horizontal'
- NAME = 'horizontal line profile'
-
- def __init__(self, parent=None):
- roi_items.HorizontalLineROI.__init__(self, parent=parent)
- _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileImageVerticalLineROI(roi_items.VerticalLineROI,
- _DefaultImageProfileRoiMixIn):
- """ROI for a vertical profile at a location of an image"""
-
- ICON = 'shape-vertical'
- NAME = 'vertical line profile'
-
- def __init__(self, parent=None):
- roi_items.VerticalLineROI.__init__(self, parent=parent)
- _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileImageLineROI(roi_items.LineROI,
- _DefaultImageProfileRoiMixIn):
- """ROI for an image profile between 2 points.
-
- The X profile of this ROI is the projecting into one of the x/y axes,
- using its scale and its orientation.
- """
-
- ICON = 'shape-diagonal'
- NAME = 'line profile'
-
- def __init__(self, parent=None):
- roi_items.LineROI.__init__(self, parent=parent)
- _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileImageDirectedLineROI(roi_items.LineROI,
- _DefaultImageProfileRoiMixIn):
- """ROI for an image profile between 2 points.
-
- The X profile of the line is displayed projected into the line itself,
- using its scale and its orientation. It's the distance from the origin.
- """
-
- ICON = 'shape-diagonal-directed'
- NAME = 'directed line profile'
-
- def __init__(self, parent=None):
- roi_items.LineROI.__init__(self, parent=parent)
- _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
- self._handleStart.setSymbol('o')
-
- def computeProfile(self, item):
- if not isinstance(item, items.ImageBase):
- raise TypeError("Unexpected class %s" % type(item))
-
- from silx.image.bilinear import BilinearImage
-
- origin = item.getOrigin()
- scale = item.getScale()
- method = self.getProfileMethod()
- lineWidth = self.getProfileLineWidth()
- currentData = item.getValueData(copy=False)
-
- roiInfo = self._getRoiInfo()
- roiStart, roiEnd, _lineProjectionMode = roiInfo
-
- startPt = ((roiStart[1] - origin[1]) / scale[1],
- (roiStart[0] - origin[0]) / scale[0])
- endPt = ((roiEnd[1] - origin[1]) / scale[1],
- (roiEnd[0] - origin[0]) / scale[0])
-
- if numpy.array_equal(startPt, endPt):
- return None
-
- bilinear = BilinearImage(currentData)
- profile = bilinear.profile_line(
- (startPt[0] - 0.5, startPt[1] - 0.5),
- (endPt[0] - 0.5, endPt[1] - 0.5),
- lineWidth,
- method=method)
-
- # Compute the line size
- lineSize = numpy.sqrt((roiEnd[1] - roiStart[1]) ** 2 +
- (roiEnd[0] - roiStart[0]) ** 2)
- coords = numpy.linspace(0, lineSize, len(profile),
- endpoint=True,
- dtype=numpy.float32)
-
- title = _lineProfileTitle(*roiStart, *roiEnd)
- title = title + "; width = %d" % lineWidth
- xLabel = "√({xlabel}²+{ylabel}²)"
- yLabel = str(method).capitalize()
-
- # Use the axis names from the original plot
- profileManager = self.getProfileManager()
- plot = profileManager.getPlotWidget()
- xLabel = _relabelAxes(plot, xLabel)
- title = _relabelAxes(plot, title)
-
- data = core.CurveProfileData(
- coords=coords,
- profile=profile,
- title=title,
- xLabel=xLabel,
- yLabel=yLabel,
- )
- return data
-
-
-class _ProfileCrossROI(roi_items.HandleBasedROI, core.ProfileRoiMixIn):
-
- """ROI to manage a cross of profiles
-
- It is managed using 2 sub ROIs for vertical and horizontal.
- """
-
- _kind = "Cross"
- """Label for this kind of ROI"""
-
- _plotShape = "point"
- """Plot shape which is used for the first interaction"""
-
- def __init__(self, parent=None):
- roi_items.HandleBasedROI.__init__(self, parent=parent)
- core.ProfileRoiMixIn.__init__(self, parent=parent)
- self.sigRegionChanged.connect(self.__regionChanged)
- self.sigAboutToBeRemoved.connect(self.__aboutToBeRemoved)
- self.__position = 0, 0
- self.__vline = None
- self.__hline = None
- self.__handle = self.addHandle()
- self.__handleLabel = self.addLabelHandle()
- self.__handleLabel.setText(self.getName())
- self.__inhibitReentance = utils.LockReentrant()
- self.computeProfile = None
- self.sigItemChanged.connect(self.__updateLineProperty)
-
- # Make sure the marker is over the ROIs
- self.__handle.setZValue(1)
- # Create the vline and the hline
- self._createSubRois()
-
- @docstring(roi_items.HandleBasedROI)
- def contains(self, position):
- roiPos = self.getPosition()
- return position[0] == roiPos[0] or position[1] == roiPos[1]
-
- def setFirstShapePoints(self, points):
- pos = points[0]
- self.setPosition(pos)
-
- def getPosition(self):
- """Returns the position of this ROI
-
- :rtype: numpy.ndarray
- """
- return self.__position
-
- def setPosition(self, pos):
- """Set the position of this ROI
-
- :param numpy.ndarray pos: 2d-coordinate of this point
- """
- self.__position = pos
- with utils.blockSignals(self.__handle):
- self.__handle.setPosition(*pos)
- with utils.blockSignals(self.__handleLabel):
- self.__handleLabel.setPosition(*pos)
- self.sigRegionChanged.emit()
-
- def handleDragUpdated(self, handle, origin, previous, current):
- if handle is self.__handle:
- self.setPosition(current)
-
- def __updateLineProperty(self, event=None, checkVisibility=True):
- if event == items.ItemChangedType.NAME:
- self.__handleLabel.setText(self.getName())
- elif event in [items.ItemChangedType.COLOR,
- items.ItemChangedType.VISIBLE]:
- lines = []
- if self.__vline:
- lines.append(self.__vline)
- if self.__hline:
- lines.append(self.__hline)
- self._updateItemProperty(event, self, lines)
-
- def _createLines(self, parent):
- """Inherit this function to return 2 ROI objects for respectivly
- the horizontal, and the vertical lines."""
- raise NotImplementedError()
-
- def _setProfileManager(self, profileManager):
- core.ProfileRoiMixIn._setProfileManager(self, profileManager)
- # Connecting the vline and the hline
- roiManager = profileManager.getRoiManager()
- roiManager.addRoi(self.__vline)
- roiManager.addRoi(self.__hline)
-
- def _createSubRois(self):
- hline, vline = self._createLines(parent=None)
- for i, line in enumerate([vline, hline]):
- line.setPosition(self.__position[i])
- line.setEditable(True)
- line.setSelectable(True)
- line.setFocusProxy(self)
- line.setName("")
- self.__vline = vline
- self.__hline = hline
- vline.sigAboutToBeRemoved.connect(self.__vlineRemoved)
- vline.sigRegionChanged.connect(self.__vlineRegionChanged)
- hline.sigAboutToBeRemoved.connect(self.__hlineRemoved)
- hline.sigRegionChanged.connect(self.__hlineRegionChanged)
-
- def _getLines(self):
- return self.__hline, self.__vline
-
- def __regionChanged(self):
- if self.__inhibitReentance.locked():
- return
- x, y = self.getPosition()
- hline, vline = self._getLines()
- if hline is None:
- return
- with self.__inhibitReentance:
- hline.setPosition(y)
- vline.setPosition(x)
-
- def __vlineRegionChanged(self):
- if self.__inhibitReentance.locked():
- return
- pos = self.getPosition()
- vline = self.__vline
- pos = vline.getPosition(), pos[1]
- with self.__inhibitReentance:
- self.setPosition(pos)
-
- def __hlineRegionChanged(self):
- if self.__inhibitReentance.locked():
- return
- pos = self.getPosition()
- hline = self.__hline
- pos = pos[0], hline.getPosition()
- with self.__inhibitReentance:
- self.setPosition(pos)
-
- def __aboutToBeRemoved(self):
- vline = self.__vline
- hline = self.__hline
- # Avoid side remove signals
- if hline is not None:
- hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved)
- hline.sigRegionChanged.disconnect(self.__hlineRegionChanged)
- if vline is not None:
- vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved)
- vline.sigRegionChanged.disconnect(self.__vlineRegionChanged)
- # Clean up the child
- profileManager = self.getProfileManager()
- roiManager = profileManager.getRoiManager()
- if hline is not None:
- roiManager.removeRoi(hline)
- self.__hline = None
- if vline is not None:
- roiManager.removeRoi(vline)
- self.__vline = None
-
- def __hlineRemoved(self):
- self.__lineRemoved(isHline=True)
-
- def __vlineRemoved(self):
- self.__lineRemoved(isHline=False)
-
- def __lineRemoved(self, isHline):
- """If any of the lines is removed: disconnect this objects, and let the
- other one persist"""
- hline, vline = self._getLines()
-
- hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved)
- vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved)
- hline.sigRegionChanged.disconnect(self.__hlineRegionChanged)
- vline.sigRegionChanged.disconnect(self.__vlineRegionChanged)
-
- self.__hline = None
- self.__vline = None
- profileManager = self.getProfileManager()
- roiManager = profileManager.getRoiManager()
- if isHline:
- self.__releaseLine(vline)
- else:
- self.__releaseLine(hline)
- roiManager.removeRoi(self)
-
- def __releaseLine(self, line):
- """Release the line in order to make it independent"""
- line.setFocusProxy(None)
- line.setName(self.getName())
- line.setEditable(self.isEditable())
- line.setSelectable(self.isSelectable())
-
-
-class ProfileImageCrossROI(_ProfileCrossROI):
- """ROI to manage a cross of profiles
-
- It is managed using 2 sub ROIs for vertical and horizontal.
- """
-
- ICON = 'shape-cross'
- NAME = 'cross profile'
- ITEM_KIND = items.ImageBase
-
- def _createLines(self, parent):
- vline = ProfileImageVerticalLineROI(parent=parent)
- hline = ProfileImageHorizontalLineROI(parent=parent)
- return hline, vline
-
- def setProfileMethod(self, method):
- """
- :param str method: method to compute the profile. Can be 'mean' or 'sum'
- """
- hline, vline = self._getLines()
- hline.setProfileMethod(method)
- vline.setProfileMethod(method)
- self.invalidateProperties()
-
- def getProfileMethod(self):
- hline, _vline = self._getLines()
- return hline.getProfileMethod()
-
- def setProfileLineWidth(self, width):
- hline, vline = self._getLines()
- hline.setProfileLineWidth(width)
- vline.setProfileLineWidth(width)
- self.invalidateProperties()
-
- def getProfileLineWidth(self):
- hline, _vline = self._getLines()
- return hline.getProfileLineWidth()
-
-
-class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
- """Provide common behavior for silx default scatter profile ROI.
- """
-
- ITEM_KIND = items.Scatter
-
- def __init__(self, parent=None):
- core.ProfileRoiMixIn.__init__(self, parent=parent)
- self.__nPoints = 1024
- self.sigRegionChanged.connect(self.__regionChanged)
-
- def __regionChanged(self):
- self.invalidateProfile()
-
- # Number of points
-
- def getNPoints(self):
- """Returns the number of points of the profiles
-
- :rtype: int
- """
- return self.__nPoints
-
- def setNPoints(self, npoints):
- """Set the number of points of the profiles
-
- :param int npoints:
- """
- npoints = int(npoints)
- if npoints < 1:
- raise ValueError("Unsupported number of points: %d" % npoints)
- elif npoints != self.__nPoints:
- self.__nPoints = npoints
- self.invalidateProperties()
- self.invalidateProfile()
-
- def _computeProfile(self, scatter, x0, y0, x1, y1):
- """Compute corresponding profile
-
- :param float x0: Profile start point X coord
- :param float y0: Profile start point Y coord
- :param float x1: Profile end point X coord
- :param float y1: Profile end point Y coord
- :return: (points, values) profile data or None
- """
- future = scatter._getInterpolator()
- try:
- interpolator = future.result()
- except CancelledError:
- return None
- if interpolator is None:
- return None # Cannot init an interpolator
-
- nPoints = self.getNPoints()
- points = numpy.transpose((
- numpy.linspace(x0, x1, nPoints, endpoint=True),
- numpy.linspace(y0, y1, nPoints, endpoint=True)))
-
- values = interpolator(points)
-
- if not numpy.any(numpy.isfinite(values)):
- return None # Profile outside convex hull
-
- return points, values
-
- def computeProfile(self, item):
- """Update profile according to current ROI"""
- if not isinstance(item, items.Scatter):
- raise TypeError("Unexpected class %s" % type(item))
-
- # Get end points
- if isinstance(self, roi_items.LineROI):
- points = self.getEndPoints()
- x0, y0 = points[0]
- x1, y1 = points[1]
- elif isinstance(self, (roi_items.VerticalLineROI, roi_items.HorizontalLineROI)):
- profileManager = self.getProfileManager()
- plot = profileManager.getPlotWidget()
-
- if isinstance(self, roi_items.HorizontalLineROI):
- x0, x1 = plot.getXAxis().getLimits()
- y0 = y1 = self.getPosition()
-
- elif isinstance(self, roi_items.VerticalLineROI):
- x0 = x1 = self.getPosition()
- y0, y1 = plot.getYAxis().getLimits()
- else:
- raise RuntimeError('Unsupported ROI for profile: {}'.format(self.__class__))
-
- if x1 < x0 or (x1 == x0 and y1 < y0):
- # Invert points
- x0, y0, x1, y1 = x1, y1, x0, y0
-
- profile = self._computeProfile(item, x0, y0, x1, y1)
- if profile is None:
- return None
-
- title = _lineProfileTitle(x0, y0, x1, y1)
- points = profile[0]
- values = profile[1]
-
- if (numpy.abs(points[-1, 0] - points[0, 0]) >
- numpy.abs(points[-1, 1] - points[0, 1])):
- xProfile = points[:, 0]
- xLabel = '{xlabel}'
- else:
- xProfile = points[:, 1]
- xLabel = '{ylabel}'
-
- # Use the axis names from the original
- profileManager = self.getProfileManager()
- plot = profileManager.getPlotWidget()
- title = _relabelAxes(plot, title)
- xLabel = _relabelAxes(plot, xLabel)
-
- data = core.CurveProfileData(
- coords=xProfile,
- profile=values,
- title=title,
- xLabel=xLabel,
- yLabel='Profile',
- )
- return data
-
-
-class ProfileScatterHorizontalLineROI(roi_items.HorizontalLineROI,
- _DefaultScatterProfileRoiMixIn):
- """ROI for an horizontal profile at a location of a scatter"""
-
- ICON = 'shape-horizontal'
- NAME = 'horizontal line profile'
-
- def __init__(self, parent=None):
- roi_items.HorizontalLineROI.__init__(self, parent=parent)
- _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileScatterVerticalLineROI(roi_items.VerticalLineROI,
- _DefaultScatterProfileRoiMixIn):
- """ROI for an horizontal profile at a location of a scatter"""
-
- ICON = 'shape-vertical'
- NAME = 'vertical line profile'
-
- def __init__(self, parent=None):
- roi_items.VerticalLineROI.__init__(self, parent=parent)
- _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileScatterLineROI(roi_items.LineROI,
- _DefaultScatterProfileRoiMixIn):
- """ROI for an horizontal profile at a location of a scatter"""
-
- ICON = 'shape-diagonal'
- NAME = 'line profile'
-
- def __init__(self, parent=None):
- roi_items.LineROI.__init__(self, parent=parent)
- _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileScatterCrossROI(_ProfileCrossROI):
- """ROI to manage a cross of profiles for scatters.
- """
-
- ICON = 'shape-cross'
- NAME = 'cross profile'
- ITEM_KIND = items.Scatter
-
- def _createLines(self, parent):
- vline = ProfileScatterVerticalLineROI(parent=parent)
- hline = ProfileScatterHorizontalLineROI(parent=parent)
- return hline, vline
-
- def getNPoints(self):
- """Returns the number of points of the profiles
-
- :rtype: int
- """
- hline, _vline = self._getLines()
- return hline.getNPoints()
-
- def setNPoints(self, npoints):
- """Set the number of points of the profiles
-
- :param int npoints:
- """
- hline, vline = self._getLines()
- hline.setNPoints(npoints)
- vline.setNPoints(npoints)
- self.invalidateProperties()
-
-
-class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn):
- """Default ROI to allow to slice in the scatter data."""
-
- ITEM_KIND = items.Scatter
-
- def __init__(self, parent=None):
- core.ProfileRoiMixIn.__init__(self, parent=parent)
- self.__area = _SliceProfileArea(self)
- self.addItem(self.__area)
- self.sigRegionChanged.connect(self._regionChanged)
- self.sigPlotItemChanged.connect(self._updateArea)
-
- def _regionChanged(self):
- self.invalidateProfile()
- self._updateArea()
-
- def _updateArea(self):
- plotItem = self.getPlotItem()
- if plotItem is None:
- self.setLineStyle("-")
- else:
- self.setLineStyle("--")
-
- def _getSlice(self, item):
- position = self.getPosition()
- bounds = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_BOUNDS)
- if isinstance(self, roi_items.HorizontalLineROI):
- axis = 1
- elif isinstance(self, roi_items.VerticalLineROI):
- axis = 0
- else:
- assert False
- if position < bounds[0][axis] or position > bounds[1][axis]:
- # ROI outside of the scatter bound
- return None
-
- major_order = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_MAJOR_ORDER)
- assert major_order == 'row'
- max_grid_yy, max_grid_xx = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_SHAPE)
-
- xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False)
- if isinstance(self, roi_items.HorizontalLineROI):
- axis = yy
- max_grid_first = max_grid_yy
- max_grid_second = max_grid_xx
- major_axis = major_order == 'column'
- elif isinstance(self, roi_items.VerticalLineROI):
- axis = xx
- max_grid_first = max_grid_xx
- max_grid_second = max_grid_yy
- major_axis = major_order == 'row'
- else:
- assert False
-
- def argnearest(array, value):
- array = numpy.abs(array - value)
- return numpy.argmin(array)
-
- if major_axis:
- # slice in the middle of the scatter
- start = max_grid_second // 2 * max_grid_first
- vslice = axis[start:start + max_grid_second]
- index = argnearest(vslice, position)
- slicing = slice(index, None, max_grid_first)
- else:
- # slice in the middle of the scatter
- vslice = axis[max_grid_second // 2::max_grid_second]
- index = argnearest(vslice, position)
- start = index * max_grid_second
- slicing = slice(start, start + max_grid_second)
-
- return slicing
-
- def computeProfile(self, item):
- if not isinstance(item, items.Scatter):
- raise TypeError("Unsupported %s item" % type(item))
-
- slicing = self._getSlice(item)
- if slicing is None:
- # ROI out of bounds
- return None
-
- _xx, _yy, values, _xx_error, _yy_error = item.getData(copy=False)
- profile = values[slicing]
-
- if isinstance(self, roi_items.HorizontalLineROI):
- title = "Horizontal slice"
- xLabel = "{xlabel} index"
- elif isinstance(self, roi_items.VerticalLineROI):
- title = "Vertical slice"
- xLabel = "{ylabel} index"
- else:
- assert False
-
- # Use the axis names from the original plot
- profileManager = self.getProfileManager()
- plot = profileManager.getPlotWidget()
- xLabel = _relabelAxes(plot, xLabel)
-
- data = core.CurveProfileData(
- coords=numpy.arange(len(profile)),
- profile=profile,
- title=title,
- xLabel=xLabel,
- yLabel="Profile",
- )
- return data
-
-
-class ProfileScatterHorizontalSliceROI(roi_items.HorizontalLineROI,
- _DefaultScatterProfileSliceRoiMixIn):
- """ROI for an horizontal profile at a location of a scatter
- using data slicing.
- """
-
- ICON = 'slice-horizontal'
- NAME = 'horizontal data slice profile'
-
- def __init__(self, parent=None):
- roi_items.HorizontalLineROI.__init__(self, parent=parent)
- _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileScatterVerticalSliceROI(roi_items.VerticalLineROI,
- _DefaultScatterProfileSliceRoiMixIn):
- """ROI for a vertical profile at a location of a scatter
- using data slicing.
- """
-
- ICON = 'slice-vertical'
- NAME = 'vertical data slice profile'
-
- def __init__(self, parent=None):
- roi_items.VerticalLineROI.__init__(self, parent=parent)
- _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileScatterCrossSliceROI(_ProfileCrossROI):
- """ROI to manage a cross of slicing profiles on scatters.
- """
-
- ICON = 'slice-cross'
- NAME = 'cross data slice profile'
- ITEM_KIND = items.Scatter
-
- def _createLines(self, parent):
- vline = ProfileScatterVerticalSliceROI(parent=parent)
- hline = ProfileScatterHorizontalSliceROI(parent=parent)
- return hline, vline
-
-
-class _DefaultImageStackProfileRoiMixIn(_DefaultImageProfileRoiMixIn):
-
- ITEM_KIND = items.ImageStack
-
- def __init__(self, parent=None):
- super(_DefaultImageStackProfileRoiMixIn, self).__init__(parent=parent)
- self.__profileType = "1D"
- """Kind of profile"""
-
- def getProfileType(self):
- return self.__profileType
-
- def setProfileType(self, kind):
- assert kind in ["1D", "2D"]
- if self.__profileType == kind:
- return
- self.__profileType = kind
- self.invalidateProperties()
- self.invalidateProfile()
-
- def computeProfile(self, item):
- if not isinstance(item, items.ImageStack):
- raise TypeError("Unexpected class %s" % type(item))
-
- kind = self.getProfileType()
- if kind == "1D":
- result = _DefaultImageProfileRoiMixIn.computeProfile(self, item)
- # z = item.getStackPosition()
- return result
-
- assert kind == "2D"
-
- def createProfile2(currentData):
- coords, profile, _area, profileName, xLabel = core.createProfile(
- roiInfo=self._getRoiInfo(),
- currentData=currentData,
- origin=origin,
- scale=scale,
- lineWidth=self.getProfileLineWidth(),
- method=method)
- return coords, profile, profileName, xLabel
-
- currentData = numpy.array(item.getStackData(copy=False))
- origin = item.getOrigin()
- scale = item.getScale()
- colormap = item.getColormap()
- method = self.getProfileMethod()
-
- coords, profile, profileName, xLabel = createProfile2(currentData)
-
- data = core.ImageProfileData(
- coords=coords,
- profile=profile,
- title=profileName,
- xLabel=xLabel,
- yLabel="Profile",
- colormap=colormap,
- )
- return data
-
-
-class ProfileImageStackHorizontalLineROI(roi_items.HorizontalLineROI,
- _DefaultImageStackProfileRoiMixIn):
- """ROI for an horizontal profile at a location of a stack of images"""
-
- ICON = 'shape-horizontal'
- NAME = 'horizontal line profile'
-
- def __init__(self, parent=None):
- roi_items.HorizontalLineROI.__init__(self, parent=parent)
- _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileImageStackVerticalLineROI(roi_items.VerticalLineROI,
- _DefaultImageStackProfileRoiMixIn):
- """ROI for an vertical profile at a location of a stack of images"""
-
- ICON = 'shape-vertical'
- NAME = 'vertical line profile'
-
- def __init__(self, parent=None):
- roi_items.VerticalLineROI.__init__(self, parent=parent)
- _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileImageStackLineROI(roi_items.LineROI,
- _DefaultImageStackProfileRoiMixIn):
- """ROI for an vertical profile at a location of a stack of images"""
-
- ICON = 'shape-diagonal'
- NAME = 'line profile'
-
- def __init__(self, parent=None):
- roi_items.LineROI.__init__(self, parent=parent)
- _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
-
-
-class ProfileImageStackCrossROI(ProfileImageCrossROI):
- """ROI for an vertical profile at a location of a stack of images"""
-
- ICON = 'shape-cross'
- NAME = 'cross profile'
- ITEM_KIND = items.ImageStack
-
- def _createLines(self, parent):
- vline = ProfileImageStackVerticalLineROI(parent=parent)
- hline = ProfileImageStackHorizontalLineROI(parent=parent)
- return hline, vline
-
- def getProfileType(self):
- hline, _vline = self._getLines()
- return hline.getProfileType()
-
- def setProfileType(self, kind):
- hline, vline = self._getLines()
- hline.setProfileType(kind)
- vline.setProfileType(kind)
- self.invalidateProperties()
diff --git a/silx/gui/plot/tools/roi.py b/silx/gui/plot/tools/roi.py
deleted file mode 100644
index 4e2d6db..0000000
--- a/silx/gui/plot/tools/roi.py
+++ /dev/null
@@ -1,1417 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides ROI interaction for :class:`~silx.gui.plot.PlotWidget`.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "28/06/2018"
-
-
-import enum
-import logging
-import time
-import weakref
-import functools
-
-import numpy
-
-from ... import qt, icons
-from ...utils import blockSignals
-from ...utils import LockReentrant
-from .. import PlotWidget
-from ..items import roi as roi_items
-
-from ...colors import rgba
-
-
-logger = logging.getLogger(__name__)
-
-
-class CreateRoiModeAction(qt.QAction):
- """
- This action is a plot mode which allows to create new ROIs using a ROI
- manager.
-
- A ROI is created using a specific `roiClass`. `initRoi` and `finalizeRoi`
- can be inherited to custom the ROI initialization.
-
- :param class roiClass: The ROI class which will be created by this action.
- :param qt.QObject parent: The action parent
- :param RegionOfInterestManager roiManager: The ROI manager
- """
-
- def __init__(self, parent, roiManager, roiClass):
- assert roiManager is not None
- assert roiClass is not None
- qt.QAction.__init__(self, parent=parent)
- self._roiManager = weakref.ref(roiManager)
- self._roiClass = roiClass
- self._singleShot = False
- self._initAction()
- self.triggered[bool].connect(self._actionTriggered)
-
- def _initAction(self):
- """Default initialization of the action"""
- roiClass = self._roiClass
-
- name = None
- iconName = None
- if hasattr(roiClass, "NAME"):
- name = roiClass.NAME
- if hasattr(roiClass, "ICON"):
- iconName = roiClass.ICON
-
- if iconName is None:
- iconName = "add-shape-unknown"
- if name is None:
- name = roiClass.__name__
- text = 'Add %s' % name
- self.setIcon(icons.getQIcon(iconName))
- self.setText(text)
- self.setCheckable(True)
- self.setToolTip(text)
-
- def getRoiClass(self):
- """Return the ROI class used by this action to create ROIs"""
- return self._roiClass
-
- def getRoiManager(self):
- return self._roiManager()
-
- def setSingleShot(self, singleShot):
- """Set it to True to deactivate the action after the first creation
- of a ROI.
-
- :param bool singleShot: New single short state
- """
- self._singleShot = singleShot
-
- def getSingleShot(self):
- """If True, after the first creation of a ROI with this mode,
- the mode is deactivated.
-
- :rtype: bool
- """
- return self._singleShot
-
- def _actionTriggered(self, checked):
- """Handle mode actions being checked by the user
-
- :param bool checked:
- :param str kind: Corresponding shape kind
- """
- roiManager = self.getRoiManager()
- if roiManager is None:
- return
-
- if checked:
- roiManager.start(self._roiClass, self)
- self.__interactiveModeStarted(roiManager)
- else:
- source = roiManager.getInteractionSource()
- if source is self:
- roiManager.stop()
-
- def __interactiveModeStarted(self, roiManager):
- roiManager.sigInteractiveRoiCreated.connect(self.initRoi)
- roiManager.sigInteractiveRoiFinalized.connect(self.__finalizeRoi)
- roiManager.sigInteractiveModeFinished.connect(self.__interactiveModeFinished)
-
- def __interactiveModeFinished(self):
- roiManager = self.getRoiManager()
- if roiManager is not None:
- roiManager.sigInteractiveRoiCreated.disconnect(self.initRoi)
- roiManager.sigInteractiveRoiFinalized.disconnect(self.__finalizeRoi)
- roiManager.sigInteractiveModeFinished.disconnect(self.__interactiveModeFinished)
- self.setChecked(False)
-
- def initRoi(self, roi):
- """Inherit it to custom the new ROI at it's creation during the
- interaction."""
- pass
-
- def __finalizeRoi(self, roi):
- self.finalizeRoi(roi)
- if self._singleShot:
- roiManager = self.getRoiManager()
- if roiManager is not None:
- roiManager.stop()
-
- def finalizeRoi(self, roi):
- """Inherit it to custom the new ROI after it's creation when the
- interaction is finalized."""
- pass
-
-
-class RoiModeSelector(qt.QWidget):
- def __init__(self, parent=None):
- super(RoiModeSelector, self).__init__(parent=parent)
- self.__roi = None
- self.__reentrant = LockReentrant()
-
- layout = qt.QHBoxLayout(self)
- if isinstance(parent, qt.QMenu):
- margins = layout.contentsMargins()
- layout.setContentsMargins(margins.left(), 0, margins.right(), 0)
- else:
- layout.setContentsMargins(0, 0, 0, 0)
-
- self._label = qt.QLabel(self)
- self._label.setText("Mode:")
- self._label.setToolTip("Select a specific interaction to edit the ROI")
- self._combo = qt.QComboBox(self)
- self._combo.currentIndexChanged.connect(self._modeSelected)
- layout.addWidget(self._label)
- layout.addWidget(self._combo)
- self._updateAvailableModes()
-
- def getRoi(self):
- """Returns the edited ROI.
-
- :rtype: roi_items.RegionOfInterest
- """
- return self.__roi
-
- def setRoi(self, roi):
- """Returns the edited ROI.
-
- :rtype: roi_items.RegionOfInterest
- """
- if self.__roi is roi:
- return
- if not isinstance(roi, roi_items.InteractionModeMixIn):
- self.__roi = None
- self._updateAvailableModes()
- return
-
- if self.__roi is not None:
- self.__roi.sigInteractionModeChanged.disconnect(self._modeChanged)
- self.__roi = roi
- if self.__roi is not None:
- self.__roi.sigInteractionModeChanged.connect(self._modeChanged)
- self._updateAvailableModes()
-
- def isEmpty(self):
- return not self._label.isVisibleTo(self)
-
- def _updateAvailableModes(self):
- roi = self.getRoi()
- if isinstance(roi, roi_items.InteractionModeMixIn):
- modes = roi.availableInteractionModes()
- else:
- modes = []
- if len(modes) <= 1:
- self._label.setVisible(False)
- self._combo.setVisible(False)
- else:
- self._label.setVisible(True)
- self._combo.setVisible(True)
- with blockSignals(self._combo):
- self._combo.clear()
- for im, m in enumerate(modes):
- self._combo.addItem(m.label, m)
- self._combo.setItemData(im, m.description, qt.Qt.ToolTipRole)
- mode = roi.getInteractionMode()
- self._modeChanged(mode)
- index = modes.index(mode)
- self._combo.setCurrentIndex(index)
-
- def _modeChanged(self, mode):
- """Triggered when the ROI interaction mode was changed externally"""
- if self.__reentrant.locked():
- # This event was initialised by the widget
- return
- roi = self.__roi
- modes = roi.availableInteractionModes()
- index = modes.index(mode)
- with blockSignals(self._combo):
- self._combo.setCurrentIndex(index)
-
- def _modeSelected(self):
- """Triggered when the ROI interaction mode was selected in the widget"""
- index = self._combo.currentIndex()
- if index == -1:
- return
- roi = self.getRoi()
- if roi is not None:
- mode = self._combo.itemData(index, qt.Qt.UserRole)
- with self.__reentrant:
- roi.setInteractionMode(mode)
-
-
-class RoiModeSelectorAction(qt.QWidgetAction):
- """Display the selected mode of a ROI and allow to change it"""
-
- def __init__(self, parent=None):
- super(RoiModeSelectorAction, self).__init__(parent)
- self.__roiManager = None
-
- def createWidget(self, parent):
- """Inherit the method to create a new widget"""
- widget = RoiModeSelector(parent)
- manager = self.__roiManager
- if manager is not None:
- roi = manager.getCurrentRoi()
- widget.setRoi(roi)
- self.setVisible(not widget.isEmpty())
- return widget
-
- def deleteWidget(self, widget):
- """Inherit the method to delete a widget"""
- widget.setRoi(None)
- return qt.QWidgetAction.deleteWidget(self, widget)
-
- def setRoiManager(self, roiManager):
- """
- Connect this action to a ROI manager.
-
- :param RegionOfInterestManager roiManager: A ROI manager
- """
- if self.__roiManager is roiManager:
- return
- if self.__roiManager is not None:
- self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged)
- self.__roiManager = roiManager
- if self.__roiManager is not None:
- self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged)
- self.__currentRoiChanged(roiManager.getCurrentRoi())
-
- def __currentRoiChanged(self, roi):
- """Handle changes of the selected ROI"""
- self.setRoi(roi)
-
- def setRoi(self, roi):
- """Set a profile ROI to edit.
-
- :param ProfileRoiMixIn roi: A profile ROI
- """
- widget = None
- for widget in self.createdWidgets():
- widget.setRoi(roi)
- if widget is not None:
- self.setVisible(not widget.isEmpty())
-
-
-class RegionOfInterestManager(qt.QObject):
- """Class handling ROI interaction on a PlotWidget.
-
- It supports the multiple ROIs: points, rectangles, polygons,
- lines, horizontal and vertical lines.
-
- See ``plotInteractiveImageROI.py`` sample code (:ref:`sample-code`).
-
- :param silx.gui.plot.PlotWidget parent:
- The plot widget in which to control the ROIs.
- """
-
- sigRoiAdded = qt.Signal(roi_items.RegionOfInterest)
- """Signal emitted when a new ROI has been added.
-
- It provides the newly add :class:`RegionOfInterest` object.
- """
-
- sigRoiAboutToBeRemoved = qt.Signal(roi_items.RegionOfInterest)
- """Signal emitted just before a ROI is removed.
-
- It provides the :class:`RegionOfInterest` object that is about to be removed.
- """
-
- sigRoiChanged = qt.Signal()
- """Signal emitted whenever the ROIs have changed."""
-
- sigCurrentRoiChanged = qt.Signal(object)
- """Signal emitted whenever a ROI is selected."""
-
- sigInteractiveModeStarted = qt.Signal(object)
- """Signal emitted when switching to ROI drawing interactive mode.
-
- It provides the class of the ROI which will be created by the interactive
- mode.
- """
-
- sigInteractiveRoiCreated = qt.Signal(object)
- """Signal emitted when a ROI is created during the interaction.
- The interaction is still incomplete and can be aborted.
-
- It provides the ROI object which was just been created.
- """
-
- sigInteractiveRoiFinalized = qt.Signal(object)
- """Signal emitted when a ROI creation is complet.
-
- It provides the ROI object which was just been created.
- """
-
- sigInteractiveModeFinished = qt.Signal()
- """Signal emitted when leaving interactive ROI drawing mode.
- """
-
- ROI_CLASSES = (
- roi_items.PointROI,
- roi_items.CrossROI,
- roi_items.RectangleROI,
- roi_items.CircleROI,
- roi_items.EllipseROI,
- roi_items.PolygonROI,
- roi_items.LineROI,
- roi_items.HorizontalLineROI,
- roi_items.VerticalLineROI,
- roi_items.ArcROI,
- roi_items.HorizontalRangeROI,
- )
-
- def __init__(self, parent):
- assert isinstance(parent, PlotWidget)
- super(RegionOfInterestManager, self).__init__(parent)
- self._rois = [] # List of ROIs
- self._drawnROI = None # New ROI being currently drawn
-
- self._roiClass = None
- self._source = None
- self._color = rgba('red')
-
- self._label = "__RegionOfInterestManager__%d" % id(self)
-
- self._currentRoi = None
- """Hold currently selected ROI"""
-
- self._eventLoop = None
-
- self._modeActions = {}
-
- parent.sigPlotSignal.connect(self._plotSignals)
-
- parent.sigInteractiveModeChanged.connect(
- self._plotInteractiveModeChanged)
-
- parent.sigItemRemoved.connect(self._itemRemoved)
-
- parent._sigDefaultContextMenu.connect(self._feedContextMenu)
-
- @classmethod
- def getSupportedRoiClasses(cls):
- """Returns the default available ROI classes
-
- :rtype: List[class]
- """
- return tuple(cls.ROI_CLASSES)
-
- # Associated QActions
-
- def getInteractionModeAction(self, roiClass):
- """Returns the QAction corresponding to a kind of ROI
-
- The QAction allows to enable the corresponding drawing
- interactive mode.
-
- :param class roiClass: The ROI class which will be created by this action.
- :rtype: QAction
- :raise ValueError: If kind is not supported
- """
- if not issubclass(roiClass, roi_items.RegionOfInterest):
- raise ValueError('Unsupported ROI class %s' % roiClass)
-
- action = self._modeActions.get(roiClass, None)
- if action is None: # Lazy-loading
- action = CreateRoiModeAction(self, self, roiClass)
- self._modeActions[roiClass] = action
- return action
-
- # PlotWidget eventFilter and listeners
-
- def _plotInteractiveModeChanged(self, source):
- """Handle change of interactive mode in the plot"""
- if source is not self:
- self.__roiInteractiveModeEnded()
-
- def _getRoiFromItem(self, item):
- """Returns the ROI which own this item, else None
- if this manager do not have knowledge of this ROI."""
- for roi in self._rois:
- if isinstance(roi, roi_items.RegionOfInterest):
- for child in roi.getItems():
- if child is item:
- return roi
- return None
-
- def _itemRemoved(self, item):
- """Called after an item was removed from the plot."""
- if not hasattr(item, "_roiGroup"):
- # Early break to avoid to use _getRoiFromItem
- # And to avoid reentrant signal when the ROI remove the item itself
- return
- roi = self._getRoiFromItem(item)
- if roi is not None:
- self.removeRoi(roi)
-
- # Handle ROI interaction
-
- def _handleInteraction(self, event):
- """Handle mouse interaction for ROI addition"""
- roiClass = self.getCurrentInteractionModeRoiClass()
- if roiClass is None:
- return # Should not happen
-
- kind = roiClass.getFirstInteractionShape()
- if kind == 'point':
- if event['event'] == 'mouseClicked' and event['button'] == 'left':
- points = numpy.array([(event['x'], event['y'])],
- dtype=numpy.float64)
- # Not an interactive creation
- roi = self._createInteractiveRoi(roiClass, points=points)
- roi.creationFinalized()
- self.sigInteractiveRoiFinalized.emit(roi)
- else: # other shapes
- if (event['event'] in ('drawingProgress', 'drawingFinished') and
- event['parameters']['label'] == self._label):
- points = numpy.array((event['xdata'], event['ydata']),
- dtype=numpy.float64).T
-
- if self._drawnROI is None: # Create new ROI
- # NOTE: Set something before createRoi, so isDrawing is True
- self._drawnROI = object()
- self._drawnROI = self._createInteractiveRoi(roiClass, points=points)
- else:
- self._drawnROI.setFirstShapePoints(points)
-
- if event['event'] == 'drawingFinished':
- if kind == 'polygon' and len(points) > 1:
- self._drawnROI.setFirstShapePoints(points[:-1])
- roi = self._drawnROI
- self._drawnROI = None # Stop drawing
- roi.creationFinalized()
- self.sigInteractiveRoiFinalized.emit(roi)
-
- # RegionOfInterest selection
-
- def __getRoiFromMarker(self, marker):
- """Returns a ROI from a marker, else None"""
- # This should be speed up
- for roi in self._rois:
- if isinstance(roi, roi_items.HandleBasedROI):
- for m in roi.getHandles():
- if m is marker:
- return roi
- else:
- for m in roi.getItems():
- if m is marker:
- return roi
- return None
-
- def setCurrentRoi(self, roi):
- """Set the currently selected ROI, and emit a signal.
-
- :param Union[RegionOfInterest,None] roi: The ROI to select
- """
- if self._currentRoi is roi:
- return
- if roi is not None:
- # Note: Fixed range to avoid infinite loops
- for _ in range(10):
- target = roi.getFocusProxy()
- if target is None:
- break
- roi = target
- else:
- raise RuntimeError("Max selection proxy depth (10) reached.")
-
- if self._currentRoi is not None:
- self._currentRoi.setHighlighted(False)
- self._currentRoi = roi
- if self._currentRoi is not None:
- self._currentRoi.setHighlighted(True)
- self.sigCurrentRoiChanged.emit(roi)
-
- def getCurrentRoi(self):
- """Returns the currently selected ROI, else None.
-
- :rtype: Union[RegionOfInterest,None]
- """
- return self._currentRoi
-
- def _plotSignals(self, event):
- """Handle mouse interaction for ROI addition"""
- clicked = False
- roi = None
- if event["event"] in ("markerClicked", "markerMoving"):
- plot = self.parent()
- legend = event["label"]
- marker = plot._getMarker(legend=legend)
- roi = self.__getRoiFromMarker(marker)
- elif event["event"] == "mouseClicked" and event["button"] == "left":
- # Marker click is only for dnd
- # This also can click on a marker
- clicked = True
- plot = self.parent()
- marker = plot._getMarkerAt(event["xpixel"], event["ypixel"])
- roi = self.__getRoiFromMarker(marker)
- else:
- return
-
- if roi not in self._rois:
- # The ROI is not own by this manager
- return
-
- if roi is not None:
- currentRoi = self.getCurrentRoi()
- if currentRoi is roi:
- if clicked:
- self.__updateMode(roi)
- elif roi.isSelectable():
- self.setCurrentRoi(roi)
- else:
- self.setCurrentRoi(None)
-
- def __updateMode(self, roi):
- if isinstance(roi, roi_items.InteractionModeMixIn):
- available = roi.availableInteractionModes()
- mode = roi.getInteractionMode()
- imode = available.index(mode)
- mode = available[(imode + 1) % len(available)]
- roi.setInteractionMode(mode)
-
- def _feedContextMenu(self, menu):
- """Called wen the default plot context menu is about to be displayed"""
- roi = self.getCurrentRoi()
- if roi is not None:
- if roi.isEditable():
- # Filter by data position
- # FIXME: It would be better to use GUI coords for it
- plot = self.parent()
- pos = plot.getWidgetHandle().mapFromGlobal(qt.QCursor.pos())
- data = plot.pixelToData(pos.x(), pos.y())
- if roi.contains(data):
- if isinstance(roi, roi_items.InteractionModeMixIn):
- self._contextMenuForInteractionMode(menu, roi)
-
- removeAction = qt.QAction(menu)
- removeAction.setText("Remove %s" % roi.getName())
- callback = functools.partial(self.removeRoi, roi)
- removeAction.triggered.connect(callback)
- menu.addAction(removeAction)
-
- def _contextMenuForInteractionMode(self, menu, roi):
- availableModes = roi.availableInteractionModes()
- currentMode = roi.getInteractionMode()
- submenu = qt.QMenu(menu)
- modeGroup = qt.QActionGroup(menu)
- modeGroup.setExclusive(True)
- for mode in availableModes:
- action = qt.QAction(menu)
- action.setText(mode.label)
- action.setToolTip(mode.description)
- action.setCheckable(True)
- if mode is currentMode:
- action.setChecked(True)
- else:
- callback = functools.partial(roi.setInteractionMode, mode)
- action.triggered.connect(callback)
- modeGroup.addAction(action)
- submenu.addAction(action)
- action = qt.QAction(menu)
- action.setMenu(submenu)
- action.setText("%s interaction mode" % roi.getName())
- menu.addAction(action)
-
- # RegionOfInterest API
-
- def getRois(self):
- """Returns the list of ROIs.
-
- It returns an empty tuple if there is currently no ROI.
-
- :return: Tuple of arrays of objects describing the ROIs
- :rtype: List[RegionOfInterest]
- """
- return tuple(self._rois)
-
- def clear(self):
- """Reset current ROIs
-
- :return: True if ROIs were reset.
- :rtype: bool
- """
- if self.getRois(): # Something to reset
- for roi in self._rois:
- roi.sigRegionChanged.disconnect(
- self._regionOfInterestChanged)
- roi.setParent(None)
- self._rois = []
- self._roisUpdated()
- return True
-
- else:
- return False
-
- def _regionOfInterestChanged(self, event=None):
- """Handle ROI object changed"""
- self.sigRoiChanged.emit()
-
- def _createInteractiveRoi(self, roiClass, points, label=None, index=None):
- """Create a new ROI with interactive creation.
-
- :param class roiClass: The class of the ROI to create
- :param numpy.ndarray points: The first shape used to create the ROI
- :param str label: The label to display along with the ROI.
- :param int index: The position where to insert the ROI.
- By default it is appended to the end of the list.
- :return: The created ROI object
- :rtype: roi_items.RegionOfInterest
- :raise RuntimeError: When ROI cannot be added because the maximum
- number of ROIs has been reached.
- """
- roi = roiClass(parent=None)
- if label is not None:
- roi.setName(str(label))
- roi.creationStarted()
- roi.setFirstShapePoints(points)
-
- self.addRoi(roi, index)
- if roi.isSelectable():
- self.setCurrentRoi(roi)
- self.sigInteractiveRoiCreated.emit(roi)
- return roi
-
- def containsRoi(self, roi):
- """Returns true if the ROI is part of this manager.
-
- :param roi_items.RegionOfInterest roi: The ROI to add
- :rtype: bool
- """
- return roi in self._rois
-
- def addRoi(self, roi, index=None, useManagerColor=True):
- """Add the ROI to the list of ROIs.
-
- :param roi_items.RegionOfInterest roi: The ROI to add
- :param int index: The position where to insert the ROI,
- By default it is appended to the end of the list of ROIs
- :param bool useManagerColor:
- Whether to set the ROI color to the default one of the manager or not.
- (Default: True).
- :raise RuntimeError: When ROI cannot be added because the maximum
- number of ROIs has been reached.
- """
- plot = self.parent()
- if plot is None:
- raise RuntimeError(
- 'Cannot add ROI: PlotWidget no more available')
-
- roi.setParent(self)
-
- if useManagerColor:
- roi.setColor(self.getColor())
-
- roi.sigRegionChanged.connect(self._regionOfInterestChanged)
- roi.sigItemChanged.connect(self._regionOfInterestChanged)
-
- if index is None:
- self._rois.append(roi)
- else:
- self._rois.insert(index, roi)
- self.sigRoiAdded.emit(roi)
- self._roisUpdated()
-
- def removeRoi(self, roi):
- """Remove a ROI from the list of ROIs.
-
- :param roi_items.RegionOfInterest roi: The ROI to remove
- :raise ValueError: When ROI does not belong to this object
- """
- if not (isinstance(roi, roi_items.RegionOfInterest) and
- roi.parent() is self and
- roi in self._rois):
- raise ValueError(
- 'RegionOfInterest does not belong to this instance')
-
- roi.sigAboutToBeRemoved.emit()
- self.sigRoiAboutToBeRemoved.emit(roi)
-
- if roi is self._currentRoi:
- self.setCurrentRoi(None)
-
- mustRestart = False
- if roi is self._drawnROI:
- self._drawnROI = None
- mustRestart = True
- self._rois.remove(roi)
- roi.sigRegionChanged.disconnect(self._regionOfInterestChanged)
- roi.sigItemChanged.disconnect(self._regionOfInterestChanged)
- roi.setParent(None)
- self._roisUpdated()
-
- if mustRestart:
- self._restart()
-
- def _roisUpdated(self):
- """Handle update of the ROI list"""
- self.sigRoiChanged.emit()
-
- # RegionOfInterest parameters
-
- def getColor(self):
- """Return the default color of created ROIs
-
- :rtype: QColor
- """
- return qt.QColor.fromRgbF(*self._color)
-
- def setColor(self, color):
- """Set the default color to use when creating ROIs.
-
- Existing ROIs are not affected.
-
- :param color: The color to use for displaying ROIs as
- either a color name, a QColor, a list of uint8 or float in [0, 1].
- """
- self._color = rgba(color)
-
- # Control ROI
-
- def getCurrentInteractionModeRoiClass(self):
- """Returns the current ROI class used by the interactive drawing mode.
-
- Returns None if the ROI manager is not in an interactive mode.
-
- :rtype: Union[class,None]
- """
- return self._roiClass
-
- def getInteractionSource(self):
- """Returns the object which have requested the ROI creation.
-
- Returns None if the ROI manager is not in an interactive mode.
-
- :rtype: Union[object,None]
- """
- return self._source
-
- def isStarted(self):
- """Returns True if an interactive ROI drawing mode is active.
-
- :rtype: bool
- """
- return self._roiClass is not None
-
- def isDrawing(self):
- """Returns True if an interactive ROI is drawing.
-
- :rtype: bool
- """
- return self._drawnROI is not None
-
- def start(self, roiClass, source=None):
- """Start an interactive ROI drawing mode.
-
- :param class roiClass: The ROI class to create. It have to inherite from
- `roi_items.RegionOfInterest`.
- :param object source: SOurce of the ROI interaction.
- :return: True if interactive ROI drawing was started, False otherwise
- :rtype: bool
- :raise ValueError: If roiClass is not supported
- """
- self.stop()
-
- if not issubclass(roiClass, roi_items.RegionOfInterest):
- raise ValueError('Unsupported ROI class %s' % roiClass)
-
- plot = self.parent()
- if plot is None:
- return False
-
- self._roiClass = roiClass
- self._source = source
-
- self._restart()
-
- plot.sigPlotSignal.connect(self._handleInteraction)
-
- self.sigInteractiveModeStarted.emit(roiClass)
-
- return True
-
- def _restart(self):
- """Restart the plot interaction without changing the
- source or the ROI class.
- """
- roiClass = self._roiClass
- plot = self.parent()
- firstInteractionShapeKind = roiClass.getFirstInteractionShape()
-
- if firstInteractionShapeKind == 'point':
- plot.setInteractiveMode(mode='select', source=self)
- else:
- if roiClass.showFirstInteractionShape():
- color = rgba(self.getColor())
- else:
- color = None
- plot.setInteractiveMode(mode='select-draw',
- source=self,
- shape=firstInteractionShapeKind,
- color=color,
- label=self._label)
-
- def __roiInteractiveModeEnded(self):
- """Handle end of ROI draw interactive mode"""
- if self.isStarted():
- self._roiClass = None
- self._source = None
-
- if self._drawnROI is not None:
- # Cancel ROI create
- roi = self._drawnROI
- self._drawnROI = None
- self.removeRoi(roi)
-
- plot = self.parent()
- if plot is not None:
- plot.sigPlotSignal.disconnect(self._handleInteraction)
-
- self.sigInteractiveModeFinished.emit()
-
- def stop(self):
- """Stop interactive ROI drawing mode.
-
- :return: True if an interactive ROI drawing mode was actually stopped
- :rtype: bool
- """
- if not self.isStarted():
- return False
-
- plot = self.parent()
- if plot is not None:
- # This leads to call __roiInteractiveModeEnded through
- # interactive mode changed signal
- plot.resetInteractiveMode()
- else: # Fallback
- self.__roiInteractiveModeEnded()
-
- return True
-
- def exec_(self, roiClass):
- """Block until :meth:`quit` is called.
-
- :param class kind: The class of the ROI which have to be created.
- See `silx.gui.plot.items.roi`.
- :return: The list of ROIs
- :rtype: tuple
- """
- self.start(roiClass)
-
- plot = self.parent()
- plot.show()
- plot.raise_()
-
- self._eventLoop = qt.QEventLoop()
- self._eventLoop.exec_()
- self._eventLoop = None
-
- self.stop()
-
- rois = self.getRois()
- self.clear()
- return rois
-
- def quit(self):
- """Stop a blocking :meth:`exec_` and call :meth:`stop`"""
- if self._eventLoop is not None:
- self._eventLoop.quit()
- self._eventLoop = None
- self.stop()
-
-
-class InteractiveRegionOfInterestManager(RegionOfInterestManager):
- """RegionOfInterestManager with features for use from interpreter.
-
- It is meant to be used through the :meth:`exec_`.
- It provides some messages to display in a status bar and
- different modes to end blocking calls to :meth:`exec_`.
-
- :param parent: See QObject
- """
-
- sigMessageChanged = qt.Signal(str)
- """Signal emitted when a new message should be displayed to the user
-
- It provides the message as a str.
- """
-
- def __init__(self, parent):
- super(InteractiveRegionOfInterestManager, self).__init__(parent)
- self._maxROI = None
- self.__timeoutEndTime = None
- self.__message = ''
- self.__validationMode = self.ValidationMode.ENTER
- self.__execClass = None
-
- self.sigRoiAdded.connect(self.__added)
- self.sigRoiAboutToBeRemoved.connect(self.__aboutToBeRemoved)
- self.sigInteractiveModeStarted.connect(self.__started)
- self.sigInteractiveModeFinished.connect(self.__finished)
-
- # Max ROI
-
- def getMaxRois(self):
- """Returns the maximum number of ROIs or None if no limit.
-
- :rtype: Union[int,None]
- """
- return self._maxROI
-
- def setMaxRois(self, max_):
- """Set the maximum number of ROIs.
-
- :param Union[int,None] max_: The max limit or None for no limit.
- :raise ValueError: If there is more ROIs than max value
- """
- if max_ is not None:
- max_ = int(max_)
- if max_ <= 0:
- raise ValueError('Max limit must be strictly positive')
-
- if len(self.getRois()) > max_:
- raise ValueError(
- 'Cannot set max limit: Already too many ROIs')
-
- self._maxROI = max_
-
- def isMaxRois(self):
- """Returns True if the maximum number of ROIs is reached.
-
- :rtype: bool
- """
- max_ = self.getMaxRois()
- return max_ is not None and len(self.getRois()) >= max_
-
- # Validation mode
-
- @enum.unique
- class ValidationMode(enum.Enum):
- """Mode of validation to leave blocking :meth:`exec_`"""
-
- AUTO = 'auto'
- """Automatically ends the interactive mode once
- the user terminates the last ROI shape."""
-
- ENTER = 'enter'
- """Ends the interactive mode when the *Enter* key is pressed."""
-
- AUTO_ENTER = 'auto_enter'
- """Ends the interactive mode when reaching max ROIs or
- when the *Enter* key is pressed.
- """
-
- NONE = 'none'
- """Do not provide the user a way to end the interactive mode.
-
- The end of :meth:`exec_` is done through :meth:`quit` or timeout.
- """
-
- def getValidationMode(self):
- """Returns the interactive mode validation in use.
-
- :rtype: ValidationMode
- """
- return self.__validationMode
-
- def setValidationMode(self, mode):
- """Set the way to perform interactive mode validation.
-
- See :class:`ValidationMode` enumeration for the supported
- validation modes.
-
- :param ValidationMode mode: The interactive mode validation to use.
- """
- assert isinstance(mode, self.ValidationMode)
- if mode != self.__validationMode:
- self.__validationMode = mode
-
- if self.isExec():
- if (self.isMaxRois() and self.getValidationMode() in
- (self.ValidationMode.AUTO,
- self.ValidationMode.AUTO_ENTER)):
- self.quit()
-
- self.__updateMessage()
-
- def eventFilter(self, obj, event):
- if event.type() == qt.QEvent.Hide:
- self.quit()
-
- if event.type() == qt.QEvent.KeyPress:
- key = event.key()
- if (key in (qt.Qt.Key_Return, qt.Qt.Key_Enter) and
- self.getValidationMode() in (
- self.ValidationMode.ENTER,
- self.ValidationMode.AUTO_ENTER)):
- # Stop on return key pressed
- self.quit()
- return True # Stop further handling of this keys
-
- if (key in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or (
- key == qt.Qt.Key_Z and
- event.modifiers() & qt.Qt.ControlModifier)):
- rois = self.getRois()
- if rois: # Something to undo
- self.removeRoi(rois[-1])
- # Stop further handling of keys if something was undone
- return True
-
- return super(InteractiveRegionOfInterestManager, self).eventFilter(obj, event)
-
- # Message API
-
- def getMessage(self):
- """Returns the current status message.
-
- This message is meant to be displayed in a status bar.
-
- :rtype: str
- """
- if self.__timeoutEndTime is None:
- return self.__message
- else:
- remaining = self.__timeoutEndTime - time.time()
- return self.__message + (' - %d seconds remaining' %
- max(1, int(remaining)))
-
- # Listen to ROI updates
-
- def __added(self, *args, **kwargs):
- """Handle new ROI added"""
- max_ = self.getMaxRois()
- if max_ is not None:
- # When reaching max number of ROIs, redo last one
- while len(self.getRois()) > max_:
- self.removeRoi(self.getRois()[-2])
-
- self.__updateMessage()
- if (self.isMaxRois() and
- self.getValidationMode() in (self.ValidationMode.AUTO,
- self.ValidationMode.AUTO_ENTER)):
- self.quit()
-
- def __aboutToBeRemoved(self, *args, **kwargs):
- """Handle removal of a ROI"""
- # RegionOfInterest not removed yet
- self.__updateMessage(nbrois=len(self.getRois()) - 1)
-
- def __started(self, roiKind):
- """Handle interactive mode started"""
- self.__updateMessage()
-
- def __finished(self):
- """Handle interactive mode finished"""
- self.__updateMessage()
-
- def __updateMessage(self, nbrois=None):
- """Update message"""
- if not self.isExec():
- message = 'Done'
-
- elif not self.isStarted():
- message = 'Use %s ROI edition mode' % self.__execClass
-
- else:
- if nbrois is None:
- nbrois = len(self.getRois())
-
- name = self.__execClass._getShortName()
-
- max_ = self.getMaxRois()
- if max_ is None:
- message = 'Select %ss (%d selected)' % (name, nbrois)
-
- elif max_ <= 1:
- message = 'Select a %s' % name
- else:
- message = 'Select %d/%d %ss' % (nbrois, max_, name)
-
- if (self.getValidationMode() == self.ValidationMode.ENTER and
- self.isMaxRois()):
- message += ' - Press Enter to confirm'
-
- if message != self.__message:
- self.__message = message
- # Use getMessage to add timeout message
- self.sigMessageChanged.emit(self.getMessage())
-
- # Handle blocking call
-
- def __timeoutUpdate(self):
- """Handle update of timeout"""
- if (self.__timeoutEndTime is not None and
- (self.__timeoutEndTime - time.time()) > 0):
- self.sigMessageChanged.emit(self.getMessage())
- else: # Stop interactive mode and message timer
- timer = self.sender()
- if timer is not None:
- timer.stop()
- self.__timeoutEndTime = None
- self.quit()
-
- def isExec(self):
- """Returns True if :meth:`exec_` is currently running.
-
- :rtype: bool"""
- return self.__execClass is not None
-
- def exec_(self, roiClass, timeout=0):
- """Block until ROI selection is done or timeout is elapsed.
-
- :meth:`quit` also ends this blocking call.
-
- :param class roiClass: The class of the ROI which have to be created.
- See `silx.gui.plot.items.roi`.
- :param int timeout: Maximum duration in seconds to block.
- Default: No timeout
- :return: The list of ROIs
- :rtype: List[RegionOfInterest]
- """
- plot = self.parent()
- if plot is None:
- return
-
- self.__execClass = roiClass
-
- plot.installEventFilter(self)
-
- if timeout > 0:
- self.__timeoutEndTime = time.time() + timeout
- timer = qt.QTimer(self)
- timer.timeout.connect(self.__timeoutUpdate)
- timer.start(1000)
-
- rois = super(InteractiveRegionOfInterestManager, self).exec_(roiClass)
-
- timer.stop()
- self.__timeoutEndTime = None
-
- else:
- rois = super(InteractiveRegionOfInterestManager, self).exec_(roiClass)
-
- plot.removeEventFilter(self)
-
- self.__execClass = None
- self.__updateMessage()
-
- return rois
-
-
-class _DeleteRegionOfInterestToolButton(qt.QToolButton):
- """Tool button deleting a ROI object
-
- :param parent: See QWidget
- :param RegionOfInterest roi: The ROI to delete
- """
-
- def __init__(self, parent, roi):
- super(_DeleteRegionOfInterestToolButton, self).__init__(parent)
- self.setIcon(icons.getQIcon('remove'))
- self.setToolTip("Remove this ROI")
- self.__roiRef = roi if roi is None else weakref.ref(roi)
- self.clicked.connect(self.__clicked)
-
- def __clicked(self, checked):
- """Handle button clicked"""
- roi = None if self.__roiRef is None else self.__roiRef()
- if roi is not None:
- manager = roi.parent()
- if manager is not None:
- manager.removeRoi(roi)
- self.__roiRef = None
-
-
-class RegionOfInterestTableWidget(qt.QTableWidget):
- """Widget displaying the ROIs of a :class:`RegionOfInterestManager`"""
-
- def __init__(self, parent=None):
- super(RegionOfInterestTableWidget, self).__init__(parent)
- self._roiManagerRef = None
-
- headers = ['Label', 'Edit', 'Kind', 'Coordinates', '']
- self.setColumnCount(len(headers))
- self.setHorizontalHeaderLabels(headers)
-
- horizontalHeader = self.horizontalHeader()
- horizontalHeader.setDefaultAlignment(qt.Qt.AlignLeft)
- if hasattr(horizontalHeader, 'setResizeMode'): # Qt 4
- setSectionResizeMode = horizontalHeader.setResizeMode
- else: # Qt5
- setSectionResizeMode = horizontalHeader.setSectionResizeMode
-
- setSectionResizeMode(0, qt.QHeaderView.Interactive)
- setSectionResizeMode(1, qt.QHeaderView.ResizeToContents)
- setSectionResizeMode(2, qt.QHeaderView.ResizeToContents)
- setSectionResizeMode(3, qt.QHeaderView.Stretch)
- setSectionResizeMode(4, qt.QHeaderView.ResizeToContents)
-
- verticalHeader = self.verticalHeader()
- verticalHeader.setVisible(False)
-
- self.setSelectionMode(qt.QAbstractItemView.NoSelection)
- self.setFocusPolicy(qt.Qt.NoFocus)
-
- self.itemChanged.connect(self.__itemChanged)
-
- def __itemChanged(self, item):
- """Handle item updates"""
- column = item.column()
- index = item.data(qt.Qt.UserRole)
-
- if index is not None:
- manager = self.getRegionOfInterestManager()
- roi = manager.getRois()[index]
- else:
- return
-
- if column == 0:
- # First collect information from item, then update ROI
- # Otherwise, this causes issues issues
- checked = item.checkState() == qt.Qt.Checked
- text= item.text()
- roi.setVisible(checked)
- roi.setName(text)
- elif column == 1:
- roi.setEditable(item.checkState() == qt.Qt.Checked)
- elif column in (2, 3, 4):
- pass # TODO
- else:
- logger.error('Unhandled column %d', column)
-
- def setRegionOfInterestManager(self, manager):
- """Set the :class:`RegionOfInterestManager` object to sync with
-
- :param RegionOfInterestManager manager:
- """
- assert manager is None or isinstance(manager, RegionOfInterestManager)
-
- previousManager = self.getRegionOfInterestManager()
-
- if previousManager is not None:
- previousManager.sigRoiChanged.disconnect(self._sync)
- self.setRowCount(0)
-
- self._roiManagerRef = weakref.ref(manager)
-
- self._sync()
-
- if manager is not None:
- manager.sigRoiChanged.connect(self._sync)
-
- def _getReadableRoiDescription(self, roi):
- """Returns modelisation of a ROI as a readable sequence of values.
-
- :rtype: str
- """
- text = str(roi)
- try:
- # Extract the params from syntax "CLASSNAME(PARAMS)"
- elements = text.split("(", 1)
- if len(elements) != 2:
- return text
- result = elements[1]
- result = result.strip()
- if not result.endswith(")"):
- return text
- result = result[0:-1]
- # Capitalize each words
- result = result.title()
- return result
- except Exception:
- logger.debug("Backtrace", exc_info=True)
- return text
-
- def _sync(self):
- """Update widget content according to ROI manger"""
- manager = self.getRegionOfInterestManager()
-
- if manager is None:
- self.setRowCount(0)
- return
-
- rois = manager.getRois()
-
- self.setRowCount(len(rois))
- for index, roi in enumerate(rois):
- baseFlags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled
-
- # Label and visible
- label = roi.getName()
- item = qt.QTableWidgetItem(label)
- item.setFlags(baseFlags | qt.Qt.ItemIsEditable | qt.Qt.ItemIsUserCheckable)
- item.setData(qt.Qt.UserRole, index)
- item.setCheckState(
- qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked)
- self.setItem(index, 0, item)
-
- # Editable
- item = qt.QTableWidgetItem()
- item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable)
- item.setData(qt.Qt.UserRole, index)
- item.setCheckState(
- qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
- self.setItem(index, 1, item)
- item.setTextAlignment(qt.Qt.AlignCenter)
- item.setText(None)
-
- # Kind
- label = roi._getShortName()
- if label is None:
- # Default value if kind is not overrided
- label = roi.__class__.__name__
- item = qt.QTableWidgetItem(label.capitalize())
- item.setFlags(baseFlags)
- self.setItem(index, 2, item)
-
- item = qt.QTableWidgetItem()
- item.setFlags(baseFlags)
-
- # Coordinates
- text = self._getReadableRoiDescription(roi)
- item.setText(text)
- self.setItem(index, 3, item)
-
- # Delete
- delBtn = _DeleteRegionOfInterestToolButton(None, roi)
- widget = qt.QWidget(self)
- layout = qt.QHBoxLayout()
- layout.setContentsMargins(2, 2, 2, 2)
- layout.setSpacing(0)
- widget.setLayout(layout)
- layout.addStretch(1)
- layout.addWidget(delBtn)
- layout.addStretch(1)
- self.setCellWidget(index, 4, widget)
-
- def getRegionOfInterestManager(self):
- """Returns the :class:`RegionOfInterestManager` this widget supervise.
-
- It returns None if not sync with an :class:`RegionOfInterestManager`.
-
- :rtype: RegionOfInterestManager
- """
- return None if self._roiManagerRef is None else self._roiManagerRef()
diff --git a/silx/gui/plot/tools/test/__init__.py b/silx/gui/plot/tools/test/__init__.py
deleted file mode 100644
index 1429545..0000000
--- a/silx/gui/plot/tools/test/__init__.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "26/03/2018"
-
-
-import unittest
-
-from . import testROI
-from . import testTools
-from . import testScatterProfileToolBar
-from . import testCurveLegendsWidget
-from . import testProfile
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTests(
- [testROI.suite(),
- testTools.suite(),
- testScatterProfileToolBar.suite(),
- testCurveLegendsWidget.suite(),
- testProfile.suite(),
- ])
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/tools/test/testCurveLegendsWidget.py b/silx/gui/plot/tools/test/testCurveLegendsWidget.py
deleted file mode 100644
index 4824dd7..0000000
--- a/silx/gui/plot/tools/test/testCurveLegendsWidget.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "02/08/2018"
-
-
-import unittest
-
-from silx.gui import qt
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.plot import PlotWindow
-from silx.gui.plot.tools import CurveLegendsWidget
-
-
-class TestCurveLegendsWidget(TestCaseQt, ParametricTestCase):
- """Tests for CurveLegendsWidget class"""
-
- def setUp(self):
- super(TestCurveLegendsWidget, self).setUp()
- self.plot = PlotWindow()
-
- self.legends = CurveLegendsWidget.CurveLegendsWidget()
- self.legends.setPlotWidget(self.plot)
-
- dock = qt.QDockWidget()
- dock.setWindowTitle('Curve Legends')
- dock.setWidget(self.legends)
- self.plot.addTabbedDockWidget(dock)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- def tearDown(self):
- del self.legends
- self.qapp.processEvents()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- super(TestCurveLegendsWidget, self).tearDown()
-
- def _assertNbLegends(self, count):
- """Check the number of legends in the CurveLegendsWidget"""
- children = self.legends.findChildren(CurveLegendsWidget._LegendWidget)
- self.assertEqual(len(children), count)
-
- def testAddRemoveCurves(self):
- """Test CurveLegendsWidget while adding/removing curves"""
- self.plot.addCurve((0, 1), (1, 2), legend='a')
- self._assertNbLegends(1)
- self.plot.addCurve((0, 1), (2, 3), legend='b')
- self._assertNbLegends(2)
-
- # Detached/attach
- self.legends.setPlotWidget(None)
- self._assertNbLegends(0)
-
- self.legends.setPlotWidget(self.plot)
- self._assertNbLegends(2)
-
- self.plot.clear()
- self._assertNbLegends(0)
-
- def testUpdateCurves(self):
- """Test CurveLegendsWidget while updating curves """
- self.plot.addCurve((0, 1), (1, 2), legend='a')
- self._assertNbLegends(1)
- self.plot.addCurve((0, 1), (2, 3), legend='b')
- self._assertNbLegends(2)
-
- # Activate curve
- self.plot.setActiveCurve('a')
- self.qapp.processEvents()
- self.plot.setActiveCurve('b')
- self.qapp.processEvents()
-
- # Change curve style
- curve = self.plot.getCurve('a')
- curve.setLineWidth(2)
- for linestyle in (':', '', '--', '-'):
- with self.subTest(linestyle=linestyle):
- curve.setLineStyle(linestyle)
- self.qapp.processEvents()
- self.qWait(1000)
-
- for symbol in ('o', 'd', '', 's'):
- with self.subTest(symbol=symbol):
- curve.setSymbol(symbol)
- self.qapp.processEvents()
- self.qWait(1000)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(
- TestCurveLegendsWidget))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/tools/test/testProfile.py b/silx/gui/plot/tools/test/testProfile.py
deleted file mode 100644
index 444cfe0..0000000
--- a/silx/gui/plot/tools/test/testProfile.py
+++ /dev/null
@@ -1,673 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "28/06/2018"
-
-
-import unittest
-import contextlib
-import numpy
-import logging
-
-from silx.gui import qt
-from silx.utils import deprecation
-from silx.utils import testutils
-
-from silx.gui.utils.testutils import TestCaseQt
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.plot import PlotWindow, Plot1D, Plot2D, Profile
-from silx.gui.plot.StackView import StackView
-from silx.gui.plot.tools.profile import rois
-from silx.gui.plot.tools.profile import editors
-from silx.gui.plot.items import roi as roi_items
-from silx.gui.plot.tools.profile import manager
-from silx.gui import plot as silx_plot
-
-_logger = logging.getLogger(__name__)
-
-
-class TestRois(TestCaseQt):
-
- def test_init(self):
- """Check that the constructor is not called twice"""
- roi = rois.ProfileImageVerticalLineROI()
- if qt.BINDING not in ["PySide", "PySide2"]:
- # the profile ROI + the shape
- self.assertEqual(roi.receivers(roi.sigRegionChanged), 2)
-
-
-class TestInteractions(TestCaseQt):
-
- @contextlib.contextmanager
- def defaultPlot(self):
- try:
- widget = silx_plot.PlotWidget()
- widget.show()
- self.qWaitForWindowExposed(widget)
- yield widget
- finally:
- widget.close()
- widget = None
- self.qWait()
-
- @contextlib.contextmanager
- def imagePlot(self):
- try:
- widget = silx_plot.Plot2D()
- image = numpy.arange(10 * 10).reshape(10, -1)
- widget.addImage(image)
- widget.show()
- self.qWaitForWindowExposed(widget)
- yield widget
- finally:
- widget.close()
- widget = None
- self.qWait()
-
- @contextlib.contextmanager
- def scatterPlot(self):
- try:
- widget = silx_plot.ScatterView()
-
- nbX, nbY = 7, 5
- yy = numpy.atleast_2d(numpy.ones(nbY)).T
- xx = numpy.atleast_2d(numpy.ones(nbX))
- positionX = numpy.linspace(10, 50, nbX) * yy
- positionX = positionX.reshape(nbX * nbY)
- positionY = numpy.atleast_2d(numpy.linspace(20, 60, nbY)).T * xx
- positionY = positionY.reshape(nbX * nbY)
- values = numpy.arange(nbX * nbY)
-
- widget.setData(positionX, positionY, values)
- widget.resetZoom()
- widget.show()
- self.qWaitForWindowExposed(widget)
- yield widget.getPlotWidget()
- finally:
- widget.close()
- widget = None
- self.qWait()
-
- @contextlib.contextmanager
- def stackPlot(self):
- try:
- widget = silx_plot.StackView()
- image = numpy.arange(10 * 10).reshape(10, -1)
- cube = numpy.array([image, image, image])
- widget.setStack(cube)
- widget.resetZoom()
- widget.show()
- self.qWaitForWindowExposed(widget)
- yield widget.getPlotWidget()
- finally:
- widget.close()
- widget = None
- self.qWait()
-
- def waitPendingOperations(self, proflie):
- for _ in range(10):
- if not proflie.hasPendingOperations():
- return
- self.qWait(100)
- _logger.error("The profile manager still have pending operations")
-
- def genericRoiTest(self, plot, roiClass):
- profileManager = manager.ProfileManager(plot, plot)
- profileManager.setItemType(image=True, scatter=True)
-
- try:
- action = profileManager.createProfileAction(roiClass, plot)
- action.triggered[bool].emit(True)
- widget = plot.getWidgetHandle()
-
- # Do the mouse interaction
- pos1 = widget.width() * 0.4, widget.height() * 0.4
- self.mouseMove(widget, pos=pos1)
- self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
-
- if issubclass(roiClass, roi_items.LineROI):
- pos2 = widget.width() * 0.6, widget.height() * 0.6
- self.mouseMove(widget, pos=pos2)
- self.mouseClick(widget, qt.Qt.LeftButton, pos=pos2)
-
- self.waitPendingOperations(profileManager)
-
- # Test that something was computed
- if issubclass(roiClass, rois._ProfileCrossROI):
- self.assertEqual(profileManager._computedProfiles, 2)
- elif issubclass(roiClass, roi_items.LineROI):
- self.assertGreaterEqual(profileManager._computedProfiles, 1)
- else:
- self.assertEqual(profileManager._computedProfiles, 1)
-
- # Test the created ROIs
- profileRois = profileManager.getRoiManager().getRois()
- if issubclass(roiClass, rois._ProfileCrossROI):
- self.assertEqual(len(profileRois), 3)
- else:
- self.assertEqual(len(profileRois), 1)
- # The first one should be the expected one
- roi = profileRois[0]
-
- # Test that something was displayed
- if issubclass(roiClass, rois._ProfileCrossROI):
- profiles = roi._getLines()
- window = profiles[0].getProfileWindow()
- self.assertIsNotNone(window)
- window = profiles[1].getProfileWindow()
- self.assertIsNotNone(window)
- else:
- window = roi.getProfileWindow()
- self.assertIsNotNone(window)
- finally:
- profileManager.clearProfile()
-
- def testImageActions(self):
- roiClasses = [
- rois.ProfileImageHorizontalLineROI,
- rois.ProfileImageVerticalLineROI,
- rois.ProfileImageLineROI,
- rois.ProfileImageCrossROI,
- ]
- with self.imagePlot() as plot:
- for roiClass in roiClasses:
- with self.subTest(roiClass=roiClass):
- self.genericRoiTest(plot, roiClass)
-
- def testScatterActions(self):
- roiClasses = [
- rois.ProfileScatterHorizontalLineROI,
- rois.ProfileScatterVerticalLineROI,
- rois.ProfileScatterLineROI,
- rois.ProfileScatterCrossROI,
- rois.ProfileScatterHorizontalSliceROI,
- rois.ProfileScatterVerticalSliceROI,
- rois.ProfileScatterCrossSliceROI,
- ]
- with self.scatterPlot() as plot:
- for roiClass in roiClasses:
- with self.subTest(roiClass=roiClass):
- self.genericRoiTest(plot, roiClass)
-
- def testStackActions(self):
- roiClasses = [
- rois.ProfileImageStackHorizontalLineROI,
- rois.ProfileImageStackVerticalLineROI,
- rois.ProfileImageStackLineROI,
- rois.ProfileImageStackCrossROI,
- ]
- with self.stackPlot() as plot:
- for roiClass in roiClasses:
- with self.subTest(roiClass=roiClass):
- self.genericRoiTest(plot, roiClass)
-
- def genericEditorTest(self, plot, roi, editor):
- if isinstance(editor, editors._NoProfileRoiEditor):
- pass
- elif isinstance(editor, editors._DefaultImageStackProfileRoiEditor):
- # GUI to ROI
- editor._lineWidth.setValue(2)
- self.assertEqual(roi.getProfileLineWidth(), 2)
- editor._methodsButton.setMethod("sum")
- self.assertEqual(roi.getProfileMethod(), "sum")
- editor._profileDim.setDimension(1)
- self.assertEqual(roi.getProfileType(), "1D")
- # ROI to GUI
- roi.setProfileLineWidth(3)
- self.assertEqual(editor._lineWidth.value(), 3)
- roi.setProfileMethod("mean")
- self.assertEqual(editor._methodsButton.getMethod(), "mean")
- roi.setProfileType("2D")
- self.assertEqual(editor._profileDim.getDimension(), 2)
- elif isinstance(editor, editors._DefaultImageProfileRoiEditor):
- # GUI to ROI
- editor._lineWidth.setValue(2)
- self.assertEqual(roi.getProfileLineWidth(), 2)
- editor._methodsButton.setMethod("sum")
- self.assertEqual(roi.getProfileMethod(), "sum")
- # ROI to GUI
- roi.setProfileLineWidth(3)
- self.assertEqual(editor._lineWidth.value(), 3)
- roi.setProfileMethod("mean")
- self.assertEqual(editor._methodsButton.getMethod(), "mean")
- elif isinstance(editor, editors._DefaultScatterProfileRoiEditor):
- # GUI to ROI
- editor._nPoints.setValue(100)
- self.assertEqual(roi.getNPoints(), 100)
- # ROI to GUI
- roi.setNPoints(200)
- self.assertEqual(editor._nPoints.value(), 200)
- else:
- assert False
-
- def testEditors(self):
- roiClasses = [
- (rois.ProfileImageHorizontalLineROI, editors._DefaultImageProfileRoiEditor),
- (rois.ProfileImageVerticalLineROI, editors._DefaultImageProfileRoiEditor),
- (rois.ProfileImageLineROI, editors._DefaultImageProfileRoiEditor),
- (rois.ProfileImageCrossROI, editors._DefaultImageProfileRoiEditor),
- (rois.ProfileScatterHorizontalLineROI, editors._DefaultScatterProfileRoiEditor),
- (rois.ProfileScatterVerticalLineROI, editors._DefaultScatterProfileRoiEditor),
- (rois.ProfileScatterLineROI, editors._DefaultScatterProfileRoiEditor),
- (rois.ProfileScatterCrossROI, editors._DefaultScatterProfileRoiEditor),
- (rois.ProfileScatterHorizontalSliceROI, editors._NoProfileRoiEditor),
- (rois.ProfileScatterVerticalSliceROI, editors._NoProfileRoiEditor),
- (rois.ProfileScatterCrossSliceROI, editors._NoProfileRoiEditor),
- (rois.ProfileImageStackHorizontalLineROI, editors._DefaultImageStackProfileRoiEditor),
- (rois.ProfileImageStackVerticalLineROI, editors._DefaultImageStackProfileRoiEditor),
- (rois.ProfileImageStackLineROI, editors._DefaultImageStackProfileRoiEditor),
- (rois.ProfileImageStackCrossROI, editors._DefaultImageStackProfileRoiEditor),
- ]
- with self.defaultPlot() as plot:
- profileManager = manager.ProfileManager(plot, plot)
- editorAction = profileManager.createEditorAction(parent=plot)
- for roiClass, editorClass in roiClasses:
- with self.subTest(roiClass=roiClass):
- roi = roiClass()
- roi._setProfileManager(profileManager)
- try:
- # Force widget creation
- menu = qt.QMenu(plot)
- menu.addAction(editorAction)
- widgets = editorAction.createdWidgets()
- self.assertGreater(len(widgets), 0)
-
- editorAction.setProfileRoi(roi)
- editorWidget = editorAction._getEditor(widgets[0])
- self.assertIsInstance(editorWidget, editorClass)
- self.genericEditorTest(plot, roi, editorWidget)
- finally:
- editorAction.setProfileRoi(None)
- menu.deleteLater()
- menu = None
- self.qapp.processEvents()
-
-
-class TestProfileToolBar(TestCaseQt, ParametricTestCase):
- """Tests for ProfileToolBar widget."""
-
- def setUp(self):
- super(TestProfileToolBar, self).setUp()
- self.plot = PlotWindow()
- self.toolBar = Profile.ProfileToolBar(plot=self.plot)
- self.plot.addToolBar(self.toolBar)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- self.mouseMove(self.plot) # Move to center
- self.qapp.processEvents()
- deprecation.FORCE = True
-
- def tearDown(self):
- deprecation.FORCE = False
- self.qapp.processEvents()
- profileManager = self.toolBar.getProfileManager()
- profileManager.clearProfile()
- profileManager = None
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- del self.toolBar
-
- super(TestProfileToolBar, self).tearDown()
-
- def testAlignedProfile(self):
- """Test horizontal and vertical profile, without and with image"""
- # Use Plot backend widget to submit mouse events
- widget = self.plot.getWidgetHandle()
- for method in ('sum', 'mean'):
- with self.subTest(method=method):
- # 2 positions to use for mouse events
- pos1 = widget.width() * 0.4, widget.height() * 0.4
- pos2 = widget.width() * 0.6, widget.height() * 0.6
-
- for action in (self.toolBar.hLineAction, self.toolBar.vLineAction):
- with self.subTest(mode=action.text()):
- # Trigger tool button for mode
- action.trigger()
- # Without image
- self.mouseMove(widget, pos=pos1)
- self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
-
- # with image
- self.plot.addImage(
- numpy.arange(100 * 100).reshape(100, -1))
- self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
- self.mouseMove(widget, pos=pos2)
- self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
-
- self.mouseMove(widget)
- self.mouseClick(widget, qt.Qt.LeftButton)
-
- manager = self.toolBar.getProfileManager()
- for _ in range(20):
- self.qWait(200)
- if not manager.hasPendingOperations():
- break
-
- @testutils.test_logging(deprecation.depreclog.name, warning=4)
- def testDiagonalProfile(self):
- """Test diagonal profile, without and with image"""
- # Use Plot backend widget to submit mouse events
- widget = self.plot.getWidgetHandle()
-
- for method in ('sum', 'mean'):
- for image in (False, True):
- with self.subTest(method=method, image=image):
- # 2 positions to use for mouse events
- pos1 = widget.width() * 0.4, widget.height() * 0.4
- pos2 = widget.width() * 0.6, widget.height() * 0.6
-
- if image:
- self.plot.addImage(
- numpy.arange(100 * 100).reshape(100, -1))
-
- # Trigger tool button for diagonal profile mode
- self.toolBar.lineAction.trigger()
-
- # draw profile line
- widget.setFocus(qt.Qt.OtherFocusReason)
- self.mouseMove(widget, pos=pos1)
- self.qWait(100)
- self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
- self.qWait(100)
- self.mouseMove(widget, pos=pos2)
- self.qWait(100)
- self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
- self.qWait(100)
-
- manager = self.toolBar.getProfileManager()
-
- for _ in range(20):
- self.qWait(200)
- if not manager.hasPendingOperations():
- break
-
- roi = manager.getCurrentRoi()
- self.assertIsNotNone(roi)
- roi.setProfileLineWidth(3)
- roi.setProfileMethod(method)
-
- for _ in range(20):
- self.qWait(200)
- if not manager.hasPendingOperations():
- break
-
- if image is True:
- curveItem = self.toolBar.getProfilePlot().getAllCurves()[0]
- if method == 'sum':
- self.assertTrue(curveItem.getData()[1].max() > 10000)
- elif method == 'mean':
- self.assertTrue(curveItem.getData()[1].max() < 10000)
-
- # Remove the ROI so the profile window is also removed
- roiManager = manager.getRoiManager()
- roiManager.removeRoi(roi)
- self.qWait(100)
-
-
-class TestDeprecatedProfileToolBar(TestCaseQt):
- """Tests old features of the ProfileToolBar widget."""
-
- def setUp(self):
- self.plot = None
- super(TestDeprecatedProfileToolBar, self).setUp()
-
- def tearDown(self):
- if self.plot is not None:
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- self.plot = None
- self.qWait()
-
- super(TestDeprecatedProfileToolBar, self).tearDown()
-
- @testutils.test_logging(deprecation.depreclog.name, warning=2)
- def testCustomProfileWindow(self):
- from silx.gui.plot import ProfileMainWindow
-
- self.plot = PlotWindow()
- profileWindow = ProfileMainWindow.ProfileMainWindow(self.plot)
- toolBar = Profile.ProfileToolBar(parent=self.plot,
- plot=self.plot,
- profileWindow=profileWindow)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
- profileWindow.show()
- self.qWaitForWindowExposed(profileWindow)
- self.qapp.processEvents()
-
- self.plot.addImage(numpy.arange(10 * 10).reshape(10, -1))
- profile = rois.ProfileImageHorizontalLineROI()
- profile.setPosition(5)
- toolBar.getProfileManager().getRoiManager().addRoi(profile)
- toolBar.getProfileManager().getRoiManager().setCurrentRoi(profile)
-
- for _ in range(20):
- self.qWait(200)
- if not toolBar.getProfileManager().hasPendingOperations():
- break
-
- # There is a displayed profile
- self.assertIsNotNone(profileWindow.getProfile())
- self.assertIs(toolBar.getProfileMainWindow(), profileWindow)
-
- # There is nothing anymore but the window is still there
- toolBar.getProfileManager().clearProfile()
- self.qapp.processEvents()
- self.assertIsNone(profileWindow.getProfile())
-
-
-class TestProfile3DToolBar(TestCaseQt):
- """Tests for Profile3DToolBar widget.
- """
- def setUp(self):
- super(TestProfile3DToolBar, self).setUp()
- self.plot = StackView()
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- self.plot.setStack(numpy.array([
- [[0, 1, 2], [3, 4, 5]],
- [[6, 7, 8], [9, 10, 11]],
- [[12, 13, 14], [15, 16, 17]]
- ]))
- deprecation.FORCE = True
-
- def tearDown(self):
- deprecation.FORCE = False
- profileManager = self.plot.getProfileToolbar().getProfileManager()
- profileManager.clearProfile()
- profileManager = None
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- self.plot = None
-
- super(TestProfile3DToolBar, self).tearDown()
-
- @testutils.test_logging(deprecation.depreclog.name, warning=2)
- def testMethodProfile2D(self):
- """Test that the profile can have a different method if we want to
- compute then in 1D or in 2D"""
-
- toolBar = self.plot.getProfileToolbar()
-
- toolBar.vLineAction.trigger()
- plot2D = self.plot.getPlotWidget().getWidgetHandle()
- pos1 = plot2D.width() * 0.5, plot2D.height() * 0.5
- self.mouseClick(plot2D, qt.Qt.LeftButton, pos=pos1)
-
- manager = toolBar.getProfileManager()
- roi = manager.getCurrentRoi()
- roi.setProfileMethod("mean")
- roi.setProfileType("2D")
- roi.setProfileLineWidth(3)
-
- for _ in range(20):
- self.qWait(200)
- if not manager.hasPendingOperations():
- break
-
- # check 2D 'mean' profile
- profilePlot = toolBar.getProfilePlot()
- data = profilePlot.getAllImages()[0].getData()
- expected = numpy.array([[1, 4], [7, 10], [13, 16]])
- numpy.testing.assert_almost_equal(data, expected)
-
- @testutils.test_logging(deprecation.depreclog.name, warning=2)
- def testMethodSumLine(self):
- """Simple interaction test to make sure the sum is correctly computed
- """
- toolBar = self.plot.getProfileToolbar()
-
- toolBar.lineAction.trigger()
- plot2D = self.plot.getPlotWidget().getWidgetHandle()
- pos1 = plot2D.width() * 0.5, plot2D.height() * 0.2
- pos2 = plot2D.width() * 0.5, plot2D.height() * 0.8
-
- self.mouseMove(plot2D, pos=pos1)
- self.mousePress(plot2D, qt.Qt.LeftButton, pos=pos1)
- self.mouseMove(plot2D, pos=pos2)
- self.mouseRelease(plot2D, qt.Qt.LeftButton, pos=pos2)
-
- manager = toolBar.getProfileManager()
- roi = manager.getCurrentRoi()
- roi.setProfileMethod("sum")
- roi.setProfileType("2D")
- roi.setProfileLineWidth(3)
-
- for _ in range(20):
- self.qWait(200)
- if not manager.hasPendingOperations():
- break
-
- # check 2D 'sum' profile
- profilePlot = toolBar.getProfilePlot()
- data = profilePlot.getAllImages()[0].getData()
- expected = numpy.array([[3, 12], [21, 30], [39, 48]])
- numpy.testing.assert_almost_equal(data, expected)
-
-
-class TestGetProfilePlot(TestCaseQt):
-
- def setUp(self):
- self.plot = None
- super(TestGetProfilePlot, self).setUp()
-
- def tearDown(self):
- if self.plot is not None:
- manager = self.plot.getProfileToolbar().getProfileManager()
- manager.clearProfile()
- manager = None
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- self.plot = None
-
- super(TestGetProfilePlot, self).tearDown()
-
- def testProfile1D(self):
- self.plot = Plot2D()
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
- self.plot.addImage([[0, 1], [2, 3]])
-
- toolBar = self.plot.getProfileToolbar()
-
- manager = toolBar.getProfileManager()
- roiManager = manager.getRoiManager()
-
- roi = rois.ProfileImageHorizontalLineROI()
- roi.setPosition(0.5)
- roiManager.addRoi(roi)
- roiManager.setCurrentRoi(roi)
-
- for _ in range(20):
- self.qWait(200)
- if not manager.hasPendingOperations():
- break
-
- profileWindow = roi.getProfileWindow()
- self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
- self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D)
-
- def testProfile2D(self):
- """Test that the profile plot associated to a stack view is either a
- Plot1D or a plot 2D instance."""
- self.plot = StackView()
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- self.plot.setStack(numpy.array([[[0, 1], [2, 3]],
- [[4, 5], [6, 7]]]))
-
- toolBar = self.plot.getProfileToolbar()
-
- manager = toolBar.getProfileManager()
- roiManager = manager.getRoiManager()
-
- roi = rois.ProfileImageStackHorizontalLineROI()
- roi.setPosition(0.5)
- roi.setProfileType("2D")
- roiManager.addRoi(roi)
- roiManager.setCurrentRoi(roi)
-
- for _ in range(20):
- self.qWait(200)
- if not manager.hasPendingOperations():
- break
-
- profileWindow = roi.getProfileWindow()
- self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
- self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot2D)
-
- roi.setProfileType("1D")
-
- for _ in range(20):
- self.qWait(200)
- if not manager.hasPendingOperations():
- break
-
- profileWindow = roi.getProfileWindow()
- self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
- self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestRois))
- test_suite.addTest(loadTests(TestInteractions))
- test_suite.addTest(loadTests(TestProfileToolBar))
- test_suite.addTest(loadTests(TestGetProfilePlot))
- test_suite.addTest(loadTests(TestProfile3DToolBar))
- test_suite.addTest(loadTests(TestDeprecatedProfileToolBar))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/tools/test/testROI.py b/silx/gui/plot/tools/test/testROI.py
deleted file mode 100644
index 8a00073..0000000
--- a/silx/gui/plot/tools/test/testROI.py
+++ /dev/null
@@ -1,694 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "28/06/2018"
-
-
-import unittest
-import numpy.testing
-
-from silx.gui import qt
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt, SignalListener
-from silx.gui.plot import PlotWindow
-import silx.gui.plot.items.roi as roi_items
-from silx.gui.plot.tools import roi
-
-
-class TestRoiItems(TestCaseQt):
-
- def testLine_geometry(self):
- item = roi_items.LineROI()
- startPoint = numpy.array([1, 2])
- endPoint = numpy.array([3, 4])
- item.setEndPoints(startPoint, endPoint)
- numpy.testing.assert_allclose(item.getEndPoints()[0], startPoint)
- numpy.testing.assert_allclose(item.getEndPoints()[1], endPoint)
-
- def testHLine_geometry(self):
- item = roi_items.HorizontalLineROI()
- item.setPosition(15)
- self.assertEqual(item.getPosition(), 15)
-
- def testVLine_geometry(self):
- item = roi_items.VerticalLineROI()
- item.setPosition(15)
- self.assertEqual(item.getPosition(), 15)
-
- def testPoint_geometry(self):
- point = numpy.array([1, 2])
- item = roi_items.PointROI()
- item.setPosition(point)
- numpy.testing.assert_allclose(item.getPosition(), point)
-
- def testRectangle_originGeometry(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- center = numpy.array([5, 10])
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin, size=size)
- numpy.testing.assert_allclose(item.getOrigin(), origin)
- numpy.testing.assert_allclose(item.getSize(), size)
- numpy.testing.assert_allclose(item.getCenter(), center)
-
- def testRectangle_centerGeometry(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- center = numpy.array([5, 10])
- item = roi_items.RectangleROI()
- item.setGeometry(center=center, size=size)
- numpy.testing.assert_allclose(item.getOrigin(), origin)
- numpy.testing.assert_allclose(item.getSize(), size)
- numpy.testing.assert_allclose(item.getCenter(), center)
-
- def testRectangle_setCenterGeometry(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin, size=size)
- newCenter = numpy.array([0, 0])
- item.setCenter(newCenter)
- expectedOrigin = numpy.array([-5, -10])
- numpy.testing.assert_allclose(item.getOrigin(), expectedOrigin)
- numpy.testing.assert_allclose(item.getCenter(), newCenter)
- numpy.testing.assert_allclose(item.getSize(), size)
-
- def testRectangle_setOriginGeometry(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin, size=size)
- newOrigin = numpy.array([10, 10])
- item.setOrigin(newOrigin)
- expectedCenter = numpy.array([15, 20])
- numpy.testing.assert_allclose(item.getOrigin(), newOrigin)
- numpy.testing.assert_allclose(item.getCenter(), expectedCenter)
- numpy.testing.assert_allclose(item.getSize(), size)
-
- def testCircle_geometry(self):
- center = numpy.array([0, 0])
- radius = 10.
- item = roi_items.CircleROI()
- item.setGeometry(center=center, radius=radius)
- numpy.testing.assert_allclose(item.getCenter(), center)
- numpy.testing.assert_allclose(item.getRadius(), radius)
-
- def testCircle_setCenter(self):
- center = numpy.array([0, 0])
- radius = 10.
- item = roi_items.CircleROI()
- item.setGeometry(center=center, radius=radius)
- newCenter = numpy.array([-10, 0])
- item.setCenter(newCenter)
- numpy.testing.assert_allclose(item.getCenter(), newCenter)
- numpy.testing.assert_allclose(item.getRadius(), radius)
-
- def testCircle_setRadius(self):
- center = numpy.array([0, 0])
- radius = 10.
- item = roi_items.CircleROI()
- item.setGeometry(center=center, radius=radius)
- newRadius = 5.1
- item.setRadius(newRadius)
- numpy.testing.assert_allclose(item.getCenter(), center)
- numpy.testing.assert_allclose(item.getRadius(), newRadius)
-
- def testCircle_contains(self):
- center = numpy.array([2, -1])
- radius = 1.
- item = roi_items.CircleROI()
- item.setGeometry(center=center, radius=radius)
- self.assertTrue(item.contains([1, -1]))
- self.assertFalse(item.contains([0, 0]))
- self.assertTrue(item.contains([2, 0]))
- self.assertFalse(item.contains([3.01, -1]))
-
- def testEllipse_contains(self):
- center = numpy.array([-2, 0])
- item = roi_items.EllipseROI()
- item.setCenter(center)
- item.setOrientation(numpy.pi / 4.0)
- item.setMajorRadius(2)
- item.setMinorRadius(1)
- print(item.getMinorRadius(), item.getMajorRadius())
- self.assertFalse(item.contains([0, 0]))
- self.assertTrue(item.contains([-1, 1]))
- self.assertTrue(item.contains([-3, 0]))
- self.assertTrue(item.contains([-2, 0]))
- self.assertTrue(item.contains([-2, 1]))
- self.assertFalse(item.contains([-4, 1]))
-
- def testRectangle_isIn(self):
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin, size=size)
- self.assertTrue(item.contains(position=(0, 0)))
- self.assertTrue(item.contains(position=(2, 14)))
- self.assertFalse(item.contains(position=(14, 12)))
-
- def testPolygon_emptyGeometry(self):
- points = numpy.empty((0, 2))
- item = roi_items.PolygonROI()
- item.setPoints(points)
- numpy.testing.assert_allclose(item.getPoints(), points)
-
- def testPolygon_geometry(self):
- points = numpy.array([[10, 10], [12, 10], [50, 1]])
- item = roi_items.PolygonROI()
- item.setPoints(points)
- numpy.testing.assert_allclose(item.getPoints(), points)
-
- def testPolygon_isIn(self):
- points = numpy.array([[0, 0], [0, 10], [5, 10]])
- item = roi_items.PolygonROI()
- item.setPoints(points)
- self.assertTrue(item.contains((0, 0)))
- self.assertFalse(item.contains((6, 2)))
- self.assertFalse(item.contains((-2, 5)))
- self.assertFalse(item.contains((2, -1)))
- self.assertFalse(item.contains((8, 1)))
- self.assertTrue(item.contains((1, 8)))
-
- def testArc_getToSetGeometry(self):
- """Test that we can use getGeometry as input to setGeometry"""
- item = roi_items.ArcROI()
- item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]]))
- item.setGeometry(*item.getGeometry())
-
- def testArc_degenerated_point(self):
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
-
- def testArc_degenerated_line(self):
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
-
- def testArc_special_circle(self):
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, 3 * numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- numpy.testing.assert_allclose(item.getCenter(), center)
- self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
- self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
- self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0)
- self.assertTrue(item.isClosed())
-
- def testArc_special_donut(self):
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- numpy.testing.assert_allclose(item.getCenter(), center)
- self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
- self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
- self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0)
- self.assertTrue(item.isClosed())
-
- def testArc_clockwiseGeometry(self):
- """Test that we can use getGeometry as input to setGeometry"""
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- numpy.testing.assert_allclose(item.getCenter(), center)
- self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
- self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
- self.assertAlmostEqual(item.getStartAngle(), startAngle)
- self.assertAlmostEqual(item.getEndAngle(), endAngle)
- self.assertAlmostEqual(item.isClosed(), False)
-
- def testArc_anticlockwiseGeometry(self):
- """Test that we can use getGeometry as input to setGeometry"""
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, -numpy.pi * 0.5
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- numpy.testing.assert_allclose(item.getCenter(), center)
- self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
- self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
- self.assertAlmostEqual(item.getStartAngle(), startAngle)
- self.assertAlmostEqual(item.getEndAngle(), endAngle)
- self.assertAlmostEqual(item.isClosed(), False)
-
- def testHRange_geometry(self):
- item = roi_items.HorizontalRangeROI()
- vmin = 1
- vmax = 3
- item.setRange(vmin, vmax)
- self.assertAlmostEqual(item.getMin(), vmin)
- self.assertAlmostEqual(item.getMax(), vmax)
- self.assertAlmostEqual(item.getCenter(), 2)
-
-
-class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
- """Tests for RegionOfInterestManager class"""
-
- def setUp(self):
- super(TestRegionOfInterestManager, self).setUp()
- self.plot = PlotWindow()
-
- self.roiTableWidget = roi.RegionOfInterestTableWidget()
- dock = qt.QDockWidget()
- dock.setWidget(self.roiTableWidget)
- self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- def tearDown(self):
- del self.roiTableWidget
- self.qapp.processEvents()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- super(TestRegionOfInterestManager, self).tearDown()
-
- def test(self):
- """Test ROI of different shapes"""
- tests = ( # shape, points=[list of (x, y), list of (x, y)]
- (roi_items.PointROI, numpy.array(([(10., 15.)], [(20., 25.)]))),
- (roi_items.RectangleROI,
- numpy.array((((1., 10.), (11., 20.)),
- ((2., 3.), (12., 13.))))),
- (roi_items.PolygonROI,
- numpy.array((((0., 1.), (0., 10.), (10., 0.)),
- ((5., 6.), (5., 16.), (15., 6.))))),
- (roi_items.LineROI,
- numpy.array((((10., 20.), (10., 30.)),
- ((30., 40.), (30., 50.))))),
- (roi_items.HorizontalLineROI,
- numpy.array((((10., 20.), (10., 30.)),
- ((30., 40.), (30., 50.))))),
- (roi_items.VerticalLineROI,
- numpy.array((((10., 20.), (10., 30.)),
- ((30., 40.), (30., 50.))))),
- (roi_items.HorizontalLineROI,
- numpy.array((((10., 20.), (10., 30.)),
- ((30., 40.), (30., 50.))))),
- )
-
- for roiClass, points in tests:
- with self.subTest(roiClass=roiClass):
- manager = roi.RegionOfInterestManager(self.plot)
- self.roiTableWidget.setRegionOfInterestManager(manager)
- manager.start(roiClass)
-
- self.assertEqual(manager.getRois(), ())
-
- finishListener = SignalListener()
- manager.sigInteractiveModeFinished.connect(finishListener)
-
- changedListener = SignalListener()
- manager.sigRoiChanged.connect(changedListener)
-
- # Add a point
- r = roiClass()
- r.setFirstShapePoints(points[0])
- manager.addRoi(r)
- self.qapp.processEvents()
- self.assertTrue(len(manager.getRois()), 1)
- self.assertEqual(changedListener.callCount(), 1)
-
- # Remove it
- manager.removeRoi(manager.getRois()[0])
- self.assertEqual(manager.getRois(), ())
- self.assertEqual(changedListener.callCount(), 2)
-
- # Add two point
- r = roiClass()
- r.setFirstShapePoints(points[0])
- manager.addRoi(r)
- self.qapp.processEvents()
- r = roiClass()
- r.setFirstShapePoints(points[1])
- manager.addRoi(r)
- self.qapp.processEvents()
- self.assertTrue(len(manager.getRois()), 2)
- self.assertEqual(changedListener.callCount(), 4)
-
- # Reset it
- result = manager.clear()
- self.assertTrue(result)
- self.assertEqual(manager.getRois(), ())
- self.assertEqual(changedListener.callCount(), 5)
-
- changedListener.clear()
-
- # Add two point
- r = roiClass()
- r.setFirstShapePoints(points[0])
- manager.addRoi(r)
- self.qapp.processEvents()
- r = roiClass()
- r.setFirstShapePoints(points[1])
- manager.addRoi(r)
- self.qapp.processEvents()
- self.assertTrue(len(manager.getRois()), 2)
- self.assertEqual(changedListener.callCount(), 2)
-
- # stop
- result = manager.stop()
- self.assertTrue(result)
- self.assertTrue(len(manager.getRois()), 1)
- self.qapp.processEvents()
- self.assertEqual(finishListener.callCount(), 1)
-
- manager.clear()
-
- def testRoiDisplay(self):
- rois = []
-
- # Line
- item = roi_items.LineROI()
- startPoint = numpy.array([1, 2])
- endPoint = numpy.array([3, 4])
- item.setEndPoints(startPoint, endPoint)
- rois.append(item)
- # Horizontal line
- item = roi_items.HorizontalLineROI()
- item.setPosition(15)
- rois.append(item)
- # Vertical line
- item = roi_items.VerticalLineROI()
- item.setPosition(15)
- rois.append(item)
- # Point
- item = roi_items.PointROI()
- point = numpy.array([1, 2])
- item.setPosition(point)
- rois.append(item)
- # Rectangle
- item = roi_items.RectangleROI()
- origin = numpy.array([0, 0])
- size = numpy.array([10, 20])
- item.setGeometry(origin=origin, size=size)
- rois.append(item)
- # Polygon
- item = roi_items.PolygonROI()
- points = numpy.array([[10, 10], [12, 10], [50, 1]])
- item.setPoints(points)
- rois.append(item)
- # Degenerated polygon: No points
- item = roi_items.PolygonROI()
- points = numpy.empty((0, 2))
- item.setPoints(points)
- rois.append(item)
- # Degenerated polygon: A single point
- item = roi_items.PolygonROI()
- points = numpy.array([[5, 10]])
- item.setPoints(points)
- rois.append(item)
- # Degenerated arc: it's a point
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- rois.append(item)
- # Degenerated arc: it's a line
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- rois.append(item)
- # Special arc: it's a donut
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- rois.append(item)
- # Arc
- item = roi_items.ArcROI()
- center = numpy.array([10, 20])
- innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
- item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
- rois.append(item)
- # Horizontal Range
- item = roi_items.HorizontalRangeROI()
- item.setRange(-1, 3)
- rois.append(item)
-
- manager = roi.RegionOfInterestManager(self.plot)
- self.roiTableWidget.setRegionOfInterestManager(manager)
- for item in rois:
- with self.subTest(roi=str(item)):
- manager.addRoi(item)
- self.qapp.processEvents()
- item.setEditable(True)
- self.qapp.processEvents()
- item.setEditable(False)
- self.qapp.processEvents()
- manager.removeRoi(item)
- self.qapp.processEvents()
-
- def testSelectionProxy(self):
- item1 = roi_items.PointROI()
- item1.setSelectable(True)
- item2 = roi_items.PointROI()
- item2.setSelectable(True)
- item1.setFocusProxy(item2)
- manager = roi.RegionOfInterestManager(self.plot)
- manager.setCurrentRoi(item1)
- self.assertIs(manager.getCurrentRoi(), item2)
-
- def testRemovedSelection(self):
- item1 = roi_items.PointROI()
- item1.setSelectable(True)
- manager = roi.RegionOfInterestManager(self.plot)
- manager.addRoi(item1)
- manager.setCurrentRoi(item1)
- manager.removeRoi(item1)
- self.assertIs(manager.getCurrentRoi(), None)
-
- def testMaxROI(self):
- """Test Max ROI"""
- origin1 = numpy.array([1., 10.])
- size1 = numpy.array([10., 10.])
- origin2 = numpy.array([2., 3.])
- size2 = numpy.array([10., 10.])
-
- manager = roi.InteractiveRegionOfInterestManager(self.plot)
- self.roiTableWidget.setRegionOfInterestManager(manager)
- self.assertEqual(manager.getRois(), ())
-
- changedListener = SignalListener()
- manager.sigRoiChanged.connect(changedListener)
-
- # Add two point
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin1, size=size1)
- manager.addRoi(item)
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin2, size=size2)
- manager.addRoi(item)
- self.qapp.processEvents()
- self.assertEqual(changedListener.callCount(), 2)
- self.assertEqual(len(manager.getRois()), 2)
-
- # Try to set max ROI to 1 while there is 2 ROIs
- with self.assertRaises(ValueError):
- manager.setMaxRois(1)
-
- manager.clear()
- self.assertEqual(len(manager.getRois()), 0)
- self.assertEqual(changedListener.callCount(), 3)
-
- # Set max limit to 1
- manager.setMaxRois(1)
-
- # Add a point
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin1, size=size1)
- manager.addRoi(item)
- self.qapp.processEvents()
- self.assertEqual(changedListener.callCount(), 4)
-
- # Add a 2nd point while max ROI is 1
- item = roi_items.RectangleROI()
- item.setGeometry(origin=origin1, size=size1)
- manager.addRoi(item)
- self.qapp.processEvents()
- self.assertEqual(changedListener.callCount(), 6)
- self.assertEqual(len(manager.getRois()), 1)
-
- def testChangeInteractionMode(self):
- """Test change of interaction mode"""
- manager = roi.RegionOfInterestManager(self.plot)
- self.roiTableWidget.setRegionOfInterestManager(manager)
- manager.start(roi_items.PointROI)
-
- interactiveModeToolBar = self.plot.getInteractiveModeToolBar()
- panAction = interactiveModeToolBar.getPanModeAction()
-
- for roiClass in manager.getSupportedRoiClasses():
- with self.subTest(roiClass=roiClass):
- # Change to pan mode
- panAction.trigger()
-
- # Change to interactive ROI mode
- action = manager.getInteractionModeAction(roiClass)
- action.trigger()
-
- self.assertEqual(roiClass, manager.getCurrentInteractionModeRoiClass())
-
- manager.clear()
-
- def testLineInteraction(self):
- """This test make sure that a ROI based on handles can be edited with
- the mouse."""
- xlimit = self.plot.getXAxis().getLimits()
- ylimit = self.plot.getYAxis().getLimits()
- points = numpy.array([xlimit, ylimit]).T
- center = numpy.mean(points, axis=0)
-
- # Create the line
- manager = roi.RegionOfInterestManager(self.plot)
- item = roi_items.LineROI()
- item.setEndPoints(points[0], points[1])
- item.setEditable(True)
- manager.addRoi(item)
- self.qapp.processEvents()
-
- # Drag the center
- widget = self.plot.getWidgetHandle()
- mx, my = self.plot.dataToPixel(*center)
- self.mouseMove(widget, pos=(mx, my))
- self.mousePress(widget, qt.Qt.LeftButton, pos=(mx, my))
- self.mouseMove(widget, pos=(mx, my+25))
- self.mouseMove(widget, pos=(mx, my+50))
- self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my+50))
-
- result = numpy.array(item.getEndPoints())
- # x location is still the same
- numpy.testing.assert_allclose(points[:, 0], result[:, 0], atol=0.5)
- # size is still the same
- numpy.testing.assert_allclose(points[1] - points[0],
- result[1] - result[0], atol=0.5)
- # But Y is not the same
- self.assertNotEqual(points[0, 1], result[0, 1])
- self.assertNotEqual(points[1, 1], result[1, 1])
- item = None
- manager.clear()
- self.qapp.processEvents()
-
- def testPlotWhenCleared(self):
- """PlotWidget.clear should clean up the available ROIs"""
- manager = roi.RegionOfInterestManager(self.plot)
- item = roi_items.LineROI()
- item.setEndPoints((0, 0), (1, 1))
- item.setEditable(True)
- manager.addRoi(item)
- self.qWait()
- try:
- # Make sure the test setup is fine
- self.assertNotEqual(len(manager.getRois()), 0)
- self.assertNotEqual(len(self.plot.getItems()), 0)
-
- # Call clear and test the expected state
- self.plot.clear()
- self.assertEqual(len(manager.getRois()), 0)
- self.assertEqual(len(self.plot.getItems()), 0)
- finally:
- # Clean up
- manager.clear()
-
- def testPlotWhenRoiRemoved(self):
- """Make sure there is no remaining items in the plot when a ROI is removed"""
- manager = roi.RegionOfInterestManager(self.plot)
- item = roi_items.LineROI()
- item.setEndPoints((0, 0), (1, 1))
- item.setEditable(True)
- manager.addRoi(item)
- self.qWait()
- try:
- # Make sure the test setup is fine
- self.assertNotEqual(len(manager.getRois()), 0)
- self.assertNotEqual(len(self.plot.getItems()), 0)
-
- # Call clear and test the expected state
- manager.removeRoi(item)
- self.assertEqual(len(manager.getRois()), 0)
- self.assertEqual(len(self.plot.getItems()), 0)
- finally:
- # Clean up
- manager.clear()
-
- def testArcRoiSwitchMode(self):
- """Make sure we can switch mode by clicking on the ROI"""
- xlimit = self.plot.getXAxis().getLimits()
- ylimit = self.plot.getYAxis().getLimits()
- points = numpy.array([xlimit, ylimit]).T
- center = numpy.mean(points, axis=0)
- size = numpy.abs(points[1] - points[0])
-
- # Create the line
- manager = roi.RegionOfInterestManager(self.plot)
- item = roi_items.ArcROI()
- item.setGeometry(center, size[1] / 10, size[1] / 2, 0, 3)
- item.setEditable(True)
- item.setSelectable(True)
- manager.addRoi(item)
- self.qapp.processEvents()
-
- # Initial state
- self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
- self.qWait(500)
-
- # Click on the center
- widget = self.plot.getWidgetHandle()
- mx, my = self.plot.dataToPixel(*center)
-
- # Select the ROI
- self.mouseMove(widget, pos=(mx, my))
- self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
- self.qWait(500)
- self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
-
- # Change the mode
- self.mouseMove(widget, pos=(mx, my))
- self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
- self.qWait(500)
- self.assertIs(item.getInteractionMode(), roi_items.ArcROI.PolarMode)
-
- manager.clear()
- self.qapp.processEvents()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestRoiItems))
- test_suite.addTest(loadTests(TestRegionOfInterestManager))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/silx/gui/plot/tools/test/testScatterProfileToolBar.py
deleted file mode 100644
index b9f4885..0000000
--- a/silx/gui/plot/tools/test/testScatterProfileToolBar.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "28/06/2018"
-
-
-import unittest
-import numpy
-
-from silx.gui import qt
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.plot import PlotWindow
-from silx.gui.plot.tools.profile import manager
-from silx.gui.plot.tools.profile import core
-from silx.gui.plot.tools.profile import rois
-
-
-class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase):
- """Tests for ScatterProfileToolBar class"""
-
- def setUp(self):
- super(TestScatterProfileToolBar, self).setUp()
- self.plot = PlotWindow()
-
- self.manager = manager.ProfileManager(plot=self.plot)
- self.manager.setItemType(scatter=True)
- self.manager.setActiveItemTracking(True)
-
- self.plot.show()
- self.qWaitForWindowExposed(self.plot)
-
- def tearDown(self):
- del self.manager
- self.qapp.processEvents()
- self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.plot.close()
- del self.plot
- super(TestScatterProfileToolBar, self).tearDown()
-
- def testHorizontalProfile(self):
- """Test ScatterProfileToolBar horizontal profile"""
-
- roiManager = self.manager.getRoiManager()
-
- # Add a scatter plot
- self.plot.addScatter(
- x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
- self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
- self.qapp.processEvents()
-
- # Set a ROI profile
- roi = rois.ProfileScatterHorizontalLineROI()
- roi.setPosition(0.5)
- roi.setNPoints(8)
- roiManager.addRoi(roi)
-
- # Wait for async interpolator init
- for _ in range(20):
- self.qWait(200)
- if not self.manager.hasPendingOperations():
- break
- self.qapp.processEvents()
-
- window = roi.getProfileWindow()
- self.assertIsNotNone(window)
- data = window.getProfile()
- self.assertIsInstance(data, core.CurveProfileData)
- self.assertEqual(len(data.coords), 8)
-
- # Check that profile has same limits than Plot
- xLimits = self.plot.getXAxis().getLimits()
- self.assertEqual(data.coords[0], xLimits[0])
- self.assertEqual(data.coords[-1], xLimits[1])
-
- # Clear the profile
- self.manager.clearProfile()
- self.qapp.processEvents()
- self.assertIsNone(roi.getProfileWindow())
-
- def testVerticalProfile(self):
- """Test ScatterProfileToolBar vertical profile"""
-
- roiManager = self.manager.getRoiManager()
-
- # Add a scatter plot
- self.plot.addScatter(
- x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
- self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
- self.qapp.processEvents()
-
- # Set a ROI profile
- roi = rois.ProfileScatterVerticalLineROI()
- roi.setPosition(0.5)
- roi.setNPoints(8)
- roiManager.addRoi(roi)
-
- # Wait for async interpolator init
- for _ in range(10):
- self.qWait(200)
- if not self.manager.hasPendingOperations():
- break
-
- window = roi.getProfileWindow()
- self.assertIsNotNone(window)
- data = window.getProfile()
- self.assertIsInstance(data, core.CurveProfileData)
- self.assertEqual(len(data.coords), 8)
-
- # Check that profile has same limits than Plot
- yLimits = self.plot.getYAxis().getLimits()
- self.assertEqual(data.coords[0], yLimits[0])
- self.assertEqual(data.coords[-1], yLimits[1])
-
- # Check that profile limits are updated when changing limits
- self.plot.getYAxis().setLimits(yLimits[0] + 1, yLimits[1] + 10)
-
- # Wait for async interpolator init
- for _ in range(10):
- self.qWait(200)
- if not self.manager.hasPendingOperations():
- break
-
- yLimits = self.plot.getYAxis().getLimits()
- data = window.getProfile()
- self.assertEqual(data.coords[0], yLimits[0])
- self.assertEqual(data.coords[-1], yLimits[1])
-
- # Clear the profile
- self.manager.clearProfile()
- self.qapp.processEvents()
- self.assertIsNone(roi.getProfileWindow())
-
- def testLineProfile(self):
- """Test ScatterProfileToolBar line profile"""
-
- roiManager = self.manager.getRoiManager()
-
- # Add a scatter plot
- self.plot.addScatter(
- x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
- self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
- self.qapp.processEvents()
-
- # Set a ROI profile
- roi = rois.ProfileScatterLineROI()
- roi.setEndPoints(numpy.array([0., 0.]), numpy.array([1., 1.]))
- roi.setNPoints(8)
- roiManager.addRoi(roi)
-
- # Wait for async interpolator init
- for _ in range(10):
- self.qWait(200)
- if not self.manager.hasPendingOperations():
- break
-
- window = roi.getProfileWindow()
- self.assertIsNotNone(window)
- data = window.getProfile()
- self.assertIsInstance(data, core.CurveProfileData)
- self.assertEqual(len(data.coords), 8)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(
- TestScatterProfileToolBar))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/tools/test/testTools.py b/silx/gui/plot/tools/test/testTools.py
deleted file mode 100644
index 70c8105..0000000
--- a/silx/gui/plot/tools/test/testTools.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for silx.gui.plot.tools package"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "02/03/2018"
-
-
-import functools
-import unittest
-import numpy
-
-from silx.utils.testutils import TestLogging
-from silx.gui.utils.testutils import qWaitForWindowExposedAndActivate
-from silx.gui import qt
-from silx.gui.plot import PlotWindow
-from silx.gui.plot import tools
-from silx.gui.plot.test.utils import PlotWidgetTestCase
-
-
-class TestPositionInfo(PlotWidgetTestCase):
- """Tests for PositionInfo widget."""
-
- def _createPlot(self):
- return PlotWindow()
-
- def setUp(self):
- super(TestPositionInfo, self).setUp()
- self.mouseMove(self.plot, pos=(0, 0))
- self.qapp.processEvents()
- self.qWait(100)
-
- def tearDown(self):
- super(TestPositionInfo, self).tearDown()
-
- def _test(self, positionWidget, converterNames, **kwargs):
- """General test of PositionInfo.
-
- - Add it to a toolbar and
- - Move mouse around the center of the PlotWindow.
- """
- toolBar = qt.QToolBar()
- self.plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar)
-
- toolBar.addWidget(positionWidget)
-
- converters = positionWidget.getConverters()
- self.assertEqual(len(converters), len(converterNames))
- for index, name in enumerate(converterNames):
- self.assertEqual(converters[index][0], name)
-
- with TestLogging(tools.__name__, **kwargs):
- # Move mouse to center
- center = self.plot.size() / 2
- self.mouseMove(self.plot, pos=(center.width(), center.height()))
- # Move out
- self.mouseMove(self.plot, pos=(1, 1))
-
- def testDefaultConverters(self):
- """Test PositionInfo with default converters"""
- positionWidget = tools.PositionInfo(plot=self.plot)
- self._test(positionWidget, ('X', 'Y'))
-
- def testCustomConverters(self):
- """Test PositionInfo with custom converters"""
- converters = [
- ('Coords', lambda x, y: (int(x), int(y))),
- ('Radius', lambda x, y: numpy.sqrt(x * x + y * y)),
- ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))
- ]
- positionWidget = tools.PositionInfo(plot=self.plot,
- converters=converters)
- self._test(positionWidget, ('Coords', 'Radius', 'Angle'))
-
- def testFailingConverters(self):
- """Test PositionInfo with failing custom converters"""
- def raiseException(x, y):
- raise RuntimeError()
-
- positionWidget = tools.PositionInfo(
- plot=self.plot,
- converters=[('Exception', raiseException)])
- self._test(positionWidget, ['Exception'], error=2)
-
- def testUpdate(self):
- """Test :meth:`PositionInfo.updateInfo`"""
- calls = []
-
- def update(calls, x, y): # Get number of calls
- calls.append((x, y))
- return len(calls)
-
- positionWidget = tools.PositionInfo(
- plot=self.plot,
- converters=[('Call count', functools.partial(update, calls))])
-
- positionWidget.updateInfo()
- self.assertEqual(len(calls), 1)
-
-
-class TestPlotToolsToolbars(PlotWidgetTestCase):
- """Tests toolbars from silx.gui.plot.tools"""
-
- def test(self):
- """"Add all toolbars"""
- for tbClass in (tools.InteractiveModeToolBar,
- tools.ImageToolBar,
- tools.CurveToolBar,
- tools.OutputToolBar):
- tb = tbClass(parent=self.plot, plot=self.plot)
- self.plot.addToolBar(tb)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- # test_suite.addTest(positionInfoTestSuite)
- for testClass in (TestPositionInfo, TestPlotToolsToolbars):
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
- testClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot/utils/axis.py b/silx/gui/plot/utils/axis.py
deleted file mode 100644
index 693e8eb..0000000
--- a/silx/gui/plot/utils/axis.py
+++ /dev/null
@@ -1,403 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module contains utils class for axes management.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "20/11/2018"
-
-import functools
-import logging
-from contextlib import contextmanager
-import weakref
-import silx.utils.weakref as silxWeakref
-from silx.gui.plot.items.axis import Axis, XAxis, YAxis
-
-try:
- from ...qt.inspect import isValid as _isQObjectValid
-except ImportError: # PySide(1) fallback
- def _isQObjectValid(obj):
- return True
-
-
-_logger = logging.getLogger(__name__)
-
-
-class SyncAxes(object):
- """Synchronize a set of plot axes together.
-
- It is created with the expected axes and starts to synchronize them.
-
- It can be customized to synchronize limits, scale, and direction of axes
- together. By default everything is synchronized.
-
- The API :meth:`start` and :meth:`stop` can be used to enable/disable the
- synchronization while this object is still alive.
-
- If this object is destroyed the synchronization stop.
-
- .. versionadded:: 0.6
- """
-
- def __init__(self, axes,
- syncLimits=True,
- syncScale=True,
- syncDirection=True,
- syncCenter=False,
- syncZoom=False,
- filterHiddenPlots=False
- ):
- """
- Constructor
-
- :param list(Axis) axes: A list of axes to synchronize together
- :param bool syncLimits: Synchronize axes limits
- :param bool syncScale: Synchronize axes scale
- :param bool syncDirection: Synchronize axes direction
- :param bool syncCenter: Synchronize the center of the axes in the center
- of the plots
- :param bool syncZoom: Synchronize the zoom of the plot
- :param bool filterHiddenPlots: True to avoid updating hidden plots.
- Default: False.
- """
- object.__init__(self)
-
- def implies(x, y): return bool(y ** x)
-
- assert(implies(syncZoom, not syncLimits))
- assert(implies(syncCenter, not syncLimits))
- assert(implies(syncLimits, not syncCenter))
- assert(implies(syncLimits, not syncZoom))
-
- self.__filterHiddenPlots = filterHiddenPlots
- self.__locked = False
- self.__axisRefs = []
- self.__syncLimits = syncLimits
- self.__syncScale = syncScale
- self.__syncDirection = syncDirection
- self.__syncCenter = syncCenter
- self.__syncZoom = syncZoom
- self.__callbacks = None
- self.__lastMainAxis = None
-
- for axis in axes:
- self.addAxis(axis)
-
- self.start()
-
- def start(self):
- """Start synchronizing axes together.
-
- The first axis is used as the reference for the first synchronization.
- After that, any changes to any axes will be used to synchronize other
- axes.
- """
- if self.isSynchronizing():
- raise RuntimeError("Axes already synchronized")
- self.__callbacks = {}
-
- axes = self.__getAxes()
-
- # register callback for further sync
- for axis in axes:
- self.__connectAxes(axis)
- self.synchronize()
-
- def isSynchronizing(self):
- """Returns true if events are connected to the axes to synchronize them
- all together
-
- :rtype: bool
- """
- return self.__callbacks is not None
-
- def __connectAxes(self, axis):
- refAxis = weakref.ref(axis)
- callbacks = []
- if self.__syncLimits:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigLimitsChanged
- sig.connect(callback)
- callbacks.append(("sigLimitsChanged", callback))
- elif self.__syncCenter and self.__syncZoom:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisCenterAndZoomChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigLimitsChanged
- sig.connect(callback)
- callbacks.append(("sigLimitsChanged", callback))
- elif self.__syncZoom:
- raise NotImplementedError()
- elif self.__syncCenter:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisCenterChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigLimitsChanged
- sig.connect(callback)
- callbacks.append(("sigLimitsChanged", callback))
- if self.__syncScale:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigScaleChanged
- sig.connect(callback)
- callbacks.append(("sigScaleChanged", callback))
- if self.__syncDirection:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged)
- callback = functools.partial(callback, refAxis)
- sig = axis.sigInvertedChanged
- sig.connect(callback)
- callbacks.append(("sigInvertedChanged", callback))
-
- if self.__filterHiddenPlots:
- # the weakref is needed to be able ignore self references
- callback = silxWeakref.WeakMethodProxy(self.__axisVisibilityChanged)
- callback = functools.partial(callback, refAxis)
- plot = axis._getPlot()
- plot.sigVisibilityChanged.connect(callback)
- callbacks.append(("sigVisibilityChanged", callback))
-
- self.__callbacks[refAxis] = callbacks
-
- def __disconnectAxes(self, axis):
- if axis is not None and _isQObjectValid(axis):
- ref = weakref.ref(axis)
- callbacks = self.__callbacks.pop(ref)
- for sigName, callback in callbacks:
- if sigName == "sigVisibilityChanged":
- obj = axis._getPlot()
- else:
- obj = axis
- if obj is not None:
- sig = getattr(obj, sigName)
- sig.disconnect(callback)
-
- def addAxis(self, axis):
- """Add a new axes to synchronize.
-
- :param ~silx.gui.plot.items.Axis axis: The axis to synchronize
- """
- self.__axisRefs.append(weakref.ref(axis))
- if self.isSynchronizing():
- self.__connectAxes(axis)
- # This could be done faster as only this axis have to be fixed
- self.synchronize()
-
- def removeAxis(self, axis):
- """Remove an axis from the synchronized axes.
-
- :param ~silx.gui.plot.items.Axis axis: The axis to remove
- """
- ref = weakref.ref(axis)
- self.__axisRefs.remove(ref)
- if self.isSynchronizing():
- self.__disconnectAxes(axis)
-
- def synchronize(self, mainAxis=None):
- """Synchronize programatically all the axes.
-
- :param ~silx.gui.plot.items.Axis mainAxis:
- The axis to take as reference (Default: the first axis).
- """
- # sync the current state
- axes = self.__getAxes()
- if len(axes) == 0:
- return
-
- if mainAxis is None:
- mainAxis = axes[0]
-
- refMainAxis = weakref.ref(mainAxis)
- if self.__syncLimits:
- self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits())
- elif self.__syncCenter and self.__syncZoom:
- self.__axisCenterAndZoomChanged(refMainAxis, *mainAxis.getLimits())
- elif self.__syncCenter:
- self.__axisCenterChanged(refMainAxis, *mainAxis.getLimits())
- if self.__syncScale:
- self.__axisScaleChanged(refMainAxis, mainAxis.getScale())
- if self.__syncDirection:
- self.__axisInvertedChanged(refMainAxis, mainAxis.isInverted())
-
- def stop(self):
- """Stop the synchronization of the axes"""
- if not self.isSynchronizing():
- raise RuntimeError("Axes not synchronized")
- for ref in list(self.__callbacks.keys()):
- axis = ref()
- self.__disconnectAxes(axis)
- self.__callbacks = None
-
- def __del__(self):
- """Destructor"""
- # clean up references
- if self.__callbacks is not None:
- self.stop()
-
- def __getAxes(self):
- """Returns list of existing axes.
-
- :rtype: List[Axis]
- """
- axes = [ref() for ref in self.__axisRefs]
- return [axis for axis in axes if axis is not None]
-
- @contextmanager
- def __inhibitSignals(self):
- self.__locked = True
- yield
- self.__locked = False
-
- def __axesToUpdate(self, changedAxis):
- for axis in self.__getAxes():
- if axis is changedAxis:
- continue
- if self.__filterHiddenPlots:
- plot = axis._getPlot()
- if not plot.isVisible():
- continue
- yield axis
-
- def __axisVisibilityChanged(self, changedAxis, isVisible):
- if not isVisible:
- return
- if self.__locked:
- return
- changedAxis = changedAxis()
- if self.__lastMainAxis is None:
- self.__lastMainAxis = self.__axisRefs[0]
- mainAxis = self.__lastMainAxis
- mainAxis = mainAxis()
- self.synchronize(mainAxis=mainAxis)
- # force back the main axis
- self.__lastMainAxis = weakref.ref(mainAxis)
-
- def __getAxesCenter(self, axis, vmin, vmax):
- """Returns the value displayed in the center of this axis range.
-
- :rtype: float
- """
- scale = axis.getScale()
- if scale == Axis.LINEAR:
- center = (vmin + vmax) * 0.5
- else:
- raise NotImplementedError("Log scale not implemented")
- return center
-
- def __getRangeInPixel(self, axis):
- """Returns the size of the axis in pixel"""
- bounds = axis._getPlot().getPlotBoundsInPixels()
- # bounds: left, top, width, height
- if isinstance(axis, XAxis):
- return bounds[2]
- elif isinstance(axis, YAxis):
- return bounds[3]
- else:
- assert(False)
-
- def __getLimitsFromCenter(self, axis, pos, pixelSize=None):
- """Returns the limits to apply to this axis to move the `pos` into the
- center of this axis.
-
- :param Axis axis:
- :param float pos: Position in the center of the computed limits
- :param Union[None,float] pixelSize: Pixel size to apply to compute the
- limits. If `None` the current pixel size is applyed.
- """
- scale = axis.getScale()
- if scale == Axis.LINEAR:
- if pixelSize is None:
- # Use the current pixel size of the axis
- limits = axis.getLimits()
- valueRange = limits[0] - limits[1]
- a = pos - valueRange * 0.5
- b = pos + valueRange * 0.5
- else:
- pixelRange = self.__getRangeInPixel(axis)
- a = pos - pixelRange * 0.5 * pixelSize
- b = pos + pixelRange * 0.5 * pixelSize
-
- else:
- raise NotImplementedError("Log scale not implemented")
- if a > b:
- return b, a
- return a, b
-
- def __axisLimitsChanged(self, changedAxis, vmin, vmax):
- if self.__locked:
- return
- self.__lastMainAxis = changedAxis
- changedAxis = changedAxis()
- with self.__inhibitSignals():
- for axis in self.__axesToUpdate(changedAxis):
- axis.setLimits(vmin, vmax)
-
- def __axisCenterAndZoomChanged(self, changedAxis, vmin, vmax):
- if self.__locked:
- return
- self.__lastMainAxis = changedAxis
- changedAxis = changedAxis()
- with self.__inhibitSignals():
- center = self.__getAxesCenter(changedAxis, vmin, vmax)
- pixelRange = self.__getRangeInPixel(changedAxis)
- if pixelRange == 0:
- return
- pixelSize = (vmax - vmin) / pixelRange
- for axis in self.__axesToUpdate(changedAxis):
- vmin, vmax = self.__getLimitsFromCenter(axis, center, pixelSize)
- axis.setLimits(vmin, vmax)
-
- def __axisCenterChanged(self, changedAxis, vmin, vmax):
- if self.__locked:
- return
- self.__lastMainAxis = changedAxis
- changedAxis = changedAxis()
- with self.__inhibitSignals():
- center = self.__getAxesCenter(changedAxis, vmin, vmax)
- for axis in self.__axesToUpdate(changedAxis):
- vmin, vmax = self.__getLimitsFromCenter(axis, center)
- axis.setLimits(vmin, vmax)
-
- def __axisScaleChanged(self, changedAxis, scale):
- if self.__locked:
- return
- self.__lastMainAxis = changedAxis
- changedAxis = changedAxis()
- with self.__inhibitSignals():
- for axis in self.__axesToUpdate(changedAxis):
- axis.setScale(scale)
-
- def __axisInvertedChanged(self, changedAxis, isInverted):
- if self.__locked:
- return
- self.__lastMainAxis = changedAxis
- changedAxis = changedAxis()
- with self.__inhibitSignals():
- for axis in self.__axesToUpdate(changedAxis):
- axis.setInverted(isInverted)
diff --git a/silx/gui/plot3d/ParamTreeView.py b/silx/gui/plot3d/ParamTreeView.py
deleted file mode 100644
index 8cf2b90..0000000
--- a/silx/gui/plot3d/ParamTreeView.py
+++ /dev/null
@@ -1,546 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module provides a :class:`QTreeView` dedicated to display plot3d models.
-
-This module contains:
-- :class:`ParamTreeView`: A QTreeView specific for plot3d parameters and scene.
-- :class:`ParameterTreeDelegate`: The delegate for :class:`ParamTreeView`.
-- A set of specific editors used by :class:`ParameterTreeDelegate`:
- :class:`FloatEditor`, :class:`Vector3DEditor`,
- :class:`Vector4DEditor`, :class:`IntSliderEditor`, :class:`BooleanEditor`
-"""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "05/12/2017"
-
-
-import numbers
-import sys
-
-import six
-
-from .. import qt
-from ..widgets.FloatEdit import FloatEdit as _FloatEdit
-from ._model import visitQAbstractItemModel
-
-
-class FloatEditor(_FloatEdit):
- """Editor widget for float.
-
- :param parent: The widget's parent
- :param float value: The initial editor value
- """
-
- valueChanged = qt.Signal(float)
- """Signal emitted when the float value has changed"""
-
- def __init__(self, parent=None, value=None):
- super(FloatEditor, self).__init__(parent, value)
- self.setAlignment(qt.Qt.AlignLeft)
- self.editingFinished.connect(self._emit)
-
- def _emit(self):
- self.valueChanged.emit(self.value)
-
- value = qt.Property(float,
- fget=_FloatEdit.value,
- fset=_FloatEdit.setValue,
- user=True,
- notify=valueChanged)
- """Qt user property of the float value this widget edits"""
-
-
-class Vector3DEditor(qt.QWidget):
- """Editor widget for QVector3D.
-
- :param parent: The widget's parent
- :param flags: The widgets's flags
- """
-
- valueChanged = qt.Signal(qt.QVector3D)
- """Signal emitted when the QVector3D value has changed"""
-
- def __init__(self, parent=None, flags=qt.Qt.Widget):
- super(Vector3DEditor, self).__init__(parent, flags)
- layout = qt.QHBoxLayout(self)
- # layout.setSpacing(0)
- layout.setContentsMargins(0, 0, 0, 0)
- self.setLayout(layout)
- self._xEdit = _FloatEdit(parent=self, value=0.)
- self._xEdit.setAlignment(qt.Qt.AlignLeft)
- # self._xEdit.editingFinished.connect(self._emit)
- self._yEdit = _FloatEdit(parent=self, value=0.)
- self._yEdit.setAlignment(qt.Qt.AlignLeft)
- # self._yEdit.editingFinished.connect(self._emit)
- self._zEdit = _FloatEdit(parent=self, value=0.)
- self._zEdit.setAlignment(qt.Qt.AlignLeft)
- # self._zEdit.editingFinished.connect(self._emit)
- layout.addWidget(qt.QLabel('x:'))
- layout.addWidget(self._xEdit)
- layout.addWidget(qt.QLabel('y:'))
- layout.addWidget(self._yEdit)
- layout.addWidget(qt.QLabel('z:'))
- layout.addWidget(self._zEdit)
- layout.addStretch(1)
-
- def _emit(self):
- vector = self.value
- self.valueChanged.emit(vector)
-
- def getValue(self):
- """Returns the QVector3D value of this widget
-
- :rtype: QVector3D
- """
- return qt.QVector3D(
- self._xEdit.value(), self._yEdit.value(), self._zEdit.value())
-
- def setValue(self, value):
- """Set the QVector3D value
-
- :param QVector3D value: The new value
- """
- self._xEdit.setValue(value.x())
- self._yEdit.setValue(value.y())
- self._zEdit.setValue(value.z())
- self.valueChanged.emit(value)
-
- value = qt.Property(qt.QVector3D,
- fget=getValue,
- fset=setValue,
- user=True,
- notify=valueChanged)
- """Qt user property of the QVector3D value this widget edits"""
-
-
-class Vector4DEditor(qt.QWidget):
- """Editor widget for QVector4D.
-
- :param parent: The widget's parent
- :param flags: The widgets's flags
- """
-
- valueChanged = qt.Signal(qt.QVector4D)
- """Signal emitted when the QVector4D value has changed"""
-
- def __init__(self, parent=None, flags=qt.Qt.Widget):
- super(Vector4DEditor, self).__init__(parent, flags)
- layout = qt.QHBoxLayout(self)
- # layout.setSpacing(0)
- layout.setContentsMargins(0, 0, 0, 0)
- self.setLayout(layout)
- self._xEdit = _FloatEdit(parent=self, value=0.)
- self._xEdit.setAlignment(qt.Qt.AlignLeft)
- # self._xEdit.editingFinished.connect(self._emit)
- self._yEdit = _FloatEdit(parent=self, value=0.)
- self._yEdit.setAlignment(qt.Qt.AlignLeft)
- # self._yEdit.editingFinished.connect(self._emit)
- self._zEdit = _FloatEdit(parent=self, value=0.)
- self._zEdit.setAlignment(qt.Qt.AlignLeft)
- # self._zEdit.editingFinished.connect(self._emit)
- self._wEdit = _FloatEdit(parent=self, value=0.)
- self._wEdit.setAlignment(qt.Qt.AlignLeft)
- # self._wEdit.editingFinished.connect(self._emit)
- layout.addWidget(qt.QLabel('x:'))
- layout.addWidget(self._xEdit)
- layout.addWidget(qt.QLabel('y:'))
- layout.addWidget(self._yEdit)
- layout.addWidget(qt.QLabel('z:'))
- layout.addWidget(self._zEdit)
- layout.addWidget(qt.QLabel('w:'))
- layout.addWidget(self._wEdit)
- layout.addStretch(1)
-
- def _emit(self):
- vector = self.value
- self.valueChanged.emit(vector)
-
- def getValue(self):
- """Returns the QVector4D value of this widget
-
- :rtype: QVector4D
- """
- return qt.QVector4D(self._xEdit.value(), self._yEdit.value(),
- self._zEdit.value(), self._wEdit.value())
-
- def setValue(self, value):
- """Set the QVector4D value
-
- :param QVector4D value: The new value
- """
- self._xEdit.setValue(value.x())
- self._yEdit.setValue(value.y())
- self._zEdit.setValue(value.z())
- self._wEdit.setValue(value.w())
- self.valueChanged.emit(value)
-
- value = qt.Property(qt.QVector4D,
- fget=getValue,
- fset=setValue,
- user=True,
- notify=valueChanged)
- """Qt user property of the QVector4D value this widget edits"""
-
-
-class IntSliderEditor(qt.QSlider):
- """Slider editor widget for integer.
-
- Note: Tracking is disabled.
-
- :param parent: The widget's parent
- """
-
- def __init__(self, parent=None):
- super(IntSliderEditor, self).__init__(parent)
- self.setOrientation(qt.Qt.Horizontal)
- self.setSingleStep(1)
- self.setRange(0, 255)
- self.setValue(0)
-
-
-class BooleanEditor(qt.QCheckBox):
- """Checkbox editor for bool.
-
- This is a QCheckBox with white background.
-
- :param parent: The widget's parent
- """
-
- def __init__(self, parent=None):
- super(BooleanEditor, self).__init__(parent)
- self.setStyleSheet("background: white;")
-
-
-class ParameterTreeDelegate(qt.QStyledItemDelegate):
- """TreeView delegate specific to plot3d scene and object parameter tree.
-
- It provides additional editors.
-
- :param parent: Delegate's parent
- """
-
- EDITORS = {
- bool: BooleanEditor,
- float: FloatEditor,
- qt.QVector3D: Vector3DEditor,
- qt.QVector4D: Vector4DEditor,
- }
- """Specific editors for different type of data"""
-
- def __init__(self, parent=None):
- super(ParameterTreeDelegate, self).__init__(parent)
-
- def _fixVariant(self, data):
- """Fix PyQt4 zero vectors being stored as QPyNullVariant.
-
- :param data: Data retrieved from the model
- :return: Corresponding object
- """
- if qt.BINDING == 'PyQt4' and isinstance(data, qt.QPyNullVariant):
- typeName = data.typeName()
- if typeName == 'QVector3D':
- data = qt.QVector3D()
- elif typeName == 'QVector4D':
- data = qt.QVector4D()
- return data
-
- def paint(self, painter, option, index):
- """See :meth:`QStyledItemDelegate.paint`"""
- data = index.data(qt.Qt.DisplayRole)
- data = self._fixVariant(data)
-
- if isinstance(data, (qt.QVector3D, qt.QVector4D)):
- if isinstance(data, qt.QVector3D):
- text = '(x: %g; y: %g; z: %g)' % (data.x(), data.y(), data.z())
- elif isinstance(data, qt.QVector4D):
- text = '(%g; %g; %g; %g)' % (data.x(), data.y(), data.z(), data.w())
- else:
- text = ''
-
- painter.save()
- painter.setRenderHint(qt.QPainter.Antialiasing, True)
-
- # Select palette color group
- colorGroup = qt.QPalette.Inactive
- if option.state & qt.QStyle.State_Active:
- colorGroup = qt.QPalette.Active
- if not option.state & qt.QStyle.State_Enabled:
- colorGroup = qt.QPalette.Disabled
-
- # Draw background if selected
- if option.state & qt.QStyle.State_Selected:
- brush = option.palette.brush(colorGroup,
- qt.QPalette.Highlight)
- painter.fillRect(option.rect, brush)
-
- # Draw text
- if option.state & qt.QStyle.State_Selected:
- colorRole = qt.QPalette.HighlightedText
- else:
- colorRole = qt.QPalette.WindowText
- color = option.palette.color(colorGroup, colorRole)
- painter.setPen(qt.QPen(color))
- painter.drawText(option.rect, qt.Qt.AlignLeft, text)
-
- painter.restore()
-
- # The following commented code does the same as QPainter based code
- # but it does not work with PySide
- # self.initStyleOption(option, index)
- # option.text = text
- # widget = option.widget
- # style = qt.QApplication.style() if not widget else widget.style()
- # style.drawControl(qt.QStyle.CE_ItemViewItem, option, painter, widget)
-
- else:
- super(ParameterTreeDelegate, self).paint(painter, option, index)
-
- def _commit(self, *args):
- """Commit data to the model from editors"""
- sender = self.sender()
- self.commitData.emit(sender)
-
- def editorEvent(self, event, model, option, index):
- """See :meth:`QStyledItemDelegate.editorEvent`"""
- if (event.type() == qt.QEvent.MouseButtonPress and
- isinstance(index.data(qt.Qt.EditRole), qt.QColor)):
- initialColor = index.data(qt.Qt.EditRole)
-
- def callback(color):
- theModel = index.model()
- theModel.setData(index, color, qt.Qt.EditRole)
-
- dialog = qt.QColorDialog(self.parent())
- # dialog.setOption(qt.QColorDialog.ShowAlphaChannel, True)
- if sys.platform == 'darwin':
- # Use of native color dialog on macos might cause problems
- dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True)
- dialog.setCurrentColor(initialColor)
- dialog.currentColorChanged.connect(callback)
- if dialog.exec_() == qt.QDialog.Rejected:
- # Reset color
- dialog.setCurrentColor(initialColor)
-
- return True
- else:
- return super(ParameterTreeDelegate, self).editorEvent(
- event, model, option, index)
-
- def createEditor(self, parent, option, index):
- """See :meth:`QStyledItemDelegate.createEditor`"""
- data = index.data(qt.Qt.EditRole)
- data = self._fixVariant(data)
- editorHint = index.data(qt.Qt.UserRole)
-
- if callable(editorHint):
- editor = editorHint()
- assert isinstance(editor, qt.QWidget)
- editor.setParent(parent)
-
- elif isinstance(data, numbers.Number) and editorHint is not None:
- # Use a slider
- editor = IntSliderEditor(parent)
- range_ = editorHint
- editor.setRange(*range_)
- editor.sliderReleased.connect(self._commit)
-
- elif isinstance(data, six.string_types) and editorHint is not None:
- # Use a combo box
- editor = qt.QComboBox(parent)
- if data not in editorHint:
- editor.addItem(data)
- editor.addItems(editorHint)
-
- index = editor.findText(data)
- editor.setCurrentIndex(index)
-
- editor.currentIndexChanged.connect(self._commit)
-
- else:
- # Handle overridden editors from Python
- # Mimic Qt C++ implementation
- for type_, editorClass in self.EDITORS.items():
- if isinstance(data, type_):
- editor = editorClass(parent)
- metaObject = editor.metaObject()
- userProperty = metaObject.userProperty()
- if userProperty.isValid() and userProperty.hasNotifySignal():
- notifySignal = userProperty.notifySignal()
- if hasattr(notifySignal, 'signature'): # Qt4
- signature = notifySignal.signature()
- else:
- signature = notifySignal.methodSignature()
- if qt.BINDING == 'PySide2':
- signature = signature.data()
- else:
- signature = bytes(signature)
-
- if hasattr(signature, 'decode'): # For PySide with python3
- signature = signature.decode('ascii')
- signalName = signature.split('(')[0]
-
- signal = getattr(editor, signalName)
- signal.connect(self._commit)
- break
-
- else: # Default handling for default types
- return super(ParameterTreeDelegate, self).createEditor(
- parent, option, index)
-
- editor.setAutoFillBackground(True)
- return editor
-
- def setModelData(self, editor, model, index):
- """See :meth:`QStyledItemDelegate.setModelData`"""
- if isinstance(editor, tuple(self.EDITORS.values())):
- # Special handling of Python classes
- # Translation of QStyledItemDelegate::setModelData to Python
- # To make it work with Python QVariant wrapping/unwrapping
- name = editor.metaObject().userProperty().name()
- if not name:
- pass # TODO handle the case of missing user property
- if name:
- if hasattr(editor, name):
- value = getattr(editor, name)
- else:
- value = editor.property(name)
- model.setData(index, value, qt.Qt.EditRole)
-
- else:
- super(ParameterTreeDelegate, self).setModelData(editor, model, index)
-
-
-class ParamTreeView(qt.QTreeView):
- """QTreeView specific to handle plot3d scene and object parameters.
-
- It provides additional editors and specific creation of persistent editors.
-
- :param parent: The widget's parent.
- """
-
- def __init__(self, parent=None):
- super(ParamTreeView, self).__init__(parent)
-
- header = self.header()
- header.setMinimumSectionSize(128) # For colormap pixmaps
- if hasattr(header, 'setSectionResizeMode'): # Qt5
- header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
- else: # Qt4
- header.setResizeMode(qt.QHeaderView.ResizeToContents)
-
- delegate = ParameterTreeDelegate()
- self.setItemDelegate(delegate)
-
- self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
- self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
-
- self.expanded.connect(self._expanded)
-
- self.setEditTriggers(qt.QAbstractItemView.CurrentChanged |
- qt.QAbstractItemView.DoubleClicked)
-
- self.__persistentEditors = set()
-
- def _openEditorForIndex(self, index):
- """Check if it has to open a persistent editor for a specific cell.
-
- :param QModelIndex index: The cell index
- """
- if index.flags() & qt.Qt.ItemIsEditable:
- data = index.data(qt.Qt.EditRole)
- editorHint = index.data(qt.Qt.UserRole)
- if (isinstance(data, bool) or
- callable(editorHint) or
- (isinstance(data, numbers.Number) and editorHint)):
- self.openPersistentEditor(index)
- self.__persistentEditors.add(index)
-
- def _openEditors(self, parent=qt.QModelIndex()):
- """Open persistent editors in a subtree starting at parent.
-
- :param QModelIndex parent: The root of the subtree to process.
- """
- model = self.model()
- if model is not None:
- for index in visitQAbstractItemModel(model, parent):
- self._openEditorForIndex(index)
-
- def setModel(self, model):
- """Set the model this TreeView is displaying
-
- :param QAbstractItemModel model:
- """
- super(ParamTreeView, self).setModel(model)
- self._openEditors()
-
- def rowsInserted(self, parent, start, end):
- """See :meth:`QTreeView.rowsInserted`"""
- super(ParamTreeView, self).rowsInserted(parent, start, end)
- model = self.model()
- if model is not None:
- for row in range(start, end+1):
- self._openEditorForIndex(model.index(row, 1, parent))
- self._openEditors(model.index(row, 0, parent))
-
- def _expanded(self, index):
- """Handle QTreeView expanded signal"""
- name = index.data(qt.Qt.DisplayRole)
- if name == 'Transform':
- rotateIndex = self.model().index(1, 0, index)
- self.setExpanded(rotateIndex, True)
-
- def dataChanged(self, topLeft, bottomRight, roles=()):
- """Handle model dataChanged signal eventually closing editors"""
- if roles: # Qt 5
- super(ParamTreeView, self).dataChanged(topLeft, bottomRight, roles)
- else: # Qt4 compatibility
- super(ParamTreeView, self).dataChanged(topLeft, bottomRight)
- if not roles or qt.Qt.UserRole in roles: # Check editorHint update
- for row in range(topLeft.row(), bottomRight.row() + 1):
- for column in range(topLeft.column(), bottomRight.column() + 1):
- index = topLeft.sibling(row, column)
- if index.isValid():
- if self._isPersistentEditorOpen(index):
- self.closePersistentEditor(index)
- self._openEditorForIndex(index)
-
- def _isPersistentEditorOpen(self, index):
- """Returns True if a persistent editor is opened for index
-
- :param QModelIndex index:
- :rtype: bool
- """
- return index in self.__persistentEditors
-
- def selectionCommand(self, index, event=None):
- """Filter out selection of not selectable items"""
- if index.flags() & qt.Qt.ItemIsSelectable:
- return super(ParamTreeView, self).selectionCommand(index, event)
- else:
- return qt.QItemSelectionModel.NoUpdate
diff --git a/silx/gui/plot3d/Plot3DWidget.py b/silx/gui/plot3d/Plot3DWidget.py
deleted file mode 100644
index f512cd8..0000000
--- a/silx/gui/plot3d/Plot3DWidget.py
+++ /dev/null
@@ -1,460 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a Qt widget embedding an OpenGL scene."""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-
-import enum
-import logging
-
-from silx.gui import qt
-from silx.gui.colors import rgba
-from . import actions
-
-from ...utils.enum import Enum as _Enum
-from ..utils.image import convertArrayToQImage
-
-from .. import _glutils as glu
-from .scene import interaction, primitives, transform
-from . import scene
-
-import numpy
-
-
-_logger = logging.getLogger(__name__)
-
-
-class _OverviewViewport(scene.Viewport):
- """A scene displaying the orientation of the data in another scene.
-
- :param Camera camera: The camera to track.
- """
-
- _SIZE = 100
- """Size in pixels of the overview square"""
-
- def __init__(self, camera=None):
- super(_OverviewViewport, self).__init__()
- self.size = self._SIZE, self._SIZE
- self.background = None # Disable clear
-
- self.scene.transforms = [transform.Scale(2.5, 2.5, 2.5)]
-
- # Add a point to draw the background (in a group with depth mask)
- backgroundPoint = primitives.ColorPoints(
- x=0., y=0., z=0.,
- color=(1., 1., 1., 0.5),
- size=self._SIZE)
- backgroundPoint.marker = 'o'
- noDepthGroup = primitives.GroupNoDepth(mask=True, notest=True)
- noDepthGroup.children.append(backgroundPoint)
- self.scene.children.append(noDepthGroup)
-
- axes = primitives.Axes()
- self.scene.children.append(axes)
-
- if camera is not None:
- camera.addListener(self._cameraChanged)
-
- def _cameraChanged(self, source):
- """Listen to camera in other scene for transformation updates.
-
- Sync the overview camera to point in the same direction
- but from a sphere centered on origin.
- """
- position = -12. * source.extrinsic.direction
- self.camera.extrinsic.position = position
-
- self.camera.extrinsic.setOrientation(
- source.extrinsic.direction, source.extrinsic.up)
-
-
-class Plot3DWidget(glu.OpenGLWidget):
- """OpenGL widget with a 3D viewport and an overview."""
-
- sigInteractiveModeChanged = qt.Signal()
- """Signal emitted when the interactive mode has changed
- """
-
- sigStyleChanged = qt.Signal(str)
- """Signal emitted when the style of the scene has changed
-
- It provides the updated property.
- """
-
- sigSceneClicked = qt.Signal(float, float)
- """Signal emitted when the scene is clicked with the left mouse button.
-
- It provides the (x, y) clicked mouse position
- """
-
- @enum.unique
- class FogMode(_Enum):
- """Different mode to render the scene with fog"""
-
- NONE = 'none'
- """No fog effect"""
-
- LINEAR = 'linear'
- """Linear fog through the whole scene"""
-
- def __init__(self, parent=None, f=qt.Qt.WindowFlags()):
- self._firstRender = True
-
- super(Plot3DWidget, self).__init__(
- parent,
- alphaBufferSize=8,
- depthBufferSize=0,
- stencilBufferSize=0,
- version=(2, 1),
- f=f)
-
- self.setAutoFillBackground(False)
- self.setMouseTracking(True)
-
- self.setFocusPolicy(qt.Qt.StrongFocus)
- self._copyAction = actions.io.CopyAction(parent=self, plot3d=self)
- self.addAction(self._copyAction)
-
- self._updating = False # True if an update is requested
-
- # Main viewport
- self.viewport = scene.Viewport()
-
- self._sceneScale = transform.Scale(1., 1., 1.)
- self.viewport.scene.transforms = [self._sceneScale,
- transform.Translate(0., 0., 0.)]
-
- # Overview area
- self.overview = _OverviewViewport(self.viewport.camera)
-
- self.setBackgroundColor((0.2, 0.2, 0.2, 1.))
-
- # Window describing on screen area to render
- self._window = scene.Window(mode='framebuffer')
- self._window.viewports = [self.viewport, self.overview]
- self._window.addListener(self._redraw)
-
- self.eventHandler = None
- self.setInteractiveMode('rotate')
-
- def __clickHandler(self, *args):
- """Handle interaction state machine click"""
- x, y = args[0][:2]
- self.sigSceneClicked.emit(x, y)
-
- def setInteractiveMode(self, mode):
- """Set the interactive mode.
-
- :param str mode: The interactive mode: 'rotate', 'pan' or None
- """
- if mode == self.getInteractiveMode():
- return
-
- if mode is None:
- self.eventHandler = None
-
- elif mode == 'rotate':
- self.eventHandler = interaction.RotateCameraControl(
- self.viewport,
- orbitAroundCenter=False,
- mode='position',
- scaleTransform=self._sceneScale,
- selectCB=self.__clickHandler)
-
- elif mode == 'pan':
- self.eventHandler = interaction.PanCameraControl(
- self.viewport,
- orbitAroundCenter=False,
- mode='position',
- scaleTransform=self._sceneScale,
- selectCB=self.__clickHandler)
-
- elif isinstance(mode, interaction.StateMachine):
- self.eventHandler = mode
-
- else:
- raise ValueError('Unsupported interactive mode %s', str(mode))
-
- if (self.eventHandler is not None and
- qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier):
- self.eventHandler.handleEvent('keyPress', qt.Qt.Key_Control)
-
- self.sigInteractiveModeChanged.emit()
-
- def getInteractiveMode(self):
- """Returns the interactive mode in use.
-
- :rtype: str
- """
- if self.eventHandler is None:
- return None
- if isinstance(self.eventHandler, interaction.RotateCameraControl):
- return 'rotate'
- elif isinstance(self.eventHandler, interaction.PanCameraControl):
- return 'pan'
- else:
- return None
-
- def setProjection(self, projection):
- """Change the projection in use.
-
- :param str projection: In 'perspective', 'orthographic'.
- """
- if projection == 'orthographic':
- projection = transform.Orthographic(size=self.viewport.size)
- elif projection == 'perspective':
- projection = transform.Perspective(fovy=30.,
- size=self.viewport.size)
- else:
- raise RuntimeError('Unsupported projection: %s' % projection)
-
- self.viewport.camera.intrinsic = projection
- self.viewport.resetCamera()
-
- def getProjection(self):
- """Return the current camera projection mode as a str.
-
- See :meth:`setProjection`
- """
- projection = self.viewport.camera.intrinsic
- if isinstance(projection, transform.Orthographic):
- return 'orthographic'
- elif isinstance(projection, transform.Perspective):
- return 'perspective'
- else:
- raise RuntimeError('Unknown projection in use')
-
- def setBackgroundColor(self, color):
- """Set the background color of the OpenGL view.
-
- :param color: RGB color of the isosurface: name, #RRGGBB or RGB values
- :type color:
- QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
- """
- color = rgba(color)
- if color != self.viewport.background:
- self.viewport.background = color
- self.sigStyleChanged.emit('backgroundColor')
-
- def getBackgroundColor(self):
- """Returns the RGBA background color (QColor)."""
- return qt.QColor.fromRgbF(*self.viewport.background)
-
- def setFogMode(self, mode):
- """Set the kind of fog to use for the whole scene.
-
- :param Union[str,FogMode] mode: The mode to use
- :raise ValueError: If mode is not supported
- """
- mode = self.FogMode.from_value(mode)
- if mode != self.getFogMode():
- self.viewport.fog.isOn = mode is self.FogMode.LINEAR
- self.sigStyleChanged.emit('fogMode')
-
- def getFogMode(self):
- """Returns the kind of fog in use
-
- :return: The kind of fog in use
- :rtype: FogMode
- """
- if self.viewport.fog.isOn:
- return self.FogMode.LINEAR
- else:
- return self.FogMode.NONE
-
- def isOrientationIndicatorVisible(self):
- """Returns True if the orientation indicator is displayed.
-
- :rtype: bool
- """
- return self.overview in self._window.viewports
-
- def setOrientationIndicatorVisible(self, visible):
- """Set the orientation indicator visibility.
-
- :param bool visible: True to show
- """
- visible = bool(visible)
- if visible != self.isOrientationIndicatorVisible():
- if visible:
- self._window.viewports = [self.viewport, self.overview]
- else:
- self._window.viewports = [self.viewport]
- self.sigStyleChanged.emit('orientationIndicatorVisible')
-
- def centerScene(self):
- """Position the center of the scene at the center of rotation."""
- self.viewport.resetCamera()
-
- def resetZoom(self, face='front'):
- """Reset the camera position to a default.
-
- :param str face: The direction the camera is looking at:
- side, front, back, top, bottom, right, left.
- Default: front.
- """
- self.viewport.camera.extrinsic.reset(face=face)
- self.centerScene()
-
- def _redraw(self, source=None):
- """Viewport listener to require repaint"""
- if not self._updating:
- self._updating = True # Mark that an update is requested
- self.update() # Queued repaint (i.e., asynchronous)
-
- def sizeHint(self):
- return qt.QSize(400, 300)
-
- def initializeGL(self):
- pass
-
- def paintGL(self):
- # In case paintGL is called by the system and not through _redraw,
- # Mark as updating.
- self._updating = True
-
- # Update near and far planes only if viewport needs refresh
- if self.viewport.dirty:
- self.viewport.adjustCameraDepthExtent()
-
- self._window.render(self.context(), self.getDevicePixelRatio())
-
- if self._firstRender: # TODO remove this ugly hack
- self._firstRender = False
- self.centerScene()
- self._updating = False
-
- def resizeGL(self, width, height):
- width *= self.getDevicePixelRatio()
- height *= self.getDevicePixelRatio()
- self._window.size = width, height
- self.viewport.size = self._window.size
- overviewWidth, overviewHeight = self.overview.size
- self.overview.origin = width - overviewWidth, height - overviewHeight
-
- def grabGL(self):
- """Renders the OpenGL scene into a numpy array
-
- :returns: OpenGL scene RGB rasterization
- :rtype: QImage
- """
- if not self.isValid():
- _logger.error('OpenGL 2.1 not available, cannot save OpenGL image')
- height, width = self._window.shape
- image = numpy.zeros((height, width, 3), dtype=numpy.uint8)
-
- else:
- self.makeCurrent()
- image = self._window.grab(self.context())
-
- return convertArrayToQImage(image)
-
- def wheelEvent(self, event):
- xpixel = event.x() * self.getDevicePixelRatio()
- ypixel = event.y() * self.getDevicePixelRatio()
- if hasattr(event, 'delta'): # Qt4
- angle = event.delta() / 8.
- else: # Qt5
- angle = event.angleDelta().y() / 8.
- event.accept()
-
- if self.eventHandler is not None and angle != 0 and self.isValid():
- self.makeCurrent()
- self.eventHandler.handleEvent('wheel', xpixel, ypixel, angle)
-
- def keyPressEvent(self, event):
- keyCode = event.key()
- # No need to accept QKeyEvent
-
- converter = {
- qt.Qt.Key_Left: 'left',
- qt.Qt.Key_Right: 'right',
- qt.Qt.Key_Up: 'up',
- qt.Qt.Key_Down: 'down'
- }
- direction = converter.get(keyCode, None)
- if direction is not None:
- if event.modifiers() == qt.Qt.ControlModifier:
- self.viewport.camera.rotate(direction)
- elif event.modifiers() == qt.Qt.ShiftModifier:
- self.viewport.moveCamera(direction)
- else:
- self.viewport.orbitCamera(direction)
-
- else:
- if (keyCode == qt.Qt.Key_Control and
- self.eventHandler is not None and
- self.isValid()):
- self.eventHandler.handleEvent('keyPress', keyCode)
-
- # Key not handled, call base class implementation
- super(Plot3DWidget, self).keyPressEvent(event)
-
- def keyReleaseEvent(self, event):
- """Catch Ctrl key release"""
- keyCode = event.key()
- if (keyCode == qt.Qt.Key_Control and
- self.eventHandler is not None and
- self.isValid()):
- self.eventHandler.handleEvent('keyRelease', keyCode)
- super(Plot3DWidget, self).keyReleaseEvent(event)
-
- # Mouse events #
- _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'}
-
- def mousePressEvent(self, event):
- xpixel = event.x() * self.getDevicePixelRatio()
- ypixel = event.y() * self.getDevicePixelRatio()
- btn = self._MOUSE_BTNS[event.button()]
- event.accept()
-
- if self.eventHandler is not None and self.isValid():
- self.makeCurrent()
- self.eventHandler.handleEvent('press', xpixel, ypixel, btn)
-
- def mouseMoveEvent(self, event):
- xpixel = event.x() * self.getDevicePixelRatio()
- ypixel = event.y() * self.getDevicePixelRatio()
- event.accept()
-
- if self.eventHandler is not None and self.isValid():
- self.makeCurrent()
- self.eventHandler.handleEvent('move', xpixel, ypixel)
-
- def mouseReleaseEvent(self, event):
- xpixel = event.x() * self.getDevicePixelRatio()
- ypixel = event.y() * self.getDevicePixelRatio()
- btn = self._MOUSE_BTNS[event.button()]
- event.accept()
-
- if self.eventHandler is not None and self.isValid():
- self.makeCurrent()
- self.eventHandler.handleEvent('release', xpixel, ypixel, btn)
diff --git a/silx/gui/plot3d/SFViewParamTree.py b/silx/gui/plot3d/SFViewParamTree.py
deleted file mode 100644
index 4e179fc..0000000
--- a/silx/gui/plot3d/SFViewParamTree.py
+++ /dev/null
@@ -1,1817 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module provides a tree widget to set/view parameters of a ScalarFieldView.
-"""
-
-from __future__ import absolute_import
-
-__authors__ = ["D. N."]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-import logging
-import sys
-import weakref
-
-import numpy
-
-from silx.gui import qt
-from silx.gui.icons import getQIcon
-from silx.gui.colors import Colormap
-from silx.gui.widgets.FloatEdit import FloatEdit
-
-from .ScalarFieldView import Isosurface
-
-
-_logger = logging.getLogger(__name__)
-
-
-class ModelColumns(object):
- NameColumn, ValueColumn, ColumnMax = range(3)
- ColumnNames = ['Name', 'Value']
-
-
-class SubjectItem(qt.QStandardItem):
- """
- Base class for observers items.
-
- Subclassing:
- ------------
- The following method can/should be reimplemented:
- - _init
- - _pullData
- - _pushData
- - _setModelData
- - _subjectChanged
- - getEditor
- - getSignals
- - leftClicked
- - queryRemove
- - setEditorData
-
- Also the following attributes are available:
- - editable
- - persistent
-
- :param subject: object that this item will be observing.
- """
-
- editable = False
- """ boolean: set to True to make the item editable. """
-
- persistent = False
- """
- boolean: set to True to make the editor persistent.
- See : Qt.QAbstractItemView.openPersistentEditor
- """
-
- def __init__(self, subject, *args):
-
- super(SubjectItem, self).__init__(*args)
-
- self.setEditable(self.editable)
-
- self.__subject = None
- self.subject = subject
-
- def setData(self, value, role=qt.Qt.UserRole, pushData=True):
- """
- Overloaded method from QStandardItem. The pushData keyword tells
- the item to push data to the subject if the role is equal to EditRole.
- This is useful to let this method know if the setData method was called
- internally or from the view.
-
- :param value: the value ti set to data
- :param role: role in the item
- :param pushData: if True push value in the existing data.
- """
- if role == qt.Qt.EditRole and pushData:
- setValue = self._pushData(value, role)
- if setValue != value:
- value = setValue
- super(SubjectItem, self).setData(value, role)
-
- @property
- def subject(self):
- """The subject this item is observing"""
- return None if self.__subject is None else self.__subject()
-
- @subject.setter
- def subject(self, subject):
- if self.__subject is not None:
- raise ValueError('Subject already set '
- ' (subject change not supported).')
- if subject is None:
- self.__subject = None
- else:
- self.__subject = weakref.ref(subject)
- if subject is not None:
- self._init()
- self._connectSignals()
-
- def _connectSignals(self):
- """
- Connects the signals. Called when the subject is set.
- """
-
- def gen_slot(_sigIdx):
- def slotfn(*args, **kwargs):
- self._subjectChanged(signalIdx=_sigIdx,
- args=args,
- kwargs=kwargs)
- return slotfn
-
- if self.__subject is not None:
- self.__slots = slots = []
-
- signals = self.getSignals()
-
- if signals:
- if not isinstance(signals, (list, tuple)):
- signals = [signals]
- for sigIdx, signal in enumerate(signals):
- slot = gen_slot(sigIdx)
- signal.connect(slot)
- slots.append((signal, slot))
-
- def _disconnectSignals(self):
- """
- Disconnects all subject's signal
- """
- if self.__slots:
- for signal, slot in self.__slots:
- try:
- signal.disconnect(slot)
- except TypeError:
- pass
-
- def _enableRow(self, enable):
- """
- Set the enabled state for this cell, or for the whole row
- if this item has a parent.
-
- :param bool enable: True if we wan't to enable the cell
- """
- parent = self.parent()
- model = self.model()
- if model is None or parent is None:
- # no parent -> no siblings
- self.setEnabled(enable)
- return
-
- for col in range(model.columnCount()):
- sibling = parent.child(self.row(), col)
- sibling.setEnabled(enable)
-
- #################################################################
- # Overloadable methods
- #################################################################
-
- def getSignals(self):
- """
- Returns the list of this items subject's signals that
- this item will be listening to.
-
- :return: list.
- """
- return None
-
- def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
- """
- Called when one of the signals is triggered. Default implementation
- just calls _pullData, compares the result to the current value stored
- as Qt.EditRole, and stores the new value if it is different. It also
- stores its str representation as Qt.DisplayRole
-
- :param signalIdx: index of the triggered signal. The value passed
- is the same as the signal position in the list returned by
- SubjectItem.getSignals.
- :param args: arguments received from the signal
- :param kwargs: keyword arguments received from the signal
- """
- data = self._pullData()
- if data == self.data(qt.Qt.EditRole):
- return
- self.setData(data, role=qt.Qt.DisplayRole, pushData=False)
- self.setData(data, role=qt.Qt.EditRole, pushData=False)
-
- def _pullData(self):
- """
- Pulls data from the subject.
-
- :return: subject data
- """
- return None
-
- def _pushData(self, value, role=qt.Qt.UserRole):
- """
- Pushes data to the subject and returns the actual value that was stored
-
- :return: the value that was stored
- """
- return value
-
- def _init(self):
- """
- Called when the subject is set.
- :return:
- """
- self._subjectChanged()
-
- def getEditor(self, parent, option, index):
- """
- Returns the editor widget used to edit this item's data. The arguments
- are the one passed to the QStyledItemDelegate.createEditor method.
-
- :param parent: the Qt parent of the editor
- :param option:
- :param index:
- :return:
- """
- return None
-
- def setEditorData(self, editor):
- """
- This is called by the View's delegate just before the editor is shown,
- its purpose it to setup the editors contents. Return False to use
- the delegate's default behaviour.
-
- :param editor:
- :return:
- """
- return True
-
- def _setModelData(self, editor):
- """
- This is called by the View's delegate just before the editor is closed,
- its allows this item to update itself with data from the editor.
-
- :param editor:
- :return:
- """
- return False
-
- def queryRemove(self, view=None):
- """
- This is called by the view to ask this items if it (the view) can
- remove it. Return True to let the view know that the item can be
- removed.
-
- :param view:
- :return:
- """
- return False
-
- def leftClicked(self):
- """
- This method is called by the view when the item's cell if left clicked.
-
- :return:
- """
- pass
-
-
-# View settings ###############################################################
-
-class ColorItem(SubjectItem):
- """color item."""
- editable = True
- persistent = True
-
- def getEditor(self, parent, option, index):
- editor = QColorEditor(parent)
- editor.color = self.getColor()
-
- # Wrapping call in lambda is a workaround for PySide with Python 3
- editor.sigColorChanged.connect(
- lambda color: self._editorSlot(color))
- return editor
-
- def _editorSlot(self, color):
- self.setData(color, qt.Qt.EditRole)
-
- def _pushData(self, value, role=qt.Qt.UserRole):
- self.setColor(value)
- return self.getColor()
-
- def _pullData(self):
- self.getColor()
-
- def setColor(self, color):
- """Override to implement actual color setter"""
- pass
-
-
-class BackgroundColorItem(ColorItem):
- itemName = 'Background'
-
- def setColor(self, color):
- self.subject.setBackgroundColor(color)
-
- def getColor(self):
- return self.subject.getBackgroundColor()
-
-
-class ForegroundColorItem(ColorItem):
- itemName = 'Foreground'
-
- def setColor(self, color):
- self.subject.setForegroundColor(color)
-
- def getColor(self):
- return self.subject.getForegroundColor()
-
-
-class HighlightColorItem(ColorItem):
- itemName = 'Highlight'
-
- def setColor(self, color):
- self.subject.setHighlightColor(color)
-
- def getColor(self):
- return self.subject.getHighlightColor()
-
-
-class _LightDirectionAngleBaseItem(SubjectItem):
- """Base class for directional light angle item."""
- editable = True
- persistent = True
-
- def _init(self):
- pass
-
- def getSignals(self):
- """Override to provide signals to listen"""
- raise NotImplementedError("MUST be implemented in subclass")
-
- def _pullData(self):
- """Override in subclass to get current angle"""
- raise NotImplementedError("MUST be implemented in subclass")
-
- def _pushData(self, value, role=qt.Qt.UserRole):
- """Override in subclass to set the angle"""
- raise NotImplementedError("MUST be implemented in subclass")
-
- def getEditor(self, parent, option, index):
- editor = qt.QSlider(parent)
- editor.setOrientation(qt.Qt.Horizontal)
- editor.setMinimum(-90)
- editor.setMaximum(90)
- editor.setValue(int(self._pullData()))
-
- # Wrapping call in lambda is a workaround for PySide with Python 3
- editor.valueChanged.connect(
- lambda value: self._pushData(value))
-
- return editor
-
- def setEditorData(self, editor):
- editor.setValue(int(self._pullData()))
- return True
-
- def _setModelData(self, editor):
- value = editor.value()
- self._pushData(value)
- return True
-
-
-class LightAzimuthAngleItem(_LightDirectionAngleBaseItem):
- """Light direction azimuth angle item."""
-
- def getSignals(self):
- return self.subject.sigAzimuthAngleChanged
-
- def _pullData(self):
- return self.subject.getAzimuthAngle()
-
- def _pushData(self, value, role=qt.Qt.UserRole):
- self.subject.setAzimuthAngle(value)
-
-
-class LightAltitudeAngleItem(_LightDirectionAngleBaseItem):
- """Light direction altitude angle item."""
-
- def getSignals(self):
- return self.subject.sigAltitudeAngleChanged
-
- def _pullData(self):
- return self.subject.getAltitudeAngle()
-
- def _pushData(self, value, role=qt.Qt.UserRole):
- self.subject.setAltitudeAngle(value)
-
-
-class _DirectionalLightProxy(qt.QObject):
- """Proxy to handle directional light with angles rather than vector.
- """
-
- sigAzimuthAngleChanged = qt.Signal()
- """Signal sent when the azimuth angle has changed."""
-
- sigAltitudeAngleChanged = qt.Signal()
- """Signal sent when altitude angle has changed."""
-
- def __init__(self, light):
- super(_DirectionalLightProxy, self).__init__()
- self._light = light
- light.addListener(self._directionUpdated)
- self._azimuth = 0.
- self._altitude = 0.
-
- def getAzimuthAngle(self):
- """Returns the signed angle in the horizontal plane.
-
- Unit: degrees.
- The 0 angle corresponds to the axis perpendicular to the screen.
-
- :rtype: float
- """
- return self._azimuth
-
- def getAltitudeAngle(self):
- """Returns the signed vertical angle from the horizontal plane.
-
- Unit: degrees.
- Range: [-90, +90]
-
- :rtype: float
- """
- return self._altitude
-
- def setAzimuthAngle(self, angle):
- """Set the horizontal angle.
-
- :param float angle: Angle from -z axis in zx plane in degrees.
- """
- if angle != self._azimuth:
- self._azimuth = angle
- self._updateLight()
- self.sigAzimuthAngleChanged.emit()
-
- def setAltitudeAngle(self, angle):
- """Set the horizontal angle.
-
- :param float angle: Angle from -z axis in zy plane in degrees.
- """
- if angle != self._altitude:
- self._altitude = angle
- self._updateLight()
- self.sigAltitudeAngleChanged.emit()
-
- def _directionUpdated(self, *args, **kwargs):
- """Handle light direction update in the scene"""
- # Invert direction to manipulate the 'source' pointing to
- # the center of the viewport
- x, y, z = - self._light.direction
-
- # Horizontal plane is plane xz
- azimuth = numpy.degrees(numpy.arctan2(x, z))
- altitude = numpy.degrees(numpy.pi/2. - numpy.arccos(y))
-
- if (abs(azimuth - self.getAzimuthAngle()) > 0.01 and
- abs(abs(altitude) - 90.) >= 0.001): # Do not update when at zenith
- self.setAzimuthAngle(azimuth)
-
- if abs(altitude - self.getAltitudeAngle()) > 0.01:
- self.setAltitudeAngle(altitude)
-
- def _updateLight(self):
- """Update light direction in the scene"""
- azimuth = numpy.radians(self._azimuth)
- delta = numpy.pi/2. - numpy.radians(self._altitude)
- z = - numpy.sin(delta) * numpy.cos(azimuth)
- x = - numpy.sin(delta) * numpy.sin(azimuth)
- y = - numpy.cos(delta)
- self._light.direction = x, y, z
-
-
-class DirectionalLightGroup(SubjectItem):
- """
- Root Item for the directional light
- """
-
- def __init__(self,subject, *args):
- self._light = _DirectionalLightProxy(
- subject.getPlot3DWidget().viewport.light)
-
- super(DirectionalLightGroup, self).__init__(subject, *args)
-
- def _init(self):
-
- nameItem = qt.QStandardItem('Azimuth')
- nameItem.setEditable(False)
- valueItem = LightAzimuthAngleItem(self._light)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Altitude')
- nameItem.setEditable(False)
- valueItem = LightAltitudeAngleItem(self._light)
- self.appendRow([nameItem, valueItem])
-
-
-class BoundingBoxItem(SubjectItem):
- """Bounding box, axes labels and grid visibility item.
-
- Item is checkable.
- """
- itemName = 'Bounding Box'
-
- def _init(self):
- visible = self.subject.isBoundingBoxVisible()
- self.setCheckable(True)
- self.setCheckState(qt.Qt.Checked if visible else qt.Qt.Unchecked)
-
- def leftClicked(self):
- checked = (self.checkState() == qt.Qt.Checked)
- if checked != self.subject.isBoundingBoxVisible():
- self.subject.setBoundingBoxVisible(checked)
-
-
-class OrientationIndicatorItem(SubjectItem):
- """Orientation indicator visibility item.
-
- Item is checkable.
- """
- itemName = 'Axes indicator'
-
- def _init(self):
- plot3d = self.subject.getPlot3DWidget()
- visible = plot3d.isOrientationIndicatorVisible()
- self.setCheckable(True)
- self.setCheckState(qt.Qt.Checked if visible else qt.Qt.Unchecked)
-
- def leftClicked(self):
- plot3d = self.subject.getPlot3DWidget()
- checked = (self.checkState() == qt.Qt.Checked)
- if checked != plot3d.isOrientationIndicatorVisible():
- plot3d.setOrientationIndicatorVisible(checked)
-
-
-class ViewSettingsItem(qt.QStandardItem):
- """Viewport settings"""
-
- def __init__(self, subject, *args):
-
- super(ViewSettingsItem, self).__init__(*args)
-
- self.setEditable(False)
-
- classes = (BackgroundColorItem,
- ForegroundColorItem,
- HighlightColorItem,
- BoundingBoxItem,
- OrientationIndicatorItem)
- for cls in classes:
- titleItem = qt.QStandardItem(cls.itemName)
- titleItem.setEditable(False)
- self.appendRow([titleItem, cls(subject)])
-
- nameItem = DirectionalLightGroup(subject, 'Light Direction')
- valueItem = qt.QStandardItem()
- self.appendRow([nameItem, valueItem])
-
-
-# Data information ############################################################
-
-class DataChangedItem(SubjectItem):
- """
- Base class for items listening to ScalarFieldView.sigDataChanged
- """
-
- def getSignals(self):
- subject = self.subject
- if subject:
- return subject.sigDataChanged, subject.sigTransformChanged
- return None
-
- def _init(self):
- self._subjectChanged()
-
-
-class DataTypeItem(DataChangedItem):
- itemName = 'dtype'
-
- def _pullData(self):
- data = self.subject.getData(copy=False)
- return ((data is not None) and str(data.dtype)) or 'N/A'
-
-
-class DataShapeItem(DataChangedItem):
- itemName = 'size'
-
- def _pullData(self):
- data = self.subject.getData(copy=False)
- if data is None:
- return 'N/A'
- else:
- return str(list(reversed(data.shape)))
-
-
-class OffsetItem(DataChangedItem):
- itemName = 'offset'
-
- def _pullData(self):
- offset = self.subject.getTranslation()
- return ((offset is not None) and str(offset)) or 'N/A'
-
-
-class ScaleItem(DataChangedItem):
- itemName = 'scale'
-
- def _pullData(self):
- scale = self.subject.getScale()
- return ((scale is not None) and str(scale)) or 'N/A'
-
-
-class MatrixItem(DataChangedItem):
-
- def __init__(self, subject, row, *args):
- self.__row = row
- super(MatrixItem, self).__init__(subject, *args)
-
- def _pullData(self):
- matrix = self.subject.getTransformMatrix()
- return str(matrix[self.__row])
-
-
-class DataSetItem(qt.QStandardItem):
-
- def __init__(self, subject, *args):
-
- super(DataSetItem, self).__init__(*args)
-
- self.setEditable(False)
-
- klasses = [DataTypeItem, DataShapeItem, OffsetItem]
- for klass in klasses:
- titleItem = qt.QStandardItem(klass.itemName)
- titleItem.setEditable(False)
- self.appendRow([titleItem, klass(subject)])
-
- matrixItem = qt.QStandardItem('matrix')
- matrixItem.setEditable(False)
- valueItem = qt.QStandardItem()
- self.appendRow([matrixItem, valueItem])
-
- for row in range(3):
- titleItem = qt.QStandardItem()
- titleItem.setEditable(False)
- valueItem = MatrixItem(subject, row)
- matrixItem.appendRow([titleItem, valueItem])
-
- titleItem = qt.QStandardItem(ScaleItem.itemName)
- titleItem.setEditable(False)
- self.appendRow([titleItem, ScaleItem(subject)])
-
-
-# Isosurface ##################################################################
-
-class IsoSurfaceRootItem(SubjectItem):
- """
- Root (i.e : column index 0) Isosurface item.
- """
-
- def __init__(self, subject, normalization, *args):
- self._isoLevelSliderNormalization = normalization
- super(IsoSurfaceRootItem, self).__init__(subject, *args)
-
- def getSignals(self):
- subject = self.subject
- return [subject.sigColorChanged,
- subject.sigVisibilityChanged]
-
- def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
- if signalIdx == 0:
- color = self.subject.getColor()
- self.setData(color, qt.Qt.DecorationRole)
- elif signalIdx == 1:
- visible = args[0]
- self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked)
-
- def _init(self):
- self.setCheckable(True)
-
- isosurface = self.subject
- color = isosurface.getColor()
- visible = isosurface.isVisible()
- self.setData(color, qt.Qt.DecorationRole)
- self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked)
-
- nameItem = qt.QStandardItem('Level')
- sliderItem = IsoSurfaceLevelSlider(self.subject,
- self._isoLevelSliderNormalization)
- self.appendRow([nameItem, sliderItem])
-
- nameItem = qt.QStandardItem('Color')
- nameItem.setEditable(False)
- valueItem = IsoSurfaceColorItem(self.subject)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Opacity')
- nameItem.setTextAlignment(qt.Qt.AlignLeft | qt.Qt.AlignTop)
- nameItem.setEditable(False)
- valueItem = IsoSurfaceAlphaItem(self.subject)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem()
- nameItem.setEditable(False)
- valueItem = IsoSurfaceAlphaLegendItem(self.subject)
- valueItem.setEditable(False)
- self.appendRow([nameItem, valueItem])
-
- def queryRemove(self, view=None):
- buttons = qt.QMessageBox.Ok | qt.QMessageBox.Cancel
- ans = qt.QMessageBox.question(view,
- 'Remove isosurface',
- 'Remove the selected iso-surface?',
- buttons=buttons)
- if ans == qt.QMessageBox.Ok:
- sfview = self.subject.parent()
- if sfview:
- sfview.removeIsosurface(self.subject)
- return False
- return False
-
- def leftClicked(self):
- checked = (self.checkState() == qt.Qt.Checked)
- visible = self.subject.isVisible()
- if checked != visible:
- self.subject.setVisible(checked)
-
-
-class IsoSurfaceLevelItem(SubjectItem):
- """
- Base class for the isosurface level items.
- """
- editable = True
-
- def getSignals(self):
- subject = self.subject
- return [subject.sigLevelChanged,
- subject.sigVisibilityChanged]
-
- def getEditor(self, parent, option, index):
- return FloatEdit(parent)
-
- def setEditorData(self, editor):
- editor.setValue(self._pullData())
- return False
-
- def _setModelData(self, editor):
- self._pushData(editor.value())
- return True
-
- def _pullData(self):
- return self.subject.getLevel()
-
- def _pushData(self, value, role=qt.Qt.UserRole):
- self.subject.setLevel(value)
- return self.subject.getLevel()
-
-
-class _IsoLevelSlider(qt.QSlider):
- """QSlider used for iso-surface level with linear scale"""
-
- def __init__(self, parent, subject, normalization):
- super(_IsoLevelSlider, self).__init__(parent=parent)
- self.subject = subject
-
- if normalization == 'arcsinh':
- self.__norm = numpy.arcsinh
- self.__invNorm = numpy.sinh
- elif normalization == 'linear':
- self.__norm = lambda x: x
- self.__invNorm = lambda x: x
- else:
- raise ValueError(
- "Unsupported normalization %s", normalization)
-
- self.sliderReleased.connect(self.__sliderReleased)
-
- self.subject.sigLevelChanged.connect(self.setLevel)
- self.subject.parent().sigDataChanged.connect(self.__dataChanged)
-
- def setLevel(self, level):
- """Set slider from iso-surface level"""
- dataRange = self.subject.parent().getDataRange()
-
- if dataRange is not None:
- min_ = self.__norm(dataRange[0])
- max_ = self.__norm(dataRange[-1])
-
- width = max_ - min_
- if width > 0:
- sliderWidth = self.maximum() - self.minimum()
- sliderPosition = sliderWidth * (self.__norm(level) - min_) / width
- self.setValue(int(sliderPosition))
-
- def __dataChanged(self):
- """Handles data update to refresh slider range if needed"""
- self.setLevel(self.subject.getLevel())
-
- def __sliderReleased(self):
- value = self.value()
- dataRange = self.subject.parent().getDataRange()
- if dataRange is not None:
- min_ = self.__norm(dataRange[0])
- max_ = self.__norm(dataRange[-1])
- width = max_ - min_
- sliderWidth = self.maximum() - self.minimum()
- level = min_ + width * value / sliderWidth
- self.subject.setLevel(self.__invNorm(level))
-
-
-class IsoSurfaceLevelSlider(IsoSurfaceLevelItem):
- """
- Isosurface level item with a slider editor.
- """
- nTicks = 1000
- persistent = True
-
- def __init__(self, subject, normalization):
- self.normalization = normalization
- super(IsoSurfaceLevelSlider, self).__init__(subject)
-
- def getEditor(self, parent, option, index):
- editor = _IsoLevelSlider(parent, self.subject, self.normalization)
- editor.setOrientation(qt.Qt.Horizontal)
- editor.setMinimum(0)
- editor.setMaximum(self.nTicks)
-
- editor.setSingleStep(1)
-
- editor.setLevel(self.subject.getLevel())
- return editor
-
- def setEditorData(self, editor):
- return True
-
- def _setModelData(self, editor):
- return True
-
-
-class IsoSurfaceColorItem(SubjectItem):
- """
- Isosurface color item.
- """
- editable = True
- persistent = True
-
- def getSignals(self):
- return self.subject.sigColorChanged
-
- def getEditor(self, parent, option, index):
- editor = QColorEditor(parent)
- color = self.subject.getColor()
- color.setAlpha(255)
- editor.color = color
- # Wrapping call in lambda is a workaround for PySide with Python 3
- editor.sigColorChanged.connect(
- lambda color: self.__editorChanged(color))
- return editor
-
- def __editorChanged(self, color):
- color.setAlpha(self.subject.getColor().alpha())
- self.subject.setColor(color)
-
- def _pushData(self, value, role=qt.Qt.UserRole):
- self.subject.setColor(value)
- return self.subject.getColor()
-
-
-class QColorEditor(qt.QWidget):
- """
- QColor editor.
- """
- sigColorChanged = qt.Signal(object)
-
- color = property(lambda self: qt.QColor(self.__color))
-
- @color.setter
- def color(self, color):
- self._setColor(color)
- self.__previousColor = color
-
- def __init__(self, *args, **kwargs):
- super(QColorEditor, self).__init__(*args, **kwargs)
- layout = qt.QHBoxLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
- button = qt.QToolButton()
- icon = qt.QIcon(qt.QPixmap(32, 32))
- button.setIcon(icon)
- layout.addWidget(button)
- button.clicked.connect(self.__showColorDialog)
- layout.addStretch(1)
-
- self.__color = None
- self.__previousColor = None
-
- def sizeHint(self):
- return qt.QSize(0, 0)
-
- def _setColor(self, qColor):
- button = self.findChild(qt.QToolButton)
- pixmap = qt.QPixmap(32, 32)
- pixmap.fill(qColor)
- button.setIcon(qt.QIcon(pixmap))
- self.__color = qColor
-
- def __showColorDialog(self):
- dialog = qt.QColorDialog(parent=self)
- if sys.platform == 'darwin':
- # Use of native color dialog on macos might cause problems
- dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True)
-
- self.__previousColor = self.__color
- dialog.setAttribute(qt.Qt.WA_DeleteOnClose)
- dialog.setModal(True)
- dialog.currentColorChanged.connect(self.__colorChanged)
- dialog.finished.connect(self.__dialogClosed)
- dialog.show()
-
- def __colorChanged(self, color):
- self.__color = color
- self._setColor(color)
- self.sigColorChanged.emit(color)
-
- def __dialogClosed(self, result):
- if result == qt.QDialog.Rejected:
- self.__colorChanged(self.__previousColor)
- self.__previousColor = None
-
-
-class IsoSurfaceAlphaItem(SubjectItem):
- """
- Isosurface alpha item.
- """
- editable = True
- persistent = True
-
- def _init(self):
- pass
-
- def getSignals(self):
- return self.subject.sigColorChanged
-
- def getEditor(self, parent, option, index):
- editor = qt.QSlider(parent)
- editor.setOrientation(qt.Qt.Horizontal)
- editor.setMinimum(0)
- editor.setMaximum(255)
-
- color = self.subject.getColor()
- editor.setValue(color.alpha())
-
- # Wrapping call in lambda is a workaround for PySide with Python 3
- editor.valueChanged.connect(
- lambda value: self.__editorChanged(value))
-
- return editor
-
- def __editorChanged(self, value):
- color = self.subject.getColor()
- color.setAlpha(value)
- self.subject.setColor(color)
-
- def setEditorData(self, editor):
- return True
-
- def _setModelData(self, editor):
- return True
-
-
-class IsoSurfaceAlphaLegendItem(SubjectItem):
- """Legend to place under opacity slider"""
-
- editable = False
- persistent = True
-
- def getEditor(self, parent, option, index):
- layout = qt.QHBoxLayout()
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
- layout.addWidget(qt.QLabel('0'))
- layout.addStretch(1)
- layout.addWidget(qt.QLabel('1'))
-
- editor = qt.QWidget(parent)
- editor.setLayout(layout)
- return editor
-
-
-class IsoSurfaceCount(SubjectItem):
- """
- Item displaying the number of isosurfaces.
- """
-
- def getSignals(self):
- subject = self.subject
- return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved]
-
- def _pullData(self):
- return len(self.subject.getIsosurfaces())
-
-
-class IsoSurfaceAddRemoveWidget(qt.QWidget):
-
- sigViewTask = qt.Signal(str)
- """Signal for the tree view to perform some task"""
-
- def __init__(self, parent, item):
- super(IsoSurfaceAddRemoveWidget, self).__init__(parent)
- self._item = item
- layout = qt.QHBoxLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
-
- addBtn = qt.QToolButton(self)
- addBtn.setText('+')
- addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
- layout.addWidget(addBtn)
- addBtn.clicked.connect(self.__addClicked)
-
- removeBtn = qt.QToolButton(self)
- removeBtn.setText('-')
- removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
- layout.addWidget(removeBtn)
- removeBtn.clicked.connect(self.__removeClicked)
-
- layout.addStretch(1)
-
- def __addClicked(self):
- sfview = self._item.subject
- if not sfview:
- return
- dataRange = sfview.getDataRange()
- if dataRange is None:
- dataRange = [0, 1]
-
- sfview.addIsosurface(
- numpy.mean((dataRange[0], dataRange[-1])), '#0000FF')
-
- def __removeClicked(self):
- self.sigViewTask.emit('remove_iso')
-
-
-class IsoSurfaceAddRemoveItem(SubjectItem):
- """
- Item displaying a simple QToolButton allowing to add an isosurface.
- """
- persistent = True
-
- def getEditor(self, parent, option, index):
- return IsoSurfaceAddRemoveWidget(parent, self)
-
-
-class IsoSurfaceGroup(SubjectItem):
- """
- Root item for the list of isosurface items.
- """
-
- def __init__(self, subject, normalization, *args):
- self._isoLevelSliderNormalization = normalization
- super(IsoSurfaceGroup, self).__init__(subject, *args)
-
- def getSignals(self):
- subject = self.subject
- return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved]
-
- def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
- if signalIdx == 0:
- if len(args) >= 1:
- isosurface = args[0]
- if not isinstance(isosurface, Isosurface):
- raise ValueError('Expected an isosurface instance.')
- self.__addIsosurface(isosurface)
- else:
- raise ValueError('Expected an isosurface instance.')
- elif signalIdx == 1:
- if len(args) >= 1:
- isosurface = args[0]
- if not isinstance(isosurface, Isosurface):
- raise ValueError('Expected an isosurface instance.')
- self.__removeIsosurface(isosurface)
- else:
- raise ValueError('Expected an isosurface instance.')
-
- def __addIsosurface(self, isosurface):
- valueItem = IsoSurfaceRootItem(
- subject=isosurface,
- normalization=self._isoLevelSliderNormalization)
- nameItem = IsoSurfaceLevelItem(subject=isosurface)
- self.insertRow(max(0, self.rowCount() - 1), [valueItem, nameItem])
-
- def __removeIsosurface(self, isosurface):
- for row in range(self.rowCount()):
- child = self.child(row)
- subject = getattr(child, 'subject', None)
- if subject == isosurface:
- self.takeRow(row)
- break
-
- def _init(self):
- nameItem = IsoSurfaceAddRemoveItem(self.subject)
- valueItem = qt.QStandardItem()
- valueItem.setEditable(False)
- self.appendRow([nameItem, valueItem])
-
- subject = self.subject
- isosurfaces = subject.getIsosurfaces()
- for isosurface in isosurfaces:
- self.__addIsosurface(isosurface)
-
-
-# Cutting Plane ###############################################################
-
-class ColormapBase(SubjectItem):
- """
- Mixin class for colormap items.
- """
-
- def getSignals(self):
- return [self.subject.getCutPlanes()[0].sigColormapChanged]
-
-
-class PlaneMinRangeItem(ColormapBase):
- """
- colormap minVal item.
- Editor is a QLineEdit with a QDoubleValidator
- """
- editable = True
-
- def _pullData(self):
- colormap = self.subject.getCutPlanes()[0].getColormap()
- auto = colormap.isAutoscale()
- if auto == self.isEnabled():
- self._enableRow(not auto)
- return colormap.getVMin()
-
- def _pushData(self, value, role=qt.Qt.UserRole):
- self._setVMin(value)
-
- def _setVMin(self, value):
- colormap = self.subject.getCutPlanes()[0].getColormap()
- vMin = value
- vMax = colormap.getVMax()
-
- if vMax is not None and value > vMax:
- vMin = vMax
- vMax = value
- colormap.setVRange(vMin, vMax)
-
- def getEditor(self, parent, option, index):
- return FloatEdit(parent)
-
- def setEditorData(self, editor):
- editor.setValue(self._pullData())
- return True
-
- def _setModelData(self, editor):
- value = editor.value()
- self._setVMin(value)
- return True
-
-
-class PlaneMaxRangeItem(ColormapBase):
- """
- colormap maxVal item.
- Editor is a QLineEdit with a QDoubleValidator
- """
- editable = True
-
- def _pullData(self):
- colormap = self.subject.getCutPlanes()[0].getColormap()
- auto = colormap.isAutoscale()
- if auto == self.isEnabled():
- self._enableRow(not auto)
- return self.subject.getCutPlanes()[0].getColormap().getVMax()
-
- def _setVMax(self, value):
- colormap = self.subject.getCutPlanes()[0].getColormap()
- vMin = colormap.getVMin()
- vMax = value
- if vMin is not None and value < vMin:
- vMax = vMin
- vMin = value
- colormap.setVRange(vMin, vMax)
-
- def getEditor(self, parent, option, index):
- return FloatEdit(parent)
-
- def setEditorData(self, editor):
- editor.setText(str(self._pullData()))
- return True
-
- def _setModelData(self, editor):
- value = editor.value()
- self._setVMax(value)
- return True
-
-
-class PlaneOrientationItem(SubjectItem):
- """
- Plane orientation item.
- Editor is a QComboBox.
- """
- editable = True
-
- _PLANE_ACTIONS = (
- ('3d-plane-normal-x', 'Plane 0',
- 'Set plane perpendicular to red axis', (1., 0., 0.)),
- ('3d-plane-normal-y', 'Plane 1',
- 'Set plane perpendicular to green axis', (0., 1., 0.)),
- ('3d-plane-normal-z', 'Plane 2',
- 'Set plane perpendicular to blue axis', (0., 0., 1.)),
- )
-
- def getSignals(self):
- return [self.subject.getCutPlanes()[0].sigPlaneChanged]
-
- def _pullData(self):
- currentNormal = self.subject.getCutPlanes()[0].getNormal(
- coordinates='scene')
- for _, text, _, normal in self._PLANE_ACTIONS:
- if numpy.allclose(normal, currentNormal):
- return text
- return ''
-
- def getEditor(self, parent, option, index):
- editor = qt.QComboBox(parent)
- for iconName, text, tooltip, normal in self._PLANE_ACTIONS:
- editor.addItem(getQIcon(iconName), text)
-
- # Wrapping call in lambda is a workaround for PySide with Python 3
- editor.currentIndexChanged[int].connect(
- lambda index: self.__editorChanged(index))
- return editor
-
- def __editorChanged(self, index):
- normal = self._PLANE_ACTIONS[index][3]
- plane = self.subject.getCutPlanes()[0]
- plane.setNormal(normal, coordinates='scene')
- plane.moveToCenter()
-
- def setEditorData(self, editor):
- currentText = self._pullData()
- index = 0
- for normIdx, (_, text, _, _) in enumerate(self._PLANE_ACTIONS):
- if text == currentText:
- index = normIdx
- break
- editor.setCurrentIndex(index)
- return True
-
- def _setModelData(self, editor):
- return True
-
-
-class PlaneInterpolationItem(SubjectItem):
- """Toggle cut plane interpolation method: nearest or linear.
-
- Item is checkable
- """
-
- def _init(self):
- interpolation = self.subject.getCutPlanes()[0].getInterpolation()
- self.setCheckable(True)
- self.setCheckState(
- qt.Qt.Checked if interpolation == 'linear' else qt.Qt.Unchecked)
- self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
-
- def getSignals(self):
- return [self.subject.getCutPlanes()[0].sigInterpolationChanged]
-
- def leftClicked(self):
- checked = self.checkState() == qt.Qt.Checked
- self._setInterpolation('linear' if checked else 'nearest')
-
- def _pullData(self):
- interpolation = self.subject.getCutPlanes()[0].getInterpolation()
- self._setInterpolation(interpolation)
- return interpolation[0].upper() + interpolation[1:]
-
- def _setInterpolation(self, interpolation):
- self.subject.getCutPlanes()[0].setInterpolation(interpolation)
-
-
-class PlaneDisplayBelowMinItem(SubjectItem):
- """Toggle whether to display or not values <= colormap min of the cut plane
-
- Item is checkable
- """
-
- def _init(self):
- display = self.subject.getCutPlanes()[0].getDisplayValuesBelowMin()
- self.setCheckable(True)
- self.setCheckState(
- qt.Qt.Checked if display else qt.Qt.Unchecked)
- self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
-
- def getSignals(self):
- return [self.subject.getCutPlanes()[0].sigTransparencyChanged]
-
- def leftClicked(self):
- checked = self.checkState() == qt.Qt.Checked
- self._setDisplayValuesBelowMin(checked)
-
- def _pullData(self):
- display = self.subject.getCutPlanes()[0].getDisplayValuesBelowMin()
- self._setDisplayValuesBelowMin(display)
- return "Displayed" if display else "Hidden"
-
- def _setDisplayValuesBelowMin(self, display):
- self.subject.getCutPlanes()[0].setDisplayValuesBelowMin(display)
-
-
-class PlaneColormapItem(ColormapBase):
- """
- colormap name item.
- Editor is a QComboBox
- """
- editable = True
-
- listValues = ['gray', 'reversed gray',
- 'temperature', 'red',
- 'green', 'blue',
- 'viridis', 'magma', 'inferno', 'plasma']
-
- def getEditor(self, parent, option, index):
- editor = qt.QComboBox(parent)
- editor.addItems(self.listValues)
-
- # Wrapping call in lambda is a workaround for PySide with Python 3
- editor.currentIndexChanged[int].connect(
- lambda index: self.__editorChanged(index))
-
- return editor
-
- def __editorChanged(self, index):
- colormapName = self.listValues[index]
- colormap = self.subject.getCutPlanes()[0].getColormap()
- colormap.setName(colormapName)
-
- def setEditorData(self, editor):
- colormapName = self.subject.getCutPlanes()[0].getColormap().getName()
- try:
- index = self.listValues.index(colormapName)
- except ValueError:
- _logger.error('Unsupported colormap: %s', colormapName)
- else:
- editor.setCurrentIndex(index)
- return True
-
- def _setModelData(self, editor):
- self.__editorChanged(editor.currentIndex())
- return True
-
- def _pullData(self):
- return self.subject.getCutPlanes()[0].getColormap().getName()
-
-
-class PlaneAutoScaleItem(ColormapBase):
- """
- colormap autoscale item.
- Item is checkable.
- """
-
- def _init(self):
- colorMap = self.subject.getCutPlanes()[0].getColormap()
- self.setCheckable(True)
- self.setCheckState((colorMap.isAutoscale() and qt.Qt.Checked)
- or qt.Qt.Unchecked)
- self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
-
- def leftClicked(self):
- checked = (self.checkState() == qt.Qt.Checked)
- self._setAutoScale(checked)
-
- def _setAutoScale(self, auto):
- view3d = self.subject
- colormap = view3d.getCutPlanes()[0].getColormap()
-
- if auto != colormap.isAutoscale():
- if auto:
- vMin = vMax = None
- else:
- dataRange = view3d.getDataRange()
- if dataRange is None:
- vMin = vMax = None
- else:
- vMin, vMax = dataRange[0], dataRange[-1]
- colormap.setVRange(vMin, vMax)
-
- def _pullData(self):
- auto = self.subject.getCutPlanes()[0].getColormap().isAutoscale()
- self._setAutoScale(auto)
- if auto:
- data = 'Auto'
- else:
- data = 'User'
- return data
-
-
-class NormalizationNode(ColormapBase):
- """
- colormap normalization item.
- Item is a QComboBox.
- """
- editable = True
- listValues = list(Colormap.NORMALIZATIONS)
-
- def getEditor(self, parent, option, index):
- editor = qt.QComboBox(parent)
- editor.addItems(self.listValues)
-
- # Wrapping call in lambda is a workaround for PySide with Python 3
- editor.currentIndexChanged[int].connect(
- lambda index: self.__editorChanged(index))
-
- return editor
-
- def __editorChanged(self, index):
- colorMap = self.subject.getCutPlanes()[0].getColormap()
- normalization = self.listValues[index]
- self.subject.getCutPlanes()[0].setColormap(name=colorMap.getName(),
- norm=normalization,
- vmin=colorMap.getVMin(),
- vmax=colorMap.getVMax())
-
- def setEditorData(self, editor):
- normalization = self.subject.getCutPlanes()[0].getColormap().getNormalization()
- index = self.listValues.index(normalization)
- editor.setCurrentIndex(index)
- return True
-
- def _setModelData(self, editor):
- self.__editorChanged(editor.currentIndex())
- return True
-
- def _pullData(self):
- return self.subject.getCutPlanes()[0].getColormap().getNormalization()
-
-
-class PlaneGroup(SubjectItem):
- """
- Root Item for the plane items.
- """
- def _init(self):
- valueItem = qt.QStandardItem()
- valueItem.setEditable(False)
- nameItem = PlaneVisibleItem(self.subject, 'Visible')
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Colormap')
- nameItem.setEditable(False)
- valueItem = PlaneColormapItem(self.subject)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Normalization')
- nameItem.setEditable(False)
- valueItem = NormalizationNode(self.subject)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Orientation')
- nameItem.setEditable(False)
- valueItem = PlaneOrientationItem(self.subject)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Interpolation')
- nameItem.setEditable(False)
- valueItem = PlaneInterpolationItem(self.subject)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Autoscale')
- nameItem.setEditable(False)
- valueItem = PlaneAutoScaleItem(self.subject)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Min')
- nameItem.setEditable(False)
- valueItem = PlaneMinRangeItem(self.subject)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Max')
- nameItem.setEditable(False)
- valueItem = PlaneMaxRangeItem(self.subject)
- self.appendRow([nameItem, valueItem])
-
- nameItem = qt.QStandardItem('Values<=Min')
- nameItem.setEditable(False)
- valueItem = PlaneDisplayBelowMinItem(self.subject)
- self.appendRow([nameItem, valueItem])
-
-
-class PlaneVisibleItem(SubjectItem):
- """
- Plane visibility item.
- Item is checkable.
- """
- def _init(self):
- plane = self.subject.getCutPlanes()[0]
- self.setCheckable(True)
- self.setCheckState((plane.isVisible() and qt.Qt.Checked)
- or qt.Qt.Unchecked)
-
- def leftClicked(self):
- plane = self.subject.getCutPlanes()[0]
- checked = (self.checkState() == qt.Qt.Checked)
- if checked != plane.isVisible():
- plane.setVisible(checked)
- if plane.isVisible():
- plane.moveToCenter()
-
-
-# Tree ########################################################################
-
-class ItemDelegate(qt.QStyledItemDelegate):
- """
- Delegate for the QTreeView filled with SubjectItems.
- """
-
- sigDelegateEvent = qt.Signal(str)
-
- def __init__(self, parent=None):
- super(ItemDelegate, self).__init__(parent)
-
- def createEditor(self, parent, option, index):
- item = index.model().itemFromIndex(index)
- if item:
- if isinstance(item, SubjectItem):
- editor = item.getEditor(parent, option, index)
- if editor:
- editor.setAutoFillBackground(True)
- if hasattr(editor, 'sigViewTask'):
- editor.sigViewTask.connect(self.__viewTask)
- return editor
-
- editor = super(ItemDelegate, self).createEditor(parent,
- option,
- index)
- return editor
-
- def updateEditorGeometry(self, editor, option, index):
- editor.setGeometry(option.rect)
-
- def setEditorData(self, editor, index):
- item = index.model().itemFromIndex(index)
- if item:
- if isinstance(item, SubjectItem) and item.setEditorData(editor):
- return
- super(ItemDelegate, self).setEditorData(editor, index)
-
- def setModelData(self, editor, model, index):
- item = index.model().itemFromIndex(index)
- if isinstance(item, SubjectItem) and item._setModelData(editor):
- return
- super(ItemDelegate, self).setModelData(editor, model, index)
-
- def __viewTask(self, task):
- self.sigDelegateEvent.emit(task)
-
-
-class TreeView(qt.QTreeView):
- """
- TreeView displaying the SubjectItems for the ScalarFieldView.
- """
-
- def __init__(self, parent=None):
- super(TreeView, self).__init__(parent)
- self.__openedIndex = None
- self._isoLevelSliderNormalization = 'linear'
-
- self.setIconSize(qt.QSize(16, 16))
-
- header = self.header()
- if hasattr(header, 'setSectionResizeMode'): # Qt5
- header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
- else: # Qt4
- header.setResizeMode(qt.QHeaderView.ResizeToContents)
-
- delegate = ItemDelegate()
- self.setItemDelegate(delegate)
- delegate.sigDelegateEvent.connect(self.__delegateEvent)
- self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
- self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
-
- self.clicked.connect(self.__clicked)
-
- def setSfView(self, sfView):
- """
- Sets the ScalarFieldView this view is controlling.
-
- :param sfView: A `ScalarFieldView`
- """
- model = qt.QStandardItemModel()
- model.setColumnCount(ModelColumns.ColumnMax)
- model.setHorizontalHeaderLabels(['Name', 'Value'])
-
- item = qt.QStandardItem()
- item.setEditable(False)
- model.appendRow([ViewSettingsItem(sfView, 'Style'), item])
-
- item = qt.QStandardItem()
- item.setEditable(False)
- model.appendRow([DataSetItem(sfView, 'Data'), item])
-
- item = IsoSurfaceCount(sfView)
- item.setEditable(False)
- model.appendRow([IsoSurfaceGroup(sfView,
- self._isoLevelSliderNormalization,
- 'Isosurfaces'),
- item])
-
- item = qt.QStandardItem()
- item.setEditable(False)
- model.appendRow([PlaneGroup(sfView, 'Cutting Plane'), item])
-
- self.setModel(model)
-
- def setModel(self, model):
- """
- Reimplementation of the QTreeView.setModel method. It connects the
- rowsRemoved signal and opens the persistent editors.
-
- :param qt.QStandardItemModel model: the model
- """
-
- prevModel = self.model()
- if prevModel:
- self.__openPersistentEditors(qt.QModelIndex(), False)
- try:
- prevModel.rowsRemoved.disconnect(self.rowsRemoved)
- except TypeError:
- pass
-
- super(TreeView, self).setModel(model)
- model.rowsRemoved.connect(self.rowsRemoved)
- self.__openPersistentEditors(qt.QModelIndex())
-
- def __openPersistentEditors(self, parent=None, openEditor=True):
- """
- Opens or closes the items persistent editors.
-
- :param qt.QModelIndex parent: starting index, or None if the whole tree
- is to be considered.
- :param bool openEditor: True to open the editors, False to close them.
- """
- model = self.model()
-
- if not model:
- return
-
- if not parent or not parent.isValid():
- parent = self.model().invisibleRootItem().index()
-
- if openEditor:
- meth = self.openPersistentEditor
- else:
- meth = self.closePersistentEditor
-
- curParent = parent
- children = [model.index(row, 0, curParent)
- for row in range(model.rowCount(curParent))]
-
- columnCount = model.columnCount()
-
- while len(children) > 0:
- curParent = children.pop(-1)
-
- children.extend([model.index(row, 0, curParent)
- for row in range(model.rowCount(curParent))])
-
- for colIdx in range(columnCount):
- sibling = model.sibling(curParent.row(),
- colIdx,
- curParent)
- item = model.itemFromIndex(sibling)
- if isinstance(item, SubjectItem) and item.persistent:
- meth(sibling)
-
- def rowsAboutToBeRemoved(self, parent, start, end):
- """
- Reimplementation of the QTreeView.rowsAboutToBeRemoved. Closes all
- persistent editors under parent.
-
- :param qt.QModelIndex parent: Parent index
- :param int start: Start index from parent index (inclusive)
- :param int end: End index from parent index (inclusive)
- """
- self.__openPersistentEditors(parent, False)
- super(TreeView, self).rowsAboutToBeRemoved(parent, start, end)
-
- def rowsRemoved(self, parent, start, end):
- """
- Called when QTreeView.rowsRemoved is emitted. Opens all persistent
- editors under parent.
-
- :param qt.QModelIndex parent: Parent index
- :param int start: Start index from parent index (inclusive)
- :param int end: End index from parent index (inclusive)
- """
- super(TreeView, self).rowsRemoved(parent, start, end)
- self.__openPersistentEditors(parent, True)
-
- def rowsInserted(self, parent, start, end):
- """
- Reimplementation of the QTreeView.rowsInserted. Opens all persistent
- editors under parent.
-
- :param qt.QModelIndex parent: Parent index
- :param int start: Start index from parent index
- :param int end: End index from parent index
- """
- self.__openPersistentEditors(parent, False)
- super(TreeView, self).rowsInserted(parent, start, end)
- self.__openPersistentEditors(parent)
-
- def keyReleaseEvent(self, event):
- """
- Reimplementation of the QTreeView.keyReleaseEvent.
- At the moment only Key_Delete is handled. It calls the selected item's
- queryRemove method, and deleted the item if needed.
-
- :param qt.QKeyEvent event: A key event
- """
-
- # TODO : better filtering
- key = event.key()
- modifiers = event.modifiers()
-
- if key == qt.Qt.Key_Delete and modifiers == qt.Qt.NoModifier:
- self.__removeIsosurfaces()
-
- super(TreeView, self).keyReleaseEvent(event)
-
- def __removeIsosurfaces(self):
- model = self.model()
- selected = self.selectedIndexes()
- items = []
- # WARNING : the selection mode is set to single, so we re not
- # supposed to have more than one item here.
- # Multiple selection deletion has not been tested.
- # Watch out for index invalidation
- for index in selected:
- leftIndex = model.sibling(index.row(), 0, index)
- leftItem = model.itemFromIndex(leftIndex)
- if isinstance(leftItem, SubjectItem) and leftItem not in items:
- items.append(leftItem)
-
- isos = [item for item in items if isinstance(item, IsoSurfaceRootItem)]
- if isos:
- for iso in isos:
- if iso.queryRemove(self):
- parentItem = iso.parent()
- parentItem.removeRow(iso.row())
- else:
- qt.QMessageBox.information(
- self,
- 'Remove isosurface',
- 'Select an iso-surface to remove it')
-
- def __clicked(self, index):
- """
- Called when the QTreeView.clicked signal is emitted. Calls the item's
- leftClick method.
-
- :param qt.QIndex index: An index
- """
- item = self.model().itemFromIndex(index)
- if isinstance(item, SubjectItem):
- item.leftClicked()
-
- def __delegateEvent(self, task):
- if task == 'remove_iso':
- self.__removeIsosurfaces()
-
- def setIsoLevelSliderNormalization(self, normalization):
- """Set the normalization for iso level slider
-
- This MUST be called *before* :meth:`setSfView` to have an effect.
-
- :param str normalization: Either 'linear' or 'arcsinh'
- """
- assert normalization in ('linear', 'arcsinh')
- self._isoLevelSliderNormalization = normalization
diff --git a/silx/gui/plot3d/_model/items.py b/silx/gui/plot3d/_model/items.py
deleted file mode 100644
index be51663..0000000
--- a/silx/gui/plot3d/_model/items.py
+++ /dev/null
@@ -1,1760 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-This module provides base classes to implement models for 3D scene content
-"""
-
-from __future__ import absolute_import, division
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-
-from collections import OrderedDict
-import functools
-import logging
-import weakref
-
-import numpy
-import six
-
-from ...utils.image import convertArrayToQImage
-from ...colors import preferredColormaps
-from ... import qt, icons
-from .. import items
-from ..items.volume import Isosurface, CutPlane, ComplexIsosurface
-from ..Plot3DWidget import Plot3DWidget
-
-
-from .core import AngleDegreeRow, BaseRow, ColorProxyRow, ProxyRow, StaticRow
-
-
-_logger = logging.getLogger(__name__)
-
-
-class ItemProxyRow(ProxyRow):
- """Provides a node to proxy a data accessible through functions.
-
- It listens on sigItemChanged to trigger the update.
-
- Warning: Only weak reference are kept on fget and fset.
-
- :param Item3D item: The item to
- :param str name: The name of this node
- :param callable fget: A callable returning the data
- :param callable fset:
- An optional callable setting the data with data as a single argument.
- :param events:
- An optional event kind or list of event kinds to react upon.
- :param callable toModelData:
- An optional callable to convert from fget
- callable to data returned by the model.
- :param callable fromModelData:
- An optional callable converting data provided to the model to
- data for fset.
- :param editorHint: Data to provide as UserRole for editor selection/setup
- """
-
- def __init__(self,
- item,
- name='',
- fget=None,
- fset=None,
- events=None,
- toModelData=None,
- fromModelData=None,
- editorHint=None):
- super(ItemProxyRow, self).__init__(
- name=name,
- fget=fget,
- fset=fset,
- notify=None,
- toModelData=toModelData,
- fromModelData=fromModelData,
- editorHint=editorHint)
-
- if isinstance(events, (items.ItemChangedType,
- items.Item3DChangedType)):
- events = (events,)
- self.__events = events
- item.sigItemChanged.connect(self.__itemChanged)
-
- def __itemChanged(self, event):
- """Handle item changed
-
- :param Union[ItemChangedType,Item3DChangedType] event:
- """
- if self.__events is None or event in self.__events:
- self._notified()
-
-
-class ItemColorProxyRow(ColorProxyRow, ItemProxyRow):
- """Combines :class:`ColorProxyRow` and :class:`ItemProxyRow`"""
-
- def __init__(self, *args, **kwargs):
- ItemProxyRow.__init__(self, *args, **kwargs)
-
-
-class ItemAngleDegreeRow(AngleDegreeRow, ItemProxyRow):
- """Combines :class:`AngleDegreeRow` and :class:`ItemProxyRow`"""
-
- def __init__(self, *args, **kwargs):
- ItemProxyRow.__init__(self, *args, **kwargs)
-
-
-class _DirectionalLightProxy(qt.QObject):
- """Proxy to handle directional light with angles rather than vector.
- """
-
- sigAzimuthAngleChanged = qt.Signal()
- """Signal sent when the azimuth angle has changed."""
-
- sigAltitudeAngleChanged = qt.Signal()
- """Signal sent when altitude angle has changed."""
-
- def __init__(self, light):
- super(_DirectionalLightProxy, self).__init__()
- self._light = light
- light.addListener(self._directionUpdated)
- self._azimuth = 0
- self._altitude = 0
-
- def getAzimuthAngle(self):
- """Returns the signed angle in the horizontal plane.
-
- Unit: degrees.
- The 0 angle corresponds to the axis perpendicular to the screen.
-
- :rtype: int
- """
- return self._azimuth
-
- def getAltitudeAngle(self):
- """Returns the signed vertical angle from the horizontal plane.
-
- Unit: degrees.
- Range: [-90, +90]
-
- :rtype: int
- """
- return self._altitude
-
- def setAzimuthAngle(self, angle):
- """Set the horizontal angle.
-
- :param int angle: Angle from -z axis in zx plane in degrees.
- """
- angle = int(round(angle))
- if angle != self._azimuth:
- self._azimuth = angle
- self._updateLight()
- self.sigAzimuthAngleChanged.emit()
-
- def setAltitudeAngle(self, angle):
- """Set the horizontal angle.
-
- :param int angle: Angle from -z axis in zy plane in degrees.
- """
- angle = int(round(angle))
- if angle != self._altitude:
- self._altitude = angle
- self._updateLight()
- self.sigAltitudeAngleChanged.emit()
-
- def _directionUpdated(self, *args, **kwargs):
- """Handle light direction update in the scene"""
- # Invert direction to manipulate the 'source' pointing to
- # the center of the viewport
- x, y, z = - self._light.direction
-
- # Horizontal plane is plane xz
- azimuth = int(round(numpy.degrees(numpy.arctan2(x, z))))
- altitude = int(round(numpy.degrees(numpy.pi/2. - numpy.arccos(y))))
-
- if azimuth != self.getAzimuthAngle():
- self.setAzimuthAngle(azimuth)
-
- if altitude != self.getAltitudeAngle():
- self.setAltitudeAngle(altitude)
-
- def _updateLight(self):
- """Update light direction in the scene"""
- azimuth = numpy.radians(self._azimuth)
- delta = numpy.pi/2. - numpy.radians(self._altitude)
- if delta == 0.: # Avoids zenith position
- delta = 0.0001
- z = - numpy.sin(delta) * numpy.cos(azimuth)
- x = - numpy.sin(delta) * numpy.sin(azimuth)
- y = - numpy.cos(delta)
- self._light.direction = x, y, z
-
-
-class Settings(StaticRow):
- """Subtree for :class:`SceneWidget` style parameters.
-
- :param SceneWidget sceneWidget: The widget to control
- """
-
- def __init__(self, sceneWidget):
- background = ColorProxyRow(
- name='Background',
- fget=sceneWidget.getBackgroundColor,
- fset=sceneWidget.setBackgroundColor,
- notify=sceneWidget.sigStyleChanged)
-
- foreground = ColorProxyRow(
- name='Foreground',
- fget=sceneWidget.getForegroundColor,
- fset=sceneWidget.setForegroundColor,
- notify=sceneWidget.sigStyleChanged)
-
- text = ColorProxyRow(
- name='Text',
- fget=sceneWidget.getTextColor,
- fset=sceneWidget.setTextColor,
- notify=sceneWidget.sigStyleChanged)
-
- highlight = ColorProxyRow(
- name='Highlight',
- fget=sceneWidget.getHighlightColor,
- fset=sceneWidget.setHighlightColor,
- notify=sceneWidget.sigStyleChanged)
-
- axesIndicator = ProxyRow(
- name='Axes Indicator',
- fget=sceneWidget.isOrientationIndicatorVisible,
- fset=sceneWidget.setOrientationIndicatorVisible,
- notify=sceneWidget.sigStyleChanged)
-
- # Light direction
-
- self._lightProxy = _DirectionalLightProxy(sceneWidget.viewport.light)
-
- azimuthNode = ProxyRow(
- name='Azimuth',
- fget=self._lightProxy.getAzimuthAngle,
- fset=self._lightProxy.setAzimuthAngle,
- notify=self._lightProxy.sigAzimuthAngleChanged,
- editorHint=(-90, 90))
-
- altitudeNode = ProxyRow(
- name='Altitude',
- fget=self._lightProxy.getAltitudeAngle,
- fset=self._lightProxy.setAltitudeAngle,
- notify=self._lightProxy.sigAltitudeAngleChanged,
- editorHint=(-90, 90))
-
- lightDirection = StaticRow(('Light Direction', None),
- children=(azimuthNode, altitudeNode))
-
- # Fog
- fog = ProxyRow(
- name='Fog',
- fget=sceneWidget.getFogMode,
- fset=sceneWidget.setFogMode,
- notify=sceneWidget.sigStyleChanged,
- toModelData=lambda mode: mode is Plot3DWidget.FogMode.LINEAR,
- fromModelData=lambda mode: Plot3DWidget.FogMode.LINEAR if mode else Plot3DWidget.FogMode.NONE)
-
- # Settings row
- children = (background, foreground, text, highlight,
- axesIndicator, lightDirection, fog)
- super(Settings, self).__init__(('Settings', None), children=children)
-
-
-class Item3DRow(BaseRow):
- """Represents an :class:`Item3D` with checkable visibility
-
- :param Item3D item: The scene item to represent.
- :param str name: The optional name of the item
- """
-
- _EVENTS = items.ItemChangedType.VISIBLE, items.Item3DChangedType.LABEL
- """Events for which to update the first column in the tree"""
-
- def __init__(self, item, name=None):
- self.__name = None if name is None else six.text_type(name)
- super(Item3DRow, self).__init__()
-
- self.setFlags(
- self.flags(0) | qt.Qt.ItemIsUserCheckable | qt.Qt.ItemIsSelectable,
- 0)
- self.setFlags(self.flags(1) | qt.Qt.ItemIsSelectable, 1)
-
- self._item = weakref.ref(item)
- item.sigItemChanged.connect(self._itemChanged)
-
- def _itemChanged(self, event):
- """Handle model update upon change"""
- if event in self._EVENTS:
- model = self.model()
- if model is not None:
- index = self.index(column=0)
- model.dataChanged.emit(index, index)
-
- def item(self):
- """Returns the :class:`Item3D` item or None"""
- return self._item()
-
- def data(self, column, role):
- if column == 0:
- if role == qt.Qt.CheckStateRole:
- item = self.item()
- if item is not None and item.isVisible():
- return qt.Qt.Checked
- else:
- return qt.Qt.Unchecked
-
- elif role == qt.Qt.DecorationRole:
- return icons.getQIcon('item-3dim')
-
- elif role == qt.Qt.DisplayRole:
- if self.__name is None:
- item = self.item()
- return '' if item is None else item.getLabel()
- else:
- return self.__name
-
- return super(Item3DRow, self).data(column, role)
-
- def setData(self, column, value, role):
- if column == 0 and role == qt.Qt.CheckStateRole:
- item = self.item()
- if item is not None:
- item.setVisible(value == qt.Qt.Checked)
- return True
- else:
- return False
- return super(Item3DRow, self).setData(column, value, role)
-
- def columnCount(self):
- return 2
-
-
-class DataItem3DBoundingBoxRow(ItemProxyRow):
- """Represents :class:`DataItem3D` bounding box visibility
-
- :param DataItem3D item: The item for which to display/control bounding box
- """
-
- def __init__(self, item):
- super(DataItem3DBoundingBoxRow, self).__init__(
- item=item,
- name='Bounding box',
- fget=item.isBoundingBoxVisible,
- fset=item.setBoundingBoxVisible,
- events=items.Item3DChangedType.BOUNDING_BOX_VISIBLE)
-
-
-class MatrixProxyRow(ItemProxyRow):
- """Proxy for a row of a DataItem3D 3x3 matrix transform
-
- :param DataItem3D item:
- :param int index: Matrix row index
- """
-
- def __init__(self, item, index):
- self._item = weakref.ref(item)
- self._index = index
-
- super(MatrixProxyRow, self).__init__(
- item=item,
- name='',
- fget=self._getMatrixRow,
- fset=self._setMatrixRow,
- events=items.Item3DChangedType.TRANSFORM)
-
- def _getMatrixRow(self):
- """Returns the matrix row.
-
- :rtype: QVector3D
- """
- item = self._item()
- if item is not None:
- matrix = item.getMatrix()
- return qt.QVector3D(*matrix[self._index, :])
- else:
- return None
-
- def _setMatrixRow(self, row):
- """Set the row of the matrix
-
- :param QVector3D row: Row values to set
- """
- item = self._item()
- if item is not None:
- matrix = item.getMatrix()
- matrix[self._index, :] = row.x(), row.y(), row.z()
- item.setMatrix(matrix)
-
- def data(self, column, role):
- data = super(MatrixProxyRow, self).data(column, role)
-
- if column == 1 and role == qt.Qt.DisplayRole:
- # Convert QVector3D to text
- data = "%g; %g; %g" % (data.x(), data.y(), data.z())
-
- return data
-
-
-class DataItem3DTransformRow(StaticRow):
- """Represents :class:`DataItem3D` transform parameters
-
- :param DataItem3D item: The item for which to display/control transform
- """
-
- _ROTATION_CENTER_OPTIONS = 'Origin', 'Lower', 'Center', 'Upper'
-
- def __init__(self, item):
- super(DataItem3DTransformRow, self).__init__(('Transform', None))
- self._item = weakref.ref(item)
-
- translation = ItemProxyRow(
- item=item,
- name='Translation',
- fget=item.getTranslation,
- fset=self._setTranslation,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=lambda data: qt.QVector3D(*data))
- self.addRow(translation)
-
- # Here to keep a reference
- self._xSetCenter = functools.partial(self._setCenter, index=0)
- self._ySetCenter = functools.partial(self._setCenter, index=1)
- self._zSetCenter = functools.partial(self._setCenter, index=2)
-
- rotateCenter = StaticRow(
- ('Center', None),
- children=(
- ItemProxyRow(item=item,
- name='X axis',
- fget=item.getRotationCenter,
- fset=self._xSetCenter,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=functools.partial(
- self._centerToModelData, index=0),
- editorHint=self._ROTATION_CENTER_OPTIONS),
- ItemProxyRow(item=item,
- name='Y axis',
- fget=item.getRotationCenter,
- fset=self._ySetCenter,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=functools.partial(
- self._centerToModelData, index=1),
- editorHint=self._ROTATION_CENTER_OPTIONS),
- ItemProxyRow(item=item,
- name='Z axis',
- fget=item.getRotationCenter,
- fset=self._zSetCenter,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=functools.partial(
- self._centerToModelData, index=2),
- editorHint=self._ROTATION_CENTER_OPTIONS),
- ))
-
- rotate = StaticRow(
- ('Rotation', None),
- children=(
- ItemAngleDegreeRow(
- item=item,
- name='Angle',
- fget=item.getRotation,
- fset=self._setAngle,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=lambda data: data[0]),
- ItemProxyRow(
- item=item,
- name='Axis',
- fget=item.getRotation,
- fset=self._setAxis,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=lambda data: qt.QVector3D(*data[1])),
- rotateCenter
- ))
- self.addRow(rotate)
-
- scale = ItemProxyRow(
- item=item,
- name='Scale',
- fget=item.getScale,
- fset=self._setScale,
- events=items.Item3DChangedType.TRANSFORM,
- toModelData=lambda data: qt.QVector3D(*data))
- self.addRow(scale)
-
- matrix = StaticRow(
- ('Matrix', None),
- children=(MatrixProxyRow(item, 0),
- MatrixProxyRow(item, 1),
- MatrixProxyRow(item, 2)))
- self.addRow(matrix)
-
- def item(self):
- """Returns the :class:`Item3D` item or None"""
- return self._item()
-
- @staticmethod
- def _centerToModelData(center, index):
- """Convert rotation center information from scene to model.
-
- :param center: The center info from the scene
- :param int index: dimension to convert
- """
- value = center[index]
- if isinstance(value, six.string_types):
- return value.title()
- elif value == 0.:
- return 'Origin'
- else:
- return six.text_type(value)
-
- def _setCenter(self, value, index):
- """Set one dimension of the rotation center.
-
- :param value: Value received through the model.
- :param int index: dimension to set
- """
- item = self.item()
- if item is not None:
- if value == 'Origin':
- value = 0.
- elif value not in self._ROTATION_CENTER_OPTIONS:
- value = float(value)
- else:
- value = value.lower()
-
- center = list(item.getRotationCenter())
- center[index] = value
- item.setRotationCenter(*center)
-
- def _setAngle(self, angle):
- """Set rotation angle.
-
- :param float angle:
- """
- item = self.item()
- if item is not None:
- _, axis = item.getRotation()
- item.setRotation(angle, axis)
-
- def _setAxis(self, axis):
- """Set rotation axis.
-
- :param QVector3D axis:
- """
- item = self.item()
- if item is not None:
- angle, _ = item.getRotation()
- item.setRotation(angle, (axis.x(), axis.y(), axis.z()))
-
- def _setTranslation(self, translation):
- """Set translation transform.
-
- :param QVector3D translation:
- """
- item = self.item()
- if item is not None:
- item.setTranslation(translation.x(), translation.y(), translation.z())
-
- def _setScale(self, scale):
- """Set scale transform.
-
- :param QVector3D scale:
- """
- item = self.item()
- if item is not None:
- sx, sy, sz = scale.x(), scale.y(), scale.z()
- if sx == 0. or sy == 0. or sz == 0.:
- _logger.warning('Cannot set scale to 0: ignored')
- else:
- item.setScale(scale.x(), scale.y(), scale.z())
-
-
-class GroupItemRow(Item3DRow):
- """Represents a :class:`GroupItem` with transforms and children
-
- :param GroupItem item: The scene group to represent.
- :param str name: The optional name of the group
- """
-
- _CHILDREN_ROW_OFFSET = 2
- """Number of rows for group parameters. Children are added after"""
-
- def __init__(self, item, name=None):
- super(GroupItemRow, self).__init__(item, name)
- self.addRow(DataItem3DBoundingBoxRow(item))
- self.addRow(DataItem3DTransformRow(item))
-
- item.sigItemAdded.connect(self._itemAdded)
- item.sigItemRemoved.connect(self._itemRemoved)
-
- for child in item.getItems():
- self.addRow(nodeFromItem(child))
-
- def _itemAdded(self, item):
- """Handle item addition to the group and add it to the model.
-
- :param Item3D item: added item
- """
- group = self.item()
- if group is None:
- return
-
- row = group.getItems().index(item)
- self.addRow(nodeFromItem(item), row + self._CHILDREN_ROW_OFFSET)
-
- def _itemRemoved(self, item):
- """Handle item removal from the group and remove it from the model.
-
- :param Item3D item: removed item
- """
- group = self.item()
- if group is None:
- return
-
- # Find item
- for row in self.children():
- if isinstance(row, Item3DRow) and row.item() is item:
- self.removeRow(row)
- break # Got it
- else:
- raise RuntimeError("Model does not correspond to scene content")
-
-
-class InterpolationRow(ItemProxyRow):
- """Represents :class:`InterpolationMixIn` property.
-
- :param Item3D item: Scene item with interpolation property
- """
-
- def __init__(self, item):
- modes = [mode.title() for mode in item.INTERPOLATION_MODES]
- super(InterpolationRow, self).__init__(
- item=item,
- name='Interpolation',
- fget=item.getInterpolation,
- fset=item.setInterpolation,
- events=items.Item3DChangedType.INTERPOLATION,
- toModelData=lambda mode: mode.title(),
- fromModelData=lambda mode: mode.lower(),
- editorHint=modes)
-
-
-class _ColormapBaseProxyRow(ProxyRow):
- """Base class for colormap model row
-
- This class handle synchronization and signals from the item and the colormap
- """
-
- _sigColormapChanged = qt.Signal()
- """Signal used internally to notify colormap (or data) update"""
-
- def __init__(self, item, *args, **kwargs):
- self._item = weakref.ref(item)
- self._colormap = item.getColormap()
-
- ProxyRow.__init__(self, *args, **kwargs)
-
- self._colormap.sigChanged.connect(self._colormapChanged)
- item.sigItemChanged.connect(self._itemChanged)
- self._sigColormapChanged.connect(self._modelUpdated)
-
- def item(self):
- """Returns the :class:`ColormapMixIn` item or None"""
- return self._item()
-
- def _getColormapRange(self):
- """Returns the range of the colormap for the current data.
-
- :return: Colormap range (min, max)
- """
- item = self.item()
- if item is not None and self._colormap is not None:
- return self._colormap.getColormapRange(item)
- else:
- return 1, 100 # Fallback
-
- def _modelUpdated(self, *args, **kwargs):
- """Emit dataChanged in the model"""
- topLeft = self.index(column=0)
- bottomRight = self.index(column=1)
- model = self.model()
- if model is not None:
- model.dataChanged.emit(topLeft, bottomRight)
-
- def _colormapChanged(self):
- self._sigColormapChanged.emit()
-
- def _itemChanged(self, event):
- """Handle change of colormap or data in the item.
-
- :param ItemChangedType event:
- """
- if event == items.ItemChangedType.COLORMAP:
- self._sigColormapChanged.emit()
- if self._colormap is not None:
- self._colormap.sigChanged.disconnect(self._colormapChanged)
-
- item = self.item()
- if item is not None:
- self._colormap = item.getColormap()
- self._colormap.sigChanged.connect(self._colormapChanged)
- else:
- self._colormap = None
-
- elif event == items.ItemChangedType.DATA:
- self._sigColormapChanged.emit()
-
-
-class _ColormapBoundRow(_ColormapBaseProxyRow):
- """ProxyRow for colormap min or max
-
- :param ColormapMixIn item: The item to handle
- :param str name: Name of the raw
- :param int index: 0 for Min and 1 of Max
- """
-
- def __init__(self, item, name, index):
- self._index = index
- _ColormapBaseProxyRow.__init__(
- self,
- item,
- name=name,
- fget=self._getBound,
- fset=self._setBound)
-
- self.setToolTip('Colormap %s bound:\n'
- 'Check to set bound manually, '
- 'uncheck for autoscale' % name.lower())
-
- def _getRawBound(self):
- """Proxy to get raw colormap bound
-
- :rtype: float or None
- """
- if self._colormap is None:
- return None
- elif self._index == 0:
- return self._colormap.getVMin()
- else: # self._index == 1
- return self._colormap.getVMax()
-
- def _getBound(self):
- """Proxy to get colormap effective bound value
-
- :rtype: float
- """
- if self._colormap is not None:
- bound = self._getRawBound()
-
- if bound is None:
- bound = self._getColormapRange()[self._index]
- return bound
- else:
- return 1. # Fallback
-
- def _setBound(self, value):
- """Proxy to set colormap bound.
-
- :param float value:
- """
- if self._colormap is not None:
- if self._index == 0:
- min_ = value
- max_ = self._colormap.getVMax()
- else: # self._index == 1
- min_ = self._colormap.getVMin()
- max_ = value
-
- if max_ is not None and min_ is not None and min_ > max_:
- min_, max_ = max_, min_
- self._colormap.setVRange(min_, max_)
-
- def flags(self, column):
- if column == 0:
- return qt.Qt.ItemIsEnabled | qt.Qt.ItemIsUserCheckable
-
- elif column == 1:
- if self._getRawBound() is not None:
- flags = qt.Qt.ItemIsEditable | qt.Qt.ItemIsEnabled
- else:
- flags = qt.Qt.NoItemFlags # Disabled if autoscale
- return flags
-
- else: # Never event
- return super(_ColormapBoundRow, self).flags(column)
-
- def data(self, column, role):
- if column == 0 and role == qt.Qt.CheckStateRole:
- if self._getRawBound() is None:
- return qt.Qt.Unchecked
- else:
- return qt.Qt.Checked
-
- else:
- return super(_ColormapBoundRow, self).data(column, role)
-
- def setData(self, column, value, role):
- if column == 0 and role == qt.Qt.CheckStateRole:
- if self._colormap is not None:
- bound = self._getBound() if value == qt.Qt.Checked else None
- self._setBound(bound)
- return True
- else:
- return False
-
- return super(_ColormapBoundRow, self).setData(column, value, role)
-
-
-class _ColormapGammaRow(_ColormapBaseProxyRow):
- """ProxyRow for colormap gamma normalization parameter
-
- :param ColormapMixIn item: The item to handle
- :param str name: Name of the raw
- """
-
- def __init__(self, item):
- _ColormapBaseProxyRow.__init__(
- self,
- item,
- name="Gamma",
- fget=self._getGammaNormalizationParameter,
- fset=self._setGammaNormalizationParameter)
-
- self.setToolTip('Colormap gamma correction parameter:\n'
- 'Only meaningful for gamma normalization.')
-
- def _getGammaNormalizationParameter(self):
- """Proxy for :meth:`Colormap.getGammaNormalizationParameter`"""
- if self._colormap is not None:
- return self._colormap.getGammaNormalizationParameter()
- else:
- return 0.0
-
- def _setGammaNormalizationParameter(self, gamma):
- """Proxy for :meth:`Colormap.setGammaNormalizationParameter`"""
- if self._colormap is not None:
- return self._colormap.setGammaNormalizationParameter(gamma)
-
- def _getNormalization(self):
- """Proxy for :meth:`Colormap.getNormalization`"""
- if self._colormap is not None:
- return self._colormap.getNormalization()
- else:
- return ''
-
- def flags(self, column):
- if column in (0, 1):
- if self._getNormalization() == 'gamma':
- flags = qt.Qt.ItemIsEditable | qt.Qt.ItemIsEnabled
- else:
- flags = qt.Qt.NoItemFlags # Disabled if not gamma correction
- return flags
-
- else: # Never event
- return super(_ColormapGammaRow, self).flags(column)
-
-
-class ColormapRow(_ColormapBaseProxyRow):
- """Represents :class:`ColormapMixIn` property.
-
- :param Item3D item: Scene item with colormap property
- """
-
- def __init__(self, item):
- super(ColormapRow, self).__init__(
- item,
- name='Colormap',
- fget=self._get)
-
- self._colormapImage = None
-
- self._colormapsMapping = {}
- for cmap in preferredColormaps():
- self._colormapsMapping[cmap.title()] = cmap
-
- self.addRow(ProxyRow(
- name='Name',
- fget=self._getName,
- fset=self._setName,
- notify=self._sigColormapChanged,
- editorHint=list(self._colormapsMapping.keys())))
-
- norms = [norm.title() for norm in self._colormap.NORMALIZATIONS]
- self.addRow(ProxyRow(
- name='Normalization',
- fget=self._getNormalization,
- fset=self._setNormalization,
- notify=self._sigColormapChanged,
- editorHint=norms))
-
- self.addRow(_ColormapGammaRow(item))
-
- modes = [mode.title() for mode in self._colormap.AUTOSCALE_MODES]
- self.addRow(ProxyRow(
- name='Autoscale Mode',
- fget=self._getAutoscaleMode,
- fset=self._setAutoscaleMode,
- notify=self._sigColormapChanged,
- editorHint=modes))
-
- self.addRow(_ColormapBoundRow(item, name='Min.', index=0))
- self.addRow(_ColormapBoundRow(item, name='Max.', index=1))
-
- self._sigColormapChanged.connect(self._updateColormapImage)
-
- def getColormapImage(self):
- """Returns image representing the colormap or None
-
- :rtype: Union[QImage,None]
- """
- if self._colormapImage is None and self._colormap is not None:
- image = numpy.zeros((16, 130, 3), dtype=numpy.uint8)
- image[1:-1, 1:-1] = self._colormap.getNColors(image.shape[1] - 2)[:, :3]
- self._colormapImage = convertArrayToQImage(image)
- return self._colormapImage
-
- def _get(self):
- """Getter for ProxyRow subclass"""
- return None
-
- def _getName(self):
- """Proxy for :meth:`Colormap.getName`"""
- if self._colormap is not None and self._colormap.getName() is not None:
- return self._colormap.getName().title()
- else:
- return ''
-
- def _setName(self, name):
- """Proxy for :meth:`Colormap.setName`"""
- # Convert back from titled to name if possible
- if self._colormap is not None:
- name = self._colormapsMapping.get(name, name)
- self._colormap.setName(name)
-
- def _getNormalization(self):
- """Proxy for :meth:`Colormap.getNormalization`"""
- if self._colormap is not None:
- return self._colormap.getNormalization().title()
- else:
- return ''
-
- def _setNormalization(self, normalization):
- """Proxy for :meth:`Colormap.setNormalization`"""
- if self._colormap is not None:
- return self._colormap.setNormalization(normalization.lower())
-
- def _getAutoscaleMode(self):
- """Proxy for :meth:`Colormap.getAutoscaleMode`"""
- if self._colormap is not None:
- return self._colormap.getAutoscaleMode().title()
- else:
- return ''
-
- def _setAutoscaleMode(self, mode):
- """Proxy for :meth:`Colormap.setAutoscaleMode`"""
- if self._colormap is not None:
- return self._colormap.setAutoscaleMode(mode.lower())
-
- def _updateColormapImage(self, *args, **kwargs):
- """Notify colormap update to update the image in the tree"""
- if self._colormapImage is not None:
- self._colormapImage = None
- model = self.model()
- if model is not None:
- index = self.index(column=1)
- model.dataChanged.emit(index, index)
-
- def data(self, column, role):
- if column == 1 and role == qt.Qt.DecorationRole:
- return self.getColormapImage()
- else:
- return super(ColormapRow, self).data(column, role)
-
-
-class SymbolRow(ItemProxyRow):
- """Represents :class:`SymbolMixIn` symbol property.
-
- :param Item3D item: Scene item with symbol property
- """
-
- def __init__(self, item):
- names = [item.getSymbolName(s) for s in item.getSupportedSymbols()]
- super(SymbolRow, self).__init__(
- item=item,
- name='Marker',
- fget=item.getSymbolName,
- fset=item.setSymbol,
- events=items.ItemChangedType.SYMBOL,
- editorHint=names)
-
-
-class SymbolSizeRow(ItemProxyRow):
- """Represents :class:`SymbolMixIn` symbol size property.
-
- :param Item3D item: Scene item with symbol size property
- """
-
- def __init__(self, item):
- super(SymbolSizeRow, self).__init__(
- item=item,
- name='Marker size',
- fget=item.getSymbolSize,
- fset=item.setSymbolSize,
- events=items.ItemChangedType.SYMBOL_SIZE,
- editorHint=(1, 20)) # TODO link with OpenGL max point size
-
-
-class PlaneEquationRow(ItemProxyRow):
- """Represents :class:`PlaneMixIn` as plane equation.
-
- :param Item3D item: Scene item with plane equation property
- """
-
- def __init__(self, item):
- super(PlaneEquationRow, self).__init__(
- item=item,
- name='Equation',
- fget=item.getParameters,
- fset=item.setParameters,
- events=items.ItemChangedType.POSITION,
- toModelData=lambda data: qt.QVector4D(*data),
- fromModelData=lambda data: (data.x(), data.y(), data.z(), data.w()))
- self._item = weakref.ref(item)
-
- def data(self, column, role):
- if column == 1 and role == qt.Qt.DisplayRole:
- item = self._item()
- if item is not None:
- params = item.getParameters()
- return ('%gx %+gy %+gz %+g = 0' %
- (params[0], params[1], params[2], params[3]))
- return super(PlaneEquationRow, self).data(column, role)
-
-
-class PlaneRow(ItemProxyRow):
- """Represents :class:`PlaneMixIn` property.
-
- :param Item3D item: Scene item with plane equation property
- """
-
- _PLANES = OrderedDict((('Plane 0', (1., 0., 0.)),
- ('Plane 1', (0., 1., 0.)),
- ('Plane 2', (0., 0., 1.)),
- ('-', None)))
- """Mapping of plane names to normals"""
-
- _PLANE_ICONS = {'Plane 0': '3d-plane-normal-x',
- 'Plane 1': '3d-plane-normal-y',
- 'Plane 2': '3d-plane-normal-z',
- '-': '3d-plane'}
- """Mapping of plane names to normals"""
-
- def __init__(self, item):
- super(PlaneRow, self).__init__(
- item=item,
- name='Plane',
- fget=self.__getPlaneName,
- fset=self.__setPlaneName,
- events=items.ItemChangedType.POSITION,
- editorHint=tuple(self._PLANES.keys()))
- self._item = weakref.ref(item)
- self._lastName = None
-
- self.addRow(PlaneEquationRow(item))
-
- def _notified(self, *args, **kwargs):
- """Handle notification of modification
-
- Here only send if plane name actually changed
- """
- if self._lastName != self.__getPlaneName():
- super(PlaneRow, self)._notified()
-
- def __getPlaneName(self):
- """Returns name of plane // to axes or '-'
-
- :rtype: str
- """
- item = self._item()
- planeNormal = item.getNormal() if item is not None else None
-
- for name, normal in self._PLANES.items():
- if numpy.array_equal(planeNormal, normal):
- return name
- return '-'
-
- def __setPlaneName(self, data):
- """Set plane normal according to given plane name
-
- :param str data: Selected plane name
- """
- item = self._item()
- if item is not None:
- for name, normal in self._PLANES.items():
- if data == name and normal is not None:
- item.setNormal(normal)
-
- def data(self, column, role):
- if column == 1 and role == qt.Qt.DecorationRole:
- return icons.getQIcon(self._PLANE_ICONS[self.__getPlaneName()])
- data = super(PlaneRow, self).data(column, role)
- if column == 1 and role == qt.Qt.DisplayRole:
- self._lastName = data
- return data
-
-
-class ComplexModeRow(ItemProxyRow):
- """Represents :class:`items.ComplexMixIn` symbol property.
-
- :param Item3D item: Scene item with symbol property
- """
-
- def __init__(self, item, name='Mode'):
- names = [m.value.replace('_', ' ').title()
- for m in item.supportedComplexModes()]
- super(ComplexModeRow, self).__init__(
- item=item,
- name=name,
- fget=item.getComplexMode,
- fset=item.setComplexMode,
- events=items.ItemChangedType.COMPLEX_MODE,
- toModelData=lambda data: data.value.replace('_', ' ').title(),
- fromModelData=lambda data: data.lower().replace(' ', '_'),
- editorHint=names)
-
-
-class RemoveIsosurfaceRow(BaseRow):
- """Class for Isosurface Delete button
-
- :param Isosurface isosurface: The isosurface item to attach the button to.
- """
-
- def __init__(self, isosurface):
- super(RemoveIsosurfaceRow, self).__init__()
- self._isosurface = weakref.ref(isosurface)
-
- def createEditor(self):
- """Specific editor factory provided to the model"""
- editor = qt.QWidget()
- layout = qt.QHBoxLayout(editor)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
-
- removeBtn = qt.QToolButton()
- removeBtn.setText('Delete')
- removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
- layout.addWidget(removeBtn)
- removeBtn.clicked.connect(self._removeClicked)
-
- layout.addStretch(1)
- return editor
-
- def isosurface(self):
- """Returns the controlled isosurface
-
- :rtype: Isosurface
- """
- return self._isosurface()
-
- def data(self, column, role):
- if column == 0 and role == qt.Qt.UserRole: # editor hint
- return self.createEditor
-
- return super(RemoveIsosurfaceRow, self).data(column, role)
-
- def flags(self, column):
- flags = super(RemoveIsosurfaceRow, self).flags(column)
- if column == 0:
- flags |= qt.Qt.ItemIsEditable
- return flags
-
- def _removeClicked(self):
- """Handle Delete button clicked"""
- isosurface = self.isosurface()
- if isosurface is not None:
- volume = isosurface.parent()
- if volume is not None:
- volume.removeIsosurface(isosurface)
-
-
-class IsosurfaceRow(Item3DRow):
- """Represents an :class:`Isosurface` item.
-
- :param Isosurface item: Isosurface item
- """
-
- _LEVEL_SLIDER_RANGE = 0, 1000
- """Range given as editor hint"""
-
- _EVENTS = items.ItemChangedType.VISIBLE, items.ItemChangedType.COLOR
- """Events for which to update the first column in the tree"""
-
- def __init__(self, item):
- super(IsosurfaceRow, self).__init__(item, name=item.getLevel())
-
- self.setFlags(self.flags(1) | qt.Qt.ItemIsEditable, 1)
-
- item.sigItemChanged.connect(self._levelChanged)
-
- self.addRow(ItemProxyRow(
- item=item,
- name='Level',
- fget=self._getValueForLevelSlider,
- fset=self._setLevelFromSliderValue,
- events=items.Item3DChangedType.ISO_LEVEL,
- editorHint=self._LEVEL_SLIDER_RANGE))
-
- self.addRow(ItemColorProxyRow(
- item=item,
- name='Color',
- fget=self._rgbColor,
- fset=self._setRgbColor,
- events=items.ItemChangedType.COLOR))
-
- self.addRow(ItemProxyRow(
- item=item,
- name='Opacity',
- fget=self._opacity,
- fset=self._setOpacity,
- events=items.ItemChangedType.COLOR,
- editorHint=(0, 255)))
-
- self.addRow(RemoveIsosurfaceRow(item))
-
- def _getValueForLevelSlider(self):
- """Convert iso level to slider value.
-
- :rtype: int
- """
- item = self.item()
- if item is not None:
- volume = item.parent()
- if volume is not None:
- dataRange = volume.getDataRange()
- if dataRange is not None:
- dataMin, dataMax = dataRange[0], dataRange[-1]
- if dataMax != dataMin:
- offset = (item.getLevel() - dataMin) / (dataMax - dataMin)
- else:
- offset = 0.
-
- sliderMin, sliderMax = self._LEVEL_SLIDER_RANGE
- value = sliderMin + (sliderMax - sliderMin) * offset
- return value
- return 0
-
- def _setLevelFromSliderValue(self, value):
- """Convert slider value to isolevel.
-
- :param int value:
- """
- item = self.item()
- if item is not None:
- volume = item.parent()
- if volume is not None:
- dataRange = volume.getDataRange()
- if dataRange is not None:
- sliderMin, sliderMax = self._LEVEL_SLIDER_RANGE
- offset = (value - sliderMin) / (sliderMax - sliderMin)
-
- dataMin, dataMax = dataRange[0], dataRange[-1]
- level = dataMin + (dataMax - dataMin) * offset
- item.setLevel(level)
-
- def _rgbColor(self):
- """Proxy to get the isosurface's RGB color without transparency
-
- :rtype: QColor
- """
- item = self.item()
- if item is None:
- return None
- else:
- color = item.getColor()
- color.setAlpha(255)
- return color
-
- def _setRgbColor(self, color):
- """Proxy to set the isosurface's RGB color without transparency
-
- :param QColor color:
- """
- item = self.item()
- if item is not None:
- color.setAlpha(item.getColor().alpha())
- item.setColor(color)
-
- def _opacity(self):
- """Proxy to get the isosurface's transparency
-
- :rtype: int
- """
- item = self.item()
- return 255 if item is None else item.getColor().alpha()
-
- def _setOpacity(self, opacity):
- """Proxy to set the isosurface's transparency.
-
- :param int opacity:
- """
- item = self.item()
- if item is not None:
- color = item.getColor()
- color.setAlpha(opacity)
- item.setColor(color)
-
- def _levelChanged(self, event):
- """Handle isosurface level changed and notify model
-
- :param ItemChangedType event:
- """
- if event == items.Item3DChangedType.ISO_LEVEL:
- model = self.model()
- if model is not None:
- index = self.index(column=1)
- model.dataChanged.emit(index, index)
-
- def data(self, column, role):
- if column == 0: # Show color as decoration, not text
- if role == qt.Qt.DisplayRole:
- return None
- elif role == qt.Qt.DecorationRole:
- return self._rgbColor()
-
- elif column == 1 and role in (qt.Qt.DisplayRole, qt.Qt.EditRole):
- item = self.item()
- return None if item is None else item.getLevel()
-
- return super(IsosurfaceRow, self).data(column, role)
-
- def setData(self, column, value, role):
- if column == 1 and role == qt.Qt.EditRole:
- item = self.item()
- if item is not None:
- item.setLevel(value)
- return True
-
- return super(IsosurfaceRow, self).setData(column, value, role)
-
-
-class ComplexIsosurfaceRow(IsosurfaceRow):
- """Represents an :class:`ComplexIsosurface` item.
-
- :param ComplexIsosurface item:
- """
-
- _EVENTS = (items.ItemChangedType.VISIBLE,
- items.ItemChangedType.COLOR,
- items.ItemChangedType.COMPLEX_MODE)
- """Events for which to update the first column in the tree"""
-
- def __init__(self, item):
- super(ComplexIsosurfaceRow, self).__init__(item)
-
- self.addRow(ComplexModeRow(item, "Color Complex Mode"), index=1)
- for row in self.children():
- if isinstance(row, ColorProxyRow):
- self._colorRow = row
- break
- else:
- raise RuntimeError("Cannot retrieve Color tree row")
- self._colormapRow = ColormapRow(item)
-
- self.__updateRowsForItem(item)
- item.sigItemChanged.connect(self.__itemChanged)
-
- def __itemChanged(self, event):
- """Update enabled/disabled rows"""
- if event == items.ItemChangedType.COMPLEX_MODE:
- item = self.sender()
- self.__updateRowsForItem(item)
-
- def __updateRowsForItem(self, item):
- """Update rows for item
-
- :param item:
- """
- if not isinstance(item, ComplexIsosurface):
- return
-
- if item.getComplexMode() == items.ComplexMixIn.ComplexMode.NONE:
- removed = self._colormapRow
- added = self._colorRow
- else:
- removed = self._colorRow
- added = self._colormapRow
-
- # Remove unwanted rows
- if removed in self.children():
- self.removeRow(removed)
-
- # Add required rows
- if added not in self.children():
- self.addRow(added, index=2)
-
- def data(self, column, role):
- if column == 0 and role == qt.Qt.DecorationRole:
- item = self.item()
- if (item is not None and
- item.getComplexMode() != items.ComplexMixIn.ComplexMode.NONE):
- return self._colormapRow.getColormapImage()
-
- return super(ComplexIsosurfaceRow, self).data(column, role)
-
-
-class AddIsosurfaceRow(BaseRow):
- """Class for Isosurface create button
-
- :param Union[ScalarField3D,ComplexField3D] volume:
- The volume item to attach the button to.
- """
-
- def __init__(self, volume):
- super(AddIsosurfaceRow, self).__init__()
- self._volume = weakref.ref(volume)
-
- def createEditor(self):
- """Specific editor factory provided to the model"""
- editor = qt.QWidget()
- layout = qt.QHBoxLayout(editor)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
-
- addBtn = qt.QToolButton()
- addBtn.setText('+')
- addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
- layout.addWidget(addBtn)
- addBtn.clicked.connect(self._addClicked)
-
- layout.addStretch(1)
- return editor
-
- def volume(self):
- """Returns the controlled volume item
-
- :rtype: Union[ScalarField3D,ComplexField3D]
- """
- return self._volume()
-
- def data(self, column, role):
- if column == 0 and role == qt.Qt.UserRole: # editor hint
- return self.createEditor
-
- return super(AddIsosurfaceRow, self).data(column, role)
-
- def flags(self, column):
- flags = super(AddIsosurfaceRow, self).flags(column)
- if column == 0:
- flags |= qt.Qt.ItemIsEditable
- return flags
-
- def _addClicked(self):
- """Handle Delete button clicked"""
- volume = self.volume()
- if volume is not None:
- dataRange = volume.getDataRange()
- if dataRange is None:
- dataRange = 0., 1.
-
- volume.addIsosurface(
- numpy.mean((dataRange[0], dataRange[-1])),
- '#0000FF')
-
-
-class VolumeIsoSurfacesRow(StaticRow):
- """Represents :class:`ScalarFieldView`'s isosurfaces
-
- :param Union[ScalarField3D,ComplexField3D] volume:
- Volume item to control
- """
-
- def __init__(self, volume):
- super(VolumeIsoSurfacesRow, self).__init__(
- ('Isosurfaces', None))
- self._volume = weakref.ref(volume)
-
- volume.sigIsosurfaceAdded.connect(self._isosurfaceAdded)
- volume.sigIsosurfaceRemoved.connect(self._isosurfaceRemoved)
-
- if isinstance(volume, items.ComplexMixIn):
- self.addRow(ComplexModeRow(volume, "Complex Mode"))
-
- for item in volume.getIsosurfaces():
- self.addRow(nodeFromItem(item))
-
- self.addRow(AddIsosurfaceRow(volume))
-
- def volume(self):
- """Returns the controlled volume item
-
- :rtype: Union[ScalarField3D,ComplexField3D]
- """
- return self._volume()
-
- def _isosurfaceAdded(self, item):
- """Handle isosurface addition
-
- :param Isosurface item: added isosurface
- """
- volume = self.volume()
- if volume is None:
- return
-
- row = volume.getIsosurfaces().index(item)
- if isinstance(volume, items.ComplexMixIn):
- row += 1 # Offset for the ComplexModeRow
- self.addRow(nodeFromItem(item), row)
-
- def _isosurfaceRemoved(self, item):
- """Handle isosurface removal
-
- :param Isosurface item: removed isosurface
- """
- volume = self.volume()
- if volume is None:
- return
-
- # Find item
- for row in self.children():
- if isinstance(row, IsosurfaceRow) and row.item() is item:
- self.removeRow(row)
- break # Got it
- else:
- raise RuntimeError("Model does not correspond to scene content")
-
-
-class Scatter2DPropertyMixInRow(object):
- """Mix-in class that enable/disable row according to Scatter2D mode.
-
- :param Scatter2D item:
- :param str propertyName: Name of the Scatter2D property of this row
- """
-
- def __init__(self, item, propertyName):
- assert propertyName in ('lineWidth', 'symbol', 'symbolSize')
- self.__propertyName = propertyName
-
- self.__isEnabled = item.isPropertyEnabled(propertyName)
- self.__updateFlags()
-
- item.sigItemChanged.connect(self.__itemChanged)
-
- def data(self, column, role):
- if column == 1 and not self.__isEnabled:
- # Discard data and editorHint if disabled
- return None
- else:
- return super(Scatter2DPropertyMixInRow, self).data(column, role)
-
- def __updateFlags(self):
- """Update model flags"""
- if self.__isEnabled:
- self.setFlags(qt.Qt.ItemIsEnabled, 0)
- self.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsEditable, 1)
- else:
- self.setFlags(qt.Qt.NoItemFlags)
-
- def __itemChanged(self, event):
- """Set flags to enable/disable the row"""
- if event == items.ItemChangedType.VISUALIZATION_MODE:
- item = self.sender()
- if item is not None: # This occurs with PySide/python2.7
- self.__isEnabled = item.isPropertyEnabled(self.__propertyName)
- self.__updateFlags()
-
- # Notify model
- model = self.model()
- if model is not None:
- begin = self.index(column=0)
- end = self.index(column=1)
- model.dataChanged.emit(begin, end)
-
-
-class Scatter2DSymbolRow(Scatter2DPropertyMixInRow, SymbolRow):
- """Specific class for Scatter2D symbol.
-
- It is enabled/disabled according to visualization mode.
-
- :param Scatter2D item:
- """
-
- def __init__(self, item):
- SymbolRow.__init__(self, item)
- Scatter2DPropertyMixInRow.__init__(self, item, 'symbol')
-
-
-class Scatter2DSymbolSizeRow(Scatter2DPropertyMixInRow, SymbolSizeRow):
- """Specific class for Scatter2D symbol size.
-
- It is enabled/disabled according to visualization mode.
-
- :param Scatter2D item:
- """
-
- def __init__(self, item):
- SymbolSizeRow.__init__(self, item)
- Scatter2DPropertyMixInRow.__init__(self, item, 'symbolSize')
-
-
-class Scatter2DLineWidth(Scatter2DPropertyMixInRow, ItemProxyRow):
- """Specific class for Scatter2D symbol size.
-
- It is enabled/disabled according to visualization mode.
-
- :param Scatter2D item:
- """
-
- def __init__(self, item):
- # TODO link editorHint with OpenGL max line width
- ItemProxyRow.__init__(self,
- item=item,
- name='Line width',
- fget=item.getLineWidth,
- fset=item.setLineWidth,
- events=items.ItemChangedType.LINE_WIDTH,
- editorHint=(1, 10))
- Scatter2DPropertyMixInRow.__init__(self, item, 'lineWidth')
-
-
-def initScatter2DNode(node, item):
- """Specific node init for Scatter2D to set order of parameters
-
- :param Item3DRow node: The model node to setup
- :param Scatter2D item: The Scatter2D the node is representing
- """
- node.addRow(ItemProxyRow(
- item=item,
- name='Mode',
- fget=item.getVisualization,
- fset=item.setVisualization,
- events=items.ItemChangedType.VISUALIZATION_MODE,
- editorHint=[m.value.title() for m in item.supportedVisualizations()],
- toModelData=lambda data: data.value.title(),
- fromModelData=lambda data: data.lower()))
-
- node.addRow(ItemProxyRow(
- item=item,
- name='Height map',
- fget=item.isHeightMap,
- fset=item.setHeightMap,
- events=items.Item3DChangedType.HEIGHT_MAP))
-
- node.addRow(ColormapRow(item))
-
- node.addRow(Scatter2DSymbolRow(item))
- node.addRow(Scatter2DSymbolSizeRow(item))
-
- node.addRow(Scatter2DLineWidth(item))
-
-
-def initVolumeNode(node, item):
- """Specific node init for volume items
-
- :param Item3DRow node: The model node to setup
- :param Union[ScalarField3D,ComplexField3D] item:
- The volume item represented by the node
- """
- node.addRow(nodeFromItem(item.getCutPlanes()[0])) # Add cut plane
- node.addRow(VolumeIsoSurfacesRow(item))
-
-
-def initVolumeCutPlaneNode(node, item):
- """Specific node init for volume CutPlane
-
- :param Item3DRow node: The model node to setup
- :param CutPlane item: The CutPlane the node is representing
- """
- if isinstance(item, items.ComplexMixIn):
- node.addRow(ComplexModeRow(item))
-
- node.addRow(PlaneRow(item))
-
- node.addRow(ColormapRow(item))
-
- node.addRow(ItemProxyRow(
- item=item,
- name='Show <=Min',
- fget=item.getDisplayValuesBelowMin,
- fset=item.setDisplayValuesBelowMin,
- events=items.ItemChangedType.ALPHA))
-
- node.addRow(InterpolationRow(item))
-
-
-NODE_SPECIFIC_INIT = [ # class, init(node, item)
- (items.Scatter2D, initScatter2DNode),
- (items.ScalarField3D, initVolumeNode),
- (CutPlane, initVolumeCutPlaneNode),
-]
-"""List of specific node init for different item class"""
-
-
-def nodeFromItem(item):
- """Create :class:`Item3DRow` subclass corresponding to item
-
- :param Item3D item: The item fow which to create the node
- :rtype: Item3DRow
- """
- assert isinstance(item, items.Item3D)
-
- # Item with specific model row class
- if isinstance(item, (items.GroupItem, items.GroupWithAxesItem)):
- return GroupItemRow(item)
- elif isinstance(item, ComplexIsosurface):
- return ComplexIsosurfaceRow(item)
- elif isinstance(item, Isosurface):
- return IsosurfaceRow(item)
-
- # Create Item3DRow and populate it
- node = Item3DRow(item)
-
- if isinstance(item, items.DataItem3D):
- node.addRow(DataItem3DBoundingBoxRow(item))
- node.addRow(DataItem3DTransformRow(item))
-
- # Specific extra init
- for cls, specificInit in NODE_SPECIFIC_INIT:
- if isinstance(item, cls):
- specificInit(node, item)
- break
-
- else: # Generic case: handle mixins
- for cls in item.__class__.__mro__:
- if cls is items.ColormapMixIn:
- node.addRow(ColormapRow(item))
-
- elif cls is items.InterpolationMixIn:
- node.addRow(InterpolationRow(item))
-
- elif cls is items.SymbolMixIn:
- node.addRow(SymbolRow(item))
- node.addRow(SymbolSizeRow(item))
-
- elif cls is items.PlaneMixIn:
- node.addRow(PlaneRow(item))
-
- return node
diff --git a/silx/gui/plot3d/actions/io.py b/silx/gui/plot3d/actions/io.py
deleted file mode 100644
index 4020d6f..0000000
--- a/silx/gui/plot3d/actions/io.py
+++ /dev/null
@@ -1,336 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides Plot3DAction related to input/output.
-
-It provides QAction to copy, save (snapshot and video), print a Plot3DWidget.
-"""
-
-from __future__ import absolute_import, division
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "06/09/2017"
-
-
-import logging
-import os
-
-import numpy
-
-from silx.gui import qt, printer
-from silx.gui.icons import getQIcon
-from .Plot3DAction import Plot3DAction
-from ..utils import mng
-from ...utils.image import convertQImageToArray
-
-
-_logger = logging.getLogger(__name__)
-
-
-class CopyAction(Plot3DAction):
- """QAction to provide copy of a Plot3DWidget
-
- :param parent: See :class:`QAction`
- :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
- Plot3DWidget the action is associated with
- """
-
- def __init__(self, parent, plot3d=None):
- super(CopyAction, self).__init__(parent, plot3d)
-
- self.setIcon(getQIcon('edit-copy'))
- self.setText('Copy')
- self.setToolTip('Copy a snapshot of the 3D scene to the clipboard')
- self.setCheckable(False)
- self.setShortcut(qt.QKeySequence.Copy)
- self.setShortcutContext(qt.Qt.WidgetShortcut)
- self.triggered[bool].connect(self._triggered)
-
- def _triggered(self, checked=False):
- plot3d = self.getPlot3DWidget()
- if plot3d is None:
- _logger.error('Cannot copy widget, no associated Plot3DWidget')
- else:
- image = plot3d.grabGL()
- qt.QApplication.clipboard().setImage(image)
-
-
-class SaveAction(Plot3DAction):
- """QAction to provide save snapshot of a Plot3DWidget
-
- :param parent: See :class:`QAction`
- :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
- Plot3DWidget the action is associated with
- """
-
- def __init__(self, parent, plot3d=None):
- super(SaveAction, self).__init__(parent, plot3d)
-
- self.setIcon(getQIcon('document-save'))
- self.setText('Save...')
- self.setToolTip('Save a snapshot of the 3D scene')
- self.setCheckable(False)
- self.setShortcut(qt.QKeySequence.Save)
- self.setShortcutContext(qt.Qt.WidgetShortcut)
- self.triggered[bool].connect(self._triggered)
-
- def _triggered(self, checked=False):
- plot3d = self.getPlot3DWidget()
- if plot3d is None:
- _logger.error('Cannot save widget, no associated Plot3DWidget')
- else:
- dialog = qt.QFileDialog(self.parent())
- dialog.setWindowTitle('Save snapshot as')
- dialog.setModal(True)
- dialog.setNameFilters(('Plot3D Snapshot PNG (*.png)',
- 'Plot3D Snapshot JPEG (*.jpg)'))
-
- dialog.setFileMode(qt.QFileDialog.AnyFile)
- dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
-
- if not dialog.exec_():
- return
-
- nameFilter = dialog.selectedNameFilter()
- filename = dialog.selectedFiles()[0]
- dialog.close()
-
- # Forces the filename extension to match the chosen filter
- extension = nameFilter.split()[-1][2:-1]
- if (len(filename) <= len(extension) or
- filename[-len(extension):].lower() != extension.lower()):
- filename += extension
-
- image = plot3d.grabGL()
- if not image.save(filename):
- _logger.error('Failed to save image as %s', filename)
- qt.QMessageBox.critical(
- self.parent(),
- 'Save snapshot as',
- 'Failed to save snapshot')
-
-
-class PrintAction(Plot3DAction):
- """QAction to provide printing of a Plot3DWidget
-
- :param parent: See :class:`QAction`
- :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
- Plot3DWidget the action is associated with
- """
-
- def __init__(self, parent, plot3d=None):
- super(PrintAction, self).__init__(parent, plot3d)
-
- self.setIcon(getQIcon('document-print'))
- self.setText('Print...')
- self.setToolTip('Print a snapshot of the 3D scene')
- self.setCheckable(False)
- self.setShortcut(qt.QKeySequence.Print)
- self.setShortcutContext(qt.Qt.WidgetShortcut)
- self.triggered[bool].connect(self._triggered)
-
- def getPrinter(self):
- """Return the QPrinter instance used for printing.
-
- :rtype: QPrinter
- """
- return printer.getDefaultPrinter()
-
- def _triggered(self, checked=False):
- plot3d = self.getPlot3DWidget()
- if plot3d is None:
- _logger.error('Cannot print widget, no associated Plot3DWidget')
- else:
- printer = self.getPrinter()
- dialog = qt.QPrintDialog(printer, plot3d)
- dialog.setWindowTitle('Print Plot3D snapshot')
- if not dialog.exec_():
- return
-
- image = plot3d.grabGL()
-
- # Draw pixmap with painter
- painter = qt.QPainter()
- if not painter.begin(printer):
- return
-
- if (printer.pageRect().width() < image.width() or
- printer.pageRect().height() < image.height()):
- # Downscale to page
- xScale = printer.pageRect().width() / image.width()
- yScale = printer.pageRect().height() / image.height()
- scale = min(xScale, yScale)
- else:
- scale = 1.
-
- rect = qt.QRectF(0,
- 0,
- scale * image.width(),
- scale * image.height())
- painter.drawImage(rect, image)
- painter.end()
-
-
-class VideoAction(Plot3DAction):
- """This action triggers the recording of a video of the scene.
-
- The scene is rotated 360 degrees around a vertical axis.
-
- :param parent: Action parent see :class:`QAction`.
- :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
- Plot3DWidget the action is associated with
- """
-
- PNG_SERIE_FILTER = 'Serie of PNG files (*.png)'
- MNG_FILTER = 'Multiple-image Network Graphics file (*.mng)'
-
- def __init__(self, parent, plot3d=None):
- super(VideoAction, self).__init__(parent, plot3d)
- self.setText('Record video..')
- self.setIcon(getQIcon('camera'))
- self.setToolTip(
- 'Record a video of a 360 degrees rotation of the 3D scene.')
- self.setCheckable(False)
- self.triggered[bool].connect(self._triggered)
-
- def _triggered(self, checked=False):
- """Action triggered callback"""
- plot3d = self.getPlot3DWidget()
- if plot3d is None:
- _logger.warning(
- 'Ignoring action triggered without Plot3DWidget set')
- return
-
- dialog = qt.QFileDialog(parent=plot3d)
- dialog.setWindowTitle('Save video as...')
- dialog.setModal(True)
- dialog.setNameFilters([self.PNG_SERIE_FILTER,
- self.MNG_FILTER])
- dialog.setFileMode(dialog.AnyFile)
- dialog.setAcceptMode(dialog.AcceptSave)
-
- if not dialog.exec_():
- return
-
- nameFilter = dialog.selectedNameFilter()
- filename = dialog.selectedFiles()[0]
-
- # Forces the filename extension to match the chosen filter
- extension = nameFilter.split()[-1][2:-1]
- if (len(filename) <= len(extension) or
- filename[-len(extension):].lower() != extension.lower()):
- filename += extension
-
- nbFrames = int(4. * 25) # 4 seconds, 25 fps
-
- if nameFilter == self.PNG_SERIE_FILTER:
- self._saveAsPNGSerie(filename, nbFrames)
- elif nameFilter == self.MNG_FILTER:
- self._saveAsMNG(filename, nbFrames)
- else:
- _logger.error('Unsupported file filter: %s', nameFilter)
-
- def _saveAsPNGSerie(self, filename, nbFrames):
- """Save video as serie of PNG files.
-
- It adds a counter to the provided filename before the extension.
-
- :param str filename: filename to use as template
- :param int nbFrames: Number of frames to generate
- """
- plot3d = self.getPlot3DWidget()
- assert plot3d is not None
-
- # Define filename template
- nbDigits = int(numpy.log10(nbFrames)) + 1
- indexFormat = '%%0%dd' % nbDigits
- extensionIndex = filename.rfind('.')
- filenameFormat = \
- filename[:extensionIndex] + indexFormat + filename[extensionIndex:]
-
- try:
- for index, image in enumerate(self._video360(nbFrames)):
- image.save(filenameFormat % index)
- except GeneratorExit:
- pass
-
- def _saveAsMNG(self, filename, nbFrames):
- """Save video as MNG file.
-
- :param str filename: filename to use
- :param int nbFrames: Number of frames to generate
- """
- plot3d = self.getPlot3DWidget()
- assert plot3d is not None
-
- frames = (convertQImageToArray(im) for im in self._video360(nbFrames))
- try:
- with open(filename, 'wb') as file_:
- for chunk in mng.convert(frames, nb_images=nbFrames):
- file_.write(chunk)
- except GeneratorExit:
- os.remove(filename) # Saving aborted, delete file
-
- def _video360(self, nbFrames):
- """Run the video and provides the images
-
- :param int nbFrames: The number of frames to generate for
- :return: Iterator of QImage of the video sequence
- """
- plot3d = self.getPlot3DWidget()
- assert plot3d is not None
-
- angleStep = 360. / nbFrames
-
- # Create progress bar dialog
- dialog = qt.QDialog(plot3d)
- dialog.setWindowTitle('Record Video')
- layout = qt.QVBoxLayout(dialog)
- progress = qt.QProgressBar()
- progress.setRange(0, nbFrames)
- layout.addWidget(progress)
-
- btnBox = qt.QDialogButtonBox(qt.QDialogButtonBox.Abort)
- btnBox.rejected.connect(dialog.reject)
- layout.addWidget(btnBox)
-
- dialog.setModal(True)
- dialog.show()
-
- qapp = qt.QApplication.instance()
-
- for frame in range(nbFrames):
- progress.setValue(frame)
- image = plot3d.grabGL()
- yield image
- plot3d.viewport.orbitCamera('left', angleStep)
- qapp.processEvents()
- if not dialog.isVisible():
- break # It as been rejected by the abort button
- else:
- dialog.accept()
-
- if dialog.result() == qt.QDialog.Rejected:
- raise GeneratorExit('Aborted')
diff --git a/silx/gui/plot3d/actions/mode.py b/silx/gui/plot3d/actions/mode.py
deleted file mode 100644
index ce09b4c..0000000
--- a/silx/gui/plot3d/actions/mode.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides Plot3DAction related to interaction modes.
-
-It provides QAction to rotate or pan a Plot3DWidget
-as well as toggle a picking mode.
-"""
-
-from __future__ import absolute_import, division
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "06/09/2017"
-
-
-import logging
-
-from ....utils.proxy import docstring
-from ... import qt
-from ...icons import getQIcon
-from .Plot3DAction import Plot3DAction
-
-
-_logger = logging.getLogger(__name__)
-
-
-class InteractiveModeAction(Plot3DAction):
- """Base class for QAction changing interactive mode of a Plot3DWidget
-
- :param parent: See :class:`QAction`
- :param str interaction: The interactive mode this action controls
- :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
- Plot3DWidget the action is associated with
- """
-
- def __init__(self, parent, interaction, plot3d=None):
- self._interaction = interaction
-
- super(InteractiveModeAction, self).__init__(parent, plot3d)
- self.setCheckable(True)
- self.triggered[bool].connect(self._triggered)
-
- def _triggered(self, checked=False):
- plot3d = self.getPlot3DWidget()
- if plot3d is None:
- _logger.error(
- 'Cannot set %s interaction, no associated Plot3DWidget' %
- self._interaction)
- else:
- plot3d.setInteractiveMode(self._interaction)
- self.setChecked(True)
-
- @docstring(Plot3DAction)
- def setPlot3DWidget(self, widget):
- # Disconnect from previous Plot3DWidget
- plot3d = self.getPlot3DWidget()
- if plot3d is not None:
- plot3d.sigInteractiveModeChanged.disconnect(
- self._interactiveModeChanged)
-
- super(InteractiveModeAction, self).setPlot3DWidget(widget)
-
- # Connect to new Plot3DWidget
- if widget is None:
- self.setChecked(False)
- else:
- self.setChecked(widget.getInteractiveMode() == self._interaction)
- widget.sigInteractiveModeChanged.connect(
- self._interactiveModeChanged)
-
- def _interactiveModeChanged(self):
- plot3d = self.getPlot3DWidget()
- if plot3d is None:
- _logger.error('Received a signal while there is no widget')
- else:
- self.setChecked(plot3d.getInteractiveMode() == self._interaction)
-
-
-class RotateArcballAction(InteractiveModeAction):
- """QAction to set arcball rotation interaction on a Plot3DWidget
-
- :param parent: See :class:`QAction`
- :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
- Plot3DWidget the action is associated with
- """
-
- def __init__(self, parent, plot3d=None):
- super(RotateArcballAction, self).__init__(parent, 'rotate', plot3d)
-
- self.setIcon(getQIcon('rotate-3d'))
- self.setText('Rotate')
- self.setToolTip('Rotate the view. Press <b>Ctrl</b> to pan.')
-
-
-class PanAction(InteractiveModeAction):
- """QAction to set pan interaction on a Plot3DWidget
-
- :param parent: See :class:`QAction`
- :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
- Plot3DWidget the action is associated with
- """
-
- def __init__(self, parent, plot3d=None):
- super(PanAction, self).__init__(parent, 'pan', plot3d)
-
- self.setIcon(getQIcon('pan'))
- self.setText('Pan')
- self.setToolTip('Pan the view. Press <b>Ctrl</b> to rotate.')
-
-
-class PickingModeAction(Plot3DAction):
- """QAction to toggle picking moe on a Plot3DWidget
-
- :param parent: See :class:`QAction`
- :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
- Plot3DWidget the action is associated with
- """
-
- sigSceneClicked = qt.Signal(float, float)
- """Signal emitted when the scene is clicked with the left mouse button.
-
- This signal is only emitted when the action is checked.
-
- It provides the (x, y) clicked mouse position
- """
-
- def __init__(self, parent, plot3d=None):
- super(PickingModeAction, self).__init__(parent, plot3d)
- self.setIcon(getQIcon('pointing-hand'))
- self.setText('Picking')
- self.setToolTip('Toggle picking with left button click')
- self.setCheckable(True)
- self.triggered[bool].connect(self._triggered)
-
- def _triggered(self, checked=False):
- plot3d = self.getPlot3DWidget()
- if plot3d is not None:
- if checked:
- plot3d.sigSceneClicked.connect(self.sigSceneClicked)
- else:
- plot3d.sigSceneClicked.disconnect(self.sigSceneClicked)
-
- @docstring(Plot3DAction)
- def setPlot3DWidget(self, widget):
- # Disconnect from previous Plot3DWidget
- plot3d = self.getPlot3DWidget()
- if plot3d is not None and self.isChecked():
- plot3d.sigSceneClicked.disconnect(self.sigSceneClicked)
-
- super(PickingModeAction, self).setPlot3DWidget(widget)
-
- # Connect to new Plot3DWidget
- if widget is None:
- self.setChecked(False)
- elif self.isChecked():
- widget.sigSceneClicked.connect(self.sigSceneClicked)
diff --git a/silx/gui/plot3d/items/core.py b/silx/gui/plot3d/items/core.py
deleted file mode 100644
index ab2ceb6..0000000
--- a/silx/gui/plot3d/items/core.py
+++ /dev/null
@@ -1,779 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides the base class for items of the :class:`.SceneWidget`.
-"""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "15/11/2017"
-
-from collections import defaultdict
-import enum
-
-import numpy
-import six
-
-from ... import qt
-from ...plot.items import ItemChangedType
-from .. import scene
-from ..scene import axes, primitives, transform
-from ._pick import PickContext
-
-
-@enum.unique
-class Item3DChangedType(enum.Enum):
- """Type of modification provided by :attr:`Item3D.sigItemChanged` signal."""
-
- INTERPOLATION = 'interpolationChanged'
- """Item3D image interpolation changed flag."""
-
- TRANSFORM = 'transformChanged'
- """Item3D transform changed flag."""
-
- HEIGHT_MAP = 'heightMapChanged'
- """Item3D height map changed flag."""
-
- ISO_LEVEL = 'isoLevelChanged'
- """Isosurface level changed flag."""
-
- LABEL = 'labelChanged'
- """Item's label changed flag."""
-
- BOUNDING_BOX_VISIBLE = 'boundingBoxVisibleChanged'
- """Item's bounding box visibility changed"""
-
- ROOT_ITEM = 'rootItemChanged'
- """Item's root changed flag."""
-
-
-class Item3D(qt.QObject):
- """Base class representing an item in the scene.
-
- :param parent: The View widget this item belongs to.
- :param primitive: An optional primitive to use as scene primitive
- """
-
- _LABEL_INDICES = defaultdict(int)
- """Store per class label indices"""
-
- sigItemChanged = qt.Signal(object)
- """Signal emitted when an item's property has changed.
-
- It provides a flag describing which property of the item has changed.
- See :class:`ItemChangedType` and :class:`Item3DChangedType`
- for flags description.
- """
-
- def __init__(self, parent, primitive=None):
- qt.QObject.__init__(self, parent)
-
- if primitive is None:
- primitive = scene.Group()
-
- self._primitive = primitive
-
- self.__syncForegroundColor()
-
- labelIndex = self._LABEL_INDICES[self.__class__]
- self._label = six.text_type(self.__class__.__name__)
- if labelIndex != 0:
- self._label += u' %d' % labelIndex
- self._LABEL_INDICES[self.__class__] += 1
-
- if isinstance(parent, Item3D):
- parent.sigItemChanged.connect(self.__parentItemChanged)
-
- def setParent(self, parent):
- """Override set parent to handle root item change"""
- previousParent = self.parent()
- if isinstance(previousParent, Item3D):
- previousParent.sigItemChanged.disconnect(self.__parentItemChanged)
-
- super(Item3D, self).setParent(parent)
-
- if isinstance(parent, Item3D):
- parent.sigItemChanged.connect(self.__parentItemChanged)
-
- self._updated(Item3DChangedType.ROOT_ITEM)
-
- def __parentItemChanged(self, event):
- """Handle updates of the parent if it is an Item3D
-
- :param Item3DChangedType event:
- """
- if event == Item3DChangedType.ROOT_ITEM:
- self._updated(Item3DChangedType.ROOT_ITEM)
-
- def root(self):
- """Returns the root of the scene this item belongs to.
-
- The root is the up-most Item3D in the scene tree hierarchy.
-
- :rtype: Union[Item3D, None]
- """
- root = None
- ancestor = self.parent()
- while isinstance(ancestor, Item3D):
- root = ancestor
- ancestor = ancestor.parent()
-
- return root
-
- def _getScenePrimitive(self):
- """Return the group containing the item rendering"""
- return self._primitive
-
- def _updated(self, event=None):
- """Handle MixIn class updates.
-
- :param event: The event to send to :attr:`sigItemChanged` signal.
- """
- if event == Item3DChangedType.ROOT_ITEM:
- self.__syncForegroundColor()
-
- if event is not None:
- self.sigItemChanged.emit(event)
-
- # Label
-
- def getLabel(self):
- """Returns the label associated to this item.
-
- :rtype: str
- """
- return self._label
-
- def setLabel(self, label):
- """Set the label associated to this item.
-
- :param str label:
- """
- label = six.text_type(label)
- if label != self._label:
- self._label = label
- self._updated(Item3DChangedType.LABEL)
-
- # Visibility
-
- def isVisible(self):
- """Returns True if item is visible, else False
-
- :rtype: bool
- """
- return self._getScenePrimitive().visible
-
- def setVisible(self, visible=True):
- """Set the visibility of the item in the scene.
-
- :param bool visible: True (default) to show the item, False to hide
- """
- visible = bool(visible)
- primitive = self._getScenePrimitive()
- if visible != primitive.visible:
- primitive.visible = visible
- self._updated(ItemChangedType.VISIBLE)
-
- # Foreground color
-
- def _setForegroundColor(self, color):
- """Set the foreground color of the item.
-
- The default implementation does nothing, override it in subclass.
-
- :param color: RGBA color
- :type color: tuple of 4 float in [0., 1.]
- """
- if hasattr(super(Item3D, self), '_setForegroundColor'):
- super(Item3D, self)._setForegroundColor(color)
-
- def __syncForegroundColor(self):
- """Retrieve foreground color from parent and update this item"""
- # Look-up for SceneWidget to get its foreground color
- root = self.root()
- if root is not None:
- widget = root.parent()
- if isinstance(widget, qt.QWidget):
- self._setForegroundColor(
- widget.getForegroundColor().getRgbF())
-
- # picking
-
- def _pick(self, context):
- """Implement picking on this item.
-
- :param PickContext context: Current picking context
- :return: Data indices at picked position or None
- :rtype: Union[None,PickingResult]
- """
- if (self.isVisible() and
- context.isEnabled() and
- context.isItemPickable(self) and
- self._pickFastCheck(context)):
- return self._pickFull(context)
- return None
-
- def _pickFastCheck(self, context):
- """Approximate item pick test (e.g., bounding box-based picking).
-
- :param PickContext context: Current picking context
- :return: True if item might be picked
- :rtype: bool
- """
- primitive = self._getScenePrimitive()
-
- positionNdc = context.getNDCPosition()
- if positionNdc is None: # No picking outside viewport
- return False
-
- bounds = primitive.bounds(transformed=False, dataBounds=False)
- if bounds is None: # primitive has no bounds
- return False
-
- bounds = primitive.objectToNDCTransform.transformBounds(bounds)
-
- return (bounds[0, 0] <= positionNdc[0] <= bounds[1, 0] and
- bounds[0, 1] <= positionNdc[1] <= bounds[1, 1])
-
- def _pickFull(self, context):
- """Perform precise picking in this item at given widget position.
-
- :param PickContext context: Current picking context
- :return: Object holding the results or None
- :rtype: Union[None,PickingResult]
- """
- return None
-
-
-class DataItem3D(Item3D):
- """Base class representing a data item with transform in the scene.
-
- :param parent: The View widget this item belongs to.
- :param Union[GroupBBox, None] group:
- The scene group to use for rendering
- """
-
- def __init__(self, parent, group=None):
- if group is None:
- group = primitives.GroupBBox()
-
- # Set-up bounding box
- group.boxVisible = False
- group.axesVisible = False
- else:
- assert isinstance(group, primitives.GroupBBox)
-
- Item3D.__init__(self, parent=parent, primitive=group)
-
- # Transformations
- self._translate = transform.Translate()
- self._rotateForwardTranslation = transform.Translate()
- self._rotate = transform.Rotate()
- self._rotateBackwardTranslation = transform.Translate()
- self._translateFromRotationCenter = transform.Translate()
- self._matrix = transform.Matrix()
- self._scale = transform.Scale()
- # Group transforms to do to data before rotation
- # This is useful to handle rotation center relative to bbox
- self._transformObjectToRotate = transform.TransformList(
- [self._matrix, self._scale])
- self._transformObjectToRotate.addListener(self._updateRotationCenter)
-
- self._rotationCenter = 0., 0., 0.
-
- self.__transforms = transform.TransformList([
- self._translate,
- self._rotateForwardTranslation,
- self._rotate,
- self._rotateBackwardTranslation,
- self._transformObjectToRotate])
-
- self._getScenePrimitive().transforms = self.__transforms
-
- def _updated(self, event=None):
- """Handle MixIn class updates.
-
- :param event: The event to send to :attr:`sigItemChanged` signal.
- """
- if event == ItemChangedType.DATA:
- self._updateRotationCenter()
- super(DataItem3D, self)._updated(event)
-
- # Transformations
-
- def _getSceneTransforms(self):
- """Return TransformList corresponding to current transforms
-
- :rtype: TransformList
- """
- return self.__transforms
-
- def setScale(self, sx=1., sy=1., sz=1.):
- """Set the scale of the item in the scene.
-
- :param float sx: Scale factor along the X axis
- :param float sy: Scale factor along the Y axis
- :param float sz: Scale factor along the Z axis
- """
- scale = numpy.array((sx, sy, sz), dtype=numpy.float32)
- if not numpy.all(numpy.equal(scale, self.getScale())):
- self._scale.scale = scale
- self._updated(Item3DChangedType.TRANSFORM)
-
- def getScale(self):
- """Returns the scales provided by :meth:`setScale`.
-
- :rtype: numpy.ndarray
- """
- return self._scale.scale
-
- def setTranslation(self, x=0., y=0., z=0.):
- """Set the translation of the origin of the item in the scene.
-
- :param float x: Offset of the data origin on the X axis
- :param float y: Offset of the data origin on the Y axis
- :param float z: Offset of the data origin on the Z axis
- """
- translation = numpy.array((x, y, z), dtype=numpy.float32)
- if not numpy.all(numpy.equal(translation, self.getTranslation())):
- self._translate.translation = translation
- self._updated(Item3DChangedType.TRANSFORM)
-
- def getTranslation(self):
- """Returns the offset set by :meth:`setTranslation`.
-
- :rtype: numpy.ndarray
- """
- return self._translate.translation
-
- _ROTATION_CENTER_TAGS = 'lower', 'center', 'upper'
-
- def _updateRotationCenter(self, *args, **kwargs):
- """Update rotation center relative to bounding box"""
- center = []
- for index, position in enumerate(self.getRotationCenter()):
- # Patch position relative to bounding box
- if position in self._ROTATION_CENTER_TAGS:
- bounds = self._getScenePrimitive().bounds(
- transformed=False, dataBounds=True)
- bounds = self._transformObjectToRotate.transformBounds(bounds)
-
- if bounds is None:
- position = 0.
- elif position == 'lower':
- position = bounds[0, index]
- elif position == 'center':
- position = 0.5 * (bounds[0, index] + bounds[1, index])
- elif position == 'upper':
- position = bounds[1, index]
-
- center.append(position)
-
- if not numpy.all(numpy.equal(
- center, self._rotateForwardTranslation.translation)):
- self._rotateForwardTranslation.translation = center
- self._rotateBackwardTranslation.translation = \
- - self._rotateForwardTranslation.translation
- self._updated(Item3DChangedType.TRANSFORM)
-
- def setRotationCenter(self, x=0., y=0., z=0.):
- """Set the center of rotation of the item.
-
- Position of the rotation center is either a float
- for an absolute position or one of the following
- string to define a position relative to the item's bounding box:
- 'lower', 'center', 'upper'
-
- :param x: rotation center position on the X axis
- :rtype: float or str
- :param y: rotation center position on the Y axis
- :rtype: float or str
- :param z: rotation center position on the Z axis
- :rtype: float or str
- """
- center = []
- for position in (x, y, z):
- if isinstance(position, six.string_types):
- assert position in self._ROTATION_CENTER_TAGS
- else:
- position = float(position)
- center.append(position)
- center = tuple(center)
-
- if center != self._rotationCenter:
- self._rotationCenter = center
- self._updateRotationCenter()
-
- def getRotationCenter(self):
- """Returns the rotation center set by :meth:`setRotationCenter`.
-
- :rtype: 3-tuple of float or str
- """
- return self._rotationCenter
-
- def setRotation(self, angle=0., axis=(0., 0., 1.)):
- """Set the rotation of the item in the scene
-
- :param float angle: The rotation angle in degrees.
- :param axis: The (x, y, z) coordinates of the rotation axis.
- """
- axis = numpy.array(axis, dtype=numpy.float32)
- assert axis.ndim == 1
- assert axis.size == 3
- if (self._rotate.angle != angle or
- not numpy.all(numpy.equal(axis, self._rotate.axis))):
- self._rotate.setAngleAxis(angle, axis)
- self._updated(Item3DChangedType.TRANSFORM)
-
- def getRotation(self):
- """Returns the rotation set by :meth:`setRotation`.
-
- :return: (angle, axis)
- :rtype: 2-tuple (float, numpy.ndarray)
- """
- return self._rotate.angle, self._rotate.axis
-
- def setMatrix(self, matrix=None):
- """Set the transform matrix
-
- :param numpy.ndarray matrix: 3x3 transform matrix
- """
- matrix4x4 = numpy.identity(4, dtype=numpy.float32)
-
- if matrix is not None:
- matrix = numpy.array(matrix, dtype=numpy.float32)
- assert matrix.shape == (3, 3)
- matrix4x4[:3, :3] = matrix
-
- if not numpy.all(numpy.equal(matrix4x4, self._matrix.getMatrix())):
- self._matrix.setMatrix(matrix4x4)
- self._updated(Item3DChangedType.TRANSFORM)
-
- def getMatrix(self):
- """Returns the matrix set by :meth:`setMatrix`
-
- :return: 3x3 matrix
- :rtype: numpy.ndarray"""
- return self._matrix.getMatrix(copy=True)[:3, :3]
-
- # Bounding box
-
- def _setForegroundColor(self, color):
- """Set the color of the bounding box
-
- :param color: RGBA color as 4 floats in [0, 1]
- """
- self._getScenePrimitive().color = color
- super(DataItem3D, self)._setForegroundColor(color)
-
- def isBoundingBoxVisible(self):
- """Returns item's bounding box visibility.
-
- :rtype: bool
- """
- return self._getScenePrimitive().boxVisible
-
- def setBoundingBoxVisible(self, visible):
- """Set item's bounding box visibility.
-
- :param bool visible:
- True to show the bounding box, False (default) to hide it
- """
- visible = bool(visible)
- primitive = self._getScenePrimitive()
- if visible != primitive.boxVisible:
- primitive.boxVisible = visible
- self._updated(Item3DChangedType.BOUNDING_BOX_VISIBLE)
-
-
-class BaseNodeItem(DataItem3D):
- """Base class for data item having children (e.g., group, 3d volume)."""
-
- def __init__(self, parent=None, group=None):
- """Base class representing a group of items in the scene.
-
- :param parent: The View widget this item belongs to.
- :param Union[GroupBBox, None] group:
- The scene group to use for rendering
- """
- DataItem3D.__init__(self, parent=parent, group=group)
-
- def getItems(self):
- """Returns the list of items currently present in the group.
-
- :rtype: tuple
- """
- raise NotImplementedError('getItems must be implemented in subclass')
-
- def visit(self, included=True):
- """Generator visiting the group content.
-
- It traverses the group sub-tree in a top-down left-to-right way.
-
- :param bool included: True (default) to include self in visit
- """
- if included:
- yield self
- for child in self.getItems():
- yield child
- if hasattr(child, 'visit'):
- for item in child.visit(included=False):
- yield item
-
- def pickItems(self, x, y, condition=None):
- """Iterator over picked items in the group at given position.
-
- Each picked item yield a :class:`PickingResult` object
- holding the picking information.
-
- It traverses the group sub-tree in a left-to-right top-down way.
-
- :param int x: X widget device pixel coordinate
- :param int y: Y widget device pixel coordinate
- :param callable condition: Optional test called for each item
- checking whether to process it or not.
- """
- viewport = self._getScenePrimitive().viewport
- if viewport is None:
- raise RuntimeError(
- 'Cannot perform picking: Item not attached to a widget')
-
- context = PickContext(x, y, viewport, condition)
- for result in self._pickItems(context):
- yield result
-
- def _pickItems(self, context):
- """Implement :meth:`pickItems`
-
- :param PickContext context: Current picking context
- """
- if not self.isVisible() or not context.isEnabled():
- return # empty iterator
-
- # Use a copy to discard context changes once this returns
- context = context.copy()
-
- if not self._pickFastCheck(context):
- return # empty iterator
-
- result = self._pick(context)
- if result is not None:
- yield result
-
- for child in self.getItems():
- if isinstance(child, BaseNodeItem):
- for result in child._pickItems(context):
- yield result # Flatten result
-
- else:
- result = child._pick(context)
- if result is not None:
- yield result
-
-
-class _BaseGroupItem(BaseNodeItem):
- """Base class for group of items sharing a common transform."""
-
- sigItemAdded = qt.Signal(object)
- """Signal emitted when a new item is added to the group.
-
- The newly added item is provided by this signal
- """
-
- sigItemRemoved = qt.Signal(object)
- """Signal emitted when an item is removed from the group.
-
- The removed item is provided by this signal.
- """
-
- def __init__(self, parent=None, group=None):
- """Base class representing a group of items in the scene.
-
- :param parent: The View widget this item belongs to.
- :param Union[GroupBBox, None] group:
- The scene group to use for rendering
- """
- BaseNodeItem.__init__(self, parent=parent, group=group)
- self._items = []
-
- def _getGroupPrimitive(self):
- """Returns the group for which to handle children.
-
- This allows this group to be different from the primitive.
- """
- return self._getScenePrimitive()
-
- def addItem(self, item, index=None):
- """Add an item to the group
-
- :param Item3D item: The item to add
- :param int index: The index at which to place the item.
- By default it is appended to the end of the list.
- :raise ValueError: If the item is already in the group.
- """
- assert isinstance(item, Item3D)
- assert item.parent() in (None, self)
-
- if item in self.getItems():
- raise ValueError("Item3D already in group: %s" % item)
-
- item.setParent(self)
- if index is None:
- self._getGroupPrimitive().children.append(
- item._getScenePrimitive())
- self._items.append(item)
- else:
- self._getGroupPrimitive().children.insert(
- index, item._getScenePrimitive())
- self._items.insert(index, item)
- self.sigItemAdded.emit(item)
-
- def getItems(self):
- """Returns the list of items currently present in the group.
-
- :rtype: tuple
- """
- return tuple(self._items)
-
- def removeItem(self, item):
- """Remove an item from the scene.
-
- :param Item3D item: The item to remove from the scene
- :raises ValueError: If the item does not belong to the group
- """
- if item not in self.getItems():
- raise ValueError("Item3D not in group: %s" % str(item))
-
- self._getGroupPrimitive().children.remove(item._getScenePrimitive())
- self._items.remove(item)
- item.setParent(None)
- self.sigItemRemoved.emit(item)
-
- def clearItems(self):
- """Remove all item from the group."""
- for item in self.getItems():
- self.removeItem(item)
-
-
-class GroupItem(_BaseGroupItem):
- """Group of items sharing a common transform."""
-
- def __init__(self, parent=None):
- super(GroupItem, self).__init__(parent=parent)
-
-
-class GroupWithAxesItem(_BaseGroupItem):
- """
- Group of items sharing a common transform surrounded with labelled axes.
- """
-
- def __init__(self, parent=None):
- """Class representing a group of items in the scene with labelled axes.
-
- :param parent: The View widget this item belongs to.
- """
- super(GroupWithAxesItem, self).__init__(parent=parent,
- group=axes.LabelledAxes())
-
- # Axes labels
-
- def setAxesLabels(self, xlabel=None, ylabel=None, zlabel=None):
- """Set the text labels of the axes.
-
- :param str xlabel: Label of the X axis, None to leave unchanged.
- :param str ylabel: Label of the Y axis, None to leave unchanged.
- :param str zlabel: Label of the Z axis, None to leave unchanged.
- """
- labelledAxes = self._getScenePrimitive()
- if xlabel is not None:
- labelledAxes.xlabel = xlabel
-
- if ylabel is not None:
- labelledAxes.ylabel = ylabel
-
- if zlabel is not None:
- labelledAxes.zlabel = zlabel
-
- class _Labels(tuple):
- """Return type of :meth:`getAxesLabels`"""
-
- def getXLabel(self):
- """Label of the X axis (str)"""
- return self[0]
-
- def getYLabel(self):
- """Label of the Y axis (str)"""
- return self[1]
-
- def getZLabel(self):
- """Label of the Z axis (str)"""
- return self[2]
-
- def getAxesLabels(self):
- """Returns the text labels of the axes
-
- >>> group = GroupWithAxesItem()
- >>> group.setAxesLabels(xlabel='X')
-
- You can get the labels either as a 3-tuple:
-
- >>> xlabel, ylabel, zlabel = group.getAxesLabels()
-
- Or as an object with methods getXLabel, getYLabel and getZLabel:
-
- >>> labels = group.getAxesLabels()
- >>> labels.getXLabel()
- ... 'X'
-
- :return: object describing the labels
- """
- labelledAxes = self._getScenePrimitive()
- return self._Labels((labelledAxes.xlabel,
- labelledAxes.ylabel,
- labelledAxes.zlabel))
-
-
-class RootGroupWithAxesItem(GroupWithAxesItem):
- """Special group with axes item for root of the scene.
-
- Uses 2 groups so that axes take transforms into account.
- """
-
- def __init__(self, parent=None):
- super(RootGroupWithAxesItem, self).__init__(parent)
- self.__group = scene.Group()
- self.__group.transforms = self._getSceneTransforms()
-
- groupWithAxes = self._getScenePrimitive()
- groupWithAxes.transforms = [] # Do not apply transforms here
- groupWithAxes.children.append(self.__group)
-
- def _getGroupPrimitive(self):
- """Returns the group for which to handle children.
-
- This allows this group to be different from the primitive.
- """
- return self.__group
diff --git a/silx/gui/plot3d/items/image.py b/silx/gui/plot3d/items/image.py
deleted file mode 100644
index 4e2b396..0000000
--- a/silx/gui/plot3d/items/image.py
+++ /dev/null
@@ -1,425 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides 2D data and RGB(A) image item class.
-"""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "15/11/2017"
-
-import numpy
-
-from ..scene import primitives, utils
-from .core import DataItem3D, ItemChangedType
-from .mixins import ColormapMixIn, InterpolationMixIn
-from ._pick import PickingResult
-
-
-class _Image(DataItem3D, InterpolationMixIn):
- """Base class for images
-
- :param parent: The View widget this item belongs to.
- """
-
- def __init__(self, parent=None):
- DataItem3D.__init__(self, parent=parent)
- InterpolationMixIn.__init__(self)
-
- def _setPrimitive(self, primitive):
- InterpolationMixIn._setPrimitive(self, primitive)
-
- def getData(self, copy=True):
- raise NotImplementedError()
-
- def _pickFull(self, context):
- """Perform picking in this item at given widget position.
-
- :param PickContext context: Current picking context
- :return: Object holding the results or None
- :rtype: Union[None,PickingResult]
- """
- rayObject = context.getPickingSegment(frame=self._getScenePrimitive())
- if rayObject is None:
- return None
-
- points = utils.segmentPlaneIntersect(
- rayObject[0, :3],
- rayObject[1, :3],
- planeNorm=numpy.array((0., 0., 1.), dtype=numpy.float64),
- planePt=numpy.array((0., 0., 0.), dtype=numpy.float64))
-
- if len(points) == 1: # Single intersection
- if points[0][0] < 0. or points[0][1] < 0.:
- return None # Outside image
- row, column = int(points[0][1]), int(points[0][0])
- data = self.getData(copy=False)
- height, width = data.shape[:2]
- if row < height and column < width:
- return PickingResult(
- self,
- positions=[(points[0][0], points[0][1], 0.)],
- indices=([row], [column]))
- else:
- return None # Outside image
- else: # Either no intersection or segment and image are coplanar
- return None
-
-
-class ImageData(_Image, ColormapMixIn):
- """Description of a 2D image data.
-
- :param parent: The View widget this item belongs to.
- """
-
- def __init__(self, parent=None):
- _Image.__init__(self, parent=parent)
- ColormapMixIn.__init__(self)
-
- self._data = numpy.zeros((0, 0), dtype=numpy.float32)
-
- self._image = primitives.ImageData(self._data)
- self._getScenePrimitive().children.append(self._image)
-
- # Connect scene primitive to mix-in class
- ColormapMixIn._setSceneColormap(self, self._image.colormap)
- _Image._setPrimitive(self, self._image)
-
- def setData(self, data, copy=True):
- """Set the image data to display.
-
- The data will be casted to float32.
-
- :param numpy.ndarray data: The image data
- :param bool copy: True (default) to copy the data,
- False to use as is (do not modify!).
- """
- self._image.setData(data, copy=copy)
- self._setColormappedData(self.getData(copy=False), copy=False)
- self._updated(ItemChangedType.DATA)
-
- def getData(self, copy=True):
- """Get the image data.
-
- :param bool copy:
- True (default) to get a copy,
- False to get internal representation (do not modify!).
- :rtype: numpy.ndarray
- :return: The image data
- """
- return self._image.getData(copy=copy)
-
-
-class ImageRgba(_Image, InterpolationMixIn):
- """Description of a 2D data RGB(A) image.
-
- :param parent: The View widget this item belongs to.
- """
-
- def __init__(self, parent=None):
- _Image.__init__(self, parent=parent)
- InterpolationMixIn.__init__(self)
-
- self._data = numpy.zeros((0, 0, 3), dtype=numpy.float32)
-
- self._image = primitives.ImageRgba(self._data)
- self._getScenePrimitive().children.append(self._image)
-
- # Connect scene primitive to mix-in class
- _Image._setPrimitive(self, self._image)
-
- def setData(self, data, copy=True):
- """Set the RGB(A) image data to display.
-
- Supported array format: float32 in [0, 1], uint8.
-
- :param numpy.ndarray data:
- The RGBA image data as an array of shape (H, W, Channels)
- :param bool copy: True (default) to copy the data,
- False to use as is (do not modify!).
- """
- self._image.setData(data, copy=copy)
- self._updated(ItemChangedType.DATA)
-
- def getData(self, copy=True):
- """Get the image data.
-
- :param bool copy:
- True (default) to get a copy,
- False to get internal representation (do not modify!).
- :rtype: numpy.ndarray
- :return: The image data
- """
- return self._image.getData(copy=copy)
-
-
-class _HeightMap(DataItem3D):
- """Base class for 2D data array displayed as a height field.
-
- :param parent: The View widget this item belongs to.
- """
-
- def __init__(self, parent=None):
- DataItem3D.__init__(self, parent=parent)
- self.__data = numpy.zeros((0, 0), dtype=numpy.float32)
-
- def _pickFull(self, context, threshold=0., sort='depth'):
- """Perform picking in this item at given widget position.
-
- :param PickContext context: Current picking context
- :param float threshold: Picking threshold in pixel.
- Perform picking in a square of size threshold x threshold.
- :param str sort: How returned indices are sorted:
-
- - 'index' (default): sort by the value of the indices
- - 'depth': Sort by the depth of the points from the current
- camera point of view.
- :return: Object holding the results or None
- :rtype: Union[None,PickingResult]
- """
- assert sort in ('index', 'depth')
-
- rayNdc = context.getPickingSegment(frame='ndc')
- if rayNdc is None: # No picking outside viewport
- return None
-
- # TODO no colormapped or color data
- # Project data to NDC
- heightData = self.getData(copy=False)
- if heightData.size == 0:
- return # Nothing displayed
-
- height, width = heightData.shape
- z = numpy.ravel(heightData)
- y, x = numpy.mgrid[0:height, 0:width]
- dataPoints = numpy.transpose((numpy.ravel(x),
- numpy.ravel(y),
- z,
- numpy.ones_like(z)))
-
- primitive = self._getScenePrimitive()
-
- pointsNdc = primitive.objectToNDCTransform.transformPoints(
- dataPoints, perspectiveDivide=True)
-
- # Perform picking
- distancesNdc = numpy.abs(pointsNdc[:, :2] - rayNdc[0, :2])
- # TODO issue with symbol size: using pixel instead of points
- threshold += 1. # symbol size
- thresholdNdc = 2. * threshold / numpy.array(primitive.viewport.size)
- picked = numpy.where(numpy.logical_and(
- numpy.all(distancesNdc < thresholdNdc, axis=1),
- numpy.logical_and(rayNdc[0, 2] <= pointsNdc[:, 2],
- pointsNdc[:, 2] <= rayNdc[1, 2])))[0]
-
- if sort == 'depth':
- # Sort picked points from front to back
- picked = picked[numpy.argsort(pointsNdc[picked, 2])]
-
- if picked.size > 0:
- # Convert indices from 1D to 2D
- return PickingResult(self,
- positions=dataPoints[picked, :3],
- indices=(picked // width, picked % width),
- fetchdata=self.getData)
- else:
- return None
-
- def setData(self, data, copy: bool=True):
- """Set the height field data.
-
- :param data:
- :param copy: True (default) to copy the data,
- False to use as is (do not modify!).
- """
- data = numpy.array(data, copy=copy)
- assert data.ndim == 2
-
- self.__data = data
- self._updated(ItemChangedType.DATA)
-
- def getData(self, copy: bool=True) -> numpy.ndarray:
- """Get the height field 2D data.
-
- :param bool copy:
- True (default) to get a copy,
- False to get internal representation (do not modify!).
- """
- return numpy.array(self.__data, copy=copy)
-
-
-class HeightMapData(_HeightMap, ColormapMixIn):
- """Description of a 2D height field associated to a colormapped dataset.
-
- :param parent: The View widget this item belongs to.
- """
-
- def __init__(self, parent=None):
- _HeightMap.__init__(self, parent=parent)
- ColormapMixIn.__init__(self)
-
- self.__data = numpy.zeros((0, 0), dtype=numpy.float32)
-
- def _updated(self, event=None):
- if event == ItemChangedType.DATA:
- self.__updateScene()
- super()._updated(event=event)
-
- def __updateScene(self):
- """Update display primitive to use"""
- self._getScenePrimitive().children = [] # Remove previous primitives
- ColormapMixIn._setSceneColormap(self, None)
-
- if not self.isVisible():
- return # Update when visible
-
- data = self.getColormappedData(copy=False)
- heightData = self.getData(copy=False)
-
- if data.size == 0 or heightData.size == 0:
- return # Nothing to display
-
- # Display as a set of points
- height, width = heightData.shape
- # Generates coordinates
- y, x = numpy.mgrid[0:height, 0:width]
-
- if data.shape != heightData.shape: # data and height size miss-match
- # Colormapped data is interpolated (nearest-neighbour) to match the height field
- data = data[numpy.floor(y * data.shape[0] / height).astype(numpy.int),
- numpy.floor(x * data.shape[1] / height).astype(numpy.int)]
-
- x = numpy.ravel(x)
- y = numpy.ravel(y)
-
- primitive = primitives.Points(
- x=x,
- y=y,
- z=numpy.ravel(heightData),
- value=numpy.ravel(data),
- size=1)
- primitive.marker = 's'
- ColormapMixIn._setSceneColormap(self, primitive.colormap)
- self._getScenePrimitive().children = [primitive]
-
- def setColormappedData(self, data, copy: bool=True):
- """Set the 2D data used to compute colors.
-
- :param data: 2D array of data
- :param copy: True (default) to copy the data,
- False to use as is (do not modify!).
- """
- data = numpy.array(data, copy=copy)
- assert data.ndim == 2
-
- self.__data = data
- self._updated(ItemChangedType.DATA)
-
- def getColormappedData(self, copy: bool=True) -> numpy.ndarray:
- """Returns the 2D data used to compute colors.
-
- :param copy:
- True (default) to get a copy,
- False to get internal representation (do not modify!).
- """
- return numpy.array(self.__data, copy=copy)
-
-
-class HeightMapRGBA(_HeightMap):
- """Description of a 2D height field associated to a RGB(A) image.
-
- :param parent: The View widget this item belongs to.
- """
-
- def __init__(self, parent=None):
- _HeightMap.__init__(self, parent=parent)
-
- self.__rgba = numpy.zeros((0, 0, 3), dtype=numpy.float32)
-
- def _updated(self, event=None):
- if event == ItemChangedType.DATA:
- self.__updateScene()
- super()._updated(event=event)
-
- def __updateScene(self):
- """Update display primitive to use"""
- self._getScenePrimitive().children = [] # Remove previous primitives
-
- if not self.isVisible():
- return # Update when visible
-
- rgba = self.getColorData(copy=False)
- heightData = self.getData(copy=False)
- if rgba.size == 0 or heightData.size == 0:
- return # Nothing to display
-
- # Display as a set of points
- height, width = heightData.shape
- # Generates coordinates
- y, x = numpy.mgrid[0:height, 0:width]
-
- if rgba.shape[:2] != heightData.shape: # image and height size miss-match
- # RGBA data is interpolated (nearest-neighbour) to match the height field
- rgba = rgba[numpy.floor(y * rgba.shape[0] / height).astype(numpy.int),
- numpy.floor(x * rgba.shape[1] / height).astype(numpy.int)]
-
- x = numpy.ravel(x)
- y = numpy.ravel(y)
-
- primitive = primitives.ColorPoints(
- x=x,
- y=y,
- z=numpy.ravel(heightData),
- color=rgba.reshape(-1, rgba.shape[-1]),
- size=1)
- primitive.marker = 's'
- self._getScenePrimitive().children = [primitive]
-
- def setColorData(self, data, copy: bool=True):
- """Set the RGB(A) image to use.
-
- Supported array format: float32 in [0, 1], uint8.
-
- :param data:
- The RGBA image data as an array of shape (H, W, Channels)
- :param copy: True (default) to copy the data,
- False to use as is (do not modify!).
- """
- data = numpy.array(data, copy=copy)
- assert data.ndim == 3
- assert data.shape[-1] in (3, 4)
- # TODO check type
-
- self.__rgba = data
- self._updated(ItemChangedType.DATA)
-
- def getColorData(self, copy: bool=True) -> numpy.ndarray:
- """Get the RGB(A) image data.
-
- :param copy: True (default) to get a copy,
- False to get internal representation (do not modify!).
- """
- return numpy.array(self.__rgba, copy=copy)
diff --git a/silx/gui/plot3d/scene/test/__init__.py b/silx/gui/plot3d/scene/test/__init__.py
deleted file mode 100644
index fc4621e..0000000
--- a/silx/gui/plot3d/scene/test/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-from __future__ import absolute_import, division, unicode_literals
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "25/07/2016"
-
-
-import unittest
-
-from .test_transform import suite as test_transform_suite
-from .test_utils import suite as test_utils_suite
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(test_transform_suite())
- testsuite.addTest(test_utils_suite())
- return testsuite
diff --git a/silx/gui/plot3d/scene/test/test_transform.py b/silx/gui/plot3d/scene/test/test_transform.py
deleted file mode 100644
index 9ea0af1..0000000
--- a/silx/gui/plot3d/scene/test/test_transform.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-from __future__ import absolute_import, division, unicode_literals
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "05/01/2017"
-
-
-import numpy
-import unittest
-
-from silx.gui.plot3d.scene import transform
-
-
-class TestTransformList(unittest.TestCase):
-
- def assertSameArrays(self, a, b):
- return self.assertTrue(numpy.allclose(a, b, atol=1e-06))
-
- def testTransformList(self):
- """Minimalistic test of TransformList"""
- transforms = transform.TransformList()
- refmatrix = numpy.identity(4, dtype=numpy.float32)
- self.assertSameArrays(refmatrix, transforms.matrix)
-
- # Append translate
- transforms.append(transform.Translate(1., 1., 1.))
- refmatrix = numpy.array(((1., 0., 0., 1.),
- (0., 1., 0., 1.),
- (0., 0., 1., 1.),
- (0., 0., 0., 1.)), dtype=numpy.float32)
- self.assertSameArrays(refmatrix, transforms.matrix)
-
- # Extend scale
- transforms.extend([transform.Scale(0.1, 2., 1.)])
- refmatrix = numpy.dot(refmatrix,
- numpy.array(((0.1, 0., 0., 0.),
- (0., 2., 0., 0.),
- (0., 0., 1., 0.),
- (0., 0., 0., 1.)),
- dtype=numpy.float32))
- self.assertSameArrays(refmatrix, transforms.matrix)
-
- # Insert rotate
- transforms.insert(0, transform.Rotate(360.))
- self.assertSameArrays(refmatrix, transforms.matrix)
-
- # Update translate and check for listener called
- self._callCount = 0
-
- def listener(source):
- self._callCount += 1
- transforms.addListener(listener)
-
- transforms[1].tx += 1
- self.assertEqual(self._callCount, 1)
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestTransformList))
- return testsuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/scene/test/test_utils.py b/silx/gui/plot3d/scene/test/test_utils.py
deleted file mode 100644
index 4a2d515..0000000
--- a/silx/gui/plot3d/scene/test/test_utils.py
+++ /dev/null
@@ -1,275 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-from __future__ import absolute_import, division, unicode_literals
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import unittest
-from silx.utils.testutils import ParametricTestCase
-
-import numpy
-
-from silx.gui.plot3d.scene import utils
-
-
-# angleBetweenVectors #########################################################
-
-class TestAngleBetweenVectors(ParametricTestCase):
-
- TESTS = { # name: (refvector, vectors, norm, refangles)
- 'single vector':
- ((1., 0., 0.), (1., 0., 0.), (0., 0., 1.), 0.),
- 'single vector, no norm':
- ((1., 0., 0.), (1., 0., 0.), None, 0.),
-
- 'with orthogonal norm':
- ((1., 0., 0.),
- ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
- (0., 0., 1.),
- (0., 90., 180., 270.)),
-
- 'with coplanar norm': # = similar to no norm
- ((1., 0., 0.),
- ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
- (1., 0., 0.),
- (0., 90., 180., 90.)),
-
- 'without norm':
- ((1., 0., 0.),
- ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
- None,
- (0., 90., 180., 90.)),
-
- 'not unit vectors':
- ((2., 2., 0.), ((1., 1., 0.), (1., -1., 0.)), None, (0., 90.)),
- }
-
- def testAngleBetweenVectorsFunction(self):
- for name, params in self.TESTS.items():
- refvector, vectors, norm, refangles = params
- with self.subTest(name):
- refangles = numpy.radians(refangles)
-
- refvector = numpy.array(refvector)
- vectors = numpy.array(vectors)
- if norm is not None:
- norm = numpy.array(norm)
-
- testangles = utils.angleBetweenVectors(
- refvector, vectors, norm)
-
- self.assertTrue(
- numpy.allclose(testangles, refangles, atol=1e-5))
-
-
-# Plane #######################################################################
-
-class AssertNotificationContext(object):
- """Context that checks if an event.Notifier is sending events."""
-
- def __init__(self, notifier, count=1):
- """Initializer.
-
- :param event.Notifier notifier: The notifier to test.
- :param int count: The expected number of calls.
- """
- self._notifier = notifier
- self._callCount = None
- self._count = count
-
- def __enter__(self):
- self._callCount = 0
- self._notifier.addListener(self._callback)
-
- def __exit__(self, exc_type, exc_value, traceback):
- # Do not return True so exceptions are propagated
- self._notifier.removeListener(self._callback)
- assert self._callCount == self._count
- self._callCount = None
-
- def _callback(self, *args, **kwargs):
- self._callCount += 1
-
-
-class TestPlaneParameters(ParametricTestCase):
- """Test Plane.parameters read/write and notifications."""
-
- PARAMETERS = {
- 'unit normal': (1., 0., 0., 1.),
- 'not unit normal': (1., 1., 0., 1.),
- 'd = 0': (1., 0., 0., 0.)
- }
-
- def testParameters(self):
- """Check parameters read/write and notification."""
- plane = utils.Plane()
-
- for name, parameters in self.PARAMETERS.items():
- with self.subTest(name, parameters=parameters):
- with AssertNotificationContext(plane):
- plane.parameters = parameters
-
- # Plane parameters are converted to have a unit normal
- normparams = parameters / numpy.linalg.norm(parameters[:3])
- self.assertTrue(numpy.allclose(plane.parameters, normparams))
-
- ZEROS_PARAMETERS = (
- (0., 0., 0., 0.),
- (0., 0., 0., 1.)
- )
-
- ZEROS = 0., 0., 0., 0.
-
- def testParametersNoPlane(self):
- """Test Plane.parameters with ||normal|| == 0 ."""
- plane = utils.Plane()
- plane.parameters = self.ZEROS
-
- for parameters in self.ZEROS_PARAMETERS:
- with self.subTest(parameters=parameters):
- with AssertNotificationContext(plane, count=0):
- plane.parameters = parameters
- self.assertTrue(
- numpy.allclose(plane.parameters, self.ZEROS, 0., 0.))
-
-
-# unindexArrays ###############################################################
-
-class TestUnindexArrays(ParametricTestCase):
- """Test unindexArrays function."""
-
- def testBasicModes(self):
- """Test for modes: points, lines and triangles"""
- indices = numpy.array((1, 2, 0))
- arrays = (numpy.array((0., 1., 2.)),
- numpy.array(((0, 0), (1, 1), (2, 2))))
- refresults = (numpy.array((1., 2., 0.)),
- numpy.array(((1, 1), (2, 2), (0, 0))))
-
- for mode in ('points', 'lines', 'triangles'):
- with self.subTest(mode=mode):
- testresults = utils.unindexArrays(mode, indices, *arrays)
- for ref, test in zip(refresults, testresults):
- self.assertTrue(numpy.equal(ref, test).all())
-
- def testPackedLines(self):
- """Test for modes: line_strip, loop"""
- indices = numpy.array((1, 2, 0))
- arrays = (numpy.array((0., 1., 2.)),
- numpy.array(((0, 0), (1, 1), (2, 2))))
- results = {
- 'line_strip': (
- numpy.array((1., 2., 2., 0.)),
- numpy.array(((1, 1), (2, 2), (2, 2), (0, 0)))),
- 'loop': (
- numpy.array((1., 2., 2., 0., 0., 1.)),
- numpy.array(((1, 1), (2, 2), (2, 2), (0, 0), (0, 0), (1, 1)))),
- }
-
- for mode, refresults in results.items():
- with self.subTest(mode=mode):
- testresults = utils.unindexArrays(mode, indices, *arrays)
- for ref, test in zip(refresults, testresults):
- self.assertTrue(numpy.equal(ref, test).all())
-
- def testPackedTriangles(self):
- """Test for modes: triangle_strip, fan"""
- indices = numpy.array((1, 2, 0, 3))
- arrays = (numpy.array((0., 1., 2., 3.)),
- numpy.array(((0, 0), (1, 1), (2, 2), (3, 3))))
- results = {
- 'triangle_strip': (
- numpy.array((1., 2., 0., 2., 0., 3.)),
- numpy.array(((1, 1), (2, 2), (0, 0), (2, 2), (0, 0), (3, 3)))),
- 'fan': (
- numpy.array((1., 2., 0., 1., 0., 3.)),
- numpy.array(((1, 1), (2, 2), (0, 0), (1, 1), (0, 0), (3, 3)))),
- }
-
- for mode, refresults in results.items():
- with self.subTest(mode=mode):
- testresults = utils.unindexArrays(mode, indices, *arrays)
- for ref, test in zip(refresults, testresults):
- self.assertTrue(numpy.equal(ref, test).all())
-
- def testBadIndices(self):
- """Test with negative indices and indices higher than array length"""
- arrays = numpy.array((0, 1)), numpy.array((0, 1, 2))
-
- # negative indices
- with self.assertRaises(AssertionError):
- utils.unindexArrays('points', (-1, 0), *arrays)
-
- # Too high indices
- with self.assertRaises(AssertionError):
- utils.unindexArrays('points', (0, 10), *arrays)
-
-
-# triangleNormals #############################################################
-
-class TestTriangleNormals(ParametricTestCase):
- """Test triangleNormals function."""
-
- def test(self):
- """Test for modes: points, lines and triangles"""
- positions = numpy.array(
- ((0., 0., 0.), (1., 0., 0.), (0., 1., 0.), # normal = Z
- (1., 1., 1.), (1., 2., 3.), (4., 5., 6.), # Random triangle
- # Degenerated triangles:
- (0., 0., 0.), (1., 0., 0.), (2., 0., 0.), # Colinear points
- (1., 1., 1.), (1., 1., 1.), (1., 1., 1.), # All same point
- ),
- dtype='float32')
-
- normals = numpy.array(
- ((0., 0., 1.),
- (-0.40824829, 0.81649658, -0.40824829),
- (0., 0., 0.),
- (0., 0., 0.)),
- dtype='float32')
-
- testnormals = utils.trianglesNormal(positions)
- self.assertTrue(numpy.allclose(testnormals, normals))
-
-
-# suite #######################################################################
-
-def suite():
- testsuite = unittest.TestSuite()
- for test in (TestAngleBetweenVectors,
- TestPlaneParameters,
- TestUnindexArrays,
- TestTriangleNormals):
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(test))
- return testsuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/scene/window.py b/silx/gui/plot3d/scene/window.py
deleted file mode 100644
index baa76a2..0000000
--- a/silx/gui/plot3d/scene/window.py
+++ /dev/null
@@ -1,430 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a class for Viewports rendering on the screen.
-
-The :class:`Window` renders a list of Viewports in the current framebuffer.
-The rendering can be performed in an off-screen framebuffer that is only
-updated when the scene has changed and not each time Qt is requiring a repaint.
-
-The :class:`Context` and :class:`ContextGL2` represent the operating system
-OpenGL context and handle OpenGL resources.
-"""
-
-from __future__ import absolute_import, division, unicode_literals
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "10/01/2017"
-
-
-import weakref
-import numpy
-
-from ..._glutils import gl
-from ... import _glutils
-
-from . import event
-
-
-class Context(object):
- """Correspond to an operating system OpenGL context.
-
- User should NEVER use an instance of this class beyond the method
- it is passed to as an argument (i.e., do not keep a reference to it).
-
- :param glContextHandle: System specific OpenGL context handle.
- """
-
- def __init__(self, glContextHandle):
- self._context = glContextHandle
- self._isCurrent = False
- self._devicePixelRatio = 1.0
-
- @property
- def isCurrent(self):
- """Whether this OpenGL context is the current one or not."""
- return self._isCurrent
-
- def setCurrent(self, isCurrent=True):
- """Set the state of the OpenGL context to reflect OpenGL state.
-
- This should not be called from the scene graph, only in the
- wrapper that handle the OpenGL context to reflect its state.
-
- :param bool isCurrent: The state of the system OpenGL context.
- """
- self._isCurrent = bool(isCurrent)
-
- @property
- def devicePixelRatio(self):
- """Ratio between device and device independent pixels (float)
-
- This is useful for font rendering.
- """
- return self._devicePixelRatio
-
- @devicePixelRatio.setter
- def devicePixelRatio(self, ratio):
- assert ratio > 0
- self._devicePixelRatio = float(ratio)
-
- def __enter__(self):
- self.setCurrent(True)
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.setCurrent(False)
-
- @property
- def glContext(self):
- """The handle to the OpenGL context provided by the system."""
- return self._context
-
- def cleanGLGarbage(self):
- """This is releasing OpenGL resource that are no longer used."""
- pass
-
-
-class ContextGL2(Context):
- """Handle a system GL2 context.
-
- User should NEVER use an instance of this class beyond the method
- it is passed to as an argument (i.e., do not keep a reference to it).
-
- :param glContextHandle: System specific OpenGL context handle.
- """
- def __init__(self, glContextHandle):
- super(ContextGL2, self).__init__(glContextHandle)
-
- self._programs = {} # GL programs already compiled
- self._vbos = {} # GL Vbos already set
- self._vboGarbage = [] # Vbos waiting to be discarded
-
- # programs
-
- def prog(self, vertexShaderSrc, fragmentShaderSrc, attrib0='position'):
- """Cache program within context.
-
- WARNING: No clean-up.
-
- :param str vertexShaderSrc: Vertex shader source code
- :param str fragmentShaderSrc: Fragment shader source code
- :param str attrib0:
- Attribute's name to bind to position 0 (default: 'position').
- On some platform, this attribute MUST be active and with an
- array attached to it in order for the rendering to occur....
- """
- assert self.isCurrent
- key = vertexShaderSrc, fragmentShaderSrc, attrib0
- program = self._programs.get(key, None)
- if program is None:
- program = _glutils.Program(
- vertexShaderSrc, fragmentShaderSrc, attrib0=attrib0)
- self._programs[key] = program
- return program
-
- # VBOs
-
- def makeVbo(self, data=None, sizeInBytes=None,
- usage=None, target=None):
- """Create a VBO in this context with the data.
-
- Current limitations:
-
- - One array per VBO
- - Do not support sharing VertexBuffer across VboAttrib
-
- Automatically discards the VBO when the returned
- :class:`VertexBuffer` istance is deleted.
-
- :param numpy.ndarray data: 2D array of data to store in VBO or None.
- :param int sizeInBytes: Size of the VBO or None.
- It should be <= data.nbytes if both are given.
- :param usage: OpenGL usage define in VertexBuffer._USAGES.
- :param target: OpenGL target in VertexBuffer._TARGETS.
- :return: The VertexBuffer created in this context.
- """
- assert self.isCurrent
- vbo = _glutils.VertexBuffer(data, sizeInBytes, usage, target)
- vboref = weakref.ref(vbo, self._deadVbo)
- # weakref is hashable as far as target is
- self._vbos[vboref] = vbo.name
- return vbo
-
- def makeVboAttrib(self, data, usage=None, target=None):
- """Create a VBO from data and returns the associated VBOAttrib.
-
- Automatically discards the VBO when the returned
- :class:`VBOAttrib` istance is deleted.
-
- :param numpy.ndarray data: 2D array of data to store in VBO or None.
- :param usage: OpenGL usage define in VertexBuffer._USAGES.
- :param target: OpenGL target in VertexBuffer._TARGETS.
- :returns: A VBOAttrib instance created in this context.
- """
- assert self.isCurrent
- vbo = self.makeVbo(data, usage=usage, target=target)
-
- assert len(data.shape) <= 2
- dimension = 1 if len(data.shape) == 1 else data.shape[1]
-
- return _glutils.VertexBufferAttrib(
- vbo,
- type_=_glutils.numpyToGLType(data.dtype),
- size=data.shape[0],
- dimension=dimension,
- offset=0,
- stride=0)
-
- def _deadVbo(self, vboRef):
- """Callback handling dead VBOAttribs."""
- vboid = self._vbos.pop(vboRef)
- if self.isCurrent:
- # Direct delete if context is active
- gl.glDeleteBuffers(vboid)
- else:
- # Deferred VBO delete if context is not active
- self._vboGarbage.append(vboid)
-
- def cleanGLGarbage(self):
- """Delete OpenGL resources that are pending for destruction.
-
- This requires the associated OpenGL context to be active.
- This is meant to be called before rendering.
- """
- assert self.isCurrent
- if self._vboGarbage:
- vboids = self._vboGarbage
- gl.glDeleteBuffers(vboids)
- self._vboGarbage = []
-
-
-class Window(event.Notifier):
- """OpenGL Framebuffer where to render viewports
-
- :param str mode: Rendering mode to use:
-
- - 'direct' to render everything for each render call
- - 'framebuffer' to cache viewport rendering in a texture and
- update the texture only when needed.
- """
-
- _position = numpy.array(((-1., -1., 0., 0.),
- (1., -1., 1., 0.),
- (-1., 1., 0., 1.),
- (1., 1., 1., 1.)),
- dtype=numpy.float32)
-
- _shaders = ("""
- attribute vec4 position;
- varying vec2 textureCoord;
-
- void main(void) {
- gl_Position = vec4(position.x, position.y, 0., 1.);
- textureCoord = position.zw;
- }
- """,
- """
- uniform sampler2D texture;
- varying vec2 textureCoord;
-
- void main(void) {
- gl_FragColor = texture2D(texture, textureCoord);
- gl_FragColor.a = 1.0;
- }
- """)
-
- def __init__(self, mode='framebuffer'):
- super(Window, self).__init__()
- self._dirty = True
- self._size = 0, 0
- self._contexts = {} # To map system GL context id to Context objects
- self._viewports = event.NotifierList()
- self._viewports.addListener(self._updated)
- self._framebufferid = 0
- self._framebuffers = {} # Cache of framebuffers
-
- assert mode in ('direct', 'framebuffer')
- self._isframebuffer = mode == 'framebuffer'
-
- @property
- def dirty(self):
- """True if this object or any attached viewports is dirty."""
- for viewport in self._viewports:
- if viewport.dirty:
- return True
- return self._dirty
-
- @property
- def size(self):
- """Size (width, height) of the window in pixels"""
- return self._size
-
- @size.setter
- def size(self, size):
- w, h = size
- size = int(w), int(h)
- if size != self._size:
- self._size = size
- self._dirty = True
- self.notify()
-
- @property
- def shape(self):
- """Shape (height, width) of the window in pixels.
-
- This is a convenient wrapper to the reverse of size.
- """
- return self._size[1], self._size[0]
-
- @shape.setter
- def shape(self, shape):
- self.size = shape[1], shape[0]
-
- @property
- def viewports(self):
- """List of viewports to render in the corresponding framebuffer"""
- return self._viewports
-
- @viewports.setter
- def viewports(self, iterable):
- self._viewports.removeListener(self._updated)
- self._viewports = event.NotifierList(iterable)
- self._viewports.addListener(self._updated)
- self._updated(self)
-
- def _updated(self, source, *args, **kwargs):
- self._dirty = True
- self.notify(*args, **kwargs)
-
- framebufferid = property(lambda self: self._framebufferid,
- doc="Framebuffer ID used to perform rendering")
-
- def grab(self, glcontext):
- """Returns the raster of the scene as an RGB numpy array
-
- :returns: OpenGL scene RGB bitmap
- as an array of dimension (height, width, 3)
- :rtype: numpy.ndarray of uint8
- """
- height, width = self.shape
- image = numpy.empty((height, width, 3), dtype=numpy.uint8)
-
- previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
- gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.framebufferid)
- gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
- gl.glReadPixels(
- 0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, image)
- gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer)
-
- # glReadPixels gives bottom to top,
- # while images are stored as top to bottom
- image = numpy.flipud(image)
-
- return numpy.array(image, copy=False, order='C')
-
- def render(self, glcontext, devicePixelRatio):
- """Perform the rendering of attached viewports
-
- :param glcontext: System identifier of the OpenGL context
- :param float devicePixelRatio:
- Ratio between device and device-independent pixels
- """
- if glcontext not in self._contexts:
- self._contexts[glcontext] = ContextGL2(glcontext) # New context
-
- with self._contexts[glcontext] as context:
- context.devicePixelRatio = devicePixelRatio
- if self._isframebuffer:
- self._renderWithOffscreenFramebuffer(context)
- else:
- self._renderDirect(context)
-
- self._dirty = False
-
- def _renderDirect(self, context):
- """Perform the direct rendering of attached viewports
-
- :param Context context: Object wrapping OpenGL context
- """
- for viewport in self._viewports:
- viewport.framebuffer = self.framebufferid
- viewport.render(context)
- viewport.resetDirty()
-
- def _renderWithOffscreenFramebuffer(self, context):
- """Renders viewports in a texture and render this texture on screen.
-
- The texture is updated only if viewport or size has changed.
-
- :param ContextGL2 context: Object wrappign OpenGL context
- """
- if self.dirty or context not in self._framebuffers:
- # Need to redraw framebuffer content
-
- if (context not in self._framebuffers or
- self._framebuffers[context].shape != self.shape):
- # Need to rebuild framebuffer
-
- if context in self._framebuffers:
- self._framebuffers[context].discard()
-
- fbo = _glutils.FramebufferTexture(gl.GL_RGBA,
- shape=self.shape,
- minFilter=gl.GL_NEAREST,
- magFilter=gl.GL_NEAREST,
- wrap=gl.GL_CLAMP_TO_EDGE)
- self._framebuffers[context] = fbo
- self._framebufferid = fbo.name
-
- # Render in framebuffer
- with self._framebuffers[context]:
- self._renderDirect(context)
-
- # Render framebuffer texture to screen
- fbo = self._framebuffers[context]
- height, width = fbo.shape
-
- program = context.prog(*self._shaders)
- program.use()
-
- gl.glViewport(0, 0, width, height)
- gl.glDisable(gl.GL_BLEND)
- gl.glDisable(gl.GL_DEPTH_TEST)
- gl.glDisable(gl.GL_SCISSOR_TEST)
- # gl.glScissor(0, 0, width, height)
- gl.glClearColor(0., 0., 0., 0.)
- gl.glClear(gl.GL_COLOR_BUFFER_BIT)
- gl.glUniform1i(program.uniforms['texture'], fbo.texture.texUnit)
- gl.glEnableVertexAttribArray(program.attributes['position'])
- gl.glVertexAttribPointer(program.attributes['position'],
- 4,
- gl.GL_FLOAT,
- gl.GL_FALSE,
- 0,
- self._position)
- fbo.texture.bind()
- gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._position))
- gl.glBindTexture(gl.GL_TEXTURE_2D, 0)
diff --git a/silx/gui/plot3d/test/__init__.py b/silx/gui/plot3d/test/__init__.py
deleted file mode 100644
index 77172d1..0000000
--- a/silx/gui/plot3d/test/__init__.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""plot3d test suite."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "09/11/2017"
-
-
-import logging
-import unittest
-from silx.test.utils import test_options
-
-
-_logger = logging.getLogger(__name__)
-
-
-def suite():
- testsuite = unittest.TestSuite()
-
- if not test_options.WITH_GL_TEST:
- # Explicitly disabled tests
- msg = "silx.gui.plot3d tests disabled: %s" % test_options.WITH_GL_TEST_REASON
- _logger.warning(msg)
-
- class SkipPlot3DTest(unittest.TestCase):
- def runTest(self):
- self.skipTest(test_options.WITH_GL_TEST_REASON)
-
- testsuite.addTest(SkipPlot3DTest())
- return testsuite
-
- # Import here to avoid loading modules if tests are disabled
-
- from ..scene.test import suite as sceneTestSuite
- from ..tools.test import suite as toolsTestSuite
- from .testGL import suite as testGLSuite
- from .testScalarFieldView import suite as testScalarFieldViewSuite
- from .testSceneWidget import suite as testSceneWidgetSuite
- from .testSceneWidgetPicking import suite as testSceneWidgetPickingSuite
- from .testSceneWindow import suite as testSceneWindowSuite
- from .testStatsWidget import suite as testStatsWidgetSuite
-
- testsuite = unittest.TestSuite()
- testsuite.addTest(testGLSuite())
- testsuite.addTest(sceneTestSuite())
- testsuite.addTest(testScalarFieldViewSuite())
- testsuite.addTest(testSceneWidgetSuite())
- testsuite.addTest(testSceneWidgetPickingSuite())
- testsuite.addTest(testSceneWindowSuite())
- testsuite.addTest(toolsTestSuite())
- testsuite.addTest(testStatsWidgetSuite())
- return testsuite
diff --git a/silx/gui/plot3d/test/testGL.py b/silx/gui/plot3d/test/testGL.py
deleted file mode 100644
index ae167ab..0000000
--- a/silx/gui/plot3d/test/testGL.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-# ###########################################################################*/
-"""Test OpenGL"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "10/08/2017"
-
-
-import logging
-import unittest
-
-from silx.gui._glutils import gl, OpenGLWidget
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import qt
-
-
-_logger = logging.getLogger(__name__)
-
-
-class TestOpenGL(TestCaseQt):
- """Tests of OpenGL widget."""
-
- class OpenGLWidgetLogger(OpenGLWidget):
- """Widget logging information of available OpenGL version"""
-
- def __init__(self):
- self._dump = False
- super(TestOpenGL.OpenGLWidgetLogger, self).__init__(version=(1, 0))
-
- def paintOpenGL(self):
- """Perform the rendering and logging"""
- if not self._dump:
- self._dump = True
- _logger.info('OpenGL info:')
- _logger.info('\tQt OpenGL context version: %d.%d', *self.getOpenGLVersion())
- _logger.info('\tGL_VERSION: %s' % gl.glGetString(gl.GL_VERSION))
- _logger.info('\tGL_SHADING_LANGUAGE_VERSION: %s' %
- gl.glGetString(gl.GL_SHADING_LANGUAGE_VERSION))
- _logger.debug('\tGL_EXTENSIONS: %s' % gl.glGetString(gl.GL_EXTENSIONS))
-
- gl.glClearColor(1., 1., 1., 1.)
- gl.glClear(gl.GL_COLOR_BUFFER_BIT)
-
- def testOpenGL(self):
- """Log OpenGL version using an OpenGLWidget"""
- super(TestOpenGL, self).setUp()
- widget = self.OpenGLWidgetLogger()
- widget.show()
- widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.qWaitForWindowExposed(widget)
- widget.close()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestOpenGL))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/test/testScalarFieldView.py b/silx/gui/plot3d/test/testScalarFieldView.py
deleted file mode 100644
index d9c743b..0000000
--- a/silx/gui/plot3d/test/testScalarFieldView.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-# ###########################################################################*/
-"""Test ScalarFieldView widget"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import logging
-import unittest
-
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import qt
-
-from silx.gui.plot3d.ScalarFieldView import ScalarFieldView
-from silx.gui.plot3d.SFViewParamTree import TreeView
-
-
-_logger = logging.getLogger(__name__)
-
-
-class TestScalarFieldView(TestCaseQt, ParametricTestCase):
- """Tests of ScalarFieldView widget."""
-
- def setUp(self):
- super(TestScalarFieldView, self).setUp()
- self.widget = ScalarFieldView()
- self.widget.show()
-
- paramTreeWidget = TreeView()
- paramTreeWidget.setSfView(self.widget)
-
- dock = qt.QDockWidget()
- dock.setWidget(paramTreeWidget)
- self.widget.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
-
- # Commented as it slows down the tests
- # self.qWaitForWindowExposed(self.widget)
-
- def tearDown(self):
- self.qapp.processEvents()
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- del self.widget
- super(TestScalarFieldView, self).tearDown()
-
- @staticmethod
- def _buildData(size):
- """Make a 3D dataset"""
- coords = numpy.linspace(-10, 10, size)
- z = coords.reshape(-1, 1, 1)
- y = coords.reshape(1, -1, 1)
- x = coords.reshape(1, 1, -1)
- return numpy.sin(x * y * z) / (x * y * z)
-
- def testSimple(self):
- """Set the data and an isosurface"""
- data = self._buildData(size=32)
-
- self.widget.setData(data)
- self.widget.addIsosurface(0.5, (1., 0., 0., 0.5))
- self.widget.addIsosurface(0.7, qt.QColor('green'))
- self.qapp.processEvents()
-
- def testNotFinite(self):
- """Test with NaN and inf in data set"""
-
- # Some NaNs and inf
- data = self._buildData(size=32)
- data[8, :, :] = numpy.nan
- data[16, :, :] = numpy.inf
- data[24, :, :] = - numpy.inf
-
- self.widget.addIsosurface(0.5, 'red')
- self.widget.setData(data, copy=True)
- self.qapp.processEvents()
- self.widget.setData(None)
-
- # All NaNs or inf
- data = numpy.empty((4, 4, 4), dtype=numpy.float32)
- for value in (numpy.nan, numpy.inf):
- with self.subTest(value=str(value)):
- data[:] = value
- self.widget.setData(data, copy=True)
- self.qapp.processEvents()
-
- def testIsoSliderNormalization(self):
- """Test set TreeView with a different isoslider normalization"""
- data = self._buildData(size=32)
-
- self.widget.setData(data)
- self.widget.addIsosurface(0.5, (1., 0., 0., 0.5))
- self.widget.addIsosurface(0.7, qt.QColor('green'))
- self.qapp.processEvents()
-
- # Add a second TreeView
- paramTreeWidget = TreeView(self.widget)
- paramTreeWidget.setIsoLevelSliderNormalization('arcsinh')
- paramTreeWidget.setSfView(self.widget)
-
- dock = qt.QDockWidget()
- dock.setWidget(paramTreeWidget)
- self.widget.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestScalarFieldView))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/test/testSceneWidget.py b/silx/gui/plot3d/test/testSceneWidget.py
deleted file mode 100644
index 13ddd37..0000000
--- a/silx/gui/plot3d/test/testSceneWidget.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-# ###########################################################################*/
-"""Test SceneWidget"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "06/03/2019"
-
-
-import unittest
-
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import qt
-
-from silx.gui.plot3d.SceneWidget import SceneWidget
-
-
-class TestSceneWidget(TestCaseQt, ParametricTestCase):
- """Tests SceneWidget picking feature"""
-
- def setUp(self):
- super(TestSceneWidget, self).setUp()
- self.widget = SceneWidget()
- self.widget.show()
- self.qWaitForWindowExposed(self.widget)
-
- def tearDown(self):
- self.qapp.processEvents()
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- del self.widget
- super(TestSceneWidget, self).tearDown()
-
- def testFogEffect(self):
- """Test fog effect on scene primitive"""
- image = self.widget.addImage(numpy.arange(100).reshape(10, 10))
- scatter = self.widget.add3DScatter(*numpy.random.random(4000).reshape(4, -1))
- scatter.setTranslation(10, 10)
- scatter.setScale(10, 10, 10)
-
- self.widget.resetZoom('front')
- self.qapp.processEvents()
-
- self.widget.setFogMode(self.widget.FogMode.LINEAR)
- self.qapp.processEvents()
-
- self.widget.setFogMode(self.widget.FogMode.NONE)
- self.qapp.processEvents()
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(
- TestSceneWidget))
- return testsuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/test/testSceneWidgetPicking.py b/silx/gui/plot3d/test/testSceneWidgetPicking.py
deleted file mode 100644
index aea30f6..0000000
--- a/silx/gui/plot3d/test/testSceneWidgetPicking.py
+++ /dev/null
@@ -1,326 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-# ###########################################################################*/
-"""Test SceneWidget picking feature"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "03/10/2018"
-
-
-import unittest
-
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import qt
-
-from silx.gui.plot3d.SceneWidget import SceneWidget, items
-
-
-class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
- """Tests SceneWidget picking feature"""
-
- def setUp(self):
- super(TestSceneWidgetPicking, self).setUp()
- self.widget = SceneWidget()
- self.widget.resize(300, 300)
- self.widget.show()
- # self.qWaitForWindowExposed(self.widget)
-
- def tearDown(self):
- self.qapp.processEvents()
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- del self.widget
- super(TestSceneWidgetPicking, self).tearDown()
-
- def _widgetCenter(self):
- """Returns widget center"""
- size = self.widget.size()
- return size.width() // 2, size.height() // 2
-
- def testPickImage(self):
- """Test picking of ImageData and ImageRgba items"""
- imageData = items.ImageData()
- imageData.setData(numpy.arange(100).reshape(10, 10))
-
- imageRgba = items.ImageRgba()
- imageRgba.setData(
- numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3))
-
- for item in (imageData, imageRgba):
- with self.subTest(item=item.__class__.__name__):
- # Add item
- self.widget.clearItems()
- self.widget.addItem(item)
- self.widget.resetZoom('front')
- self.qapp.processEvents()
-
- # Picking on data (at widget center)
- picking = list(self.widget.pickItems(*self._widgetCenter()))
-
- self.assertEqual(len(picking), 1)
- self.assertIs(picking[0].getItem(), item)
- self.assertEqual(picking[0].getPositions('ndc').shape, (1, 3))
- data = picking[0].getData()
- self.assertEqual(len(data), 1)
- self.assertTrue(numpy.array_equal(
- data,
- item.getData()[picking[0].getIndices()]))
-
- # Picking outside data
- picking = list(self.widget.pickItems(1, 1))
- self.assertEqual(len(picking), 0)
-
- def testPickScatter(self):
- """Test picking of Scatter2D and Scatter3D items"""
- data = numpy.arange(100)
-
- scatter2d = items.Scatter2D()
- scatter2d.setData(x=data, y=data, value=data)
-
- scatter3d = items.Scatter3D()
- scatter3d.setData(x=data, y=data, z=data, value=data)
-
- for item in (scatter2d, scatter3d):
- with self.subTest(item=item.__class__.__name__):
- # Add item
- self.widget.clearItems()
- self.widget.addItem(item)
- self.widget.resetZoom('front')
- self.qapp.processEvents()
-
- # Picking on data (at widget center)
- picking = list(self.widget.pickItems(*self._widgetCenter()))
-
- self.assertEqual(len(picking), 1)
- self.assertIs(picking[0].getItem(), item)
- nbPos = len(picking[0].getPositions('ndc'))
- data = picking[0].getData()
- self.assertEqual(nbPos, len(data))
- self.assertTrue(numpy.array_equal(
- data,
- item.getValueData()[picking[0].getIndices()]))
-
- # Picking outside data
- picking = list(self.widget.pickItems(1, 1))
- self.assertEqual(len(picking), 0)
-
- def testPickVolume(self):
- """Test picking of volume CutPlane and Isosurface items"""
- for dtype in (numpy.float32, numpy.complex64):
- with self.subTest(dtype=dtype):
- refData = numpy.arange(10**3, dtype=dtype).reshape(10, 10, 10)
- volume = self.widget.addVolume(refData)
- if dtype == numpy.complex64:
- volume.setComplexMode(volume.ComplexMode.REAL)
- refData = numpy.real(refData)
- self.widget.resetZoom('front')
-
- cutplane = volume.getCutPlanes()[0]
- if dtype == numpy.complex64:
- cutplane.setComplexMode(volume.ComplexMode.REAL)
- cutplane.getColormap().setVRange(0, 100)
- cutplane.setNormal((0, 0, 1))
-
- # Picking on data without anything displayed
- cutplane.setVisible(False)
- picking = list(self.widget.pickItems(*self._widgetCenter()))
- self.assertEqual(len(picking), 0)
-
- # Picking on data with the cut plane
- cutplane.setVisible(True)
- picking = list(self.widget.pickItems(*self._widgetCenter()))
-
- self.assertEqual(len(picking), 1)
- self.assertIs(picking[0].getItem(), cutplane)
- data = picking[0].getData()
- self.assertEqual(len(data), 1)
- self.assertEqual(picking[0].getPositions().shape, (1, 3))
- self.assertTrue(numpy.array_equal(
- data,
- refData[picking[0].getIndices()]))
-
- # Picking on data with an isosurface
- isosurface = volume.addIsosurface(
- level=500, color=(1., 0., 0., .5))
- picking = list(self.widget.pickItems(*self._widgetCenter()))
- self.assertEqual(len(picking), 2)
- self.assertIs(picking[0].getItem(), cutplane)
- self.assertIs(picking[1].getItem(), isosurface)
- self.assertEqual(picking[1].getPositions().shape, (1, 3))
- data = picking[1].getData()
- self.assertEqual(len(data), 1)
- self.assertTrue(numpy.array_equal(
- data,
- refData[picking[1].getIndices()]))
-
- # Picking outside data
- picking = list(self.widget.pickItems(1, 1))
- self.assertEqual(len(picking), 0)
-
- self.widget.clearItems()
-
- def testPickMesh(self):
- """Test picking of Mesh items"""
-
- triangles = items.Mesh()
- triangles.setData(
- position=((0, 0, 0), (1, 0, 0), (1, 1, 0),
- (0, 0, 0), (1, 1, 0), (0, 1, 0)),
- color=(1, 0, 0, 1),
- mode='triangles')
- triangleStrip = items.Mesh()
- triangleStrip.setData(
- position=(((1, 0, 0), (0, 0, 0), (1, 1, 0), (0, 1, 0))),
- color=(0, 1, 0, 1),
- mode='triangle_strip')
- triangleFan = items.Mesh()
- triangleFan.setData(
- position=((0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0)),
- color=(0, 0, 1, 1),
- mode='fan')
-
- for item in (triangles, triangleStrip, triangleFan):
- with self.subTest(mode=item.getDrawMode()):
- # Add item
- self.widget.clearItems()
- self.widget.addItem(item)
- self.widget.resetZoom('front')
- self.qapp.processEvents()
-
- # Picking on data (at widget center)
- picking = list(self.widget.pickItems(*self._widgetCenter()))
-
- self.assertEqual(len(picking), 1)
- self.assertIs(picking[0].getItem(), item)
- nbPos = len(picking[0].getPositions())
- data = picking[0].getData()
- self.assertEqual(nbPos, len(data))
- self.assertTrue(numpy.array_equal(
- data,
- item.getPositionData()[picking[0].getIndices()]))
-
- # Picking outside data
- picking = list(self.widget.pickItems(1, 1))
- self.assertEqual(len(picking), 0)
-
- def testPickMeshWithIndices(self):
- """Test picking of Mesh items defined by indices"""
-
- triangles = items.Mesh()
- triangles.setData(
- position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
- color=(1, 0, 0, 1),
- indices=numpy.array( # dummy triangles and square
- (0, 0, 1, 0, 1, 2, 1, 2, 3), dtype=numpy.uint8),
- mode='triangles')
- triangleStrip = items.Mesh()
- triangleStrip.setData(
- position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
- color=(0, 1, 0, 1),
- indices=numpy.array( # dummy triangles and square
- (1, 0, 0, 1, 2, 3), dtype=numpy.uint8),
- mode='triangle_strip')
- triangleFan = items.Mesh()
- triangleFan.setData(
- position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
- color=(0, 0, 1, 1),
- indices=numpy.array( # dummy triangle, square, dummy
- (1, 1, 0, 2, 3, 3), dtype=numpy.uint8),
- mode='fan')
-
- for item in (triangles, triangleStrip, triangleFan):
- with self.subTest(mode=item.getDrawMode()):
- # Add item
- self.widget.clearItems()
- self.widget.addItem(item)
- self.widget.resetZoom('front')
- self.qapp.processEvents()
-
- # Picking on data (at widget center)
- picking = list(self.widget.pickItems(*self._widgetCenter()))
-
- self.assertEqual(len(picking), 1)
- self.assertIs(picking[0].getItem(), item)
- nbPos = len(picking[0].getPositions())
- data = picking[0].getData()
- self.assertEqual(nbPos, len(data))
- self.assertTrue(numpy.array_equal(
- data,
- item.getPositionData()[picking[0].getIndices()]))
-
- # Picking outside data
- picking = list(self.widget.pickItems(1, 1))
- self.assertEqual(len(picking), 0)
-
- def testPickCylindricalMesh(self):
- """Test picking of Box, Cylinder and Hexagon items"""
-
- positions = numpy.array(((0., 0., 0.), (1., 1., 0.), (2., 2., 0.)))
- box = items.Box()
- box.setData(position=positions)
- cylinder = items.Cylinder()
- cylinder.setData(position=positions)
- hexagon = items.Hexagon()
- hexagon.setData(position=positions)
-
- for item in (box, cylinder, hexagon):
- with self.subTest(item=item.__class__.__name__):
- # Add item
- self.widget.clearItems()
- self.widget.addItem(item)
- self.widget.resetZoom('front')
- self.qapp.processEvents()
-
- # Picking on data (at widget center)
- picking = list(self.widget.pickItems(*self._widgetCenter()))
-
- self.assertEqual(len(picking), 1)
- self.assertIs(picking[0].getItem(), item)
- nbPos = len(picking[0].getPositions())
- data = picking[0].getData()
- print(item.__class__.__name__, [positions[1]], data)
- self.assertTrue(numpy.all(numpy.equal(positions[1], data)))
- self.assertEqual(nbPos, len(data))
- self.assertTrue(numpy.array_equal(
- data,
- item.getPosition()[picking[0].getIndices()]))
-
- # Picking outside data
- picking = list(self.widget.pickItems(1, 1))
- self.assertEqual(len(picking), 0)
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(
- TestSceneWidgetPicking))
- return testsuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/test/testSceneWindow.py b/silx/gui/plot3d/test/testSceneWindow.py
deleted file mode 100644
index 8cf6b81..0000000
--- a/silx/gui/plot3d/test/testSceneWindow.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2019-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-# ###########################################################################*/
-"""Test SceneWindow"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "22/03/2019"
-
-
-import unittest
-
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import qt
-
-from silx.gui.plot3d.SceneWindow import SceneWindow
-from silx.gui.plot3d.items import HeightMapData, HeightMapRGBA
-
-class TestSceneWindow(TestCaseQt, ParametricTestCase):
- """Tests SceneWidget picking feature"""
-
- def setUp(self):
- super(TestSceneWindow, self).setUp()
- self.window = SceneWindow()
- self.window.show()
- self.qWaitForWindowExposed(self.window)
-
- def tearDown(self):
- self.qapp.processEvents()
- self.window.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.window.close()
- del self.window
- super(TestSceneWindow, self).tearDown()
-
- def testAdd(self):
- """Test add basic scene primitive"""
- sceneWidget = self.window.getSceneWidget()
- items = []
-
- # RGB image
- image = sceneWidget.addImage(numpy.random.random(
- 10*10*3).astype(numpy.float32).reshape(10, 10, 3))
- image.setLabel('RGB image')
- items.append(image)
- self.assertEqual(sceneWidget.getItems(), tuple(items))
-
- # Data image
- image = sceneWidget.addImage(
- numpy.arange(100, dtype=numpy.float32).reshape(10, 10))
- image.setTranslation(10.)
- items.append(image)
- self.assertEqual(sceneWidget.getItems(), tuple(items))
-
- # 2D scatter
- scatter = sceneWidget.add2DScatter(
- *numpy.random.random(3000).astype(numpy.float32).reshape(3, -1),
- index=0)
- scatter.setTranslation(0, 10)
- scatter.setScale(10, 10, 10)
- items.insert(0, scatter)
- self.assertEqual(sceneWidget.getItems(), tuple(items))
-
- # 3D scatter
- scatter = sceneWidget.add3DScatter(
- *numpy.random.random(4000).astype(numpy.float32).reshape(4, -1))
- scatter.setTranslation(10, 10)
- scatter.setScale(10, 10, 10)
- items.append(scatter)
- self.assertEqual(sceneWidget.getItems(), tuple(items))
-
- # 3D array of float
- volume = sceneWidget.addVolume(
- numpy.arange(10**3, dtype=numpy.float32).reshape(10, 10, 10))
- volume.setTranslation(0, 0, 10)
- volume.setRotation(45, (0, 0, 1))
- volume.addIsosurface(500, 'red')
- volume.getCutPlanes()[0].getColormap().setName('viridis')
- items.append(volume)
- self.assertEqual(sceneWidget.getItems(), tuple(items))
-
- # 3D array of complex
- volume = sceneWidget.addVolume(
- numpy.arange(10**3).reshape(10, 10, 10).astype(numpy.complex64))
- volume.setTranslation(10, 0, 10)
- volume.setRotation(45, (0, 0, 1))
- volume.setComplexMode(volume.ComplexMode.REAL)
- volume.addIsosurface(500, (1., 0., 0., .5))
- items.append(volume)
- self.assertEqual(sceneWidget.getItems(), tuple(items))
-
- sceneWidget.resetZoom('front')
- self.qapp.processEvents()
-
- def testHeightMap(self):
- """Test height map items"""
- sceneWidget = self.window.getSceneWidget()
-
- height = numpy.arange(10000).reshape(100, 100) /100.
-
- for shape in ((100, 100), (4, 5), (150, 20), (110, 110)):
- with self.subTest(shape=shape):
- items = []
-
- # Colormapped data height map
- data = numpy.arange(numpy.prod(shape)).astype(numpy.float32).reshape(shape)
-
- heightmap = HeightMapData()
- heightmap.setData(height)
- heightmap.setColormappedData(data)
- heightmap.getColormap().setName('viridis')
- items.append(heightmap)
- sceneWidget.addItem(heightmap)
-
- # RGBA height map
- colors = numpy.zeros(shape + (3,), dtype=numpy.float32)
- colors[:, :, 1] = numpy.random.random(shape)
-
- heightmap = HeightMapRGBA()
- heightmap.setData(height)
- heightmap.setColorData(colors)
- heightmap.setTranslation(100., 0., 0.)
- items.append(heightmap)
- sceneWidget.addItem(heightmap)
-
- self.assertEqual(sceneWidget.getItems(), tuple(items))
- sceneWidget.resetZoom('front')
- self.qapp.processEvents()
- sceneWidget.clearItems()
-
- def testChangeContent(self):
- """Test add/remove/clear items"""
- sceneWidget = self.window.getSceneWidget()
- items = []
-
- # Add 2 images
- image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10)
- items.append(sceneWidget.addImage(image))
- items.append(sceneWidget.addImage(image))
- self.qapp.processEvents()
- self.assertEqual(sceneWidget.getItems(), tuple(items))
-
- # Clear
- sceneWidget.clearItems()
- self.qapp.processEvents()
- self.assertEqual(sceneWidget.getItems(), ())
-
- # Add 2 images and remove first one
- image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10)
- sceneWidget.addImage(image)
- items = (sceneWidget.addImage(image),)
- self.qapp.processEvents()
-
- sceneWidget.removeItem(sceneWidget.getItems()[0])
- self.qapp.processEvents()
- self.assertEqual(sceneWidget.getItems(), items)
-
- def testColors(self):
- """Test setting scene colors"""
- sceneWidget = self.window.getSceneWidget()
-
- color = qt.QColor(128, 128, 128)
- sceneWidget.setBackgroundColor(color)
- self.assertEqual(sceneWidget.getBackgroundColor(), color)
-
- color = qt.QColor(0, 0, 0)
- sceneWidget.setForegroundColor(color)
- self.assertEqual(sceneWidget.getForegroundColor(), color)
-
- color = qt.QColor(255, 0, 0)
- sceneWidget.setTextColor(color)
- self.assertEqual(sceneWidget.getTextColor(), color)
-
- color = qt.QColor(0, 255, 0)
- sceneWidget.setHighlightColor(color)
- self.assertEqual(sceneWidget.getHighlightColor(), color)
-
- self.qapp.processEvents()
-
- def testInteractiveMode(self):
- """Test changing interactive mode"""
- sceneWidget = self.window.getSceneWidget()
- center = numpy.array((sceneWidget.width() //2, sceneWidget.height() // 2))
-
- self.mouseMove(sceneWidget, pos=center)
- self.mouseClick(sceneWidget, qt.Qt.LeftButton, pos=center)
-
- volume = sceneWidget.addVolume(
- numpy.arange(10**3).astype(numpy.float32).reshape(10, 10, 10))
- sceneWidget.selection().setCurrentItem( volume.getCutPlanes()[0])
- sceneWidget.resetZoom('side')
-
- for mode in (None, 'rotate', 'pan', 'panSelectedPlane'):
- with self.subTest(mode=mode):
- sceneWidget.setInteractiveMode(mode)
- self.qapp.processEvents()
- self.assertEqual(sceneWidget.getInteractiveMode(), mode)
-
- self.mouseMove(sceneWidget, pos=center)
- self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center)
- self.mouseMove(sceneWidget, pos=center-10)
- self.mouseMove(sceneWidget, pos=center-20)
- self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20)
-
- self.keyPress(sceneWidget, qt.Qt.Key_Control)
- self.mouseMove(sceneWidget, pos=center)
- self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center)
- self.mouseMove(sceneWidget, pos=center-10)
- self.mouseMove(sceneWidget, pos=center-20)
- self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20)
- self.keyRelease(sceneWidget, qt.Qt.Key_Control)
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(
- TestSceneWindow))
- return testsuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/test/testStatsWidget.py b/silx/gui/plot3d/test/testStatsWidget.py
deleted file mode 100644
index bcab1a4..0000000
--- a/silx/gui/plot3d/test/testStatsWidget.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-# ###########################################################################*/
-"""Test silx.gui.plot.StatsWidget with SceneWidget and ScalarFieldView"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "25/01/2019"
-
-
-import unittest
-
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.plot.stats.stats import Stats
-from silx.gui import qt
-
-from silx.gui.plot.StatsWidget import BasicStatsWidget
-
-from silx.gui.plot3d.ScalarFieldView import ScalarFieldView
-from silx.gui.plot3d.SceneWidget import SceneWidget, items
-
-
-class TestSceneWidget(TestCaseQt, ParametricTestCase):
- """Tests StatsWidget combined with SceneWidget"""
-
- def setUp(self):
- super(TestSceneWidget, self).setUp()
- self.sceneWidget = SceneWidget()
- self.sceneWidget.resize(300, 300)
- self.sceneWidget.show()
- self.statsWidget = BasicStatsWidget()
- self.statsWidget.setPlot(self.sceneWidget)
- # self.qWaitForWindowExposed(self.sceneWidget)
-
- def tearDown(self):
- Stats._getContext.cache_clear()
- self.qapp.processEvents()
- self.sceneWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.sceneWidget.close()
- del self.sceneWidget
- self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.statsWidget.close()
- del self.statsWidget
- super(TestSceneWidget, self).tearDown()
-
- def test(self):
- """Test StatsWidget with SceneWidget"""
- # Prepare scene
-
- # Data image
- image = self.sceneWidget.addImage(numpy.arange(100).reshape(10, 10))
- image.setLabel('Image')
- # RGB image
- imageRGB = self.sceneWidget.addImage(
- numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3))
- imageRGB.setLabel('RGB Image')
- # 2D scatter
- data = numpy.arange(100)
- scatter2D = self.sceneWidget.add2DScatter(x=data, y=data, value=data)
- scatter2D.setLabel('2D Scatter')
- # 3D scatter
- scatter3D = self.sceneWidget.add3DScatter(x=data, y=data, z=data, value=data)
- scatter3D.setLabel('3D Scatter')
- # Add a group
- group = items.GroupItem()
- self.sceneWidget.addItem(group)
- # 3D scalar field
- data = numpy.arange(64**3).reshape(64, 64, 64)
- scalarField = items.ScalarField3D()
- scalarField.setData(data, copy=False)
- scalarField.setLabel('3D Scalar field')
- group.addItem(scalarField)
-
- statsTable = self.statsWidget._getStatsTable()
-
- # Test selection only
- self.statsWidget.setDisplayOnlyActiveItem(True)
- self.assertEqual(statsTable.rowCount(), 0)
-
- self.sceneWidget.selection().setCurrentItem(group)
- self.assertEqual(statsTable.rowCount(), 0)
-
- for item in (image, scatter2D, scatter3D, scalarField):
- with self.subTest('selection only', item=item.getLabel()):
- self.sceneWidget.selection().setCurrentItem(item)
- self.assertEqual(statsTable.rowCount(), 1)
- self._checkItem(item)
-
- # Test all data
- self.statsWidget.setDisplayOnlyActiveItem(False)
- self.assertEqual(statsTable.rowCount(), 4)
-
- for item in (image, scatter2D, scatter3D, scalarField):
- with self.subTest('all items', item=item.getLabel()):
- self._checkItem(item)
-
- def _checkItem(self, item):
- """Check that item is in StatsTable and that stats are OK
-
- :param silx.gui.plot3d.items.Item3D item:
- """
- if isinstance(item, (items.Scatter2D, items.Scatter3D)):
- data = item.getValueData(copy=False)
- else:
- data = item.getData(copy=False)
-
- statsTable = self.statsWidget._getStatsTable()
- tableItems = statsTable._itemToTableItems(item)
- self.assertTrue(len(tableItems) > 0)
- self.assertEqual(tableItems['legend'].text(), item.getLabel())
- self.assertEqual(float(tableItems['min'].text()), numpy.min(data))
- self.assertEqual(float(tableItems['max'].text()), numpy.max(data))
- # TODO
-
-
-class TestScalarFieldView(TestCaseQt):
- """Tests StatsWidget combined with ScalarFieldView"""
-
- def setUp(self):
- super(TestScalarFieldView, self).setUp()
- self.scalarFieldView = ScalarFieldView()
- self.scalarFieldView.resize(300, 300)
- self.scalarFieldView.show()
- self.statsWidget = BasicStatsWidget()
- self.statsWidget.setPlot(self.scalarFieldView)
- # self.qWaitForWindowExposed(self.sceneWidget)
-
- def tearDown(self):
- Stats._getContext.cache_clear()
- self.qapp.processEvents()
- self.scalarFieldView.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.scalarFieldView.close()
- del self.scalarFieldView
- self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.statsWidget.close()
- del self.statsWidget
- super(TestScalarFieldView, self).tearDown()
-
- def _getTextFor(self, row, name):
- """Returns text in table at given row for column name
-
- :param int row: Row number in the table
- :param str name: Column id
- :rtype: Union[str,None]
- """
- statsTable = self.statsWidget._getStatsTable()
-
- for column in range(statsTable.columnCount()):
- headerItem = statsTable.horizontalHeaderItem(column)
- if headerItem.data(qt.Qt.UserRole) == name:
- tableItem = statsTable.item(row, column)
- return tableItem.text()
-
- return None
-
- def test(self):
- """Test StatsWidget with ScalarFieldView"""
- data = numpy.arange(64**3, dtype=numpy.float64).reshape(64, 64, 64)
- self.scalarFieldView.setData(data)
-
- statsTable = self.statsWidget._getStatsTable()
-
- # Test selection only
- self.statsWidget.setDisplayOnlyActiveItem(True)
- self.assertEqual(statsTable.rowCount(), 1)
-
- # Test all data
- self.statsWidget.setDisplayOnlyActiveItem(False)
- self.assertEqual(statsTable.rowCount(), 1)
-
- for column in range(statsTable.columnCount()):
- self.assertEqual(float(self._getTextFor(0, 'min')), numpy.min(data))
- self.assertEqual(float(self._getTextFor(0, 'max')), numpy.max(data))
- sum_ = numpy.sum(data)
- comz = numpy.sum(numpy.arange(data.shape[0]) * numpy.sum(data, axis=(1, 2))) / sum_
- comy = numpy.sum(numpy.arange(data.shape[1]) * numpy.sum(data, axis=(0, 2))) / sum_
- comx = numpy.sum(numpy.arange(data.shape[2]) * numpy.sum(data, axis=(0, 1))) / sum_
- self.assertEqual(self._getTextFor(0, 'COM'), str((comx, comy, comz)))
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(
- TestSceneWidget))
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(
- TestScalarFieldView))
- return testsuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/plot3d/tools/GroupPropertiesWidget.py b/silx/gui/plot3d/tools/GroupPropertiesWidget.py
deleted file mode 100644
index ec995a3..0000000
--- a/silx/gui/plot3d/tools/GroupPropertiesWidget.py
+++ /dev/null
@@ -1,202 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-""":class:`GroupPropertiesWidget` allows to reset properties in a GroupItem."""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-from ....gui import qt
-from ....gui.colors import Colormap
-from ....gui.dialog.ColormapDialog import ColormapDialog
-
-from ..items import SymbolMixIn, ColormapMixIn
-
-
-class GroupPropertiesWidget(qt.QWidget):
- """Set properties of all items in a :class:`GroupItem`
-
- :param QWidget parent:
- """
-
- MAX_MARKER_SIZE = 20
- """Maximum value for marker size"""
-
- MAX_LINE_WIDTH = 10
- """Maximum value for line width"""
-
- def __init__(self, parent=None):
- super(GroupPropertiesWidget, self).__init__(parent)
- self._group = None
- self.setEnabled(False)
-
- # Set widgets
- layout = qt.QFormLayout(self)
- self.setLayout(layout)
-
- # Colormap
- colormapButton = qt.QPushButton('Set...')
- colormapButton.setToolTip("Set colormap for all items")
- colormapButton.clicked.connect(self._colormapButtonClicked)
- layout.addRow('Colormap', colormapButton)
-
- self._markerComboBox = qt.QComboBox(self)
- self._markerComboBox.addItems(SymbolMixIn.getSupportedSymbolNames())
-
- # Marker
- markerButton = qt.QPushButton('Set')
- markerButton.setToolTip("Set marker for all items")
- markerButton.clicked.connect(self._markerButtonClicked)
-
- markerLayout = qt.QHBoxLayout()
- markerLayout.setContentsMargins(0, 0, 0, 0)
- markerLayout.addWidget(self._markerComboBox, 1)
- markerLayout.addWidget(markerButton, 0)
-
- layout.addRow('Marker', markerLayout)
-
- # Marker size
- self._markerSizeSlider = qt.QSlider()
- self._markerSizeSlider.setOrientation(qt.Qt.Horizontal)
- self._markerSizeSlider.setSingleStep(1)
- self._markerSizeSlider.setRange(1, self.MAX_MARKER_SIZE)
- self._markerSizeSlider.setValue(1)
-
- markerSizeButton = qt.QPushButton('Set')
- markerSizeButton.setToolTip("Set marker size for all items")
- markerSizeButton.clicked.connect(self._markerSizeButtonClicked)
-
- markerSizeLayout = qt.QHBoxLayout()
- markerSizeLayout.setContentsMargins(0, 0, 0, 0)
- markerSizeLayout.addWidget(qt.QLabel('1'))
- markerSizeLayout.addWidget(self._markerSizeSlider, 1)
- markerSizeLayout.addWidget(qt.QLabel(str(self.MAX_MARKER_SIZE)))
- markerSizeLayout.addWidget(markerSizeButton, 0)
-
- layout.addRow('Marker Size', markerSizeLayout)
-
- # Line width
- self._lineWidthSlider = qt.QSlider()
- self._lineWidthSlider.setOrientation(qt.Qt.Horizontal)
- self._lineWidthSlider.setSingleStep(1)
- self._lineWidthSlider.setRange(1, self.MAX_LINE_WIDTH)
- self._lineWidthSlider.setValue(1)
-
- lineWidthButton = qt.QPushButton('Set')
- lineWidthButton.setToolTip("Set line width for all items")
- lineWidthButton.clicked.connect(self._lineWidthButtonClicked)
-
- lineWidthLayout = qt.QHBoxLayout()
- lineWidthLayout.setContentsMargins(0, 0, 0, 0)
- lineWidthLayout.addWidget(qt.QLabel('1'))
- lineWidthLayout.addWidget(self._lineWidthSlider, 1)
- lineWidthLayout.addWidget(qt.QLabel(str(self.MAX_LINE_WIDTH)))
- lineWidthLayout.addWidget(lineWidthButton, 0)
-
- layout.addRow('Line Width', lineWidthLayout)
-
- self._colormapDialog = None # To store dialog
- self._colormap = Colormap()
-
- def getGroup(self):
- """Returns the :class:`GroupItem` this widget is attached to.
-
- :rtype: Union[GroupItem, None]
- """
- return self._group
-
- def setGroup(self, group):
- """Set the :class:`GroupItem` this widget is attached to.
-
- :param GroupItem group: GroupItem to control (or None)
- """
- self._group = group
- if group is not None:
- self.setEnabled(True)
-
- def _colormapButtonClicked(self, checked=False):
- """Handle colormap button clicked"""
- group = self.getGroup()
- if group is None:
- return
-
- if self._colormapDialog is None:
- self._colormapDialog = ColormapDialog(self)
- self._colormapDialog.setColormap(self._colormap)
-
- previousColormap = self._colormapDialog.getColormap()
- if self._colormapDialog.exec_():
- colormap = self._colormapDialog.getColormap()
-
- for item in group.visit():
- if isinstance(item, ColormapMixIn):
- itemCmap = item.getColormap()
- cmapName = colormap.getName()
- if cmapName is not None:
- itemCmap.setName(colormap.getName())
- else:
- itemCmap.setColormapLUT(colormap.getColormapLUT())
- itemCmap.setNormalization(colormap.getNormalization())
- itemCmap.setGammaNormalizationParameter(
- colormap.getGammaNormalizationParameter())
- itemCmap.setVRange(colormap.getVMin(), colormap.getVMax())
- else:
- # Reset colormap
- self._colormapDialog.setColormap(previousColormap)
-
- def _markerButtonClicked(self, checked=False):
- """Handle marker set button clicked"""
- group = self.getGroup()
- if group is None:
- return
-
- marker = self._markerComboBox.currentText()
- for item in group.visit():
- if isinstance(item, SymbolMixIn):
- item.setSymbol(marker)
-
- def _markerSizeButtonClicked(self, checked=False):
- """Handle marker size set button clicked"""
- group = self.getGroup()
- if group is None:
- return
-
- markerSize = self._markerSizeSlider.value()
- for item in group.visit():
- if isinstance(item, SymbolMixIn):
- item.setSymbolSize(markerSize)
-
- def _lineWidthButtonClicked(self, checked=False):
- """Handle line width set button clicked"""
- group = self.getGroup()
- if group is None:
- return
-
- lineWidth = self._lineWidthSlider.value()
- for item in group.visit():
- if hasattr(item, 'setLineWidth'):
- item.setLineWidth(lineWidth)
diff --git a/silx/gui/plot3d/tools/PositionInfoWidget.py b/silx/gui/plot3d/tools/PositionInfoWidget.py
deleted file mode 100644
index 78f2959..0000000
--- a/silx/gui/plot3d/tools/PositionInfoWidget.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a widget that displays data values of a SceneWidget.
-"""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "01/10/2018"
-
-
-import logging
-import weakref
-
-from ... import qt
-from .. import actions
-from .. import items
-from ..items import volume
-from ..SceneWidget import SceneWidget
-
-
-_logger = logging.getLogger(__name__)
-
-
-class PositionInfoWidget(qt.QWidget):
- """Widget displaying information about picked position
-
- :param QWidget parent: See :class:`QWidget`
- """
-
- def __init__(self, parent=None):
- super(PositionInfoWidget, self).__init__(parent)
- self._sceneWidgetRef = None
-
- self.setToolTip("Double-click on a data point to show its value")
- layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight, self)
-
- self._xLabel = self._addInfoField('X')
- self._yLabel = self._addInfoField('Y')
- self._zLabel = self._addInfoField('Z')
- self._dataLabel = self._addInfoField('Data')
- self._itemLabel = self._addInfoField('Item')
-
- layout.addStretch(1)
-
- self._action = actions.mode.PickingModeAction(parent=self)
- self._action.setText('Selection')
- self._action.setToolTip(
- 'Toggle selection information update with left button click')
- self._action.sigSceneClicked.connect(self.pick)
- self._action.changed.connect(self.__actionChanged)
- self._action.setChecked(False) # Disabled by default
- self.__actionChanged() # Sync action/widget
-
- def __actionChanged(self):
- """Handle toggle action change signal"""
- if self.toggleAction().isChecked() != self.isEnabled():
- self.setEnabled(self.toggleAction().isChecked())
-
- def toggleAction(self):
- """The action to toggle the picking mode.
-
- :rtype: QAction
- """
- return self._action
-
- def _addInfoField(self, label):
- """Add a description: info widget to this widget
-
- :param str label: Description label
- :return: The QLabel used to display the info
- :rtype: QLabel
- """
- subLayout = qt.QHBoxLayout()
- subLayout.setContentsMargins(0, 0, 0, 0)
-
- subLayout.addWidget(qt.QLabel(label + ':'))
-
- widget = qt.QLabel('-')
- widget.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
- widget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
- widget.setMinimumWidth(widget.fontMetrics().width('#######'))
- subLayout.addWidget(widget)
-
- subLayout.addStretch(1)
-
- layout = self.layout()
- layout.addLayout(subLayout)
- return widget
-
- def getSceneWidget(self):
- """Returns the associated :class:`SceneWidget` or None.
-
- :rtype: Union[None,~silx.gui.plot3d.SceneWidget.SceneWidget]
- """
- if self._sceneWidgetRef is None:
- return None
- else:
- return self._sceneWidgetRef()
-
- def setSceneWidget(self, widget):
- """Set the associated :class:`SceneWidget`
-
- :param ~silx.gui.plot3d.SceneWidget.SceneWidget widget:
- 3D scene for which to display information
- """
- if widget is not None and not isinstance(widget, SceneWidget):
- raise ValueError("widget must be a SceneWidget or None")
-
- self._sceneWidgetRef = None if widget is None else weakref.ref(widget)
-
- self.toggleAction().setPlot3DWidget(widget)
-
- def clear(self):
- """Clean-up displayed values"""
- for widget in (self._xLabel, self._yLabel, self._zLabel,
- self._dataLabel, self._itemLabel):
- widget.setText('-')
-
- _SUPPORTED_ITEMS = (items.Scatter3D,
- items.Scatter2D,
- items.ImageData,
- items.ImageRgba,
- items.HeightMapData,
- items.HeightMapRGBA,
- items.Mesh,
- items.Box,
- items.Cylinder,
- items.Hexagon,
- volume.CutPlane,
- volume.Isosurface)
- """Type of items that are picked"""
-
- def _isSupportedItem(self, item):
- """Returns True if item is of supported type
-
- :param Item3D item: The Item3D to check
- :rtype: bool
- """
- return isinstance(item, self._SUPPORTED_ITEMS)
-
- def pick(self, x, y):
- """Pick items in the associated SceneWidget and display result
-
- Only the closest point is displayed.
-
- :param int x: X coordinate in pixel in the SceneWidget
- :param int y: Y coordinate in pixel in the SceneWidget
- """
- self.clear()
-
- sceneWidget = self.getSceneWidget()
- if sceneWidget is None: # No associated widget
- _logger.info('Picking without associated SceneWidget')
- return
-
- # Find closest (and latest in the tree) supported item
- closestNdcZ = float('inf')
- picking = None
- for result in sceneWidget.pickItems(x, y,
- condition=self._isSupportedItem):
- ndcZ = result.getPositions('ndc', copy=False)[0, 2]
- if ndcZ <= closestNdcZ:
- closestNdcZ = ndcZ
- picking = result
-
- if picking is None:
- return # No picked item
-
- item = picking.getItem()
- self._itemLabel.setText(item.getLabel())
- positions = picking.getPositions('scene', copy=False)
- x, y, z = positions[0]
- self._xLabel.setText("%g" % x)
- self._yLabel.setText("%g" % y)
- self._zLabel.setText("%g" % z)
-
- data = picking.getData(copy=False)
- if data is not None:
- data = data[0]
- if hasattr(data, '__len__'):
- text = ' '.join(["%.3g"] * len(data)) % tuple(data)
- else:
- text = "%g" % data
- self._dataLabel.setText(text)
-
- def updateInfo(self):
- """Update information according to cursor position"""
- widget = self.getSceneWidget()
- if widget is None:
- _logger.info('Update without associated SceneWidget')
- self.clear()
- return
-
- position = widget.mapFromGlobal(qt.QCursor.pos())
- self.pick(position.x(), position.y())
diff --git a/silx/gui/plot3d/tools/test/__init__.py b/silx/gui/plot3d/tools/test/__init__.py
deleted file mode 100644
index 2dbc0ab..0000000
--- a/silx/gui/plot3d/tools/test/__init__.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""plot3d tools test suite."""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "03/10/2018"
-
-
-import unittest
-from .testPositionInfoWidget import suite as testPositionInfoWidgetSuite
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(testPositionInfoWidgetSuite())
- return testsuite
diff --git a/silx/gui/plot3d/tools/test/testPositionInfoWidget.py b/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
deleted file mode 100644
index 4520a2a..0000000
--- a/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-# ###########################################################################*/
-"""Test PositionInfoWidget"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "03/10/2018"
-
-
-import unittest
-
-import numpy
-
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import qt
-
-from silx.gui.plot3d.SceneWidget import SceneWidget
-from silx.gui.plot3d.tools.PositionInfoWidget import PositionInfoWidget
-
-
-class TestPositionInfoWidget(TestCaseQt):
- """Tests PositionInfoWidget"""
-
- def setUp(self):
- super(TestPositionInfoWidget, self).setUp()
- self.sceneWidget = SceneWidget()
- self.sceneWidget.resize(300, 300)
- self.sceneWidget.show()
-
- self.positionInfoWidget = PositionInfoWidget()
- self.positionInfoWidget.setSceneWidget(self.sceneWidget)
- self.positionInfoWidget.show()
- self.qWaitForWindowExposed(self.positionInfoWidget)
-
- # self.qWaitForWindowExposed(self.widget)
-
- def tearDown(self):
- self.qapp.processEvents()
-
- self.sceneWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.sceneWidget.close()
- del self.sceneWidget
-
- self.positionInfoWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.positionInfoWidget.close()
- del self.positionInfoWidget
- super(TestPositionInfoWidget, self).tearDown()
-
- def test(self):
- """Test PositionInfoWidget"""
- self.assertIs(self.positionInfoWidget.getSceneWidget(),
- self.sceneWidget)
-
- data = numpy.arange(100)
- self.sceneWidget.add2DScatter(x=data, y=data, value=data)
- self.sceneWidget.resetZoom('front')
-
- # Double click at the center
- self.mouseDClick(self.sceneWidget, button=qt.Qt.LeftButton)
-
- # Clear displayed value
- self.positionInfoWidget.clear()
-
- # Update info from API
- self.positionInfoWidget.pick(x=10, y=10)
-
- # Remove SceneWidget
- self.positionInfoWidget.setSceneWidget(None)
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(
- TestPositionInfoWidget))
- return testsuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/qt/__init__.py b/silx/gui/qt/__init__.py
deleted file mode 100644
index ace2841..0000000
--- a/silx/gui/qt/__init__.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Common wrapper over Python Qt bindings:
-
-- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_
-- `PySide2 <https://wiki.qt.io/Qt_for_Python>`_
-- `PyQt4 <http://pyqt.sourceforge.net/Docs/PyQt4/>`_
-
-If a Qt binding is already loaded, it will use it, otherwise the different
-Qt bindings are tried in this order: PyQt5, PyQt4, PySide2.
-
-The name of the loaded Qt binding is stored in the BINDING variable.
-
-This module provides a flat namespace over Qt bindings by importing
-all symbols from **QtCore** and **QtGui** packages and if available
-from **QtOpenGL** and **QtSvg** packages.
-For **PyQt5**, it also imports all symbols from **QtWidgets** and
-**QtPrintSupport** packages.
-
-Example of using :mod:`silx.gui.qt` module:
-
->>> from silx.gui import qt
->>> app = qt.QApplication([])
->>> widget = qt.QWidget()
-
-For an alternative solution providing a structured namespace,
-see `qtpy <https://pypi.org/project/QtPy/>`_ which
-provides the namespace of PyQt5 over PyQt4, PySide and PySide2.
-"""
-
-from ._qt import * # noqa
-from ._utils import * # noqa
-
-
-if sys.platform == "darwin":
- if BINDING in ["PySide", "PyQt4"]:
- from . import _macosx
- _macosx.patch_QUrl_toLocalFile()
diff --git a/silx/gui/qt/_macosx.py b/silx/gui/qt/_macosx.py
deleted file mode 100644
index 07f3143..0000000
--- a/silx/gui/qt/_macosx.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-Patches for Mac OS X
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "30/11/2016"
-
-
-def patch_QUrl_toLocalFile():
- """Apply a monkey-patch on qt.QUrl to allow to reach filename when the URL
- come from a MIME data from a file drop. Without, `QUrl.toLocalName` with
- some version of Mac OS X returns a path which looks like
- `/.file/id=180.112`.
-
- Qt5 is or will be patch, but Qt4 and PySide are not.
-
- This fix uses the file URL and use an subprocess with an
- AppleScript. The script convert the URI into a posix path.
- The interpreter (osascript) is available on default OS X installs.
-
- See https://bugreports.qt.io/browse/QTBUG-40449
- """
- from ._qt import QUrl
- import subprocess
-
- def QUrl_toLocalFile(self):
- path = QUrl._oldToLocalFile(self)
- if not path.startswith("/.file/id="):
- return path
-
- url = self.toString()
- script = 'get posix path of my posix file \"%s\" -- kthxbai' % url
- try:
- p = subprocess.Popen(["osascript", "-e", script], stdout=subprocess.PIPE)
- out, _err = p.communicate()
- if p.returncode == 0:
- return out.strip()
- except OSError:
- pass
- return path
-
- QUrl._oldToLocalFile = QUrl.toLocalFile
- QUrl.toLocalFile = QUrl_toLocalFile
diff --git a/silx/gui/qt/_pyside_dynamic.py b/silx/gui/qt/_pyside_dynamic.py
deleted file mode 100644
index 6013416..0000000
--- a/silx/gui/qt/_pyside_dynamic.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Taken from: https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8
-# Plus: https://github.com/spyder-ide/qtpy/commit/001a862c401d757feb63025f88dbb4601d353c84
-
-# Copyright (c) 2011 Sebastian Wiesner <lunaryorn@gmail.com>
-# Modifications by Charl Botha <cpbotha@vxlabs.com>
-# * customWidgets support (registerCustomWidget() causes segfault in
-# pyside 1.1.2 on Ubuntu 12.04 x86_64)
-# * workingDirectory support in loadUi
-
-# found this here:
-# https://github.com/lunaryorn/snippets/blob/master/qt4/designer/pyside_dynamic.py
-
-# Permission is hereby granted, free of charge, to any person obtaining a
-# copy of this software and associated documentation files (the "Software"),
-# to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-"""
- How to load a user interface dynamically with PySide.
-
- .. moduleauthor:: Sebastian Wiesner <lunaryorn@gmail.com>
-"""
-
-from __future__ import (print_function, division, unicode_literals,
- absolute_import)
-
-import logging
-import sys
-
-if "PySide.QtCore" in sys.modules:
- from PySide.QtCore import QMetaObject
- from PySide.QtUiTools import QUiLoader
-else: # PySide2
- from PySide2.QtCore import QMetaObject, Property, Qt
- from PySide2.QtWidgets import QFrame
- from PySide2.QtUiTools import QUiLoader
-
-_logger = logging.getLogger(__name__)
-
-
-class UiLoader(QUiLoader):
- """
- Subclass :class:`~PySide.QtUiTools.QUiLoader` to create the user interface
- in a base instance.
-
- Unlike :class:`~PySide.QtUiTools.QUiLoader` itself this class does not
- create a new instance of the top-level widget, but creates the user
- interface in an existing instance of the top-level class.
-
- This mimics the behaviour of :func:`PyQt*.uic.loadUi`.
- """
-
- def __init__(self, baseinstance, customWidgets=None):
- """
- Create a loader for the given ``baseinstance``.
-
- The user interface is created in ``baseinstance``, which must be an
- instance of the top-level class in the user interface to load, or a
- subclass thereof.
-
- ``customWidgets`` is a dictionary mapping from class name to class
- object for widgets that you've promoted in the Qt Designer
- interface. Usually, this should be done by calling
- registerCustomWidget on the QUiLoader, but
- with PySide 1.1.2 on Ubuntu 12.04 x86_64 this causes a segfault.
-
- ``parent`` is the parent object of this loader.
- """
-
- QUiLoader.__init__(self, baseinstance)
- self.baseinstance = baseinstance
- self.customWidgets = {}
- self.uifile = None
- self.customWidgets.update(customWidgets)
-
- def createWidget(self, class_name, parent=None, name=''):
- """
- Function that is called for each widget defined in ui file,
- overridden here to populate baseinstance instead.
- """
-
- if parent is None and self.baseinstance:
- # supposed to create the top-level widget, return the base instance
- # instead
- return self.baseinstance
-
- else:
- if class_name in self.availableWidgets():
- # create a new widget for child widgets
- widget = QUiLoader.createWidget(self, class_name, parent, name)
-
- else:
- # if not in the list of availableWidgets,
- # must be a custom widget
- # this will raise KeyError if the user has not supplied the
- # relevant class_name in the dictionary, or TypeError, if
- # customWidgets is None
- if class_name not in self.customWidgets:
- raise Exception('No custom widget ' + class_name +
- ' found in customWidgets param of' +
- 'UiFile %s.' % self.uifile)
- try:
- widget = self.customWidgets[class_name](parent)
- except Exception:
- _logger.error("Fail to instanciate widget %s from file %s", class_name, self.uifile)
- raise
-
- if self.baseinstance:
- # set an attribute for the new child widget on the base
- # instance, just like PyQt*.uic.loadUi does.
- setattr(self.baseinstance, name, widget)
-
- # this outputs the various widget names, e.g.
- # sampleGraphicsView, dockWidget, samplesTableView etc.
- # print(name)
-
- return widget
-
- def _parse_custom_widgets(self, ui_file):
- """
- This function is used to parse a ui file and look for the <customwidgets>
- section, then automatically load all the custom widget classes.
- """
- import importlib
- from xml.etree.ElementTree import ElementTree
-
- # Parse the UI file
- etree = ElementTree()
- ui = etree.parse(ui_file)
-
- # Get the customwidgets section
- custom_widgets = ui.find('customwidgets')
-
- if custom_widgets is None:
- return
-
- custom_widget_classes = {}
-
- for custom_widget in custom_widgets.getchildren():
-
- cw_class = custom_widget.find('class').text
- cw_header = custom_widget.find('header').text
-
- module = importlib.import_module(cw_header)
-
- custom_widget_classes[cw_class] = getattr(module, cw_class)
-
- self.customWidgets.update(custom_widget_classes)
-
- def load(self, uifile):
- self._parse_custom_widgets(uifile)
- self.uifile = uifile
- return QUiLoader.load(self, uifile)
-
-
-if "PySide2.QtCore" in sys.modules:
-
- class _Line(QFrame):
- """Widget to use as 'Line' Qt designer"""
- def __init__(self, parent=None):
- super(_Line, self).__init__(parent)
- self.setFrameShape(QFrame.HLine)
- self.setFrameShadow(QFrame.Sunken)
-
- def getOrientation(self):
- shape = self.frameShape()
- if shape == QFrame.HLine:
- return Qt.Horizontal
- elif shape == QFrame.VLine:
- return Qt.Vertical
- else:
- raise RuntimeError("Wrong shape: %d", shape)
-
- def setOrientation(self, orientation):
- if orientation == Qt.Horizontal:
- self.setFrameShape(QFrame.HLine)
- elif orientation == Qt.Vertical:
- self.setFrameShape(QFrame.VLine)
- else:
- raise ValueError("Unsupported orientation %s" % str(orientation))
-
- orientation = Property("Qt::Orientation", getOrientation, setOrientation)
-
- CUSTOM_WIDGETS = {"Line": _Line}
- """Default custom widgets for `loadUi`"""
-
-else: # PySide support
- CUSTOM_WIDGETS = {}
- """Default custom widgets for `loadUi`"""
-
-
-def loadUi(uifile, baseinstance=None, package=None, resource_suffix=None):
- """
- Dynamically load a user interface from the given ``uifile``.
-
- ``uifile`` is a string containing a file name of the UI file to load.
-
- If ``baseinstance`` is ``None``, the a new instance of the top-level widget
- will be created. Otherwise, the user interface is created within the given
- ``baseinstance``. In this case ``baseinstance`` must be an instance of the
- top-level widget class in the UI file to load, or a subclass thereof. In
- other words, if you've created a ``QMainWindow`` interface in the designer,
- ``baseinstance`` must be a ``QMainWindow`` or a subclass thereof, too. You
- cannot load a ``QMainWindow`` UI file with a plain
- :class:`~PySide.QtGui.QWidget` as ``baseinstance``.
-
- :method:`~PySide.QtCore.QMetaObject.connectSlotsByName()` is called on the
- created user interface, so you can implemented your slots according to its
- conventions in your widget class.
-
- Return ``baseinstance``, if ``baseinstance`` is not ``None``. Otherwise
- return the newly created instance of the user interface.
- """
- if package is not None:
- _logger.warning(
- "loadUi package parameter not implemented with PySide")
- if resource_suffix is not None:
- _logger.warning(
- "loadUi resource_suffix parameter not implemented with PySide")
-
- loader = UiLoader(baseinstance, customWidgets=CUSTOM_WIDGETS)
- widget = loader.load(uifile)
- QMetaObject.connectSlotsByName(widget)
- return widget
diff --git a/silx/gui/qt/_pyside_missing.py b/silx/gui/qt/_pyside_missing.py
deleted file mode 100644
index a7e2781..0000000
--- a/silx/gui/qt/_pyside_missing.py
+++ /dev/null
@@ -1,274 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""
-Python implementation of classes which are not provided by default by PySide.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "17/01/2017"
-
-
-from PySide.QtGui import QAbstractProxyModel
-from PySide.QtCore import QModelIndex
-from PySide.QtCore import Qt
-from PySide.QtGui import QItemSelection
-from PySide.QtGui import QItemSelectionRange
-
-
-class QIdentityProxyModel(QAbstractProxyModel):
- """Python translation of the source code of Qt c++ file"""
-
- def __init__(self, parent=None):
- super(QIdentityProxyModel, self).__init__(parent)
- self.__ignoreNextLayoutAboutToBeChanged = False
- self.__ignoreNextLayoutChanged = False
- self.__persistentIndexes = []
-
- def columnCount(self, parent):
- parent = self.mapToSource(parent)
- return self.sourceModel().columnCount(parent)
-
- def dropMimeData(self, data, action, row, column, parent):
- parent = self.mapToSource(parent)
- return self.sourceModel().dropMimeData(data, action, row, column, parent)
-
- def index(self, row, column, parent=QModelIndex()):
- parent = self.mapToSource(parent)
- i = self.sourceModel().index(row, column, parent)
- return self.mapFromSource(i)
-
- def insertColumns(self, column, count, parent=QModelIndex()):
- parent = self.mapToSource(parent)
- return self.sourceModel().insertColumns(column, count, parent)
-
- def insertRows(self, row, count, parent=QModelIndex()):
- parent = self.mapToSource(parent)
- return self.sourceModel().insertRows(row, count, parent)
-
- def mapFromSource(self, sourceIndex):
- if self.sourceModel() is None or not sourceIndex.isValid():
- return QModelIndex()
- index = self.createIndex(sourceIndex.row(), sourceIndex.column(), sourceIndex.internalPointer())
- return index
-
- def mapSelectionFromSource(self, sourceSelection):
- proxySelection = QItemSelection()
- if self.sourceModel() is None:
- return proxySelection
-
- cursor = sourceSelection.constBegin()
- end = sourceSelection.constEnd()
- while cursor != end:
- topLeft = self.mapFromSource(cursor.topLeft())
- bottomRight = self.mapFromSource(cursor.bottomRight())
- proxyRange = QItemSelectionRange(topLeft, bottomRight)
- proxySelection.append(proxyRange)
- cursor += 1
- return proxySelection
-
- def mapSelectionToSource(self, proxySelection):
- sourceSelection = QItemSelection()
- if self.sourceModel() is None:
- return sourceSelection
-
- cursor = proxySelection.constBegin()
- end = proxySelection.constEnd()
- while cursor != end:
- topLeft = self.mapToSource(cursor.topLeft())
- bottomRight = self.mapToSource(cursor.bottomRight())
- sourceRange = QItemSelectionRange(topLeft, bottomRight)
- sourceSelection.append(sourceRange)
- cursor += 1
- return sourceSelection
-
- def mapToSource(self, proxyIndex):
- if self.sourceModel() is None or not proxyIndex.isValid():
- return QModelIndex()
- return self.sourceModel().createIndex(proxyIndex.row(), proxyIndex.column(), proxyIndex.internalPointer())
-
- def match(self, start, role, value, hits=1, flags=Qt.MatchFlags(Qt.MatchStartsWith | Qt.MatchWrap)):
- if self.sourceModel() is None:
- return []
-
- start = self.mapToSource(start)
- sourceList = self.sourceModel().match(start, role, value, hits, flags)
- proxyList = []
- for cursor in sourceList:
- proxyList.append(self.mapFromSource(cursor))
- return proxyList
-
- def parent(self, child):
- sourceIndex = self.mapToSource(child)
- sourceParent = sourceIndex.parent()
- index = self.mapFromSource(sourceParent)
- return index
-
- def removeColumns(self, column, count, parent=QModelIndex()):
- parent = self.mapToSource(parent)
- return self.sourceModel().removeColumns(column, count, parent)
-
- def removeRows(self, row, count, parent=QModelIndex()):
- parent = self.mapToSource(parent)
- return self.sourceModel().removeRows(row, count, parent)
-
- def rowCount(self, parent=QModelIndex()):
- parent = self.mapToSource(parent)
- return self.sourceModel().rowCount(parent)
-
- def setSourceModel(self, newSourceModel):
- """Bind and unbind the source model events"""
- self.beginResetModel()
-
- sourceModel = self.sourceModel()
- if sourceModel is not None:
- sourceModel.rowsAboutToBeInserted.disconnect(self.__rowsAboutToBeInserted)
- sourceModel.rowsInserted.disconnect(self.__rowsInserted)
- sourceModel.rowsAboutToBeRemoved.disconnect(self.__rowsAboutToBeRemoved)
- sourceModel.rowsRemoved.disconnect(self.__rowsRemoved)
- sourceModel.rowsAboutToBeMoved.disconnect(self.__rowsAboutToBeMoved)
- sourceModel.rowsMoved.disconnect(self.__rowsMoved)
- sourceModel.columnsAboutToBeInserted.disconnect(self.__columnsAboutToBeInserted)
- sourceModel.columnsInserted.disconnect(self.__columnsInserted)
- sourceModel.columnsAboutToBeRemoved.disconnect(self.__columnsAboutToBeRemoved)
- sourceModel.columnsRemoved.disconnect(self.__columnsRemoved)
- sourceModel.columnsAboutToBeMoved.disconnect(self.__columnsAboutToBeMoved)
- sourceModel.columnsMoved.disconnect(self.__columnsMoved)
- sourceModel.modelAboutToBeReset.disconnect(self.__modelAboutToBeReset)
- sourceModel.modelReset.disconnect(self.__modelReset)
- sourceModel.dataChanged.disconnect(self.__dataChanged)
- sourceModel.headerDataChanged.disconnect(self.__headerDataChanged)
- sourceModel.layoutAboutToBeChanged.disconnect(self.__layoutAboutToBeChanged)
- sourceModel.layoutChanged.disconnect(self.__layoutChanged)
-
- super(QIdentityProxyModel, self).setSourceModel(newSourceModel)
-
- sourceModel = self.sourceModel()
- if sourceModel is not None:
- sourceModel.rowsAboutToBeInserted.connect(self.__rowsAboutToBeInserted)
- sourceModel.rowsInserted.connect(self.__rowsInserted)
- sourceModel.rowsAboutToBeRemoved.connect(self.__rowsAboutToBeRemoved)
- sourceModel.rowsRemoved.connect(self.__rowsRemoved)
- sourceModel.rowsAboutToBeMoved.connect(self.__rowsAboutToBeMoved)
- sourceModel.rowsMoved.connect(self.__rowsMoved)
- sourceModel.columnsAboutToBeInserted.connect(self.__columnsAboutToBeInserted)
- sourceModel.columnsInserted.connect(self.__columnsInserted)
- sourceModel.columnsAboutToBeRemoved.connect(self.__columnsAboutToBeRemoved)
- sourceModel.columnsRemoved.connect(self.__columnsRemoved)
- sourceModel.columnsAboutToBeMoved.connect(self.__columnsAboutToBeMoved)
- sourceModel.columnsMoved.connect(self.__columnsMoved)
- sourceModel.modelAboutToBeReset.connect(self.__modelAboutToBeReset)
- sourceModel.modelReset.connect(self.__modelReset)
- sourceModel.dataChanged.connect(self.__dataChanged)
- sourceModel.headerDataChanged.connect(self.__headerDataChanged)
- sourceModel.layoutAboutToBeChanged.connect(self.__layoutAboutToBeChanged)
- sourceModel.layoutChanged.connect(self.__layoutChanged)
-
- self.endResetModel()
-
- def __columnsAboutToBeInserted(self, parent, start, end):
- parent = self.mapFromSource(parent)
- self.beginInsertColumns(parent, start, end)
-
- def __columnsAboutToBeMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest):
- sourceParent = self.mapFromSource(sourceParent)
- destParent = self.mapFromSource(destParent)
- self.beginMoveColumns(sourceParent, sourceStart, sourceEnd, destParent, dest)
-
- def __columnsAboutToBeRemoved(self, parent, start, end):
- parent = self.mapFromSource(parent)
- self.beginRemoveColumns(parent, start, end)
-
- def __columnsInserted(self, parent, start, end):
- self.endInsertColumns()
-
- def __columnsMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest):
- self.endMoveColumns()
-
- def __columnsRemoved(self, parent, start, end):
- self.endRemoveColumns()
-
- def __dataChanged(self, topLeft, bottomRight):
- topLeft = self.mapFromSource(topLeft)
- bottomRight = self.mapFromSource(bottomRight)
- self.dataChanged(topLeft, bottomRight)
-
- def __headerDataChanged(self, orientation, first, last):
- self.headerDataChanged(orientation, first, last)
-
- def __layoutAboutToBeChanged(self):
- """Store persistent indexes"""
- if self.__ignoreNextLayoutAboutToBeChanged:
- return
-
- for proxyPersistentIndex in self.persistentIndexList():
- self.__proxyIndexes.append()
- sourcePersistentIndex = self.mapToSource(proxyPersistentIndex)
- mapping = proxyPersistentIndex, sourcePersistentIndex
- self.__persistentIndexes.append(mapping)
-
- self.layoutAboutToBeChanged()
-
- def __layoutChanged(self):
- """Restore persistent indexes"""
- if self.__ignoreNextLayoutChanged:
- return
-
- for mapping in self.__persistentIndexes:
- proxyIndex, sourcePersistentIndex = mapping
- sourcePersistentIndex = self.mapFromSource(sourcePersistentIndex)
- self.changePersistentIndex(proxyIndex, sourcePersistentIndex)
-
- self.__persistentIndexes = []
-
- self.layoutChanged()
-
- def __modelAboutToBeReset(self):
- self.beginResetModel()
-
- def __modelReset(self):
- self.endResetModel()
-
- def __rowsAboutToBeInserted(self, parent, start, end):
- parent = self.mapFromSource(parent)
- self.beginInsertRows(parent, start, end)
-
- def __rowsAboutToBeMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest):
- sourceParent = self.mapFromSource(sourceParent)
- destParent = self.mapFromSource(destParent)
- self.beginMoveRows(sourceParent, sourceStart, sourceEnd, destParent, dest)
-
- def __rowsAboutToBeRemoved(self, parent, start, end):
- parent = self.mapFromSource(parent)
- self.beginRemoveRows(parent, start, end)
-
- def __rowsInserted(self, parent, start, end):
- self.endInsertRows()
-
- def __rowsMoved(self, sourceParent, sourceStart, sourceEnd, destParent, dest):
- self.endMoveRows()
-
- def __rowsRemoved(self, parent, start, end):
- self.endRemoveRows()
diff --git a/silx/gui/qt/_qt.py b/silx/gui/qt/_qt.py
deleted file mode 100644
index 29a6354..0000000
--- a/silx/gui/qt/_qt.py
+++ /dev/null
@@ -1,289 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Load Qt binding"""
-
-__authors__ = ["V.A. Sole"]
-__license__ = "MIT"
-__date__ = "23/05/2018"
-
-
-import logging
-import sys
-import traceback
-
-from ...utils.deprecation import deprecated_warning
-
-
-_logger = logging.getLogger(__name__)
-
-
-BINDING = None
-"""The name of the Qt binding in use: PyQt5, PyQt4 or PySide2."""
-
-QtBinding = None # noqa
-"""The Qt binding module in use: PyQt5, PyQt4 or PySide2."""
-
-HAS_SVG = False
-"""True if Qt provides support for Scalable Vector Graphics (QtSVG)."""
-
-HAS_OPENGL = False
-"""True if Qt provides support for OpenGL (QtOpenGL)."""
-
-# First check for an already loaded wrapper
-if 'PySide2.QtCore' in sys.modules:
- BINDING = 'PySide2'
-
-elif 'PySide.QtCore' in sys.modules:
- BINDING = 'PySide'
-
-elif 'PyQt5.QtCore' in sys.modules:
- BINDING = 'PyQt5'
-
-elif 'PyQt4.QtCore' in sys.modules:
- BINDING = 'PyQt4'
-
-else: # Then try Qt bindings
- try:
- import PyQt5.QtCore # noqa
- except ImportError:
- if 'PyQt5' in sys.modules:
- del sys.modules["PyQt5"]
- try:
- import sip
- sip.setapi("QString", 2)
- sip.setapi("QVariant", 2)
- sip.setapi('QDate', 2)
- sip.setapi('QDateTime', 2)
- sip.setapi('QTextStream', 2)
- sip.setapi('QTime', 2)
- sip.setapi('QUrl', 2)
- import PyQt4.QtCore # noqa
- except ImportError:
- if 'PyQt4' in sys.modules:
- del sys.modules["sip"]
- del sys.modules["PyQt4"]
- try:
- import PySide2.QtCore # noqa
- except ImportError:
- if 'PySide2' in sys.modules:
- del sys.modules["PySide2"]
- try:
- import PySide.QtCore # noqa
- except ImportError:
- if 'PySide' in sys.modules:
- del sys.modules["PySide"]
- raise ImportError(
- 'No Qt wrapper found. Install PyQt5, PyQt4 or PySide2.')
- else:
- BINDING = 'PySide'
- else:
- BINDING = 'PySide2'
- else:
- BINDING = 'PyQt4'
- else:
- BINDING = 'PyQt5'
-
-
-if BINDING == 'PyQt4':
- _logger.debug('Using PyQt4 bindings')
- deprecated_warning("Qt Binding", "PyQt4",
- replacement='PyQt5',
- since_version='0.9.0')
-
- if sys.version_info < (3, ):
- try:
- import sip
- sip.setapi("QString", 2)
- sip.setapi("QVariant", 2)
- sip.setapi('QDate', 2)
- sip.setapi('QDateTime', 2)
- sip.setapi('QTextStream', 2)
- sip.setapi('QTime', 2)
- sip.setapi('QUrl', 2)
- except:
- _logger.warning("Cannot set sip API")
-
- import PyQt4 as QtBinding # noqa
-
- from PyQt4.QtCore import * # noqa
- from PyQt4.QtGui import * # noqa
-
- try:
- from PyQt4.QtOpenGL import * # noqa
- except ImportError:
- _logger.info("PyQt4.QtOpenGL not available")
- HAS_OPENGL = False
- else:
- HAS_OPENGL = True
-
- try:
- from PyQt4.QtSvg import * # noqa
- except ImportError:
- _logger.info("PyQt4.QtSvg not available")
- HAS_SVG = False
- else:
- HAS_SVG = True
-
- from PyQt4.uic import loadUi # noqa
-
- Signal = pyqtSignal
-
- Property = pyqtProperty
-
- Slot = pyqtSlot
-
-elif BINDING == 'PySide':
- _logger.debug('Using PySide bindings')
- deprecated_warning("Qt Binding", "PySide",
- replacement='PySide2',
- since_version='0.9.0')
-
- import PySide as QtBinding # noqa
-
- from PySide.QtCore import * # noqa
- from PySide.QtGui import * # noqa
-
- try:
- from PySide.QtOpenGL import * # noqa
- except ImportError:
- _logger.info("PySide.QtOpenGL not available")
- HAS_OPENGL = False
- else:
- HAS_OPENGL = True
-
- try:
- from PySide.QtSvg import * # noqa
- except ImportError:
- _logger.info("PySide.QtSvg not available")
- HAS_SVG = False
- else:
- HAS_SVG = True
-
- pyqtSignal = Signal
-
- # Import loadUi wrapper for PySide
- from ._pyside_dynamic import loadUi # noqa
-
- # Import missing classes
- if not hasattr(locals(), "QIdentityProxyModel"):
- from ._pyside_missing import QIdentityProxyModel # noqa
-
-elif BINDING == 'PyQt5':
- _logger.debug('Using PyQt5 bindings')
-
- import PyQt5 as QtBinding # noqa
-
- from PyQt5.QtCore import * # noqa
- from PyQt5.QtGui import * # noqa
- from PyQt5.QtWidgets import * # noqa
- from PyQt5.QtPrintSupport import * # noqa
-
- try:
- from PyQt5.QtOpenGL import * # noqa
- except ImportError:
- _logger.info("PySide.QtOpenGL not available")
- HAS_OPENGL = False
- else:
- HAS_OPENGL = True
-
- try:
- from PyQt5.QtSvg import * # noqa
- except ImportError:
- _logger.info("PyQt5.QtSvg not available")
- HAS_SVG = False
- else:
- HAS_SVG = True
-
- from PyQt5.uic import loadUi # noqa
-
- Signal = pyqtSignal
-
- Property = pyqtProperty
-
- Slot = pyqtSlot
-
- # Disable PyQt5's cooperative multi-inheritance since other bindings do not provide it.
- # See https://www.riverbankcomputing.com/static/Docs/PyQt5/multiinheritance.html?highlight=inheritance
- class _Foo(object): pass
- class QObject(QObject, _Foo): pass
-
-
-elif BINDING == 'PySide2':
- _logger.debug('Using PySide2 bindings')
-
- import PySide2 as QtBinding # noqa
-
- from PySide2.QtCore import * # noqa
- from PySide2.QtGui import * # noqa
- from PySide2.QtWidgets import * # noqa
- from PySide2.QtPrintSupport import * # noqa
-
- try:
- from PySide2.QtOpenGL import * # noqa
- except ImportError:
- _logger.info("PySide2.QtOpenGL not available")
- HAS_OPENGL = False
- else:
- HAS_OPENGL = True
-
- try:
- from PySide2.QtSvg import * # noqa
- except ImportError:
- _logger.info("PySide2.QtSvg not available")
- HAS_SVG = False
- else:
- HAS_SVG = True
-
- # Import loadUi wrapper for PySide2
- from ._pyside_dynamic import loadUi # noqa
-
- pyqtSignal = Signal
-
-else:
- raise ImportError('No Qt wrapper found. Install PyQt4, PyQt5, PySide2')
-
-
-# provide a exception handler but not implement it by default
-def exceptionHandler(type_, value, trace):
- """
- This exception handler prevents quitting to the command line when there is
- an unhandled exception while processing a Qt signal.
-
- The script/application willing to use it should implement code similar to:
-
- .. code-block:: python
-
- if __name__ == "__main__":
- sys.excepthook = qt.exceptionHandler
-
- """
- _logger.error("%s %s %s", type_, value, ''.join(traceback.format_tb(trace)))
- msg = QMessageBox()
- msg.setWindowTitle("Unhandled exception")
- msg.setIcon(QMessageBox.Critical)
- msg.setInformativeText("%s %s\nPlease report details" % (type_, value))
- msg.setDetailedText(("%s " % value) + ''.join(traceback.format_tb(trace)))
- msg.raise_()
- msg.exec_()
diff --git a/silx/gui/qt/_utils.py b/silx/gui/qt/_utils.py
deleted file mode 100644
index 4a7a1c0..0000000
--- a/silx/gui/qt/_utils.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides convenient functions related to Qt.
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "30/11/2016"
-
-
-import sys as _sys
-from . import _qt
-
-
-def supportedImageFormats():
- """Return a set of string of file format extensions supported by the
- Qt runtime."""
- if _sys.version_info[0] < 3 or _qt.BINDING == 'PySide':
- convert = str
- elif _qt.BINDING == 'PySide2':
- def convert(data):
- return str(data.data(), 'ascii')
- else:
- convert = lambda data: str(data, 'ascii')
- formats = _qt.QImageReader.supportedImageFormats()
- return set([convert(data) for data in formats])
-
-
-__globalThreadPoolInstance = None
-"""Store the own silx global thread pool"""
-
-
-def silxGlobalThreadPool():
- """"Manage an own QThreadPool to avoid issue on Qt5 Windows with the
- default Qt global thread pool.
-
- A thread pool is create in lazy loading. With a maximum of 4 threads.
- Else `qt.Thread.idealThreadCount()` is used.
-
- :rtype: qt.QThreadPool
- """
- global __globalThreadPoolInstance
- if __globalThreadPoolInstance is None:
- tp = _qt.QThreadPool()
- # Setting maxThreadCount fixes a segfault with PyQt 5.9.1 on Windows
- maxThreadCount = min(4, tp.maxThreadCount())
- tp.setMaxThreadCount(maxThreadCount)
- __globalThreadPoolInstance = tp
- return __globalThreadPoolInstance
diff --git a/silx/gui/qt/inspect.py b/silx/gui/qt/inspect.py
deleted file mode 100644
index 3c08835..0000000
--- a/silx/gui/qt/inspect.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides functions to access Qt C++ object state:
-
-- :func:`isValid` to check whether a QObject C++ pointer is valid.
-- :func:`createdByPython` to check if a QObject was created from Python.
-- :func:`ownedByPython` to check if a QObject is currently owned by Python.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "08/10/2018"
-
-
-from . import _qt as qt
-
-
-if qt.BINDING in ('PyQt4', 'PyQt5'):
- if qt.BINDING == 'PyQt5':
- try:
- from PyQt5.sip import isdeleted as _isdeleted # noqa
- from PyQt5.sip import ispycreated as createdByPython # noqa
- from PyQt5.sip import ispyowned as ownedByPython # noqa
- except ImportError:
- from sip import isdeleted as _isdeleted # noqa
- from sip import ispycreated as createdByPython # noqa
- from sip import ispyowned as ownedByPython # noqa
-
- else: # PyQt4
- from sip import isdeleted as _isdeleted # noqa
- from sip import ispycreated as createdByPython # noqa
- from sip import ispyowned as ownedByPython # noqa
-
- def isValid(obj):
- """Returns True if underlying C++ object is valid.
-
- :param QObject obj:
- :rtype: bool
- """
- return not _isdeleted(obj)
-
-elif qt.BINDING == 'PySide2':
- try:
- from PySide2.shiboken2 import isValid # noqa
- from PySide2.shiboken2 import createdByPython # noqa
- from PySide2.shiboken2 import ownedByPython # noqa
- except ImportError:
- from shiboken2 import isValid # noqa
- from shiboken2 import createdByPython # noqa
- from shiboken2 import ownedByPython # noqa
-
-elif qt.BINDING == 'PySide':
- try: # Available through PySide
- from PySide.shiboken import isValid # noqa
- from PySide.shiboken import createdByPython # noqa
- from PySide.shiboken import ownedByPython # noqa
- except ImportError: # Available through standalone shiboken package
- from Shiboken.shiboken import isValid # noqa
- from Shiboken.shiboken import createdByPython # noqa
- from Shiboken.shiboken import ownedByPython # noqa
-
-else:
- raise ImportError("Unsupported Qt binding %s" % qt.BINDING)
-
-__all__ = ['isValid', 'createdByPython', 'ownedByPython']
diff --git a/silx/gui/test/__init__.py b/silx/gui/test/__init__.py
deleted file mode 100644
index 2e7901d..0000000
--- a/silx/gui/test/__init__.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-
-import logging
-import os
-import sys
-import unittest
-
-from silx.test.utils import test_options
-
-_logger = logging.getLogger(__name__)
-
-
-def suite():
-
- test_suite = unittest.TestSuite()
-
- if sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
- # On Linux and no DISPLAY available (e.g., ssh without -X)
- _logger.warning('silx.gui tests disabled (DISPLAY env. variable not set)')
-
- class SkipGUITest(unittest.TestCase):
- def runTest(self):
- self.skipTest(
- 'silx.gui tests disabled (DISPLAY env. variable not set)')
-
- test_suite.addTest(SkipGUITest())
- return test_suite
-
- elif not test_options.WITH_QT_TEST:
- # Explicitly disabled tests
- msg = "silx.gui tests disabled: %s" % test_options.WITH_QT_TEST_REASON
- _logger.warning(msg)
-
- class SkipGUITest(unittest.TestCase):
- def runTest(self):
- self.skipTest(test_options.WITH_QT_TEST_REASON)
-
- test_suite.addTest(SkipGUITest())
- return test_suite
-
- # Import here to avoid loading QT if tests are disabled
-
- from ..plot import test as test_plot
- from ..fit import test as test_fit
- from ..hdf5 import test as test_hdf5
- from ..widgets import test as test_widgets
- from ..data import test as test_data
- from ..dialog import test as test_dialog
- from ..utils import test as test_utils
-
- from . import test_qt
- # Console tests disabled due to corruption of python environment
- # (see issue #538 on github)
- # from . import test_console
- from . import test_icons
- from . import test_colors
-
- try:
- from ..plot3d.test import suite as test_plot3d_suite
-
- except ImportError:
- _logger.warning(
- 'silx.gui.plot3d tests disabled '
- '(PyOpenGL or QtOpenGL not installed)')
-
- class SkipPlot3DTest(unittest.TestCase):
- def runTest(self):
- self.skipTest('silx.gui.plot3d tests disabled '
- '(PyOpenGL or QtOpenGL not installed)')
-
- test_plot3d_suite = SkipPlot3DTest
-
- test_suite.addTest(test_qt.suite())
- test_suite.addTest(test_plot.suite())
- test_suite.addTest(test_fit.suite())
- test_suite.addTest(test_hdf5.suite())
- test_suite.addTest(test_widgets.suite())
- # test_suite.addTest(test_console.suite()) # see issue #538 on github
- test_suite.addTest(test_icons.suite())
- test_suite.addTest(test_colors.suite())
- test_suite.addTest(test_data.suite())
- test_suite.addTest(test_plot3d_suite())
- test_suite.addTest(test_dialog.suite())
- # Run test_utils last: it interferes with OpenGLWidget through isOpenGLAvailable
- test_suite.addTest(test_utils.suite())
- return test_suite
diff --git a/silx/gui/test/test_colors.py b/silx/gui/test/test_colors.py
deleted file mode 100755
index 9e23a93..0000000
--- a/silx/gui/test/test_colors.py
+++ /dev/null
@@ -1,619 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides the Colormap object
-"""
-
-from __future__ import absolute_import
-
-__authors__ = ["H.Payno"]
-__license__ = "MIT"
-__date__ = "09/11/2018"
-
-import unittest
-import numpy
-from silx.utils.testutils import ParametricTestCase
-from silx.gui import qt
-from silx.gui import colors
-from silx.gui.colors import Colormap
-from silx.gui.plot import items
-from silx.utils.exceptions import NotEditableError
-
-
-class TestColor(ParametricTestCase):
- """Basic tests of rgba function"""
-
- TEST_COLORS = { # name: (colors, expected values)
- 'blue': ('blue', (0., 0., 1., 1.)),
- '#010203': ('#010203', (1. / 255., 2. / 255., 3. / 255., 1.)),
- '#01020304': ('#01020304', (1. / 255., 2. / 255., 3. / 255., 4. / 255.)),
- '3 x uint8': (numpy.array((1, 255, 0), dtype=numpy.uint8),
- (1 / 255., 1., 0., 1.)),
- '4 x uint8': (numpy.array((1, 255, 0, 1), dtype=numpy.uint8),
- (1 / 255., 1., 0., 1 / 255.)),
- '3 x float overflow': ((3., 0.5, 1.), (1., 0.5, 1., 1.)),
- }
-
- def testRGBA(self):
- """"Test rgba function with accepted values"""
- for name, test in self.TEST_COLORS.items():
- color, expected = test
- with self.subTest(msg=name):
- result = colors.rgba(color)
- self.assertEqual(result, expected)
-
- def testQColor(self):
- """"Test getQColor function with accepted values"""
- for name, test in self.TEST_COLORS.items():
- color, expected = test
- with self.subTest(msg=name):
- result = colors.asQColor(color)
- self.assertAlmostEqual(result.redF(), expected[0], places=4)
- self.assertAlmostEqual(result.greenF(), expected[1], places=4)
- self.assertAlmostEqual(result.blueF(), expected[2], places=4)
- self.assertAlmostEqual(result.alphaF(), expected[3], places=4)
-
-
-class TestApplyColormapToData(ParametricTestCase):
- """Tests of applyColormapToData function"""
-
- def testApplyColormapToData(self):
- """Simple test of applyColormapToData function"""
- colormap = Colormap(name='gray', normalization='linear',
- vmin=0, vmax=255)
-
- size = 10
- expected = numpy.empty((size, 4), dtype='uint8')
- expected[:, 0] = numpy.arange(size, dtype='uint8')
- expected[:, 1] = expected[:, 0]
- expected[:, 2] = expected[:, 0]
- expected[:, 3] = 255
-
- for dtype in ('uint8', 'int32', 'float32', 'float64'):
- with self.subTest(dtype=dtype):
- array = numpy.arange(size, dtype=dtype)
- result = colormap.applyToData(data=array)
- self.assertTrue(numpy.all(numpy.equal(result, expected)))
-
- def testAutoscaleFromDataReference(self):
- colormap = Colormap(name='gray', normalization='linear')
- data = numpy.array([50])
- reference = numpy.array([0, 100])
- value = colormap.applyToData(data, reference)
- self.assertEqual(len(value), 1)
- self.assertEqual(value[0, 0], 128)
-
- def testAutoscaleFromItemReference(self):
- colormap = Colormap(name='gray', normalization='linear')
- data = numpy.array([50])
- image = items.ImageData()
- image.setData(numpy.array([[0, 100]]))
- value = colormap.applyToData(data, reference=image)
- self.assertEqual(len(value), 1)
- self.assertEqual(value[0, 0], 128)
-
- def testNaNColor(self):
- """Test Colormap.applyToData with NaN values"""
- colormap = Colormap(name='gray', normalization='linear')
- colormap.setNaNColor('red')
- self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
-
- data = numpy.array([50., numpy.nan])
- image = items.ImageData()
- image.setData(numpy.array([[0, 100]]))
- value = colormap.applyToData(data, reference=image)
- self.assertEqual(len(value), 2)
- self.assertTrue(numpy.array_equal(value[0], (128, 128, 128, 255)))
- self.assertTrue(numpy.array_equal(value[1], (255, 0, 0, 255)))
-
-
-class TestDictAPI(unittest.TestCase):
- """Make sure the old dictionary API is working
- """
-
- def setUp(self):
- self.vmin = -1.0
- self.vmax = 12
-
- def testGetItem(self):
- """test the item getter API ([xxx])"""
- colormap = Colormap(name='viridis',
- normalization=Colormap.LINEAR,
- vmin=self.vmin,
- vmax=self.vmax)
- self.assertTrue(colormap['name'] == 'viridis')
- self.assertTrue(colormap['normalization'] == Colormap.LINEAR)
- self.assertTrue(colormap['vmin'] == self.vmin)
- self.assertTrue(colormap['vmax'] == self.vmax)
- with self.assertRaises(KeyError):
- colormap['toto']
-
- def testGetDict(self):
- """Test the getDict function API"""
- clmObject = Colormap(name='viridis',
- normalization=Colormap.LINEAR,
- vmin=self.vmin,
- vmax=self.vmax)
- clmDict = clmObject._toDict()
- self.assertTrue(clmDict['name'] == 'viridis')
- self.assertTrue(clmDict['autoscale'] is False)
- self.assertTrue(clmDict['vmin'] == self.vmin)
- self.assertTrue(clmDict['vmax'] == self.vmax)
- self.assertTrue(clmDict['normalization'] == Colormap.LINEAR)
-
- clmObject.setVRange(None, None)
- self.assertTrue(clmObject._toDict()['autoscale'] is True)
-
- def testSetValidDict(self):
- """Test that if a colormap is created from a dict then it is correctly
- created and the values are copied (so if some values from the dict
- is changing, this won't affect the Colormap object"""
- clm_dict = {
- 'name': 'temperature',
- 'vmin': 1.0,
- 'vmax': 2.0,
- 'normalization': 'linear',
- 'colors': None,
- 'autoscale': False
- }
-
- # Test that the colormap is correctly created
- colormapObject = Colormap._fromDict(clm_dict)
- self.assertTrue(colormapObject.getName() == clm_dict['name'])
- self.assertTrue(colormapObject.getColormapLUT() == clm_dict['colors'])
- self.assertTrue(colormapObject.getVMin() == clm_dict['vmin'])
- self.assertTrue(colormapObject.getVMax() == clm_dict['vmax'])
- self.assertTrue(colormapObject.isAutoscale() == clm_dict['autoscale'])
-
- # Check that the colormap has copied the values
- clm_dict['vmin'] = None
- clm_dict['vmax'] = None
- clm_dict['colors'] = [1.0, 2.0]
- clm_dict['autoscale'] = True
- clm_dict['normalization'] = Colormap.LOGARITHM
- clm_dict['name'] = 'viridis'
-
- self.assertFalse(colormapObject.getName() == clm_dict['name'])
- self.assertFalse(colormapObject.getColormapLUT() == clm_dict['colors'])
- self.assertFalse(colormapObject.getVMin() == clm_dict['vmin'])
- self.assertFalse(colormapObject.getVMax() == clm_dict['vmax'])
- self.assertFalse(colormapObject.isAutoscale() == clm_dict['autoscale'])
-
- def testMissingKeysFromDict(self):
- """Make sure we can create a Colormap object from a dictionary even if
- there is missing keys except if those keys are 'colors' or 'name'
- """
- colormap = Colormap._fromDict({'name': 'blue'})
- self.assertTrue(colormap.getVMin() is None)
- colormap = Colormap._fromDict({'colors': numpy.zeros((5, 3))})
- self.assertTrue(colormap.getName() is None)
-
- with self.assertRaises(ValueError):
- Colormap._fromDict({})
-
- def testUnknowNorm(self):
- """Make sure an error is raised if the given normalization is not
- knowed
- """
- clm_dict = {
- 'name': 'temperature',
- 'vmin': 1.0,
- 'vmax': 2.0,
- 'normalization': 'toto',
- 'colors': None,
- 'autoscale': False
- }
- with self.assertRaises(ValueError):
- Colormap._fromDict(clm_dict)
-
- def testNumericalColors(self):
- """Make sure the old API using colors=int was supported"""
- clm_dict = {
- 'name': 'temperature',
- 'vmin': 1.0,
- 'vmax': 2.0,
- 'colors': 256,
- 'autoscale': False
- }
- Colormap._fromDict(clm_dict)
-
-
-class TestObjectAPI(ParametricTestCase):
- """Test the new Object API of the colormap"""
- def testVMinVMax(self):
- """Test getter and setter associated to vmin and vmax values"""
- vmin = 1.0
- vmax = 2.0
-
- colormapObject = Colormap(name='viridis',
- vmin=vmin,
- vmax=vmax,
- normalization=Colormap.LINEAR)
-
- with self.assertRaises(ValueError):
- colormapObject.setVMin(3)
-
- with self.assertRaises(ValueError):
- colormapObject.setVMax(-2)
-
- with self.assertRaises(ValueError):
- colormapObject.setVRange(3, -2)
-
- self.assertTrue(colormapObject.getColormapRange() == (1.0, 2.0))
- self.assertTrue(colormapObject.isAutoscale() is False)
- colormapObject.setVRange(None, None)
- self.assertTrue(colormapObject.getVMin() is None)
- self.assertTrue(colormapObject.getVMax() is None)
- self.assertTrue(colormapObject.isAutoscale() is True)
-
- def testCopy(self):
- """Make sure the copy function is correctly processing
- """
- colormapObject = Colormap(name=None,
- colors=numpy.array([[1., 0., 0.],
- [0., 1., 0.],
- [0., 0., 1.]]),
- vmin=None,
- vmax=None,
- normalization=Colormap.LOGARITHM)
-
- colormapObject2 = colormapObject.copy()
- self.assertTrue(colormapObject == colormapObject2)
- colormapObject.setColormapLUT([[0, 0, 0], [255, 255, 255]])
- self.assertFalse(colormapObject == colormapObject2)
-
- colormapObject2 = colormapObject.copy()
- self.assertTrue(colormapObject == colormapObject2)
- colormapObject.setNormalization(Colormap.LINEAR)
- self.assertFalse(colormapObject == colormapObject2)
-
- def testGetColorMapRange(self):
- """Make sure the getColormapRange function of colormap is correctly
- applying
- """
- # test linear scale
- data = numpy.array([-1, 1, 2, 3, float('nan')])
- cl1 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=0,
- vmax=2)
- cl2 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=2)
- cl3 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=0,
- vmax=None)
- cl4 = Colormap(name='gray',
- normalization=Colormap.LINEAR,
- vmin=None,
- vmax=None)
-
- self.assertTrue(cl1.getColormapRange(data) == (0, 2))
- self.assertTrue(cl2.getColormapRange(data) == (-1, 2))
- self.assertTrue(cl3.getColormapRange(data) == (0, 3))
- self.assertTrue(cl4.getColormapRange(data) == (-1, 3))
-
- # test linear with annoying cases
- self.assertEqual(cl3.getColormapRange((-1, -2)), (0, 0))
- self.assertEqual(cl4.getColormapRange(()), (0., 1.))
- self.assertEqual(cl4.getColormapRange(
- (float('nan'), float('inf'), 1., -float('inf'), 2)), (1., 2.))
- self.assertEqual(cl4.getColormapRange(
- (float('nan'), float('inf'))), (0., 1.))
-
- # test log scale
- data = numpy.array([float('nan'), -1, 1, 10, 100, 1000])
- cl1 = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=1,
- vmax=100)
- cl2 = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=100)
- cl3 = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=1,
- vmax=None)
- cl4 = Colormap(name='gray',
- normalization=Colormap.LOGARITHM,
- vmin=None,
- vmax=None)
-
- self.assertTrue(cl1.getColormapRange(data) == (1, 100))
- self.assertTrue(cl2.getColormapRange(data) == (1, 100))
- self.assertTrue(cl3.getColormapRange(data) == (1, 1000))
- self.assertTrue(cl4.getColormapRange(data) == (1, 1000))
-
- # test log with annoying cases
- self.assertEqual(cl3.getColormapRange((0.1, 0.2)), (1, 1))
- self.assertEqual(cl4.getColormapRange((-2., -1.)), (1., 1.))
- self.assertEqual(cl4.getColormapRange(()), (1., 10.))
- self.assertEqual(cl4.getColormapRange(
- (float('nan'), float('inf'), 1., -float('inf'), 2)), (1., 2.))
- self.assertEqual(cl4.getColormapRange(
- (float('nan'), float('inf'))), (1., 10.))
-
- def testApplyToData(self):
- """Test applyToData on different datasets"""
- datasets = [
- numpy.zeros((0, 0)), # Empty array
- numpy.array((numpy.nan, numpy.inf)), # All non-finite
- numpy.array((-numpy.inf, numpy.inf, 1.0, 2.0)), # Some infinite
- ]
-
- for normalization in ('linear', 'log'):
- colormap = Colormap(name='gray',
- normalization=normalization,
- vmin=None,
- vmax=None)
-
- for data in datasets:
- with self.subTest(data=data):
- image = colormap.applyToData(data)
- self.assertEqual(image.dtype, numpy.uint8)
- self.assertEqual(image.shape[-1], 4)
- self.assertEqual(image.shape[:-1], data.shape)
-
- def testGetNColors(self):
- """Test getNColors method"""
- # specific LUT
- colormap = Colormap(name=None,
- colors=((0., 0., 0.), (1., 1., 1.)),
- vmin=1000,
- vmax=2000)
- colors = colormap.getNColors()
- self.assertTrue(numpy.all(numpy.equal(
- colors,
- ((0, 0, 0, 255), (255, 255, 255, 255)))))
-
- def testEditableMode(self):
- """Make sure the colormap will raise NotEditableError when try to
- change a colormap not editable"""
- colormap = Colormap()
- colormap.setEditable(False)
- with self.assertRaises(NotEditableError):
- colormap.setVRange(0., 1.)
- with self.assertRaises(NotEditableError):
- colormap.setVMin(1.)
- with self.assertRaises(NotEditableError):
- colormap.setVMax(1.)
- with self.assertRaises(NotEditableError):
- colormap.setNormalization(Colormap.LOGARITHM)
- with self.assertRaises(NotEditableError):
- colormap.setName('magma')
- with self.assertRaises(NotEditableError):
- colormap.setColormapLUT([[0., 0., 0.], [1., 1., 1.]])
- with self.assertRaises(NotEditableError):
- colormap._setFromDict(colormap._toDict())
- state = colormap.saveState()
- with self.assertRaises(NotEditableError):
- colormap.restoreState(state)
-
- def testBadColorsType(self):
- """Make sure colors can't be something else than an array"""
- with self.assertRaises(TypeError):
- Colormap(colors=256)
-
- def testEqual(self):
- colormap1 = Colormap()
- colormap2 = Colormap()
- self.assertEqual(colormap1, colormap2)
-
- def testCompareString(self):
- colormap = Colormap()
- self.assertNotEqual(colormap, "a")
-
- def testCompareNone(self):
- colormap = Colormap()
- self.assertNotEqual(colormap, None)
-
- def testSet(self):
- colormap = Colormap()
- other = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
- self.assertNotEqual(colormap, other)
- colormap.setFromColormap(other)
- self.assertIsNot(colormap, other)
- self.assertEqual(colormap, other)
-
- def testAutoscaleMode(self):
- colormap = Colormap(autoscaleMode=Colormap.STDDEV3)
- self.assertEqual(colormap.getAutoscaleMode(), Colormap.STDDEV3)
- colormap.setAutoscaleMode(Colormap.MINMAX)
- self.assertEqual(colormap.getAutoscaleMode(), Colormap.MINMAX)
-
- def testStoreRestore(self):
- colormaps = [
- Colormap(name="viridis"),
- Colormap(normalization=Colormap.SQRT)
- ]
- cmap = Colormap(normalization=Colormap.GAMMA)
- cmap.setGammaNormalizationParameter(1.2)
- cmap.setNaNColor('red')
- colormaps.append(cmap)
- for expected in colormaps:
- with self.subTest(colormap=expected):
- state = expected.saveState()
- result = Colormap()
- result.restoreState(state)
- self.assertEqual(expected, result)
-
- def testStorageV1(self):
- state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00\x00'\
- b'\x00\x01\x00\x00\x00\x0E\x00v\x00i\x00r\x00i\x00d\x00i\x00s\x00'\
- b'\x00\x00\x00\x06\x00?\xF0\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00'\
- b'l\x00o\x00g'
- state = qt.QByteArray(state)
- colormap = Colormap()
- colormap.restoreState(state)
-
- expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
- self.assertEqual(colormap, expected)
-
- def testStorageV2(self):
- state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00'\
- b'\x00\x00\x02\x00\x00\x00\x0e\x00v\x00i\x00r\x00i\x00d\x00i\x00'\
- b's\x00\x00\x00\x00\x06\x00?\xf0\x00\x00\x00\x00\x00\x00\x00\x00'\
- b'\x00\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06'\
- b'\x00l\x00o\x00g\x00\x00\x00\x0c\x00m\x00i\x00n\x00m\x00a\x00x'
- state = qt.QByteArray(state)
- colormap = Colormap()
- colormap.restoreState(state)
-
- expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
- expected.setGammaNormalizationParameter(1.5)
- self.assertEqual(colormap, expected)
-
-
-class TestPreferredColormaps(unittest.TestCase):
- """Test get|setPreferredColormaps functions"""
-
- def setUp(self):
- # Save preferred colormaps
- self._colormaps = colors.preferredColormaps()
-
- def tearDown(self):
- # Restore saved preferred colormaps
- colors.setPreferredColormaps(self._colormaps)
-
- def test(self):
- colormaps = 'viridis', 'magma'
-
- colors.setPreferredColormaps(colormaps)
- self.assertEqual(colors.preferredColormaps(), colormaps)
-
- with self.assertRaises(ValueError):
- colors.setPreferredColormaps(())
-
- with self.assertRaises(ValueError):
- colors.setPreferredColormaps(('This is not a colormap',))
-
- colormaps = 'red', 'green'
- colors.setPreferredColormaps(('This is not a colormap',) + colormaps)
- self.assertEqual(colors.preferredColormaps(), colormaps)
-
-
-class TestRegisteredLut(unittest.TestCase):
- """Test get|setPreferredColormaps functions"""
-
- def setUp(self):
- # Save preferred colormaps
- lut = numpy.arange(8 * 3)
- lut.shape = -1, 3
- lut = lut / (8.0 * 3)
- colors.registerLUT("test_8", colors=lut, cursor_color='blue')
-
- def testColormap(self):
- colormap = Colormap("test_8")
- self.assertIsNotNone(colormap)
-
- def testCursor(self):
- color = colors.cursorColorForColormap("test_8")
- self.assertEqual(color, 'blue')
-
- def testLut(self):
- colormap = Colormap("test_8")
- colors = colormap.getNColors(8)
- self.assertEqual(len(colors), 8)
-
- def testUint8(self):
- lut = numpy.array([[255, 0, 0], [200, 0, 0], [150, 0, 0]], dtype="uint")
- colors.registerLUT("test_type", lut)
- colormap = colors.Colormap(name="test_type")
- lut = colormap.getNColors(3)
- self.assertEqual(lut.shape, (3, 4))
- self.assertEqual(lut[0, 0], 255)
-
- def testFloatRGB(self):
- lut = numpy.array([[1.0, 0, 0], [0.5, 0, 0], [0, 0, 0]], dtype="float")
- colors.registerLUT("test_type", lut)
- colormap = colors.Colormap(name="test_type")
- lut = colormap.getNColors(3)
- self.assertEqual(lut.shape, (3, 4))
- self.assertEqual(lut[0, 0], 255)
-
- def testFloatRGBA(self):
- lut = numpy.array([[1.0, 0, 0, 128 / 256.0], [0.5, 0, 0, 1.0], [0.0, 0, 0, 1.0]], dtype="float")
- colors.registerLUT("test_type", lut)
- colormap = colors.Colormap(name="test_type")
- lut = colormap.getNColors(3)
- self.assertEqual(lut.shape, (3, 4))
- self.assertEqual(lut[0, 0], 255)
- self.assertEqual(lut[0, 3], 128)
-
-
-class TestAutoscaleRange(ParametricTestCase):
-
- def testAutoscaleRange(self):
- nan = numpy.nan
- data_std_inside = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2])
- data_std_inside_nan = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, numpy.nan])
- data = [
- # Positive values
- (Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50]), (10, 50)),
- (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100]), (10, 100)),
- (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside, (0.026671473215424735, 1.9733285267845753)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside, (1, 1.6733506885453602)),
- (Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
-
- # With nan
- (Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50, nan]), (10, 50)),
- (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, nan]), (10, 100)),
- (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside_nan, (0.026671473215424735, 1.9733285267845753)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside_nan, (1, 1.6733506885453602)),
- # With negative
- (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, -50]), (10, 100)),
- (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, -10]), (10, 100)),
- ]
- for norm, mode, array, expectedRange in data:
- with self.subTest(norm=norm, mode=mode, array=array):
- colormap = Colormap()
- colormap.setNormalization(norm)
- colormap.setAutoscaleMode(mode)
- vRange = colormap._computeAutoscaleRange(array)
- if vRange is None:
- self.assertIsNone(expectedRange)
- else:
- self.assertAlmostEqual(vRange[0], expectedRange[0])
- self.assertAlmostEqual(vRange[1], expectedRange[1])
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestApplyColormapToData))
- test_suite.addTest(loadTests(TestColor))
- test_suite.addTest(loadTests(TestDictAPI))
- test_suite.addTest(loadTests(TestObjectAPI))
- test_suite.addTest(loadTests(TestPreferredColormaps))
- test_suite.addTest(loadTests(TestRegisteredLut))
- test_suite.addTest(loadTests(TestAutoscaleRange))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/test/test_console.py b/silx/gui/test/test_console.py
deleted file mode 100644
index 7db5f12..0000000
--- a/silx/gui/test/test_console.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for IPython console widget"""
-
-from __future__ import print_function
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "05/12/2016"
-
-
-import unittest
-
-from silx.gui.utils.testutils import TestCaseQt
-
-from silx.gui import qt
-try:
- from silx.gui.console import IPythonDockWidget
-except ImportError:
- console_missing = True
-else:
- console_missing = False
-
-
-# dummy objects to test pushing variables to the interactive namespace
-_a = 1
-
-
-def _f():
- print("Hello World!")
-
-
-@unittest.skipIf(console_missing, "Could not import Ipython and/or qtconsole")
-class TestConsole(TestCaseQt):
- """Basic test for ``module.IPythonDockWidget``"""
-
- def setUp(self):
- super(TestConsole, self).setUp()
- self.console = IPythonDockWidget(
- available_vars={"a": _a, "f": _f},
- custom_banner="Welcome!\n")
- self.console.show()
- self.qWaitForWindowExposed(self.console)
-
- def tearDown(self):
- self.console.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.console.close()
- del self.console
- super(TestConsole, self).tearDown()
-
- def testShow(self):
- pass
-
- def testInteract(self):
- self.mouseClick(self.console, qt.Qt.LeftButton)
- self.keyClicks(self.console, 'import silx')
- self.keyClick(self.console, qt.Qt.Key_Enter)
- self.qapp.processEvents()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestConsole))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/test/test_icons.py b/silx/gui/test/test_icons.py
deleted file mode 100644
index 1757f30..0000000
--- a/silx/gui/test/test_icons.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic test of Qt icons module."""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "06/09/2017"
-
-
-import gc
-import unittest
-import weakref
-import tempfile
-import shutil
-import os
-
-import silx.resources
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import icons
-
-
-class TestIcons(TestCaseQt):
- """Test to check that icons module."""
-
- @classmethod
- def setUpClass(cls):
- super(TestIcons, cls).setUpClass()
-
- cls.tmpDirectory = tempfile.mkdtemp(prefix="resource_")
- os.mkdir(os.path.join(cls.tmpDirectory, "gui"))
- destination = os.path.join(cls.tmpDirectory, "gui", "icons")
- os.mkdir(destination)
- shutil.copy(silx.resources.resource_filename("gui/icons/zoom-in.png"), destination)
- shutil.copy(silx.resources.resource_filename("gui/icons/zoom-out.svg"), destination)
-
- @classmethod
- def tearDownClass(cls):
- super(TestIcons, cls).tearDownClass()
- shutil.rmtree(cls.tmpDirectory)
-
- def setUp(self):
- # Store the original configuration
- self._oldResources = dict(silx.resources._RESOURCE_DIRECTORIES)
- silx.resources.register_resource_directory("test", "foo.bar", forced_path=self.tmpDirectory)
- unittest.TestCase.setUp(self)
-
- def tearDown(self):
- unittest.TestCase.tearDown(self)
- # Restiture the original configuration
- silx.resources._RESOURCE_DIRECTORIES = self._oldResources
-
- def testIcon(self):
- icon = icons.getQIcon("silx:gui/icons/zoom-out")
- self.assertIsNotNone(icon)
-
- def testPrefix(self):
- icon = icons.getQIcon("silx:gui/icons/zoom-out")
- self.assertIsNotNone(icon)
-
- def testSvgIcon(self):
- if "svg" not in qt.supportedImageFormats():
- self.skipTest("SVG not supported")
- icon = icons.getQIcon("test:gui/icons/zoom-out")
- self.assertIsNotNone(icon)
-
- def testPngIcon(self):
- icon = icons.getQIcon("test:gui/icons/zoom-in")
- self.assertIsNotNone(icon)
-
- def testUnexistingIcon(self):
- self.assertRaises(ValueError, icons.getQIcon, "not-exists")
-
- def testExistingQPixmap(self):
- icon = icons.getQPixmap("crop")
- self.assertIsNotNone(icon)
-
- def testUnexistingQPixmap(self):
- self.assertRaises(ValueError, icons.getQPixmap, "not-exists")
-
- def testCache(self):
- icon1 = icons.getQIcon("crop")
- icon2 = icons.getQIcon("crop")
- self.assertIs(icon1, icon2)
-
- def testCacheReleased(self):
- icon = icons.getQIcon("crop")
- icon_ref = weakref.ref(icon)
- del icon
- gc.collect()
- self.assertIsNone(icon_ref())
-
-
-class TestAnimatedIcons(TestCaseQt):
- """Test to check that icons module."""
-
- def testProcessWorking(self):
- icon = icons.getWaitIcon()
- self.assertIsNotNone(icon)
-
- def testProcessWorkingCache(self):
- icon1 = icons.getWaitIcon()
- icon2 = icons.getWaitIcon()
- self.assertIs(icon1, icon2)
-
- def testMovieIconExists(self):
- if "mng" not in qt.supportedImageFormats():
- self.skipTest("MNG not supported")
- icon = icons.MovieAnimatedIcon("process-working")
- self.assertIsNotNone(icon)
-
- def testMovieIconNotExists(self):
- self.assertRaises(ValueError, icons.MovieAnimatedIcon, "not-exists")
-
- def testMultiImageIconExists(self):
- icon = icons.MultiImageAnimatedIcon("process-working")
- self.assertIsNotNone(icon)
-
- def testPrefixedResourceExists(self):
- icon = icons.MultiImageAnimatedIcon("silx:gui/icons/process-working")
- self.assertIsNotNone(icon)
-
- def testMultiImageIconNotExists(self):
- self.assertRaises(ValueError, icons.MultiImageAnimatedIcon, "not-exists")
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestIcons))
- test_suite.addTest(loadTests(TestAnimatedIcons))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/test/test_qt.py b/silx/gui/test/test_qt.py
deleted file mode 100644
index 0d10620..0000000
--- a/silx/gui/test/test_qt.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic test of Qt bindings wrapper."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "05/12/2016"
-
-
-import os.path
-import unittest
-
-from silx.test.utils import temp_dir
-from silx.gui.utils.testutils import TestCaseQt
-
-from silx.gui import qt
-try:
- from silx.gui.qt import inspect as qt_inspect
-except ImportError:
- qt_inspect = None
-
-
-class TestQtWrapper(unittest.TestCase):
- """Minimalistic test to check that Qt has been loaded."""
-
- def testQObject(self):
- """Test that QObject is there."""
- obj = qt.QObject()
- self.assertTrue(obj is not None)
-
-
-class TestLoadUi(TestCaseQt):
- """Test loadUi function"""
-
- TEST_UI = """<?xml version="1.0" encoding="UTF-8"?>
- <ui version="4.0">
- <class>MainWindow</class>
- <widget class="QMainWindow" name="MainWindow">
- <property name="geometry">
- <rect>
- <x>0</x>
- <y>0</y>
- <width>293</width>
- <height>296</height>
- </rect>
- </property>
- <property name="windowTitle">
- <string>Test loadUi</string>
- </property>
- <widget class="QWidget" name="centralwidget">
- <widget class="QPushButton" name="pushButton">
- <property name="geometry">
- <rect>
- <x>10</x>
- <y>10</y>
- <width>89</width>
- <height>27</height>
- </rect>
- </property>
- <property name="text">
- <string>Button 1</string>
- </property>
- </widget>
- <widget class="QPushButton" name="pushButton_2">
- <property name="geometry">
- <rect>
- <x>10</x>
- <y>50</y>
- <width>89</width>
- <height>27</height>
- </rect>
- </property>
- <property name="text">
- <string>Button 2</string>
- </property>
- </widget>
- <widget class="Line" name="line">
- <property name="geometry">
- <rect>
- <x>10</x>
- <y>90</y>
- <width>118</width>
- <height>3</height>
- </rect>
- </property>
- <property name="orientation">
- <enum>Qt::Horizontal</enum>
- </property>
- </widget>
- <widget class="Line" name="line_2">
- <property name="geometry">
- <rect>
- <x>150</x>
- <y>20</y>
- <width>3</width>
- <height>61</height>
- </rect>
- </property>
- <property name="orientation">
- <enum>Qt::Vertical</enum>
- </property>
- </widget>
- </widget>
- <widget class="QMenuBar" name="menubar">
- <property name="geometry">
- <rect>
- <x>0</x>
- <y>0</y>
- <width>293</width>
- <height>25</height>
- </rect>
- </property>
- </widget>
- <widget class="QStatusBar" name="statusbar"/>
- </widget>
- <resources/>
- <connections/>
- </ui>
- """
-
- @unittest.skipIf(qt.BINDING == "PySide", "Not fully working with PySide")
- def testLoadUi(self):
- """Create a QMainWindow from an ui file"""
- with temp_dir() as tmp:
- uifile = os.path.join(tmp, "test.ui")
-
- # write file
- with open(uifile, mode='w') as f:
- f.write(self.TEST_UI)
-
- class TestMainWindow(qt.QMainWindow):
- def __init__(self, parent=None):
- super(TestMainWindow, self).__init__(parent)
- qt.loadUi(uifile, self)
-
- testMainWindow = TestMainWindow()
- testMainWindow.show()
- self.qWaitForWindowExposed(testMainWindow)
-
- testMainWindow.setAttribute(qt.Qt.WA_DeleteOnClose)
- testMainWindow.close()
-
-
-class TestQtInspect(unittest.TestCase):
- """Test functions of silx.gui.qt.inspect module"""
-
- # shiboken module is not always available
- @unittest.skipIf(qt.BINDING == 'PySide' and qt_inspect is None,
- reason="shiboken module not available")
- def test(self):
- """Test functions of silx.gui.qt.inspect module"""
- self.assertIsNotNone(qt_inspect)
-
- parent = qt.QObject()
-
- self.assertTrue(qt_inspect.isValid(parent))
- self.assertTrue(qt_inspect.createdByPython(parent))
- self.assertTrue(qt_inspect.ownedByPython(parent))
-
- obj = qt.QObject(parent)
-
- self.assertTrue(qt_inspect.isValid(obj))
- self.assertTrue(qt_inspect.createdByPython(obj))
- self.assertFalse(qt_inspect.ownedByPython(obj))
-
- del parent
- self.assertFalse(qt_inspect.isValid(obj))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestCaseCls in (TestQtWrapper, TestLoadUi, TestQtInspect):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestCaseCls))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/utils/glutils/__init__.py b/silx/gui/utils/glutils/__init__.py
deleted file mode 100644
index c90f029..0000000
--- a/silx/gui/utils/glutils/__init__.py
+++ /dev/null
@@ -1,199 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides the :func:`isOpenGLAvailable` utility function.
-"""
-
-import os
-import sys
-import subprocess
-from silx.gui import qt
-
-
-class _isOpenGLAvailableResult:
- """Store result of checking OpenGL availability.
-
- It provides a `status` boolean attribute storing the result of the check and
- an `error` string attribute storting the possible error message.
- """
-
- def __init__(self, status=True, error=''):
- self.__status = bool(status)
- self.__error = str(error)
-
- status = property(lambda self: self.__status, doc="True if OpenGL is working")
- error = property(lambda self: self.__error, doc="Error message")
-
- def __bool__(self):
- return self.status
-
- def __repr__(self):
- return '<_isOpenGLAvailableResult: %s, "%s">' % (self.status, self.error)
-
-
-def _runtimeOpenGLCheck(version):
- """Run OpenGL check in a subprocess.
-
- This is done by starting a subprocess that displays a Qt OpenGL widget.
-
- :param List[int] version:
- The minimal required OpenGL version as a 2-tuple (major, minor).
- Default: (2, 1)
- :return: An error string that is empty if no error occured
- :rtype: str
- """
- major, minor = str(version[0]), str(version[1])
- env = os.environ.copy()
- env['PYTHONPATH'] = os.pathsep.join(
- [os.path.abspath(p) for p in sys.path])
-
- try:
- error = subprocess.check_output(
- [sys.executable, '-s', '-S', __file__, major, minor],
- env=env,
- timeout=2)
- except subprocess.TimeoutExpired:
- status = False
- error = "Qt OpenGL widget hang"
- if sys.platform.startswith('linux'):
- error += ':\nIf connected remotely, GLX forwarding might be disabled.'
- except subprocess.CalledProcessError as e:
- status = False
- error = "Qt OpenGL widget error: retcode=%d, error=%s" % (e.returncode, e.output)
- else:
- status = True
- error = error.decode()
- return _isOpenGLAvailableResult(status, error)
-
-
-_runtimeCheckCache = {} # Cache runtime check results: {version: result}
-
-
-def isOpenGLAvailable(version=(2, 1), runtimeCheck=True):
- """Check if OpenGL is available through Qt and actually working.
-
- After some basic tests, this is done by starting a subprocess that
- displays a Qt OpenGL widget.
-
- :param List[int] version:
- The minimal required OpenGL version as a 2-tuple (major, minor).
- Default: (2, 1)
- :param bool runtimeCheck:
- True (default) to run the test creating a Qt OpenGL widgt in a subprocess,
- False to avoid this check.
- :return: A result object that evaluates to True if successful and
- which has a `status` boolean attribute (True if successful) and
- an `error` string attribute that is not empty if `status` is False.
- """
- error = ''
-
- if sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
- # On Linux and no DISPLAY available (e.g., ssh without -X)
- error = 'DISPLAY environment variable not set'
-
- else:
- # Check pyopengl availability
- try:
- import silx.gui._glutils.gl # noqa
- except ImportError:
- error = "Cannot import OpenGL wrapper: pyopengl is not installed"
- else:
- # Pre checks for Qt < 5.4
- if not hasattr(qt, 'QOpenGLWidget'):
- if not qt.HAS_OPENGL:
- error = '%s.QtOpenGL not available' % qt.BINDING
-
- elif qt.QApplication.instance() and not qt.QGLFormat.hasOpenGL():
- # qt.QGLFormat.hasOpenGL MUST be called with a QApplication created
- # so this is only checked if the QApplication is already created
- error = 'Qt reports OpenGL not available'
-
- result = _isOpenGLAvailableResult(error == '', error)
-
- if result: # No error so far, runtime check
- if version in _runtimeCheckCache: # Use cache
- result = _runtimeCheckCache[version]
- elif runtimeCheck: # Run test in subprocess
- result = _runtimeOpenGLCheck(version)
- _runtimeCheckCache[version] = result
-
- return result
-
-
-if __name__ == "__main__":
- from silx.gui._glutils import OpenGLWidget
- from silx.gui._glutils import gl
- import argparse
-
- class _TestOpenGLWidget(OpenGLWidget):
- """Widget checking that OpenGL is indeed available
-
- :param List[int] version: (major, minor) minimum OpenGL version
- """
-
- def __init__(self, version):
- super(_TestOpenGLWidget, self).__init__(
- alphaBufferSize=0,
- depthBufferSize=0,
- stencilBufferSize=0,
- version=version)
-
- def paintEvent(self, event):
- super(_TestOpenGLWidget, self).paintEvent(event)
-
- # Check once paint has been done
- app = qt.QApplication.instance()
- if not self.isValid():
- print("OpenGL widget is not valid")
- app.exit(1)
- else:
- qt.QTimer.singleShot(100, app.quit)
-
- def paintGL(self):
- gl.glClearColor(1., 0., 0., 0.)
- gl.glClear(gl.GL_COLOR_BUFFER_BIT)
-
-
- parser = argparse.ArgumentParser()
- parser.add_argument('major')
- parser.add_argument('minor')
-
- args = parser.parse_args(args=sys.argv[1:])
-
- app = qt.QApplication([])
- window = qt.QMainWindow(flags=
- qt.Qt.Popup |
- qt.Qt.FramelessWindowHint |
- qt.Qt.NoDropShadowWindowHint |
- qt.Qt.WindowStaysOnTopHint)
- window.setAttribute(qt.Qt.WA_ShowWithoutActivating)
- window.move(0, 0)
- window.resize(3, 3)
- widget = _TestOpenGLWidget(version=(args.major, args.minor))
- window.setCentralWidget(widget)
- window.setWindowOpacity(0.04)
- window.show()
-
- qt.QTimer.singleShot(1000, app.quit)
- sys.exit(app.exec_())
diff --git a/silx/gui/utils/image.py b/silx/gui/utils/image.py
deleted file mode 100644
index 3ac737f..0000000
--- a/silx/gui/utils/image.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides conversions between numpy.ndarray and QImage
-
-- :func:`convertArrayToQImage`
-- :func:`convertQImageToArray`
-"""
-
-from __future__ import division
-
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "04/09/2018"
-
-
-import sys
-import numpy
-from numpy.lib.stride_tricks import as_strided as _as_strided
-
-from .. import qt
-
-
-def convertArrayToQImage(array):
- """Convert an array-like image to a QImage.
-
- The created QImage is using a copy of the array data.
-
- Limitation: Only RGB or RGBA images with 8 bits per channel are supported.
-
- :param array: Array-like image data of shape (height, width, channels)
- Channels are expected to be either RGB or RGBA.
- :type array: numpy.ndarray of uint8
- :return: Corresponding Qt image with RGB888 or ARGB32 format.
- :rtype: QImage
- """
- array = numpy.array(array, copy=False, order='C', dtype=numpy.uint8)
-
- if array.ndim != 3 or array.shape[2] not in (3, 4):
- raise ValueError(
- 'Image must be a 3D array with 3 or 4 channels per pixel')
-
- if array.shape[2] == 4:
- format_ = qt.QImage.Format_ARGB32
- # RGBA -> ARGB + take care of endianness
- if sys.byteorder == 'little': # RGBA -> BGRA
- array = array[:, :, (2, 1, 0, 3)]
- else: # big endian: RGBA -> ARGB
- array = array[:, :, (3, 0, 1, 2)]
-
- array = numpy.array(array, order='C') # Make a contiguous array
-
- else: # array.shape[2] == 3
- format_ = qt.QImage.Format_RGB888
-
- height, width, depth = array.shape
- qimage = qt.QImage(
- array.data,
- width,
- height,
- array.strides[0], # bytesPerLine
- format_)
-
- return qimage.copy() # Making a copy of the image and its data
-
-
-def convertQImageToArray(image):
- """Convert a QImage to a numpy array.
-
- If QImage format is not Format_RGB888, Format_RGBA8888 or Format_ARGB32,
- it is first converted to one of this format depending on
- the presence of an alpha channel.
-
- The created numpy array is using a copy of the QImage data.
-
- :param QImage image: The QImage to convert.
- :return: The image array of RGB or RGBA channels of shape
- (height, width, channels (3 or 4))
- :rtype: numpy.ndarray of uint8
- """
- rgba8888 = getattr(qt.QImage, 'Format_RGBA8888', None) # Only in Qt5
-
- # Convert to supported format if needed
- if image.format() not in (qt.QImage.Format_ARGB32,
- qt.QImage.Format_RGB888,
- rgba8888):
- if image.hasAlphaChannel():
- image = image.convertToFormat(
- rgba8888 if rgba8888 is not None else qt.QImage.Format_ARGB32)
- else:
- image = image.convertToFormat(qt.QImage.Format_RGB888)
-
- format_ = image.format()
- channels = 3 if format_ == qt.QImage.Format_RGB888 else 4
-
- ptr = image.bits()
- if qt.BINDING not in ('PySide', 'PySide2'):
- ptr.setsize(image.byteCount())
- if qt.BINDING == 'PyQt4' and sys.version_info[0] == 2:
- ptr = ptr.asstring()
- elif sys.version_info[0] == 3: # PySide with Python3
- ptr = ptr.tobytes()
-
- # Create an array view on QImage internal data
- view = _as_strided(
- numpy.frombuffer(ptr, dtype=numpy.uint8),
- shape=(image.height(), image.width(), channels),
- strides=(image.bytesPerLine(), channels, 1))
-
- if format_ == qt.QImage.Format_ARGB32:
- # Convert from ARGB to RGBA
- # Not a byte-ordered format: do care about endianness
- if sys.byteorder == 'little': # BGRA -> RGBA
- view = view[:, :, (2, 1, 0, 3)]
- else: # big endian: ARGB -> RGBA
- view = view[:, :, (1, 2, 3, 0)]
-
- # Format_RGB888 and Format_RGBA8888 do not need reshuffling channels:
- # They are byte-ordered and already in the right order
-
- return numpy.array(view, copy=True, order='C')
diff --git a/silx/gui/utils/matplotlib.py b/silx/gui/utils/matplotlib.py
deleted file mode 100644
index 484e01a..0000000
--- a/silx/gui/utils/matplotlib.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-from __future__ import absolute_import
-
-"""This module initializes matplotlib and sets-up the backend to use.
-
-It MUST be imported prior to any other import of matplotlib.
-
-It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding
-to the used backend.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "02/05/2018"
-
-
-from pkg_resources import parse_version
-import matplotlib
-
-from .. import qt
-
-
-def _matplotlib_use(backend, force):
- """Wrapper of `matplotlib.use` to set-up backend.
-
- It adds extra initialization for PySide and PySide2 with matplotlib < 2.2.
- """
- # This is kept for compatibility with matplotlib < 2.2
- if parse_version(matplotlib.__version__) < parse_version('2.2'):
- if qt.BINDING == 'PySide':
- matplotlib.rcParams['backend.qt4'] = 'PySide'
- if qt.BINDING == 'PySide2':
- matplotlib.rcParams['backend.qt5'] = 'PySide2'
-
- matplotlib.use(backend, force=force)
-
-
-if qt.BINDING in ('PyQt4', 'PySide'):
- _matplotlib_use('Qt4Agg', force=False)
- from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg # noqa
-
-elif qt.BINDING in ('PyQt5', 'PySide2'):
- _matplotlib_use('Qt5Agg', force=False)
- from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa
-
-else:
- raise ImportError("Unsupported Qt binding: %s" % qt.BINDING)
diff --git a/silx/gui/utils/test/__init__.py b/silx/gui/utils/test/__init__.py
deleted file mode 100755
index 41e0d6a..0000000
--- a/silx/gui/utils/test/__init__.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""silx.gui.utils tests"""
-
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "24/04/2018"
-
-
-import unittest
-
-from . import test_async
-from . import test_glutils
-from . import test_image
-from . import test_qtutils
-from . import test_testutils
-from . import test
-
-
-def suite():
- """Test suite for module silx.image.test"""
- test_suite = unittest.TestSuite()
- test_suite.addTest(test.suite())
- test_suite.addTest(test_async.suite())
- test_suite.addTest(test_glutils.suite())
- test_suite.addTest(test_image.suite())
- test_suite.addTest(test_qtutils.suite())
- test_suite.addTest(test_testutils.suite())
- return test_suite
-
-
-if __name__ == "__main__":
- unittest.main(defaultTest="suite")
diff --git a/silx/gui/utils/test/test.py b/silx/gui/utils/test/test.py
deleted file mode 100644
index 8bba852..0000000
--- a/silx/gui/utils/test/test.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of functions available in silx.gui.utils module."""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "01/08/2019"
-
-
-import unittest
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt, SignalListener
-
-from silx.gui.utils import blockSignals
-
-
-class TestBlockSignals(TestCaseQt):
- """Test blockSignals context manager"""
-
- def _test(self, *objs):
- """Test for provided objects"""
- listener = SignalListener()
- for obj in objs:
- obj.objectNameChanged.connect(listener)
- obj.setObjectName("received")
-
- with blockSignals(*objs):
- for obj in objs:
- obj.setObjectName("silent")
-
- self.assertEqual(listener.arguments(), [("received",)] * len(objs))
-
- @unittest.skipUnless(qt.BINDING in ('PyQt5', 'PySide2'), 'Qt5 only test')
- def testManyObjects(self):
- """Test blockSignals with 2 QObjects"""
- self._test(qt.QObject(), qt.QObject())
-
- @unittest.skipUnless(qt.BINDING in ('PyQt5', 'PySide2'), 'Qt5 only test')
- def testOneObject(self):
- """Test blockSignals context manager with a single QObject"""
- self._test(qt.QObject())
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
- TestBlockSignals))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/utils/test/test_async.py b/silx/gui/utils/test/test_async.py
deleted file mode 100644
index dcfde1d..0000000
--- a/silx/gui/utils/test/test_async.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of async module."""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "09/03/2018"
-
-
-import threading
-import unittest
-
-
-from concurrent.futures import wait
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-
-from silx.gui.utils import concurrent
-
-
-class TestSubmitToQtThread(TestCaseQt):
- """Test submission of tasks to Qt main thread"""
-
- def setUp(self):
- # Reset executor to test lazy-loading in different conditions
- concurrent._executor = None
- super(TestSubmitToQtThread, self).setUp()
-
- def _task(self, value1, value2):
- return value1, value2
-
- def _taskWithException(self, *args, **kwargs):
- raise RuntimeError('task exception')
-
- def testFromMainThread(self):
- """Call submitToQtMainThread from the main thread"""
- value1, value2 = 0, 1
- future = concurrent.submitToQtMainThread(self._task, value1, value2=value2)
- self.assertTrue(future.done())
- self.assertEqual(future.result(1), (value1, value2))
- self.assertIsNone(future.exception(1))
-
- future = concurrent.submitToQtMainThread(self._taskWithException)
- self.assertTrue(future.done())
- with self.assertRaises(RuntimeError):
- future.result(1)
- self.assertIsInstance(future.exception(1), RuntimeError)
-
- def _threadedTest(self):
- """Function run in a thread for the tests"""
- value1, value2 = 0, 1
- future = concurrent.submitToQtMainThread(self._task, value1, value2=value2)
-
- wait([future], 3)
-
- self.assertTrue(future.done())
- self.assertEqual(future.result(1), (value1, value2))
- self.assertIsNone(future.exception(1))
-
- future = concurrent.submitToQtMainThread(self._taskWithException)
-
- wait([future], 3)
-
- self.assertTrue(future.done())
- with self.assertRaises(RuntimeError):
- future.result(1)
- self.assertIsInstance(future.exception(1), RuntimeError)
-
- def testFromPythonThread(self):
- """Call submitToQtMainThread from a Python thread"""
- thread = threading.Thread(target=self._threadedTest)
- thread.start()
- for i in range(100): # Loop over for 10 seconds
- self.qapp.processEvents()
- thread.join(0.1)
- if not thread.is_alive():
- break
- else:
- self.fail(('Thread task still running'))
-
- def testFromQtThread(self):
- """Call submitToQtMainThread from a Qt thread pool"""
- class Runner(qt.QRunnable):
- def __init__(self, fn):
- super(Runner, self).__init__()
- self._fn = fn
-
- def run(self):
- self._fn()
-
- def autoDelete(self):
- return True
-
- threadPool = qt.silxGlobalThreadPool()
- runner = Runner(self._threadedTest)
- threadPool.start(runner)
- for i in range(100): # Loop over for 10 seconds
- self.qapp.processEvents()
- done = threadPool.waitForDone(100)
- if done:
- break
- else:
- self.fail('Thread pool task still running')
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
- TestSubmitToQtThread))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/utils/test/test_glutils.py b/silx/gui/utils/test/test_glutils.py
deleted file mode 100644
index 66df8cf..0000000
--- a/silx/gui/utils/test/test_glutils.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for the silx.gui.utils.glutils module."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "15/01/2020"
-
-
-import logging
-import unittest
-from silx.gui.utils.glutils import isOpenGLAvailable
-
-
-_logger = logging.getLogger(__name__)
-
-
-class TestIsOpenGLAvailable(unittest.TestCase):
- """Test isOpenGLAvailable"""
-
- def test(self):
- for version in ((2, 1), (2, 1), (1000, 1)):
- with self.subTest(version=version):
- result = isOpenGLAvailable(version=version)
- _logger.info("isOpenGLAvailable returned: %s", str(result))
- if version[0] == 1000:
- self.assertFalse(result)
- if not result:
- self.assertFalse(result.status)
- self.assertTrue(len(result.error) > 0)
- else:
- self.assertTrue(result.status)
- self.assertTrue(len(result.error) == 0)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
- TestIsOpenGLAvailable))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/utils/test/test_image.py b/silx/gui/utils/test/test_image.py
deleted file mode 100644
index cda7d95..0000000
--- a/silx/gui/utils/test/test_image.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of utils module."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "16/01/2017"
-
-import numpy
-import unittest
-
-from silx.gui import qt
-from silx.utils.testutils import ParametricTestCase
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.utils.image import convertArrayToQImage, convertQImageToArray
-
-
-class TestQImageConversion(TestCaseQt, ParametricTestCase):
- """Tests conversion of QImage to/from numpy array."""
-
- def testConvertArrayToQImage(self):
- """Test conversion of numpy array to QImage"""
- for format_, channels in [('Format_RGB888', 3),
- ('Format_ARGB32', 4)]:
- with self.subTest(format_):
- image = numpy.arange(
- 3*3*channels, dtype=numpy.uint8).reshape(3, 3, channels)
- qimage = convertArrayToQImage(image)
-
- self.assertEqual(qimage.height(), image.shape[0])
- self.assertEqual(qimage.width(), image.shape[1])
- self.assertEqual(qimage.format(), getattr(qt.QImage, format_))
-
- for row in range(3):
- for col in range(3):
- # Qrgb has no alpha channel, not compared
- # Qt uses x,y while array is row,col...
- self.assertEqual(qt.QColor(qimage.pixel(col, row)),
- qt.QColor(*image[row, col, :3]))
-
-
- def testConvertQImageToArray(self):
- """Test conversion of QImage to numpy array"""
- for format_, channels in [
- ('Format_RGB888', 3), # Native support
- ('Format_ARGB32', 4), # Native support
- ('Format_RGB32', 3)]: # Conversion to RGB
- with self.subTest(format_):
- color = numpy.arange(channels) # RGB(A) values
- qimage = qt.QImage(3, 3, getattr(qt.QImage, format_))
- qimage.fill(qt.QColor(*color))
- image = convertQImageToArray(qimage)
-
- self.assertEqual(qimage.height(), image.shape[0])
- self.assertEqual(qimage.width(), image.shape[1])
- self.assertEqual(image.shape[2], len(color))
- self.assertTrue(numpy.all(numpy.equal(image, color)))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(
- TestQImageConversion))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/utils/test/test_qtutils.py b/silx/gui/utils/test/test_qtutils.py
deleted file mode 100755
index 043a0a6..0000000
--- a/silx/gui/utils/test/test_qtutils.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of functions available in silx.gui.utils module."""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "01/08/2019"
-
-
-import unittest
-from silx.gui import qt
-from silx.gui import utils
-from silx.gui.utils.testutils import TestCaseQt
-
-
-class TestQEventName(TestCaseQt):
- """Test QEvent names"""
-
- def testNoneType(self):
- result = utils.getQEventName(0)
- self.assertEqual(result, "None")
-
- def testNoneEvent(self):
- event = qt.QEvent(qt.QEvent.Type(0))
- result = utils.getQEventName(event)
- self.assertEqual(result, "None")
-
- def testUserType(self):
- result = utils.getQEventName(1050)
- self.assertIn("User", result)
- self.assertIn("1050", result)
-
- def testQtUndefinedType(self):
- result = utils.getQEventName(900)
- self.assertIn("Unknown", result)
- self.assertIn("900", result)
-
- def testUndefinedType(self):
- result = utils.getQEventName(70000)
- self.assertIn("Unknown", result)
- self.assertIn("70000", result)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestQEventName))
- return test_suite
-
-
-if __name__ == "__main__":
- unittest.main(defaultTest="suite")
diff --git a/silx/gui/utils/test/test_testutils.py b/silx/gui/utils/test/test_testutils.py
deleted file mode 100644
index 8a58e6e..0000000
--- a/silx/gui/utils/test/test_testutils.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of testutils module."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "16/01/2017"
-
-import unittest
-import sys
-
-from silx.gui import qt
-from ..testutils import TestCaseQt
-
-
-class TestOutcome(unittest.TestCase):
- """Tests conversion of QImage to/from numpy array."""
-
- @unittest.skipIf(sys.version_info.major <= 2, 'Python3 only')
- def testNoneOutcome(self):
- test = TestCaseQt()
- test._currentTestSucceeded()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loader(TestOutcome))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/utils/testutils.py b/silx/gui/utils/testutils.py
deleted file mode 100644
index 30b9e34..0000000
--- a/silx/gui/utils/testutils.py
+++ /dev/null
@@ -1,518 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Helper class to write Qt widget unittests."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "05/10/2018"
-
-
-import gc
-import logging
-import unittest
-import time
-import functools
-import sys
-import os
-
-_logger = logging.getLogger(__name__)
-
-from silx.gui import qt
-from silx.gui.qt import inspect as _inspect
-
-
-if qt.BINDING == 'PySide':
- from PySide.QtTest import QTest
-elif qt.BINDING == 'PySide2':
- from PySide2.QtTest import QTest
-elif qt.BINDING == 'PyQt5':
- from PyQt5.QtTest import QTest
-elif qt.BINDING == 'PyQt4':
- from PyQt4.QtTest import QTest
-else:
- raise ImportError('Unsupported Qt bindings')
-
-# Qt4/Qt5 compatibility wrapper
-if qt.BINDING in ('PySide', 'PyQt4'):
- _logger.info("QTest.qWaitForWindowExposed not available," +
- "using QTest.qWaitForWindowShown instead.")
-
- def qWaitForWindowExposed(window, timeout=None):
- """Mimic QTest.qWaitForWindowExposed for Qt4."""
- QTest.qWaitForWindowShown(window)
- return True
-else:
- qWaitForWindowExposed = QTest.qWaitForWindowExposed
-
-
-def qWaitForWindowExposedAndActivate(window, timeout=None):
- """Waits until the window is shown in the screen.
-
- It also activates the window and raises it.
-
- See QTest.qWaitForWindowExposed for details.
- """
- if timeout is None:
- result = qWaitForWindowExposed(window)
- else:
- result = qWaitForWindowExposed(window, timeout)
-
- if result:
- # Makes sure window is active and on top
- window.activateWindow()
- window.raise_()
-
- return result
-
-
-class TestCaseQt(unittest.TestCase):
- """Base class to write test for Qt stuff.
-
- It creates a QApplication before running the tests.
- WARNING: The QApplication is shared by all tests, which might have side
- effects.
-
- After each test, this class is checking for widgets remaining alive.
- To allow some widgets to remain alive at the end of a test, set the
- allowedLeakingWidgets attribute to the number of widgets that can remain
- alive at the end of the test.
- With PySide, this test is not run for now as it seems PySide
- is leaking widgets internally.
-
- All keyboard and mouse event simulation methods call qWait(20) after
- simulating the event (as QTest does on Mac OSX).
- This was introduced to fix issues with continuous integration tests
- running with Xvfb on Linux.
- """
-
- DEFAULT_TIMEOUT_WAIT = 100
- """Default timeout for qWait"""
-
- TIMEOUT_WAIT = 0
- """Extra timeout in millisecond to add to qSleep, qWait and
- qWaitForWindowExposed.
-
- Intended purpose is for debugging, to add extra time to waits in order to
- allow to view the tested widgets.
- """
-
- _qapp = None
- """Placeholder for QApplication"""
-
- @classmethod
- def exceptionHandler(cls, exceptionClass, exception, stack):
- import traceback
- message = (''.join(traceback.format_tb(stack)))
- template = 'Traceback (most recent call last):\n{2}{0}: {1}'
- message = template.format(exceptionClass.__name__, exception, message)
- cls._exceptions.append(message)
-
- @classmethod
- def setUpClass(cls):
- """Makes sure Qt is inited"""
- cls._oldExceptionHook = sys.excepthook
- sys.excepthook = cls.exceptionHandler
-
- # Makes sure a QApplication exists and do it once for all
- if not qt.QApplication.instance():
- cls._qapp = qt.QApplication([])
-
- @classmethod
- def tearDownClass(cls):
- sys.excepthook = cls._oldExceptionHook
-
- def setUp(self):
- """Get the list of existing widgets."""
- self.allowedLeakingWidgets = 0
- self.__previousWidgets = self.qapp.allWidgets()
- self.__class__._exceptions = []
-
- def _currentTestSucceeded(self):
- if hasattr(self, '_outcome'):
- # For Python >= 3.4
- result = self.defaultTestResult() # these 2 methods have no side effects
- if hasattr(self._outcome, 'errors'):
- self._feedErrorsToResult(result, self._outcome.errors)
- else:
- # For Python < 3.4
- result = getattr(self, '_outcomeForDoCleanups', self._resultForDoCleanups)
-
- skipped = self.id() in [case.id() for case, _ in result.skipped]
- error = self.id() in [case.id() for case, _ in result.errors]
- failure = self.id() in [case.id() for case, _ in result.failures]
- return not error and not failure and not skipped
-
- def _checkForUnreleasedWidgets(self):
- """Test fixture checking that no more widgets exists."""
- gc.collect()
-
- widgets = [widget for widget in self.qapp.allWidgets()
- if (widget not in self.__previousWidgets and
- _inspect.createdByPython(widget))]
- del self.__previousWidgets
-
- if qt.BINDING in ('PySide', 'PySide2'):
- return # Do not test for leaking widgets with PySide
-
- allowedLeakingWidgets = self.allowedLeakingWidgets
- self.allowedLeakingWidgets = 0
-
- if widgets and len(widgets) <= allowedLeakingWidgets:
- _logger.info(
- '%s: %d remaining widgets after test' % (self.id(),
- len(widgets)))
-
- if len(widgets) > allowedLeakingWidgets:
- raise RuntimeError(
- "Test ended with widgets alive: %s" % str(widgets))
-
- def tearDown(self):
- if len(self.__class__._exceptions) > 0:
- messages = "\n".join(self.__class__._exceptions)
- raise AssertionError("Exception occured in Qt thread:\n" + messages)
-
- if self._currentTestSucceeded():
- self._checkForUnreleasedWidgets()
-
- @property
- def qapp(self):
- """The QApplication currently running."""
- return qt.QApplication.instance()
-
- # Proxy to QTest
-
- Press = QTest.Press
- """Key press action code"""
-
- Release = QTest.Release
- """Key release action code"""
-
- Click = QTest.Click
- """Key click action code"""
-
- QTest = property(lambda self: QTest,
- doc="""The Qt QTest class from the used Qt binding.""")
-
- def keyClick(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
- """Simulate clicking a key.
-
- See QTest.keyClick for details.
- """
- QTest.keyClick(widget, key, modifier, delay)
- self.qWait(20)
-
- def keyClicks(self, widget, sequence, modifier=qt.Qt.NoModifier, delay=-1):
- """Simulate clicking a sequence of keys.
-
- See QTest.keyClick for details.
- """
- QTest.keyClicks(widget, sequence, modifier, delay)
- self.qWait(20)
-
- def keyEvent(self, action, widget, key,
- modifier=qt.Qt.NoModifier, delay=-1):
- """Sends a Qt key event.
-
- See QTest.keyEvent for details.
- """
- QTest.keyEvent(action, widget, key, modifier, delay)
- self.qWait(20)
-
- def keyPress(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
- """Sends a Qt key press event.
-
- See QTest.keyPress for details.
- """
- QTest.keyPress(widget, key, modifier, delay)
- self.qWait(20)
-
- def keyRelease(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
- """Sends a Qt key release event.
-
- See QTest.keyRelease for details.
- """
- QTest.keyRelease(widget, key, modifier, delay)
- self.qWait(20)
-
- def mouseClick(self, widget, button, modifier=None, pos=None, delay=-1):
- """Simulate clicking a mouse button.
-
- See QTest.mouseClick for details.
- """
- if modifier is None:
- modifier = qt.Qt.KeyboardModifiers()
- pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
- QTest.mouseClick(widget, button, modifier, pos, delay)
- self.qWait(20)
-
- def mouseDClick(self, widget, button, modifier=None, pos=None, delay=-1):
- """Simulate double clicking a mouse button.
-
- See QTest.mouseDClick for details.
- """
- if modifier is None:
- modifier = qt.Qt.KeyboardModifiers()
- pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
- QTest.mouseDClick(widget, button, modifier, pos, delay)
- self.qWait(20)
-
- def mouseMove(self, widget, pos=None, delay=-1):
- """Simulate moving the mouse.
-
- See QTest.mouseMove for details.
- """
- pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
- QTest.mouseMove(widget, pos, delay)
- self.qWait(20)
-
- def mousePress(self, widget, button, modifier=None, pos=None, delay=-1):
- """Simulate pressing a mouse button.
-
- See QTest.mousePress for details.
- """
- if modifier is None:
- modifier = qt.Qt.KeyboardModifiers()
- pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
- QTest.mousePress(widget, button, modifier, pos, delay)
- self.qWait(20)
-
- def mouseRelease(self, widget, button, modifier=None, pos=None, delay=-1):
- """Simulate releasing a mouse button.
-
- See QTest.mouseRelease for details.
- """
- if modifier is None:
- modifier = qt.Qt.KeyboardModifiers()
- pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
- QTest.mouseRelease(widget, button, modifier, pos, delay)
- self.qWait(20)
-
- def qSleep(self, ms):
- """Sleep for ms milliseconds, blocking the execution of the test.
-
- See QTest.qSleep for details.
- """
- QTest.qSleep(int(ms) + self.TIMEOUT_WAIT)
-
- @classmethod
- def qWait(cls, ms=None):
- """Waits for ms milliseconds, events will be processed.
-
- See QTest.qWait for details.
- """
- if ms is None:
- ms = cls.DEFAULT_TIMEOUT_WAIT
-
- if qt.BINDING in ('PySide', 'PySide2'):
- # PySide has no qWait, provide a replacement
- timeout = int(ms)
- endTimeMS = int(time.time() * 1000) + timeout
- qapp = qt.QApplication.instance()
- while timeout > 0:
- qapp.processEvents(qt.QEventLoop.AllEvents,
- maxtime=timeout)
- timeout = endTimeMS - int(time.time() * 1000)
- else:
- QTest.qWait(int(ms) + cls.TIMEOUT_WAIT)
-
- def qWaitForWindowExposed(self, window, timeout=None):
- """Waits until the window is shown in the screen.
-
- See QTest.qWaitForWindowExposed for details.
- """
- result = qWaitForWindowExposedAndActivate(window, timeout)
-
- if self.TIMEOUT_WAIT:
- QTest.qWait(self.TIMEOUT_WAIT)
-
- return result
-
- _qobject_destroyed = False
-
- @classmethod
- def _aboutToDestroy(cls):
- cls._qobject_destroyed = True
-
- @classmethod
- def qWaitForDestroy(cls, ref):
- """
- Wait for Qt object destruction.
-
- Use a weakref as parameter to avoid any strong references to the
- object.
-
- It have to be used as following. Removing the reference to the object
- before calling the function looks to be expected, else
- :meth:`deleteLater` will not work.
-
- .. code-block:: python
-
- ref = weakref.ref(self.obj)
- self.obj = None
- self.qWaitForDestroy(ref)
-
- :param weakref ref: A weakref to an object to avoid any reference
- :return: True if the object was destroyed
- :rtype: bool
- """
- cls._qobject_destroyed = False
- if qt.BINDING == 'PyQt4':
- # Without this, QWidget will be still alive on PyQt4
- # (at least on Windows Python 2.7)
- # If it is not skipped on PySide, silx.gui.dialog tests will
- # segfault (at least on Windows Python 2.7)
- import gc
- gc.collect()
- qobject = ref()
- if qobject is None:
- return True
- qobject.destroyed.connect(cls._aboutToDestroy)
- qobject.deleteLater()
- qobject = None
- for _ in range(10):
- if cls._qobject_destroyed:
- break
- cls.qWait(10)
- else:
- _logger.debug("Object was not destroyed")
-
- return ref() is None
-
- def logScreenShot(self, level=logging.ERROR):
- """Take a screenshot and log it into the logging system if the
- logger is enabled for the expected level.
-
- The screenshot is stored in the directory "./build/test-debug", and
- the logging system only log the path to this file.
-
- :param level: Logging level
- """
- if not _logger.isEnabledFor(level):
- return
- basedir = os.path.abspath(os.path.join("build", "test-debug"))
- if not os.path.exists(basedir):
- os.makedirs(basedir)
- filename = "Screenshot_%s.png" % self.id()
- filename = os.path.join(basedir, filename)
-
- if not hasattr(self.qapp, "primaryScreen"):
- # Qt4
- winId = qt.QApplication.desktop().winId()
- pixmap = qt.QPixmap.grabWindow(winId)
- else:
- # Qt5
- screen = self.qapp.primaryScreen()
- pixmap = screen.grabWindow(0)
- pixmap.save(filename)
- _logger.log(level, "Screenshot saved at %s", filename)
-
-
-class SignalListener(object):
- """Util to listen a Qt event and store parameters
- """
-
- def __init__(self):
- self.__calls = []
-
- def __call__(self, *args, **kargs):
- self.__calls.append((args, kargs))
-
- def clear(self):
- """Clear stored data"""
- self.__calls = []
-
- def callCount(self):
- """
- Returns how many times the listener was called.
-
- :rtype: int
- """
- return len(self.__calls)
-
- def arguments(self, callIndex=None, argumentIndex=None):
- """Returns positional arguments optionally filtered by call count id
- or argument index.
-
- :param int callIndex: Index of the called data
- :param int argumentIndex: Index of the positional argument.
- """
- if callIndex is not None:
- result = self.__calls[callIndex][0]
- if argumentIndex is not None:
- result = result[argumentIndex]
- else:
- result = [x[0] for x in self.__calls]
- if argumentIndex is not None:
- result = [x[argumentIndex] for x in result]
- return result
-
- def karguments(self, callIndex=None, argumentName=None):
- """Returns positional arguments optionally filtered by call count id
- or name of the keyword argument.
-
- :param int callIndex: Index of the called data
- :param int argumentName: Name of the keyword argument.
- """
- if callIndex is not None:
- result = self.__calls[callIndex][1]
- if argumentName is not None:
- result = result[argumentName]
- else:
- result = [x[1] for x in self.__calls]
- if argumentName is not None:
- result = [x[argumentName] for x in result]
- return result
-
- def partial(self, *args, **kargs):
- """Returns a new partial object which when called will behave like this
- listener called with the positional arguments args and keyword
- arguments keywords. If more arguments are supplied to the call, they
- are appended to args. If additional keyword arguments are supplied,
- they extend and override keywords.
- """
- return functools.partial(self, *args, **kargs)
-
-
-def getQToolButtonFromAction(action):
- """Return a QToolButton corresponding to a QAction.
-
- :param QAction action: The QAction from which to get QToolButton.
- :return: A QToolButton associated to action or None.
- """
- for widget in action.associatedWidgets():
- if isinstance(widget, qt.QToolButton):
- return widget
- return None
-
-
-def findChildren(parent, kind, name=None):
- if qt.BINDING in ("PySide", "PySide2") and name is not None:
- result = []
- for obj in parent.findChildren(kind):
- if obj.objectName() == name:
- result.append(obj)
- return result
- else:
- return parent.findChildren(kind, name=name)
diff --git a/silx/gui/widgets/ElidedLabel.py b/silx/gui/widgets/ElidedLabel.py
deleted file mode 100644
index fe53bb9..0000000
--- a/silx/gui/widgets/ElidedLabel.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Module contains an elidable label
-"""
-
-__license__ = "MIT"
-__date__ = "07/12/2018"
-
-from silx.gui import qt
-
-
-class ElidedLabel(qt.QLabel):
- """QLabel with an edile property.
-
- By default if the text is too big, it is elided on the right.
-
- This mode can be changed with :func:`setElideMode`.
-
- In case the text is elided, the full content is displayed as part of the
- tool tip. This behavior can be disabled with :func:`setTextAsToolTip`.
- """
-
- def __init__(self, parent=None):
- super(ElidedLabel, self).__init__(parent)
- self.__text = ""
- self.__toolTip = ""
- self.__textAsToolTip = True
- self.__textIsElided = False
- self.__elideMode = qt.Qt.ElideRight
- self.__updateMinimumSize()
-
- def resizeEvent(self, event):
- self.__updateText()
- return qt.QLabel.resizeEvent(self, event)
-
- def setFont(self, font):
- qt.QLabel.setFont(self, font)
- self.__updateMinimumSize()
- self.__updateText()
-
- def __updateMinimumSize(self):
- metrics = self.fontMetrics()
- width = metrics.width("...")
- self.setMinimumWidth(width)
-
- def __updateText(self):
- metrics = self.fontMetrics()
- elidedText = metrics.elidedText(self.__text, self.__elideMode, self.width())
- qt.QLabel.setText(self, elidedText)
- wasElided = self.__textIsElided
- self.__textIsElided = elidedText != self.__text
- if self.__textIsElided or wasElided != self.__textIsElided:
- self.__updateToolTip()
-
- def __updateToolTip(self):
- if self.__textIsElided and self.__textAsToolTip:
- qt.QLabel.setToolTip(self, self.__text + "<br/>" + self.__toolTip)
- else:
- qt.QLabel.setToolTip(self, self.__toolTip)
-
- # Properties
-
- def setText(self, text):
- self.__text = text
- self.__updateText()
-
- def getText(self):
- return self.__text
-
- text = qt.Property(str, getText, setText)
-
- def setToolTip(self, toolTip):
- self.__toolTip = toolTip
- self.__updateToolTip()
-
- def getToolTip(self):
- return self.__toolTip
-
- toolTip = qt.Property(str, getToolTip, setToolTip)
-
- def setElideMode(self, elideMode):
- """Set the elide mode.
-
- :param qt.Qt.TextElideMode elidMode: Elide mode to use
- """
- self.__elideMode = elideMode
- self.__updateText()
-
- def getElideMode(self):
- """Returns the used elide mode.
-
- :rtype: qt.Qt.TextElideMode
- """
- return self.__elideMode
-
- elideMode = qt.Property(qt.Qt.TextElideMode, getToolTip, setToolTip)
-
- def setTextAsToolTip(self, enabled):
- """Enable displaying text as part of the tooltip if it is elided.
-
- :param bool enabled: Enable the behavior
- """
- if self.__textAsToolTip == enabled:
- return
- self.__textAsToolTip = enabled
- self.__updateToolTip()
-
- def getTextAsToolTip(self):
- """True if an elided text is displayed as part of the tooltip.
-
- :rtype: bool
- """
- return self.__textAsToolTip
-
- textAsToolTip = qt.Property(bool, getTextAsToolTip, setTextAsToolTip)
diff --git a/silx/gui/widgets/FloatEdit.py b/silx/gui/widgets/FloatEdit.py
deleted file mode 100644
index 36a39a7..0000000
--- a/silx/gui/widgets/FloatEdit.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Module contains a float editor
-"""
-
-from __future__ import division
-
-__authors__ = ["V.A. Sole", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "02/10/2017"
-
-from .. import qt
-
-
-class FloatEdit(qt.QLineEdit):
- """Field to edit a float value.
-
- :param parent: See :class:`QLineEdit`
- :param float value: The value to set the QLineEdit to.
- """
- def __init__(self, parent=None, value=None):
- qt.QLineEdit.__init__(self, parent)
- validator = qt.QDoubleValidator(self)
- self.setValidator(validator)
- self.setAlignment(qt.Qt.AlignRight)
- if value is not None:
- self.setValue(value)
-
- def value(self):
- """Return the QLineEdit current value as a float."""
- text = self.text()
- value, validated = self.validator().locale().toDouble(text)
- if not validated:
- self.setValue(value)
- return value
-
- def setValue(self, value):
- """Set the current value of the LineEdit
-
- :param float value: The value to set the QLineEdit to.
- """
- text = self.validator().locale().toString(float(value))
- self.setText(text)
diff --git a/silx/gui/widgets/PeriodicTable.py b/silx/gui/widgets/PeriodicTable.py
deleted file mode 100644
index 0233e8c..0000000
--- a/silx/gui/widgets/PeriodicTable.py
+++ /dev/null
@@ -1,831 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Periodic table widgets
-
-Classes
--------
-
-Widgets:
-
- - :class:`PeriodicTable`
- - :class:`PeriodicList`
- - :class:`PeriodicCombo`
-
-Data model:
-
- - :class:`PeriodicTableItem`
- - :class:`ColoredPeriodicTableItem`
-
-
-Example of usage
-----------------
-
-This example uses the widgets with the standard builtin elements list.
-
-.. code-block:: python
-
- from silx.gui import qt
- from silx.gui.widgets.PeriodicTable import PeriodicTable, \
- PeriodicCombo, PeriodicList
-
- a = qt.QApplication([])
-
- w = qt.QTabWidget()
-
- ptable = PeriodicTable(w, selectable=True)
- pcombo = PeriodicCombo(w)
- plist = PeriodicList(w)
-
- w.addTab(ptable, "PeriodicTable")
- w.addTab(plist, "PeriodicList")
- w.addTab(pcombo, "PeriodicCombo")
-
- ptable.setSelection(['H', 'Fe', 'Si'])
- plist.setSelectedElements(['H', 'Be', 'F'])
- pcombo.setSelection("Li")
-
- def change_list(items):
- print("New list selection:", [item.symbol for item in items])
-
- def change_combo(item):
- print("New combo selection:", item.symbol)
-
- def click_table(item):
- print("New table click:", item.symbol)
-
- def change_table(items):
- print("New table selection:", [item.symbol for item in items])
-
- ptable.sigElementClicked.connect(click_table)
- ptable.sigSelectionChanged.connect(change_table)
- plist.sigSelectionChanged.connect(change_list)
- pcombo.sigSelectionChanged.connect(change_combo)
-
- w.show()
- a.exec_()
-
-
-The second example explains how to define custom elements.
-
-.. code-block:: python
-
- from silx.gui import qt
- from silx.gui.widgets.PeriodicTable import PeriodicTable, \
- PeriodicCombo, PeriodicList
- from silx.gui.widgets.PeriodicTable import PeriodicTableItem
-
- # subclass PeriodicTableItem
- class MyPeriodicTableItem(PeriodicTableItem):
- "New item with added mass number and number of protons"
- def __init__(self, symbol, Z, A, col, row, name, mass,
- subcategory=""):
- PeriodicTableItem.__init__(
- self, symbol, Z, col, row, name, mass,
- subcategory)
-
- self.A = A
- "Mass number (neutrons + protons)"
-
- self.num_neutrons = A - Z
- "Number of neutrons"
-
- # build your list of elements
- my_elements = [MyPeriodicTableItem("H", 1, 1, 1, 1, "hydrogen",
- 1.00800, "diatomic nonmetal"),
- MyPeriodicTableItem("He", 2, 4, 18, 1, "helium",
- 4.0030, "noble gas"),
- # etc ...
- ]
-
- app = qt.QApplication([])
-
- ptable = PeriodicTable(elements=my_elements, selectable=True)
- ptable.show()
-
- def click_table(item):
- "Callback function printing the mass number of clicked element"
- print("New table click, mass number:", item.A)
-
- ptable.sigElementClicked.connect(click_table)
- app.exec_()
-
-"""
-
-__authors__ = ["E. Papillon", "V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "26/01/2017"
-
-from collections import OrderedDict
-import logging
-from silx.gui import qt
-
-_logger = logging.getLogger(__name__)
-
-# Symbol Atomic Number col row name mass subcategory
-_elements = [("H", 1, 1, 1, "hydrogen", 1.00800, "diatomic nonmetal"),
- ("He", 2, 18, 1, "helium", 4.0030, "noble gas"),
- ("Li", 3, 1, 2, "lithium", 6.94000, "alkali metal"),
- ("Be", 4, 2, 2, "beryllium", 9.01200, "alkaline earth metal"),
- ("B", 5, 13, 2, "boron", 10.8110, "metalloid"),
- ("C", 6, 14, 2, "carbon", 12.0100, "polyatomic nonmetal"),
- ("N", 7, 15, 2, "nitrogen", 14.0080, "diatomic nonmetal"),
- ("O", 8, 16, 2, "oxygen", 16.0000, "diatomic nonmetal"),
- ("F", 9, 17, 2, "fluorine", 19.0000, "diatomic nonmetal"),
- ("Ne", 10, 18, 2, "neon", 20.1830, "noble gas"),
- ("Na", 11, 1, 3, "sodium", 22.9970, "alkali metal"),
- ("Mg", 12, 2, 3, "magnesium", 24.3200, "alkaline earth metal"),
- ("Al", 13, 13, 3, "aluminium", 26.9700, "post transition metal"),
- ("Si", 14, 14, 3, "silicon", 28.0860, "metalloid"),
- ("P", 15, 15, 3, "phosphorus", 30.9750, "polyatomic nonmetal"),
- ("S", 16, 16, 3, "sulphur", 32.0660, "polyatomic nonmetal"),
- ("Cl", 17, 17, 3, "chlorine", 35.4570, "diatomic nonmetal"),
- ("Ar", 18, 18, 3, "argon", 39.9440, "noble gas"),
- ("K", 19, 1, 4, "potassium", 39.1020, "alkali metal"),
- ("Ca", 20, 2, 4, "calcium", 40.0800, "alkaline earth metal"),
- ("Sc", 21, 3, 4, "scandium", 44.9600, "transition metal"),
- ("Ti", 22, 4, 4, "titanium", 47.9000, "transition metal"),
- ("V", 23, 5, 4, "vanadium", 50.9420, "transition metal"),
- ("Cr", 24, 6, 4, "chromium", 51.9960, "transition metal"),
- ("Mn", 25, 7, 4, "manganese", 54.9400, "transition metal"),
- ("Fe", 26, 8, 4, "iron", 55.8500, "transition metal"),
- ("Co", 27, 9, 4, "cobalt", 58.9330, "transition metal"),
- ("Ni", 28, 10, 4, "nickel", 58.6900, "transition metal"),
- ("Cu", 29, 11, 4, "copper", 63.5400, "transition metal"),
- ("Zn", 30, 12, 4, "zinc", 65.3800, "transition metal"),
- ("Ga", 31, 13, 4, "gallium", 69.7200, "post transition metal"),
- ("Ge", 32, 14, 4, "germanium", 72.5900, "metalloid"),
- ("As", 33, 15, 4, "arsenic", 74.9200, "metalloid"),
- ("Se", 34, 16, 4, "selenium", 78.9600, "polyatomic nonmetal"),
- ("Br", 35, 17, 4, "bromine", 79.9200, "diatomic nonmetal"),
- ("Kr", 36, 18, 4, "krypton", 83.8000, "noble gas"),
- ("Rb", 37, 1, 5, "rubidium", 85.4800, "alkali metal"),
- ("Sr", 38, 2, 5, "strontium", 87.6200, "alkaline earth metal"),
- ("Y", 39, 3, 5, "yttrium", 88.9050, "transition metal"),
- ("Zr", 40, 4, 5, "zirconium", 91.2200, "transition metal"),
- ("Nb", 41, 5, 5, "niobium", 92.9060, "transition metal"),
- ("Mo", 42, 6, 5, "molybdenum", 95.9500, "transition metal"),
- ("Tc", 43, 7, 5, "technetium", 99.0000, "transition metal"),
- ("Ru", 44, 8, 5, "ruthenium", 101.0700, "transition metal"),
- ("Rh", 45, 9, 5, "rhodium", 102.9100, "transition metal"),
- ("Pd", 46, 10, 5, "palladium", 106.400, "transition metal"),
- ("Ag", 47, 11, 5, "silver", 107.880, "transition metal"),
- ("Cd", 48, 12, 5, "cadmium", 112.410, "transition metal"),
- ("In", 49, 13, 5, "indium", 114.820, "post transition metal"),
- ("Sn", 50, 14, 5, "tin", 118.690, "post transition metal"),
- ("Sb", 51, 15, 5, "antimony", 121.760, "metalloid"),
- ("Te", 52, 16, 5, "tellurium", 127.600, "metalloid"),
- ("I", 53, 17, 5, "iodine", 126.910, "diatomic nonmetal"),
- ("Xe", 54, 18, 5, "xenon", 131.300, "noble gas"),
- ("Cs", 55, 1, 6, "caesium", 132.910, "alkali metal"),
- ("Ba", 56, 2, 6, "barium", 137.360, "alkaline earth metal"),
- ("La", 57, 3, 6, "lanthanum", 138.920, "lanthanide"),
- ("Ce", 58, 4, 9, "cerium", 140.130, "lanthanide"),
- ("Pr", 59, 5, 9, "praseodymium", 140.920, "lanthanide"),
- ("Nd", 60, 6, 9, "neodymium", 144.270, "lanthanide"),
- ("Pm", 61, 7, 9, "promethium", 147.000, "lanthanide"),
- ("Sm", 62, 8, 9, "samarium", 150.350, "lanthanide"),
- ("Eu", 63, 9, 9, "europium", 152.000, "lanthanide"),
- ("Gd", 64, 10, 9, "gadolinium", 157.260, "lanthanide"),
- ("Tb", 65, 11, 9, "terbium", 158.930, "lanthanide"),
- ("Dy", 66, 12, 9, "dysprosium", 162.510, "lanthanide"),
- ("Ho", 67, 13, 9, "holmium", 164.940, "lanthanide"),
- ("Er", 68, 14, 9, "erbium", 167.270, "lanthanide"),
- ("Tm", 69, 15, 9, "thulium", 168.940, "lanthanide"),
- ("Yb", 70, 16, 9, "ytterbium", 173.040, "lanthanide"),
- ("Lu", 71, 17, 9, "lutetium", 174.990, "lanthanide"),
- ("Hf", 72, 4, 6, "hafnium", 178.500, "transition metal"),
- ("Ta", 73, 5, 6, "tantalum", 180.950, "transition metal"),
- ("W", 74, 6, 6, "tungsten", 183.920, "transition metal"),
- ("Re", 75, 7, 6, "rhenium", 186.200, "transition metal"),
- ("Os", 76, 8, 6, "osmium", 190.200, "transition metal"),
- ("Ir", 77, 9, 6, "iridium", 192.200, "transition metal"),
- ("Pt", 78, 10, 6, "platinum", 195.090, "transition metal"),
- ("Au", 79, 11, 6, "gold", 197.200, "transition metal"),
- ("Hg", 80, 12, 6, "mercury", 200.610, "transition metal"),
- ("Tl", 81, 13, 6, "thallium", 204.390, "post transition metal"),
- ("Pb", 82, 14, 6, "lead", 207.210, "post transition metal"),
- ("Bi", 83, 15, 6, "bismuth", 209.000, "post transition metal"),
- ("Po", 84, 16, 6, "polonium", 209.000, "post transition metal"),
- ("At", 85, 17, 6, "astatine", 210.000, "metalloid"),
- ("Rn", 86, 18, 6, "radon", 222.000, "noble gas"),
- ("Fr", 87, 1, 7, "francium", 223.000, "alkali metal"),
- ("Ra", 88, 2, 7, "radium", 226.000, "alkaline earth metal"),
- ("Ac", 89, 3, 7, "actinium", 227.000, "actinide"),
- ("Th", 90, 4, 10, "thorium", 232.000, "actinide"),
- ("Pa", 91, 5, 10, "proactinium", 231.03588, "actinide"),
- ("U", 92, 6, 10, "uranium", 238.070, "actinide"),
- ("Np", 93, 7, 10, "neptunium", 237.000, "actinide"),
- ("Pu", 94, 8, 10, "plutonium", 239.100, "actinide"),
- ("Am", 95, 9, 10, "americium", 243, "actinide"),
- ("Cm", 96, 10, 10, "curium", 247, "actinide"),
- ("Bk", 97, 11, 10, "berkelium", 247, "actinide"),
- ("Cf", 98, 12, 10, "californium", 251, "actinide"),
- ("Es", 99, 13, 10, "einsteinium", 252, "actinide"),
- ("Fm", 100, 14, 10, "fermium", 257, "actinide"),
- ("Md", 101, 15, 10, "mendelevium", 258, "actinide"),
- ("No", 102, 16, 10, "nobelium", 259, "actinide"),
- ("Lr", 103, 17, 10, "lawrencium", 262, "actinide"),
- ("Rf", 104, 4, 7, "rutherfordium", 261, "transition metal"),
- ("Db", 105, 5, 7, "dubnium", 262, "transition metal"),
- ("Sg", 106, 6, 7, "seaborgium", 266, "transition metal"),
- ("Bh", 107, 7, 7, "bohrium", 264, "transition metal"),
- ("Hs", 108, 8, 7, "hassium", 269, "transition metal"),
- ("Mt", 109, 9, 7, "meitnerium", 268)]
-
-
-class PeriodicTableItem(object):
- """Periodic table item, used as generic item in :class:`PeriodicTable`,
- :class:`PeriodicCombo` and :class:`PeriodicList`.
-
- This implementation stores the minimal amount of information needed by the
- widgets:
-
- - atomic symbol
- - atomic number
- - element name
- - atomic mass
- - column of element in periodic table
- - row of element in periodic table
-
- You can subclass this class to add additional information.
-
- :param str symbol: Atomic symbol (e.g. H, He, Li...)
- :param int Z: Proton number
- :param int col: 1-based column index of element in periodic table
- :param int row: 1-based row index of element in periodic table
- :param str name: PeriodicTableItem name ("hydrogen", ...)
- :param float mass: Atomic mass (gram per mol)
- :param str subcategory: Subcategory, based on physical properties
- (e.g. "alkali metal", "noble gas"...)
- """
- def __init__(self, symbol, Z, col, row, name, mass,
- subcategory=""):
- self.symbol = symbol
- """Atomic symbol (e.g. H, He, Li...)"""
- self.Z = Z
- """Atomic number (Proton number)"""
- self.col = col
- """1-based column index of element in periodic table"""
- self.row = row
- """1-based row index of element in periodic table"""
- self.name = name
- """PeriodicTableItem name ("hydrogen", ...)"""
- self.mass = mass
- """Atomic mass (gram per mol)"""
- self.subcategory = subcategory
- """Subcategory, based on physical properties
- (e.g. "alkali metal", "noble gas"...)"""
-
- # pymca compatibility (elements used to be stored as a list of lists)
- def __getitem__(self, idx):
- if idx == 6:
- _logger.warning("density not implemented in silx, returning 0.")
-
- ret = [self.symbol, self.Z,
- self.col, self.row,
- self.name, self.mass,
- 0.]
- return ret[idx]
-
- def __len__(self):
- return 6
-
-
-class ColoredPeriodicTableItem(PeriodicTableItem):
- """:class:`PeriodicTableItem` with an added :attr:`bgcolor`.
- The background color can be passed as a parameter to the constructor.
- If it is not specified, it will be defined based on
- :attr:`subcategory`.
-
- :param str bgcolor: Custom background color for element in
- periodic table, as a RGB string *#RRGGBB*"""
- COLORS = {
- "diatomic nonmetal": "#7FFF00", # chartreuse
- "noble gas": "#00FFFF", # cyan
- "alkali metal": "#FFE4B5", # Moccasin
- "alkaline earth metal": "#FFA500", # orange
- "polyatomic nonmetal": "#7FFFD4", # aquamarine
- "transition metal": "#FFA07A", # light salmon
- "metalloid": "#8FBC8F", # Dark Sea Green
- "post transition metal": "#D3D3D3", # light gray
- "lanthanide": "#FFB6C1", # light pink
- "actinide": "#F08080", # Light Coral
- "": "#FFFFFF" # white
- }
- """Dictionary defining RGB colors for each subcategory."""
-
- def __init__(self, symbol, Z, col, row, name, mass,
- subcategory="", bgcolor=None):
- PeriodicTableItem.__init__(self, symbol, Z, col, row, name, mass,
- subcategory)
-
- self.bgcolor = self.COLORS.get(subcategory, "#FFFFFF")
- """Background color of element in the periodic table,
- based on its subcategory. This should be a string of a hexadecimal
- RGB code, with the format *#RRGGBB*.
- If the subcategory is unknown, use white (*#FFFFFF*)
- """
-
- # possible custom color
- if bgcolor is not None:
- self.bgcolor = bgcolor
-
-
-_defaultTableItems = [ColoredPeriodicTableItem(*info) for info in _elements]
-
-
-class _ElementButton(qt.QPushButton):
- """Atomic element button, used as a cell in the periodic table
- """
- sigElementEnter = qt.pyqtSignal(object)
- """Signal emitted as the cursor enters the widget"""
- sigElementLeave = qt.pyqtSignal(object)
- """Signal emitted as the cursor leaves the widget"""
- sigElementClicked = qt.pyqtSignal(object)
- """Signal emitted when the widget is clicked"""
-
- def __init__(self, item, parent=None):
- """
-
- :param parent: Parent widget
- :param PeriodicTableItem item: :class:`PeriodicTableItem` object
- """
- qt.QPushButton.__init__(self, parent)
-
- self.item = item
- """:class:`PeriodicTableItem` object represented by this button"""
-
- self.setText(item.symbol)
- self.setFlat(1)
- self.setCheckable(0)
-
- self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
- qt.QSizePolicy.Expanding))
-
- self.selected = False
- self.current = False
-
- # selection colors
- self.selected_color = qt.QColor(qt.Qt.yellow)
- self.current_color = qt.QColor(qt.Qt.gray)
- self.selected_current_color = qt.QColor(qt.Qt.darkYellow)
-
- # element colors
-
- if hasattr(item, "bgcolor"):
- self.bgcolor = qt.QColor(item.bgcolor)
- else:
- self.bgcolor = qt.QColor("#FFFFFF")
-
- self.brush = qt.QBrush()
- self.__setBrush()
-
- self.clicked.connect(self.clickedSlot)
-
- def sizeHint(self):
- return qt.QSize(40, 40)
-
- def setCurrent(self, b):
- """Set this element button as current.
- Multiple buttons can be selected.
-
- :param b: boolean
- """
- self.current = b
- self.__setBrush()
-
- def isCurrent(self):
- """
- :return: True if element button is current
- """
- return self.current
-
- def isSelected(self):
- """
- :return: True if element button is selected
- """
- return self.selected
-
- def setSelected(self, b):
- """Set this element button as selected.
- Only a single button can be selected.
-
- :param b: boolean
- """
- self.selected = b
- self.__setBrush()
-
- def __setBrush(self):
- """Selected cells are yellow when not current.
- The current cell is dark yellow when selected or grey when not
- selected.
- Other cells have no bg color by default, unless specified at
- instantiation (:attr:`bgcolor`)"""
- palette = self.palette()
- # if self.current and self.selected:
- # self.brush = qt.QBrush(self.selected_current_color)
- # el
- if self.selected:
- self.brush = qt.QBrush(self.selected_color)
- # elif self.current:
- # self.brush = qt.QBrush(self.current_color)
- elif self.bgcolor is not None:
- self.brush = qt.QBrush(self.bgcolor)
- else:
- self.brush = qt.QBrush()
- palette.setBrush(self.backgroundRole(),
- self.brush)
- self.setPalette(palette)
- self.update()
-
- def paintEvent(self, pEvent):
- # get button geometry
- widgGeom = self.rect()
- paintGeom = qt.QRect(widgGeom.left() + 1,
- widgGeom.top() + 1,
- widgGeom.width() - 2,
- widgGeom.height() - 2)
-
- # paint background color
- painter = qt.QPainter(self)
- if self.brush is not None:
- painter.fillRect(paintGeom, self.brush)
- # paint frame
- pen = qt.QPen(qt.Qt.black)
- pen.setWidth(1 if not self.isCurrent() else 5)
- painter.setPen(pen)
- painter.drawRect(paintGeom)
- painter.end()
- qt.QPushButton.paintEvent(self, pEvent)
-
- def enterEvent(self, e):
- """Emit a :attr:`sigElementEnter` signal and send a
- :class:`PeriodicTableItem` object"""
- self.sigElementEnter.emit(self.item)
-
- def leaveEvent(self, e):
- """Emit a :attr:`sigElementLeave` signal and send a
- :class:`PeriodicTableItem` object"""
- self.sigElementLeave.emit(self.item)
-
- def clickedSlot(self):
- """Emit a :attr:`sigElementClicked` signal and send a
- :class:`PeriodicTableItem` object"""
- self.sigElementClicked.emit(self.item)
-
-
-class PeriodicTable(qt.QWidget):
- """Periodic Table widget
-
- .. image:: img/PeriodicTable.png
-
- The following example shows how to connect clicking to selection::
-
- from silx.gui import qt
- from silx.gui.widgets.PeriodicTable import PeriodicTable
- app = qt.QApplication([])
- pt = PeriodicTable()
- pt.sigElementClicked.connect(pt.elementToggle)
- pt.show()
- app.exec_()
-
- To print all selected elements each time a new element is selected::
-
- def my_slot(item):
- pt.elementToggle(item)
- selected_elements = pt.getSelection()
- for e in selected_elements:
- print(e.symbol)
-
- pt.sigElementClicked.connect(my_slot)
-
- """
- sigElementClicked = qt.pyqtSignal(object)
- """When any element is clicked in the table, the widget emits
- this signal and sends a :class:`PeriodicTableItem` object.
- """
-
- sigSelectionChanged = qt.pyqtSignal(object)
- """When any element is selected/unselected in the table, the widget emits
- this signal and sends a list of :class:`PeriodicTableItem` objects.
-
- .. note::
-
- To enable selection of elements, you must set *selectable=True*
- when you instantiate the widget. Alternatively, you can also connect
- :attr:`sigElementClicked` to :meth:`elementToggle` manually::
-
- pt = PeriodicTable()
- pt.sigElementClicked.connect(pt.elementToggle)
-
-
- :param parent: parent QWidget
- :param str name: Widget window title
- :param elements: List of items (:class:`PeriodicTableItem` objects) to
- be represented in the table. By default, take elements from
- a predefined list with minimal information (symbol, atomic number,
- name, mass).
- :param bool selectable: If *True*, multiple elements can be
- selected by clicking with the mouse. If *False* (default),
- selection is only possible with method :meth:`setSelection`.
- """
-
- def __init__(self, parent=None, name="PeriodicTable", elements=None,
- selectable=False):
- self.selectable = selectable
- qt.QWidget.__init__(self, parent)
- self.setWindowTitle(name)
- self.gridLayout = qt.QGridLayout(self)
- self.gridLayout.setContentsMargins(0, 0, 0, 0)
- self.gridLayout.addItem(qt.QSpacerItem(0, 5), 7, 0)
-
- for idx in range(10):
- self.gridLayout.setRowStretch(idx, 3)
- # row 8 (above lanthanoids is empty)
- self.gridLayout.setRowStretch(7, 2)
-
- # Element information displayed when cursor enters a cell
- self.eltLabel = qt.QLabel(self)
- f = self.eltLabel.font()
- f.setBold(1)
- self.eltLabel.setFont(f)
- self.eltLabel.setAlignment(qt.Qt.AlignHCenter)
- self.gridLayout.addWidget(self.eltLabel, 1, 1, 3, 10)
-
- self._eltCurrent = None
- """Current :class:`_ElementButton` (last clicked)"""
-
- self._eltButtons = OrderedDict()
- """Dictionary of all :class:`_ElementButton`. Keys are the symbols
- ("H", "He", "Li"...)"""
-
- if elements is None:
- elements = _defaultTableItems
- # fill cells with elements
- for elmt in elements:
- self.__addElement(elmt)
-
- def __addElement(self, elmt):
- """Add one :class:`_ElementButton` widget into the grid,
- connect its signals to interact with the cursor"""
- b = _ElementButton(elmt, self)
- b.setAutoDefault(False)
-
- self._eltButtons[elmt.symbol] = b
- self.gridLayout.addWidget(b, elmt.row, elmt.col)
-
- b.sigElementEnter.connect(self.elementEnter)
- b.sigElementLeave.connect(self._elementLeave)
- b.sigElementClicked.connect(self._elementClicked)
-
- def elementEnter(self, item):
- """Update label with element info (e.g. "Nb(41) - niobium")
- when mouse cursor hovers an element.
-
- :param PeriodicTableItem item: Element entered by cursor
- """
- self.eltLabel.setText("%s(%d) - %s" % (item.symbol, item.Z, item.name))
-
- def _elementLeave(self, item):
- """Clear label when the cursor leaves the cell
-
- :param PeriodicTableItem item: Element left
- """
- self.eltLabel.setText("")
-
- def _elementClicked(self, item):
- """Emit :attr:`sigElementClicked`,
- toggle selected state of element
-
- :param PeriodicTableItem item: Element clicked
- """
- if self._eltCurrent is not None:
- self._eltCurrent.setCurrent(False)
- self._eltButtons[item.symbol].setCurrent(True)
- self._eltCurrent = self._eltButtons[item.symbol]
- if self.selectable:
- self.elementToggle(item)
- self.sigElementClicked.emit(item)
-
- def getSelection(self):
- """Return a list of selected elements, as a list of :class:`PeriodicTableItem`
- objects.
-
- :return: Selected items
- :rtype: List[PeriodicTableItem]
- """
- return [b.item for b in self._eltButtons.values() if b.isSelected()]
-
- def setSelection(self, symbols):
- """Set selected elements.
-
- This causes the sigSelectionChanged signal
- to be emitted, even if the selection didn't actually change.
-
- :param List[str] symbols: List of symbols of elements to be selected
- (e.g. *["Fe", "Hg", "Li"]*)
- """
- # accept list of PeriodicTableItems as input, because getSelection
- # returns these objects and it makes sense to have getter and setter
- # use same type of data
- if isinstance(symbols[0], PeriodicTableItem):
- symbols = [elmt.symbol for elmt in symbols]
-
- for (e, b) in self._eltButtons.items():
- b.setSelected(e in symbols)
- self.sigSelectionChanged.emit(self.getSelection())
-
- def setElementSelected(self, symbol, state):
- """Modify *selected* status of a single element (select or unselect)
-
- :param str symbol: PeriodicTableItem symbol to be selected
- :param bool state: *True* to select, *False* to unselect
- """
- self._eltButtons[symbol].setSelected(state)
- self.sigSelectionChanged.emit(self.getSelection())
-
- def isElementSelected(self, symbol):
- """Return *True* if element is selected, else *False*
-
- :param str symbol: PeriodicTableItem symbol
- :return: *True* if element is selected, else *False*
- """
- return self._eltButtons[symbol].isSelected()
-
- def elementToggle(self, item):
- """Toggle selected/unselected state for element
-
- :param item: PeriodicTableItem object
- """
- b = self._eltButtons[item.symbol]
- b.setSelected(not b.isSelected())
- self.sigSelectionChanged.emit(self.getSelection())
-
-
-class PeriodicCombo(qt.QComboBox):
- """
- Combo list with all atomic elements of the periodic table
-
- .. image:: img/PeriodicCombo.png
-
- :param bool detailed: True (default) display element symbol, Z and name.
- False display only element symbol and Z.
- :param elements: List of items (:class:`PeriodicTableItem` objects) to
- be represented in the table. By default, take elements from
- a predefined list with minimal information (symbol, atomic number,
- name, mass).
- """
- sigSelectionChanged = qt.pyqtSignal(object)
- """Signal emitted when the selection changes. Send
- :class:`PeriodicTableItem` object representing selected
- element
- """
-
- def __init__(self, parent=None, detailed=True, elements=None):
- qt.QComboBox.__init__(self, parent)
-
- # add all elements from global list
- if elements is None:
- elements = _defaultTableItems
- for i, elmt in enumerate(elements):
- if detailed:
- txt = "%2s (%d) - %s" % (elmt.symbol, elmt.Z, elmt.name)
- else:
- txt = "%2s (%d)" % (elmt.symbol, elmt.Z)
- self.insertItem(i, txt)
-
- self.currentIndexChanged[int].connect(self.__selectionChanged)
-
- def __selectionChanged(self, idx):
- """Emit :attr:`sigSelectionChanged`"""
- self.sigSelectionChanged.emit(_defaultTableItems[idx])
-
- def getSelection(self):
- """Get selected element
-
- :return: Selected element
- :rtype: PeriodicTableItem
- """
- return _defaultTableItems[self.currentIndex()]
-
- def setSelection(self, symbol):
- """Set selected item in combobox by giving the atomic symbol
-
- :param symbol: Symbol of element to be selected
- """
- # accept PeriodicTableItem for getter/setter consistency
- if isinstance(symbol, PeriodicTableItem):
- symbol = symbol.symbol
- symblist = [elmt.symbol for elmt in _defaultTableItems]
- self.setCurrentIndex(symblist.index(symbol))
-
-
-class PeriodicList(qt.QTreeWidget):
- """List of atomic elements in a :class:`QTreeView`
-
- .. image:: img/PeriodicList.png
-
- :param QWidget parent: Parent widget
- :param bool detailed: True (default) display element symbol, Z and name.
- False display only element symbol and Z.
- :param single: *True* for single element selection with mouse click,
- *False* for multiple element selection mode.
- """
- sigSelectionChanged = qt.pyqtSignal(object)
- """When any element is selected/unselected in the widget, it emits
- this signal and sends a list of currently selected
- :class:`PeriodicTableItem` objects.
- """
-
- def __init__(self, parent=None, detailed=True, single=False, elements=None):
- qt.QTreeWidget.__init__(self, parent)
-
- self.detailed = detailed
-
- headers = ["Z", "Symbol"]
- if detailed:
- headers.append("Name")
- self.setColumnCount(3)
- else:
- self.setColumnCount(2)
- self.setHeaderLabels(headers)
- self.header().setStretchLastSection(False)
-
- self.setRootIsDecorated(0)
- self.itemClicked.connect(self.__selectionChanged)
- self.setSelectionMode(qt.QAbstractItemView.SingleSelection if single
- else qt.QAbstractItemView.ExtendedSelection)
- self.__fill_widget(elements)
- self.resizeColumnToContents(0)
- self.resizeColumnToContents(1)
- if detailed:
- self.resizeColumnToContents(2)
-
- def __fill_widget(self, elements):
- """Fill tree widget with elements """
- if elements is None:
- elements = _defaultTableItems
-
- self.tree_items = []
-
- previous_item = None
- for elmt in elements:
- if previous_item is None:
- item = qt.QTreeWidgetItem(self)
- else:
- item = qt.QTreeWidgetItem(self, previous_item)
- item.setText(0, str(elmt.Z))
- item.setText(1, elmt.symbol)
- if self.detailed:
- item.setText(2, elmt.name)
- self.tree_items.append(item)
- previous_item = item
-
- def __selectionChanged(self, treeItem, column):
- """Emit a :attr:`sigSelectionChanged` and send a list of
- :class:`PeriodicTableItem` objects."""
- self.sigSelectionChanged.emit(self.getSelection())
-
- def getSelection(self):
- """Get a list of selected elements, as a list of :class:`PeriodicTableItem`
- objects.
-
- :return: Selected elements
- :rtype: List[PeriodicTableItem]"""
- return [_defaultTableItems[idx] for idx in range(len(self.tree_items))
- if self.tree_items[idx].isSelected()]
-
- # setSelection is a bad name (name of a QTreeWidget method)
- def setSelectedElements(self, symbolList):
- """
-
- :param symbolList: List of atomic symbols ["H", "He", "Li"...]
- to be selected in the widget
- """
- # accept PeriodicTableItem for getter/setter consistency
- if isinstance(symbolList[0], PeriodicTableItem):
- symbolList = [elmt.symbol for elmt in symbolList]
- for idx in range(len(self.tree_items)):
- self.tree_items[idx].setSelected(_defaultTableItems[idx].symbol in symbolList)
diff --git a/silx/gui/widgets/PrintGeometryDialog.py b/silx/gui/widgets/PrintGeometryDialog.py
deleted file mode 100644
index db0f3b3..0000000
--- a/silx/gui/widgets/PrintGeometryDialog.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-
-from silx.gui import qt
-from silx.gui.widgets.FloatEdit import FloatEdit
-
-
-class PrintGeometryWidget(qt.QWidget):
- """Widget to specify the size and aspect ratio of an item
- before sending it to the print preview dialog.
-
- Use methods :meth:`setPrintGeometry` and :meth:`getPrintGeometry`
- to interact with the widget.
- """
- def __init__(self, parent=None):
- super(PrintGeometryWidget, self).__init__(parent)
- self.mainLayout = qt.QGridLayout(self)
- self.mainLayout.setContentsMargins(0, 0, 0, 0)
- self.mainLayout.setSpacing(2)
- hbox = qt.QWidget(self)
- hboxLayout = qt.QHBoxLayout(hbox)
- hboxLayout.setContentsMargins(0, 0, 0, 0)
- hboxLayout.setSpacing(2)
- label = qt.QLabel(self)
- label.setText("Units")
- label.setAlignment(qt.Qt.AlignCenter)
- self._pageButton = qt.QRadioButton()
- self._pageButton.setText("Page")
- self._inchButton = qt.QRadioButton()
- self._inchButton.setText("Inches")
- self._cmButton = qt.QRadioButton()
- self._cmButton.setText("Centimeters")
- self._buttonGroup = qt.QButtonGroup(self)
- self._buttonGroup.addButton(self._pageButton)
- self._buttonGroup.addButton(self._inchButton)
- self._buttonGroup.addButton(self._cmButton)
- self._buttonGroup.setExclusive(True)
-
- # units
- self.mainLayout.addWidget(label, 0, 0, 1, 4)
- hboxLayout.addWidget(self._pageButton)
- hboxLayout.addWidget(self._inchButton)
- hboxLayout.addWidget(self._cmButton)
- self.mainLayout.addWidget(hbox, 1, 0, 1, 4)
- self._pageButton.setChecked(True)
-
- # xOffset
- label = qt.QLabel(self)
- label.setText("X Offset:")
- self.mainLayout.addWidget(label, 2, 0)
- self._xOffset = FloatEdit(self, 0.1)
- self.mainLayout.addWidget(self._xOffset, 2, 1)
-
- # yOffset
- label = qt.QLabel(self)
- label.setText("Y Offset:")
- self.mainLayout.addWidget(label, 2, 2)
- self._yOffset = FloatEdit(self, 0.1)
- self.mainLayout.addWidget(self._yOffset, 2, 3)
-
- # width
- label = qt.QLabel(self)
- label.setText("Width:")
- self.mainLayout.addWidget(label, 3, 0)
- self._width = FloatEdit(self, 0.9)
- self.mainLayout.addWidget(self._width, 3, 1)
-
- # height
- label = qt.QLabel(self)
- label.setText("Height:")
- self.mainLayout.addWidget(label, 3, 2)
- self._height = FloatEdit(self, 0.9)
- self.mainLayout.addWidget(self._height, 3, 3)
-
- # aspect ratio
- self._aspect = qt.QCheckBox(self)
- self._aspect.setText("Keep screen aspect ratio")
- self._aspect.setChecked(True)
- self.mainLayout.addWidget(self._aspect, 4, 1, 1, 2)
-
- def getPrintGeometry(self):
- """Return the print geometry dictionary.
-
- See :meth:`setPrintGeometry` for documentation about the
- print geometry dictionary."""
- ddict = {}
- if self._inchButton.isChecked():
- ddict['units'] = "inches"
- elif self._cmButton.isChecked():
- ddict['units'] = "centimeters"
- else:
- ddict['units'] = "page"
-
- ddict['xOffset'] = self._xOffset.value()
- ddict['yOffset'] = self._yOffset.value()
- ddict['width'] = self._width.value()
- ddict['height'] = self._height.value()
-
- if self._aspect.isChecked():
- ddict['keepAspectRatio'] = True
- else:
- ddict['keepAspectRatio'] = False
- return ddict
-
- def setPrintGeometry(self, geometry=None):
- """Set the print geometry.
-
- The geometry parameters must be provided as a dictionary with
- the following keys:
-
- - *"xOffset"* (float)
- - *"yOffset"* (float)
- - *"width"* (float)
- - *"height"* (float)
- - *"units"*: possible values *"page", "inch", "cm"*
- - *"keepAspectRatio"*: *True* or *False*
-
- If *units* is *"page"*, the values should be floats in [0, 1.]
- and are interpreted as a fraction of the page width or height.
-
- :param dict geometry: Geometry parameters, as a dictionary."""
- if geometry is None:
- geometry = {}
- oldDict = self.getPrintGeometry()
- for key in ["units", "xOffset", "yOffset",
- "width", "height", "keepAspectRatio"]:
- geometry[key] = geometry.get(key, oldDict[key])
-
- if geometry['units'].lower().startswith("inc"):
- self._inchButton.setChecked(True)
- elif geometry['units'].lower().startswith("c"):
- self._cmButton.setChecked(True)
- else:
- self._pageButton.setChecked(True)
-
- self._xOffset.setText("%s" % float(geometry['xOffset']))
- self._yOffset.setText("%s" % float(geometry['yOffset']))
- self._width.setText("%s" % float(geometry['width']))
- self._height.setText("%s" % float(geometry['height']))
- if geometry['keepAspectRatio']:
- self._aspect.setChecked(True)
- else:
- self._aspect.setChecked(False)
-
-
-class PrintGeometryDialog(qt.QDialog):
- """Dialog embedding a :class:`PrintGeometryWidget`.
-
- Use methods :meth:`setPrintGeometry` and :meth:`getPrintGeometry`
- to interact with the widget.
-
- Execute method :meth:`exec_` to run the dialog.
- The return value of that method is *True* if the geometry was set
- (*Ok* button clicked) or *False* if the user clicked the *Cancel*
- button.
- """
-
- def __init__(self, parent=None):
- qt.QDialog.__init__(self, parent)
- self.setWindowTitle("Set print size preferences")
- layout = qt.QVBoxLayout(self)
- layout.setContentsMargins(0, 0, 0, 0)
- layout.setSpacing(0)
- self.configurationWidget = PrintGeometryWidget(self)
- hbox = qt.QWidget(self)
- hboxLayout = qt.QHBoxLayout(hbox)
- self.okButton = qt.QPushButton(hbox)
- self.okButton.setText("Accept")
- self.okButton.setAutoDefault(False)
- self.rejectButton = qt.QPushButton(hbox)
- self.rejectButton.setText("Dismiss")
- self.rejectButton.setAutoDefault(False)
- self.okButton.clicked.connect(self.accept)
- self.rejectButton.clicked.connect(self.reject)
- hboxLayout.setContentsMargins(0, 0, 0, 0)
- hboxLayout.setSpacing(2)
- # hboxLayout.addWidget(qt.HorizontalSpacer(hbox))
- hboxLayout.addWidget(self.okButton)
- hboxLayout.addWidget(self.rejectButton)
- # hboxLayout.addWidget(qt.HorizontalSpacer(hbox))
- layout.addWidget(self.configurationWidget)
- layout.addWidget(hbox)
-
- def setPrintGeometry(self, geometry):
- """Return the print geometry dictionary.
-
- See :meth:`PrintGeometryWidget.setPrintGeometry` for documentation on
- print geometry dictionary.
-
- :param dict geometry: Print geometry parameters dictionary.
- """
- self.configurationWidget.setPrintGeometry(geometry)
-
- def getPrintGeometry(self):
- """Return the print geometry dictionary.
-
- See :meth:`PrintGeometryWidget.setPrintGeometry` for documentation on
- print geometry dictionary."""
- return self.configurationWidget.getPrintGeometry()
diff --git a/silx/gui/widgets/PrintPreview.py b/silx/gui/widgets/PrintPreview.py
deleted file mode 100644
index 96af34b..0000000
--- a/silx/gui/widgets/PrintPreview.py
+++ /dev/null
@@ -1,728 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module implements a print preview dialog.
-
-The dialog provides methods to send images, pixmaps and SVG
-items to the page to be printed.
-
-The user can interactively move and resize the items.
-"""
-import sys
-import logging
-from silx.gui import qt, printer
-
-
-__authors__ = ["V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "11/07/2017"
-
-
-_logger = logging.getLogger(__name__)
-
-
-class PrintPreviewDialog(qt.QDialog):
- """Print preview dialog widget.
- """
- def __init__(self, parent=None, printer=None):
-
- qt.QDialog.__init__(self, parent)
- self.setWindowTitle("Print Preview")
- self.setModal(False)
- self.resize(400, 500)
-
- self.mainLayout = qt.QVBoxLayout(self)
- self.mainLayout.setContentsMargins(0, 0, 0, 0)
- self.mainLayout.setSpacing(0)
-
- self._buildToolbar()
-
- self.printer = printer
- # :class:`QPrinter` (paint device that paints on a printer).
- # :meth:`showEvent` has been reimplemented to enforce printer
- # setup.
-
- self.printDialog = None
- # :class:`QPrintDialog` (dialog for specifying the printer's
- # configuration)
-
- self.scene = None
- # :class:`QGraphicsScene` (surface for managing
- # 2D graphical items)
-
- self.page = None
- # :class:`QGraphicsRectItem` used as white background page on which
- # to display the print preview.
-
- self.view = None
- # :class:`QGraphicsView` widget for displaying :attr:`scene`
-
- self._svgItems = []
- # List storing :class:`QSvgRenderer` items to be printed, added in
- # :meth:`addSvgItem`, cleared in :meth:`_clearAll`.
- # This ensures that there is a reference pointing to the items,
- # which ensures they are not destroyed before being printed.
-
- self._viewScale = 1.0
- # Zoom level (1.0 is 100%)
-
- self._toBeCleared = False
- # Flag indicating that all items must be removed from :attr:`scene`
- # and from :attr:`_svgItems`.
- # Set to True after a successful printing. The widget is then hidden,
- # and it will be cleared the next time it is shown.
- # Reset to False after :meth:`_clearAll` has done its job.
-
- def _buildToolbar(self):
- toolBar = qt.QWidget(self)
- # a layout for the toolbar
- toolsLayout = qt.QHBoxLayout(toolBar)
- toolsLayout.setContentsMargins(0, 0, 0, 0)
- toolsLayout.setSpacing(0)
-
- hideBut = qt.QPushButton("Hide", toolBar)
- hideBut.setToolTip("Hide print preview dialog")
- hideBut.clicked.connect(self.hide)
-
- cancelBut = qt.QPushButton("Clear All", toolBar)
- cancelBut.setToolTip("Remove all items")
- cancelBut.clicked.connect(self._clearAll)
-
- removeBut = qt.QPushButton("Remove",
- toolBar)
- removeBut.setToolTip("Remove selected item (use left click to select)")
- removeBut.clicked.connect(self._remove)
-
- setupBut = qt.QPushButton("Setup", toolBar)
- setupBut.setToolTip("Select and configure a printer")
- setupBut.clicked.connect(self.setup)
-
- printBut = qt.QPushButton("Print", toolBar)
- printBut.setToolTip("Print page and close print preview")
- printBut.clicked.connect(self._print)
-
- zoomPlusBut = qt.QPushButton("Zoom +", toolBar)
- zoomPlusBut.clicked.connect(self._zoomPlus)
-
- zoomMinusBut = qt.QPushButton("Zoom -", toolBar)
- zoomMinusBut.clicked.connect(self._zoomMinus)
-
- toolsLayout.addWidget(hideBut)
- toolsLayout.addWidget(printBut)
- toolsLayout.addWidget(cancelBut)
- toolsLayout.addWidget(removeBut)
- toolsLayout.addWidget(setupBut)
- # toolsLayout.addStretch()
- # toolsLayout.addWidget(marginLabel)
- # toolsLayout.addWidget(self.marginSpin)
- toolsLayout.addStretch()
- # toolsLayout.addWidget(scaleLabel)
- # toolsLayout.addWidget(self.scaleCombo)
- toolsLayout.addWidget(zoomPlusBut)
- toolsLayout.addWidget(zoomMinusBut)
- # toolsLayout.addStretch()
- self.toolBar = toolBar
- self.mainLayout.addWidget(self.toolBar)
-
- def _buildStatusBar(self):
- """Create the status bar used to display the printer name
- or output file name."""
- # status bar
- statusBar = qt.QStatusBar(self)
- self.targetLabel = qt.QLabel(statusBar)
- self._updateTargetLabel()
- statusBar.addWidget(self.targetLabel)
- self.mainLayout.addWidget(statusBar)
-
- def _updateTargetLabel(self):
- """Update printer name or file name shown in the status bar."""
- if self.printer is None:
- self.targetLabel.setText("Undefined printer")
- return
- if self.printer.outputFileName():
- self.targetLabel.setText("File:" +
- self.printer.outputFileName())
- else:
- self.targetLabel.setText("Printer:" +
- self.printer.printerName())
-
- def _updatePrinter(self):
- """Resize :attr:`page`, :attr:`scene` and :attr:`view` to :attr:`printer`
- width and height."""
- printer = self.printer
- assert printer is not None, \
- "_updatePrinter should not be called unless a printer is defined"
- if self.scene is None:
- self.scene = qt.QGraphicsScene()
- self.scene.setBackgroundBrush(qt.QColor(qt.Qt.lightGray))
- self.scene.setSceneRect(qt.QRectF(0, 0, printer.width(), printer.height()))
-
- if self.page is None:
- self.page = qt.QGraphicsRectItem(0, 0, printer.width(), printer.height())
- self.page.setBrush(qt.QColor(qt.Qt.white))
- self.scene.addItem(self.page)
-
- self.scene.setSceneRect(qt.QRectF(0, 0, printer.width(), printer.height()))
- self.page.setPos(qt.QPointF(0.0, 0.0))
- self.page.setRect(qt.QRectF(0, 0, printer.width(), printer.height()))
-
- if self.view is None:
- self.view = qt.QGraphicsView(self.scene)
- self.mainLayout.addWidget(self.view)
- self._buildStatusBar()
- # self.view.scale(1./self._viewScale, 1./self._viewScale)
- self.view.fitInView(self.page.rect(), qt.Qt.KeepAspectRatio)
- self._viewScale = 1.00
- self._updateTargetLabel()
-
- # Public methods
- def addImage(self, image, title=None, comment=None, commentPosition=None):
- """Add an image to the print preview scene.
-
- :param QImage image: Image to be added to the scene
- :param str title: Title shown above (centered) the image
- :param str comment: Comment displayed below the image
- :param commentPosition: "CENTER" or "LEFT"
- """
- self.addPixmap(qt.QPixmap.fromImage(image),
- title=title, comment=comment,
- commentPosition=commentPosition)
-
- def addPixmap(self, pixmap, title=None, comment=None, commentPosition=None):
- """Add a pixmap to the print preview scene
-
- :param QPixmap pixmap: Pixmap to be added to the scene
- :param str title: Title shown above (centered) the pixmap
- :param str comment: Comment displayed below the pixmap
- :param commentPosition: "CENTER" or "LEFT"
- """
- if self._toBeCleared:
- self._clearAll()
- self.ensurePrinterIsSet()
- if self.printer is None:
- _logger.error("printer is not set, cannot add pixmap to page")
- return
- if title is None:
- title = ' ' * 88
- if comment is None:
- comment = ' ' * 88
- if commentPosition is None:
- commentPosition = "CENTER"
- if qt.qVersion() < "5.0":
- rectItem = qt.QGraphicsRectItem(self.page, self.scene)
- else:
- rectItem = qt.QGraphicsRectItem(self.page)
-
- rectItem.setRect(qt.QRectF(1, 1,
- pixmap.width(), pixmap.height()))
-
- pen = rectItem.pen()
- color = qt.QColor(qt.Qt.red)
- color.setAlpha(1)
- pen.setColor(color)
- rectItem.setPen(pen)
- rectItem.setZValue(1)
- rectItem.setFlag(qt.QGraphicsItem.ItemIsSelectable, True)
- rectItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
- rectItem.setFlag(qt.QGraphicsItem.ItemIsFocusable, False)
-
- rectItemResizeRect = _GraphicsResizeRectItem(rectItem, self.scene)
- rectItemResizeRect.setZValue(2)
-
- if qt.qVersion() < "5.0":
- pixmapItem = qt.QGraphicsPixmapItem(rectItem, self.scene)
- else:
- pixmapItem = qt.QGraphicsPixmapItem(rectItem)
- pixmapItem.setPixmap(pixmap)
- pixmapItem.setZValue(0)
-
- # I add the title
- if qt.qVersion() < "5.0":
- textItem = qt.QGraphicsTextItem(title, rectItem, self.scene)
- else:
- textItem = qt.QGraphicsTextItem(title, rectItem)
- textItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
- offset = 0.5 * textItem.boundingRect().width()
- textItem.moveBy(0.5 * pixmap.width() - offset, -20)
- textItem.setZValue(2)
-
- # I add the comment
- if qt.qVersion() < "5.0":
- commentItem = qt.QGraphicsTextItem(comment, rectItem, self.scene)
- else:
- commentItem = qt.QGraphicsTextItem(comment, rectItem)
- commentItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
- offset = 0.5 * commentItem.boundingRect().width()
- if commentPosition.upper() == "LEFT":
- x = 1
- else:
- x = 0.5 * pixmap.width() - offset
- commentItem.moveBy(x, pixmap.height() + 20)
- commentItem.setZValue(2)
-
- rectItem.moveBy(20, 40)
-
- def addSvgItem(self, item, title=None,
- comment=None, commentPosition=None,
- viewBox=None, keepRatio=True):
- """Add a SVG item to the scene.
-
- :param QSvgRenderer item: SVG item to be added to the scene.
- :param str title: Title shown above (centered) the SVG item.
- :param str comment: Comment displayed below the SVG item.
- :param str commentPosition: "CENTER" or "LEFT"
- :param QRectF viewBox: Bounding box for the item on the print page
- (xOffset, yOffset, width, height). If None, use original
- item size.
- :param bool keepRatio: If True, resizing the item will preserve its
- original aspect ratio.
- """
- if not qt.HAS_SVG:
- raise RuntimeError("Missing QtSvg library.")
- if not isinstance(item, qt.QSvgRenderer):
- raise TypeError("addSvgItem: QSvgRenderer expected")
- if self._toBeCleared:
- self._clearAll()
- self.ensurePrinterIsSet()
- if self.printer is None:
- _logger.error("printer is not set, cannot add SvgItem to page")
- return
-
- if title is None:
- title = 50 * ' '
- if comment is None:
- comment = 80 * ' '
- if commentPosition is None:
- commentPosition = "CENTER"
-
- if viewBox is None:
- if hasattr(item, "_viewBox"):
- # PyMca compatibility: viewbox attached to item
- viewBox = item._viewBox
- else:
- # try the original item viewbox
- viewBox = item.viewBoxF()
-
- svgItem = _GraphicsSvgRectItem(viewBox, self.page)
- svgItem.setSvgRenderer(item)
-
- svgItem.setCacheMode(qt.QGraphicsItem.NoCache)
- svgItem.setZValue(0)
- svgItem.setFlag(qt.QGraphicsItem.ItemIsSelectable, True)
- svgItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
- svgItem.setFlag(qt.QGraphicsItem.ItemIsFocusable, False)
-
- rectItemResizeRect = _GraphicsResizeRectItem(svgItem, self.scene,
- keepratio=keepRatio)
- rectItemResizeRect.setZValue(2)
-
- self._svgItems.append(item)
-
- # Comment / legend
- dummyComment = 80 * "1"
- if qt.qVersion() < '5.0':
- commentItem = qt.QGraphicsTextItem(dummyComment, svgItem, self.scene)
- else:
- commentItem = qt.QGraphicsTextItem(dummyComment, svgItem)
- commentItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
- # we scale the text to have the legend box have the same width as the graph
- scaleCalculationRect = qt.QRectF(commentItem.boundingRect())
- scale = svgItem.boundingRect().width() / scaleCalculationRect.width()
-
- commentItem.setPlainText(comment)
- commentItem.setZValue(1)
-
- commentItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
- if qt.qVersion() < "5.0":
- commentItem.scale(scale, scale)
- else:
- commentItem.setScale(scale)
-
- # align
- if commentPosition.upper() == "CENTER":
- alignment = qt.Qt.AlignCenter
- elif commentPosition.upper() == "RIGHT":
- alignment = qt.Qt.AlignRight
- else:
- alignment = qt.Qt.AlignLeft
- commentItem.setTextWidth(commentItem.boundingRect().width())
- center_format = qt.QTextBlockFormat()
- center_format.setAlignment(alignment)
- cursor = commentItem.textCursor()
- cursor.select(qt.QTextCursor.Document)
- cursor.mergeBlockFormat(center_format)
- cursor.clearSelection()
- commentItem.setTextCursor(cursor)
- if alignment == qt.Qt.AlignLeft:
- deltax = 0
- else:
- deltax = (svgItem.boundingRect().width() - commentItem.boundingRect().width()) / 2.
- commentItem.moveBy(svgItem.boundingRect().x() + deltax,
- svgItem.boundingRect().y() + svgItem.boundingRect().height())
-
- # Title
- if qt.qVersion() < '5.0':
- textItem = qt.QGraphicsTextItem(title, svgItem, self.scene)
- else:
- textItem = qt.QGraphicsTextItem(title, svgItem)
- textItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
- textItem.setZValue(1)
- textItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
-
- title_offset = 0.5 * textItem.boundingRect().width()
- textItem.moveBy(svgItem.boundingRect().x() +
- 0.5 * svgItem.boundingRect().width() - title_offset * scale,
- svgItem.boundingRect().y())
- if qt.qVersion() < "5.0":
- textItem.scale(scale, scale)
- else:
- textItem.setScale(scale)
-
- def setup(self):
- """Open a print dialog to ensure the :attr:`printer` is set.
-
- If the setting fails or is cancelled, :attr:`printer` is reset to
- *None*.
- """
- if self.printer is None:
- self.printer = printer.getDefaultPrinter()
- if self.printDialog is None:
- self.printDialog = qt.QPrintDialog(self.printer, self)
- if self.printDialog.exec_():
- if self.printer.width() <= 0 or self.printer.height() <= 0:
- self.message = qt.QMessageBox(self)
- self.message.setIcon(qt.QMessageBox.Critical)
- self.message.setText("Unknown library error \non printer initialization")
- self.message.setWindowTitle("Library Error")
- self.message.setModal(0)
- self.printer = None
- return
- self.printer.setFullPage(True)
- self._updatePrinter()
- else:
- # printer setup cancelled, check for a possible previous configuration
- if self.page is None:
- # not initialized
- self.printer = None
-
- def ensurePrinterIsSet(self):
- """If the printer is not already set, try to interactively
- setup the printer using a QPrintDialog.
- In case of failure, hide widget and log a warning.
-
- :return: True if printer was set. False if it failed or if the
- selection dialog was canceled.
- """
- if self.printer is None:
- self.setup()
- if self.printer is None:
- self.hide()
- _logger.warning("Printer setup failed or was cancelled, " +
- "but printer is required.")
- return self.printer is not None
-
- def setOutputFileName(self, name):
- """Set output filename.
-
- Setting a non-empty name enables printing to file.
-
- :param str name: File name (path)"""
- self.printer.setOutputFileName(name)
-
- # overloaded methods
- def exec_(self):
- if self._toBeCleared:
- self._clearAll()
- return qt.QDialog.exec_(self)
-
- def raise_(self):
- if self._toBeCleared:
- self._clearAll()
- return qt.QDialog.raise_(self)
-
- def showEvent(self, event):
- """Reimplemented to force printer setup.
- In case of failure, hide the widget."""
- if self._toBeCleared:
- self._clearAll()
- self.ensurePrinterIsSet()
-
- return super(PrintPreviewDialog, self).showEvent(event)
-
- # button callbacks
- def _print(self):
- """Do the printing, hide the print preview dialog,
- set :attr:`_toBeCleared` flag to True to trigger clearing the
- next time the dialog is shown.
-
- If the printer is not setup, do it first."""
- printer = self.printer
-
- painter = qt.QPainter()
- if not painter.begin(printer) or printer is None:
- _logger.error("Cannot initialize printer")
- return
- try:
- self.scene.render(painter, qt.QRectF(0, 0, printer.width(), printer.height()),
- qt.QRectF(self.page.rect().x(), self.page.rect().y(),
- self.page.rect().width(), self.page.rect().height()),
- qt.Qt.KeepAspectRatio)
- painter.end()
- self.hide()
- self.accept()
- self._toBeCleared = True
- except: # FIXME
- painter.end()
- qt.QMessageBox.critical(self, "ERROR",
- 'Printing problem:\n %s' % sys.exc_info()[1])
- _logger.error('printing problem:\n %s' % sys.exc_info()[1])
- return
-
- def _zoomPlus(self):
- self._viewScale *= 1.20
- self.view.scale(1.20, 1.20)
-
- def _zoomMinus(self):
- self._viewScale *= 0.80
- self.view.scale(0.80, 0.80)
-
- def _clearAll(self):
- """
- Clear the print preview window, remove all items
- but keep the page.
- """
- itemlist = self.scene.items()
- keep = self.page
- while len(itemlist) != 1:
- if itemlist.index(keep) == 0:
- self.scene.removeItem(itemlist[1])
- else:
- self.scene.removeItem(itemlist[0])
- itemlist = self.scene.items()
- self._svgItems = []
- self._toBeCleared = False
-
- def _remove(self):
- """Remove selected item in :attr:`scene`.
- """
- itemlist = self.scene.items()
-
- # this loop is not efficient if there are many items ...
- for item in itemlist:
- if item.isSelected():
- self.scene.removeItem(item)
-
-
-class SingletonPrintPreviewDialog(PrintPreviewDialog):
- """Singleton print preview dialog.
-
- All widgets in a program that instantiate this class will share
- a single print preview dialog. This enables sending
- multiple images to a single page to be printed.
- """
- _instance = None
-
- def __new__(self, *var, **kw):
- if self._instance is None:
- self._instance = PrintPreviewDialog(*var, **kw)
- return self._instance
-
-
-class _GraphicsSvgRectItem(qt.QGraphicsRectItem):
- """:class:`qt.QGraphicsRectItem` with an attached
- :class:`qt.QSvgRenderer`, and with a painter redefined to render
- the SVG item."""
- def setSvgRenderer(self, renderer):
- """
-
- :param QSvgRenderer renderer: svg renderer
- """
- self._renderer = renderer
-
- def paint(self, painter, *var, **kw):
- self._renderer.render(painter, self.boundingRect())
-
-
-class _GraphicsResizeRectItem(qt.QGraphicsRectItem):
- """Resizable QGraphicsRectItem."""
- def __init__(self, parent=None, scene=None, keepratio=True):
- if qt.qVersion() < '5.0':
- qt.QGraphicsRectItem.__init__(self, parent, scene)
- else:
- qt.QGraphicsRectItem.__init__(self, parent)
- rect = parent.boundingRect()
- x = rect.x()
- y = rect.y()
- w = rect.width()
- h = rect.height()
- self._newRect = None
- self.keepRatio = keepratio
- self.setRect(qt.QRectF(x + w - 40, y + h - 40, 40, 40))
- self.setAcceptHoverEvents(True)
- pen = qt.QPen()
- color = qt.QColor(qt.Qt.white)
- color.setAlpha(0)
- pen.setColor(color)
- pen.setStyle(qt.Qt.NoPen)
- self.setPen(pen)
- self.setBrush(color)
- self.setFlag(self.ItemIsMovable, True)
- self.show()
-
- def hoverEnterEvent(self, event):
- if self.parentItem().isSelected():
- self.parentItem().setSelected(False)
- if self.keepRatio:
- self.setCursor(qt.QCursor(qt.Qt.SizeFDiagCursor))
- else:
- self.setCursor(qt.QCursor(qt.Qt.SizeAllCursor))
- self.setBrush(qt.QBrush(qt.Qt.yellow, qt.Qt.SolidPattern))
- return qt.QGraphicsRectItem.hoverEnterEvent(self, event)
-
- def hoverLeaveEvent(self, event):
- self.setCursor(qt.QCursor(qt.Qt.ArrowCursor))
- pen = qt.QPen()
- color = qt.QColor(qt.Qt.white)
- color.setAlpha(0)
- pen.setColor(color)
- pen.setStyle(qt.Qt.NoPen)
- self.setPen(pen)
- self.setBrush(color)
- return qt.QGraphicsRectItem.hoverLeaveEvent(self, event)
-
- def mousePressEvent(self, event):
- if self._newRect is not None:
- self._newRect = None
- self._point0 = self.pos()
- parent = self.parentItem()
- scene = self.scene()
- # following line prevents dragging along the previously selected
- # item when resizing another one
- scene.clearSelection()
-
- rect = parent.boundingRect()
- self._x = rect.x()
- self._y = rect.y()
- self._w = rect.width()
- self._h = rect.height()
- self._ratio = self._w / self._h
- if qt.qVersion() < "5.0":
- self._newRect = qt.QGraphicsRectItem(parent, scene)
- else:
- self._newRect = qt.QGraphicsRectItem(parent)
- self._newRect.setRect(qt.QRectF(self._x,
- self._y,
- self._w,
- self._h))
- qt.QGraphicsRectItem.mousePressEvent(self, event)
-
- def mouseMoveEvent(self, event):
- point1 = self.pos()
- deltax = point1.x() - self._point0.x()
- deltay = point1.y() - self._point0.y()
- if self.keepRatio:
- r1 = (self._w + deltax) / self._w
- r2 = (self._h + deltay) / self._h
- if r1 < r2:
- self._newRect.setRect(qt.QRectF(self._x,
- self._y,
- self._w + deltax,
- (self._w + deltax) / self._ratio))
- else:
- self._newRect.setRect(qt.QRectF(self._x,
- self._y,
- (self._h + deltay) * self._ratio,
- self._h + deltay))
- else:
- self._newRect.setRect(qt.QRectF(self._x,
- self._y,
- self._w + deltax,
- self._h + deltay))
- qt.QGraphicsRectItem.mouseMoveEvent(self, event)
-
- def mouseReleaseEvent(self, event):
- point1 = self.pos()
- deltax = point1.x() - self._point0.x()
- deltay = point1.y() - self._point0.y()
- self.moveBy(-deltax, -deltay)
- parent = self.parentItem()
-
- # deduce scale from rectangle
- if (qt.qVersion() < "5.0") or self.keepRatio:
- scalex = self._newRect.rect().width() / self._w
- scaley = scalex
- else:
- scalex = self._newRect.rect().width() / self._w
- scaley = self._newRect.rect().height() / self._h
-
- if qt.qVersion() < "5.0":
- parent.scale(scalex, scaley)
- else:
- # apply the scale to the previous transformation matrix
- previousTransform = parent.transform()
- parent.setTransform(
- previousTransform.scale(scalex, scaley))
-
- self.scene().removeItem(self._newRect)
- self._newRect = None
- qt.QGraphicsRectItem.mouseReleaseEvent(self, event)
-
-
-def main():
- """
- """
- if len(sys.argv) < 2:
- print("give an image file as parameter please.")
- sys.exit(1)
-
- if len(sys.argv) > 2:
- print("only one parameter please.")
- sys.exit(1)
-
- filename = sys.argv[1]
- w = PrintPreviewDialog()
- w.resize(400, 500)
-
- comment = ""
- for i in range(20):
- comment += "Line number %d: En un lugar de La Mancha de cuyo nombre ...\n" % i
-
- if filename[-3:] == "svg":
- item = qt.QSvgRenderer(filename, w.page)
- w.addSvgItem(item, title=filename,
- comment=comment, commentPosition="CENTER")
- else:
- w.addPixmap(qt.QPixmap.fromImage(qt.QImage(filename)),
- title=filename,
- comment=comment,
- commentPosition="CENTER")
- w.addImage(qt.QImage(filename), comment=comment, commentPosition="LEFT")
-
- sys.exit(w.exec_())
-
-
-if __name__ == '__main__':
- a = qt.QApplication(sys.argv)
- main()
- a.exec_()
diff --git a/silx/gui/widgets/RangeSlider.py b/silx/gui/widgets/RangeSlider.py
deleted file mode 100644
index 31dbd4e..0000000
--- a/silx/gui/widgets/RangeSlider.py
+++ /dev/null
@@ -1,765 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides a :class:`RangeSlider` widget.
-
-.. image:: img/RangeSlider.png
- :align: center
-"""
-from __future__ import absolute_import, division
-
-__authors__ = ["D. Naudet", "T. Vincent"]
-__license__ = "MIT"
-__date__ = "26/11/2018"
-
-
-import numpy as numpy
-
-from silx.gui import qt, icons, colors
-from silx.gui.utils.image import convertArrayToQImage
-
-
-class StyleOptionRangeSlider(qt.QStyleOption):
- def __init__(self):
- super(StyleOptionRangeSlider, self).__init__()
- self.minimum = None
- self.maximum = None
- self.sliderPosition1 = None
- self.sliderPosition2 = None
- self.handlerRect1 = None
- self.handlerRect2 = None
-
-
-class RangeSlider(qt.QWidget):
- """Range slider with 2 thumbs and an optional colored groove.
-
- The position of the slider thumbs can be retrieved either as values
- in the slider range or as a number of steps or pixels.
-
- :param QWidget parent: See QWidget
- """
-
- _SLIDER_WIDTH = 10
- """Width of the slider rectangle"""
-
- _PIXMAP_VOFFSET = 7
- """Vertical groove pixmap offset"""
-
- sigRangeChanged = qt.Signal(float, float)
- """Signal emitted when the value range has changed.
-
- It provides the new range (min, max).
- """
-
- sigValueChanged = qt.Signal(float, float)
- """Signal emitted when the value of the sliders has changed.
-
- It provides the slider values (first, second).
- """
-
- sigPositionCountChanged = qt.Signal(object)
- """This signal is emitted when the number of steps has changed.
-
- It provides the new position count.
- """
-
- sigPositionChanged = qt.Signal(int, int)
- """Signal emitted when the position of the sliders has changed.
-
- It provides the slider positions in steps or pixels (first, second).
- """
-
- def __init__(self, parent=None):
- self.__pixmap = None
- self.__positionCount = None
- self.__firstValue = 0.
- self.__secondValue = 1.
- self.__minValue = 0.
- self.__maxValue = 1.
- self.__hoverRect = qt.QRect()
- self.__hoverControl = None
-
- self.__focus = None
- self.__moving = None
-
- self.__icons = {
- 'first': icons.getQIcon('previous'),
- 'second': icons.getQIcon('next')
- }
-
- # call the super constructor AFTER defining all members that
- # are used in the "paint" method
- super(RangeSlider, self).__init__(parent)
-
- self.setFocusPolicy(qt.Qt.ClickFocus)
- self.setAttribute(qt.Qt.WA_Hover)
-
- self.setMinimumSize(qt.QSize(50, 20))
- self.setMaximumHeight(20)
-
- # Broadcast value changed signal
- self.sigValueChanged.connect(self.__emitPositionChanged)
-
- def event(self, event):
- t = event.type()
- if t == qt.QEvent.HoverEnter or t == qt.QEvent.HoverLeave or t == qt.QEvent.HoverMove:
- return self.__updateHoverControl(event.pos())
- else:
- return super(RangeSlider, self).event(event)
-
- def __updateHoverControl(self, pos):
- hoverControl, hoverRect = self.__findHoverControl(pos)
- if hoverControl != self.__hoverControl:
- self.update(self.__hoverRect)
- self.update(hoverRect)
- self.__hoverControl = hoverControl
- self.__hoverRect = hoverRect
- return True
- return hoverControl is not None
-
- def __findHoverControl(self, pos):
- """Returns the control at the position and it's rect location"""
- for name in ["first", "second"]:
- rect = self.__sliderRect(name)
- if rect.contains(pos):
- return name, rect
- rect = self.__drawArea()
- if rect.contains(pos):
- return "groove", rect
- return None, qt.QRect()
-
- # Position <-> Value conversion
-
- def __positionToValue(self, position):
- """Returns value corresponding to position
-
- :param int position:
- :rtype: float
- """
- min_, max_ = self.getMinimum(), self.getMaximum()
- maxPos = self.__getCurrentPositionCount() - 1
- return min_ + (max_ - min_) * int(position) / maxPos
-
- def __valueToPosition(self, value):
- """Returns closest position corresponding to value
-
- :param float value:
- :rtype: int
- """
- min_, max_ = self.getMinimum(), self.getMaximum()
- maxPos = self.__getCurrentPositionCount() - 1
- return int(0.5 + maxPos * (float(value) - min_) / (max_ - min_))
-
- # Position (int) API
-
- def __getCurrentPositionCount(self):
- """Return current count (either position count or widget width
-
- :rtype: int
- """
- count = self.getPositionCount()
- if count is not None:
- return count
- else:
- return max(2, self.width() - self._SLIDER_WIDTH)
-
- def getPositionCount(self):
- """Returns the number of positions.
-
- :rtype: Union[int,None]"""
- return self.__positionCount
-
- def setPositionCount(self, count):
- """Set the number of positions.
-
- Slider values are eventually adjusted.
-
- :param Union[int,None] count:
- Either the number of possible positions or
- None to allow any values.
- :raise ValueError: If count <= 1
- """
- count = None if count is None else int(count)
- if count != self.getPositionCount():
- if count is not None and count <= 1:
- raise ValueError("Position count must be higher than 1")
- self.__positionCount = count
- emit = self.__setValues(*self.getValues())
- self.sigPositionCountChanged.emit(count)
- if emit:
- self.sigValueChanged.emit(*self.getValues())
-
- def getFirstPosition(self):
- """Returns first slider position
-
- :rtype: int
- """
- return self.__valueToPosition(self.getFirstValue())
-
- def setFirstPosition(self, position):
- """Set the position of the first slider
-
- The position is adjusted to valid values
-
- :param int position:
- """
- self.setFirstValue(self.__positionToValue(position))
-
- def getSecondPosition(self):
- """Returns second slider position
-
- :rtype: int
- """
- return self.__valueToPosition(self.getSecondValue())
-
- def setSecondPosition(self, position):
- """Set the position of the second slider
-
- The position is adjusted to valid values
-
- :param int position:
- """
- self.setSecondValue(self.__positionToValue(position))
-
- def getPositions(self):
- """Returns slider positions (first, second)
-
- :rtype: List[int]
- """
- return self.getFirstPosition(), self.getSecondPosition()
-
- def setPositions(self, first, second):
- """Set the position of both sliders at once
-
- First is clipped to the slider range: [0, max].
- Second is clipped to valid values: [first, max]
-
- :param int first:
- :param int second:
- """
- self.setValues(self.__positionToValue(first),
- self.__positionToValue(second))
-
- # Value (float) API
-
- def __emitPositionChanged(self, *args, **kwargs):
- self.sigPositionChanged.emit(*self.getPositions())
-
- def __rangeChanged(self):
- """Handle change of value range"""
- emit = self.__setValues(*self.getValues())
- self.sigRangeChanged.emit(*self.getRange())
- if emit:
- self.sigValueChanged.emit(*self.getValues())
-
- def getMinimum(self):
- """Returns the minimum value of the slider range
-
- :rtype: float
- """
- return self.__minValue
-
- def setMinimum(self, minimum):
- """Set the minimum value of the slider range.
-
- It eventually adjusts maximum.
- Slider positions remains unchanged and slider values are modified.
-
- :param float minimum:
- """
- minimum = float(minimum)
- if minimum != self.getMinimum():
- if minimum > self.getMaximum():
- self.__maxValue = minimum
- self.__minValue = minimum
- self.__rangeChanged()
-
- def getMaximum(self):
- """Returns the maximum value of the slider range
-
- :rtype: float
- """
- return self.__maxValue
-
- def setMaximum(self, maximum):
- """Set the maximum value of the slider range
-
- It eventually adjusts minimum.
- Slider positions remains unchanged and slider values are modified.
-
- :param float maximum:
- """
- maximum = float(maximum)
- if maximum != self.getMaximum():
- if maximum < self.getMinimum():
- self.__minValue = maximum
- self.__maxValue = maximum
- self.__rangeChanged()
-
- def getRange(self):
- """Returns the range of values (min, max)
-
- :rtype: List[float]
- """
- return self.getMinimum(), self.getMaximum()
-
- def setRange(self, minimum, maximum):
- """Set the range of values.
-
- If maximum is lower than minimum, minimum is the only valid value.
- Slider positions remains unchanged and slider values are modified.
-
- :param float minimum:
- :param float maximum:
- """
- minimum, maximum = float(minimum), float(maximum)
- if minimum != self.getMinimum() or maximum != self.getMaximum():
- self.__minValue = minimum
- self.__maxValue = max(maximum, minimum)
- self.__rangeChanged()
-
- def getFirstValue(self):
- """Returns the value of the first slider
-
- :rtype: float
- """
- return self.__firstValue
-
- def __clipFirstValue(self, value, max_=None):
- """Clip first value to range and steps
-
- :param float value:
- :param float max_: Alternative maximum to use
- """
- if max_ is None:
- max_ = self.getSecondValue()
- value = min(max(self.getMinimum(), float(value)), max_)
- if self.getPositionCount() is not None: # Clip to steps
- value = self.__positionToValue(self.__valueToPosition(value))
- return value
-
- def setFirstValue(self, value):
- """Set the value of the first slider
-
- Value is clipped to valid values.
-
- :param float value:
- """
- value = self.__clipFirstValue(value)
- if value != self.getFirstValue():
- self.__firstValue = value
- self.update()
- self.sigValueChanged.emit(*self.getValues())
-
- def getSecondValue(self):
- """Returns the value of the second slider
-
- :rtype: float
- """
- return self.__secondValue
-
- def __clipSecondValue(self, value):
- """Clip second value to range and steps
-
- :param float value:
- """
- value = min(max(self.getFirstValue(), float(value)), self.getMaximum())
- if self.getPositionCount() is not None: # Clip to steps
- value = self.__positionToValue(self.__valueToPosition(value))
- return value
-
- def setSecondValue(self, value):
- """Set the value of the second slider
-
- Value is clipped to valid values.
-
- :param float value:
- """
- value = self.__clipSecondValue(value)
- if value != self.getSecondValue():
- self.__secondValue = value
- self.update()
- self.sigValueChanged.emit(*self.getValues())
-
- def getValues(self):
- """Returns value of both sliders at once
-
- :return: (first value, second value)
- :rtype: List[float]
- """
- return self.getFirstValue(), self.getSecondValue()
-
- def setValues(self, first, second):
- """Set values for both sliders at once
-
- First is clipped to the slider range: [minimum, maximum].
- Second is clipped to valid values: [first, maximum]
-
- :param float first:
- :param float second:
- """
- if self.__setValues(first, second):
- self.sigValueChanged.emit(*self.getValues())
-
- def __setValues(self, first, second):
- """Set values for both sliders at once
-
- First is clipped to the slider range: [minimum, maximum].
- Second is clipped to valid values: [first, maximum]
-
- :param float first:
- :param float second:
- :return: True if values has changed, False otherwise
- :rtype: bool
- """
- first = self.__clipFirstValue(first, self.getMaximum())
- second = self.__clipSecondValue(second)
- values = first, second
-
- if self.getValues() != values:
- self.__firstValue, self.__secondValue = values
- self.update()
- return True
- return False
-
- # Groove API
-
- def getGroovePixmap(self):
- """Returns the pixmap displayed in the slider groove if any.
-
- :rtype: Union[QPixmap,None]
- """
- return self.__pixmap
-
- def setGroovePixmap(self, pixmap):
- """Set the pixmap displayed in the slider groove.
-
- :param Union[QPixmap,None] pixmap: The QPixmap to use or None to unset.
- """
- assert pixmap is None or isinstance(pixmap, qt.QPixmap)
- self.__pixmap = pixmap
- self.update()
-
- def setGroovePixmapFromProfile(self, profile, colormap=None):
- """Set the pixmap displayed in the slider groove from histogram values.
-
- :param Union[numpy.ndarray,None] profile:
- 1D array of values to display
- :param Union[~silx.gui.colors.Colormap,str] colormap:
- The colormap name or object to convert profile values to colors
- """
- if profile is None:
- self.setSliderPixmap(None)
- return
-
- profile = numpy.array(profile, copy=False)
-
- if profile.size == 0:
- self.setSliderPixmap(None)
- return
-
- if colormap is None:
- colormap = colors.Colormap()
- elif isinstance(colormap, str):
- colormap = colors.Colormap(name=colormap)
- assert isinstance(colormap, colors.Colormap)
-
- rgbImage = colormap.applyToData(profile.reshape(1, -1))[:, :, :3]
- qimage = convertArrayToQImage(rgbImage)
- qpixmap = qt.QPixmap.fromImage(qimage)
- self.setGroovePixmap(qpixmap)
-
- # Handle interaction
-
- def mousePressEvent(self, event):
- super(RangeSlider, self).mousePressEvent(event)
-
- if event.buttons() == qt.Qt.LeftButton:
- picked = None
- for name in ('first', 'second'):
- area = self.__sliderRect(name)
- if area.contains(event.pos()):
- picked = name
- break
-
- self.__moving = picked
- self.__focus = picked
- self.update()
-
- def mouseMoveEvent(self, event):
- super(RangeSlider, self).mouseMoveEvent(event)
-
- if self.__moving is not None:
- delta = self._SLIDER_WIDTH // 2
- if self.__moving == 'first':
- position = self.__xPixelToPosition(event.pos().x() + delta)
- self.setFirstPosition(position)
- else:
- position = self.__xPixelToPosition(event.pos().x() - delta)
- self.setSecondPosition(position)
-
- def mouseReleaseEvent(self, event):
- super(RangeSlider, self).mouseReleaseEvent(event)
-
- if event.button() == qt.Qt.LeftButton and self.__moving is not None:
- self.__moving = None
- self.update()
-
- def focusOutEvent(self, event):
- if self.__focus is not None:
- self.__focus = None
- self.update()
- super(RangeSlider, self).focusOutEvent(event)
-
- def keyPressEvent(self, event):
- key = event.key()
- if event.modifiers() == qt.Qt.NoModifier and self.__focus is not None:
- if key in (qt.Qt.Key_Left, qt.Qt.Key_Down):
- if self.__focus == 'first':
- self.setFirstPosition(self.getFirstPosition() - 1)
- else:
- self.setSecondPosition(self.getSecondPosition() - 1)
- return # accept event
- elif key in (qt.Qt.Key_Right, qt.Qt.Key_Up):
- if self.__focus == 'first':
- self.setFirstPosition(self.getFirstPosition() + 1)
- else:
- self.setSecondPosition(self.getSecondPosition() + 1)
- return # accept event
-
- super(RangeSlider, self).keyPressEvent(event)
-
- # Handle resize
-
- def resizeEvent(self, event):
- super(RangeSlider, self).resizeEvent(event)
-
- # If no step, signal position update when width change
- if (self.getPositionCount() is None and
- event.size().width() != event.oldSize().width()):
- self.sigPositionChanged.emit(*self.getPositions())
-
- # Handle repaint
-
- def __xPixelToPosition(self, x):
- """Convert position in pixel to slider position
-
- :param int x: X in pixel coordinates
- :rtype: int
- """
- sliderArea = self.__sliderAreaRect()
- maxPos = self.__getCurrentPositionCount() - 1
- position = maxPos * (x - sliderArea.left()) / (sliderArea.width() - 1)
- return int(position + 0.5)
-
- def __sliderRect(self, name):
- """Returns rectangle corresponding to slider in pixels
-
- :param str name: 'first' or 'second'
- :rtype: QRect
- :raise ValueError: If wrong name
- """
- assert name in ('first', 'second')
- if name == 'first':
- offset = - self._SLIDER_WIDTH
- position = self.getFirstPosition()
- elif name == 'second':
- offset = 0
- position = self.getSecondPosition()
- else:
- raise ValueError('Unknown name')
-
- sliderArea = self.__sliderAreaRect()
-
- maxPos = self.__getCurrentPositionCount() - 1
- xOffset = int((sliderArea.width() - 1) * position / maxPos)
- xPos = sliderArea.left() + xOffset + offset
-
- return qt.QRect(xPos,
- sliderArea.top(),
- self._SLIDER_WIDTH,
- sliderArea.height())
-
- def __drawArea(self):
- return self.rect().adjusted(self._SLIDER_WIDTH, 0,
- -self._SLIDER_WIDTH, 0)
-
- def __sliderAreaRect(self):
- return self.__drawArea().adjusted(self._SLIDER_WIDTH // 2,
- 0,
- -self._SLIDER_WIDTH // 2 + 1,
- 0)
-
- def __pixMapRect(self):
- return self.__sliderAreaRect().adjusted(0,
- self._PIXMAP_VOFFSET,
- -1,
- -self._PIXMAP_VOFFSET)
-
- def paintEvent(self, event):
- painter = qt.QPainter(self)
-
- style = qt.QApplication.style()
-
- area = self.__drawArea()
- if self.__pixmap is not None:
- pixmapRect = self.__pixMapRect()
-
- option = qt.QStyleOptionProgressBar()
- option.initFrom(self)
- option.rect = area
- option.state = (qt.QStyle.State_Enabled if self.isEnabled()
- else qt.QStyle.State_None)
- style.drawControl(qt.QStyle.CE_ProgressBarGroove,
- option,
- painter,
- self)
-
- painter.save()
- pen = painter.pen()
- pen.setWidth(1)
- pen.setColor(qt.Qt.black if self.isEnabled() else qt.Qt.gray)
- painter.setPen(pen)
- painter.drawRect(pixmapRect.adjusted(-1, -1, 0, 1))
- painter.restore()
-
- if self.isEnabled():
- rect = area.adjusted(self._SLIDER_WIDTH // 2,
- self._PIXMAP_VOFFSET,
- -self._SLIDER_WIDTH // 2,
- -self._PIXMAP_VOFFSET + 1)
- painter.drawPixmap(rect,
- self.__pixmap,
- self.__pixmap.rect())
- else:
- option = StyleOptionRangeSlider()
- option.initFrom(self)
- option.rect = area
- option.sliderPosition1 = self.__firstValue
- option.sliderPosition2 = self.__secondValue
- option.handlerRect1 = self.__sliderRect("first")
- option.handlerRect2 = self.__sliderRect("second")
- option.minimum = self.__minValue
- option.maximum = self.__maxValue
- option.state = (qt.QStyle.State_Enabled if self.isEnabled()
- else qt.QStyle.State_None)
- if self.__hoverControl == "groove":
- option.state |= qt.QStyle.State_MouseOver
- elif option.state & qt.QStyle.State_MouseOver:
- option.state ^= qt.QStyle.State_MouseOver
- self.drawRangeSliderBackground(painter, option, self)
-
- # Avoid glitch when moving handles
- hoverControl = self.__moving or self.__hoverControl
-
- for name in ('first', 'second'):
- rect = self.__sliderRect(name)
- option = qt.QStyleOptionButton()
- option.initFrom(self)
- option.icon = self.__icons[name]
- option.iconSize = rect.size() * 0.7
- if hoverControl == name:
- option.state |= qt.QStyle.State_MouseOver
- elif option.state & qt.QStyle.State_MouseOver:
- option.state ^= qt.QStyle.State_MouseOver
- if self.__focus == name:
- option.state |= qt.QStyle.State_HasFocus
- elif option.state & qt.QStyle.State_HasFocus:
- option.state ^= qt.QStyle.State_HasFocus
- option.rect = rect
- style.drawControl(
- qt.QStyle.CE_PushButton, option, painter, self)
-
- def sizeHint(self):
- return qt.QSize(200, self.minimumHeight())
-
- @classmethod
- def drawRangeSliderBackground(cls, painter, option, widget):
- """Draw the background of the RangeSlider widget into the painter.
-
- :param qt.QPainter painter: A painter
- :param StyleOptionRangeSlider option: Options to draw the widget
- :param qt.QWidget: The widget which have to be drawn
- """
- painter.save()
- painter.translate(0.5, 0.5)
-
- backgroundRect = qt.QRect(option.rect)
- if backgroundRect.height() > 8:
- center = backgroundRect.center()
- backgroundRect.setHeight(8)
- backgroundRect.moveCenter(center)
-
- selectedRangeRect = qt.QRect(backgroundRect)
- selectedRangeRect.setLeft(option.handlerRect1.center().x())
- selectedRangeRect.setRight(option.handlerRect2.center().x())
-
- highlight = option.palette.color(qt.QPalette.Highlight)
- activeHighlight = highlight
- selectedOutline = option.palette.color(qt.QPalette.Highlight)
-
- buttonColor = option.palette.button().color()
- val = qt.qGray(buttonColor.rgb())
- buttonColor = buttonColor.lighter(100 + max(1, (180 - val) // 6))
- buttonColor.setHsv(buttonColor.hue(), (buttonColor.saturation() * 3) // 4, buttonColor.value())
-
- grooveColor = qt.QColor()
- grooveColor.setHsv(buttonColor.hue(),
- min(255, (int)(buttonColor.saturation())),
- min(255, (int)(buttonColor.value() * 0.9)))
-
- selectedInnerContrastLine = qt.QColor(255, 255, 255, 30)
-
- outline = option.palette.color(qt.QPalette.Background).darker(140)
- if (option.state & qt.QStyle.State_HasFocus and option.state & qt.QStyle.State_KeyboardFocusChange):
- outline = highlight.darker(125)
- if outline.value() > 160:
- outline.setHsl(highlight.hue(), highlight.saturation(), 160)
-
- # Draw background groove
- painter.setRenderHint(qt.QPainter.Antialiasing, True)
- gradient = qt.QLinearGradient()
- gradient.setStart(backgroundRect.center().x(), backgroundRect.top())
- gradient.setFinalStop(backgroundRect.center().x(), backgroundRect.bottom())
- painter.setPen(qt.QPen(outline))
- gradient.setColorAt(0, grooveColor.darker(110))
- gradient.setColorAt(1, grooveColor.lighter(110))
- painter.setBrush(gradient)
- painter.drawRoundedRect(backgroundRect.adjusted(1, 1, -2, -2), 1, 1)
-
- # Draw slider background for the value
- gradient = qt.QLinearGradient()
- gradient.setStart(selectedRangeRect.center().x(), selectedRangeRect.top())
- gradient.setFinalStop(selectedRangeRect.center().x(), selectedRangeRect.bottom())
- painter.setRenderHint(qt.QPainter.Antialiasing, True)
- painter.setPen(qt.QPen(selectedOutline))
- gradient.setColorAt(0, activeHighlight)
- gradient.setColorAt(1, activeHighlight.lighter(130))
- painter.setBrush(gradient)
- painter.drawRoundedRect(selectedRangeRect.adjusted(1, 1, -2, -2), 1, 1)
- painter.setPen(selectedInnerContrastLine)
- painter.setBrush(qt.Qt.NoBrush)
- painter.drawRoundedRect(selectedRangeRect.adjusted(2, 2, -3, -3), 1, 1)
-
- painter.restore()
diff --git a/silx/gui/widgets/TableWidget.py b/silx/gui/widgets/TableWidget.py
deleted file mode 100644
index 8167fec..0000000
--- a/silx/gui/widgets/TableWidget.py
+++ /dev/null
@@ -1,626 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module provides table widgets handling cut, copy and paste for
-multiple cell selections. These actions can be triggered using keyboard
-shortcuts or through a context menu (right-click).
-
-:class:`TableView` is a subclass of :class:`QTableView`. The added features
-are made available to users after a model is added to the widget, using
-:meth:`TableView.setModel`.
-
-:class:`TableWidget` is a subclass of :class:`qt.QTableWidget`, a table view
-with a built-in standard data model. The added features are available as soon as
-the widget is initialized.
-
-The cut, copy and paste actions are implemented as QActions:
-
- - :class:`CopySelectedCellsAction` (*Ctrl+C*)
- - :class:`CopyAllCellsAction`
- - :class:`CutSelectedCellsAction` (*Ctrl+X*)
- - :class:`CutAllCellsAction`
- - :class:`PasteCellsAction` (*Ctrl+V*)
-
-The copy actions are enabled by default. The cut and paste actions must be
-explicitly enabled, by passing parameters ``cut=True, paste=True`` when
-creating the widgets, or later by calling their :meth:`enableCut` and
-:meth:`enablePaste` methods.
-"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "03/07/2017"
-
-
-import sys
-from .. import qt
-
-
-if sys.platform.startswith("win"):
- row_separator = "\r\n"
-else:
- row_separator = "\n"
-
-col_separator = "\t"
-
-
-class CopySelectedCellsAction(qt.QAction):
- """QAction to copy text from selected cells in a :class:`QTableWidget`
- into the clipboard.
-
- If multiple cells are selected, the copied text will be a concatenation
- of the texts in all selected cells, tabulated with tabulation and
- newline characters.
-
- If the cells are sparsely selected, the structure is preserved by
- representing the unselected cells as empty strings in between two
- tabulation characters.
- Beware of pasting this data in another table widget, because depending
- on how the paste is implemented, the empty cells may cause data in the
- target table to be deleted, even though you didn't necessarily select the
- corresponding cell in the origin table.
-
- :param table: :class:`QTableView` to which this action belongs.
- """
- def __init__(self, table):
- if not isinstance(table, qt.QTableView):
- raise ValueError('CopySelectedCellsAction must be initialised ' +
- 'with a QTableWidget.')
- super(CopySelectedCellsAction, self).__init__(table)
- self.setText("Copy selection")
- self.setToolTip("Copy selected cells into the clipboard.")
- self.setShortcut(qt.QKeySequence.Copy)
- self.setShortcutContext(qt.Qt.WidgetShortcut)
- self.triggered.connect(self.copyCellsToClipboard)
- self.table = table
- self.cut = False
- """:attr:`cut` can be set to True by classes inheriting this action,
- to do a cut action."""
-
- def copyCellsToClipboard(self):
- """Concatenate the text content of all selected cells into a string
- using tabulations and newlines to keep the table structure.
- Put this text into the clipboard.
- """
- selected_idx = self.table.selectedIndexes()
- if not selected_idx:
- return
- selected_idx_tuples = [(idx.row(), idx.column()) for idx in selected_idx]
-
- selected_rows = [idx[0] for idx in selected_idx_tuples]
- selected_columns = [idx[1] for idx in selected_idx_tuples]
-
- data_model = self.table.model()
-
- copied_text = ""
- for row in range(min(selected_rows), max(selected_rows) + 1):
- for col in range(min(selected_columns), max(selected_columns) + 1):
- index = data_model.index(row, col)
- cell_text = data_model.data(index)
- flags = data_model.flags(index)
-
- if (row, col) in selected_idx_tuples and cell_text is not None:
- copied_text += cell_text
- if self.cut and (flags & qt.Qt.ItemIsEditable):
- data_model.setData(index, "")
- copied_text += col_separator
- # remove the right-most tabulation
- copied_text = copied_text[:-len(col_separator)]
- # add a newline
- copied_text += row_separator
- # remove final newline
- copied_text = copied_text[:-len(row_separator)]
-
- # put this text into clipboard
- qapp = qt.QApplication.instance()
- qapp.clipboard().setText(copied_text)
-
-
-class CopyAllCellsAction(qt.QAction):
- """QAction to copy text from all cells in a :class:`QTableWidget`
- into the clipboard.
-
- The copied text will be a concatenation
- of the texts in all cells, tabulated with tabulation and
- newline characters.
-
- :param table: :class:`QTableView` to which this action belongs.
- """
- def __init__(self, table):
- if not isinstance(table, qt.QTableView):
- raise ValueError('CopyAllCellsAction must be initialised ' +
- 'with a QTableWidget.')
- super(CopyAllCellsAction, self).__init__(table)
- self.setText("Copy all")
- self.setToolTip("Copy all cells into the clipboard.")
- self.triggered.connect(self.copyCellsToClipboard)
- self.table = table
- self.cut = False
-
- def copyCellsToClipboard(self):
- """Concatenate the text content of all cells into a string
- using tabulations and newlines to keep the table structure.
- Put this text into the clipboard.
- """
- data_model = self.table.model()
- copied_text = ""
- for row in range(data_model.rowCount()):
- for col in range(data_model.columnCount()):
- index = data_model.index(row, col)
- cell_text = data_model.data(index)
- flags = data_model.flags(index)
- if cell_text is not None:
- copied_text += cell_text
- if self.cut and (flags & qt.Qt.ItemIsEditable):
- data_model.setData(index, "")
- copied_text += col_separator
- # remove the right-most tabulation
- copied_text = copied_text[:-len(col_separator)]
- # add a newline
- copied_text += row_separator
- # remove final newline
- copied_text = copied_text[:-len(row_separator)]
-
- # put this text into clipboard
- qapp = qt.QApplication.instance()
- qapp.clipboard().setText(copied_text)
-
-
-class CutSelectedCellsAction(CopySelectedCellsAction):
- """QAction to cut text from selected cells in a :class:`QTableWidget`
- into the clipboard.
-
- The text is deleted from the original table widget
- (use :class:`CopySelectedCellsAction` to preserve the original data).
-
- If multiple cells are selected, the cut text will be a concatenation
- of the texts in all selected cells, tabulated with tabulation and
- newline characters.
-
- If the cells are sparsely selected, the structure is preserved by
- representing the unselected cells as empty strings in between two
- tabulation characters.
- Beware of pasting this data in another table widget, because depending
- on how the paste is implemented, the empty cells may cause data in the
- target table to be deleted, even though you didn't necessarily select the
- corresponding cell in the origin table.
-
- :param table: :class:`QTableView` to which this action belongs."""
- def __init__(self, table):
- super(CutSelectedCellsAction, self).__init__(table)
- self.setText("Cut selection")
- self.setShortcut(qt.QKeySequence.Cut)
- self.setShortcutContext(qt.Qt.WidgetShortcut)
- # cutting is already implemented in CopySelectedCellsAction (but
- # it is disabled), we just need to enable it
- self.cut = True
-
-
-class CutAllCellsAction(CopyAllCellsAction):
- """QAction to cut text from all cells in a :class:`QTableWidget`
- into the clipboard.
-
- The text is deleted from the original table widget
- (use :class:`CopyAllCellsAction` to preserve the original data).
-
- The cut text will be a concatenation
- of the texts in all cells, tabulated with tabulation and
- newline characters.
-
- :param table: :class:`QTableView` to which this action belongs."""
- def __init__(self, table):
- super(CutAllCellsAction, self).__init__(table)
- self.setText("Cut all")
- self.setToolTip("Cut all cells into the clipboard.")
- self.cut = True
-
-
-def _parseTextAsTable(text, row_separator=row_separator, col_separator=col_separator):
- """Parse text into list of lists (2D sequence).
-
- The input text must be tabulated using tabulation characters and
- newlines to separate columns and rows.
-
- :param text: text to be parsed
- :param record_separator: String, or single character, to be interpreted
- as a record/row separator.
- :param field_separator: String, or single character, to be interpreted
- as a field/column separator.
- :return: 2D sequence of strings
- """
- rows = text.split(row_separator)
- table_data = [row.split(col_separator) for row in rows]
- return table_data
-
-
-class PasteCellsAction(qt.QAction):
- """QAction to paste text from the clipboard into the table.
-
- If the text contains tabulations and
- newlines, they are interpreted as column and row separators.
- In such a case, the text is split into multiple texts to be pasted
- into multiple cells.
-
- If a cell content is an empty string in the original text, it is
- ignored: the destination cell's text will not be deleted.
-
- :param table: :class:`QTableView` to which this action belongs.
- """
- def __init__(self, table):
- if not isinstance(table, qt.QTableView):
- raise ValueError('PasteCellsAction must be initialised ' +
- 'with a QTableWidget.')
- super(PasteCellsAction, self).__init__(table)
- self.table = table
- self.setText("Paste")
- self.setShortcut(qt.QKeySequence.Paste)
- self.setShortcutContext(qt.Qt.WidgetShortcut)
- self.setToolTip("Paste data. The selected cell is the top-left" +
- "corner of the paste area.")
- self.triggered.connect(self.pasteCellFromClipboard)
-
- def pasteCellFromClipboard(self):
- """Paste text from clipboard into the table.
-
- :return: *True* in case of success, *False* if pasting data failed.
- """
- selected_idx = self.table.selectedIndexes()
- if len(selected_idx) != 1:
- msgBox = qt.QMessageBox(parent=self.table)
- msgBox.setText("A single cell must be selected to paste data")
- msgBox.exec_()
- return False
-
- data_model = self.table.model()
-
- selected_row = selected_idx[0].row()
- selected_col = selected_idx[0].column()
-
- qapp = qt.QApplication.instance()
- clipboard_text = qapp.clipboard().text()
- table_data = _parseTextAsTable(clipboard_text)
-
- protected_cells = 0
- out_of_range_cells = 0
-
- # paste table data into cells, using selected cell as origin
- for row_offset in range(len(table_data)):
- for col_offset in range(len(table_data[row_offset])):
- target_row = selected_row + row_offset
- target_col = selected_col + col_offset
-
- if target_row >= data_model.rowCount() or\
- target_col >= data_model.columnCount():
- out_of_range_cells += 1
- continue
-
- index = data_model.index(target_row, target_col)
- flags = data_model.flags(index)
-
- # ignore empty strings
- if table_data[row_offset][col_offset] != "":
- if not flags & qt.Qt.ItemIsEditable:
- protected_cells += 1
- continue
- data_model.setData(index, table_data[row_offset][col_offset])
- # item.setText(table_data[row_offset][col_offset])
-
- if protected_cells or out_of_range_cells:
- msgBox = qt.QMessageBox(parent=self.table)
- msg = "Some data could not be inserted, "
- msg += "due to out-of-range or write-protected cells."
- msgBox.setText(msg)
- msgBox.exec_()
- return False
- return True
-
-
-class CopySingleCellAction(qt.QAction):
- """QAction to copy text from a single cell in a modified
- :class:`QTableWidget`.
-
- This action relies on the fact that the text in the last clicked cell
- are stored in :attr:`_last_cell_clicked` of the modified widget.
-
- In most cases, :class:`CopySelectedCellsAction` handles single cells,
- but if the selection mode of the widget has been set to NoSelection
- it is necessary to use this class instead.
-
- :param table: :class:`QTableView` to which this action belongs.
- """
- def __init__(self, table):
- if not isinstance(table, qt.QTableView):
- raise ValueError('CopySingleCellAction must be initialised ' +
- 'with a QTableWidget.')
- super(CopySingleCellAction, self).__init__(table)
- self.setText("Copy cell")
- self.setToolTip("Copy cell content into the clipboard.")
- self.triggered.connect(self.copyCellToClipboard)
- self.table = table
-
- def copyCellToClipboard(self):
- """
- """
- cell_text = self.table._text_last_cell_clicked
- if cell_text is None:
- return
-
- # put this text into clipboard
- qapp = qt.QApplication.instance()
- qapp.clipboard().setText(cell_text)
-
-
-class TableWidget(qt.QTableWidget):
- """:class:`QTableWidget` with a context menu displaying up to 5 actions:
-
- - :class:`CopySelectedCellsAction`
- - :class:`CopyAllCellsAction`
- - :class:`CutSelectedCellsAction`
- - :class:`CutAllCellsAction`
- - :class:`PasteCellsAction`
-
- These actions interact with the clipboard and can be used to copy data
- to or from an external application, or another widget.
-
- The cut and paste actions are disabled by default, due to the risk of
- overwriting data (no *Undo* action is available). Use :meth:`enablePaste`
- and :meth:`enableCut` to activate them.
-
- .. image:: img/TableWidget.png
-
- :param parent: Parent QWidget
- :param bool cut: Enable cut action
- :param bool paste: Enable paste action
- """
- def __init__(self, parent=None, cut=False, paste=False):
- super(TableWidget, self).__init__(parent)
- self._text_last_cell_clicked = None
-
- self.copySelectedCellsAction = CopySelectedCellsAction(self)
- self.copyAllCellsAction = CopyAllCellsAction(self)
- self.copySingleCellAction = None
- self.pasteCellsAction = None
- self.cutSelectedCellsAction = None
- self.cutAllCellsAction = None
-
- self.addAction(self.copySelectedCellsAction)
- self.addAction(self.copyAllCellsAction)
- if cut:
- self.enableCut()
- if paste:
- self.enablePaste()
-
- self.setContextMenuPolicy(qt.Qt.ActionsContextMenu)
-
- def mousePressEvent(self, event):
- item = self.itemAt(event.pos())
- if item is not None:
- self._text_last_cell_clicked = item.text()
- super(TableWidget, self).mousePressEvent(event)
-
- def enablePaste(self):
- """Enable paste action, to paste data from the clipboard into the
- table.
-
- .. warning::
-
- This action can cause data to be overwritten.
- There is currently no *Undo* action to retrieve lost data.
- """
- self.pasteCellsAction = PasteCellsAction(self)
- self.addAction(self.pasteCellsAction)
-
- def enableCut(self):
- """Enable cut action.
-
- .. warning::
-
- This action can cause data to be deleted.
- There is currently no *Undo* action to retrieve lost data."""
- self.cutSelectedCellsAction = CutSelectedCellsAction(self)
- self.cutAllCellsAction = CutAllCellsAction(self)
- self.addAction(self.cutSelectedCellsAction)
- self.addAction(self.cutAllCellsAction)
-
- def setSelectionMode(self, mode):
- """Overloaded from QTableWidget to disable cut/copy selection
- actions in case mode is NoSelection
-
- :param mode:
- :return:
- """
- if mode == qt.QTableView.NoSelection:
- self.copySelectedCellsAction.setVisible(False)
- self.copySelectedCellsAction.setEnabled(False)
- if self.cutSelectedCellsAction is not None:
- self.cutSelectedCellsAction.setVisible(False)
- self.cutSelectedCellsAction.setEnabled(False)
- if self.copySingleCellAction is None:
- self.copySingleCellAction = CopySingleCellAction(self)
- self.insertAction(self.copySelectedCellsAction, # before first action
- self.copySingleCellAction)
- self.copySingleCellAction.setVisible(True)
- self.copySingleCellAction.setEnabled(True)
- else:
- self.copySelectedCellsAction.setVisible(True)
- self.copySelectedCellsAction.setEnabled(True)
- if self.cutSelectedCellsAction is not None:
- self.cutSelectedCellsAction.setVisible(True)
- self.cutSelectedCellsAction.setEnabled(True)
- if self.copySingleCellAction is not None:
- self.copySingleCellAction.setVisible(False)
- self.copySingleCellAction.setEnabled(False)
- super(TableWidget, self).setSelectionMode(mode)
-
-
-class TableView(qt.QTableView):
- """:class:`QTableView` with a context menu displaying up to 5 actions:
-
- - :class:`CopySelectedCellsAction`
- - :class:`CopyAllCellsAction`
- - :class:`CutSelectedCellsAction`
- - :class:`CutAllCellsAction`
- - :class:`PasteCellsAction`
-
- These actions interact with the clipboard and can be used to copy data
- to or from an external application, or another widget.
-
- The cut and paste actions are disabled by default, due to the risk of
- overwriting data (no *Undo* action is available). Use :meth:`enablePaste`
- and :meth:`enableCut` to activate them.
-
- .. note::
-
- These actions will be available only after a model is associated
- with this view, using :meth:`setModel`.
-
- :param parent: Parent QWidget
- :param bool cut: Enable cut action
- :param bool paste: Enable paste action
- """
- def __init__(self, parent=None, cut=False, paste=False):
- super(TableView, self).__init__(parent)
- self._text_last_cell_clicked = None
-
- self.cut = cut
- self.paste = paste
-
- self.copySelectedCellsAction = None
- self.copyAllCellsAction = None
- self.copySingleCellAction = None
- self.pasteCellsAction = None
- self.cutSelectedCellsAction = None
- self.cutAllCellsAction = None
-
- def mousePressEvent(self, event):
- qindex = self.indexAt(event.pos())
- if self.copyAllCellsAction is not None: # model was set
- self._text_last_cell_clicked = self.model().data(qindex)
- super(TableView, self).mousePressEvent(event)
-
- def setModel(self, model):
- """Set the data model for the table view, activate the actions
- and the context menu.
-
- :param model: :class:`qt.QAbstractItemModel` object
- """
- super(TableView, self).setModel(model)
-
- self.copySelectedCellsAction = CopySelectedCellsAction(self)
- self.copyAllCellsAction = CopyAllCellsAction(self)
- self.addAction(self.copySelectedCellsAction)
- self.addAction(self.copyAllCellsAction)
- if self.cut:
- self.enableCut()
- if self.paste:
- self.enablePaste()
-
- self.setContextMenuPolicy(qt.Qt.ActionsContextMenu)
-
- def enablePaste(self):
- """Enable paste action, to paste data from the clipboard into the
- table.
-
- .. warning::
-
- This action can cause data to be overwritten.
- There is currently no *Undo* action to retrieve lost data.
- """
- self.pasteCellsAction = PasteCellsAction(self)
- self.addAction(self.pasteCellsAction)
-
- def enableCut(self):
- """Enable cut action.
-
- .. warning::
-
- This action can cause data to be deleted.
- There is currently no *Undo* action to retrieve lost data.
- """
- self.cutSelectedCellsAction = CutSelectedCellsAction(self)
- self.cutAllCellsAction = CutAllCellsAction(self)
- self.addAction(self.cutSelectedCellsAction)
- self.addAction(self.cutAllCellsAction)
-
- def addAction(self, action):
- # ensure the actions are not added multiple times:
- # compare action type and parent widget with those of existing actions
- for existing_action in self.actions():
- if type(action) == type(existing_action):
- if hasattr(action, "table") and\
- action.table is existing_action.table:
- return None
- super(TableView, self).addAction(action)
-
- def setSelectionMode(self, mode):
- """Overloaded from QTableView to disable cut/copy selection
- actions in case mode is NoSelection
-
- :param mode:
- :return:
- """
- if mode == qt.QTableView.NoSelection:
- self.copySelectedCellsAction.setVisible(False)
- self.copySelectedCellsAction.setEnabled(False)
- if self.cutSelectedCellsAction is not None:
- self.cutSelectedCellsAction.setVisible(False)
- self.cutSelectedCellsAction.setEnabled(False)
- if self.copySingleCellAction is None:
- self.copySingleCellAction = CopySingleCellAction(self)
- self.insertAction(self.copySelectedCellsAction, # before first action
- self.copySingleCellAction)
- self.copySingleCellAction.setVisible(True)
- self.copySingleCellAction.setEnabled(True)
- else:
- self.copySelectedCellsAction.setVisible(True)
- self.copySelectedCellsAction.setEnabled(True)
- if self.cutSelectedCellsAction is not None:
- self.cutSelectedCellsAction.setVisible(True)
- self.cutSelectedCellsAction.setEnabled(True)
- if self.copySingleCellAction is not None:
- self.copySingleCellAction.setVisible(False)
- self.copySingleCellAction.setEnabled(False)
- super(TableView, self).setSelectionMode(mode)
-
-
-if __name__ == "__main__":
- app = qt.QApplication([])
-
- tablewidget = TableWidget()
- tablewidget.setWindowTitle("TableWidget")
- tablewidget.setColumnCount(10)
- tablewidget.setRowCount(7)
- tablewidget.enableCut()
- tablewidget.enablePaste()
- tablewidget.show()
-
- tableview = TableView(cut=True, paste=True)
- tableview.setWindowTitle("TableView")
- model = qt.QStandardItemModel()
- model.setColumnCount(10)
- model.setRowCount(7)
- tableview.setModel(model)
- tableview.show()
-
- app.exec_()
diff --git a/silx/gui/widgets/UrlSelectionTable.py b/silx/gui/widgets/UrlSelectionTable.py
deleted file mode 100644
index fb15edd..0000000
--- a/silx/gui/widgets/UrlSelectionTable.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# /*##########################################################################
-# Copyright (C) 2017-2021 European Synchrotron Radiation Facility
-#
-# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
-# the ESRF by the Software group.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-#############################################################################*/
-"""Some widget construction to check if a sample moved"""
-
-__author__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "19/03/2018"
-
-from silx.gui import qt
-from collections import OrderedDict
-from silx.gui.widgets.TableWidget import TableWidget
-from silx.io.url import DataUrl
-import functools
-import logging
-import os
-
-logger = logging.getLogger(__name__)
-
-
-class UrlSelectionTable(TableWidget):
- """Table used to select the color channel to be displayed for each"""
-
- COLUMS_INDEX = OrderedDict([
- ('url', 0),
- ('img A', 1),
- ('img B', 2),
- ])
-
- sigImageAChanged = qt.Signal(str)
- """Signal emitted when the image A change. Param is the image url path"""
-
- sigImageBChanged = qt.Signal(str)
- """Signal emitted when the image B change. Param is the image url path"""
-
- def __init__(self, parent=None):
- TableWidget.__init__(self, parent)
- self.clear()
-
- def clear(self):
- qt.QTableWidget.clear(self)
- self.setRowCount(0)
- self.setColumnCount(len(self.COLUMS_INDEX))
- self.setHorizontalHeaderLabels(list(self.COLUMS_INDEX.keys()))
- self.verticalHeader().hide()
- if hasattr(self.horizontalHeader(), 'setSectionResizeMode'): # Qt5
- self.horizontalHeader().setSectionResizeMode(0,
- qt.QHeaderView.Stretch)
- else: # Qt4
- self.horizontalHeader().setResizeMode(0, qt.QHeaderView.Stretch)
-
- self.setSortingEnabled(True)
- self._checkBoxes = {}
-
- def setUrls(self, urls: list) -> None:
- """
-
- :param urls: urls to be displayed
- """
- for url in urls:
- self.addUrl(url=url)
-
- def addUrl(self, url, **kwargs):
- """
-
- :param url:
- :param args:
- :return: index of the created items row
- :rtype int
- """
- assert isinstance(url, DataUrl)
- row = self.rowCount()
- self.setRowCount(row + 1)
-
- _item = qt.QTableWidgetItem()
- _item.setText(os.path.basename(url.path()))
- _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
- self.setItem(row, self.COLUMS_INDEX['url'], _item)
-
- widgetImgA = qt.QRadioButton(parent=self)
- widgetImgA.setAutoExclusive(False)
- self.setCellWidget(row, self.COLUMS_INDEX['img A'], widgetImgA)
- callbackImgA = functools.partial(self._activeImgAChanged, url.path())
- widgetImgA.toggled.connect(callbackImgA)
-
- widgetImgB = qt.QRadioButton(parent=self)
- widgetImgA.setAutoExclusive(False)
- self.setCellWidget(row, self.COLUMS_INDEX['img B'], widgetImgB)
- callbackImgB = functools.partial(self._activeImgBChanged, url.path())
- widgetImgB.toggled.connect(callbackImgB)
-
- self._checkBoxes[url.path()] = {'img A': widgetImgA,
- 'img B': widgetImgB}
- self.resizeColumnsToContents()
- return row
-
- def _activeImgAChanged(self, name):
- self._updatecheckBoxes('img A', name)
- self.sigImageAChanged.emit(name)
-
- def _activeImgBChanged(self, name):
- self._updatecheckBoxes('img B', name)
- self.sigImageBChanged.emit(name)
-
- def _updatecheckBoxes(self, whichImg, name):
- assert name in self._checkBoxes
- assert whichImg in self._checkBoxes[name]
- if self._checkBoxes[name][whichImg].isChecked():
- for radioUrl in self._checkBoxes:
- if radioUrl != name:
- self._checkBoxes[radioUrl][whichImg].blockSignals(True)
- self._checkBoxes[radioUrl][whichImg].setChecked(False)
- self._checkBoxes[radioUrl][whichImg].blockSignals(False)
-
- def getSelection(self):
- """
-
- :return: url selected for img A and img B.
- """
- imgA = imgB = None
- for radioUrl in self._checkBoxes:
- if self._checkBoxes[radioUrl]['img A'].isChecked():
- imgA = radioUrl
- if self._checkBoxes[radioUrl]['img B'].isChecked():
- imgB = radioUrl
- return imgA, imgB
-
- def setSelection(self, url_img_a, url_img_b):
- """
-
- :param ddict: key: image url, values: list of active channels
- """
- for radioUrl in self._checkBoxes:
- for img in ('img A', 'img B'):
- self._checkBoxes[radioUrl][img].blockSignals(True)
- self._checkBoxes[radioUrl][img].setChecked(False)
- self._checkBoxes[radioUrl][img].blockSignals(False)
-
- self._checkBoxes[radioUrl][img].blockSignals(True)
- self._checkBoxes[url_img_a]['img A'].setChecked(True)
- self._checkBoxes[radioUrl][img].blockSignals(False)
-
- self._checkBoxes[radioUrl][img].blockSignals(True)
- self._checkBoxes[url_img_b]['img B'].setChecked(True)
- self._checkBoxes[radioUrl][img].blockSignals(False)
- self.sigImageAChanged.emit(url_img_a)
- self.sigImageBChanged.emit(url_img_b)
-
- def removeUrl(self, url):
- raise NotImplementedError("")
diff --git a/silx/gui/widgets/WaitingPushButton.py b/silx/gui/widgets/WaitingPushButton.py
deleted file mode 100644
index 499de1a..0000000
--- a/silx/gui/widgets/WaitingPushButton.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2004-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""WaitingPushButton module
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "26/04/2017"
-
-from .. import qt
-from .. import icons
-
-
-class WaitingPushButton(qt.QPushButton):
- """Button which allows to display a waiting status when, for example,
- something is still computing.
-
- The component is graphically disabled when it is in waiting. Then we
- overwrite the enabled method to dissociate the 2 concepts:
- graphically enabled/disabled, and enabled/disabled
-
- .. image:: img/WaitingPushButton.png
- """
-
- def __init__(self, parent=None, text=None, icon=None):
- """Constructor
-
- :param str text: Text displayed on the button
- :param qt.QIcon icon: Icon displayed on the button
- :param qt.QWidget parent: Parent of the widget
- """
- if icon is not None:
- qt.QPushButton.__init__(self, icon, text, parent)
- elif text is not None:
- qt.QPushButton.__init__(self, text, parent)
- else:
- qt.QPushButton.__init__(self, parent)
-
- self.__waiting = False
- self.__enabled = True
- self.__icon = icon
- self.__disabled_when_waiting = True
- self.__waitingIcon = icons.getWaitIcon()
-
- def sizeHint(self):
- """Returns the recommended size for the widget.
-
- This implementation of the recommended size always consider there is an
- icon. In this way it avoid to update the layout when the waiting icon
- is displayed.
- """
- self.ensurePolished()
-
- w = 0
- h = 0
-
- opt = qt.QStyleOptionButton()
- self.initStyleOption(opt)
-
- # Content with icon
- # no condition, assume that there is an icon to avoid blinking
- # when the widget switch to waiting state
- ih = opt.iconSize.height()
- iw = opt.iconSize.width() + 4
- w += iw
- h = max(h, ih)
-
- # Content with text
- text = self.text()
- isEmpty = text == ""
- if isEmpty:
- text = "XXXX"
- fm = self.fontMetrics()
- textSize = fm.size(qt.Qt.TextShowMnemonic, text)
- if not isEmpty or w == 0:
- w += textSize.width()
- if not isEmpty or h == 0:
- h = max(h, textSize.height())
-
- # Content with menu indicator
- opt.rect.setSize(qt.QSize(w, h)) # PM_MenuButtonIndicator depends on the height
- if self.menu() is not None:
- w += self.style().pixelMetric(qt.QStyle.PM_MenuButtonIndicator, opt, self)
-
- contentSize = qt.QSize(w, h)
- if qt.qVersion().startswith("4.8."):
- # On PyQt4/PySide the method QCommonStyle sizeFromContents returns
- # different size when the widget provides an icon or not.
- # In Qt5 there is not this problem.
- opt.icon = qt.QIcon()
- sizeHint = self.style().sizeFromContents(qt.QStyle.CT_PushButton, opt, contentSize, self)
- sizeHint = sizeHint.expandedTo(qt.QApplication.globalStrut())
- return sizeHint
-
- def setDisabledWhenWaiting(self, isDisabled):
- """Enable or disable the auto disable behaviour when the button is waiting.
-
- :param bool isDisabled: Enable the auto-disable behaviour
- """
- if self.__disabled_when_waiting == isDisabled:
- return
- self.__disabled_when_waiting = isDisabled
- self.__updateVisibleEnabled()
-
- def isDisabledWhenWaiting(self):
- """Returns true if the button is auto disabled when it is waiting.
-
- :rtype: bool
- """
- return self.__disabled_when_waiting
-
- disabledWhenWaiting = qt.Property(bool, isDisabledWhenWaiting, setDisabledWhenWaiting)
- """Property to enable/disable the auto disabled state when the button is waiting."""
-
- def __setWaitingIcon(self, icon):
- """Called when the waiting icon is updated. It is called every frames
- of the animation.
-
- :param qt.QIcon icon: The new waiting icon
- """
- qt.QPushButton.setIcon(self, icon)
-
- def setIcon(self, icon):
- """Set the button icon. If the button is waiting, the icon is not
- visible directly, but will be visible when the waiting state will be
- removed.
-
- :param qt.QIcon icon: An icon
- """
- self.__icon = icon
- self.__updateVisibleIcon()
-
- def getIcon(self):
- """Returns the icon set to the button. If the widget is waiting
- it is not returning the visible icon, but the one requested by
- the application (the one displayed when the widget is not in
- waiting state).
-
- :rtype: qt.QIcon
- """
- return self.__icon
-
- icon = qt.Property(qt.QIcon, getIcon, setIcon)
- """Property providing access to the icon."""
-
- def __updateVisibleIcon(self):
- """Update the visible icon according to the state of the widget."""
- if not self.isWaiting():
- icon = self.__icon
- else:
- icon = self.__waitingIcon.currentIcon()
- if icon is None:
- icon = qt.QIcon()
- qt.QPushButton.setIcon(self, icon)
-
- def setEnabled(self, enabled):
- """Set the enabled state of the widget.
-
- :param bool enabled: The enabled state
- """
- if self.__enabled == enabled:
- return
- self.__enabled = enabled
- self.__updateVisibleEnabled()
-
- def isEnabled(self):
- """Returns the enabled state of the widget.
-
- :rtype: bool
- """
- return self.__enabled
-
- enabled = qt.Property(bool, isEnabled, setEnabled)
- """Property providing access to the enabled state of the widget"""
-
- def __updateVisibleEnabled(self):
- """Update the visible enabled state according to the state of the
- widget."""
- if self.__disabled_when_waiting:
- enabled = not self.isWaiting() and self.__enabled
- else:
- enabled = self.__enabled
- qt.QPushButton.setEnabled(self, enabled)
-
- def setWaiting(self, waiting):
- """Set the waiting state of the widget.
-
- :param bool waiting: Requested state"""
- if self.__waiting == waiting:
- return
- self.__waiting = waiting
-
- if self.__waiting:
- self.__waitingIcon.register(self)
- self.__waitingIcon.iconChanged.connect(self.__setWaitingIcon)
- else:
- # unregister only if the object is registred
- self.__waitingIcon.unregister(self)
- self.__waitingIcon.iconChanged.disconnect(self.__setWaitingIcon)
-
- self.__updateVisibleEnabled()
- self.__updateVisibleIcon()
-
- def isWaiting(self):
- """Returns true if the widget is in waiting state.
-
- :rtype: bool"""
- return self.__waiting
-
- @qt.Slot()
- def wait(self):
- """Enable the waiting state."""
- self.setWaiting(True)
-
- @qt.Slot()
- def stopWaiting(self):
- """Disable the waiting state."""
- self.setWaiting(False)
-
- @qt.Slot()
- def swapWaiting(self):
- """Swap the waiting state."""
- self.setWaiting(not self.isWaiting())
diff --git a/silx/gui/widgets/test/__init__.py b/silx/gui/widgets/test/__init__.py
deleted file mode 100644
index 9aaec76..0000000
--- a/silx/gui/widgets/test/__init__.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-import unittest
-
-from . import test_periodictable
-from . import test_tablewidget
-from . import test_threadpoolpushbutton
-from . import test_hierarchicaltableview
-from . import test_printpreview
-from . import test_framebrowser
-from . import test_boxlayoutdockwidget
-from . import test_rangeslider
-from . import test_flowlayout
-from . import test_elidedlabel
-from . import test_legendiconwidget
-
-__authors__ = ["V. Valls", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "19/07/2017"
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTests(
- [test_threadpoolpushbutton.suite(),
- test_tablewidget.suite(),
- test_periodictable.suite(),
- test_printpreview.suite(),
- test_hierarchicaltableview.suite(),
- test_framebrowser.suite(),
- test_boxlayoutdockwidget.suite(),
- test_rangeslider.suite(),
- test_flowlayout.suite(),
- test_elidedlabel.suite(),
- test_legendiconwidget.suite(),
- ])
- return test_suite
diff --git a/silx/gui/widgets/test/test_boxlayoutdockwidget.py b/silx/gui/widgets/test/test_boxlayoutdockwidget.py
deleted file mode 100644
index 9a93ca1..0000000
--- a/silx/gui/widgets/test/test_boxlayoutdockwidget.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for BoxLayoutDockWidget"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "06/03/2018"
-
-import unittest
-
-from silx.gui.widgets.BoxLayoutDockWidget import BoxLayoutDockWidget
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-
-
-class TestBoxLayoutDockWidget(TestCaseQt):
- """Tests for BoxLayoutDockWidget"""
-
- def setUp(self):
- """Create and show a main window"""
- self.window = qt.QMainWindow()
- self.qWaitForWindowExposed(self.window)
-
- def tearDown(self):
- """Delete main window"""
- self.window.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.window.close()
- del self.window
- self.qapp.processEvents()
-
- def test(self):
- """Test update of layout direction according to dock area"""
- # Create a widget with a QBoxLayout
- layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight)
- layout.addWidget(qt.QLabel('First'))
- layout.addWidget(qt.QLabel('Second'))
- widget = qt.QWidget()
- widget.setLayout(layout)
-
- # Add it to a BoxLayoutDockWidget
- dock = BoxLayoutDockWidget()
- dock.setWidget(widget)
-
- self.window.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
- self.qapp.processEvents()
- self.assertEqual(layout.direction(), qt.QBoxLayout.LeftToRight)
-
- self.window.addDockWidget(qt.Qt.LeftDockWidgetArea, dock)
- self.qapp.processEvents()
- self.assertEqual(layout.direction(), qt.QBoxLayout.TopToBottom)
-
-
-def suite():
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loader(TestBoxLayoutDockWidget))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_elidedlabel.py b/silx/gui/widgets/test/test_elidedlabel.py
deleted file mode 100644
index 2856733..0000000
--- a/silx/gui/widgets/test/test_elidedlabel.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for ElidedLabel"""
-
-__license__ = "MIT"
-__date__ = "08/06/2020"
-
-import unittest
-
-from silx.gui import qt
-from silx.gui.widgets.ElidedLabel import ElidedLabel
-from silx.gui.utils import testutils
-
-
-class TestElidedLabel(testutils.TestCaseQt):
-
- def setUp(self):
- self.label = ElidedLabel()
- self.label.show()
- self.qWaitForWindowExposed(self.label)
-
- def tearDown(self):
- self.label.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.label.close()
- del self.label
- self.qapp.processEvents()
-
- def testElidedValue(self):
- """Test elided text"""
- raw = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
- self.label.setText(raw)
- self.label.setFixedWidth(30)
- displayedText = qt.QLabel.text(self.label)
- self.assertNotEqual(raw, displayedText)
- self.assertIn("…", displayedText)
- self.assertIn("m", displayedText)
-
- def testNotElidedValue(self):
- """Test elided text"""
- raw = "mmmmmmm"
- self.label.setText(raw)
- self.label.setFixedWidth(200)
- displayedText = qt.QLabel.text(self.label)
- self.assertNotIn("…", displayedText)
- self.assertEqual(raw, displayedText)
-
- def testUpdateFromElidedToNotElided(self):
- """Test tooltip when not elided"""
- raw1 = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
- raw2 = "nn"
- self.label.setText(raw1)
- self.label.setFixedWidth(30)
- self.label.setText(raw2)
- displayedTooltip = qt.QLabel.toolTip(self.label)
- self.assertNotIn(raw1, displayedTooltip)
- self.assertNotIn(raw2, displayedTooltip)
-
- def testUpdateFromNotElidedToElided(self):
- """Test tooltip when elided"""
- raw1 = "nn"
- raw2 = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
- self.label.setText(raw1)
- self.label.setFixedWidth(30)
- self.label.setText(raw2)
- displayedTooltip = qt.QLabel.toolTip(self.label)
- self.assertNotIn(raw1, displayedTooltip)
- self.assertIn(raw2, displayedTooltip)
-
- def testUpdateFromElidedToElided(self):
- """Test tooltip when elided"""
- raw1 = "nnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn"
- raw2 = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
- self.label.setText(raw1)
- self.label.setFixedWidth(30)
- self.label.setText(raw2)
- displayedTooltip = qt.QLabel.toolTip(self.label)
- self.assertNotIn(raw1, displayedTooltip)
- self.assertIn(raw2, displayedTooltip)
-
-
-def suite():
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loader(TestElidedLabel))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_flowlayout.py b/silx/gui/widgets/test/test_flowlayout.py
deleted file mode 100644
index 1497945..0000000
--- a/silx/gui/widgets/test/test_flowlayout.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for FlowLayout"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "02/08/2018"
-
-import unittest
-
-from silx.gui.widgets.FlowLayout import FlowLayout
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-
-
-class TestFlowLayout(TestCaseQt):
- """Tests for FlowLayout"""
-
- def setUp(self):
- """Create and show a widget"""
- self.widget = qt.QWidget()
- self.widget.show()
- self.qWaitForWindowExposed(self.widget)
-
- def tearDown(self):
- """Delete widget"""
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- del self.widget
- self.qapp.processEvents()
-
- def test(self):
- """Basic tests"""
- layout = FlowLayout()
- self.widget.setLayout(layout)
-
- layout.addWidget(qt.QLabel('first'))
- layout.addWidget(qt.QLabel('second'))
- self.assertEqual(layout.count(), 2)
-
- layout.setHorizontalSpacing(10)
- self.assertEqual(layout.horizontalSpacing(), 10)
- layout.setVerticalSpacing(5)
- self.assertEqual(layout.verticalSpacing(), 5)
-
-
-def suite():
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loader(TestFlowLayout))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_framebrowser.py b/silx/gui/widgets/test/test_framebrowser.py
deleted file mode 100644
index 2dfd302..0000000
--- a/silx/gui/widgets/test/test_framebrowser.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "23/03/2018"
-
-
-import unittest
-
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.widgets.FrameBrowser import FrameBrowser
-
-
-class TestFrameBrowser(TestCaseQt):
- """Test for FrameBrowser"""
-
- def test(self):
- """Test FrameBrowser"""
- widget = FrameBrowser()
- widget.show()
- self.qWaitForWindowExposed(widget)
-
- nFrames = 20
- widget.setNFrames(nFrames)
- self.assertEqual(widget.getRange(), (0, nFrames - 1))
- self.assertEqual(widget.getValue(), 0)
-
- range_ = -100, 100
- widget.setRange(*range_)
- self.assertEqual(widget.getRange(), range_)
- self.assertEqual(widget.getValue(), range_[0])
-
- widget.setValue(0)
- self.assertEqual(widget.getValue(), 0)
-
- widget.setValue(range_[1] + 100)
- self.assertEqual(widget.getValue(), range_[1])
-
- widget.setValue(range_[0] - 100)
- self.assertEqual(widget.getValue(), range_[0])
-
-
-def suite():
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loader(TestFrameBrowser))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_hierarchicaltableview.py b/silx/gui/widgets/test/test_hierarchicaltableview.py
deleted file mode 100644
index 9fad54d..0000000
--- a/silx/gui/widgets/test/test_hierarchicaltableview.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "07/04/2017"
-
-import unittest
-
-from .. import HierarchicalTableView
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import qt
-
-
-class TableModel(HierarchicalTableView.HierarchicalTableModel):
-
- def __init__(self, parent):
- HierarchicalTableView.HierarchicalTableModel.__init__(self, parent)
- self.__content = {}
-
- def rowCount(self, parent=qt.QModelIndex()):
- return 3
-
- def columnCount(self, parent=qt.QModelIndex()):
- return 3
-
- def setData1(self):
- if qt.qVersion() > "4.6":
- self.beginResetModel()
- else:
- self.reset()
-
- content = {}
- content[0, 0] = ("title", True, (1, 3))
- content[0, 1] = ("a", True, (2, 1))
- content[1, 1] = ("b", False, (1, 2))
- content[1, 2] = ("c", False, (1, 1))
- content[2, 2] = ("d", False, (1, 1))
- self.__content = content
- if qt.qVersion() > "4.6":
- self.endResetModel()
-
- def data(self, index, role=qt.Qt.DisplayRole):
- if not index.isValid():
- return None
- cell = self.__content.get((index.column(), index.row()), None)
- if cell is None:
- return None
-
- if role == self.SpanRole:
- return cell[2]
- elif role == self.IsHeaderRole:
- return cell[1]
- elif role == qt.Qt.DisplayRole:
- return cell[0]
- return None
-
-
-class TestHierarchicalTableView(TestCaseQt):
- """Test for HierarchicalTableView"""
-
- def testEmpty(self):
- widget = HierarchicalTableView.HierarchicalTableView()
- widget.show()
- self.qWaitForWindowExposed(widget)
-
- def testModel(self):
- widget = HierarchicalTableView.HierarchicalTableView()
- model = TableModel(widget)
- # set the data before using the model into the widget
- model.setData1()
- widget.setModel(model)
- span = widget.rowSpan(0, 0), widget.columnSpan(0, 0)
- self.assertEqual(span, (1, 3))
- widget.show()
- self.qWaitForWindowExposed(widget)
-
- def testModelUpdate(self):
- widget = HierarchicalTableView.HierarchicalTableView()
- model = TableModel(widget)
- widget.setModel(model)
- # set the data after using the model into the widget
- model.setData1()
- span = widget.rowSpan(0, 0), widget.columnSpan(0, 0)
- self.assertEqual(span, (1, 3))
-
-
-def suite():
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loader(TestHierarchicalTableView))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_legendiconwidget.py b/silx/gui/widgets/test/test_legendiconwidget.py
deleted file mode 100644
index f845f75..0000000
--- a/silx/gui/widgets/test/test_legendiconwidget.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for LegendIconWidget"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "23/10/2020"
-
-import unittest
-
-from silx.gui import qt
-from silx.gui.widgets.LegendIconWidget import LegendIconWidget
-from silx.gui.utils.testutils import TestCaseQt
-from silx.utils.testutils import ParametricTestCase
-
-
-class TestLegendIconWidget(TestCaseQt, ParametricTestCase):
- """Tests for TestRangeSlider"""
-
- def setUp(self):
- self.widget = LegendIconWidget()
- self.widget.show()
- self.qWaitForWindowExposed(self.widget)
-
- def tearDown(self):
- self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.widget.close()
- del self.widget
- self.qapp.processEvents()
-
- def testCreate(self):
- self.qapp.processEvents()
-
- def testColormap(self):
- self.widget.setColormap("viridis")
- self.qapp.processEvents()
-
- def testSymbol(self):
- self.widget.setSymbol("o")
- self.widget.setSymbolColormap("viridis")
- self.qapp.processEvents()
-
-
-def suite():
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loader(TestLegendIconWidget))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_periodictable.py b/silx/gui/widgets/test/test_periodictable.py
deleted file mode 100644
index 3e7eb16..0000000
--- a/silx/gui/widgets/test/test_periodictable.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "05/12/2016"
-
-import unittest
-
-from .. import PeriodicTable
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui import qt
-
-
-class TestPeriodicTable(TestCaseQt):
- """Basic test for ArrayTableWidget with a numpy array"""
-
- def testShow(self):
- """basic test (instantiation done in setUp)"""
- pt = PeriodicTable.PeriodicTable()
- pt.show()
- self.qWaitForWindowExposed(pt)
-
- def testSelectable(self):
- """basic test (instantiation done in setUp)"""
- pt = PeriodicTable.PeriodicTable(selectable=True)
- self.assertTrue(pt.selectable)
-
- def testCustomElements(self):
- PTI = PeriodicTable.ColoredPeriodicTableItem
- my_items = [
- PTI("Xx", 42, 43, 44, "xaxatorium", 1002.2,
- bgcolor="#FF0000"),
- PTI("Yy", 25, 22, 44, "yoyotrium", 8.8)
- ]
-
- pt = PeriodicTable.PeriodicTable(elements=my_items)
-
- pt.setSelection(["He", "Xx"])
- selection = pt.getSelection()
- self.assertEqual(len(selection), 1) # "He" not found
- self.assertEqual(selection[0].symbol, "Xx")
- self.assertEqual(selection[0].Z, 42)
- self.assertEqual(selection[0].col, 43)
- self.assertAlmostEqual(selection[0].mass, 1002.2)
- self.assertEqual(qt.QColor(selection[0].bgcolor),
- qt.QColor(qt.Qt.red))
-
- self.assertTrue(pt.isElementSelected("Xx"))
- self.assertFalse(pt.isElementSelected("Yy"))
- self.assertRaises(KeyError, pt.isElementSelected, "Yx")
-
- def testVeryCustomElements(self):
- class MyPTI(PeriodicTable.PeriodicTableItem):
- def __init__(self, *args):
- PeriodicTable.PeriodicTableItem.__init__(self, *args[:6])
- self.my_feature = args[6]
-
- my_items = [
- MyPTI("Xx", 42, 43, 44, "xaxatorium", 1002.2, "spam"),
- MyPTI("Yy", 25, 22, 44, "yoyotrium", 8.8, "eggs")
- ]
-
- pt = PeriodicTable.PeriodicTable(elements=my_items)
-
- pt.setSelection(["Xx", "Yy"])
- selection = pt.getSelection()
- self.assertEqual(len(selection), 2)
- self.assertEqual(selection[1].symbol, "Yy")
- self.assertEqual(selection[1].Z, 25)
- self.assertEqual(selection[1].col, 22)
- self.assertEqual(selection[1].row, 44)
- self.assertAlmostEqual(selection[0].mass, 1002.2)
- self.assertAlmostEqual(selection[0].my_feature, "spam")
-
-
-class TestPeriodicCombo(TestCaseQt):
- """Basic test for ArrayTableWidget with a numpy array"""
- def setUp(self):
- super(TestPeriodicCombo, self).setUp()
- self.pc = PeriodicTable.PeriodicCombo()
-
- def tearDown(self):
- del self.pc
- super(TestPeriodicCombo, self).tearDown()
-
- def testShow(self):
- """basic test (instantiation done in setUp)"""
- self.pc.show()
- self.qWaitForWindowExposed(self.pc)
-
- def testSelect(self):
- self.pc.setSelection("Sb")
- selection = self.pc.getSelection()
- self.assertIsInstance(selection,
- PeriodicTable.PeriodicTableItem)
- self.assertEqual(selection.symbol, "Sb")
- self.assertEqual(selection.Z, 51)
- self.assertEqual(selection.name, "antimony")
-
-
-class TestPeriodicList(TestCaseQt):
- """Basic test for ArrayTableWidget with a numpy array"""
- def setUp(self):
- super(TestPeriodicList, self).setUp()
- self.pl = PeriodicTable.PeriodicList()
-
- def tearDown(self):
- del self.pl
- super(TestPeriodicList, self).tearDown()
-
- def testShow(self):
- """basic test (instantiation done in setUp)"""
- self.pl.show()
- self.qWaitForWindowExposed(self.pl)
-
- def testSelect(self):
- self.pl.setSelectedElements(["Li", "He", "Au"])
- sel_elmts = self.pl.getSelection()
-
- self.assertEqual(len(sel_elmts), 3,
- "Wrong number of elements selected")
- for e in sel_elmts:
- self.assertIsInstance(e, PeriodicTable.PeriodicTableItem)
- self.assertIn(e.symbol, ["Li", "He", "Au"])
- self.assertIn(e.Z, [2, 3, 79])
- self.assertIn(e.name, ["lithium", "helium", "gold"])
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestPeriodicTable))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestPeriodicList))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestPeriodicCombo))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_printpreview.py b/silx/gui/widgets/test/test_printpreview.py
deleted file mode 100644
index 3c29171..0000000
--- a/silx/gui/widgets/test/test_printpreview.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test PrintPreview"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "19/07/2017"
-
-
-import unittest
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.widgets.PrintPreview import PrintPreviewDialog
-from silx.gui import qt
-
-from silx.resources import resource_filename
-
-
-class TestPrintPreview(TestCaseQt):
- def testShow(self):
- p = qt.QPrinter()
- d = PrintPreviewDialog(printer=p)
- d.show()
- self.qapp.processEvents()
-
- def testAddImage(self):
- p = qt.QPrinter()
- d = PrintPreviewDialog(printer=p)
- d.addImage(qt.QImage(resource_filename("gui/icons/clipboard.png")))
- self.qapp.processEvents()
-
- def testAddSvg(self):
- p = qt.QPrinter()
- d = PrintPreviewDialog(printer=p)
- d.addSvgItem(qt.QSvgRenderer(resource_filename("gui/icons/clipboard.svg"), d.page))
- self.qapp.processEvents()
-
- def testAddPixmap(self):
- p = qt.QPrinter()
- d = PrintPreviewDialog(printer=p)
- d.addPixmap(qt.QPixmap.fromImage(qt.QImage(resource_filename("gui/icons/clipboard.png"))))
- self.qapp.processEvents()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestPrintPreview))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_rangeslider.py b/silx/gui/widgets/test/test_rangeslider.py
deleted file mode 100644
index 2829050..0000000
--- a/silx/gui/widgets/test/test_rangeslider.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for RangeSlider"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "01/08/2018"
-
-import unittest
-
-from silx.gui import qt, colors
-from silx.gui.widgets.RangeSlider import RangeSlider
-from silx.gui.utils.testutils import TestCaseQt
-from silx.utils.testutils import ParametricTestCase
-
-
-class TestRangeSlider(TestCaseQt, ParametricTestCase):
- """Tests for TestRangeSlider"""
-
- def setUp(self):
- self.slider = RangeSlider()
- self.slider.show()
- self.qWaitForWindowExposed(self.slider)
-
- def tearDown(self):
- self.slider.setAttribute(qt.Qt.WA_DeleteOnClose)
- self.slider.close()
- del self.slider
- self.qapp.processEvents()
-
- def testRangeValue(self):
- """Test slider range and values"""
-
- # Play with range
- self.slider.setRange(1, 2)
- self.assertEqual(self.slider.getRange(), (1., 2.))
- self.assertEqual(self.slider.getValues(), (1., 1.))
-
- self.slider.setMinimum(-1)
- self.assertEqual(self.slider.getRange(), (-1., 2.))
- self.assertEqual(self.slider.getValues(), (1., 1.))
-
- self.slider.setMaximum(0)
- self.assertEqual(self.slider.getRange(), (-1., 0.))
- self.assertEqual(self.slider.getValues(), (0., 0.))
-
- # Play with values
- self.slider.setFirstValue(-2.)
- self.assertEqual(self.slider.getValues(), (-1., 0.))
-
- self.slider.setFirstValue(-0.5)
- self.assertEqual(self.slider.getValues(), (-0.5, 0.))
-
- self.slider.setSecondValue(2.)
- self.assertEqual(self.slider.getValues(), (-0.5, 0.))
-
- self.slider.setSecondValue(-0.1)
- self.assertEqual(self.slider.getValues(), (-0.5, -0.1))
-
- def testStepCount(self):
- """Test related to step count"""
- self.slider.setPositionCount(11)
- self.assertEqual(self.slider.getPositionCount(), 11)
- self.slider.setFirstValue(0.32)
- self.assertEqual(self.slider.getFirstValue(), 0.3)
- self.assertEqual(self.slider.getFirstPosition(), 3)
-
- self.slider.setPositionCount(3) # Value is adjusted
- self.assertEqual(self.slider.getValues(), (0.5, 1.))
- self.assertEqual(self.slider.getPositions(), (1, 2))
-
- def testGroove(self):
- """Test Groove pixmap"""
- profile = list(range(100))
-
- for cmap in ('jet', colors.Colormap('viridis')):
- with self.subTest(str(cmap)):
- self.slider.setGroovePixmapFromProfile(profile, cmap)
- pixmap = self.slider.getGroovePixmap()
- self.assertIsInstance(pixmap, qt.QPixmap)
- self.assertEqual(pixmap.width(), len(profile))
-
-
-def suite():
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loader(TestRangeSlider))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_tablewidget.py b/silx/gui/widgets/test/test_tablewidget.py
deleted file mode 100644
index 6822aef..0000000
--- a/silx/gui/widgets/test/test_tablewidget.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test TableWidget"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "05/12/2016"
-
-
-import unittest
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.widgets.TableWidget import TableWidget
-
-
-class TestTableWidget(TestCaseQt):
- def setUp(self):
- super(TestTableWidget, self).setUp()
- self._result = []
-
- def testShow(self):
- table = TableWidget()
- table.setColumnCount(10)
- table.setRowCount(7)
- table.enableCut()
- table.enablePaste()
- table.show()
- table.hide()
- self.qapp.processEvents()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestTableWidget))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/gui/widgets/test/test_threadpoolpushbutton.py b/silx/gui/widgets/test/test_threadpoolpushbutton.py
deleted file mode 100644
index e92eb02..0000000
--- a/silx/gui/widgets/test/test_threadpoolpushbutton.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test for silx.gui.hdf5 module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import unittest
-import time
-from silx.gui import qt
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.utils.testutils import SignalListener
-from silx.gui.widgets.ThreadPoolPushButton import ThreadPoolPushButton
-from silx.utils.testutils import TestLogging
-
-
-class TestThreadPoolPushButton(TestCaseQt):
-
- def setUp(self):
- super(TestThreadPoolPushButton, self).setUp()
- self._result = []
-
- def waitForPendingOperations(self, object):
- for i in range(50):
- if not object.hasPendingOperations():
- break
- self.qWait(10)
- else:
- raise RuntimeError("Still waiting for a pending operation")
-
- def _trace(self, name, delay=0):
- self._result.append(name)
- if delay != 0:
- time.sleep(delay / 1000.0)
-
- def _compute(self):
- return "result"
-
- def _computeFail(self):
- raise Exception("exception")
-
- def testExecute(self):
- button = ThreadPoolPushButton()
- button.setCallable(self._trace, "a", 0)
- button.executeCallable()
- time.sleep(0.1)
- self.assertListEqual(self._result, ["a"])
- self.waitForPendingOperations(button)
-
- def testMultiExecution(self):
- button = ThreadPoolPushButton()
- button.setCallable(self._trace, "a", 0)
- number = qt.silxGlobalThreadPool().maxThreadCount()
- for _ in range(number):
- button.executeCallable()
- self.waitForPendingOperations(button)
- self.assertListEqual(self._result, ["a"] * number)
-
- def testSaturateThreadPool(self):
- button = ThreadPoolPushButton()
- button.setCallable(self._trace, "a", 100)
- number = qt.silxGlobalThreadPool().maxThreadCount() * 2
- for _ in range(number):
- button.executeCallable()
- self.waitForPendingOperations(button)
- self.assertListEqual(self._result, ["a"] * number)
-
- def testSuccess(self):
- listener = SignalListener()
- button = ThreadPoolPushButton()
- button.setCallable(self._compute)
- button.beforeExecuting.connect(listener.partial(test="be"))
- button.started.connect(listener.partial(test="s"))
- button.succeeded.connect(listener.partial(test="result"))
- button.failed.connect(listener.partial(test="Unexpected exception"))
- button.finished.connect(listener.partial(test="f"))
- button.executeCallable()
- self.qapp.processEvents()
- time.sleep(0.1)
- self.qapp.processEvents()
- result = listener.karguments(argumentName="test")
- self.assertListEqual(result, ["be", "s", "result", "f"])
-
- def testFail(self):
- listener = SignalListener()
- button = ThreadPoolPushButton()
- button.setCallable(self._computeFail)
- button.beforeExecuting.connect(listener.partial(test="be"))
- button.started.connect(listener.partial(test="s"))
- button.succeeded.connect(listener.partial(test="Unexpected success"))
- button.failed.connect(listener.partial(test="exception"))
- button.finished.connect(listener.partial(test="f"))
- with TestLogging('silx.gui.widgets.ThreadPoolPushButton', error=1):
- button.executeCallable()
- self.qapp.processEvents()
- time.sleep(0.1)
- self.qapp.processEvents()
- result = listener.karguments(argumentName="test")
- self.assertListEqual(result, ["be", "s", "exception", "f"])
- listener.clear()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestThreadPoolPushButton))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/image/marchingsquares/setup.py b/silx/image/marchingsquares/setup.py
deleted file mode 100644
index 902f297..0000000
--- a/silx/image/marchingsquares/setup.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "23/04/2018"
-
-import os
-import numpy
-from numpy.distutils.misc_util import Configuration
-
-
-def configuration(parent_package='', top_path=None):
- config = Configuration('marchingsquares', parent_package, top_path)
- config.add_subpackage('test')
-
- silx_include = os.path.join(top_path, "silx", "utils", "include")
- config.add_extension('_mergeimpl',
- sources=['_mergeimpl.pyx'],
- include_dirs=[numpy.get_include(), silx_include],
- language='c++',
- extra_link_args=['-fopenmp'],
- extra_compile_args=['-fopenmp'])
-
- return config
-
-
-if __name__ == "__main__":
- from numpy.distutils.core import setup
- setup(configuration=configuration)
diff --git a/silx/image/marchingsquares/test/__init__.py b/silx/image/marchingsquares/test/__init__.py
deleted file mode 100644
index 5351a28..0000000
--- a/silx/image/marchingsquares/test/__init__.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Project: silx
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2012-2016 European Synchrotron Radiation Facility, Grenoble, France
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "17/04/2018"
-
-import unittest
-from . import test_funcapi
-from . import test_mergeimpl
-
-
-def suite():
- """Test suite for module silx.image.test"""
- test_suite = unittest.TestSuite()
- test_suite.addTest(test_funcapi.suite())
- test_suite.addTest(test_mergeimpl.suite())
- return test_suite
diff --git a/silx/image/marchingsquares/test/test_funcapi.py b/silx/image/marchingsquares/test/test_funcapi.py
deleted file mode 100644
index a84a493..0000000
--- a/silx/image/marchingsquares/test/test_funcapi.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Project: silx
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2012-2016 European Synchrotron Radiation Facility, Grenoble, France
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "17/04/2018"
-
-import unittest
-import numpy
-import silx.image.marchingsquares
-
-
-class MockMarchingSquares(object):
-
- last = None
-
- def __init__(self, image, mask=None):
- MockMarchingSquares.last = self
- self.events = []
- self.events.append(("image", image))
- self.events.append(("mask", mask))
-
- def find_pixels(self, level):
- self.events.append(("find_pixels", level))
- return None
-
- def find_contours(self, level):
- self.events.append(("find_contours", level))
- return None
-
-
-class TestFunctionalApi(unittest.TestCase):
- """Test that the default functional API is called using the right
- parameters to the right location."""
-
- def setUp(self):
- self.old_impl = silx.image.marchingsquares.MarchingSquaresMergeImpl
- silx.image.marchingsquares.MarchingSquaresMergeImpl = MockMarchingSquares
-
- def tearDown(self):
- silx.image.marchingsquares.MarchingSquaresMergeImpl = self.old_impl
- del self.old_impl
-
- def test_default_find_contours(self):
- image = numpy.ones((2, 2), dtype=numpy.float32)
- mask = numpy.zeros((2, 2), dtype=numpy.int32)
- level = 2.5
- silx.image.marchingsquares.find_contours(image=image, level=level, mask=mask)
- events = MockMarchingSquares.last.events
- self.assertEqual(len(events), 3)
- self.assertEqual(events[0][0], "image")
- self.assertEqual(events[0][1][0, 0], 1)
- self.assertEqual(events[1][0], "mask")
- self.assertEqual(events[1][1][0, 0], 0)
- self.assertEqual(events[2][0], "find_contours")
- self.assertEqual(events[2][1], level)
-
- def test_default_find_pixels(self):
- image = numpy.ones((2, 2), dtype=numpy.float32)
- mask = numpy.zeros((2, 2), dtype=numpy.int32)
- level = 3.5
- silx.image.marchingsquares.find_pixels(image=image, level=level, mask=mask)
- events = MockMarchingSquares.last.events
- self.assertEqual(len(events), 3)
- self.assertEqual(events[0][0], "image")
- self.assertEqual(events[0][1][0, 0], 1)
- self.assertEqual(events[1][0], "mask")
- self.assertEqual(events[1][1][0, 0], 0)
- self.assertEqual(events[2][0], "find_pixels")
- self.assertEqual(events[2][1], level)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestFunctionalApi))
- return test_suite
diff --git a/silx/image/marchingsquares/test/test_mergeimpl.py b/silx/image/marchingsquares/test/test_mergeimpl.py
deleted file mode 100644
index 1c14f33..0000000
--- a/silx/image/marchingsquares/test/test_mergeimpl.py
+++ /dev/null
@@ -1,272 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Project: silx
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2012-2016 European Synchrotron Radiation Facility, Grenoble, France
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "18/04/2018"
-
-import unittest
-import numpy
-from .._mergeimpl import MarchingSquaresMergeImpl
-
-
-class TestMergeImplApi(unittest.TestCase):
-
- def test_image_not_an_array(self):
- bad_image = 1
- self.assertRaises(ValueError, MarchingSquaresMergeImpl, bad_image)
-
- def test_image_bad_dim(self):
- bad_image = numpy.array([[[1.0]]])
- self.assertRaises(ValueError, MarchingSquaresMergeImpl, bad_image)
-
- def test_image_not_big_enough(self):
- bad_image = numpy.array([[1.0, 1.0, 1.0, 1.0]])
- self.assertRaises(ValueError, MarchingSquaresMergeImpl, bad_image)
-
- def test_mask_not_an_array(self):
- image = numpy.array([[1.0, 1.0], [1.0, 1.0]])
- bad_mask = 1
- self.assertRaises(ValueError, MarchingSquaresMergeImpl, image, bad_mask)
-
- def test_mask_not_match(self):
- image = numpy.array([[1.0, 1.0], [1.0, 1.0]])
- bad_mask = numpy.array([[1.0, 1.0]])
- self.assertRaises(ValueError, MarchingSquaresMergeImpl, image, bad_mask)
-
- def test_ok_anyway_bad_type(self):
- image = numpy.array([[1.0, 1.0], [1.0, 1.0]], dtype=numpy.int32)
- mask = numpy.array([[1.0, 1.0], [1.0, 1.0]], dtype=numpy.float32)
- MarchingSquaresMergeImpl(image, mask)
-
- def test_find_contours_result(self):
- image = numpy.zeros((2, 2))
- image[0, 0] = 1
- ms = MarchingSquaresMergeImpl(image)
- polygons = ms.find_contours(0.5)
- self.assertIsInstance(polygons, list)
- self.assertTrue(len(polygons), 1)
- self.assertIsInstance(polygons[0], numpy.ndarray)
- self.assertEqual(polygons[0].shape[1], 2)
- self.assertEqual(polygons[0].dtype.kind, "f")
-
- def test_find_pixels_result(self):
- image = numpy.zeros((2, 2))
- image[0, 0] = 1
- ms = MarchingSquaresMergeImpl(image)
- pixels = ms.find_pixels(0.5)
- self.assertIsInstance(pixels, numpy.ndarray)
- self.assertEqual(pixels.shape[1], 2)
- self.assertEqual(pixels.dtype.kind, "i")
-
- def test_find_contours_empty_result(self):
- image = numpy.zeros((2, 2))
- ms = MarchingSquaresMergeImpl(image)
- polygons = ms.find_contours(0.5)
- self.assertIsInstance(polygons, list)
- self.assertEqual(len(polygons), 0)
-
- def test_find_pixels_empty_result(self):
- image = numpy.zeros((2, 2))
- ms = MarchingSquaresMergeImpl(image)
- pixels = ms.find_pixels(0.5)
- self.assertIsInstance(pixels, numpy.ndarray)
- self.assertEqual(pixels.shape[1], 2)
- self.assertEqual(pixels.shape[0], 0)
- self.assertEqual(pixels.dtype.kind, "i")
-
- def test_find_contours_yx_result(self):
- image = numpy.zeros((2, 2))
- image[1, 0] = 1
- ms = MarchingSquaresMergeImpl(image)
- polygons = ms.find_contours(0.5)
- polygon = polygons[0]
- self.assertTrue((polygon == (0.5, 0)).any())
- self.assertTrue((polygon == (1, 0.5)).any())
-
- def test_find_pixels_yx_result(self):
- image = numpy.zeros((2, 2))
- image[1, 0] = 1
- ms = MarchingSquaresMergeImpl(image)
- pixels = ms.find_pixels(0.5)
- self.assertTrue((pixels == (1, 0)).any())
-
-
-class TestMergeImplContours(unittest.TestCase):
-
- def test_merge_segments(self):
- image = numpy.zeros((4, 4))
- image[(2, 3), :] = 1
- ms = MarchingSquaresMergeImpl(image)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 1)
-
- def test_merge_segments_2(self):
- image = numpy.zeros((4, 4))
- image[(2, 3), :] = 1
- image[2, 2] = 0
- ms = MarchingSquaresMergeImpl(image)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 1)
-
- def test_merge_tiles(self):
- image = numpy.zeros((4, 4))
- image[(2, 3), :] = 1
- ms = MarchingSquaresMergeImpl(image, group_size=2)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 1)
-
- def test_fully_masked(self):
- image = numpy.zeros((5, 5))
- image[(2, 3), :] = 1
- mask = numpy.ones((5, 5))
- ms = MarchingSquaresMergeImpl(image, mask)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 0)
-
- def test_fully_masked_minmax(self):
- """This invalidates all the tiles. The route is not the same."""
- image = numpy.zeros((5, 5))
- image[(2, 3), :] = 1
- mask = numpy.ones((5, 5))
- ms = MarchingSquaresMergeImpl(image, mask, group_size=2, use_minmax_cache=True)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 0)
-
- def test_masked_segments(self):
- image = numpy.zeros((5, 5))
- image[(2, 3, 4), :] = 1
- mask = numpy.zeros((5, 5))
- mask[:, 2] = 1
- ms = MarchingSquaresMergeImpl(image, mask)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 2)
-
- def test_closed_polygon(self):
- image = numpy.zeros((5, 5))
- image[2, 2] = 1
- image[1, 2] = 1
- image[3, 2] = 1
- image[2, 1] = 1
- image[2, 3] = 1
- mask = None
- ms = MarchingSquaresMergeImpl(image, mask)
- polygons = ms.find_contours(0.9)
- self.assertEqual(len(polygons), 1)
- self.assertEqual(list(polygons[0][0]), list(polygons[0][-1]))
-
- def test_closed_polygon_between_tiles(self):
- image = numpy.zeros((5, 5))
- image[2, 2] = 1
- image[1, 2] = 1
- image[3, 2] = 1
- image[2, 1] = 1
- image[2, 3] = 1
- mask = None
- ms = MarchingSquaresMergeImpl(image, mask, group_size=2)
- polygons = ms.find_contours(0.9)
- self.assertEqual(len(polygons), 1)
- self.assertEqual(list(polygons[0][0]), list(polygons[0][-1]))
-
- def test_open_polygon(self):
- image = numpy.zeros((5, 5))
- image[2, 2] = 1
- image[1, 2] = 1
- image[3, 2] = 1
- image[2, 1] = 1
- image[2, 3] = 1
- mask = numpy.zeros((5, 5))
- mask[1, 1] = 1
- ms = MarchingSquaresMergeImpl(image, mask)
- polygons = ms.find_contours(0.9)
- self.assertEqual(len(polygons), 1)
- self.assertNotEqual(list(polygons[0][0]), list(polygons[0][-1]))
-
- def test_ambiguous_pattern(self):
- image = numpy.zeros((6, 8))
- image[(3, 4), :] = 1
- image[:, (0, -1)] = 0
- image[3, 3] = -0.001
- image[4, 4] = 0.0
- mask = None
- ms = MarchingSquaresMergeImpl(image, mask)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 2)
-
- def test_ambiguous_pattern_2(self):
- image = numpy.zeros((6, 8))
- image[(3, 4), :] = 1
- image[:, (0, -1)] = 0
- image[3, 3] = +0.001
- image[4, 4] = 0.0
- mask = None
- ms = MarchingSquaresMergeImpl(image, mask)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 1)
-
- def count_closed_polygons(self, polygons):
- closed = 0
- for polygon in polygons:
- if list(polygon[0]) == list(polygon[-1]):
- closed += 1
- return closed
-
- def test_image(self):
- # example from skimage
- x, y = numpy.ogrid[-numpy.pi:numpy.pi:100j, -numpy.pi:numpy.pi:100j]
- image = numpy.sin(numpy.exp((numpy.sin(x)**3 + numpy.cos(y)**2)))
- mask = None
- ms = MarchingSquaresMergeImpl(image, mask)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 11)
- self.assertEqual(self.count_closed_polygons(polygons), 3)
-
- def test_image_tiled(self):
- # example from skimage
- x, y = numpy.ogrid[-numpy.pi:numpy.pi:100j, -numpy.pi:numpy.pi:100j]
- image = numpy.sin(numpy.exp((numpy.sin(x)**3 + numpy.cos(y)**2)))
- mask = None
- ms = MarchingSquaresMergeImpl(image, mask, group_size=50)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 11)
- self.assertEqual(self.count_closed_polygons(polygons), 3)
-
- def test_image_tiled_minmax(self):
- # example from skimage
- x, y = numpy.ogrid[-numpy.pi:numpy.pi:100j, -numpy.pi:numpy.pi:100j]
- image = numpy.sin(numpy.exp((numpy.sin(x)**3 + numpy.cos(y)**2)))
- mask = None
- ms = MarchingSquaresMergeImpl(image, mask, group_size=50, use_minmax_cache=True)
- polygons = ms.find_contours(0.5)
- self.assertEqual(len(polygons), 11)
- self.assertEqual(self.count_closed_polygons(polygons), 3)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestMergeImplApi))
- test_suite.addTest(loadTests(TestMergeImplContours))
- return test_suite
diff --git a/silx/image/test/__init__.py b/silx/image/test/__init__.py
deleted file mode 100644
index f469edc..0000000
--- a/silx/image/test/__init__.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Project: silx
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2012-2018 European Synchrotron Radiation Facility, Grenoble, France
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-__authors__ = ["J. Kieffer"]
-__license__ = "MIT"
-__date__ = "17/04/2018"
-
-import unittest
-from . import test_bilinear
-from . import test_shapes
-from . import test_medianfilter
-from . import test_tomography
-from . import test_bb
-from ..marchingsquares.test import suite as marchingsquares_suite
-
-
-def suite():
- """Test suite for module silx.image.test"""
- test_suite = unittest.TestSuite()
- test_suite.addTest(test_bilinear.suite())
- test_suite.addTest(test_medianfilter.suite())
- test_suite.addTest(test_shapes.suite())
- test_suite.addTest(test_tomography.suite())
- test_suite.addTest(marchingsquares_suite())
- test_suite.addTest(test_bb.suite())
- return test_suite
diff --git a/silx/image/test/test_bb.py b/silx/image/test/test_bb.py
deleted file mode 100644
index 3f33e80..0000000
--- a/silx/image/test/test_bb.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic tests for Bounding box"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "27/09/2019"
-
-
-import unittest
-import numpy
-from silx.image._boundingbox import _BoundingBox
-
-
-class TestBB(unittest.TestCase):
- """Some simple test on the bounding box class"""
- def test_creation(self):
- """test some constructors"""
- pts = numpy.array([(0, 0), (10, 20), (20, 0)])
- bb = _BoundingBox.from_points(pts)
- self.assertTrue(bb.bottom_left == (0, 0))
- self.assertTrue(bb.top_right == (20, 20))
- pts = numpy.array([(0, 10), (10, 20), (45, 30), (35, 0)])
- bb = _BoundingBox.from_points(pts)
- self.assertTrue(bb.bottom_left == (0, 0))
- print(bb.top_right)
- self.assertTrue(bb.top_right == (45, 30))
-
- def test_isIn_pt(self):
- """test the isIn function with points"""
- bb = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
- self.assertTrue(bb.contains((10, 4)))
- self.assertTrue(bb.contains((6, 2)))
- self.assertTrue(bb.contains((12, 2)))
- self.assertFalse(bb.contains((0, 0)))
- self.assertFalse(bb.contains((20, 0)))
- self.assertFalse(bb.contains((10, 0)))
-
- def test_collide(self):
- """test the collide function"""
- bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
- self.assertTrue(bb1.collide(_BoundingBox(bottom_left=(6, 2), top_right=(12, 6))))
- bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
- self.assertFalse(bb1.collide(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2))))
-
- def test_isIn_bb(self):
- """test the isIn function with other bounding box"""
- bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
- self.assertTrue(bb1.contains(_BoundingBox(bottom_left=(6, 2), top_right=(12, 6))))
- bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
- self.assertTrue(bb1.contains(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2))))
- self.assertFalse(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2)).contains(bb1))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for TestClass in (TestBB,):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/image/test/test_bilinear.py b/silx/image/test/test_bilinear.py
deleted file mode 100644
index 55eaccb..0000000
--- a/silx/image/test/test_bilinear.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Project: silx (originally pyFAI)
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2012-2017 European Synchrotron Radiation Facility, Grenoble, France
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-__authors__ = ["J. Kieffer"]
-__license__ = "MIT"
-__date__ = "25/11/2020"
-
-import unittest
-import numpy
-import logging
-logger = logging.getLogger(__name__)
-from ..bilinear import BilinearImage
-
-
-class TestBilinear(unittest.TestCase):
- """basic maximum search test"""
- N = 1000
-
- def test_max_search_round(self):
- """test maximum search using random points: maximum is at the pixel center"""
- a = numpy.arange(100) - 40.
- b = numpy.arange(100) - 60.
- ga = numpy.exp(-a * a / 4000)
- gb = numpy.exp(-b * b / 6000)
- gg = numpy.outer(ga, gb)
- b = BilinearImage(gg)
-
- self.assertAlmostEqual(b.maxi, 1, 2, "maxi is almost 1")
- self.assertLess(b.mini, 0.3, "mini should be around 0.23")
-
- ok = 0
- for s in range(self.N):
- i, j = numpy.random.randint(100), numpy.random.randint(100)
- k, l = b.local_maxi((i, j))
- if abs(k - 40) > 1e-4 or abs(l - 60) > 1e-4:
- logger.warning("Wrong guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
- else:
- logger.debug("Good guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
- ok += 1
- logger.debug("Success rate: %.1f", 100. * ok / self.N)
- self.assertEqual(ok, self.N, "Maximum is always found")
-
- def test_max_search_half(self):
- """test maximum search using random points: maximum is at a pixel edge"""
- a = numpy.arange(100) - 40.5
- b = numpy.arange(100) - 60.5
- ga = numpy.exp(-a * a / 4000)
- gb = numpy.exp(-b * b / 6000)
- gg = numpy.outer(ga, gb)
- b = BilinearImage(gg)
- ok = 0
- for s in range(self.N):
- i, j = numpy.random.randint(100), numpy.random.randint(100)
- k, l = b.local_maxi((i, j))
- if abs(k - 40.5) > 0.5 or abs(l - 60.5) > 0.5:
- logger.warning("Wrong guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
- else:
- logger.debug("Good guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
- ok += 1
- logger.debug("Success rate: %.1f", 100. * ok / self.N)
- self.assertEqual(ok, self.N, "Maximum is always found")
-
- def test_map(self):
- N = 6
- y, x = numpy.ogrid[:N,:N + 10]
- img = x + y
- b = BilinearImage(img)
- x2d = numpy.zeros_like(y) + x
- y2d = numpy.zeros_like(x) + y
- res1 = b.map_coordinates((y2d, x2d))
- self.assertEqual(abs(res1 - img).max(), 0, "images are the same (corners)")
-
- x2d = numpy.zeros_like(y) + (x[:,:-1] + 0.5)
- y2d = numpy.zeros_like(x[:,:-1]) + y
- res1 = b.map_coordinates((y2d, x2d))
- self.assertEqual(abs(res1 - img[:,:-1] - 0.5).max(), 0, "images are the same (middle)")
-
- x2d = numpy.zeros_like(y[:-1,:]) + (x[:,:-1] + 0.5)
- y2d = numpy.zeros_like(x[:,:-1]) + (y[:-1,:] + 0.5)
- res1 = b.map_coordinates((y2d, x2d))
- self.assertEqual(abs(res1 - img[:-1, 1:]).max(), 0, "images are the same (center)")
-
- def test_mask_grad(self):
- N = 100
- img = numpy.arange(N * N).reshape(N, N)
- # No mask on the boundaries, makes the test complicated, pixel always separated
- masked = 2 * numpy.random.randint(0, int((N - 1) / 2), size=(2, N)) + 1
- mask = numpy.zeros((N, N), dtype=numpy.uint8)
- mask[(masked[0], masked[1])] = 1
- self.assertLessEqual(mask.sum(), N, "At most N pixels are masked")
-
- b = BilinearImage(img, mask=mask)
- self.assertEqual(b.has_mask, True, "interpolator has mask")
- self.assertEqual(b.maxi, N * N - 1, "maxi is N²-1")
- self.assertEqual(b.mini, 0, "mini is 0")
-
- y, x = numpy.ogrid[:N,:N]
- x2d = numpy.zeros_like(y) + x
- y2d = numpy.zeros_like(x) + y
- res1 = b.map_coordinates((y2d, x2d))
- self.assertEqual(numpy.nanmax(abs(res1 - img)), 0, "images are the same (corners), or Nan ")
-
- x2d = numpy.zeros_like(y) + (x[:,:-1] + 0.5)
- y2d = numpy.zeros_like(x[:,:-1]) + y
- res1 = b.map_coordinates((y2d, x2d))
- self.assertLessEqual(numpy.max(abs(res1 - img[:, 1:] + 1 / 2.)), 0.5, "images are the same (middle) +/- 0.5")
-
- x2d = numpy.zeros_like(y[:-1]) + (x[:,:-1] + 0.5)
- y2d = numpy.zeros_like(x[:,:-1]) + (y[:-1] + 0.5)
- res1 = b.map_coordinates((y2d, x2d))
- exp = 0.25 * (img[:-1,:-1] + img[:-1, 1:] + img[1:,:-1] + img[1:, 1:])
- self.assertLessEqual(abs(res1 - exp).max(), N / 4, "images are almost the same (center)")
-
- def test_profile_grad(self):
- N = 100
- img = numpy.arange(N * N).reshape(N, N)
- b = BilinearImage(img)
- res1 = b.profile_line((0, 0), (N - 1, N - 1))
- l = numpy.ceil(numpy.sqrt(2) * N)
- self.assertEqual(len(res1), l, "Profile has correct length")
- self.assertLess((res1[:-2] - res1[1:-1]).std(), 1e-3, "profile is linear (excluding last point)")
-
- def test_profile_gaus(self):
- N = 100
- x = numpy.arange(N) - N // 2.0
- g = numpy.exp(-x * x / (N * N))
- img = numpy.outer(g, g)
- b = BilinearImage(img)
- res_hor = b.profile_line((N // 2, 0), (N // 2, N - 1))
- res_ver = b.profile_line((0, N // 2), (N - 1, N // 2))
- self.assertEqual(len(res_hor), N, "Profile has correct length")
- self.assertEqual(len(res_ver), N, "Profile has correct length")
- self.assertLess(abs(res_hor - g).max(), 1e-5, "correct horizontal profile")
- self.assertLess(abs(res_ver - g).max(), 1e-5, "correct vertical profile")
-
- # Profile with linewidth=3
- expected_profile = img[:, N // 2 - 1:N // 2 + 2].mean(axis=1)
- res_hor = b.profile_line((N // 2, 0), (N // 2, N - 1), linewidth=3)
- res_ver = b.profile_line((0, N // 2), (N - 1, N // 2), linewidth=3)
-
- self.assertEqual(len(res_hor), N, "Profile has correct length")
- self.assertEqual(len(res_ver), N, "Profile has correct length")
- self.assertLess(abs(res_hor - expected_profile).max(), 1e-5,
- "correct horizontal profile")
- self.assertLess(abs(res_ver - expected_profile).max(), 1e-5,
- "correct vertical profile")
-
-
-def suite():
- testsuite = unittest.TestSuite()
- testsuite.addTest(TestBilinear("test_max_search_round"))
- testsuite.addTest(TestBilinear("test_max_search_half"))
- testsuite.addTest(TestBilinear("test_map"))
- testsuite.addTest(TestBilinear("test_profile_grad"))
- testsuite.addTest(TestBilinear("test_profile_gaus"))
- testsuite.addTest(TestBilinear("test_mask_grad"))
- return testsuite
diff --git a/silx/image/test/test_medianfilter.py b/silx/image/test/test_medianfilter.py
deleted file mode 100644
index 5b062d9..0000000
--- a/silx/image/test/test_medianfilter.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests that the different implementation of opencl (cpp, opencl) are
- accessible
-"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "11/05/2017"
-
-import unittest
-from silx.image import medianfilter
-import numpy
-
-from silx.opencl.common import ocl
-
-
-class TestMedianFilterEngines(unittest.TestCase):
- """Make sure we have access to all the different implementation of
- median filter from image medfilt"""
-
-
- IMG = numpy.arange(10000.).reshape(100, 100)
-
- KERNEL = (1, 1)
-
- def testCppMedFilt2d(self):
- """test cpp engine for medfilt2d"""
- res = medianfilter.medfilt2d(
- image=TestMedianFilterEngines.IMG,
- kernel_size=TestMedianFilterEngines.KERNEL,
- engine='cpp')
- self.assertTrue(numpy.array_equal(res, TestMedianFilterEngines.IMG))
-
- @unittest.skipUnless(ocl, "PyOpenCl is missing")
- def testOpenCLMedFilt2d(self):
- """test cpp engine for medfilt2d"""
- res = medianfilter.medfilt2d(
- image=TestMedianFilterEngines.IMG,
- kernel_size=TestMedianFilterEngines.KERNEL,
- engine='opencl')
- self.assertTrue(numpy.array_equal(res, TestMedianFilterEngines.IMG))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for testClass in (TestMedianFilterEngines, ):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(testClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/image/test/test_shapes.py b/silx/image/test/test_shapes.py
deleted file mode 100644
index 6539bba..0000000
--- a/silx/image/test/test_shapes.py
+++ /dev/null
@@ -1,366 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for polygon functions
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "15/02/2019"
-
-
-import logging
-import unittest
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-from silx.image import shapes
-
-_logger = logging.getLogger(__name__)
-
-
-class TestPolygonFill(ParametricTestCase):
- """basic poylgon test"""
-
- def test_squares(self):
- """Test polygon fill for a square polygons"""
- mask_shape = 4, 4
- tests = {
- # test name: [(row min, row max), (col min, col max)]
- 'square in': [(1, 3), (1, 3)],
- 'square out': [(1, 3), (1, 10)],
- 'square around': [(-1, 5), (-1, 5)],
- }
-
- for test_name, (rows, cols) in tests.items():
- with self.subTest(msg=test_name, rows=rows, cols=cols,
- mask_shape=mask_shape):
- ref_mask = numpy.zeros(mask_shape, dtype=numpy.uint8)
- ref_mask[max(0, rows[0]):rows[1],
- max(0, cols[0]):cols[1]] = True
-
- vertices = [(rows[0], cols[0]), (rows[1], cols[0]),
- (rows[1], cols[1]), (rows[0], cols[1])]
- mask = shapes.polygon_fill_mask(vertices, ref_mask.shape)
- is_equal = numpy.all(numpy.equal(ref_mask, mask))
- if not is_equal:
- _logger.debug('%s failed with mask != ref_mask:',
- test_name)
- _logger.debug('result:\n%s', str(mask))
- _logger.debug('ref:\n%s', str(ref_mask))
- self.assertTrue(is_equal)
-
- def test_eight(self):
- """Tests with eight shape with different rotation and direction"""
- ref_mask = numpy.array((
- (1, 1, 1, 1, 1, 0),
- (0, 1, 1, 1, 0, 0),
- (0, 0, 1, 0, 0, 0),
- (0, 0, 1, 0, 0, 0),
- (0, 1, 1, 1, 0, 0),
- (0, 0, 0, 0, 0, 0)), dtype=numpy.uint8)
- ref_mask_rot = numpy.asarray(numpy.logical_not(ref_mask),
- dtype=numpy.uint8)
- ref_mask_rot[:, -1] = 0
- ref_mask_rot[-1, :] = 0
-
- tests = {
- 'dir 1': ([(0, 0), (5, 5), (5, 0), (0, 5)], ref_mask),
- 'dir 1, rot 90': ([(5, 0), (0, 5), (5, 5), (0, 0)], ref_mask_rot),
- 'dir 1, rot 180': ([(5, 5), (0, 0), (0, 5), (5, 0)], ref_mask),
- 'dir 1, rot -90': ([(0, 5), (5, 0), (0, 0), (5, 5)], ref_mask_rot),
- 'dir 2': ([(0, 0), (0, 5), (5, 0), (5, 5)], ref_mask),
- 'dir 2, rot 90': ([(5, 0), (0, 0), (5, 5), (0, 5)], ref_mask_rot),
- 'dir 2, rot 180': ([(5, 5), (5, 0), (0, 5), (0, 0)], ref_mask),
- 'dir 2, rot -90': ([(0, 5), (5, 5), (0, 0), (5, 0)], ref_mask_rot),
- }
-
- for test_name, (vertices, ref_mask) in tests.items():
- with self.subTest(msg=test_name):
- mask = shapes.polygon_fill_mask(vertices, ref_mask.shape)
- is_equal = numpy.all(numpy.equal(ref_mask, mask))
- if not is_equal:
- _logger.debug('%s failed with mask != ref_mask:',
- test_name)
- _logger.debug('result:\n%s', str(mask))
- _logger.debug('ref:\n%s', str(ref_mask))
- self.assertTrue(is_equal)
-
- def test_shapes(self):
- """Tests with shapes and reference mask"""
- tests = {
- # name: (
- # polygon corners as a list of (row, col),
- # ref_mask)
- 'concave polygon': (
- [(1, 1), (4, 3), (1, 5), (2, 3)],
- numpy.array((
- (0, 0, 0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0, 0, 0),
- (0, 0, 1, 1, 1, 0, 0, 0),
- (0, 0, 0, 1, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0, 0, 0)), dtype=numpy.uint8)),
- 'concave polygon partly outside mask': (
- [(-1, -1), (4, 3), (1, 5), (2, 3)],
- numpy.array((
- (1, 0, 0, 0, 0, 0),
- (0, 1, 0, 0, 0, 0),
- (0, 0, 1, 1, 1, 0),
- (0, 0, 0, 1, 0, 0),
- (0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0),
- (0, 0, 0, 0, 0, 0)), dtype=numpy.uint8)),
- 'polygon surrounding mask': (
- [(-1, -1), (-1, 7), (7, 7), (7, -1), (0, -1),
- (8, -2), (8, 8), (-2, 8)],
- numpy.zeros((6, 6), dtype=numpy.uint8))
- }
-
- for test_name, (vertices, ref_mask) in tests.items():
- with self.subTest(msg=test_name):
- mask = shapes.polygon_fill_mask(vertices, ref_mask.shape)
- is_equal = numpy.all(numpy.equal(ref_mask, mask))
- if not is_equal:
- _logger.debug('%s failed with mask != ref_mask:',
- test_name)
- _logger.debug('result:\n%s', str(mask))
- _logger.debug('ref:\n%s', str(ref_mask))
- self.assertTrue(is_equal)
-
-
-class TestDrawLine(ParametricTestCase):
- """basic draw line test"""
-
- def test_aligned_lines(self):
- """Test drawing horizontal, vertical and diagonal lines"""
-
- lines = { # test_name: (drow, dcol)
- 'Horizontal line, col0 < col1': (0, 10),
- 'Horizontal line, col0 > col1': (0, -10),
- 'Vertical line, row0 < row1': (10, 0),
- 'Vertical line, row0 > row1': (-10, 0),
- 'Diagonal col0 < col1 and row0 < row1': (10, 10),
- 'Diagonal col0 < col1 and row0 > row1': (-10, 10),
- 'Diagonal col0 > col1 and row0 < row1': (10, -10),
- 'Diagonal col0 > col1 and row0 > row1': (-10, -10),
- }
- row0, col0 = 1, 2 # Start point
-
- for test_name, (drow, dcol) in lines.items():
- row1 = row0 + drow
- col1 = col0 + dcol
- with self.subTest(msg=test_name, drow=drow, dcol=dcol):
- # Build reference coordinates from drow and dcol
- if drow == 0:
- rows = row0 * numpy.ones(abs(dcol) + 1)
- else:
- step = 1 if drow > 0 else -1
- rows = row0 + numpy.arange(0, drow + step, step)
-
- if dcol == 0:
- cols = col0 * numpy.ones(abs(drow) + 1)
- else:
- step = 1 if dcol > 0 else -1
- cols = col0 + numpy.arange(0, dcol + step, step)
- ref_coords = rows, cols
-
- result = shapes.draw_line(row0, col0, row1, col1)
- self.assertTrue(self.isEqual(test_name, result, ref_coords))
-
- def test_noline(self):
- """Test pt0 == pt1"""
- for width in range(4):
- with self.subTest(width=width):
- result = shapes.draw_line(1, 2, 1, 2, width)
- self.assertTrue(numpy.all(numpy.equal(result, [(1,), (2,)])))
-
- def test_lines(self):
- """Test lines not aligned with axes for 8 slopes and directions"""
- row0, col0 = 1, 1
-
- dy, dx = 3, 5
- ref_coords = numpy.array(
- [(0, 0), (1, 1), (1, 2), (2, 3), (2, 4), (3, 5)])
-
- # Build lines for the 8 octants from this coordinantes
- lines = { # name: (drow, dcol, ref_coords)
- '1st octant': (dy, dx, ref_coords),
- '2nd octant': (dx, dy, ref_coords[:, (1, 0)]), # invert x and y
- '3rd octant': (dx, -dy, ref_coords[:, (1, 0)] * (1, -1)),
- '4th octant': (dy, -dx, ref_coords * (1, -1)),
- '5th octant': (-dy, -dx, ref_coords * (-1, -1)),
- '6th octant': (-dx, -dy, ref_coords[:, (1, 0)] * (-1, -1)),
- '7th octant': (-dx, dy, ref_coords[:, (1, 0)] * (-1, 1)),
- '8th octant': (-dy, dx, ref_coords * (-1, 1))
- }
-
- # Test with different starting points with positive and negative coords
- for row0, col0 in ((0, 0), (2, 3), (-4, 1), (-5, -6), (8, -7)):
- for name, (drow, dcol, ref_coords) in lines.items():
- row1 = row0 + drow
- col1 = col0 + dcol
- # Transpose from ((row0, col0), ...) to (rows, cols)
- ref_coords = numpy.transpose(ref_coords + (row0, col0))
-
- with self.subTest(msg=name,
- pt0=(row0, col0), pt1=(row1, col1)):
- result = shapes.draw_line(row0, col0, row1, col1)
- self.assertTrue(self.isEqual(name, result, ref_coords))
-
- def test_width(self):
- """Test of line width"""
-
- lines = { # test_name: row0, col0, row1, col1, width, ref
- 'horizontal w=2':
- (0, 0, 0, 1, 2, ((0, 1, 0, 1),
- (0, 0, 1, 1))),
- 'horizontal w=3':
- (0, 0, 0, 1, 3, ((-1, 0, 1, -1, 0, 1),
- (0, 0, 0, 1, 1, 1))),
- 'vertical w=2':
- (0, 0, 1, 0, 2, ((0, 0, 1, 1),
- (0, 1, 0, 1))),
- 'vertical w=3':
- (0, 0, 1, 0, 3, ((0, 0, 0, 1, 1, 1),
- (-1, 0, 1, -1, 0, 1))),
- 'diagonal w=3':
- (0, 0, 1, 1, 3, ((-1, 0, 1, 0, 1, 2),
- (0, 0, 0, 1, 1, 1))),
- '1st octant w=3':
- (0, 0, 1, 2, 3,
- numpy.array(((-1, 0), (0, 0), (1, 0),
- (0, 1), (1, 1), (2, 1),
- (0, 2), (1, 2), (2, 2))).T),
- '2nd octant w=3':
- (0, 0, 2, 1, 3,
- numpy.array(((0, -1), (0, 0), (0, 1),
- (1, 0), (1, 1), (1, 2),
- (2, 0), (2, 1), (2, 2))).T),
- }
-
- for test_name, (row0, col0, row1, col1, width, ref) in lines.items():
- with self.subTest(msg=test_name,
- pt0=(row0, col0), pt1=(row1, col1), width=width):
- result = shapes.draw_line(row0, col0, row1, col1, width)
- self.assertTrue(self.isEqual(test_name, result, ref))
-
- def isEqual(self, test_name, result, ref):
- """Test equality of two numpy arrays and log them if different"""
- is_equal = numpy.all(numpy.equal(result, ref))
- if not is_equal:
- _logger.debug('%s failed with result != ref:',
- test_name)
- _logger.debug('result:\n%s', str(result))
- _logger.debug('ref:\n%s', str(ref))
- return is_equal
-
-
-class TestCircleFill(ParametricTestCase):
- """Tests for circle filling"""
-
- def testCircle(self):
- """Test circle_fill with different input parameters"""
-
- square3x3 = numpy.array(((-1, -1, -1, 0, 0, 0, 1, 1, 1),
- (-1, 0, 1, -1, 0, 1, -1, 0, 1)))
-
- tests = [
- # crow, ccol, radius, ref_coords = (ref_rows, ref_cols)
- (0, 0, 1, ((0,), (0,))),
- (10, 15, 1, ((10,), (15,))),
- (0, 0, 1.5, square3x3),
- (5, 10, 2, (5 + square3x3[0], 10 + square3x3[1])),
- (10, 20, 3.5, (
- 10 + numpy.array((-3, -3, -3,
- -2, -2, -2, -2, -2,
- -1, -1, -1, -1, -1, -1, -1,
- 0, 0, 0, 0, 0, 0, 0,
- 1, 1, 1, 1, 1, 1, 1,
- 2, 2, 2, 2, 2,
- 3, 3, 3)),
- 20 + numpy.array((-1, 0, 1,
- -2, -1, 0, 1, 2,
- -3, -2, -1, 0, 1, 2, 3,
- -3, -2, -1, 0, 1, 2, 3,
- -3, -2, -1, 0, 1, 2, 3,
- -2, -1, 0, 1, 2,
- -1, 0, 1)))),
- ]
-
- for crow, ccol, radius, ref_coords in tests:
- with self.subTest(crow=crow, ccol=ccol, radius=radius):
- coords = shapes.circle_fill(crow, ccol, radius)
- is_equal = numpy.all(numpy.equal(coords, ref_coords))
- if not is_equal:
- _logger.debug('result:\n%s', str(coords))
- _logger.debug('ref:\n%s', str(ref_coords))
- self.assertTrue(is_equal)
-
-
-class TestEllipseFill(unittest.TestCase):
- """Tests for ellipse filling"""
-
- def testPoint(self):
- args = [1, 1, 1, 1]
- result = shapes.ellipse_fill(*args)
- expected = numpy.array(([1], [1]))
- numpy.testing.assert_array_equal(result, expected)
-
- def testTranslatedPoint(self):
- args = [10, 10, 1, 1]
- result = shapes.ellipse_fill(*args)
- expected = numpy.array(([10], [10]))
- numpy.testing.assert_array_equal(result, expected)
-
- def testEllipse(self):
- args = [0, 0, 20, 10]
- rows, cols = shapes.ellipse_fill(*args)
- self.assertEqual(len(rows), 617)
- self.assertEqual(rows.mean(), 0)
- self.assertAlmostEqual(rows.std(), 10.025575, places=3)
- self.assertEqual(len(cols), 617)
- self.assertEqual(cols.mean(), 0)
- self.assertAlmostEqual(cols.std(), 4.897325, places=3)
-
- def testTranslatedEllipse(self):
- args = [0, 0, 20, 10]
- expected_rows, expected_cols = shapes.ellipse_fill(*args)
- args = [10, 50, 20, 10]
- rows, cols = shapes.ellipse_fill(*args)
- numpy.testing.assert_allclose(rows, expected_rows + 10)
- numpy.testing.assert_allclose(cols, expected_cols + 50)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for testClass in (TestPolygonFill, TestDrawLine, TestCircleFill, TestEllipseFill):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(testClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/image/test/test_tomography.py b/silx/image/test/test_tomography.py
deleted file mode 100644
index 2a6a33c..0000000
--- a/silx/image/test/test_tomography.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-Tests that the functions of tomography are valid
-"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "12/09/2017"
-
-import unittest
-import numpy
-from silx.test.utils import utilstest
-from silx.image import tomography
-
-class TestTomography(unittest.TestCase):
- """
-
- """
-
- def setUp(self):
- self.sinoTrueData = numpy.load(utilstest.getfile("sino500.npz"))["data"]
-
- def testCalcCenterCentroid(self):
- centerTD = tomography.calc_center_centroid(self.sinoTrueData)
- self.assertTrue(numpy.isclose(centerTD, 256, rtol=0.01))
-
- def testCalcCenterCorr(self):
- centerTrueData = tomography.calc_center_corr(self.sinoTrueData,
- fullrot=False,
- props=1)
- self.assertTrue(numpy.isclose(centerTrueData, 256, rtol=0.01))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for testClass in (TestTomography, ):
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(testClass))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/io/commonh5.py b/silx/io/commonh5.py
deleted file mode 100644
index 57232d8..0000000
--- a/silx/io/commonh5.py
+++ /dev/null
@@ -1,1083 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-This module contains generic objects, emulating *h5py* groups, datasets and
-files. They are used in :mod:`spech5` and :mod:`fabioh5`.
-
-.. note:: This module has a dependency on the `h5py <http://www.h5py.org/>`_
- library, which is not a mandatory dependency for `silx`.
-"""
-import collections
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
-import weakref
-
-import h5py
-import numpy
-import six
-
-from . import utils
-
-__authors__ = ["V. Valls", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "02/07/2018"
-
-
-class _MappingProxyType(abc.MutableMapping):
- """Read-only dictionary
-
- This class is available since Python 3.3, but not on earlyer Python
- versions.
- """
-
- def __init__(self, data):
- self._data = data
-
- def __getitem__(self, key):
- return self._data[key]
-
- def __len__(self):
- return len(self._data)
-
- def __iter__(self):
- return iter(self._data)
-
- def get(self, key, default=None):
- return self._data.get(key, default)
-
- def __setitem__(self, key, value):
- raise RuntimeError("Cannot modify read-only dictionary")
-
- def __delitem__(self, key):
- raise RuntimeError("Cannot modify read-only dictionary")
-
- def pop(self, key):
- raise RuntimeError("Cannot modify read-only dictionary")
-
- def clear(self):
- raise RuntimeError("Cannot modify read-only dictionary")
-
- def update(self, key, value):
- raise RuntimeError("Cannot modify read-only dictionary")
-
- def setdefault(self, key):
- raise RuntimeError("Cannot modify read-only dictionary")
-
-
-class Node(object):
- """This is the base class for all :mod:`spech5` and :mod:`fabioh5`
- classes. It represents a tree node, and knows its parent node
- (:attr:`parent`).
- The API mimics a *h5py* node, with following attributes: :attr:`file`,
- :attr:`attrs`, :attr:`name`, and :attr:`basename`.
- """
-
- def __init__(self, name, parent=None, attrs=None):
- self._set_parent(parent)
- self.__basename = name
- self.__attrs = {}
- if attrs is not None:
- self.__attrs.update(attrs)
-
- def _set_basename(self, name):
- self.__basename = name
-
- @property
- def h5_class(self):
- """Returns the HDF5 class which is mimicked by this class.
-
- :rtype: H5Type
- """
- raise NotImplementedError()
-
- @property
- def h5py_class(self):
- """Returns the h5py classes which is mimicked by this class. It can be
- one of `h5py.File, h5py.Group` or `h5py.Dataset`
-
- This should not be used anymore. Prefer using `h5_class`
-
- :rtype: Class
- """
- h5_class = self.h5_class
- if h5_class == utils.H5Type.FILE:
- return h5py.File
- elif h5_class == utils.H5Type.GROUP:
- return h5py.Group
- elif h5_class == utils.H5Type.DATASET:
- return h5py.Dataset
- elif h5_class == utils.H5Type.SOFT_LINK:
- return h5py.SoftLink
- raise NotImplementedError()
-
- @property
- def parent(self):
- """Returns the parent of the node.
-
- :rtype: Node
- """
- if self.__parent is None:
- parent = None
- else:
- parent = self.__parent()
- if parent is None:
- self.__parent = None
- return parent
-
- def _set_parent(self, parent):
- """Set the parent of this node.
-
- It do not update the parent object.
-
- :param Node parent: New parent for this node
- """
- if parent is not None:
- self.__parent = weakref.ref(parent)
- else:
- self.__parent = None
-
- @property
- def file(self):
- """Returns the file node of this node.
-
- :rtype: Node
- """
- node = self
- while node.parent is not None:
- node = node.parent
- if isinstance(node, File):
- return node
- else:
- return None
-
- @property
- def attrs(self):
- """Returns HDF5 attributes of this node.
-
- :rtype: dict
- """
- if self._is_editable():
- return self.__attrs
- else:
- return _MappingProxyType(self.__attrs)
-
- @property
- def name(self):
- """Returns the HDF5 name of this node.
- """
- parent = self.parent
- if parent is None:
- return "/"
- if parent.name == "/":
- return "/" + self.basename
- return parent.name + "/" + self.basename
-
- @property
- def basename(self):
- """Returns the HDF5 basename of this node.
- """
- return self.__basename
-
- def _is_editable(self):
- """Returns true if the file is editable or if the node is not linked
- to a tree.
-
- :rtype: bool
- """
- f = self.file
- return f is None or f.mode == "w"
-
-
-class Dataset(Node):
- """This class handles a numpy data object, as a mimicry of a
- *h5py.Dataset*.
- """
-
- def __init__(self, name, data, parent=None, attrs=None):
- Node.__init__(self, name, parent, attrs=attrs)
- if data is not None:
- self._check_data(data)
- self.__data = data
-
- def _check_data(self, data):
- """Check that the data provided by the dataset is valid.
-
- It is valid when it can be stored in a HDF5 using h5py.
-
- :param numpy.ndarray data: Data associated to the dataset
- :raises TypeError: In the case the data is not valid.
- """
- if isinstance(data, (six.text_type, six.binary_type)):
- return
-
- chartype = data.dtype.char
- if chartype == "U":
- pass
- elif chartype == "O":
- d = h5py.special_dtype(vlen=data.dtype)
- if d is not None:
- return
- d = h5py.special_dtype(ref=data.dtype)
- if d is not None:
- return
- else:
- return
-
- msg = "Type of the dataset '%s' is not supported. Found '%s'."
- raise TypeError(msg % (self.name, data.dtype))
-
- def _set_data(self, data):
- """Set the data exposed by the dataset.
-
- It have to be called only one time before the data is used. It should
- not be edited after use.
-
- :param numpy.ndarray data: Data associated to the dataset
- """
- self._check_data(data)
- self.__data = data
-
- def _get_data(self):
- """Returns the exposed data
-
- :rtype: numpy.ndarray
- """
- return self.__data
-
- @property
- def h5_class(self):
- """Returns the HDF5 class which is mimicked by this class.
-
- :rtype: H5Type
- """
- return utils.H5Type.DATASET
-
- @property
- def dtype(self):
- """Returns the numpy datatype exposed by this dataset.
-
- :rtype: numpy.dtype
- """
- return self._get_data().dtype
-
- @property
- def shape(self):
- """Returns the shape of the data exposed by this dataset.
-
- :rtype: tuple
- """
- if isinstance(self._get_data(), numpy.ndarray):
- return self._get_data().shape
- else:
- return tuple()
-
- @property
- def size(self):
- """Returns the size of the data exposed by this dataset.
-
- :rtype: int
- """
- if isinstance(self._get_data(), numpy.ndarray):
- return self._get_data().size
- else:
- # It is returned as float64 1.0 by h5py
- return numpy.float64(1.0)
-
- def __len__(self):
- """Returns the size of the data exposed by this dataset.
-
- :rtype: int
- """
- if isinstance(self._get_data(), numpy.ndarray):
- return len(self._get_data())
- else:
- # It is returned as float64 1.0 by h5py
- raise TypeError("Attempt to take len() of scalar dataset")
-
- def __getitem__(self, item):
- """Returns the slice of the data exposed by this dataset.
-
- :rtype: numpy.ndarray
- """
- if not isinstance(self._get_data(), numpy.ndarray):
- if item == Ellipsis:
- return numpy.array(self._get_data())
- elif item == tuple():
- return self._get_data()
- else:
- raise ValueError("Scalar can only be reached with an ellipsis or an empty tuple")
- return self._get_data().__getitem__(item)
-
- def __str__(self):
- basename = self.name.split("/")[-1]
- return '<HDF5-like dataset "%s": shape %s, type "%s">' % \
- (basename, self.shape, self.dtype.str)
-
- def __getslice__(self, i, j):
- """Returns the slice of the data exposed by this dataset.
-
- Deprecated but still in use for python 2.7
-
- :rtype: numpy.ndarray
- """
- return self.__getitem__(slice(i, j, None))
-
- @property
- def value(self):
- """Returns the data exposed by this dataset.
-
- Deprecated by h5py. It is prefered to use indexing `[()]`.
-
- :rtype: numpy.ndarray
- """
- return self._get_data()
-
- @property
- def compression(self):
- """Returns compression as provided by `h5py.Dataset`.
-
- There is no compression."""
- return None
-
- @property
- def compression_opts(self):
- """Returns compression options as provided by `h5py.Dataset`.
-
- There is no compression."""
- return None
-
- @property
- def chunks(self):
- """Returns chunks as provided by `h5py.Dataset`.
-
- There is no chunks."""
- return None
-
- @property
- def is_virtual(self):
- """Checks virtual data as provided by `h5py.Dataset`"""
- return False
-
- def virtual_sources(self):
- """Returns virtual dataset sources as provided by `h5py.Dataset`.
-
- :rtype: list"""
- raise RuntimeError("Not a virtual dataset")
-
- @property
- def external(self):
- """Returns external sources as provided by `h5py.Dataset`.
-
- :rtype: list or None"""
- return None
-
- def __array__(self, dtype=None):
- # Special case for (0,)*-shape datasets
- if numpy.product(self.shape) == 0:
- return self[()]
- else:
- return numpy.array(self[...], dtype=self.dtype if dtype is None else dtype)
-
- def __iter__(self):
- """Iterate over the first axis. TypeError if scalar."""
- if len(self.shape) == 0:
- raise TypeError("Can't iterate over a scalar dataset")
- return self._get_data().__iter__()
-
- # make comparisons and operations on the data
- def __eq__(self, other):
- """When comparing datasets, compare the actual data."""
- if utils.is_dataset(other):
- return self[()] == other[()]
- return self[()] == other
-
- def __add__(self, other):
- return self[()] + other
-
- def __radd__(self, other):
- return other + self[()]
-
- def __sub__(self, other):
- return self[()] - other
-
- def __rsub__(self, other):
- return other - self[()]
-
- def __mul__(self, other):
- return self[()] * other
-
- def __rmul__(self, other):
- return other * self[()]
-
- def __truediv__(self, other):
- return self[()] / other
-
- def __rtruediv__(self, other):
- return other / self[()]
-
- def __floordiv__(self, other):
- return self[()] // other
-
- def __rfloordiv__(self, other):
- return other // self[()]
-
- def __neg__(self):
- return -self[()]
-
- def __abs__(self):
- return abs(self[()])
-
- def __float__(self):
- return float(self[()])
-
- def __int__(self):
- return int(self[()])
-
- def __bool__(self):
- if self[()]:
- return True
- return False
-
- def __nonzero__(self):
- # python 2
- return self.__bool__()
-
- def __ne__(self, other):
- if utils.is_dataset(other):
- return self[()] != other[()]
- else:
- return self[()] != other
-
- def __lt__(self, other):
- if utils.is_dataset(other):
- return self[()] < other[()]
- else:
- return self[()] < other
-
- def __le__(self, other):
- if utils.is_dataset(other):
- return self[()] <= other[()]
- else:
- return self[()] <= other
-
- def __gt__(self, other):
- if utils.is_dataset(other):
- return self[()] > other[()]
- else:
- return self[()] > other
-
- def __ge__(self, other):
- if utils.is_dataset(other):
- return self[()] >= other[()]
- else:
- return self[()] >= other
-
- def __getattr__(self, item):
- """Proxy to underlying numpy array methods.
- """
- data = self._get_data()
- if hasattr(data, item):
- return getattr(data, item)
-
- raise AttributeError("Dataset has no attribute %s" % item)
-
-
-class DatasetProxy(Dataset):
- """Virtual dataset providing content of another dataset"""
-
- def __init__(self, name, target, parent=None):
- Dataset.__init__(self, name, data=None, parent=parent)
- if not utils.is_dataset(target):
- raise TypeError("A Dataset is expected but %s found", target.__class__)
- self.__target = target
-
- @property
- def shape(self):
- return self.__target.shape
-
- @property
- def size(self):
- return self.__target.size
-
- @property
- def dtype(self):
- return self.__target.dtype
-
- def _get_data(self):
- return self.__target[...]
-
- @property
- def attrs(self):
- return self.__target.attrs
-
-
-class _LinkToDataset(Dataset):
- """Virtual dataset providing link to another dataset"""
-
- def __init__(self, name, target, parent=None):
- Dataset.__init__(self, name, data=None, parent=parent)
- self.__target = target
-
- def _get_data(self):
- return self.__target._get_data()
-
- @property
- def attrs(self):
- return self.__target.attrs
-
-
-class LazyLoadableDataset(Dataset):
- """Abstract dataset which provides a lazy loading of the data.
-
- The class has to be inherited and the :meth:`_create_data` method has to be
- implemented to return the numpy data exposed by the dataset. This factory
- method is only called once, when the data is needed.
- """
-
- def __init__(self, name, parent=None, attrs=None):
- super(LazyLoadableDataset, self).__init__(name, None, parent, attrs=attrs)
- self._is_initialized = False
-
- def _create_data(self):
- """
- Factory to create the data exposed by the dataset when it is needed.
-
- It has to be implemented for the class to work.
-
- :rtype: numpy.ndarray
- """
- raise NotImplementedError()
-
- def _get_data(self):
- """Returns the data exposed by the dataset.
-
- Overwrite Dataset method :meth:`_get_data` to implement the lazy
- loading feature.
-
- :rtype: numpy.ndarray
- """
- if not self._is_initialized:
- data = self._create_data()
- # is_initialized before set_data to avoid infinit initialization
- # is case of wrong check of the data
- self._is_initialized = True
- self._set_data(data)
- return super(LazyLoadableDataset, self)._get_data()
-
-
-class SoftLink(Node):
- """This class is a tree node that mimics a *h5py.Softlink*.
-
- In this implementation, the path to the target must be absolute.
- """
- def __init__(self, name, path, parent=None):
- assert str(path).startswith("/") # TODO: h5py also allows a relative path
-
- Node.__init__(self, name, parent)
-
- # attr target defined for spech5 backward compatibility
- self.target = str(path)
-
- @property
- def h5_class(self):
- """Returns the HDF5 class which is mimicked by this class.
-
- :rtype: H5Type
- """
- return utils.H5Type.SOFT_LINK
-
- @property
- def path(self):
- """Soft link value. Not guaranteed to be a valid path."""
- return self.target
-
-
-class Group(Node):
- """This class mimics a `h5py.Group`."""
-
- def __init__(self, name, parent=None, attrs=None):
- Node.__init__(self, name, parent, attrs=attrs)
- self.__items = collections.OrderedDict()
-
- def _get_items(self):
- """Returns the child items as a name-node dictionary.
-
- :rtype: dict
- """
- return self.__items
-
- def add_node(self, node):
- """Add a child to this group.
-
- :param Node node: Child to add to this group
- """
- self._get_items()[node.basename] = node
- node._set_parent(self)
-
- @property
- def h5_class(self):
- """Returns the HDF5 class which is mimicked by this class.
-
- :rtype: H5Type
- """
- return utils.H5Type.GROUP
-
- def _get(self, name, getlink):
- """If getlink is True and name points to an existing SoftLink, this
- SoftLink is returned. In all other situations, we try to return a
- Group or Dataset, or we raise a KeyError if we fail."""
- if "/" not in name:
- result = self._get_items()[name]
- elif name.startswith("/"):
- root = self.file
- if name == "/":
- return root
- result = root._get(name[1:], getlink)
- else:
- path = name.split("/")
- result = self
- for item_name in path:
- if isinstance(result, SoftLink):
- # traverse links
- l_name, l_target = result.name, result.path
- result = result.file.get(l_target)
- if result is None:
- raise KeyError(
- "Unable to open object (broken SoftLink %s -> %s)" %
- (l_name, l_target))
- if not item_name:
- # trailing "/" in name (legal for accessing Groups only)
- if isinstance(result, Group):
- continue
- if not isinstance(result, Group):
- raise KeyError("Unable to open object (Component not found)")
- result = result._get_items()[item_name]
-
- if isinstance(result, SoftLink) and not getlink:
- link = result
- target = result.file.get(link.path)
- if result is None:
- msg = "Unable to open object (broken SoftLink %s -> %s)"
- raise KeyError(msg % (link.name, link.path))
- # Convert SoftLink into typed group/dataset
- if isinstance(target, Group):
- result = _LinkToGroup(name=link.basename, target=target, parent=link.parent)
- elif isinstance(target, Dataset):
- result = _LinkToDataset(name=link.basename, target=target, parent=link.parent)
- else:
- raise TypeError("Unexpected target type %s" % type(target))
-
- return result
-
- def get(self, name, default=None, getclass=False, getlink=False):
- """Retrieve an item or other information.
-
- If getlink only is true, the returned value is always `h5py.HardLink`,
- because this implementation do not use links. Like the original
- implementation.
-
- :param str name: name of the item
- :param object default: default value returned if the name is not found
- :param bool getclass: if true, the returned object is the class of the object found
- :param bool getlink: if true, links object are returned instead of the target
- :return: An object, else None
- :rtype: object
- """
- if name not in self:
- return default
-
- node = self._get(name, getlink=True)
- if isinstance(node, SoftLink) and not getlink:
- # get target
- try:
- node = self._get(name, getlink=False)
- except KeyError:
- return default
- elif not isinstance(node, SoftLink) and getlink:
- # ExternalLink objects don't exist in silx, so it must be a HardLink
- node = h5py.HardLink()
-
- if getclass:
- obj = utils.get_h5py_class(node)
- if obj is None:
- obj = node.__class__
- else:
- obj = node
- return obj
-
- def __setitem__(self, name, obj):
- """Add an object to the group.
-
- :param str name: Location on the group to store the object.
- This path name must not exists.
- :param object obj: Object to store on the file. According to the type,
- the behaviour will not be the same.
-
- - `commonh5.SoftLink`: Create the corresponding link.
- - `numpy.ndarray`: The array is converted to a dataset object.
- - `commonh5.Node`: A hard link should be created pointing to the
- given object. This implementation uses a soft link.
- If the node do not have parent it is connected to the tree
- without using a link (that's a hard link behaviour).
- - other object: Convert first the object with ndarray and then
- store it. ValueError if the resulting array dtype is not
- supported.
- """
- if name in self:
- # From the h5py API
- raise RuntimeError("Unable to create link (name already exists)")
-
- elements = name.rsplit("/", 1)
- if len(elements) == 1:
- parent = self
- basename = elements[0]
- else:
- group_path, basename = elements
- if group_path in self:
- parent = self[group_path]
- else:
- parent = self.create_group(group_path)
-
- if isinstance(obj, SoftLink):
- obj._set_basename(basename)
- node = obj
- elif isinstance(obj, Node):
- if obj.parent is None:
- obj._set_basename(basename)
- node = obj
- else:
- node = SoftLink(basename, obj.name)
- elif isinstance(obj, numpy.dtype):
- node = Dataset(basename, data=obj)
- elif isinstance(obj, numpy.ndarray):
- node = Dataset(basename, data=obj)
- else:
- data = numpy.array(obj)
- try:
- node = Dataset(basename, data=data)
- except TypeError as e:
- raise ValueError(e.args[0])
-
- parent.add_node(node)
-
- def __getitem__(self, name):
- """Return a child from his name.
-
- :param str name: name of a member or a path throug members using '/'
- separator. A '/' as a prefix access to the root item of the tree.
- :rtype: Node
- """
- if name is None or name == "":
- raise ValueError("No name")
- return self._get(name, getlink=False)
-
- def __contains__(self, name):
- """Returns true if name is an existing child of this group.
-
- :rtype: bool
- """
- if "/" not in name:
- return name in self._get_items()
-
- if name.startswith("/"):
- # h5py allows to access any valid full path from any group
- node = self.file
- else:
- node = self
-
- name = name.lstrip("/")
- basenames = name.split("/")
- for basename in basenames:
- if basename.strip() == "":
- # presence of a trailing "/" in name
- # (OK for groups, not for datasets)
- if isinstance(node, SoftLink):
- # traverse links
- node = node.file.get(node.path, getlink=False)
- if node is None:
- # broken link
- return False
- if utils.is_dataset(node):
- return False
- continue
- if basename not in node._get_items():
- return False
- node = node[basename]
-
- return True
-
- def __len__(self):
- """Returns the number of children contained in this group.
-
- :rtype: int
- """
- return len(self._get_items())
-
- def __iter__(self):
- """Iterate over member names"""
- for x in self._get_items().__iter__():
- yield x
-
- if six.PY2:
- def keys(self):
- """Returns a list of the children's names."""
- return self._get_items().keys()
-
- def values(self):
- """Returns a list of the children nodes (groups and datasets).
-
- .. versionadded:: 0.6
- """
- return self._get_items().values()
-
- def items(self):
- """Returns a list of tuples containing (name, node) pairs.
- """
- return self._get_items().items()
-
- else:
- def keys(self):
- """Returns an iterator over the children's names in a group."""
- return self._get_items().keys()
-
- def values(self):
- """Returns an iterator over the children nodes (groups and datasets)
- in a group.
-
- .. versionadded:: 0.6
- """
- return self._get_items().values()
-
- def items(self):
- """Returns items iterator containing name-node mapping.
-
- :rtype: iterator
- """
- return self._get_items().items()
-
- def visit(self, func, visit_links=False):
- """Recursively visit all names in this group and subgroups.
- See the documentation for `h5py.Group.visit` for more help.
-
- :param func: Callable (function, method or callable object)
- :type func: callable
- """
- origin_name = self.name
- return self._visit(func, origin_name, visit_links)
-
- def visititems(self, func, visit_links=False):
- """Recursively visit names and objects in this group.
- See the documentation for `h5py.Group.visititems` for more help.
-
- :param func: Callable (function, method or callable object)
- :type func: callable
- :param bool visit_links: If *False*, ignore links. If *True*,
- call `func(name)` for links and recurse into target groups.
- """
- origin_name = self.name
- return self._visit(func, origin_name, visit_links,
- visititems=True)
-
- def _visit(self, func, origin_name,
- visit_links=False, visititems=False):
- """
-
- :param origin_name: name of first group that initiated the recursion
- This is used to compute the relative path from each item's
- absolute path.
- """
- for member in self.values():
- ret = None
- if not isinstance(member, SoftLink) or visit_links:
- relative_name = member.name[len(origin_name):]
- # remove leading slash and unnecessary trailing slash
- relative_name = relative_name.strip("/")
- if visititems:
- ret = func(relative_name, member)
- else:
- ret = func(relative_name)
- if ret is not None:
- return ret
- if isinstance(member, Group):
- member._visit(func, origin_name, visit_links, visititems)
-
- def create_group(self, name):
- """Create and return a new subgroup.
-
- Name may be absolute or relative. Fails if the target name already
- exists.
-
- :param str name: Name of the new group
- """
- if not self._is_editable():
- raise RuntimeError("File is not editable")
- if name in self:
- raise ValueError("Unable to create group (name already exists)")
-
- if name.startswith("/"):
- name = name[1:]
- return self.file.create_group(name)
-
- elements = name.split('/')
- group = self
- for basename in elements:
- if basename in group:
- group = group[basename]
- if not isinstance(group, Group):
- raise RuntimeError("Unable to create group (group parent is missing")
- else:
- node = Group(basename)
- group.add_node(node)
- group = node
- return group
-
- def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
- """Create and return a sub dataset.
-
- :param str name: Name of the dataset.
- :param shape: Dataset shape. Use "()" for scalar datasets.
- Required if "data" isn't provided.
- :param dtype: Numpy dtype or string.
- If omitted, dtype('f') will be used.
- Required if "data" isn't provided; otherwise, overrides data
- array's dtype.
- :param numpy.ndarray data: Provide data to initialize the dataset.
- If used, you can omit shape and dtype arguments.
- :param kwds: Extra arguments. Nothing yet supported.
- """
- if not self._is_editable():
- raise RuntimeError("File is not editable")
- if len(kwds) > 0:
- raise TypeError("Extra args provided, but nothing supported")
- if "/" in name:
- raise TypeError("Path are not supported")
- if data is None:
- if dtype is None:
- dtype = numpy.float64
- data = numpy.empty(shape=shape, dtype=dtype)
- elif dtype is not None:
- data = data.astype(dtype)
- dataset = Dataset(name, data)
- self.add_node(dataset)
- return dataset
-
-
-class _LinkToGroup(Group):
- """Virtual group providing link to another group"""
-
- def __init__(self, name, target, parent=None):
- Group.__init__(self, name, parent=parent)
- self.__target = target
-
- def _get_items(self):
- return self.__target._get_items()
-
- @property
- def attrs(self):
- return self.__target.attrs
-
-
-class LazyLoadableGroup(Group):
- """Abstract group which provides a lazy loading of the child.
-
- The class has to be inherited and the :meth:`_create_child` method has
- to be implemented to add (:meth:`_add_node`) all children. This factory
- is only called once, when children are needed.
- """
-
- def __init__(self, name, parent=None, attrs=None):
- Group.__init__(self, name, parent, attrs)
- self.__is_initialized = False
-
- def _get_items(self):
- """Returns the internal structure which contains the children.
-
- It overwrite method :meth:`_get_items` to implement the lazy
- loading feature.
-
- :rtype: dict
- """
- if not self.__is_initialized:
- self.__is_initialized = True
- self._create_child()
- return Group._get_items(self)
-
- def _create_child(self):
- """
- Factory to create the child contained by the group when it is needed.
-
- It has to be implemented to work.
- """
- raise NotImplementedError()
-
-
-class File(Group):
- """This class is the special :class:`Group` that is the root node
- of the tree structure. It mimics `h5py.File`."""
-
- def __init__(self, name=None, mode=None, attrs=None):
- """
- Constructor
-
- :param str name: File name if it exists
- :param str mode: Access mode
- - "r": Read-only. Methods :meth:`create_dataset` and
- :meth:`create_group` are locked.
- - "w": File is editable. Methods :meth:`create_dataset` and
- :meth:`create_group` are available.
- :param dict attrs: Default attributes
- """
- Group.__init__(self, name="", parent=None, attrs=attrs)
- self._file_name = name
- if mode is None:
- mode = "r"
- assert(mode in ["r", "w"])
- self._mode = mode
-
- @property
- def filename(self):
- return self._file_name
-
- @property
- def mode(self):
- return self._mode
-
- @property
- def h5_class(self):
- """Returns the :class:`h5py.File` class"""
- return utils.H5Type.FILE
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.close()
-
- def close(self):
- """Close the object, and free up associated resources.
- """
- # should be implemented in subclass
- pass
diff --git a/silx/io/convert.py b/silx/io/convert.py
deleted file mode 100644
index 5b809ba..0000000
--- a/silx/io/convert.py
+++ /dev/null
@@ -1,343 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""This module provides classes and function to convert file formats supported
-by *silx* into HDF5 file. Currently, SPEC file and fabio images are the
-supported formats.
-
-Read the documentation of :mod:`silx.io.spech5` and :mod:`silx.io.fabioh5` for
-information on the structure of the output HDF5 files.
-
-Text strings are written to the HDF5 datasets as variable-length utf-8.
-
-.. warning::
-
- The output format for text strings changed in silx version 0.7.0.
- Prior to that, text was output as fixed-length ASCII.
-
- To be on the safe side, when reading back a HDF5 file written with an
- older version of silx, you can test for the presence of a *decode*
- attribute. To ensure that you always work with unicode text::
-
- >>> import h5py
- >>> h5f = h5py.File("my_scans.h5", "r")
- >>> title = h5f["/68.1/title"]
- >>> if hasattr(title, "decode"):
- ... title = title.decode()
-
-
-.. note:: This module has a dependency on the `h5py <http://www.h5py.org/>`_
- library, which is not a mandatory dependency for `silx`. You might need
- to install it if you don't already have it.
-"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "17/07/2018"
-
-
-import logging
-
-import h5py
-import numpy
-import six
-
-import silx.io
-from silx.io import is_dataset, is_group, is_softlink
-from silx.io import fabioh5
-
-
-_logger = logging.getLogger(__name__)
-
-
-def _create_link(h5f, link_name, target_name,
- link_type="soft", overwrite_data=False):
- """Create a link in a HDF5 file
-
- If member with name ``link_name`` already exists, delete it first or
- ignore link depending on global param ``overwrite_data``.
-
- :param h5f: :class:`h5py.File` object
- :param link_name: Link path
- :param target_name: Handle for target group or dataset
- :param str link_type: "soft" or "hard"
- :param bool overwrite_data: If True, delete existing member (group,
- dataset or link) with the same name. Default is False.
- """
- if link_name not in h5f:
- _logger.debug("Creating link " + link_name + " -> " + target_name)
- elif overwrite_data:
- _logger.warning("Overwriting " + link_name + " with link to " +
- target_name)
- del h5f[link_name]
- else:
- _logger.warning(link_name + " already exist. Cannot create link to " +
- target_name)
- return None
-
- if link_type == "hard":
- h5f[link_name] = h5f[target_name]
- elif link_type == "soft":
- h5f[link_name] = h5py.SoftLink(target_name)
- else:
- raise ValueError("link_type must be 'hard' or 'soft'")
-
-
-def _attr_utf8(attr_value):
- """If attr_value is bytes, make sure we output utf-8
-
- :param attr_value: String (possibly bytes if PY2)
- :return: Attr ready to be written by h5py as utf8
- """
- if isinstance(attr_value, six.binary_type) or \
- isinstance(attr_value, six.text_type):
- out_attr_value = numpy.array(
- attr_value,
- dtype=h5py.special_dtype(vlen=six.text_type))
- else:
- out_attr_value = attr_value
-
- return out_attr_value
-
-
-class Hdf5Writer(object):
- """Converter class to write the content of a data file to a HDF5 file.
- """
- def __init__(self,
- h5path='/',
- overwrite_data=False,
- link_type="soft",
- create_dataset_args=None,
- min_size=500):
- """
-
- :param h5path: Target path where the scan groups will be written
- in the output HDF5 file.
- :param bool overwrite_data:
- See documentation of :func:`write_to_h5`
- :param str link_type: ``"hard"`` or ``"soft"`` (default)
- :param dict create_dataset_args: Dictionary of args you want to pass to
- ``h5py.File.create_dataset``.
- See documentation of :func:`write_to_h5`
- :param int min_size:
- See documentation of :func:`write_to_h5`
- """
- self.h5path = h5path
- if not h5path.startswith("/"):
- # target path must be absolute
- self.h5path = "/" + h5path
- if not self.h5path.endswith("/"):
- self.h5path += "/"
-
- self._h5f = None
- """h5py.File object, assigned in :meth:`write`"""
-
- if create_dataset_args is None:
- create_dataset_args = {}
- self.create_dataset_args = create_dataset_args
-
- self.min_size = min_size
-
- self.overwrite_data = overwrite_data # boolean
-
- self.link_type = link_type
- """'soft' or 'hard' """
-
- self._links = []
- """List of *(link_path, target_path)* tuples."""
-
- def write(self, infile, h5f):
- """Do the conversion from :attr:`sfh5` (Spec file) to *h5f* (HDF5)
-
- All the parameters needed for the conversion have been initialized
- in the constructor.
-
- :param infile: :class:`SpecH5` object
- :param h5f: :class:`h5py.File` instance
- """
- # Recurse through all groups and datasets to add them to the HDF5
- self._h5f = h5f
- infile.visititems(self.append_member_to_h5, visit_links=True)
-
- # Handle the attributes of the root group
- root_grp = h5f[self.h5path]
- for key in infile.attrs:
- if self.overwrite_data or key not in root_grp.attrs:
- root_grp.attrs.create(key,
- _attr_utf8(infile.attrs[key]))
-
- # Handle links at the end, when their targets are created
- for link_name, target_name in self._links:
- _create_link(self._h5f, link_name, target_name,
- link_type=self.link_type,
- overwrite_data=self.overwrite_data)
- self._links = []
-
- def append_member_to_h5(self, h5like_name, obj):
- """Add one group or one dataset to :attr:`h5f`"""
- h5_name = self.h5path + h5like_name.lstrip("/")
- if is_softlink(obj):
- # links to be created after all groups and datasets
- h5_target = self.h5path + obj.path.lstrip("/")
- self._links.append((h5_name, h5_target))
-
- elif is_dataset(obj):
- _logger.debug("Saving dataset: " + h5_name)
-
- member_initially_exists = h5_name in self._h5f
-
- if self.overwrite_data and member_initially_exists:
- _logger.warning("Overwriting dataset: " + h5_name)
- del self._h5f[h5_name]
-
- if self.overwrite_data or not member_initially_exists:
- if isinstance(obj, fabioh5.FrameData) and len(obj.shape) > 2:
- # special case of multiframe data
- # write frame by frame to save memory usage low
- ds = self._h5f.create_dataset(h5_name,
- shape=obj.shape,
- dtype=obj.dtype,
- **self.create_dataset_args)
- for i, frame in enumerate(obj):
- ds[i] = frame
- else:
- # fancy arguments don't apply to small dataset
- if obj.size < self.min_size:
- ds = self._h5f.create_dataset(h5_name, data=obj.value)
- else:
- ds = self._h5f.create_dataset(h5_name, data=obj.value,
- **self.create_dataset_args)
- else:
- ds = self._h5f[h5_name]
-
- # add HDF5 attributes
- for key in obj.attrs:
- if self.overwrite_data or key not in ds.attrs:
- ds.attrs.create(key,
- _attr_utf8(obj.attrs[key]))
-
- if not self.overwrite_data and member_initially_exists:
- _logger.warning("Not overwriting existing dataset: " + h5_name)
-
- elif is_group(obj):
- if h5_name not in self._h5f:
- _logger.debug("Creating group: " + h5_name)
- grp = self._h5f.create_group(h5_name)
- else:
- grp = self._h5f[h5_name]
-
- # add HDF5 attributes
- for key in obj.attrs:
- if self.overwrite_data or key not in grp.attrs:
- grp.attrs.create(key,
- _attr_utf8(obj.attrs[key]))
-
-
-def _is_commonh5_group(grp):
- """Return True if grp is a commonh5 group.
- (h5py.Group objects are not commonh5 groups)"""
- return is_group(grp) and not isinstance(grp, h5py.Group)
-
-
-def write_to_h5(infile, h5file, h5path='/', mode="a",
- overwrite_data=False, link_type="soft",
- create_dataset_args=None, min_size=500):
- """Write content of a h5py-like object into a HDF5 file.
-
- :param infile: Path of input file, or :class:`commonh5.File` object
- or :class:`commonh5.Group` object.
- :param h5file: Path of output HDF5 file or HDF5 file handle
- (`h5py.File` object)
- :param str h5path: Target path in HDF5 file in which scan groups are created.
- Default is root (``"/"``)
- :param str mode: Can be ``"r+"`` (read/write, file must exist),
- ``"w"`` (write, existing file is lost), ``"w-"`` (write, fail
- if exists) or ``"a"`` (read/write if exists, create otherwise).
- This parameter is ignored if ``h5file`` is a file handle.
- :param bool overwrite_data: If ``True``, existing groups and datasets can be
- overwritten, if ``False`` they are skipped. This parameter is only
- relevant if ``file_mode`` is ``"r+"`` or ``"a"``.
- :param str link_type: *"soft"* (default) or *"hard"*
- :param dict create_dataset_args: Dictionary of args you want to pass to
- ``h5py.File.create_dataset``. This allows you to specify filters and
- compression parameters. Don't specify ``name`` and ``data``.
- These arguments are only applied to datasets larger than 1MB.
- :param int min_size: Minimum number of elements in a dataset to apply
- chunking and compression. Default is 500.
-
- The structure of the spec data in an HDF5 file is described in the
- documentation of :mod:`silx.io.spech5`.
- """
- writer = Hdf5Writer(h5path=h5path,
- overwrite_data=overwrite_data,
- link_type=link_type,
- create_dataset_args=create_dataset_args,
- min_size=min_size)
-
- # both infile and h5file can be either file handle or a file name: 4 cases
- if not isinstance(h5file, h5py.File) and not is_group(infile):
- with silx.io.open(infile) as h5pylike:
- if not _is_commonh5_group(h5pylike):
- raise IOError("Cannot convert HDF5 file %s to HDF5" % infile)
- with h5py.File(h5file, mode) as h5f:
- writer.write(h5pylike, h5f)
- elif isinstance(h5file, h5py.File) and not is_group(infile):
- with silx.io.open(infile) as h5pylike:
- if not _is_commonh5_group(h5pylike):
- raise IOError("Cannot convert HDF5 file %s to HDF5" % infile)
- writer.write(h5pylike, h5file)
- elif is_group(infile) and not isinstance(h5file, h5py.File):
- if not _is_commonh5_group(infile):
- raise IOError("Cannot convert HDF5 file %s to HDF5" % infile.file.name)
- with h5py.File(h5file, mode) as h5f:
- writer.write(infile, h5f)
- else:
- if not _is_commonh5_group(infile):
- raise IOError("Cannot convert HDF5 file %s to HDF5" % infile.file.name)
- writer.write(infile, h5file)
-
-
-def convert(infile, h5file, mode="w-", create_dataset_args=None):
- """Convert a supported file into an HDF5 file, write scans into the
- root group (``/``).
-
- This is a convenience shortcut to call::
-
- write_to_h5(h5like, h5file, h5path='/',
- mode="w-", link_type="soft")
-
- :param infile: Path of input file or :class:`commonh5.File` object
- or :class:`commonh5.Group` object
- :param h5file: Path of output HDF5 file, or h5py.File object
- :param mode: Can be ``"w"`` (write, existing file is
- lost), ``"w-"`` (write, fail if exists). This is ignored
- if ``h5file`` is a file handle.
- :param create_dataset_args: Dictionary of args you want to pass to
- ``h5py.File.create_dataset``. This allows you to specify filters and
- compression parameters. Don't specify ``name`` and ``data``.
- """
- if mode not in ["w", "w-"]:
- raise IOError("File mode must be 'w' or 'w-'. Use write_to_h5" +
- " to append data to an existing HDF5 file.")
- write_to_h5(infile, h5file, h5path='/', mode=mode,
- create_dataset_args=create_dataset_args)
diff --git a/silx/io/dictdump.py b/silx/io/dictdump.py
deleted file mode 100644
index e907668..0000000
--- a/silx/io/dictdump.py
+++ /dev/null
@@ -1,842 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""This module offers a set of functions to dump a python dictionary indexed
-by text strings to following file formats: `HDF5, INI, JSON`
-"""
-
-from collections import OrderedDict
-from collections.abc import Mapping
-import json
-import logging
-import numpy
-import os.path
-import sys
-import h5py
-
-from .configdict import ConfigDict
-from .utils import is_group
-from .utils import is_dataset
-from .utils import is_link
-from .utils import is_softlink
-from .utils import is_externallink
-from .utils import is_file as is_h5_file_like
-from .utils import open as h5open
-from .utils import h5py_read_dataset
-from .utils import H5pyAttributesReadWrapper
-from silx.utils.deprecation import deprecated_warning
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "17/07/2018"
-
-logger = logging.getLogger(__name__)
-
-vlen_utf8 = h5py.special_dtype(vlen=str)
-vlen_bytes = h5py.special_dtype(vlen=bytes)
-
-
-def _prepare_hdf5_write_value(array_like):
- """Cast a python object into a numpy array in a HDF5 friendly format.
-
- :param array_like: Input dataset in a type that can be digested by
- ``numpy.array()`` (`str`, `list`, `numpy.ndarray`…)
- :return: ``numpy.ndarray`` ready to be written as an HDF5 dataset
- """
- array = numpy.asarray(array_like)
- if numpy.issubdtype(array.dtype, numpy.bytes_):
- return numpy.array(array_like, dtype=vlen_bytes)
- elif numpy.issubdtype(array.dtype, numpy.str_):
- return numpy.array(array_like, dtype=vlen_utf8)
- else:
- return array
-
-
-class _SafeH5FileWrite:
- """Context manager returning a :class:`h5py.File` object.
-
- If this object is initialized with a file path, we open the file
- and then we close it on exiting.
-
- If a :class:`h5py.File` instance is provided to :meth:`__init__` rather
- than a path, we assume that the user is responsible for closing the
- file.
-
- This behavior is well suited for handling h5py file in a recursive
- function. The object is created in the initial call if a path is provided,
- and it is closed only at the end when all the processing is finished.
- """
- def __init__(self, h5file, mode="w"):
- """
- :param h5file: HDF5 file path or :class:`h5py.File` instance
- :param str mode: Can be ``"r+"`` (read/write, file must exist),
- ``"w"`` (write, existing file is lost), ``"w-"`` (write, fail if
- exists) or ``"a"`` (read/write if exists, create otherwise).
- This parameter is ignored if ``h5file`` is a file handle.
- """
- self.raw_h5file = h5file
- self.mode = mode
-
- def __enter__(self):
- if not isinstance(self.raw_h5file, h5py.File):
- self.h5file = h5py.File(self.raw_h5file, self.mode)
- self.close_when_finished = True
- else:
- self.h5file = self.raw_h5file
- self.close_when_finished = False
- return self.h5file
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self.close_when_finished:
- self.h5file.close()
-
-
-class _SafeH5FileRead:
- """Context manager returning a :class:`h5py.File` or a
- :class:`silx.io.spech5.SpecH5` or a :class:`silx.io.fabioh5.File` object.
-
- The general behavior is the same as :class:`_SafeH5FileWrite` except
- that SPEC files and all formats supported by fabio can also be opened,
- but in read-only mode.
- """
- def __init__(self, h5file):
- """
-
- :param h5file: HDF5 file path or h5py.File-like object
- """
- self.raw_h5file = h5file
-
- def __enter__(self):
- if not is_h5_file_like(self.raw_h5file):
- self.h5file = h5open(self.raw_h5file)
- self.close_when_finished = True
- else:
- self.h5file = self.raw_h5file
- self.close_when_finished = False
-
- return self.h5file
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self.close_when_finished:
- self.h5file.close()
-
-
-def _normalize_h5_path(h5root, h5path):
- """
- :param h5root: File name or h5py-like File, Group or Dataset
- :param str h5path: relative to ``h5root``
- :returns 2-tuple: (File or file object, h5path)
- """
- if is_group(h5root):
- group_name = h5root.name
- if group_name == "/":
- pass
- elif h5path:
- h5path = group_name + "/" + h5path
- else:
- h5path = group_name
- h5file = h5root.file
- elif is_dataset(h5root):
- h5path = h5root.name
- h5file = h5root.file
- else:
- h5file = h5root
- if not h5path:
- h5path = "/"
- elif not h5path.endswith("/"):
- h5path += "/"
- return h5file, h5path
-
-
-def dicttoh5(treedict, h5file, h5path='/',
- mode="w", overwrite_data=None,
- create_dataset_args=None, update_mode=None):
- """Write a nested dictionary to a HDF5 file, using keys as member names.
-
- If a dictionary value is a sub-dictionary, a group is created. If it is
- any other data type, it is cast into a numpy array and written as a
- :mod:`h5py` dataset. Dictionary keys must be strings and cannot contain
- the ``/`` character.
-
- If dictionary keys are tuples they are interpreted to set h5 attributes.
- The tuples should have the format (dataset_name, attr_name).
-
- Existing HDF5 items can be deleted by providing the dictionary value
- ``None``, provided that ``update_mode in ["modify", "replace"]``.
-
- .. note::
-
- This function requires `h5py <http://www.h5py.org/>`_ to be installed.
-
- :param treedict: Nested dictionary/tree structure with strings or tuples as
- keys and array-like objects as leafs. The ``"/"`` character can be used
- to define sub trees. If tuples are used as keys they should have the
- format (dataset_name,attr_name) and will add a 5h attribute with the
- corresponding value.
- :param h5file: File name or h5py-like File, Group or Dataset
- :param h5path: Target path in the HDF5 file relative to ``h5file``.
- Default is root (``"/"``)
- :param mode: Can be ``"r+"`` (read/write, file must exist),
- ``"w"`` (write, existing file is lost), ``"w-"`` (write, fail if
- exists) or ``"a"`` (read/write if exists, create otherwise).
- This parameter is ignored if ``h5file`` is a file handle.
- :param overwrite_data: Deprecated. ``True`` is approximately equivalent
- to ``update_mode="modify"`` and ``False`` is equivalent to
- ``update_mode="add"``.
- :param create_dataset_args: Dictionary of args you want to pass to
- ``h5f.create_dataset``. This allows you to specify filters and
- compression parameters. Don't specify ``name`` and ``data``.
- :param update_mode: Can be ``add`` (default), ``modify`` or ``replace``.
-
- * ``add``: Extend the existing HDF5 tree when possible. Existing HDF5
- items (groups, datasets and attributes) remain untouched.
- * ``modify``: Extend the existing HDF5 tree when possible, modify
- existing attributes, modify same-sized dataset values and delete
- HDF5 items with a ``None`` value in the dict tree.
- * ``replace``: Replace the existing HDF5 tree. Items from the root of
- the HDF5 tree that are not present in the root of the dict tree
- will remain untouched.
-
- Example::
-
- from silx.io.dictdump import dicttoh5
-
- city_area = {
- "Europe": {
- "France": {
- "Isère": {
- "Grenoble": 18.44,
- ("Grenoble","unit"): "km2"
- },
- "Nord": {
- "Tourcoing": 15.19,
- ("Tourcoing","unit"): "km2"
- },
- },
- },
- }
-
- create_ds_args = {'compression': "gzip",
- 'shuffle': True,
- 'fletcher32': True}
-
- dicttoh5(city_area, "cities.h5", h5path="/area",
- create_dataset_args=create_ds_args)
- """
-
- if overwrite_data is not None:
- reason = (
- "`overwrite_data=True` becomes `update_mode='modify'` and "
- "`overwrite_data=False` becomes `update_mode='add'`"
- )
- deprecated_warning(
- type_="argument",
- name="overwrite_data",
- reason=reason,
- replacement="update_mode",
- since_version="0.15",
- )
-
- if update_mode is None:
- if overwrite_data:
- update_mode = "modify"
- else:
- update_mode = "add"
- else:
- valid_existing_values = ("add", "replace", "modify")
- if update_mode not in valid_existing_values:
- raise ValueError((
- "Argument 'update_mode' can only have values: {}"
- "".format(valid_existing_values)
- ))
- if overwrite_data is not None:
- logger.warning("The argument `overwrite_data` is ignored")
-
- if not isinstance(treedict, Mapping):
- raise TypeError("'treedict' must be a dictionary")
-
- h5file, h5path = _normalize_h5_path(h5file, h5path)
-
- def _iter_treedict(attributes=False):
- nonlocal treedict
- for key, value in treedict.items():
- if isinstance(key, tuple) == attributes:
- yield key, value
-
- change_allowed = update_mode in ("replace", "modify")
-
- with _SafeH5FileWrite(h5file, mode=mode) as h5f:
- # Create the root of the tree
- if h5path in h5f:
- if not is_group(h5f[h5path]):
- if update_mode == "replace":
- del h5f[h5path]
- h5f.create_group(h5path)
- else:
- return
- else:
- h5f.create_group(h5path)
-
- # Loop over all groups, links and datasets
- for key, value in _iter_treedict(attributes=False):
- h5name = h5path + key
- exists = h5name in h5f
-
- if value is None:
- # Delete HDF5 item
- if exists and change_allowed:
- del h5f[h5name]
- exists = False
- elif isinstance(value, Mapping):
- # HDF5 group
- if exists and update_mode == "replace":
- del h5f[h5name]
- exists = False
- if value:
- dicttoh5(value, h5f, h5name,
- update_mode=update_mode,
- create_dataset_args=create_dataset_args)
- elif not exists:
- h5f.create_group(h5name)
- elif is_link(value):
- # HDF5 link
- if exists and update_mode == "replace":
- del h5f[h5name]
- exists = False
- if not exists:
- # Create link from h5py link object
- h5f[h5name] = value
- else:
- # HDF5 dataset
- if exists and not change_allowed:
- continue
- data = _prepare_hdf5_write_value(value)
-
- # Edit the existing dataset
- attrs_backup = None
- if exists:
- try:
- h5f[h5name][()] = data
- continue
- except Exception:
- # Delete the existing dataset
- if update_mode != "replace":
- if not is_dataset(h5f[h5name]):
- continue
- attrs_backup = dict(h5f[h5name].attrs)
- del h5f[h5name]
-
- # Create dataset
- # can't apply filters on scalars (datasets with shape == ())
- if data.shape == () or create_dataset_args is None:
- h5f.create_dataset(h5name,
- data=data)
- else:
- h5f.create_dataset(h5name,
- data=data,
- **create_dataset_args)
- if attrs_backup:
- h5f[h5name].attrs.update(attrs_backup)
-
- # Loop over all attributes
- for key, value in _iter_treedict(attributes=True):
- if len(key) != 2:
- raise ValueError("HDF5 attribute must be described by 2 values")
- h5name = h5path + key[0]
- attr_name = key[1]
-
- if h5name not in h5f:
- # Create an empty group to store the attribute
- h5f.create_group(h5name)
-
- h5a = h5f[h5name].attrs
- exists = attr_name in h5a
-
- if value is None:
- # Delete HDF5 attribute
- if exists and change_allowed:
- del h5a[attr_name]
- exists = False
- else:
- # Add/modify HDF5 attribute
- if exists and not change_allowed:
- continue
- data = _prepare_hdf5_write_value(value)
- h5a[attr_name] = data
-
-
-def _has_nx_class(treedict, key=""):
- return key + "@NX_class" in treedict or \
- (key, "NX_class") in treedict
-
-
-def _ensure_nx_class(treedict, parents=tuple()):
- """Each group needs an "NX_class" attribute.
- """
- if _has_nx_class(treedict):
- return
- nparents = len(parents)
- if nparents == 0:
- treedict[("", "NX_class")] = "NXroot"
- elif nparents == 1:
- treedict[("", "NX_class")] = "NXentry"
- else:
- treedict[("", "NX_class")] = "NXcollection"
-
-
-def nexus_to_h5_dict(
- treedict, parents=tuple(), add_nx_class=True, has_nx_class=False
-):
- """The following conversions are applied:
- * key with "{name}@{attr_name}" notation: key converted to 2-tuple
- * key with ">{url}" notation: strip ">" and convert value to
- h5py.SoftLink or h5py.ExternalLink
-
- :param treedict: Nested dictionary/tree structure with strings as keys
- and array-like objects as leafs. The ``"/"`` character can be used
- to define sub tree. The ``"@"`` character is used to write attributes.
- The ``">"`` prefix is used to define links.
- :param parents: Needed to resolve up-links (tuple of HDF5 group names)
- :param add_nx_class: Add "NX_class" attribute when missing
- :param has_nx_class: The "NX_class" attribute is defined in the parent
-
- :rtype dict:
- """
- if not isinstance(treedict, Mapping):
- raise TypeError("'treedict' must be a dictionary")
- copy = dict()
- for key, value in treedict.items():
- if "@" in key:
- # HDF5 attribute
- key = tuple(key.rsplit("@", 1))
- elif key.startswith(">"):
- # HDF5 link
- if isinstance(value, str):
- key = key[1:]
- first, sep, second = value.partition("::")
- if sep:
- value = h5py.ExternalLink(first, second)
- else:
- if ".." in first:
- # Up-links not supported: make absolute
- parts = []
- for p in list(parents) + first.split("/"):
- if not p or p == ".":
- continue
- elif p == "..":
- parts.pop(-1)
- else:
- parts.append(p)
- first = "/" + "/".join(parts)
- value = h5py.SoftLink(first)
- elif is_link(value):
- key = key[1:]
- if isinstance(value, Mapping):
- # HDF5 group
- key_has_nx_class = add_nx_class and _has_nx_class(treedict, key)
- copy[key] = nexus_to_h5_dict(
- value,
- parents=parents+(key,),
- add_nx_class=add_nx_class,
- has_nx_class=key_has_nx_class)
- else:
- # HDF5 dataset or link
- copy[key] = value
- if add_nx_class and not has_nx_class:
- _ensure_nx_class(copy, parents)
- return copy
-
-
-def h5_to_nexus_dict(treedict):
- """The following conversions are applied:
- * 2-tuple key: converted to string ("@" notation)
- * h5py.Softlink value: converted to string (">" key prefix)
- * h5py.ExternalLink value: converted to string (">" key prefix)
-
- :param treedict: Nested dictionary/tree structure with strings as keys
- and array-like objects as leafs. The ``"/"`` character can be used
- to define sub tree.
-
- :rtype dict:
- """
- copy = dict()
- for key, value in treedict.items():
- if isinstance(key, tuple):
- if len(key) != 2:
- raise ValueError("HDF5 attribute must be described by 2 values")
- key = "%s@%s" % (key[0], key[1])
- elif is_softlink(value):
- key = ">" + key
- value = value.path
- elif is_externallink(value):
- key = ">" + key
- value = value.filename + "::" + value.path
- if isinstance(value, Mapping):
- copy[key] = h5_to_nexus_dict(value)
- else:
- copy[key] = value
- return copy
-
-
-def _name_contains_string_in_list(name, strlist):
- if strlist is None:
- return False
- for filter_str in strlist:
- if filter_str in name:
- return True
- return False
-
-
-def _handle_error(mode: str, exception, msg: str, *args) -> None:
- """Handle errors.
-
- :param str mode: 'raise', 'log', 'ignore'
- :param type exception: Exception class to use in 'raise' mode
- :param str msg: Error message template
- :param List[str] args: Arguments for error message template
- """
- if mode == 'ignore':
- return # no-op
- elif mode == 'log':
- logger.error(msg, *args)
- elif mode == 'raise':
- raise exception(msg % args)
- else:
- raise ValueError("Unsupported error handling: %s" % mode)
-
-
-def h5todict(h5file,
- path="/",
- exclude_names=None,
- asarray=True,
- dereference_links=True,
- include_attributes=False,
- errors='raise'):
- """Read a HDF5 file and return a nested dictionary with the complete file
- structure and all data.
-
- Example of usage::
-
- from silx.io.dictdump import h5todict
-
- # initialize dict with file header and scan header
- header94 = h5todict("oleg.dat",
- "/94.1/instrument/specfile")
- # add positioners subdict
- header94["positioners"] = h5todict("oleg.dat",
- "/94.1/instrument/positioners")
- # add scan data without mca data
- header94["detector data"] = h5todict("oleg.dat",
- "/94.1/measurement",
- exclude_names="mca_")
-
-
- .. note:: This function requires `h5py <http://www.h5py.org/>`_ to be
- installed.
-
- .. note:: If you write a dictionary to a HDF5 file with
- :func:`dicttoh5` and then read it back with :func:`h5todict`, data
- types are not preserved. All values are cast to numpy arrays before
- being written to file, and they are read back as numpy arrays (or
- scalars). In some cases, you may find that a list of heterogeneous
- data types is converted to a numpy array of strings.
-
- :param h5file: File name or h5py-like File, Group or Dataset
- :param str path: Target path in the HDF5 file relative to ``h5file``
- :param List[str] exclude_names: Groups and datasets whose name contains
- a string in this list will be ignored. Default is None (ignore nothing)
- :param bool asarray: True (default) to read scalar as arrays, False to
- read them as scalar
- :param bool dereference_links: True (default) to dereference links, False
- to preserve the link itself
- :param bool include_attributes: False (default)
- :param str errors: Handling of errors (HDF5 access issue, broken link,...):
- - 'raise' (default): Raise an exception
- - 'log': Log as errors
- - 'ignore': Ignore errors
- :return: Nested dictionary
- """
- h5file, path = _normalize_h5_path(h5file, path)
- with _SafeH5FileRead(h5file) as h5f:
- ddict = {}
- if path not in h5f:
- _handle_error(
- errors, KeyError, 'Path "%s" does not exist in file.', path)
- return ddict
-
- try:
- root = h5f[path]
- except KeyError as e:
- if not isinstance(h5f.get(path, getlink=True), h5py.HardLink):
- _handle_error(errors,
- KeyError,
- 'Cannot retrieve path "%s" (broken link)',
- path)
- else:
- _handle_error(errors, KeyError, ', '.join(e.args))
- return ddict
-
- # Read the attributes of the group
- if include_attributes:
- attrs = H5pyAttributesReadWrapper(root.attrs)
- for aname, avalue in attrs.items():
- ddict[("", aname)] = avalue
- # Read the children of the group
- for key in root:
- if _name_contains_string_in_list(key, exclude_names):
- continue
- h5name = path + "/" + key
- # Preserve HDF5 link when requested
- if not dereference_links:
- lnk = h5f.get(h5name, getlink=True)
- if is_link(lnk):
- ddict[key] = lnk
- continue
-
- try:
- h5obj = h5f[h5name]
- except KeyError as e:
- if not isinstance(h5f.get(h5name, getlink=True), h5py.HardLink):
- _handle_error(errors,
- KeyError,
- 'Cannot retrieve path "%s" (broken link)',
- h5name)
- else:
- _handle_error(errors, KeyError, ', '.join(e.args))
- continue
-
- if is_group(h5obj):
- # Child is an HDF5 group
- ddict[key] = h5todict(h5f,
- h5name,
- exclude_names=exclude_names,
- asarray=asarray,
- dereference_links=dereference_links,
- include_attributes=include_attributes)
- else:
- # Child is an HDF5 dataset
- try:
- data = h5py_read_dataset(h5obj)
- except OSError:
- _handle_error(errors,
- OSError,
- 'Cannot retrieve dataset "%s"',
- h5name)
- else:
- if asarray: # Convert HDF5 dataset to numpy array
- data = numpy.array(data, copy=False)
- ddict[key] = data
- # Read the attributes of the child
- if include_attributes:
- attrs = H5pyAttributesReadWrapper(h5obj.attrs)
- for aname, avalue in attrs.items():
- ddict[(key, aname)] = avalue
- return ddict
-
-
-def dicttonx(treedict, h5file, h5path="/", add_nx_class=None, **kw):
- """
- Write a nested dictionary to a HDF5 file, using string keys as member names.
- The NeXus convention is used to identify attributes with ``"@"`` character,
- therefore the dataset_names should not contain ``"@"``.
-
- Similarly, links are identified by keys starting with the ``">"`` character.
- The corresponding value can be a soft or external link.
-
- :param treedict: Nested dictionary/tree structure with strings as keys
- and array-like objects as leafs. The ``"/"`` character can be used
- to define sub tree. The ``"@"`` character is used to write attributes.
- The ``">"`` prefix is used to define links.
- :param add_nx_class: Add "NX_class" attribute when missing. By default it
- is ``True`` when ``update_mode`` is ``"add"`` or ``None``.
-
- The named parameters are passed to dicttoh5.
-
- Example::
-
- import numpy
- from silx.io.dictdump import dicttonx
-
- gauss = {
- "entry":{
- "title":u"A plot of a gaussian",
- "instrument": {
- "@NX_class": u"NXinstrument",
- "positioners": {
- "@NX_class": u"NXCollection",
- "x": numpy.arange(0,1.1,.1)
- }
- }
- "plot": {
- "y": numpy.array([0.08, 0.19, 0.39, 0.66, 0.9, 1.,
- 0.9, 0.66, 0.39, 0.19, 0.08]),
- ">x": "../instrument/positioners/x",
- "@signal": "y",
- "@axes": "x",
- "@NX_class":u"NXdata",
- "title:u"Gauss Plot",
- },
- "@NX_class": u"NXentry",
- "default":"plot",
- }
- "@NX_class": u"NXroot",
- "@default": "entry",
- }
-
- dicttonx(gauss,"test.h5")
- """
- h5file, h5path = _normalize_h5_path(h5file, h5path)
- parents = tuple(p for p in h5path.split("/") if p)
- if add_nx_class is None:
- add_nx_class = kw.get("update_mode", None) in (None, "add")
- nxtreedict = nexus_to_h5_dict(
- treedict, parents=parents, add_nx_class=add_nx_class
- )
- dicttoh5(nxtreedict, h5file, h5path=h5path, **kw)
-
-
-def nxtodict(h5file, include_attributes=True, **kw):
- """Read a HDF5 file and return a nested dictionary with the complete file
- structure and all data.
-
- As opposed to h5todict, all keys will be strings and no h5py objects are
- present in the tree.
-
- The named parameters are passed to h5todict.
- """
- nxtreedict = h5todict(h5file, include_attributes=include_attributes, **kw)
- return h5_to_nexus_dict(nxtreedict)
-
-
-def dicttojson(ddict, jsonfile, indent=None, mode="w"):
- """Serialize ``ddict`` as a JSON formatted stream to ``jsonfile``.
-
- :param ddict: Dictionary (or any object compatible with ``json.dump``).
- :param jsonfile: JSON file name or file-like object.
- If a file name is provided, the function opens the file in the
- specified mode and closes it again.
- :param indent: If indent is a non-negative integer, then JSON array
- elements and object members will be pretty-printed with that indent
- level. An indent level of ``0`` will only insert newlines.
- ``None`` (the default) selects the most compact representation.
- :param mode: File opening mode (``w``, ``a``, ``w+``…)
- """
- if not hasattr(jsonfile, "write"):
- jsonf = open(jsonfile, mode)
- else:
- jsonf = jsonfile
-
- json.dump(ddict, jsonf, indent=indent)
-
- if not hasattr(jsonfile, "write"):
- jsonf.close()
-
-
-def dicttoini(ddict, inifile, mode="w"):
- """Output dict as configuration file (similar to Microsoft Windows INI).
-
- :param dict: Dictionary of configuration parameters
- :param inifile: INI file name or file-like object.
- If a file name is provided, the function opens the file in the
- specified mode and closes it again.
- :param mode: File opening mode (``w``, ``a``, ``w+``…)
- """
- if not hasattr(inifile, "write"):
- inif = open(inifile, mode)
- else:
- inif = inifile
-
- ConfigDict(initdict=ddict).write(inif)
-
- if not hasattr(inifile, "write"):
- inif.close()
-
-
-def dump(ddict, ffile, mode="w", fmat=None):
- """Dump dictionary to a file
-
- :param ddict: Dictionary with string keys
- :param ffile: File name or file-like object with a ``write`` method
- :param str fmat: Output format: ``"json"``, ``"hdf5"`` or ``"ini"``.
- When None (the default), it uses the filename extension as the format.
- Dumping to a HDF5 file requires `h5py <http://www.h5py.org/>`_ to be
- installed.
- :param str mode: File opening mode (``w``, ``a``, ``w+``…)
- Default is *"w"*, write mode, overwrite if exists.
- :raises IOError: if file format is not supported
- """
- if fmat is None:
- # If file-like object get its name, else use ffile as filename
- filename = getattr(ffile, 'name', ffile)
- fmat = os.path.splitext(filename)[1][1:] # Strip extension leading '.'
- fmat = fmat.lower()
-
- if fmat == "json":
- dicttojson(ddict, ffile, indent=2, mode=mode)
- elif fmat in ["hdf5", "h5"]:
- dicttoh5(ddict, ffile, mode=mode)
- elif fmat in ["ini", "cfg"]:
- dicttoini(ddict, ffile, mode=mode)
- else:
- raise IOError("Unknown format " + fmat)
-
-
-def load(ffile, fmat=None):
- """Load dictionary from a file
-
- When loading from a JSON or INI file, an OrderedDict is returned to
- preserve the values' insertion order.
-
- :param ffile: File name or file-like object with a ``read`` method
- :param fmat: Input format: ``json``, ``hdf5`` or ``ini``.
- When None (the default), it uses the filename extension as the format.
- Loading from a HDF5 file requires `h5py <http://www.h5py.org/>`_ to be
- installed.
- :return: Dictionary (ordered dictionary for JSON and INI)
- :raises IOError: if file format is not supported
- """
- must_be_closed = False
- if not hasattr(ffile, "read"):
- f = open(ffile, "r")
- fname = ffile
- must_be_closed = True
- else:
- f = ffile
- fname = ffile.name
-
- try:
- if fmat is None: # Use file extension as format
- fmat = os.path.splitext(fname)[1][1:] # Strip extension leading '.'
- fmat = fmat.lower()
-
- if fmat == "json":
- return json.load(f, object_pairs_hook=OrderedDict)
- if fmat in ["hdf5", "h5"]:
- return h5todict(fname)
- elif fmat in ["ini", "cfg"]:
- return ConfigDict(filelist=[fname])
- else:
- raise IOError("Unknown format " + fmat)
- finally:
- if must_be_closed:
- f.close()
diff --git a/silx/io/fabioh5.py b/silx/io/fabioh5.py
deleted file mode 100755
index 2fd719d..0000000
--- a/silx/io/fabioh5.py
+++ /dev/null
@@ -1,1051 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""This module provides functions to read fabio images as an HDF5 file.
-
- >>> import silx.io.fabioh5
- >>> f = silx.io.fabioh5.File("foobar.edf")
-
-.. note:: This module has a dependency on the `h5py <http://www.h5py.org/>`_
- and `fabio <https://github.com/silx-kit/fabio>`_ libraries,
- which are not mandatory dependencies for `silx`.
-
-"""
-
-import collections
-import datetime
-import logging
-import numbers
-import os
-
-import fabio.file_series
-import numpy
-import six
-
-from . import commonh5
-from silx import version as silx_version
-import silx.utils.number
-import h5py
-
-
-_logger = logging.getLogger(__name__)
-
-
-_fabio_extensions = set([])
-
-
-def supported_extensions():
- """Returns all extensions supported by fabio.
-
- :returns: A set containing extensions like "*.edf".
- :rtype: Set[str]
- """
- global _fabio_extensions
- if len(_fabio_extensions) > 0:
- return _fabio_extensions
-
- formats = fabio.fabioformats.get_classes(reader=True)
- all_extensions = set([])
-
- for reader in formats:
- if not hasattr(reader, "DEFAULT_EXTENSIONS"):
- continue
-
- ext = reader.DEFAULT_EXTENSIONS
- ext = ["*.%s" % e for e in ext]
- all_extensions.update(ext)
-
- _fabio_extensions = set(all_extensions)
- return _fabio_extensions
-
-
-class _FileSeries(fabio.file_series.file_series):
- """
- .. note:: Overwrite a function to fix an issue in fabio.
- """
- def jump(self, num):
- """
- Goto a position in sequence
- """
- assert num < len(self) and num >= 0, "num out of range"
- self._current = num
- return self[self._current]
-
-
-class FrameData(commonh5.LazyLoadableDataset):
- """Expose a cube of image from a Fabio file using `FabioReader` as
- cache."""
-
- def __init__(self, name, fabio_reader, parent=None):
- if fabio_reader.is_spectrum():
- attrs = {"interpretation": "spectrum"}
- else:
- attrs = {"interpretation": "image"}
- commonh5.LazyLoadableDataset.__init__(self, name, parent, attrs=attrs)
- self.__fabio_reader = fabio_reader
- self._shape = None
- self._dtype = None
-
- def _create_data(self):
- return self.__fabio_reader.get_data()
-
- def _update_cache(self):
- if isinstance(self.__fabio_reader.fabio_file(),
- fabio.file_series.file_series):
- # Reading all the files is taking too much time
- # Reach the information from the only first frame
- first_image = self.__fabio_reader.fabio_file().first_image()
- self._dtype = first_image.data.dtype
- shape0 = self.__fabio_reader.frame_count()
- shape1, shape2 = first_image.data.shape
- self._shape = shape0, shape1, shape2
- else:
- self._dtype = super(commonh5.LazyLoadableDataset, self).dtype
- self._shape = super(commonh5.LazyLoadableDataset, self).shape
-
- @property
- def dtype(self):
- if self._dtype is None:
- self._update_cache()
- return self._dtype
-
- @property
- def shape(self):
- if self._shape is None:
- self._update_cache()
- return self._shape
-
- def __iter__(self):
- for frame in self.__fabio_reader.iter_frames():
- yield frame.data
-
- def __getitem__(self, item):
- # optimization for fetching a single frame if data not already loaded
- if not self._is_initialized:
- if isinstance(item, six.integer_types) and \
- isinstance(self.__fabio_reader.fabio_file(),
- fabio.file_series.file_series):
- if item < 0:
- # negative indexing
- item += len(self)
- return self.__fabio_reader.fabio_file().jump_image(item).data
- return super(FrameData, self).__getitem__(item)
-
-
-class RawHeaderData(commonh5.LazyLoadableDataset):
- """Lazy loadable raw header"""
-
- def __init__(self, name, fabio_reader, parent=None):
- commonh5.LazyLoadableDataset.__init__(self, name, parent)
- self.__fabio_reader = fabio_reader
-
- def _create_data(self):
- """Initialize hold data by merging all headers of each frames.
- """
- headers = []
- types = set([])
- for fabio_frame in self.__fabio_reader.iter_frames():
- header = fabio_frame.header
-
- data = []
- for key, value in header.items():
- data.append("%s: %s" % (str(key), str(value)))
-
- data = "\n".join(data)
- try:
- line = data.encode("ascii")
- types.add(numpy.string_)
- except UnicodeEncodeError:
- try:
- line = data.encode("utf-8")
- types.add(numpy.unicode_)
- except UnicodeEncodeError:
- # Fallback in void
- line = numpy.void(data)
- types.add(numpy.void)
-
- headers.append(line)
-
- if numpy.void in types:
- dtype = numpy.void
- elif numpy.unicode_ in types:
- dtype = numpy.unicode_
- else:
- dtype = numpy.string_
-
- if dtype == numpy.unicode_:
- # h5py only support vlen unicode
- dtype = h5py.special_dtype(vlen=six.text_type)
-
- return numpy.array(headers, dtype=dtype)
-
-
-class MetadataGroup(commonh5.LazyLoadableGroup):
- """Abstract class for groups containing a reference to a fabio image.
- """
-
- def __init__(self, name, metadata_reader, kind, parent=None, attrs=None):
- commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
- self.__metadata_reader = metadata_reader
- self.__kind = kind
-
- def _create_child(self):
- keys = self.__metadata_reader.get_keys(self.__kind)
- for name in keys:
- data = self.__metadata_reader.get_value(self.__kind, name)
- dataset = commonh5.Dataset(name, data)
- self.add_node(dataset)
-
- @property
- def _metadata_reader(self):
- return self.__metadata_reader
-
-
-class DetectorGroup(commonh5.LazyLoadableGroup):
- """Define the detector group (sub group of instrument) using Fabio data.
- """
-
- def __init__(self, name, fabio_reader, parent=None, attrs=None):
- if attrs is None:
- attrs = {"NX_class": "NXdetector"}
- commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
- self.__fabio_reader = fabio_reader
-
- def _create_child(self):
- data = FrameData("data", self.__fabio_reader)
- self.add_node(data)
-
- # TODO we should add here Nexus informations we can extract from the
- # metadata
-
- others = MetadataGroup("others", self.__fabio_reader, kind=FabioReader.DEFAULT)
- self.add_node(others)
-
-
-class ImageGroup(commonh5.LazyLoadableGroup):
- """Define the image group (sub group of measurement) using Fabio data.
- """
-
- def __init__(self, name, fabio_reader, parent=None, attrs=None):
- commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
- self.__fabio_reader = fabio_reader
-
- def _create_child(self):
- basepath = self.parent.parent.name
- data = commonh5.SoftLink("data", path=basepath + "/instrument/detector_0/data")
- self.add_node(data)
- detector = commonh5.SoftLink("info", path=basepath + "/instrument/detector_0")
- self.add_node(detector)
-
-
-class NxDataPreviewGroup(commonh5.LazyLoadableGroup):
- """Define the NxData group which is used as the default NXdata to show the
- content of the file.
- """
-
- def __init__(self, name, fabio_reader, parent=None):
- if fabio_reader.is_spectrum():
- interpretation = "spectrum"
- else:
- interpretation = "image"
- attrs = {
- "NX_class": "NXdata",
- "interpretation": interpretation,
- "signal": "data",
- }
- commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
- self.__fabio_reader = fabio_reader
-
- def _create_child(self):
- basepath = self.parent.name
- data = commonh5.SoftLink("data", path=basepath + "/instrument/detector_0/data")
- self.add_node(data)
-
-
-class SampleGroup(commonh5.LazyLoadableGroup):
- """Define the image group (sub group of measurement) using Fabio data.
- """
-
- def __init__(self, name, fabio_reader, parent=None):
- attrs = {"NXclass": "NXsample"}
- commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
- self.__fabio_reader = fabio_reader
-
- def _create_child(self):
- if self.__fabio_reader.has_ub_matrix():
- scalar = {"interpretation": "scalar"}
- data = self.__fabio_reader.get_unit_cell_abc()
- data = commonh5.Dataset("unit_cell_abc", data, attrs=scalar)
- self.add_node(data)
- unit_cell_data = numpy.zeros((1, 6), numpy.float32)
- unit_cell_data[0, :3] = data
- data = self.__fabio_reader.get_unit_cell_alphabetagamma()
- data = commonh5.Dataset("unit_cell_alphabetagamma", data, attrs=scalar)
- self.add_node(data)
- unit_cell_data[0, 3:] = data
- data = commonh5.Dataset("unit_cell", unit_cell_data, attrs=scalar)
- self.add_node(data)
- data = self.__fabio_reader.get_ub_matrix()
- data = commonh5.Dataset("ub_matrix", data, attrs=scalar)
- self.add_node(data)
-
-
-class MeasurementGroup(commonh5.LazyLoadableGroup):
- """Define the measurement group for fabio file.
- """
-
- def __init__(self, name, fabio_reader, parent=None, attrs=None):
- commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
- self.__fabio_reader = fabio_reader
-
- def _create_child(self):
- keys = self.__fabio_reader.get_keys(FabioReader.COUNTER)
-
- # create image measurement but take care that no other metadata use
- # this name
- for i in range(1000):
- name = "image_%i" % i
- if name not in keys:
- data = ImageGroup(name, self.__fabio_reader)
- self.add_node(data)
- break
- else:
- raise Exception("image_i for 0..1000 already used")
-
- # add all counters
- for name in keys:
- data = self.__fabio_reader.get_value(FabioReader.COUNTER, name)
- dataset = commonh5.Dataset(name, data)
- self.add_node(dataset)
-
-
-class FabioReader(object):
- """Class which read and cache data and metadata from a fabio image."""
-
- DEFAULT = 0
- COUNTER = 1
- POSITIONER = 2
-
- def __init__(self, file_name=None, fabio_image=None, file_series=None):
- """
- Constructor
-
- :param str file_name: File name of the image file to read
- :param fabio.fabioimage.FabioImage fabio_image: An already openned
- :class:`fabio.fabioimage.FabioImage` instance.
- :param Union[list[str],fabio.file_series.file_series] file_series: An
- list of file name or a :class:`fabio.file_series.file_series`
- instance
- """
- self.__at_least_32bits = False
- self.__signed_type = False
-
- self.__load(file_name, fabio_image, file_series)
- self.__counters = {}
- self.__positioners = {}
- self.__measurements = {}
- self.__key_filters = set([])
- self.__data = None
- self.__frame_count = self.frame_count()
- self._read()
-
- def __load(self, file_name=None, fabio_image=None, file_series=None):
- if file_name is not None and fabio_image:
- raise TypeError("Parameters file_name and fabio_image are mutually exclusive.")
- if file_name is not None and fabio_image:
- raise TypeError("Parameters fabio_image and file_series are mutually exclusive.")
-
- self.__must_be_closed = False
-
- if file_name is not None:
- self.__fabio_file = fabio.open(file_name)
- self.__must_be_closed = True
- elif fabio_image is not None:
- if isinstance(fabio_image, fabio.fabioimage.FabioImage):
- self.__fabio_file = fabio_image
- else:
- raise TypeError("FabioImage expected but %s found.", fabio_image.__class__)
- elif file_series is not None:
- if isinstance(file_series, list):
- self.__fabio_file = _FileSeries(file_series)
- elif isinstance(file_series, fabio.file_series.file_series):
- self.__fabio_file = file_series
- else:
- raise TypeError("file_series or list expected but %s found.", file_series.__class__)
-
- def close(self):
- """Close the object, and free up associated resources.
-
- The associated FabioImage is closed only if the object was created from
- a filename by this class itself.
-
- After calling this method, attempts to use the object (and children)
- may fail.
- """
- if self.__must_be_closed:
- # Make sure the API of fabio provide it a 'close' method
- # TODO the test can be removed if fabio version >= 0.8
- if hasattr(self.__fabio_file, "close"):
- self.__fabio_file.close()
- self.__fabio_file = None
-
- def fabio_file(self):
- return self.__fabio_file
-
- def frame_count(self):
- """Returns the number of frames available."""
- if isinstance(self.__fabio_file, fabio.file_series.file_series):
- return len(self.__fabio_file)
- elif isinstance(self.__fabio_file, fabio.fabioimage.FabioImage):
- return self.__fabio_file.nframes
- else:
- raise TypeError("Unsupported type %s", self.__fabio_file.__class__)
-
- def iter_frames(self):
- """Iter all the available frames.
-
- A frame provides at least `data` and `header` attributes.
- """
- if isinstance(self.__fabio_file, fabio.file_series.file_series):
- for file_number in range(len(self.__fabio_file)):
- with self.__fabio_file.jump_image(file_number) as fabio_image:
- # return the first frame only
- assert(fabio_image.nframes == 1)
- yield fabio_image
- elif isinstance(self.__fabio_file, fabio.fabioimage.FabioImage):
- for frame_count in range(self.__fabio_file.nframes):
- if self.__fabio_file.nframes == 1:
- yield self.__fabio_file
- else:
- yield self.__fabio_file.getframe(frame_count)
- else:
- raise TypeError("Unsupported type %s", self.__fabio_file.__class__)
-
- def _create_data(self):
- """Initialize hold data by merging all frames into a single cube.
-
- Choose the cube size which fit the best the data. If some images are
- smaller than expected, the empty space is set to 0.
-
- The computation is cached into the class, and only done ones.
- """
- images = []
- for fabio_frame in self.iter_frames():
- images.append(fabio_frame.data)
-
- # returns the data without extra dim in case of single frame
- if len(images) == 1:
- return images[0]
-
- # get the max size
- max_dim = max([i.ndim for i in images])
- max_shape = [0] * max_dim
- for image in images:
- for dim in range(image.ndim):
- if image.shape[dim] > max_shape[dim]:
- max_shape[dim] = image.shape[dim]
- max_shape = tuple(max_shape)
-
- # fix smallest images
- for index, image in enumerate(images):
- if image.shape == max_shape:
- continue
- location = [slice(0, i) for i in image.shape]
- while len(location) < max_dim:
- location.append(0)
- normalized_image = numpy.zeros(max_shape, dtype=image.dtype)
- normalized_image[tuple(location)] = image
- images[index] = normalized_image
-
- # create a cube
- return numpy.array(images)
-
- def __get_dict(self, kind):
- """Returns a dictionary from according to an expected kind"""
- if kind == self.DEFAULT:
- return self.__measurements
- elif kind == self.COUNTER:
- return self.__counters
- elif kind == self.POSITIONER:
- return self.__positioners
- else:
- raise Exception("Unexpected kind %s", kind)
-
- def get_data(self):
- """Returns a cube from all available data from frames
-
- :rtype: numpy.ndarray
- """
- if self.__data is None:
- self.__data = self._create_data()
- return self.__data
-
- def get_keys(self, kind):
- """Get all available keys according to a kind of metadata.
-
- :rtype: list
- """
- return self.__get_dict(kind).keys()
-
- def get_value(self, kind, name):
- """Get a metadata value according to the kind and the name.
-
- :rtype: numpy.ndarray
- """
- value = self.__get_dict(kind)[name]
- if not isinstance(value, numpy.ndarray):
- if kind in [self.COUNTER, self.POSITIONER]:
- # Force normalization for counters and positioners
- old = self._set_vector_normalization(at_least_32bits=True, signed_type=True)
- else:
- old = None
- value = self._convert_metadata_vector(value)
- self.__get_dict(kind)[name] = value
- if old is not None:
- self._set_vector_normalization(*old)
- return value
-
- def _set_counter_value(self, frame_id, name, value):
- """Set a counter metadata according to the frame id"""
- if name not in self.__counters:
- self.__counters[name] = [None] * self.__frame_count
- self.__counters[name][frame_id] = value
-
- def _set_positioner_value(self, frame_id, name, value):
- """Set a positioner metadata according to the frame id"""
- if name not in self.__positioners:
- self.__positioners[name] = [None] * self.__frame_count
- self.__positioners[name][frame_id] = value
-
- def _set_measurement_value(self, frame_id, name, value):
- """Set a measurement metadata according to the frame id"""
- if name not in self.__measurements:
- self.__measurements[name] = [None] * self.__frame_count
- self.__measurements[name][frame_id] = value
-
- def _enable_key_filters(self, fabio_file):
- self.__key_filters.clear()
- if hasattr(fabio_file, "RESERVED_HEADER_KEYS"):
- # Provided in fabio 0.5
- for key in fabio_file.RESERVED_HEADER_KEYS:
- self.__key_filters.add(key.lower())
-
- def _read(self):
- """Read all metadata from the fabio file and store it into this
- object."""
-
- file_series = isinstance(self.__fabio_file, fabio.file_series.file_series)
- if not file_series:
- self._enable_key_filters(self.__fabio_file)
-
- for frame_id, fabio_frame in enumerate(self.iter_frames()):
- if file_series:
- self._enable_key_filters(fabio_frame)
- self._read_frame(frame_id, fabio_frame.header)
-
- def _is_filtered_key(self, key):
- """
- If this function returns True, the :meth:`_read_key` while not be
- called with this `key`while reading the metatdata frame.
-
- :param str key: A key of the metadata
- :rtype: bool
- """
- return key.lower() in self.__key_filters
-
- def _read_frame(self, frame_id, header):
- """Read all metadata from a frame and store it into this
- object."""
- for key, value in header.items():
- if self._is_filtered_key(key):
- continue
- self._read_key(frame_id, key, value)
-
- def _read_key(self, frame_id, name, value):
- """Read a key from the metadata and cache it into this object."""
- self._set_measurement_value(frame_id, name, value)
-
- def _set_vector_normalization(self, at_least_32bits, signed_type):
- previous = self.__at_least_32bits, self.__signed_type
- self.__at_least_32bits = at_least_32bits
- self.__signed_type = signed_type
- return previous
-
- def _normalize_vector_type(self, dtype):
- """Normalize the """
- if self.__at_least_32bits:
- if numpy.issubdtype(dtype, numpy.signedinteger):
- dtype = numpy.result_type(dtype, numpy.uint32)
- if numpy.issubdtype(dtype, numpy.unsignedinteger):
- dtype = numpy.result_type(dtype, numpy.uint32)
- elif numpy.issubdtype(dtype, numpy.floating):
- dtype = numpy.result_type(dtype, numpy.float32)
- elif numpy.issubdtype(dtype, numpy.complexfloating):
- dtype = numpy.result_type(dtype, numpy.complex64)
- if self.__signed_type:
- if numpy.issubdtype(dtype, numpy.unsignedinteger):
- signed = numpy.dtype("%s%i" % ('i', dtype.itemsize))
- dtype = numpy.result_type(dtype, signed)
- return dtype
-
- def _convert_metadata_vector(self, values):
- """Convert a list of numpy data into a numpy array with the better
- fitting type."""
- converted = []
- types = set([])
- has_none = False
- is_array = False
- array = []
-
- for v in values:
- if v is None:
- converted.append(None)
- has_none = True
- array.append(None)
- else:
- c = self._convert_value(v)
- if c.shape != tuple():
- array.append(v.split(" "))
- is_array = True
- else:
- array.append(v)
- converted.append(c)
- types.add(c.dtype)
-
- if has_none and len(types) == 0:
- # That's a list of none values
- return numpy.array([0] * len(values), numpy.int8)
-
- result_type = numpy.result_type(*types)
-
- if issubclass(result_type.type, numpy.string_):
- # use the raw data to create the array
- result = values
- elif issubclass(result_type.type, numpy.unicode_):
- # use the raw data to create the array
- result = values
- else:
- result = converted
-
- result_type = self._normalize_vector_type(result_type)
-
- if has_none:
- # Fix missing data according to the array type
- if result_type.kind == "S":
- none_value = b""
- elif result_type.kind == "U":
- none_value = u""
- elif result_type.kind == "f":
- none_value = numpy.float64("NaN")
- elif result_type.kind == "i":
- none_value = numpy.int64(0)
- elif result_type.kind == "u":
- none_value = numpy.int64(0)
- elif result_type.kind == "b":
- none_value = numpy.bool_(False)
- else:
- none_value = None
-
- for index, r in enumerate(result):
- if r is not None:
- continue
- result[index] = none_value
- values[index] = none_value
- array[index] = none_value
-
- if result_type.kind in "uifd" and len(types) > 1 and len(values) > 1:
- # Catch numerical precision
- if is_array and len(array) > 1:
- return numpy.array(array, dtype=result_type)
- else:
- return numpy.array(values, dtype=result_type)
- return numpy.array(result, dtype=result_type)
-
- def _convert_value(self, value):
- """Convert a string into a numpy object (scalar or array).
-
- The value is most of the time a string, but it can be python object
- in case if TIFF decoder for example.
- """
- if isinstance(value, list):
- # convert to a numpy array
- return numpy.array(value)
- if isinstance(value, dict):
- # convert to a numpy associative array
- key_dtype = numpy.min_scalar_type(list(value.keys()))
- value_dtype = numpy.min_scalar_type(list(value.values()))
- associative_type = [('key', key_dtype), ('value', value_dtype)]
- assert key_dtype.kind != "O" and value_dtype.kind != "O"
- return numpy.array(list(value.items()), dtype=associative_type)
- if isinstance(value, numbers.Number):
- dtype = numpy.min_scalar_type(value)
- assert dtype.kind != "O"
- return dtype.type(value)
-
- if isinstance(value, six.binary_type):
- try:
- value = value.decode('utf-8')
- except UnicodeDecodeError:
- return numpy.void(value)
-
- if " " in value:
- result = self._convert_list(value)
- else:
- result = self._convert_scalar_value(value)
- return result
-
- def _convert_scalar_value(self, value):
- """Convert a string into a numpy int or float.
-
- If it is not possible it returns a numpy string.
- """
- try:
- numpy_type = silx.utils.number.min_numerical_convertible_type(value)
- converted = numpy_type(value)
- except ValueError:
- converted = numpy.string_(value)
- return converted
-
- def _convert_list(self, value):
- """Convert a string into a typed numpy array.
-
- If it is not possible it returns a numpy string.
- """
- try:
- numpy_values = []
- values = value.split(" ")
- types = set([])
- for string_value in values:
- v = self._convert_scalar_value(string_value)
- numpy_values.append(v)
- types.add(v.dtype.type)
-
- result_type = numpy.result_type(*types)
-
- if issubclass(result_type.type, (numpy.string_, six.binary_type)):
- # use the raw data to create the result
- return numpy.string_(value)
- elif issubclass(result_type.type, (numpy.unicode_, six.text_type)):
- # use the raw data to create the result
- return numpy.unicode_(value)
- else:
- if len(types) == 1:
- return numpy.array(numpy_values, dtype=result_type)
- else:
- return numpy.array(values, dtype=result_type)
- except ValueError:
- return numpy.string_(value)
-
- def has_sample_information(self):
- """Returns true if there is information about the sample in the
- file
-
- :rtype: bool
- """
- return self.has_ub_matrix()
-
- def has_ub_matrix(self):
- """Returns true if a UB matrix is available.
-
- :rtype: bool
- """
- return False
-
- def is_spectrum(self):
- """Returns true if the data should be interpreted as
- MCA data.
-
- :rtype: bool
- """
- return False
-
-
-class EdfFabioReader(FabioReader):
- """Class which read and cache data and metadata from a fabio image.
-
- It is mostly the same as FabioReader, but counter_mne and
- motor_mne are parsed using a special way.
- """
-
- def __init__(self, file_name=None, fabio_image=None, file_series=None):
- FabioReader.__init__(self, file_name, fabio_image, file_series)
- self.__unit_cell_abc = None
- self.__unit_cell_alphabetagamma = None
- self.__ub_matrix = None
-
- def _read_frame(self, frame_id, header):
- """Overwrite the method to check and parse special keys: counter and
- motors keys."""
- self.__catch_keys = set([])
- if "motor_pos" in header and "motor_mne" in header:
- self.__catch_keys.add("motor_pos")
- self.__catch_keys.add("motor_mne")
- self._read_mnemonic_key(frame_id, "motor", header)
- if "counter_pos" in header and "counter_mne" in header:
- self.__catch_keys.add("counter_pos")
- self.__catch_keys.add("counter_mne")
- self._read_mnemonic_key(frame_id, "counter", header)
- FabioReader._read_frame(self, frame_id, header)
-
- def _is_filtered_key(self, key):
- if key in self.__catch_keys:
- return True
- return FabioReader._is_filtered_key(self, key)
-
- def _get_mnemonic_key(self, base_key, header):
- mnemonic_values_key = base_key + "_mne"
- mnemonic_values = header.get(mnemonic_values_key, "")
- mnemonic_values = mnemonic_values.split()
- pos_values_key = base_key + "_pos"
- pos_values = header.get(pos_values_key, "")
- pos_values = pos_values.split()
-
- result = collections.OrderedDict()
- nbitems = max(len(mnemonic_values), len(pos_values))
- for i in range(nbitems):
- if i < len(mnemonic_values):
- mnemonic = mnemonic_values[i]
- else:
- # skip the element
- continue
-
- if i < len(pos_values):
- pos = pos_values[i]
- else:
- pos = None
-
- result[mnemonic] = pos
- return result
-
- def _read_mnemonic_key(self, frame_id, base_key, header):
- """Parse a mnemonic key"""
- is_counter = base_key == "counter"
- is_positioner = base_key == "motor"
- data = self._get_mnemonic_key(base_key, header)
-
- for mnemonic, pos in data.items():
- if is_counter:
- self._set_counter_value(frame_id, mnemonic, pos)
- elif is_positioner:
- self._set_positioner_value(frame_id, mnemonic, pos)
- else:
- raise Exception("State unexpected (base_key: %s)" % base_key)
-
- def _get_first_header(self):
- """
- ..note:: This function can be cached
- """
- fabio_file = self.fabio_file()
- if isinstance(fabio_file, fabio.file_series.file_series):
- return fabio_file.jump_image(0).header
- return fabio_file.header
-
- def has_ub_matrix(self):
- """Returns true if a UB matrix is available.
-
- :rtype: bool
- """
- header = self._get_first_header()
- expected_keys = set(["UB_mne", "UB_pos", "sample_mne", "sample_pos"])
- return expected_keys.issubset(header)
-
- def parse_ub_matrix(self):
- header = self._get_first_header()
- ub_data = self._get_mnemonic_key("UB", header)
- s_data = self._get_mnemonic_key("sample", header)
- if len(ub_data) > 9:
- _logger.warning("UB_mne and UB_pos contains more than expected keys.")
- if len(s_data) > 6:
- _logger.warning("sample_mne and sample_pos contains more than expected keys.")
-
- data = numpy.array([s_data["U0"], s_data["U1"], s_data["U2"]], dtype=float)
- unit_cell_abc = data
-
- data = numpy.array([s_data["U3"], s_data["U4"], s_data["U5"]], dtype=float)
- unit_cell_alphabetagamma = data
-
- ub_matrix = numpy.array([[
- [ub_data["UB0"], ub_data["UB1"], ub_data["UB2"]],
- [ub_data["UB3"], ub_data["UB4"], ub_data["UB5"]],
- [ub_data["UB6"], ub_data["UB7"], ub_data["UB8"]]]], dtype=float)
-
- self.__unit_cell_abc = unit_cell_abc
- self.__unit_cell_alphabetagamma = unit_cell_alphabetagamma
- self.__ub_matrix = ub_matrix
-
- def get_unit_cell_abc(self):
- """Get a numpy array data as defined for the dataset unit_cell_abc
- from the NXsample dataset.
-
- :rtype: numpy.ndarray
- """
- if self.__unit_cell_abc is None:
- self.parse_ub_matrix()
- return self.__unit_cell_abc
-
- def get_unit_cell_alphabetagamma(self):
- """Get a numpy array data as defined for the dataset
- unit_cell_alphabetagamma from the NXsample dataset.
-
- :rtype: numpy.ndarray
- """
- if self.__unit_cell_alphabetagamma is None:
- self.parse_ub_matrix()
- return self.__unit_cell_alphabetagamma
-
- def get_ub_matrix(self):
- """Get a numpy array data as defined for the dataset ub_matrix
- from the NXsample dataset.
-
- :rtype: numpy.ndarray
- """
- if self.__ub_matrix is None:
- self.parse_ub_matrix()
- return self.__ub_matrix
-
- def is_spectrum(self):
- """Returns true if the data should be interpreted as
- MCA data.
- EDF files or file series, with two or more header names starting with
- "MCA", should be interpreted as MCA data.
-
- :rtype: bool
- """
- count = 0
- for key in self._get_first_header():
- if key.lower().startswith("mca"):
- count += 1
- if count >= 2:
- return True
- return False
-
-
-class File(commonh5.File):
- """Class which handle a fabio image as a mimick of a h5py.File.
- """
-
- def __init__(self, file_name=None, fabio_image=None, file_series=None):
- """
- Constructor
-
- :param str file_name: File name of the image file to read
- :param fabio.fabioimage.FabioImage fabio_image: An already openned
- :class:`fabio.fabioimage.FabioImage` instance.
- :param Union[list[str],fabio.file_series.file_series] file_series: An
- list of file name or a :class:`fabio.file_series.file_series`
- instance
- """
- self.__fabio_reader = self.create_fabio_reader(file_name, fabio_image, file_series)
- if fabio_image is not None:
- file_name = fabio_image.filename
- scan = self.create_scan_group(self.__fabio_reader)
-
- attrs = {"NX_class": "NXroot",
- "file_time": datetime.datetime.now().isoformat(),
- "creator": "silx %s" % silx_version,
- "default": scan.basename}
- if file_name is not None:
- attrs["file_name"] = file_name
- commonh5.File.__init__(self, name=file_name, attrs=attrs)
- self.add_node(scan)
-
- def create_scan_group(self, fabio_reader):
- """Factory to create the scan group.
-
- :param FabioImage fabio_image: A Fabio image
- :param FabioReader fabio_reader: A reader for the Fabio image
- :rtype: commonh5.Group
- """
- nxdata = NxDataPreviewGroup("image", fabio_reader)
- scan_attrs = {
- "NX_class": "NXentry",
- "default": nxdata.basename,
- }
- scan = commonh5.Group("scan_0", attrs=scan_attrs)
- instrument = commonh5.Group("instrument", attrs={"NX_class": "NXinstrument"})
- measurement = MeasurementGroup("measurement", fabio_reader, attrs={"NX_class": "NXcollection"})
- file_ = commonh5.Group("file", attrs={"NX_class": "NXcollection"})
- positioners = MetadataGroup("positioners", fabio_reader, FabioReader.POSITIONER, attrs={"NX_class": "NXpositioner"})
- raw_header = RawHeaderData("scan_header", fabio_reader, self)
- detector = DetectorGroup("detector_0", fabio_reader)
-
- scan.add_node(instrument)
- instrument.add_node(positioners)
- instrument.add_node(file_)
- instrument.add_node(detector)
- file_.add_node(raw_header)
- scan.add_node(measurement)
- scan.add_node(nxdata)
-
- if fabio_reader.has_sample_information():
- sample = SampleGroup("sample", fabio_reader)
- scan.add_node(sample)
-
- return scan
-
- def create_fabio_reader(self, file_name, fabio_image, file_series):
- """Factory to create fabio reader.
-
- :rtype: FabioReader"""
- use_edf_reader = False
- first_file_name = None
- first_image = None
-
- if isinstance(file_series, list):
- first_file_name = file_series[0]
- elif isinstance(file_series, fabio.file_series.file_series):
- first_image = file_series.first_image()
- elif fabio_image is not None:
- first_image = fabio_image
- else:
- first_file_name = file_name
-
- if first_file_name is not None:
- _, ext = os.path.splitext(first_file_name)
- ext = ext[1:]
- edfimage = fabio.edfimage.EdfImage
- if hasattr(edfimage, "DEFAULT_EXTENTIONS"):
- # Typo on fabio 0.5
- edf_extensions = edfimage.DEFAULT_EXTENTIONS
- else:
- edf_extensions = edfimage.DEFAULT_EXTENSIONS
- use_edf_reader = ext in edf_extensions
- elif first_image is not None:
- use_edf_reader = isinstance(first_image, fabio.edfimage.EdfImage)
- else:
- assert(False)
-
- if use_edf_reader:
- reader = EdfFabioReader(file_name, fabio_image, file_series)
- else:
- reader = FabioReader(file_name, fabio_image, file_series)
- return reader
-
- def close(self):
- """Close the object, and free up associated resources.
-
- After calling this method, attempts to use the object (and children)
- may fail.
- """
- self.__fabio_reader.close()
- self.__fabio_reader = None
diff --git a/silx/io/h5py_utils.py b/silx/io/h5py_utils.py
deleted file mode 100644
index cbdb44a..0000000
--- a/silx/io/h5py_utils.py
+++ /dev/null
@@ -1,317 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-This module provides utility methods on top of h5py, mainly to handle
-parallel writing and reading.
-"""
-
-__authors__ = ["W. de Nolf"]
-__license__ = "MIT"
-__date__ = "27/01/2020"
-
-
-import os
-import traceback
-import h5py
-
-from .._version import calc_hexversion
-from ..utils import retry as retry_mod
-
-H5PY_HEX_VERSION = calc_hexversion(*h5py.version.version_tuple[:3])
-HDF5_HEX_VERSION = calc_hexversion(*h5py.version.hdf5_version_tuple[:3])
-
-HDF5_SWMR_VERSION = calc_hexversion(*h5py.get_config().swmr_min_hdf5_version[:3])
-HDF5_TRACK_ORDER_VERSION = calc_hexversion(2, 9, 0)
-
-HAS_SWMR = HDF5_HEX_VERSION >= HDF5_SWMR_VERSION
-HAS_TRACK_ORDER = H5PY_HEX_VERSION >= HDF5_TRACK_ORDER_VERSION
-
-
-def _is_h5py_exception(e):
- for frame in traceback.walk_tb(e.__traceback__):
- if frame[0].f_locals.get("__package__", None) == "h5py":
- return True
- return False
-
-
-def _retry_h5py_error(e):
- """
- :param BaseException e:
- :returns bool:
- """
- if _is_h5py_exception(e):
- if isinstance(e, (OSError, RuntimeError)):
- return True
- elif isinstance(e, KeyError):
- # For example this needs to be retried:
- # KeyError: 'Unable to open object (bad object header version number)'
- return "Unable to open object" in str(e)
- elif isinstance(e, retry_mod.RetryError):
- return True
- return False
-
-
-def retry(**kw):
- """Decorator for a method that needs to be executed until it not longer
- fails on HDF5 IO. Mainly used for reading an HDF5 file that is being
- written.
-
- :param \**kw: see `silx.utils.retry`
- """
- kw.setdefault("retry_on_error", _retry_h5py_error)
- return retry_mod.retry(**kw)
-
-
-def retry_contextmanager(**kw):
- """Decorator to make a context manager from a method that needs to be
- entered until it not longer fails on HDF5 IO. Mainly used for reading
- an HDF5 file that is being written.
-
- :param \**kw: see `silx.utils.retry_contextmanager`
- """
- kw.setdefault("retry_on_error", _retry_h5py_error)
- return retry_mod.retry_contextmanager(**kw)
-
-
-def retry_in_subprocess(**kw):
- """Same as `retry` but it also retries segmentation faults.
-
- On Window you cannot use this decorator with the "@" syntax:
-
- .. code-block:: python
-
- def _method(*args, **kw):
- ...
-
- method = retry_in_subprocess()(_method)
-
- :param \**kw: see `silx.utils.retry_in_subprocess`
- """
- kw.setdefault("retry_on_error", _retry_h5py_error)
- return retry_mod.retry_in_subprocess(**kw)
-
-
-def group_has_end_time(h5item):
- """Returns True when the HDF5 item is a Group with an "end_time"
- dataset. A reader can use this as an indication that the Group
- has been fully written (at least if the writer supports this).
-
- :param Union[h5py.Group,h5py.Dataset] h5item:
- :returns bool:
- """
- if isinstance(h5item, h5py.Group):
- return "end_time" in h5item
- else:
- return False
-
-
-@retry_contextmanager()
-def open_item(filename, name, retry_invalid=False, validate=None):
- """Yield an HDF5 dataset or group (retry until it can be instantiated).
-
- :param str filename:
- :param bool retry_invalid: retry when item is missing or not valid
- :param callable or None validate:
- :yields Dataset, Group or None:
- """
- with File(filename) as h5file:
- try:
- item = h5file[name]
- except KeyError as e:
- if "doesn't exist" in str(e):
- if retry_invalid:
- raise retry_mod.RetryError
- else:
- item = None
- else:
- raise
- if callable(validate) and item is not None:
- if not validate(item):
- if retry_invalid:
- raise retry_mod.RetryError
- else:
- item = None
- yield item
-
-
-def _top_level_names(filename, include_only=group_has_end_time):
- """Return all valid top-level HDF5 names.
-
- :param str filename:
- :param callable or None include_only:
- :returns list(str):
- """
- with File(filename) as h5file:
- try:
- if callable(include_only):
- return [name for name in h5file["/"] if include_only(h5file[name])]
- else:
- return list(h5file["/"])
- except KeyError:
- raise retry_mod.RetryError
-
-
-top_level_names = retry()(_top_level_names)
-safe_top_level_names = retry_in_subprocess()(_top_level_names)
-
-
-class File(h5py.File):
- """Takes care of HDF5 file locking and SWMR mode without the need
- to handle those explicitely.
-
- When using this class, you cannot open different files simultatiously
- with different modes because the locking flag is an environment variable.
- """
-
- _HDF5_FILE_LOCKING = None
- _NOPEN = 0
- _SWMR_LIBVER = "latest"
-
- def __init__(
- self,
- filename,
- mode=None,
- enable_file_locking=None,
- swmr=None,
- libver=None,
- **kwargs
- ):
- """The arguments `enable_file_locking` and `swmr` should not be
- specified explicitly for normal use cases.
-
- :param str filename:
- :param str or None mode: read-only by default
- :param bool or None enable_file_locking: by default it is disabled for `mode='r'`
- and `swmr=False` and enabled for all
- other modes.
- :param bool or None swmr: try both modes when `mode='r'` and `swmr=None`
- :param **kwargs: see `h5py.File.__init__`
- """
- if mode is None:
- mode = "r"
- elif mode not in ("r", "w", "w-", "x", "a", "r+"):
- raise ValueError("invalid mode {}".format(mode))
- if not HAS_SWMR:
- swmr = False
-
- if enable_file_locking is None:
- enable_file_locking = bool(mode != "r" or swmr)
- if self._NOPEN:
- self._check_locking_env(enable_file_locking)
- else:
- self._set_locking_env(enable_file_locking)
-
- if swmr and libver is None:
- libver = self._SWMR_LIBVER
-
- if HAS_TRACK_ORDER:
- kwargs.setdefault("track_order", True)
- try:
- super().__init__(filename, mode=mode, swmr=swmr, libver=libver, **kwargs)
- except OSError as e:
- # wlock wSWMR rlock rSWMR OSError: Unable to open file (...)
- # 1 TRUE FALSE FALSE FALSE -
- # 2 TRUE FALSE FALSE TRUE -
- # 3 TRUE FALSE TRUE FALSE unable to lock file, errno = 11, error message = 'Resource temporarily unavailable'
- # 4 TRUE FALSE TRUE TRUE unable to lock file, errno = 11, error message = 'Resource temporarily unavailable'
- # 5 TRUE TRUE FALSE FALSE file is already open for write (may use <h5clear file> to clear file consistency flags)
- # 6 TRUE TRUE FALSE TRUE -
- # 7 TRUE TRUE TRUE FALSE file is already open for write (may use <h5clear file> to clear file consistency flags)
- # 8 TRUE TRUE TRUE TRUE -
- if (
- mode == "r"
- and swmr is None
- and "file is already open for write" in str(e)
- ):
- # Try reading in SWMR mode (situation 5 and 7)
- swmr = True
- if libver is None:
- libver = self._SWMR_LIBVER
- super().__init__(
- filename, mode=mode, swmr=swmr, libver=libver, **kwargs
- )
- else:
- raise
- else:
- self._add_nopen(1)
- try:
- if mode != "r" and swmr:
- # Try setting writer in SWMR mode
- self.swmr_mode = True
- except Exception:
- self.close()
- raise
-
- @classmethod
- def _add_nopen(cls, v):
- cls._NOPEN = max(cls._NOPEN + v, 0)
-
- def close(self):
- super().close()
- self._add_nopen(-1)
- if not self._NOPEN:
- self._restore_locking_env()
-
- def _set_locking_env(self, enable):
- self._backup_locking_env()
- if enable:
- os.environ["HDF5_USE_FILE_LOCKING"] = "TRUE"
- elif enable is None:
- try:
- del os.environ["HDF5_USE_FILE_LOCKING"]
- except KeyError:
- pass
- else:
- os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
-
- def _get_locking_env(self):
- v = os.environ.get("HDF5_USE_FILE_LOCKING")
- if v == "TRUE":
- return True
- elif v is None:
- return None
- else:
- return False
-
- def _check_locking_env(self, enable):
- if enable != self._get_locking_env():
- if enable:
- raise RuntimeError(
- "Close all HDF5 files before enabling HDF5 file locking"
- )
- else:
- raise RuntimeError(
- "Close all HDF5 files before disabling HDF5 file locking"
- )
-
- def _backup_locking_env(self):
- v = os.environ.get("HDF5_USE_FILE_LOCKING")
- if v is None:
- self._HDF5_FILE_LOCKING = None
- else:
- self._HDF5_FILE_LOCKING = v == "TRUE"
-
- def _restore_locking_env(self):
- self._set_locking_env(self._HDF5_FILE_LOCKING)
- self._HDF5_FILE_LOCKING = None
diff --git a/silx/io/nxdata/_utils.py b/silx/io/nxdata/_utils.py
deleted file mode 100644
index 8b3474a..0000000
--- a/silx/io/nxdata/_utils.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Utility functions used by NXdata validation and parsing."""
-
-import copy
-import logging
-
-import numpy
-import six
-
-from silx.io import is_dataset
-from silx.utils.deprecation import deprecated
-
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "17/04/2018"
-
-
-nxdata_logger = logging.getLogger("silx.io.nxdata")
-
-
-INTERPDIM = {"scalar": 0,
- "spectrum": 1,
- "image": 2,
- "rgba-image": 3, # "hsla-image": 3, "cmyk-image": 3, # TODO
- "vertex": 1} # 3D scatter: 1D signal + 3 axes (x, y, z) of same legth
-"""Number of signal dimensions associated to each possible @interpretation
-attribute.
-"""
-
-
-@deprecated(since_version="0.8.0", replacement="get_attr_as_unicode")
-def get_attr_as_string(*args, **kwargs):
- return get_attr_as_unicode(*args, **kwargs)
-
-
-def get_attr_as_unicode(item, attr_name, default=None):
- """Return item.attrs[attr_name] as unicode or as a
- list of unicode.
-
- Numpy arrays of strings or bytes returned by h5py are converted to
- lists of unicode.
-
- :param item: Group or dataset
- :param attr_name: Attribute name
- :param default: Value to be returned if attribute is not found.
- :return: item.attrs[attr_name]
- """
- attr = item.attrs.get(attr_name, default)
-
- if isinstance(attr, six.binary_type):
- # byte-string
- return attr.decode("utf-8")
- elif isinstance(attr, numpy.ndarray) and not attr.shape:
- if isinstance(attr[()], six.binary_type):
- # byte string as ndarray scalar
- return attr[()].decode("utf-8")
- else:
- # other scalar, possibly unicode
- return attr[()]
- elif isinstance(attr, numpy.ndarray) and len(attr.shape):
- if hasattr(attr[0], "decode"):
- # array of byte-strings
- return [element.decode("utf-8") for element in attr]
- else:
- # other array, most likely unicode objects
- return [element for element in attr]
- else:
- return copy.deepcopy(attr)
-
-
-def get_uncertainties_names(group, signal_name):
- # Test consistency of @uncertainties
- uncertainties_names = get_attr_as_unicode(group, "uncertainties")
- if uncertainties_names is None:
- uncertainties_names = get_attr_as_unicode(group[signal_name], "uncertainties")
- if isinstance(uncertainties_names, six.text_type):
- uncertainties_names = [uncertainties_names]
- return uncertainties_names
-
-
-def get_signal_name(group):
- """Return the name of the (main) signal in a NXdata group.
- Return None if this info is missing (invalid NXdata).
-
- """
- signal_name = get_attr_as_unicode(group, "signal", default=None)
- if signal_name is None:
- nxdata_logger.info("NXdata group %s does not define a signal attr. "
- "Testing legacy specification.", group.name)
- for key in group:
- if "signal" in group[key].attrs:
- signal_name = key
- signal_attr = group[key].attrs["signal"]
- if signal_attr in [1, b"1", u"1"]:
- # This is the main (default) signal
- break
- return signal_name
-
-
-def get_auxiliary_signals_names(group):
- """Return list of auxiliary signals names"""
- auxiliary_signals_names = get_attr_as_unicode(group, "auxiliary_signals",
- default=[])
- if isinstance(auxiliary_signals_names, (six.text_type, six.binary_type)):
- auxiliary_signals_names = [auxiliary_signals_names]
- return auxiliary_signals_names
-
-
-def validate_auxiliary_signals(group, signal_name, auxiliary_signals_names):
- """Check data dimensionality and size. Return False if invalid."""
- issues = []
- for asn in auxiliary_signals_names:
- if asn not in group or not is_dataset(group[asn]):
- issues.append(
- "Cannot find auxiliary signal dataset '%s'" % asn)
- elif group[signal_name].shape != group[asn].shape:
- issues.append("Auxiliary signal dataset '%s' does not" % asn +
- " have the same shape as the main signal.")
- return issues
-
-
-def validate_number_of_axes(group, signal_name, num_axes):
- issues = []
- ndims = len(group[signal_name].shape)
- if 1 < ndims < num_axes:
- # ndim = 1 with several axes could be a scatter
- issues.append(
- "More @axes defined than there are " +
- "signal dimensions: " +
- "%d axes, %d dimensions." % (num_axes, ndims))
-
- # case of less axes than dimensions: number of axes must match
- # dimensionality defined by @interpretation
- elif ndims > num_axes:
- interpretation = get_attr_as_unicode(group[signal_name], "interpretation")
- if interpretation is None:
- interpretation = get_attr_as_unicode(group, "interpretation")
- if interpretation is None:
- issues.append("No @interpretation and not enough" +
- " @axes defined.")
-
- elif interpretation not in INTERPDIM:
- issues.append("Unrecognized @interpretation=" + interpretation +
- " for data with wrong number of defined @axes.")
- elif interpretation == "rgba-image":
- if ndims != 3 or group[signal_name].shape[-1] not in [3, 4]:
- issues.append(
- "Inconsistent RGBA Image. Expected 3 dimensions with " +
- "last one of length 3 or 4. Got ndim=%d " % ndims +
- "with last dimension of length %d." % group[signal_name].shape[-1])
- if num_axes != 2:
- issues.append(
- "Inconsistent number of axes for RGBA Image. Expected "
- "3, but got %d." % ndims)
-
- elif num_axes != INTERPDIM[interpretation]:
- issues.append(
- "%d-D signal with @interpretation=%s " % (ndims, interpretation) +
- "must define %d or %d axes." % (ndims, INTERPDIM[interpretation]))
- return issues
diff --git a/silx/io/nxdata/parse.py b/silx/io/nxdata/parse.py
deleted file mode 100644
index b1c1bba..0000000
--- a/silx/io/nxdata/parse.py
+++ /dev/null
@@ -1,997 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This package provides a collection of functions to work with h5py-like
-groups following the NeXus *NXdata* specification.
-
-See http://download.nexusformat.org/sphinx/classes/base_classes/NXdata.html
-
-The main class is :class:`NXdata`.
-You can also fetch the default NXdata in a NXroot or a NXentry with function
-:func:`get_default`.
-
-
-Other public functions:
-
- - :func:`is_valid_nxdata`
- - :func:`is_NXroot_with_default_NXdata`
- - :func:`is_NXentry_with_default_NXdata`
- - :func:`is_group_with_default_NXdata`
-
-"""
-
-import json
-import numpy
-import six
-
-from silx.io.utils import is_group, is_file, is_dataset, h5py_read_dataset
-
-from ._utils import get_attr_as_unicode, INTERPDIM, nxdata_logger, \
- get_uncertainties_names, get_signal_name, \
- get_auxiliary_signals_names, validate_auxiliary_signals, validate_number_of_axes
-
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "24/03/2020"
-
-
-class InvalidNXdataError(Exception):
- pass
-
-
-class _SilxStyle(object):
- """NXdata@SILX_style parser.
-
- :param NXdata nxdata:
- NXdata description for which to extract silx_style information.
- """
-
- def __init__(self, nxdata):
- naxes = len(nxdata.axes)
- self._axes_scale_types = [None] * naxes
- self._signal_scale_type = None
-
- stylestr = get_attr_as_unicode(nxdata.group, "SILX_style")
- if stylestr is None:
- return
-
- try:
- style = json.loads(stylestr)
- except json.JSONDecodeError:
- nxdata_logger.error(
- "Ignoring SILX_style, cannot parse: %s", stylestr)
- return
-
- if not isinstance(style, dict):
- nxdata_logger.error(
- "Ignoring SILX_style, cannot parse: %s", stylestr)
-
- if 'axes_scale_types' in style:
- axes_scale_types = style['axes_scale_types']
-
- if isinstance(axes_scale_types, str):
- # Convert single argument to list
- axes_scale_types = [axes_scale_types]
-
- if not isinstance(axes_scale_types, list):
- nxdata_logger.error(
- "Ignoring SILX_style:axes_scale_types, not a list")
- else:
- for scale_type in axes_scale_types:
- if scale_type not in ('linear', 'log'):
- nxdata_logger.error(
- "Ignoring SILX_style:axes_scale_types, invalid value: %s", str(scale_type))
- break
- else: # All values are valid
- if len(axes_scale_types) > naxes:
- nxdata_logger.error(
- "Clipping SILX_style:axes_scale_types, too many values")
- axes_scale_types = axes_scale_types[:naxes]
- elif len(axes_scale_types) < naxes:
- # Extend axes_scale_types with None to match number of axes
- axes_scale_types = [None] * (naxes - len(axes_scale_types)) + axes_scale_types
- self._axes_scale_types = tuple(axes_scale_types)
-
- if 'signal_scale_type' in style:
- scale_type = style['signal_scale_type']
- if scale_type not in ('linear', 'log'):
- nxdata_logger.error(
- "Ignoring SILX_style:signal_scale_type, invalid value: %s", str(scale_type))
- else:
- self._signal_scale_type = scale_type
-
- axes_scale_types = property(
- lambda self: self._axes_scale_types,
- doc="Tuple of NXdata axes scale types (None, 'linear' or 'log'). List[str]")
-
- signal_scale_type = property(
- lambda self: self._signal_scale_type,
- doc="NXdata signal scale type (None, 'linear' or 'log'). str")
-
-
-class NXdata(object):
- """NXdata parser.
-
- .. note::
-
- Before attempting to access any attribute or property,
- you should check that :attr:`is_valid` is *True*.
-
- :param group: h5py-like group following the NeXus *NXdata* specification.
- :param boolean validate: Set this parameter to *False* to skip the initial
- validation. This option is provided for optimisation purposes, for cases
- where :meth:`silx.io.nxdata.is_valid_nxdata` has already been called
- prior to instantiating this :class:`NXdata`.
- """
- def __init__(self, group, validate=True):
- super(NXdata, self).__init__()
- self._plot_style = None
-
- self.group = group
- """h5py-like group object with @NX_class=NXdata.
- """
-
- self.issues = []
- """List of error messages for malformed NXdata."""
-
- if validate:
- self._validate()
- self.is_valid = not self.issues
- """Validity status for this NXdata.
- If False, all properties and attributes will be None.
- """
-
- self._is_scatter = None
- self._axes = None
-
- self.signal = None
- """Main signal dataset in this NXdata group.
- In case more than one signal is present in this group,
- the other ones can be found in :attr:`auxiliary_signals`.
- """
-
- self.signal_name = None
- """Signal long name, as specified in the @long_name attribute of the
- signal dataset. If not specified, the dataset name is used."""
-
- self.signal_ndim = None
- self.signal_is_0d = None
- self.signal_is_1d = None
- self.signal_is_2d = None
- self.signal_is_3d = None
-
- self.axes_names = None
- """List of axes names in a NXdata group.
-
- This attribute is similar to :attr:`axes_dataset_names` except that
- if an axis dataset has a "@long_name" attribute, it will be used
- instead of the dataset name.
- """
-
- if not self.is_valid:
- nxdata_logger.debug("%s", self.issues)
- else:
- self.signal = self.group[self.signal_dataset_name]
- self.signal_name = get_attr_as_unicode(self.signal, "long_name")
-
- if self.signal_name is None:
- self.signal_name = self.signal_dataset_name
-
- # ndim will be available in very recent h5py versions only
- self.signal_ndim = getattr(self.signal, "ndim",
- len(self.signal.shape))
-
- self.signal_is_0d = self.signal_ndim == 0
- self.signal_is_1d = self.signal_ndim == 1
- self.signal_is_2d = self.signal_ndim == 2
- self.signal_is_3d = self.signal_ndim == 3
-
- self.axes_names = []
- # check if axis dataset defines @long_name
- for _, dsname in enumerate(self.axes_dataset_names):
- if dsname is not None and "long_name" in self.group[dsname].attrs:
- self.axes_names.append(get_attr_as_unicode(self.group[dsname], "long_name"))
- else:
- self.axes_names.append(dsname)
-
- # excludes scatters
- self.signal_is_1d = self.signal_is_1d and len(self.axes) <= 1 # excludes n-D scatters
-
- self._plot_style = _SilxStyle(self)
-
- def _validate(self):
- """Fill :attr:`issues` with error messages for each error found."""
- if not is_group(self.group):
- raise TypeError("group must be a h5py-like group")
- if get_attr_as_unicode(self.group, "NX_class") != "NXdata":
- self.issues.append("Group has no attribute @NX_class='NXdata'")
- return
-
- signal_name = get_signal_name(self.group)
- if signal_name is None:
- self.issues.append("No @signal attribute on the NXdata group, "
- "and no dataset with a @signal=1 attr found")
- # very difficult to do more consistency tests without signal
- return
-
- elif signal_name not in self.group or not is_dataset(self.group[signal_name]):
- self.issues.append("Cannot find signal dataset '%s'" % signal_name)
- return
-
- auxiliary_signals_names = get_auxiliary_signals_names(self.group)
- self.issues += validate_auxiliary_signals(self.group,
- signal_name,
- auxiliary_signals_names)
-
- if "axes" in self.group.attrs:
- axes_names = get_attr_as_unicode(self.group, "axes")
- if isinstance(axes_names, (six.text_type, six.binary_type)):
- axes_names = [axes_names]
-
- self.issues += validate_number_of_axes(self.group, signal_name,
- num_axes=len(axes_names))
-
- # Test consistency of @uncertainties
- uncertainties_names = get_uncertainties_names(self.group, signal_name)
- if uncertainties_names is not None:
- if len(uncertainties_names) != len(axes_names):
- if len(uncertainties_names) < len(axes_names):
- # ignore the field to avoid index error in the axes loop
- uncertainties_names = None
- self.issues.append("@uncertainties does not define the same " +
- "number of fields than @axes. Field ignored")
- else:
- self.issues.append("@uncertainties does not define the same " +
- "number of fields than @axes")
-
- # Test individual axes
- is_scatter = True # true if all axes have the same size as the signal
- signal_size = 1
- for dim in self.group[signal_name].shape:
- signal_size *= dim
- polynomial_axes_names = []
- for i, axis_name in enumerate(axes_names):
-
- if axis_name == ".":
- continue
- if axis_name not in self.group or not is_dataset(self.group[axis_name]):
- self.issues.append("Could not find axis dataset '%s'" % axis_name)
- continue
-
- axis_size = 1
- for dim in self.group[axis_name].shape:
- axis_size *= dim
-
- if len(self.group[axis_name].shape) != 1:
- # I don't know how to interpret n-D axes
- self.issues.append("Axis %s is not 1D" % axis_name)
- continue
- else:
- # for a 1-d axis,
- fg_idx = self.group[axis_name].attrs.get("first_good", 0)
- lg_idx = self.group[axis_name].attrs.get("last_good", len(self.group[axis_name]) - 1)
- axis_len = lg_idx + 1 - fg_idx
-
- if axis_len != signal_size:
- if axis_len not in self.group[signal_name].shape + (1, 2):
- self.issues.append(
- "Axis %s number of elements does not " % axis_name +
- "correspond to the length of any signal dimension,"
- " it does not appear to be a constant or a linear calibration," +
- " and this does not seem to be a scatter plot.")
- continue
- elif axis_len in (1, 2):
- polynomial_axes_names.append(axis_name)
- is_scatter = False
- else:
- if not is_scatter:
- self.issues.append(
- "Axis %s number of elements is equal " % axis_name +
- "to the length of the signal, but this does not seem" +
- " to be a scatter (other axes have different sizes)")
- continue
-
- # Test individual uncertainties
- errors_name = axis_name + "_errors"
- if errors_name not in self.group and uncertainties_names is not None:
- errors_name = uncertainties_names[i]
- if errors_name in self.group and axis_name not in polynomial_axes_names:
- if self.group[errors_name].shape != self.group[axis_name].shape:
- self.issues.append(
- "Errors '%s' does not have the same " % errors_name +
- "dimensions as axis '%s'." % axis_name)
-
- # test dimensions of errors associated with signal
-
- signal_errors = signal_name + "_errors"
- if "errors" in self.group and is_dataset(self.group["errors"]):
- errors = "errors"
- elif signal_errors in self.group and is_dataset(self.group[signal_errors]):
- errors = signal_errors
- else:
- errors = None
- if errors:
- if self.group[errors].shape != self.group[signal_name].shape:
- # In principle just the same size should be enough but
- # NeXus documentation imposes to have the same shape
- self.issues.append(
- "Dataset containing standard deviations must " +
- "have the same dimensions as the signal.")
-
- @property
- def signal_dataset_name(self):
- """Name of the main signal dataset."""
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
- signal_dataset_name = get_attr_as_unicode(self.group, "signal")
- if signal_dataset_name is None:
- # find a dataset with @signal == 1
- for dsname in self.group:
- signal_attr = self.group[dsname].attrs.get("signal")
- if signal_attr in [1, b"1", u"1"]:
- # This is the main (default) signal
- signal_dataset_name = dsname
- break
- assert signal_dataset_name is not None
- return signal_dataset_name
-
- @property
- def auxiliary_signals_dataset_names(self):
- """Sorted list of names of the auxiliary signals datasets.
-
- These are the names provided by the *@auxiliary_signals* attribute
- on the NXdata group.
-
- In case the NXdata group does not specify a *@signal* attribute
- but has a dataset with an attribute *@signal=1*,
- we look for datasets with attributes *@signal=2, @signal=3...*
- (deprecated NXdata specification)."""
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
- signal_dataset_name = get_attr_as_unicode(self.group, "signal")
- if signal_dataset_name is not None:
- auxiliary_signals_names = get_attr_as_unicode(self.group, "auxiliary_signals")
- if auxiliary_signals_names is not None:
- if not isinstance(auxiliary_signals_names,
- (tuple, list, numpy.ndarray)):
- # tolerate a single string, but coerce into a list
- return [auxiliary_signals_names]
- return list(auxiliary_signals_names)
- return []
-
- # try old spec, @signal=1 (2, 3...) on dataset
- numbered_names = []
- for dsname in self.group:
- if dsname == self.signal_dataset_name:
- # main signal, not auxiliary
- continue
- ds = self.group[dsname]
- signal_attr = ds.attrs.get("signal")
- if signal_attr is not None and not is_dataset(ds):
- nxdata_logger.warning("Item %s with @signal=%s is not a dataset (%s)",
- dsname, signal_attr, type(ds))
- continue
- if signal_attr is not None:
- try:
- signal_number = int(signal_attr)
- except (ValueError, TypeError):
- nxdata_logger.warning("Could not parse attr @signal=%s on "
- "dataset %s as an int",
- signal_attr, dsname)
- continue
- numbered_names.append((signal_number, dsname))
- return [a[1] for a in sorted(numbered_names)]
-
- @property
- def auxiliary_signals_names(self):
- """List of names of the auxiliary signals.
-
- Similar to :attr:`auxiliary_signals_dataset_names`, but the @long_name
- is used when this attribute is present, instead of the dataset name.
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- signal_names = []
- for asdn in self.auxiliary_signals_dataset_names:
- if "long_name" in self.group[asdn].attrs:
- signal_names.append(self.group[asdn].attrs["long_name"])
- else:
- signal_names.append(asdn)
- return signal_names
-
- @property
- def auxiliary_signals(self):
- """List of all auxiliary signal datasets."""
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- return [self.group[dsname] for dsname in self.auxiliary_signals_dataset_names]
-
- @property
- def interpretation(self):
- """*@interpretation* attribute associated with the *signal*
- dataset of the NXdata group. ``None`` if no interpretation
- attribute is present.
-
- The *interpretation* attribute provides information about the last
- dimensions of the signal. The allowed values are:
-
- - *"scalar"*: 0-D data to be plotted
- - *"spectrum"*: 1-D data to be plotted
- - *"image"*: 2-D data to be plotted
- - *"vertex"*: 3-D data to be plotted
-
- For example, a 3-D signal with interpretation *"spectrum"* should be
- considered to be a 2-D array of 1-D data. A 3-D signal with
- interpretation *"image"* should be interpreted as a 1-D array (a list)
- of 2-D images. An n-D array with interpretation *"image"* should be
- interpreted as an (n-2)-D array of images.
-
- A warning message is logged if the returned interpretation is not one
- of the allowed values, but no error is raised and the unknown
- interpretation is returned anyway.
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- allowed_interpretations = [None, "scaler", "scalar", "spectrum", "image",
- "rgba-image", # "hsla-image", "cmyk-image"
- "vertex"]
-
- interpretation = get_attr_as_unicode(self.signal, "interpretation")
- if interpretation is None:
- interpretation = get_attr_as_unicode(self.group, "interpretation")
-
- if interpretation not in allowed_interpretations:
- nxdata_logger.warning("Interpretation %s is not valid." % interpretation +
- " Valid values: " + ", ".join(str(s) for s in allowed_interpretations))
- return interpretation
-
- @property
- def axes(self):
- """List of the axes datasets.
-
- The list typically has as many elements as there are dimensions in the
- signal dataset, the exception being scatter plots which use a 1D
- signal and multiple 1D axes of the same size.
-
- If an axis dataset applies to several dimensions of the signal, it
- will be repeated in the list.
-
- If a dimension of the signal has no dimension scale, `None` is
- inserted in its position in the list.
-
- .. note::
-
- The *@axes* attribute should define as many entries as there
- are dimensions in the signal, to avoid any ambiguity.
- If this is not the case, this implementation relies on the existence
- of an *@interpretation* (*spectrum* or *image*) attribute in the
- *signal* dataset.
-
- .. note::
-
- If an axis dataset defines attributes @first_good or @last_good,
- the output will be a numpy array resulting from slicing that
- axis (*axis[first_good:last_good + 1]*).
-
- :rtype: List[Dataset or 1D array or None]
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- if self._axes is not None:
- # use cache
- return self._axes
- axes = []
- for axis_name in self.axes_dataset_names:
- if axis_name is None:
- axes.append(None)
- else:
- axes.append(self.group[axis_name])
-
- # keep only good range of axis data
- for i, axis in enumerate(axes):
- if axis is None:
- continue
- if "first_good" not in axis.attrs and "last_good" not in axis.attrs:
- continue
- fg_idx = axis.attrs.get("first_good", 0)
- lg_idx = axis.attrs.get("last_good", len(axis) - 1)
- axes[i] = axis[fg_idx:lg_idx + 1]
-
- self._axes = axes
- return self._axes
-
- @property
- def axes_dataset_names(self):
- """List of axes dataset names.
-
- If an axis dataset applies to several dimensions of the signal, its
- name will be repeated in the list.
-
- If a dimension of the signal has no dimension scale (i.e. there is a
- "." in that position in the *@axes* array), `None` is inserted in the
- output list in its position.
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- numbered_names = [] # used in case of @axis=0 (old spec)
- axes_dataset_names = get_attr_as_unicode(self.group, "axes")
- if axes_dataset_names is None:
- # try @axes on signal dataset (older NXdata specification)
- axes_dataset_names = get_attr_as_unicode(self.signal, "axes")
- if axes_dataset_names is not None:
- # we expect a comma separated string
- if hasattr(axes_dataset_names, "split"):
- axes_dataset_names = axes_dataset_names.split(":")
- else:
- # try @axis on the individual datasets (oldest NXdata specification)
- for dsname in self.group:
- if not is_dataset(self.group[dsname]):
- continue
- axis_attr = self.group[dsname].attrs.get("axis")
- if axis_attr is not None:
- try:
- axis_num = int(axis_attr)
- except (ValueError, TypeError):
- nxdata_logger.warning("Could not interpret attr @axis as"
- "int on dataset %s", dsname)
- continue
- numbered_names.append((axis_num, dsname))
-
- ndims = len(self.signal.shape)
- if axes_dataset_names is None:
- if numbered_names:
- axes_dataset_names = []
- numbers = [a[0] for a in numbered_names]
- names = [a[1] for a in numbered_names]
- for i in range(ndims):
- if i in numbers:
- axes_dataset_names.append(names[numbers.index(i)])
- else:
- axes_dataset_names.append(None)
- return axes_dataset_names
- else:
- return [None] * ndims
-
- if isinstance(axes_dataset_names, (six.text_type, six.binary_type)):
- axes_dataset_names = [axes_dataset_names]
-
- for i, axis_name in enumerate(axes_dataset_names):
- if hasattr(axis_name, "decode"):
- axis_name = axis_name.decode()
- if axis_name == ".":
- axes_dataset_names[i] = None
-
- if len(axes_dataset_names) != ndims:
- if self.is_scatter and ndims == 1:
- # case of a 1D signal with arbitrary number of axes
- return list(axes_dataset_names)
- if self.interpretation != "rgba-image":
- # @axes may only define 1 or 2 axes if @interpretation=spectrum/image.
- # Use the existing names for the last few dims, and prepend with Nones.
- assert len(axes_dataset_names) == INTERPDIM[self.interpretation]
- all_dimensions_names = [None] * (ndims - INTERPDIM[self.interpretation])
- for axis_name in axes_dataset_names:
- all_dimensions_names.append(axis_name)
- else:
- # 2 axes applying to the first two dimensions.
- # The 3rd signal dimension is expected to contain 3(4) RGB(A) values.
- assert len(axes_dataset_names) == 2
- all_dimensions_names = [axn for axn in axes_dataset_names]
- all_dimensions_names.append(None)
- return all_dimensions_names
-
- return list(axes_dataset_names)
-
- @property
- def title(self):
- """Plot title. If not found, returns an empty string.
-
- This attribute does not appear in the NXdata specification, but it is
- implemented in *nexpy* as a dataset named "title" inside the NXdata
- group. This dataset is expected to contain text.
-
- Because the *nexpy* approach could cause a conflict if the signal
- dataset or an axis dataset happened to be called "title", we also
- support providing the title as an attribute of the NXdata group.
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- title = self.group.get("title")
- data_dataset_names = [self.signal_name] + self.axes_dataset_names
- if (title is not None and is_dataset(title) and
- "title" not in data_dataset_names):
- return str(h5py_read_dataset(title))
-
- title = self.group.attrs.get("title")
- if title is None:
- return ""
- return str(title)
-
- def get_axis_errors(self, axis_name):
- """Return errors (uncertainties) associated with an axis.
-
- If the axis has attributes @first_good or @last_good, the output
- is trimmed accordingly (a numpy array will be returned rather than a
- dataset).
-
- :param str axis_name: Name of axis dataset. This dataset **must exist**.
- :return: Dataset with axis errors, or None
- :raise KeyError: if this group does not contain a dataset named axis_name
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- # ensure axis_name is decoded, before comparing it with decoded attributes
- if hasattr(axis_name, "decode"):
- axis_name = axis_name.decode("utf-8")
- if axis_name not in self.group:
- # tolerate axis_name given as @long_name
- for item in self.group:
- long_name = get_attr_as_unicode(self.group[item], "long_name")
- if long_name is not None and long_name == axis_name:
- axis_name = item
- break
-
- if axis_name not in self.group:
- raise KeyError("group does not contain a dataset named '%s'" % axis_name)
-
- len_axis = len(self.group[axis_name])
-
- fg_idx = self.group[axis_name].attrs.get("first_good", 0)
- lg_idx = self.group[axis_name].attrs.get("last_good", len_axis - 1)
-
- # case of axisname_errors dataset present
- errors_name = axis_name + "_errors"
- if errors_name in self.group and is_dataset(self.group[errors_name]):
- if fg_idx != 0 or lg_idx != (len_axis - 1):
- return self.group[errors_name][fg_idx:lg_idx + 1]
- else:
- return self.group[errors_name]
- # case of uncertainties dataset name provided in @uncertainties
- uncertainties_names = get_attr_as_unicode(self.group, "uncertainties")
- if uncertainties_names is None:
- uncertainties_names = get_attr_as_unicode(self.signal, "uncertainties")
- if isinstance(uncertainties_names, six.text_type):
- uncertainties_names = [uncertainties_names]
- if uncertainties_names is not None:
- # take the uncertainty with the same index as the axis in @axes
- axes_ds_names = get_attr_as_unicode(self.group, "axes")
- if axes_ds_names is None:
- axes_ds_names = get_attr_as_unicode(self.signal, "axes")
- if isinstance(axes_ds_names, six.text_type):
- axes_ds_names = [axes_ds_names]
- elif isinstance(axes_ds_names, numpy.ndarray):
- # transform numpy.ndarray into list
- axes_ds_names = list(axes_ds_names)
- assert isinstance(axes_ds_names, list)
- if hasattr(axes_ds_names[0], "decode"):
- axes_ds_names = [ax_name.decode("utf-8") for ax_name in axes_ds_names]
- if axis_name not in axes_ds_names:
- raise KeyError("group attr @axes does not mention a dataset " +
- "named '%s'" % axis_name)
- errors = self.group[uncertainties_names[list(axes_ds_names).index(axis_name)]]
- if fg_idx == 0 and lg_idx == (len_axis - 1):
- return errors # dataset
- else:
- return errors[fg_idx:lg_idx + 1] # numpy array
- return None
-
- @property
- def errors(self):
- """Return errors (uncertainties) associated with the signal values.
-
- :return: Dataset with errors, or None
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- # case of signal
- signal_errors = self.signal_dataset_name + "_errors"
- if "errors" in self.group and is_dataset(self.group["errors"]):
- errors = "errors"
- elif signal_errors in self.group and is_dataset(self.group[signal_errors]):
- errors = signal_errors
- else:
- return None
- return self.group[errors]
-
- @property
- def plot_style(self):
- """Information extracted from the optional SILX_style attribute
-
- :raises: InvalidNXdataError
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- return self._plot_style
-
- @property
- def is_scatter(self):
- """True if the signal is 1D and all the axes have the
- same size as the signal."""
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- if self._is_scatter is not None:
- return self._is_scatter
- if not self.signal_is_1d:
- self._is_scatter = False
- else:
- self._is_scatter = True
- sigsize = 1
- for dim in self.signal.shape:
- sigsize *= dim
- for axis in self.axes:
- if axis is None:
- continue
- axis_size = 1
- for dim in axis.shape:
- axis_size *= dim
- self._is_scatter = self._is_scatter and (axis_size == sigsize)
- return self._is_scatter
-
- @property
- def is_x_y_value_scatter(self):
- """True if this is a scatter with a signal and two axes."""
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- return self.is_scatter and len(self.axes) == 2
-
- # we currently have no widget capable of plotting 4D data
- @property
- def is_unsupported_scatter(self):
- """True if this is a scatter with a signal and more than 2 axes."""
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- return self.is_scatter and len(self.axes) > 2
-
- @property
- def is_curve(self):
- """This property is True if the signal is 1D or :attr:`interpretation` is
- *"spectrum"*, and there is at most one axis with a consistent length.
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- if self.signal_is_0d or self.interpretation not in [None, "spectrum"]:
- return False
- # the axis, if any, must be of the same length as the last dimension
- # of the signal, or of length 2 (a + b *x scale)
- if self.axes[-1] is not None and len(self.axes[-1]) not in [
- self.signal.shape[-1], 2]:
- return False
- if self.interpretation is None:
- # We no longer test whether x values are monotonic
- # (in the past, in that case, we used to consider it a scatter)
- return self.signal_is_1d
- # everything looks good
- return True
-
- @property
- def is_image(self):
- """True if the signal is 2D, or 3D with last dimension of length 3 or 4
- and interpretation *rgba-image*, or >2D with interpretation *image*.
- The axes (if any) length must also be consistent with the signal shape.
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- if self.interpretation in ["scalar", "spectrum", "scaler"]:
- return False
- if self.signal_is_0d or self.signal_is_1d:
- return False
- if not self.signal_is_2d and \
- self.interpretation not in ["image", "rgba-image"]:
- return False
- if self.signal_is_3d and self.interpretation == "rgba-image":
- if self.signal.shape[-1] not in [3, 4]:
- return False
- img_axes = self.axes[0:2]
- img_shape = self.signal.shape[0:2]
- else:
- img_axes = self.axes[-2:]
- img_shape = self.signal.shape[-2:]
- for i, axis in enumerate(img_axes):
- if axis is not None and len(axis) not in [img_shape[i], 2]:
- return False
-
- return True
-
- @property
- def is_stack(self):
- """True in the signal is at least 3D and interpretation is not
- "scalar", "spectrum", "image" or "rgba-image".
- The axes length must also be consistent with the last 3 dimensions
- of the signal.
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- if self.signal_ndim < 3 or self.interpretation in [
- "scalar", "scaler", "spectrum", "image", "rgba-image"]:
- return False
- stack_shape = self.signal.shape[-3:]
- for i, axis in enumerate(self.axes[-3:]):
- if axis is not None and len(axis) not in [stack_shape[i], 2]:
- return False
- return True
-
- @property
- def is_volume(self):
- """True in the signal is exactly 3D and interpretation
- "scalar", or nothing.
-
- The axes length must also be consistent with the 3 dimensions
- of the signal.
- """
- if not self.is_valid:
- raise InvalidNXdataError("Unable to parse invalid NXdata")
-
- if self.signal_ndim != 3:
- return False
- if self.interpretation not in [None, "scalar", "scaler"]:
- # 'scaler' and 'scalar' for a three dimensional array indicate a scalar field in 3D
- return False
- volume_shape = self.signal.shape[-3:]
- for i, axis in enumerate(self.axes[-3:]):
- if axis is not None and len(axis) not in [volume_shape[i], 2]:
- return False
- return True
-
-
-def is_valid_nxdata(group): # noqa
- """Check if a h5py group is a **valid** NX_data group.
-
- :param group: h5py-like group
- :return: True if this NXdata group is valid.
- :raise TypeError: if group is not a h5py group, a spech5 group,
- or a fabioh5 group
- """
- nxd = NXdata(group)
- return nxd.is_valid
-
-
-def is_group_with_default_NXdata(group, validate=True):
- """Return True if group defines a valid default
- NXdata.
-
- .. note::
-
- See https://github.com/silx-kit/silx/issues/2215
-
- :param group: h5py-like object.
- :param bool validate: Set this to skip the NXdata validation, and only
- check the existence of the group.
- Parameter provided for optimisation purposes, to avoid double
- validation if the validation is already performed separately."""
- default_nxdata_name = group.attrs.get("default")
- if default_nxdata_name is None or default_nxdata_name not in group:
- return False
-
- default_nxdata_group = group.get(default_nxdata_name)
-
- if not is_group(default_nxdata_group):
- return False
-
- if not validate:
- return True
- else:
- return is_valid_nxdata(default_nxdata_group)
-
-
-def is_NXentry_with_default_NXdata(group, validate=True):
- """Return True if group is a valid NXentry defining a valid default
- NXdata.
-
- :param group: h5py-like object.
- :param bool validate: Set this to skip the NXdata validation, and only
- check the existence of the group.
- Parameter provided for optimisation purposes, to avoid double
- validation if the validation is already performed separately."""
- if not is_group(group):
- return False
-
- if get_attr_as_unicode(group, "NX_class") != "NXentry":
- return False
-
- return is_group_with_default_NXdata(group, validate)
-
-
-def is_NXroot_with_default_NXdata(group, validate=True):
- """Return True if group is a valid NXroot defining a default NXentry
- defining a valid default NXdata.
-
- .. note::
-
- A NXroot group cannot directly define a default NXdata. If a
- *@default* argument is present, it must point to a NXentry group.
- This NXentry must define a valid NXdata for this function to return
- True.
-
- :param group: h5py-like object.
- :param bool validate: Set this to False if you are sure that the target group
- is valid NXdata (i.e. :func:`silx.io.nxdata.is_valid_nxdata(target_group)`
- returns True). Parameter provided for optimisation purposes.
- """
- if not is_group(group):
- return False
-
- # A NXroot is supposed to be at the root of a data file, and @NX_class
- # is therefore optional. We accept groups that are not located at the root
- # if they have @NX_class=NXroot (use case: several nexus files archived
- # in a single HDF5 file)
- if get_attr_as_unicode(group, "NX_class") != "NXroot" and not is_file(group):
- return False
-
- default_nxentry_name = group.attrs.get("default")
- if default_nxentry_name is None or default_nxentry_name not in group:
- return False
-
- default_nxentry_group = group.get(default_nxentry_name)
- return is_NXentry_with_default_NXdata(default_nxentry_group,
- validate=validate)
-
-
-def get_default(group, validate=True):
- """Return a :class:`NXdata` object corresponding to the default NXdata group
- in the group specified as parameter.
-
- This function can find the NXdata if the group is already a NXdata, or
- if it is a NXentry defining a default NXdata, or if it is a NXroot
- defining such a default valid NXentry.
-
- Return None if no valid NXdata could be found.
-
- :param group: h5py-like group following the Nexus specification
- (NXdata, NXentry or NXroot).
- :param bool validate: Set this to False if you are sure that group
- is valid NXdata (i.e. :func:`silx.io.nxdata.is_valid_nxdata(group)`
- returns True). Parameter provided for optimisation purposes.
- :return: :class:`NXdata` object or None
- :raise TypeError: if group is not a h5py-like group
- """
- if not is_group(group):
- raise TypeError("Provided parameter is not a h5py-like group")
-
- if is_NXroot_with_default_NXdata(group, validate=validate):
- default_entry = group[group.attrs["default"]]
- default_data = default_entry[default_entry.attrs["default"]]
- elif is_group_with_default_NXdata(group, validate=validate):
- default_data = group[group.attrs["default"]]
- elif not validate or is_valid_nxdata(group):
- default_data = group
- else:
- return None
-
- return NXdata(default_data, validate=False)
diff --git a/silx/io/nxdata/write.py b/silx/io/nxdata/write.py
deleted file mode 100644
index e9ac3ac..0000000
--- a/silx/io/nxdata/write.py
+++ /dev/null
@@ -1,203 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-import os
-import logging
-
-import h5py
-import numpy
-import six
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "17/04/2018"
-
-
-_logger = logging.getLogger(__name__)
-
-
-def _str_to_utf8(text):
- return numpy.array(text, dtype=h5py.special_dtype(vlen=six.text_type))
-
-
-def save_NXdata(filename, signal, axes=None,
- signal_name="data", axes_names=None,
- signal_long_name=None, axes_long_names=None,
- signal_errors=None, axes_errors=None,
- title=None, interpretation=None,
- nxentry_name="entry", nxdata_name=None):
- """Write data to an NXdata group.
-
- .. note::
-
- No consistency checks are made regarding the dimensionality of the
- signal and number of axes. The user is responsible for providing
- meaningful data, that can be interpreted by visualization software.
-
- :param str filename: Path to output file. If the file does not
- exists, it is created.
- :param numpy.ndarray signal: Signal array.
- :param List[numpy.ndarray] axes: List of axes arrays.
- :param str signal_name: Name of signal dataset, in output file
- :param List[str] axes_names: List of dataset names for axes, in
- output file
- :param str signal_long_name: *@long_name* attribute for signal, or None.
- :param axes_long_names: None, or list of long names
- for axes
- :type axes_long_names: List[str, None]
- :param numpy.ndarray signal_errors: Array of errors associated with the
- signal
- :param axes_errors: List of arrays of errors
- associated with each axis
- :type axes_errors: List[numpy.ndarray, None]
- :param str title: Graph title (saved as a "title" dataset) or None.
- :param str interpretation: *@interpretation* attribute ("spectrum",
- "image", "rgba-image" or None). This is only needed in cases of
- ambiguous dimensionality, e.g. a 3D array which represents a RGBA
- image rather than a stack.
- :param str nxentry_name: Name of group in which the NXdata group
- is created. By default, "/entry" is used.
-
- .. note::
-
- The Nexus format specification requires for NXdata groups
- be part of a NXentry group.
- The specified group should have attribute *@NX_class=NXentry*, in
- order for the created file to be nexus compliant.
- :param str nxdata_name: Name of NXdata group. If omitted (None), the
- function creates a new group using the first available name ("data0",
- or "data1"...).
- Overwriting an existing group (or dataset) is not supported, you must
- delete it yourself prior to calling this function if this is what you
- want.
- :return: True if save was successful, else False.
- """
- if h5py is None:
- raise ImportError("h5py could not be imported, but is required by "
- "save_NXdata function")
-
- if axes_names is not None:
- assert axes is not None, "Axes names defined, but missing axes arrays"
- assert len(axes) == len(axes_names), \
- "Mismatch between number of axes and axes_names"
-
- if axes is not None and axes_names is None:
- axes_names = []
- for i, axis in enumerate(axes):
- axes_names.append("dim%d" % i if axis is not None else ".")
- if axes is None:
- axes = []
-
- # Open file in
- if os.path.exists(filename):
- errmsg = "Cannot write/append to existing path %s"
- if not os.path.isfile(filename):
- errmsg += " (not a file)"
- _logger.error(errmsg, filename)
- return False
- if not os.access(filename, os.W_OK):
- errmsg += " (no permission to write)"
- _logger.error(errmsg, filename)
- return False
- mode = "r+"
- else:
- mode = "w-"
-
- with h5py.File(filename, mode=mode) as h5f:
- # get or create entry
- if nxentry_name is not None:
- entry = h5f.require_group(nxentry_name)
- if "default" not in h5f.attrs:
- # set this entry as default
- h5f.attrs["default"] = _str_to_utf8(nxentry_name)
- if "NX_class" not in entry.attrs:
- entry.attrs["NX_class"] = u"NXentry"
- else:
- # write NXdata into the root of the file (invalid nexus!)
- entry = h5f
-
- # Create NXdata group
- if nxdata_name is not None:
- if nxdata_name in entry:
- _logger.error("Cannot assign an NXdata group to an existing"
- " group or dataset")
- return False
- else:
- # no name specified, take one that is available
- nxdata_name = "data0"
- i = 1
- while nxdata_name in entry:
- _logger.info("%s item already exists in NXentry group," +
- " trying %s", nxdata_name, "data%d" % i)
- nxdata_name = "data%d" % i
- i += 1
-
- data_group = entry.create_group(nxdata_name)
- data_group.attrs["NX_class"] = u"NXdata"
- data_group.attrs["signal"] = _str_to_utf8(signal_name)
- if axes:
- data_group.attrs["axes"] = _str_to_utf8(axes_names)
- if title:
- # not in NXdata spec, but implemented by nexpy
- data_group["title"] = title
- # better way imho
- data_group.attrs["title"] = _str_to_utf8(title)
-
- signal_dataset = data_group.create_dataset(signal_name,
- data=signal)
- if signal_long_name:
- signal_dataset.attrs["long_name"] = _str_to_utf8(signal_long_name)
- if interpretation:
- signal_dataset.attrs["interpretation"] = _str_to_utf8(interpretation)
-
- for i, axis_array in enumerate(axes):
- if axis_array is None:
- assert axes_names[i] in [".", None], \
- "Axis name defined for dim %d but no axis array" % i
- continue
- axis_dataset = data_group.create_dataset(axes_names[i],
- data=axis_array)
- if axes_long_names is not None:
- axis_dataset.attrs["long_name"] = _str_to_utf8(axes_long_names[i])
-
- if signal_errors is not None:
- data_group.create_dataset("errors",
- data=signal_errors)
-
- if axes_errors is not None:
- assert isinstance(axes_errors, (list, tuple)), \
- "axes_errors must be a list or a tuple of ndarray or None"
- assert len(axes_errors) == len(axes_names), \
- "Mismatch between number of axes_errors and axes_names"
- for i, axis_errors in enumerate(axes_errors):
- if axis_errors is not None:
- dsname = axes_names[i] + "_errors"
- data_group.create_dataset(dsname,
- data=axis_errors)
- if "default" not in entry.attrs:
- # set this NXdata as default
- entry.attrs["default"] = nxdata_name
-
- return True
diff --git a/silx/io/specfile.pyx b/silx/io/specfile.pyx
deleted file mode 100644
index 4e76c2c..0000000
--- a/silx/io/specfile.pyx
+++ /dev/null
@@ -1,1268 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-This module is a cython binding to wrap the C SpecFile library, to access
-SpecFile data within a python program.
-
-Documentation for the original C library SpecFile can be found on the ESRF
-website:
-`The manual for the SpecFile Library <http://www.esrf.eu/files/live/sites/www/files/Instrumentation/software/beamline-control/BLISS/documentation/SpecFileManual.pdf>`_
-
-Examples
-========
-
-Start by importing :class:`SpecFile` and instantiate it:
-
-.. code-block:: python
-
- from silx.io.specfile import SpecFile
-
- sf = SpecFile("test.dat")
-
-A :class:`SpecFile` instance can be accessed like a dictionary to obtain a
-:class:`Scan` instance.
-
-If the key is a string representing two values
-separated by a dot (e.g. ``"1.2"``), they will be treated as the scan number
-(``#S`` header line) and the scan order::
-
- # get second occurrence of scan "#S 1"
- myscan = sf["1.2"]
-
- # access scan data as a numpy array
- nlines, ncolumns = myscan.data.shape
-
-If the key is an integer, it will be treated as a 0-based index::
-
- first_scan = sf[0]
- second_scan = sf[1]
-
-It is also possible to browse through all scans using :class:`SpecFile` as
-an iterator::
-
- for scan in sf:
- print(scan.scan_header_dict['S'])
-
-MCA spectra can be selectively loaded using an instance of :class:`MCA`
-provided by :class:`Scan`::
-
- # Only one MCA spectrum is loaded in memory
- second_mca = first_scan.mca[1]
-
- # Iterating trough all MCA spectra in a scan:
- for mca_data in first_scan.mca:
- print(sum(mca_data))
-
-Classes
-=======
-
-- :class:`SpecFile`
-- :class:`Scan`
-- :class:`MCA`
-
-Exceptions
-==========
-
-- :class:`SfError`
-- :class:`SfErrMemoryAlloc`
-- :class:`SfErrFileOpen`
-- :class:`SfErrFileClose`
-- :class:`SfErrFileRead`
-- :class:`SfErrFileWrite`
-- :class:`SfErrLineNotFound`
-- :class:`SfErrScanNotFound`
-- :class:`SfErrHeaderNotFound`
-- :class:`SfErrLabelNotFound`
-- :class:`SfErrMotorNotFound`
-- :class:`SfErrPositionNotFound`
-- :class:`SfErrLineEmpty`
-- :class:`SfErrUserNotFound`
-- :class:`SfErrColNotFound`
-- :class:`SfErrMcaNotFound`
-
-"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "11/08/2017"
-
-import os.path
-import logging
-import numpy
-import re
-import sys
-
-_logger = logging.getLogger(__name__)
-
-cimport cython
-from libc.stdlib cimport free
-
-cimport silx.io.specfile_wrapper as specfile_wrapper
-
-
-SF_ERR_NO_ERRORS = 0
-SF_ERR_FILE_OPEN = 2
-SF_ERR_SCAN_NOT_FOUND = 7
-
-
-# custom errors
-class SfError(Exception):
- """Base exception inherited by all exceptions raised when a
- C function from the legacy SpecFile library returns an error
- code.
- """
- pass
-
-class SfErrMemoryAlloc(SfError, MemoryError): pass
-class SfErrFileOpen(SfError, IOError): pass
-class SfErrFileClose(SfError, IOError): pass
-class SfErrFileRead(SfError, IOError): pass
-class SfErrFileWrite(SfError, IOError): pass
-class SfErrLineNotFound(SfError, KeyError): pass
-class SfErrScanNotFound(SfError, IndexError): pass
-class SfErrHeaderNotFound(SfError, KeyError): pass
-class SfErrLabelNotFound(SfError, KeyError): pass
-class SfErrMotorNotFound(SfError, KeyError): pass
-class SfErrPositionNotFound(SfError, KeyError): pass
-class SfErrLineEmpty(SfError, IOError): pass
-class SfErrUserNotFound(SfError, KeyError): pass
-class SfErrColNotFound(SfError, KeyError): pass
-class SfErrMcaNotFound(SfError, IndexError): pass
-
-
-ERRORS = {
- 1: SfErrMemoryAlloc,
- 2: SfErrFileOpen,
- 3: SfErrFileClose,
- 4: SfErrFileRead,
- 5: SfErrFileWrite,
- 6: SfErrLineNotFound,
- 7: SfErrScanNotFound,
- 8: SfErrHeaderNotFound,
- 9: SfErrLabelNotFound,
- 10: SfErrMotorNotFound,
- 11: SfErrPositionNotFound,
- 12: SfErrLineEmpty,
- 13: SfErrUserNotFound,
- 14: SfErrColNotFound,
- 15: SfErrMcaNotFound,
-}
-
-
-class SfNoMcaError(SfError):
- """Custom exception raised when ``SfNoMca()`` returns ``-1``
- """
- pass
-
-
-class MCA(object):
- """
-
- :param scan: Parent Scan instance
- :type scan: :class:`Scan`
-
- :var calibration: MCA calibration :math:`(a, b, c)` (as in
- :math:`a + b x + c x²`) from ``#@CALIB`` scan header.
- :type calibration: list of 3 floats, default ``[0., 1., 0.]``
- :var channels: MCA channels list from ``#@CHANN`` scan header.
- In the absence of a ``#@CHANN`` header, this attribute is a list
- ``[0, …, N-1]`` where ``N`` is the length of the first spectrum.
- In the absence of MCA spectra, this attribute defaults to ``None``.
- :type channels: list of int
-
- This class provides access to Multi-Channel Analysis data. A :class:`MCA`
- instance can be indexed to access 1D numpy arrays representing single
- MCA spectra.
-
- To create a :class:`MCA` instance, you must provide a parent :class:`Scan`
- instance, which in turn will provide a reference to the original
- :class:`SpecFile` instance::
-
- sf = SpecFile("/path/to/specfile.dat")
- scan2 = Scan(sf, scan_index=2)
- mcas_in_scan2 = MCA(scan2)
- for i in len(mcas_in_scan2):
- mca_data = mcas_in_scan2[i]
- ... # do some something with mca_data (1D numpy array)
-
- A more pythonic way to do the same work, without having to explicitly
- instantiate ``scan`` and ``mcas_in_scan``, would be::
-
- sf = SpecFile("specfilename.dat")
- # scan2 from previous example can be referred to as sf[2]
- # mcas_in_scan2 from previous example can be referred to as scan2.mca
- for mca_data in sf[2].mca:
- ... # do some something with mca_data (1D numpy array)
-
- """
- def __init__(self, scan):
- self._scan = scan
-
- # Header dict
- self._header = scan.mca_header_dict
-
- self.calibration = []
- """List of lists of calibration values,
- one list of 3 floats per MCA device or a single list applying to
- all devices """
- self._parse_calibration()
-
- self.channels = []
- """List of lists of channels,
- one list of integers per MCA device or a single list applying to
- all devices"""
- self._parse_channels()
-
- def _parse_channels(self):
- """Fill :attr:`channels`"""
- # Channels list
- if "CHANN" in self._header:
- chann_lines = self._header["CHANN"].split("\n")
- all_chann_values = [chann_line.split() for chann_line in chann_lines]
- for one_line_chann_values in all_chann_values:
- length, start, stop, increment = map(int, one_line_chann_values)
- self.channels.append(list(range(start, stop + 1, increment)))
- elif len(self):
- # in the absence of #@CHANN, use shape of first MCA
- length = self[0].shape[0]
- start, stop, increment = (0, length - 1, 1)
- self.channels.append(list(range(start, stop + 1, increment)))
-
- def _parse_calibration(self):
- """Fill :attr:`calibration`"""
- # Channels list
- if "CALIB" in self._header:
- calib_lines = self._header["CALIB"].split("\n")
- all_calib_values = [calib_line.split() for calib_line in calib_lines]
- for one_line_calib_values in all_calib_values:
- self.calibration.append(list(map(float, one_line_calib_values)))
- else:
- # in the absence of #@calib, use default
- self.calibration.append([0., 1., 0.])
-
- def __len__(self):
- """
-
- :return: Number of mca in Scan
- :rtype: int
- """
- return self._scan._specfile.number_of_mca(self._scan.index)
-
- def __getitem__(self, key):
- """Return a single MCA data line
-
- :param key: 0-based index of MCA within Scan
- :type key: int
-
- :return: Single MCA
- :rtype: 1D numpy array
- """
- if not len(self):
- raise IndexError("No MCA spectrum found in this scan")
-
- if isinstance(key, (int, long)):
- mca_index = key
- # allow negative index, like lists
- if mca_index < 0:
- mca_index = len(self) + mca_index
- else:
- raise TypeError("MCA index should be an integer (%s provided)" %
- (type(key)))
-
- if not 0 <= mca_index < len(self):
- msg = "MCA index must be in range 0-%d" % (len(self) - 1)
- raise IndexError(msg)
-
- return self._scan._specfile.get_mca(self._scan.index,
- mca_index)
-
- def __iter__(self):
- """Return the next MCA data line each time this method is called.
-
- :return: Single MCA
- :rtype: 1D numpy array
- """
- for mca_index in range(len(self)):
- yield self._scan._specfile.get_mca(self._scan.index, mca_index)
-
-
-def _add_or_concatenate(dictionary, key, value):
- """If key doesn't exist in dictionary, create a new ``key: value`` pair.
- Else append/concatenate the new value to the existing one
- """
- try:
- if key not in dictionary:
- dictionary[key] = value
- else:
- dictionary[key] += "\n" + value
- except TypeError:
- raise TypeError("Parameter value must be a string.")
-
-
-class Scan(object):
- """
-
- :param specfile: Parent SpecFile from which this scan is extracted.
- :type specfile: :class:`SpecFile`
- :param scan_index: Unique index defining the scan in the SpecFile
- :type scan_index: int
-
- Interface to access a SpecFile scan
-
- A scan is a block of descriptive header lines followed by a 2D data array.
-
- Following three ways of accessing a scan are equivalent::
-
- sf = SpecFile("/path/to/specfile.dat")
-
- # Explicit class instantiation
- scan2 = Scan(sf, scan_index=2)
-
- # 0-based index on a SpecFile object
- scan2 = sf[2]
-
- # Using a "n.m" key (scan number starting with 1, scan order)
- scan2 = sf["3.1"]
- """
- def __init__(self, specfile, scan_index):
- self._specfile = specfile
-
- self._index = scan_index
- self._number = specfile.number(scan_index)
- self._order = specfile.order(scan_index)
-
- self._scan_header_lines = self._specfile.scan_header(self._index)
- self._file_header_lines = self._specfile.file_header(self._index)
-
- if self._file_header_lines == self._scan_header_lines:
- self._file_header_lines = []
- self._header = self._file_header_lines + self._scan_header_lines
-
- self._scan_header_dict = {}
- self._mca_header_dict = {}
- for line in self._scan_header_lines:
- match = re.search(r"#(\w+) *(.*)", line)
- match_mca = re.search(r"#@(\w+) *(.*)", line)
- if match:
- hkey = match.group(1).lstrip("#").strip()
- hvalue = match.group(2).strip()
- _add_or_concatenate(self._scan_header_dict, hkey, hvalue)
- elif match_mca:
- hkey = match_mca.group(1).lstrip("#").strip()
- hvalue = match_mca.group(2).strip()
- _add_or_concatenate(self._mca_header_dict, hkey, hvalue)
- else:
- # this shouldn't happen
- _logger.warning("Unable to parse scan header line " + line)
-
- self._labels = []
- if self.record_exists_in_hdr('L'):
- try:
- self._labels = self._specfile.labels(self._index)
- except SfErrLineNotFound:
- # SpecFile.labels raises an IndexError when encountering
- # a Scan with no data, even if the header exists.
- L_header = re.sub(r" {2,}", " ", # max. 2 spaces
- self._scan_header_dict["L"])
- self._labels = L_header.split(" ")
-
- self._file_header_dict = {}
- for line in self._file_header_lines:
- match = re.search(r"#(\w+) *(.*)", line)
- if match:
- # header type
- hkey = match.group(1).lstrip("#").strip()
- hvalue = match.group(2).strip()
- _add_or_concatenate(self._file_header_dict, hkey, hvalue)
- else:
- _logger.warning("Unable to parse file header line " + line)
-
- self._motor_names = self._specfile.motor_names(self._index)
- self._motor_positions = self._specfile.motor_positions(self._index)
-
- self._data = None
- self._mca = None
-
- @cython.embedsignature(False)
- @property
- def index(self):
- """Unique scan index 0 - len(specfile)-1
-
- This attribute is implemented as a read-only property as changing
- its value may cause nasty side-effects (such as loading data from a
- different scan without updating the header accordingly."""
- return self._index
-
- @cython.embedsignature(False)
- @property
- def number(self):
- """First value on #S line (as int)"""
- return self._number
-
- @cython.embedsignature(False)
- @property
- def order(self):
- """Order can be > 1 if the same number is repeated in a specfile"""
- return self._order
-
- @cython.embedsignature(False)
- @property
- def header(self):
- """List of raw header lines (as a list of strings).
-
- This includes the file header, the scan header and possibly a MCA
- header.
- """
- return self._header
-
- @cython.embedsignature(False)
- @property
- def scan_header(self):
- """List of raw scan header lines (as a list of strings).
- """
- return self._scan_header_lines
-
- @cython.embedsignature(False)
- @property
- def file_header(self):
- """List of raw file header lines (as a list of strings).
- """
- return self._file_header_lines
-
- @cython.embedsignature(False)
- @property
- def scan_header_dict(self):
- """
- Dictionary of scan header strings, keys without the leading``#``
- (e.g. ``scan_header_dict["S"]``).
- Note: this does not include MCA header lines starting with ``#@``.
- """
- return self._scan_header_dict
-
- @cython.embedsignature(False)
- @property
- def mca_header_dict(self):
- """
- Dictionary of MCA header strings, keys without the leading ``#@``
- (e.g. ``mca_header_dict["CALIB"]``).
- """
- return self._mca_header_dict
-
- @cython.embedsignature(False)
- @property
- def file_header_dict(self):
- """
- Dictionary of file header strings, keys without the leading ``#``
- (e.g. ``file_header_dict["F"]``).
- """
- return self._file_header_dict
-
- @cython.embedsignature(False)
- @property
- def labels(self):
- """
- List of data column headers from ``#L`` scan header
- """
- return self._labels
-
- @cython.embedsignature(False)
- @property
- def data(self):
- """Scan data as a 2D numpy.ndarray with the usual attributes
- (e.g. data.shape).
-
- The first index is the detector, the second index is the sample index.
- """
- if self._data is None:
- self._data = numpy.transpose(self._specfile.data(self._index))
-
- return self._data
-
- @cython.embedsignature(False)
- @property
- def mca(self):
- """MCA data in this scan.
-
- Each multichannel analysis is a 1D numpy array. Metadata about
- MCA data is to be found in :py:attr:`mca_header`.
-
- :rtype: :class:`MCA`
- """
- if self._mca is None:
- self._mca = MCA(self)
- return self._mca
-
- @cython.embedsignature(False)
- @property
- def motor_names(self):
- """List of motor names from the ``#O`` file header line.
- """
- return self._motor_names
-
- @cython.embedsignature(False)
- @property
- def motor_positions(self):
- """List of motor positions as floats from the ``#P`` scan header line.
- """
- return self._motor_positions
-
- def record_exists_in_hdr(self, record):
- """Check whether a scan header line exists.
-
- This should be used before attempting to retrieve header information
- using a C function that may crash with a *segmentation fault* if the
- header isn't defined in the SpecFile.
-
- :param record: single upper case letter corresponding to the
- header you want to test (e.g. ``L`` for labels)
- :type record: str
-
- :return: True or False
- :rtype: boolean
- """
- for line in self._header:
- if line.startswith("#" + record):
- return True
- return False
-
- def data_line(self, line_index):
- """Returns data for a given line of this scan.
-
- .. note::
-
- A data line returned by this method, corresponds to a data line
- in the original specfile (a series of data points, one per
- detector). In the :attr:`data` array, this line index corresponds
- to the index in the second dimension (~ column) of the array.
-
- :param line_index: Index of data line to retrieve (starting with 0)
- :type line_index: int
-
- :return: Line data as a 1D array of doubles
- :rtype: numpy.ndarray
- """
- # attribute data corresponds to a transposed version of the original
- # specfile data (where detectors correspond to columns)
- return self.data[:, line_index]
-
- def data_column_by_name(self, label):
- """Returns a data column
-
- :param label: Label of data column to retrieve, as defined on the
- ``#L`` line of the scan header.
- :type label: str
-
- :return: Line data as a 1D array of doubles
- :rtype: numpy.ndarray
- """
- try:
- ret = self._specfile.data_column_by_name(self._index, label)
- except SfErrLineNotFound:
- # Could be a "#C Scan aborted after 0 points"
- _logger.warning("Cannot get data column %s in scan %d.%d",
- label, self.number, self.order)
- ret = numpy.empty((0, ), numpy.double)
- return ret
-
- def motor_position_by_name(self, name):
- """Returns the position for a given motor
-
- :param name: Name of motor, as defined on the ``#O`` line of the
- file header.
- :type name: str
-
- :return: Motor position
- :rtype: float
- """
- return self._specfile.motor_position_by_name(self._index, name)
-
-
-def _string_to_char_star(string_):
- """Convert a string to ASCII encoded bytes when using python3"""
- if sys.version_info[0] >= 3 and not isinstance(string_, bytes):
- return bytes(string_, "ascii")
- return string_
-
-
-def is_specfile(filename):
- """Test if a file is a SPEC file, by checking if one of the first two
- lines starts with *#F* (SPEC file header) or *#S* (scan header).
-
- :param str filename: File path
- :return: *True* if file is a SPEC file, *False* if it is not a SPEC file
- :rtype: bool
- """
- if not os.path.isfile(filename):
- return False
- # test for presence of #S or #F in first 10 lines
- with open(filename, "rb") as f:
- chunk = f.read(2500)
- for i, line in enumerate(chunk.split(b"\n")):
- if line.startswith(b"#S ") or line.startswith(b"#F "):
- return True
- if i >= 10:
- break
- return False
-
-
-cdef class SpecFile(object):
- """
-
- :param filename: Path of the SpecFile to read
-
- This class wraps the main data and header access functions of the C
- SpecFile library.
- """
-
- cdef:
- specfile_wrapper.SpecFileHandle *handle
- str filename
-
- def __cinit__(self, filename):
- cdef int error = 0
- self.handle = NULL
-
- if is_specfile(filename):
- filename = _string_to_char_star(filename)
- self.handle = specfile_wrapper.SfOpen(filename, &error)
- if error:
- self._handle_error(error)
- else:
- # handle_error takes care of raising the correct error,
- # this causes the destructor to be called
- self._handle_error(SF_ERR_FILE_OPEN)
-
- def __init__(self, filename):
- if not isinstance(filename, str):
- # decode bytes to str in python 3, str to unicode in python 2
- self.filename = filename.decode()
- else:
- self.filename = filename
-
- def __dealloc__(self):
- """Destructor: Calls SfClose(self.handle)"""
- self.close()
-
- def close(self):
- """Close the file descriptor"""
- # handle is NULL if SfOpen failed
- if self.handle:
- if specfile_wrapper.SfClose(self.handle):
- _logger.warning("Error while closing SpecFile")
- self.handle = NULL
-
- def __len__(self):
- """Return the number of scans in the SpecFile
- """
- return specfile_wrapper.SfScanNo(self.handle)
-
- def __iter__(self):
- """Return the next :class:`Scan` in a SpecFile each time this method
- is called.
-
- This usually happens when the python built-in function ``next()`` is
- called with a :class:`SpecFile` instance as a parameter, or when a
- :class:`SpecFile` instance is used as an iterator (e.g. in a ``for``
- loop).
- """
- for scan_index in range(len(self)):
- yield Scan(self, scan_index)
-
- def __getitem__(self, key):
- """Return a :class:`Scan` object.
-
- This special method is called when a :class:`SpecFile` instance is
- accessed as a dictionary (e.g. ``sf[key]``).
-
- :param key: 0-based scan index or ``"n.m"`` key, where ``n`` is the scan
- number defined on the ``#S`` header line and ``m`` is the order
- :type key: int or str
-
- :return: Scan defined by its 0-based index or its ``"n.m"`` key
- :rtype: :class:`Scan`
- """
- msg = "The scan identification key can be an integer representing "
- msg += "the unique scan index or a string 'N.M' with N being the scan"
- msg += " number and M the order (eg '2.3')."
-
- if isinstance(key, int):
- scan_index = key
- # allow negative index, like lists
- if scan_index < 0:
- scan_index = len(self) + scan_index
- else:
- try:
- (number, order) = map(int, key.split("."))
- scan_index = self.index(number, order)
- except (ValueError, SfErrScanNotFound, KeyError):
- # int() can raise a value error
- raise KeyError(msg + "\nValid keys: '" +
- "', '".join(self.keys()) + "'")
- except AttributeError:
- # e.g. "AttrErr: 'float' object has no attribute 'split'"
- raise TypeError(msg)
-
- if not 0 <= scan_index < len(self):
- msg = "Scan index must be in range 0-%d" % (len(self) - 1)
- raise IndexError(msg)
-
- return Scan(self, scan_index)
-
- def keys(self):
- """Returns list of scan keys (eg ``['1.1', '2.1',...]``).
-
- :return: list of scan keys
- :rtype: list of strings
- """
- ret_list = []
- list_of_numbers = self._list()
- count = {}
-
- for number in list_of_numbers:
- if number not in count:
- count[number] = 1
- else:
- count[number] += 1
- ret_list.append(u'%d.%d' % (number, count[number]))
-
- return ret_list
-
- def __contains__(self, key):
- """Return ``True`` if ``key`` is a valid scan key.
- Valid keys can be a string such as ``"1.1"`` or a 0-based scan index.
- """
- return key in (self.keys() + list(range(len(self))))
-
- def _get_error_string(self, error_code):
- """Returns the error message corresponding to the error code.
-
- :param code: Error code
- :type code: int
- :return: Human readable error message
- :rtype: str
- """
- return (<bytes> specfile_wrapper.SfError(error_code)).decode()
-
- def _handle_error(self, error_code):
- """Inspect error code, raise adequate error type if necessary.
-
- :param code: Error code
- :type code: int
- """
- error_message = self._get_error_string(error_code)
- if error_code in ERRORS:
- raise ERRORS[error_code](error_message)
-
- def index(self, scan_number, scan_order=1):
- """Returns scan index from scan number and order.
-
- :param scan_number: Scan number (possibly non-unique).
- :type scan_number: int
- :param scan_order: Scan order.
- :type scan_order: int default 1
-
- :return: Unique scan index
- :rtype: int
-
-
- Scan indices are increasing from ``0`` to ``len(self)-1`` in the
- order in which they appear in the file.
- Scan numbers are defined by users and are not necessarily unique.
- The scan order for a given scan number increments each time the scan
- number appears in a given file.
- """
- idx = specfile_wrapper.SfIndex(self.handle, scan_number, scan_order)
- if idx == -1:
- self._handle_error(SF_ERR_SCAN_NOT_FOUND)
- return idx - 1
-
- def number(self, scan_index):
- """Returns scan number from scan index.
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: User defined scan number.
- :rtype: int
- """
- idx = specfile_wrapper.SfNumber(self.handle, scan_index + 1)
- if idx == -1:
- self._handle_error(SF_ERR_SCAN_NOT_FOUND)
- return idx
-
- def order(self, scan_index):
- """Returns scan order from scan index.
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: Scan order (sequential number incrementing each time a
- non-unique occurrence of a scan number is encountered).
- :rtype: int
- """
- ordr = specfile_wrapper.SfOrder(self.handle, scan_index + 1)
- if ordr == -1:
- self._handle_error(SF_ERR_SCAN_NOT_FOUND)
- return ordr
-
- def _list(self):
- """see documentation of :meth:`list`
- """
- cdef:
- long *scan_numbers
- int error = SF_ERR_NO_ERRORS
-
- scan_numbers = specfile_wrapper.SfList(self.handle, &error)
- self._handle_error(error)
-
- ret_list = []
- for i in range(len(self)):
- ret_list.append(scan_numbers[i])
-
- free(scan_numbers)
- return ret_list
-
- def list(self):
- """Returns list (1D numpy array) of scan numbers in SpecFile.
-
- :return: list of scan numbers (from `` #S`` lines) in the same order
- as in the original SpecFile (e.g ``[1, 1, 2, 3, …]``).
- :rtype: numpy array
- """
- # this method is overloaded in specfilewrapper to output a string
- # representation of the list
- return self._list()
-
- def data(self, scan_index):
- """Returns data for the specified scan index.
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: Complete scan data as a 2D array of doubles
- :rtype: numpy.ndarray
- """
- cdef:
- double** mydata
- long* data_info
- int i, j
- int error = SF_ERR_NO_ERRORS
- long nlines, ncolumns, regular
- double[:, :] ret_array
-
- sfdata_error = specfile_wrapper.SfData(self.handle,
- scan_index + 1,
- &mydata,
- &data_info,
- &error)
- if sfdata_error == -1 and not error:
- # this has happened in some situations with empty scans (#1759)
- _logger.warning("SfData returned -1 without an error."
- " Assuming aborted scan.")
-
- self._handle_error(error)
-
- if <long>data_info != 0:
- nlines = data_info[0]
- ncolumns = data_info[1]
- regular = data_info[2]
- else:
- nlines = 0
- ncolumns = 0
- regular = 0
-
- ret_array = numpy.empty((nlines, ncolumns), dtype=numpy.double)
-
- for i in range(nlines):
- for j in range(ncolumns):
- ret_array[i, j] = mydata[i][j]
-
- specfile_wrapper.freeArrNZ(<void ***>&mydata, nlines)
- free(data_info)
- return numpy.asarray(ret_array)
-
- def data_column_by_name(self, scan_index, label):
- """Returns data column for the specified scan index and column label.
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
- :param label: Label of data column, as defined in the ``#L`` line
- of the scan header.
- :type label: str
-
- :return: Data column as a 1D array of doubles
- :rtype: numpy.ndarray
- """
- cdef:
- double* data_column
- long i, nlines
- int error = SF_ERR_NO_ERRORS
- double[:] ret_array
-
- label = _string_to_char_star(label)
-
- nlines = specfile_wrapper.SfDataColByName(self.handle,
- scan_index + 1,
- label,
- &data_column,
- &error)
- self._handle_error(error)
-
- if nlines == -1:
- # this can happen on empty scans in some situations (see #1759)
- _logger.warning("SfDataColByName returned -1 without an error."
- " Assuming aborted scan.")
- nlines = 0
-
- ret_array = numpy.empty((nlines,), dtype=numpy.double)
-
- for i in range(nlines):
- ret_array[i] = data_column[i]
-
- free(data_column)
- return numpy.asarray(ret_array)
-
- def scan_header(self, scan_index):
- """Return list of scan header lines.
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: List of raw scan header lines
- :rtype: list of str
- """
- cdef:
- char** lines
- int error = SF_ERR_NO_ERRORS
-
- nlines = specfile_wrapper.SfHeader(self.handle,
- scan_index + 1,
- "", # no pattern matching
- &lines,
- &error)
-
- self._handle_error(error)
-
- lines_list = []
- for i in range(nlines):
- line = <bytes>lines[i].decode()
- lines_list.append(line)
-
- specfile_wrapper.freeArrNZ(<void***>&lines, nlines)
- return lines_list
-
- def file_header(self, scan_index=0):
- """Return list of file header lines.
-
- A file header contains all lines between a ``#F`` header line and
- a ``#S`` header line (start of scan). We need to specify a scan
- number because there can be more than one file header in a given file.
- A file header applies to all subsequent scans, until a new file
- header is defined.
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: List of raw file header lines
- :rtype: list of str
- """
- cdef:
- char** lines
- int error = SF_ERR_NO_ERRORS
-
- nlines = specfile_wrapper.SfFileHeader(self.handle,
- scan_index + 1,
- "", # no pattern matching
- &lines,
- &error)
- self._handle_error(error)
-
- lines_list = []
- for i in range(nlines):
- line = <bytes>lines[i].decode()
- lines_list.append(line)
-
- specfile_wrapper.freeArrNZ(<void***>&lines, nlines)
- return lines_list
-
- def columns(self, scan_index):
- """Return number of columns in a scan from the ``#N`` header line
- (without ``#N`` and scan number)
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: Number of columns in scan from ``#N`` line
- :rtype: int
- """
- cdef:
- int error = SF_ERR_NO_ERRORS
-
- ncolumns = specfile_wrapper.SfNoColumns(self.handle,
- scan_index + 1,
- &error)
- self._handle_error(error)
-
- return ncolumns
-
- def command(self, scan_index):
- """Return ``#S`` line (without ``#S`` and scan number)
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: S line
- :rtype: str
- """
- cdef:
- int error = SF_ERR_NO_ERRORS
-
- s_record = <bytes> specfile_wrapper.SfCommand(self.handle,
- scan_index + 1,
- &error)
- self._handle_error(error)
-
- return s_record.decode()
-
- def date(self, scan_index=0):
- """Return date from ``#D`` line
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: Date from ``#D`` line
- :rtype: str
- """
- cdef:
- int error = SF_ERR_NO_ERRORS
-
- d_line = <bytes> specfile_wrapper.SfDate(self.handle,
- scan_index + 1,
- &error)
- self._handle_error(error)
-
- return d_line.decode()
-
- def labels(self, scan_index):
- """Return all labels from ``#L`` line
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: All labels from ``#L`` line
- :rtype: list of strings
- """
- cdef:
- char** all_labels
- int error = SF_ERR_NO_ERRORS
-
- nlabels = specfile_wrapper.SfAllLabels(self.handle,
- scan_index + 1,
- &all_labels,
- &error)
- self._handle_error(error)
-
- labels_list = []
- for i in range(nlabels):
- labels_list.append(<bytes>all_labels[i].decode())
-
- specfile_wrapper.freeArrNZ(<void***>&all_labels, nlabels)
- return labels_list
-
- def motor_names(self, scan_index=0):
- """Return all motor names from ``#O`` lines
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.If not specified, defaults to 0 (meaning the
- function returns motors names associated with the first scan).
- This parameter makes a difference only if there are more than
- on file header in the file, in which case the file header applies
- to all following scans until a new file header appears.
- :type scan_index: int
-
- :return: All motor names
- :rtype: list of strings
- """
- cdef:
- char** all_motors
- int error = SF_ERR_NO_ERRORS
-
- nmotors = specfile_wrapper.SfAllMotors(self.handle,
- scan_index + 1,
- &all_motors,
- &error)
- self._handle_error(error)
-
- motors_list = []
- for i in range(nmotors):
- motors_list.append(<bytes>all_motors[i].decode())
-
- specfile_wrapper.freeArrNZ(<void***>&all_motors, nmotors)
- return motors_list
-
- def motor_positions(self, scan_index):
- """Return all motor positions
-
- :param scan_index: Unique scan index between ``0``
- and ``len(self)-1``.
- :type scan_index: int
-
- :return: All motor positions
- :rtype: list of double
- """
- cdef:
- double* motor_positions
- int error = SF_ERR_NO_ERRORS
-
- nmotors = specfile_wrapper.SfAllMotorPos(self.handle,
- scan_index + 1,
- &motor_positions,
- &error)
- self._handle_error(error)
-
- motor_positions_list = []
- for i in range(nmotors):
- motor_positions_list.append(motor_positions[i])
-
- free(motor_positions)
- return motor_positions_list
-
- def motor_position_by_name(self, scan_index, name):
- """Return motor position
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: Specified motor position
- :rtype: double
- """
- cdef:
- int error = SF_ERR_NO_ERRORS
-
- name = _string_to_char_star(name)
-
- motor_position = specfile_wrapper.SfMotorPosByName(self.handle,
- scan_index + 1,
- name,
- &error)
- self._handle_error(error)
-
- return motor_position
-
- def number_of_mca(self, scan_index):
- """Return number of mca spectra in a scan.
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: Number of mca spectra.
- :rtype: int
- """
- cdef:
- int error = SF_ERR_NO_ERRORS
-
- num_mca = specfile_wrapper.SfNoMca(self.handle,
- scan_index + 1,
- &error)
- # error code updating isn't implemented in SfNoMCA
- if num_mca == -1:
- raise SfNoMcaError("Failed to retrieve number of MCA " +
- "(SfNoMca returned -1)")
- return num_mca
-
- def mca_calibration(self, scan_index):
- """Return MCA calibration in the form :math:`a + b x + c x²`
-
- Raise a KeyError if there is no ``@CALIB`` line in the scan header.
-
- :param scan_index: Unique scan index between ``0`` and
- ``len(self)-1``.
- :type scan_index: int
-
- :return: MCA calibration as a list of 3 values :math:`(a, b, c)`
- :rtype: list of floats
- """
- cdef:
- int error = SF_ERR_NO_ERRORS
- double* mca_calib
-
- mca_calib_error = specfile_wrapper.SfMcaCalib(self.handle,
- scan_index + 1,
- &mca_calib,
- &error)
-
- # error code updating isn't implemented in SfMcaCalib
- if mca_calib_error:
- raise KeyError("MCA calibration line (@CALIB) not found")
-
- mca_calib_list = []
- for i in range(3):
- mca_calib_list.append(mca_calib[i])
-
- free(mca_calib)
- return mca_calib_list
-
- def get_mca(self, scan_index, mca_index):
- """Return one MCA spectrum
-
- :param scan_index: Unique scan index between ``0`` and ``len(self)-1``.
- :type scan_index: int
- :param mca_index: Index of MCA in the scan
- :type mca_index: int
-
- :return: MCA spectrum
- :rtype: 1D numpy array
- """
- cdef:
- int error = SF_ERR_NO_ERRORS
- double* mca_data
- long len_mca
- double[:] ret_array
-
- len_mca = specfile_wrapper.SfGetMca(self.handle,
- scan_index + 1,
- mca_index + 1,
- &mca_data,
- &error)
- self._handle_error(error)
-
- ret_array = numpy.empty((len_mca,), dtype=numpy.double)
-
- for i in range(len_mca):
- ret_array[i] = mca_data[i]
-
- free(mca_data)
- return numpy.asarray(ret_array)
diff --git a/silx/io/spech5.py b/silx/io/spech5.py
deleted file mode 100644
index 1eaec7c..0000000
--- a/silx/io/spech5.py
+++ /dev/null
@@ -1,883 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""This module provides a h5py-like API to access SpecFile data.
-
-API description
-+++++++++++++++
-
-Specfile data structure exposed by this API:
-
-::
-
- /
- 1.1/
- title = "…"
- start_time = "…"
- instrument/
- specfile/
- file_header = "…"
- scan_header = "…"
- positioners/
- motor_name = value
- …
- mca_0/
- data = …
- calibration = …
- channels = …
- preset_time = …
- elapsed_time = …
- live_time = …
-
- mca_1/
- …
- …
- measurement/
- colname0 = …
- colname1 = …
- …
- mca_0/
- data -> /1.1/instrument/mca_0/data
- info -> /1.1/instrument/mca_0/
- …
- sample/
- ub_matrix = …
- unit_cell = …
- unit_cell_abc = …
- unit_cell_alphabetagamma = …
- 2.1/
- …
-
-``file_header`` and ``scan_header`` are the raw headers as they
-appear in the original file, as a string of lines separated by newline (``\\n``) characters.
-
-The title is the content of the ``#S`` scan header line without the leading
-``#S`` and without the scan number (e.g ``"ascan ss1vo -4.55687 -0.556875 40 0.2"``).
-
-The start time is converted to ISO8601 format (``"2016-02-23T22:49:05Z"``),
-if the original date format is standard.
-
-Numeric datasets are stored in *float32* format, except for scalar integers
-which are stored as *int64*.
-
-Motor positions (e.g. ``/1.1/instrument/positioners/motor_name``) can be
-1D numpy arrays if they are measured as scan data, or else scalars as defined
-on ``#P`` scan header lines. A simple test is done to check if the motor name
-is also a data column header defined in the ``#L`` scan header line.
-
-Scan data (e.g. ``/1.1/measurement/colname0``) is accessed by column,
-the dataset name ``colname0`` being the column label as defined in the ``#L``
-scan header line.
-
-If a ``/`` character is present in a column label or in a motor name in the
-original SPEC file, it will be substituted with a ``%`` character in the
-corresponding dataset name.
-
-MCA data is exposed as a 2D numpy array containing all spectra for a given
-analyser. The number of analysers is calculated as the number of MCA spectra
-per scan data line. Demultiplexing is then performed to assign the correct
-spectra to a given analyser.
-
-MCA calibration is an array of 3 scalars, from the ``#@CALIB`` header line.
-It is identical for all MCA analysers, as there can be only one
-``#@CALIB`` line per scan.
-
-MCA channels is an array containing all channel numbers. This information is
-computed from the ``#@CHANN`` scan header line (if present), or computed from
-the shape of the first spectrum in a scan (``[0, … len(first_spectrum] - 1]``).
-
-Accessing data
-++++++++++++++
-
-Data and groups are accessed in :mod:`h5py` fashion::
-
- from silx.io.spech5 import SpecH5
-
- # Open a SpecFile
- sfh5 = SpecH5("test.dat")
-
- # using SpecH5 as a regular group to access scans
- scan1group = sfh5["1.1"]
- instrument_group = scan1group["instrument"]
-
- # alternative: full path access
- measurement_group = sfh5["/1.1/measurement"]
-
- # accessing a scan data column by name as a 1D numpy array
- data_array = measurement_group["Pslit HGap"]
-
- # accessing all mca-spectra for one MCA device
- mca_0_spectra = measurement_group["mca_0/data"]
-
-:class:`SpecH5` files and groups provide a :meth:`keys` method::
-
- >>> sfh5.keys()
- ['96.1', '97.1', '98.1']
- >>> sfh5['96.1'].keys()
- ['title', 'start_time', 'instrument', 'measurement']
-
-They can also be treated as iterators:
-
-.. code-block:: python
-
- from silx.io import is_dataset
-
- for scan_group in SpecH5("test.dat"):
- dataset_names = [item.name in scan_group["measurement"] if
- is_dataset(item)]
- print("Found data columns in scan " + scan_group.name)
- print(", ".join(dataset_names))
-
-You can test for existence of data or groups::
-
- >>> "/1.1/measurement/Pslit HGap" in sfh5
- True
- >>> "positioners" in sfh5["/2.1/instrument"]
- True
- >>> "spam" in sfh5["1.1"]
- False
-
-.. note::
-
- Text used to be stored with a dtype ``numpy.string_`` in silx versions
- prior to *0.7.0*. The type ``numpy.string_`` is a byte-string format.
- The consequence of this is that you had to decode strings before using
- them in **Python 3**::
-
- >>> from silx.io.spech5 import SpecH5
- >>> sfh5 = SpecH5("31oct98.dat")
- >>> sfh5["/68.1/title"]
- b'68 ascan tx3 -28.5 -24.5 20 0.5'
- >>> sfh5["/68.1/title"].decode()
- '68 ascan tx3 -28.5 -24.5 20 0.5'
-
- From silx version *0.7.0* onwards, text is now stored as unicode. This
- corresponds to the default text type in python 3, and to the *unicode*
- type in Python 2.
-
- To be on the safe side, you can test for the presence of a *decode*
- attribute, to ensure that you always work with unicode text::
-
- >>> title = sfh5["/68.1/title"]
- >>> if hasattr(title, "decode"):
- ... title = title.decode()
-
-"""
-
-import datetime
-import logging
-import re
-import io
-
-import h5py
-import numpy
-import six
-
-from silx import version as silx_version
-from .specfile import SpecFile, SfErrColNotFound
-from . import commonh5
-
-__authors__ = ["P. Knobel", "D. Naudet"]
-__license__ = "MIT"
-__date__ = "17/07/2018"
-
-logger1 = logging.getLogger(__name__)
-
-
-text_dtype = h5py.special_dtype(vlen=six.text_type)
-
-
-def to_h5py_utf8(str_list):
- """Convert a string or a list of strings to a numpy array of
- unicode strings that can be written to HDF5 as utf-8.
-
- This ensures that the type will be consistent between python 2 and
- python 3, if attributes or datasets are saved to an HDF5 file.
- """
- return numpy.array(str_list, dtype=text_dtype)
-
-
-def _get_number_of_mca_analysers(scan):
- """
- :param SpecFile sf: :class:`SpecFile` instance
- """
- number_of_mca_spectra = len(scan.mca)
- # Scan.data is transposed
- number_of_data_lines = scan.data.shape[1]
-
- if not number_of_data_lines == 0:
- # Number of MCA spectra must be a multiple of number of data lines
- assert number_of_mca_spectra % number_of_data_lines == 0
- return number_of_mca_spectra // number_of_data_lines
- elif number_of_mca_spectra:
- # Case of a scan without data lines, only MCA.
- # Our only option is to assume that the number of analysers
- # is the number of #@CHANN lines
- return len(scan.mca.channels)
- else:
- return 0
-
-
-def _motor_in_scan(sf, scan_key, motor_name):
- """
- :param sf: :class:`SpecFile` instance
- :param scan_key: Scan identification key (e.g. ``1.1``)
- :param motor_name: Name of motor as defined in file header lines
- :return: ``True`` if motor exists in scan, else ``False``
- :raise: ``KeyError`` if scan_key not found in SpecFile
- """
- if scan_key not in sf:
- raise KeyError("Scan key %s " % scan_key +
- "does not exist in SpecFile %s" % sf.filename)
- ret = motor_name in sf[scan_key].motor_names
- if not ret and "%" in motor_name:
- motor_name = motor_name.replace("%", "/")
- ret = motor_name in sf[scan_key].motor_names
- return ret
-
-
-def _column_label_in_scan(sf, scan_key, column_label):
- """
- :param sf: :class:`SpecFile` instance
- :param scan_key: Scan identification key (e.g. ``1.1``)
- :param column_label: Column label as defined in scan header
- :return: ``True`` if data column label exists in scan, else ``False``
- :raise: ``KeyError`` if scan_key not found in SpecFile
- """
- if scan_key not in sf:
- raise KeyError("Scan key %s " % scan_key +
- "does not exist in SpecFile %s" % sf.filename)
- ret = column_label in sf[scan_key].labels
- if not ret and "%" in column_label:
- column_label = column_label.replace("%", "/")
- ret = column_label in sf[scan_key].labels
- return ret
-
-
-def _parse_UB_matrix(header_line):
- """Parse G3 header line and return UB matrix
-
- :param str header_line: G3 header line
- :return: UB matrix
- """
- return numpy.array(list(map(float, header_line.split()))).reshape((1, 3, 3))
-
-
-def _ub_matrix_in_scan(scan):
- """Return True if scan header has a G3 line and all values are not 0.
-
- :param scan: specfile.Scan instance
- :return: True or False
- """
- if "G3" not in scan.scan_header_dict:
- return False
- return numpy.any(_parse_UB_matrix(scan.scan_header_dict["G3"]))
-
-
-def _parse_unit_cell(header_line):
- return numpy.array(list(map(float, header_line.split()))[0:6]).reshape((1, 6))
-
-
-def _unit_cell_in_scan(scan):
- """Return True if scan header has a G1 line and all values are not 0.
-
- :param scan: specfile.Scan instance
- :return: True or False
- """
- if "G1" not in scan.scan_header_dict:
- return False
- return numpy.any(_parse_unit_cell(scan.scan_header_dict["G1"]))
-
-
-def _parse_ctime(ctime_lines, analyser_index=0):
- """
- :param ctime_lines: e.g ``@CTIME %f %f %f``, first word ``@CTIME`` optional
- When multiple CTIME lines are present in a scan header, this argument
- is a concatenation of them separated by a ``\\n`` character.
- :param analyser_index: MCA device/analyser index, when multiple devices
- are in a scan.
- :return: (preset_time, live_time, elapsed_time)
- """
- ctime_lines = ctime_lines.lstrip("@CTIME ")
- ctimes_lines_list = ctime_lines.split("\n")
- if len(ctimes_lines_list) == 1:
- # single @CTIME line for all devices
- ctime_line = ctimes_lines_list[0]
- else:
- ctime_line = ctimes_lines_list[analyser_index]
- if not len(ctime_line.split()) == 3:
- raise ValueError("Incorrect format for @CTIME header line " +
- '(expected "@CTIME %f %f %f").')
- return list(map(float, ctime_line.split()))
-
-
-def spec_date_to_iso8601(date, zone=None):
- """Convert SpecFile date to Iso8601.
-
- :param date: Date (see supported formats below)
- :type date: str
- :param zone: Time zone as it appears in a ISO8601 date
-
- Supported formats:
-
- * ``DDD MMM dd hh:mm:ss YYYY``
- * ``DDD YYYY/MM/dd hh:mm:ss YYYY``
-
- where `DDD` is the abbreviated weekday, `MMM` is the month abbreviated
- name, `MM` is the month number (zero padded), `dd` is the weekday number
- (zero padded) `YYYY` is the year, `hh` the hour (zero padded), `mm` the
- minute (zero padded) and `ss` the second (zero padded).
- All names are expected to be in english.
-
- Examples::
-
- >>> spec_date_to_iso8601("Thu Feb 11 09:54:35 2016")
- '2016-02-11T09:54:35'
-
- >>> spec_date_to_iso8601("Sat 2015/03/14 03:53:50")
- '2015-03-14T03:53:50'
- """
- months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul',
- 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
- days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
-
- days_rx = '(?P<day>' + '|'.join(days) + ')'
- months_rx = '(?P<month>' + '|'.join(months) + ')'
- year_rx = r'(?P<year>\d{4})'
- day_nb_rx = r'(?P<day_nb>[0-3 ]\d)'
- month_nb_rx = r'(?P<month_nb>[0-1]\d)'
- hh_rx = r'(?P<hh>[0-2]\d)'
- mm_rx = r'(?P<mm>[0-5]\d)'
- ss_rx = r'(?P<ss>[0-5]\d)'
- tz_rx = r'(?P<tz>[+-]\d\d:\d\d){0,1}'
-
- # date formats must have either month_nb (1..12) or month (Jan, Feb, ...)
- re_tpls = ['{days} {months} {day_nb} {hh}:{mm}:{ss}{tz} {year}',
- '{days} {year}/{month_nb}/{day_nb} {hh}:{mm}:{ss}{tz}']
-
- grp_d = None
-
- for rx in re_tpls:
- full_rx = rx.format(days=days_rx,
- months=months_rx,
- year=year_rx,
- day_nb=day_nb_rx,
- month_nb=month_nb_rx,
- hh=hh_rx,
- mm=mm_rx,
- ss=ss_rx,
- tz=tz_rx)
- m = re.match(full_rx, date)
-
- if m:
- grp_d = m.groupdict()
- break
-
- if not grp_d:
- raise ValueError('Date format not recognized : {0}'.format(date))
-
- year = grp_d['year']
-
- month = grp_d.get('month_nb')
-
- if not month:
- month = '{0:02d}'.format(months.index(grp_d.get('month')) + 1)
-
- day = grp_d['day_nb']
-
- tz = grp_d['tz']
- if not tz:
- tz = zone
-
- time = '{0}:{1}:{2}'.format(grp_d['hh'],
- grp_d['mm'],
- grp_d['ss'])
-
- full_date = '{0}-{1}-{2}T{3}{4}'.format(year,
- month,
- day,
- time,
- tz if tz else '')
- return full_date
-
-
-def _demultiplex_mca(scan, analyser_index):
- """Return MCA data for a single analyser.
-
- Each MCA spectrum is a 1D array. For each analyser, there is one
- spectrum recorded per scan data line. When there are more than a single
- MCA analyser in a scan, the data will be multiplexed. For instance if
- there are 3 analysers, the consecutive spectra for the first analyser must
- be accessed as ``mca[0], mca[3], mca[6]…``.
-
- :param scan: :class:`Scan` instance containing the MCA data
- :param analyser_index: 0-based index referencing the analyser
- :type analyser_index: int
- :return: 2D numpy array containing all spectra for one analyser
- """
- number_of_analysers = _get_number_of_mca_analysers(scan)
- number_of_spectra = len(scan.mca)
- number_of_spectra_per_analyser = number_of_spectra // number_of_analysers
- len_spectrum = len(scan.mca[analyser_index])
-
- mca_array = numpy.empty((number_of_spectra_per_analyser, len_spectrum))
-
- for i in range(number_of_spectra_per_analyser):
- mca_array[i, :] = scan.mca[analyser_index + i * number_of_analysers]
-
- return mca_array
-
-
-# Node classes
-class SpecH5Dataset(object):
- """This convenience class is to be inherited by all datasets, for
- compatibility purpose with code that tests for
- ``isinstance(obj, SpecH5Dataset)``.
-
- This legacy behavior is deprecated. The correct way to test
- if an object is a dataset is to use :meth:`silx.io.utils.is_dataset`.
-
- Datasets must also inherit :class:`SpecH5NodeDataset` or
- :class:`SpecH5LazyNodeDataset` which actually implement all the
- API."""
- pass
-
-
-class SpecH5NodeDataset(commonh5.Dataset, SpecH5Dataset):
- """This class inherits :class:`commonh5.Dataset`, to which it adds
- little extra functionality. The main additional functionality is the
- proxy behavior that allows to mimic the numpy array stored in this
- class.
- """
- def __init__(self, name, data, parent=None, attrs=None):
- # get proper value types, to inherit from numpy
- # attributes (dtype, shape, size)
- if isinstance(data, six.string_types):
- # use unicode (utf-8 when saved to HDF5 output)
- value = to_h5py_utf8(data)
- elif isinstance(data, float):
- # use 32 bits for float scalars
- value = numpy.float32(data)
- elif isinstance(data, int):
- value = numpy.int_(data)
- else:
- # Enforce numpy array
- array = numpy.array(data)
- data_kind = array.dtype.kind
-
- if data_kind in ["S", "U"]:
- value = numpy.asarray(array,
- dtype=text_dtype)
- elif data_kind in ["f"]:
- value = numpy.asarray(array, dtype=numpy.float32)
- else:
- value = array
- commonh5.Dataset.__init__(self, name, value, parent, attrs)
-
- def __getattr__(self, item):
- """Proxy to underlying numpy array methods.
- """
- if hasattr(self[()], item):
- return getattr(self[()], item)
-
- raise AttributeError("SpecH5Dataset has no attribute %s" % item)
-
-
-class SpecH5LazyNodeDataset(commonh5.LazyLoadableDataset, SpecH5Dataset):
- """This class inherits :class:`commonh5.LazyLoadableDataset`,
- to which it adds a proxy behavior that allows to mimic the numpy
- array stored in this class.
-
- The class has to be inherited and the :meth:`_create_data` method has to be
- implemented to return the numpy data exposed by the dataset. This factory
- method is only called once, when the data is needed.
- """
- def __getattr__(self, item):
- """Proxy to underlying numpy array methods.
- """
- if hasattr(self[()], item):
- return getattr(self[()], item)
-
- raise AttributeError("SpecH5Dataset has no attribute %s" % item)
-
- def _create_data(self):
- """
- Factory to create the data exposed by the dataset when it is needed.
-
- It has to be implemented for the class to work.
-
- :rtype: numpy.ndarray
- """
- raise NotImplementedError()
-
-
-class SpecH5Group(object):
- """This convenience class is to be inherited by all groups, for
- compatibility purposes with code that tests for
- ``isinstance(obj, SpecH5Group)``.
-
- This legacy behavior is deprecated. The correct way to test
- if an object is a group is to use :meth:`silx.io.utils.is_group`.
-
- Groups must also inherit :class:`silx.io.commonh5.Group`, which
- actually implements all the methods and attributes."""
- pass
-
-
-class SpecH5(commonh5.File, SpecH5Group):
- """This class opens a SPEC file and exposes it as a *h5py.File*.
-
- It inherits :class:`silx.io.commonh5.Group` (via :class:`commonh5.File`),
- which implements most of its API.
- """
-
- def __init__(self, filename):
- """
- :param filename: Path to SpecFile in filesystem
- :type filename: str
- """
- if isinstance(filename, io.IOBase):
- # see https://github.com/silx-kit/silx/issues/858
- filename = filename.name
-
- self._sf = SpecFile(filename)
-
- attrs = {"NX_class": to_h5py_utf8("NXroot"),
- "file_time": to_h5py_utf8(
- datetime.datetime.now().isoformat()),
- "file_name": to_h5py_utf8(filename),
- "creator": to_h5py_utf8("silx spech5 %s" % silx_version)}
- commonh5.File.__init__(self, filename, attrs=attrs)
-
- for scan_key in self._sf.keys():
- scan = self._sf[scan_key]
- scan_group = ScanGroup(scan_key, parent=self, scan=scan)
- self.add_node(scan_group)
-
- def close(self):
- self._sf.close()
- self._sf = None
-
-
-class ScanGroup(commonh5.Group, SpecH5Group):
- def __init__(self, scan_key, parent, scan):
- """
-
- :param parent: parent Group
- :param str scan_key: Scan key (e.g. "1.1")
- :param scan: specfile.Scan object
- """
- commonh5.Group.__init__(self, scan_key, parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXentry")})
-
- # take title in #S after stripping away scan number and spaces
- s_hdr_line = scan.scan_header_dict["S"]
- title = s_hdr_line.lstrip("0123456789").lstrip()
- self.add_node(SpecH5NodeDataset(name="title",
- data=to_h5py_utf8(title),
- parent=self))
-
- if "D" in scan.scan_header_dict:
- try:
- start_time_str = spec_date_to_iso8601(scan.scan_header_dict["D"])
- except (IndexError, ValueError):
- logger1.warning("Could not parse date format in scan %s header." +
- " Using original date not converted to ISO-8601",
- scan_key)
- start_time_str = scan.scan_header_dict["D"]
- elif "D" in scan.file_header_dict:
- logger1.warning("No #D line in scan %s header. " +
- "Using file header for start_time.",
- scan_key)
- try:
- start_time_str = spec_date_to_iso8601(scan.file_header_dict["D"])
- except (IndexError, ValueError):
- logger1.warning("Could not parse date format in scan %s header. " +
- "Using original date not converted to ISO-8601",
- scan_key)
- start_time_str = scan.file_header_dict["D"]
- else:
- logger1.warning("No #D line in %s header. Setting date to empty string.",
- scan_key)
- start_time_str = ""
- self.add_node(SpecH5NodeDataset(name="start_time",
- data=to_h5py_utf8(start_time_str),
- parent=self))
-
- self.add_node(InstrumentGroup(parent=self, scan=scan))
- self.add_node(MeasurementGroup(parent=self, scan=scan))
- if _unit_cell_in_scan(scan) or _ub_matrix_in_scan(scan):
- self.add_node(SampleGroup(parent=self, scan=scan))
-
-
-class InstrumentGroup(commonh5.Group, SpecH5Group):
- def __init__(self, parent, scan):
- """
-
- :param parent: parent Group
- :param scan: specfile.Scan object
- """
- commonh5.Group.__init__(self, name="instrument", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXinstrument")})
-
- self.add_node(InstrumentSpecfileGroup(parent=self, scan=scan))
- self.add_node(PositionersGroup(parent=self, scan=scan))
-
- num_analysers = _get_number_of_mca_analysers(scan)
- for anal_idx in range(num_analysers):
- self.add_node(InstrumentMcaGroup(parent=self,
- analyser_index=anal_idx,
- scan=scan))
-
-
-class InstrumentSpecfileGroup(commonh5.Group, SpecH5Group):
- def __init__(self, parent, scan):
- commonh5.Group.__init__(self, name="specfile", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXcollection")})
- self.add_node(SpecH5NodeDataset(
- name="file_header",
- data=to_h5py_utf8(scan.file_header),
- parent=self,
- attrs={}))
- self.add_node(SpecH5NodeDataset(
- name="scan_header",
- data=to_h5py_utf8(scan.scan_header),
- parent=self,
- attrs={}))
-
-
-class PositionersGroup(commonh5.Group, SpecH5Group):
- def __init__(self, parent, scan):
- commonh5.Group.__init__(self, name="positioners", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXcollection")})
-
- dataset_info = [] # Store list of positioner's (name, value)
- is_error = False # True if error encountered
-
- for motor_name in scan.motor_names:
- safe_motor_name = motor_name.replace("/", "%")
- if motor_name in scan.labels and scan.data.shape[0] > 0:
- # return a data column if one has the same label as the motor
- motor_value = scan.data_column_by_name(motor_name)
- else:
- # Take value from #P scan header.
- # (may return float("inf") if #P line is missing from scan hdr)
- try:
- motor_value = scan.motor_position_by_name(motor_name)
- except SfErrColNotFound:
- is_error = True
- motor_value = float('inf')
- dataset_info.append((safe_motor_name, motor_value))
-
- if is_error: # Filter-out scalar values
- logger1.warning("Mismatching number of elements in #P and #O: Ignoring")
- dataset_info = [
- (name, value) for name, value in dataset_info
- if not isinstance(value, float)]
-
- for name, value in dataset_info:
- self.add_node(SpecH5NodeDataset(
- name=name,
- data=value,
- parent=self))
-
-
-class InstrumentMcaGroup(commonh5.Group, SpecH5Group):
- def __init__(self, parent, analyser_index, scan):
- name = "mca_%d" % analyser_index
- commonh5.Group.__init__(self, name=name, parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXdetector")})
-
- mcaDataDataset = McaDataDataset(parent=self,
- analyser_index=analyser_index,
- scan=scan)
- self.add_node(mcaDataDataset)
- spectrum_length = mcaDataDataset.shape[-1]
- mcaDataDataset = None
-
- if len(scan.mca.channels) == 1:
- # single @CALIB line applying to multiple devices
- calibration_dataset = scan.mca.calibration[0]
- channels_dataset = scan.mca.channels[0]
- else:
- calibration_dataset = scan.mca.calibration[analyser_index]
- channels_dataset = scan.mca.channels[analyser_index]
-
- channels_length = len(channels_dataset)
- if (channels_length > 1) and (spectrum_length > 0):
- logger1.info("Spectrum and channels length mismatch")
- # this should always be the case
- if channels_length > spectrum_length:
- channels_dataset = channels_dataset[:spectrum_length]
- elif channels_length < spectrum_length:
- # only trust first channel and increment
- channel0 = channels_dataset[0]
- increment = channels_dataset[1] - channels_dataset[0]
- channels_dataset = numpy.linspace(channel0,
- channel0 + increment * spectrum_length,
- spectrum_length, endpoint=False)
-
- self.add_node(SpecH5NodeDataset(name="calibration",
- data=calibration_dataset,
- parent=self))
- self.add_node(SpecH5NodeDataset(name="channels",
- data=channels_dataset,
- parent=self))
-
- if "CTIME" in scan.mca_header_dict:
- ctime_line = scan.mca_header_dict['CTIME']
- preset_time, live_time, elapsed_time = _parse_ctime(ctime_line, analyser_index)
- self.add_node(SpecH5NodeDataset(name="preset_time",
- data=preset_time,
- parent=self))
- self.add_node(SpecH5NodeDataset(name="live_time",
- data=live_time,
- parent=self))
- self.add_node(SpecH5NodeDataset(name="elapsed_time",
- data=elapsed_time,
- parent=self))
-
-
-class McaDataDataset(SpecH5LazyNodeDataset):
- """Lazy loadable dataset for MCA data"""
- def __init__(self, parent, analyser_index, scan):
- commonh5.LazyLoadableDataset.__init__(
- self, name="data", parent=parent,
- attrs={"interpretation": to_h5py_utf8("spectrum"),})
- self._scan = scan
- self._analyser_index = analyser_index
- self._shape = None
- self._num_analysers = _get_number_of_mca_analysers(self._scan)
-
- def _create_data(self):
- return _demultiplex_mca(self._scan, self._analyser_index)
-
- @property
- def shape(self):
- if self._shape is None:
- num_spectra_in_file = len(self._scan.mca)
- num_spectra_per_analyser = num_spectra_in_file // self._num_analysers
- len_spectrum = len(self._scan.mca[self._analyser_index])
- self._shape = num_spectra_per_analyser, len_spectrum
- return self._shape
-
- @property
- def size(self):
- return numpy.prod(self.shape, dtype=numpy.intp)
-
- @property
- def dtype(self):
- # we initialize the data with numpy.empty() without specifying a dtype
- # in _demultiplex_mca()
- return numpy.empty((1, )).dtype
-
- def __len__(self):
- return self.shape[0]
-
- def __getitem__(self, item):
- # optimization for fetching a single spectrum if data not already loaded
- if not self._is_initialized:
- if isinstance(item, six.integer_types):
- if item < 0:
- # negative indexing
- item += len(self)
- return self._scan.mca[self._analyser_index +
- item * self._num_analysers]
- # accessing a slice or element of a single spectrum [i, j:k]
- try:
- spectrum_idx, channel_idx_or_slice = item
- assert isinstance(spectrum_idx, six.integer_types)
- except (ValueError, TypeError, AssertionError):
- pass
- else:
- if spectrum_idx < 0:
- item += len(self)
- idx = self._analyser_index + spectrum_idx * self._num_analysers
- return self._scan.mca[idx][channel_idx_or_slice]
-
- return super(McaDataDataset, self).__getitem__(item)
-
-
-class MeasurementGroup(commonh5.Group, SpecH5Group):
- def __init__(self, parent, scan):
- """
-
- :param parent: parent Group
- :param scan: specfile.Scan object
- """
- commonh5.Group.__init__(self, name="measurement", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXcollection"),})
- for label in scan.labels:
- safe_label = label.replace("/", "%")
- self.add_node(SpecH5NodeDataset(name=safe_label,
- data=scan.data_column_by_name(label),
- parent=self))
-
- num_analysers = _get_number_of_mca_analysers(scan)
- for anal_idx in range(num_analysers):
- self.add_node(MeasurementMcaGroup(parent=self, analyser_index=anal_idx))
-
-
-class MeasurementMcaGroup(commonh5.Group, SpecH5Group):
- def __init__(self, parent, analyser_index):
- basename = "mca_%d" % analyser_index
- commonh5.Group.__init__(self, name=basename, parent=parent,
- attrs={})
-
- target_name = self.name.replace("measurement", "instrument")
- self.add_node(commonh5.SoftLink(name="data",
- path=target_name + "/data",
- parent=self))
- self.add_node(commonh5.SoftLink(name="info",
- path=target_name,
- parent=self))
-
-
-class SampleGroup(commonh5.Group, SpecH5Group):
- def __init__(self, parent, scan):
- """
-
- :param parent: parent Group
- :param scan: specfile.Scan object
- """
- commonh5.Group.__init__(self, name="sample", parent=parent,
- attrs={"NX_class": to_h5py_utf8("NXsample"),})
-
- if _unit_cell_in_scan(scan):
- self.add_node(SpecH5NodeDataset(name="unit_cell",
- data=_parse_unit_cell(scan.scan_header_dict["G1"]),
- parent=self,
- attrs={"interpretation": to_h5py_utf8("scalar")}))
- self.add_node(SpecH5NodeDataset(name="unit_cell_abc",
- data=_parse_unit_cell(scan.scan_header_dict["G1"])[0, 0:3],
- parent=self,
- attrs={"interpretation": to_h5py_utf8("scalar")}))
- self.add_node(SpecH5NodeDataset(name="unit_cell_alphabetagamma",
- data=_parse_unit_cell(scan.scan_header_dict["G1"])[0, 3:6],
- parent=self,
- attrs={"interpretation": to_h5py_utf8("scalar")}))
- if _ub_matrix_in_scan(scan):
- self.add_node(SpecH5NodeDataset(name="ub_matrix",
- data=_parse_UB_matrix(scan.scan_header_dict["G3"]),
- parent=self,
- attrs={"interpretation": to_h5py_utf8("scalar")}))
diff --git a/silx/io/test/__init__.py b/silx/io/test/__init__.py
deleted file mode 100644
index 68b6e9b..0000000
--- a/silx/io/test/__init__.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-
-__authors__ = ["T. Vincent", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "08/12/2017"
-
-import unittest
-
-from .test_specfile import suite as test_specfile_suite
-from .test_specfilewrapper import suite as test_specfilewrapper_suite
-from .test_dictdump import suite as test_dictdump_suite
-from .test_spech5 import suite as test_spech5_suite
-from .test_spectoh5 import suite as test_spectoh5_suite
-from .test_octaveh5 import suite as test_octaveh5_suite
-from .test_fabioh5 import suite as test_fabioh5_suite
-from .test_utils import suite as test_utils_suite
-from .test_nxdata import suite as test_nxdata_suite
-from .test_commonh5 import suite as test_commonh5_suite
-from .test_rawh5 import suite as test_rawh5_suite
-from .test_url import suite as test_url_suite
-from .test_h5py_utils import suite as test_h5py_utils_suite
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(test_dictdump_suite())
- test_suite.addTest(test_specfile_suite())
- test_suite.addTest(test_specfilewrapper_suite())
- test_suite.addTest(test_spech5_suite())
- test_suite.addTest(test_spectoh5_suite())
- test_suite.addTest(test_octaveh5_suite())
- test_suite.addTest(test_utils_suite())
- test_suite.addTest(test_fabioh5_suite())
- test_suite.addTest(test_nxdata_suite())
- test_suite.addTest(test_commonh5_suite())
- test_suite.addTest(test_rawh5_suite())
- test_suite.addTest(test_url_suite())
- test_suite.addTest(test_h5py_utils_suite())
- return test_suite
diff --git a/silx/io/test/test_commonh5.py b/silx/io/test/test_commonh5.py
deleted file mode 100644
index 168ef34..0000000
--- a/silx/io/test/test_commonh5.py
+++ /dev/null
@@ -1,295 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for commonh5 wrapper"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "21/09/2017"
-
-import logging
-import numpy
-import unittest
-import tempfile
-import shutil
-
-_logger = logging.getLogger(__name__)
-
-import silx.io
-import silx.io.utils
-import h5py
-
-try:
- from .. import commonh5
-except ImportError:
- commonh5 = None
-
-
-class TestCommonFeatures(unittest.TestCase):
- """Test common features supported by h5py and our implementation."""
-
- @classmethod
- def createFile(cls):
- return None
-
- @classmethod
- def setUpClass(cls):
- # Set to None cause create_resource can raise an excpetion
- cls.h5 = None
- cls.h5 = cls.create_resource()
- if cls.h5 is None:
- raise unittest.SkipTest("File not created")
-
- @classmethod
- def create_resource(cls):
- """Must be implemented"""
- return None
-
- @classmethod
- def tearDownClass(cls):
- cls.h5 = None
-
- def test_file(self):
- node = self.h5
- self.assertTrue(silx.io.is_file(node))
- self.assertTrue(silx.io.is_group(node))
- self.assertFalse(silx.io.is_dataset(node))
- self.assertEqual(len(node.attrs), 0)
-
- def test_group(self):
- node = self.h5["group"]
- self.assertFalse(silx.io.is_file(node))
- self.assertTrue(silx.io.is_group(node))
- self.assertFalse(silx.io.is_dataset(node))
- self.assertEqual(len(node.attrs), 0)
- class_ = self.h5.get("group", getclass=True)
- classlink = self.h5.get("group", getlink=True, getclass=True)
- self.assertEqual(class_, h5py.Group)
- self.assertEqual(classlink, h5py.HardLink)
-
- def test_dataset(self):
- node = self.h5["group/dataset"]
- self.assertFalse(silx.io.is_file(node))
- self.assertFalse(silx.io.is_group(node))
- self.assertTrue(silx.io.is_dataset(node))
- self.assertEqual(len(node.attrs), 0)
- class_ = self.h5.get("group/dataset", getclass=True)
- classlink = self.h5.get("group/dataset", getlink=True, getclass=True)
- self.assertEqual(class_, h5py.Dataset)
- self.assertEqual(classlink, h5py.HardLink)
-
- def test_soft_link(self):
- node = self.h5["link/soft_link"]
- self.assertEqual(node.name, "/link/soft_link")
- class_ = self.h5.get("link/soft_link", getclass=True)
- link = self.h5.get("link/soft_link", getlink=True)
- classlink = self.h5.get("link/soft_link", getlink=True, getclass=True)
- self.assertEqual(class_, h5py.Dataset)
- self.assertTrue(isinstance(link, (h5py.SoftLink, commonh5.SoftLink)))
- self.assertTrue(silx.io.utils.is_softlink(link))
- self.assertEqual(classlink, h5py.SoftLink)
-
- def test_external_link(self):
- node = self.h5["link/external_link"]
- self.assertEqual(node.name, "/target/dataset")
- class_ = self.h5.get("link/external_link", getclass=True)
- classlink = self.h5.get("link/external_link", getlink=True, getclass=True)
- self.assertEqual(class_, h5py.Dataset)
- self.assertEqual(classlink, h5py.ExternalLink)
-
- def test_external_link_to_link(self):
- node = self.h5["link/external_link_to_link"]
- self.assertEqual(node.name, "/target/link")
- class_ = self.h5.get("link/external_link_to_link", getclass=True)
- classlink = self.h5.get("link/external_link_to_link", getlink=True, getclass=True)
- self.assertEqual(class_, h5py.Dataset)
- self.assertEqual(classlink, h5py.ExternalLink)
-
- def test_create_groups(self):
- c = self.h5.create_group(self.id() + "/a/b/c")
- d = c.create_group("/" + self.id() + "/a/b/d")
-
- self.assertRaises(ValueError, self.h5.create_group, self.id() + "/a/b/d")
- self.assertEqual(c.name, "/" + self.id() + "/a/b/c")
- self.assertEqual(d.name, "/" + self.id() + "/a/b/d")
-
- def test_setitem_python_object_dataset(self):
- group = self.h5.create_group(self.id())
- group["a"] = 10
- self.assertEqual(group["a"].dtype.kind, "i")
-
- def test_setitem_numpy_dataset(self):
- group = self.h5.create_group(self.id())
- group["a"] = numpy.array([10, 20, 30])
- self.assertEqual(group["a"].dtype.kind, "i")
- self.assertEqual(group["a"].shape, (3,))
-
- def test_setitem_link(self):
- group = self.h5.create_group(self.id())
- group["a"] = 10
- group["b"] = group["a"]
- self.assertEqual(group["b"].dtype.kind, "i")
-
- def test_setitem_dataset_is_sub_group(self):
- self.h5[self.id() + "/a"] = 10
-
-
-class TestCommonFeatures_h5py(TestCommonFeatures):
- """Check if h5py is compliant with what we expect."""
-
- @classmethod
- def create_resource(cls):
- cls.tmp_dir = tempfile.mkdtemp()
-
- externalh5 = h5py.File(cls.tmp_dir + "/external.h5", mode="w")
- externalh5["target/dataset"] = 50
- externalh5["target/link"] = h5py.SoftLink("/target/dataset")
- externalh5.close()
-
- h5 = h5py.File(cls.tmp_dir + "/base.h5", mode="w")
- h5["group/dataset"] = 50
- h5["link/soft_link"] = h5py.SoftLink("/group/dataset")
- h5["link/external_link"] = h5py.ExternalLink("external.h5", "/target/dataset")
- h5["link/external_link_to_link"] = h5py.ExternalLink("external.h5", "/target/link")
-
- return h5
-
- @classmethod
- def tearDownClass(cls):
- super(TestCommonFeatures_h5py, cls).tearDownClass()
- if hasattr(cls, "tmp_dir") and cls.tmp_dir is not None:
- shutil.rmtree(cls.tmp_dir)
-
-
-class TestCommonFeatures_commonH5(TestCommonFeatures):
- """Check if commonh5 is compliant with h5py."""
-
- @classmethod
- def create_resource(cls):
- h5 = commonh5.File("base.h5", "w")
- h5.create_group("group").create_dataset("dataset", data=numpy.int32(50))
-
- link = h5.create_group("link")
- link.add_node(commonh5.SoftLink("soft_link", "/group/dataset"))
-
- return h5
-
- def test_external_link(self):
- # not applicable
- pass
-
- def test_external_link_to_link(self):
- # not applicable
- pass
-
-
-class TestSpecificCommonH5(unittest.TestCase):
- """Test specific features from commonh5.
-
- Test of shared features should be done by TestCommonFeatures."""
-
- def setUp(self):
- if commonh5 is None:
- self.skipTest("silx.io.commonh5 is needed")
-
- def test_node_attrs(self):
- node = commonh5.Node("Foo", attrs={"a": 1})
- self.assertEqual(node.attrs["a"], 1)
- node.attrs["b"] = 8
- self.assertEqual(node.attrs["b"], 8)
- node.attrs["b"] = 2
- self.assertEqual(node.attrs["b"], 2)
-
- def test_node_readonly_attrs(self):
- f = commonh5.File(name="Foo", mode="r")
- node = commonh5.Node("Foo", attrs={"a": 1})
- node.attrs["b"] = 8
- f.add_node(node)
- self.assertEqual(node.attrs["b"], 8)
- try:
- node.attrs["b"] = 1
- self.fail()
- except RuntimeError:
- pass
-
- def test_create_dataset(self):
- f = commonh5.File(name="Foo", mode="w")
- node = f.create_dataset("foo", data=numpy.array([1]))
- self.assertIs(node.parent, f)
- self.assertIs(f["foo"], node)
-
- def test_create_group(self):
- f = commonh5.File(name="Foo", mode="w")
- node = f.create_group("foo")
- self.assertIs(node.parent, f)
- self.assertIs(f["foo"], node)
-
- def test_readonly_create_dataset(self):
- f = commonh5.File(name="Foo", mode="r")
- try:
- f.create_dataset("foo", data=numpy.array([1]))
- self.fail()
- except RuntimeError:
- pass
-
- def test_readonly_create_group(self):
- f = commonh5.File(name="Foo", mode="r")
- try:
- f.create_group("foo")
- self.fail()
- except RuntimeError:
- pass
-
- def test_create_unicode_dataset(self):
- f = commonh5.File(name="Foo", mode="w")
- try:
- f.create_dataset("foo", data=numpy.array(u"aaaa"))
- self.fail()
- except TypeError:
- pass
-
- def test_setitem_dataset(self):
- self.h5 = commonh5.File(name="Foo", mode="w")
- group = self.h5.create_group(self.id())
- group["a"] = commonh5.Dataset(None, data=numpy.array(10))
- self.assertEqual(group["a"].dtype.kind, "i")
-
- def test_setitem_explicit_link(self):
- self.h5 = commonh5.File(name="Foo", mode="w")
- group = self.h5.create_group(self.id())
- group["a"] = 10
- group["b"] = commonh5.SoftLink(None, path="/" + self.id() + "/a")
- self.assertEqual(group["b"].dtype.kind, "i")
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestCommonFeatures_h5py))
- test_suite.addTest(loadTests(TestCommonFeatures_commonH5))
- test_suite.addTest(loadTests(TestSpecificCommonH5))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_dictdump.py b/silx/io/test/test_dictdump.py
deleted file mode 100644
index 93c9183..0000000
--- a/silx/io/test/test_dictdump.py
+++ /dev/null
@@ -1,1025 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for dicttoh5 module"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-from collections import OrderedDict
-import numpy
-import os
-import tempfile
-import unittest
-import h5py
-from copy import deepcopy
-
-from collections import defaultdict
-
-from silx.utils.testutils import TestLogging
-
-from ..configdict import ConfigDict
-from .. import dictdump
-from ..dictdump import dicttoh5, dicttojson, dump
-from ..dictdump import h5todict, load
-from ..dictdump import logger as dictdump_logger
-from ..utils import is_link
-from ..utils import h5py_read_dataset
-
-
-def tree():
- """Tree data structure as a recursive nested dictionary"""
- return defaultdict(tree)
-
-
-inhabitants = 160215
-
-city_attrs = tree()
-city_attrs["Europe"]["France"]["Grenoble"]["area"] = "18.44 km2"
-city_attrs["Europe"]["France"]["Grenoble"]["inhabitants"] = inhabitants
-city_attrs["Europe"]["France"]["Grenoble"]["coordinates"] = [45.1830, 5.7196]
-city_attrs["Europe"]["France"]["Tourcoing"]["area"]
-
-ext_attrs = tree()
-ext_attrs["ext_group"]["dataset"] = 10
-ext_filename = "ext.h5"
-
-link_attrs = tree()
-link_attrs["links"]["group"]["dataset"] = 10
-link_attrs["links"]["group"]["relative_softlink"] = h5py.SoftLink("dataset")
-link_attrs["links"]["relative_softlink"] = h5py.SoftLink("group/dataset")
-link_attrs["links"]["absolute_softlink"] = h5py.SoftLink("/links/group/dataset")
-link_attrs["links"]["external_link"] = h5py.ExternalLink(ext_filename, "/ext_group/dataset")
-
-
-class DictTestCase(unittest.TestCase):
-
- def assertRecursiveEqual(self, expected, actual, nodes=tuple()):
- err_msg = "\n\n Tree nodes: {}".format(nodes)
- if isinstance(expected, dict):
- self.assertTrue(isinstance(actual, dict), msg=err_msg)
- self.assertEqual(
- set(expected.keys()),
- set(actual.keys()),
- msg=err_msg
- )
- for k in actual:
- self.assertRecursiveEqual(
- expected[k],
- actual[k],
- nodes=nodes + (k,),
- )
- return
- if isinstance(actual, numpy.ndarray):
- actual = actual.tolist()
- if isinstance(expected, numpy.ndarray):
- expected = expected.tolist()
-
- self.assertEqual(expected, actual, msg=err_msg)
-
-
-class H5DictTestCase(DictTestCase):
-
- def _dictRoundTripNormalize(self, treedict):
- """Convert the dictionary as expected from a round-trip
- treedict -> dicttoh5 -> h5todict -> newtreedict
- """
- for key, value in list(treedict.items()):
- if isinstance(value, dict):
- self._dictRoundTripNormalize(value)
-
- # Expand treedict[("group", "attr_name")]
- # to treedict["group"]["attr_name"]
- for key, value in list(treedict.items()):
- if not isinstance(key, tuple):
- continue
- # Put the attribute inside the group
- grpname, attr = key
- if not grpname:
- continue
- group = treedict.setdefault(grpname, dict())
- if isinstance(group, dict):
- del treedict[key]
- group[("", attr)] = value
-
- def dictRoundTripNormalize(self, treedict):
- treedict2 = deepcopy(treedict)
- self._dictRoundTripNormalize(treedict2)
- return treedict2
-
-
-class TestDictToH5(H5DictTestCase):
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5")
- self.h5_ext_fname = os.path.join(self.tempdir, ext_filename)
-
- def tearDown(self):
- if os.path.exists(self.h5_fname):
- os.unlink(self.h5_fname)
- if os.path.exists(self.h5_ext_fname):
- os.unlink(self.h5_ext_fname)
- os.rmdir(self.tempdir)
-
- def testH5CityAttrs(self):
- filters = {'shuffle': True,
- 'fletcher32': True}
- dicttoh5(city_attrs, self.h5_fname, h5path='/city attributes',
- mode="w", create_dataset_args=filters)
-
- h5f = h5py.File(self.h5_fname, mode='r')
-
- self.assertIn("Tourcoing/area", h5f["/city attributes/Europe/France"])
- ds = h5f["/city attributes/Europe/France/Grenoble/inhabitants"]
- self.assertEqual(ds[...], 160215)
-
- # filters only apply to datasets that are not scalars (shape != () )
- ds = h5f["/city attributes/Europe/France/Grenoble/coordinates"]
- #self.assertEqual(ds.compression, "gzip")
- self.assertTrue(ds.fletcher32)
- self.assertTrue(ds.shuffle)
-
- h5f.close()
-
- ddict = load(self.h5_fname, fmat="hdf5")
- self.assertAlmostEqual(
- min(ddict["city attributes"]["Europe"]["France"]["Grenoble"]["coordinates"]),
- 5.7196)
-
- def testH5OverwriteDeprecatedApi(self):
- dd = ConfigDict({'t': True})
-
- dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a')
- dd = ConfigDict({'t': False})
- dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a',
- overwrite_data=False)
-
- res = h5todict(self.h5_fname)
- assert(res['t'] == True)
-
- dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a',
- overwrite_data=True)
-
- res = h5todict(self.h5_fname)
- assert(res['t'] == False)
-
- def testAttributes(self):
- """Any kind of attribute can be described"""
- ddict = {
- "group": {"datatset": "hmmm", ("", "group_attr"): 10},
- "dataset": "aaaaaaaaaaaaaaa",
- ("", "root_attr"): 11,
- ("dataset", "dataset_attr"): 12,
- ("group", "group_attr2"): 13,
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttoh5(ddict, h5file)
- self.assertEqual(h5file["group"].attrs['group_attr'], 10)
- self.assertEqual(h5file.attrs['root_attr'], 11)
- self.assertEqual(h5file["dataset"].attrs['dataset_attr'], 12)
- self.assertEqual(h5file["group"].attrs['group_attr2'], 13)
-
- def testPathAttributes(self):
- """A group is requested at a path"""
- ddict = {
- ("", "NX_class"): 'NXcollection',
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- # This should not warn
- with TestLogging(dictdump_logger, warning=0):
- dictdump.dicttoh5(ddict, h5file, h5path="foo/bar")
-
- def testKeyOrder(self):
- ddict1 = {
- "d": "plow",
- ("d", "a"): "ox",
- }
- ddict2 = {
- ("d", "a"): "ox",
- "d": "plow",
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttoh5(ddict1, h5file, h5path="g1")
- dictdump.dicttoh5(ddict2, h5file, h5path="g2")
- self.assertEqual(h5file["g1/d"].attrs['a'], "ox")
- self.assertEqual(h5file["g2/d"].attrs['a'], "ox")
-
- def testAttributeValues(self):
- """Any NX data types can be used"""
- ddict = {
- ("", "bool"): True,
- ("", "int"): 11,
- ("", "float"): 1.1,
- ("", "str"): "a",
- ("", "boollist"): [True, False, True],
- ("", "intlist"): [11, 22, 33],
- ("", "floatlist"): [1.1, 2.2, 3.3],
- ("", "strlist"): ["a", "bb", "ccc"],
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttoh5(ddict, h5file)
- for k, expected in ddict.items():
- result = h5file.attrs[k[1]]
- if isinstance(expected, list):
- if isinstance(expected[0], str):
- numpy.testing.assert_array_equal(result, expected)
- else:
- numpy.testing.assert_array_almost_equal(result, expected)
- else:
- self.assertEqual(result, expected)
-
- def testAttributeAlreadyExists(self):
- """A duplicated attribute warns if overwriting is not enabled"""
- ddict = {
- "group": {"dataset": "hmmm", ("", "attr"): 10},
- ("group", "attr"): 10,
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttoh5(ddict, h5file)
- self.assertEqual(h5file["group"].attrs['attr'], 10)
-
- def testFlatDict(self):
- """Description of a tree with a single level of keys"""
- ddict = {
- "group/group/dataset": 10,
- ("group/group/dataset", "attr"): 11,
- ("group/group", "attr"): 12,
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttoh5(ddict, h5file)
- self.assertEqual(h5file["group/group/dataset"][()], 10)
- self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11)
- self.assertEqual(h5file["group/group"].attrs['attr'], 12)
-
- def testLinks(self):
- with h5py.File(self.h5_ext_fname, "w") as h5file:
- dictdump.dicttoh5(ext_attrs, h5file)
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttoh5(link_attrs, h5file)
- with h5py.File(self.h5_fname, "r") as h5file:
- self.assertEqual(h5file["links/group/dataset"][()], 10)
- self.assertEqual(h5file["links/group/relative_softlink"][()], 10)
- self.assertEqual(h5file["links/relative_softlink"][()], 10)
- self.assertEqual(h5file["links/absolute_softlink"][()], 10)
- self.assertEqual(h5file["links/external_link"][()], 10)
-
- def testDumpNumpyArray(self):
- ddict = {
- 'darks': {
- '0': numpy.array([[0, 0, 0], [0, 0, 0]], dtype=numpy.uint16)
- }
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttoh5(ddict, h5file)
- with h5py.File(self.h5_fname, "r") as h5file:
- numpy.testing.assert_array_equal(h5py_read_dataset(h5file["darks"]["0"]),
- ddict['darks']['0'])
-
- def testOverwrite(self):
- # Tree structure that will be tested
- group1 = {
- ("", "attr2"): "original2",
- "dset1": 0,
- "dset2": [0, 1],
- ("dset1", "attr1"): "original1",
- ("dset1", "attr2"): "original2",
- ("dset2", "attr1"): "original1",
- ("dset2", "attr2"): "original2",
- }
- group2 = {
- "subgroup1": group1.copy(),
- "subgroup2": group1.copy(),
- ("subgroup1", "attr1"): "original1",
- ("subgroup2", "attr1"): "original1"
- }
- group2.update(group1)
- # initial HDF5 tree
- otreedict = {
- ('', 'attr1'): "original1",
- ('', 'attr2'): "original2",
- 'group1': group1,
- 'group2': group2,
- ('group1', 'attr1'): "original1",
- ('group2', 'attr1'): "original1"
- }
- wtreedict = None # dumped dictionary
- etreedict = None # expected HDF5 tree after dump
-
- def reset_file():
- dicttoh5(
- otreedict,
- h5file=self.h5_fname,
- mode="w",
- )
-
- def append_file(update_mode):
- dicttoh5(
- wtreedict,
- h5file=self.h5_fname,
- mode="a",
- update_mode=update_mode
- )
-
- def assert_file():
- rtreedict = h5todict(
- self.h5_fname,
- include_attributes=True,
- asarray=False
- )
- netreedict = self.dictRoundTripNormalize(etreedict)
- try:
- self.assertRecursiveEqual(netreedict, rtreedict)
- except AssertionError:
- from pprint import pprint
- print("\nDUMP:")
- pprint(wtreedict)
- print("\nEXPECTED:")
- pprint(netreedict)
- print("\nHDF5:")
- pprint(rtreedict)
- raise
-
- def assert_append(update_mode):
- append_file(update_mode)
- assert_file()
-
- # Test wrong arguments
- with self.assertRaises(ValueError):
- dicttoh5(
- otreedict,
- h5file=self.h5_fname,
- mode="w",
- update_mode="wrong-value"
- )
-
- # No writing
- reset_file()
- etreedict = deepcopy(otreedict)
- assert_file()
-
- # Write identical dictionary
- wtreedict = deepcopy(otreedict)
-
- reset_file()
- etreedict = deepcopy(otreedict)
- for update_mode in [None, "add", "modify", "replace"]:
- assert_append(update_mode)
-
- # Write empty dictionary
- wtreedict = dict()
-
- reset_file()
- etreedict = deepcopy(otreedict)
- for update_mode in [None, "add", "modify", "replace"]:
- assert_append(update_mode)
-
- # Modified dataset
- wtreedict = dict()
- wtreedict["group2"] = dict()
- wtreedict["group2"]["subgroup2"] = dict()
- wtreedict["group2"]["subgroup2"]["dset1"] = {"dset3": [10, 20]}
- wtreedict["group2"]["subgroup2"]["dset2"] = [10, 20]
-
- reset_file()
- etreedict = deepcopy(otreedict)
- for update_mode in [None, "add"]:
- assert_append(update_mode)
-
- etreedict["group2"]["subgroup2"]["dset2"] = [10, 20]
- assert_append("modify")
-
- etreedict["group2"] = dict()
- del etreedict[("group2", "attr1")]
- etreedict["group2"]["subgroup2"] = dict()
- etreedict["group2"]["subgroup2"]["dset1"] = {"dset3": [10, 20]}
- etreedict["group2"]["subgroup2"]["dset2"] = [10, 20]
- assert_append("replace")
-
- # Modified group
- wtreedict = dict()
- wtreedict["group2"] = dict()
- wtreedict["group2"]["subgroup2"] = [0, 1]
-
- reset_file()
- etreedict = deepcopy(otreedict)
- for update_mode in [None, "add", "modify"]:
- assert_append(update_mode)
-
- etreedict["group2"] = dict()
- del etreedict[("group2", "attr1")]
- etreedict["group2"]["subgroup2"] = [0, 1]
- assert_append("replace")
-
- # Modified attribute
- wtreedict = dict()
- wtreedict["group2"] = dict()
- wtreedict["group2"]["subgroup2"] = dict()
- wtreedict["group2"]["subgroup2"][("dset1", "attr1")] = "modified"
-
- reset_file()
- etreedict = deepcopy(otreedict)
- for update_mode in [None, "add"]:
- assert_append(update_mode)
-
- etreedict["group2"]["subgroup2"][("dset1", "attr1")] = "modified"
- assert_append("modify")
-
- etreedict["group2"] = dict()
- del etreedict[("group2", "attr1")]
- etreedict["group2"]["subgroup2"] = dict()
- etreedict["group2"]["subgroup2"]["dset1"] = dict()
- etreedict["group2"]["subgroup2"]["dset1"][("", "attr1")] = "modified"
- assert_append("replace")
-
- # Delete group
- wtreedict = dict()
- wtreedict["group2"] = dict()
- wtreedict["group2"]["subgroup2"] = None
-
- reset_file()
- etreedict = deepcopy(otreedict)
- for update_mode in [None, "add"]:
- assert_append(update_mode)
-
- del etreedict["group2"]["subgroup2"]
- del etreedict["group2"][("subgroup2", "attr1")]
- assert_append("modify")
-
- etreedict["group2"] = dict()
- del etreedict[("group2", "attr1")]
- assert_append("replace")
-
- # Delete dataset
- wtreedict = dict()
- wtreedict["group2"] = dict()
- wtreedict["group2"]["subgroup2"] = dict()
- wtreedict["group2"]["subgroup2"]["dset2"] = None
-
- reset_file()
- etreedict = deepcopy(otreedict)
- for update_mode in [None, "add"]:
- assert_append(update_mode)
-
- del etreedict["group2"]["subgroup2"]["dset2"]
- del etreedict["group2"]["subgroup2"][("dset2", "attr1")]
- del etreedict["group2"]["subgroup2"][("dset2", "attr2")]
- assert_append("modify")
-
- etreedict["group2"] = dict()
- del etreedict[("group2", "attr1")]
- etreedict["group2"]["subgroup2"] = dict()
- assert_append("replace")
-
- # Delete attribute
- wtreedict = dict()
- wtreedict["group2"] = dict()
- wtreedict["group2"]["subgroup2"] = dict()
- wtreedict["group2"]["subgroup2"][("dset2", "attr1")] = None
-
- reset_file()
- etreedict = deepcopy(otreedict)
- for update_mode in [None, "add"]:
- assert_append(update_mode)
-
- del etreedict["group2"]["subgroup2"][("dset2", "attr1")]
- assert_append("modify")
-
- etreedict["group2"] = dict()
- del etreedict[("group2", "attr1")]
- etreedict["group2"]["subgroup2"] = dict()
- etreedict["group2"]["subgroup2"]["dset2"] = dict()
- assert_append("replace")
-
-
-class TestH5ToDict(H5DictTestCase):
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5")
- self.h5_ext_fname = os.path.join(self.tempdir, ext_filename)
- dicttoh5(city_attrs, self.h5_fname)
- dicttoh5(link_attrs, self.h5_fname, mode="a")
- dicttoh5(ext_attrs, self.h5_ext_fname)
-
- def tearDown(self):
- if os.path.exists(self.h5_fname):
- os.unlink(self.h5_fname)
- if os.path.exists(self.h5_ext_fname):
- os.unlink(self.h5_ext_fname)
- os.rmdir(self.tempdir)
-
- def testExcludeNames(self):
- ddict = h5todict(self.h5_fname, path="/Europe/France",
- exclude_names=["ourcoing", "inhab", "toto"])
- self.assertNotIn("Tourcoing", ddict)
- self.assertIn("Grenoble", ddict)
-
- self.assertNotIn("inhabitants", ddict["Grenoble"])
- self.assertIn("coordinates", ddict["Grenoble"])
- self.assertIn("area", ddict["Grenoble"])
-
- def testAsArrayTrue(self):
- """Test with asarray=True, the default"""
- ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble")
- self.assertTrue(numpy.array_equal(ddict["inhabitants"], numpy.array(inhabitants)))
-
- def testAsArrayFalse(self):
- """Test with asarray=False"""
- ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble", asarray=False)
- self.assertEqual(ddict["inhabitants"], inhabitants)
-
- def testDereferenceLinks(self):
- ddict = h5todict(self.h5_fname, path="links", dereference_links=True)
- self.assertTrue(ddict["absolute_softlink"], 10)
- self.assertTrue(ddict["relative_softlink"], 10)
- self.assertTrue(ddict["external_link"], 10)
- self.assertTrue(ddict["group"]["relative_softlink"], 10)
-
- def testPreserveLinks(self):
- ddict = h5todict(self.h5_fname, path="links", dereference_links=False)
- self.assertTrue(is_link(ddict["absolute_softlink"]))
- self.assertTrue(is_link(ddict["relative_softlink"]))
- self.assertTrue(is_link(ddict["external_link"]))
- self.assertTrue(is_link(ddict["group"]["relative_softlink"]))
-
- def testStrings(self):
- ddict = {"dset_bytes": b"bytes",
- "dset_utf8": "utf8",
- "dset_2bytes": [b"bytes", b"bytes"],
- "dset_2utf8": ["utf8", "utf8"],
- ("", "attr_bytes"): b"bytes",
- ("", "attr_utf8"): "utf8",
- ("", "attr_2bytes"): [b"bytes", b"bytes"],
- ("", "attr_2utf8"): ["utf8", "utf8"]}
- dicttoh5(ddict, self.h5_fname, mode="w")
- adict = h5todict(self.h5_fname, include_attributes=True, asarray=False)
- self.assertEqual(ddict["dset_bytes"], adict["dset_bytes"])
- self.assertEqual(ddict["dset_utf8"], adict["dset_utf8"])
- self.assertEqual(ddict[("", "attr_bytes")], adict[("", "attr_bytes")])
- self.assertEqual(ddict[("", "attr_utf8")], adict[("", "attr_utf8")])
- numpy.testing.assert_array_equal(ddict["dset_2bytes"], adict["dset_2bytes"])
- numpy.testing.assert_array_equal(ddict["dset_2utf8"], adict["dset_2utf8"])
- numpy.testing.assert_array_equal(ddict[("", "attr_2bytes")], adict[("", "attr_2bytes")])
- numpy.testing.assert_array_equal(ddict[("", "attr_2utf8")], adict[("", "attr_2utf8")])
-
-
-class TestDictToNx(H5DictTestCase):
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.h5_fname = os.path.join(self.tempdir, "nx.h5")
- self.h5_ext_fname = os.path.join(self.tempdir, "nx_ext.h5")
-
- def tearDown(self):
- if os.path.exists(self.h5_fname):
- os.unlink(self.h5_fname)
- if os.path.exists(self.h5_ext_fname):
- os.unlink(self.h5_ext_fname)
- os.rmdir(self.tempdir)
-
- def testAttributes(self):
- """Any kind of attribute can be described"""
- ddict = {
- "group": {"dataset": 100, "@group_attr1": 10},
- "dataset": 200,
- "@root_attr": 11,
- "dataset@dataset_attr": "12",
- "group@group_attr2": 13,
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttonx(ddict, h5file)
- self.assertEqual(h5file["group"].attrs['group_attr1'], 10)
- self.assertEqual(h5file.attrs['root_attr'], 11)
- self.assertEqual(h5file["dataset"].attrs['dataset_attr'], "12")
- self.assertEqual(h5file["group"].attrs['group_attr2'], 13)
-
- def testKeyOrder(self):
- ddict1 = {
- "d": "plow",
- "d@a": "ox",
- }
- ddict2 = {
- "d@a": "ox",
- "d": "plow",
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttonx(ddict1, h5file, h5path="g1")
- dictdump.dicttonx(ddict2, h5file, h5path="g2")
- self.assertEqual(h5file["g1/d"].attrs['a'], "ox")
- self.assertEqual(h5file["g2/d"].attrs['a'], "ox")
-
- def testAttributeValues(self):
- """Any NX data types can be used"""
- ddict = {
- "@bool": True,
- "@int": 11,
- "@float": 1.1,
- "@str": "a",
- "@boollist": [True, False, True],
- "@intlist": [11, 22, 33],
- "@floatlist": [1.1, 2.2, 3.3],
- "@strlist": ["a", "bb", "ccc"],
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttonx(ddict, h5file)
- for k, expected in ddict.items():
- result = h5file.attrs[k[1:]]
- if isinstance(expected, list):
- if isinstance(expected[0], str):
- numpy.testing.assert_array_equal(result, expected)
- else:
- numpy.testing.assert_array_almost_equal(result, expected)
- else:
- self.assertEqual(result, expected)
-
- def testFlatDict(self):
- """Description of a tree with a single level of keys"""
- ddict = {
- "group/group/dataset": 10,
- "group/group/dataset@attr": 11,
- "group/group@attr": 12,
- }
- with h5py.File(self.h5_fname, "w") as h5file:
- dictdump.dicttonx(ddict, h5file)
- self.assertEqual(h5file["group/group/dataset"][()], 10)
- self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11)
- self.assertEqual(h5file["group/group"].attrs['attr'], 12)
-
- def testLinks(self):
- ddict = {"ext_group": {"dataset": 10}}
- dictdump.dicttonx(ddict, self.h5_ext_fname)
- ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
- ">relative_softlink": "group/dataset",
- ">absolute_softlink": "/links/group/dataset",
- ">external_link": "nx_ext.h5::/ext_group/dataset"}}
- dictdump.dicttonx(ddict, self.h5_fname)
- with h5py.File(self.h5_fname, "r") as h5file:
- self.assertEqual(h5file["links/group/dataset"][()], 10)
- self.assertEqual(h5file["links/group/relative_softlink"][()], 10)
- self.assertEqual(h5file["links/relative_softlink"][()], 10)
- self.assertEqual(h5file["links/absolute_softlink"][()], 10)
- self.assertEqual(h5file["links/external_link"][()], 10)
-
- def testUpLinks(self):
- ddict = {"data": {"group": {"dataset": 10, ">relative_softlink": "dataset"}},
- "links": {"group": {"subgroup": {">relative_softlink": "../../../data/group/dataset"}}}}
- dictdump.dicttonx(ddict, self.h5_fname)
- with h5py.File(self.h5_fname, "r") as h5file:
- self.assertEqual(h5file["/links/group/subgroup/relative_softlink"][()], 10)
-
- def testOverwrite(self):
- entry_name = "entry"
- wtreedict = {
- "group1": {"a": 1, "b": 2},
- "group2@attr3": "attr3",
- "group2@attr4": "attr4",
- "group2": {
- "@attr1": "attr1",
- "@attr2": "attr2",
- "c": 3,
- "d": 4,
- "dataset4": 8,
- "dataset4@units": "keV",
- },
- "group3": {"subgroup": {"e": 9, "f": 10}},
- "dataset1": 5,
- "dataset2": 6,
- "dataset3": 7,
- "dataset3@units": "mm",
- }
- esubtree = {
- "@NX_class": "NXentry",
- "group1": {"@NX_class": "NXcollection", "a": 1, "b": 2},
- "group2": {
- "@NX_class": "NXcollection",
- "@attr1": "attr1",
- "@attr2": "attr2",
- "@attr3": "attr3",
- "@attr4": "attr4",
- "c": 3,
- "d": 4,
- "dataset4": 8,
- "dataset4@units": "keV",
- },
- "group3": {
- "@NX_class": "NXcollection",
- "subgroup": {"@NX_class": "NXcollection", "e": 9, "f": 10},
- },
- "dataset1": 5,
- "dataset2": 6,
- "dataset3": 7,
- "dataset3@units": "mm",
- }
- etreedict = {entry_name: esubtree}
-
- def append_file(update_mode, add_nx_class):
- dictdump.dicttonx(
- wtreedict,
- h5file=self.h5_fname,
- mode="a",
- h5path=entry_name,
- update_mode=update_mode,
- add_nx_class=add_nx_class
- )
-
- def assert_file():
- rtreedict = dictdump.nxtodict(
- self.h5_fname,
- include_attributes=True,
- asarray=False,
- )
- netreedict = self.dictRoundTripNormalize(etreedict)
- try:
- self.assertRecursiveEqual(netreedict, rtreedict)
- except AssertionError:
- from pprint import pprint
- print("\nDUMP:")
- pprint(wtreedict)
- print("\nEXPECTED:")
- pprint(netreedict)
- print("\nHDF5:")
- pprint(rtreedict)
- raise
-
- def assert_append(update_mode, add_nx_class=None):
- append_file(update_mode, add_nx_class=add_nx_class)
- assert_file()
-
- # First to an empty file
- assert_append(None)
-
- # Add non-existing attributes/datasets/groups
- wtreedict["group1"].pop("a")
- wtreedict["group2"].pop("@attr1")
- wtreedict["group2"]["@attr2"] = "attr3" # only for update
- wtreedict["group2"]["@type"] = "test"
- wtreedict["group2"]["dataset4"] = 9 # only for update
- del wtreedict["group2"]["dataset4@units"]
- wtreedict["group3"] = {}
- esubtree["group2"]["@type"] = "test"
- assert_append("add")
-
- # Add update existing attributes and datasets
- esubtree["group2"]["@attr2"] = "attr3"
- esubtree["group2"]["dataset4"] = 9
- assert_append("modify")
-
- # Do not add missing NX_class by default when updating
- wtreedict["group2"]["@NX_class"] = "NXprocess"
- esubtree["group2"]["@NX_class"] = "NXprocess"
- assert_append("modify")
- del wtreedict["group2"]["@NX_class"]
- assert_append("modify")
-
- # Overwrite existing groups/datasets/attributes
- esubtree["group1"].pop("a")
- esubtree["group2"].pop("@attr1")
- esubtree["group2"]["@NX_class"] = "NXcollection"
- esubtree["group2"]["dataset4"] = 9
- del esubtree["group2"]["dataset4@units"]
- esubtree["group3"] = {"@NX_class": "NXcollection"}
- assert_append("replace", add_nx_class=True)
-
-
-class TestNxToDict(H5DictTestCase):
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.h5_fname = os.path.join(self.tempdir, "nx.h5")
- self.h5_ext_fname = os.path.join(self.tempdir, "nx_ext.h5")
-
- def tearDown(self):
- if os.path.exists(self.h5_fname):
- os.unlink(self.h5_fname)
- if os.path.exists(self.h5_ext_fname):
- os.unlink(self.h5_ext_fname)
- os.rmdir(self.tempdir)
-
- def testAttributes(self):
- """Any kind of attribute can be described"""
- ddict = {
- "group": {"dataset": 100, "@group_attr1": 10},
- "dataset": 200,
- "@root_attr": 11,
- "dataset@dataset_attr": "12",
- "group@group_attr2": 13,
- }
- dictdump.dicttonx(ddict, self.h5_fname)
- ddict = dictdump.nxtodict(self.h5_fname, include_attributes=True)
- self.assertEqual(ddict["group"]["@group_attr1"], 10)
- self.assertEqual(ddict["@root_attr"], 11)
- self.assertEqual(ddict["dataset@dataset_attr"], "12")
- self.assertEqual(ddict["group"]["@group_attr2"], 13)
-
- def testDereferenceLinks(self):
- """Write links and dereference on read"""
- ddict = {"ext_group": {"dataset": 10}}
- dictdump.dicttonx(ddict, self.h5_ext_fname)
- ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
- ">relative_softlink": "group/dataset",
- ">absolute_softlink": "/links/group/dataset",
- ">external_link": "nx_ext.h5::/ext_group/dataset"}}
- dictdump.dicttonx(ddict, self.h5_fname)
-
- ddict = dictdump.h5todict(self.h5_fname, dereference_links=True)
- self.assertTrue(ddict["links"]["absolute_softlink"], 10)
- self.assertTrue(ddict["links"]["relative_softlink"], 10)
- self.assertTrue(ddict["links"]["external_link"], 10)
- self.assertTrue(ddict["links"]["group"]["relative_softlink"], 10)
-
- def testPreserveLinks(self):
- """Write/read links"""
- ddict = {"ext_group": {"dataset": 10}}
- dictdump.dicttonx(ddict, self.h5_ext_fname)
- ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
- ">relative_softlink": "group/dataset",
- ">absolute_softlink": "/links/group/dataset",
- ">external_link": "nx_ext.h5::/ext_group/dataset"}}
- dictdump.dicttonx(ddict, self.h5_fname)
-
- ddict = dictdump.nxtodict(self.h5_fname, dereference_links=False)
- self.assertTrue(ddict["links"][">absolute_softlink"], "dataset")
- self.assertTrue(ddict["links"][">relative_softlink"], "group/dataset")
- self.assertTrue(ddict["links"][">external_link"], "/links/group/dataset")
- self.assertTrue(ddict["links"]["group"][">relative_softlink"], "nx_ext.h5::/ext_group/datase")
-
- def testNotExistingPath(self):
- """Test converting not existing path"""
- with h5py.File(self.h5_fname, 'a') as f:
- f['data'] = 1
-
- ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='ignore')
- self.assertFalse(ddict)
-
- with TestLogging(dictdump_logger, error=1):
- ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='log')
- self.assertFalse(ddict)
-
- with self.assertRaises(KeyError):
- h5todict(self.h5_fname, path="/I/am/not/a/path", errors='raise')
-
- def testBrokenLinks(self):
- """Test with broken links"""
- with h5py.File(self.h5_fname, 'a') as f:
- f["/Mars/BrokenSoftLink"] = h5py.SoftLink("/Idontexists")
- f["/Mars/BrokenExternalLink"] = h5py.ExternalLink("notexistingfile.h5", "/Idontexists")
-
- ddict = h5todict(self.h5_fname, path="/Mars", errors='ignore')
- self.assertFalse(ddict)
-
- with TestLogging(dictdump_logger, error=2):
- ddict = h5todict(self.h5_fname, path="/Mars", errors='log')
- self.assertFalse(ddict)
-
- with self.assertRaises(KeyError):
- h5todict(self.h5_fname, path="/Mars", errors='raise')
-
-
-class TestDictToJson(DictTestCase):
- def setUp(self):
- self.dir_path = tempfile.mkdtemp()
- self.json_fname = os.path.join(self.dir_path, "cityattrs.json")
-
- def tearDown(self):
- os.unlink(self.json_fname)
- os.rmdir(self.dir_path)
-
- def testJsonCityAttrs(self):
- self.json_fname = os.path.join(self.dir_path, "cityattrs.json")
- dicttojson(city_attrs, self.json_fname, indent=3)
-
- with open(self.json_fname, "r") as f:
- json_content = f.read()
- self.assertIn('"inhabitants": 160215', json_content)
-
-
-class TestDictToIni(DictTestCase):
- def setUp(self):
- self.dir_path = tempfile.mkdtemp()
- self.ini_fname = os.path.join(self.dir_path, "test.ini")
-
- def tearDown(self):
- os.unlink(self.ini_fname)
- os.rmdir(self.dir_path)
-
- def testConfigDictIO(self):
- """Ensure values and types of data is preserved when dictionary is
- written to file and read back."""
- testdict = {
- 'simple_types': {
- 'float': 1.0,
- 'int': 1,
- 'percent string': '5 % is too much',
- 'backslash string': 'i can use \\',
- 'empty_string': '',
- 'nonestring': 'None',
- 'nonetype': None,
- 'interpstring': 'interpolation: %(percent string)s',
- },
- 'containers': {
- 'list': [-1, 'string', 3.0, False, None],
- 'array': numpy.array([1.0, 2.0, 3.0]),
- 'dict': {
- 'key1': 'Hello World',
- 'key2': 2.0,
- }
- }
- }
-
- dump(testdict, self.ini_fname)
-
- #read the data back
- readdict = load(self.ini_fname)
-
- testdictkeys = list(testdict.keys())
- readkeys = list(readdict.keys())
-
- self.assertTrue(len(readkeys) == len(testdictkeys),
- "Number of read keys not equal")
-
- self.assertEqual(readdict['simple_types']["interpstring"],
- "interpolation: 5 % is too much")
-
- testdict['simple_types']["interpstring"] = "interpolation: 5 % is too much"
-
- for key in testdict["simple_types"]:
- original = testdict['simple_types'][key]
- read = readdict['simple_types'][key]
- self.assertEqual(read, original,
- "Read <%s> instead of <%s>" % (read, original))
-
- for key in testdict["containers"]:
- original = testdict["containers"][key]
- read = readdict["containers"][key]
- if key == 'array':
- self.assertEqual(read.all(), original.all(),
- "Read <%s> instead of <%s>" % (read, original))
- else:
- self.assertEqual(read, original,
- "Read <%s> instead of <%s>" % (read, original))
-
- def testConfigDictOrder(self):
- """Ensure order is preserved when dictionary is
- written to file and read back."""
- test_dict = {'banana': 3, 'apple': 4, 'pear': 1, 'orange': 2}
- # sort by key
- test_ordered_dict1 = OrderedDict(sorted(test_dict.items(),
- key=lambda t: t[0]))
- # sort by value
- test_ordered_dict2 = OrderedDict(sorted(test_dict.items(),
- key=lambda t: t[1]))
- # add the two ordered dict as sections of a third ordered dict
- test_ordered_dict3 = OrderedDict()
- test_ordered_dict3["section1"] = test_ordered_dict1
- test_ordered_dict3["section2"] = test_ordered_dict2
-
- # write to ini and read back as a ConfigDict (inherits OrderedDict)
- dump(test_ordered_dict3,
- self.ini_fname, fmat="ini")
- read_instance = ConfigDict()
- read_instance.read(self.ini_fname)
-
- # loop through original and read-back dictionaries,
- # test identical order for key/value pairs
- for orig_key, section in zip(test_ordered_dict3.keys(),
- read_instance.keys()):
- self.assertEqual(orig_key, section)
- for orig_key2, read_key in zip(test_ordered_dict3[section].keys(),
- read_instance[section].keys()):
- self.assertEqual(orig_key2, read_key)
- self.assertEqual(test_ordered_dict3[section][orig_key2],
- read_instance[section][read_key])
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestDictToIni))
- test_suite.addTest(loadTests(TestDictToH5))
- test_suite.addTest(loadTests(TestDictToNx))
- test_suite.addTest(loadTests(TestDictToJson))
- test_suite.addTest(loadTests(TestH5ToDict))
- test_suite.addTest(loadTests(TestNxToDict))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_fabioh5.py b/silx/io/test/test_fabioh5.py
deleted file mode 100755
index f2c85b1..0000000
--- a/silx/io/test/test_fabioh5.py
+++ /dev/null
@@ -1,629 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for fabioh5 wrapper"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "02/07/2018"
-
-import os
-import logging
-import numpy
-import unittest
-import tempfile
-import shutil
-
-_logger = logging.getLogger(__name__)
-
-import fabio
-import h5py
-
-from .. import commonh5
-from .. import fabioh5
-
-
-class TestFabioH5(unittest.TestCase):
-
- def setUp(self):
-
- header = {
- "integer": "-100",
- "float": "1.0",
- "string": "hi!",
- "list_integer": "100 50 0",
- "list_float": "1.0 2.0 3.5",
- "string_looks_like_list": "2000 hi!",
- }
- data = numpy.array([[10, 11], [12, 13], [14, 15]], dtype=numpy.int64)
- self.fabio_image = fabio.numpyimage.NumpyImage(data, header)
- self.h5_image = fabioh5.File(fabio_image=self.fabio_image)
-
- def test_main_groups(self):
- self.assertEqual(self.h5_image.h5py_class, h5py.File)
- self.assertEqual(self.h5_image["/"].h5py_class, h5py.File)
- self.assertEqual(self.h5_image["/scan_0"].h5py_class, h5py.Group)
- self.assertEqual(self.h5_image["/scan_0/instrument"].h5py_class, h5py.Group)
- self.assertEqual(self.h5_image["/scan_0/measurement"].h5py_class, h5py.Group)
-
- def test_wrong_path_syntax(self):
- # result tested with a default h5py file
- self.assertRaises(ValueError, lambda: self.h5_image[""])
-
- def test_wrong_root_name(self):
- # result tested with a default h5py file
- self.assertRaises(KeyError, lambda: self.h5_image["/foo"])
-
- def test_wrong_root_path(self):
- # result tested with a default h5py file
- self.assertRaises(KeyError, lambda: self.h5_image["/foo/foo"])
-
- def test_wrong_name(self):
- # result tested with a default h5py file
- self.assertRaises(KeyError, lambda: self.h5_image["foo"])
-
- def test_wrong_path(self):
- # result tested with a default h5py file
- self.assertRaises(KeyError, lambda: self.h5_image["foo/foo"])
-
- def test_single_frame(self):
- data = numpy.arange(2 * 3)
- data.shape = 2, 3
- fabio_image = fabio.edfimage.edfimage(data=data)
- h5_image = fabioh5.File(fabio_image=fabio_image)
-
- dataset = h5_image["/scan_0/instrument/detector_0/data"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertTrue(isinstance(dataset[()], numpy.ndarray))
- self.assertEqual(dataset.dtype.kind, "i")
- self.assertEqual(dataset.shape, (2, 3))
- self.assertEqual(dataset[...][0, 0], 0)
- self.assertEqual(dataset.attrs["interpretation"], "image")
-
- def test_multi_frames(self):
- data = numpy.arange(2 * 3)
- data.shape = 2, 3
- fabio_image = fabio.edfimage.edfimage(data=data)
- fabio_image.append_frame(data=data)
- h5_image = fabioh5.File(fabio_image=fabio_image)
-
- dataset = h5_image["/scan_0/instrument/detector_0/data"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertTrue(isinstance(dataset[()], numpy.ndarray))
- self.assertEqual(dataset.dtype.kind, "i")
- self.assertEqual(dataset.shape, (2, 2, 3))
- self.assertEqual(dataset[...][0, 0, 0], 0)
- self.assertEqual(dataset.attrs["interpretation"], "image")
-
- def test_heterogeneous_frames(self):
- """Frames containing 2 images with different sizes and a cube"""
- data1 = numpy.arange(2 * 3)
- data1.shape = 2, 3
- data2 = numpy.arange(2 * 5)
- data2.shape = 2, 5
- data3 = numpy.arange(2 * 5 * 1)
- data3.shape = 2, 5, 1
- fabio_image = fabio.edfimage.edfimage(data=data1)
- fabio_image.append_frame(data=data2)
- fabio_image.append_frame(data=data3)
- h5_image = fabioh5.File(fabio_image=fabio_image)
-
- dataset = h5_image["/scan_0/instrument/detector_0/data"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertTrue(isinstance(dataset[()], numpy.ndarray))
- self.assertEqual(dataset.dtype.kind, "i")
- self.assertEqual(dataset.shape, (3, 2, 5, 1))
- self.assertEqual(dataset[...][0, 0, 0], 0)
- self.assertEqual(dataset.attrs["interpretation"], "image")
-
- def test_single_3d_frame(self):
- """Image source contains a cube"""
- data = numpy.arange(2 * 3 * 4)
- data.shape = 2, 3, 4
- # Do not provide the data to the constructor to avoid slicing of the
- # data. In this way the result stay a cube, and not a multi-frame
- fabio_image = fabio.edfimage.edfimage()
- fabio_image.data = data
- h5_image = fabioh5.File(fabio_image=fabio_image)
-
- dataset = h5_image["/scan_0/instrument/detector_0/data"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertTrue(isinstance(dataset[()], numpy.ndarray))
- self.assertEqual(dataset.dtype.kind, "i")
- self.assertEqual(dataset.shape, (2, 3, 4))
- self.assertEqual(dataset[...][0, 0, 0], 0)
- self.assertEqual(dataset.attrs["interpretation"], "image")
-
- def test_metadata_int(self):
- dataset = self.h5_image["/scan_0/instrument/detector_0/others/integer"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertEqual(dataset[()], -100)
- self.assertEqual(dataset.dtype.kind, "i")
- self.assertEqual(dataset.shape, (1,))
-
- def test_metadata_float(self):
- dataset = self.h5_image["/scan_0/instrument/detector_0/others/float"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertEqual(dataset[()], 1.0)
- self.assertEqual(dataset.dtype.kind, "f")
- self.assertEqual(dataset.shape, (1,))
-
- def test_metadata_string(self):
- dataset = self.h5_image["/scan_0/instrument/detector_0/others/string"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertEqual(dataset[()], numpy.string_("hi!"))
- self.assertEqual(dataset.dtype.type, numpy.string_)
- self.assertEqual(dataset.shape, (1,))
-
- def test_metadata_list_integer(self):
- dataset = self.h5_image["/scan_0/instrument/detector_0/others/list_integer"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertEqual(dataset.dtype.kind, "u")
- self.assertEqual(dataset.shape, (1, 3))
- self.assertEqual(dataset[0, 0], 100)
- self.assertEqual(dataset[0, 1], 50)
-
- def test_metadata_list_float(self):
- dataset = self.h5_image["/scan_0/instrument/detector_0/others/list_float"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertEqual(dataset.dtype.kind, "f")
- self.assertEqual(dataset.shape, (1, 3))
- self.assertEqual(dataset[0, 0], 1.0)
- self.assertEqual(dataset[0, 1], 2.0)
-
- def test_metadata_list_looks_like_list(self):
- dataset = self.h5_image["/scan_0/instrument/detector_0/others/string_looks_like_list"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertEqual(dataset[()], numpy.string_("2000 hi!"))
- self.assertEqual(dataset.dtype.type, numpy.string_)
- self.assertEqual(dataset.shape, (1,))
-
- def test_float_32(self):
- float_list = [u'1.2', u'1.3', u'1.4']
- data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
- fabio_image = None
- for float_item in float_list:
- header = {"float_item": float_item}
- if fabio_image is None:
- fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
- else:
- fabio_image.append_frame(data=data, header=header)
- h5_image = fabioh5.File(fabio_image=fabio_image)
- data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
- # There is no equality between items
- self.assertEqual(len(data), len(set(data)))
- # At worst a float32
- self.assertIn(data.dtype.kind, ['d', 'f'])
- self.assertLessEqual(data.dtype.itemsize, 32 / 8)
-
- def test_float_64(self):
- float_list = [
- u'1469117129.082226',
- u'1469117136.684986', u'1469117144.312749', u'1469117151.892507',
- u'1469117159.474265', u'1469117167.100027', u'1469117174.815799',
- u'1469117182.437561', u'1469117190.094326', u'1469117197.721089']
- data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
- fabio_image = None
- for float_item in float_list:
- header = {"time_of_day": float_item}
- if fabio_image is None:
- fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
- else:
- fabio_image.append_frame(data=data, header=header)
- h5_image = fabioh5.File(fabio_image=fabio_image)
- data = h5_image["/scan_0/instrument/detector_0/others/time_of_day"]
- # There is no equality between items
- self.assertEqual(len(data), len(set(data)))
- # At least a float64
- self.assertIn(data.dtype.kind, ['d', 'f'])
- self.assertGreaterEqual(data.dtype.itemsize, 64 / 8)
-
- def test_mixed_float_size__scalar(self):
- # We expect to have a precision of 32 bits
- float_list = [u'1.2', u'1.3001']
- expected_float_result = [1.2, 1.3001]
- data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
- fabio_image = None
- for float_item in float_list:
- header = {"float_item": float_item}
- if fabio_image is None:
- fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
- else:
- fabio_image.append_frame(data=data, header=header)
- h5_image = fabioh5.File(fabio_image=fabio_image)
- data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
- # At worst a float32
- self.assertIn(data.dtype.kind, ['d', 'f'])
- self.assertLessEqual(data.dtype.itemsize, 32 / 8)
- for computed, expected in zip(data, expected_float_result):
- numpy.testing.assert_almost_equal(computed, expected, 5)
-
- def test_mixed_float_size__list(self):
- # We expect to have a precision of 32 bits
- float_list = [u'1.2 1.3001']
- expected_float_result = numpy.array([[1.2, 1.3001]])
- data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
- fabio_image = None
- for float_item in float_list:
- header = {"float_item": float_item}
- if fabio_image is None:
- fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
- else:
- fabio_image.append_frame(data=data, header=header)
- h5_image = fabioh5.File(fabio_image=fabio_image)
- data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
- # At worst a float32
- self.assertIn(data.dtype.kind, ['d', 'f'])
- self.assertLessEqual(data.dtype.itemsize, 32 / 8)
- for computed, expected in zip(data, expected_float_result):
- numpy.testing.assert_almost_equal(computed, expected, 5)
-
- def test_mixed_float_size__list_of_list(self):
- # We expect to have a precision of 32 bits
- float_list = [u'1.2 1.3001', u'1.3001 1.3001']
- expected_float_result = numpy.array([[1.2, 1.3001], [1.3001, 1.3001]])
- data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
- fabio_image = None
- for float_item in float_list:
- header = {"float_item": float_item}
- if fabio_image is None:
- fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
- else:
- fabio_image.append_frame(data=data, header=header)
- h5_image = fabioh5.File(fabio_image=fabio_image)
- data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
- # At worst a float32
- self.assertIn(data.dtype.kind, ['d', 'f'])
- self.assertLessEqual(data.dtype.itemsize, 32 / 8)
- for computed, expected in zip(data, expected_float_result):
- numpy.testing.assert_almost_equal(computed, expected, 5)
-
- def test_ub_matrix(self):
- """Data from mediapix.edf"""
- header = {}
- header["UB_mne"] = 'UB0 UB1 UB2 UB3 UB4 UB5 UB6 UB7 UB8'
- header["UB_pos"] = '1.99593e-16 2.73682e-16 -1.54 -1.08894 1.08894 1.6083e-16 1.08894 1.08894 9.28619e-17'
- header["sample_mne"] = 'U0 U1 U2 U3 U4 U5'
- header["sample_pos"] = '4.08 4.08 4.08 90 90 90'
- data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
- fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
- h5_image = fabioh5.File(fabio_image=fabio_image)
- sample = h5_image["/scan_0/sample"]
- self.assertIsNotNone(sample)
- self.assertEqual(sample.attrs["NXclass"], "NXsample")
-
- d = sample['unit_cell_abc']
- expected = numpy.array([4.08, 4.08, 4.08])
- self.assertIsNotNone(d)
- self.assertEqual(d.shape, (3, ))
- self.assertIn(d.dtype.kind, ['d', 'f'])
- numpy.testing.assert_array_almost_equal(d[...], expected)
-
- d = sample['unit_cell_alphabetagamma']
- expected = numpy.array([90.0, 90.0, 90.0])
- self.assertIsNotNone(d)
- self.assertEqual(d.shape, (3, ))
- self.assertIn(d.dtype.kind, ['d', 'f'])
- numpy.testing.assert_array_almost_equal(d[...], expected)
-
- d = sample['ub_matrix']
- expected = numpy.array([[[1.99593e-16, 2.73682e-16, -1.54],
- [-1.08894, 1.08894, 1.6083e-16],
- [1.08894, 1.08894, 9.28619e-17]]])
- self.assertIsNotNone(d)
- self.assertEqual(d.shape, (1, 3, 3))
- self.assertIn(d.dtype.kind, ['d', 'f'])
- numpy.testing.assert_array_almost_equal(d[...], expected)
-
- def test_interpretation_mca_edf(self):
- """EDF files with two or more headers starting with "MCA"
- must have @interpretation = "spectrum" an the data."""
- header = {
- "Title": "zapimage samy -4.975 -5.095 80 500 samz -4.091 -4.171 70 0",
- "MCA a": -23.812,
- "MCA b": 2.7107,
- "MCA c": 8.1164e-06}
-
- data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
- fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
- h5_image = fabioh5.File(fabio_image=fabio_image)
-
- data_dataset = h5_image["/scan_0/measurement/image_0/data"]
- self.assertEqual(data_dataset.attrs["interpretation"], "spectrum")
-
- data_dataset = h5_image["/scan_0/instrument/detector_0/data"]
- self.assertEqual(data_dataset.attrs["interpretation"], "spectrum")
-
- data_dataset = h5_image["/scan_0/measurement/image_0/info/data"]
- self.assertEqual(data_dataset.attrs["interpretation"], "spectrum")
-
- def test_get_api(self):
- result = self.h5_image.get("scan_0", getclass=True, getlink=True)
- self.assertIs(result, h5py.HardLink)
- result = self.h5_image.get("scan_0", getclass=False, getlink=True)
- self.assertIsInstance(result, h5py.HardLink)
- result = self.h5_image.get("scan_0", getclass=True, getlink=False)
- self.assertIs(result, h5py.Group)
- result = self.h5_image.get("scan_0", getclass=False, getlink=False)
- self.assertIsInstance(result, commonh5.Group)
-
- def test_detector_link(self):
- detector1 = self.h5_image["/scan_0/instrument/detector_0"]
- detector2 = self.h5_image["/scan_0/measurement/image_0/info"]
- self.assertIsNot(detector1, detector2)
- self.assertEqual(list(detector1.items()), list(detector2.items()))
- self.assertEqual(self.h5_image.get(detector2.name, getlink=True).path, detector1.name)
-
- def test_detector_data_link(self):
- data1 = self.h5_image["/scan_0/instrument/detector_0/data"]
- data2 = self.h5_image["/scan_0/measurement/image_0/data"]
- self.assertIsNot(data1, data2)
- self.assertIs(data1._get_data(), data2._get_data())
- self.assertEqual(self.h5_image.get(data2.name, getlink=True).path, data1.name)
-
- def test_dirty_header(self):
- """Test that it does not fail"""
- try:
- header = {}
- header["foo"] = b'abc'
- data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
- fabio_image = fabio.edfimage.edfimage(data=data, header=header)
- header = {}
- header["foo"] = b'a\x90bc\xFE'
- fabio_image.append_frame(data=data, header=header)
- except Exception as e:
- _logger.error(e.args[0])
- _logger.debug("Backtrace", exc_info=True)
- self.skipTest("fabio do not allow to create the resource")
-
- h5_image = fabioh5.File(fabio_image=fabio_image)
- scan_header_path = "/scan_0/instrument/file/scan_header"
- self.assertIn(scan_header_path, h5_image)
- data = h5_image[scan_header_path]
- self.assertIsInstance(data[...], numpy.ndarray)
-
- def test_unicode_header(self):
- """Test that it does not fail"""
- try:
- header = {}
- header["foo"] = b'abc'
- data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
- fabio_image = fabio.edfimage.edfimage(data=data, header=header)
- header = {}
- header["foo"] = u'abc\u2764'
- fabio_image.append_frame(data=data, header=header)
- except Exception as e:
- _logger.error(e.args[0])
- _logger.debug("Backtrace", exc_info=True)
- self.skipTest("fabio do not allow to create the resource")
-
- h5_image = fabioh5.File(fabio_image=fabio_image)
- scan_header_path = "/scan_0/instrument/file/scan_header"
- self.assertIn(scan_header_path, h5_image)
- data = h5_image[scan_header_path]
- self.assertIsInstance(data[...], numpy.ndarray)
-
-
-class TestFabioH5MultiFrames(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
-
- names = ["A", "B", "C", "D"]
- values = [["32000", "-10", "5.0", "1"],
- ["-32000", "-10", "5.0", "1"]]
-
- fabio_file = None
-
- for i in range(10):
- header = {
- "image_id": "%d" % i,
- "integer": "-100",
- "float": "1.0",
- "string": "hi!",
- "list_integer": "100 50 0",
- "list_float": "1.0 2.0 3.5",
- "string_looks_like_list": "2000 hi!",
- "motor_mne": " ".join(names),
- "motor_pos": " ".join(values[i % len(values)]),
- "counter_mne": " ".join(names),
- "counter_pos": " ".join(values[i % len(values)])
- }
- for iname, name in enumerate(names):
- header[name] = values[i % len(values)][iname]
-
- data = numpy.array([[i, 11], [12, 13], [14, 15]], dtype=numpy.int64)
- if fabio_file is None:
- fabio_file = fabio.edfimage.EdfImage(data=data, header=header)
- else:
- fabio_file.append_frame(data=data, header=header)
-
- cls.fabio_file = fabio_file
- cls.fabioh5 = fabioh5.File(fabio_image=fabio_file)
-
- def test_others(self):
- others = self.fabioh5["/scan_0/instrument/detector_0/others"]
- dataset = others["A"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 1)
- self.assertEqual(dataset.dtype.kind, "i")
- dataset = others["B"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 1)
- self.assertEqual(dataset.dtype.kind, "i")
- dataset = others["C"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 1)
- self.assertEqual(dataset.dtype.kind, "f")
- dataset = others["D"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 1)
- self.assertEqual(dataset.dtype.kind, "u")
-
- def test_positioners(self):
- counters = self.fabioh5["/scan_0/instrument/positioners"]
- # At least 32 bits, no unsigned values
- dataset = counters["A"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 4)
- self.assertEqual(dataset.dtype.kind, "i")
- dataset = counters["B"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 4)
- self.assertEqual(dataset.dtype.kind, "i")
- dataset = counters["C"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 4)
- self.assertEqual(dataset.dtype.kind, "f")
- dataset = counters["D"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 4)
- self.assertEqual(dataset.dtype.kind, "i")
-
- def test_counters(self):
- counters = self.fabioh5["/scan_0/measurement"]
- # At least 32 bits, no unsigned values
- dataset = counters["A"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 4)
- self.assertEqual(dataset.dtype.kind, "i")
- dataset = counters["B"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 4)
- self.assertEqual(dataset.dtype.kind, "i")
- dataset = counters["C"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 4)
- self.assertEqual(dataset.dtype.kind, "f")
- dataset = counters["D"]
- self.assertGreaterEqual(dataset.dtype.itemsize, 4)
- self.assertEqual(dataset.dtype.kind, "i")
-
-
-class TestFabioH5WithEdf(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
-
- cls.tmp_directory = tempfile.mkdtemp()
-
- cls.edf_filename = os.path.join(cls.tmp_directory, "test.edf")
-
- header = {
- "integer": "-100",
- "float": "1.0",
- "string": "hi!",
- "list_integer": "100 50 0",
- "list_float": "1.0 2.0 3.5",
- "string_looks_like_list": "2000 hi!",
- }
- data = numpy.array([[10, 11], [12, 13], [14, 15]], dtype=numpy.int64)
- fabio_image = fabio.edfimage.edfimage(data, header)
- fabio_image.write(cls.edf_filename)
-
- cls.fabio_image = fabio.open(cls.edf_filename)
- cls.h5_image = fabioh5.File(fabio_image=cls.fabio_image)
-
- @classmethod
- def tearDownClass(cls):
- cls.fabio_image = None
- cls.h5_image = None
- shutil.rmtree(cls.tmp_directory)
-
- def test_reserved_format_metadata(self):
- if fabio.hexversion < 327920: # 0.5.0 final
- self.skipTest("fabio >= 0.5.0 final is needed")
-
- # The EDF contains reserved keys in the header
- self.assertIn("HeaderID", self.fabio_image.header)
- # We do not expose them in FabioH5
- self.assertNotIn("/scan_0/instrument/detector_0/others/HeaderID", self.h5_image)
-
-
-class _TestableFrameData(fabioh5.FrameData):
- """Allow to test if the full data is reached."""
- def _create_data(self):
- raise RuntimeError("Not supposed to be called")
-
-
-class TestFabioH5WithFileSeries(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
-
- cls.tmp_directory = tempfile.mkdtemp()
-
- cls.edf_filenames = []
-
- for i in range(10):
- filename = os.path.join(cls.tmp_directory, "test_%04d.edf" % i)
- cls.edf_filenames.append(filename)
-
- header = {
- "image_id": "%d" % i,
- "integer": "-100",
- "float": "1.0",
- "string": "hi!",
- "list_integer": "100 50 0",
- "list_float": "1.0 2.0 3.5",
- "string_looks_like_list": "2000 hi!",
- }
- data = numpy.array([[i, 11], [12, 13], [14, 15]], dtype=numpy.int64)
- fabio_image = fabio.edfimage.edfimage(data, header)
- fabio_image.write(filename)
-
- @classmethod
- def tearDownClass(cls):
- shutil.rmtree(cls.tmp_directory)
-
- def _testH5Image(self, h5_image):
- # test data
- dataset = h5_image["/scan_0/instrument/detector_0/data"]
- self.assertEqual(dataset.h5py_class, h5py.Dataset)
- self.assertTrue(isinstance(dataset[()], numpy.ndarray))
- self.assertEqual(dataset.dtype.kind, "i")
- self.assertEqual(dataset.shape, (10, 3, 2))
- self.assertEqual(list(dataset[:, 0, 0]), list(range(10)))
- self.assertEqual(dataset.attrs["interpretation"], "image")
- # test metatdata
- dataset = h5_image["/scan_0/instrument/detector_0/others/image_id"]
- self.assertEqual(list(dataset[...]), list(range(10)))
-
- def testFileList(self):
- h5_image = fabioh5.File(file_series=self.edf_filenames)
- self._testH5Image(h5_image)
-
- def testFileSeries(self):
- file_series = fabioh5._FileSeries(self.edf_filenames)
- h5_image = fabioh5.File(file_series=file_series)
- self._testH5Image(h5_image)
-
- def testFrameDataCache(self):
- file_series = fabioh5._FileSeries(self.edf_filenames)
- reader = fabioh5.FabioReader(file_series=file_series)
- frameData = _TestableFrameData("foo", reader)
- self.assertEqual(frameData.dtype.kind, "i")
- self.assertEqual(frameData.shape, (10, 3, 2))
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestFabioH5))
- test_suite.addTest(loadTests(TestFabioH5MultiFrames))
- test_suite.addTest(loadTests(TestFabioH5WithEdf))
- test_suite.addTest(loadTests(TestFabioH5WithFileSeries))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_h5py_utils.py b/silx/io/test/test_h5py_utils.py
deleted file mode 100644
index 2e2e3dd..0000000
--- a/silx/io/test/test_h5py_utils.py
+++ /dev/null
@@ -1,397 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for h5py utilities"""
-
-__authors__ = ["W. de Nolf"]
-__license__ = "MIT"
-__date__ = "27/01/2020"
-
-
-import unittest
-import os
-import sys
-import time
-import shutil
-import tempfile
-import threading
-import multiprocessing
-from contextlib import contextmanager
-
-from .. import h5py_utils
-from ...utils.retry import RetryError, RetryTimeoutError
-
-IS_WINDOWS = sys.platform == "win32"
-
-
-def _subprocess_context_main(queue, contextmgr, *args, **kw):
- try:
- with contextmgr(*args, **kw):
- queue.put(None)
- threading.Event().wait()
- except Exception:
- queue.put(None)
- raise
-
-
-@contextmanager
-def _subprocess_context(contextmgr, *args, **kw):
- timeout = kw.pop("timeout", 10)
- queue = multiprocessing.Queue(maxsize=1)
- p = multiprocessing.Process(
- target=_subprocess_context_main, args=(queue, contextmgr) + args, kwargs=kw
- )
- p.start()
- try:
- queue.get(timeout=timeout)
- yield
- finally:
- try:
- p.kill()
- except AttributeError:
- p.terminate()
- p.join(timeout)
-
-
-@contextmanager
-def _open_context(filename, **kw):
- with h5py_utils.File(filename, **kw) as f:
- if kw.get("mode") == "w":
- f["check"] = True
- f.flush()
- yield f
-
-
-def _cause_segfault():
- import ctypes
-
- i = ctypes.c_char(b"a")
- j = ctypes.pointer(i)
- c = 0
- while True:
- j[c] = b"a"
- c += 1
-
-
-def _top_level_names_test(txtfilename, *args, **kw):
- sys.stderr = open(os.devnull, "w")
-
- with open(txtfilename, mode="r") as f:
- failcounter = int(f.readline().strip())
-
- ncausefailure = kw.pop("ncausefailure")
- faildelay = kw.pop("faildelay")
- if failcounter < ncausefailure:
- time.sleep(faildelay)
- failcounter += 1
- with open(txtfilename, mode="w") as f:
- f.write(str(failcounter))
- if failcounter % 2:
- raise RetryError
- else:
- _cause_segfault()
- return h5py_utils._top_level_names(*args, **kw)
-
-
-top_level_names_test = h5py_utils.retry_in_subprocess()(_top_level_names_test)
-
-
-def subtests(test):
- def wrapper(self):
- for _ in self._subtests():
- with self.subTest(**self._subtest_options):
- test(self)
-
- return wrapper
-
-
-class TestH5pyUtils(unittest.TestCase):
- def setUp(self):
- self.test_dir = tempfile.mkdtemp()
-
- def tearDown(self):
- shutil.rmtree(self.test_dir)
-
- def _subtests(self):
- self._subtest_options = {"mode": "w"}
- self.filename_generator = self._filenames()
- yield
- self._subtest_options = {"mode": "w", "libver": "latest"}
- self.filename_generator = self._filenames()
- yield
-
- @property
- def _liber_allows_concurrent_access(self):
- return self._subtest_options.get("libver") in [None, "earliest", "v18"]
-
- def _filenames(self):
- i = 1
- while True:
- filename = os.path.join(self.test_dir, "file{}.h5".format(i))
- with self._open_context(filename):
- pass
- yield filename
- i += 1
-
- def _new_filename(self):
- return next(self.filename_generator)
-
- @contextmanager
- def _open_context(self, filename, **kwargs):
- kw = self._subtest_options
- kw.update(kwargs)
- with _open_context(filename, **kw) as f:
-
- yield f
-
- @contextmanager
- def _open_context_subprocess(self, filename, **kwargs):
- kw = self._subtest_options
- kw.update(kwargs)
- with _subprocess_context(_open_context, filename, **kw):
- yield
-
- def _assert_hdf5_data(self, f):
- self.assertTrue(f["check"][()])
-
- def _validate_hdf5_data(self, filename, swmr=False):
- with self._open_context(filename, mode="r") as f:
- self.assertEqual(f.swmr_mode, swmr)
- self._assert_hdf5_data(f)
-
- @subtests
- def test_modes_single_process(self):
- orig = os.environ.get("HDF5_USE_FILE_LOCKING")
- filename1 = self._new_filename()
- self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
- filename2 = self._new_filename()
- self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
- with self._open_context(filename1, mode="r"):
- with self._open_context(filename2, mode="r"):
- pass
- for mode in ["w", "a"]:
- with self.assertRaises(RuntimeError):
- with self._open_context(filename2, mode=mode):
- pass
- self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
- with self._open_context(filename1, mode="a"):
- for mode in ["w", "a"]:
- with self._open_context(filename2, mode=mode):
- pass
- with self.assertRaises(RuntimeError):
- with self._open_context(filename2, mode="r"):
- pass
- self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
-
- @subtests
- def test_modes_multi_process(self):
- if not self._liber_allows_concurrent_access:
- # A concurrent reader with HDF5_USE_FILE_LOCKING=FALSE is
- # no longer works with HDF5 >=1.10 (you get an exception
- # when trying to open the file)
- return
- filename = self._new_filename()
-
- # File open by truncating writer
- with self._open_context_subprocess(filename, mode="w"):
- with self._open_context(filename, mode="r") as f:
- self._assert_hdf5_data(f)
- if IS_WINDOWS:
- with self._open_context(filename, mode="a") as f:
- self._assert_hdf5_data(f)
- else:
- with self.assertRaises(OSError):
- with self._open_context(filename, mode="a") as f:
- pass
- self._validate_hdf5_data(filename)
-
- # File open by appending writer
- with self._open_context_subprocess(filename, mode="a"):
- with self._open_context(filename, mode="r") as f:
- self._assert_hdf5_data(f)
- if IS_WINDOWS:
- with self._open_context(filename, mode="a") as f:
- self._assert_hdf5_data(f)
- else:
- with self.assertRaises(OSError):
- with self._open_context(filename, mode="a") as f:
- pass
- self._validate_hdf5_data(filename)
-
- # File open by reader
- with self._open_context_subprocess(filename, mode="r"):
- with self._open_context(filename, mode="r") as f:
- self._assert_hdf5_data(f)
- with self._open_context(filename, mode="a") as f:
- pass
- self._validate_hdf5_data(filename)
-
- # File open by locking reader
- with _subprocess_context(
- _open_context, filename, mode="r", enable_file_locking=True
- ):
- with self._open_context(filename, mode="r") as f:
- self._assert_hdf5_data(f)
- if IS_WINDOWS:
- with self._open_context(filename, mode="a") as f:
- self._assert_hdf5_data(f)
- else:
- with self.assertRaises(OSError):
- with self._open_context(filename, mode="a") as f:
- pass
- self._validate_hdf5_data(filename)
-
- @subtests
- @unittest.skipIf(not h5py_utils.HAS_SWMR, "SWMR not supported")
- def test_modes_multi_process_swmr(self):
- filename = self._new_filename()
-
- with self._open_context(filename, mode="w", libver="latest") as f:
- pass
-
- # File open by SWMR writer
- with self._open_context_subprocess(filename, mode="a", swmr=True):
- with self._open_context(filename, mode="r") as f:
- assert f.swmr_mode
- self._assert_hdf5_data(f)
- with self.assertRaises(OSError):
- with self._open_context(filename, mode="a") as f:
- pass
- self._validate_hdf5_data(filename, swmr=True)
-
- @subtests
- def test_retry_defaults(self):
- filename = self._new_filename()
-
- names = h5py_utils.top_level_names(filename)
- self.assertEqual(names, [])
-
- names = h5py_utils.safe_top_level_names(filename)
- self.assertEqual(names, [])
-
- names = h5py_utils.top_level_names(filename, include_only=None)
- self.assertEqual(names, ["check"])
-
- names = h5py_utils.safe_top_level_names(filename, include_only=None)
- self.assertEqual(names, ["check"])
-
- with h5py_utils.open_item(filename, "/check", validate=lambda x: False) as item:
- self.assertEqual(item, None)
-
- with h5py_utils.open_item(filename, "/check", validate=None) as item:
- self.assertTrue(item[()])
-
- with self.assertRaises(RetryTimeoutError):
- with h5py_utils.open_item(
- filename,
- "/check",
- retry_timeout=0.1,
- retry_invalid=True,
- validate=lambda x: False,
- ) as item:
- pass
-
- ncall = 0
-
- def validate(item):
- nonlocal ncall
- if ncall >= 1:
- return True
- else:
- ncall += 1
- raise RetryError
-
- with h5py_utils.open_item(
- filename, "/check", validate=validate, retry_timeout=1, retry_invalid=True
- ) as item:
- self.assertTrue(item[()])
-
- @subtests
- def test_retry_custom(self):
- filename = self._new_filename()
- ncausefailure = 3
- faildelay = 0.1
- sufficient_timeout = ncausefailure * (faildelay + 10)
- insufficient_timeout = ncausefailure * faildelay * 0.5
-
- @h5py_utils.retry_contextmanager()
- def open_item(filename, name):
- nonlocal failcounter
- if failcounter < ncausefailure:
- time.sleep(faildelay)
- failcounter += 1
- raise RetryError
- with h5py_utils.File(filename) as h5file:
- yield h5file[name]
-
- failcounter = 0
- kw = {"retry_timeout": sufficient_timeout}
- with open_item(filename, "/check", **kw) as item:
- self.assertTrue(item[()])
-
- failcounter = 0
- kw = {"retry_timeout": insufficient_timeout}
- with self.assertRaises(RetryTimeoutError):
- with open_item(filename, "/check", **kw) as item:
- pass
-
- @subtests
- def test_retry_in_subprocess(self):
- filename = self._new_filename()
- txtfilename = os.path.join(self.test_dir, "failcounter.txt")
- ncausefailure = 3
- faildelay = 0.1
- sufficient_timeout = ncausefailure * (faildelay + 10)
- insufficient_timeout = ncausefailure * faildelay * 0.5
-
- kw = {
- "retry_timeout": sufficient_timeout,
- "include_only": None,
- "ncausefailure": ncausefailure,
- "faildelay": faildelay,
- }
- with open(txtfilename, mode="w") as f:
- f.write("0")
- names = top_level_names_test(txtfilename, filename, **kw)
- self.assertEqual(names, ["check"])
-
- kw = {
- "retry_timeout": insufficient_timeout,
- "include_only": None,
- "ncausefailure": ncausefailure,
- "faildelay": faildelay,
- }
- with open(txtfilename, mode="w") as f:
- f.write("0")
- with self.assertRaises(RetryTimeoutError):
- top_level_names_test(txtfilename, filename, **kw)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestH5pyUtils))
- return test_suite
-
-
-if __name__ == "__main__":
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_nxdata.py b/silx/io/test/test_nxdata.py
deleted file mode 100644
index 80cc193..0000000
--- a/silx/io/test/test_nxdata.py
+++ /dev/null
@@ -1,579 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for NXdata parsing"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "24/03/2020"
-
-
-import tempfile
-import unittest
-import h5py
-import numpy
-import six
-
-from .. import nxdata
-
-
-text_dtype = h5py.special_dtype(vlen=six.text_type)
-
-
-class TestNXdata(unittest.TestCase):
- def setUp(self):
- tmp = tempfile.NamedTemporaryFile(prefix="nxdata_examples_", suffix=".h5", delete=True)
- tmp.file.close()
- self.h5fname = tmp.name
- self.h5f = h5py.File(tmp.name, "w")
-
- # SCALARS
- g0d = self.h5f.create_group("scalars")
-
- g0d0 = g0d.create_group("0D_scalar")
- g0d0.attrs["NX_class"] = "NXdata"
- g0d0.attrs["signal"] = "scalar"
- g0d0.create_dataset("scalar", data=10)
- g0d0.create_dataset("scalar_errors", data=0.1)
-
- g0d1 = g0d.create_group("2D_scalars")
- g0d1.attrs["NX_class"] = "NXdata"
- g0d1.attrs["signal"] = "scalars"
- ds = g0d1.create_dataset("scalars", data=numpy.arange(3 * 10).reshape((3, 10)))
- ds.attrs["interpretation"] = "scalar"
-
- g0d1 = g0d.create_group("4D_scalars")
- g0d1.attrs["NX_class"] = "NXdata"
- g0d1.attrs["signal"] = "scalars"
- ds = g0d1.create_dataset("scalars", data=numpy.arange(2 * 2 * 3 * 10).reshape((2, 2, 3, 10)))
- ds.attrs["interpretation"] = "scalar"
-
- # SPECTRA
- g1d = self.h5f.create_group("spectra")
-
- g1d0 = g1d.create_group("1D_spectrum")
- g1d0.attrs["NX_class"] = "NXdata"
- g1d0.attrs["signal"] = "count"
- g1d0.attrs["auxiliary_signals"] = numpy.array(["count2", "count3"],
- dtype=text_dtype)
- g1d0.attrs["axes"] = "energy_calib"
- g1d0.attrs["uncertainties"] = numpy.array(["energy_errors", ],
- dtype=text_dtype)
- g1d0.create_dataset("count", data=numpy.arange(10))
- g1d0.create_dataset("count2", data=0.5 * numpy.arange(10))
- d = g1d0.create_dataset("count3", data=0.4 * numpy.arange(10))
- d.attrs["long_name"] = "3rd counter"
- g1d0.create_dataset("title", data="Title as dataset (like nexpy)")
- g1d0.create_dataset("energy_calib", data=(10, 5)) # 10 * idx + 5
- g1d0.create_dataset("energy_errors", data=3.14 * numpy.random.rand(10))
-
- g1d1 = g1d.create_group("2D_spectra")
- g1d1.attrs["NX_class"] = "NXdata"
- g1d1.attrs["signal"] = "counts"
- ds = g1d1.create_dataset("counts", data=numpy.arange(3 * 10).reshape((3, 10)))
- ds.attrs["interpretation"] = "spectrum"
-
- g1d2 = g1d.create_group("4D_spectra")
- g1d2.attrs["NX_class"] = "NXdata"
- g1d2.attrs["signal"] = "counts"
- g1d2.attrs["axes"] = numpy.array(["energy", ], dtype=text_dtype)
- ds = g1d2.create_dataset("counts", data=numpy.arange(2 * 2 * 3 * 10).reshape((2, 2, 3, 10)))
- ds.attrs["interpretation"] = "spectrum"
- ds = g1d2.create_dataset("errors", data=4.5 * numpy.random.rand(2, 2, 3, 10))
- ds = g1d2.create_dataset("energy", data=5 + 10 * numpy.arange(15),
- shuffle=True, compression="gzip")
- ds.attrs["long_name"] = "Calibrated energy"
- ds.attrs["first_good"] = 3
- ds.attrs["last_good"] = 12
- g1d2.create_dataset("energy_errors", data=10 * numpy.random.rand(15))
-
- # IMAGES
- g2d = self.h5f.create_group("images")
-
- g2d0 = g2d.create_group("2D_regular_image")
- g2d0.attrs["NX_class"] = "NXdata"
- g2d0.attrs["signal"] = "image"
- g2d0.attrs["auxiliary_signals"] = "image2"
- g2d0.attrs["axes"] = numpy.array(["rows_calib", "columns_coordinates"],
- dtype=text_dtype)
- g2d0.create_dataset("image", data=numpy.arange(4 * 6).reshape((4, 6)))
- g2d0.create_dataset("image2", data=numpy.arange(4 * 6).reshape((4, 6)))
- ds = g2d0.create_dataset("rows_calib", data=(10, 5))
- ds.attrs["long_name"] = "Calibrated Y"
- g2d0.create_dataset("columns_coordinates", data=0.5 + 0.02 * numpy.arange(6))
-
- g2d1 = g2d.create_group("2D_irregular_data")
- g2d1.attrs["NX_class"] = "NXdata"
- g2d1.attrs["signal"] = "data"
- g2d1.attrs["title"] = "Title as group attr"
- g2d1.attrs["axes"] = numpy.array(["rows_coordinates", "columns_coordinates"],
- dtype=text_dtype)
- g2d1.create_dataset("data", data=numpy.arange(64 * 128).reshape((64, 128)))
- g2d1.create_dataset("rows_coordinates", data=numpy.arange(64) + numpy.random.rand(64))
- g2d1.create_dataset("columns_coordinates", data=numpy.arange(128) + 2.5 * numpy.random.rand(128))
-
- g2d2 = g2d.create_group("3D_images")
- g2d2.attrs["NX_class"] = "NXdata"
- g2d2.attrs["signal"] = "images"
- ds = g2d2.create_dataset("images", data=numpy.arange(2 * 4 * 6).reshape((2, 4, 6)))
- ds.attrs["interpretation"] = "image"
-
- g2d3 = g2d.create_group("5D_images")
- g2d3.attrs["NX_class"] = "NXdata"
- g2d3.attrs["signal"] = "images"
- g2d3.attrs["axes"] = numpy.array(["rows_coordinates", "columns_coordinates"],
- dtype=text_dtype)
- ds = g2d3.create_dataset("images", data=numpy.arange(2 * 2 * 2 * 4 * 6).reshape((2, 2, 2, 4, 6)))
- ds.attrs["interpretation"] = "image"
- g2d3.create_dataset("rows_coordinates", data=5 + 10 * numpy.arange(4))
- g2d3.create_dataset("columns_coordinates", data=0.5 + 0.02 * numpy.arange(6))
-
- g2d4 = g2d.create_group("RGBA_image")
- g2d4.attrs["NX_class"] = "NXdata"
- g2d4.attrs["signal"] = "image"
- g2d4.attrs["axes"] = numpy.array(["rows_calib", "columns_coordinates"],
- dtype=text_dtype)
- rgba_image = numpy.linspace(0, 1, num=7*8*3).reshape((7, 8, 3))
- rgba_image[:, :, 1] = 1 - rgba_image[:, :, 1] # invert G channel to add some color
- ds = g2d4.create_dataset("image", data=rgba_image)
- ds.attrs["interpretation"] = "rgba-image"
- ds = g2d4.create_dataset("rows_calib", data=(10, 5))
- ds.attrs["long_name"] = "Calibrated Y"
- g2d4.create_dataset("columns_coordinates", data=0.5+0.02*numpy.arange(8))
-
- # SCATTER
- g = self.h5f.create_group("scatters")
-
- gd0 = g.create_group("x_y_scatter")
- gd0.attrs["NX_class"] = "NXdata"
- gd0.attrs["signal"] = "y"
- gd0.attrs["axes"] = numpy.array(["x", ], dtype=text_dtype)
- gd0.create_dataset("y", data=numpy.random.rand(128) - 0.5)
- gd0.create_dataset("x", data=2 * numpy.random.rand(128))
- gd0.create_dataset("x_errors", data=0.05 * numpy.random.rand(128))
- gd0.create_dataset("errors", data=0.05 * numpy.random.rand(128))
-
- gd1 = g.create_group("x_y_value_scatter")
- gd1.attrs["NX_class"] = "NXdata"
- gd1.attrs["signal"] = "values"
- gd1.attrs["axes"] = numpy.array(["x", "y"], dtype=text_dtype)
- gd1.create_dataset("values", data=3.14 * numpy.random.rand(128))
- gd1.create_dataset("y", data=numpy.random.rand(128))
- gd1.create_dataset("y_errors", data=0.02 * numpy.random.rand(128))
- gd1.create_dataset("x", data=numpy.random.rand(128))
- gd1.create_dataset("x_errors", data=0.02 * numpy.random.rand(128))
-
- def tearDown(self):
- self.h5f.close()
-
- def testValidity(self):
- for group in self.h5f:
- for subgroup in self.h5f[group]:
- self.assertTrue(
- nxdata.is_valid_nxdata(self.h5f[group][subgroup]),
- "%s/%s not found to be a valid NXdata group" % (group, subgroup))
-
- def testScalars(self):
- nxd = nxdata.NXdata(self.h5f["scalars/0D_scalar"])
- self.assertTrue(nxd.signal_is_0d)
- self.assertEqual(nxd.signal[()], 10)
- self.assertEqual(nxd.axes_names, [])
- self.assertEqual(nxd.axes_dataset_names, [])
- self.assertEqual(nxd.axes, [])
- self.assertIsNotNone(nxd.errors)
- self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
- self.assertIsNone(nxd.interpretation)
-
- nxd = nxdata.NXdata(self.h5f["scalars/2D_scalars"])
- self.assertTrue(nxd.signal_is_2d)
- self.assertEqual(nxd.signal[1, 2], 12)
- self.assertEqual(nxd.axes_names, [None, None])
- self.assertEqual(nxd.axes_dataset_names, [None, None])
- self.assertEqual(nxd.axes, [None, None])
- self.assertIsNone(nxd.errors)
- self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
- self.assertEqual(nxd.interpretation, "scalar")
-
- nxd = nxdata.NXdata(self.h5f["scalars/4D_scalars"])
- self.assertFalse(nxd.signal_is_0d or nxd.signal_is_1d or
- nxd.signal_is_2d or nxd.signal_is_3d)
- self.assertEqual(nxd.signal[1, 0, 1, 4], 74)
- self.assertEqual(nxd.axes_names, [None, None, None, None])
- self.assertEqual(nxd.axes_dataset_names, [None, None, None, None])
- self.assertEqual(nxd.axes, [None, None, None, None])
- self.assertIsNone(nxd.errors)
- self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
- self.assertEqual(nxd.interpretation, "scalar")
-
- def testSpectra(self):
- nxd = nxdata.NXdata(self.h5f["spectra/1D_spectrum"])
- self.assertTrue(nxd.signal_is_1d)
- self.assertTrue(nxd.is_curve)
- self.assertTrue(numpy.array_equal(numpy.array(nxd.signal),
- numpy.arange(10)))
- self.assertEqual(nxd.axes_names, ["energy_calib"])
- self.assertEqual(nxd.axes_dataset_names, ["energy_calib"])
- self.assertEqual(nxd.axes[0][0], 10)
- self.assertEqual(nxd.axes[0][1], 5)
- self.assertIsNone(nxd.errors)
- self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
- self.assertIsNone(nxd.interpretation)
- self.assertEqual(nxd.title, "Title as dataset (like nexpy)")
-
- self.assertEqual(nxd.auxiliary_signals_dataset_names,
- ["count2", "count3"])
- self.assertEqual(nxd.auxiliary_signals_names,
- ["count2", "3rd counter"])
- self.assertAlmostEqual(nxd.auxiliary_signals[1][2],
- 0.8) # numpy.arange(10) * 0.4
-
- nxd = nxdata.NXdata(self.h5f["spectra/2D_spectra"])
- self.assertTrue(nxd.signal_is_2d)
- self.assertTrue(nxd.is_curve)
- self.assertEqual(nxd.axes_names, [None, None])
- self.assertEqual(nxd.axes_dataset_names, [None, None])
- self.assertEqual(nxd.axes, [None, None])
- self.assertIsNone(nxd.errors)
- self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
- self.assertEqual(nxd.interpretation, "spectrum")
-
- nxd = nxdata.NXdata(self.h5f["spectra/4D_spectra"])
- self.assertFalse(nxd.signal_is_0d or nxd.signal_is_1d or
- nxd.signal_is_2d or nxd.signal_is_3d)
- self.assertTrue(nxd.is_curve)
- self.assertEqual(nxd.axes_names,
- [None, None, None, "Calibrated energy"])
- self.assertEqual(nxd.axes_dataset_names,
- [None, None, None, "energy"])
- self.assertEqual(nxd.axes[:3], [None, None, None])
- self.assertEqual(nxd.axes[3].shape, (10, )) # dataset shape (15, ) sliced [3:12]
- self.assertIsNotNone(nxd.errors)
- self.assertEqual(nxd.errors.shape, (2, 2, 3, 10))
- self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
- self.assertEqual(nxd.interpretation, "spectrum")
- self.assertEqual(nxd.get_axis_errors("energy").shape,
- (10,))
- # test getting axis errors by long_name
- self.assertTrue(numpy.array_equal(nxd.get_axis_errors("Calibrated energy"),
- nxd.get_axis_errors("energy")))
- self.assertTrue(numpy.array_equal(nxd.get_axis_errors(b"Calibrated energy"),
- nxd.get_axis_errors("energy")))
-
- def testImages(self):
- nxd = nxdata.NXdata(self.h5f["images/2D_regular_image"])
- self.assertTrue(nxd.signal_is_2d)
- self.assertTrue(nxd.is_image)
- self.assertEqual(nxd.axes_names, ["Calibrated Y", "columns_coordinates"])
- self.assertEqual(list(nxd.axes_dataset_names),
- ["rows_calib", "columns_coordinates"])
- self.assertIsNone(nxd.errors)
- self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
- self.assertIsNone(nxd.interpretation)
- self.assertEqual(len(nxd.auxiliary_signals), 1)
- self.assertEqual(nxd.auxiliary_signals_names, ["image2"])
-
- nxd = nxdata.NXdata(self.h5f["images/2D_irregular_data"])
- self.assertTrue(nxd.signal_is_2d)
- self.assertTrue(nxd.is_image)
-
- self.assertEqual(nxd.axes_dataset_names, nxd.axes_names)
- self.assertEqual(list(nxd.axes_dataset_names),
- ["rows_coordinates", "columns_coordinates"])
- self.assertEqual(len(nxd.axes), 2)
- self.assertIsNone(nxd.errors)
- self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
- self.assertIsNone(nxd.interpretation)
- self.assertEqual(nxd.title, "Title as group attr")
-
- nxd = nxdata.NXdata(self.h5f["images/5D_images"])
- self.assertTrue(nxd.is_image)
- self.assertFalse(nxd.signal_is_0d or nxd.signal_is_1d or
- nxd.signal_is_2d or nxd.signal_is_3d)
- self.assertEqual(nxd.axes_names,
- [None, None, None, 'rows_coordinates', 'columns_coordinates'])
- self.assertEqual(nxd.axes_dataset_names,
- [None, None, None, 'rows_coordinates', 'columns_coordinates'])
- self.assertIsNone(nxd.errors)
- self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
- self.assertEqual(nxd.interpretation, "image")
-
- nxd = nxdata.NXdata(self.h5f["images/RGBA_image"])
- self.assertTrue(nxd.is_image)
- self.assertEqual(nxd.interpretation, "rgba-image")
- self.assertTrue(nxd.signal_is_3d)
- self.assertEqual(nxd.axes_names, ["Calibrated Y",
- "columns_coordinates",
- None])
- self.assertEqual(list(nxd.axes_dataset_names),
- ["rows_calib", "columns_coordinates", None])
-
- def testScatters(self):
- nxd = nxdata.NXdata(self.h5f["scatters/x_y_scatter"])
- self.assertTrue(nxd.signal_is_1d)
- self.assertEqual(nxd.axes_names, ["x"])
- self.assertEqual(nxd.axes_dataset_names,
- ["x"])
- self.assertIsNotNone(nxd.errors)
- self.assertEqual(nxd.get_axis_errors("x").shape,
- (128, ))
- self.assertTrue(nxd.is_scatter)
- self.assertFalse(nxd.is_x_y_value_scatter)
- self.assertIsNone(nxd.interpretation)
-
- nxd = nxdata.NXdata(self.h5f["scatters/x_y_value_scatter"])
- self.assertFalse(nxd.signal_is_1d)
- self.assertTrue(nxd.axes_dataset_names,
- nxd.axes_names)
- self.assertEqual(nxd.axes_dataset_names,
- ["x", "y"])
- self.assertEqual(nxd.get_axis_errors("x").shape,
- (128, ))
- self.assertEqual(nxd.get_axis_errors("y").shape,
- (128, ))
- self.assertEqual(len(nxd.axes), 2)
- self.assertIsNone(nxd.errors)
- self.assertTrue(nxd.is_scatter)
- self.assertTrue(nxd.is_x_y_value_scatter)
- self.assertIsNone(nxd.interpretation)
-
-
-class TestLegacyNXdata(unittest.TestCase):
- def setUp(self):
- tmp = tempfile.NamedTemporaryFile(prefix="nxdata_legacy_examples_",
- suffix=".h5", delete=True)
- tmp.file.close()
- self.h5fname = tmp.name
- self.h5f = h5py.File(tmp.name, "w")
-
- def tearDown(self):
- self.h5f.close()
-
- def testSignalAttrOnDataset(self):
- g = self.h5f.create_group("2D")
- g.attrs["NX_class"] = "NXdata"
-
- ds0 = g.create_dataset("image0",
- data=numpy.arange(4 * 6).reshape((4, 6)))
- ds0.attrs["signal"] = 1
- ds0.attrs["long_name"] = "My first image"
-
- ds1 = g.create_dataset("image1",
- data=numpy.arange(4 * 6).reshape((4, 6)))
- ds1.attrs["signal"] = "2"
- ds1.attrs["long_name"] = "My 2nd image"
-
- ds2 = g.create_dataset("image2",
- data=numpy.arange(4 * 6).reshape((4, 6)))
- ds2.attrs["signal"] = 3
-
- nxd = nxdata.NXdata(self.h5f["2D"])
-
- self.assertEqual(nxd.signal_dataset_name, "image0")
- self.assertEqual(nxd.signal_name, "My first image")
- self.assertEqual(nxd.signal.shape,
- (4, 6))
-
- self.assertEqual(len(nxd.auxiliary_signals), 2)
- self.assertEqual(nxd.auxiliary_signals[1].shape,
- (4, 6))
-
- self.assertEqual(nxd.auxiliary_signals_dataset_names,
- ["image1", "image2"])
- self.assertEqual(nxd.auxiliary_signals_names,
- ["My 2nd image", "image2"])
-
- def testAxesOnSignalDataset(self):
- g = self.h5f.create_group("2D")
- g.attrs["NX_class"] = "NXdata"
-
- ds0 = g.create_dataset("image0",
- data=numpy.arange(4 * 6).reshape((4, 6)))
- ds0.attrs["signal"] = 1
- ds0.attrs["axes"] = "yaxis:xaxis"
-
- ds1 = g.create_dataset("yaxis",
- data=numpy.arange(4))
- ds2 = g.create_dataset("xaxis",
- data=numpy.arange(6))
-
- nxd = nxdata.NXdata(self.h5f["2D"])
-
- self.assertEqual(nxd.axes_dataset_names,
- ["yaxis", "xaxis"])
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- numpy.arange(4)))
- self.assertTrue(numpy.array_equal(nxd.axes[1],
- numpy.arange(6)))
-
- def testAxesOnAxesDatasets(self):
- g = self.h5f.create_group("2D")
- g.attrs["NX_class"] = "NXdata"
-
- ds0 = g.create_dataset("image0",
- data=numpy.arange(4 * 6).reshape((4, 6)))
- ds0.attrs["signal"] = 1
- ds1 = g.create_dataset("yaxis",
- data=numpy.arange(4))
- ds1.attrs["axis"] = 0
- ds2 = g.create_dataset("xaxis",
- data=numpy.arange(6))
- ds2.attrs["axis"] = "1"
-
- nxd = nxdata.NXdata(self.h5f["2D"])
- self.assertEqual(nxd.axes_dataset_names,
- ["yaxis", "xaxis"])
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- numpy.arange(4)))
- self.assertTrue(numpy.array_equal(nxd.axes[1],
- numpy.arange(6)))
-
- def testAsciiUndefinedAxesAttrs(self):
- """Some files may not be using utf8 for str attrs"""
- g = self.h5f.create_group("bytes_attrs")
- g.attrs["NX_class"] = b"NXdata"
- g.attrs["signal"] = b"image0"
- g.attrs["axes"] = b"yaxis", b"."
-
- g.create_dataset("image0",
- data=numpy.arange(4 * 6).reshape((4, 6)))
- g.create_dataset("yaxis",
- data=numpy.arange(4))
-
- nxd = nxdata.NXdata(self.h5f["bytes_attrs"])
- self.assertEqual(nxd.axes_dataset_names,
- ["yaxis", None])
-
-
-class TestSaveNXdata(unittest.TestCase):
- def setUp(self):
- tmp = tempfile.NamedTemporaryFile(prefix="nxdata",
- suffix=".h5", delete=True)
- tmp.file.close()
- self.h5fname = tmp.name
-
- def testSimpleSave(self):
- sig = numpy.array([0, 1, 2])
- a0 = numpy.array([2, 3, 4])
- a1 = numpy.array([3, 4, 5])
- nxdata.save_NXdata(filename=self.h5fname,
- signal=sig,
- axes=[a0, a1],
- signal_name="sig",
- axes_names=["a0", "a1"],
- nxentry_name="a",
- nxdata_name="mydata")
-
- h5f = h5py.File(self.h5fname, "r")
- self.assertTrue(nxdata.is_valid_nxdata(h5f["a/mydata"]))
-
- nxd = nxdata.NXdata(h5f["/a/mydata"])
- self.assertTrue(numpy.array_equal(nxd.signal,
- sig))
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- a0))
-
- h5f.close()
-
- def testSimplestSave(self):
- sig = numpy.array([0, 1, 2])
- nxdata.save_NXdata(filename=self.h5fname,
- signal=sig)
-
- h5f = h5py.File(self.h5fname, "r")
-
- self.assertTrue(nxdata.is_valid_nxdata(h5f["/entry/data0"]))
-
- nxd = nxdata.NXdata(h5f["/entry/data0"])
- self.assertTrue(numpy.array_equal(nxd.signal,
- sig))
- h5f.close()
-
- def testSaveDefaultAxesNames(self):
- sig = numpy.array([0, 1, 2])
- a0 = numpy.array([2, 3, 4])
- a1 = numpy.array([3, 4, 5])
- nxdata.save_NXdata(filename=self.h5fname,
- signal=sig,
- axes=[a0, a1],
- signal_name="sig",
- axes_names=None,
- axes_long_names=["a", "b"],
- nxentry_name="a",
- nxdata_name="mydata")
-
- h5f = h5py.File(self.h5fname, "r")
- self.assertTrue(nxdata.is_valid_nxdata(h5f["a/mydata"]))
-
- nxd = nxdata.NXdata(h5f["/a/mydata"])
- self.assertTrue(numpy.array_equal(nxd.signal,
- sig))
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- a0))
- self.assertEqual(nxd.axes_dataset_names,
- [u"dim0", u"dim1"])
- self.assertEqual(nxd.axes_names,
- [u"a", u"b"])
-
- h5f.close()
-
- def testSaveToExistingEntry(self):
- h5f = h5py.File(self.h5fname, "w")
- g = h5f.create_group("myentry")
- g.attrs["NX_class"] = "NXentry"
- h5f.close()
-
- sig = numpy.array([0, 1, 2])
- a0 = numpy.array([2, 3, 4])
- a1 = numpy.array([3, 4, 5])
- nxdata.save_NXdata(filename=self.h5fname,
- signal=sig,
- axes=[a0, a1],
- signal_name="sig",
- axes_names=["a0", "a1"],
- nxentry_name="myentry",
- nxdata_name="toto")
-
- h5f = h5py.File(self.h5fname, "r")
- self.assertTrue(nxdata.is_valid_nxdata(h5f["myentry/toto"]))
-
- nxd = nxdata.NXdata(h5f["myentry/toto"])
- self.assertTrue(numpy.array_equal(nxd.signal,
- sig))
- self.assertTrue(numpy.array_equal(nxd.axes[0],
- a0))
- h5f.close()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestNXdata))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestLegacyNXdata))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestSaveNXdata))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_octaveh5.py b/silx/io/test/test_octaveh5.py
deleted file mode 100644
index 2e65820..0000000
--- a/silx/io/test/test_octaveh5.py
+++ /dev/null
@@ -1,165 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-Tests for the octaveh5 module
-"""
-
-__authors__ = ["C. Nemoz", "H. Payno"]
-__license__ = "MIT"
-__date__ = "12/07/2016"
-
-import unittest
-import os
-import tempfile
-
-try:
- from ..octaveh5 import Octaveh5
-except ImportError:
- Octaveh5 = None
-
-
-@unittest.skipIf(Octaveh5 is None, "Could not import h5py")
-class TestOctaveH5(unittest.TestCase):
- @staticmethod
- def _get_struct_FT():
- return {
- 'NO_CHECK': 0.0, 'SHOWSLICE': 1.0, 'DOTOMO': 1.0, 'DATABASE': 0.0, 'ANGLE_OFFSET': 0.0,
- 'VOLSELECTION_REMEMBER': 0.0, 'NUM_PART': 4.0, 'VOLOUTFILE': 0.0, 'RINGSCORRECTION': 0.0,
- 'DO_TEST_SLICE': 1.0, 'ZEROOFFMASK': 1.0, 'VERSION': 'fastomo3 version 2.0',
- 'CORRECT_SPIKES_THRESHOLD': 0.040000000000000001, 'SHOWPROJ': 0.0, 'HALF_ACQ': 0.0,
- 'ANGLE_OFFSET_VALUE': 0.0, 'FIXEDSLICE': 'middle', 'VOLSELECT': 'total' }
- @staticmethod
- def _get_struct_PYHSTEXE():
- return {
- 'EXE': 'PyHST2_2015d', 'VERBOSE': 0.0, 'OFFV': 'PyHST2_2015d', 'TOMO': 0.0,
- 'VERBOSE_FILE': 'pyhst_out.txt', 'DIR': '/usr/bin/', 'OFFN': 'pyhst2'}
-
- @staticmethod
- def _get_struct_FTAXIS():
- return {
- 'POSITION_VALUE': 12345.0, 'COR_ERROR': 0.0, 'FILESDURINGSCAN': 0.0, 'PLOTFIGURE': 1.0,
- 'DIM1': 0.0, 'OVERSAMPLING': 5.0, 'TO_THE_CENTER': 1.0, 'POSITION': 'fixed',
- 'COR_POSITION': 0.0, 'HA': 0.0 }
-
- @staticmethod
- def _get_struct_PAGANIN():
- return {
- 'MKEEP_MASK': 0.0, 'UNSHARP_SIGMA': 0.80000000000000004, 'DILATE': 2.0, 'UNSHARP_COEFF': 3.0,
- 'MEDIANR': 4.0, 'DB': 500.0, 'MKEEP_ABS': 0.0, 'MODE': 0.0, 'THRESHOLD': 0.5,
- 'MKEEP_BONE': 0.0, 'DB2': 100.0, 'MKEEP_CORR': 0.0, 'MKEEP_SOFT': 0.0 }
-
- @staticmethod
- def _get_struct_BEAMGEO():
- return {'DIST': 55.0, 'SY': 0.0, 'SX': 0.0, 'TYPE': 'p'}
-
-
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.test_3_6_fname = os.path.join(self.tempdir, "silx_tmp_t00_octaveTest_3_6.h5")
- self.test_3_8_fname = os.path.join(self.tempdir, "silx_tmp_t00_octaveTest_3_8.h5")
-
- def tearDown(self):
- if os.path.isfile(self.test_3_6_fname):
- os.unlink(self.test_3_6_fname)
- if os.path.isfile(self.test_3_8_fname):
- os.unlink(self.test_3_8_fname)
-
- def testWritedIsReaded(self):
- """
- Simple test to write and reaf the structure compatible with the octave h5 using structure.
- This test is for # test for octave version > 3.8
- """
- writer = Octaveh5()
-
- writer.open(self.test_3_8_fname, 'a')
- # step 1 writing the file
- writer.write('FT', self._get_struct_FT())
- writer.write('PYHSTEXE', self._get_struct_PYHSTEXE())
- writer.write('FTAXIS', self._get_struct_FTAXIS())
- writer.write('PAGANIN', self._get_struct_PAGANIN())
- writer.write('BEAMGEO', self._get_struct_BEAMGEO())
- writer.close()
-
- # step 2 reading the file
- reader = Octaveh5().open(self.test_3_8_fname)
- # 2.1 check FT
- data_readed = reader.get('FT')
- self.assertEqual(data_readed, self._get_struct_FT() )
- # 2.2 check PYHSTEXE
- data_readed = reader.get('PYHSTEXE')
- self.assertEqual(data_readed, self._get_struct_PYHSTEXE() )
- # 2.3 check FTAXIS
- data_readed = reader.get('FTAXIS')
- self.assertEqual(data_readed, self._get_struct_FTAXIS() )
- # 2.4 check PAGANIN
- data_readed = reader.get('PAGANIN')
- self.assertEqual(data_readed, self._get_struct_PAGANIN() )
- # 2.5 check BEAMGEO
- data_readed = reader.get('BEAMGEO')
- self.assertEqual(data_readed, self._get_struct_BEAMGEO() )
- reader.close()
-
- def testWritedIsReadedOldOctaveVersion(self):
- """The same test as testWritedIsReaded but for octave version < 3.8
- """
- # test for octave version < 3.8
- writer = Octaveh5(3.6)
-
- writer.open(self.test_3_6_fname, 'a')
-
- # step 1 writing the file
- writer.write('FT', self._get_struct_FT())
- writer.write('PYHSTEXE', self._get_struct_PYHSTEXE())
- writer.write('FTAXIS', self._get_struct_FTAXIS())
- writer.write('PAGANIN', self._get_struct_PAGANIN())
- writer.write('BEAMGEO', self._get_struct_BEAMGEO())
- writer.close()
-
- # step 2 reading the file
- reader = Octaveh5(3.6).open(self.test_3_6_fname)
- # 2.1 check FT
- data_readed = reader.get('FT')
- self.assertEqual(data_readed, self._get_struct_FT() )
- # 2.2 check PYHSTEXE
- data_readed = reader.get('PYHSTEXE')
- self.assertEqual(data_readed, self._get_struct_PYHSTEXE() )
- # 2.3 check FTAXIS
- data_readed = reader.get('FTAXIS')
- self.assertEqual(data_readed, self._get_struct_FTAXIS() )
- # 2.4 check PAGANIN
- data_readed = reader.get('PAGANIN')
- self.assertEqual(data_readed, self._get_struct_PAGANIN() )
- # 2.5 check BEAMGEO
- data_readed = reader.get('BEAMGEO')
- self.assertEqual(data_readed, self._get_struct_BEAMGEO() )
- reader.close()
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestOctaveH5))
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_rawh5.py b/silx/io/test/test_rawh5.py
deleted file mode 100644
index 0f7205c..0000000
--- a/silx/io/test/test_rawh5.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test for silx.gui.hdf5 module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "21/09/2017"
-
-
-import unittest
-import tempfile
-import numpy
-import shutil
-from ..import rawh5
-
-
-class TestNumpyFile(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
- cls.tmpDirectory = tempfile.mkdtemp()
-
- @classmethod
- def tearDownClass(cls):
- shutil.rmtree(cls.tmpDirectory)
-
- def testNumpyFile(self):
- filename = "%s/%s.npy" % (self.tmpDirectory, self.id())
- c = numpy.random.rand(5, 5)
- numpy.save(filename, c)
- h5 = rawh5.NumpyFile(filename)
- self.assertIn("data", h5)
- self.assertEqual(h5["data"].dtype.kind, "f")
-
- def testNumpyZFile(self):
- filename = "%s/%s.npz" % (self.tmpDirectory, self.id())
- a = numpy.array(u"aaaaa")
- b = numpy.array([1, 2, 3, 4])
- c = numpy.random.rand(5, 5)
- d = numpy.array(b"aaaaa")
- e = numpy.array(u"i \u2661 my mother")
- numpy.savez(filename, a, b=b, c=c, d=d, e=e)
- h5 = rawh5.NumpyFile(filename)
- self.assertIn("arr_0", h5)
- self.assertIn("b", h5)
- self.assertIn("c", h5)
- self.assertIn("d", h5)
- self.assertIn("e", h5)
- self.assertEqual(h5["arr_0"].dtype.kind, "U")
- self.assertEqual(h5["b"].dtype.kind, "i")
- self.assertEqual(h5["c"].dtype.kind, "f")
- self.assertEqual(h5["d"].dtype.kind, "S")
- self.assertEqual(h5["e"].dtype.kind, "U")
-
- def testNumpyZFileContainingDirectories(self):
- filename = "%s/%s.npz" % (self.tmpDirectory, self.id())
- data = {}
- data['a/b/c'] = numpy.arange(10)
- data['a/b/e'] = numpy.arange(10)
- numpy.savez(filename, **data)
- h5 = rawh5.NumpyFile(filename)
- self.assertIn("a/b/c", h5)
- self.assertIn("a/b/e", h5)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestNumpyFile))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/io/test/test_specfile.py b/silx/io/test/test_specfile.py
deleted file mode 100644
index 79d5544..0000000
--- a/silx/io/test/test_specfile.py
+++ /dev/null
@@ -1,433 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for specfile wrapper"""
-
-__authors__ = ["P. Knobel", "V.A. Sole"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import locale
-import logging
-import numpy
-import os
-import sys
-import tempfile
-import unittest
-
-from silx.utils import testutils
-
-from ..specfile import SpecFile, Scan
-from .. import specfile
-
-
-logger1 = logging.getLogger(__name__)
-
-sftext = """#F /tmp/sf.dat
-#E 1455180875
-#D Thu Feb 11 09:54:35 2016
-#C imaging User = opid17
-#U00 user comment first line
-#U01 This is a dummy file to test SpecFile parsing
-#U02
-#U03 last line
-
-#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
-#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
-#o0 pshg mrtu mrtd
-#o2 ss1vo ss1ho ss1vg
-
-#J0 Seconds IA ion.mono Current
-#J1 xbpmc2 idgap1 Inorm
-
-#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
-#D Thu Feb 11 09:55:20 2016
-#T 0.2 (Seconds)
-#G0 0
-#G1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
-#G3 0 0 0 0 0 0 0 0 0
-#G4 0
-#Q
-#P0 180.005 -0.66875 0.87125
-#P1 14.74255 16.197579 12.238283
-#UMI0 Current AutoM Shutter
-#UMI1 192.51 OFF FE open
-#UMI2 Refill in 39883 sec, Fill Mode: uniform multibunch / Message: Feb 11 08:00 Delivery:Next Refill at 21:00;
-#N 4
-#L first column second column 3rd_col
--1.23 5.89 8
-8.478100E+01 5 1.56
-3.14 2.73 -3.14
-1.2 2.3 3.4
-
-#S 25 ascan c3th 1.33245 1.52245 40 0.15
-#D Thu Feb 11 10:00:31 2016
-#P0 80.005 -1.66875 1.87125
-#P1 4.74255 6.197579 2.238283
-#N 5
-#L column0 column1 col2 col3
-0.0 0.1 0.2 0.3
-1.0 1.1 1.2 1.3
-2.0 2.1 2.2 2.3
-3.0 3.1 3.2 3.3
-
-#S 26 yyyyyy
-#D Thu Feb 11 09:55:20 2016
-#P0 80.005 -1.66875 1.87125
-#P1 4.74255 6.197579 2.238283
-#N 4
-#L first column second column 3rd_col
-#C Sat Oct 31 15:51:47 1998. Scan aborted after 0 points.
-
-#F /tmp/sf.dat
-#E 1455180876
-#D Thu Feb 11 09:54:36 2016
-
-#S 1 aaaaaa
-#U first duplicate line
-#U second duplicate line
-#@MCADEV 1
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#N 3
-#L uno duo
-1 2
-@A 0 1 2
-3 4
-@A 3.1 4 5
-5 6
-@A 6 7.7 8
-"""
-
-
-loc = locale.getlocale(locale.LC_NUMERIC)
-try:
- locale.setlocale(locale.LC_NUMERIC, 'de_DE.utf8')
-except locale.Error:
- try_DE = False
-else:
- try_DE = True
- locale.setlocale(locale.LC_NUMERIC, loc)
-
-
-class TestSpecFile(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- fd, cls.fname1 = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd, sftext)
- else:
- os.write(fd, bytes(sftext, 'ascii'))
- os.close(fd)
-
- fd2, cls.fname2 = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd2, sftext[370:923])
- else:
- os.write(fd2, bytes(sftext[370:923], 'ascii'))
- os.close(fd2)
-
- fd3, cls.fname3 = tempfile.mkstemp(text=False)
- txt = sftext[371:923]
- if sys.version_info < (3, ):
- os.write(fd3, txt)
- else:
- os.write(fd3, bytes(txt, 'ascii'))
- os.close(fd3)
-
- @classmethod
- def tearDownClass(cls):
- os.unlink(cls.fname1)
- os.unlink(cls.fname2)
- os.unlink(cls.fname3)
-
- def setUp(self):
- self.sf = SpecFile(self.fname1)
- self.scan1 = self.sf[0]
- self.scan1_2 = self.sf["1.2"]
- self.scan25 = self.sf["25.1"]
- self.empty_scan = self.sf["26.1"]
-
- self.sf_no_fhdr = SpecFile(self.fname2)
- self.scan1_no_fhdr = self.sf_no_fhdr[0]
-
- self.sf_no_fhdr_crash = SpecFile(self.fname3)
- self.scan1_no_fhdr_crash = self.sf_no_fhdr_crash[0]
-
- def tearDown(self):
- self.sf.close()
- self.sf_no_fhdr.close()
- self.sf_no_fhdr_crash.close()
-
- def test_open(self):
- self.assertIsInstance(self.sf, SpecFile)
- with self.assertRaises(specfile.SfErrFileOpen):
- SpecFile("doesnt_exist.dat")
-
- # test filename types unicode and bytes
- if sys.version_info[0] < 3:
- try:
- SpecFile(self.fname1)
- except TypeError:
- self.fail("failed to handle filename as python2 str")
- try:
- SpecFile(unicode(self.fname1))
- except TypeError:
- self.fail("failed to handle filename as python2 unicode")
- else:
- try:
- SpecFile(self.fname1)
- except TypeError:
- self.fail("failed to handle filename as python3 str")
- try:
- SpecFile(bytes(self.fname1, 'utf-8'))
- except TypeError:
- self.fail("failed to handle filename as python3 bytes")
-
- def test_number_of_scans(self):
- self.assertEqual(4, len(self.sf))
-
- def test_list_of_scan_indices(self):
- self.assertEqual(self.sf.list(),
- [1, 25, 26, 1])
- self.assertEqual(self.sf.keys(),
- ["1.1", "25.1", "26.1", "1.2"])
-
- def test_index_number_order(self):
- self.assertEqual(self.sf.index(1, 2), 3) # sf["1.2"]==sf[3]
- self.assertEqual(self.sf.number(1), 25) # sf[1]==sf["25"]
- self.assertEqual(self.sf.order(3), 2) # sf[3]==sf["1.2"]
- with self.assertRaises(specfile.SfErrScanNotFound):
- self.sf.index(3, 2)
- with self.assertRaises(specfile.SfErrScanNotFound):
- self.sf.index(99)
-
- def assertRaisesRegex(self, *args, **kwargs):
- # Python 2 compatibility
- if sys.version_info.major >= 3:
- return super(TestSpecFile, self).assertRaisesRegex(*args, **kwargs)
- else:
- return self.assertRaisesRegexp(*args, **kwargs)
-
- def test_getitem(self):
- self.assertIsInstance(self.sf[2], Scan)
- self.assertIsInstance(self.sf["1.2"], Scan)
- # int out of range
- with self.assertRaisesRegex(IndexError, 'Scan index must be in ran'):
- self.sf[107]
- # float indexing not allowed
- with self.assertRaisesRegex(TypeError, 'The scan identification k'):
- self.sf[1.2]
- # non existant scan with "N.M" indexing
- with self.assertRaises(KeyError):
- self.sf["3.2"]
-
- def test_specfile_iterator(self):
- i = 0
- for scan in self.sf:
- if i == 1:
- self.assertEqual(scan.motor_positions,
- self.sf[1].motor_positions)
- i += 1
- # number of returned scans
- self.assertEqual(i, len(self.sf))
-
- def test_scan_index(self):
- self.assertEqual(self.scan1.index, 0)
- self.assertEqual(self.scan1_2.index, 3)
- self.assertEqual(self.scan25.index, 1)
-
- def test_scan_headers(self):
- self.assertEqual(self.scan25.scan_header_dict['S'],
- "25 ascan c3th 1.33245 1.52245 40 0.15")
- self.assertEqual(self.scan1.header[17], '#G0 0')
- self.assertEqual(len(self.scan1.header), 29)
- # parsing headers with long keys
- self.assertEqual(self.scan1.scan_header_dict['UMI0'],
- 'Current AutoM Shutter')
- # parsing empty headers
- self.assertEqual(self.scan1.scan_header_dict['Q'], '')
- # duplicate headers: concatenated (with newline)
- self.assertEqual(self.scan1_2.scan_header_dict["U"],
- "first duplicate line\nsecond duplicate line")
-
- def test_file_headers(self):
- self.assertEqual(self.scan1.header[1],
- '#E 1455180875')
- self.assertEqual(self.scan1.file_header_dict['F'],
- '/tmp/sf.dat')
-
- def test_multiple_file_headers(self):
- """Scan 1.2 is after the second file header, with a different
- Epoch"""
- self.assertEqual(self.scan1_2.header[1],
- '#E 1455180876')
-
- def test_scan_labels(self):
- self.assertEqual(self.scan1.labels,
- ['first column', 'second column', '3rd_col'])
-
- def test_data(self):
- # data_line() and data_col() take 1-based indices as arg
- self.assertAlmostEqual(self.scan1.data_line(1)[2],
- 1.56)
- # tests for data transposition between original file and .data attr
- self.assertAlmostEqual(self.scan1.data[2, 0],
- 8)
- self.assertEqual(self.scan1.data.shape, (3, 4))
- self.assertAlmostEqual(numpy.sum(self.scan1.data), 113.631)
-
- def test_data_column_by_name(self):
- self.assertAlmostEqual(self.scan25.data_column_by_name("col2")[1],
- 1.2)
- # Scan.data is transposed after readinq, so column is the first index
- self.assertAlmostEqual(numpy.sum(self.scan25.data_column_by_name("col2")),
- numpy.sum(self.scan25.data[2, :]))
- with self.assertRaises(specfile.SfErrColNotFound):
- self.scan25.data_column_by_name("ygfxgfyxg")
-
- def test_motors(self):
- self.assertEqual(len(self.scan1.motor_names), 6)
- self.assertEqual(len(self.scan1.motor_positions), 6)
- self.assertAlmostEqual(sum(self.scan1.motor_positions),
- 223.385912)
- self.assertEqual(self.scan1.motor_names[1], 'MRTSlit UP')
- self.assertAlmostEqual(
- self.scan25.motor_position_by_name('MRTSlit UP'),
- -1.66875)
-
- def test_absence_of_file_header(self):
- """We expect Scan.file_header to be an empty list in the absence
- of a file header.
- """
- self.assertEqual(len(self.scan1_no_fhdr.motor_names), 0)
- # motor positions can still be read in the scan header
- # even in the absence of motor names
- self.assertAlmostEqual(sum(self.scan1_no_fhdr.motor_positions),
- 223.385912)
- self.assertEqual(len(self.scan1_no_fhdr.header), 15)
- self.assertEqual(len(self.scan1_no_fhdr.scan_header), 15)
- self.assertEqual(len(self.scan1_no_fhdr.file_header), 0)
-
- def test_crash_absence_of_file_header(self):
- """Test no crash in absence of file header and no leading newline
- character
- """
- self.assertEqual(len(self.scan1_no_fhdr_crash.motor_names), 0)
- # motor positions can still be read in the scan header
- # even in the absence of motor names
- self.assertAlmostEqual(sum(self.scan1_no_fhdr_crash.motor_positions),
- 223.385912)
- self.assertEqual(len(self.scan1_no_fhdr_crash.scan_header), 15)
- self.assertEqual(len(self.scan1_no_fhdr_crash.file_header), 0)
-
- def test_mca(self):
- self.assertEqual(len(self.scan1.mca), 0)
- self.assertEqual(len(self.scan1_2.mca), 3)
- self.assertEqual(self.scan1_2.mca[1][2], 5)
- self.assertEqual(sum(self.scan1_2.mca[2]), 21.7)
-
- # Negative indexing
- self.assertEqual(sum(self.scan1_2.mca[len(self.scan1_2.mca) - 1]),
- sum(self.scan1_2.mca[-1]))
-
- # Test iterator
- line_count, total_sum = (0, 0)
- for mca_line in self.scan1_2.mca:
- line_count += 1
- total_sum += sum(mca_line)
- self.assertEqual(line_count, 3)
- self.assertAlmostEqual(total_sum, 36.8)
-
- def test_mca_header(self):
- self.assertEqual(self.scan1.mca_header_dict, {})
- self.assertEqual(len(self.scan1_2.mca_header_dict), 4)
- self.assertEqual(self.scan1_2.mca_header_dict["CALIB"], "1 2 3")
- self.assertEqual(self.scan1_2.mca.calibration,
- [[1., 2., 3.]])
- # default calib in the absence of #@CALIB
- self.assertEqual(self.scan25.mca.calibration,
- [[0., 1., 0.]])
- self.assertEqual(self.scan1_2.mca.channels,
- [[0, 1, 2]])
- # absence of #@CHANN and spectra
- self.assertEqual(self.scan25.mca.channels,
- [])
-
- @testutils.test_logging(specfile._logger.name, warning=1)
- def test_empty_scan(self):
- """Test reading a scan with no data points"""
- self.assertEqual(len(self.empty_scan.labels),
- 3)
- col1 = self.empty_scan.data_column_by_name("second column")
- self.assertEqual(col1.shape, (0, ))
-
-
-class TestSFLocale(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- fd, cls.fname = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd, sftext)
- else:
- os.write(fd, bytes(sftext, 'ascii'))
- os.close(fd)
-
- @classmethod
- def tearDownClass(cls):
- os.unlink(cls.fname)
- locale.setlocale(locale.LC_NUMERIC, loc) # restore saved locale
-
- def crunch_data(self):
- self.sf3 = SpecFile(self.fname)
- self.assertAlmostEqual(self.sf3[0].data_line(1)[2],
- 1.56)
- self.sf3.close()
-
- @unittest.skipIf(not try_DE, "de_DE.utf8 locale not installed")
- def test_locale_de_DE(self):
- locale.setlocale(locale.LC_NUMERIC, 'de_DE.utf8')
- self.crunch_data()
-
- def test_locale_user(self):
- locale.setlocale(locale.LC_NUMERIC, '') # use user's preferred locale
- self.crunch_data()
-
- def test_locale_C(self):
- locale.setlocale(locale.LC_NUMERIC, 'C') # use default (C) locale
- self.crunch_data()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestSpecFile))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestSFLocale))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_specfilewrapper.py b/silx/io/test/test_specfilewrapper.py
deleted file mode 100644
index 2f463fa..0000000
--- a/silx/io/test/test_specfilewrapper.py
+++ /dev/null
@@ -1,206 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for old specfile wrapper"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "15/05/2017"
-
-import locale
-import logging
-import numpy
-import os
-import sys
-import tempfile
-import unittest
-
-logger1 = logging.getLogger(__name__)
-
-from ..specfilewrapper import Specfile
-
-sftext = """#F /tmp/sf.dat
-#E 1455180875
-#D Thu Feb 11 09:54:35 2016
-#C imaging User = opid17
-#U00 user comment first line
-#U01 This is a dummy file to test SpecFile parsing
-#U02
-#U03 last line
-
-#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
-#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
-#o0 pshg mrtu mrtd
-#o2 ss1vo ss1ho ss1vg
-
-#J0 Seconds IA ion.mono Current
-#J1 xbpmc2 idgap1 Inorm
-
-#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
-#D Thu Feb 11 09:55:20 2016
-#T 0.2 (Seconds)
-#G0 0
-#G1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
-#G3 0 0 0 0 0 0 0 0 0
-#G4 0
-#Q
-#P0 180.005 -0.66875 0.87125
-#P1 14.74255 16.197579 12.238283
-#UMI0 Current AutoM Shutter
-#UMI1 192.51 OFF FE open
-#UMI2 Refill in 39883 sec, Fill Mode: uniform multibunch / Message: Feb 11 08:00 Delivery:Next Refill at 21:00;
-#N 4
-#L first column second column 3rd_col
--1.23 5.89 8
-8.478100E+01 5 1.56
-3.14 2.73 -3.14
-1.2 2.3 3.4
-
-#S 25 ascan c3th 1.33245 1.52245 40 0.15
-#D Thu Feb 11 10:00:31 2016
-#P0 80.005 -1.66875 1.87125
-#P1 4.74255 6.197579 2.238283
-#N 5
-#L column0 column1 col2 col3
-0.0 0.1 0.2 0.3
-1.0 1.1 1.2 1.3
-2.0 2.1 2.2 2.3
-3.0 3.1 3.2 3.3
-
-#F /tmp/sf.dat
-#E 1455180876
-#D Thu Feb 11 09:54:36 2016
-
-#S 1 aaaaaa
-#U first duplicate line
-#U second duplicate line
-#@MCADEV 1
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#N 3
-#L uno duo
-1 2
-@A 0 1 2
-3 4
-@A 3.1 4 5
-5 6
-@A 6 7.7 8
-"""
-
-
-class TestSpecfilewrapper(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- fd, cls.fname1 = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd, sftext)
- else:
- os.write(fd, bytes(sftext, 'ascii'))
- os.close(fd)
-
- @classmethod
- def tearDownClass(cls):
- os.unlink(cls.fname1)
-
- def setUp(self):
- self.sf = Specfile(self.fname1)
- self.scan1 = self.sf[0]
- self.scan1_2 = self.sf.select("1.2")
- self.scan25 = self.sf.select("25.1")
-
- def tearDown(self):
- self.sf.close()
-
- def test_number_of_scans(self):
- self.assertEqual(3, len(self.sf))
-
- def test_list_of_scan_indices(self):
- self.assertEqual(self.sf.list(),
- '1,25,1')
- self.assertEqual(self.sf.keys(),
- ["1.1", "25.1", "1.2"])
-
- def test_scan_headers(self):
- self.assertEqual(self.scan25.header('S'),
- ["#S 25 ascan c3th 1.33245 1.52245 40 0.15"])
- self.assertEqual(self.scan1.header("G0"), ['#G0 0'])
- # parsing headers with long keys
- # parsing empty headers
- self.assertEqual(self.scan1.header('Q'), ['#Q '])
-
- def test_file_headers(self):
- self.assertEqual(self.scan1.header("E"),
- ['#E 1455180875'])
- self.assertEqual(self.sf.title(),
- "imaging")
- self.assertEqual(self.sf.epoch(),
- 1455180875)
- self.assertEqual(self.sf.allmotors(),
- ["Pslit HGap", "MRTSlit UP", "MRTSlit DOWN",
- "Sslit1 VOff", "Sslit1 HOff", "Sslit1 VGap"])
-
- def test_scan_labels(self):
- self.assertEqual(self.scan1.alllabels(),
- ['first column', 'second column', '3rd_col'])
-
- def test_data(self):
- self.assertAlmostEqual(self.scan1.dataline(3)[2],
- -3.14)
- self.assertAlmostEqual(self.scan1.datacol(1)[2],
- 3.14)
- # tests for data transposition between original file and .data attr
- self.assertAlmostEqual(self.scan1.data()[2, 0],
- 8)
- self.assertEqual(self.scan1.data().shape, (3, 4))
- self.assertAlmostEqual(numpy.sum(self.scan1.data()), 113.631)
-
- def test_date(self):
- self.assertEqual(self.scan1.date(),
- "Thu Feb 11 09:55:20 2016")
-
- def test_motors(self):
- self.assertEqual(len(self.sf.allmotors()), 6)
- self.assertEqual(len(self.scan1.allmotorpos()), 6)
- self.assertAlmostEqual(sum(self.scan1.allmotorpos()),
- 223.385912)
- self.assertEqual(self.sf.allmotors()[1], 'MRTSlit UP')
-
- def test_mca(self):
- self.assertEqual(self.scan1_2.mca(2)[2], 5)
- self.assertEqual(sum(self.scan1_2.mca(3)), 21.7)
-
- def test_mca_header(self):
- self.assertEqual(self.scan1_2.header("CALIB"),
- ["#@CALIB 1 2 3"])
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestSpecfilewrapper))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_spech5.py b/silx/io/test/test_spech5.py
deleted file mode 100644
index 0263c3c..0000000
--- a/silx/io/test/test_spech5.py
+++ /dev/null
@@ -1,881 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for spech5"""
-from numpy import array_equal
-import os
-import io
-import sys
-import tempfile
-import unittest
-import datetime
-from functools import partial
-
-from silx.utils import testutils
-
-from .. import spech5
-from ..spech5 import (SpecH5, SpecH5Dataset, spec_date_to_iso8601)
-from .. import specfile
-
-import h5py
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "12/02/2018"
-
-sftext = """#F /tmp/sf.dat
-#E 1455180875
-#D Thu Feb 11 09:54:35 2016
-#C imaging User = opid17
-#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
-#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
-#o0 pshg mrtu mrtd
-#o2 ss1vo ss1ho ss1vg
-
-#J0 Seconds IA ion.mono Current
-#J1 xbpmc2 idgap1 Inorm
-
-#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
-#D Thu Feb 11 09:55:20 2016
-#T 0.2 (Seconds)
-#P0 180.005 -0.66875 0.87125
-#P1 14.74255 16.197579 12.238283
-#N 4
-#L MRTSlit UP second column 3rd_col
--1.23 5.89 8
-8.478100E+01 5 1.56
-3.14 2.73 -3.14
-1.2 2.3 3.4
-
-#S 25 ascan c3th 1.33245 1.52245 40 0.15
-#D Sat 2015/03/14 03:53:50
-#P0 80.005 -1.66875 1.87125
-#P1 4.74255 6.197579 2.238283
-#N 5
-#L column0 column1 col2 col3
-0.0 0.1 0.2 0.3
-1.0 1.1 1.2 1.3
-2.0 2.1 2.2 2.3
-3.0 3.1 3.2 3.3
-
-#S 1 aaaaaa
-#D Thu Feb 11 10:00:32 2016
-#@MCADEV 1
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#@CTIME 123.4 234.5 345.6
-#N 3
-#L uno duo
-1 2
-@A 0 1 2
-@A 10 9 8
-@A 1 1 1.1
-3 4
-@A 3.1 4 5
-@A 7 6 5
-@A 1 1 1
-5 6
-@A 6 7.7 8
-@A 4 3 2
-@A 1 1 1
-
-#S 1000 bbbbb
-#G1 3.25 3.25 5.207 90 90 120 2.232368448 2.232368448 1.206680489 90 90 60 1 1 2 -1 2 2 26.132 7.41 -88.96 1.11 1.000012861 15.19 26.06 67.355 -88.96 1.11 1.000012861 15.11 0.723353 0.723353
-#G3 0.0106337923671 0.027529133 1.206191273 -1.43467075 0.7633438883 0.02401568018 -1.709143587 -2.097621783 0.02456954971
-#L a b
-1 2
-
-#S 1001 ccccc
-#G1 0. 0. 0. 0 0 0 2.232368448 2.232368448 1.206680489 90 90 60 1 1 2 -1 2 2 26.132 7.41 -88.96 1.11 1.000012861 15.19 26.06 67.355 -88.96 1.11 1.000012861 15.11 0.723353 0.723353
-#G3 0. 0. 0. 0. 0.0 0. 0. 0. 0.
-#L a b
-1 2
-
-"""
-
-
-class TestSpecDate(unittest.TestCase):
- """
- Test of the spec_date_to_iso8601 function.
- """
- # TODO : time zone tests
- # TODO : error cases
-
- @classmethod
- def setUpClass(cls):
- import locale
- # FYI : not threadsafe
- cls.locale_saved = locale.setlocale(locale.LC_TIME)
- locale.setlocale(locale.LC_TIME, 'C')
-
- @classmethod
- def tearDownClass(cls):
- import locale
- # FYI : not threadsafe
- locale.setlocale(locale.LC_TIME, cls.locale_saved)
-
- def setUp(self):
- # covering all week days
- self.n_days = range(1, 10)
- # covering all months
- self.n_months = range(1, 13)
-
- self.n_years = [1999, 2016, 2020]
- self.n_seconds = [0, 5, 26, 59]
- self.n_minutes = [0, 9, 42, 59]
- self.n_hours = [0, 2, 17, 23]
-
- self.formats = ['%a %b %d %H:%M:%S %Y', '%a %Y/%m/%d %H:%M:%S']
-
- self.check_date_formats = partial(self.__check_date_formats,
- year=self.n_years[0],
- month=self.n_months[0],
- day=self.n_days[0],
- hour=self.n_hours[0],
- minute=self.n_minutes[0],
- second=self.n_seconds[0],
- msg=None)
-
- def __check_date_formats(self,
- year,
- month,
- day,
- hour,
- minute,
- second,
- msg=None):
- dt = datetime.datetime(year, month, day, hour, minute, second)
- expected_date = dt.isoformat()
-
- for i_fmt, fmt in enumerate(self.formats):
- spec_date = dt.strftime(fmt)
- iso_date = spec_date_to_iso8601(spec_date)
- self.assertEqual(iso_date,
- expected_date,
- msg='Testing {0}. format={1}. '
- 'Expected "{2}", got "{3} ({4})" (dt={5}).'
- ''.format(msg,
- i_fmt,
- expected_date,
- iso_date,
- spec_date,
- dt))
-
- def testYearsNominal(self):
- for year in self.n_years:
- self.check_date_formats(year=year, msg='year')
-
- def testMonthsNominal(self):
- for month in self.n_months:
- self.check_date_formats(month=month, msg='month')
-
- def testDaysNominal(self):
- for day in self.n_days:
- self.check_date_formats(day=day, msg='day')
-
- def testHoursNominal(self):
- for hour in self.n_hours:
- self.check_date_formats(hour=hour, msg='hour')
-
- def testMinutesNominal(self):
- for minute in self.n_minutes:
- self.check_date_formats(minute=minute, msg='minute')
-
- def testSecondsNominal(self):
- for second in self.n_seconds:
- self.check_date_formats(second=second, msg='second')
-
-
-class TestSpecH5(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- fd, cls.fname = tempfile.mkstemp()
- if sys.version_info < (3, ):
- os.write(fd, sftext)
- else:
- os.write(fd, bytes(sftext, 'ascii'))
- os.close(fd)
-
- @classmethod
- def tearDownClass(cls):
- os.unlink(cls.fname)
-
- def setUp(self):
- self.sfh5 = SpecH5(self.fname)
-
- def tearDown(self):
- self.sfh5.close()
-
- def testContainsFile(self):
- self.assertIn("/1.2/measurement", self.sfh5)
- self.assertIn("/25.1", self.sfh5)
- self.assertIn("25.1", self.sfh5)
- self.assertNotIn("25.2", self.sfh5)
- # measurement is a child of a scan, full path would be required to
- # access from root level
- self.assertNotIn("measurement", self.sfh5)
- # Groups may or may not have a trailing /
- self.assertIn("/1.2/measurement/mca_1/", self.sfh5)
- self.assertIn("/1.2/measurement/mca_1", self.sfh5)
- # Datasets can't have a trailing /
- self.assertNotIn("/1.2/measurement/mca_0/info/calibration/ ", self.sfh5)
- # No mca_8
- self.assertNotIn("/1.2/measurement/mca_8/info/calibration", self.sfh5)
- # Link
- self.assertIn("/1.2/measurement/mca_0/info/calibration", self.sfh5)
-
- def testContainsGroup(self):
- self.assertIn("measurement", self.sfh5["/1.2/"])
- self.assertIn("measurement", self.sfh5["/1.2"])
- self.assertIn("25.1", self.sfh5["/"])
- self.assertNotIn("25.2", self.sfh5["/"])
- self.assertIn("instrument/positioners/Sslit1 HOff", self.sfh5["/1.1"])
- # illegal trailing "/" after dataset name
- self.assertNotIn("instrument/positioners/Sslit1 HOff/",
- self.sfh5["/1.1"])
- # full path to element in group (OK)
- self.assertIn("/1.1/instrument/positioners/Sslit1 HOff",
- self.sfh5["/1.1/instrument"])
-
- def testDataColumn(self):
- self.assertAlmostEqual(sum(self.sfh5["/1.2/measurement/duo"]),
- 12.0)
- self.assertAlmostEqual(
- sum(self.sfh5["1.1"]["measurement"]["MRTSlit UP"]),
- 87.891, places=4)
-
- def testDate(self):
- # start time is in Iso8601 format
- self.assertEqual(self.sfh5["/1.1/start_time"],
- u"2016-02-11T09:55:20")
- self.assertEqual(self.sfh5["25.1/start_time"],
- u"2015-03-14T03:53:50")
-
- def assertRaisesRegex(self, *args, **kwargs):
- # Python 2 compatibility
- if sys.version_info.major >= 3:
- return super(TestSpecH5, self).assertRaisesRegex(*args, **kwargs)
- else:
- return self.assertRaisesRegexp(*args, **kwargs)
-
- def testDatasetInstanceAttr(self):
- """The SpecH5Dataset objects must implement some dummy attributes
- to improve compatibility with widgets dealing with h5py datasets."""
- self.assertIsNone(self.sfh5["/1.1/start_time"].compression)
- self.assertIsNone(self.sfh5["1.1"]["measurement"]["MRTSlit UP"].chunks)
-
- # error message must be explicit
- with self.assertRaisesRegex(
- AttributeError,
- "SpecH5Dataset has no attribute tOTo"):
- dummy = self.sfh5["/1.1/start_time"].tOTo
-
- def testGet(self):
- """Test :meth:`SpecH5Group.get`"""
- # default value of param *default* is None
- self.assertIsNone(self.sfh5.get("toto"))
- self.assertEqual(self.sfh5["25.1"].get("toto", default=-3),
- -3)
-
- self.assertEqual(self.sfh5.get("/1.1/start_time", default=-3),
- u"2016-02-11T09:55:20")
-
- def testGetClass(self):
- """Test :meth:`SpecH5Group.get`"""
- self.assertIs(self.sfh5["1.1"].get("start_time", getclass=True),
- h5py.Dataset)
- self.assertIs(self.sfh5["1.1"].get("instrument", getclass=True),
- h5py.Group)
-
- # spech5 does not define external link, so there is no way
- # a group can *get* a SpecH5 class
-
- def testGetApi(self):
- result = self.sfh5.get("1.1", getclass=True, getlink=True)
- self.assertIs(result, h5py.HardLink)
- result = self.sfh5.get("1.1", getclass=False, getlink=True)
- self.assertIsInstance(result, h5py.HardLink)
- result = self.sfh5.get("1.1", getclass=True, getlink=False)
- self.assertIs(result, h5py.Group)
- result = self.sfh5.get("1.1", getclass=False, getlink=False)
- self.assertIsInstance(result, spech5.SpecH5Group)
-
- def testGetItemGroup(self):
- group = self.sfh5["25.1"]["instrument"]
- self.assertEqual(list(group["positioners"].keys()),
- ["Pslit HGap", "MRTSlit UP", "MRTSlit DOWN",
- "Sslit1 VOff", "Sslit1 HOff", "Sslit1 VGap"])
- with self.assertRaises(KeyError):
- group["Holy Grail"]
-
- def testGetitemSpecH5(self):
- self.assertEqual(self.sfh5["/1.2/instrument/positioners"],
- self.sfh5["1.2"]["instrument"]["positioners"])
-
- def testH5pyClass(self):
- """Test :attr:`h5py_class` returns the corresponding h5py class
- (h5py.File, h5py.Group, h5py.Dataset)"""
- a_file = self.sfh5
- self.assertIs(a_file.h5py_class,
- h5py.File)
-
- a_group = self.sfh5["/1.2/measurement"]
- self.assertIs(a_group.h5py_class,
- h5py.Group)
-
- a_dataset = self.sfh5["/1.1/instrument/positioners/Sslit1 HOff"]
- self.assertIs(a_dataset.h5py_class,
- h5py.Dataset)
-
- def testHeader(self):
- file_header = self.sfh5["/1.2/instrument/specfile/file_header"]
- scan_header = self.sfh5["/1.2/instrument/specfile/scan_header"]
-
- # File header has 10 lines
- self.assertEqual(len(file_header), 10)
- # 1.2 has 9 scan & mca header lines
- self.assertEqual(len(scan_header), 9)
-
- # line 4 of file header
- self.assertEqual(
- file_header[3],
- u"#C imaging User = opid17")
- # line 4 of scan header
- scan_header = self.sfh5["25.1/instrument/specfile/scan_header"]
-
- self.assertEqual(
- scan_header[3],
- u"#P1 4.74255 6.197579 2.238283")
-
- def testLinks(self):
- self.assertTrue(
- array_equal(self.sfh5["/1.2/measurement/mca_0/data"],
- self.sfh5["/1.2/instrument/mca_0/data"])
- )
- self.assertTrue(
- array_equal(self.sfh5["/1.2/measurement/mca_0/info/data"],
- self.sfh5["/1.2/instrument/mca_0/data"])
- )
- self.assertTrue(
- array_equal(self.sfh5["/1.2/measurement/mca_0/info/channels"],
- self.sfh5["/1.2/instrument/mca_0/channels"])
- )
- self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/"].keys(),
- self.sfh5["/1.2/instrument/mca_0/"].keys())
-
- self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/preset_time"],
- self.sfh5["/1.2/instrument/mca_0/preset_time"])
- self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/live_time"],
- self.sfh5["/1.2/instrument/mca_0/live_time"])
- self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/elapsed_time"],
- self.sfh5["/1.2/instrument/mca_0/elapsed_time"])
-
- def testListScanIndices(self):
- self.assertEqual(list(self.sfh5.keys()),
- ["1.1", "25.1", "1.2", "1000.1", "1001.1"])
- self.assertEqual(self.sfh5["1.2"].attrs,
- {"NX_class": "NXentry", })
-
- def testMcaAbsent(self):
- def access_absent_mca():
- """This must raise a KeyError, because scan 1.1 has no MCA"""
- return self.sfh5["/1.1/measurement/mca_0/"]
- self.assertRaises(KeyError, access_absent_mca)
-
- def testMcaCalib(self):
- mca0_calib = self.sfh5["/1.2/measurement/mca_0/info/calibration"]
- mca1_calib = self.sfh5["/1.2/measurement/mca_1/info/calibration"]
- self.assertEqual(mca0_calib.tolist(),
- [1, 2, 3])
- # calibration is unique in this scan and applies to all analysers
- self.assertEqual(mca0_calib.tolist(),
- mca1_calib.tolist())
-
- def testMcaChannels(self):
- mca0_chann = self.sfh5["/1.2/measurement/mca_0/info/channels"]
- mca1_chann = self.sfh5["/1.2/measurement/mca_1/info/channels"]
- self.assertEqual(mca0_chann.tolist(),
- [0, 1, 2])
- self.assertEqual(mca0_chann.tolist(),
- mca1_chann.tolist())
-
- def testMcaCtime(self):
- """Tests for #@CTIME mca header"""
- datasets = ["preset_time", "live_time", "elapsed_time"]
- for ds in datasets:
- self.assertNotIn("/1.1/instrument/mca_0/" + ds, self.sfh5)
- self.assertIn("/1.2/instrument/mca_0/" + ds, self.sfh5)
-
- mca0_preset_time = self.sfh5["/1.2/instrument/mca_0/preset_time"]
- mca1_preset_time = self.sfh5["/1.2/instrument/mca_1/preset_time"]
- self.assertLess(mca0_preset_time - 123.4,
- 10**-5)
- # ctime is unique in a this scan and applies to all analysers
- self.assertEqual(mca0_preset_time,
- mca1_preset_time)
-
- mca0_live_time = self.sfh5["/1.2/instrument/mca_0/live_time"]
- mca1_live_time = self.sfh5["/1.2/instrument/mca_1/live_time"]
- self.assertLess(mca0_live_time - 234.5,
- 10**-5)
- self.assertEqual(mca0_live_time,
- mca1_live_time)
-
- mca0_elapsed_time = self.sfh5["/1.2/instrument/mca_0/elapsed_time"]
- mca1_elapsed_time = self.sfh5["/1.2/instrument/mca_1/elapsed_time"]
- self.assertLess(mca0_elapsed_time - 345.6,
- 10**-5)
- self.assertEqual(mca0_elapsed_time,
- mca1_elapsed_time)
-
- def testMcaData(self):
- # sum 1st MCA in scan 1.2 over rows
- mca_0_data = self.sfh5["/1.2/measurement/mca_0/data"]
- for summed_row, expected in zip(mca_0_data.sum(axis=1).tolist(),
- [3.0, 12.1, 21.7]):
- self.assertAlmostEqual(summed_row, expected, places=4)
-
- # sum 3rd MCA in scan 1.2 along both axis
- mca_2_data = self.sfh5["1.2"]["measurement"]["mca_2"]["data"]
- self.assertAlmostEqual(sum(sum(mca_2_data)), 9.1, places=5)
- # attrs
- self.assertEqual(mca_0_data.attrs, {"interpretation": "spectrum"})
-
- def testMotorPosition(self):
- positioners_group = self.sfh5["/1.1/instrument/positioners"]
- # MRTSlit DOWN position is defined in #P0 san header line
- self.assertAlmostEqual(float(positioners_group["MRTSlit DOWN"]),
- 0.87125)
- # MRTSlit UP position is defined in first data column
- for a, b in zip(positioners_group["MRTSlit UP"].tolist(),
- [-1.23, 8.478100E+01, 3.14, 1.2]):
- self.assertAlmostEqual(float(a), b, places=4)
-
- def testNumberMcaAnalysers(self):
- """Scan 1.2 has 2 data columns + 3 mca spectra per data line."""
- self.assertEqual(len(self.sfh5["1.2"]["measurement"]), 5)
-
- def testTitle(self):
- self.assertEqual(self.sfh5["/25.1/title"],
- u"ascan c3th 1.33245 1.52245 40 0.15")
-
- def testValues(self):
- group = self.sfh5["/25.1"]
- self.assertTrue(hasattr(group, "values"))
- self.assertTrue(callable(group.values))
- self.assertIn(self.sfh5["/25.1/title"],
- self.sfh5["/25.1"].values())
-
- # visit and visititems ignore links
- def testVisit(self):
- name_list = []
- self.sfh5.visit(name_list.append)
- self.assertIn('1.2/instrument/positioners/Pslit HGap', name_list)
- self.assertIn("1.2/instrument/specfile/scan_header", name_list)
- self.assertEqual(len(name_list), 117)
-
- # test also visit of a subgroup, with various group name formats
- name_list_leading_and_trailing_slash = []
- self.sfh5['/1.2/instrument/'].visit(name_list_leading_and_trailing_slash.append)
- name_list_leading_slash = []
- self.sfh5['/1.2/instrument'].visit(name_list_leading_slash.append)
- name_list_trailing_slash = []
- self.sfh5['1.2/instrument/'].visit(name_list_trailing_slash.append)
- name_list_no_slash = []
- self.sfh5['1.2/instrument'].visit(name_list_no_slash.append)
-
- # no differences expected in the output names
- self.assertEqual(name_list_leading_and_trailing_slash,
- name_list_leading_slash)
- self.assertEqual(name_list_leading_slash,
- name_list_trailing_slash)
- self.assertEqual(name_list_leading_slash,
- name_list_no_slash)
- self.assertIn("positioners/Pslit HGap", name_list_no_slash)
- self.assertIn("positioners", name_list_no_slash)
-
- def testVisitItems(self):
- dataset_name_list = []
-
- def func_generator(l):
- """return a function appending names to list l"""
- def func(name, obj):
- if isinstance(obj, SpecH5Dataset):
- l.append(name)
- return func
-
- self.sfh5.visititems(func_generator(dataset_name_list))
- self.assertIn('1.2/instrument/positioners/Pslit HGap', dataset_name_list)
- self.assertEqual(len(dataset_name_list), 85)
-
- # test also visit of a subgroup, with various group name formats
- name_list_leading_and_trailing_slash = []
- self.sfh5['/1.2/instrument/'].visititems(func_generator(name_list_leading_and_trailing_slash))
- name_list_leading_slash = []
- self.sfh5['/1.2/instrument'].visititems(func_generator(name_list_leading_slash))
- name_list_trailing_slash = []
- self.sfh5['1.2/instrument/'].visititems(func_generator(name_list_trailing_slash))
- name_list_no_slash = []
- self.sfh5['1.2/instrument'].visititems(func_generator(name_list_no_slash))
-
- # no differences expected in the output names
- self.assertEqual(name_list_leading_and_trailing_slash,
- name_list_leading_slash)
- self.assertEqual(name_list_leading_slash,
- name_list_trailing_slash)
- self.assertEqual(name_list_leading_slash,
- name_list_no_slash)
- self.assertIn("positioners/Pslit HGap", name_list_no_slash)
-
- def testNotSpecH5(self):
- fd, fname = tempfile.mkstemp()
- os.write(fd, b"Not a spec file!")
- os.close(fd)
- self.assertRaises(specfile.SfErrFileOpen, SpecH5, fname)
- self.assertRaises(IOError, SpecH5, fname)
- os.unlink(fname)
-
- def testSample(self):
- self.assertNotIn("sample", self.sfh5["/1.1"])
- self.assertIn("sample", self.sfh5["/1000.1"])
- self.assertIn("ub_matrix", self.sfh5["/1000.1/sample"])
- self.assertIn("unit_cell", self.sfh5["/1000.1/sample"])
- self.assertIn("unit_cell_abc", self.sfh5["/1000.1/sample"])
- self.assertIn("unit_cell_alphabetagamma", self.sfh5["/1000.1/sample"])
-
- # All 0 values
- self.assertNotIn("sample", self.sfh5["/1001.1"])
- with self.assertRaises(KeyError):
- self.sfh5["/1001.1/sample/unit_cell"]
-
- @testutils.test_logging(spech5.logger1.name, warning=2)
- def testOpenFileDescriptor(self):
- """Open a SpecH5 file from a file descriptor"""
- with io.open(self.sfh5.filename) as f:
- sfh5 = SpecH5(f)
- self.assertIsNotNone(sfh5)
- name_list = []
- # check if the object is working
- self.sfh5.visit(name_list.append)
- sfh5.close()
-
-
-sftext_multi_mca_headers = """
-#S 1 aaaaaa
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#@CTIME 123.4 234.5 345.6
-#@MCA %16C
-#@CHANN 3 1 3 1
-#@CALIB 5.5 6.6 7.7
-#@CTIME 10 11 12
-#N 3
-#L uno duo
-1 2
-@A 0 1 2
-@A 10 9 8
-3 4
-@A 3.1 4 5
-@A 7 6 5
-5 6
-@A 6 7.7 8
-@A 4 3 2
-
-"""
-
-
-class TestSpecH5MultiMca(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- fd, cls.fname = tempfile.mkstemp(text=False)
- if sys.version_info < (3, ):
- os.write(fd, sftext_multi_mca_headers)
- else:
- os.write(fd, bytes(sftext_multi_mca_headers, 'ascii'))
- os.close(fd)
-
- @classmethod
- def tearDownClass(cls):
- os.unlink(cls.fname)
-
- def setUp(self):
- self.sfh5 = SpecH5(self.fname)
-
- def tearDown(self):
- self.sfh5.close()
-
- def testMcaCalib(self):
- mca0_calib = self.sfh5["/1.1/measurement/mca_0/info/calibration"]
- mca1_calib = self.sfh5["/1.1/measurement/mca_1/info/calibration"]
- self.assertEqual(mca0_calib.tolist(),
- [1, 2, 3])
- self.assertAlmostEqual(sum(mca1_calib.tolist()),
- sum([5.5, 6.6, 7.7]),
- places=5)
-
- def testMcaChannels(self):
- mca0_chann = self.sfh5["/1.1/measurement/mca_0/info/channels"]
- mca1_chann = self.sfh5["/1.1/measurement/mca_1/info/channels"]
- self.assertEqual(mca0_chann.tolist(),
- [0., 1., 2.])
- # @CHANN is unique in this scan and applies to all analysers
- self.assertEqual(mca1_chann.tolist(),
- [1., 2., 3.])
-
- def testMcaCtime(self):
- """Tests for #@CTIME mca header"""
- mca0_preset_time = self.sfh5["/1.1/instrument/mca_0/preset_time"]
- mca1_preset_time = self.sfh5["/1.1/instrument/mca_1/preset_time"]
- self.assertLess(mca0_preset_time - 123.4,
- 10**-5)
- self.assertLess(mca1_preset_time - 10,
- 10**-5)
-
- mca0_live_time = self.sfh5["/1.1/instrument/mca_0/live_time"]
- mca1_live_time = self.sfh5["/1.1/instrument/mca_1/live_time"]
- self.assertLess(mca0_live_time - 234.5,
- 10**-5)
- self.assertLess(mca1_live_time - 11,
- 10**-5)
-
- mca0_elapsed_time = self.sfh5["/1.1/instrument/mca_0/elapsed_time"]
- mca1_elapsed_time = self.sfh5["/1.1/instrument/mca_1/elapsed_time"]
- self.assertLess(mca0_elapsed_time - 345.6,
- 10**-5)
- self.assertLess(mca1_elapsed_time - 12,
- 10**-5)
-
-
-sftext_no_cols = r"""#F C:/DATA\test.mca
-#D Thu Jul 7 08:40:19 2016
-
-#S 1 31oct98.dat 22.1 If4
-#D Thu Jul 7 08:40:19 2016
-#C no data cols, one mca analyser, single spectrum
-#@MCA %16C
-#@CHANN 151 0 150 1
-#@CALIB 0 2 0
-@A 789 784 788 814 847 862 880 904 925 955 987 1015 1031 1070 1111 1139 \
-1203 1236 1290 1392 1492 1558 1688 1813 1977 2119 2346 2699 3121 3542 4102 4970 \
-6071 7611 10426 16188 28266 40348 50539 55555 56162 54162 47102 35718 24588 17034 12994 11444 \
-11808 13461 15687 18885 23827 31578 41999 49556 58084 59415 59456 55698 44525 28219 17680 12881 \
-9518 7415 6155 5246 4646 3978 3612 3299 3020 2761 2670 2472 2500 2310 2286 2106 \
-1989 1890 1782 1655 1421 1293 1135 990 879 757 672 618 532 488 445 424 \
-414 373 351 325 307 284 270 247 228 213 199 187 183 176 164 156 \
-153 140 142 130 118 118 103 101 97 86 90 86 87 81 75 82 \
-80 76 77 75 76 77 62 69 74 60 65 68 65 58 63 64 \
-63 59 60 56 57 60 55
-
-#S 2 31oct98.dat 22.1 If4
-#D Thu Jul 7 08:40:19 2016
-#C no data cols, one mca analyser, multiple spectra
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#@CTIME 123.4 234.5 345.6
-@A 0 1 2
-@A 10 9 8
-@A 1 1 1.1
-@A 3.1 4 5
-@A 7 6 5
-@A 1 1 1
-@A 6 7.7 8
-@A 4 3 2
-@A 1 1 1
-
-#S 3 31oct98.dat 22.1 If4
-#D Thu Jul 7 08:40:19 2016
-#C no data cols, 3 mca analysers, multiple spectra
-#@MCADEV 1
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#@CTIME 123.4 234.5 345.6
-#@MCADEV 2
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#@CTIME 123.4 234.5 345.6
-#@MCADEV 3
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#@CTIME 123.4 234.5 345.6
-@A 0 1 2
-@A 10 9 8
-@A 1 1 1.1
-@A 3.1 4 5
-@A 7 6 5
-@A 1 1 1
-@A 6 7.7 8
-@A 4 3 2
-@A 1 1 1
-"""
-
-
-class TestSpecH5NoDataCols(unittest.TestCase):
- """Test reading SPEC files with only MCA data"""
- @classmethod
- def setUpClass(cls):
- fd, cls.fname = tempfile.mkstemp()
- if sys.version_info < (3, ):
- os.write(fd, sftext_no_cols)
- else:
- os.write(fd, bytes(sftext_no_cols, 'ascii'))
- os.close(fd)
-
- @classmethod
- def tearDownClass(cls):
- os.unlink(cls.fname)
-
- def setUp(self):
- self.sfh5 = SpecH5(self.fname)
-
- def tearDown(self):
- self.sfh5.close()
-
- def testScan1(self):
- # 1.1: single analyser, single spectrum, 151 channels
- self.assertIn("mca_0",
- self.sfh5["1.1/instrument/"])
- self.assertEqual(self.sfh5["1.1/instrument/mca_0/data"].shape,
- (1, 151))
- self.assertNotIn("mca_1",
- self.sfh5["1.1/instrument/"])
-
- def testScan2(self):
- # 2.1: single analyser, 9 spectra, 3 channels
- self.assertIn("mca_0",
- self.sfh5["2.1/instrument/"])
- self.assertEqual(self.sfh5["2.1/instrument/mca_0/data"].shape,
- (9, 3))
- self.assertNotIn("mca_1",
- self.sfh5["2.1/instrument/"])
-
- def testScan3(self):
- # 3.1: 3 analysers, 3 spectra/analyser, 3 channels
- for i in range(3):
- self.assertIn("mca_%d" % i,
- self.sfh5["3.1/instrument/"])
- self.assertEqual(
- self.sfh5["3.1/instrument/mca_%d/data" % i].shape,
- (3, 3))
-
- self.assertNotIn("mca_3",
- self.sfh5["3.1/instrument/"])
-
-
-sf_text_slash = r"""#F /data/id09/archive/logspecfiles/laue/2016/scan_231_laue_16-11-29.dat
-#D Sat Dec 10 22:20:59 2016
-#O0 Pslit/HGap MRTSlit%UP
-
-#S 1 laue_16-11-29.log 231.1 PD3/A
-#D Sat Dec 10 22:20:59 2016
-#P0 180.005 -0.66875
-#N 2
-#L GONY/mm PD3%A
--2.015 5.250424e-05
--2.01 5.30798e-05
--2.005 5.281903e-05
--2 5.220436e-05
-"""
-
-
-class TestSpecH5SlashInLabels(unittest.TestCase):
- """Test reading SPEC files with labels containing a / character
-
- The / character must be substituted with a %
- """
- @classmethod
- def setUpClass(cls):
- fd, cls.fname = tempfile.mkstemp()
- if sys.version_info < (3, ):
- os.write(fd, sf_text_slash)
- else:
- os.write(fd, bytes(sf_text_slash, 'ascii'))
- os.close(fd)
-
- @classmethod
- def tearDownClass(cls):
- os.unlink(cls.fname)
-
- def setUp(self):
- self.sfh5 = SpecH5(self.fname)
-
- def tearDown(self):
- self.sfh5.close()
-
- def testLabels(self):
- """Ensure `/` is substituted with `%` and
- ensure legitimate `%` in names are still working"""
- self.assertEqual(list(self.sfh5["1.1/measurement/"].keys()),
- ["GONY%mm", "PD3%A"])
-
- # substituted "%"
- self.assertIn("GONY%mm",
- self.sfh5["1.1/measurement/"])
- self.assertNotIn("GONY/mm",
- self.sfh5["1.1/measurement/"])
- self.assertAlmostEqual(self.sfh5["1.1/measurement/GONY%mm"][0],
- -2.015, places=4)
- # legitimate "%"
- self.assertIn("PD3%A",
- self.sfh5["1.1/measurement/"])
-
- def testMotors(self):
- """Ensure `/` is substituted with `%` and
- ensure legitimate `%` in names are still working"""
- self.assertEqual(list(self.sfh5["1.1/instrument/positioners"].keys()),
- ["Pslit%HGap", "MRTSlit%UP"])
- # substituted "%"
- self.assertIn("Pslit%HGap",
- self.sfh5["1.1/instrument/positioners"])
- self.assertNotIn("Pslit/HGap",
- self.sfh5["1.1/instrument/positioners"])
- self.assertAlmostEqual(
- self.sfh5["1.1/instrument/positioners/Pslit%HGap"],
- 180.005, places=4)
- # legitimate "%"
- self.assertIn("MRTSlit%UP",
- self.sfh5["1.1/instrument/positioners"])
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestSpecH5))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestSpecDate))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestSpecH5MultiMca))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestSpecH5NoDataCols))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestSpecH5SlashInLabels))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_spectoh5.py b/silx/io/test/test_spectoh5.py
deleted file mode 100644
index 903a62c..0000000
--- a/silx/io/test/test_spectoh5.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for SpecFile to HDF5 converter"""
-
-from numpy import array_equal
-import os
-import sys
-import tempfile
-import unittest
-
-import h5py
-
-from ..spech5 import SpecH5, SpecH5Group
-from ..convert import convert, write_to_h5
-from ..utils import h5py_read_dataset
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "12/02/2018"
-
-
-sfdata = b"""#F /tmp/sf.dat
-#E 1455180875
-#D Thu Feb 11 09:54:35 2016
-#C imaging User = opid17
-#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
-#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
-#o0 pshg mrtu mrtd
-#o2 ss1vo ss1ho ss1vg
-
-#J0 Seconds IA ion.mono Current
-#J1 xbpmc2 idgap1 Inorm
-
-#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
-#D Thu Feb 11 09:55:20 2016
-#T 0.2 (Seconds)
-#P0 180.005 -0.66875 0.87125
-#P1 14.74255 16.197579 12.238283
-#N 4
-#L MRTSlit UP second column 3rd_col
--1.23 5.89 8
-8.478100E+01 5 1.56
-3.14 2.73 -3.14
-1.2 2.3 3.4
-
-#S 1 aaaaaa
-#D Thu Feb 11 10:00:32 2016
-#@MCADEV 1
-#@MCA %16C
-#@CHANN 3 0 2 1
-#@CALIB 1 2 3
-#N 3
-#L uno duo
-1 2
-@A 0 1 2
-@A 10 9 8
-3 4
-@A 3.1 4 5
-@A 7 6 5
-5 6
-@A 6 7.7 8
-@A 4 3 2
-"""
-
-
-class TestConvertSpecHDF5(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- fd, cls.spec_fname = tempfile.mkstemp(prefix="TestConvertSpecHDF5")
- os.write(fd, sfdata)
- os.close(fd)
-
- fd, cls.h5_fname = tempfile.mkstemp(prefix="TestConvertSpecHDF5")
- # Close and delete (we just need the name)
- os.close(fd)
- os.unlink(cls.h5_fname)
-
- @classmethod
- def tearDownClass(cls):
- os.unlink(cls.spec_fname)
-
- def setUp(self):
- convert(self.spec_fname, self.h5_fname)
-
- self.sfh5 = SpecH5(self.spec_fname)
- self.h5f = h5py.File(self.h5_fname, "a")
-
- def tearDown(self):
- self.h5f.close()
- self.sfh5.close()
- os.unlink(self.h5_fname)
-
- def testAppendToHDF5(self):
- write_to_h5(self.sfh5, self.h5f, h5path="/foo/bar/spam")
- self.assertTrue(
- array_equal(self.h5f["/1.2/measurement/mca_1/data"],
- self.h5f["/foo/bar/spam/1.2/measurement/mca_1/data"])
- )
-
- def testWriteSpecH5Group(self):
- """Test passing a SpecH5Group as parameter, instead of a Spec filename
- or a SpecH5."""
- g = self.sfh5["1.1/instrument"]
- self.assertIsInstance(g, SpecH5Group) # let's be paranoid
- write_to_h5(g, self.h5f, h5path="my instruments")
-
- self.assertAlmostEqual(self.h5f["my instruments/positioners/Sslit1 HOff"][tuple()],
- 16.197579, places=4)
-
- def testTitle(self):
- """Test the value of a dataset"""
- title12 = h5py_read_dataset(self.h5f["/1.2/title"])
- self.assertEqual(title12,
- u"aaaaaa")
-
- def testAttrs(self):
- # Test root group (file) attributes
- self.assertEqual(self.h5f.attrs["NX_class"],
- u"NXroot")
- # Test dataset attributes
- ds = self.h5f["/1.2/instrument/mca_1/data"]
- self.assertTrue("interpretation" in ds.attrs)
- self.assertEqual(list(ds.attrs.values()),
- [u"spectrum"])
- # Test group attributes
- grp = self.h5f["1.1"]
- self.assertEqual(grp.attrs["NX_class"],
- u"NXentry")
- self.assertEqual(len(list(grp.attrs.keys())),
- 1)
-
- def testHdf5HasSameMembers(self):
- spec_member_list = []
-
- def append_spec_members(name):
- spec_member_list.append(name)
- self.sfh5.visit(append_spec_members)
-
- hdf5_member_list = []
-
- def append_hdf5_members(name):
- hdf5_member_list.append(name)
- self.h5f.visit(append_hdf5_members)
-
- # 1. For some reason, h5py visit method doesn't include the leading
- # "/" character when it passes the member name to the function,
- # even though an explicit the .name attribute of a member will
- # have a leading "/"
- spec_member_list = [m.lstrip("/") for m in spec_member_list]
-
- self.assertEqual(set(hdf5_member_list),
- set(spec_member_list))
-
- def testLinks(self):
- self.assertTrue(
- array_equal(self.sfh5["/1.2/measurement/mca_0/data"],
- self.h5f["/1.2/measurement/mca_0/data"])
- )
- self.assertTrue(
- array_equal(self.h5f["/1.2/instrument/mca_1/channels"],
- self.h5f["/1.2/measurement/mca_1/info/channels"])
- )
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestConvertSpecHDF5))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/test/test_url.py b/silx/io/test/test_url.py
deleted file mode 100644
index 114f6a7..0000000
--- a/silx/io/test/test_url.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for url module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "29/01/2018"
-
-
-import unittest
-from ..url import DataUrl
-
-
-class TestDataUrl(unittest.TestCase):
-
- def assertUrl(self, url, expected):
- self.assertEqual(url.is_valid(), expected[0])
- self.assertEqual(url.is_absolute(), expected[1])
- self.assertEqual(url.scheme(), expected[2])
- self.assertEqual(url.file_path(), expected[3])
- self.assertEqual(url.data_path(), expected[4])
- self.assertEqual(url.data_slice(), expected[5])
-
- def test_fabio_absolute(self):
- url = DataUrl("fabio:///data/image.edf?slice=2")
- expected = [True, True, "fabio", "/data/image.edf", None, (2, )]
- self.assertUrl(url, expected)
-
- def test_fabio_absolute_windows(self):
- url = DataUrl("fabio:///C:/data/image.edf?slice=2")
- expected = [True, True, "fabio", "C:/data/image.edf", None, (2, )]
- self.assertUrl(url, expected)
-
- def test_silx_absolute(self):
- url = DataUrl("silx:///data/image.h5?path=/data/dataset&slice=1,5")
- expected = [True, True, "silx", "/data/image.h5", "/data/dataset", (1, 5)]
- self.assertUrl(url, expected)
-
- def test_commandline_shell_separator(self):
- url = DataUrl("silx:///data/image.h5::path=/data/dataset&slice=1,5")
- expected = [True, True, "silx", "/data/image.h5", "/data/dataset", (1, 5)]
- self.assertUrl(url, expected)
-
- def test_silx_absolute2(self):
- url = DataUrl("silx:///data/image.edf?/scan_0/detector/data")
- expected = [True, True, "silx", "/data/image.edf", "/scan_0/detector/data", None]
- self.assertUrl(url, expected)
-
- def test_silx_absolute_windows(self):
- url = DataUrl("silx:///C:/data/image.h5?/scan_0/detector/data")
- expected = [True, True, "silx", "C:/data/image.h5", "/scan_0/detector/data", None]
- self.assertUrl(url, expected)
-
- def test_silx_relative(self):
- url = DataUrl("silx:./image.h5")
- expected = [True, False, "silx", "./image.h5", None, None]
- self.assertUrl(url, expected)
-
- def test_fabio_relative(self):
- url = DataUrl("fabio:./image.edf")
- expected = [True, False, "fabio", "./image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_silx_relative2(self):
- url = DataUrl("silx:image.h5")
- expected = [True, False, "silx", "image.h5", None, None]
- self.assertUrl(url, expected)
-
- def test_fabio_relative2(self):
- url = DataUrl("fabio:image.edf")
- expected = [True, False, "fabio", "image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_relative(self):
- url = DataUrl("image.edf")
- expected = [True, False, None, "image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_relative2(self):
- url = DataUrl("./foo/bar/image.edf")
- expected = [True, False, None, "./foo/bar/image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_relative3(self):
- url = DataUrl("foo/bar/image.edf")
- expected = [True, False, None, "foo/bar/image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_absolute(self):
- url = DataUrl("/data/image.edf")
- expected = [True, True, None, "/data/image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_file_absolute_windows(self):
- url = DataUrl("C:/data/image.edf")
- expected = [True, True, None, "C:/data/image.edf", None, None]
- self.assertUrl(url, expected)
-
- def test_absolute_with_path(self):
- url = DataUrl("/foo/foobar.h5?/foo/bar")
- expected = [True, True, None, "/foo/foobar.h5", "/foo/bar", None]
- self.assertUrl(url, expected)
-
- def test_windows_file_data_slice(self):
- url = DataUrl("C:/foo/foobar.h5?path=/foo/bar&slice=5,1")
- expected = [True, True, None, "C:/foo/foobar.h5", "/foo/bar", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_scheme_file_data_slice(self):
- url = DataUrl("silx:/foo/foobar.h5?path=/foo/bar&slice=5,1")
- expected = [True, True, "silx", "/foo/foobar.h5", "/foo/bar", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_scheme_windows_file_data_slice(self):
- url = DataUrl("silx:C:/foo/foobar.h5?path=/foo/bar&slice=5,1")
- expected = [True, True, "silx", "C:/foo/foobar.h5", "/foo/bar", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_empty(self):
- url = DataUrl("")
- expected = [False, False, None, "", None, None]
- self.assertUrl(url, expected)
-
- def test_unknown_scheme(self):
- url = DataUrl("foo:/foo/foobar.h5?path=/foo/bar&slice=5,1")
- expected = [False, True, "foo", "/foo/foobar.h5", "/foo/bar", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_slice(self):
- url = DataUrl("/a.h5?path=/b&slice=5,1")
- expected = [True, True, None, "/a.h5", "/b", (5, 1)]
- self.assertUrl(url, expected)
-
- def test_slice2(self):
- url = DataUrl("/a.h5?path=/b&slice=2:5")
- expected = [True, True, None, "/a.h5", "/b", (slice(2, 5),)]
- self.assertUrl(url, expected)
-
- def test_slice3(self):
- url = DataUrl("/a.h5?path=/b&slice=::2")
- expected = [True, True, None, "/a.h5", "/b", (slice(None, None, 2),)]
- self.assertUrl(url, expected)
-
- def test_slice_ellipsis(self):
- url = DataUrl("/a.h5?path=/b&slice=...")
- expected = [True, True, None, "/a.h5", "/b", (Ellipsis, )]
- self.assertUrl(url, expected)
-
- def test_slice_slicing(self):
- url = DataUrl("/a.h5?path=/b&slice=:")
- expected = [True, True, None, "/a.h5", "/b", (slice(None), )]
- self.assertUrl(url, expected)
-
- def test_slice_missing_element(self):
- url = DataUrl("/a.h5?path=/b&slice=5,,1")
- expected = [False, True, None, "/a.h5", "/b", None]
- self.assertUrl(url, expected)
-
- def test_slice_no_elements(self):
- url = DataUrl("/a.h5?path=/b&slice=")
- expected = [False, True, None, "/a.h5", "/b", None]
- self.assertUrl(url, expected)
-
- def test_create_relative_url(self):
- url = DataUrl(scheme="silx", file_path="./foo.h5", data_path="/", data_slice=(5, 1))
- self.assertFalse(url.is_absolute())
- url2 = DataUrl(url.path())
- self.assertEqual(url, url2)
-
- def test_create_absolute_url(self):
- url = DataUrl(scheme="silx", file_path="/foo.h5", data_path="/", data_slice=(5, 1))
- url2 = DataUrl(url.path())
- self.assertEqual(url, url2)
-
- def test_create_absolute_windows_url(self):
- url = DataUrl(scheme="silx", file_path="C:/foo.h5", data_path="/", data_slice=(5, 1))
- url2 = DataUrl(url.path())
- self.assertEqual(url, url2)
-
- def test_create_slice_url(self):
- url = DataUrl(scheme="silx", file_path="/foo.h5", data_path="/", data_slice=(5, 1, Ellipsis, slice(None)))
- url2 = DataUrl(url.path())
- self.assertEqual(url, url2)
-
- def test_wrong_url(self):
- url = DataUrl(scheme="silx", file_path="/foo.h5", data_slice=(5, 1))
- self.assertFalse(url.is_valid())
-
- def test_path_creation(self):
- """make sure the construction of path succeed and that we can
- recreate a DataUrl from a path"""
- for data_slice in (1, (1,)):
- with self.subTest(data_slice=data_slice):
- url = DataUrl(scheme="silx", file_path="/foo.h5", data_slice=data_slice)
- path = url.path()
- DataUrl(path=path)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestDataUrl))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/io/test/test_utils.py b/silx/io/test/test_utils.py
deleted file mode 100644
index 13ab532..0000000
--- a/silx/io/test/test_utils.py
+++ /dev/null
@@ -1,888 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for utils module"""
-
-import io
-import numpy
-import os
-import re
-import shutil
-import tempfile
-import unittest
-import sys
-
-from .. import utils
-from ..._version import calc_hexversion
-import silx.io.url
-
-import h5py
-from ..utils import h5ls
-
-import fabio
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "03/12/2020"
-
-expected_spec1 = r"""#F .*
-#D .*
-
-#S 1 Ordinate1
-#D .*
-#N 2
-#L Abscissa Ordinate1
-1 4\.00
-2 5\.00
-3 6\.00
-"""
-
-expected_spec2 = expected_spec1 + r"""
-#S 2 Ordinate2
-#D .*
-#N 2
-#L Abscissa Ordinate2
-1 7\.00
-2 8\.00
-3 9\.00
-"""
-
-expected_spec2reg = r"""#F .*
-#D .*
-
-#S 1 Ordinate1
-#D .*
-#N 3
-#L Abscissa Ordinate1 Ordinate2
-1 4\.00 7\.00
-2 5\.00 8\.00
-3 6\.00 9\.00
-"""
-
-expected_spec2irr = expected_spec1 + r"""
-#S 2 Ordinate2
-#D .*
-#N 2
-#L Abscissa Ordinate2
-1 7\.00
-2 8\.00
-"""
-
-expected_csv = r"""Abscissa;Ordinate1;Ordinate2
-1;4\.00;7\.00e\+00
-2;5\.00;8\.00e\+00
-3;6\.00;9\.00e\+00
-"""
-
-expected_csv2 = r"""x;y0;y1
-1;4\.00;7\.00e\+00
-2;5\.00;8\.00e\+00
-3;6\.00;9\.00e\+00
-"""
-
-
-class TestSave(unittest.TestCase):
- """Test saving curves as SpecFile:
- """
-
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.spec_fname = os.path.join(self.tempdir, "savespec.dat")
- self.csv_fname = os.path.join(self.tempdir, "savecsv.csv")
- self.npy_fname = os.path.join(self.tempdir, "savenpy.npy")
-
- self.x = [1, 2, 3]
- self.xlab = "Abscissa"
- self.y = [[4, 5, 6], [7, 8, 9]]
- self.y_irr = [[4, 5, 6], [7, 8]]
- self.ylabs = ["Ordinate1", "Ordinate2"]
-
- def tearDown(self):
- if os.path.isfile(self.spec_fname):
- os.unlink(self.spec_fname)
- if os.path.isfile(self.csv_fname):
- os.unlink(self.csv_fname)
- if os.path.isfile(self.npy_fname):
- os.unlink(self.npy_fname)
- shutil.rmtree(self.tempdir)
-
- def test_save_csv(self):
- utils.save1D(self.csv_fname, self.x, self.y,
- xlabel=self.xlab, ylabels=self.ylabs,
- filetype="csv", fmt=["%d", "%.2f", "%.2e"],
- csvdelim=";", autoheader=True)
-
- csvf = open(self.csv_fname)
- actual_csv = csvf.read()
- csvf.close()
-
- self.assertRegex(actual_csv, expected_csv)
-
- def test_save_npy(self):
- """npy file is saved with numpy.save after building a numpy array
- and converting it to a named record array"""
- npyf = open(self.npy_fname, "wb")
- utils.save1D(npyf, self.x, self.y,
- xlabel=self.xlab, ylabels=self.ylabs)
- npyf.close()
-
- npy_recarray = numpy.load(self.npy_fname)
-
- self.assertEqual(npy_recarray.shape, (3,))
- self.assertTrue(numpy.array_equal(npy_recarray['Ordinate1'],
- numpy.array((4, 5, 6))))
-
- def test_savespec_filename(self):
- """Save SpecFile using savespec()"""
- utils.savespec(self.spec_fname, self.x, self.y[0], xlabel=self.xlab,
- ylabel=self.ylabs[0], fmt=["%d", "%.2f"],
- close_file=True, scan_number=1)
-
- specf = open(self.spec_fname)
- actual_spec = specf.read()
- specf.close()
- self.assertRegex(actual_spec, expected_spec1)
-
- def test_savespec_file_handle(self):
- """Save SpecFile using savespec(), passing a file handle"""
- # first savespec: open, write file header, save y[0] as scan 1,
- # return file handle
- specf = utils.savespec(self.spec_fname, self.x, self.y[0],
- xlabel=self.xlab, ylabel=self.ylabs[0],
- fmt=["%d", "%.2f"], close_file=False)
-
- # second savespec: save y[1] as scan 2, close file
- utils.savespec(specf, self.x, self.y[1], xlabel=self.xlab,
- ylabel=self.ylabs[1], fmt=["%d", "%.2f"],
- write_file_header=False, close_file=True,
- scan_number=2)
-
- specf = open(self.spec_fname)
- actual_spec = specf.read()
- specf.close()
- self.assertRegex(actual_spec, expected_spec2)
-
- def test_save_spec_reg(self):
- """Save SpecFile using save() on a regular pattern"""
- utils.save1D(self.spec_fname, self.x, self.y, xlabel=self.xlab,
- ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"])
-
- specf = open(self.spec_fname)
- actual_spec = specf.read()
- specf.close()
-
- self.assertRegex(actual_spec, expected_spec2reg)
-
- def test_save_spec_irr(self):
- """Save SpecFile using save() on an irregular pattern"""
- # invalid test case ?!
- return
- utils.save1D(self.spec_fname, self.x, self.y_irr, xlabel=self.xlab,
- ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"])
-
- specf = open(self.spec_fname)
- actual_spec = specf.read()
- specf.close()
- self.assertRegex(actual_spec, expected_spec2irr)
-
- def test_save_csv_no_labels(self):
- """Save csv using save(), with autoheader=True but
- xlabel=None and ylabels=None
- This is a non-regression test for bug #223"""
- self.tempdir = tempfile.mkdtemp()
- self.spec_fname = os.path.join(self.tempdir, "savespec.dat")
- self.csv_fname = os.path.join(self.tempdir, "savecsv.csv")
- self.npy_fname = os.path.join(self.tempdir, "savenpy.npy")
-
- self.x = [1, 2, 3]
- self.xlab = "Abscissa"
- self.y = [[4, 5, 6], [7, 8, 9]]
- self.ylabs = ["Ordinate1", "Ordinate2"]
- utils.save1D(self.csv_fname, self.x, self.y,
- autoheader=True, fmt=["%d", "%.2f", "%.2e"])
-
- csvf = open(self.csv_fname)
- actual_csv = csvf.read()
- csvf.close()
- self.assertRegex(actual_csv, expected_csv2)
-
-
-def assert_match_any_string_in_list(test, pattern, list_of_strings):
- for string_ in list_of_strings:
- if re.match(pattern, string_):
- return True
- return False
-
-
-class TestH5Ls(unittest.TestCase):
- """Test displaying the following HDF5 file structure:
-
- +foo
- +bar
- <HDF5 dataset "spam": shape (2, 2), type "<i8">
- <HDF5 dataset "tmp": shape (3,), type "<i8">
- <HDF5 dataset "data": shape (1,), type "<f8">
-
- """
-
- def assertMatchAnyStringInList(self, pattern, list_of_strings):
- for string_ in list_of_strings:
- if re.match(pattern, string_):
- return None
- raise AssertionError("regex pattern %s does not match any" % pattern +
- " string in list " + str(list_of_strings))
-
- def testHdf5(self):
- fd, self.h5_fname = tempfile.mkstemp(text=False)
- # Close and delete (we just want the name)
- os.close(fd)
- os.unlink(self.h5_fname)
- self.h5f = h5py.File(self.h5_fname, "w")
- self.h5f["/foo/bar/tmp"] = [1, 2, 3]
- self.h5f["/foo/bar/spam"] = [[1, 2], [3, 4]]
- self.h5f["/foo/data"] = [3.14]
- self.h5f.close()
-
- rep = h5ls(self.h5_fname)
- lines = rep.split("\n")
-
- self.assertIn("+foo", lines)
- self.assertIn("\t+bar", lines)
-
- match = r'\t\t<HDF5 dataset "tmp": shape \(3,\), type "<i[48]">'
- self.assertMatchAnyStringInList(match, lines)
- match = r'\t\t<HDF5 dataset "spam": shape \(2, 2\), type "<i[48]">'
- self.assertMatchAnyStringInList(match, lines)
- match = r'\t<HDF5 dataset "data": shape \(1,\), type "<f[48]">'
- self.assertMatchAnyStringInList(match, lines)
-
- os.unlink(self.h5_fname)
-
- # Following test case disabled d/t errors on AppVeyor:
- # os.unlink(spec_fname)
- # PermissionError: [WinError 32] The process cannot access the file because
- # it is being used by another process: 'C:\\...\\savespec.dat'
-
- # def testSpec(self):
- # tempdir = tempfile.mkdtemp()
- # spec_fname = os.path.join(tempdir, "savespec.dat")
- #
- # x = [1, 2, 3]
- # xlab = "Abscissa"
- # y = [[4, 5, 6], [7, 8, 9]]
- # ylabs = ["Ordinate1", "Ordinate2"]
- # utils.save1D(spec_fname, x, y, xlabel=xlab,
- # ylabels=ylabs, filetype="spec",
- # fmt=["%d", "%.2f"])
- #
- # rep = h5ls(spec_fname)
- # lines = rep.split("\n")
- # self.assertIn("+1.1", lines)
- # self.assertIn("\t+instrument", lines)
- #
- # self.assertMatchAnyStringInList(
- # r'\t\t\t<SPEC dataset "file_header": shape \(\), type "|S60">',
- # lines)
- # self.assertMatchAnyStringInList(
- # r'\t\t<SPEC dataset "Ordinate1": shape \(3L?,\), type "<f4">',
- # lines)
- #
- # os.unlink(spec_fname)
- # shutil.rmtree(tempdir)
-
-
-class TestOpen(unittest.TestCase):
- """Test `silx.io.utils.open` function."""
-
- @classmethod
- def setUpClass(cls):
- cls.tmp_directory = tempfile.mkdtemp()
- cls.createResources(cls.tmp_directory)
-
- @classmethod
- def createResources(cls, directory):
-
- cls.h5_filename = os.path.join(directory, "test.h5")
- h5 = h5py.File(cls.h5_filename, mode="w")
- h5["group/group/dataset"] = 50
- h5.close()
-
- cls.spec_filename = os.path.join(directory, "test.dat")
- utils.savespec(cls.spec_filename, [1], [1.1], xlabel="x", ylabel="y",
- fmt=["%d", "%.2f"], close_file=True, scan_number=1)
-
- cls.edf_filename = os.path.join(directory, "test.edf")
- header = fabio.fabioimage.OrderedDict()
- header["integer"] = "10"
- data = numpy.array([[10, 50], [50, 10]])
- fabiofile = fabio.edfimage.EdfImage(data, header)
- fabiofile.write(cls.edf_filename)
-
- cls.txt_filename = os.path.join(directory, "test.txt")
- f = io.open(cls.txt_filename, "w+t")
- f.write(u"Kikoo")
- f.close()
-
- cls.missing_filename = os.path.join(directory, "test.missing")
-
- @classmethod
- def tearDownClass(cls):
- shutil.rmtree(cls.tmp_directory)
-
- def testH5(self):
- f = utils.open(self.h5_filename)
- self.assertIsNotNone(f)
- self.assertIsInstance(f, h5py.File)
- f.close()
-
- def testH5With(self):
- with utils.open(self.h5_filename) as f:
- self.assertIsNotNone(f)
- self.assertIsInstance(f, h5py.File)
-
- def testH5_withPath(self):
- f = utils.open(self.h5_filename + "::/group/group/dataset")
- self.assertIsNotNone(f)
- self.assertEqual(f.h5py_class, h5py.Dataset)
- self.assertEqual(f[()], 50)
- f.close()
-
- def testH5With_withPath(self):
- with utils.open(self.h5_filename + "::/group/group") as f:
- self.assertIsNotNone(f)
- self.assertEqual(f.h5py_class, h5py.Group)
- self.assertIn("dataset", f)
-
- def testSpec(self):
- f = utils.open(self.spec_filename)
- self.assertIsNotNone(f)
- self.assertEqual(f.h5py_class, h5py.File)
- f.close()
-
- def testSpecWith(self):
- with utils.open(self.spec_filename) as f:
- self.assertIsNotNone(f)
- self.assertEqual(f.h5py_class, h5py.File)
-
- def testEdf(self):
- f = utils.open(self.edf_filename)
- self.assertIsNotNone(f)
- self.assertEqual(f.h5py_class, h5py.File)
- f.close()
-
- def testEdfWith(self):
- with utils.open(self.edf_filename) as f:
- self.assertIsNotNone(f)
- self.assertEqual(f.h5py_class, h5py.File)
-
- def testUnsupported(self):
- self.assertRaises(IOError, utils.open, self.txt_filename)
-
- def testNotExists(self):
- # load it
- self.assertRaises(IOError, utils.open, self.missing_filename)
-
- def test_silx_scheme(self):
- url = silx.io.url.DataUrl(scheme="silx", file_path=self.h5_filename, data_path="/")
- with utils.open(url.path()) as f:
- self.assertIsNotNone(f)
- self.assertTrue(silx.io.utils.is_file(f))
-
- def test_fabio_scheme(self):
- url = silx.io.url.DataUrl(scheme="fabio", file_path=self.edf_filename)
- self.assertRaises(IOError, utils.open, url.path())
-
- def test_bad_url(self):
- url = silx.io.url.DataUrl(scheme="sil", file_path=self.h5_filename)
- self.assertRaises(IOError, utils.open, url.path())
-
- def test_sliced_url(self):
- url = silx.io.url.DataUrl(file_path=self.h5_filename, data_slice=(5,))
- self.assertRaises(IOError, utils.open, url.path())
-
-
-class TestNodes(unittest.TestCase):
- """Test `silx.io.utils.is_` functions."""
-
- def test_real_h5py_objects(self):
- name = tempfile.mktemp(suffix=".h5")
- try:
- with h5py.File(name, "w") as h5file:
- h5group = h5file.create_group("arrays")
- h5dataset = h5group.create_dataset("scalar", data=10)
-
- self.assertTrue(utils.is_file(h5file))
- self.assertTrue(utils.is_group(h5file))
- self.assertFalse(utils.is_dataset(h5file))
-
- self.assertFalse(utils.is_file(h5group))
- self.assertTrue(utils.is_group(h5group))
- self.assertFalse(utils.is_dataset(h5group))
-
- self.assertFalse(utils.is_file(h5dataset))
- self.assertFalse(utils.is_group(h5dataset))
- self.assertTrue(utils.is_dataset(h5dataset))
- finally:
- os.unlink(name)
-
- def test_h5py_like_file(self):
-
- class Foo(object):
-
- def __init__(self):
- self.h5_class = utils.H5Type.FILE
-
- obj = Foo()
- self.assertTrue(utils.is_file(obj))
- self.assertTrue(utils.is_group(obj))
- self.assertFalse(utils.is_dataset(obj))
-
- def test_h5py_like_group(self):
-
- class Foo(object):
-
- def __init__(self):
- self.h5_class = utils.H5Type.GROUP
-
- obj = Foo()
- self.assertFalse(utils.is_file(obj))
- self.assertTrue(utils.is_group(obj))
- self.assertFalse(utils.is_dataset(obj))
-
- def test_h5py_like_dataset(self):
-
- class Foo(object):
-
- def __init__(self):
- self.h5_class = utils.H5Type.DATASET
-
- obj = Foo()
- self.assertFalse(utils.is_file(obj))
- self.assertFalse(utils.is_group(obj))
- self.assertTrue(utils.is_dataset(obj))
-
- def test_bad(self):
-
- class Foo(object):
-
- def __init__(self):
- pass
-
- obj = Foo()
- self.assertFalse(utils.is_file(obj))
- self.assertFalse(utils.is_group(obj))
- self.assertFalse(utils.is_dataset(obj))
-
- def test_bad_api(self):
-
- class Foo(object):
-
- def __init__(self):
- self.h5_class = int
-
- obj = Foo()
- self.assertFalse(utils.is_file(obj))
- self.assertFalse(utils.is_group(obj))
- self.assertFalse(utils.is_dataset(obj))
-
-
-class TestGetData(unittest.TestCase):
- """Test `silx.io.utils.get_data` function."""
-
- @classmethod
- def setUpClass(cls):
- cls.tmp_directory = tempfile.mkdtemp()
- cls.createResources(cls.tmp_directory)
-
- @classmethod
- def createResources(cls, directory):
-
- cls.h5_filename = os.path.join(directory, "test.h5")
- h5 = h5py.File(cls.h5_filename, mode="w")
- h5["group/group/scalar"] = 50
- h5["group/group/array"] = [1, 2, 3, 4, 5]
- h5["group/group/array2d"] = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
- h5.close()
-
- cls.spec_filename = os.path.join(directory, "test.dat")
- utils.savespec(cls.spec_filename, [1], [1.1], xlabel="x", ylabel="y",
- fmt=["%d", "%.2f"], close_file=True, scan_number=1)
-
- cls.edf_filename = os.path.join(directory, "test.edf")
- cls.edf_multiframe_filename = os.path.join(directory, "test_multi.edf")
- header = fabio.fabioimage.OrderedDict()
- header["integer"] = "10"
- data = numpy.array([[10, 50], [50, 10]])
- fabiofile = fabio.edfimage.EdfImage(data, header)
- fabiofile.write(cls.edf_filename)
- fabiofile.append_frame(data=data, header=header)
- fabiofile.write(cls.edf_multiframe_filename)
-
- cls.txt_filename = os.path.join(directory, "test.txt")
- f = io.open(cls.txt_filename, "w+t")
- f.write(u"Kikoo")
- f.close()
-
- cls.missing_filename = os.path.join(directory, "test.missing")
-
- @classmethod
- def tearDownClass(cls):
- shutil.rmtree(cls.tmp_directory)
-
- def test_hdf5_scalar(self):
- url = "silx:%s?/group/group/scalar" % self.h5_filename
- data = utils.get_data(url=url)
- self.assertEqual(data, 50)
-
- def test_hdf5_array(self):
- url = "silx:%s?/group/group/array" % self.h5_filename
- data = utils.get_data(url=url)
- self.assertEqual(data.shape, (5,))
- self.assertEqual(data[0], 1)
-
- def test_hdf5_array_slice(self):
- url = "silx:%s?path=/group/group/array2d&slice=1" % self.h5_filename
- data = utils.get_data(url=url)
- self.assertEqual(data.shape, (5,))
- self.assertEqual(data[0], 6)
-
- def test_hdf5_array_slice_out_of_range(self):
- url = "silx:%s?path=/group/group/array2d&slice=5" % self.h5_filename
- # ValueError: h5py 2.x
- # IndexError: h5py 3.x
- self.assertRaises((ValueError, IndexError), utils.get_data, url)
-
- def test_edf_using_silx(self):
- url = "silx:%s?/scan_0/instrument/detector_0/data" % self.edf_filename
- data = utils.get_data(url=url)
- self.assertEqual(data.shape, (2, 2))
- self.assertEqual(data[0, 0], 10)
-
- def test_fabio_frame(self):
- url = "fabio:%s?slice=1" % self.edf_multiframe_filename
- data = utils.get_data(url=url)
- self.assertEqual(data.shape, (2, 2))
- self.assertEqual(data[0, 0], 10)
-
- def test_fabio_singleframe(self):
- url = "fabio:%s?slice=0" % self.edf_filename
- data = utils.get_data(url=url)
- self.assertEqual(data.shape, (2, 2))
- self.assertEqual(data[0, 0], 10)
-
- def test_fabio_too_much_frames(self):
- url = "fabio:%s?slice=..." % self.edf_multiframe_filename
- self.assertRaises(ValueError, utils.get_data, url)
-
- def test_fabio_no_frame(self):
- url = "fabio:%s" % self.edf_filename
- data = utils.get_data(url=url)
- self.assertEqual(data.shape, (2, 2))
- self.assertEqual(data[0, 0], 10)
-
- def test_unsupported_scheme(self):
- url = "foo:/foo/bar"
- self.assertRaises(ValueError, utils.get_data, url)
-
- def test_no_scheme(self):
- url = "%s?path=/group/group/array2d&slice=5" % self.h5_filename
- self.assertRaises((ValueError, IOError), utils.get_data, url)
-
- def test_file_not_exists(self):
- url = "silx:/foo/bar"
- self.assertRaises(IOError, utils.get_data, url)
-
-
-def _h5_py_version_older_than(version):
- v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]]
- r_majeur, r_mineur, r_micro = [int(i) for i in version.split('.')]
- return calc_hexversion(v_majeur, v_mineur, v_micro) >= calc_hexversion(r_majeur, r_mineur, r_micro)
-
-
-@unittest.skipUnless(_h5_py_version_older_than('2.9.0'), 'h5py version < 2.9.0')
-class TestRawFileToH5(unittest.TestCase):
- """Test conversion of .vol file to .h5 external dataset"""
-
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self._vol_file = os.path.join(self.tempdir, 'test_vol.vol')
- self._file_info = os.path.join(self.tempdir, 'test_vol.info.vol')
- self._dataset_shape = 100, 20, 5
- data = numpy.random.random(self._dataset_shape[0] *
- self._dataset_shape[1] *
- self._dataset_shape[2]).astype(dtype=numpy.float32).reshape(self._dataset_shape)
- numpy.save(file=self._vol_file, arr=data)
- # those are storing into .noz file
- assert os.path.exists(self._vol_file + '.npy')
- os.rename(self._vol_file + '.npy', self._vol_file)
- self.h5_file = os.path.join(self.tempdir, 'test_h5.h5')
- self.external_dataset_path = '/root/my_external_dataset'
- self._data_url = silx.io.url.DataUrl(file_path=self.h5_file,
- data_path=self.external_dataset_path)
- with open(self._file_info, 'w') as _fi:
- _fi.write('NUM_X = %s\n' % self._dataset_shape[2])
- _fi.write('NUM_Y = %s\n' % self._dataset_shape[1])
- _fi.write('NUM_Z = %s\n' % self._dataset_shape[0])
-
- def tearDown(self):
- shutil.rmtree(self.tempdir)
-
- def check_dataset(self, h5_file, data_path, shape):
- """Make sure the external dataset is valid"""
- with h5py.File(h5_file, 'r') as _file:
- return data_path in _file and _file[data_path].shape == shape
-
- def test_h5_file_not_existing(self):
- """Test that can create a file with external dataset from scratch"""
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- dtype=numpy.float32)
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
- os.remove(self.h5_file)
- utils.vol_to_h5_external_dataset(vol_file=self._vol_file,
- output_url=self._data_url,
- info_file=self._file_info)
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
-
- def test_h5_file_existing(self):
- """Test that can add the external dataset from an existing file"""
- with h5py.File(self.h5_file, 'w') as _file:
- _file['/root/dataset1'] = numpy.zeros((100, 100))
- _file['/root/group/dataset2'] = numpy.ones((100, 100))
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- dtype=numpy.float32)
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
-
- def test_vol_file_not_existing(self):
- """Make sure error is raised if .vol file does not exists"""
- os.remove(self._vol_file)
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- dtype=numpy.float32)
-
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
-
- def test_conflicts(self):
- """Test several conflict cases"""
- # test if path already exists
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- dtype=numpy.float32)
- with self.assertRaises(ValueError):
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- overwrite=False,
- dtype=numpy.float32)
-
- utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
- output_url=self._data_url,
- shape=(100, 20, 5),
- overwrite=True,
- dtype=numpy.float32)
-
- self.assertTrue(self.check_dataset(h5_file=self.h5_file,
- data_path=self.external_dataset_path,
- shape=self._dataset_shape))
-
-
-class TestH5Strings(unittest.TestCase):
- """Test HDF5 str and bytes writing and reading"""
-
- @classmethod
- def setUpClass(cls):
- cls.tempdir = tempfile.mkdtemp()
- cls.vlenstr = h5py.special_dtype(vlen=str)
- cls.vlenbytes = h5py.special_dtype(vlen=bytes)
- try:
- cls.unicode = unicode
- except NameError:
- cls.unicode = str
-
- @classmethod
- def tearDownClass(cls):
- shutil.rmtree(cls.tempdir)
-
- def setUp(self):
- self.file = h5py.File(os.path.join(self.tempdir, 'file.h5'), mode="w")
-
- def tearDown(self):
- self.file.close()
-
- @classmethod
- def _make_array(cls, value, n):
- if isinstance(value, bytes):
- dtype = cls.vlenbytes
- elif isinstance(value, cls.unicode):
- dtype = cls.vlenstr
- else:
- return numpy.array([value] * n)
- return numpy.array([value] * n, dtype=dtype)
-
- @classmethod
- def _get_charset(cls, value):
- if isinstance(value, bytes):
- return h5py.h5t.CSET_ASCII
- elif isinstance(value, cls.unicode):
- return h5py.h5t.CSET_UTF8
- else:
- return None
-
- def _check_dataset(self, value, result=None):
- # Write+read scalar
- if result:
- decode_ascii = True
- else:
- decode_ascii = False
- result = value
- charset = self._get_charset(value)
- self.file["data"] = value
- data = utils.h5py_read_dataset(self.file["data"], decode_ascii=decode_ascii)
- assert type(data) == type(result), data
- assert data == result, data
- if charset:
- assert self.file["data"].id.get_type().get_cset() == charset
-
- # Write+read variable length
- self.file["vlen_data"] = self._make_array(value, 2)
- data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii, index=0)
- assert type(data) == type(result), data
- assert data == result, data
- data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii)
- numpy.testing.assert_array_equal(data, [result] * 2)
- if charset:
- assert self.file["vlen_data"].id.get_type().get_cset() == charset
-
- def _check_attribute(self, value, result=None):
- if result:
- decode_ascii = True
- else:
- decode_ascii = False
- result = value
- self.file.attrs["data"] = value
- data = utils.h5py_read_attribute(self.file.attrs, "data", decode_ascii=decode_ascii)
- assert type(data) == type(result), data
- assert data == result, data
-
- self.file.attrs["vlen_data"] = self._make_array(value, 2)
- data = utils.h5py_read_attribute(self.file.attrs, "vlen_data", decode_ascii=decode_ascii)
- assert type(data[0]) == type(result), data[0]
- assert data[0] == result, data[0]
- numpy.testing.assert_array_equal(data, [result] * 2)
-
- data = utils.h5py_read_attributes(self.file.attrs, decode_ascii=decode_ascii)["vlen_data"]
- assert type(data[0]) == type(result), data[0]
- assert data[0] == result, data[0]
- numpy.testing.assert_array_equal(data, [result] * 2)
-
- def test_dataset_ascii_bytes(self):
- self._check_dataset(b"abc")
-
- def test_attribute_ascii_bytes(self):
- self._check_attribute(b"abc")
-
- def test_dataset_ascii_bytes_decode(self):
- self._check_dataset(b"abc", result="abc")
-
- def test_attribute_ascii_bytes_decode(self):
- self._check_attribute(b"abc", result="abc")
-
- def test_dataset_ascii_str(self):
- self._check_dataset("abc")
-
- def test_attribute_ascii_str(self):
- self._check_attribute("abc")
-
- def test_dataset_utf8_str(self):
- self._check_dataset("\u0101bc")
-
- def test_attribute_utf8_str(self):
- self._check_attribute("\u0101bc")
-
- def test_dataset_utf8_bytes(self):
- # 0xC481 is the byte representation of U+0101
- self._check_dataset(b"\xc4\x81bc")
-
- def test_attribute_utf8_bytes(self):
- # 0xC481 is the byte representation of U+0101
- self._check_attribute(b"\xc4\x81bc")
-
- def test_dataset_utf8_bytes_decode(self):
- # 0xC481 is the byte representation of U+0101
- self._check_dataset(b"\xc4\x81bc", result="\u0101bc")
-
- def test_attribute_utf8_bytes_decode(self):
- # 0xC481 is the byte representation of U+0101
- self._check_attribute(b"\xc4\x81bc", result="\u0101bc")
-
- def test_dataset_latin1_bytes(self):
- # extended ascii character 0xE4
- self._check_dataset(b"\xe423")
-
- def test_attribute_latin1_bytes(self):
- # extended ascii character 0xE4
- self._check_attribute(b"\xe423")
-
- def test_dataset_latin1_bytes_decode(self):
- # U+DCE4: surrogate for extended ascii character 0xE4
- self._check_dataset(b"\xe423", result="\udce423")
-
- def test_attribute_latin1_bytes_decode(self):
- # U+DCE4: surrogate for extended ascii character 0xE4
- self._check_attribute(b"\xe423", result="\udce423")
-
- def test_dataset_no_string(self):
- self._check_dataset(numpy.int64(10))
-
- def test_attribute_no_string(self):
- self._check_attribute(numpy.int64(10))
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestSave))
- test_suite.addTest(loadTests(TestH5Ls))
- test_suite.addTest(loadTests(TestOpen))
- test_suite.addTest(loadTests(TestNodes))
- test_suite.addTest(loadTests(TestGetData))
- test_suite.addTest(loadTests(TestRawFileToH5))
- test_suite.addTest(loadTests(TestH5Strings))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/io/url.py b/silx/io/url.py
deleted file mode 100644
index 66b75f0..0000000
--- a/silx/io/url.py
+++ /dev/null
@@ -1,390 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""URL module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "29/01/2018"
-
-import logging
-import six
-from collections.abc import Iterable
-
-parse = six.moves.urllib.parse
-
-
-_logger = logging.getLogger(__name__)
-
-
-class DataUrl(object):
- """Non-mutable object to parse a string representing a resource data
- locator.
-
- It supports:
-
- - path to file and path inside file to the data
- - data slicing
- - fabio or silx access to the data
- - absolute and relative file access
-
- >>> # fabio access using absolute path
- >>> DataUrl("fabio:///data/image.edf?slice=2")
- >>> DataUrl("fabio:///C:/data/image.edf?slice=2")
-
- >>> # silx access using absolute path
- >>> DataUrl("silx:///data/image.h5?path=/data/dataset&slice=1,5")
- >>> DataUrl("silx:///data/image.edf?path=/scan_0/detector/data")
- >>> DataUrl("silx:///C:/data/image.edf?path=/scan_0/detector/data")
-
- >>> # `path=` can be omited if there is no other query keys
- >>> DataUrl("silx:///data/image.h5?/data/dataset")
- >>> # is the same as
- >>> DataUrl("silx:///data/image.h5?path=/data/dataset")
-
- >>> # `::` can be used instead of `?` which can be useful with shell in
- >>> # command lines
- >>> DataUrl("silx:///data/image.h5::/data/dataset")
- >>> # is the same as
- >>> DataUrl("silx:///data/image.h5?/data/dataset")
-
- >>> # Relative path access
- >>> DataUrl("silx:./image.h5")
- >>> DataUrl("fabio:./image.edf")
- >>> DataUrl("silx:image.h5")
- >>> DataUrl("fabio:image.edf")
-
- >>> # Is also support parsing of file access for convenience
- >>> DataUrl("./foo/bar/image.edf")
- >>> DataUrl("C:/data/")
-
- :param str path: Path representing a link to a data. If specified, other
- arguments are not used.
- :param str file_path: Link to the file containing the the data.
- None if there is no data selection.
- :param str data_path: Data selection applyed to the data file selected.
- None if there is no data selection.
- :param Tuple[int,slice,Ellipse] data_slice: Slicing applyed of the selected
- data. None if no slicing applyed.
- :param Union[str,None] scheme: Scheme of the URL. "silx", "fabio"
- is supported. Other strings can be provided, but :meth:`is_valid` will
- be false.
- """
- def __init__(self, path=None, file_path=None, data_path=None, data_slice=None, scheme=None):
- self.__is_valid = False
- if path is not None:
- assert(file_path is None)
- assert(data_path is None)
- assert(data_slice is None)
- assert(scheme is None)
- self.__parse_from_path(path)
- else:
- self.__file_path = file_path
- self.__data_path = data_path
- self.__data_slice = data_slice
- self.__scheme = scheme
- self.__path = None
- self.__check_validity()
-
- def __eq__(self, other):
- if not isinstance(other, DataUrl):
- return False
- if self.is_valid() != other.is_valid():
- return False
- if self.is_valid():
- if self.__scheme != other.scheme():
- return False
- if self.__file_path != other.file_path():
- return False
- if self.__data_path != other.data_path():
- return False
- if self.__data_slice != other.data_slice():
- return False
- return True
- else:
- return self.__path == other.path()
-
- def __ne__(self, other):
- return not (self == other)
-
- def __repr__(self):
- return str(self)
-
- def __str__(self):
- if self.is_valid() or self.__path is None:
- def quote_string(string):
- if isinstance(string, six.string_types):
- return "'%s'" % string
- else:
- return string
-
- template = "DataUrl(valid=%s, scheme=%s, file_path=%s, data_path=%s, data_slice=%s)"
- return template % (self.__is_valid,
- quote_string(self.__scheme),
- quote_string(self.__file_path),
- quote_string(self.__data_path),
- self.__data_slice)
- else:
- template = "DataUrl(valid=%s, string=%s)"
- return template % (self.__is_valid, self.__path)
-
- def __check_validity(self):
- """Check the validity of the attributes."""
- if self.__file_path in [None, ""]:
- self.__is_valid = False
- return
-
- if self.__scheme is None:
- self.__is_valid = True
- elif self.__scheme == "fabio":
- self.__is_valid = self.__data_path is None
- elif self.__scheme == "silx":
- # If there is a slice you must have a data path
- # But you can have a data path without slice
- slice_implies_data = (self.__data_path is None and self.__data_slice is None) or self.__data_path is not None
- self.__is_valid = slice_implies_data
- else:
- self.__is_valid = False
-
- @staticmethod
- def _parse_slice(slice_string):
- """Parse a slicing sequence and return an associated tuple.
-
- It supports a sequence of `...`, `:`, and integers separated by a coma.
-
- :rtype: tuple
- """
- def str_to_slice(string):
- if string == "...":
- return Ellipsis
- elif ':' in string:
- if string == ":":
- return slice(None)
- else:
- def get_value(my_str):
- if my_str in ('', None):
- return None
- else:
- return int(my_str)
- sss = string.split(':')
- start = get_value(sss[0])
- stop = get_value(sss[1] if len(sss) > 1 else None)
- step = get_value(sss[2] if len(sss) > 2 else None)
- return slice(start, stop, step)
- else:
- return int(string)
-
- if slice_string == "":
- raise ValueError("An empty slice is not valid")
-
- tokens = slice_string.split(",")
- data_slice = []
- for t in tokens:
- try:
- data_slice.append(str_to_slice(t))
- except ValueError:
- raise ValueError("'%s' is not a valid slicing" % t)
- return tuple(data_slice)
-
- def __parse_from_path(self, path):
- """Parse the path and initialize attributes.
-
- :param str path: Path representing the URL.
- """
- self.__path = path
- # only replace if ? not here already. Otherwise can mess sith
- # data_slice if == ::2 for example
- if '?' not in path:
- path = path.replace("::", "?", 1)
- url = parse.urlparse(path)
-
- is_valid = True
-
- if len(url.scheme) <= 2:
- # Windows driver
- scheme = None
- pos = self.__path.index(url.path)
- file_path = self.__path[0:pos] + url.path
- else:
- scheme = url.scheme if url.scheme != "" else None
- file_path = url.path
-
- # Check absolute windows path
- if len(file_path) > 2 and file_path[0] == '/':
- if file_path[1] == ":" or file_path[2] == ":":
- file_path = file_path[1:]
-
- self.__scheme = scheme
- self.__file_path = file_path
-
- query = parse.parse_qsl(url.query, keep_blank_values=True)
- if len(query) == 1 and query[0][1] == "":
- # there is no query keys
- data_path = query[0][0]
- data_slice = None
- else:
- merged_query = {}
- for name, value in query:
- if name in query:
- merged_query[name].append(value)
- else:
- merged_query[name] = [value]
-
- def pop_single_value(merged_query, name):
- if name in merged_query:
- values = merged_query.pop(name)
- if len(values) > 1:
- _logger.warning("More than one query key named '%s'. The last one is used.", name)
- value = values[-1]
- else:
- value = None
- return value
-
- data_path = pop_single_value(merged_query, "path")
- data_slice = pop_single_value(merged_query, "slice")
- if data_slice is not None:
- try:
- data_slice = self._parse_slice(data_slice)
- except ValueError:
- is_valid = False
- data_slice = None
-
- for key in merged_query.keys():
- _logger.warning("Query key %s unsupported. Key skipped.", key)
-
- self.__data_path = data_path
- self.__data_slice = data_slice
-
- if is_valid:
- self.__check_validity()
- else:
- self.__is_valid = False
-
- def is_valid(self):
- """Returns true if the URL is valid. Else attributes can be None.
-
- :rtype: bool
- """
- return self.__is_valid
-
- def path(self):
- """Returns the string representing the URL.
-
- :rtype: str
- """
- if self.__path is not None:
- return self.__path
-
- def slice_to_string(data_slice):
- if data_slice == Ellipsis:
- return "..."
- elif data_slice == slice(None):
- return ":"
- elif isinstance(data_slice, int):
- return str(data_slice)
- else:
- raise TypeError("Unexpected slicing type. Found %s" % type(data_slice))
-
- if self.__data_path is not None and self.__data_slice is None:
- query = self.__data_path
- else:
- queries = []
- if self.__data_path is not None:
- queries.append("path=" + self.__data_path)
- if self.__data_slice is not None:
- if isinstance(self.__data_slice, Iterable):
- data_slice = ",".join([slice_to_string(s) for s in self.__data_slice])
- else:
- data_slice = slice_to_string(self.__data_slice)
- queries.append("slice=" + data_slice)
- query = "&".join(queries)
-
- path = ""
- if self.__file_path is not None:
- path += self.__file_path
-
- if query != "":
- path = path + "?" + query
-
- if self.__scheme is not None:
- if self.is_absolute():
- if path.startswith("/"):
- path = self.__scheme + "://" + path
- else:
- path = self.__scheme + ":///" + path
- else:
- path = self.__scheme + ":" + path
-
- return path
-
- def is_absolute(self):
- """Returns true if the file path is an absolute path.
-
- :rtype: bool
- """
- file_path = self.file_path()
- if file_path is None:
- return False
- if len(file_path) > 0:
- if file_path[0] == "/":
- return True
- if len(file_path) > 2:
- # Windows
- if file_path[1] == ":" or file_path[2] == ":":
- return True
- elif len(file_path) > 1:
- # Windows
- if file_path[1] == ":":
- return True
- return False
-
- def file_path(self):
- """Returns the path to the file containing the data.
-
- :rtype: str
- """
- return self.__file_path
-
- def data_path(self):
- """Returns the path inside the file to the data.
-
- :rtype: str
- """
- return self.__data_path
-
- def data_slice(self):
- """Returns the slicing applied to the data.
-
- It is a tuple containing numbers, slice or ellipses.
-
- :rtype: Tuple[int, slice, Ellipse]
- """
- return self.__data_slice
-
- def scheme(self):
- """Returns the scheme. It can be None if no scheme is specified.
-
- :rtype: Union[str, None]
- """
- return self.__scheme
diff --git a/silx/io/utils.py b/silx/io/utils.py
deleted file mode 100644
index 12e9a7e..0000000
--- a/silx/io/utils.py
+++ /dev/null
@@ -1,1142 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-""" I/O utility functions"""
-
-__authors__ = ["P. Knobel", "V. Valls"]
-__license__ = "MIT"
-__date__ = "03/12/2020"
-
-import enum
-import os.path
-import sys
-import time
-import logging
-import collections
-
-import numpy
-import six
-
-from silx.utils.proxy import Proxy
-import silx.io.url
-from .._version import calc_hexversion
-
-import h5py
-import h5py.h5t
-import h5py.h5a
-
-try:
- import h5pyd
-except ImportError as e:
- h5pyd = None
-
-logger = logging.getLogger(__name__)
-
-NEXUS_HDF5_EXT = [".h5", ".nx5", ".nxs", ".hdf", ".hdf5", ".cxi"]
-"""List of possible extensions for HDF5 file formats."""
-
-
-class H5Type(enum.Enum):
- """Identify a set of HDF5 concepts"""
- DATASET = 1
- GROUP = 2
- FILE = 3
- SOFT_LINK = 4
- EXTERNAL_LINK = 5
- HARD_LINK = 6
-
-
-_CLASSES_TYPE = None
-"""Store mapping between classes and types"""
-
-string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
-
-builtin_open = open
-
-
-def supported_extensions(flat_formats=True):
- """Returns the list file extensions supported by `silx.open`.
-
- The result filter out formats when the expected module is not available.
-
- :param bool flat_formats: If true, also include flat formats like npy or
- edf (while the expected module is available)
- :returns: A dictionary indexed by file description and containing a set of
- extensions (an extension is a string like "\\*.ext").
- :rtype: Dict[str, Set[str]]
- """
- formats = collections.OrderedDict()
- formats["HDF5 files"] = set(["*.h5", "*.hdf", "*.hdf5"])
- formats["NeXus files"] = set(["*.nx", "*.nxs", "*.h5", "*.hdf", "*.hdf5"])
- formats["NeXus layout from spec files"] = set(["*.dat", "*.spec", "*.mca"])
- if flat_formats:
- try:
- from silx.io import fabioh5
- except ImportError:
- fabioh5 = None
- if fabioh5 is not None:
- formats["NeXus layout from fabio files"] = set(fabioh5.supported_extensions())
-
- extensions = ["*.npz"]
- if flat_formats:
- extensions.append("*.npy")
-
- formats["Numpy binary files"] = set(extensions)
- formats["Coherent X-Ray Imaging files"] = set(["*.cxi"])
- return formats
-
-
-def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
- fmt="%.7g", csvdelim=";", newline="\n", header="",
- footer="", comments="#", autoheader=False):
- """Saves any number of curves to various formats: `Specfile`, `CSV`,
- `txt` or `npy`. All curves must have the same number of points and share
- the same ``x`` values.
-
- :param fname: Output file path, or file handle open in write mode.
- If ``fname`` is a path, file is opened in ``w`` mode. Existing file
- with a same name will be overwritten.
- :param x: 1D-Array (or list) of abscissa values.
- :param y: 2D-array (or list of lists) of ordinates values. First index
- is the curve index, second index is the sample index. The length
- of the second dimension (number of samples) must be equal to
- ``len(x)``. ``y`` can be a 1D-array in case there is only one curve
- to be saved.
- :param filetype: Filetype: ``"spec", "csv", "txt", "ndarray"``.
- If ``None``, filetype is detected from file name extension
- (``.dat, .csv, .txt, .npy``).
- :param xlabel: Abscissa label
- :param ylabels: List of `y` labels
- :param fmt: Format string for data. You can specify a short format
- string that defines a single format for both ``x`` and ``y`` values,
- or a list of two different format strings (e.g. ``["%d", "%.7g"]``).
- Default is ``"%.7g"``.
- This parameter does not apply to the `npy` format.
- :param csvdelim: String or character separating columns in `txt` and
- `CSV` formats. The user is responsible for ensuring that this
- delimiter is not used in data labels when writing a `CSV` file.
- :param newline: String or character separating lines/records in `txt`
- format (default is line break character ``\\n``).
- :param header: String that will be written at the beginning of the file in
- `txt` format.
- :param footer: String that will be written at the end of the file in `txt`
- format.
- :param comments: String that will be prepended to the ``header`` and
- ``footer`` strings, to mark them as comments. Default: ``#``.
- :param autoheader: In `CSV` or `txt`, ``True`` causes the first header
- line to be written as a standard CSV header line with column labels
- separated by the specified CSV delimiter.
-
- When saving to Specfile format, each curve is saved as a separate scan
- with two data columns (``x`` and ``y``).
-
- `CSV` and `txt` formats are similar, except that the `txt` format allows
- user defined header and footer text blocks, whereas the `CSV` format has
- only a single header line with columns labels separated by field
- delimiters and no footer. The `txt` format also allows defining a record
- separator different from a line break.
-
- The `npy` format is written with ``numpy.save`` and can be read back with
- ``numpy.load``. If ``xlabel`` and ``ylabels`` are undefined, data is saved
- as a regular 2D ``numpy.ndarray`` (contatenation of ``x`` and ``y``). If
- both ``xlabel`` and ``ylabels`` are defined, the data is saved as a
- ``numpy.recarray`` after being transposed and having labels assigned to
- columns.
- """
-
- available_formats = ["spec", "csv", "txt", "ndarray"]
-
- if filetype is None:
- exttypes = {".dat": "spec",
- ".csv": "csv",
- ".txt": "txt",
- ".npy": "ndarray"}
- outfname = (fname if not hasattr(fname, "name") else
- fname.name)
- fileext = os.path.splitext(outfname)[1]
- if fileext in exttypes:
- filetype = exttypes[fileext]
- else:
- raise IOError("File type unspecified and could not be " +
- "inferred from file extension (not in " +
- "txt, dat, csv, npy)")
- else:
- filetype = filetype.lower()
-
- if filetype not in available_formats:
- raise IOError("File type %s is not supported" % (filetype))
-
- # default column headers
- if xlabel is None:
- xlabel = "x"
- if ylabels is None:
- if numpy.array(y).ndim > 1:
- ylabels = ["y%d" % i for i in range(len(y))]
- else:
- ylabels = ["y"]
- elif isinstance(ylabels, (list, tuple)):
- # if ylabels is provided as a list, every element must
- # be a string
- ylabels = [ylabel if isinstance(ylabel, string_types) else "y%d" % i
- for ylabel in ylabels]
-
- if filetype.lower() == "spec":
- # Check if we have regular data:
- ref = len(x)
- regular = True
- for one_y in y:
- regular &= len(one_y) == ref
- if regular:
- if isinstance(fmt, (list, tuple)) and len(fmt) < (len(ylabels) + 1):
- fmt = fmt + [fmt[-1] * (1 + len(ylabels) - len(fmt))]
- specf = savespec(fname, x, y, xlabel, ylabels, fmt=fmt,
- scan_number=1, mode="w", write_file_header=True,
- close_file=False)
- else:
- y_array = numpy.asarray(y)
- # make sure y_array is a 2D array even for a single curve
- if y_array.ndim == 1:
- y_array.shape = 1, -1
- elif y_array.ndim not in [1, 2]:
- raise IndexError("y must be a 1D or 2D array")
-
- # First curve
- specf = savespec(fname, x, y_array[0], xlabel, ylabels[0], fmt=fmt,
- scan_number=1, mode="w", write_file_header=True,
- close_file=False)
- # Other curves
- for i in range(1, y_array.shape[0]):
- specf = savespec(specf, x, y_array[i], xlabel, ylabels[i],
- fmt=fmt, scan_number=i + 1, mode="w",
- write_file_header=False, close_file=False)
-
- # close file if we created it
- if not hasattr(fname, "write"):
- specf.close()
-
- else:
- autoheader_line = xlabel + csvdelim + csvdelim.join(ylabels)
- if xlabel is not None and ylabels is not None and filetype == "csv":
- # csv format: optional single header line with labels, no footer
- if autoheader:
- header = autoheader_line + newline
- else:
- header = ""
- comments = ""
- footer = ""
- newline = "\n"
- elif filetype == "txt" and autoheader:
- # Comments string is added at the beginning of header string in
- # savetxt(). We add another one after the first header line and
- # before the rest of the header.
- if header:
- header = autoheader_line + newline + comments + header
- else:
- header = autoheader_line + newline
-
- # Concatenate x and y in a single 2D array
- X = numpy.vstack((x, y))
-
- if filetype.lower() in ["csv", "txt"]:
- X = X.transpose()
- savetxt(fname, X, fmt=fmt, delimiter=csvdelim,
- newline=newline, header=header, footer=footer,
- comments=comments)
-
- elif filetype.lower() == "ndarray":
- if xlabel is not None and ylabels is not None:
- labels = [xlabel] + ylabels
-
- # .transpose is needed here because recarray labels
- # apply to columns
- X = numpy.core.records.fromrecords(X.transpose(),
- names=labels)
- numpy.save(fname, X)
-
-
-# Replace with numpy.savetxt when dropping support of numpy < 1.7.0
-def savetxt(fname, X, fmt="%.7g", delimiter=";", newline="\n",
- header="", footer="", comments="#"):
- """``numpy.savetxt`` backport of header and footer arguments from
- numpy=1.7.0.
-
- See ``numpy.savetxt`` help:
- http://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.savetxt.html
- """
- if not hasattr(fname, "name"):
- ffile = builtin_open(fname, 'wb')
- else:
- ffile = fname
-
- if header:
- if sys.version_info[0] >= 3:
- header = header.encode("utf-8")
- ffile.write(header)
-
- numpy.savetxt(ffile, X, fmt, delimiter, newline)
-
- if footer:
- footer = (comments + footer.replace(newline, newline + comments) +
- newline)
- if sys.version_info[0] >= 3:
- footer = footer.encode("utf-8")
- ffile.write(footer)
-
- if not hasattr(fname, "name"):
- ffile.close()
-
-
-def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g",
- scan_number=1, mode="w", write_file_header=True,
- close_file=False):
- """Saves one curve to a SpecFile.
-
- The curve is saved as a scan with two data columns. To save multiple
- curves to a single SpecFile, call this function for each curve by
- providing the same file handle each time.
-
- :param specfile: Output SpecFile name, or file handle open in write
- or append mode. If a file name is provided, a new file is open in
- write mode (existing file with the same name will be lost)
- :param x: 1D-Array (or list) of abscissa values
- :param y: 1D-array (or list), or list of them of ordinates values.
- All dataset must have the same length as x
- :param xlabel: Abscissa label (default ``"X"``)
- :param ylabel: Ordinate label, may be a list of labels when multiple curves
- are to be saved together.
- :param fmt: Format string for data. You can specify a short format
- string that defines a single format for both ``x`` and ``y`` values,
- or a list of two different format strings (e.g. ``["%d", "%.7g"]``).
- Default is ``"%.7g"``.
- :param scan_number: Scan number (default 1).
- :param mode: Mode for opening file: ``w`` (default), ``a``, ``r+``,
- ``w+``, ``a+``. This parameter is only relevant if ``specfile`` is a
- path.
- :param write_file_header: If ``True``, write a file header before writing
- the scan (``#F`` and ``#D`` line).
- :param close_file: If ``True``, close the file after saving curve.
- :return: ``None`` if ``close_file`` is ``True``, else return the file
- handle.
- """
- # Make sure we use binary mode for write
- # (issue with windows: write() replaces \n with os.linesep in text mode)
- if "b" not in mode:
- first_letter = mode[0]
- assert first_letter in "rwa"
- mode = mode.replace(first_letter, first_letter + "b")
-
- x_array = numpy.asarray(x)
- y_array = numpy.asarray(y)
- if y_array.ndim > 2:
- raise IndexError("Y columns must have be packed as 1D")
-
- if y_array.shape[-1] != x_array.shape[0]:
- raise IndexError("X and Y columns must have the same length")
-
- if y_array.ndim == 2:
- assert isinstance(ylabel, (list, tuple))
- assert y_array.shape[0] == len(ylabel)
- labels = (xlabel, *ylabel)
- else:
- labels = (xlabel, ylabel)
- data = numpy.vstack((x_array, y_array))
- ncol = data.shape[0]
- assert len(labels) == ncol
-
- print(xlabel, ylabel, fmt, ncol, x_array, y_array)
- if isinstance(fmt, string_types) and fmt.count("%") == 1:
- full_fmt_string = " ".join([fmt] * ncol)
- elif isinstance(fmt, (list, tuple)) and len(fmt) == ncol:
- full_fmt_string = " ".join(fmt)
- else:
- raise ValueError("`fmt` must be a single format string or a list of " +
- "format strings with as many format as ncolumns")
-
- if not hasattr(specfile, "write"):
- f = builtin_open(specfile, mode)
- else:
- f = specfile
-
- current_date = "#D %s" % (time.ctime(time.time()))
- if write_file_header:
- lines = [ "#F %s" % f.name, current_date, ""]
- else:
- lines = [""]
-
- lines += [ "#S %d %s" % (scan_number, labels[1]),
- current_date,
- "#N %d" % ncol,
- "#L " + " ".join(labels)]
-
- for i in data.T:
- lines.append(full_fmt_string % tuple(i))
- lines.append("")
- output = "\n".join(lines)
- f.write(output.encode())
-
- if close_file:
- f.close()
- return None
- return f
-
-
-def h5ls(h5group, lvl=0):
- """Return a simple string representation of a HDF5 tree structure.
-
- :param h5group: Any :class:`h5py.Group` or :class:`h5py.File` instance,
- or a HDF5 file name
- :param lvl: Number of tabulations added to the group. ``lvl`` is
- incremented as we recursively process sub-groups.
- :return: String representation of an HDF5 tree structure
-
-
- Group names and dataset representation are printed preceded by a number of
- tabulations corresponding to their depth in the tree structure.
- Datasets are represented as :class:`h5py.Dataset` objects.
-
- Example::
-
- >>> print(h5ls("Downloads/sample.h5"))
- +fields
- +fieldB
- <HDF5 dataset "z": shape (256, 256), type "<f4">
- +fieldE
- <HDF5 dataset "x": shape (256, 256), type "<f4">
- <HDF5 dataset "y": shape (256, 256), type "<f4">
-
- .. note:: This function requires `h5py <http://www.h5py.org/>`_ to be
- installed.
- """
- h5repr = ''
- if is_group(h5group):
- h5f = h5group
- elif isinstance(h5group, string_types):
- h5f = open(h5group) # silx.io.open
- else:
- raise TypeError("h5group must be a hdf5-like group object or a file name.")
-
- for key in h5f.keys():
- # group
- if hasattr(h5f[key], 'keys'):
- h5repr += '\t' * lvl + '+' + key
- h5repr += '\n'
- h5repr += h5ls(h5f[key], lvl + 1)
- # dataset
- else:
- h5repr += '\t' * lvl
- h5repr += str(h5f[key])
- h5repr += '\n'
-
- if isinstance(h5group, string_types):
- h5f.close()
-
- return h5repr
-
-
-def _open_local_file(filename):
- """
- Load a file as an `h5py.File`-like object.
-
- Format supported:
- - h5 files, if `h5py` module is installed
- - SPEC files exposed as a NeXus layout
- - raster files exposed as a NeXus layout (if `fabio` is installed)
- - Numpy files ('npy' and 'npz' files)
-
- The file is opened in read-only mode.
-
- :param str filename: A filename
- :raises: IOError if the file can't be loaded as an h5py.File like object
- :rtype: h5py.File
- """
- if not os.path.isfile(filename):
- raise IOError("Filename '%s' must be a file path" % filename)
-
- debugging_info = []
- try:
- _, extension = os.path.splitext(filename)
-
- if extension in [".npz", ".npy"]:
- try:
- from . import rawh5
- return rawh5.NumpyFile(filename)
- except (IOError, ValueError) as e:
- debugging_info.append((sys.exc_info(),
- "File '%s' can't be read as a numpy file." % filename))
-
- if h5py.is_hdf5(filename):
- try:
- return h5py.File(filename, "r")
- except OSError:
- return h5py.File(filename, "r", libver='latest', swmr=True)
-
- try:
- from . import fabioh5
- return fabioh5.File(filename)
- except ImportError:
- debugging_info.append((sys.exc_info(), "fabioh5 can't be loaded."))
- except Exception:
- debugging_info.append((sys.exc_info(),
- "File '%s' can't be read as fabio file." % filename))
-
- try:
- from . import spech5
- return spech5.SpecH5(filename)
- except ImportError:
- debugging_info.append((sys.exc_info(),
- "spech5 can't be loaded."))
- except IOError:
- debugging_info.append((sys.exc_info(),
- "File '%s' can't be read as spec file." % filename))
- finally:
- for exc_info, message in debugging_info:
- logger.debug(message, exc_info=exc_info)
-
- raise IOError("File '%s' can't be read as HDF5" % filename)
-
-
-class _MainNode(Proxy):
- """A main node is a sub node of the HDF5 tree which is responsible of the
- closure of the file.
-
- It is a proxy to the sub node, plus support context manager and `close`
- method usually provided by `h5py.File`.
-
- :param h5_node: Target to the proxy.
- :param h5_file: Main file. This object became the owner of this file.
- """
-
- def __init__(self, h5_node, h5_file):
- super(_MainNode, self).__init__(h5_node)
- self.__file = h5_file
- self.__class = get_h5_class(h5_node)
-
- @property
- def h5_class(self):
- """Returns the HDF5 class which is mimicked by this class.
-
- :rtype: H5Type
- """
- return self.__class
-
- @property
- def h5py_class(self):
- """Returns the h5py classes which is mimicked by this class. It can be
- one of `h5py.File, h5py.Group` or `h5py.Dataset`.
-
- :rtype: h5py class
- """
- return h5type_to_h5py_class(self.__class)
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.close()
-
- def close(self):
- """Close the file"""
- self.__file.close()
- self.__file = None
-
-
-def open(filename): # pylint:disable=redefined-builtin
- """
- Open a file as an `h5py`-like object.
-
- Format supported:
- - h5 files, if `h5py` module is installed
- - SPEC files exposed as a NeXus layout
- - raster files exposed as a NeXus layout (if `fabio` is installed)
- - Numpy files ('npy' and 'npz' files)
-
- The filename can be trailled an HDF5 path using the separator `::`. In this
- case the object returned is a proxy to the target node, implementing the
- `close` function and supporting `with` context.
-
- The file is opened in read-only mode.
-
- :param str filename: A filename which can containt an HDF5 path by using
- `::` separator.
- :raises: IOError if the file can't be loaded or path can't be found
- :rtype: h5py-like node
- """
- url = silx.io.url.DataUrl(filename)
-
- if url.scheme() in [None, "file", "silx"]:
- # That's a local file
- if not url.is_valid():
- raise IOError("URL '%s' is not valid" % filename)
- h5_file = _open_local_file(url.file_path())
- elif url.scheme() in ["fabio"]:
- raise IOError("URL '%s' containing fabio scheme is not supported" % filename)
- else:
- # That's maybe an URL supported by h5pyd
- uri = six.moves.urllib.parse.urlparse(filename)
- if h5pyd is None:
- raise IOError("URL '%s' unsupported. Try to install h5pyd." % filename)
- path = uri.path
- endpoint = "%s://%s" % (uri.scheme, uri.netloc)
- if path.startswith("/"):
- path = path[1:]
- return h5pyd.File(path, 'r', endpoint=endpoint)
-
- if url.data_slice():
- raise IOError("URL '%s' containing slicing is not supported" % filename)
-
- if url.data_path() in [None, "/", ""]:
- # The full file is requested
- return h5_file
- else:
- # Only a children is requested
- if url.data_path() not in h5_file:
- msg = "File '%s' does not contain path '%s'." % (filename, url.data_path())
- raise IOError(msg)
- node = h5_file[url.data_path()]
- proxy = _MainNode(node, h5_file)
- return proxy
-
-
-def _get_classes_type():
- """Returns a mapping between Python classes and HDF5 concepts.
-
- This function allow an lazy initialization to avoid recurssive import
- of modules.
- """
- global _CLASSES_TYPE
- from . import commonh5
-
- if _CLASSES_TYPE is not None:
- return _CLASSES_TYPE
-
- _CLASSES_TYPE = collections.OrderedDict()
-
- _CLASSES_TYPE[commonh5.Dataset] = H5Type.DATASET
- _CLASSES_TYPE[commonh5.File] = H5Type.FILE
- _CLASSES_TYPE[commonh5.Group] = H5Type.GROUP
- _CLASSES_TYPE[commonh5.SoftLink] = H5Type.SOFT_LINK
-
- _CLASSES_TYPE[h5py.Dataset] = H5Type.DATASET
- _CLASSES_TYPE[h5py.File] = H5Type.FILE
- _CLASSES_TYPE[h5py.Group] = H5Type.GROUP
- _CLASSES_TYPE[h5py.SoftLink] = H5Type.SOFT_LINK
- _CLASSES_TYPE[h5py.HardLink] = H5Type.HARD_LINK
- _CLASSES_TYPE[h5py.ExternalLink] = H5Type.EXTERNAL_LINK
-
- if h5pyd is not None:
- _CLASSES_TYPE[h5pyd.Dataset] = H5Type.DATASET
- _CLASSES_TYPE[h5pyd.File] = H5Type.FILE
- _CLASSES_TYPE[h5pyd.Group] = H5Type.GROUP
- _CLASSES_TYPE[h5pyd.SoftLink] = H5Type.SOFT_LINK
- _CLASSES_TYPE[h5pyd.HardLink] = H5Type.HARD_LINK
- _CLASSES_TYPE[h5pyd.ExternalLink] = H5Type.EXTERNAL_LINK
-
- return _CLASSES_TYPE
-
-
-def get_h5_class(obj=None, class_=None):
- """
- Returns the HDF5 type relative to the object or to the class.
-
- :param obj: Instance of an object
- :param class_: A class
- :rtype: H5Type
- """
- if class_ is None:
- class_ = obj.__class__
-
- classes = _get_classes_type()
- t = classes.get(class_, None)
- if t is not None:
- return t
-
- if obj is not None:
- if hasattr(obj, "h5_class"):
- return obj.h5_class
-
- for referencedClass_, type_ in classes.items():
- if issubclass(class_, referencedClass_):
- classes[class_] = type_
- return type_
-
- classes[class_] = None
- return None
-
-
-def h5type_to_h5py_class(type_):
- """
- Returns an h5py class from an H5Type. None if nothing found.
-
- :param H5Type type_:
- :rtype: H5py class
- """
- if type_ == H5Type.FILE:
- return h5py.File
- if type_ == H5Type.GROUP:
- return h5py.Group
- if type_ == H5Type.DATASET:
- return h5py.Dataset
- if type_ == H5Type.SOFT_LINK:
- return h5py.SoftLink
- if type_ == H5Type.HARD_LINK:
- return h5py.HardLink
- if type_ == H5Type.EXTERNAL_LINK:
- return h5py.ExternalLink
- return None
-
-
-def get_h5py_class(obj):
- """Returns the h5py class from an object.
-
- If it is an h5py object or an h5py-like object, an h5py class is returned.
- If the object is not an h5py-like object, None is returned.
-
- :param obj: An object
- :return: An h5py object
- """
- if hasattr(obj, "h5py_class"):
- return obj.h5py_class
- type_ = get_h5_class(obj)
- return h5type_to_h5py_class(type_)
-
-
-def is_file(obj):
- """
- True is the object is an h5py.File-like object.
-
- :param obj: An object
- """
- t = get_h5_class(obj)
- return t == H5Type.FILE
-
-
-def is_group(obj):
- """
- True if the object is a h5py.Group-like object. A file is a group.
-
- :param obj: An object
- """
- t = get_h5_class(obj)
- return t in [H5Type.GROUP, H5Type.FILE]
-
-
-def is_dataset(obj):
- """
- True if the object is a h5py.Dataset-like object.
-
- :param obj: An object
- """
- t = get_h5_class(obj)
- return t == H5Type.DATASET
-
-
-def is_softlink(obj):
- """
- True if the object is a h5py.SoftLink-like object.
-
- :param obj: An object
- """
- t = get_h5_class(obj)
- return t == H5Type.SOFT_LINK
-
-
-def is_externallink(obj):
- """
- True if the object is a h5py.ExternalLink-like object.
-
- :param obj: An object
- """
- t = get_h5_class(obj)
- return t == H5Type.EXTERNAL_LINK
-
-
-def is_link(obj):
- """
- True if the object is a h5py link-like object.
-
- :param obj: An object
- """
- t = get_h5_class(obj)
- return t in {H5Type.SOFT_LINK, H5Type.EXTERNAL_LINK}
-
-
-def get_data(url):
- """Returns a numpy data from an URL.
-
- Examples:
-
- >>> # 1st frame from an EDF using silx.io.open
- >>> data = silx.io.get_data("silx:/users/foo/image.edf::/scan_0/instrument/detector_0/data[0]")
-
- >>> # 1st frame from an EDF using fabio
- >>> data = silx.io.get_data("fabio:/users/foo/image.edf::[0]")
-
- Yet 2 schemes are supported by the function.
-
- - If `silx` scheme is used, the file is opened using
- :meth:`silx.io.open`
- and the data is reach using usually NeXus paths.
- - If `fabio` scheme is used, the file is opened using :meth:`fabio.open`
- from the FabIO library.
- No data path have to be specified, but each frames can be accessed
- using the data slicing.
- This shortcut of :meth:`silx.io.open` allow to have a faster access to
- the data.
-
- .. seealso:: :class:`silx.io.url.DataUrl`
-
- :param Union[str,silx.io.url.DataUrl]: A data URL
- :rtype: Union[numpy.ndarray, numpy.generic]
- :raises ImportError: If the mandatory library to read the file is not
- available.
- :raises ValueError: If the URL is not valid or do not match the data
- :raises IOError: If the file is not found or in case of internal error of
- :meth:`fabio.open` or :meth:`silx.io.open`. In this last case more
- informations are displayed in debug mode.
- """
- if not isinstance(url, silx.io.url.DataUrl):
- url = silx.io.url.DataUrl(url)
-
- if not url.is_valid():
- raise ValueError("URL '%s' is not valid" % url.path())
-
- if not os.path.exists(url.file_path()):
- raise IOError("File '%s' not found" % url.file_path())
-
- if url.scheme() == "silx":
- data_path = url.data_path()
- data_slice = url.data_slice()
-
- with open(url.file_path()) as h5:
- if data_path not in h5:
- raise ValueError("Data path from URL '%s' not found" % url.path())
- data = h5[data_path]
-
- if not silx.io.is_dataset(data):
- raise ValueError("Data path from URL '%s' is not a dataset" % url.path())
-
- if data_slice is not None:
- data = h5py_read_dataset(data, index=data_slice)
- else:
- # works for scalar and array
- data = h5py_read_dataset(data)
-
- elif url.scheme() == "fabio":
- import fabio
- data_slice = url.data_slice()
- if data_slice is None:
- data_slice = (0,)
- if data_slice is None or len(data_slice) != 1:
- raise ValueError("Fabio slice expect a single frame, but %s found" % data_slice)
- index = data_slice[0]
- if not isinstance(index, int):
- raise ValueError("Fabio slice expect a single integer, but %s found" % data_slice)
-
- try:
- fabio_file = fabio.open(url.file_path())
- except Exception:
- logger.debug("Error while opening %s with fabio", url.file_path(), exc_info=True)
- raise IOError("Error while opening %s with fabio (use debug for more information)" % url.path())
-
- if fabio_file.nframes == 1:
- if index != 0:
- raise ValueError("Only a single frame available. Slice %s out of range" % index)
- data = fabio_file.data
- else:
- data = fabio_file.getframe(index).data
-
- # There is no explicit close
- fabio_file = None
-
- else:
- raise ValueError("Scheme '%s' not supported" % url.scheme())
-
- return data
-
-
-def rawfile_to_h5_external_dataset(bin_file, output_url, shape, dtype,
- overwrite=False):
- """
- Create a HDF5 dataset at `output_url` pointing to the given vol_file.
-
- Either `shape` or `info_file` must be provided.
-
- :param str bin_file: Path to the .vol file
- :param DataUrl output_url: HDF5 URL where to save the external dataset
- :param tuple shape: Shape of the volume
- :param numpy.dtype dtype: Data type of the volume elements (default: float32)
- :param bool overwrite: True to allow overwriting (default: False).
- """
- assert isinstance(output_url, silx.io.url.DataUrl)
- assert isinstance(shape, (tuple, list))
- v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]]
- if calc_hexversion(v_majeur, v_mineur, v_micro)< calc_hexversion(2,9,0):
- raise Exception('h5py >= 2.9 should be installed to access the '
- 'external feature.')
-
- with h5py.File(output_url.file_path(), mode="a") as _h5_file:
- if output_url.data_path() in _h5_file:
- if overwrite is False:
- raise ValueError('data_path already exists')
- else:
- logger.warning('will overwrite path %s' % output_url.data_path())
- del _h5_file[output_url.data_path()]
- external = [(bin_file, 0, h5py.h5f.UNLIMITED)]
- _h5_file.create_dataset(output_url.data_path(),
- shape,
- dtype=dtype,
- external=external)
-
-
-def vol_to_h5_external_dataset(vol_file, output_url, info_file=None,
- vol_dtype=numpy.float32, overwrite=False):
- """
- Create a HDF5 dataset at `output_url` pointing to the given vol_file.
-
- If the vol_file.info containing the shape is not on the same folder as the
- vol-file then you should specify her location.
-
- :param str vol_file: Path to the .vol file
- :param DataUrl output_url: HDF5 URL where to save the external dataset
- :param Union[str,None] info_file:
- .vol.info file name written by pyhst and containing the shape information
- :param numpy.dtype vol_dtype: Data type of the volume elements (default: float32)
- :param bool overwrite: True to allow overwriting (default: False).
- :raises ValueError: If fails to read shape from the .vol.info file
- """
- _info_file = info_file
- if _info_file is None:
- _info_file = vol_file + '.info'
- if not os.path.exists(_info_file):
- logger.error('info_file not given and %s does not exists, please'
- 'specify .vol.info file' % _info_file)
- return
-
- def info_file_to_dict():
- ddict = {}
- with builtin_open(info_file, "r") as _file:
- lines = _file.readlines()
- for line in lines:
- if not '=' in line:
- continue
- l = line.rstrip().replace(' ', '')
- l = l.split('#')[0]
- key, value = l.split('=')
- ddict[key.lower()] = value
- return ddict
-
- ddict = info_file_to_dict()
- if 'num_x' not in ddict or 'num_y' not in ddict or 'num_z' not in ddict:
- raise ValueError(
- 'Unable to retrieve volume shape from %s' % info_file)
-
- dimX = int(ddict['num_x'])
- dimY = int(ddict['num_y'])
- dimZ = int(ddict['num_z'])
- shape = (dimZ, dimY, dimX)
-
- return rawfile_to_h5_external_dataset(bin_file=vol_file,
- output_url=output_url,
- shape=shape,
- dtype=vol_dtype,
- overwrite=overwrite)
-
-
-def h5py_decode_value(value, encoding="utf-8", errors="surrogateescape"):
- """Keep bytes when value cannot be decoded
-
- :param value: bytes or array of bytes
- :param encoding str:
- :param errors str:
- """
- try:
- if numpy.isscalar(value):
- return value.decode(encoding, errors=errors)
- str_item = [b.decode(encoding, errors=errors) for b in value.flat]
- return numpy.array(str_item, dtype=object).reshape(value.shape)
- except UnicodeDecodeError:
- return value
-
-
-def h5py_encode_value(value, encoding="utf-8", errors="surrogateescape"):
- """Keep string when value cannot be encoding
-
- :param value: string or array of strings
- :param encoding str:
- :param errors str:
- """
- try:
- if numpy.isscalar(value):
- return value.encode(encoding, errors=errors)
- bytes_item = [s.encode(encoding, errors=errors) for s in value.flat]
- return numpy.array(bytes_item, dtype=object).reshape(value.shape)
- except UnicodeEncodeError:
- return value
-
-
-class H5pyDatasetReadWrapper:
- """Wrapper to handle H5T_STRING decoding on-the-fly when reading
- a dataset. Uniform behaviour for h5py 2.x and h5py 3.x
-
- h5py abuses H5T_STRING with ASCII character set
- to store `bytes`: dset[()] = b"..."
- Therefore an H5T_STRING with ASCII encoding is not decoded by default.
- """
-
- H5PY_AUTODECODE_NONASCII = int(h5py.version.version.split(".")[0]) < 3
-
- def __init__(self, dset, decode_ascii=False):
- """
- :param h5py.Dataset dset:
- :param bool decode_ascii:
- """
- try:
- string_info = h5py.h5t.check_string_dtype(dset.dtype)
- except AttributeError:
- # h5py < 2.10
- try:
- idx = dset.id.get_type().get_cset()
- except AttributeError:
- # Not an H5T_STRING
- encoding = None
- else:
- encoding = ["ascii", "utf-8"][idx]
- else:
- # h5py >= 2.10
- try:
- encoding = string_info.encoding
- except AttributeError:
- # Not an H5T_STRING
- encoding = None
- if encoding == "ascii" and not decode_ascii:
- encoding = None
- if encoding != "ascii" and self.H5PY_AUTODECODE_NONASCII:
- # Decoding is already done by the h5py library
- encoding = None
- if encoding == "ascii":
- # ASCII can be decoded as UTF-8
- encoding = "utf-8"
- self._encoding = encoding
- self._dset = dset
-
- def __getitem__(self, args):
- value = self._dset[args]
- if self._encoding:
- return h5py_decode_value(value, encoding=self._encoding)
- else:
- return value
-
-
-class H5pyAttributesReadWrapper:
- """Wrapper to handle H5T_STRING decoding on-the-fly when reading
- an attribute. Uniform behaviour for h5py 2.x and h5py 3.x
-
- h5py abuses H5T_STRING with ASCII character set
- to store `bytes`: dset[()] = b"..."
- Therefore an H5T_STRING with ASCII encoding is not decoded by default.
- """
-
- H5PY_AUTODECODE = int(h5py.version.version.split(".")[0]) >= 3
-
- def __init__(self, attrs, decode_ascii=False):
- """
- :param h5py.Dataset dset:
- :param bool decode_ascii:
- """
- self._attrs = attrs
- self._decode_ascii = decode_ascii
-
- def __getitem__(self, args):
- value = self._attrs[args]
-
- # Get the string encoding (if a string)
- try:
- dtype = self._attrs.get_id(args).dtype
- except AttributeError:
- # h5py < 2.10
- attr_id = h5py.h5a.open(self._attrs._id, self._attrs._e(args))
- try:
- idx = attr_id.get_type().get_cset()
- except AttributeError:
- # Not an H5T_STRING
- return value
- else:
- encoding = ["ascii", "utf-8"][idx]
- else:
- # h5py >= 2.10
- try:
- encoding = h5py.h5t.check_string_dtype(dtype).encoding
- except AttributeError:
- # Not an H5T_STRING
- return value
-
- if self.H5PY_AUTODECODE:
- if encoding == "ascii" and not self._decode_ascii:
- # Undo decoding by the h5py library
- return h5py_encode_value(value, encoding="utf-8")
- else:
- if encoding == "ascii" and self._decode_ascii:
- # Decode ASCII as UTF-8 for consistency
- return h5py_decode_value(value, encoding="utf-8")
-
- # Decoding is already done by the h5py library
- return value
-
- def items(self):
- for k in self._attrs.keys():
- yield k, self[k]
-
-
-def h5py_read_dataset(dset, index=tuple(), decode_ascii=False):
- """Read data from dataset object. UTF-8 strings will be
- decoded while ASCII strings will only be decoded when
- `decode_ascii=True`.
-
- :param h5py.Dataset dset:
- :param index: slicing (all by default)
- :param bool decode_ascii:
- """
- return H5pyDatasetReadWrapper(dset, decode_ascii=decode_ascii)[index]
-
-
-def h5py_read_attribute(attrs, name, decode_ascii=False):
- """Read data from attributes. UTF-8 strings will be
- decoded while ASCII strings will only be decoded when
- `decode_ascii=True`.
-
- :param h5py.AttributeManager attrs:
- :param str name: attribute name
- :param bool decode_ascii:
- """
- return H5pyAttributesReadWrapper(attrs, decode_ascii=decode_ascii)[name]
-
-
-def h5py_read_attributes(attrs, decode_ascii=False):
- """Read data from attributes. UTF-8 strings will be
- decoded while ASCII strings will only be decoded when
- `decode_ascii=True`.
-
- :param h5py.AttributeManager attrs:
- :param bool decode_ascii:
- """
- return dict(H5pyAttributesReadWrapper(attrs, decode_ascii=decode_ascii).items())
diff --git a/silx/math/colormap.pyx b/silx/math/colormap.pyx
deleted file mode 100644
index 2cefe04..0000000
--- a/silx/math/colormap.pyx
+++ /dev/null
@@ -1,559 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""This module provides :func:`cmap` which applies a colormap to a dataset.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "16/05/2018"
-
-
-import os
-cimport cython
-from cython.parallel import prange
-cimport numpy as cnumpy
-from libc.math cimport frexp, sinh, sqrt
-from .math_compatibility cimport asinh, isnan, isfinite, lrint, INFINITY, NAN
-
-import logging
-import numbers
-
-import numpy
-
-__all__ = ['cmap']
-
-_logger = logging.getLogger(__name__)
-
-
-cdef int DEFAULT_NUM_THREADS
-if hasattr(os, 'sched_getaffinity'):
- DEFAULT_NUM_THREADS = min(4, len(os.sched_getaffinity(0)))
-elif os.cpu_count() is not None:
- DEFAULT_NUM_THREADS = min(4, os.cpu_count())
-else: # Fallback
- DEFAULT_NUM_THREADS = 1
-# Number of threads to use for the computation (initialized to up to 4)
-
-cdef int USE_OPENMP_THRESHOLD = 1000
-"""OpenMP is not used for arrays with less elements than this threshold"""
-
-# Supported data types
-ctypedef fused data_types:
- cnumpy.uint8_t
- cnumpy.int8_t
- cnumpy.uint16_t
- cnumpy.int16_t
- cnumpy.uint32_t
- cnumpy.int32_t
- cnumpy.uint64_t
- cnumpy.int64_t
- float
- double
- long double
-
-
-# Data types using a LUT to apply the colormap
-ctypedef fused lut_types:
- cnumpy.uint8_t
- cnumpy.int8_t
- cnumpy.uint16_t
- cnumpy.int16_t
-
-
-# Data types using default colormap implementation
-ctypedef fused default_types:
- cnumpy.uint32_t
- cnumpy.int32_t
- cnumpy.uint64_t
- cnumpy.int64_t
- float
- double
- long double
-
-
-# Supported colors/output types
-ctypedef fused image_types:
- cnumpy.uint8_t
- float
-
-
-# Normalization
-
-ctypedef double (*NormalizationFunction)(double) nogil
-
-
-cdef class Normalization:
- """Base class for colormap normalization"""
-
- def apply(self, data, double vmin, double vmax):
- """Apply normalization.
-
- :param Union[float,numpy.ndarray] data:
- :param float vmin: Lower bound of the range
- :param float vmax: Upper bound of the range
- :rtype: Union[float,numpy.ndarray]
- """
- cdef int length
- cdef double[:] result
-
- if isinstance(data, numbers.Real):
- return self.apply_double(<double> data, vmin, vmax)
- else:
- data = numpy.array(data, copy=False)
- length = <int> data.size
- result = numpy.empty(length, dtype=numpy.float64)
- data1d = numpy.ravel(data)
- for index in range(length):
- result[index] = self.apply_double(
- <double> data1d[index], vmin, vmax)
- return numpy.array(result).reshape(data.shape)
-
- def revert(self, data, double vmin, double vmax):
- """Revert normalization.
-
- :param Union[float,numpy.ndarray] data:
- :param float vmin: Lower bound of the range
- :param float vmax: Upper bound of the range
- :rtype: Union[float,numpy.ndarray]
- """
- cdef int length
- cdef double[:] result
-
- if isinstance(data, numbers.Real):
- return self.revert_double(<double> data, vmin, vmax)
- else:
- data = numpy.array(data, copy=False)
- length = <int> data.size
- result = numpy.empty(length, dtype=numpy.float64)
- data1d = numpy.ravel(data)
- for index in range(length):
- result[index] = self.revert_double(
- <double> data1d[index], vmin, vmax)
- return numpy.array(result).reshape(data.shape)
-
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
- """Apply normalization to a floating point value
-
- Override in subclass
-
- :param float value:
- :param float vmin: Lower bound of the range
- :param float vmax: Upper bound of the range
- """
- return value
-
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
- """Apply inverse of normalization to a floating point value
-
- Override in subclass
-
- :param float value:
- :param float vmin: Lower bound of the range
- :param float vmax: Upper bound of the range
- """
- return value
-
-
-cdef class LinearNormalization(Normalization):
- """Linear normalization"""
-
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
- return value
-
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
- return value
-
-
-cdef class LogarithmicNormalization(Normalization):
- """Logarithmic normalization using a fast log approximation"""
- cdef:
- readonly int lutsize
- readonly double[::1] lut # LUT used for fast log approximation
-
- def __cinit__(self, int lutsize=4096):
- # Initialize log approximation LUT
- self.lutsize = lutsize
- self.lut = numpy.log2(
- numpy.linspace(0.5, 1., lutsize + 1,
- endpoint=True).astype(numpy.float64))
- # index_lut can overflow of 1
- self.lut[lutsize] = self.lut[lutsize - 1]
-
- def __dealloc__(self):
- self.lut = None
-
- @cython.wraparound(False)
- @cython.boundscheck(False)
- @cython.nonecheck(False)
- @cython.cdivision(True)
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
- """Return log10(value) fast approximation based on LUT"""
- cdef double result = NAN # if value < 0.0 or value == NAN
- cdef int exponent, index_lut
- cdef double mantissa # in [0.5, 1) unless value == 0 NaN or +/-inf
-
- if value <= 0.0 or not isfinite(value):
- if value == 0.0:
- result = - INFINITY
- elif value > 0.0: # i.e., value = +INFINITY
- result = value # i.e. +INFINITY
- else:
- mantissa = frexp(value, &exponent)
- index_lut = lrint(self.lutsize * 2 * (mantissa - 0.5))
- # 1/log2(10) = 0.30102999566398114
- result = 0.30102999566398114 * (<double> exponent +
- self.lut[index_lut])
- return result
-
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
- return 10**value
-
-
-cdef class ArcsinhNormalization(Normalization):
- """Inverse hyperbolic sine normalization"""
-
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
- return asinh(value)
-
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
- return sinh(value)
-
-
-cdef class SqrtNormalization(Normalization):
- """Square root normalization"""
-
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
- return sqrt(value)
-
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
- return value**2
-
-
-cdef class PowerNormalization(Normalization):
- """Gamma correction:
-
- Linear normalization to [0, 1] followed by power normalization.
-
- :param gamma: Gamma correction factor
- """
-
- cdef:
- readonly double gamma
-
- def __cinit__(self, double gamma):
- self.gamma = gamma
-
- def __init__(self, gamma):
- # Needed for multiple inheritance to work
- pass
-
- cdef double apply_double(self, double value, double vmin, double vmax) nogil:
- if vmin == vmax:
- return 0.
- elif value <= vmin:
- return 0.
- elif value >= vmax:
- return 1.
- else:
- return ((value - vmin) / (vmax - vmin))**self.gamma
-
- cdef double revert_double(self, double value, double vmin, double vmax) nogil:
- if value <= 0.:
- return vmin
- elif value >= 1.:
- return vmax
- else:
- return vmin + (vmax - vmin) * value**(1.0/self.gamma)
-
-
-# Colormap
-
-@cython.wraparound(False)
-@cython.boundscheck(False)
-@cython.nonecheck(False)
-@cython.cdivision(True)
-cdef image_types[:, ::1] compute_cmap(
- default_types[:] data,
- image_types[:, ::1] colors,
- Normalization normalization,
- double vmin,
- double vmax,
- image_types[::1] nan_color):
- """Apply colormap to data.
-
- :param data: Input data
- :param colors: Colors look-up-table
- :param vmin: Lower bound of the colormap range
- :param vmax: Upper bound of the colormap range
- :param nan_color: Color to use for NaN value
- :param normalization: Normalization to apply
- :return: Data converted to colors
- """
- cdef image_types[:, ::1] output
- cdef double scale, value, normalized_vmin, normalized_vmax
- cdef int length, nb_channels, nb_colors
- cdef int channel, index, lut_index, num_threads
-
- nb_colors = <int> colors.shape[0]
- nb_channels = <int> colors.shape[1]
- length = <int> data.size
-
- output = numpy.empty((length, nb_channels),
- dtype=numpy.array(colors, copy=False).dtype)
-
- normalized_vmin = normalization.apply_double(vmin, vmin, vmax)
- normalized_vmax = normalization.apply_double(vmax, vmin, vmax)
-
- if not isfinite(normalized_vmin) or not isfinite(normalized_vmax):
- raise ValueError('Colormap range is not valid')
-
- if normalized_vmin == normalized_vmax:
- scale = 0.
- else:
- scale = nb_colors / (normalized_vmax - normalized_vmin)
-
- if length < USE_OPENMP_THRESHOLD:
- num_threads = 1
- else:
- num_threads = min(
- DEFAULT_NUM_THREADS,
- int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS)))
-
- with nogil:
- for index in prange(length, num_threads=num_threads):
- value = normalization.apply_double(
- <double> data[index], vmin, vmax)
-
- # Handle NaN
- if isnan(value):
- for channel in range(nb_channels):
- output[index, channel] = nan_color[channel]
- continue
-
- if value <= normalized_vmin:
- lut_index = 0
- elif value >= normalized_vmax:
- lut_index = nb_colors - 1
- else:
- lut_index = <int>((value - normalized_vmin) * scale)
- # Index can overflow of 1
- if lut_index >= nb_colors:
- lut_index = nb_colors - 1
-
- for channel in range(nb_channels):
- output[index, channel] = colors[lut_index, channel]
-
- return output
-
-@cython.wraparound(False)
-@cython.boundscheck(False)
-@cython.nonecheck(False)
-@cython.cdivision(True)
-cdef image_types[:, ::1] compute_cmap_with_lut(
- lut_types[:] data,
- image_types[:, ::1] colors,
- Normalization normalization,
- double vmin,
- double vmax,
- image_types[::1] nan_color):
- """Convert data to colors using look-up table to speed the process.
-
- Only supports data of types: uint8, uint16, int8, int16.
-
- :param data: Input data
- :param colors: Colors look-up-table
- :param vmin: Lower bound of the colormap range
- :param vmax: Upper bound of the colormap range
- :param nan_color: Color to use for NaN values
- :param normalization: Normalization to apply
- :return: The generated image
- """
- cdef image_types[:, ::1] output
- cdef double[:] values
- cdef image_types[:, ::1] lut
- cdef int type_min, type_max
- cdef int nb_channels, length
- cdef int channel, index, lut_index, num_threads
-
- length = <int> data.size
- nb_channels = <int> colors.shape[1]
-
- if lut_types is cnumpy.int8_t:
- type_min = -128
- type_max = 127
- elif lut_types is cnumpy.uint8_t:
- type_min = 0
- type_max = 255
- elif lut_types is cnumpy.int16_t:
- type_min = -32768
- type_max = 32767
- else: # uint16_t
- type_min = 0
- type_max = 65535
-
- colors_dtype = numpy.array(colors).dtype
-
- values = numpy.arange(type_min, type_max + 1, dtype=numpy.float64)
- lut = compute_cmap(
- values, colors, normalization, vmin, vmax, nan_color)
-
- output = numpy.empty((length, nb_channels), dtype=colors_dtype)
-
- if length < USE_OPENMP_THRESHOLD:
- num_threads = 1
- else:
- num_threads = min(
- DEFAULT_NUM_THREADS,
- int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS)))
-
- with nogil:
- # Apply LUT
- for index in prange(length, num_threads=num_threads):
- lut_index = data[index] - type_min
- for channel in range(nb_channels):
- output[index, channel] = lut[lut_index, channel]
-
- return output
-
-
-# Normalizations without parameters
-_BASIC_NORMALIZATIONS = {
- 'linear': LinearNormalization(),
- 'log': LogarithmicNormalization(),
- 'arcsinh': ArcsinhNormalization(),
- 'sqrt': SqrtNormalization(),
- }
-
-
-@cython.wraparound(False)
-@cython.boundscheck(False)
-@cython.nonecheck(False)
-@cython.cdivision(True)
-def _cmap(data_types[:] data,
- image_types[:, ::1] colors,
- Normalization normalization,
- double vmin,
- double vmax,
- image_types[::1] nan_color):
- """Implementation of colormap.
-
- Use :func:`cmap`.
-
- :param data: Input data
- :param colors: Colors look-up-table
- :param normalization: Normalization object to apply
- :param vmin: Lower bound of the colormap range
- :param vmax: Upper bound of the colormap range
- :param nan_color: Color to use for NaN value.
- :return: The generated image
- """
- cdef image_types[:, ::1] output
-
- # Proxy for calling the right implementation depending on data type
- if data_types in lut_types: # Use LUT implementation
- output = compute_cmap_with_lut(
- data, colors, normalization, vmin, vmax, nan_color)
-
- elif data_types in default_types: # Use default implementation
- output = compute_cmap(
- data, colors, normalization, vmin, vmax, nan_color)
-
- else:
- raise ValueError('Unsupported data type')
-
- return numpy.array(output, copy=False)
-
-
-def cmap(data,
- colors,
- double vmin,
- double vmax,
- normalization='linear',
- nan_color=None):
- """Convert data to colors with provided colors look-up table.
-
- :param numpy.ndarray data: The input data
- :param numpy.ndarray colors: Color look-up table as a 2D array.
- It MUST be of type uint8 or float32
- :param vmin: Data value to map to the beginning of colormap.
- :param vmax: Data value to map to the end of the colormap.
- :param Union[str,Normalization] normalization:
- Either a :class:`Normalization` instance or a str in:
-
- - 'linear' (default)
- - 'log'
- - 'arcsinh'
- - 'sqrt'
- - 'gamma'
-
- :param nan_color: Color to use for NaN value.
- Default: A color with all channels set to 0
- :return: Array of colors. The shape of the
- returned array is that of data array + the last dimension of colors.
- The dtype of the returned array is that of the colors array.
- :rtype: numpy.ndarray
- """
- cdef int nb_channels
- cdef Normalization norm
-
- # Make data a numpy array of native endian type (no need for contiguity)
- data = numpy.array(data, copy=False)
- native_endian_dtype = data.dtype.newbyteorder('N')
- if native_endian_dtype.kind == 'f' and native_endian_dtype.itemsize == 2:
- native_endian_dtype = "=f4" # Use native float32 instead of float16
- data = numpy.array(data, copy=False, dtype=native_endian_dtype)
-
- # Make colors a contiguous array of native endian type
- colors = numpy.array(colors, copy=False)
- nb_channels = colors.shape[colors.ndim - 1]
- colors = numpy.ascontiguousarray(colors,
- dtype=colors.dtype.newbyteorder('N'))
-
- # Make normalization a Normalization object
- if isinstance(normalization, str):
- norm = _BASIC_NORMALIZATIONS.get(normalization, None)
- if norm is None:
- raise ValueError('Unsupported normalization %s' % normalization)
- else:
- norm = normalization
-
- # Check nan_color
- if nan_color is None:
- nan_color = numpy.zeros((nb_channels,), dtype=colors.dtype)
- else:
- nan_color = numpy.ascontiguousarray(
- nan_color, dtype=colors.dtype).reshape(-1)
- assert nan_color.shape == (nb_channels,)
-
- image = _cmap(
- data.reshape(-1),
- colors.reshape(-1, nb_channels),
- norm,
- vmin,
- vmax,
- nan_color)
- image.shape = data.shape + (nb_channels,)
-
- return image
diff --git a/silx/math/fft/test/__init__.py b/silx/math/fft/test/__init__.py
deleted file mode 100644
index 83f8926..0000000
--- a/silx/math/fft/test/__init__.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-
-from .test_fft import suite
diff --git a/silx/math/fft/test/test_fft.py b/silx/math/fft/test/test_fft.py
deleted file mode 100644
index 9ef2fd2..0000000
--- a/silx/math/fft/test/test_fft.py
+++ /dev/null
@@ -1,270 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of the FFT module"""
-
-import numpy as np
-import unittest
-import logging
-try:
- from scipy.misc import ascent
- __have_scipy = True
-except ImportError:
- __have_scipy = False
-from silx.utils.testutils import ParametricTestCase
-from silx.math.fft.fft import FFT
-from silx.math.fft.clfft import __have_clfft__
-from silx.math.fft.cufft import __have_cufft__
-from silx.math.fft.fftw import __have_fftw__
-
-from silx.test.utils import test_options
-
-logger = logging.getLogger(__name__)
-
-
-class TransformInfos(object):
- def __init__(self):
- self.dimensions = [
- "1D",
- "batched_1D",
- "2D",
- "batched_2D",
- "3D",
- ]
- self.modes = {
- "R2C": np.float32,
- "R2C_double": np.float64,
- "C2C": np.complex64,
- "C2C_double": np.complex128,
- }
- self.sizes = {
- "1D": [(128,), (127,)],
- "2D": [(128, 128), (128, 127), (127, 128), (127, 127)],
- "3D": [(64, 64, 64), (64, 64, 63), (64, 63, 64), (63, 64, 64),
- (64, 63, 63), (63, 64, 63), (63, 63, 64), (63, 63, 63)]
- }
- self.axes = {
- "1D": None,
- "batched_1D": (-1,),
- "2D": None,
- "batched_2D": (-2, -1),
- "3D": None,
- }
- self.sizes["batched_1D"] = self.sizes["2D"]
- self.sizes["batched_2D"] = self.sizes["3D"]
-
-
-class TestData(object):
- def __init__(self):
- self.data = ascent().astype("float32")
- self.data1d = self.data[:, 0] # non-contiguous data
- self.data3d = np.tile(self.data[:64, :64], (64, 1, 1))
- self.data_refs = {
- 1: self.data1d,
- 2: self.data,
- 3: self.data3d,
- }
-
-
-@unittest.skipUnless(__have_scipy, "scipy is missing")
-class TestFFT(ParametricTestCase):
- """Test cuda/opencl/fftw backends of FFT"""
-
- def setUp(self):
- self.tol = {
- np.dtype("float32"): 1e-3,
- np.dtype("float64"): 1e-9,
- np.dtype("complex64"): 1e-3,
- np.dtype("complex128"): 1e-9,
- }
- self.transform_infos = TransformInfos()
- self.test_data = TestData()
-
- @staticmethod
- def calc_mae(arr1, arr2):
- """
- Compute the Max Absolute Error between two arrays
- """
- return np.max(np.abs(arr1 - arr2))
-
- @unittest.skipIf(not __have_cufft__,
- "cuda back-end requires pycuda and scikit-cuda")
- def test_cuda(self):
- import pycuda.autoinit
-
- # Error is higher when using cuda. fast_math mode ?
- self.tol[np.dtype("float32")] *= 2
-
- self.__run_tests(backend="cuda")
-
- @unittest.skipIf(not __have_clfft__,
- "opencl back-end requires pyopencl and gpyfft")
- def test_opencl(self):
- from silx.opencl.common import ocl
- if ocl is not None:
- self.__run_tests(backend="opencl", ctx=ocl.create_context())
-
- @unittest.skipIf(not __have_fftw__,
- "fftw back-end requires pyfftw")
- def test_fftw(self):
- self.__run_tests(backend="fftw")
-
- def __run_tests(self, backend, **extra_args):
- """Run all tests with the given backend
-
- :param str backend:
- :param dict extra_args: Additional arguments to provide to FFT
- """
- for trdim in self.transform_infos.dimensions:
- for mode in self.transform_infos.modes:
- for size in self.transform_infos.sizes[trdim]:
- with self.subTest(trdim=trdim, mode=mode, size=size):
- self.__test(backend, trdim, mode, size, **extra_args)
-
- def __test(self, backend, trdim, mode, size, **extra_args):
- """Compare given backend with numpy for given conditions"""
- logger.debug("backend: %s, trdim: %s, mode: %s, size: %s",
- backend, trdim, mode, str(size))
- if size == "3D" and test_options.TEST_LOW_MEM:
- self.skipTest("low mem")
-
- ndim = len(size)
- input_data = self.test_data.data_refs[ndim].astype(
- self.transform_infos.modes[mode])
- tol = self.tol[np.dtype(input_data.dtype)]
- if trdim == "3D":
- tol *= 10 # Error is relatively high in high dimensions
- # It seems that cuda has problems with C2D batched 1D
- if trdim == "batched_1D" and backend == "cuda" and mode == "C2C":
- tol *= 10
-
- # Python < 3.5 does not want to mix **extra_args with existing kwargs
- fft_args = {
- "template": input_data,
- "axes": self.transform_infos.axes[trdim],
- "backend": backend,
- }
- fft_args.update(extra_args)
- F = FFT(
- **fft_args
- )
- F_np = FFT(
- template=input_data,
- axes=self.transform_infos.axes[trdim],
- backend="numpy"
- )
-
- # Forward FFT
- res = F.fft(input_data)
- res_np = F_np.fft(input_data)
- mae = self.calc_mae(res, res_np)
- all_close = np.allclose(res, res_np, atol=tol, rtol=tol),
- self.assertTrue(
- all_close,
- "FFT %s:%s, MAE(%s, numpy) = %f (tol = %.2e)" % (mode, trdim, backend, mae, tol)
- )
-
- # Inverse FFT
- res2 = F.ifft(res)
- mae = self.calc_mae(res2, input_data)
- self.assertTrue(
- mae < tol,
- "IFFT %s:%s, MAE(%s, numpy) = %f" % (mode, trdim, backend, mae)
- )
-
-
-@unittest.skipUnless(__have_scipy, "scipy is missing")
-class TestNumpyFFT(ParametricTestCase):
- """
- Test the Numpy backend individually.
- """
-
- def setUp(self):
- transforms = {
- "1D": {
- True: (np.fft.rfft, np.fft.irfft),
- False: (np.fft.fft, np.fft.ifft),
- },
- "2D": {
- True: (np.fft.rfft2, np.fft.irfft2),
- False: (np.fft.fft2, np.fft.ifft2),
- },
- "3D": {
- True: (np.fft.rfftn, np.fft.irfftn),
- False: (np.fft.fftn, np.fft.ifftn),
- },
- }
- transforms["batched_1D"] = transforms["1D"]
- transforms["batched_2D"] = transforms["2D"]
- self.transforms = transforms
- self.transform_infos = TransformInfos()
- self.test_data = TestData()
-
- def test(self):
- """Test the numpy backend against native fft.
-
- Results should be exactly the same.
- """
- for trdim in self.transform_infos.dimensions:
- for mode in self.transform_infos.modes:
- for size in self.transform_infos.sizes[trdim]:
- with self.subTest(trdim=trdim, mode=mode, size=size):
- self.__test(trdim, mode, size)
-
- def __test(self, trdim, mode, size):
- logger.debug("trdim: %s, mode: %s, size: %s", trdim, mode, str(size))
- ndim = len(size)
- input_data = self.test_data.data_refs[ndim].astype(
- self.transform_infos.modes[mode])
- np_fft, np_ifft = self.transforms[trdim][np.isrealobj(input_data)]
-
- F = FFT(
- template=input_data,
- axes=self.transform_infos.axes[trdim],
- backend="numpy"
- )
- # Test FFT
- res = F.fft(input_data)
- ref = np_fft(input_data)
- self.assertTrue(np.allclose(res, ref))
-
- # Test IFFT
- res2 = F.ifft(res)
- ref2 = np_ifft(ref)
- self.assertTrue(np.allclose(res2, ref2))
-
-
-def suite():
- suite = unittest.TestSuite()
- for cls in (TestNumpyFFT, TestFFT):
- suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(cls))
- return suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
-
-
diff --git a/silx/math/fit/fitmanager.py b/silx/math/fit/fitmanager.py
deleted file mode 100644
index b60e073..0000000
--- a/silx/math/fit/fitmanager.py
+++ /dev/null
@@ -1,1087 +0,0 @@
-# coding: utf-8
-# /*#########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ##########################################################################*/
-"""
-This module provides a tool to perform advanced fitting. The actual fit relies
-on :func:`silx.math.fit.leastsq`.
-
-This module deals with:
-
- - handling of the model functions (using a set of default functions or
- loading custom user functions)
- - handling of estimation function, that are used to determine the number
- of parameters to be fitted for functions with unknown number of
- parameters (such as the sum of a variable number of gaussian curves),
- and find reasonable initial parameters for input to the iterative
- fitting algorithm
- - handling of custom derivative functions that can be passed as a
- parameter to :func:`silx.math.fit.leastsq`
- - providing different background models
-
-"""
-from collections import OrderedDict
-import logging
-import numpy
-from numpy.linalg.linalg import LinAlgError
-import os
-import sys
-
-from .filters import strip, smooth1d
-from .leastsq import leastsq
-from .fittheory import FitTheory
-from . import bgtheories
-
-
-__authors__ = ["V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "16/01/2017"
-
-_logger = logging.getLogger(__name__)
-
-
-class FitManager(object):
- """
- Fit functions manager
-
- :param x: Abscissa data. If ``None``, :attr:`xdata` is set to
- ``numpy.array([0.0, 1.0, 2.0, ..., len(y)-1])``
- :type x: Sequence or numpy array or None
- :param y: The dependant data ``y = f(x)``. ``y`` must have the same
- shape as ``x`` if ``x`` is not ``None``.
- :type y: Sequence or numpy array or None
- :param sigmay: The uncertainties in the ``ydata`` array. These can be
- used as weights in the least-squares problem, if ``weight_flag``
- is ``True``.
- If ``None``, the uncertainties are assumed to be 1, unless
- ``weight_flag`` is ``True``, in which case the square-root
- of ``y`` is used.
- :type sigmay: Sequence or numpy array or None
- :param weight_flag: If this parameter is ``True`` and ``sigmay``
- uncertainties are not specified, the square root of ``y`` is used
- as weights in the least-squares problem. If ``False``, the
- uncertainties are set to 1.
- :type weight_flag: boolean
- """
- def __init__(self, x=None, y=None, sigmay=None, weight_flag=False):
- """
- """
- self.fitconfig = {
- 'WeightFlag': weight_flag,
- 'fitbkg': 'No Background',
- 'fittheory': None,
- # Next few parameters are defined for compatibility with legacy theories
- # which take the background as argument for their estimation function
- 'StripWidth': 2,
- 'StripIterations': 5000,
- 'StripThresholdFactor': 1.0,
- 'SmoothingFlag': False
- }
- """Dictionary of fit configuration parameters.
- These parameters can be modified using the :meth:`configure` method.
-
- Keys are:
-
- - 'fitbkg': name of the function used for fitting a low frequency
- background signal
- - 'FwhmPoints': default full width at half maximum value for the
- peaks'.
- - 'Sensitivity': Sensitivity parameter for the peak detection
- algorithm (:func:`silx.math.fit.peak_search`)
- """
-
- self.theories = OrderedDict()
- """Dictionary of fit theories, defining functions to be fitted
- to individual peaks.
-
- Keys are descriptive theory names (e.g "Gaussians" or "Step up").
- Values are :class:`silx.math.fit.fittheory.FitTheory` objects with
- the following attributes:
-
- - *"function"* is the fit function for an individual peak
- - *"parameters"* is a sequence of parameter names
- - *"estimate"* is the parameter estimation function
- - *"configure"* is the function returning the configuration dict
- for the theory in the format described in the :attr:` fitconfig`
- documentation
- - *"derivative"* (optional) is a custom derivative function, whose
- signature is described in the documentation of
- :func:`silx.math.fit.leastsq.leastsq`
- (``model_deriv(xdata, parameters, index)``).
- - *"description"* is a description string
- """
-
- self.selectedtheory = None
- """Name of currently selected theory. This name matches a key in
- :attr:`theories`."""
-
- self.bgtheories = OrderedDict()
- """Dictionary of background theories.
-
- See :attr:`theories` for documentation on theories.
- """
-
- # Load default theories (constant, linear, strip)
- self.loadbgtheories(bgtheories)
-
- self.selectedbg = 'No Background'
- """Name of currently selected background theory. This name must be
- an existing key in :attr:`bgtheories`."""
-
- self.fit_results = []
- """This list stores detailed information about all fit parameters.
- It is initialized in :meth:`estimate` and completed with final fit
- values in :meth:`runfit`.
-
- Each fit parameter is stored as a dictionary with following fields:
-
- - 'name': Parameter name.
- - 'estimation': Estimated value.
- - 'group': Group number. Group 0 corresponds to the background
- function parameters. Group ``n`` (for ``n>0``) corresponds to
- the fit function parameters for the n-th peak.
- - 'code': Constraint code
-
- - 0 - FREE
- - 1 - POSITIVE
- - 2 - QUOTED
- - 3 - FIXED
- - 4 - FACTOR
- - 5 - DELTA
- - 6 - SUM
-
- - 'cons1':
-
- - Ignored if 'code' is FREE, POSITIVE or FIXED.
- - Min value of the parameter if code is QUOTED
- - Index of fitted parameter to which 'cons2' is related
- if code is FACTOR, DELTA or SUM.
-
- - 'cons2':
-
- - Ignored if 'code' is FREE, POSITIVE or FIXED.
- - Max value of the parameter if QUOTED
- - Factor to apply to related parameter with index 'cons1' if
- 'code' is FACTOR
- - Difference with parameter with index 'cons1' if
- 'code' is DELTA
- - Sum obtained when adding parameter with index 'cons1' if
- 'code' is SUM
-
- - 'fitresult': Fitted value.
- - 'sigma': Standard deviation for the parameter estimate
- - 'xmin': Lower limit of the ``x`` data range on which the fit
- was performed
- - 'xmax': Upeer limit of the ``x`` data range on which the fit
- was performed
- """
-
- self.parameter_names = []
- """This list stores all fit parameter names: background function
- parameters and fit function parameters for every peak. It is filled
- in :meth:`estimate`.
-
- It is the responsibility of the estimate function defined in
- :attr:`theories` to determine how many parameters are needed,
- based on how many peaks are detected and how many parameters are needed
- to fit an individual peak.
- """
-
- self.setdata(x, y, sigmay)
-
- ##################
- # Public methods #
- ##################
- def addbackground(self, bgname, bgtheory):
- """Add a new background theory to dictionary :attr:`bgtheories`.
-
- :param bgname: String with the name describing the function
- :param bgtheory: :class:`FitTheory` object
- :type bgtheory: :class:`silx.math.fit.fittheory.FitTheory`
- """
- self.bgtheories[bgname] = bgtheory
-
- def addtheory(self, name, theory=None,
- function=None, parameters=None,
- estimate=None, configure=None, derivative=None,
- description=None, pymca_legacy=False):
- """Add a new theory to dictionary :attr:`theories`.
-
- You can pass a name and a :class:`FitTheory` object as arguments, or
- alternatively provide all arguments necessary to instantiate a new
- :class:`FitTheory` object.
-
- See :meth:`loadtheories` for more information on estimation functions,
- configuration functions and custom derivative functions.
-
- :param name: String with the name describing the function
- :param theory: :class:`FitTheory` object, defining a fit function and
- associated information (estimation function, description…).
- If this parameter is provided, all other parameters, except for
- ``name``, are ignored.
- :type theory: :class:`silx.math.fit.fittheory.FitTheory`
- :param callable function: Mandatory argument if ``theory`` is not provided.
- See documentation for :attr:`silx.math.fit.fittheory.FitTheory.function`.
- :param List[str] parameters: Mandatory argument if ``theory`` is not provided.
- See documentation for :attr:`silx.math.fit.fittheory.FitTheory.parameters`.
- :param callable estimate: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.estimate`
- :param callable configure: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.configure`
- :param callable derivative: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.derivative`
- :param str description: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.description`
- :param config_widget: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.config_widget`
- :param bool pymca_legacy: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.pymca_legacy`
- """
- if theory is not None:
- self.theories[name] = theory
-
- elif function is not None and parameters is not None:
- self.theories[name] = FitTheory(
- description=description,
- function=function,
- parameters=parameters,
- estimate=estimate,
- configure=configure,
- derivative=derivative,
- pymca_legacy=pymca_legacy
- )
-
- else:
- raise TypeError("You must supply a FitTheory object or define " +
- "a fit function and its parameters.")
-
- def addbgtheory(self, name, theory=None,
- function=None, parameters=None,
- estimate=None, configure=None,
- derivative=None, description=None):
- """Add a new theory to dictionary :attr:`bgtheories`.
-
- You can pass a name and a :class:`FitTheory` object as arguments, or
- alternatively provide all arguments necessary to instantiate a new
- :class:`FitTheory` object.
-
- :param name: String with the name describing the function
- :param theory: :class:`FitTheory` object, defining a fit function and
- associated information (estimation function, description…).
- If this parameter is provided, all other parameters, except for
- ``name``, are ignored.
- :type theory: :class:`silx.math.fit.fittheory.FitTheory`
- :param function function: Mandatory argument if ``theory`` is not provided.
- See documentation for :attr:`silx.math.fit.fittheory.FitTheory.function`.
- :param list[str] parameters: Mandatory argument if ``theory`` is not provided.
- See documentation for :attr:`silx.math.fit.fittheory.FitTheory.parameters`.
- :param function estimate: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.estimate`
- :param function configure: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.configure`
- :param function derivative: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.derivative`
- :param str description: See documentation for
- :attr:`silx.math.fit.fittheory.FitTheory.description`
- """
- if theory is not None:
- self.bgtheories[name] = theory
-
- elif function is not None and parameters is not None:
- self.bgtheories[name] = FitTheory(
- description=description,
- function=function,
- parameters=parameters,
- estimate=estimate,
- configure=configure,
- derivative=derivative,
- is_background=True
- )
-
- else:
- raise TypeError("You must supply a FitTheory object or define " +
- "a background function and its parameters.")
-
- def configure(self, **kw):
- """Configure the current theory by filling or updating the
- :attr:`fitconfig` dictionary.
- Call the custom configuration function, if any. This allows the user
- to modify the behavior of the custom fit function or the custom
- estimate function.
-
- This methods accepts only named parameters. All ``**kw`` parameters
- are expected to be fields of :attr:`fitconfig` to be updated, unless
- they have a special meaning for the custom configuration function
- of the currently selected theory..
-
- This method returns the modified config dictionary returned by the
- custom configuration function.
- """
- # inspect **kw to find known keys, update them in self.fitconfig
- for key in self.fitconfig:
- if key in kw:
- self.fitconfig[key] = kw[key]
-
- # initialize dict with existing config dict
- result = {}
- result.update(self.fitconfig)
-
- if "WeightFlag" in kw:
- if kw["WeightFlag"]:
- self.enableweight()
- else:
- self.disableweight()
-
- if self.selectedtheory is None:
- return result
-
- # Apply custom configuration function
- custom_config_fun = self.theories[self.selectedtheory].configure
- if custom_config_fun is not None:
- result.update(custom_config_fun(**kw))
-
- custom_bg_config_fun = self.bgtheories[self.selectedbg].configure
- if custom_bg_config_fun is not None:
- result.update(custom_bg_config_fun(**kw))
-
- # Update self.fitconfig with custom config
- for key in self.fitconfig:
- if key in result:
- self.fitconfig[key] = result[key]
-
- result.update(self.fitconfig)
- return result
-
- def estimate(self, callback=None):
- """
- Fill :attr:`fit_results` with an estimation of the fit parameters.
-
- At first, the background parameters are estimated, if a background
- model has been specified.
- Then, a custom estimation function related to the model function is
- called.
-
- This process determines the number of needed fit parameters and
- provides an initial estimation for them, to serve as an input for the
- actual iterative fitting performed in :meth:`runfit`.
-
- :param callback: Optional callback function, conforming to the
- signature ``callback(data)`` with ``data`` being a dictionary.
- This callback function is called before and after the estimation
- process, and is given a dictionary containing the values of
- :attr:`state` (``'Estimate in progress'`` or ``'Ready to Fit'``)
- and :attr:`chisq`.
- This is used for instance in :mod:`silx.gui.fit.FitWidget` to
- update a widget displaying a status message.
- :return: Estimated parameters
- """
- self.state = 'Estimate in progress'
- self.chisq = None
-
- if callback is not None:
- callback(data={'chisq': self.chisq,
- 'status': self.state})
-
- CONS = {0: 'FREE',
- 1: 'POSITIVE',
- 2: 'QUOTED',
- 3: 'FIXED',
- 4: 'FACTOR',
- 5: 'DELTA',
- 6: 'SUM',
- 7: 'IGNORE'}
-
- # Filter-out not finite data
- xwork = self.xdata[self._finite_mask]
- ywork = self.ydata[self._finite_mask]
-
- # estimate the background
- bg_params, bg_constraints = self.estimate_bkg(xwork, ywork)
-
- # estimate the function
- try:
- fun_params, fun_constraints = self.estimate_fun(xwork, ywork)
- except LinAlgError:
- self.state = 'Estimate failed'
- if callback is not None:
- callback(data={'status': self.state})
- raise
-
- # build the names
- self.parameter_names = []
-
- for bg_param_name in self.bgtheories[self.selectedbg].parameters:
- self.parameter_names.append(bg_param_name)
-
- fun_param_names = self.theories[self.selectedtheory].parameters
- param_index, peak_index = 0, 0
- while param_index < len(fun_params):
- peak_index += 1
- for fun_param_name in fun_param_names:
- self.parameter_names.append(fun_param_name + "%d" % peak_index)
- param_index += 1
-
- self.fit_results = []
- nb_fun_params_per_group = len(fun_param_names)
- group_number = 0
- xmin = min(xwork)
- xmax = max(xwork)
- nb_bg_params = len(bg_params)
- for (pindex, pname) in enumerate(self.parameter_names):
- # First come background parameters
- if pindex < nb_bg_params:
- estimation_value = bg_params[pindex]
- constraint_code = CONS[int(bg_constraints[pindex][0])]
- cons1 = bg_constraints[pindex][1]
- cons2 = bg_constraints[pindex][2]
- # then come peak function parameters
- else:
- fun_param_index = pindex - nb_bg_params
-
- # increment group_number for each new fitted peak
- if (fun_param_index % nb_fun_params_per_group) == 0:
- group_number += 1
-
- estimation_value = fun_params[fun_param_index]
- constraint_code = CONS[int(fun_constraints[fun_param_index][0])]
- # cons1 is the index of another fit parameter. In the global
- # fit_results, we must adjust the index to account for the bg
- # params added to the start of the list.
- cons1 = fun_constraints[fun_param_index][1]
- if constraint_code in ["FACTOR", "DELTA", "SUM"]:
- cons1 += nb_bg_params
- cons2 = fun_constraints[fun_param_index][2]
-
- self.fit_results.append({'name': pname,
- 'estimation': estimation_value,
- 'group': group_number,
- 'code': constraint_code,
- 'cons1': cons1,
- 'cons2': cons2,
- 'fitresult': 0.0,
- 'sigma': 0.0,
- 'xmin': xmin,
- 'xmax': xmax})
-
- self.state = 'Ready to Fit'
- self.chisq = None
- self.niter = 0
-
- if callback is not None:
- callback(data={'chisq': self.chisq,
- 'status': self.state})
- return numpy.append(bg_params, fun_params)
-
- def fit(self):
- """Convenience method to call :meth:`estimate` followed by :meth:`runfit`.
-
- :return: Output of :meth:`runfit`"""
- self.estimate()
- return self.runfit()
-
- def gendata(self, x=None, paramlist=None, estimated=False):
- """Return a data array using the currently selected fit function
- and the fitted parameters.
-
- :param x: Independent variable where the function is calculated.
- If ``None``, use :attr:`xdata`.
- :param paramlist: List of dictionaries, each dictionary item being a
- fit parameter. The dictionary's format is documented in
- :attr:`fit_results`.
- If ``None`` (default), use parameters from :attr:`fit_results`.
- :param estimated: If *True*, use estimated parameters.
- :return: :meth:`fitfunction` calculated for parameters whose code is
- not set to ``"IGNORE"``.
-
- This calculates :meth:`fitfunction` on `x` data using fit parameters
- from a list of parameter dictionaries, if field ``code`` is not set
- to ``"IGNORE"``.
- """
- x = self.xdata if x is None else numpy.array(x, copy=False)
-
- if paramlist is None:
- paramlist = self.fit_results
- active_params = []
- for param in paramlist:
- if param['code'] not in ['IGNORE', 7]:
- if not estimated:
- active_params.append(param['fitresult'])
- else:
- active_params.append(param['estimation'])
-
- # Mask x with not finite (support nD x)
- finite_mask = numpy.all(numpy.isfinite(x), axis=tuple(range(1, x.ndim)))
-
- if numpy.all(finite_mask): # All values are finite: fast path
- return self.fitfunction(numpy.array(x, copy=True), *active_params)
-
- else: # Only run fitfunction on finite data and complete result with NaNs
- # Create result with same number as elements as x, filling holes with NaNs
- result = numpy.full((x.shape[0],), numpy.nan, dtype=numpy.float64)
- result[finite_mask] = self.fitfunction(
- numpy.array(x[finite_mask], copy=True), *active_params)
- return result
-
- def get_estimation(self):
- """Return the list of fit parameter names."""
- if self.state not in ["Ready to fit", "Fit in progress", "Ready"]:
- _logger.warning("get_estimation() called before estimate() completed")
- return [param["estimation"] for param in self.fit_results]
-
- def get_names(self):
- """Return the list of fit parameter estimations."""
- if self.state not in ["Ready to fit", "Fit in progress", "Ready"]:
- msg = "get_names() called before estimate() completed, "
- msg += "names are not populated at this stage"
- _logger.warning(msg)
- return [param["name"] for param in self.fit_results]
-
- def get_fitted_parameters(self):
- """Return the list of fitted parameters."""
- if self.state not in ["Ready"]:
- msg = "get_fitted_parameters() called before runfit() completed, "
- msg += "results are not available a this stage"
- _logger.warning(msg)
- return [param["fitresult"] for param in self.fit_results]
-
- def loadtheories(self, theories):
- """Import user defined fit functions defined in an external Python
- source file, and save them in :attr:`theories`.
-
- An example of such a file can be found in the sources of
- :mod:`silx.math.fit.fittheories`. It must contain a
- dictionary named ``THEORY`` with the following structure::
-
- THEORY = {
- 'theory_name_1':
- FitTheory(description='Description of theory 1',
- function=fitfunction1,
- parameters=('param name 1', 'param name 2', …),
- estimate=estimation_function1,
- configure=configuration_function1,
- derivative=derivative_function1),
- 'theory_name_2':
- FitTheory(…),
- }
-
- See documentation of :mod:`silx.math.fit.fittheories` and
- :mod:`silx.math.fit.fittheory` for more
- information on designing your fit functions file.
-
- This method can also load user defined functions in the legacy
- format used in *PyMca*.
-
- :param theories: Name of python source file, or module containing the
- definition of fit functions.
- :raise: ImportError if theories cannot be imported
- """
- from types import ModuleType
- if isinstance(theories, ModuleType):
- theories_module = theories
- else:
- # if theories is not a module, it must be a string
- string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
- if not isinstance(theories, string_types):
- raise ImportError("theory must be a python module, a module" +
- "name or a python filename")
- # if theories is a filename
- if os.path.isfile(theories):
- sys.path.append(os.path.dirname(theories))
- f = os.path.basename(os.path.splitext(theories)[0])
- theories_module = __import__(f)
- # if theories is a module name
- else:
- theories_module = __import__(theories)
-
- if hasattr(theories_module, "INIT"):
- theories.INIT()
-
- if not hasattr(theories_module, "THEORY"):
- msg = "File %s does not contain a THEORY dictionary" % theories
- raise ImportError(msg)
-
- elif isinstance(theories_module.THEORY, dict):
- # silx format for theory definition
- for theory_name, fittheory in list(theories_module.THEORY.items()):
- self.addtheory(theory_name, fittheory)
- else:
- self._load_legacy_theories(theories_module)
-
- def loadbgtheories(self, theories):
- """Import user defined background functions defined in an external Python
- module (source file), and save them in :attr:`theories`.
-
- An example of such a file can be found in the sources of
- :mod:`silx.math.fit.fittheories`. It must contain a
- dictionary named ``THEORY`` with the following structure::
-
- THEORY = {
- 'theory_name_1':
- FitTheory(description='Description of theory 1',
- function=fitfunction1,
- parameters=('param name 1', 'param name 2', …),
- estimate=estimation_function1,
- configure=configuration_function1,
- 'theory_name_2':
- FitTheory(…),
- }
-
- See documentation of :mod:`silx.math.fit.bgtheories` and
- :mod:`silx.math.fit.fittheory` for more
- information on designing your background functions file.
-
- :param theories: Module or name of python source file containing the
- definition of background functions.
- :raise: ImportError if theories cannot be imported
- """
- from types import ModuleType
- if isinstance(theories, ModuleType):
- theories_module = theories
- else:
- # if theories is not a module, it must be a string
- string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
- if not isinstance(theories, string_types):
- raise ImportError("theory must be a python module, a module" +
- "name or a python filename")
- # if theories is a filename
- if os.path.isfile(theories):
- sys.path.append(os.path.dirname(theories))
- f = os.path.basename(os.path.splitext(theories)[0])
- theories_module = __import__(f)
- # if theories is a module name
- else:
- theories_module = __import__(theories)
-
- if hasattr(theories_module, "INIT"):
- theories.INIT()
-
- if not hasattr(theories_module, "THEORY"):
- msg = "File %s does not contain a THEORY dictionary" % theories
- raise ImportError(msg)
-
- elif isinstance(theories_module.THEORY, dict):
- # silx format for theory definition
- for theory_name, fittheory in list(theories_module.THEORY.items()):
- self.addbgtheory(theory_name, fittheory)
-
- def setbackground(self, theory):
- """Choose a background type from within :attr:`bgtheories`.
-
- This updates :attr:`selectedbg`.
-
- :param theory: The name of the background to be used.
- :raise: KeyError if ``theory`` is not a key of :attr:`bgtheories``.
- """
- if theory in self.bgtheories:
- self.selectedbg = theory
- else:
- msg = "No theory with name %s in bgtheories.\n" % theory
- msg += "Available theories: %s\n" % self.bgtheories.keys()
- raise KeyError(msg)
-
- # run configure to apply our fitconfig to the selected theory
- # through its custom config function
- self.configure(**self.fitconfig)
-
- def setdata(self, x, y, sigmay=None, xmin=None, xmax=None):
- """Set data attributes:
-
- - ``xdata0``, ``ydata0`` and ``sigmay0`` store the initial data
- and uncertainties. These attributes are not modified after
- initialization.
- - ``xdata``, ``ydata`` and ``sigmay`` store the data after
- removing values where ``xdata < xmin`` or ``xdata > xmax``.
- These attributes may be modified at a latter stage by filters.
-
- :param x: Abscissa data. If ``None``, :attr:`xdata`` is set to
- ``numpy.array([0.0, 1.0, 2.0, ..., len(y)-1])``
- :type x: Sequence or numpy array or None
- :param y: The dependant data ``y = f(x)``. ``y`` must have the same
- shape as ``x`` if ``x`` is not ``None``.
- :type y: Sequence or numpy array or None
- :param sigmay: The uncertainties in the ``ydata`` array. These are
- used as weights in the least-squares problem.
- If ``None``, the uncertainties are assumed to be 1.
- :type sigmay: Sequence or numpy array or None
- :param xmin: Lower value of x values to use for fitting
- :param xmax: Upper value of x values to use for fitting
- """
- if y is None:
- self.xdata0 = numpy.array([], numpy.float64)
- self.ydata0 = numpy.array([], numpy.float64)
- # self.sigmay0 = numpy.array([], numpy.float64)
- self.xdata = numpy.array([], numpy.float64)
- self.ydata = numpy.array([], numpy.float64)
- # self.sigmay = numpy.array([], numpy.float64)
-
- else:
- self.ydata0 = numpy.array(y)
- self.ydata = numpy.array(y)
- if x is None:
- self.xdata0 = numpy.arange(len(self.ydata0))
- self.xdata = numpy.arange(len(self.ydata0))
- else:
- self.xdata0 = numpy.array(x)
- self.xdata = numpy.array(x)
-
- # default weight
- if sigmay is None:
- self.sigmay0 = None
- self.sigmay = numpy.sqrt(self.ydata) if self.fitconfig["WeightFlag"] else None
- else:
- self.sigmay0 = numpy.array(sigmay)
- self.sigmay = numpy.array(sigmay) if self.fitconfig["WeightFlag"] else None
-
- # take the data between limits, using boolean array indexing
- if (xmin is not None or xmax is not None) and len(self.xdata):
- xmin = xmin if xmin is not None else min(self.xdata)
- xmax = xmax if xmax is not None else max(self.xdata)
- bool_array = (self.xdata >= xmin) & (self.xdata <= xmax)
- self.xdata = self.xdata[bool_array]
- self.ydata = self.ydata[bool_array]
- self.sigmay = self.sigmay[bool_array] if sigmay is not None else None
-
- self._finite_mask = numpy.logical_and(
- numpy.all(numpy.isfinite(self.xdata), axis=tuple(range(1, self.xdata.ndim))),
- numpy.isfinite(self.ydata))
-
- def enableweight(self):
- """This method can be called to set :attr:`sigmay`. If :attr:`sigmay0` was filled with
- actual uncertainties in :meth:`setdata`, use these values.
- Else, use ``sqrt(self.ydata)``.
- """
- if self.sigmay0 is None:
- self.sigmay = numpy.sqrt(self.ydata) if self.fitconfig["WeightFlag"] else None
- else:
- self.sigmay = self.sigmay0
-
- def disableweight(self):
- """This method can be called to set :attr:`sigmay` equal to ``None``.
- As a result, :func:`leastsq` will consider that the weights in the
- least square problem are 1 for all samples."""
- self.sigmay = None
-
- def settheory(self, theory):
- """Pick a theory from :attr:`theories`.
-
- :param theory: Name of the theory to be used.
- :raise: KeyError if ``theory`` is not a key of :attr:`theories`.
- """
- if theory is None:
- self.selectedtheory = None
- elif theory in self.theories:
- self.selectedtheory = theory
- else:
- msg = "No theory with name %s in theories.\n" % theory
- msg += "Available theories: %s\n" % self.theories.keys()
- raise KeyError(msg)
-
- # run configure to apply our fitconfig to the selected theory
- # through its custom config function
- self.configure(**self.fitconfig)
-
- def runfit(self, callback=None):
- """Run the actual fitting and fill :attr:`fit_results` with fit results.
-
- Before running this method, :attr:`fit_results` must already be
- populated with a list of all parameters and their estimated values.
- For this, run :meth:`estimate` beforehand.
-
- :param callback: Optional callback function, conforming to the
- signature ``callback(data)`` with ``data`` being a dictionary.
- This callback function is called before and after the estimation
- process, and is given a dictionary containing the values of
- :attr:`state` (``'Fit in progress'`` or ``'Ready'``)
- and :attr:`chisq`.
- This is used for instance in :mod:`silx.gui.fit.FitWidget` to
- update a widget displaying a status message.
- :return: Tuple ``(fitted parameters, uncertainties, infodict)``.
- *infodict* is the dictionary returned by
- :func:`silx.math.fit.leastsq` when called with option
- ``full_output=True``. Uncertainties is a sequence of uncertainty
- values associated with each fitted parameter.
- """
- # self.dataupdate()
-
- self.state = 'Fit in progress'
- self.chisq = None
-
- if callback is not None:
- callback(data={'chisq': self.chisq,
- 'status': self.state})
-
- param_val = []
- param_constraints = []
- # Initial values are set to the ones computed in estimate()
- for param in self.fit_results:
- param_val.append(param['estimation'])
- param_constraints.append([param['code'], param['cons1'], param['cons2']])
-
- # Filter-out not finite data
- ywork = self.ydata[self._finite_mask]
- xwork = self.xdata[self._finite_mask]
-
- try:
- params, covariance_matrix, infodict = leastsq(
- self.fitfunction, # bg + actual model function
- xwork, ywork, param_val,
- sigma=self.sigmay,
- constraints=param_constraints,
- model_deriv=self.theories[self.selectedtheory].derivative,
- full_output=True, left_derivative=True)
- except LinAlgError:
- self.state = 'Fit failed'
- callback(data={'status': self.state})
- raise
-
- sigmas = infodict['uncertainties']
-
- for i, param in enumerate(self.fit_results):
- if param['code'] != 'IGNORE':
- param['fitresult'] = params[i]
- param['sigma'] = sigmas[i]
-
- self.chisq = infodict["reduced_chisq"]
- self.niter = infodict["niter"]
- self.state = 'Ready'
-
- if callback is not None:
- callback(data={'chisq': self.chisq,
- 'status': self.state})
-
- return params, sigmas, infodict
-
- ###################
- # Private methods #
- ###################
- def fitfunction(self, x, *pars):
- """Function to be fitted.
-
- This is the sum of the selected background function plus
- the selected fit model function.
-
- :param x: Independent variable where the function is calculated.
- :param pars: Sequence of all fit parameters. The first few parameters
- are background parameters, then come the peak function parameters.
- :return: Output of the fit function with ``x`` as input and ``pars``
- as fit parameters.
- """
- result = numpy.zeros(numpy.shape(x), numpy.float64)
-
- if self.selectedbg is not None:
- bg_pars_list = self.bgtheories[self.selectedbg].parameters
- nb_bg_pars = len(bg_pars_list)
-
- bgfun = self.bgtheories[self.selectedbg].function
- result += bgfun(x, self.ydata, *pars[0:nb_bg_pars])
- else:
- nb_bg_pars = 0
-
- selectedfun = self.theories[self.selectedtheory].function
- result += selectedfun(x, *pars[nb_bg_pars:])
-
- return result
-
- def estimate_bkg(self, x, y):
- """Estimate background parameters using the function defined in
- the current fit configuration.
-
- To change the selected background model, attribute :attr:`selectdbg`
- must be changed using method :meth:`setbackground`.
-
- The actual background function to be used is
- referenced in :attr:`bgtheories`
-
- :param x: Sequence of x data
- :param y: sequence of y data
- :return: Tuple of two sequences and one data array
- ``(estimated_param, constraints, bg_data)``:
-
- - ``estimated_param`` is a list of estimated values for each
- background parameter.
- - ``constraints`` is a 2D sequence of dimension ``(n_parameters, 3)``
-
- - ``constraints[i][0]``: Constraint code.
- See explanation about codes in :attr:`fit_results`
-
- - ``constraints[i][1]``
- See explanation about 'cons1' in :attr:`fit_results`
- documentation.
-
- - ``constraints[i][2]``
- See explanation about 'cons2' in :attr:`fit_results`
- documentation.
- """
- background_estimate_function = self.bgtheories[self.selectedbg].estimate
- if background_estimate_function is not None:
- return background_estimate_function(x, y)
- else:
- return [], []
-
- def estimate_fun(self, x, y):
- """Estimate fit parameters using the function defined in
- the current fit configuration.
-
- :param x: Sequence of x data
- :param y: sequence of y data
- :param bg: Background signal, to be subtracted from ``y`` before fitting.
- :return: Tuple of two sequences ``(estimated_param, constraints)``:
-
- - ``estimated_param`` is a list of estimated values for each
- background parameter.
- - ``constraints`` is a 2D sequence of dimension (n_parameters, 3)
-
- - ``constraints[i][0]``: Constraint code.
- See explanation about codes in :attr:`fit_results`
-
- - ``constraints[i][1]``
- See explanation about 'cons1' in :attr:`fit_results`
- documentation.
-
- - ``constraints[i][2]``
- See explanation about 'cons2' in :attr:`fit_results`
- documentation.
- :raise: ``TypeError`` if estimation function is not callable
-
- """
- estimatefunction = self.theories[self.selectedtheory].estimate
- if hasattr(estimatefunction, '__call__'):
- if not self.theories[self.selectedtheory].pymca_legacy:
- return estimatefunction(x, y)
- else:
- # legacy pymca estimate functions have a different signature
- if self.fitconfig["fitbkg"] == "No Background":
- bg = numpy.zeros_like(y)
- else:
- if self.fitconfig["SmoothingFlag"]:
- y = smooth1d(y)
- bg = strip(y,
- w=self.fitconfig["StripWidth"],
- niterations=self.fitconfig["StripIterations"],
- factor=self.fitconfig["StripThresholdFactor"])
- # fitconfig can be filled by user defined config function
- xscaling = self.fitconfig.get('Xscaling', 1.0)
- yscaling = self.fitconfig.get('Yscaling', 1.0)
- return estimatefunction(x, y, bg, xscaling, yscaling)
- else:
- raise TypeError("Estimation function in attribute " +
- "theories[%s]" % self.selectedtheory +
- " must be callable.")
-
- def _load_legacy_theories(self, theories_module):
- """Load theories from a custom module in the old PyMca format.
-
- See PyMca5.PyMcaMath.fitting.SpecfitFunctions for an example.
- """
- mandatory_attributes = ["THEORY", "PARAMETERS",
- "FUNCTION", "ESTIMATE"]
- err_msg = "Custom fit function file must define: "
- err_msg += ", ".join(mandatory_attributes)
- for attr in mandatory_attributes:
- if not hasattr(theories_module, attr):
- raise ImportError(err_msg)
-
- derivative = theories_module.DERIVATIVE if hasattr(theories_module, "DERIVATIVE") else None
- configure = theories_module.CONFIGURE if hasattr(theories_module, "CONFIGURE") else None
- estimate = theories_module.ESTIMATE if hasattr(theories_module, "ESTIMATE") else None
- if isinstance(theories_module.THEORY, (list, tuple)):
- # multiple fit functions
- for i in range(len(theories_module.THEORY)):
- deriv = derivative[i] if derivative is not None else None
- config = configure[i] if configure is not None else None
- estim = estimate[i] if estimate is not None else None
- self.addtheory(theories_module.THEORY[i],
- FitTheory(
- theories_module.FUNCTION[i],
- theories_module.PARAMETERS[i],
- estim,
- config,
- deriv,
- pymca_legacy=True))
- else:
- # single fit function
- self.addtheory(theories_module.THEORY,
- FitTheory(
- theories_module.FUNCTION,
- theories_module.PARAMETERS,
- estimate,
- configure,
- derivative,
- pymca_legacy=True))
-
-
-def test():
- from .functions import sum_gauss
- from . import fittheories
- from . import bgtheories
-
- # Create synthetic data with a sum of gaussian functions
- x = numpy.arange(1000).astype(numpy.float64)
-
- p = [1000, 100., 250,
- 255, 690., 45,
- 1500, 800.5, 95]
- y = 0.5 * x + 13 + sum_gauss(x, *p)
-
- # Fitting
- fit = FitManager()
- # more sensitivity necessary to resolve
- # overlapping peaks at x=690 and x=800.5
- fit.setdata(x=x, y=y)
- fit.loadtheories(fittheories)
- fit.settheory('Gaussians')
- fit.loadbgtheories(bgtheories)
- fit.setbackground('Linear')
- fit.estimate()
- fit.runfit()
-
- print("Searched parameters = ", p)
- print("Obtained parameters : ")
- dummy_list = []
- for param in fit.fit_results:
- print(param['name'], ' = ', param['fitresult'])
- dummy_list.append(param['fitresult'])
- print("chisq = ", fit.chisq)
-
- # Plot
- constant, slope = dummy_list[:2]
- p1 = dummy_list[2:]
- print(p1)
- y2 = slope * x + constant + sum_gauss(x, *p1)
-
- try:
- from silx.gui import qt
- from silx.gui.plot.PlotWindow import PlotWindow
- app = qt.QApplication([])
- pw = PlotWindow(control=True)
- pw.addCurve(x, y, "Original")
- pw.addCurve(x, y2, "Fit result")
- pw.legendsDockWidget.show()
- pw.show()
- app.exec_()
- except ImportError:
- _logger.warning("Could not import qt to display fit result as curve")
-
-
-if __name__ == "__main__":
- test()
diff --git a/silx/math/fit/fittheories.py b/silx/math/fit/fittheories.py
deleted file mode 100644
index 6b19a38..0000000
--- a/silx/math/fit/fittheories.py
+++ /dev/null
@@ -1,1374 +0,0 @@
-# coding: utf-8
-#/*##########################################################################
-#
-# Copyright (c) 2004-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-########################################################################### */
-"""This modules provides a set of fit functions and associated
-estimation functions in a format that can be imported into a
-:class:`silx.math.fit.FitManager` instance.
-
-These functions are well suited for fitting multiple gaussian shaped peaks
-typically found in spectroscopy data. The estimation functions are designed
-to detect how many peaks are present in the data, and provide an initial
-estimate for their height, their center location and their full-width
-at half maximum (fwhm).
-
-The limitation of these estimation algorithms is that only gaussians having a
-similar fwhm can be detected by the peak search algorithm.
-This *search fwhm* can be defined by the user, if
-he knows the characteristics of his data, or can be automatically estimated
-based on the fwhm of the largest peak in the data.
-
-The source code of this module can serve as template for defining your own
-fit functions.
-
-The functions to be imported by :meth:`FitManager.loadtheories` are defined by
-a dictionary :const:`THEORY`: with the following structure::
-
- from silx.math.fit.fittheory import FitTheory
-
- THEORY = {
- 'theory_name_1': FitTheory(
- description='Description of theory 1',
- function=fitfunction1,
- parameters=('param name 1', 'param name 2', …),
- estimate=estimation_function1,
- configure=configuration_function1,
- derivative=derivative_function1),
-
- 'theory_name_2': FitTheory(…),
- }
-
-.. note::
-
- Consider using an OrderedDict instead of a regular dictionary, when
- defining your own theory dictionary, if the order matters to you.
- This will likely be the case if you intend to load a selection of
- functions in a GUI such as :class:`silx.gui.fit.FitManager`.
-
-Theory names can be customized (e.g. ``gauss, lorentz, splitgauss``…).
-
-The mandatory parameters for :class:`FitTheory` are ``function`` and
-``parameters``.
-
-You can also define an ``INIT`` function that will be executed by
-:meth:`FitManager.loadtheories`.
-
-See the documentation of :class:`silx.math.fit.fittheory.FitTheory`
-for more information.
-
-Module members:
----------------
-"""
-import numpy
-from collections import OrderedDict
-import logging
-
-from silx.math.fit import functions
-from silx.math.fit.peaks import peak_search, guess_fwhm
-from silx.math.fit.filters import strip, savitsky_golay
-from silx.math.fit.leastsq import leastsq
-from silx.math.fit.fittheory import FitTheory
-
-_logger = logging.getLogger(__name__)
-
-__authors__ = ["V.A. Sole", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "15/05/2017"
-
-
-DEFAULT_CONFIG = {
- 'NoConstraintsFlag': False,
- 'PositiveFwhmFlag': True,
- 'PositiveHeightAreaFlag': True,
- 'SameFwhmFlag': False,
- 'QuotedPositionFlag': False, # peak not outside data range
- 'QuotedEtaFlag': False, # force 0 < eta < 1
- # Peak detection
- 'AutoScaling': False,
- 'Yscaling': 1.0,
- 'FwhmPoints': 8,
- 'AutoFwhm': True,
- 'Sensitivity': 2.5,
- 'ForcePeakPresence': True,
- # Hypermet
- 'HypermetTails': 15,
- 'QuotedFwhmFlag': 0,
- 'MaxFwhm2InputRatio': 1.5,
- 'MinFwhm2InputRatio': 0.4,
- # short tail parameters
- 'MinGaussArea4ShortTail': 50000.,
- 'InitialShortTailAreaRatio': 0.050,
- 'MaxShortTailAreaRatio': 0.100,
- 'MinShortTailAreaRatio': 0.0010,
- 'InitialShortTailSlopeRatio': 0.70,
- 'MaxShortTailSlopeRatio': 2.00,
- 'MinShortTailSlopeRatio': 0.50,
- # long tail parameters
- 'MinGaussArea4LongTail': 1000.0,
- 'InitialLongTailAreaRatio': 0.050,
- 'MaxLongTailAreaRatio': 0.300,
- 'MinLongTailAreaRatio': 0.010,
- 'InitialLongTailSlopeRatio': 20.0,
- 'MaxLongTailSlopeRatio': 50.0,
- 'MinLongTailSlopeRatio': 5.0,
- # step tail
- 'MinGaussHeight4StepTail': 5000.,
- 'InitialStepTailHeightRatio': 0.002,
- 'MaxStepTailHeightRatio': 0.0100,
- 'MinStepTailHeightRatio': 0.0001,
- # Hypermet constraints
- # position in range [estimated position +- estimated fwhm/2]
- 'HypermetQuotedPositionFlag': True,
- 'DeltaPositionFwhmUnits': 0.5,
- 'SameSlopeRatioFlag': 1,
- 'SameAreaRatioFlag': 1,
- # Strip bg removal
- 'StripBackgroundFlag': True,
- 'SmoothingFlag': True,
- 'SmoothingWidth': 5,
- 'StripWidth': 2,
- 'StripIterations': 5000,
- 'StripThresholdFactor': 1.0}
-"""This dictionary defines default configuration parameters that have effects
-on fit functions and estimation functions, mainly on fit constraints.
-This dictionary is accessible as attribute :attr:`FitTheories.config`,
-which can be modified by configuration functions defined in
-:const:`CONFIGURE`.
-"""
-
-CFREE = 0
-CPOSITIVE = 1
-CQUOTED = 2
-CFIXED = 3
-CFACTOR = 4
-CDELTA = 5
-CSUM = 6
-CIGNORED = 7
-
-
-class FitTheories(object):
- """Class wrapping functions from :class:`silx.math.fit.functions`
- and providing estimate functions for all of these fit functions."""
- def __init__(self, config=None):
- if config is None:
- self.config = DEFAULT_CONFIG
- else:
- self.config = config
-
- def ahypermet(self, x, *pars):
- """
- Wrapping of :func:`silx.math.fit.functions.sum_ahypermet` without
- the tail flags in the function signature.
-
- Depending on the value of `self.config['HypermetTails']`, one can
- activate or deactivate the various terms of the hypermet function.
-
- `self.config['HypermetTails']` must be an integer between 0 and 15.
- It is a set of 4 binary flags, one for activating each one of the
- hypermet terms: *gaussian function, short tail, long tail, step*.
-
- For example, 15 can be expressed as ``1111`` in base 2, so a flag of
- 15 means all terms are active.
- """
- g_term = self.config['HypermetTails'] & 1
- st_term = (self.config['HypermetTails'] >> 1) & 1
- lt_term = (self.config['HypermetTails'] >> 2) & 1
- step_term = (self.config['HypermetTails'] >> 3) & 1
- return functions.sum_ahypermet(x, *pars,
- gaussian_term=g_term, st_term=st_term,
- lt_term=lt_term, step_term=step_term)
-
- def poly(self, x, *pars):
- """Order n polynomial.
- The order of the polynomial is defined by the number of
- coefficients (``*pars``).
-
- """
- p = numpy.poly1d(pars)
- return p(x)
-
- @staticmethod
- def estimate_poly(x, y, n=2):
- """Estimate polynomial coefficients for a degree n polynomial.
-
- """
- pcoeffs = numpy.polyfit(x, y, n)
- constraints = numpy.zeros((n + 1, 3), numpy.float64)
- return pcoeffs, constraints
-
- def estimate_quadratic(self, x, y):
- """Estimate quadratic coefficients
-
- """
- return self.estimate_poly(x, y, n=2)
-
- def estimate_cubic(self, x, y):
- """Estimate coefficients for a degree 3 polynomial
-
- """
- return self.estimate_poly(x, y, n=3)
-
- def estimate_quartic(self, x, y):
- """Estimate coefficients for a degree 4 polynomial
-
- """
- return self.estimate_poly(x, y, n=4)
-
- def estimate_quintic(self, x, y):
- """Estimate coefficients for a degree 5 polynomial
-
- """
- return self.estimate_poly(x, y, n=5)
-
- def strip_bg(self, y):
- """Return the strip background of y, using parameters from
- :attr:`config` dictionary (*StripBackgroundFlag, StripWidth,
- StripIterations, StripThresholdFactor*)"""
- remove_strip_bg = self.config.get('StripBackgroundFlag', False)
- if remove_strip_bg:
- if self.config['SmoothingFlag']:
- y = savitsky_golay(y, self.config['SmoothingWidth'])
- strip_width = self.config['StripWidth']
- strip_niterations = self.config['StripIterations']
- strip_thr_factor = self.config['StripThresholdFactor']
- return strip(y, w=strip_width,
- niterations=strip_niterations,
- factor=strip_thr_factor)
- else:
- return numpy.zeros_like(y)
-
- def guess_yscaling(self, y):
- """Estimate scaling for y prior to peak search.
- A smoothing filter is applied to y to estimate the noise level
- (chi-squared)
-
- :param y: Data array
- :return: Scaling factor
- """
- # ensure y is an array
- yy = numpy.array(y, copy=False)
-
- # smooth
- convolution_kernel = numpy.ones(shape=(3,)) / 3.
- ysmooth = numpy.convolve(y, convolution_kernel, mode="same")
-
- # remove zeros
- idx_array = numpy.fabs(y) > 0.0
- yy = yy[idx_array]
- ysmooth = ysmooth[idx_array]
-
- # compute scaling factor
- chisq = numpy.mean((yy - ysmooth)**2 / numpy.fabs(yy))
- if chisq > 0:
- return 1. / chisq
- else:
- return 1.0
-
- def peak_search(self, y, fwhm, sensitivity):
- """Search for peaks in y array, after padding the array and
- multiplying its value by a scaling factor.
-
- :param y: 1-D data array
- :param int fwhm: Typical full width at half maximum for peaks,
- in number of points. This parameter is used for to discriminate between
- true peaks and background fluctuations.
- :param float sensitivity: Sensitivity parameter. This is a threshold factor
- for peak detection. Only peaks larger than the standard deviation
- of the noise multiplied by this sensitivity parameter are detected.
- :return: List of peak indices
- """
- # add padding
- ysearch = numpy.ones((len(y) + 2 * fwhm,), numpy.float64)
- ysearch[0:fwhm] = y[0]
- ysearch[-1:-fwhm - 1:-1] = y[len(y)-1]
- ysearch[fwhm:fwhm + len(y)] = y[:]
-
- scaling = self.guess_yscaling(y) if self.config["AutoScaling"] else self.config["Yscaling"]
-
- if len(ysearch) > 1.5 * fwhm:
- peaks = peak_search(scaling * ysearch,
- fwhm=fwhm, sensitivity=sensitivity)
- return [peak_index - fwhm for peak_index in peaks
- if 0 <= peak_index - fwhm < len(y)]
- else:
- return []
-
- def estimate_height_position_fwhm(self, x, y):
- """Estimation of *Height, Position, FWHM* of peaks, for gaussian-like
- curves.
-
- This functions finds how many parameters are needed, based on the
- number of peaks detected. Then it estimates the fit parameters
- with a few iterations of fitting gaussian functions.
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each peak are:
- *Height, Position, FWHM*.
- Fit constraints depend on :attr:`config`.
- """
- fittedpar = []
-
- bg = self.strip_bg(y)
-
- if self.config['AutoFwhm']:
- search_fwhm = guess_fwhm(y)
- else:
- search_fwhm = int(float(self.config['FwhmPoints']))
- search_sens = float(self.config['Sensitivity'])
-
- if search_fwhm < 3:
- _logger.warning("Setting peak fwhm to 3 (lower limit)")
- search_fwhm = 3
- self.config['FwhmPoints'] = 3
-
- if search_sens < 1:
- _logger.warning("Setting peak search sensitivity to 1. " +
- "(lower limit to filter out noise peaks)")
- search_sens = 1
- self.config['Sensitivity'] = 1
-
- npoints = len(y)
-
- # Find indices of peaks in data array
- peaks = self.peak_search(y,
- fwhm=search_fwhm,
- sensitivity=search_sens)
-
- if not len(peaks):
- forcepeak = int(float(self.config.get('ForcePeakPresence', 0)))
- if forcepeak:
- delta = y - bg
- # get index of global maximum
- # (first one if several samples are equal to this value)
- peaks = [numpy.nonzero(delta == delta.max())[0][0]]
-
- # Find index of largest peak in peaks array
- index_largest_peak = 0
- if len(peaks) > 0:
- # estimate fwhm as 5 * sampling interval
- sig = 5 * abs(x[npoints - 1] - x[0]) / npoints
- peakpos = x[int(peaks[0])]
- if abs(peakpos) < 1.0e-16:
- peakpos = 0.0
- param = numpy.array(
- [y[int(peaks[0])] - bg[int(peaks[0])], peakpos, sig])
- height_largest_peak = param[0]
- peak_index = 1
- for i in peaks[1:]:
- param2 = numpy.array(
- [y[int(i)] - bg[int(i)], x[int(i)], sig])
- param = numpy.concatenate((param, param2))
- if param2[0] > height_largest_peak:
- height_largest_peak = param2[0]
- index_largest_peak = peak_index
- peak_index += 1
-
- # Subtract background
- xw = x
- yw = y - bg
-
- cons = numpy.zeros((len(param), 3), numpy.float64)
-
- # peak height must be positive
- cons[0:len(param):3, 0] = CPOSITIVE
- # force peaks to stay around their position
- cons[1:len(param):3, 0] = CQUOTED
-
- # set possible peak range to estimated peak +- guessed fwhm
- if len(xw) > search_fwhm:
- fwhmx = numpy.fabs(xw[int(search_fwhm)] - xw[0])
- cons[1:len(param):3, 1] = param[1:len(param):3] - 0.5 * fwhmx
- cons[1:len(param):3, 2] = param[1:len(param):3] + 0.5 * fwhmx
- else:
- shape = [max(1, int(x)) for x in (param[1:len(param):3])]
- cons[1:len(param):3, 1] = min(xw) * numpy.ones(
- shape,
- numpy.float64)
- cons[1:len(param):3, 2] = max(xw) * numpy.ones(
- shape,
- numpy.float64)
-
- # ensure fwhm is positive
- cons[2:len(param):3, 0] = CPOSITIVE
-
- # run a quick iterative fit (4 iterations) to improve
- # estimations
- fittedpar, _, _ = leastsq(functions.sum_gauss, xw, yw, param,
- max_iter=4, constraints=cons.tolist(),
- full_output=True)
-
- # set final constraints based on config parameters
- cons = numpy.zeros((len(fittedpar), 3), numpy.float64)
- peak_index = 0
- for i in range(len(peaks)):
- # Setup height area constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['PositiveHeightAreaFlag']:
- cons[peak_index, 0] = CPOSITIVE
- cons[peak_index, 1] = 0
- cons[peak_index, 2] = 0
- peak_index += 1
-
- # Setup position constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['QuotedPositionFlag']:
- cons[peak_index, 0] = CQUOTED
- cons[peak_index, 1] = min(x)
- cons[peak_index, 2] = max(x)
- peak_index += 1
-
- # Setup positive FWHM constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['PositiveFwhmFlag']:
- cons[peak_index, 0] = CPOSITIVE
- cons[peak_index, 1] = 0
- cons[peak_index, 2] = 0
- if self.config['SameFwhmFlag']:
- if i != index_largest_peak:
- cons[peak_index, 0] = CFACTOR
- cons[peak_index, 1] = 3 * index_largest_peak + 2
- cons[peak_index, 2] = 1.0
- peak_index += 1
-
- return fittedpar, cons
-
- def estimate_agauss(self, x, y):
- """Estimation of *Area, Position, FWHM* of peaks, for gaussian-like
- curves.
-
- This functions uses :meth:`estimate_height_position_fwhm`, then
- converts the height parameters to area under the curve with the
- formula ``area = sqrt(2*pi) * height * fwhm / (2 * sqrt(2 * log(2))``
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each peak are:
- *Area, Position, FWHM*.
- Fit constraints depend on :attr:`config`.
- """
- fittedpar, cons = self.estimate_height_position_fwhm(x, y)
- # get the number of found peaks
- npeaks = len(fittedpar) // 3
- for i in range(npeaks):
- height = fittedpar[3 * i]
- fwhm = fittedpar[3 * i + 2]
- # Replace height with area in fittedpar
- fittedpar[3 * i] = numpy.sqrt(2 * numpy.pi) * height * fwhm / (
- 2.0 * numpy.sqrt(2 * numpy.log(2)))
- return fittedpar, cons
-
- def estimate_alorentz(self, x, y):
- """Estimation of *Area, Position, FWHM* of peaks, for Lorentzian
- curves.
-
- This functions uses :meth:`estimate_height_position_fwhm`, then
- converts the height parameters to area under the curve with the
- formula ``area = height * fwhm * 0.5 * pi``
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each peak are:
- *Area, Position, FWHM*.
- Fit constraints depend on :attr:`config`.
- """
- fittedpar, cons = self.estimate_height_position_fwhm(x, y)
- # get the number of found peaks
- npeaks = len(fittedpar) // 3
- for i in range(npeaks):
- height = fittedpar[3 * i]
- fwhm = fittedpar[3 * i + 2]
- # Replace height with area in fittedpar
- fittedpar[3 * i] = (height * fwhm * 0.5 * numpy.pi)
- return fittedpar, cons
-
- def estimate_splitgauss(self, x, y):
- """Estimation of *Height, Position, FWHM1, FWHM2* of peaks, for
- asymmetric gaussian-like curves.
-
- This functions uses :meth:`estimate_height_position_fwhm`, then
- adds a second (identical) estimation of FWHM to the fit parameters
- for each peak, and the corresponding constraint.
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each peak are:
- *Height, Position, FWHM1, FWHM2*.
- Fit constraints depend on :attr:`config`.
- """
- fittedpar, cons = self.estimate_height_position_fwhm(x, y)
- # get the number of found peaks
- npeaks = len(fittedpar) // 3
- estimated_parameters = []
- estimated_constraints = numpy.zeros((4 * npeaks, 3), numpy.float64)
- for i in range(npeaks):
- for j in range(3):
- estimated_parameters.append(fittedpar[3 * i + j])
- # fwhm2 estimate = fwhm1
- estimated_parameters.append(fittedpar[3 * i + 2])
- # height
- estimated_constraints[4 * i, 0] = cons[3 * i, 0]
- estimated_constraints[4 * i, 1] = cons[3 * i, 1]
- estimated_constraints[4 * i, 2] = cons[3 * i, 2]
- # position
- estimated_constraints[4 * i + 1, 0] = cons[3 * i + 1, 0]
- estimated_constraints[4 * i + 1, 1] = cons[3 * i + 1, 1]
- estimated_constraints[4 * i + 1, 2] = cons[3 * i + 1, 2]
- # fwhm1
- estimated_constraints[4 * i + 2, 0] = cons[3 * i + 2, 0]
- estimated_constraints[4 * i + 2, 1] = cons[3 * i + 2, 1]
- estimated_constraints[4 * i + 2, 2] = cons[3 * i + 2, 2]
- # fwhm2
- estimated_constraints[4 * i + 3, 0] = cons[3 * i + 2, 0]
- estimated_constraints[4 * i + 3, 1] = cons[3 * i + 2, 1]
- estimated_constraints[4 * i + 3, 2] = cons[3 * i + 2, 2]
- if cons[3 * i + 2, 0] == CFACTOR:
- # convert indices of related parameters
- # (this happens if SameFwhmFlag == True)
- estimated_constraints[4 * i + 2, 1] = \
- int(cons[3 * i + 2, 1] / 3) * 4 + 2
- estimated_constraints[4 * i + 3, 1] = \
- int(cons[3 * i + 2, 1] / 3) * 4 + 3
- return estimated_parameters, estimated_constraints
-
- def estimate_pvoigt(self, x, y):
- """Estimation of *Height, Position, FWHM, eta* of peaks, for
- pseudo-Voigt curves.
-
- Pseudo-Voigt are a sum of a gaussian curve *G(x)* and a lorentzian
- curve *L(x)* with the same height, center, fwhm parameters:
- ``y(x) = eta * G(x) + (1-eta) * L(x)``
-
- This functions uses :meth:`estimate_height_position_fwhm`, then
- adds a constant estimation of *eta* (0.5) to the fit parameters
- for each peak, and the corresponding constraint.
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each peak are:
- *Height, Position, FWHM, eta*.
- Constraint for the eta parameter can be set to QUOTED (0.--1.)
- by setting :attr:`config`['QuotedEtaFlag'] to ``True``.
- If this is not the case, the constraint code is set to FREE.
- """
- fittedpar, cons = self.estimate_height_position_fwhm(x, y)
- npeaks = len(fittedpar) // 3
- newpar = []
- newcons = numpy.zeros((4 * npeaks, 3), numpy.float64)
- # find out related parameters proper index
- if not self.config['NoConstraintsFlag']:
- if self.config['SameFwhmFlag']:
- j = 0
- # get the index of the free FWHM
- for i in range(npeaks):
- if cons[3 * i + 2, 0] != 4:
- j = i
- for i in range(npeaks):
- if i != j:
- cons[3 * i + 2, 1] = 4 * j + 2
- for i in range(npeaks):
- newpar.append(fittedpar[3 * i])
- newpar.append(fittedpar[3 * i + 1])
- newpar.append(fittedpar[3 * i + 2])
- newpar.append(0.5)
- # height
- newcons[4 * i, 0] = cons[3 * i, 0]
- newcons[4 * i, 1] = cons[3 * i, 1]
- newcons[4 * i, 2] = cons[3 * i, 2]
- # position
- newcons[4 * i + 1, 0] = cons[3 * i + 1, 0]
- newcons[4 * i + 1, 1] = cons[3 * i + 1, 1]
- newcons[4 * i + 1, 2] = cons[3 * i + 1, 2]
- # fwhm
- newcons[4 * i + 2, 0] = cons[3 * i + 2, 0]
- newcons[4 * i + 2, 1] = cons[3 * i + 2, 1]
- newcons[4 * i + 2, 2] = cons[3 * i + 2, 2]
- # Eta constrains
- newcons[4 * i + 3, 0] = CFREE
- newcons[4 * i + 3, 1] = 0
- newcons[4 * i + 3, 2] = 0
- if self.config['QuotedEtaFlag']:
- newcons[4 * i + 3, 0] = CQUOTED
- newcons[4 * i + 3, 1] = 0.0
- newcons[4 * i + 3, 2] = 1.0
- return newpar, newcons
-
- def estimate_splitpvoigt(self, x, y):
- """Estimation of *Height, Position, FWHM1, FWHM2, eta* of peaks, for
- asymmetric pseudo-Voigt curves.
-
- This functions uses :meth:`estimate_height_position_fwhm`, then
- adds an identical FWHM2 parameter and a constant estimation of
- *eta* (0.5) to the fit parameters for each peak, and the corresponding
- constraints.
-
- Constraint for the eta parameter can be set to QUOTED (0.--1.)
- by setting :attr:`config`['QuotedEtaFlag'] to ``True``.
- If this is not the case, the constraint code is set to FREE.
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each peak are:
- *Height, Position, FWHM1, FWHM2, eta*.
- """
- fittedpar, cons = self.estimate_height_position_fwhm(x, y)
- npeaks = len(fittedpar) // 3
- newpar = []
- newcons = numpy.zeros((5 * npeaks, 3), numpy.float64)
- # find out related parameters proper index
- if not self.config['NoConstraintsFlag']:
- if self.config['SameFwhmFlag']:
- j = 0
- # get the index of the free FWHM
- for i in range(npeaks):
- if cons[3 * i + 2, 0] != 4:
- j = i
- for i in range(npeaks):
- if i != j:
- cons[3 * i + 2, 1] = 4 * j + 2
- for i in range(npeaks):
- # height
- newpar.append(fittedpar[3 * i])
- # position
- newpar.append(fittedpar[3 * i + 1])
- # fwhm1
- newpar.append(fittedpar[3 * i + 2])
- # fwhm2 estimate equal to fwhm1
- newpar.append(fittedpar[3 * i + 2])
- # eta
- newpar.append(0.5)
- # constraint codes
- # ----------------
- # height
- newcons[5 * i, 0] = cons[3 * i, 0]
- # position
- newcons[5 * i + 1, 0] = cons[3 * i + 1, 0]
- # fwhm1
- newcons[5 * i + 2, 0] = cons[3 * i + 2, 0]
- # fwhm2
- newcons[5 * i + 3, 0] = cons[3 * i + 2, 0]
- # cons 1
- # ------
- newcons[5 * i, 1] = cons[3 * i, 1]
- newcons[5 * i + 1, 1] = cons[3 * i + 1, 1]
- newcons[5 * i + 2, 1] = cons[3 * i + 2, 1]
- newcons[5 * i + 3, 1] = cons[3 * i + 2, 1]
- # cons 2
- # ------
- newcons[5 * i, 2] = cons[3 * i, 2]
- newcons[5 * i + 1, 2] = cons[3 * i + 1, 2]
- newcons[5 * i + 2, 2] = cons[3 * i + 2, 2]
- newcons[5 * i + 3, 2] = cons[3 * i + 2, 2]
-
- if cons[3 * i + 2, 0] == CFACTOR:
- # fwhm2 connstraint depends on fwhm1
- newcons[5 * i + 3, 1] = newcons[5 * i + 2, 1] + 1
- # eta constraints
- newcons[5 * i + 4, 0] = CFREE
- newcons[5 * i + 4, 1] = 0
- newcons[5 * i + 4, 2] = 0
- if self.config['QuotedEtaFlag']:
- newcons[5 * i + 4, 0] = CQUOTED
- newcons[5 * i + 4, 1] = 0.0
- newcons[5 * i + 4, 2] = 1.0
- return newpar, newcons
-
- def estimate_apvoigt(self, x, y):
- """Estimation of *Area, Position, FWHM1, eta* of peaks, for
- pseudo-Voigt curves.
-
- This functions uses :meth:`estimate_pvoigt`, then converts the height
- parameter to area.
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each peak are:
- *Area, Position, FWHM, eta*.
- """
- fittedpar, cons = self.estimate_pvoigt(x, y)
- npeaks = len(fittedpar) // 4
- # Assume 50% of the area is determined by the gaussian and 50% by
- # the Lorentzian.
- for i in range(npeaks):
- height = fittedpar[4 * i]
- fwhm = fittedpar[4 * i + 2]
- fittedpar[4 * i] = 0.5 * (height * fwhm * 0.5 * numpy.pi) +\
- 0.5 * (height * fwhm / (2.0 * numpy.sqrt(2 * numpy.log(2)))
- ) * numpy.sqrt(2 * numpy.pi)
- return fittedpar, cons
-
- def estimate_ahypermet(self, x, y):
- """Estimation of *area, position, fwhm, st_area_r, st_slope_r,
- lt_area_r, lt_slope_r, step_height_r* of peaks, for hypermet curves.
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each peak are:
- *area, position, fwhm, st_area_r, st_slope_r,
- lt_area_r, lt_slope_r, step_height_r* .
- """
- yscaling = self.config.get('Yscaling', 1.0)
- if yscaling == 0:
- yscaling = 1.0
- fittedpar, cons = self.estimate_height_position_fwhm(x, y)
- npeaks = len(fittedpar) // 3
- newpar = []
- newcons = numpy.zeros((8 * npeaks, 3), numpy.float64)
- main_peak = 0
- # find out related parameters proper index
- if not self.config['NoConstraintsFlag']:
- if self.config['SameFwhmFlag']:
- j = 0
- # get the index of the free FWHM
- for i in range(npeaks):
- if cons[3 * i + 2, 0] != 4:
- j = i
- for i in range(npeaks):
- if i != j:
- cons[3 * i + 2, 1] = 8 * j + 2
- main_peak = j
- for i in range(npeaks):
- if fittedpar[3 * i] > fittedpar[3 * main_peak]:
- main_peak = i
-
- for i in range(npeaks):
- height = fittedpar[3 * i]
- position = fittedpar[3 * i + 1]
- fwhm = fittedpar[3 * i + 2]
- area = (height * fwhm / (2.0 * numpy.sqrt(2 * numpy.log(2)))
- ) * numpy.sqrt(2 * numpy.pi)
- # the gaussian parameters
- newpar.append(area)
- newpar.append(position)
- newpar.append(fwhm)
- # print "area, pos , fwhm = ",area,position,fwhm
- # Avoid zero derivatives because of not calculating contribution
- g_term = 1
- st_term = 1
- lt_term = 1
- step_term = 1
- if self.config['HypermetTails'] != 0:
- g_term = self.config['HypermetTails'] & 1
- st_term = (self.config['HypermetTails'] >> 1) & 1
- lt_term = (self.config['HypermetTails'] >> 2) & 1
- step_term = (self.config['HypermetTails'] >> 3) & 1
- if g_term == 0:
- # fix the gaussian parameters
- newcons[8 * i, 0] = CFIXED
- newcons[8 * i + 1, 0] = CFIXED
- newcons[8 * i + 2, 0] = CFIXED
- # the short tail parameters
- if ((area * yscaling) <
- self.config['MinGaussArea4ShortTail']) | \
- (st_term == 0):
- newpar.append(0.0)
- newpar.append(0.0)
- newcons[8 * i + 3, 0] = CFIXED
- newcons[8 * i + 3, 1] = 0.0
- newcons[8 * i + 3, 2] = 0.0
- newcons[8 * i + 4, 0] = CFIXED
- newcons[8 * i + 4, 1] = 0.0
- newcons[8 * i + 4, 2] = 0.0
- else:
- newpar.append(self.config['InitialShortTailAreaRatio'])
- newpar.append(self.config['InitialShortTailSlopeRatio'])
- newcons[8 * i + 3, 0] = CQUOTED
- newcons[8 * i + 3, 1] = self.config['MinShortTailAreaRatio']
- newcons[8 * i + 3, 2] = self.config['MaxShortTailAreaRatio']
- newcons[8 * i + 4, 0] = CQUOTED
- newcons[8 * i + 4, 1] = self.config['MinShortTailSlopeRatio']
- newcons[8 * i + 4, 2] = self.config['MaxShortTailSlopeRatio']
- # the long tail parameters
- if ((area * yscaling) <
- self.config['MinGaussArea4LongTail']) | \
- (lt_term == 0):
- newpar.append(0.0)
- newpar.append(0.0)
- newcons[8 * i + 5, 0] = CFIXED
- newcons[8 * i + 5, 1] = 0.0
- newcons[8 * i + 5, 2] = 0.0
- newcons[8 * i + 6, 0] = CFIXED
- newcons[8 * i + 6, 1] = 0.0
- newcons[8 * i + 6, 2] = 0.0
- else:
- newpar.append(self.config['InitialLongTailAreaRatio'])
- newpar.append(self.config['InitialLongTailSlopeRatio'])
- newcons[8 * i + 5, 0] = CQUOTED
- newcons[8 * i + 5, 1] = self.config['MinLongTailAreaRatio']
- newcons[8 * i + 5, 2] = self.config['MaxLongTailAreaRatio']
- newcons[8 * i + 6, 0] = CQUOTED
- newcons[8 * i + 6, 1] = self.config['MinLongTailSlopeRatio']
- newcons[8 * i + 6, 2] = self.config['MaxLongTailSlopeRatio']
- # the step parameters
- if ((height * yscaling) <
- self.config['MinGaussHeight4StepTail']) | \
- (step_term == 0):
- newpar.append(0.0)
- newcons[8 * i + 7, 0] = CFIXED
- newcons[8 * i + 7, 1] = 0.0
- newcons[8 * i + 7, 2] = 0.0
- else:
- newpar.append(self.config['InitialStepTailHeightRatio'])
- newcons[8 * i + 7, 0] = CQUOTED
- newcons[8 * i + 7, 1] = self.config['MinStepTailHeightRatio']
- newcons[8 * i + 7, 2] = self.config['MaxStepTailHeightRatio']
- # if self.config['NoConstraintsFlag'] == 1:
- # newcons=numpy.zeros((8*npeaks, 3),numpy.float64)
- if npeaks > 0:
- if g_term:
- if self.config['PositiveHeightAreaFlag']:
- for i in range(npeaks):
- newcons[8 * i, 0] = CPOSITIVE
- if self.config['PositiveFwhmFlag']:
- for i in range(npeaks):
- newcons[8 * i + 2, 0] = CPOSITIVE
- if self.config['SameFwhmFlag']:
- for i in range(npeaks):
- if i != main_peak:
- newcons[8 * i + 2, 0] = CFACTOR
- newcons[8 * i + 2, 1] = 8 * main_peak + 2
- newcons[8 * i + 2, 2] = 1.0
- if self.config['HypermetQuotedPositionFlag']:
- for i in range(npeaks):
- delta = self.config['DeltaPositionFwhmUnits'] * fwhm
- newcons[8 * i + 1, 0] = CQUOTED
- newcons[8 * i + 1, 1] = newpar[8 * i + 1] - delta
- newcons[8 * i + 1, 2] = newpar[8 * i + 1] + delta
- if self.config['SameSlopeRatioFlag']:
- for i in range(npeaks):
- if i != main_peak:
- newcons[8 * i + 4, 0] = CFACTOR
- newcons[8 * i + 4, 1] = 8 * main_peak + 4
- newcons[8 * i + 4, 2] = 1.0
- newcons[8 * i + 6, 0] = CFACTOR
- newcons[8 * i + 6, 1] = 8 * main_peak + 6
- newcons[8 * i + 6, 2] = 1.0
- if self.config['SameAreaRatioFlag']:
- for i in range(npeaks):
- if i != main_peak:
- newcons[8 * i + 3, 0] = CFACTOR
- newcons[8 * i + 3, 1] = 8 * main_peak + 3
- newcons[8 * i + 3, 2] = 1.0
- newcons[8 * i + 5, 0] = CFACTOR
- newcons[8 * i + 5, 1] = 8 * main_peak + 5
- newcons[8 * i + 5, 2] = 1.0
- return newpar, newcons
-
- def estimate_stepdown(self, x, y):
- """Estimation of parameters for stepdown curves.
-
- The functions estimates gaussian parameters for the derivative of
- the data, takes the largest gaussian peak and uses its estimated
- parameters to define the center of the step and its fwhm. The
- estimated amplitude returned is simply ``max(y) - min(y)``.
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit newconstraints.
- Parameters to be estimated for each stepdown are:
- *height, centroid, fwhm* .
- """
- crappyfilter = [-0.25, -0.75, 0.0, 0.75, 0.25]
- cutoff = len(crappyfilter) // 2
- y_deriv = numpy.convolve(y,
- crappyfilter,
- mode="valid")
-
- # make the derivative's peak have the same amplitude as the step
- if max(y_deriv) > 0:
- y_deriv = y_deriv * max(y) / max(y_deriv)
-
- fittedpar, newcons = self.estimate_height_position_fwhm(
- x[cutoff:-cutoff], y_deriv)
-
- data_amplitude = max(y) - min(y)
-
- # use parameters from largest gaussian found
- if len(fittedpar):
- npeaks = len(fittedpar) // 3
- largest_index = 0
- largest = [data_amplitude,
- fittedpar[3 * largest_index + 1],
- fittedpar[3 * largest_index + 2]]
- for i in range(npeaks):
- if fittedpar[3 * i] > largest[0]:
- largest_index = i
- largest = [data_amplitude,
- fittedpar[3 * largest_index + 1],
- fittedpar[3 * largest_index + 2]]
- else:
- # no peak was found
- largest = [data_amplitude, # height
- x[len(x)//2], # center: middle of x range
- self.config["FwhmPoints"] * (x[1] - x[0])] # fwhm: default value
-
- # Setup constrains
- newcons = numpy.zeros((3, 3), numpy.float64)
- if not self.config['NoConstraintsFlag']:
- # Setup height constrains
- if self.config['PositiveHeightAreaFlag']:
- newcons[0, 0] = CPOSITIVE
- newcons[0, 1] = 0
- newcons[0, 2] = 0
-
- # Setup position constrains
- if self.config['QuotedPositionFlag']:
- newcons[1, 0] = CQUOTED
- newcons[1, 1] = min(x)
- newcons[1, 2] = max(x)
-
- # Setup positive FWHM constrains
- if self.config['PositiveFwhmFlag']:
- newcons[2, 0] = CPOSITIVE
- newcons[2, 1] = 0
- newcons[2, 2] = 0
-
- return largest, newcons
-
- def estimate_slit(self, x, y):
- """Estimation of parameters for slit curves.
-
- The functions estimates stepup and stepdown parameters for the largest
- steps, and uses them for calculating the center (middle between stepup
- and stepdown), the height (maximum amplitude in data), the fwhm
- (distance between the up- and down-step centers) and the beamfwhm
- (average of FWHM for up- and down-step).
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each slit are:
- *height, position, fwhm, beamfwhm* .
- """
- largestup, cons = self.estimate_stepup(x, y)
- largestdown, cons = self.estimate_stepdown(x, y)
- fwhm = numpy.fabs(largestdown[1] - largestup[1])
- beamfwhm = 0.5 * (largestup[2] + largestdown[1])
- beamfwhm = min(beamfwhm, fwhm / 10.0)
- beamfwhm = max(beamfwhm, (max(x) - min(x)) * 3.0 / len(x))
-
- y_minus_bg = y - self.strip_bg(y)
- height = max(y_minus_bg)
-
- i1 = numpy.nonzero(y_minus_bg >= 0.5 * height)[0]
- xx = numpy.take(x, i1)
- position = (xx[0] + xx[-1]) / 2.0
- fwhm = xx[-1] - xx[0]
- largest = [height, position, fwhm, beamfwhm]
- cons = numpy.zeros((4, 3), numpy.float64)
- # Setup constrains
- if not self.config['NoConstraintsFlag']:
- # Setup height constrains
- if self.config['PositiveHeightAreaFlag']:
- cons[0, 0] = CPOSITIVE
- cons[0, 1] = 0
- cons[0, 2] = 0
-
- # Setup position constrains
- if self.config['QuotedPositionFlag']:
- cons[1, 0] = CQUOTED
- cons[1, 1] = min(x)
- cons[1, 2] = max(x)
-
- # Setup positive FWHM constrains
- if self.config['PositiveFwhmFlag']:
- cons[2, 0] = CPOSITIVE
- cons[2, 1] = 0
- cons[2, 2] = 0
-
- # Setup positive FWHM constrains
- if self.config['PositiveFwhmFlag']:
- cons[3, 0] = CPOSITIVE
- cons[3, 1] = 0
- cons[3, 2] = 0
- return largest, cons
-
- def estimate_stepup(self, x, y):
- """Estimation of parameters for a single step up curve.
-
- The functions estimates gaussian parameters for the derivative of
- the data, takes the largest gaussian peak and uses its estimated
- parameters to define the center of the step and its fwhm. The
- estimated amplitude returned is simply ``max(y) - min(y)``.
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- Parameters to be estimated for each stepup are:
- *height, centroid, fwhm* .
- """
- crappyfilter = [0.25, 0.75, 0.0, -0.75, -0.25]
- cutoff = len(crappyfilter) // 2
- y_deriv = numpy.convolve(y, crappyfilter, mode="valid")
- if max(y_deriv) > 0:
- y_deriv = y_deriv * max(y) / max(y_deriv)
-
- fittedpar, cons = self.estimate_height_position_fwhm(
- x[cutoff:-cutoff], y_deriv)
-
- # for height, use the data amplitude after removing the background
- data_amplitude = max(y) - min(y)
-
- # find params of the largest gaussian found
- if len(fittedpar):
- npeaks = len(fittedpar) // 3
- largest_index = 0
- largest = [data_amplitude,
- fittedpar[3 * largest_index + 1],
- fittedpar[3 * largest_index + 2]]
- for i in range(npeaks):
- if fittedpar[3 * i] > largest[0]:
- largest_index = i
- largest = [fittedpar[3 * largest_index],
- fittedpar[3 * largest_index + 1],
- fittedpar[3 * largest_index + 2]]
- else:
- # no peak was found
- largest = [data_amplitude, # height
- x[len(x)//2], # center: middle of x range
- self.config["FwhmPoints"] * (x[1] - x[0])] # fwhm: default value
-
- newcons = numpy.zeros((3, 3), numpy.float64)
- # Setup constrains
- if not self.config['NoConstraintsFlag']:
- # Setup height constraints
- if self.config['PositiveHeightAreaFlag']:
- newcons[0, 0] = CPOSITIVE
- newcons[0, 1] = 0
- newcons[0, 2] = 0
-
- # Setup position constraints
- if self.config['QuotedPositionFlag']:
- newcons[1, 0] = CQUOTED
- newcons[1, 1] = min(x)
- newcons[1, 2] = max(x)
-
- # Setup positive FWHM constraints
- if self.config['PositiveFwhmFlag']:
- newcons[2, 0] = CPOSITIVE
- newcons[2, 1] = 0
- newcons[2, 2] = 0
-
- return largest, newcons
-
- def estimate_periodic_gauss(self, x, y):
- """Estimation of parameters for periodic gaussian curves:
- *number of peaks, distance between peaks, height, position of the
- first peak, fwhm*
-
- The functions detects all peaks, then computes the parameters the
- following way:
-
- - *distance*: average of distances between detected peaks
- - *height*: average height of detected peaks
- - *fwhm*: fwhm of the highest peak (in number of samples) if
- field ``'AutoFwhm'`` in :attr:`config` is ``True``, else take
- the default value (field ``'FwhmPoints'`` in :attr:`config`)
-
- :param x: Array of abscissa values
- :param y: Array of ordinate values (``y = f(x)``)
- :return: Tuple of estimated fit parameters and fit constraints.
- """
- yscaling = self.config.get('Yscaling', 1.0)
- if yscaling == 0:
- yscaling = 1.0
-
- bg = self.strip_bg(y)
-
- if self.config['AutoFwhm']:
- search_fwhm = guess_fwhm(y)
- else:
- search_fwhm = int(float(self.config['FwhmPoints']))
- search_sens = float(self.config['Sensitivity'])
-
- if search_fwhm < 3:
- search_fwhm = 3
-
- if search_sens < 1:
- search_sens = 1
-
- if len(y) > 1.5 * search_fwhm:
- peaks = peak_search(yscaling * y, fwhm=search_fwhm,
- sensitivity=search_sens)
- else:
- peaks = []
- npeaks = len(peaks)
- if not npeaks:
- fittedpar = []
- cons = numpy.zeros((len(fittedpar), 3), numpy.float64)
- return fittedpar, cons
-
- fittedpar = [0.0, 0.0, 0.0, 0.0, 0.0]
-
- # The number of peaks
- fittedpar[0] = npeaks
-
- # The separation between peaks in x units
- delta = 0.0
- height = 0.0
- for i in range(npeaks):
- height += y[int(peaks[i])] - bg[int(peaks[i])]
- if i != npeaks - 1:
- delta += (x[int(peaks[i + 1])] - x[int(peaks[i])])
-
- # delta between peaks
- if npeaks > 1:
- fittedpar[1] = delta / (npeaks - 1)
-
- # starting height
- fittedpar[2] = height / npeaks
-
- # position of the first peak
- fittedpar[3] = x[int(peaks[0])]
-
- # Estimate the fwhm
- fittedpar[4] = search_fwhm
-
- # setup constraints
- cons = numpy.zeros((5, 3), numpy.float64)
- cons[0, 0] = CFIXED # the number of gaussians
- if npeaks == 1:
- cons[1, 0] = CFIXED # the delta between peaks
- else:
- cons[1, 0] = CFREE
- j = 2
- # Setup height area constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['PositiveHeightAreaFlag']:
- # POSITIVE = 1
- cons[j, 0] = CPOSITIVE
- cons[j, 1] = 0
- cons[j, 2] = 0
- j += 1
-
- # Setup position constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['QuotedPositionFlag']:
- # QUOTED = 2
- cons[j, 0] = CQUOTED
- cons[j, 1] = min(x)
- cons[j, 2] = max(x)
- j += 1
-
- # Setup positive FWHM constrains
- if not self.config['NoConstraintsFlag']:
- if self.config['PositiveFwhmFlag']:
- # POSITIVE=1
- cons[j, 0] = CPOSITIVE
- cons[j, 1] = 0
- cons[j, 2] = 0
- j += 1
- return fittedpar, cons
-
- def configure(self, **kw):
- """Add new / unknown keyword arguments to :attr:`config`,
- update entries in :attr:`config` if the parameter name is a existing
- key.
-
- :param kw: Dictionary of keyword arguments.
- :return: Configuration dictionary :attr:`config`
- """
- if not kw.keys():
- return self.config
- for key in kw.keys():
- notdone = 1
- # take care of lower / upper case problems ...
- for config_key in self.config.keys():
- if config_key.lower() == key.lower():
- self.config[config_key] = kw[key]
- notdone = 0
- if notdone:
- self.config[key] = kw[key]
- return self.config
-
-fitfuns = FitTheories()
-
-THEORY = OrderedDict((
- ('Gaussians',
- FitTheory(description='Gaussian functions',
- function=functions.sum_gauss,
- parameters=('Height', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_height_position_fwhm,
- configure=fitfuns.configure)),
- ('Lorentz',
- FitTheory(description='Lorentzian functions',
- function=functions.sum_lorentz,
- parameters=('Height', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_height_position_fwhm,
- configure=fitfuns.configure)),
- ('Area Gaussians',
- FitTheory(description='Gaussian functions (area)',
- function=functions.sum_agauss,
- parameters=('Area', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_agauss,
- configure=fitfuns.configure)),
- ('Area Lorentz',
- FitTheory(description='Lorentzian functions (area)',
- function=functions.sum_alorentz,
- parameters=('Area', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_alorentz,
- configure=fitfuns.configure)),
- ('Pseudo-Voigt Line',
- FitTheory(description='Pseudo-Voigt functions',
- function=functions.sum_pvoigt,
- parameters=('Height', 'Position', 'FWHM', 'Eta'),
- estimate=fitfuns.estimate_pvoigt,
- configure=fitfuns.configure)),
- ('Area Pseudo-Voigt',
- FitTheory(description='Pseudo-Voigt functions (area)',
- function=functions.sum_apvoigt,
- parameters=('Area', 'Position', 'FWHM', 'Eta'),
- estimate=fitfuns.estimate_apvoigt,
- configure=fitfuns.configure)),
- ('Split Gaussian',
- FitTheory(description='Asymmetric gaussian functions',
- function=functions.sum_splitgauss,
- parameters=('Height', 'Position', 'LowFWHM',
- 'HighFWHM'),
- estimate=fitfuns.estimate_splitgauss,
- configure=fitfuns.configure)),
- ('Split Lorentz',
- FitTheory(description='Asymmetric lorentzian functions',
- function=functions.sum_splitlorentz,
- parameters=('Height', 'Position', 'LowFWHM', 'HighFWHM'),
- estimate=fitfuns.estimate_splitgauss,
- configure=fitfuns.configure)),
- ('Split Pseudo-Voigt',
- FitTheory(description='Asymmetric pseudo-Voigt functions',
- function=functions.sum_splitpvoigt,
- parameters=('Height', 'Position', 'LowFWHM',
- 'HighFWHM', 'Eta'),
- estimate=fitfuns.estimate_splitpvoigt,
- configure=fitfuns.configure)),
- ('Step Down',
- FitTheory(description='Step down function',
- function=functions.sum_stepdown,
- parameters=('Height', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_stepdown,
- configure=fitfuns.configure)),
- ('Step Up',
- FitTheory(description='Step up function',
- function=functions.sum_stepup,
- parameters=('Height', 'Position', 'FWHM'),
- estimate=fitfuns.estimate_stepup,
- configure=fitfuns.configure)),
- ('Slit',
- FitTheory(description='Slit function',
- function=functions.sum_slit,
- parameters=('Height', 'Position', 'FWHM', 'BeamFWHM'),
- estimate=fitfuns.estimate_slit,
- configure=fitfuns.configure)),
- ('Atan',
- FitTheory(description='Arctan step up function',
- function=functions.atan_stepup,
- parameters=('Height', 'Position', 'Width'),
- estimate=fitfuns.estimate_stepup,
- configure=fitfuns.configure)),
- ('Hypermet',
- FitTheory(description='Hypermet functions',
- function=fitfuns.ahypermet, # customized version of functions.sum_ahypermet
- parameters=('G_Area', 'Position', 'FWHM', 'ST_Area',
- 'ST_Slope', 'LT_Area', 'LT_Slope', 'Step_H'),
- estimate=fitfuns.estimate_ahypermet,
- configure=fitfuns.configure)),
- # ('Periodic Gaussians',
- # FitTheory(description='Periodic gaussian functions',
- # function=functions.periodic_gauss,
- # parameters=('N', 'Delta', 'Height', 'Position', 'FWHM'),
- # estimate=fitfuns.estimate_periodic_gauss,
- # configure=fitfuns.configure))
- ('Degree 2 Polynomial',
- FitTheory(description='Degree 2 polynomial'
- '\ny = a*x^2 + b*x +c',
- function=fitfuns.poly,
- parameters=['a', 'b', 'c'],
- estimate=fitfuns.estimate_quadratic)),
- ('Degree 3 Polynomial',
- FitTheory(description='Degree 3 polynomial'
- '\ny = a*x^3 + b*x^2 + c*x + d',
- function=fitfuns.poly,
- parameters=['a', 'b', 'c', 'd'],
- estimate=fitfuns.estimate_cubic)),
- ('Degree 4 Polynomial',
- FitTheory(description='Degree 4 polynomial'
- '\ny = a*x^4 + b*x^3 + c*x^2 + d*x + e',
- function=fitfuns.poly,
- parameters=['a', 'b', 'c', 'd', 'e'],
- estimate=fitfuns.estimate_quartic)),
- ('Degree 5 Polynomial',
- FitTheory(description='Degree 5 polynomial'
- '\ny = a*x^5 + b*x^4 + c*x^3 + d*x^2 + e*x + f',
- function=fitfuns.poly,
- parameters=['a', 'b', 'c', 'd', 'e', 'f'],
- estimate=fitfuns.estimate_quintic)),
-))
-"""Dictionary of fit theories: fit functions and their associated estimation
-function, parameters list, configuration function and description.
-"""
-
-
-def test(a):
- from silx.math.fit import fitmanager
- x = numpy.arange(1000).astype(numpy.float64)
- p = [1500, 100., 50.0,
- 1500, 700., 50.0]
- y_synthetic = functions.sum_gauss(x, *p) + 1
-
- fit = fitmanager.FitManager(x, y_synthetic)
- fit.addtheory('Gaussians', functions.sum_gauss, ['Height', 'Position', 'FWHM'],
- a.estimate_height_position_fwhm)
- fit.settheory('Gaussians')
- fit.setbackground('Linear')
-
- fit.estimate()
- fit.runfit()
-
- y_fit = fit.gendata()
-
- print("Fit parameter names: %s" % str(fit.get_names()))
- print("Theoretical parameters: %s" % str(numpy.append([1, 0], p)))
- print("Fitted parameters: %s" % str(fit.get_fitted_parameters()))
-
- try:
- from silx.gui import qt
- from silx.gui.plot import plot1D
- app = qt.QApplication([])
-
- # Offset of 1 to see the difference in log scale
- plot1D(x, (y_synthetic + 1, y_fit), "Input data + 1, Fit")
-
- app.exec_()
- except ImportError:
- _logger.warning("Unable to load qt binding, can't plot results.")
-
-
-if __name__ == "__main__":
- test(fitfuns)
diff --git a/silx/math/fit/test/__init__.py b/silx/math/fit/test/__init__.py
deleted file mode 100644
index d3d8ce8..0000000
--- a/silx/math/fit/test/__init__.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "22/06/2016"
-
-import unittest
-
-from .test_fit import suite as test_curve_fit
-from .test_functions import suite as test_fitfuns
-from .test_filters import suite as test_fitfilters
-from .test_peaks import suite as test_peaks
-from .test_fitmanager import suite as test_fitmanager
-from .test_bgtheories import suite as test_bgtheories
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(test_curve_fit())
- test_suite.addTest(test_fitfuns())
- test_suite.addTest(test_fitfilters())
- test_suite.addTest(test_peaks())
- test_suite.addTest(test_fitmanager())
- test_suite.addTest(test_bgtheories())
- return test_suite
diff --git a/silx/math/fit/test/test_bgtheories.py b/silx/math/fit/test/test_bgtheories.py
deleted file mode 100644
index e9fea37..0000000
--- a/silx/math/fit/test/test_bgtheories.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-import copy
-import unittest
-import numpy
-import random
-
-from silx.math.fit import bgtheories
-from silx.math.fit.functions import sum_gauss
-
-
-class TestBgTheories(unittest.TestCase):
- """
- """
- def setUp(self):
- self.x = numpy.arange(100)
- self.y = 10 + 0.05 * self.x + sum_gauss(self.x, 10., 45., 15.)
- # add a very narrow high amplitude peak to test strip and snip
- self.y += sum_gauss(self.x, 100., 75., 2.)
- self.narrow_peak_index = list(self.x).index(75)
- random.seed()
-
- def tearDown(self):
- pass
-
- def testTheoriesAttrs(self):
- for theory_name in bgtheories.THEORY:
- self.assertIsInstance(theory_name, str)
- self.assertTrue(hasattr(bgtheories.THEORY[theory_name],
- "function"))
- self.assertTrue(hasattr(bgtheories.THEORY[theory_name].function,
- "__call__"))
- # Ensure legacy functions are not renamed accidentally
- self.assertTrue(
- {"No Background", "Constant", "Linear", "Strip", "Snip"}.issubset(
- set(bgtheories.THEORY)))
-
- def testNoBg(self):
- nobgfun = bgtheories.THEORY["No Background"].function
- self.assertTrue(numpy.array_equal(nobgfun(self.x, self.y),
- numpy.zeros_like(self.x)))
- # default estimate
- self.assertEqual(bgtheories.THEORY["No Background"].estimate(self.x, self.y),
- ([], []))
-
- def testConstant(self):
- consfun = bgtheories.THEORY["Constant"].function
- c = random.random() * 100
- self.assertTrue(numpy.array_equal(consfun(self.x, self.y, c),
- c * numpy.ones_like(self.x)))
- # default estimate
- esti_par, cons = bgtheories.THEORY["Constant"].estimate(self.x, self.y)
- self.assertEqual(cons,
- [[0, 0, 0]])
- self.assertAlmostEqual(esti_par,
- min(self.y))
-
- def testLinear(self):
- linfun = bgtheories.THEORY["Linear"].function
- a = random.random() * 100
- b = random.random() * 100
- self.assertTrue(numpy.array_equal(linfun(self.x, self.y, a, b),
- a + b * self.x))
- # default estimate
- esti_par, cons = bgtheories.THEORY["Linear"].estimate(self.x, self.y)
-
- self.assertEqual(cons,
- [[0, 0, 0], [0, 0, 0]])
- self.assertAlmostEqual(esti_par[0], 10, places=3)
- self.assertAlmostEqual(esti_par[1], 0.05, places=3)
-
- def testStrip(self):
- stripfun = bgtheories.THEORY["Strip"].function
- anchors = sorted(random.sample(list(self.x), 4))
- anchors_indices = [list(self.x).index(a) for a in anchors]
-
- # we really want to strip away the narrow peak
- anchors_indices_copy = copy.deepcopy(anchors_indices)
- for idx in anchors_indices_copy:
- if abs(idx - self.narrow_peak_index) < 5:
- anchors_indices.remove(idx)
- anchors.remove(self.x[idx])
-
- width = 2
- niter = 1000
- bgtheories.THEORY["Strip"].configure(AnchorsList=anchors, AnchorsFlag=True)
-
- bg = stripfun(self.x, self.y, width, niter)
-
- # assert peak amplitude has been decreased
- self.assertLess(bg[self.narrow_peak_index],
- self.y[self.narrow_peak_index])
-
- # default estimate
- for i in anchors_indices:
- self.assertEqual(bg[i], self.y[i])
-
- # estimated parameters are equal to the default ones in the config dict
- bgtheories.THEORY["Strip"].configure(StripWidth=7, StripIterations=8)
- esti_par, cons = bgtheories.THEORY["Strip"].estimate(self.x, self.y)
- self.assertTrue(numpy.array_equal(cons, [[3, 0, 0], [3, 0, 0]]))
- self.assertEqual(esti_par, [7, 8])
-
- def testSnip(self):
- snipfun = bgtheories.THEORY["Snip"].function
- anchors = sorted(random.sample(list(self.x), 4))
- anchors_indices = [list(self.x).index(a) for a in anchors]
-
- # we want to strip away the narrow peak, so remove nearby anchors
- anchors_indices_copy = copy.deepcopy(anchors_indices)
- for idx in anchors_indices_copy:
- if abs(idx - self.narrow_peak_index) < 5:
- anchors_indices.remove(idx)
- anchors.remove(self.x[idx])
-
- width = 16
- bgtheories.THEORY["Snip"].configure(AnchorsList=anchors, AnchorsFlag=True)
- bg = snipfun(self.x, self.y, width)
-
- # assert peak amplitude has been decreased
- self.assertLess(bg[self.narrow_peak_index],
- self.y[self.narrow_peak_index],
- "Snip didn't decrease the peak amplitude.")
-
- # anchored data must remain fixed
- for i in anchors_indices:
- self.assertEqual(bg[i], self.y[i])
-
- # estimated parameters are equal to the default ones in the config dict
- bgtheories.THEORY["Snip"].configure(SnipWidth=7)
- esti_par, cons = bgtheories.THEORY["Snip"].estimate(self.x, self.y)
- self.assertTrue(numpy.array_equal(cons, [[3, 0, 0]]))
- self.assertEqual(esti_par, [7])
-
-
-test_cases = (TestBgTheories,)
-
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/fit/test/test_filters.py b/silx/math/fit/test/test_filters.py
deleted file mode 100644
index 078b998..0000000
--- a/silx/math/fit/test/test_filters.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-import numpy
-import unittest
-from silx.math.fit import filters
-from silx.math.fit import functions
-from silx.test.utils import add_relative_noise
-
-
-class TestSmooth(unittest.TestCase):
- """
- Unit tests of smoothing functions.
-
- Test that the difference between a synthetic curve with 5% added random
- noise and the result of smoothing that signal is less than 5%. We compare
- the sum of all samples in each curve.
- """
- def setUp(self):
- x = numpy.arange(5000)
- # (height1, center1, fwhm1, beamfwhm...)
- slit_params = (50, 500, 200, 100,
- 50, 600, 80, 30,
- 20, 2000, 150, 150,
- 50, 2250, 110, 100,
- 40, 3000, 50, 10,
- 23, 4980, 250, 20)
-
- self.y1 = functions.sum_slit(x, *slit_params)
- # 5% noise
- self.y1 = add_relative_noise(self.y1, 5.)
-
- # (height1, center1, fwhm1...)
- step_params = (50, 500, 200,
- 50, 600, 80,
- 20, 2000, 150,
- 50, 2250, 110,
- 40, 3000, 50,
- 23, 4980, 250,)
-
- self.y2 = functions.sum_stepup(x, *step_params)
- # 5% noise
- self.y2 = add_relative_noise(self.y2, 5.)
-
- self.y3 = functions.sum_stepdown(x, *step_params)
- # 5% noise
- self.y3 = add_relative_noise(self.y3, 5.)
-
- def tearDown(self):
- pass
-
- def testSavitskyGolay(self):
- npts = 25
- for y in [self.y1, self.y2, self.y3]:
- smoothed_y = filters.savitsky_golay(y, npoints=npts)
-
- # we added +-5% of random noise. The difference must be much lower
- # than 5%.
- diff = abs(sum(smoothed_y) - sum(y)) / sum(y)
- self.assertLess(diff, 0.05,
- "Difference between data with 5%% noise and " +
- "smoothed data is > 5%% (%f %%)" % (diff * 100))
-
- # Try various smoothing levels
- npts += 25
-
- def testSmooth1d(self):
- """Test the 1D smoothing against the formula
- ys[i] = (y[i-1] + 2 * y[i] + y[i+1]) / 4 (for 1 < i < n-1)"""
- smoothed_y = filters.smooth1d(self.y1)
-
- for i in range(1, len(self.y1) - 1):
- self.assertAlmostEqual(4 * smoothed_y[i],
- self.y1[i-1] + 2 * self.y1[i] + self.y1[i+1])
-
- def testSmooth2d(self):
- """Test that a 2D smoothing is the same as two successive and
- orthogonal 1D smoothings"""
- x = numpy.arange(10000)
-
- noise = 2 * numpy.random.random(10000) - 1
- noise *= 0.05
- y = x * (1 + noise)
-
- y.shape = (100, 100)
-
- smoothed_y = filters.smooth2d(y)
-
- intermediate_smooth = numpy.zeros_like(y)
- expected_smooth = numpy.zeros_like(y)
- # smooth along first dimension
- for i in range(0, y.shape[0]):
- intermediate_smooth[i, :] = filters.smooth1d(y[i, :])
-
- # smooth along second dimension
- for j in range(0, y.shape[1]):
- expected_smooth[:, j] = filters.smooth1d(intermediate_smooth[:, j])
-
- for i in range(0, y.shape[0]):
- for j in range(0, y.shape[1]):
- self.assertAlmostEqual(smoothed_y[i, j],
- expected_smooth[i, j])
-
-
-test_cases = (TestSmooth,)
-
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/fit/test/test_fit.py b/silx/math/fit/test/test_fit.py
deleted file mode 100644
index 3fdf394..0000000
--- a/silx/math/fit/test/test_fit.py
+++ /dev/null
@@ -1,387 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-Nominal tests of the leastsq function.
-"""
-
-import unittest
-
-import numpy
-import sys
-
-from silx.utils import testutils
-from silx.math.fit.leastsq import _logger as fitlogger
-
-
-class Test_leastsq(unittest.TestCase):
- """
- Unit tests of the leastsq function.
- """
-
- ndims = None
-
- def setUp(self):
- try:
- from silx.math.fit import leastsq
- self.instance = leastsq
- except ImportError:
- self.instance = None
-
- def myexp(x):
- # put a (bad) filter to avoid over/underflows
- # with no python looping
- return numpy.exp(x*numpy.less(abs(x), 250)) - \
- 1.0 * numpy.greater_equal(abs(x), 250)
-
- self.my_exp = myexp
-
- def gauss(x, *params):
- params = numpy.array(params, copy=False, dtype=numpy.float64)
- result = params[0] + params[1] * x
- for i in range(2, len(params), 3):
- p = params[i:(i+3)]
- dummy = 2.3548200450309493*(x - p[1])/p[2]
- result += p[0] * self.my_exp(-0.5 * dummy * dummy)
- return result
-
- self.gauss = gauss
-
- def gauss_derivative(x, params, idx):
- if idx == 0:
- return numpy.ones(len(x), numpy.float64)
- if idx == 1:
- return x
- gaussian_peak = (idx - 2) // 3
- gaussian_parameter = (idx - 2) % 3
- actual_idx = 2 + 3 * gaussian_peak
- p = params[actual_idx:(actual_idx+3)]
- if gaussian_parameter == 0:
- return self.gauss(x, *[0, 0, 1.0, p[1], p[2]])
- if gaussian_parameter == 1:
- tmp = self.gauss(x, *[0, 0, p[0], p[1], p[2]])
- tmp *= 2.3548200450309493*(x - p[1])/p[2]
- return tmp * 2.3548200450309493/p[2]
- if gaussian_parameter == 2:
- tmp = self.gauss(x, *[0, 0, p[0], p[1], p[2]])
- tmp *= 2.3548200450309493*(x - p[1])/p[2]
- return tmp * 2.3548200450309493*(x - p[1])/(p[2]*p[2])
-
- self.gauss_derivative = gauss_derivative
-
- def tearDown(self):
- self.instance = None
- self.gauss = None
- self.gauss_derivative = None
- self.my_exp = None
- self.model_function = None
- self.model_derivative = None
-
- def testImport(self):
- self.assertTrue(self.instance is not None,
- "Cannot import leastsq from silx.math.fit")
-
- def testUnconstrainedFitNoWeight(self):
- parameters_actual = [10.5, 2, 1000.0, 20., 15]
- x = numpy.arange(10000.)
- y = self.gauss(x, *parameters_actual)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
- model_function = self.gauss
-
- fittedpar, cov = self.instance(model_function, x, y, parameters_estimate)
- test_condition = numpy.allclose(parameters_actual, fittedpar)
- if not test_condition:
- msg = "Unsuccessfull fit\n"
- for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
- self.assertTrue(test_condition, msg)
-
- def testUnconstrainedFitWeight(self):
- parameters_actual = [10.5,2,1000.0,20.,15]
- x = numpy.arange(10000.)
- y = self.gauss(x, *parameters_actual)
- sigma = numpy.sqrt(y)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
- model_function = self.gauss
-
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma)
- test_condition = numpy.allclose(parameters_actual, fittedpar)
- if not test_condition:
- msg = "Unsuccessfull fit\n"
- for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
- self.assertTrue(test_condition, msg)
-
- def testDerivativeFunction(self):
- parameters_actual = [10.5, 2, 10000.0, 20., 150, 5000, 900., 300]
- x = numpy.arange(10000.)
- y = self.gauss(x, *parameters_actual)
- delta = numpy.sqrt(numpy.finfo(numpy.float64).eps)
- for i in range(len(parameters_actual)):
- p = parameters_actual * 1
- if p[i] == 0:
- delta_par = delta
- else:
- delta_par = p[i] * delta
- if i > 2:
- p[0] = 0.0
- p[1] = 0.0
- p[i] += delta_par
- yPlus = self.gauss(x, *p)
- p[i] = parameters_actual[i] - delta_par
- yMinus = self.gauss(x, *p)
- numerical_derivative = (yPlus - yMinus) / (2 * delta_par)
- #numerical_derivative = (self.gauss(x, *p) - y) / delta_par
- p[i] = parameters_actual[i]
- derivative = self.gauss_derivative(x, p, i)
- diff = numerical_derivative - derivative
- test_condition = numpy.allclose(numerical_derivative,
- derivative, atol=5.0e-6)
- if not test_condition:
- msg = "Error calculating derivative of parameter %d." % i
- msg += "\n diff min = %g diff max = %g" % (diff.min(), diff.max())
- self.assertTrue(test_condition, msg)
-
- def testConstrainedFit(self):
- CFREE = 0
- CPOSITIVE = 1
- CQUOTED = 2
- CFIXED = 3
- CFACTOR = 4
- CDELTA = 5
- CSUM = 6
- parameters_actual = [10.5, 2, 10000.0, 20., 150, 5000, 900., 300]
- x = numpy.arange(10000.)
- y = self.gauss(x, *parameters_actual)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10, 400, 850, 200]
- model_function = self.gauss
- model_deriv = self.gauss_derivative
- constraints_all_free = [[0, 0, 0]] * len(parameters_actual)
- constraints_all_positive = [[1, 0, 0]] * len(parameters_actual)
- constraints_delta_position = [[0, 0, 0]] * len(parameters_actual)
- constraints_delta_position[6] = [CDELTA, 3, 880]
- constraints_sum_position = constraints_all_positive * 1
- constraints_sum_position[6] = [CSUM, 3, 920]
- constraints_factor = constraints_delta_position * 1
- constraints_factor[2] = [CFACTOR, 5, 2]
- constraints_list = [None,
- constraints_all_free,
- constraints_all_positive,
- constraints_delta_position,
- constraints_sum_position]
-
- # for better code coverage, the warning recommending to set full_output
- # to True when using constraints should be shown at least once
- full_output = True
- for index, constraints in enumerate(constraints_list):
- if index == 2:
- full_output = None
- elif index == 3:
- full_output = 0
- for model_deriv in [None, self.gauss_derivative]:
- for sigma in [None, numpy.sqrt(y)]:
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- constraints=constraints,
- model_deriv=model_deriv,
- full_output=full_output)[:2]
- full_output = True
-
- test_condition = numpy.allclose(parameters_actual, fittedpar)
- if not test_condition:
- msg = "Unsuccessfull fit\n"
- for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
- self.assertTrue(test_condition, msg)
-
- def testUnconstrainedFitAnalyticalDerivative(self):
- parameters_actual = [10.5, 2, 1000.0, 20., 15]
- x = numpy.arange(10000.)
- y = self.gauss(x, *parameters_actual)
- sigma = numpy.sqrt(y)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
- model_function = self.gauss
- model_deriv = self.gauss_derivative
-
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- model_deriv=model_deriv)
- test_condition = numpy.allclose(parameters_actual, fittedpar)
- if not test_condition:
- msg = "Unsuccessfull fit\n"
- for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
- self.assertTrue(test_condition, msg)
-
- @testutils.test_logging(fitlogger.name, warning=2)
- def testBadlyShapedData(self):
- parameters_actual = [10.5, 2, 1000.0, 20., 15]
- x = numpy.arange(10000.).reshape(1000, 10)
- y = self.gauss(x, *parameters_actual)
- sigma = numpy.sqrt(y)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
- model_function = self.gauss
-
- for check_finite in [True, False]:
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=check_finite)
- test_condition = numpy.allclose(parameters_actual, fittedpar)
- if not test_condition:
- msg = "Unsuccessfull fit\n"
- for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
- self.assertTrue(test_condition, msg)
-
- @testutils.test_logging(fitlogger.name, warning=3)
- def testDataWithNaN(self):
- parameters_actual = [10.5, 2, 1000.0, 20., 15]
- x = numpy.arange(10000.).reshape(1000, 10)
- y = self.gauss(x, *parameters_actual)
- sigma = numpy.sqrt(y)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
- model_function = self.gauss
- x[500] = numpy.inf
- # check default behavior
- try:
- self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma)
- except ValueError:
- info = "%s" % sys.exc_info()[1]
- self.assertTrue("array must not contain inf" in info)
-
- # check requested behavior
- try:
- self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=True)
- except ValueError:
- info = "%s" % sys.exc_info()[1]
- self.assertTrue("array must not contain inf" in info)
-
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=False)
- test_condition = numpy.allclose(parameters_actual, fittedpar)
- if not test_condition:
- msg = "Unsuccessfull fit\n"
- for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
- self.assertTrue(test_condition, msg)
-
- # testing now with ydata containing NaN
- x = numpy.arange(10000.).reshape(1000, 10)
- y[500] = numpy.nan
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=False)
-
- test_condition = numpy.allclose(parameters_actual, fittedpar)
- if not test_condition:
- msg = "Unsuccessfull fit\n"
- for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
- self.assertTrue(test_condition, msg)
-
- # testing now with sigma containing NaN
- sigma[300] = numpy.nan
- fittedpar, cov = self.instance(model_function, x, y,
- parameters_estimate,
- sigma=sigma,
- check_finite=False)
- test_condition = numpy.allclose(parameters_actual, fittedpar)
- if not test_condition:
- msg = "Unsuccessfull fit\n"
- for i in range(len(fittedpar)):
- msg += "Expected %g obtained %g\n" % (parameters_actual[i],
- fittedpar[i])
- self.assertTrue(test_condition, msg)
-
- def testUncertainties(self):
- """Test for validity of uncertainties in returned full-output
- dictionary. This is a non-regression test for pull request #197"""
- parameters_actual = [10.5, 2, 1000.0, 20., 15, 2001.0, 30.1, 16]
- x = numpy.arange(10000.)
- y = self.gauss(x, *parameters_actual)
- parameters_estimate = [0.0, 1.0, 900.0, 25., 10., 1500., 20., 2.0]
-
- # test that uncertainties are not 0.
- fittedpar, cov, infodict = self.instance(self.gauss, x, y, parameters_estimate,
- full_output=True)
- uncertainties = infodict["uncertainties"]
- self.assertEqual(len(uncertainties), len(parameters_actual))
- self.assertEqual(len(uncertainties), len(fittedpar))
- for uncertainty in uncertainties:
- self.assertNotAlmostEqual(uncertainty, 0.)
-
- # set constraint FIXED for half the parameters.
- # This should cause leastsq to return 100% uncertainty.
- parameters_estimate = [10.6, 2.1, 1000.1, 20.1, 15.1, 2001.1, 30.2, 16.1]
- CFIXED = 3
- CFREE = 0
- constraints = []
- for i in range(len(parameters_estimate)):
- if i % 2:
- constraints.append([CFIXED, 0, 0])
- else:
- constraints.append([CFREE, 0, 0])
- fittedpar, cov, infodict = self.instance(self.gauss, x, y, parameters_estimate,
- constraints=constraints,
- full_output=True)
- uncertainties = infodict["uncertainties"]
- for i in range(len(parameters_estimate)):
- if i % 2:
- # test that all FIXED parameters have 100% uncertainty
- self.assertAlmostEqual(uncertainties[i],
- parameters_estimate[i])
-
-
-test_cases = (Test_leastsq,)
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/fit/test/test_fitmanager.py b/silx/math/fit/test/test_fitmanager.py
deleted file mode 100644
index acac242..0000000
--- a/silx/math/fit/test/test_fitmanager.py
+++ /dev/null
@@ -1,513 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-Tests for fitmanager module
-"""
-
-import unittest
-import numpy
-import os.path
-
-from silx.math.fit import fitmanager
-from silx.math.fit import fittheories
-from silx.math.fit import bgtheories
-from silx.math.fit.fittheory import FitTheory
-from silx.math.fit.functions import sum_gauss, sum_stepdown, sum_stepup
-
-from silx.utils.testutils import ParametricTestCase
-from silx.test.utils import temp_dir
-
-custom_function_definition = """
-import copy
-from silx.math.fit.fittheory import FitTheory
-
-CONFIG = {'d': 1.}
-
-def myfun(x, a, b, c):
- "Model function"
- return (a * x**2 + b * x + c) / CONFIG['d']
-
-def myesti(x, y):
- "Initial parameters for iterative fit (a, b, c) = (1, 1, 1)"
- return (1., 1., 1.), ((0, 0, 0), (0, 0, 0), (0, 0, 0))
-
-def myconfig(d=1., **kw):
- "This function can modify CONFIG"
- CONFIG["d"] = d
- return CONFIG
-
-def myderiv(x, parameters, index):
- "Custom derivative (does not work, causes singular matrix)"
- pars_plus = copy.copy(parameters)
- pars_plus[index] *= 1.0001
-
- pars_minus = parameters
- pars_minus[index] *= copy.copy(0.9999)
-
- delta_fun = myfun(x, *pars_plus) - myfun(x, *pars_minus)
- delta_par = parameters[index] * 0.0001 * 2
-
- return delta_fun / delta_par
-
-THEORY = {
- 'my fit theory':
- FitTheory(function=myfun,
- parameters=('A', 'B', 'C'),
- estimate=myesti,
- configure=myconfig,
- derivative=myderiv)
-}
-
-"""
-
-old_custom_function_definition = """
-CONFIG = {'d': 1.0}
-
-def myfun(x, a, b, c):
- "Model function"
- return (a * x**2 + b * x + c) / CONFIG['d']
-
-def myesti(x, y, bg, xscalinq, yscaling):
- "Initial parameters for iterative fit (a, b, c) = (1, 1, 1)"
- return (1., 1., 1.), ((0, 0, 0), (0, 0, 0), (0, 0, 0))
-
-def myconfig(**kw):
- "Update or complete CONFIG dictionary"
- for key in kw:
- CONFIG[key] = kw[key]
- return CONFIG
-
-THEORY = ['my fit theory']
-PARAMETERS = [('A', 'B', 'C')]
-FUNCTION = [myfun]
-ESTIMATE = [myesti]
-CONFIGURE = [myconfig]
-
-"""
-
-
-def _order_of_magnitude(x):
- return numpy.log10(x).round()
-
-
-class TestFitmanager(ParametricTestCase):
- """
- Unit tests of multi-peak functions.
- """
- def setUp(self):
- pass
-
- def tearDown(self):
- pass
-
- def testFitManager(self):
- """Test fit manager on synthetic data using a gaussian function
- and a linear background"""
- # Create synthetic data with a sum of gaussian functions
- x = numpy.arange(1000).astype(numpy.float64)
-
- p = [1000, 100., 250,
- 255, 650., 45,
- 1500, 800.5, 95]
- linear_bg = 2.65 * x + 13
- y = linear_bg + sum_gauss(x, *p)
-
- y_with_nans = numpy.array(y)
- y_with_nans[::10] = numpy.nan
-
- x_with_nans = numpy.array(x)
- x_with_nans[5::15] = numpy.nan
-
- tests = {
- 'all finite': (x, y),
- 'y with NaNs': (x, y_with_nans),
- 'x with NaNs': (x_with_nans, y),
- }
-
- for name, (xdata, ydata) in tests.items():
- with self.subTest(name=name):
- # Fitting
- fit = fitmanager.FitManager()
- fit.setdata(x=xdata, y=ydata)
- fit.loadtheories(fittheories)
- # Use one of the default fit functions
- fit.settheory('Gaussians')
- fit.setbackground('Linear')
- fit.estimate()
- fit.runfit()
-
- # fit.fit_results[]
-
- # first 2 parameters are related to the linear background
- self.assertEqual(fit.fit_results[0]["name"], "Constant")
- self.assertAlmostEqual(fit.fit_results[0]["fitresult"], 13)
- self.assertEqual(fit.fit_results[1]["name"], "Slope")
- self.assertAlmostEqual(fit.fit_results[1]["fitresult"], 2.65)
-
- for i, param in enumerate(fit.fit_results[2:]):
- param_number = i // 3 + 1
- if i % 3 == 0:
- self.assertEqual(param["name"],
- "Height%d" % param_number)
- elif i % 3 == 1:
- self.assertEqual(param["name"],
- "Position%d" % param_number)
- elif i % 3 == 2:
- self.assertEqual(param["name"],
- "FWHM%d" % param_number)
-
- self.assertAlmostEqual(param["fitresult"],
- p[i])
- self.assertAlmostEqual(_order_of_magnitude(param["estimation"]),
- _order_of_magnitude(p[i]))
-
- def testLoadCustomFitFunction(self):
- """Test FitManager using a custom fit function defined in an external
- file and imported with FitManager.loadtheories"""
- # Create synthetic data with a sum of gaussian functions
- x = numpy.arange(100).astype(numpy.float64)
-
- # a, b, c are the fit parameters
- # d is a known scaling parameter that is set using configure()
- a, b, c, d = 1.5, 2.5, 3.5, 4.5
- y = (a * x**2 + b * x + c) / d
-
- # Fitting
- fit = fitmanager.FitManager()
- fit.setdata(x=x, y=y)
-
- # Create a temporary function definition file, and import it
- with temp_dir() as tmpDir:
- tmpfile = os.path.join(tmpDir, 'customfun.py')
- # custom_function_definition
- fd = open(tmpfile, "w")
- fd.write(custom_function_definition)
- fd.close()
- fit.loadtheories(tmpfile)
- tmpfile_pyc = os.path.join(tmpDir, 'customfun.pyc')
- if os.path.exists(tmpfile_pyc):
- os.unlink(tmpfile_pyc)
- os.unlink(tmpfile)
-
- fit.settheory('my fit theory')
- # Test configure
- fit.configure(d=4.5)
- fit.estimate()
- fit.runfit()
-
- self.assertEqual(fit.fit_results[0]["name"],
- "A1")
- self.assertAlmostEqual(fit.fit_results[0]["fitresult"],
- 1.5)
- self.assertEqual(fit.fit_results[1]["name"],
- "B1")
- self.assertAlmostEqual(fit.fit_results[1]["fitresult"],
- 2.5)
- self.assertEqual(fit.fit_results[2]["name"],
- "C1")
- self.assertAlmostEqual(fit.fit_results[2]["fitresult"],
- 3.5)
-
- def testLoadOldCustomFitFunction(self):
- """Test FitManager using a custom fit function defined in an external
- file and imported with FitManager.loadtheories (legacy PyMca format)"""
- # Create synthetic data with a sum of gaussian functions
- x = numpy.arange(100).astype(numpy.float64)
-
- # a, b, c are the fit parameters
- # d is a known scaling parameter that is set using configure()
- a, b, c, d = 1.5, 2.5, 3.5, 4.5
- y = (a * x**2 + b * x + c) / d
-
- # Fitting
- fit = fitmanager.FitManager()
- fit.setdata(x=x, y=y)
-
- # Create a temporary function definition file, and import it
- with temp_dir() as tmpDir:
- tmpfile = os.path.join(tmpDir, 'oldcustomfun.py')
- # custom_function_definition
- fd = open(tmpfile, "w")
- fd.write(old_custom_function_definition)
- fd.close()
- fit.loadtheories(tmpfile)
- tmpfile_pyc = os.path.join(tmpDir, 'oldcustomfun.pyc')
- if os.path.exists(tmpfile_pyc):
- os.unlink(tmpfile_pyc)
- os.unlink(tmpfile)
-
- fit.settheory('my fit theory')
- fit.configure(d=4.5)
- fit.estimate()
- fit.runfit()
-
- self.assertEqual(fit.fit_results[0]["name"],
- "A1")
- self.assertAlmostEqual(fit.fit_results[0]["fitresult"],
- 1.5)
- self.assertEqual(fit.fit_results[1]["name"],
- "B1")
- self.assertAlmostEqual(fit.fit_results[1]["fitresult"],
- 2.5)
- self.assertEqual(fit.fit_results[2]["name"],
- "C1")
- self.assertAlmostEqual(fit.fit_results[2]["fitresult"],
- 3.5)
-
- def testAddTheory(self, estimate=True):
- """Test FitManager using a custom fit function imported with
- FitManager.addtheory"""
- # Create synthetic data with a sum of gaussian functions
- x = numpy.arange(100).astype(numpy.float64)
-
- # a, b, c are the fit parameters
- # d is a known scaling parameter that is set using configure()
- a, b, c, d = -3.14, 1234.5, 10000, 4.5
- y = (a * x**2 + b * x + c) / d
-
- # Fitting
- fit = fitmanager.FitManager()
- fit.setdata(x=x, y=y)
-
- # Define and add the fit theory
- CONFIG = {'d': 1.}
-
- def myfun(x_, a_, b_, c_):
- """"Model function"""
- return (a_ * x_**2 + b_ * x_ + c_) / CONFIG['d']
-
- def myesti(x_, y_):
- """"Initial parameters for iterative fit:
- (a, b, c) = (1, 1, 1)
- Constraints all set to 0 (FREE)"""
- return (1., 1., 1.), ((0, 0, 0), (0, 0, 0), (0, 0, 0))
-
- def myconfig(d_=1., **kw):
- """This function can modify CONFIG"""
- CONFIG["d"] = d_
- return CONFIG
-
- def myderiv(x_, parameters, index):
- """Custom derivative"""
- pars_plus = numpy.array(parameters, copy=True)
- pars_plus[index] *= 1.001
-
- pars_minus = numpy.array(parameters, copy=True)
- pars_minus[index] *= 0.999
-
- delta_fun = myfun(x_, *pars_plus) - myfun(x_, *pars_minus)
- delta_par = parameters[index] * 0.001 * 2
-
- return delta_fun / delta_par
-
- fit.addtheory("polynomial",
- FitTheory(function=myfun,
- parameters=["A", "B", "C"],
- estimate=myesti if estimate else None,
- configure=myconfig,
- derivative=myderiv))
-
- fit.settheory('polynomial')
- fit.configure(d_=4.5)
- fit.estimate()
- params1, sigmas, infodict = fit.runfit()
-
- self.assertEqual(fit.fit_results[0]["name"],
- "A1")
- self.assertAlmostEqual(fit.fit_results[0]["fitresult"],
- -3.14)
- self.assertEqual(fit.fit_results[1]["name"],
- "B1")
- # params1[1] is the same as fit.fit_results[1]["fitresult"]
- self.assertAlmostEqual(params1[1],
- 1234.5)
- self.assertEqual(fit.fit_results[2]["name"],
- "C1")
- self.assertAlmostEqual(params1[2],
- 10000)
-
- # change configuration scaling factor and check that the fit returns
- # different values
- fit.configure(d_=5.)
- fit.estimate()
- params2, sigmas, infodict = fit.runfit()
- for p1, p2 in zip(params1, params2):
- self.assertFalse(numpy.array_equal(p1, p2),
- "Fit parameters are equal even though the " +
- "configuration has been changed")
-
- def testNoEstimate(self):
- """Ensure that the in the absence of the estimation function,
- the default estimation function :meth:`FitTheory.default_estimate`
- is used."""
- self.testAddTheory(estimate=False)
-
- def testStep(self):
- """Test fit manager on a step function with a more complex estimate
- function than the gaussian (convolution filter)"""
- for theory_name, theory_fun in (('Step Down', sum_stepdown),
- ('Step Up', sum_stepup)):
- # Create synthetic data with a sum of gaussian functions
- x = numpy.arange(1000).astype(numpy.float64)
-
- # ('Height', 'Position', 'FWHM')
- p = [1000, 439, 250]
-
- constantbg = 13
- y = theory_fun(x, *p) + constantbg
-
- # Fitting
- fit = fitmanager.FitManager()
- fit.setdata(x=x, y=y)
- fit.loadtheories(fittheories)
- fit.settheory(theory_name)
- fit.setbackground('Constant')
-
- fit.estimate()
-
- params, sigmas, infodict = fit.runfit()
-
- # first parameter is the constant background
- self.assertAlmostEqual(params[0], 13, places=5)
- for i, param in enumerate(params[1:]):
- self.assertAlmostEqual(param, p[i], places=5)
- self.assertAlmostEqual(_order_of_magnitude(fit.fit_results[i+1]["estimation"]),
- _order_of_magnitude(p[i]))
-
-
-def quadratic(x, a, b, c):
- return a * x**2 + b * x + c
-
-
-def cubic(x, a, b, c, d):
- return a * x**3 + b * x**2 + c * x + d
-
-
-class TestPolynomials(unittest.TestCase):
- """Test polynomial fit theories and fit background"""
- def setUp(self):
- self.x = numpy.arange(100).astype(numpy.float64)
-
- def testQuadraticBg(self):
- gaussian_params = [100, 45, 8]
- poly_params = [0.05, -2, 3]
- p = numpy.poly1d(poly_params)
-
- y = p(self.x) + sum_gauss(self.x, *gaussian_params)
-
- fm = fitmanager.FitManager(self.x, y)
- fm.loadbgtheories(bgtheories)
- fm.loadtheories(fittheories)
- fm.settheory("Gaussians")
- fm.setbackground("Degree 2 Polynomial")
- esti_params = fm.estimate()
- fit_params = fm.runfit()[0]
-
- for p, pfit in zip(poly_params + gaussian_params, fit_params):
- self.assertAlmostEqual(p,
- pfit)
-
- def testCubicBg(self):
- gaussian_params = [1000, 45, 8]
- poly_params = [0.0005, -0.05, 3, -4]
- p = numpy.poly1d(poly_params)
-
- y = p(self.x) + sum_gauss(self.x, *gaussian_params)
-
- fm = fitmanager.FitManager(self.x, y)
- fm.loadtheories(fittheories)
- fm.settheory("Gaussians")
- fm.setbackground("Degree 3 Polynomial")
- esti_params = fm.estimate()
- fit_params = fm.runfit()[0]
-
- for p, pfit in zip(poly_params + gaussian_params, fit_params):
- self.assertAlmostEqual(p,
- pfit)
-
- def testQuarticcBg(self):
- gaussian_params = [10000, 69, 25]
- poly_params = [5e-10, 0.0005, 0.005, 2, 4]
- p = numpy.poly1d(poly_params)
-
- y = p(self.x) + sum_gauss(self.x, *gaussian_params)
-
- fm = fitmanager.FitManager(self.x, y)
- fm.loadtheories(fittheories)
- fm.settheory("Gaussians")
- fm.setbackground("Degree 4 Polynomial")
- esti_params = fm.estimate()
- fit_params = fm.runfit()[0]
-
- for p, pfit in zip(poly_params + gaussian_params, fit_params):
- self.assertAlmostEqual(p,
- pfit,
- places=5)
-
- def _testPoly(self, poly_params, theory, places=5):
- p = numpy.poly1d(poly_params)
-
- y = p(self.x)
-
- fm = fitmanager.FitManager(self.x, y)
- fm.loadbgtheories(bgtheories)
- fm.loadtheories(fittheories)
- fm.settheory(theory)
- esti_params = fm.estimate()
- fit_params = fm.runfit()[0]
-
- for p, pfit in zip(poly_params, fit_params):
- self.assertAlmostEqual(p, pfit, places=places)
-
- def testQuadratic(self):
- self._testPoly([0.05, -2, 3],
- "Degree 2 Polynomial")
-
- def testCubic(self):
- self._testPoly([0.0005, -0.05, 3, -4],
- "Degree 3 Polynomial")
-
- def testQuartic(self):
- self._testPoly([1, -2, 3, -4, -5],
- "Degree 4 Polynomial")
-
- def testQuintic(self):
- self._testPoly([1, -2, 3, -4, -5, 6],
- "Degree 5 Polynomial",
- places=4)
-
-
-test_cases = (TestFitmanager, TestPolynomials)
-
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/fit/test/test_functions.py b/silx/math/fit/test/test_functions.py
deleted file mode 100644
index ce7dbd6..0000000
--- a/silx/math/fit/test/test_functions.py
+++ /dev/null
@@ -1,272 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-Tests for functions module
-"""
-
-import unittest
-import numpy
-import math
-
-from silx.math.fit import functions
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "21/07/2016"
-
-class Test_functions(unittest.TestCase):
- """
- Unit tests of multi-peak functions.
- """
- def setUp(self):
- self.x = numpy.arange(11)
-
- # height, center, sigma1, sigma2
- (h, c, s1, s2) = (7., 5., 3., 2.1)
- self.g_params = {
- "height": h,
- "center": c,
- #"sigma": s,
- "fwhm1": 2 * math.sqrt(2 * math.log(2)) * s1,
- "fwhm2": 2 * math.sqrt(2 * math.log(2)) * s2,
- "area1": h * s1 * math.sqrt(2 * math.pi)
- }
- # result of `7 * scipy.signal.gaussian(11, 3)`
- self.scipy_gaussian = numpy.array(
- [1.74546546, 2.87778603, 4.24571462, 5.60516182, 6.62171628,
- 7., 6.62171628, 5.60516182, 4.24571462, 2.87778603,
- 1.74546546]
- )
-
- # result of:
- # numpy.concatenate((7 * scipy.signal.gaussian(11, 3)[0:5],
- # 7 * scipy.signal.gaussian(11, 2.1)[5:11]))
- self.scipy_asym_gaussian = numpy.array(
- [1.74546546, 2.87778603, 4.24571462, 5.60516182, 6.62171628,
- 7., 6.24968751, 4.44773692, 2.52313452, 1.14093853, 0.41124877]
- )
-
- def tearDown(self):
- pass
-
- def testGauss(self):
- """Compare sum_gauss with scipy.signals.gaussian"""
- y = functions.sum_gauss(self.x,
- self.g_params["height"],
- self.g_params["center"],
- self.g_params["fwhm1"])
-
- for i in range(11):
- self.assertAlmostEqual(y[i], self.scipy_gaussian[i])
-
- def testAGauss(self):
- """Compare sum_agauss with scipy.signals.gaussian"""
- y = functions.sum_agauss(self.x,
- self.g_params["area1"],
- self.g_params["center"],
- self.g_params["fwhm1"])
- for i in range(11):
- self.assertAlmostEqual(y[i], self.scipy_gaussian[i])
-
- def testFastAGauss(self):
- """Compare sum_fastagauss with scipy.signals.gaussian
- Limit precision to 3 decimal places."""
- y = functions.sum_fastagauss(self.x,
- self.g_params["area1"],
- self.g_params["center"],
- self.g_params["fwhm1"])
- for i in range(11):
- self.assertAlmostEqual(y[i], self.scipy_gaussian[i], 3)
-
-
- def testSplitGauss(self):
- """Compare sum_splitgauss with scipy.signals.gaussian"""
- y = functions.sum_splitgauss(self.x,
- self.g_params["height"],
- self.g_params["center"],
- self.g_params["fwhm1"],
- self.g_params["fwhm2"])
- for i in range(11):
- self.assertAlmostEqual(y[i], self.scipy_asym_gaussian[i])
-
- def testErf(self):
- """Compare erf with math.erf"""
- # scalars
- self.assertAlmostEqual(functions.erf(0.14), math.erf(0.14), places=5)
- self.assertAlmostEqual(functions.erf(0), math.erf(0), places=5)
- self.assertAlmostEqual(functions.erf(-0.74), math.erf(-0.74), places=5)
-
- # lists
- x = [-5, -2, -1.5, -0.6, 0, 0.1, 2, 3]
- erfx = functions.erf(x)
- for i in range(len(x)):
- self.assertAlmostEqual(erfx[i],
- math.erf(x[i]),
- places=5)
-
- # ndarray
- x = numpy.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
- erfx = functions.erf(x)
- for i in range(x.shape[0]):
- for j in range(x.shape[1]):
- self.assertAlmostEqual(erfx[i, j],
- math.erf(x[i, j]),
- places=5)
-
- def testErfc(self):
- """Compare erf with math.erf"""
- # scalars
- self.assertAlmostEqual(functions.erfc(0.14), math.erfc(0.14), places=5)
- self.assertAlmostEqual(functions.erfc(0), math.erfc(0), places=5)
- self.assertAlmostEqual(functions.erfc(-0.74), math.erfc(-0.74), places=5)
-
- # lists
- x = [-5, -2, -1.5, -0.6, 0, 0.1, 2, 3]
- erfcx = functions.erfc(x)
- for i in range(len(x)):
- self.assertAlmostEqual(erfcx[i], math.erfc(x[i]), places=5)
-
- # ndarray
- x = numpy.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
- erfcx = functions.erfc(x)
- for i in range(x.shape[0]):
- for j in range(x.shape[1]):
- self.assertAlmostEqual(erfcx[i, j], math.erfc(x[i, j]), places=5)
-
- def testAtanStepUp(self):
- """Compare atan_stepup with math.atan
-
- atan_stepup(x, a, b, c) = a * (0.5 + (arctan((x - b) / c) / pi))"""
- x0 = numpy.arange(100) / 6.33
- y0 = functions.atan_stepup(x0, 11.1, 22.2, 3.33)
-
- for x, y in zip(x0, y0):
- self.assertAlmostEqual(
- 11.1 * (0.5 + math.atan((x - 22.2) / 3.33) / math.pi),
- y
- )
-
- def testStepUp(self):
- """sanity check for step up:
-
- - derivative must be largest around the step center
- - max value must be close to height parameter
-
- """
- x0 = numpy.arange(1000)
- center = 444
- height = 1234
- fwhm = 210
- y0 = functions.sum_stepup(x0, height, center, fwhm)
-
- self.assertLess(max(y0), height)
- self.assertAlmostEqual(max(y0), height, places=1)
- self.assertAlmostEqual(min(y0), 0, places=1)
-
- deriv0 = _numerical_derivative(functions.sum_stepup, x0, [height, center, fwhm])
-
- # Test center position within +- 1 sample of max derivative
- index_max_deriv = numpy.argmax(deriv0)
- self.assertLess(abs(index_max_deriv - center),
- 1)
-
- def testStepDown(self):
- """sanity check for step down:
-
- - absolute value of derivative must be largest around the step center
- - max value must be close to height parameter
-
- """
- x0 = numpy.arange(1000)
- center = 444
- height = 1234
- fwhm = 210
- y0 = functions.sum_stepdown(x0, height, center, fwhm)
-
- self.assertLess(max(y0), height)
- self.assertAlmostEqual(max(y0), height, places=1)
- self.assertAlmostEqual(min(y0), 0, places=1)
-
- deriv0 = _numerical_derivative(functions.sum_stepdown, x0, [height, center, fwhm])
-
- # Test center position within +- 1 sample of max derivative
- index_min_deriv = numpy.argmax(-deriv0)
- self.assertLess(abs(index_min_deriv - center),
- 1)
-
- def testSlit(self):
- """sanity check for slit:
-
- - absolute value of derivative must be largest around the step center
- - max value must be close to height parameter
-
- """
- x0 = numpy.arange(1000)
- center = 444
- height = 1234
- fwhm = 210
- beamfwhm = 30
- y0 = functions.sum_slit(x0, height, center, fwhm, beamfwhm)
-
- self.assertAlmostEqual(max(y0), height, places=1)
- self.assertAlmostEqual(min(y0), 0, places=1)
-
- deriv0 = _numerical_derivative(functions.sum_slit, x0, [height, center, fwhm, beamfwhm])
-
- # Test step up center position (center - fwhm/2) within +- 1 sample of max derivative
- index_max_deriv = numpy.argmax(deriv0)
- self.assertLess(abs(index_max_deriv - (center - fwhm/2)),
- 1)
- # Test step down center position (center + fwhm/2) within +- 1 sample of min derivative
- index_min_deriv = numpy.argmin(deriv0)
- self.assertLess(abs(index_min_deriv - (center + fwhm/2)),
- 1)
-
-
-def _numerical_derivative(f, x, params=[], delta_factor=0.0001):
- """Compute the numerical derivative of ``f`` for all values of ``x``.
-
- :param f: function
- :param x: Array of evenly spaced abscissa values
- :param params: list of additional parameters
- :return: Array of derivative values
- """
- deltax = (x[1] - x[0]) * delta_factor
- y_plus = f(x + deltax, *params)
- y_minus = f(x - deltax, *params)
-
- return (y_plus - y_minus) / (2 * deltax)
-
-test_cases = (Test_functions,)
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/fit/test/test_peaks.py b/silx/math/fit/test/test_peaks.py
deleted file mode 100644
index 17eb75d..0000000
--- a/silx/math/fit/test/test_peaks.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-Tests for peaks module
-"""
-
-import unittest
-import numpy
-import math
-
-from silx.math.fit import functions
-from silx.math.fit import peaks
-
-class Test_peak_search(unittest.TestCase):
- """
- Unit tests of peak_search on various types of multi-peak functions.
- """
- def setUp(self):
- self.x = numpy.arange(5000)
- # (height1, center1, fwhm1, ...)
- self.h_c_fwhm = (50, 500, 100,
- 50, 600, 80,
- 20, 2000, 100,
- 50, 2250, 110,
- 40, 3000, 99,
- 23, 4980, 80)
- # (height1, center1, fwhm1, eta1 ...)
- self.h_c_fwhm_eta = (50, 500, 100, 0.4,
- 50, 600, 80, 0.5,
- 20, 2000, 100, 0.6,
- 50, 2250, 110, 0.7,
- 40, 3000, 99, 0.8,
- 23, 4980, 80, 0.3,)
- # (height1, center1, fwhm11, fwhm21, ...)
- self.h_c_fwhm_fwhm = (50, 500, 100, 85,
- 50, 600, 80, 110,
- 20, 2000, 100, 100,
- 50, 2250, 110, 99,
- 40, 3000, 99, 110,
- 23, 4980, 80, 80,)
- # (height1, center1, fwhm11, fwhm21, eta1 ...)
- self.h_c_fwhm_fwhm_eta = (50, 500, 100, 85, 0.4,
- 50, 600, 80, 110, 0.5,
- 20, 2000, 100, 100, 0.6,
- 50, 2250, 110, 99, 0.7,
- 40, 3000, 99, 110, 0.8,
- 23, 4980, 80, 80, 0.3,)
- # (area1, center1, fwhm1, ...)
- self.a_c_fwhm = (2550, 500, 100,
- 2000, 600, 80,
- 500, 2000, 100,
- 4000, 2250, 110,
- 2300, 3000, 99,
- 3333, 4980, 80)
- # (area1, center1, fwhm1, eta1 ...)
- self.a_c_fwhm_eta = (500, 500, 100, 0.4,
- 500, 600, 80, 0.5,
- 200, 2000, 100, 0.6,
- 500, 2250, 110, 0.7,
- 400, 3000, 99, 0.8,
- 230, 4980, 80, 0.3,)
- # (area, position, fwhm, st_area_r, st_slope_r, lt_area_r, lt_slope_r, step_height_r)
- self.hypermet_params = (1000, 500, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 1000, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 2000, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 2350, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 3000, 200, 0.2, 100, 0.3, 100, 0.05,
- 1000, 4900, 200, 0.2, 100, 0.3, 100, 0.05,)
-
-
- def tearDown(self):
- pass
-
- def get_peaks(self, function, params):
- """
-
- :param function: Multi-peak function
- :param params: Parameter for this function
- :return: list of (peak, relevance) tuples
- """
- y = function(self.x, *params)
- return peaks.peak_search(y=y, fwhm=100, relevance_info=True)
-
- def testPeakSearch_various_functions(self):
- """Run peak search on a variety of synthetic functions, and
- check that result falls within +-25 samples of the actual peak
- (reasonable delta considering a fwhm of ~100 samples) and effects
- of overlapping peaks)."""
- f_p = ((functions.sum_gauss, self.h_c_fwhm ),
- (functions.sum_lorentz, self.h_c_fwhm),
- (functions.sum_pvoigt, self.h_c_fwhm_eta),
- (functions.sum_splitgauss, self.h_c_fwhm_fwhm),
- (functions.sum_splitlorentz, self.h_c_fwhm_fwhm),
- (functions.sum_splitpvoigt, self.h_c_fwhm_fwhm_eta),
- (functions.sum_agauss, self.a_c_fwhm),
- (functions.sum_fastagauss, self.a_c_fwhm),
- (functions.sum_alorentz, self.a_c_fwhm),
- (functions.sum_apvoigt, self.a_c_fwhm_eta),
- (functions.sum_ahypermet, self.hypermet_params),
- (functions.sum_fastahypermet, self.hypermet_params),)
-
- for function, params in f_p:
- peaks = self.get_peaks(function, params)
-
- self.assertEqual(len(peaks), 6,
- "Wrong number of peaks detected")
-
- for i in range(6):
- theoretical_peak_index = params[i*(len(params)//6) + 1]
- found_peak_index = peaks[i][0]
- self.assertLess(abs(found_peak_index - theoretical_peak_index), 25)
-
-
-test_cases = (Test_peak_search,)
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/medianfilter/test/__init__.py b/silx/math/medianfilter/test/__init__.py
deleted file mode 100644
index 92a6524..0000000
--- a/silx/math/medianfilter/test/__init__.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "22/06/2016"
-
-import unittest
-
-from . import test_medianfilter
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(test_medianfilter.suite())
- return test_suite
diff --git a/silx/math/medianfilter/test/benchmark.py b/silx/math/medianfilter/test/benchmark.py
deleted file mode 100644
index cbb16b3..0000000
--- a/silx/math/medianfilter/test/benchmark.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2017-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests of the median filter"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "02/05/2017"
-
-from silx.gui import qt
-from silx.math.medianfilter import medfilt2d as medfilt2d_silx
-import numpy
-import numpy.random
-from timeit import Timer
-from silx.gui.plot import Plot1D
-import logging
-
-try:
- import scipy
-except:
- scipy = None
-else:
- import scipy.ndimage
-
-try:
- import PyMca5.PyMca as pymca
-except:
- pymca = None
-else:
- from PyMca5.PyMca.median import medfilt2d as medfilt2d_pymca
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-class BenchmarkMedianFilter(object):
- """Simple benchmark of the median fiter silx vs scipy"""
-
- NB_ITER = 3
-
- def __init__(self, imageWidth, kernels):
- self.img = numpy.random.rand(imageWidth, imageWidth)
- self.kernels = kernels
-
- self.run()
-
- def run(self):
- self.execTime = {}
- for kernel in self.kernels:
- self.execTime[kernel] = self.bench(kernel)
-
- def bench(self, width):
- def execSilx():
- medfilt2d_silx(self.img, width)
-
- def execScipy():
- scipy.ndimage.median_filter(input=self.img,
- size=width,
- mode='nearest')
-
- def execPymca():
- medfilt2d_pymca(self.img, width)
-
- execTime = {}
-
- t = Timer(execSilx)
- execTime["silx"] = t.timeit(BenchmarkMedianFilter.NB_ITER)
- logger.info(
- 'exec time silx (kernel size = %s) is %s' % (width, execTime["silx"]))
-
- if scipy is not None:
- t = Timer(execScipy)
- execTime["scipy"] = t.timeit(BenchmarkMedianFilter.NB_ITER)
- logger.info(
- 'exec time scipy (kernel size = %s) is %s' % (width, execTime["scipy"]))
- if pymca is not None:
- t = Timer(execPymca)
- execTime["pymca"] = t.timeit(BenchmarkMedianFilter.NB_ITER)
- logger.info(
- 'exec time pymca (kernel size = %s) is %s' % (width, execTime["pymca"]))
-
- return execTime
-
- def getExecTimeFor(self, id):
- res = []
- for k in self.kernels:
- res.append(self.execTime[k][id])
- return res
-
-
-app = qt.QApplication([])
-kernels = [3, 5, 7, 11, 15]
-benchmark = BenchmarkMedianFilter(imageWidth=1000, kernels=kernels)
-plot = Plot1D()
-plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("silx"), legend='silx')
-if scipy is not None:
- plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("scipy"), legend='scipy')
-if pymca is not None:
- plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("pymca"), legend='pymca')
-plot.show()
-app.exec_()
-del app
diff --git a/silx/math/medianfilter/test/test_medianfilter.py b/silx/math/medianfilter/test/test_medianfilter.py
deleted file mode 100644
index 3a45b3d..0000000
--- a/silx/math/medianfilter/test/test_medianfilter.py
+++ /dev/null
@@ -1,740 +0,0 @@
-# coding: utf-8
-# ##########################################################################
-# Copyright (C) 2017-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################
-"""Tests of the median filter"""
-
-__authors__ = ["H. Payno"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-import unittest
-import numpy
-from silx.math.medianfilter import medfilt2d, medfilt1d
-from silx.math.medianfilter.medianfilter import reflect, mirror
-from silx.math.medianfilter.medianfilter import MODES as silx_mf_modes
-from silx.utils.testutils import ParametricTestCase
-try:
- import scipy
- import scipy.misc
-except:
- scipy = None
-else:
- import scipy.ndimage
-
-import logging
-_logger = logging.getLogger(__name__)
-
-RANDOM_FLOAT_MAT = numpy.array([
- [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
- [0.76532598, 0.02839148, 0.05272484, 0.65166994, 0.42161216],
- [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.28773158],
- [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
- [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165]])
-
-RANDOM_INT_MAT = numpy.array([
- [0, 5, 2, 6, 1],
- [2, 3, 1, 7, 1],
- [9, 8, 6, 7, 8],
- [5, 6, 8, 2, 4]])
-
-
-class TestMedianFilterNearest(ParametricTestCase):
- """Unit tests for the median filter in nearest mode"""
-
- def testFilter3_100(self):
- """Test median filter on a 10x10 matrix with a 3x3 kernel."""
- dataIn = numpy.arange(100, dtype=numpy.int32)
- dataIn = dataIn.reshape((10, 10))
-
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(3, 3),
- conditional=False,
- mode='nearest')
- self.assertTrue(dataOut[0, 0] == 1)
- self.assertTrue(dataOut[9, 0] == 90)
- self.assertTrue(dataOut[9, 9] == 98)
-
- self.assertTrue(dataOut[0, 9] == 9)
- self.assertTrue(dataOut[0, 4] == 5)
- self.assertTrue(dataOut[9, 4] == 93)
- self.assertTrue(dataOut[4, 4] == 44)
-
- def testFilter3_9(self):
- "Test median filter on a 3x3 matrix with a 3x3 kernel."
- dataIn = numpy.array([0, -1, 1,
- 12, 6, -2,
- 100, 4, 12],
- dtype=numpy.int16)
- dataIn = dataIn.reshape((3, 3))
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(3, 3),
- conditional=False,
- mode='nearest')
- self.assertTrue(dataOut.shape == dataIn.shape)
- self.assertTrue(dataOut[1, 1] == 4)
- self.assertTrue(dataOut[0, 0] == 0)
- self.assertTrue(dataOut[0, 1] == 0)
- self.assertTrue(dataOut[1, 0] == 6)
-
- def testFilterWidthOne(self):
- """Make sure a filter of one by one give the same result as the input
- """
- dataIn = numpy.arange(100, dtype=numpy.int32)
- dataIn = dataIn.reshape((10, 10))
-
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(1, 1),
- conditional=False,
- mode='nearest')
-
- self.assertTrue(numpy.array_equal(dataIn, dataOut))
-
- def testFilter3_1d(self):
- """Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=3, conditional=False,
- mode='nearest'),
- [0, 2, 5, 2, 1])
- )
-
- def testFilter3Conditionnal(self):
- """Test that the conditional filter apply correctly in a 10x10 matrix
- with a 3x3 kernel
- """
- dataIn = numpy.arange(100, dtype=numpy.int32)
- dataIn = dataIn.reshape((10, 10))
-
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(3, 3),
- conditional=True,
- mode='nearest')
- self.assertTrue(dataOut[0, 0] == 1)
- self.assertTrue(dataOut[0, 1] == 1)
- self.assertTrue(numpy.array_equal(dataOut[1:8, 1:8], dataIn[1:8, 1:8]))
- self.assertTrue(dataOut[9, 9] == 98)
-
- def testFilter3_1D(self):
- """Simple test of a 3x3 median filter on a 1D array"""
- dataIn = numpy.arange(100, dtype=numpy.int32)
-
- dataOut = medfilt2d(image=dataIn,
- kernel_size=(5),
- conditional=False,
- mode='nearest')
-
- self.assertTrue(dataOut[0] == 0)
- self.assertTrue(dataOut[9] == 9)
- self.assertTrue(dataOut[99] == 99)
-
- def testNaNs(self):
- """Test median filter on image with NaNs in nearest mode"""
- # Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
- nan_corner[0, 0] = numpy.nan
- output = medfilt2d(
- nan_corner, kernel_size=3, conditional=False, mode='nearest')
- self.assertEqual(output[0, 0], 10)
- self.assertEqual(output[0, 1], 2)
- self.assertEqual(output[1, 0], 11)
- self.assertEqual(output[1, 1], 12)
-
- # Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
- some_nans[0, 1] = numpy.nan
- some_nans[1, 1] = numpy.nan
- some_nans[1, 0] = numpy.nan
- output = medfilt2d(
- some_nans, kernel_size=3, conditional=False, mode='nearest')
- self.assertEqual(output[0, 0], 0)
- self.assertEqual(output[0, 1], 2)
- self.assertEqual(output[1, 0], 20)
- self.assertEqual(output[1, 1], 20)
-
-
-class TestMedianFilterReflect(ParametricTestCase):
- """Unit test for the median filter in reflect mode"""
-
- def testArange9(self):
- """Test from a 3x3 window to RANDOM_FLOAT_MAT"""
- img = numpy.arange(9, dtype=numpy.int32)
- img = img.reshape(3, 3)
- kernel = (3, 3)
- res = medfilt2d(image=img,
- kernel_size=kernel,
- conditional=False,
- mode='reflect')
- self.assertTrue(
- numpy.array_equal(res.ravel(), [1, 2, 2, 3, 4, 5, 6, 6, 7]))
-
- def testRandom10(self):
- """Test a (5, 3) window to a RANDOM_FLOAT_MAT"""
- kernel = (5, 3)
-
- thRes = numpy.array([
- [0.23067427, 0.56049024, 0.56049024, 0.4440632, 0.42161216],
- [0.23067427, 0.62717157, 0.56049024, 0.56049024, 0.46372299],
- [0.62717157, 0.62717157, 0.56049024, 0.56049024, 0.4440632],
- [0.76532598, 0.68382382, 0.56049024, 0.56049024, 0.42161216],
- [0.81025249, 0.68382382, 0.56049024, 0.68382382, 0.46372299]])
-
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='reflect')
-
- self.assertTrue(numpy.array_equal(thRes, res))
-
- def testApplyReflect1D(self):
- """Test the reflect function used for the median filter in reflect mode
- """
- # test for inside values
- self.assertTrue(reflect(2, 3) == 2)
- # test for boundaries values
- self.assertTrue(reflect(3, 3) == 2)
- self.assertTrue(reflect(4, 3) == 1)
- self.assertTrue(reflect(5, 3) == 0)
- self.assertTrue(reflect(6, 3) == 0)
- self.assertTrue(reflect(7, 3) == 1)
- self.assertTrue(reflect(-1, 3) == 0)
- self.assertTrue(reflect(-2, 3) == 1)
- self.assertTrue(reflect(-3, 3) == 2)
- self.assertTrue(reflect(-4, 3) == 2)
- self.assertTrue(reflect(-5, 3) == 1)
- self.assertTrue(reflect(-6, 3) == 0)
- self.assertTrue(reflect(-7, 3) == 0)
-
- def testRandom10Conditionnal(self):
- """Test the median filter in reflect mode and with the conditionnal
- option"""
- kernel = (3, 1)
-
- thRes = numpy.array([
- [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
- [0.23067427, 0.62717157, 0.56049024, 0.44406320, 0.42161216],
- [0.76532598, 0.20303021, 0.56049024, 0.46372299, 0.42161216],
- [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.33623165],
- [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165]])
-
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=True,
- mode='reflect')
- self.assertTrue(numpy.array_equal(thRes, res))
-
- def testNaNs(self):
- """Test median filter on image with NaNs in reflect mode"""
- # Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
- nan_corner[0, 0] = numpy.nan
- output = medfilt2d(
- nan_corner, kernel_size=3, conditional=False, mode='reflect')
- self.assertEqual(output[0, 0], 10)
- self.assertEqual(output[0, 1], 2)
- self.assertEqual(output[1, 0], 11)
- self.assertEqual(output[1, 1], 12)
-
- # Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
- some_nans[0, 1] = numpy.nan
- some_nans[1, 1] = numpy.nan
- some_nans[1, 0] = numpy.nan
- output = medfilt2d(
- some_nans, kernel_size=3, conditional=False, mode='reflect')
- self.assertEqual(output[0, 0], 0)
- self.assertEqual(output[0, 1], 2)
- self.assertEqual(output[1, 0], 20)
- self.assertEqual(output[1, 1], 20)
-
- def testFilter3_1d(self):
- """Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=5, conditional=False,
- mode='reflect'),
- [2, 2, 2, 2, 2])
- )
-
-
-class TestMedianFilterMirror(ParametricTestCase):
- """Unit test for the median filter in mirror mode
- """
-
- def testApplyMirror1D(self):
- """Test the reflect function used for the median filter in mirror mode
- """
- # test for inside values
- self.assertTrue(mirror(2, 3) == 2)
- # test for boundaries values
- self.assertTrue(mirror(4, 4) == 2)
- self.assertTrue(mirror(5, 4) == 1)
- self.assertTrue(mirror(6, 4) == 0)
- self.assertTrue(mirror(7, 4) == 1)
- self.assertTrue(mirror(8, 4) == 2)
- self.assertTrue(mirror(-1, 4) == 1)
- self.assertTrue(mirror(-2, 4) == 2)
- self.assertTrue(mirror(-3, 4) == 3)
- self.assertTrue(mirror(-4, 4) == 2)
- self.assertTrue(mirror(-5, 4) == 1)
- self.assertTrue(mirror(-6, 4) == 0)
-
- def testRandom10(self):
- """Test a (5, 3) window to a random array"""
- kernel = (3, 5)
-
- thRes = numpy.array([
- [0.05272484, 0.40555336, 0.42161216, 0.42161216, 0.42161216],
- [0.56049024, 0.56049024, 0.4440632, 0.4440632, 0.4440632],
- [0.56049024, 0.46372299, 0.46372299, 0.46372299, 0.46372299],
- [0.68382382, 0.56049024, 0.56049024, 0.46372299, 0.56049024],
- [0.68382382, 0.46372299, 0.68382382, 0.46372299, 0.68382382]])
-
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='mirror')
-
- self.assertTrue(numpy.array_equal(thRes, res))
-
- def testRandom10Conditionnal(self):
- """Test the median filter in reflect mode and with the conditionnal
- option"""
- kernel = (1, 3)
-
- thRes = numpy.array([
- [0.62717157, 0.62717157, 0.62717157, 0.70278975, 0.40555336],
- [0.02839148, 0.05272484, 0.05272484, 0.42161216, 0.65166994],
- [0.74219128, 0.56049024, 0.56049024, 0.44406320, 0.44406320],
- [0.20303021, 0.68382382, 0.46372299, 0.68382382, 0.46372299],
- [0.07813661, 0.81651256, 0.81651256, 0.81651256, 0.84220106]])
-
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=True,
- mode='mirror')
-
- self.assertTrue(numpy.array_equal(thRes, res))
-
- def testNaNs(self):
- """Test median filter on image with NaNs in mirror mode"""
- # Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
- nan_corner[0, 0] = numpy.nan
- output = medfilt2d(
- nan_corner, kernel_size=3, conditional=False, mode='mirror')
- self.assertEqual(output[0, 0], 11)
- self.assertEqual(output[0, 1], 11)
- self.assertEqual(output[1, 0], 11)
- self.assertEqual(output[1, 1], 12)
-
- # Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
- some_nans[0, 1] = numpy.nan
- some_nans[1, 1] = numpy.nan
- some_nans[1, 0] = numpy.nan
- output = medfilt2d(
- some_nans, kernel_size=3, conditional=False, mode='mirror')
- self.assertEqual(output[0, 0], 0)
- self.assertEqual(output[0, 1], 12)
- self.assertEqual(output[1, 0], 21)
- self.assertEqual(output[1, 1], 20)
-
- def testFilter3_1d(self):
- """Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=5, conditional=False,
- mode='mirror'),
- [2, 5, 2, 5, 2])
- )
-
-class TestMedianFilterShrink(ParametricTestCase):
- """Unit test for the median filter in mirror mode
- """
-
- def testRandom_3x3(self):
- """Test the median filter in shrink mode and with the conditionnal
- option"""
- kernel = (3, 3)
-
- thRes = numpy.array([
- [0.62717157, 0.62717157, 0.62717157, 0.65166994, 0.65166994],
- [0.62717157, 0.56049024, 0.56049024, 0.44406320, 0.44406320],
- [0.74219128, 0.56049024, 0.46372299, 0.46372299, 0.46372299],
- [0.74219128, 0.68382382, 0.56049024, 0.56049024, 0.46372299],
- [0.81025249, 0.81025249, 0.68382382, 0.81281709, 0.81281709]])
-
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='shrink')
-
- self.assertTrue(numpy.array_equal(thRes, res))
-
- def testBounds(self):
- """Test the median filter in shrink mode with 3 different kernels
- which should return the same result due to the large values of kernels
- used.
- """
- kernel1 = (1, 9)
- kernel2 = (1, 11)
- kernel3 = (1, 21)
-
- thRes = numpy.array([[2, 2, 2, 2, 2],
- [2, 2, 2, 2, 2],
- [8, 8, 8, 8, 8],
- [5, 5, 5, 5, 5]])
-
- resK1 = medfilt2d(image=RANDOM_INT_MAT,
- kernel_size=kernel1,
- conditional=False,
- mode='shrink')
-
- resK2 = medfilt2d(image=RANDOM_INT_MAT,
- kernel_size=kernel2,
- conditional=False,
- mode='shrink')
-
- resK3 = medfilt2d(image=RANDOM_INT_MAT,
- kernel_size=kernel3,
- conditional=False,
- mode='shrink')
-
- self.assertTrue(numpy.array_equal(resK1, thRes))
- self.assertTrue(numpy.array_equal(resK2, resK1))
- self.assertTrue(numpy.array_equal(resK3, resK1))
-
- def testRandom_3x3Conditionnal(self):
- """Test the median filter in reflect mode and with the conditionnal
- option"""
- kernel = (3, 3)
-
- thRes = numpy.array([
- [0.05564293, 0.62717157, 0.62717157, 0.40555336, 0.65166994],
- [0.62717157, 0.56049024, 0.05272484, 0.65166994, 0.42161216],
- [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.46372299],
- [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
- [0.81025249, 0.81025249, 0.81651256, 0.81281709, 0.81281709]])
-
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=True,
- mode='shrink')
-
- self.assertTrue(numpy.array_equal(res, thRes))
-
- def testRandomInt(self):
- """Test 3x3 kernel on RANDOM_INT_MAT
- """
- kernel = (3, 3)
-
- thRes = numpy.array([[3, 2, 5, 2, 6],
- [5, 3, 6, 6, 7],
- [6, 6, 6, 6, 7],
- [8, 8, 7, 7, 7]])
-
- resK1 = medfilt2d(image=RANDOM_INT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='shrink')
-
- self.assertTrue(numpy.array_equal(resK1, thRes))
-
- def testNaNs(self):
- """Test median filter on image with NaNs in shrink mode"""
- # Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
- nan_corner[0, 0] = numpy.nan
- output = medfilt2d(
- nan_corner, kernel_size=3, conditional=False, mode='shrink')
- self.assertEqual(output[0, 0], 10)
- self.assertEqual(output[0, 1], 10)
- self.assertEqual(output[1, 0], 11)
- self.assertEqual(output[1, 1], 12)
-
- # Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
- some_nans[0, 1] = numpy.nan
- some_nans[1, 1] = numpy.nan
- some_nans[1, 0] = numpy.nan
- output = medfilt2d(
- some_nans, kernel_size=3, conditional=False, mode='shrink')
- self.assertEqual(output[0, 0], 0)
- self.assertEqual(output[0, 1], 2)
- self.assertEqual(output[1, 0], 20)
- self.assertEqual(output[1, 1], 20)
-
- def testFilter3_1d(self):
- """Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=3, conditional=False,
- mode='shrink'),
- [5, 2, 5, 2, 6])
- )
-
-class TestMedianFilterConstant(ParametricTestCase):
- """Unit test for the median filter in constant mode
- """
-
- def testRandom10(self):
- """Test a (5, 3) window to a random array"""
- kernel = (3, 5)
-
- thRes = numpy.array([
- [0., 0.02839148, 0.05564293, 0.02839148, 0.],
- [0.05272484, 0.40555336, 0.4440632, 0.42161216, 0.28773158],
- [0.05272484, 0.44406320, 0.46372299, 0.42161216, 0.28773158],
- [0.20303021, 0.46372299, 0.56049024, 0.44406320, 0.33623165],
- [0., 0.07813661, 0.33623165, 0.07813661, 0.]])
-
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode='constant')
-
- self.assertTrue(numpy.array_equal(thRes, res))
-
- RANDOM_FLOAT_MAT = numpy.array([
- [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
- [0.76532598, 0.02839148, 0.05272484, 0.65166994, 0.42161216],
- [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.28773158],
- [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
- [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165]])
-
- def testRandom10Conditionnal(self):
- """Test the median filter in reflect mode and with the conditionnal
- option"""
- kernel = (1, 3)
-
- print(RANDOM_FLOAT_MAT)
-
- thRes = numpy.array([
- [0.05564293, 0.62717157, 0.62717157, 0.70278975, 0.40555336],
- [0.02839148, 0.05272484, 0.05272484, 0.42161216, 0.42161216],
- [0.23067427, 0.56049024, 0.56049024, 0.44406320, 0.28773158],
- [0.20303021, 0.68382382, 0.46372299, 0.68382382, 0.46372299],
- [0.07813661, 0.81651256, 0.81651256, 0.81651256, 0.33623165]])
-
- res = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=True,
- mode='constant')
-
- self.assertTrue(numpy.array_equal(thRes, res))
-
- def testNaNs(self):
- """Test median filter on image with NaNs in constant mode"""
- # Data with a NaN in first corner
- nan_corner = numpy.arange(100.).reshape(10, 10)
- nan_corner[0, 0] = numpy.nan
- output = medfilt2d(nan_corner,
- kernel_size=3,
- conditional=False,
- mode='constant',
- cval=0)
- self.assertEqual(output[0, 0], 0)
- self.assertEqual(output[0, 1], 2)
- self.assertEqual(output[1, 0], 10)
- self.assertEqual(output[1, 1], 12)
-
- # Data with some NaNs
- some_nans = numpy.arange(100.).reshape(10, 10)
- some_nans[0, 1] = numpy.nan
- some_nans[1, 1] = numpy.nan
- some_nans[1, 0] = numpy.nan
- output = medfilt2d(some_nans,
- kernel_size=3,
- conditional=False,
- mode='constant',
- cval=0)
- self.assertEqual(output[0, 0], 0)
- self.assertEqual(output[0, 1], 0)
- self.assertEqual(output[1, 0], 0)
- self.assertEqual(output[1, 1], 20)
-
- def testFilter3_1d(self):
- """Test binding and result of the 1d filter"""
- self.assertTrue(numpy.array_equal(
- medfilt1d(RANDOM_INT_MAT[0], kernel_size=5, conditional=False,
- mode='constant'),
- [0, 2, 2, 2, 1])
- )
-
-class TestGeneralExecution(ParametricTestCase):
- """Some general test on median filter application"""
-
- def testTypes(self):
- """Test that all needed types have their implementation of the median
- filter
- """
- for mode in silx_mf_modes:
- for testType in [numpy.float32, numpy.float64, numpy.int16,
- numpy.uint16, numpy.int32, numpy.int64,
- numpy.uint64]:
- with self.subTest(mode=mode, type=testType):
- data = (numpy.random.rand(10, 10) * 65000).astype(testType)
- out = medfilt2d(image=data,
- kernel_size=(3, 3),
- conditional=False,
- mode=mode)
- self.assertTrue(out.dtype.type is testType)
-
- def testInputDataIsNotModify(self):
- """Make sure input data is not modify by the median filter"""
- dataIn = numpy.arange(100, dtype=numpy.int32)
- dataIn = dataIn.reshape((10, 10))
- dataInCopy = dataIn.copy()
-
- for mode in silx_mf_modes:
- with self.subTest(mode=mode):
- medfilt2d(image=dataIn,
- kernel_size=(3, 3),
- conditional=False,
- mode=mode)
- self.assertTrue(numpy.array_equal(dataIn, dataInCopy))
-
- def testAllNaNs(self):
- """Test median filter on image all NaNs"""
- all_nans = numpy.empty((10, 10), dtype=numpy.float32)
- all_nans[:] = numpy.nan
-
- for mode in silx_mf_modes:
- for conditional in (True, False):
- with self.subTest(mode=mode, conditional=conditional):
- output = medfilt2d(
- all_nans,
- kernel_size=3,
- conditional=conditional,
- mode=mode,
- cval=numpy.nan)
- self.assertTrue(numpy.all(numpy.isnan(output)))
-
- def testConditionalWithNaNs(self):
- """Test that NaNs are propagated through conditional median filter"""
- for mode in silx_mf_modes:
- with self.subTest(mode=mode):
- image = numpy.ones((10, 10), dtype=numpy.float32)
- nan_mask = numpy.zeros_like(image, dtype=bool)
- nan_mask[0, 0] = True
- nan_mask[4, :] = True
- nan_mask[6, 4] = True
- image[nan_mask] = numpy.nan
- output = medfilt2d(
- image,
- kernel_size=3,
- conditional=True,
- mode=mode)
- out_isnan = numpy.isnan(output)
- self.assertTrue(numpy.all(out_isnan[nan_mask]))
- self.assertFalse(
- numpy.any(out_isnan[numpy.logical_not(nan_mask)]))
-
-
-def _getScipyAndSilxCommonModes():
- """return the mode which are comparable between silx and scipy"""
- modes = silx_mf_modes.copy()
- del modes['shrink']
- return modes
-
-
-@unittest.skipUnless(scipy is not None, "scipy not available")
-class TestVsScipy(ParametricTestCase):
- """Compare scipy.ndimage.median_filter vs silx.math.medianfilter
- on comparable
- """
- def testWithArange(self):
- """Test vs scipy with different kernels on arange matrix"""
- data = numpy.arange(10000, dtype=numpy.int32)
- data = data.reshape(100, 100)
-
- kernels = [(3, 7), (7, 5), (1, 1), (3, 3)]
- modesToTest = _getScipyAndSilxCommonModes()
- for kernel in kernels:
- for mode in modesToTest:
- with self.subTest(kernel=kernel, mode=mode):
- resScipy = scipy.ndimage.median_filter(input=data,
- size=kernel,
- mode=mode)
- resSilx = medfilt2d(image=data,
- kernel_size=kernel,
- conditional=False,
- mode=mode)
-
- self.assertTrue(numpy.array_equal(resScipy, resSilx))
-
- def testRandomMatrice(self):
- """Test vs scipy with different kernels on RANDOM_FLOAT_MAT"""
- kernels = [(3, 7), (7, 5), (1, 1), (3, 3)]
- modesToTest = _getScipyAndSilxCommonModes()
- for kernel in kernels:
- for mode in modesToTest:
- with self.subTest(kernel=kernel, mode=mode):
- resScipy = scipy.ndimage.median_filter(input=RANDOM_FLOAT_MAT,
- size=kernel,
- mode=mode)
-
- resSilx = medfilt2d(image=RANDOM_FLOAT_MAT,
- kernel_size=kernel,
- conditional=False,
- mode=mode)
-
- self.assertTrue(numpy.array_equal(resScipy, resSilx))
-
- def testAscentOrLena(self):
- """Test vs scipy with """
- if hasattr(scipy.misc, 'ascent'):
- img = scipy.misc.ascent()
- else:
- img = scipy.misc.lena()
-
- kernels = [(3, 1), (3, 5), (5, 9), (9, 3)]
- modesToTest = _getScipyAndSilxCommonModes()
-
- for kernel in kernels:
- for mode in modesToTest:
- with self.subTest(kernel=kernel, mode=mode):
- resScipy = scipy.ndimage.median_filter(input=img,
- size=kernel,
- mode=mode)
-
- resSilx = medfilt2d(image=img,
- kernel_size=kernel,
- conditional=False,
- mode=mode)
-
- self.assertTrue(numpy.array_equal(resScipy, resSilx))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for test in [TestGeneralExecution,
- TestVsScipy,
- TestMedianFilterNearest,
- TestMedianFilterReflect,
- TestMedianFilterMirror,
- TestMedianFilterShrink,
- TestMedianFilterConstant]:
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(test))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/math/setup.py b/silx/math/setup.py
deleted file mode 100644
index 8cc15e6..0000000
--- a/silx/math/setup.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-
-__authors__ = ["D. Naudet"]
-__license__ = "MIT"
-__date__ = "27/03/2017"
-
-import os.path
-
-import numpy
-
-from numpy.distutils.misc_util import Configuration
-
-
-def configuration(parent_package='', top_path=None):
- config = Configuration('math', parent_package, top_path)
- config.add_subpackage('test')
- config.add_subpackage('fit')
- config.add_subpackage('medianfilter')
- config.add_subpackage('fft')
-
- # =====================================
- # histogramnd
- # =====================================
- histo_src = [os.path.join('histogramnd', 'src', 'histogramnd_c.c'),
- 'chistogramnd.pyx']
- histo_inc = [os.path.join('histogramnd', 'include'),
- numpy.get_include()]
-
- config.add_extension('chistogramnd',
- sources=histo_src,
- include_dirs=histo_inc,
- language='c')
-
- # =====================================
- # histogramnd_lut
- # =====================================
- config.add_extension('chistogramnd_lut',
- sources=['chistogramnd_lut.pyx'],
- include_dirs=histo_inc,
- language='c')
- # =====================================
- # marching cubes
- # =====================================
- mc_src = [os.path.join('marchingcubes', 'mc_lut.cpp'),
- 'marchingcubes.pyx']
- config.add_extension('marchingcubes',
- sources=mc_src,
- include_dirs=['marchingcubes', numpy.get_include()],
- language='c++')
-
- # min/max
- config.add_extension('combo',
- sources=['combo.pyx'],
- include_dirs=['include'],
- language='c')
-
- config.add_extension('colormap',
- sources=["colormap.pyx"],
- language='c',
- include_dirs=['include', numpy.get_include()],
- extra_link_args=['-fopenmp'],
- extra_compile_args=['-fopenmp'])
-
- config.add_extension('interpolate',
- sources=["interpolate.pyx"],
- language='c',
- include_dirs=['include', numpy.get_include()],
- extra_link_args=['-fopenmp'],
- extra_compile_args=['-fopenmp'])
-
- return config
-
-
-if __name__ == "__main__":
- from numpy.distutils.core import setup
-
- setup(configuration=configuration)
diff --git a/silx/math/test/__init__.py b/silx/math/test/__init__.py
deleted file mode 100644
index e9f29f3..0000000
--- a/silx/math/test/__init__.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-
-__authors__ = ["D. Naudet"]
-__license__ = "MIT"
-__date__ = "04/07/2016"
-
-import unittest
-
-from .test_histogramnd_error import suite as test_histo_error
-from .test_histogramnd_nominal import suite as test_histo_nominal
-from .test_histogramnd_vs_np import suite as test_histo_vs_np
-from .test_HistogramndLut_nominal import suite as test_histolut_nominal
-from ..fit.test import suite as test_fit_suite
-from .test_marchingcubes import suite as test_marchingcubes_suite
-from ..medianfilter.test import suite as test_medianfilter_suite
-from .test_combo import suite as test_combo_suite
-from .test_calibration import suite as test_calibration_suite
-from .test_colormap import suite as test_colormap_suite
-from .test_interpolate import suite as test_interpolate_suite
-from ..fft.test import suite as test_fft_suite
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(test_histo_nominal())
- test_suite.addTest(test_histo_error())
- test_suite.addTest(test_histo_vs_np())
- test_suite.addTest(test_fit_suite())
- test_suite.addTest(test_histolut_nominal())
- test_suite.addTest(test_marchingcubes_suite())
- test_suite.addTest(test_medianfilter_suite())
- test_suite.addTest(test_combo_suite())
- test_suite.addTest(test_calibration_suite())
- test_suite.addTest(test_colormap_suite())
- test_suite.addTest(test_interpolate_suite())
- test_suite.addTest(test_fft_suite())
- return test_suite
diff --git a/silx/math/test/benchmark_combo.py b/silx/math/test/benchmark_combo.py
deleted file mode 100644
index e179f76..0000000
--- a/silx/math/test/benchmark_combo.py
+++ /dev/null
@@ -1,203 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Benchmarks of the combo module"""
-
-from __future__ import division
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import logging
-import os.path
-import time
-import unittest
-
-import numpy
-
-from silx.test.utils import temp_dir
-from silx.utils.testutils import ParametricTestCase
-
-from silx.math import combo
-
-_logger = logging.getLogger(__name__)
-_logger.setLevel(logging.DEBUG)
-
-
-class BenchmarkMinMax(ParametricTestCase):
- """Benchmark of min max combo"""
-
- DTYPES = ('float32', 'float64',
- 'int8', 'int16', 'int32', 'int64',
- 'uint8', 'uint16', 'uint32', 'uint64')
-
- ARANGE = 'ascent', 'descent', 'random'
-
- EXPONENT = 3, 4, 5, 6, 7
-
- def test_benchmark_min_max(self):
- """Benchmark min_max without min positive.
-
- Compares with:
-
- - numpy.nanmin, numpy.nanmax and
- - numpy.argmin, numpy.argmax
-
- It runs bench for different types, different data size and 3
- data sets: increasing , decreasing and random data.
- """
- durations = {'min/max': [], 'argmin/max': [], 'combo': []}
-
- _logger.info('Benchmark against argmin/argmax and nanmin/nanmax')
-
- for dtype in self.DTYPES:
- for arange in self.ARANGE:
- for exponent in self.EXPONENT:
- size = 10**exponent
- with self.subTest(dtype=dtype, size=size, arange=arange):
- if arange == 'ascent':
- data = numpy.arange(0, size, 1, dtype=dtype)
- elif arange == 'descent':
- data = numpy.arange(size, 0, -1, dtype=dtype)
- else:
- if dtype in ('float32', 'float64'):
- data = numpy.random.random(size)
- else:
- data = numpy.random.randint(10**6, size=size)
- data = numpy.array(data, dtype=dtype)
-
- start = time.time()
- ref_min = numpy.nanmin(data)
- ref_max = numpy.nanmax(data)
- durations['min/max'].append(time.time() - start)
-
- start = time.time()
- ref_argmin = numpy.argmin(data)
- ref_argmax = numpy.argmax(data)
- durations['argmin/max'].append(time.time() - start)
-
- start = time.time()
- result = combo.min_max(data, min_positive=False)
- durations['combo'].append(time.time() - start)
-
- _logger.info(
- '%s-%s-10**%d\tx%.2f argmin/max x%.2f min/max',
- dtype, arange, exponent,
- durations['argmin/max'][-1] / durations['combo'][-1],
- durations['min/max'][-1] / durations['combo'][-1])
-
- self.assertEqual(result.minimum, ref_min)
- self.assertEqual(result.maximum, ref_max)
- self.assertEqual(result.argmin, ref_argmin)
- self.assertEqual(result.argmax, ref_argmax)
-
- self.show_results('min/max', durations, 'combo')
-
- def test_benchmark_min_pos(self):
- """Benchmark min_max wit min positive.
-
- Compares with:
-
- - numpy.nanmin(data[data > 0]); numpy.nanmin(pos); numpy.nanmax(pos)
-
- It runs bench for different types, different data size and 3
- data sets: increasing , decreasing and random data.
- """
- durations = {'min/max': [], 'combo': []}
-
- _logger.info('Benchmark against min, max, positive min')
-
- for dtype in self.DTYPES:
- for arange in self.ARANGE:
- for exponent in self.EXPONENT:
- size = 10**exponent
- with self.subTest(dtype=dtype, size=size, arange=arange):
- if arange == 'ascent':
- data = numpy.arange(0, size, 1, dtype=dtype)
- elif arange == 'descent':
- data = numpy.arange(size, 0, -1, dtype=dtype)
- else:
- if dtype in ('float32', 'float64'):
- data = numpy.random.random(size)
- else:
- data = numpy.random.randint(10**6, size=size)
- data = numpy.array(data, dtype=dtype)
-
- start = time.time()
- ref_min_positive = numpy.nanmin(data[data > 0])
- ref_min = numpy.nanmin(data)
- ref_max = numpy.nanmax(data)
- durations['min/max'].append(time.time() - start)
-
- start = time.time()
- result = combo.min_max(data, min_positive=True)
- durations['combo'].append(time.time() - start)
-
- _logger.info(
- '%s-%s-10**%d\tx%.2f min/minpos/max',
- dtype, arange, exponent,
- durations['min/max'][-1] / durations['combo'][-1])
-
- self.assertEqual(result.min_positive, ref_min_positive)
- self.assertEqual(result.minimum, ref_min)
- self.assertEqual(result.maximum, ref_max)
-
- self.show_results('min/max/min positive', durations, 'combo')
-
- def show_results(self, title, durations, ref_key):
- try:
- from matplotlib import pyplot
- except ImportError:
- _logger.warning('matplotlib not available')
- return
-
- pyplot.title(title)
- pyplot.xlabel('-'.join(self.DTYPES))
- pyplot.ylabel('duration (sec)')
- for label, values in durations.items():
- pyplot.semilogy(values, label=label)
- pyplot.legend()
- pyplot.show()
-
- pyplot.title(title)
- pyplot.xlabel('-'.join(self.DTYPES))
- pyplot.ylabel('Duration ratio')
- ref = numpy.array(durations[ref_key])
- for label, values in durations.items():
- values = numpy.array(values)
- pyplot.plot(values/ref, label=label + ' / ' + ref_key)
- pyplot.legend()
- pyplot.show()
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTests(
- unittest.defaultTestLoader.loadTestsFromTestCase(BenchmarkMinMax))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/test/test_HistogramndLut_nominal.py b/silx/math/test/test_HistogramndLut_nominal.py
deleted file mode 100644
index 08ca682..0000000
--- a/silx/math/test/test_HistogramndLut_nominal.py
+++ /dev/null
@@ -1,587 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-Nominal tests of the HistogramndLut function.
-"""
-
-import unittest
-
-import numpy as np
-
-from silx.math import HistogramndLut
-
-
-def _get_bin_edges(histo_range, n_bins, n_dims):
- edges = []
- for i_dim in range(n_dims):
- edges.append(histo_range[i_dim, 0] +
- np.arange(n_bins[i_dim] + 1) *
- (histo_range[i_dim, 1] - histo_range[i_dim, 0]) /
- n_bins[i_dim])
- return tuple(edges)
-
-
-# ==============================================================
-# ==============================================================
-# ==============================================================
-
-
-class _TestHistogramndLut_nominal(unittest.TestCase):
- """
- Unit tests of the HistogramndLut class.
- """
-
- ndims = None
-
- def setUp(self):
- ndims = self.ndims
- self.tested_dim = ndims-1
-
- if ndims is None:
- raise ValueError('ndims class member not set.')
-
- sample = np.array([5.5, -3.3,
- 0., -0.5,
- 3.3, 8.8,
- -7.7, 6.0,
- -4.0])
-
- weights = np.array([500.5, -300.3,
- 0.01, -0.5,
- 300.3, 800.8,
- -700.7, 600.6,
- -400.4])
-
- n_elems = len(sample)
-
- if ndims == 1:
- shape = (n_elems,)
- else:
- shape = (n_elems, ndims)
-
- self.sample = np.zeros(shape=shape, dtype=sample.dtype)
- if ndims == 1:
- self.sample = sample
- else:
- self.sample[..., ndims-1] = sample
-
- self.weights = weights
-
- # the tests are performed along one dimension,
- # all the other bins indices along the other dimensions
- # are expected to be 2
- # (e.g : when testing a 2D sample : [0, x] will go into
- # bin [2, y] because of the bin ranges [-2, 2] and n_bins = 4
- # for the first dimension)
- self.other_axes_index = 2
- self.histo_range = np.repeat([[-2., 2.]], ndims, axis=0)
- self.histo_range[ndims-1] = [-4., 6.]
-
- self.n_bins = np.array([4]*ndims)
- self.n_bins[ndims-1] = 5
-
- if ndims == 1:
- def fill_histo(h, v, dim, op=None):
- if op:
- h[:] = op(h[:], v)
- else:
- h[:] = v
- self.fill_histo = fill_histo
- else:
- def fill_histo(h, v, dim, op=None):
- idx = [self.other_axes_index]*len(h.shape)
- idx[dim] = slice(0, None)
- idx = tuple(idx)
- if op:
- h[idx] = op(h[idx], v)
- else:
- h[idx] = v
- self.fill_histo = fill_histo
-
- def test_nominal_bin_edges(self):
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- bin_edges = instance.bins_edges
-
- expected_edges = _get_bin_edges(self.histo_range,
- self.n_bins,
- self.ndims)
-
- for i_edges, edges in enumerate(expected_edges):
- self.assertTrue(np.array_equal(bin_edges[i_edges],
- expected_edges[i_edges]),
- msg='Testing bin_edges for dim {0}'
- ''.format(i_edges+1))
-
- def test_nominal_histo_range(self):
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- histo_range = instance.histo_range
-
- self.assertTrue(np.array_equal(histo_range, self.histo_range))
-
- def test_nominal_last_bin_closed(self):
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- last_bin_closed = instance.last_bin_closed
-
- self.assertEqual(last_bin_closed, False)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- last_bin_closed=True)
-
- last_bin_closed = instance.last_bin_closed
-
- self.assertEqual(last_bin_closed, True)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- last_bin_closed=False)
-
- last_bin_closed = instance.last_bin_closed
-
- self.assertEqual(last_bin_closed, False)
-
- def test_nominal_n_bins_array(self):
-
- test_n_bins = np.arange(self.ndims) + 10
- instance = HistogramndLut(self.sample,
- self.histo_range,
- test_n_bins)
-
- n_bins = instance.n_bins
-
- self.assertTrue(np.array_equal(test_n_bins, n_bins))
-
- def test_nominal_n_bins_scalar(self):
-
- test_n_bins = 10
- expected_n_bins = np.array([test_n_bins] * self.ndims)
- instance = HistogramndLut(self.sample,
- self.histo_range,
- test_n_bins)
-
- n_bins = instance.n_bins
-
- self.assertTrue(np.array_equal(expected_n_bins, n_bins))
-
- def test_nominal_histo_ref(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- instance.accumulate(self.weights)
-
- histo = instance.histo()
- w_histo = instance.weighted_histo()
- histo_ref = instance.histo(copy=False)
- w_histo_ref = instance.weighted_histo(copy=False)
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
- self.assertTrue(np.array_equal(histo_ref, expected_h))
- self.assertTrue(np.array_equal(w_histo_ref, expected_c))
-
- histo_ref[0, ...] = histo_ref[0, ...] + 10
- w_histo_ref[0, ...] = w_histo_ref[0, ...] + 20
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
- self.assertFalse(np.array_equal(histo_ref, expected_h))
- self.assertFalse(np.array_equal(w_histo_ref, expected_c))
-
- histo_2 = instance.histo()
- w_histo_2 = instance.weighted_histo()
-
- self.assertFalse(np.array_equal(histo_2, expected_h))
- self.assertFalse(np.array_equal(w_histo_2, expected_c))
- self.assertTrue(np.array_equal(histo_2, histo_ref))
- self.assertTrue(np.array_equal(w_histo_2, w_histo_ref))
-
- def test_nominal_accumulate_once(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- instance.accumulate(self.weights)
-
- histo = instance.histo()
- w_histo = instance.weighted_histo()
-
- self.assertEqual(w_histo.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
- self.assertTrue(np.array_equal(instance.histo(), expected_h))
- self.assertTrue(np.array_equal(instance.weighted_histo(),
- expected_c))
-
- def test_nominal_accumulate_twice(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- # calling accumulate twice
- expected_h *= 2
- expected_c *= 2
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- instance.accumulate(self.weights)
-
- instance.accumulate(self.weights)
-
- histo = instance.histo()
- w_histo = instance.weighted_histo()
-
- self.assertEqual(w_histo.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
- self.assertTrue(np.array_equal(instance.histo(), expected_h))
- self.assertTrue(np.array_equal(instance.weighted_histo(),
- expected_c))
-
- def test_nominal_apply_lut_once(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- histo, w_histo = instance.apply_lut(self.weights)
-
- self.assertEqual(w_histo.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
- self.assertEqual(instance.histo(), None)
- self.assertEqual(instance.weighted_histo(), None)
-
- def test_nominal_apply_lut_twice(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- # calling apply_lut twice
- expected_h *= 2
- expected_c *= 2
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- histo, w_histo = instance.apply_lut(self.weights)
- histo_2, w_histo_2 = instance.apply_lut(self.weights,
- histo=histo,
- weighted_histo=w_histo)
-
- self.assertEqual(id(histo), id(histo_2))
- self.assertEqual(id(w_histo), id(w_histo_2))
- self.assertEqual(w_histo.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
- self.assertEqual(instance.histo(), None)
- self.assertEqual(instance.weighted_histo(), None)
-
- def test_nominal_accumulate_last_bin_closed(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 2])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 1101.1])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- last_bin_closed=True)
-
- instance.accumulate(self.weights)
-
- histo = instance.histo()
- w_histo = instance.weighted_histo()
-
- self.assertEqual(w_histo.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
-
- def test_nominal_accumulate_weight_min_max(self):
- """
- """
- weight_min = -299.9
- weight_max = 499.9
-
- expected_h_tpl = np.array([0, 1, 1, 1, 0])
- expected_c_tpl = np.array([0., -0.5, 0.01, 300.3, 0.])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- instance.accumulate(self.weights,
- weight_min=weight_min,
- weight_max=weight_max)
-
- histo = instance.histo()
- w_histo = instance.weighted_histo()
-
- self.assertEqual(w_histo.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
-
- def test_nominal_accumulate_forced_int32(self):
- """
- double weights, int32 weighted_histogram
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700, 0, 0, 300, 500])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- dtype=np.int32)
-
- instance.accumulate(self.weights)
-
- histo = instance.histo()
- w_histo = instance.weighted_histo()
-
- self.assertEqual(w_histo.dtype, np.int32)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
-
- def test_nominal_accumulate_forced_float32(self):
- """
- int32 weights, float32 weighted_histogram
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700., 0., 0., 300., 500.])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.float32)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins,
- dtype=np.float32)
-
- instance.accumulate(self.weights.astype(np.int32))
-
- histo = instance.histo()
- w_histo = instance.weighted_histo()
-
- self.assertEqual(w_histo.dtype, np.float32)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
-
- def test_nominal_accumulate_int32(self):
- """
- int32 weights
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700, 0, 0, 300, 500])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.int32)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- instance.accumulate(self.weights.astype(np.int32))
-
- histo = instance.histo()
- w_histo = instance.weighted_histo()
-
- self.assertEqual(w_histo.dtype, np.int32)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
-
- def test_nominal_accumulate_int32_double(self):
- """
- int32 weights
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700, 0, 0, 300, 500])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.int32)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- instance = HistogramndLut(self.sample,
- self.histo_range,
- self.n_bins)
-
- instance.accumulate(self.weights.astype(np.int32))
- instance.accumulate(self.weights)
-
- histo = instance.histo()
- w_histo = instance.weighted_histo()
-
- expected_h *= 2
- expected_c *= 2
-
- self.assertEqual(w_histo.dtype, np.int32)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(w_histo, expected_c))
-
- def testNoneNativeTypes(self):
- type = self.sample.dtype.newbyteorder("B")
- sampleB = self.sample.astype(type)
-
- type = self.sample.dtype.newbyteorder("L")
- sampleL = self.sample.astype(type)
-
- histo_inst = HistogramndLut(sampleB,
- self.histo_range,
- self.n_bins)
-
- histo_inst = HistogramndLut(sampleL,
- self.histo_range,
- self.n_bins)
-
-
-class TestHistogramndLut_nominal_1d(_TestHistogramndLut_nominal):
- ndims = 1
-
-
-class TestHistogramndLut_nominal_2d(_TestHistogramndLut_nominal):
- ndims = 2
-
-
-class TestHistogramndLut_nominal_3d(_TestHistogramndLut_nominal):
- ndims = 3
-
-
-# ==============================================================
-# ==============================================================
-# ==============================================================
-
-
-test_cases = (TestHistogramndLut_nominal_1d,
- TestHistogramndLut_nominal_2d,
- TestHistogramndLut_nominal_3d,)
-
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/test/test_calibration.py b/silx/math/test/test_calibration.py
deleted file mode 100644
index 5a0c20e..0000000
--- a/silx/math/test/test_calibration.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests of the calibration module"""
-
-from __future__ import division
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "14/05/2018"
-
-
-import unittest
-
-import numpy
-
-from silx.math.calibration import NoCalibration, LinearCalibration, \
- ArrayCalibration, FunctionCalibration
-
-
-X = numpy.array([3.14, 2.73, 1337])
-
-
-class TestNoCalibration(unittest.TestCase):
- def setUp(self):
- self.calib = NoCalibration()
-
- def testIsAffine(self):
- self.assertTrue(self.calib.is_affine())
-
- def testSlope(self):
- self.assertEqual(self.calib.get_slope(), 1.)
-
- def testYIntercept(self):
- self.assertEqual(self.calib(0.),
- 0.)
-
- def testCall(self):
- self.assertTrue(numpy.array_equal(self.calib(X), X))
-
-
-class TestLinearCalibration(unittest.TestCase):
- def setUp(self):
- self.y_intercept = 1.5
- self.slope = 2.5
- self.calib = LinearCalibration(y_intercept=self.y_intercept,
- slope=self.slope)
-
- def testIsAffine(self):
- self.assertTrue(self.calib.is_affine())
-
- def testSlope(self):
- self.assertEqual(self.calib.get_slope(), self.slope)
-
- def testYIntercept(self):
- self.assertEqual(self.calib(0.),
- self.y_intercept)
-
- def testCall(self):
- self.assertTrue(numpy.array_equal(self.calib(X),
- self.y_intercept + self.slope * X))
-
-
-class TestArrayCalibration(unittest.TestCase):
- def setUp(self):
- self.arr = numpy.array([45.2, 25.3, 666., -8.])
- self.calib = ArrayCalibration(self.arr)
- self.affine_calib = ArrayCalibration([0.1, 0.2, 0.3])
-
- def testIsAffine(self):
- self.assertFalse(self.calib.is_affine())
- self.assertTrue(self.affine_calib.is_affine())
-
- def testSlope(self):
- with self.assertRaises(AttributeError):
- self.calib.get_slope()
- self.assertEqual(self.affine_calib.get_slope(),
- 0.1)
-
- def testYIntercept(self):
- self.assertEqual(self.calib(0),
- self.arr[0])
-
- def testCall(self):
- with self.assertRaises(ValueError):
- # X is an array with a different shape
- self.calib(X)
-
- with self.assertRaises(ValueError):
- # floats are not valid indices
- self.calib(3.14)
-
- self.assertTrue(
- numpy.array_equal(self.calib([1, 2, 3, 4]),
- self.arr))
-
- for idx, value in enumerate(self.arr):
- self.assertEqual(self.calib(idx), value)
-
-
-class TestFunctionCalibration(unittest.TestCase):
- def setUp(self):
- self.non_affine_fun = numpy.sin
- self.non_affine_calib = FunctionCalibration(self.non_affine_fun)
-
- self.affine_fun = lambda x: 52. * x + 0.01
- self.affine_calib = FunctionCalibration(self.affine_fun,
- is_affine=True)
-
- def testIsAffine(self):
- self.assertFalse(self.non_affine_calib.is_affine())
- self.assertTrue(self.affine_calib.is_affine())
-
- def testSlope(self):
- with self.assertRaises(AttributeError):
- self.non_affine_calib.get_slope()
- self.assertAlmostEqual(self.affine_calib.get_slope(),
- 52.)
-
- def testCall(self):
- for x in X:
- self.assertAlmostEqual(self.non_affine_calib(x),
- self.non_affine_fun(x))
- self.assertAlmostEqual(self.affine_calib(x),
- self.affine_fun(x))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestNoCalibration))
- test_suite.addTest(loadTests(TestArrayCalibration))
- test_suite.addTest(loadTests(TestLinearCalibration))
- test_suite.addTest(loadTests(TestFunctionCalibration))
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/test/test_colormap.py b/silx/math/test/test_colormap.py
deleted file mode 100644
index 4e80710..0000000
--- a/silx/math/test/test_colormap.py
+++ /dev/null
@@ -1,266 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Test for colormap mapping implementation"""
-
-from __future__ import division
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "16/05/2018"
-
-
-import logging
-import sys
-import unittest
-
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-from silx.math import colormap
-
-
-_logger = logging.getLogger(__name__)
-
-
-class TestNormalization(ParametricTestCase):
- """Test silx.math.colormap.Normalization sub classes"""
-
- def _testCodec(self, normalization, rtol=1e-5):
- """Test apply/revert for normalizations"""
- test_data = (numpy.arange(1, 10, dtype=numpy.int32),
- numpy.linspace(1., 100., 1000, dtype=numpy.float32),
- numpy.linspace(-1., 1., 100, dtype=numpy.float32),
- 1.,
- 1)
-
- for index in range(len(test_data)):
- with self.subTest(normalization=normalization, data_index=index):
- data = test_data[index]
- normalized = normalization.apply(data, 1., 100.)
- result = normalization.revert(normalized, 1., 100.)
-
- self.assertTrue(numpy.array_equal(
- numpy.isnan(normalized), numpy.isnan(result)))
-
- if isinstance(data, numpy.ndarray):
- notNaN = numpy.logical_not(numpy.isnan(result))
- data = data[notNaN]
- result = result[notNaN]
- self.assertTrue(numpy.allclose(data, result, rtol=rtol))
-
- def testLinearNormalization(self):
- """Test for LinearNormalization"""
- normalization = colormap.LinearNormalization()
- self._testCodec(normalization)
-
- def testLogarithmicNormalization(self):
- """Test for LogarithmicNormalization"""
- normalization = colormap.LogarithmicNormalization()
- # relative tolerance is higher because of the log approximation
- self._testCodec(normalization, rtol=1e-3)
-
- # Specific extra tests
- self.assertTrue(numpy.isnan(normalization.apply(-1., 1., 100.)))
- self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 1., 100.)))
- self.assertEqual(normalization.apply(numpy.inf, 1., 100.), numpy.inf)
- self.assertEqual(normalization.apply(0, 1., 100.), - numpy.inf)
-
- def testArcsinhNormalization(self):
- """Test for ArcsinhNormalization"""
- self._testCodec(colormap.ArcsinhNormalization())
-
- def testSqrtNormalization(self):
- """Test for SqrtNormalization"""
- normalization = colormap.SqrtNormalization()
- self._testCodec(normalization)
-
- # Specific extra tests
- self.assertTrue(numpy.isnan(normalization.apply(-1., 0., 100.)))
- self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 0., 100.)))
- self.assertEqual(normalization.apply(numpy.inf, 0., 100.), numpy.inf)
- self.assertEqual(normalization.apply(0, 0., 100.), 0.)
-
-
-class TestColormap(ParametricTestCase):
- """Test silx.math.colormap.cmap"""
-
- NORMALIZATIONS = (
- 'linear',
- 'log',
- 'arcsinh',
- 'sqrt',
- colormap.LinearNormalization(),
- colormap.LogarithmicNormalization(),
- colormap.PowerNormalization(2.),
- colormap.PowerNormalization(0.5))
-
- @staticmethod
- def ref_colormap(data, colors, vmin, vmax, normalization, nan_color):
- """Reference implementation of colormap
-
- :param numpy.ndarray data: Data to convert
- :param numpy.ndarray colors: Color look-up-table
- :param float vmin: Lower bound of the colormap range
- :param float vmax: Upper bound of the colormap range
- :param str normalization: Normalization to use
- :param Union[numpy.ndarray, None] nan_color: Color to use for NaN
- """
- norm_functions = {'linear': lambda v: v,
- 'log': numpy.log10,
- 'arcsinh': numpy.arcsinh,
- 'sqrt': numpy.sqrt}
-
- if isinstance(normalization, str):
- norm_function = norm_functions[normalization]
- else:
- def norm_function(value):
- return normalization.apply(value, vmin, vmax)
-
- with numpy.errstate(divide='ignore', invalid='ignore'):
- # Ignore divide by zero and invalid value encountered in log10, sqrt
- norm_data, vmin, vmax = map(norm_function, (data, vmin, vmax))
-
- if normalization == 'arcsinh' and sys.platform == 'win32':
- # There is a difference of behavior of numpy.arcsinh
- # between Windows and other OS for results of infinite values
- # This makes Windows behaves as Linux and MacOS
- norm_data[data == numpy.inf] = numpy.inf
- norm_data[data == -numpy.inf] = -numpy.inf
-
- nb_colors = len(colors)
- scale = nb_colors / (vmax - vmin)
-
- # Substraction must be done in float to avoid overflow with uint
- indices = numpy.clip(scale * (norm_data - float(vmin)),
- 0, nb_colors - 1)
- indices[numpy.isnan(indices)] = nb_colors # Use an extra index for NaN
- indices = indices.astype('uint')
-
- # Add NaN color to array
- if nan_color is None:
- nan_color = (0,) * colors.shape[-1]
- colors = numpy.append(colors, numpy.atleast_2d(nan_color), axis=0)
-
- return colors[indices]
-
- def _test(self, data, colors, vmin, vmax, normalization, nan_color):
- """Run test of colormap against alternative implementation
-
- :param numpy.ndarray data: Data to convert
- :param numpy.ndarray colors: Color look-up-table
- :param float vmin: Lower bound of the colormap range
- :param float vmax: Upper bound of the colormap range
- :param str normalization: Normalization to use
- :param Union[numpy.ndarray, None] nan_color: Color to use for NaN
- """
- image = colormap.cmap(
- data, colors, vmin, vmax, normalization, nan_color)
-
- ref_image = self.ref_colormap(
- data, colors, vmin, vmax, normalization, nan_color)
-
- self.assertTrue(numpy.allclose(ref_image, image))
- self.assertEqual(image.dtype, colors.dtype)
- self.assertEqual(image.shape, data.shape + (colors.shape[-1],))
-
- def test(self):
- """Test all dtypes with finite data
-
- Test all supported types and endianness
- """
- colors = numpy.zeros((256, 4), dtype=numpy.uint8)
- colors[:, 0] = numpy.arange(len(colors))
- colors[:, 3] = 255
-
- # Generates (u)int and floats types
- dtypes = [e + k + i for e in '<>' for k in 'uif' for i in '1248'
- if k != 'f' or i != '1']
- dtypes.append(numpy.dtype(numpy.longdouble).name) # Add long double
-
- for normalization in self.NORMALIZATIONS:
- for dtype in dtypes:
- with self.subTest(dtype=dtype, normalization=normalization):
- _logger.info('normalization: %s, dtype: %s',
- normalization, dtype)
- data = numpy.arange(-5, 15, dtype=dtype).reshape(4, 5)
-
- self._test(data, colors, 1, 10, normalization, None)
-
- def test_not_finite(self):
- """Test float data with not finite values"""
- colors = numpy.zeros((256, 4), dtype=numpy.uint8)
- colors[:, 0] = numpy.arange(len(colors))
- colors[:, 3] = 255
-
- test_data = { # message: data
- 'no finite values': (float('inf'), float('-inf'), float('nan')),
- 'only NaN': (float('nan'), float('nan'), float('nan')),
- 'mix finite/not finite': (float('inf'), float('-inf'), 1., float('nan')),
- }
-
- for normalization in self.NORMALIZATIONS:
- for msg, data in test_data.items():
- with self.subTest(msg, normalization=normalization):
- _logger.info('normalization: %s, %s', normalization, msg)
- data = numpy.array(data, dtype=numpy.float64)
- self._test(data, colors, 1, 10, normalization, (0, 0, 0, 0))
-
- def test_errors(self):
- """Test raising exception for bad vmin, vmax, normalization parameters
- """
- colors = numpy.zeros((256, 4), dtype=numpy.uint8)
- colors[:, 0] = numpy.arange(len(colors))
- colors[:, 3] = 255
-
- data = numpy.arange(10, dtype=numpy.float64)
-
- test_params = [ # (vmin, vmax, normalization)
- (-1., 2., 'log'),
- (0., 1., 'log'),
- (1., 0., 'log'),
- (-1., 1., 'sqrt'),
- (1., -1., 'sqrt'),
- ]
-
- for vmin, vmax, normalization in test_params:
- with self.subTest(
- vmin=vmin, vmax=vmax, normalization=normalization):
- _logger.info('normalization: %s, range: [%f, %f]',
- normalization, vmin, vmax)
- with self.assertRaises(ValueError):
- self._test(data, colors, vmin, vmax, normalization, None)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestColormap))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestNormalization))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/math/test/test_combo.py b/silx/math/test/test_combo.py
deleted file mode 100644
index 1335763..0000000
--- a/silx/math/test/test_combo.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests of the combo module"""
-
-from __future__ import division
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import unittest
-
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-
-from silx.math.combo import min_max
-
-
-class TestMinMax(ParametricTestCase):
- """Tests of min max combo"""
-
- FLOATING_DTYPES = 'float32', 'float64'
- if hasattr(numpy, "float128"):
- FLOATING_DTYPES += ('float128',)
- SIGNED_INT_DTYPES = 'int8', 'int16', 'int32', 'int64'
- UNSIGNED_INT_DTYPES = 'uint8', 'uint16', 'uint32', 'uint64'
- DTYPES = FLOATING_DTYPES + SIGNED_INT_DTYPES + UNSIGNED_INT_DTYPES
-
- def _numpy_min_max(self, data, min_positive=False, finite=False):
- """Reference numpy implementation of min_max
-
- :param numpy.ndarray data: Data set to use for test
- :param bool min_positive: True to test with positive min
- :param bool finite: True to only test finite values
- """
- data = numpy.array(data, copy=False)
- if data.size == 0:
- raise ValueError('Zero-sized array')
-
- minimum = None
- argmin = None
- maximum = None
- argmax = None
- min_pos = None
- argmin_pos = None
-
- if finite:
- filtered_data = data[numpy.isfinite(data)]
- else:
- filtered_data = data
-
- if filtered_data.size > 0:
- if numpy.all(numpy.isnan(filtered_data)):
- minimum = numpy.nan
- argmin = 0
- maximum = numpy.nan
- argmax = 0
- else:
- minimum = numpy.nanmin(filtered_data)
- # nanargmin equivalent
- argmin = numpy.where(data == minimum)[0][0]
- maximum = numpy.nanmax(filtered_data)
- # nanargmax equivalent
- argmax = numpy.where(data == maximum)[0][0]
-
- if min_positive:
- with numpy.errstate(invalid='ignore'):
- # Ignore invalid value encountered in greater
- pos_data = filtered_data[filtered_data > 0]
- if pos_data.size > 0:
- min_pos = numpy.min(pos_data)
- argmin_pos = numpy.where(data == min_pos)[0][0]
-
- return minimum, min_pos, maximum, argmin, argmin_pos, argmax
-
- def _test_min_max(self, data, min_positive, finite=False):
- """Compare min_max with numpy for the given dataset
-
- :param numpy.ndarray data: Data set to use for test
- :param bool min_positive: True to test with positive min
- :param bool finite: True to only test finite values
- """
- minimum, min_pos, maximum, argmin, argmin_pos, argmax = \
- self._numpy_min_max(data, min_positive, finite)
-
- result = min_max(data, min_positive, finite)
-
- self.assertSimilar(minimum, result.minimum)
- self.assertSimilar(min_pos, result.min_positive)
- self.assertSimilar(maximum, result.maximum)
- self.assertSimilar(argmin, result.argmin)
- self.assertSimilar(argmin_pos, result.argmin_positive)
- self.assertSimilar(argmax, result.argmax)
-
- def assertSimilar(self, a, b):
- """Assert that a and b are both None or NaN or that a == b."""
- self.assertTrue((a is None and b is None) or
- (numpy.isnan(a) and numpy.isnan(b)) or
- a == b)
-
- def test_different_datasets(self):
- """Test min_max with different numpy.arange datasets."""
- size = 1000
-
- for dtype in self.DTYPES:
-
- tests = {
- '0 to N': (0, 1),
- 'N-1 to 0': (size - 1, -1)}
- if dtype not in self.UNSIGNED_INT_DTYPES:
- tests['N/2 to -N/2'] = size // 2, -1
- tests['0 to -N'] = 0, -1
-
- for name, (start, step) in tests.items():
- for min_positive in (True, False):
- with self.subTest(dtype=dtype,
- min_positive=min_positive,
- data=name):
- data = numpy.arange(
- start, start + step * size, step, dtype=dtype)
-
- self._test_min_max(data, min_positive)
-
- def test_nodata(self):
- """Test min_max with None and empty array"""
- for dtype in self.DTYPES:
- with self.subTest(dtype=dtype):
- with self.assertRaises(TypeError):
- min_max(None)
-
- data = numpy.array((), dtype=dtype)
- with self.assertRaises(ValueError):
- min_max(data)
-
- NAN_TEST_DATA = [
- (float('nan'), float('nan')), # All NaNs
- (float('nan'), 1.0), # NaN first and positive
- (float('nan'), -1.0), # NaN first and negative
- (1.0, 2.0, float('nan')), # NaN last and positive
- (-1.0, -2.0, float('nan')), # NaN last and negative
- (1.0, float('nan'), -1.0), # Some NaN
- ]
-
- def test_nandata(self):
- """Test min_max with NaN in data"""
- for dtype in self.FLOATING_DTYPES:
- for data in self.NAN_TEST_DATA:
- with self.subTest(dtype=dtype, data=data):
- data = numpy.array(data, dtype=dtype)
- self._test_min_max(data, min_positive=True)
-
- INF_TEST_DATA = [
- [float('inf')] * 3, # All +inf
- [float('-inf')] * 3, # All -inf
- (float('inf'), float('-inf')), # + and - inf
- (float('inf'), float('-inf'), float('nan')), # +/-inf, nan last
- (float('nan'), float('-inf'), float('inf')), # +/-inf, nan first
- (float('inf'), float('nan'), float('-inf')), # +/-inf, nan center
- ]
-
- def test_infdata(self):
- """Test min_max with inf."""
- for dtype in self.FLOATING_DTYPES:
- for data in self.INF_TEST_DATA:
- with self.subTest(dtype=dtype, data=data):
- data = numpy.array(data, dtype=dtype)
- self._test_min_max(data, min_positive=True)
-
- def test_finite(self):
- """Test min_max with finite=True"""
- tests = [
- (-1., 2., 0.), # Basic test
- (float('nan'), float('inf'), float('-inf')), # NaN + Inf
- (float('nan'), float('inf'), -2, float('-inf')), # NaN + Inf + 1 value
- (float('inf'), -3, -2), # values + inf
- ]
- tests += self.INF_TEST_DATA
- tests += self.NAN_TEST_DATA
-
- for dtype in self.FLOATING_DTYPES:
- for data in tests:
- with self.subTest(dtype=dtype, data=data):
- data = numpy.array(data, dtype=dtype)
- self._test_min_max(data, min_positive=True, finite=True)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTests(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestMinMax))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/test/test_histogramnd_error.py b/silx/math/test/test_histogramnd_error.py
deleted file mode 100644
index a8c66a0..0000000
--- a/silx/math/test/test_histogramnd_error.py
+++ /dev/null
@@ -1,535 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-
-__authors__ = ["D. Naudet"]
-__license__ = "MIT"
-__date__ = "01/02/2016"
-
-"""
-Tests of the histogramnd function, error cases.
-"""
-import sys
-import platform
-import unittest
-
-import numpy as np
-
-from silx.math.chistogramnd import chistogramnd as histogramnd
-from silx.math import Histogramnd
-
-
-# ==============================================================
-# ==============================================================
-# ==============================================================
-
-
-class _Test_chistogramnd_errors(unittest.TestCase):
- """
- Unit tests of the chistogramnd error cases.
- """
- def setUp(self):
- raise NotImplementedError('')
-
- def test_weights_shape(self):
- """
- """
-
- for err_w_shape in self.err_weights_shapes:
- test_msg = ('Testing invalid weights shape : {0}'
- ''.format(err_w_shape))
-
- err_weights = np.random.randint(0,
- high=10,
- size=err_w_shape)
- err_weights = err_weights.astype(np.double)
-
- ex_str = None
- try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=err_weights)[0:2]
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str,
- '<weights> must be an array whose length '
- 'is equal to the number of samples.')
-
- def test_histo_range_shape(self):
- """
- """
- n_dims = 1 if len(self.s_shape) == 1 else self.s_shape[1]
- expected_txt_tpl = ('<histo_range> error : expected {n_dims} sets '
- 'of lower and upper bin edges, '
- 'got the following instead : {histo_range}. '
- '(provided <sample> contains '
- '{n_dims}D values)')
-
- for err_histo_range in self.err_histo_range_shapes:
- test_msg = ('Testing invalid histo_range shape : {0}'
- ''.format(err_histo_range))
-
- expected_txt = expected_txt_tpl.format(histo_range=err_histo_range,
- n_dims=n_dims)
-
- ex_str = None
- try:
- histo, cumul = histogramnd(self.sample,
- err_histo_range,
- self.n_bins,
- weights=self.weights)[0:2]
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
- def test_nbins_shape(self):
- """
- """
-
- expected_txt = ('n_bins must be either a scalar (same number '
- 'of bins for all dimensions) or '
- 'an array (number of bins for each '
- 'dimension).')
-
- for err_n_bins in self.err_n_bins_shapes:
- test_msg = ('Testing invalid n_bins shape : {0}'
- ''.format(err_n_bins))
-
- ex_str = None
- try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- err_n_bins,
- weights=self.weights)[0:2]
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
- def test_nbins_values(self):
- """
- """
- expected_txt = ('<n_bins> : only positive values allowed.')
-
- for err_n_bins in self.err_n_bins_values:
- test_msg = ('Testing invalid n_bins value : {0}'
- ''.format(err_n_bins))
-
- ex_str = None
- try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- err_n_bins,
- weights=self.weights)[0:2]
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
- def test_histo_shape(self):
- """
- """
- for err_h_shape in self.err_histo_shapes:
-
- # windows & python 2.7 : numpy shapes are long values
- if platform.system() == 'Windows':
- version = (sys.version_info.major, sys.version_info.minor)
- if version <= (2, 7):
- err_h_shape = tuple([long(val) for val in err_h_shape])
-
- test_msg = ('Testing invalid histo shape : {0}'
- ''.format(err_h_shape))
-
- expected_txt = ('Provided <histo> array doesn\'t have '
- 'a shape compatible with <n_bins> '
- ': should be {0} instead of {1}.'
- ''.format(self.h_shape, err_h_shape))
-
- histo = np.zeros(shape=err_h_shape, dtype=np.uint32)
-
- ex_str = None
- try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- histo=histo)[0:2]
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
- def test_histo_dtype(self):
- """
- """
- for err_h_dtype in self.err_histo_dtypes:
- test_msg = ('Testing invalid histo dtype : {0}'
- ''.format(err_h_dtype))
-
- histo = np.zeros(shape=self.h_shape, dtype=err_h_dtype)
-
- expected_txt = ('Provided <histo> array doesn\'t have '
- 'the expected type '
- ': should be {0} instead of {1}.'
- ''.format(np.uint32, histo.dtype))
-
- ex_str = None
- try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- histo=histo)[0:2]
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
- def test_weighted_histo_shape(self):
- """
- """
- # using the same values as histo
- for err_h_shape in self.err_histo_shapes:
-
- # windows & python 2.7 : numpy shapes are long values
- if platform.system() == 'Windows':
- version = (sys.version_info.major, sys.version_info.minor)
- if version <= (2, 7):
- err_h_shape = tuple([long(val) for val in err_h_shape])
-
- test_msg = ('Testing invalid weighted_histo shape : {0}'
- ''.format(err_h_shape))
-
- expected_txt = ('Provided <weighted_histo> array doesn\'t have '
- 'a shape compatible with <n_bins> '
- ': should be {0} instead of {1}.'
- ''.format(self.h_shape, err_h_shape))
-
- cumul = np.zeros(shape=err_h_shape, dtype=np.double)
-
- ex_str = None
- try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- weighted_histo=cumul)[0:2]
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
- def test_cumul_dtype(self):
- """
- """
- # using the same values as histo
- for err_h_dtype in self.err_histo_dtypes:
- test_msg = ('Testing invalid weighted_histo dtype : {0}'
- ''.format(err_h_dtype))
-
- cumul = np.zeros(shape=self.h_shape, dtype=err_h_dtype)
-
- expected_txt = ('Provided <weighted_histo> array doesn\'t have '
- 'the expected type '
- ': should be {0} or {1} instead of {2}.'
- ''.format(np.float64, np.float32, cumul.dtype))
-
- ex_str = None
- try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- weighted_histo=cumul)[0:2]
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
- def test_wh_histo_dtype(self):
- """
- """
- # using the same values as histo
- for err_h_dtype in self.err_histo_dtypes:
- test_msg = ('Testing invalid wh_dtype dtype : {0}'
- ''.format(err_h_dtype))
-
- expected_txt = ('<wh_dtype> type not supported : {0}.'
- ''.format(err_h_dtype))
-
- ex_str = None
- try:
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- wh_dtype=err_h_dtype)[0:2]
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
- def test_unmanaged_dtypes(self):
- """
- """
- for err_unmanaged_dtype in self.err_unmanaged_dtypes:
- test_msg = ('Testing unmanaged dtypes : {0}'
- ''.format(err_unmanaged_dtype))
-
- sample = self.sample.astype(err_unmanaged_dtype[0])
- weights = self.weights.astype(err_unmanaged_dtype[1])
-
- expected_txt = ('Case not supported - sample:{0} '
- 'and weights:{1}.'
- ''.format(sample.dtype,
- weights.dtype))
-
- ex_str = None
- try:
- histogramnd(sample,
- self.histo_range,
- self.n_bins,
- weights=weights)
- except TypeError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str, msg=test_msg)
- self.assertEqual(ex_str, expected_txt, msg=test_msg)
-
- def test_uncontiguous_histo(self):
- """
- """
- # non contiguous array
- shape = np.array(self.n_bins, ndmin=1)
- shape[0] *= 2
- histo_tmp = np.zeros(shape)
- histo = histo_tmp[::2, ...]
-
- expected_txt = ('<histo> must be a C_CONTIGUOUS numpy array.')
-
- ex_str = None
- try:
- histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- histo=histo)
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str)
- self.assertEqual(ex_str, expected_txt)
-
- def test_uncontiguous_weighted_histo(self):
- """
- """
- # non contiguous array
- shape = np.array(self.n_bins, ndmin=1)
- shape[0] *= 2
- cumul_tmp = np.zeros(shape)
- cumul = cumul_tmp[::2, ...]
-
- expected_txt = ('<weighted_histo> must be a C_CONTIGUOUS numpy array.')
-
- ex_str = None
- try:
- histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- weighted_histo=cumul)
- except ValueError as ex:
- ex_str = str(ex)
-
- self.assertIsNotNone(ex_str)
- self.assertEqual(ex_str, expected_txt)
-
-
-class Test_chistogramnd_1D_errors(_Test_chistogramnd_errors):
- """
- Unit tests of the 1D histogramnd error cases.
- """
-
- def setUp(self):
- # nominal values
- self.n_elements = 1000
- self.s_shape = (self.n_elements,)
- self.w_shape = (self.n_elements,)
-
- self.histo_range = [0., 100.]
- self.n_bins = 10
-
- self.h_shape = (self.n_bins,)
-
- self.sample = np.random.randint(0,
- high=10,
- size=self.s_shape)
- self.sample = self.sample.astype(np.double)
-
- self.weights = np.random.randint(0,
- high=10,
- size=self.w_shape)
- self.weights = self.weights.astype(np.double)
-
- self.err_weights_shapes = ((self.n_elements+1,),
- (self.n_elements-1,),
- (self.n_elements-1, 3))
- self.err_histo_range_shapes = ([0.],
- [0., 1., 2.],
- [[0.], [1.]])
- self.err_n_bins_shapes = ([10, 2],
- [[10], [2]])
- self.err_n_bins_values = (0,
- [-10],
- None)
- self.err_histo_shapes = ((self.n_bins+1,),
- (self.n_bins-1,),
- (self.n_bins, self.n_bins))
- # these are used for testing the histo parameter as well
- # as the weighted_histo parameter.
- self.err_histo_dtypes = (np.uint16,
- np.float16)
-
- self.err_unmanaged_dtypes = ((np.double, np.uint16),
- (np.uint16, np.double),
- (np.uint16, np.uint16))
-
-class Test_chistogramnd_ND_range(unittest.TestCase):
- """
-
- """
-
- def test_invalid_histo_range(self):
- data = np.random.random((60, 60))
- nbins = 10
-
- with self.assertRaises(ValueError):
- histo_range = data.min(), np.inf
-
- Histogramnd(sample=data.ravel(),
- histo_range=histo_range,
- n_bins=nbins)
-
- histo_range = data.min(), np.nan
-
- Histogramnd(sample=data.ravel(),
- histo_range=histo_range,
- n_bins=nbins)
-
-
-class Test_chistogramnd_ND_errors(_Test_chistogramnd_errors):
- """
- Unit tests of the 3D histogramnd error cases.
- """
-
- def setUp(self):
- # nominal values
- self.n_elements = 1000
- self.s_shape = (self.n_elements, 3)
- self.w_shape = (self.n_elements,)
-
- self.histo_range = [[0., 100.], [0., 100.], [0., 100.]]
- self.n_bins = (10, 20, 30)
-
- self.h_shape = self.n_bins
-
- self.sample = np.random.randint(0,
- high=10,
- size=self.s_shape)
- self.sample = self.sample.astype(np.double)
-
- self.weights = np.random.randint(0,
- high=10,
- size=self.w_shape)
- self.weights = self.weights.astype(np.double)
-
- self.err_weights_shapes = ((self.n_elements+1,),
- (self.n_elements-1,),
- (self.n_elements-1, 3))
- self.err_histo_range_shapes = ([0.],
- [0., 1.],
- [[0., 10.], [0., 10.]],
- [0., 10., 0, 10., 0, 10.])
- self.err_n_bins_shapes = ([10, 2],
- [[10], [20], [30]])
- self.err_n_bins_values = (0,
- [-10],
- [10, 20, -4],
- None,
- [10, None, 30])
- self.err_histo_shapes = ((self.n_bins[0]+1,
- self.n_bins[1],
- self.n_bins[2]),
- (self.n_bins[0],
- self.n_bins[1],
- self.n_bins[2]-1),
- (self.n_bins[0],
- self.n_bins[1]),
- (self.n_bins[1],
- self.n_bins[0],
- self.n_bins[2]),
- (self.n_bins[0],
- self.n_bins[1],
- self.n_bins[2],
- 10)
- )
- # these are used for testing the histo parameter as well
- # as the weighted_histo parameter.
- self.err_histo_dtypes = (np.uint16,
- np.float16)
-
- self.err_unmanaged_dtypes = ((np.double, np.uint16),
- (np.uint16, np.double),
- (np.uint16, np.uint16))
-# ==============================================================
-# ==============================================================
-# ==============================================================
-
-
-test_cases = (Test_chistogramnd_1D_errors,
- Test_chistogramnd_ND_errors,
- Test_chistogramnd_ND_range)
-
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/test/test_histogramnd_nominal.py b/silx/math/test/test_histogramnd_nominal.py
deleted file mode 100644
index 92febdd..0000000
--- a/silx/math/test/test_histogramnd_nominal.py
+++ /dev/null
@@ -1,949 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-Nominal tests of the histogramnd function.
-"""
-
-import unittest
-
-import numpy as np
-
-from silx.math.chistogramnd import chistogramnd as histogramnd
-from silx.math import Histogramnd
-
-
-def _get_bin_edges(histo_range, n_bins, n_dims):
- edges = []
- for i_dim in range(n_dims):
- edges.append(histo_range[i_dim, 0] +
- np.arange(n_bins[i_dim] + 1) *
- (histo_range[i_dim, 1] - histo_range[i_dim, 0]) /
- n_bins[i_dim])
- return tuple(edges)
-
-
-# ==============================================================
-# ==============================================================
-# ==============================================================
-
-
-class _Test_chistogramnd_nominal(unittest.TestCase):
- """
- Unit tests of the histogramnd function.
- """
-
- ndims = None
-
- def setUp(self):
- ndims = self.ndims
- self.tested_dim = ndims-1
-
- if ndims is None:
- raise ValueError('ndims class member not set.')
-
- sample = np.array([5.5, -3.3,
- 0., -0.5,
- 3.3, 8.8,
- -7.7, 6.0,
- -4.0])
-
- weights = np.array([500.5, -300.3,
- 0.01, -0.5,
- 300.3, 800.8,
- -700.7, 600.6,
- -400.4])
-
- n_elems = len(sample)
-
- if ndims == 1:
- shape = (n_elems,)
- else:
- shape = (n_elems, ndims)
-
- self.sample = np.zeros(shape=shape, dtype=sample.dtype)
- if ndims == 1:
- self.sample = sample
- else:
- self.sample[..., ndims-1] = sample
-
- self.weights = weights
-
- # the tests are performed along one dimension,
- # all the other bins indices along the other dimensions
- # are expected to be 2
- # (e.g : when testing a 2D sample : [0, x] will go into
- # bin [2, y] because of the bin ranges [-2, 2] and n_bins = 4
- # for the first dimension)
- self.other_axes_index = 2
- self.histo_range = np.repeat([[-2., 2.]], ndims, axis=0)
- self.histo_range[ndims-1] = [-4., 6.]
-
- self.n_bins = np.array([4]*ndims)
- self.n_bins[ndims-1] = 5
-
- if ndims == 1:
- def fill_histo(h, v, dim, op=None):
- if op:
- h[:] = op(h[:], v)
- else:
- h[:] = v
- self.fill_histo = fill_histo
- else:
- def fill_histo(h, v, dim, op=None):
- idx = [self.other_axes_index]*len(h.shape)
- idx[dim] = slice(0, None)
- idx = tuple(idx)
- if op:
- h[idx] = op(h[idx], v)
- else:
- h[idx] = v
- self.fill_histo = fill_histo
-
- def test_nominal(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul, bin_edges = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
-
- expected_edges = _get_bin_edges(self.histo_range,
- self.n_bins,
- self.ndims)
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- for i_edges, edges in enumerate(expected_edges):
- self.assertTrue(np.array_equal(bin_edges[i_edges],
- expected_edges[i_edges]),
- msg='Testing bin_edges for dim {0}'
- ''.format(i_edges+1))
-
- def test_nominal_wh_dtype(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.float32)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul, bin_edges = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- wh_dtype=np.float32)
-
- self.assertEqual(cumul.dtype, np.float32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.allclose(cumul, expected_c))
-
- def test_nominal_uncontiguous_sample(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- shape = list(self.sample.shape)
- shape[0] *= 2
- sample = np.zeros(shape, dtype=self.sample.dtype)
- uncontig_sample = sample[::2, ...]
- uncontig_sample[:] = self.sample
-
- self.assertFalse(uncontig_sample.flags['C_CONTIGUOUS'],
- msg='Making sure the array is not contiguous.')
-
- histo, cumul, bin_edges = histogramnd(uncontig_sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- def test_nominal_uncontiguous_weights(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- shape = list(self.weights.shape)
- shape[0] *= 2
- weights = np.zeros(shape, dtype=self.weights.dtype)
- uncontig_weights = weights[::2, ...]
- uncontig_weights[:] = self.weights
-
- self.assertFalse(uncontig_weights.flags['C_CONTIGUOUS'],
- msg='Making sure the array is not contiguous.')
-
- histo, cumul, bin_edges = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=uncontig_weights)
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- def test_nominal_wo_weights(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
-
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None)[0:2]
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(cumul is None)
-
- def test_nominal_wo_weights_w_cumul(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
-
- # creating an array of ones just to make sure that
- # it is not cleared by histogramnd
- cumul_in = np.ones(self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
-
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None,
- weighted_histo=cumul_in)[0:2]
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(cumul is None)
- self.assertTrue(np.array_equal(cumul_in,
- np.ones(shape=self.n_bins,
- dtype=np.double)))
-
- def test_nominal_wo_weights_w_histo(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
-
- # creating an array of ones just to make sure that
- # it is not cleared by histogramnd
- histo_in = np.ones(self.n_bins, dtype=np.uint32)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
-
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None,
- histo=histo_in)[0:2]
-
- self.assertTrue(np.array_equal(histo, expected_h + 1))
- self.assertTrue(cumul is None)
- self.assertEqual(id(histo), id(histo_in))
-
- def test_nominal_last_bin_closed(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 2])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 1101.1])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)[0:2]
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- def test_int32_weights_double_weights_range(self):
- """
- """
- weight_min = -299.9 # ===> will be cast to -299
- weight_max = 499.9 # ===> will be cast to 499
-
- expected_h_tpl = np.array([0, 1, 1, 1, 0])
- expected_c_tpl = np.array([0., 0., 0., 300., 0.])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights.astype(np.int32),
- weight_min=weight_min,
- weight_max=weight_max)[0:2]
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- def test_reuse_histo(self):
- """
- """
-
- expected_h_tpl = np.array([2, 3, 2, 2, 2])
- expected_c_tpl = np.array([0.0, -7007, -5.0, 0.1, 3003.0])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)[0:2]
-
- sample_2 = self.sample[:]
- if len(sample_2.shape) == 1:
- idx = (slice(0, None),)
- else:
- idx = slice(0, None), self.tested_dim
-
- sample_2[idx] += 2
-
- histo_2, cumul = histogramnd(sample_2, # <==== !!
- self.histo_range,
- self.n_bins,
- weights=10 * self.weights, # <==== !!
- histo=histo)[0:2]
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
- self.assertEqual(id(histo), id(histo_2))
-
- def test_reuse_cumul(self):
- """
- """
-
- expected_h_tpl = np.array([0, 2, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -7007.5, -4.99, 300.4, 3503.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)[0:2]
-
- sample_2 = self.sample[:]
- if len(sample_2.shape) == 1:
- idx = (slice(0, None),)
- else:
- idx = slice(0, None), self.tested_dim
-
- sample_2[idx] += 2
-
- histo, cumul_2 = histogramnd(sample_2, # <==== !!
- self.histo_range,
- self.n_bins,
- weights=10 * self.weights, # <==== !!
- weighted_histo=cumul)[0:2]
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
- self.assertEqual(id(cumul), id(cumul_2))
-
- def test_reuse_cumul_float(self):
- """
- """
-
- expected_h_tpl = np.array([0, 2, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -7007.5, -4.99, 300.4, 3503.5],
- dtype=np.float32)
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)[0:2]
-
- # converting the cumul array to float
- cumul = cumul.astype(np.float32)
-
- sample_2 = self.sample[:]
- if len(sample_2.shape) == 1:
- idx = (slice(0, None),)
- else:
- idx = slice(0, None), self.tested_dim
-
- sample_2[idx] += 2
-
- histo, cumul_2 = histogramnd(sample_2, # <==== !!
- self.histo_range,
- self.n_bins,
- weights=10 * self.weights, # <==== !!
- weighted_histo=cumul)[0:2]
-
- self.assertEqual(cumul.dtype, np.float32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertEqual(id(cumul), id(cumul_2))
- self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
-
-class _Test_Histogramnd_nominal(unittest.TestCase):
- """
- Unit tests of the Histogramnd class.
- """
-
- ndims = None
-
- def setUp(self):
- ndims = self.ndims
- self.tested_dim = ndims-1
-
- if ndims is None:
- raise ValueError('ndims class member not set.')
-
- sample = np.array([5.5, -3.3,
- 0., -0.5,
- 3.3, 8.8,
- -7.7, 6.0,
- -4.0])
-
- weights = np.array([500.5, -300.3,
- 0.01, -0.5,
- 300.3, 800.8,
- -700.7, 600.6,
- -400.4])
-
- n_elems = len(sample)
-
- if ndims == 1:
- shape = (n_elems,)
- else:
- shape = (n_elems, ndims)
-
- self.sample = np.zeros(shape=shape, dtype=sample.dtype)
- if ndims == 1:
- self.sample = sample
- else:
- self.sample[..., ndims-1] = sample
-
- self.weights = weights
-
- # the tests are performed along one dimension,
- # all the other bins indices along the other dimensions
- # are expected to be 2
- # (e.g : when testing a 2D sample : [0, x] will go into
- # bin [2, y] because of the bin ranges [-2, 2] and n_bins = 4
- # for the first dimension)
- self.other_axes_index = 2
- self.histo_range = np.repeat([[-2., 2.]], ndims, axis=0)
- self.histo_range[ndims-1] = [-4., 6.]
-
- self.n_bins = np.array([4]*ndims)
- self.n_bins[ndims-1] = 5
-
- if ndims == 1:
- def fill_histo(h, v, dim, op=None):
- if op:
- h[:] = op(h[:], v)
- else:
- h[:] = v
- self.fill_histo = fill_histo
- else:
- def fill_histo(h, v, dim, op=None):
- idx = [self.other_axes_index]*len(h.shape)
- idx[dim] = slice(0, None)
- if op:
- h[idx] = op(h[idx], v)
- else:
- h[idx] = v
- self.fill_histo = fill_histo
-
- def test_nominal(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
-
- histo, cumul, bin_edges = histo
-
- expected_edges = _get_bin_edges(self.histo_range,
- self.n_bins,
- self.ndims)
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- for i_edges, edges in enumerate(expected_edges):
- self.assertTrue(np.array_equal(bin_edges[i_edges],
- expected_edges[i_edges]),
- msg='Testing bin_edges for dim {0}'
- ''.format(i_edges+1))
-
- def test_nominal_wh_dtype(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.float32)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul, bin_edges = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- wh_dtype=np.float32)
-
- self.assertEqual(cumul.dtype, np.float32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.allclose(cumul, expected_c))
-
- def test_nominal_uncontiguous_sample(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- shape = list(self.sample.shape)
- shape[0] *= 2
- sample = np.zeros(shape, dtype=self.sample.dtype)
- uncontig_sample = sample[::2, ...]
- uncontig_sample[:] = self.sample
-
- self.assertFalse(uncontig_sample.flags['C_CONTIGUOUS'],
- msg='Making sure the array is not contiguous.')
-
- histo, cumul, bin_edges = Histogramnd(uncontig_sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- def test_nominal_uncontiguous_weights(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- shape = list(self.weights.shape)
- shape[0] *= 2
- weights = np.zeros(shape, dtype=self.weights.dtype)
- uncontig_weights = weights[::2, ...]
- uncontig_weights[:] = self.weights
-
- self.assertFalse(uncontig_weights.flags['C_CONTIGUOUS'],
- msg='Making sure the array is not contiguous.')
-
- histo, cumul, bin_edges = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=uncontig_weights)
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- def test_nominal_wo_weights(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
-
- histo, cumul = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None)[0:2]
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(cumul is None)
-
- def test_nominal_last_bin_closed(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 2])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 1101.1])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)[0:2]
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- def test_int32_weights_double_weights_range(self):
- """
- """
- weight_min = -299.9 # ===> will be cast to -299
- weight_max = 499.9 # ===> will be cast to 499
-
- expected_h_tpl = np.array([0, 1, 1, 1, 0])
- expected_c_tpl = np.array([0., 0., 0., 300., 0.])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo, cumul = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights.astype(np.int32),
- weight_min=weight_min,
- weight_max=weight_max)[0:2]
-
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- def test_nominal_no_sample(self):
- """
- """
-
- histo_inst = Histogramnd(None,
- self.histo_range,
- self.n_bins)
-
- histo, weighted_histo, edges = histo_inst
-
- self.assertIsNone(histo)
- self.assertIsNone(weighted_histo)
- self.assertIsNone(edges)
- self.assertIsNone(histo_inst.histo)
- self.assertIsNone(histo_inst.weighted_histo)
- self.assertIsNone(histo_inst.edges)
-
- def test_empty_init_accumulate(self):
- """
- """
- expected_h_tpl = np.array([2, 1, 1, 1, 1])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo_inst = Histogramnd(None,
- self.histo_range,
- self.n_bins)
-
- histo_inst.accumulate(self.sample,
- weights=self.weights)
-
- histo = histo_inst.histo
- cumul = histo_inst.weighted_histo
- bin_edges = histo_inst.edges
-
- expected_edges = _get_bin_edges(self.histo_range,
- self.n_bins,
- self.ndims)
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertEqual(histo.dtype, np.uint32)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- for i_edges, edges in enumerate(expected_edges):
- self.assertTrue(np.array_equal(bin_edges[i_edges],
- expected_edges[i_edges]),
- msg='Testing bin_edges for dim {0}'
- ''.format(i_edges+1))
-
- def test_accumulate(self):
- """
- """
-
- expected_h_tpl = np.array([2, 3, 2, 2, 2])
- expected_c_tpl = np.array([-700.7, -7007.5, -4.99, 300.4, 3503.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo_inst = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
-
- sample_2 = self.sample[:]
- if len(sample_2.shape) == 1:
- idx = (slice(0, None),)
- else:
- idx = slice(0, None), self.tested_dim
-
- sample_2[idx] += 2
-
- histo_inst.accumulate(sample_2, # <==== !!
- weights=10 * self.weights) # <==== !!
-
- histo = histo_inst.histo
- cumul = histo_inst.weighted_histo
- bin_edges = histo_inst.edges
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
-
- def test_accumulate_no_weights(self):
- """
- """
-
- expected_h_tpl = np.array([2, 3, 2, 2, 2])
- expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo_inst = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
-
- sample_2 = self.sample[:]
- if len(sample_2.shape) == 1:
- idx = (slice(0, None),)
- else:
- idx = slice(0, None), self.tested_dim
-
- sample_2[idx] += 2
-
- histo_inst.accumulate(sample_2) # <==== !!
-
- histo = histo_inst.histo
- cumul = histo_inst.weighted_histo
- bin_edges = histo_inst.edges
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
-
- def test_accumulate_no_weights_at_init(self):
- """
- """
-
- expected_h_tpl = np.array([2, 3, 2, 2, 2])
- expected_c_tpl = np.array([0.0, -700.7, -0.5, 0.01, 300.3])
-
- expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
- expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
-
- self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
- self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
-
- histo_inst = Histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=None) # <==== !!
-
- cumul = histo_inst.weighted_histo
- self.assertIsNone(cumul)
-
- sample_2 = self.sample[:]
- if len(sample_2.shape) == 1:
- idx = (slice(0, None),)
- else:
- idx = slice(0, None), self.tested_dim
-
- sample_2[idx] += 2
-
- histo_inst.accumulate(sample_2,
- weights=self.weights) # <==== !!
-
- histo = histo_inst.histo
- cumul = histo_inst.weighted_histo
- bin_edges = histo_inst.edges
-
- self.assertEqual(cumul.dtype, np.float64)
- self.assertTrue(np.array_equal(histo, expected_h))
- self.assertTrue(np.array_equal(cumul, expected_c))
-
- def testNoneNativeTypes(self):
- type = self.sample.dtype.newbyteorder("B")
- sampleB = self.sample.astype(type)
-
- type = self.sample.dtype.newbyteorder("L")
- sampleL = self.sample.astype(type)
-
- histo_inst = Histogramnd(sampleB,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
-
- histo_inst = Histogramnd(sampleL,
- self.histo_range,
- self.n_bins,
- weights=self.weights)
-
-
-class Test_chistogram_nominal_1d(_Test_chistogramnd_nominal):
- ndims = 1
-
-
-class Test_chistogram_nominal_2d(_Test_chistogramnd_nominal):
- ndims = 2
-
-
-class Test_chistogram_nominal_3d(_Test_chistogramnd_nominal):
- ndims = 3
-
-
-class Test_Histogramnd_nominal_1d(_Test_Histogramnd_nominal):
- ndims = 1
-
-
-class Test_Histogramnd_nominal_2d(_Test_Histogramnd_nominal):
- ndims = 2
-
-
-class Test_Histogramnd_nominal_3d(_Test_Histogramnd_nominal):
- ndims = 3
-
-
-# ==============================================================
-# ==============================================================
-# ==============================================================
-
-
-test_cases = (Test_chistogram_nominal_1d,
- Test_chistogram_nominal_2d,
- Test_chistogram_nominal_3d,
- Test_Histogramnd_nominal_1d,
- # Test_Histogramnd_nominal_2d,
- # Test_Histogramnd_nominal_3d
- )
-
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/test/test_histogramnd_vs_np.py b/silx/math/test/test_histogramnd_vs_np.py
deleted file mode 100644
index 36d71f9..0000000
--- a/silx/math/test/test_histogramnd_vs_np.py
+++ /dev/null
@@ -1,848 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""
-Tests for the histogramnd function.
-Results are compared to numpy's histogramdd.
-"""
-
-import unittest
-import operator
-
-import numpy as np
-
-from silx.math.chistogramnd import chistogramnd as histogramnd
-
-# ==============================================================
-# ==============================================================
-# ==============================================================
-
-_RTOL_DICT = {np.float64: 10**-13,
- np.float32: 10**-5}
-
-# ==============================================================
-# ==============================================================
-# ==============================================================
-
-
-def _add_values_to_array_if_missing(array, values, n_values):
- max_in_col = np.any(array[:, ...] == values, axis=0)
-
- if len(array.shape) == 1:
- if not max_in_col:
- rnd_idx = np.random.randint(0,
- high=len(array)-1,
- size=(n_values,))
- array[rnd_idx] = values
- else:
- for i in range(len(max_in_col)):
- if not max_in_col[i]:
- rnd_idx = np.random.randint(0,
- high=len(array)-1,
- size=(n_values,))
- array[rnd_idx, i] = values[i]
-
-
-def _get_values_index(array, values, op=operator.lt):
- idx = op(array[:, ...], values)
- if array.ndim > 1:
- idx = np.all(idx, axis=1)
- return np.where(idx)[0]
-
-
-def _get_in_range_indices(array,
- minvalues,
- maxvalues,
- minop=operator.ge,
- maxop=operator.lt):
- idx = np.logical_and(minop(array, minvalues),
- maxop(array, maxvalues))
- if array.ndim > 1:
- idx = np.all(idx, axis=1)
- return np.where(idx)[0]
-
-
-class _TestHistogramnd(unittest.TestCase):
-
- """
- Unit tests of the histogramnd function.
- """
- sample_rng = None
- weights_rng = None
- n_dims = None
-
- filter_min = None
- filter_max = None
-
- histo_range = None
- n_bins = None
-
- dtype_sample = None
- dtype_weights = None
-
- def generate_data(self):
-
- self.longMessage = True
-
- int_min = 0
- int_max = 100000
- n_elements = 10**5
-
- if self.n_dims == 1:
- shape = (n_elements,)
- else:
- shape = (n_elements, self.n_dims,)
-
- self.rng_state = np.random.get_state()
-
- self.state_msg = ('Current RNG state :\n'
- '{0}'.format(self.rng_state))
-
- sample = np.random.randint(int_min,
- high=int_max,
- size=shape)
-
- sample = sample.astype(self.dtype_sample)
- sample = (self.sample_rng[0] +
- (sample-int_min) *
- (self.sample_rng[1]-self.sample_rng[0]) /
- (int_max-int_min)).astype(self.dtype_sample)
-
- weights = np.random.randint(int_min,
- high=int_max,
- size=(n_elements,))
- weights = weights.astype(self.dtype_weights)
- weights = (self.weights_rng[0] +
- (weights-int_min) *
- (self.weights_rng[1]-self.weights_rng[0]) /
- (int_max-int_min)).astype(self.dtype_weights)
-
- # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
- # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
- # the bins range are cast to the same type as the sample
- # in order to get the same results as numpy
- # (which doesnt cast the range)
- self.histo_range = np.array(self.histo_range).astype(self.dtype_sample)
-
- # adding some values that are equal to the max
- # in order to test the opened/closed last bin
- bins_max = [b[1] for b in self.histo_range]
- _add_values_to_array_if_missing(sample,
- bins_max,
- 100)
-
- # adding some values that are equal to the min weight value
- # in order to test the filters
- _add_values_to_array_if_missing(weights,
- self.weights_rng[0],
- 100)
-
- # adding some values that are equal to the max weight value
- # in order to test the filters
- _add_values_to_array_if_missing(weights,
- self.weights_rng[1],
- 100)
-
- return sample, weights
-
- def setUp(self):
- self.sample, self.weights = self.generate_data()
- self.rtol = _RTOL_DICT.get(self.dtype_weights, None)
-
- def array_compare(self, ar_a, ar_b):
- if self.rtol is None:
- return np.array_equal(ar_a, ar_b)
- return np.allclose(ar_a, ar_b, self.rtol)
-
- def test_bin_ranges(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)
-
- result_np = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
-
- for i_edges, edges in enumerate(result_c[2]):
- # allclose for now until I can try with the latest version (TBD)
- # of numpy
- self.assertTrue(np.allclose(edges,
- result_np[1][i_edges]),
- msg='{0}. Testing bin_edges for dim {1}.'
- ''.format(self.state_msg, i_edges+1))
-
- def test_last_bin_closed(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)
-
- result_np = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
-
- result_np_w = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights)
-
- # comparing "hits"
- hits_cmp = np.array_equal(result_c[0],
- result_np[0])
- # comparing weights
- weights_cmp = np.array_equal(result_c[1],
- result_np_w[0])
-
- self.assertTrue(hits_cmp, msg=self.state_msg)
- self.assertTrue(weights_cmp, msg=self.state_msg)
-
- bins_min = [rng[0] for rng in self.histo_range]
- bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample,
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
-
- self.assertEqual(result_c[0].sum(), inrange_idx.shape[0],
- msg=self.state_msg)
-
- # we have to sum the weights using the same precision as the
- # histogramnd function
- weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
-
- def test_last_bin_open(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=False)
-
- bins_max = [rng[1] for rng in self.histo_range]
- filtered_idx = _get_values_index(self.sample, bins_max)
-
- result_np = np.histogramdd(self.sample[filtered_idx],
- bins=self.n_bins,
- range=self.histo_range)
-
- result_np_w = np.histogramdd(self.sample[filtered_idx],
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights[filtered_idx])
-
- # comparing "hits"
- hits_cmp = np.array_equal(result_c[0], result_np[0])
- # comparing weights
- weights_cmp = np.array_equal(result_c[1],
- result_np_w[0])
-
- self.assertTrue(hits_cmp, msg=self.state_msg)
- self.assertTrue(weights_cmp, msg=self.state_msg)
-
- bins_min = [rng[0] for rng in self.histo_range]
- bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample,
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.lt)
-
- self.assertEqual(result_c[0].sum(), len(inrange_idx),
- msg=self.state_msg)
- # we have to sum the weights using the same precision as the
- # histogramnd function
- weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
-
- def test_filter_min(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True,
- weight_min=self.filter_min)
-
- # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
- filter_min = self.dtype_weights(self.filter_min)
-
- weight_idx = _get_values_index(self.weights,
- filter_min, # <------ !!!
- operator.ge)
-
- result_np = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range)
-
- result_np_w = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights[weight_idx])
-
- # comparing "hits"
- hits_cmp = np.array_equal(result_c[0],
- result_np[0])
- # comparing weights
- weights_cmp = np.array_equal(result_c[1], result_np_w[0])
-
- self.assertTrue(hits_cmp, msg=self.state_msg)
- self.assertTrue(weights_cmp, msg=self.state_msg)
-
- bins_min = [rng[0] for rng in self.histo_range]
- bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample[weight_idx],
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
-
- inrange_idx = weight_idx[inrange_idx]
-
- self.assertEqual(result_c[0].sum(), len(inrange_idx),
- msg=self.state_msg)
-
- # we have to sum the weights using the same precision as the
- # histogramnd function
- weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
-
- def test_filter_max(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True,
- weight_max=self.filter_max)
-
- # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
- filter_max = self.dtype_weights(self.filter_max)
-
- weight_idx = _get_values_index(self.weights,
- filter_max, # <------ !!!
- operator.le)
-
- result_np = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range)
-
- result_np_w = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights[weight_idx])
-
- # comparing "hits"
- hits_cmp = np.array_equal(result_c[0],
- result_np[0])
- # comparing weights
- weights_cmp = np.array_equal(result_c[1], result_np_w[0])
-
- self.assertTrue(hits_cmp, msg=self.state_msg)
- self.assertTrue(weights_cmp, msg=self.state_msg)
-
- bins_min = [rng[0] for rng in self.histo_range]
- bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample[weight_idx],
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
-
- inrange_idx = weight_idx[inrange_idx]
-
- self.assertEqual(result_c[0].sum(), len(inrange_idx),
- msg=self.state_msg)
-
- # we have to sum the weights using the same precision as the
- # histogramnd function
- weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
-
- def test_filter_minmax(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True,
- weight_min=self.filter_min,
- weight_max=self.filter_max)
-
- # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
- filter_min = self.dtype_weights(self.filter_min)
- filter_max = self.dtype_weights(self.filter_max)
-
- weight_idx = _get_in_range_indices(self.weights,
- filter_min, # <------ !!!
- filter_max, # <------ !!!
- minop=operator.ge,
- maxop=operator.le)
-
- result_np = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range)
-
- result_np_w = np.histogramdd(self.sample[weight_idx],
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights[weight_idx])
-
- # comparing "hits"
- hits_cmp = np.array_equal(result_c[0],
- result_np[0])
- # comparing weights
- weights_cmp = np.array_equal(result_c[1], result_np_w[0])
-
- self.assertTrue(hits_cmp)
- self.assertTrue(weights_cmp)
-
- bins_min = [rng[0] for rng in self.histo_range]
- bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample[weight_idx],
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
-
- inrange_idx = weight_idx[inrange_idx]
-
- self.assertEqual(result_c[0].sum(), len(inrange_idx),
- msg=self.state_msg)
-
- # we have to sum the weights using the same precision as the
- # histogramnd function
- weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
- self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
- msg=self.state_msg)
-
- def test_reuse_histo(self):
- """
-
- """
- result_c_1 = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)
-
- result_np_1 = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
-
- np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights)
-
- sample_2, weights_2 = self.generate_data()
-
- result_c_2 = histogramnd(sample_2,
- self.histo_range,
- self.n_bins,
- weights=weights_2,
- last_bin_closed=True,
- histo=result_c_1[0])
-
- result_np_2 = np.histogramdd(sample_2,
- bins=self.n_bins,
- range=self.histo_range)
-
- result_np_w_2 = np.histogramdd(sample_2,
- bins=self.n_bins,
- range=self.histo_range,
- weights=weights_2)
-
- # comparing "hits"
- hits_cmp = np.array_equal(result_c_2[0],
- result_np_1[0] +
- result_np_2[0])
- # comparing weights
- weights_cmp = np.array_equal(result_c_2[1],
- result_np_w_2[0])
-
- self.assertTrue(hits_cmp, msg=self.state_msg)
- self.assertTrue(weights_cmp, msg=self.state_msg)
-
- def test_reuse_cumul(self):
- """
-
- """
- result_c = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True)
-
- np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
-
- result_np_w = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights)
-
- sample_2, weights_2 = self.generate_data()
-
- result_c_2 = histogramnd(sample_2,
- self.histo_range,
- self.n_bins,
- weights=weights_2,
- last_bin_closed=True,
- weighted_histo=result_c[1])
-
- result_np_2 = np.histogramdd(sample_2,
- bins=self.n_bins,
- range=self.histo_range)
-
- result_np_w_2 = np.histogramdd(sample_2,
- bins=self.n_bins,
- range=self.histo_range,
- weights=weights_2)
-
- # comparing "hits"
- hits_cmp = np.array_equal(result_c_2[0],
- result_np_2[0])
- # comparing weights
-
- self.assertTrue(hits_cmp, msg=self.state_msg)
- self.assertTrue(self.array_compare(result_c_2[1],
- result_np_w[0] + result_np_w_2[0]),
- msg=self.state_msg)
-
- def test_reuse_cumul_float(self):
- """
-
- """
- n_bins = np.array(self.n_bins, ndmin=1)
- if len(self.sample.shape) == 2:
- if len(n_bins) == self.sample.shape[1]:
- shp = tuple([x for x in n_bins])
- else:
- shp = (self.n_bins,) * self.sample.shape[1]
- cumul = np.zeros(shp, dtype=np.float32)
- else:
- shp = (self.n_bins,)
- cumul = np.zeros(shp, dtype=np.float32)
-
- result_c_1 = histogramnd(self.sample,
- self.histo_range,
- self.n_bins,
- weights=self.weights,
- last_bin_closed=True,
- weighted_histo=cumul)
-
- result_np_1 = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range)
-
- result_np_w_1 = np.histogramdd(self.sample,
- bins=self.n_bins,
- range=self.histo_range,
- weights=self.weights)
-
- # comparing "hits"
- hits_cmp = np.array_equal(result_c_1[0],
- result_np_1[0])
-
- self.assertTrue(hits_cmp, msg=self.state_msg)
- self.assertEqual(result_c_1[1].dtype, np.float32, msg=self.state_msg)
-
- bins_min = [rng[0] for rng in self.histo_range]
- bins_max = [rng[1] for rng in self.histo_range]
- inrange_idx = _get_in_range_indices(self.sample,
- bins_min,
- bins_max,
- minop=operator.ge,
- maxop=operator.le)
- weights_sum = \
- self.weights[inrange_idx].astype(np.float32).sum(dtype=np.float64)
- self.assertTrue(np.allclose(result_c_1[1].sum(dtype=np.float64),
- weights_sum), msg=self.state_msg)
- self.assertTrue(np.allclose(result_c_1[1].sum(dtype=np.float64),
- result_np_w_1[0].sum(dtype=np.float64)),
- msg=self.state_msg)
-
-
-class _TestHistogramnd_1d(_TestHistogramnd):
-
- """
- Unit tests of the 1D histogramnd function.
- """
-
- sample_rng = [-55., 100.]
- weights_rng = [-70., 150.]
- n_dims = 1
- filter_min = -15.6
- filter_max = 85.7
-
- histo_range = [[-30.2, 90.3]]
- n_bins = 30
-
- dtype = None
-
-
-class _TestHistogramnd_2d(_TestHistogramnd):
-
- """
- Unit tests of the 1D histogramnd function.
- """
-
- sample_rng = [-50.2, 100.99]
- weights_rng = [70., 150.]
- n_dims = 2
- filter_min = 81.7
- filter_max = 135.3
-
- histo_range = [[10., 90.], [20., 70.]]
- n_bins = 30
-
- dtype = None
-
-
-class _TestHistogramnd_3d(_TestHistogramnd):
-
- """
- Unit tests of the 1D histogramnd function.
- """
-
- sample_rng = [10.2, 200.9]
- weights_rng = [0., 100.]
- n_dims = 3
- filter_min = 31.5
- filter_max = 83.7
-
- histo_range = [[30.8, 150.2], [20.1, 90.9], [10.1, 195.]]
- n_bins = 30
-
- dtype = None
-
-
-# ################################################################
-# ################################################################
-# ################################################################
-# ################################################################
-
-
-class TestHistogramnd_1d_double_double(_TestHistogramnd_1d):
- dtype_sample = np.double
- dtype_weights = np.double
-
-
-class TestHistogramnd_1d_double_float(_TestHistogramnd_1d):
- dtype_sample = np.double
- dtype_weights = np.float32
-
-
-class TestHistogramnd_1d_double_int32(_TestHistogramnd_1d):
- dtype_sample = np.double
- dtype_weights = np.int32
-
-
-class TestHistogramnd_1d_float_double(_TestHistogramnd_1d):
- dtype_sample = np.float32
- dtype_weights = np.double
-
-
-class TestHistogramnd_1d_float_float(_TestHistogramnd_1d):
- dtype_sample = np.float32
- dtype_weights = np.float32
-
-
-class TestHistogramnd_1d_float_int32(_TestHistogramnd_1d):
- dtype_sample = np.float32
- dtype_weights = np.int32
-
-
-class TestHistogramnd_1d_int32_double(_TestHistogramnd_1d):
- dtype_sample = np.int32
- dtype_weights = np.double
-
-
-class TestHistogramnd_1d_int32_float(_TestHistogramnd_1d):
- dtype_sample = np.int32
- dtype_weights = np.float32
-
-
-class TestHistogramnd_1d_int32_int32(_TestHistogramnd_1d):
- dtype_sample = np.int32
- dtype_weights = np.int32
-
-
-class TestHistogramnd_2d_double_double(_TestHistogramnd_2d):
- dtype_sample = np.double
- dtype_weights = np.double
-
-
-class TestHistogramnd_2d_double_float(_TestHistogramnd_2d):
- dtype_sample = np.double
- dtype_weights = np.float32
-
-
-class TestHistogramnd_2d_double_int32(_TestHistogramnd_2d):
- dtype_sample = np.double
- dtype_weights = np.int32
-
-
-class TestHistogramnd_2d_float_double(_TestHistogramnd_2d):
- dtype_sample = np.float32
- dtype_weights = np.double
-
-
-class TestHistogramnd_2d_float_float(_TestHistogramnd_2d):
- dtype_sample = np.float32
- dtype_weights = np.float32
-
-
-class TestHistogramnd_2d_float_int32(_TestHistogramnd_2d):
- dtype_sample = np.float32
- dtype_weights = np.int32
-
-
-class TestHistogramnd_2d_int32_double(_TestHistogramnd_2d):
- dtype_sample = np.int32
- dtype_weights = np.double
-
-
-class TestHistogramnd_2d_int32_float(_TestHistogramnd_2d):
- dtype_sample = np.int32
- dtype_weights = np.float32
-
-
-class TestHistogramnd_2d_int32_int32(_TestHistogramnd_2d):
- dtype_sample = np.int32
- dtype_weights = np.int32
-
-
-class TestHistogramnd_3d_double_double(_TestHistogramnd_3d):
- dtype_sample = np.double
- dtype_weights = np.double
-
-
-class TestHistogramnd_3d_double_float(_TestHistogramnd_3d):
- dtype_sample = np.double
- dtype_weights = np.float32
-
-
-class TestHistogramnd_3d_double_int32(_TestHistogramnd_3d):
- dtype_sample = np.double
- dtype_weights = np.int32
-
-
-class TestHistogramnd_3d_float_double(_TestHistogramnd_3d):
- dtype_sample = np.float32
- dtype_weights = np.double
-
-
-class TestHistogramnd_3d_float_float(_TestHistogramnd_3d):
- dtype_sample = np.float32
- dtype_weights = np.float32
-
-
-class TestHistogramnd_3d_float_int32(_TestHistogramnd_3d):
- dtype_sample = np.float32
- dtype_weights = np.int32
-
-
-class TestHistogramnd_3d_int32_double(_TestHistogramnd_3d):
- dtype_sample = np.int32
- dtype_weights = np.double
-
-
-class TestHistogramnd_3d_int32_float(_TestHistogramnd_3d):
- dtype_sample = np.int32
- dtype_weights = np.float32
-
-
-class TestHistogramnd_3d_int32_int32(_TestHistogramnd_3d):
- dtype_sample = np.int32
- dtype_weights = np.int32
-
-
-# ==============================================================
-# ==============================================================
-# ==============================================================
-
-
-test_cases = (TestHistogramnd_1d_double_double,
- TestHistogramnd_1d_double_float,
- TestHistogramnd_1d_double_int32,
- TestHistogramnd_1d_float_double,
- TestHistogramnd_1d_float_float,
- TestHistogramnd_1d_float_int32,
- TestHistogramnd_1d_int32_double,
- TestHistogramnd_1d_int32_float,
- TestHistogramnd_1d_int32_int32,
- TestHistogramnd_2d_double_double,
- TestHistogramnd_2d_double_float,
- TestHistogramnd_2d_double_int32,
- TestHistogramnd_2d_float_double,
- TestHistogramnd_2d_float_float,
- TestHistogramnd_2d_float_int32,
- TestHistogramnd_2d_int32_double,
- TestHistogramnd_2d_int32_float,
- TestHistogramnd_2d_int32_int32,
- TestHistogramnd_3d_double_double,
- TestHistogramnd_3d_double_float,
- TestHistogramnd_3d_double_int32,
- TestHistogramnd_3d_float_double,
- TestHistogramnd_3d_float_float,
- TestHistogramnd_3d_float_int32,
- TestHistogramnd_3d_int32_double,
- TestHistogramnd_3d_int32_float,
- TestHistogramnd_3d_int32_int32,)
-
-
-def suite():
- loader = unittest.defaultTestLoader
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- tests = loader.loadTestsFromTestCase(test_class)
- test_suite.addTests(tests)
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/math/test/test_interpolate.py b/silx/math/test/test_interpolate.py
deleted file mode 100644
index 9989302..0000000
--- a/silx/math/test/test_interpolate.py
+++ /dev/null
@@ -1,136 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Test for interpolate module"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "11/07/2019"
-
-
-import unittest
-
-import numpy
-try:
- from scipy.interpolate import interpn
-except ImportError:
- interpn = None
-
-from silx.utils.testutils import ParametricTestCase
-from silx.math import interpolate
-
-
-@unittest.skipUnless(interpn is not None, "scipy missing")
-class TestInterp3d(ParametricTestCase):
- """Test silx.math.interpolate.interp3d"""
-
- @staticmethod
- def ref_interp3d(data, points):
- """Reference implementation of interp3d based on scipy
-
- :param numpy.ndarray data: 3D floating dataset
- :param numpy.ndarray points: Array of points of shape (N, 3)
- """
- return interpn(
- [numpy.arange(dim, dtype=data.dtype) for dim in data.shape],
- data,
- points,
- method='linear')
-
- def test_random_data(self):
- """Test interp3d with random data"""
- size = 32
- npoints = 10
-
- ref_data = numpy.random.random((size, size, size))
- ref_points = numpy.random.random(npoints*3).reshape(npoints, 3) * (size -1)
-
- for dtype in (numpy.float32, numpy.float64):
- data = ref_data.astype(dtype)
- points = ref_points.astype(dtype)
- ref_result = self.ref_interp3d(data, points)
-
- for method in (u'linear', u'linear_omp'):
- with self.subTest(method=method):
- result = interpolate.interp3d(data, points, method=method)
- self.assertTrue(numpy.allclose(ref_result, result))
-
- def test_notfinite_data(self):
- """Test interp3d with NaN and inf"""
- data = numpy.ones((3, 3, 3), dtype=numpy.float64)
- data[0, 0, 0] = numpy.nan
- data[2, 2, 2] = numpy.inf
- points = numpy.array([(0.5, 0.5, 0.5),
- (1.5, 1.5, 1.5)])
-
- for method in (u'linear', u'linear_omp'):
- with self.subTest(method=method):
- result = interpolate.interp3d(
- data, points, method=method)
- self.assertTrue(numpy.isnan(result[0]))
- self.assertTrue(result[1] == numpy.inf)
-
- def test_points_outside(self):
- """Test interp3d with points outside the volume"""
- data = numpy.ones((4, 4, 4), dtype=numpy.float64)
- points = numpy.array([(-0.1, -0.1, -0.1),
- (3.1, 3.1, 3.1),
- (-0.1, 1., 1.),
- (1., 1., 3.1)])
-
- for method in (u'linear', u'linear_omp'):
- for fill_value in (numpy.nan, 0., -1.):
- with self.subTest(method=method):
- result = interpolate.interp3d(
- data, points, method=method, fill_value=fill_value)
- if numpy.isnan(fill_value):
- self.assertTrue(numpy.all(numpy.isnan(result)))
- else:
- self.assertTrue(numpy.all(numpy.equal(result, fill_value)))
-
- def test_integer_points(self):
- """Test interp3d with integer points coord"""
- data = numpy.arange(4**3, dtype=numpy.float64).reshape(4, 4, 4)
- points = numpy.array([(0., 0., 0.),
- (0., 0., 1.),
- (2., 3., 0.),
- (3., 3., 3.)])
-
- ref_result = data[tuple(points.T.astype(numpy.int32))]
-
- for method in (u'linear', u'linear_omp'):
- with self.subTest(method=method):
- result = interpolate.interp3d(data, points, method=method)
- self.assertTrue(numpy.allclose(ref_result, result))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestInterp3d))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/math/test/test_marchingcubes.py b/silx/math/test/test_marchingcubes.py
deleted file mode 100644
index 41f7e30..0000000
--- a/silx/math/test/test_marchingcubes.py
+++ /dev/null
@@ -1,188 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests of the marchingcubes module"""
-
-from __future__ import division
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-import unittest
-
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-
-from silx.math import marchingcubes
-
-
-class TestMarchingCubes(ParametricTestCase):
- """Tests of marching cubes"""
-
- def assertAllClose(self, array1, array2, msg=None,
- rtol=1e-05, atol=1e-08):
- """Assert that the 2 numpy.ndarrays are almost equal.
-
- :param str msg: Message to provide when assert fails
- :param float rtol: Relative tolerance, see :func:`numpy.allclose`
- :param float atol: Absolute tolerance, see :func:`numpy.allclose`
- """
- if not numpy.allclose(array1, array2, rtol, atol):
- raise self.failureException(msg)
-
- def test_cube(self):
- """Unit tests with a single cube"""
-
- # No isosurface
- cube_zero = numpy.zeros((2, 2, 2), dtype=numpy.float32)
-
- result = marchingcubes.MarchingCubes(cube_zero, 1.)
- self.assertEqual(result.shape, cube_zero.shape)
- self.assertEqual(result.isolevel, 1.)
- self.assertEqual(result.invert_normals, True)
-
- vertices, normals, indices = result
- self.assertEqual(len(vertices), 0)
- self.assertEqual(len(normals), 0)
- self.assertEqual(len(indices), 0)
-
- # Cube array dimensions: shape = (dim 0, dim 1, dim2)
- #
- # dim 0 (Z)
- # ^
- # |
- # 4 +------+ 5
- # /| /|
- # / | / |
- # 6 +------+ 7|
- # | | | |
- # |0 +---|--+ 1 -> dim 2 (X)
- # | / | /
- # |/ |/
- # 2 +------+ 3
- # /
- # dim 1 (Y)
-
- # isosurface perpendicular to dim 0 (Z)
- cube = numpy.array(
- (((0., 0.), (0., 0.)),
- ((1., 1.), (1., 1.))), dtype=numpy.float32)
- level = 0.5
- vertices, normals, indices = marchingcubes.MarchingCubes(
- cube, level, invert_normals=False)
- self.assertAllClose(vertices[:, 0], level)
- self.assertAllClose(normals, (1., 0., 0.))
- self.assertEqual(len(indices), 2)
-
- # isosurface perpendicular to dim 1 (Y)
- cube = numpy.array(
- (((0., 0.), (1., 1.)),
- ((0., 0.), (1., 1.))), dtype=numpy.float32)
- level = 0.2
- vertices, normals, indices = marchingcubes.MarchingCubes(cube, level)
- self.assertAllClose(vertices[:, 1], level)
- self.assertAllClose(normals, (0., -1., 0.))
- self.assertEqual(len(indices), 2)
-
- # isosurface perpendicular to dim 2 (X)
- cube = numpy.array(
- (((0., 1.), (0., 1.)),
- ((0., 1.), (0., 1.))), dtype=numpy.float32)
- level = 0.9
- vertices, normals, indices = marchingcubes.MarchingCubes(
- cube, level, invert_normals=False)
- self.assertAllClose(vertices[:, 2], level)
- self.assertAllClose(normals, (0., 0., 1.))
- self.assertEqual(len(indices), 2)
-
- # isosurface normal in dim1, dim 0 (Y, Z) plane
- cube = numpy.array(
- (((0., 0.), (0., 0.)),
- ((0., 0.), (1., 1.))), dtype=numpy.float32)
- level = 0.5
- vertices, normals, indices = marchingcubes.MarchingCubes(cube, level)
- self.assertAllClose(normals[:, 2], 0.)
- self.assertEqual(len(indices), 2)
-
- def test_sampling(self):
- """Test different sampling, comparing to reference without sampling"""
- isolevel = 0.5
- size = 9
- chessboard = numpy.zeros((size, size, size), dtype=numpy.float32)
- chessboard.reshape(-1)[::2] = 1 # OK as long as dimensions are odd
-
- ref_result = marchingcubes.MarchingCubes(chessboard, isolevel)
-
- samplings = [
- (2, 1, 1),
- (1, 2, 1),
- (1, 1, 2),
- (2, 2, 2),
- (3, 3, 3),
- (1, 3, 1),
- (1, 1, 3),
- ]
-
- for sampling in samplings:
- with self.subTest(sampling=sampling):
- sampling = numpy.array(sampling)
-
- data = 1e6 * numpy.ones(
- sampling * size, dtype=numpy.float32)
- # Copy ref chessboard in data according to sampling
- data[::sampling[0], ::sampling[1], ::sampling[2]] = chessboard
-
- result = marchingcubes.MarchingCubes(data, isolevel,
- sampling=sampling)
- # Compare vertices normalized with shape
- self.assertAllClose(
- ref_result.get_vertices() / ref_result.shape,
- result.get_vertices() / result.shape,
- atol=0., rtol=0.)
-
- # Compare normals
- # This comparison only works for normals aligned with axes
- # otherwise non uniform sampling would make different normals
- self.assertAllClose(ref_result.get_normals(),
- result.get_normals(),
- atol=0., rtol=0.)
-
- self.assertAllClose(ref_result.get_indices(),
- result.get_indices(),
- atol=0., rtol=0.)
-
-
-test_cases = (TestMarchingCubes,)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- for test_class in test_cases:
- test_suite.addTests(
- unittest.defaultTestLoader.loadTestsFromTestCase(test_class))
- return test_suite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/opencl/codec/test/__init__.py b/silx/opencl/codec/test/__init__.py
deleted file mode 100644
index ec76dd3..0000000
--- a/silx/opencl/codec/test/__init__.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Project: silx
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2013-2017 European Synchrotron Radiation Facility, Grenoble, France
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-__authors__ = ["J. Kieffer"]
-__license__ = "MIT"
-__date__ = "13/10/2017"
-
-import unittest
-from . import test_byte_offset
-
-
-def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(test_byte_offset.suite())
-
- return testSuite
diff --git a/silx/opencl/codec/test/test_byte_offset.py b/silx/opencl/codec/test/test_byte_offset.py
deleted file mode 100644
index d1482ce..0000000
--- a/silx/opencl/codec/test/test_byte_offset.py
+++ /dev/null
@@ -1,315 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Project: Byte-offset decompression in OpenCL
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2013-2020 European Synchrotron Radiation Facility,
-# Grenoble, France
-# Permission is hereby granted, free of charge, to any person
-# obtaining a copy of this software and associated documentation
-# files (the "Software"), to deal in the Software without
-# restriction, including without limitation the rights to use,
-# copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following
-# conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
-# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
-# OTHER DEALINGS IN THE SOFTWARE.
-
-"""
-Test suite for byte-offset decompression
-"""
-
-from __future__ import division, print_function
-
-__authors__ = ["Jérôme Kieffer"]
-__contact__ = "jerome.kieffer@esrf.eu"
-__license__ = "MIT"
-__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "02/03/2021"
-
-import sys
-import time
-import logging
-import numpy
-from silx.opencl.common import ocl, pyopencl
-from silx.opencl.codec import byte_offset
-import fabio
-import unittest
-logger = logging.getLogger(__name__)
-
-
-@unittest.skipUnless(ocl and pyopencl,
- "PyOpenCl is missing")
-class TestByteOffset(unittest.TestCase):
-
- @staticmethod
- def _create_test_data(shape, nexcept, lam=200):
- """Create test (image, compressed stream) pair.
-
- :param shape: Shape of test image
- :param int nexcept: Number of exceptions in the image
- :param lam: Expectation of interval argument for numpy.random.poisson
- :return: (reference image array, compressed stream)
- """
- size = numpy.prod(shape)
- ref = numpy.random.poisson(lam, numpy.prod(shape))
- exception_loc = numpy.random.randint(0, size, size=nexcept)
- exception_value = numpy.random.randint(0, 1000000, size=nexcept)
- ref[exception_loc] = exception_value
- ref.shape = shape
-
- raw = fabio.compression.compByteOffset(ref)
- return ref, raw
-
- def test_decompress(self):
- """
- tests the byte offset decompression on GPU
- """
- ref, raw = self._create_test_data(shape=(91, 97), nexcept=229)
- # ref, raw = self._create_test_data(shape=(7, 9), nexcept=0)
-
- size = numpy.prod(ref.shape)
-
- try:
- bo = byte_offset.ByteOffset(raw_size=len(raw), dec_size=size, profile=True)
- except (RuntimeError, pyopencl.RuntimeError) as err:
- logger.warning(err)
- if sys.platform == "darwin":
- raise unittest.SkipTest("Byte-offset decompression is known to be buggy on MacOS-CPU")
- else:
- raise err
- print(bo.block_size)
-
- t0 = time.time()
- res_cy = fabio.compression.decByteOffset(raw)
- t1 = time.time()
- res_cl = bo.decode(raw)
- t2 = time.time()
- delta_cy = abs(ref.ravel() - res_cy).max()
- delta_cl = abs(ref.ravel() - res_cl.get()).max()
-
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
- bo.log_profile()
- # print(ref)
- # print(res_cl.get())
- self.assertEqual(delta_cy, 0, "Checks fabio works")
- self.assertEqual(delta_cl, 0, "Checks opencl works")
-
- def test_many_decompress(self, ntest=10):
- """
- tests the byte offset decompression on GPU, many images to ensure there
- is not leaking in memory
- """
- shape = (991, 997)
- size = numpy.prod(shape)
- ref, raw = self._create_test_data(shape=shape, nexcept=0, lam=100)
-
- try:
- bo = byte_offset.ByteOffset(len(raw), size, profile=True)
- except (RuntimeError, pyopencl.RuntimeError) as err:
- logger.warning(err)
- if sys.platform == "darwin":
- raise unittest.SkipTest("Byte-offset decompression is known to be buggy on MacOS-CPU")
- else:
- raise err
- t0 = time.time()
- res_cy = fabio.compression.decByteOffset(raw)
- t1 = time.time()
- res_cl = bo(raw)
- t2 = time.time()
- delta_cy = abs(ref.ravel() - res_cy).max()
- delta_cl = abs(ref.ravel() - res_cl.get()).max()
- self.assertEqual(delta_cy, 0, "Checks fabio works")
- self.assertEqual(delta_cl, 0, "Checks opencl works")
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
-
- for i in range(ntest):
- ref, raw = self._create_test_data(shape=shape, nexcept=2729, lam=200)
-
- t0 = time.time()
- res_cy = fabio.compression.decByteOffset(raw)
- t1 = time.time()
- res_cl = bo(raw)
- t2 = time.time()
- delta_cy = abs(ref.ravel() - res_cy).max()
- delta_cl = abs(ref.ravel() - res_cl.get()).max()
- self.assertEqual(delta_cy, 0, "Checks fabio works #%i" % i)
- self.assertEqual(delta_cl, 0, "Checks opencl works #%i" % i)
-
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
- bo.log_profile(stats=True)
-
- def test_encode(self):
- """Test byte offset compression"""
- ref, raw = self._create_test_data(shape=(2713, 2719), nexcept=2729)
-
- try:
- bo = byte_offset.ByteOffset(len(raw), ref.size, profile=True)
- except (RuntimeError, pyopencl.RuntimeError) as err:
- logger.warning(err)
- raise err
-
- t0 = time.time()
- compressed_array = bo.encode(ref)
- t1 = time.time()
-
- compressed_stream = compressed_array.get().tobytes()
- self.assertEqual(raw, compressed_stream)
-
- logger.debug("Global execution time: OpenCL: %.3fms.",
- 1000.0 * (t1 - t0))
- bo.log_profile()
-
- def test_encode_to_array(self):
- """Test byte offset compression while providing an out array"""
-
- ref, raw = self._create_test_data(shape=(2713, 2719), nexcept=2729)
-
- try:
- bo = byte_offset.ByteOffset(profile=True)
- except (RuntimeError, pyopencl.RuntimeError) as err:
- logger.warning(err)
- raise err
- # Test with out buffer too small
- out = pyopencl.array.empty(bo.queue, (10,), numpy.int8)
- with self.assertRaises(ValueError):
- bo.encode(ref, out)
-
- # Test with out buffer too big
- out = pyopencl.array.empty(bo.queue, (len(raw) + 10,), numpy.int8)
-
- compressed_array = bo.encode(ref, out)
-
- # Get size from returned array
- compressed_size = compressed_array.size
- self.assertEqual(compressed_size, len(raw))
-
- # Get data from out array, read it from bo object queue
- out_bo_queue = out.with_queue(bo.queue)
- compressed_stream = out_bo_queue.get().tobytes()[:compressed_size]
- self.assertEqual(raw, compressed_stream)
-
- def test_encode_to_bytes(self):
- """Test byte offset compression to bytes"""
- ref, raw = self._create_test_data(shape=(2713, 2719), nexcept=2729)
-
- try:
- bo = byte_offset.ByteOffset(profile=True)
- except (RuntimeError, pyopencl.RuntimeError) as err:
- logger.warning(err)
- raise err
-
- t0 = time.time()
- res_fabio = fabio.compression.compByteOffset(ref)
- t1 = time.time()
- compressed_stream = bo.encode_to_bytes(ref)
- t2 = time.time()
-
- self.assertEqual(raw, compressed_stream)
-
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
- bo.log_profile()
-
- def test_encode_to_bytes_from_array(self):
- """Test byte offset compression to bytes from a pyopencl array.
- """
- ref, raw = self._create_test_data(shape=(2713, 2719), nexcept=2729)
-
- try:
- bo = byte_offset.ByteOffset(profile=True)
- except (RuntimeError, pyopencl.RuntimeError) as err:
- logger.warning(err)
- raise err
-
- d_ref = pyopencl.array.to_device(
- bo.queue, ref.astype(numpy.int32).ravel())
-
- t0 = time.time()
- res_fabio = fabio.compression.compByteOffset(ref)
- t1 = time.time()
- compressed_stream = bo.encode_to_bytes(d_ref)
- t2 = time.time()
-
- self.assertEqual(raw, compressed_stream)
-
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
- bo.log_profile()
-
- def test_many_encode(self, ntest=10):
- """Test byte offset compression with many image"""
- shape = (991, 997)
- ref, raw = self._create_test_data(shape=shape, nexcept=0, lam=100)
-
- try:
- bo = byte_offset.ByteOffset(profile=False)
- except (RuntimeError, pyopencl.RuntimeError) as err:
- logger.warning(err)
- raise err
-
- bo_durations = []
-
- t0 = time.time()
- res_fabio = fabio.compression.compByteOffset(ref)
- t1 = time.time()
- compressed_stream = bo.encode_to_bytes(ref)
- t2 = time.time()
- bo_durations.append(1000.0 * (t2 - t1))
-
- self.assertEqual(raw, compressed_stream)
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
-
- for i in range(ntest):
- ref, raw = self._create_test_data(shape=shape, nexcept=2729, lam=200)
-
- t0 = time.time()
- res_fabio = fabio.compression.compByteOffset(ref)
- t1 = time.time()
- compressed_stream = bo.encode_to_bytes(ref)
- t2 = time.time()
- bo_durations.append(1000.0 * (t2 - t1))
-
- self.assertEqual(raw, compressed_stream)
- logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
- 1000.0 * (t1 - t0),
- 1000.0 * (t2 - t1))
-
- logger.debug("OpenCL execution time: Mean: %fms, Min: %fms, Max: %fms",
- numpy.mean(bo_durations),
- numpy.min(bo_durations),
- numpy.max(bo_durations))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(TestByteOffset("test_decompress"))
- test_suite.addTest(TestByteOffset("test_many_decompress"))
- test_suite.addTest(TestByteOffset("test_encode"))
- test_suite.addTest(TestByteOffset("test_encode_to_array"))
- test_suite.addTest(TestByteOffset("test_encode_to_bytes"))
- test_suite.addTest(TestByteOffset("test_encode_to_bytes_from_array"))
- test_suite.addTest(TestByteOffset("test_many_encode"))
- return test_suite
diff --git a/silx/opencl/common.py b/silx/opencl/common.py
deleted file mode 100644
index b66b7b7..0000000
--- a/silx/opencl/common.py
+++ /dev/null
@@ -1,689 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Project: S I L X project
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2012-2021 European Synchrotron Radiation Facility, Grenoble, France
-#
-# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
-#
-# Permission is hereby granted, free of charge, to any person
-# obtaining a copy of this software and associated documentation
-# files (the "Software"), to deal in the Software without
-# restriction, including without limitation the rights to use,
-# copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following
-# conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
-# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
-# OTHER DEALINGS IN THE SOFTWARE.
-#
-
-__author__ = "Jerome Kieffer"
-__contact__ = "Jerome.Kieffer@ESRF.eu"
-__license__ = "MIT"
-__copyright__ = "2012-2017 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "30/11/2020"
-__status__ = "stable"
-__all__ = ["ocl", "pyopencl", "mf", "release_cl_buffers", "allocate_cl_buffers",
- "measure_workgroup_size", "kernel_workgroup_size"]
-
-import os
-import logging
-
-import numpy
-
-from .utils import get_opencl_code
-
-logger = logging.getLogger(__name__)
-
-if os.environ.get("SILX_OPENCL") in ["0", "False"]:
- logger.info("Use of OpenCL has been disabled from environment variable: SILX_OPENCL=0")
- pyopencl = None
-else:
- try:
- import pyopencl
- except ImportError:
- logger.warning("Unable to import pyOpenCl. Please install it from: https://pypi.org/project/pyopencl")
- pyopencl = None
- else:
- try:
- pyopencl.get_platforms()
- except pyopencl.LogicError:
- logger.warning("The module pyOpenCL has been imported but can't be used here")
- pyopencl = None
- else:
- import pyopencl.array as array
- mf = pyopencl.mem_flags
-
-if pyopencl is None:
-
- # Define default mem flags
- class mf(object):
- WRITE_ONLY = 1
- READ_ONLY = 1
- READ_WRITE = 1
-
-FLOP_PER_CORE = {"GPU": 64, # GPU, Fermi at least perform 64 flops per cycle/multicore, G80 were at 24 or 48 ...
- "CPU": 4, # CPU, at least intel's have 4 operation per cycle
- "ACC": 8} # ACC: the Xeon-phi (MIC) appears to be able to process 8 Flops per hyperthreaded-core
-
-# Sources : https://en.wikipedia.org/wiki/CUDA
-NVIDIA_FLOP_PER_CORE = {(1, 0): 24, # Guessed !
- (1, 1): 24, # Measured on G98 [Quadro NVS 295]
- (1, 2): 24, # Guessed !
- (1, 3): 24, # measured on a GT285 (GT200)
- (2, 0): 64, # Measured on a 580 (GF110)
- (2, 1): 96, # Measured on Quadro2000 GF106GL
- (3, 0): 384, # Guessed!
- (3, 5): 384, # Measured on K20
- (3, 7): 384, # K80: Guessed!
- (5, 0): 256, # Maxwell 4 warps/SM 2 flops/ CU
- (5, 2): 256, # Titan-X
- (5, 3): 256, # TX1
- (6, 0): 128, # GP100
- (6, 1): 128, # GP104
- (6, 2): 128, # ?
- (7, 0): 128, # Volta # measured on Telsa V100
- (7, 1): 128, # Volta ?
- }
-
-AMD_FLOP_PER_CORE = 160 # Measured on a M7820 10 core, 700MHz 1120GFlops
-
-
-class Device(object):
- """
- Simple class that contains the structure of an OpenCL device
- """
-
- def __init__(self, name="None", dtype=None, version=None, driver_version=None,
- extensions="", memory=None, available=None,
- cores=None, frequency=None, flop_core=None, idx=0, workgroup=1):
- """
- Simple container with some important data for the OpenCL device description.
-
- :param name: name of the device
- :param dtype: device type: CPU/GPU/ACC...
- :param version: driver version
- :param driver_version:
- :param extensions: List of opencl extensions
- :param memory: maximum memory available on the device
- :param available: is the device deactivated or not
- :param cores: number of SM/cores
- :param frequency: frequency of the device
- :param flop_core: Flopating Point operation per core per cycle
- :param idx: index of the device within the platform
- :param workgroup: max workgroup size
- """
- self.name = name.strip()
- self.type = dtype
- self.version = version
- self.driver_version = driver_version
- self.extensions = extensions.split()
- self.memory = memory
- self.available = available
- self.cores = cores
- self.frequency = frequency
- self.id = idx
- self.max_work_group_size = workgroup
- if not flop_core:
- flop_core = FLOP_PER_CORE.get(dtype, 1)
- if cores and frequency:
- self.flops = cores * frequency * flop_core
- else:
- self.flops = flop_core
-
- def __repr__(self):
- return "%s" % self.name
-
- def pretty_print(self):
- """
- Complete device description
-
- :return: string
- """
- lst = ["Name\t\t:\t%s" % self.name,
- "Type\t\t:\t%s" % self.type,
- "Memory\t\t:\t%.3f MB" % (self.memory / 2.0 ** 20),
- "Cores\t\t:\t%s CU" % self.cores,
- "Frequency\t:\t%s MHz" % self.frequency,
- "Speed\t\t:\t%.3f GFLOPS" % (self.flops / 1000.),
- "Version\t\t:\t%s" % self.version,
- "Available\t:\t%s" % self.available]
- return os.linesep.join(lst)
-
- def set_unavailable(self):
- """Use this method to flag a faulty device
- """
- self.available = False
-
-
-class Platform(object):
- """
- Simple class that contains the structure of an OpenCL platform
- """
-
- def __init__(self, name="None", vendor="None", version=None, extensions=None, idx=0):
- """
- Class containing all descriptions of a platform and all devices description within that platform.
-
- :param name: platform name
- :param vendor: name of the brand/vendor
- :param version:
- :param extensions: list of the extension provided by the platform to all of its devices
- :param idx: index of the platform
- """
- self.name = name.strip()
- self.vendor = vendor.strip()
- self.version = version
- self.extensions = extensions.split()
- self.devices = []
- self.id = idx
-
- def __repr__(self):
- return "%s" % self.name
-
- def add_device(self, device):
- """
- Add new device to the platform
-
- :param device: Device instance
- """
- self.devices.append(device)
-
- def get_device(self, key):
- """
- Return a device according to key
-
- :param key: identifier for a device, either it's id (int) or it's name
- :type key: int or str
- """
- out = None
- try:
- devid = int(key)
- except ValueError:
- for a_dev in self.devices:
- if a_dev.name == key:
- out = a_dev
- else:
- if len(self.devices) > devid > 0:
- out = self.devices[devid]
- return out
-
-
-def _measure_workgroup_size(device_or_context, fast=False):
- """Mesure the maximal work group size of the given device
-
- DEPRECATED since not perfectly correct !
-
- :param device_or_context: instance of pyopencl.Device or pyopencl.Context
- or 2-tuple (platformid,deviceid)
- :param fast: ask the kernel the valid value, don't probe it
- :return: maximum size for the workgroup
- """
- if isinstance(device_or_context, pyopencl.Device):
- try:
- ctx = pyopencl.Context(devices=[device_or_context])
- except pyopencl._cl.LogicError as error:
- platform = device_or_context.platform
- platformid = pyopencl.get_platforms().index(platform)
- deviceid = platform.get_devices().index(device_or_context)
- ocl.platforms[platformid].devices[deviceid].set_unavailable()
- raise RuntimeError("Unable to create context on %s/%s: %s" % (platform, device_or_context, error))
- else:
- device = device_or_context
- elif isinstance(device_or_context, pyopencl.Context):
- ctx = device_or_context
- device = device_or_context.devices[0]
- elif isinstance(device_or_context, (tuple, list)) and len(device_or_context) == 2:
- ctx = ocl.create_context(platformid=device_or_context[0],
- deviceid=device_or_context[1])
- device = ctx.devices[0]
- else:
- raise RuntimeError("""given parameter device_or_context is not an
- instanciation of a device or a context""")
- shape = device.max_work_group_size
- # get the context
-
- assert ctx is not None
- queue = pyopencl.CommandQueue(ctx)
-
- max_valid_wg = 1
- data = numpy.random.random(shape).astype(numpy.float32)
- d_data = pyopencl.array.to_device(queue, data)
- d_data_1 = pyopencl.array.empty_like(d_data)
- d_data_1.fill(numpy.float32(1.0))
-
- program = pyopencl.Program(ctx, get_opencl_code("addition")).build()
- if fast:
- max_valid_wg = program.addition.get_work_group_info(pyopencl.kernel_work_group_info.WORK_GROUP_SIZE, device)
- else:
- maxi = int(round(numpy.log2(shape)))
- for i in range(maxi + 1):
- d_res = pyopencl.array.empty_like(d_data)
- wg = 1 << i
- try:
- evt = program.addition(
- queue, (shape,), (wg,),
- d_data.data, d_data_1.data, d_res.data, numpy.int32(shape))
- evt.wait()
- except Exception as error:
- logger.info("%s on device %s for WG=%s/%s", error, device.name, wg, shape)
- program = queue = d_res = d_data_1 = d_data = None
- break
- else:
- res = d_res.get()
- good = numpy.allclose(res, data + 1)
- if good:
- if wg > max_valid_wg:
- max_valid_wg = wg
- else:
- logger.warning("ArithmeticError on %s for WG=%s/%s", wg, device.name, shape)
-
- return max_valid_wg
-
-
-def _is_nvidia_gpu(vendor, devtype):
- return (vendor == "NVIDIA Corporation") and (devtype == "GPU")
-
-
-class OpenCL(object):
- """
- Simple class that wraps the structure ocl_tools_extended.h
-
- This is a static class.
- ocl should be the only instance and shared among all python modules.
- """
-
- platforms = []
- nb_devices = 0
- context_cache = {} # key: 2-tuple of int, value: context
- if pyopencl:
- platform = device = pypl = devtype = extensions = pydev = None
- for idx, platform in enumerate(pyopencl.get_platforms()):
- pypl = Platform(platform.name, platform.vendor, platform.version, platform.extensions, idx)
- for idd, device in enumerate(platform.get_devices()):
- ####################################################
- # Nvidia does not report int64 atomics (we are using) ...
- # this is a hack around as any nvidia GPU with double-precision supports int64 atomics
- ####################################################
- extensions = device.extensions
- if (pypl.vendor == "NVIDIA Corporation") and ('cl_khr_fp64' in extensions):
- extensions += ' cl_khr_int64_base_atomics cl_khr_int64_extended_atomics'
- try:
- devtype = pyopencl.device_type.to_string(device.type).upper()
- except ValueError:
- # pocl does not describe itself as a CPU !
- devtype = "CPU"
- if len(devtype) > 3:
- if "GPU" in devtype:
- devtype = "GPU"
- elif "ACC" in devtype:
- devtype = "ACC"
- elif "CPU" in devtype:
- devtype = "CPU"
- else:
- devtype = devtype[:3]
- if _is_nvidia_gpu(device.vendor, devtype) and ("compute_capability_major_nv" in dir(device)):
- try:
- comput_cap = device.compute_capability_major_nv, device.compute_capability_minor_nv
- except pyopencl.LogicError:
- flop_core = FLOP_PER_CORE["GPU"]
- else:
- flop_core = NVIDIA_FLOP_PER_CORE.get(comput_cap, FLOP_PER_CORE["GPU"])
- elif (pypl.vendor == "Advanced Micro Devices, Inc.") and (devtype == "GPU"):
- flop_core = AMD_FLOP_PER_CORE
- elif devtype == "CPU":
- flop_core = FLOP_PER_CORE.get(devtype, 1)
- else:
- flop_core = 1
- workgroup = device.max_work_group_size
- if (devtype == "CPU") and (pypl.vendor == "Apple"):
- logger.info("For Apple's OpenCL on CPU: Measuring actual valid max_work_goup_size.")
- workgroup = _measure_workgroup_size(device, fast=True)
- if (devtype == "GPU") and os.environ.get("GPU") == "False":
- # Environment variable to disable GPU devices
- continue
- pydev = Device(device.name, devtype, device.version, device.driver_version, extensions,
- device.global_mem_size, bool(device.available), device.max_compute_units,
- device.max_clock_frequency, flop_core, idd, workgroup)
- pypl.add_device(pydev)
- nb_devices += 1
- platforms.append(pypl)
- del platform, device, pypl, devtype, extensions, pydev
-
- def __repr__(self):
- out = ["OpenCL devices:"]
- for platformid, platform in enumerate(self.platforms):
- deviceids = ["(%s,%s) %s" % (platformid, deviceid, dev.name)
- for deviceid, dev in enumerate(platform.devices)]
- out.append("[%s] %s: " % (platformid, platform.name) + ", ".join(deviceids))
- return os.linesep.join(out)
-
- def get_platform(self, key):
- """
- Return a platform according
-
- :param key: identifier for a platform, either an Id (int) or it's name
- :type key: int or str
- """
- out = None
- try:
- platid = int(key)
- except ValueError:
- for a_plat in self.platforms:
- if a_plat.name == key:
- out = a_plat
- else:
- if len(self.platforms) > platid > 0:
- out = self.platforms[platid]
- return out
-
- def select_device(self, dtype="ALL", memory=None, extensions=None, best=True, **kwargs):
- """
- Select a device based on few parameters (at the end, keep the one with most memory)
-
- :param dtype: "gpu" or "cpu" or "all" ....
- :param memory: minimum amount of memory (int)
- :param extensions: list of extensions to be present
- :param best: shall we look for the
- :returns: A tuple of plateform ID and device ID, else None if nothing
- found
- """
- if extensions is None:
- extensions = []
- if "type" in kwargs:
- dtype = kwargs["type"].upper()
- else:
- dtype = dtype.upper()
- if len(dtype) > 3:
- dtype = dtype[:3]
- best_found = None
- for platformid, platform in enumerate(self.platforms):
- for deviceid, device in enumerate(platform.devices):
- if not device.available:
- continue
- if (dtype in ["ALL", "DEF"]) or (device.type == dtype):
- if (memory is None) or (memory <= device.memory):
- found = True
- for ext in extensions:
- if ext not in device.extensions:
- found = False
- if found:
- if not best:
- return platformid, deviceid
- else:
- if not best_found:
- best_found = platformid, deviceid, device.flops
- elif best_found[2] < device.flops:
- best_found = platformid, deviceid, device.flops
- if best_found:
- return best_found[0], best_found[1]
-
- # Nothing found
- return None
-
- def create_context(self, devicetype="ALL", useFp64=False, platformid=None,
- deviceid=None, cached=True, memory=None, extensions=None):
- """
- Choose a device and initiate a context.
-
- Devicetypes can be GPU,gpu,CPU,cpu,DEF,ACC,ALL.
- Suggested are GPU,CPU.
- For each setting to work there must be such an OpenCL device and properly installed.
- E.g.: If Nvidia driver is installed, GPU will succeed but CPU will fail.
- The AMD SDK kit is required for CPU via OpenCL.
- :param devicetype: string in ["cpu","gpu", "all", "acc"]
- :param useFp64: boolean specifying if double precision will be used: deprecated use extensions=["cl_khr_fp64"]
- :param platformid: integer
- :param deviceid: integer
- :param cached: True if we want to cache the context
- :param memory: minimum amount of memory of the device
- :param extensions: list of extensions to be present
- :return: OpenCL context on the selected device
- """
- if extensions is None:
- extensions = []
- if useFp64:
- logger.warning("Deprecation: please select your device using the extension name!, i.e. extensions=['cl_khr_fp64']")
- extensions.append('cl_khr_fp64')
-
- if (platformid is not None) and (deviceid is not None):
- platformid = int(platformid)
- deviceid = int(deviceid)
- elif "PYOPENCL_CTX" in os.environ:
- pyopencl_ctx = [int(i) if i.isdigit() else 0 for i in os.environ["PYOPENCL_CTX"].split(":")]
- pyopencl_ctx += [0] * (2 - len(pyopencl_ctx)) # pad with 0
- platformid, deviceid = pyopencl_ctx
- else:
- ids = ocl.select_device(type=devicetype, extensions=extensions)
- if ids:
- platformid, deviceid = ids
- ctx = None
- if (platformid is not None) and (deviceid is not None):
- if (platformid, deviceid) in self.context_cache:
- ctx = self.context_cache[(platformid, deviceid)]
- else:
- try:
- ctx = pyopencl.Context(devices=[pyopencl.get_platforms()[platformid].get_devices()[deviceid]])
- except pyopencl._cl.LogicError as error:
- self.platforms[platformid].devices[deviceid].set_unavailable()
- logger.warning("Unable to create context on %s/%s: %s", platformid, deviceid, error)
- ctx = None
- else:
- if cached:
- self.context_cache[(platformid, deviceid)] = ctx
- if ctx is None:
- logger.warning("Last chance to get an OpenCL device ... probably not the one requested")
- ctx = pyopencl.create_some_context(interactive=False)
- return ctx
-
- def device_from_context(self, context):
- """
- Retrieves the Device from the context
-
- :param context: OpenCL context
- :return: instance of Device
- """
- odevice = context.devices[0]
- oplat = odevice.platform
- device_id = oplat.get_devices().index(odevice)
- platform_id = pyopencl.get_platforms().index(oplat)
- return self.platforms[platform_id].devices[device_id]
-
-
-if pyopencl:
- ocl = OpenCL()
- if ocl.nb_devices == 0:
- ocl = None
-else:
- ocl = None
-
-
-def release_cl_buffers(cl_buffers):
- """
- :param cl_buffers: the buffer you want to release
- :type cl_buffers: dict(str, pyopencl.Buffer)
-
- This method release the memory of the buffers store in the dict
- """
- for key, buffer_ in cl_buffers.items():
- if buffer_ is not None:
- if isinstance(buffer_, pyopencl.array.Array):
- try:
- buffer_.data.release()
- except pyopencl.LogicError:
- logger.error("Error while freeing buffer %s", key)
- else:
- try:
- buffer_.release()
- except pyopencl.LogicError:
- logger.error("Error while freeing buffer %s", key)
- cl_buffers[key] = None
- return cl_buffers
-
-
-def allocate_cl_buffers(buffers, device=None, context=None):
- """
- :param buffers: the buffers info use to create the pyopencl.Buffer
- :type buffers: list(std, flag, numpy.dtype, int)
- :param device: one of the context device
- :param context: opencl contextdevice
- :return: a dict containing the instanciated pyopencl.Buffer
- :rtype: dict(str, pyopencl.Buffer)
-
- This method instanciate the pyopencl.Buffer from the buffers
- description.
- """
- mem = {}
- if device is None:
- device = ocl.device_from_context(context)
-
- # check if enough memory is available on the device
- ualloc = 0
- for _, _, dtype, size in buffers:
- ualloc += numpy.dtype(dtype).itemsize * size
- memory = device.memory
- logger.info("%.3fMB are needed on device which has %.3fMB",
- ualloc / 1.0e6, memory / 1.0e6)
- if ualloc >= memory:
- memError = "Fatal error in allocate_buffers."
- memError += "Not enough device memory for buffers"
- memError += "(%lu requested, %lu available)" % (ualloc, memory)
- raise MemoryError(memError) # noqa
-
- # do the allocation
- try:
- for name, flag, dtype, size in buffers:
- mem[name] = pyopencl.Buffer(context, flag,
- numpy.dtype(dtype).itemsize * size)
- except pyopencl.MemoryError as error:
- release_cl_buffers(mem)
- raise MemoryError(error)
-
- return mem
-
-
-def allocate_texture(ctx, shape, hostbuf=None, support_1D=False):
- """
- Allocate an OpenCL image ("texture").
-
- :param ctx: OpenCL context
- :param shape: Shape of the image. Note that pyopencl and OpenCL < 1.2
- do not support 1D images, so 1D images are handled as 2D with one row
- :param support_1D: force the image to be 1D if the shape has only one dim
- """
- if len(shape) == 1 and not(support_1D):
- shape = (1,) + shape
- return pyopencl.Image(
- ctx,
- pyopencl.mem_flags.READ_ONLY | pyopencl.mem_flags.USE_HOST_PTR,
- pyopencl.ImageFormat(
- pyopencl.channel_order.INTENSITY,
- pyopencl.channel_type.FLOAT
- ),
- hostbuf=numpy.zeros(shape[::-1], dtype=numpy.float32)
- )
-
-
-def check_textures_availability(ctx):
- """
- Check whether textures are supported on the current OpenCL context.
-
- :param ctx: OpenCL context
- """
- try:
- dummy_texture = allocate_texture(ctx, (16, 16))
- # Need to further access some attributes (pocl)
- dummy_height = dummy_texture.height
- textures_available = True
- del dummy_texture, dummy_height
- except (pyopencl.RuntimeError, pyopencl.LogicError):
- textures_available = False
- # Nvidia Fermi GPUs (compute capability 2.X) do not support opencl read_imagef
- # There is no way to detect this until a kernel is compiled
- try:
- cc = ctx.devices[0].compute_capability_major_nv
- textures_available &= (cc >= 3)
- except (pyopencl.LogicError, AttributeError): # probably not a Nvidia GPU
- pass
- #
- return textures_available
-
-
-def measure_workgroup_size(device):
- """Measure the actual size of the workgroup
-
- :param device: device or context or 2-tuple with indexes
- :return: the actual measured workgroup size
-
- if device is "all", returns a dict with all devices with their ids as keys.
- """
- if (ocl is None) or (device is None):
- return None
-
- if isinstance(device, tuple) and (len(device) == 2):
- # this is probably a tuple (platformid, deviceid)
- device = ocl.create_context(platformid=device[0], deviceid=device[1])
-
- if device == "all":
- res = {}
- for pid, platform in enumerate(ocl.platforms):
- for did, _devices in enumerate(platform.devices):
- tup = (pid, did)
- res[tup] = measure_workgroup_size(tup)
- else:
- res = _measure_workgroup_size(device)
- return res
-
-
-def query_kernel_info(program, kernel, what="WORK_GROUP_SIZE"):
- """Extract the compile time information from a kernel
-
- :param program: OpenCL program
- :param kernel: kernel or name of the kernel
- :param what: what is the query about ?
- :return: int or 3-int for the workgroup size.
-
- Possible information available are:
- * 'COMPILE_WORK_GROUP_SIZE': Returns the work-group size specified inside the kernel (__attribute__((reqd_work_gr oup_size(X, Y, Z))))
- * 'GLOBAL_WORK_SIZE': maximum global size that can be used to execute a kernel #OCL2.1!
- * 'LOCAL_MEM_SIZE': amount of local memory in bytes being used by the kernel
- * 'PREFERRED_WORK_GROUP_SIZE_MULTIPLE': preferred multiple of workgroup size for launch. This is a performance hint.
- * 'PRIVATE_MEM_SIZE' Returns the minimum amount of private memory, in bytes, used by each workitem in the kernel
- * 'WORK_GROUP_SIZE': maximum work-group size that can be used to execute a kernel on a specific device given by device
-
- Further information on:
- https://www.khronos.org/registry/OpenCL/sdk/1.1/docs/man/xhtml/clGetKernelWorkGroupInfo.html
-
- """
- assert isinstance(program, pyopencl.Program)
- if not isinstance(kernel, pyopencl.Kernel):
- kernel_name = kernel
- assert kernel in (k.function_name for k in program.all_kernels()), "the kernel exists"
- kernel = program.__getattr__(kernel_name)
-
- device = program.devices[0]
- query_wg = getattr(pyopencl.kernel_work_group_info, what)
- return kernel.get_work_group_info(query_wg, device)
-
-
-def kernel_workgroup_size(program, kernel):
- """Extract the compile time maximum workgroup size
-
- :param program: OpenCL program
- :param kernel: kernel or name of the kernel
- :return: the maximum acceptable workgroup size for the given kernel
- """
- return query_kernel_info(program, kernel, what="WORK_GROUP_SIZE")
diff --git a/silx/opencl/test/__init__.py b/silx/opencl/test/__init__.py
deleted file mode 100644
index 928dbaf..0000000
--- a/silx/opencl/test/__init__.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Project: silx
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2012-2016 European Synchrotron Radiation Facility, Grenoble, France
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-__authors__ = ["J. Kieffer"]
-__license__ = "MIT"
-__date__ = "17/05/2021"
-
-import os
-import unittest
-from . import test_addition
-from . import test_medfilt
-from . import test_backprojection
-from . import test_projection
-from . import test_linalg
-from . import test_array_utils
-from ..codec import test as test_codec
-from . import test_image
-from . import test_kahan
-from . import test_doubleword
-from . import test_stats
-from . import test_convolution
-from . import test_sparse
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTests(test_addition.suite())
- test_suite.addTests(test_medfilt.suite())
- test_suite.addTests(test_backprojection.suite())
- test_suite.addTests(test_projection.suite())
- test_suite.addTests(test_linalg.suite())
- test_suite.addTests(test_array_utils.suite())
- test_suite.addTests(test_codec.suite())
- test_suite.addTests(test_image.suite())
- test_suite.addTests(test_kahan.suite())
- test_suite.addTests(test_doubleword.suite())
- test_suite.addTests(test_stats.suite())
- test_suite.addTests(test_convolution.suite())
- test_suite.addTests(test_sparse.suite())
- # Allow to remove sift from the project
- test_base_dir = os.path.dirname(__file__)
- sift_dir = os.path.join(test_base_dir, "..", "sift")
- if os.path.exists(sift_dir):
- from ..sift import test as test_sift
- test_suite.addTests(test_sift.suite())
-
- return test_suite
diff --git a/silx/opencl/test/test_addition.py b/silx/opencl/test/test_addition.py
deleted file mode 100644
index 19dfdf0..0000000
--- a/silx/opencl/test/test_addition.py
+++ /dev/null
@@ -1,154 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Project: Sift implementation in Python + OpenCL
-# https://github.com/silx-kit/silx
-#
-# Permission is hereby granted, free of charge, to any person
-# obtaining a copy of this software and associated documentation
-# files (the "Software"), to deal in the Software without
-# restriction, including without limitation the rights to use,
-# copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following
-# conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
-# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
-# OTHER DEALINGS IN THE SOFTWARE.
-
-"""
-Simple test of an addition
-"""
-
-__authors__ = ["Henri Payno, Jérôme Kieffer"]
-__contact__ = "jerome.kieffer@esrf.eu"
-__license__ = "MIT"
-__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "30/11/2020"
-
-import logging
-import numpy
-
-import unittest
-from ..common import ocl, _measure_workgroup_size, query_kernel_info
-if ocl:
- import pyopencl
- import pyopencl.array
-from ..utils import get_opencl_code
-logger = logging.getLogger(__name__)
-
-
-@unittest.skipUnless(ocl, "PyOpenCl is missing")
-class TestAddition(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
- super(TestAddition, cls).setUpClass()
- if ocl:
- cls.ctx = ocl.create_context()
- if logger.getEffectiveLevel() <= logging.INFO:
- cls.PROFILE = True
- cls.queue = pyopencl.CommandQueue(
- cls.ctx,
- properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
- else:
- cls.PROFILE = False
- cls.queue = pyopencl.CommandQueue(cls.ctx)
- cls.max_valid_wg = 0
-
- @classmethod
- def tearDownClass(cls):
- super(TestAddition, cls).tearDownClass()
- print("Maximum valid workgroup size %s on device %s" % (cls.max_valid_wg, cls.ctx.devices[0]))
- cls.ctx = None
- cls.queue = None
-
- def setUp(self):
- if ocl is None:
- return
- self.shape = 4096
- self.data = numpy.random.random(self.shape).astype(numpy.float32)
- self.d_array_img = pyopencl.array.to_device(self.queue, self.data)
- self.d_array_5 = pyopencl.array.empty_like(self.d_array_img)
- self.d_array_5.fill(-5)
- self.program = pyopencl.Program(self.ctx, get_opencl_code("addition")).build()
-
- def tearDown(self):
- self.img = self.data = None
- self.d_array_img = self.d_array_5 = self.program = None
-
- @unittest.skipUnless(ocl, "pyopencl is missing")
- def test_add(self):
- """
- tests the addition kernel
- """
- maxi = int(round(numpy.log2(self.shape)))
- for i in range(maxi):
- d_array_result = pyopencl.array.empty_like(self.d_array_img)
- wg = 1 << i
- try:
- evt = self.program.addition(self.queue, (self.shape,), (wg,),
- self.d_array_img.data, self.d_array_5.data, d_array_result.data, numpy.int32(self.shape))
- evt.wait()
- except Exception as error:
- max_valid_wg = self.program.addition.get_work_group_info(pyopencl.kernel_work_group_info.WORK_GROUP_SIZE, self.ctx.devices[0])
- msg = "Error %s on WG=%s: %s" % (error, wg, max_valid_wg)
- self.assertLess(max_valid_wg, wg, msg)
- break
- else:
- res = d_array_result.get()
- good = numpy.allclose(res, self.data - 5)
- if good and wg > self.max_valid_wg:
- self.__class__.max_valid_wg = wg
- self.assertTrue(good, "calculation is correct for WG=%s" % wg)
-
- @unittest.skipUnless(ocl, "pyopencl is missing")
- def test_measurement(self):
- """
- tests that all devices are working properly ... lengthy and error prone
- """
- for platform in ocl.platforms:
- for did, device in enumerate(platform.devices):
- meas = _measure_workgroup_size((platform.id, device.id))
- self.assertEqual(meas, device.max_work_group_size,
- "Workgroup size for %s/%s: %s == %s" % (platform, device, meas, device.max_work_group_size))
-
- @unittest.skipUnless(ocl, "pyopencl is missing")
- def test_query(self):
- """
- tests that all devices are working properly ... lengthy and error prone
- """
- for what in ("COMPILE_WORK_GROUP_SIZE",
- "LOCAL_MEM_SIZE",
- "PREFERRED_WORK_GROUP_SIZE_MULTIPLE",
- "PRIVATE_MEM_SIZE",
- "WORK_GROUP_SIZE"):
- logger.info("%s: %s", what, query_kernel_info(program=self.program, kernel="addition", what=what))
-
- # Not all ICD work properly ....
- #self.assertEqual(3, len(query_kernel_info(program=self.program, kernel="addition", what="COMPILE_WORK_GROUP_SIZE")), "3D kernel")
-
- min_wg = query_kernel_info(program=self.program, kernel="addition", what="PREFERRED_WORK_GROUP_SIZE_MULTIPLE")
- max_wg = query_kernel_info(program=self.program, kernel="addition", what="WORK_GROUP_SIZE")
- self.assertEqual(max_wg % min_wg, 0, msg="max_wg is a multiple of min_wg")
-
-
-def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(TestAddition("test_add"))
- # testSuite.addTest(TestAddition("test_measurement"))
- testSuite.addTest(TestAddition("test_query"))
- return testSuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/opencl/test/test_array_utils.py b/silx/opencl/test/test_array_utils.py
deleted file mode 100644
index 833d828..0000000
--- a/silx/opencl/test/test_array_utils.py
+++ /dev/null
@@ -1,161 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of the OpenCL array_utils"""
-
-from __future__ import division, print_function
-
-__authors__ = ["Pierre paleo"]
-__license__ = "MIT"
-__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "14/06/2017"
-
-
-import time
-import logging
-import numpy as np
-import unittest
-try:
- import mako
-except ImportError:
- mako = None
-from ..common import ocl
-if ocl:
- import pyopencl as cl
- import pyopencl.array as parray
- from .. import linalg
-from ..utils import get_opencl_code
-from silx.test.utils import utilstest
-
-logger = logging.getLogger(__name__)
-try:
- from scipy.ndimage.filters import laplace
- _has_scipy = True
-except ImportError:
- _has_scipy = False
-
-
-
-@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
-class TestCpy2d(unittest.TestCase):
-
- def setUp(self):
- if ocl is None:
- return
- self.ctx = ocl.create_context()
- if logger.getEffectiveLevel() <= logging.INFO:
- self.PROFILE = True
- self.queue = cl.CommandQueue(
- self.ctx,
- properties=cl.command_queue_properties.PROFILING_ENABLE)
- else:
- self.PROFILE = False
- self.queue = cl.CommandQueue(self.ctx)
- self.allocate_arrays()
- self.program = cl.Program(self.ctx, get_opencl_code("array_utils")).build()
-
- def allocate_arrays(self):
- """
- Allocate various types of arrays for the tests
- """
- self.prng_state = np.random.get_state()
- # Generate arrays of random shape
- self.shape1 = np.random.randint(20, high=512, size=(2,))
- self.shape2 = np.random.randint(20, high=512, size=(2,))
- self.array1 = np.random.rand(*self.shape1).astype(np.float32)
- self.array2 = np.random.rand(*self.shape2).astype(np.float32)
- self.d_array1 = parray.to_device(self.queue, self.array1)
- self.d_array2 = parray.to_device(self.queue, self.array2)
- # Generate random offsets
- offset1_y = np.random.randint(2, high=min(self.shape1[0], self.shape2[0]) - 10)
- offset1_x = np.random.randint(2, high=min(self.shape1[1], self.shape2[1]) - 10)
- offset2_y = np.random.randint(2, high=min(self.shape1[0], self.shape2[0]) - 10)
- offset2_x = np.random.randint(2, high=min(self.shape1[1], self.shape2[1]) - 10)
- self.offset1 = (offset1_y, offset1_x)
- self.offset2 = (offset2_y, offset2_x)
- # Compute the size of the rectangle to transfer
- size_y = np.random.randint(2, high=min(self.shape1[0], self.shape2[0]) - max(offset1_y, offset2_y) + 1)
- size_x = np.random.randint(2, high=min(self.shape1[1], self.shape2[1]) - max(offset1_x, offset2_x) + 1)
- self.transfer_shape = (size_y, size_x)
-
- def tearDown(self):
- self.array1 = None
- self.array2 = None
- self.d_array1.data.release()
- self.d_array2.data.release()
- self.d_array1 = None
- self.d_array2 = None
- self.ctx = None
- self.queue = None
-
- def compare(self, result, reference):
- errmax = np.max(np.abs(result - reference))
- logger.info("Max error = %e" % (errmax))
- self.assertTrue(errmax == 0, str("Max error is too high"))#. PRNG state was %s" % str(self.prng_state)))
-
- @unittest.skipUnless(ocl and mako, "pyopencl is missing")
- def test_cpy2d(self):
- """
- Test rectangular transfer of self.d_array1 to self.d_array2
- """
- # Reference
- o1 = self.offset1
- o2 = self.offset2
- T = self.transfer_shape
- logger.info("""Testing D->D rectangular copy with (N1_y, N1_x) = %s,
- (N2_y, N2_x) = %s:
- array2[%d:%d, %d:%d] = array1[%d:%d, %d:%d]""" %
- (
- str(self.shape1), str(self.shape2),
- o2[0], o2[0] + T[0],
- o2[1], o2[1] + T[1],
- o1[0], o1[0] + T[0],
- o1[1], o1[1] + T[1]
- )
- )
- self.array2[o2[0]:o2[0] + T[0], o2[1]:o2[1] + T[1]] = self.array1[o1[0]:o1[0] + T[0], o1[1]:o1[1] + T[1]]
- kernel_args = (
- self.d_array2.data,
- self.d_array1.data,
- np.int32(self.shape2[1]),
- np.int32(self.shape1[1]),
- np.int32(self.offset2[::-1]),
- np.int32(self.offset1[::-1]),
- np.int32(self.transfer_shape[::-1])
- )
- wg = None
- ndrange = self.transfer_shape[::-1]
- self.program.cpy2d(self.queue, ndrange, wg, *kernel_args)
- res = self.d_array2.get()
- self.compare(res, self.array2)
-
-
-def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(TestCpy2d("test_cpy2d"))
- return testSuite
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/opencl/test/test_backprojection.py b/silx/opencl/test/test_backprojection.py
deleted file mode 100644
index 9dfdd3a..0000000
--- a/silx/opencl/test/test_backprojection.py
+++ /dev/null
@@ -1,231 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of the filtered backprojection module"""
-
-from __future__ import division, print_function
-
-__authors__ = ["Pierre paleo"]
-__license__ = "MIT"
-__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "19/01/2018"
-
-
-import time
-import logging
-import numpy as np
-import unittest
-from math import pi
-try:
- import mako
-except ImportError:
- mako = None
-from ..common import ocl
-if ocl:
- from .. import backprojection
- from ...image.tomography import compute_fourier_filter
-from silx.test.utils import utilstest
-
-logger = logging.getLogger(__name__)
-
-
-def generate_coords(img_shp, center=None):
- """
- Return two 2D arrays containing the indexes of an image.
- The zero is at the center of the image.
- """
- l_r, l_c = float(img_shp[0]), float(img_shp[1])
- R, C = np.mgrid[:l_r, :l_c]
- if center is None:
- center0, center1 = l_r / 2., l_c / 2.
- else:
- center0, center1 = center
- R = R + 0.5 - center0
- C = C + 0.5 - center1
- return R, C
-
-
-def clip_circle(img, center=None, radius=None):
- """
- Puts zeros outside the inscribed circle of the image support.
- """
- R, C = generate_coords(img.shape, center)
- M = R * R + C * C
- res = np.zeros_like(img)
- if radius is None:
- radius = img.shape[0] / 2. - 1
- mask = M < radius * radius
- res[mask] = img[mask]
- return res
-
-
-@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
-class TestFBP(unittest.TestCase):
-
- def setUp(self):
- if ocl is None:
- return
- self.getfiles()
- self.fbp = backprojection.Backprojection(self.sino.shape, profile=True)
- if self.fbp.compiletime_workgroup_size < 16 * 16:
- self.skipTest("Current implementation of OpenCL backprojection is "
- "not supported on this platform yet")
- # Astra does not use the same backprojector implementation.
- # Therefore, we cannot expect results to be the "same" (up to float32
- # numerical error)
- self.tol = 5e-2
- if not(self.fbp._use_textures) or self.fbp.device.type == "CPU":
- # Precision is less when using CPU
- # (either CPU textures or "manual" linear interpolation)
- self.tol *= 2
-
- def tearDown(self):
- self.sino = None
- # self.fbp.log_profile()
- self.fbp = None
-
- def getfiles(self):
- # load sinogram of 512x512 MRI phantom
- self.sino = np.load(utilstest.getfile("sino500.npz"))["data"]
- # load reconstruction made with ASTRA FBP (with filter designed in spatial domain)
- self.reference_rec = np.load(utilstest.getfile("rec_astra_500.npz"))["data"]
-
- def measure(self):
- "Common measurement of timings"
- t1 = time.time()
- try:
- result = self.fbp.filtered_backprojection(self.sino)
- except RuntimeError as msg:
- logger.error(msg)
- return
- t2 = time.time()
- return t2 - t1, result
-
- def compare(self, res):
- """
- Compare a result with the reference reconstruction.
- Only the valid reconstruction zone (inscribed circle) is taken into
- account
- """
- res_clipped = clip_circle(res)
- ref_clipped = clip_circle(self.reference_rec)
- delta = abs(res_clipped - ref_clipped)
- bad = delta > 1
- logger.debug("Absolute difference: %s with %s outlier pixels out of %s"
- "", delta.max(), bad.sum(), np.prod(bad.shape))
- return delta.max()
-
- @unittest.skipUnless(ocl and mako, "pyopencl is missing")
- def test_fbp(self):
- """
- tests FBP
- """
- # Test single reconstruction
- # --------------------------
- t, res = self.measure()
- if t is None:
- logger.info("test_fp: skipped")
- else:
- logger.info("test_backproj: time = %.3fs" % t)
- err = self.compare(res)
- msg = str("Max error = %e" % err)
- logger.info(msg)
- self.assertTrue(err < self.tol, "Max error is too high")
-
- # Test multiple reconstructions
- # -----------------------------
- res0 = np.copy(res)
- for i in range(10):
- res = self.fbp.filtered_backprojection(self.sino)
- errmax = np.max(np.abs(res - res0))
- self.assertTrue(errmax < 1.e-6, "Max error is too high")
-
- @unittest.skipUnless(ocl and mako, "pyopencl is missing")
- def test_fbp_filters(self):
- """
- Test the different available filters of silx FBP.
- """
- avail_filters = [
- "ramlak", "shepp-logan", "cosine", "hamming",
- "hann"
- ]
- # Create a Dirac delta function at a single angle view.
- # As the filters are radially invarant:
- # - backprojection yields an image where each line is a Dirac.
- # - FBP yields an image where each line is the spatial filter
- # One can simply filter "dirac" without backprojecting it, but this
- # test will also ensure that backprojection behaves well.
- dirac = np.zeros_like(self.sino)
- na, dw = dirac.shape
- dirac[0, dw//2] = na / pi * 2
-
- for filter_name in avail_filters:
- B = backprojection.Backprojection(dirac.shape, filter_name=filter_name)
- r = B(dirac)
- # Check that radial invariance is kept
- std0 = np.max(np.abs(np.std(r, axis=0)))
- self.assertTrue(
- std0 < 5.e-6,
- "Something wrong with FBP(filter=%s)" % filter_name
- )
- # Check that the filter is retrieved
- r_f = np.fft.fft(np.fft.fftshift(r[0])).real / 2. # filter factor
- ref_filter_f = compute_fourier_filter(dw, filter_name)
- errmax = np.max(np.abs(r_f - ref_filter_f))
- logger.info("FBP filter %s: max error=%e" % (filter_name, errmax))
- self.assertTrue(
- errmax < 1.e-3,
- "Something wrong with FBP(filter=%s)" % filter_name
- )
-
- @unittest.skipUnless(ocl and mako, "pyopencl is missing")
- def test_fbp_oddsize(self):
- # Generate a 513-sinogram.
- # The padded width will be nextpow(513*2).
- # silx [0.10, 0.10.1] will give 1029, which makes R2C transform fail.
- sino = np.pad(self.sino, ((0, 0), (1, 0)), mode='edge')
- B = backprojection.Backprojection(sino.shape, axis_position=self.fbp.axis_pos+1)
- res = B(sino)
- # Compare with self.reference_rec. Tolerance is high as backprojector
- # is not fully shift-invariant.
- errmax = np.max(np.abs(clip_circle(res[1:, 1:] - self.reference_rec)))
- self.assertLess(
- errmax, 1.e-1,
- "Something wrong with FBP on odd-sized sinogram"
- )
-
-
-
-
-def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(TestFBP("test_fbp"))
- testSuite.addTest(TestFBP("test_fbp_filters"))
- testSuite.addTest(TestFBP("test_fbp_oddsize"))
- return testSuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/opencl/test/test_convolution.py b/silx/opencl/test/test_convolution.py
deleted file mode 100644
index 7bceb0d..0000000
--- a/silx/opencl/test/test_convolution.py
+++ /dev/null
@@ -1,265 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-"""
-Test of the Convolution class.
-"""
-
-from __future__ import division, print_function
-
-__authors__ = ["Pierre Paleo"]
-__contact__ = "pierre.paleo@esrf.fr"
-__license__ = "MIT"
-__copyright__ = "2019 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "01/08/2019"
-
-import logging
-from itertools import product
-import numpy as np
-from silx.utils.testutils import parameterize
-from silx.image.utils import gaussian_kernel
-
-try:
- from scipy.ndimage import convolve, convolve1d
- from scipy.misc import ascent
-
- scipy_convolve = convolve
- scipy_convolve1d = convolve1d
-except ImportError:
- scipy_convolve = None
-import unittest
-from ..common import ocl, check_textures_availability
-
-if ocl:
- import pyopencl as cl
- import pyopencl.array as parray
- from silx.opencl.convolution import Convolution
-logger = logging.getLogger(__name__)
-
-
-@unittest.skipUnless(ocl and scipy_convolve, "PyOpenCl/scipy is missing")
-class TestConvolution(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- super(TestConvolution, cls).setUpClass()
- cls.image = np.ascontiguousarray(ascent()[:, :511], dtype="f")
- cls.data1d = cls.image[0]
- cls.data2d = cls.image
- cls.data3d = np.tile(cls.image[224:-224, 224:-224], (62, 1, 1))
- cls.kernel1d = gaussian_kernel(1.0)
- cls.kernel2d = np.outer(cls.kernel1d, cls.kernel1d)
- cls.kernel3d = np.multiply.outer(cls.kernel2d, cls.kernel1d)
- cls.ctx = ocl.create_context()
- cls.tol = {
- "1D": 1e-4,
- "2D": 1e-3,
- "3D": 1e-3,
- }
-
- @classmethod
- def tearDownClass(cls):
- cls.data1d = cls.data2d = cls.data3d = cls.image = None
- cls.kernel1d = cls.kernel2d = cls.kernel3d = None
-
- @staticmethod
- def compare(arr1, arr2):
- return np.max(np.abs(arr1 - arr2))
-
- @staticmethod
- def print_err(conv):
- errmsg = str(
- """
- Something wrong with %s
- mode=%s, texture=%s
- """
- % (conv.use_case_desc, conv.mode, conv.use_textures)
- )
- return errmsg
-
- def __init__(self, methodName="runTest", param=None):
- unittest.TestCase.__init__(self, methodName)
- self.param = param
- self.mode = param["boundary_handling"]
- logger.debug(
- """
- Testing convolution with boundary_handling=%s,
- use_textures=%s, input_device=%s, output_device=%s
- """
- % (
- self.mode,
- param["use_textures"],
- param["input_on_device"],
- param["output_on_device"],
- )
- )
-
- def instantiate_convol(self, shape, kernel, axes=None):
- if self.mode == "constant":
- if not (self.param["use_textures"]) or (
- self.param["use_textures"]
- and not (check_textures_availability(self.ctx))
- ):
- self.skipTest("mode=constant not implemented without textures")
- C = Convolution(
- shape,
- kernel,
- mode=self.mode,
- ctx=self.ctx,
- axes=axes,
- extra_options={"dont_use_textures": not (self.param["use_textures"])},
- )
- return C
-
- def get_data_and_kernel(self, test_name):
- dims = {
- "test_1D": (1, 1),
- "test_separable_2D": (2, 1),
- "test_separable_3D": (3, 1),
- "test_nonseparable_2D": (2, 2),
- "test_nonseparable_3D": (3, 3),
- }
- dim_data = {1: self.data1d, 2: self.data2d, 3: self.data3d}
- dim_kernel = {
- 1: self.kernel1d,
- 2: self.kernel2d,
- 3: self.kernel3d,
- }
- dd, kd = dims[test_name]
- return dim_data[dd], dim_kernel[kd]
-
- def get_reference_function(self, test_name):
- ref_func = {
- "test_1D": lambda x, y: scipy_convolve1d(x, y, mode=self.mode),
- "test_separable_2D": lambda x, y: scipy_convolve1d(
- scipy_convolve1d(x, y, mode=self.mode, axis=1),
- y,
- mode=self.mode,
- axis=0,
- ),
- "test_separable_3D": lambda x, y: scipy_convolve1d(
- scipy_convolve1d(
- scipy_convolve1d(x, y, mode=self.mode, axis=2),
- y,
- mode=self.mode,
- axis=1,
- ),
- y,
- mode=self.mode,
- axis=0,
- ),
- "test_nonseparable_2D": lambda x, y: scipy_convolve(x, y, mode=self.mode),
- "test_nonseparable_3D": lambda x, y: scipy_convolve(x, y, mode=self.mode),
- }
- return ref_func[test_name]
-
- def template_test(self, test_name):
- data, kernel = self.get_data_and_kernel(test_name)
- conv = self.instantiate_convol(data.shape, kernel)
- if self.param["input_on_device"]:
- data_ref = parray.to_device(conv.queue, data)
- else:
- data_ref = data
- if self.param["output_on_device"]:
- d_res = parray.empty_like(conv.data_out)
- d_res.fill(0)
- res = d_res
- else:
- res = None
- res = conv(data_ref, output=res)
- if self.param["output_on_device"]:
- res = res.get()
- ref_func = self.get_reference_function(test_name)
- ref = ref_func(data, kernel)
- metric = self.compare(res, ref)
- logger.info("%s: max error = %.2e" % (test_name, metric))
- tol = self.tol[str("%dD" % kernel.ndim)]
- self.assertLess(metric, tol, self.print_err(conv))
-
- def test_1D(self):
- self.template_test("test_1D")
-
- def test_separable_2D(self):
- self.template_test("test_separable_2D")
-
- def test_separable_3D(self):
- self.template_test("test_separable_3D")
-
- def test_nonseparable_2D(self):
- self.template_test("test_nonseparable_2D")
-
- def test_nonseparable_3D(self):
- self.template_test("test_nonseparable_3D")
-
- def test_batched_2D(self):
- """
- Test batched (nonseparable) 2D convolution on 3D data.
- In this test: batch along "z" (axis 0)
- """
- data = self.data3d
- kernel = self.kernel2d
- conv = self.instantiate_convol(data.shape, kernel, axes=(0,))
- res = conv(data) # 3D
- ref = scipy_convolve(data[0], kernel, mode=self.mode) # 2D
-
- std = np.std(res, axis=0)
- std_max = np.max(np.abs(std))
- self.assertLess(std_max, self.tol["2D"], self.print_err(conv))
- metric = self.compare(res[0], ref)
- logger.info("test_nonseparable_3D: max error = %.2e" % metric)
- self.assertLess(metric, self.tol["2D"], self.print_err(conv))
-
-
-def test_convolution():
- boundary_handling_ = ["reflect", "nearest", "wrap", "constant"]
- use_textures_ = [True, False]
- input_on_device_ = [True, False]
- output_on_device_ = [True, False]
- testSuite = unittest.TestSuite()
-
- param_vals = list(
- product(boundary_handling_, use_textures_, input_on_device_, output_on_device_)
- )
- for boundary_handling, use_textures, input_dev, output_dev in param_vals:
- testcase = parameterize(
- TestConvolution,
- param={
- "boundary_handling": boundary_handling,
- "input_on_device": input_dev,
- "output_on_device": output_dev,
- "use_textures": use_textures,
- },
- )
- testSuite.addTest(testcase)
- return testSuite
-
-
-def suite():
- testSuite = test_convolution()
- return testSuite
-
-
-if __name__ == "__main__":
- unittest.main(defaultTest="suite")
diff --git a/silx/opencl/test/test_doubleword.py b/silx/opencl/test/test_doubleword.py
deleted file mode 100644
index ca947e0..0000000
--- a/silx/opencl/test/test_doubleword.py
+++ /dev/null
@@ -1,258 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-#
-# Project: The silx project
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2021-2021 European Synchrotron Radiation Facility, Grenoble, France
-#
-# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-"test suite for OpenCL code"
-
-__author__ = "Jérôme Kieffer"
-__contact__ = "Jerome.Kieffer@ESRF.eu"
-__license__ = "MIT"
-__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "31/05/2021"
-
-import unittest
-import numpy
-import logging
-import platform
-
-logger = logging.getLogger(__name__)
-try:
- import pyopencl
-except ImportError as error:
- logger.warning("OpenCL module (pyopencl) is not present, skip tests. %s.", error)
- pyopencl = None
-
-from .. import ocl
-if ocl is not None:
- from ..utils import read_cl_file
- from .. import pyopencl
- import pyopencl.array
- from pyopencl.elementwise import ElementwiseKernel
-from ...test.utils import test_options
-
-EPS32 = numpy.finfo("float32").eps
-EPS64 = numpy.finfo("float64").eps
-
-
-class TestDoubleWord(unittest.TestCase):
- """
- Test the kernels for compensated math in OpenCL
- """
-
- @classmethod
- def setUpClass(cls):
- if not test_options.WITH_OPENCL_TEST:
- raise unittest.SkipTest("User request to skip OpenCL tests")
- if pyopencl is None or ocl is None:
- raise unittest.SkipTest("OpenCL module (pyopencl) is not present or no device available")
-
- cls.ctx = ocl.create_context(devicetype="GPU")
- cls.queue = pyopencl.CommandQueue(cls.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
-
- # this is running 32 bits OpenCL woth POCL
- if (platform.machine() in ("i386", "i686", "x86_64") and (tuple.__itemsize__ == 4) and
- cls.ctx.devices[0].platform.name == 'Portable Computing Language'):
- cls.args = "-DX87_VOLATILE=volatile"
- else:
- cls.args = ""
- size = 1024
- cls.a = 1.0 + numpy.random.random(size)
- cls.b = 1.0 + numpy.random.random(size)
- cls.ah = cls.a.astype(numpy.float32)
- cls.bh = cls.b.astype(numpy.float32)
- cls.al = (cls.a - cls.ah).astype(numpy.float32)
- cls.bl = (cls.b - cls.bh).astype(numpy.float32)
- cls.doubleword = read_cl_file("doubleword.cl")
-
- @classmethod
- def tearDownClass(cls):
- cls.queue = None
- cls.ctx = None
- cls.a = cls.al = cls.ah = None
- cls.b = cls.bl = cls.bh = None
- cls.doubleword = None
-
- def test_fast_sum2(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *a, float *b, float *res_h, float *res_l",
- "float2 tmp = fast_fp_plus_fp(a[i], b[i]); res_h[i] = tmp.s0; res_l[i] = tmp.s1",
- preamble=self.doubleword)
- a_g = pyopencl.array.to_device(self.queue, self.ah)
- b_g = pyopencl.array.to_device(self.queue, self.bl)
- res_l = pyopencl.array.empty_like(a_g)
- res_h = pyopencl.array.empty_like(a_g)
- test_kernel(a_g, b_g, res_h, res_l)
- self.assertEqual(abs(self.ah + self.bl - res_h.get()).max(), 0, "Major matches")
- self.assertGreater(abs(self.ah.astype(numpy.float64) + self.bl - res_h.get()).max(), 0, "Exact mismatches")
- self.assertEqual(abs(self.ah.astype(numpy.float64) + self.bl - (res_h.get().astype(numpy.float64) + res_l.get())).max(), 0, "Exact matches")
-
- def test_sum2(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *a, float *b, float *res_h, float *res_l",
- "float2 tmp = fp_plus_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
- a_g = pyopencl.array.to_device(self.queue, self.ah)
- b_g = pyopencl.array.to_device(self.queue, self.bh)
- res_l = pyopencl.array.empty_like(a_g)
- res_h = pyopencl.array.empty_like(a_g)
- test_kernel(a_g, b_g, res_h, res_l)
- self.assertEqual(abs(self.ah + self.bh - res_h.get()).max(), 0, "Major matches")
- self.assertGreater(abs(self.ah.astype(numpy.float64) + self.bh - res_h.get()).max(), 0, "Exact mismatches")
- self.assertEqual(abs(self.ah.astype(numpy.float64) + self.bh - (res_h.get().astype(numpy.float64) + res_l.get())).max(), 0, "Exact matches")
-
- def test_prod2(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *a, float *b, float *res_h, float *res_l",
- "float2 tmp = fp_times_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
- a_g = pyopencl.array.to_device(self.queue, self.ah)
- b_g = pyopencl.array.to_device(self.queue, self.bh)
- res_l = pyopencl.array.empty_like(a_g)
- res_h = pyopencl.array.empty_like(a_g)
- test_kernel(a_g, b_g, res_h, res_l)
- res_m = res_h.get()
- res = res_h.get().astype(numpy.float64) + res_l.get()
- self.assertEqual(abs(self.ah * self.bh - res_m).max(), 0, "Major matches")
- self.assertGreater(abs(self.ah.astype(numpy.float64) * self.bh - res_m).max(), 0, "Exact mismatches")
- self.assertEqual(abs(self.ah.astype(numpy.float64) * self.bh - res).max(), 0, "Exact matches")
-
- def test_dw_plus_fp(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *b, float *res_h, float *res_l",
- "float2 tmp = dw_plus_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
- ah_g = pyopencl.array.to_device(self.queue, self.ah)
- al_g = pyopencl.array.to_device(self.queue, self.al)
- b_g = pyopencl.array.to_device(self.queue, self.bh)
- res_l = pyopencl.array.empty_like(b_g)
- res_h = pyopencl.array.empty_like(b_g)
- test_kernel(ah_g, al_g, b_g, res_h, res_l)
- res_m = res_h.get()
- res = res_h.get().astype(numpy.float64) + res_l.get()
- self.assertLess(abs(self.a + self.bh - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a + self.bh - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.ah.astype(numpy.float64) + self.al + self.bh - res).max(), 2 * EPS32 ** 2, "Exact matches")
-
- def test_dw_plus_dw(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
- "float2 tmp = dw_plus_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
- ah_g = pyopencl.array.to_device(self.queue, self.ah)
- al_g = pyopencl.array.to_device(self.queue, self.al)
- bh_g = pyopencl.array.to_device(self.queue, self.bh)
- bl_g = pyopencl.array.to_device(self.queue, self.bl)
- res_l = pyopencl.array.empty_like(bh_g)
- res_h = pyopencl.array.empty_like(bh_g)
- test_kernel(ah_g, al_g, bh_g, bl_g, res_h, res_l)
- res_m = res_h.get()
- res = res_h.get().astype(numpy.float64) + res_l.get()
- self.assertLess(abs(self.a + self.b - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a + self.b - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a + self.b - res).max(), 3 * EPS32 ** 2, "Exact matches")
-
- def test_dw_times_fp(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *b, float *res_h, float *res_l",
- "float2 tmp = dw_times_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
- ah_g = pyopencl.array.to_device(self.queue, self.ah)
- al_g = pyopencl.array.to_device(self.queue, self.al)
- b_g = pyopencl.array.to_device(self.queue, self.bh)
- res_l = pyopencl.array.empty_like(b_g)
- res_h = pyopencl.array.empty_like(b_g)
- test_kernel(ah_g, al_g, b_g, res_h, res_l)
- res_m = res_h.get()
- res = res_h.get().astype(numpy.float64) + res_l.get()
- self.assertLess(abs(self.a * self.bh - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a * self.bh - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a * self.bh - res).max(), 2 * EPS32 ** 2, "Exact matches")
-
- def test_dw_times_dw(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
- "float2 tmp = dw_times_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
- ah_g = pyopencl.array.to_device(self.queue, self.ah)
- al_g = pyopencl.array.to_device(self.queue, self.al)
- bh_g = pyopencl.array.to_device(self.queue, self.bh)
- bl_g = pyopencl.array.to_device(self.queue, self.bl)
- res_l = pyopencl.array.empty_like(bh_g)
- res_h = pyopencl.array.empty_like(bh_g)
- test_kernel(ah_g, al_g, bh_g, bl_g, res_h, res_l)
- res_m = res_h.get()
- res = res_h.get().astype(numpy.float64) + res_l.get()
- self.assertLess(abs(self.a * self.b - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a * self.b - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a * self.b - res).max(), 5 * EPS32 ** 2, "Exact matches")
-
- def test_dw_div_fp(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *b, float *res_h, float *res_l",
- "float2 tmp = dw_div_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
- ah_g = pyopencl.array.to_device(self.queue, self.ah)
- al_g = pyopencl.array.to_device(self.queue, self.al)
- b_g = pyopencl.array.to_device(self.queue, self.bh)
- res_l = pyopencl.array.empty_like(b_g)
- res_h = pyopencl.array.empty_like(b_g)
- test_kernel(ah_g, al_g, b_g, res_h, res_l)
- res_m = res_h.get()
- res = res_h.get().astype(numpy.float64) + res_l.get()
- self.assertLess(abs(self.a / self.bh - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a / self.bh - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a / self.bh - res).max(), 3 * EPS32 ** 2, "Exact matches")
-
- def test_dw_div_dw(self):
- test_kernel = ElementwiseKernel(self.ctx,
- "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
- "float2 tmp = dw_div_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
- preamble=self.doubleword)
- ah_g = pyopencl.array.to_device(self.queue, self.ah)
- al_g = pyopencl.array.to_device(self.queue, self.al)
- bh_g = pyopencl.array.to_device(self.queue, self.bh)
- bl_g = pyopencl.array.to_device(self.queue, self.bl)
- res_l = pyopencl.array.empty_like(bh_g)
- res_h = pyopencl.array.empty_like(bh_g)
- test_kernel(ah_g, al_g, bh_g, bl_g, res_h, res_l)
- res_m = res_h.get()
- res = res_h.get().astype(numpy.float64) + res_l.get()
- self.assertLess(abs(self.a / self.b - res_m).max(), EPS32, "Major matches")
- self.assertGreater(abs(self.a / self.b - res_m).max(), EPS64, "Exact mismatches")
- self.assertLess(abs(self.a / self.b - res).max(), 6 * EPS32 ** 2, "Exact matches")
-
-
-def suite():
- testsuite = unittest.TestSuite()
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- testsuite.addTest(loader(TestDoubleWord))
- return testsuite
-
-
-if __name__ == '__main__':
- runner = unittest.TextTestRunner()
- runner.run(suite())
diff --git a/silx/opencl/test/test_image.py b/silx/opencl/test/test_image.py
deleted file mode 100644
index d73a854..0000000
--- a/silx/opencl/test/test_image.py
+++ /dev/null
@@ -1,137 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Project: image manipulation in OpenCL
-# https://github.com/silx-kit/silx
-#
-# Permission is hereby granted, free of charge, to any person
-# obtaining a copy of this software and associated documentation
-# files (the "Software"), to deal in the Software without
-# restriction, including without limitation the rights to use,
-# copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following
-# conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
-# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
-# OTHER DEALINGS IN THE SOFTWARE.
-
-"""
-Simple test of image manipulation
-"""
-
-from __future__ import division, print_function
-
-__authors__ = ["Jérôme Kieffer"]
-__contact__ = "jerome.kieffer@esrf.eu"
-__license__ = "MIT"
-__copyright__ = "2017 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "13/02/2018"
-
-import logging
-import numpy
-
-import unittest
-from ..common import ocl, _measure_workgroup_size
-if ocl:
- import pyopencl
- import pyopencl.array
-from ...test.utils import utilstest
-from ..image import ImageProcessing
-logger = logging.getLogger(__name__)
-try:
- from PIL import Image
-except ImportError:
- Image = None
-
-
-@unittest.skipUnless(ocl and Image, "PyOpenCl/Image is missing")
-class TestImage(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
- super(TestImage, cls).setUpClass()
- if ocl:
- cls.ctx = ocl.create_context()
- cls.lena = utilstest.getfile("lena.png")
- cls.data = numpy.asarray(Image.open(cls.lena))
- cls.ip = ImageProcessing(ctx=cls.ctx, template=cls.data, profile=True)
-
- @classmethod
- def tearDownClass(cls):
- super(TestImage, cls).tearDownClass()
- cls.ctx = None
- cls.lena = None
- cls.data = None
- if logger.level <= logging.INFO:
- logger.warning("\n".join(cls.ip.log_profile()))
- cls.ip = None
-
- def setUp(self):
- if ocl is None:
- return
- self.data = numpy.asarray(Image.open(self.lena))
-
- def tearDown(self):
- self.img = self.data = None
-
- @unittest.skipUnless(ocl, "pyopencl is missing")
- def test_cast(self):
- """
- tests the cast kernel
- """
- res = self.ip.to_float(self.data)
- self.assertEqual(res.shape, self.data.shape, "shape")
- self.assertEqual(res.dtype, numpy.float32, "dtype")
- self.assertEqual(abs(res - self.data).max(), 0, "content")
-
- @unittest.skipUnless(ocl, "pyopencl is missing")
- def test_normalize(self):
- """
- tests that all devices are working properly ...
- """
- tmp = pyopencl.array.empty(self.ip.ctx, self.data.shape, "float32")
- res = self.ip.to_float(self.data, out=tmp)
- res2 = self.ip.normalize(tmp, -100, 100, copy=False)
- norm = (self.data.astype(numpy.float32) - self.data.min()) / (self.data.max() - self.data.min())
- ref2 = 200 * norm - 100
- self.assertLess(abs(res2 - ref2).max(), 3e-5, "content")
-
- @unittest.skipUnless(ocl, "pyopencl is missing")
- def test_histogram(self):
- """
- Test on a greyscaled image ... of Lena :)
- """
- lena_bw = (0.2126 * self.data[:, :, 0] +
- 0.7152 * self.data[:, :, 1] +
- 0.0722 * self.data[:, :, 2]).astype("int32")
- ref = numpy.histogram(lena_bw, 255)
- ip = ImageProcessing(ctx=self.ctx, template=lena_bw, profile=True)
- res = ip.histogram(lena_bw, 255)
- ip.log_profile()
- delta = (ref[0] - res[0])
- deltap = (ref[1] - res[1])
- self.assertEqual(delta.sum(), 0, "errors are self-compensated")
- self.assertLessEqual(abs(delta).max(), 1, "errors are small")
- self.assertLessEqual(abs(deltap).max(), 3e-5, "errors on position are small: %s" % (abs(deltap).max()))
-
-
-def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(TestImage("test_cast"))
- testSuite.addTest(TestImage("test_normalize"))
- testSuite.addTest(TestImage("test_histogram"))
- return testSuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/opencl/test/test_kahan.py b/silx/opencl/test/test_kahan.py
deleted file mode 100644
index 6ea599b..0000000
--- a/silx/opencl/test/test_kahan.py
+++ /dev/null
@@ -1,269 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-#
-# Project: OpenCL numerical library
-# https://github.com/silx-kit/silx
-#
-# Copyright (C) 2015-2021 European Synchrotron Radiation Facility, Grenoble, France
-#
-# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-
-"test suite for OpenCL code"
-
-__author__ = "Jérôme Kieffer"
-__contact__ = "Jerome.Kieffer@ESRF.eu"
-__license__ = "MIT"
-__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "17/05/2021"
-
-
-import unittest
-import numpy
-import logging
-import platform
-
-logger = logging.getLogger(__name__)
-try:
- import pyopencl
-except ImportError as error:
- logger.warning("OpenCL module (pyopencl) is not present, skip tests. %s.", error)
- pyopencl = None
-
-from .. import ocl
-if ocl is not None:
- from ..utils import read_cl_file
- from .. import pyopencl
- import pyopencl.array
-from ...test.utils import test_options
-
-
-class TestKahan(unittest.TestCase):
- """
- Test the kernels for compensated math in OpenCL
- """
-
- @classmethod
- def setUpClass(cls):
- if not test_options.WITH_OPENCL_TEST:
- raise unittest.SkipTest("User request to skip OpenCL tests")
- if pyopencl is None or ocl is None:
- raise unittest.SkipTest("OpenCL module (pyopencl) is not present or no device available")
-
- cls.ctx = ocl.create_context(devicetype="GPU")
- cls.queue = pyopencl.CommandQueue(cls.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
-
- # this is running 32 bits OpenCL woth POCL
- if (platform.machine() in ("i386", "i686", "x86_64") and (tuple.__itemsize__ == 4) and
- cls.ctx.devices[0].platform.name == 'Portable Computing Language'):
- cls.args = "-DX87_VOLATILE=volatile"
- else:
- cls.args = ""
-
- @classmethod
- def tearDownClass(cls):
- cls.queue = None
- cls.ctx = None
-
- @staticmethod
- def dummy_sum(ary, dtype=None):
- "perform the actual sum in a dummy way "
- if dtype is None:
- dtype = ary.dtype.type
- sum_ = dtype(0)
- for i in ary:
- sum_ += i
- return sum_
-
- def test_kahan(self):
- # simple test
- N = 26
- data = (1 << (N - 1 - numpy.arange(N))).astype(numpy.float32)
-
- ref64 = numpy.sum(data, dtype=numpy.float64)
- ref32 = self.dummy_sum(data)
- if (ref64 == ref32):
- logger.warning("Kahan: invalid tests as float32 provides the same result as float64")
- # Dummy kernel to evaluate
- src = """
- kernel void summation(global float* data,
- int size,
- global float* result)
- {
- float2 acc = (float2)(0.0f, 0.0f);
- for (int i=0; i<size; i++)
- {
- acc = kahan_sum(acc, data[i]);
- }
- result[0] = acc.s0;
- result[1] = acc.s1;
- }
- """
- prg = pyopencl.Program(self.ctx, read_cl_file("kahan.cl") + src).build(self.args)
- ones_d = pyopencl.array.to_device(self.queue, data)
- res_d = pyopencl.array.empty(self.queue, 2, numpy.float32)
- res_d.fill(0)
- evt = prg.summation(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
- evt.wait()
- res = res_d.get().sum(dtype=numpy.float64)
- self.assertEqual(ref64, res, "test_kahan")
-
- def test_dot16(self):
- # simple test
- N = 16
- data = (1 << (N - 1 - numpy.arange(N))).astype(numpy.float32)
-
- ref64 = numpy.dot(data.astype(numpy.float64), data.astype(numpy.float64))
- ref32 = numpy.dot(data, data)
- if (ref64 == ref32):
- logger.warning("dot16: invalid tests as float32 provides the same result as float64")
- # Dummy kernel to evaluate
- src = """
- kernel void test_dot16(global float* data,
- int size,
- global float* result)
- {
- float2 acc = (float2)(0.0f, 0.0f);
- float16 data16 = (float16) (data[0],data[1],data[2],data[3],data[4],
- data[5],data[6],data[7],data[8],data[9],
- data[10],data[11],data[12],data[13],data[14],data[15]);
- acc = comp_dot16(data16, data16);
- result[0] = acc.s0;
- result[1] = acc.s1;
- }
-
- kernel void test_dot8(global float* data,
- int size,
- global float* result)
- {
- float2 acc = (float2)(0.0f, 0.0f);
- float8 data0 = (float8) (data[0],data[2],data[4],data[6],data[8],data[10],data[12],data[14]);
- float8 data1 = (float8) (data[1],data[3],data[5],data[7],data[9],data[11],data[13],data[15]);
- acc = comp_dot8(data0, data1);
- result[0] = acc.s0;
- result[1] = acc.s1;
- }
-
- kernel void test_dot4(global float* data,
- int size,
- global float* result)
- {
- float2 acc = (float2)(0.0f, 0.0f);
- float4 data0 = (float4) (data[0],data[4],data[8],data[12]);
- float4 data1 = (float4) (data[3],data[7],data[11],data[15]);
- acc = comp_dot4(data0, data1);
- result[0] = acc.s0;
- result[1] = acc.s1;
- }
-
- kernel void test_dot3(global float* data,
- int size,
- global float* result)
- {
- float2 acc = (float2)(0.0f, 0.0f);
- float3 data0 = (float3) (data[0],data[4],data[12]);
- float3 data1 = (float3) (data[3],data[11],data[15]);
- acc = comp_dot3(data0, data1);
- result[0] = acc.s0;
- result[1] = acc.s1;
- }
-
- kernel void test_dot2(global float* data,
- int size,
- global float* result)
- {
- float2 acc = (float2)(0.0f, 0.0f);
- float2 data0 = (float2) (data[0],data[14]);
- float2 data1 = (float2) (data[1],data[15]);
- acc = comp_dot2(data0, data1);
- result[0] = acc.s0;
- result[1] = acc.s1;
- }
-
- """
-
- prg = pyopencl.Program(self.ctx, read_cl_file("kahan.cl") + src).build(self.args)
- ones_d = pyopencl.array.to_device(self.queue, data)
- res_d = pyopencl.array.empty(self.queue, 2, numpy.float32)
- res_d.fill(0)
- evt = prg.test_dot16(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
- evt.wait()
- res = res_d.get().sum(dtype="float64")
- self.assertEqual(ref64, res, "test_dot16")
-
- res_d.fill(0)
- data0 = data[0::2]
- data1 = data[1::2]
- ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
- ref32 = numpy.dot(data0, data1)
- if (ref64 == ref32):
- logger.warning("dot8: invalid tests as float32 provides the same result as float64")
- evt = prg.test_dot8(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
- evt.wait()
- res = res_d.get().sum(dtype="float64")
- self.assertEqual(ref64, res, "test_dot8")
-
- res_d.fill(0)
- data0 = data[0::4]
- data1 = data[3::4]
- ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
- ref32 = numpy.dot(data0, data1)
- if (ref64 == ref32):
- logger.warning("dot4: invalid tests as float32 provides the same result as float64")
- evt = prg.test_dot4(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
- evt.wait()
- res = res_d.get().sum(dtype="float64")
- self.assertEqual(ref64, res, "test_dot4")
-
- res_d.fill(0)
- data0 = numpy.array([data[0], data[4], data[12]])
- data1 = numpy.array([data[3], data[11], data[15]])
- ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
- ref32 = numpy.dot(data0, data1)
- if (ref64 == ref32):
- logger.warning("dot3: invalid tests as float32 provides the same result as float64")
- evt = prg.test_dot3(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
- evt.wait()
- res = res_d.get().sum(dtype="float64")
- self.assertEqual(ref64, res, "test_dot3")
-
- res_d.fill(0)
- data0 = numpy.array([data[0], data[14]])
- data1 = numpy.array([data[1], data[15]])
- ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
- ref32 = numpy.dot(data0, data1)
- if (ref64 == ref32):
- logger.warning("dot2: invalid tests as float32 provides the same result as float64")
- evt = prg.test_dot2(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
- evt.wait()
- res = res_d.get().sum(dtype="float64")
- self.assertEqual(ref64, res, "test_dot2")
-
-
-def suite():
- testsuite = unittest.TestSuite()
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- testsuite.addTest(loader(TestKahan))
- return testsuite
-
-
-if __name__ == '__main__':
- runner = unittest.TextTestRunner()
- runner.run(suite())
diff --git a/silx/opencl/test/test_linalg.py b/silx/opencl/test/test_linalg.py
deleted file mode 100644
index 0b6c730..0000000
--- a/silx/opencl/test/test_linalg.py
+++ /dev/null
@@ -1,216 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of the linalg module"""
-
-from __future__ import division, print_function
-
-__authors__ = ["Pierre paleo"]
-__license__ = "MIT"
-__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "01/08/2019"
-
-
-import time
-import logging
-import numpy as np
-import unittest
-try:
- import mako
-except ImportError:
- mako = None
-from ..common import ocl
-if ocl:
- import pyopencl as cl
- import pyopencl.array as parray
- from .. import linalg
-from silx.test.utils import utilstest
-
-logger = logging.getLogger(__name__)
-try:
- from scipy.ndimage.filters import laplace
- _has_scipy = True
-except ImportError:
- _has_scipy = False
-
-
-# TODO move this function in math or image ?
-def gradient(img):
- '''
- Compute the gradient of an image as a numpy array
- Code from https://github.com/emmanuelle/tomo-tv/
- '''
- shape = [img.ndim, ] + list(img.shape)
- gradient = np.zeros(shape, dtype=img.dtype)
- slice_all = [0, slice(None, -1),]
- for d in range(img.ndim):
- gradient[tuple(slice_all)] = np.diff(img, axis=d)
- slice_all[0] = d + 1
- slice_all.insert(1, slice(None))
- return gradient
-
-
-# TODO move this function in math or image ?
-def divergence(grad):
- '''
- Compute the divergence of a gradient
- Code from https://github.com/emmanuelle/tomo-tv/
- '''
- res = np.zeros(grad.shape[1:])
- for d in range(grad.shape[0]):
- this_grad = np.rollaxis(grad[d], d)
- this_res = np.rollaxis(res, d)
- this_res[:-1] += this_grad[:-1]
- this_res[1:-1] -= this_grad[:-2]
- this_res[-1] -= this_grad[-2]
- return res
-
-
-@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
-class TestLinAlg(unittest.TestCase):
-
- def setUp(self):
- if ocl is None:
- return
- self.getfiles()
- self.la = linalg.LinAlg(self.image.shape)
- self.allocate_arrays()
-
- def allocate_arrays(self):
- """
- Allocate various types of arrays for the tests
- """
- # numpy images
- self.grad = np.zeros(self.image.shape, dtype=np.complex64)
- self.grad2 = np.zeros((2,) + self.image.shape, dtype=np.float32)
- self.grad_ref = gradient(self.image)
- self.div_ref = divergence(self.grad_ref)
- self.image2 = np.zeros_like(self.image)
- # Device images
- self.gradient_parray = parray.empty(self.la.queue, self.image.shape, np.complex64)
- self.gradient_parray.fill(0)
- # we should be using cl.Buffer(self.la.ctx, cl.mem_flags.READ_WRITE, size=self.image.nbytes*2),
- # but platforms not suporting openCL 1.2 have a problem with enqueue_fill_buffer,
- # so we use the parray "fill" utility
- self.gradient_buffer = self.gradient_parray.data
- # Do the same for image
- self.image_parray = parray.to_device(self.la.queue, self.image)
- self.image_buffer = self.image_parray.data
- # Refs
- tmp = np.zeros(self.image.shape, dtype=np.complex64)
- tmp.real = np.copy(self.grad_ref[0])
- tmp.imag = np.copy(self.grad_ref[1])
- self.grad_ref_parray = parray.to_device(self.la.queue, tmp)
- self.grad_ref_buffer = self.grad_ref_parray.data
-
- def tearDown(self):
- self.image = None
- self.image2 = None
- self.grad = None
- self.grad2 = None
- self.grad_ref = None
- self.div_ref = None
- self.gradient_parray.data.release()
- self.gradient_parray = None
- self.gradient_buffer = None
- self.image_parray.data.release()
- self.image_parray = None
- self.image_buffer = None
- self.grad_ref_parray.data.release()
- self.grad_ref_parray = None
- self.grad_ref_buffer = None
-
- def getfiles(self):
- # load 512x512 MRI phantom - TODO include Lena or ascent once a .npz is available
- self.image = np.load(utilstest.getfile("Brain512.npz"))["data"]
-
- def compare(self, result, reference, abstol, name):
- errmax = np.max(np.abs(result - reference))
- logger.info("%s: Max error = %e" % (name, errmax))
- self.assertTrue(errmax < abstol, str("%s: Max error is too high" % name))
-
- @unittest.skipUnless(ocl and mako, "pyopencl is missing")
- def test_gradient(self):
- arrays = {
- "numpy.ndarray": self.image,
- "buffer": self.image_buffer,
- "parray": self.image_parray
- }
- for desc, image in arrays.items():
- # Test with dst on host (numpy.ndarray)
- res = self.la.gradient(image, return_to_host=True)
- self.compare(res, self.grad_ref, 1e-6, str("gradient[src=%s, dst=numpy.ndarray]" % desc))
- # Test with dst on device (pyopencl.Buffer)
- self.la.gradient(image, dst=self.gradient_buffer)
- cl.enqueue_copy(self.la.queue, self.grad, self.gradient_buffer)
- self.grad2[0] = self.grad.real
- self.grad2[1] = self.grad.imag
- self.compare(self.grad2, self.grad_ref, 1e-6, str("gradient[src=%s, dst=buffer]" % desc))
- # Test with dst on device (pyopencl.Array)
- self.la.gradient(image, dst=self.gradient_parray)
- self.grad = self.gradient_parray.get()
- self.grad2[0] = self.grad.real
- self.grad2[1] = self.grad.imag
- self.compare(self.grad2, self.grad_ref, 1e-6, str("gradient[src=%s, dst=parray]" % desc))
-
- @unittest.skipUnless(ocl and mako, "pyopencl is missing")
- def test_divergence(self):
- arrays = {
- "numpy.ndarray": self.grad_ref,
- "buffer": self.grad_ref_buffer,
- "parray": self.grad_ref_parray
- }
- for desc, grad in arrays.items():
- # Test with dst on host (numpy.ndarray)
- res = self.la.divergence(grad, return_to_host=True)
- self.compare(res, self.div_ref, 1e-6, str("divergence[src=%s, dst=numpy.ndarray]" % desc))
- # Test with dst on device (pyopencl.Buffer)
- self.la.divergence(grad, dst=self.image_buffer)
- cl.enqueue_copy(self.la.queue, self.image2, self.image_buffer)
- self.compare(self.image2, self.div_ref, 1e-6, str("divergence[src=%s, dst=buffer]" % desc))
- # Test with dst on device (pyopencl.Array)
- self.la.divergence(grad, dst=self.image_parray)
- self.image2 = self.image_parray.get()
- self.compare(self.image2, self.div_ref, 1e-6, str("divergence[src=%s, dst=parray]" % desc))
-
- @unittest.skipUnless(ocl and mako and _has_scipy, "pyopencl and/or scipy is missing")
- def test_laplacian(self):
- laplacian_ref = laplace(self.image)
- # Laplacian = div(grad)
- self.la.gradient(self.image)
- laplacian_ocl = self.la.divergence(self.la.d_gradient, return_to_host=True)
- self.compare(laplacian_ocl, laplacian_ref, 1e-6, "laplacian")
-
-
-def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(TestLinAlg("test_gradient"))
- testSuite.addTest(TestLinAlg("test_divergence"))
- testSuite.addTest(TestLinAlg("test_laplacian"))
- return testSuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/opencl/test/test_medfilt.py b/silx/opencl/test/test_medfilt.py
deleted file mode 100644
index 976b199..0000000
--- a/silx/opencl/test/test_medfilt.py
+++ /dev/null
@@ -1,175 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Project: Median filter of images + OpenCL
-# https://github.com/silx-kit/silx
-#
-# Permission is hereby granted, free of charge, to any person
-# obtaining a copy of this software and associated documentation
-# files (the "Software"), to deal in the Software without
-# restriction, including without limitation the rights to use,
-# copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following
-# conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
-# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
-# OTHER DEALINGS IN THE SOFTWARE.
-
-"""
-Simple test of the median filter
-"""
-
-from __future__ import division, print_function
-
-__authors__ = ["Jérôme Kieffer"]
-__contact__ = "jerome.kieffer@esrf.eu"
-__license__ = "MIT"
-__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "05/07/2018"
-
-
-import sys
-import time
-import logging
-import numpy
-import unittest
-from collections import namedtuple
-try:
- import mako
-except ImportError:
- mako = None
-from ..common import ocl
-if ocl:
- import pyopencl
- import pyopencl.array
- from .. import medfilt
-
-logger = logging.getLogger(__name__)
-
-Result = namedtuple("Result", ["size", "error", "sp_time", "oc_time"])
-
-try:
- from scipy.misc import ascent
-except:
- def ascent():
- """Dummy image from random data"""
- return numpy.random.random((512, 512))
-try:
- from scipy.ndimage import filters
- median_filter = filters.median_filter
- HAS_SCIPY = True
-except:
- HAS_SCIPY = False
- from silx.math import medfilt2d as median_filter
-
-@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
-class TestMedianFilter(unittest.TestCase):
-
- def setUp(self):
- if ocl is None:
- return
- self.data = ascent().astype(numpy.float32)
- self.medianfilter = medfilt.MedianFilter2D(self.data.shape, devicetype="gpu")
-
- def tearDown(self):
- self.data = None
- self.medianfilter = None
-
- def measure(self, size):
- "Common measurement of accuracy and timings"
- t0 = time.time()
- if HAS_SCIPY:
- ref = median_filter(self.data, size, mode="nearest")
- else:
- ref = median_filter(self.data, size)
- t1 = time.time()
- try:
- got = self.medianfilter.medfilt2d(self.data, size)
- except RuntimeError as msg:
- logger.error(msg)
- return
- t2 = time.time()
- delta = abs(got - ref).max()
- return Result(size, delta, t1 - t0, t2 - t1)
-
- @unittest.skipUnless(ocl and mako, "pyopencl is missing")
- def test_medfilt(self):
- """
- tests the median filter kernel
- """
- r = self.measure(size=11)
- if r is None:
- logger.info("test_medfilt: size: %s: skipped")
- else:
- logger.info("test_medfilt: size: %s error %s, t_ref: %.3fs, t_ocl: %.3fs" % r)
- self.assertEqual(r.error, 0, 'Results are correct')
-
- def benchmark(self, limit=36):
- "Run some benchmarking"
- try:
- import PyQt5
- from ...gui.matplotlib import pylab
- from ...gui.utils import update_fig
- except:
- pylab = None
-
- def update_fig(*ag, **kwarg):
- pass
-
- fig = pylab.figure()
- fig.suptitle("Median filter of an image 512x512")
- sp = fig.add_subplot(1, 1, 1)
- sp.set_title(self.medianfilter.ctx.devices[0].name)
- sp.set_xlabel("Window width & height")
- sp.set_ylabel("Execution time (s)")
- sp.set_xlim(2, limit + 1)
- sp.set_ylim(0, 4)
- data_size = []
- data_scipy = []
- data_opencl = []
- plot_sp = sp.plot(data_size, data_scipy, "-or", label="scipy")[0]
- plot_opencl = sp.plot(data_size, data_opencl, "-ob", label="opencl")[0]
- sp.legend(loc=2)
- fig.show()
- update_fig(fig)
- for s in range(3, limit, 2):
- r = self.measure(s)
- print(r)
- if r.error == 0:
- data_size.append(s)
- data_scipy.append(r.sp_time)
- data_opencl.append(r.oc_time)
- plot_sp.set_data(data_size, data_scipy)
- plot_opencl.set_data(data_size, data_opencl)
- update_fig(fig)
- fig.show()
- if sys.version_info[0] < 3:
- raw_input()
- else:
- input()
-
-
-def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(TestMedianFilter("test_medfilt"))
- return testSuite
-
-
-def benchmark():
- testSuite = unittest.TestSuite()
- testSuite.addTest(TestMedianFilter("benchmark"))
- return testSuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/opencl/test/test_projection.py b/silx/opencl/test/test_projection.py
deleted file mode 100644
index 7631128..0000000
--- a/silx/opencl/test/test_projection.py
+++ /dev/null
@@ -1,131 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of the forward projection module"""
-
-from __future__ import division, print_function
-
-__authors__ = ["Pierre paleo"]
-__license__ = "MIT"
-__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "19/01/2018"
-
-
-import time
-import logging
-import numpy as np
-import unittest
-try:
- import mako
-except ImportError:
- mako = None
-from ..common import ocl
-if ocl:
- from .. import projection
-from silx.test.utils import utilstest
-
-logger = logging.getLogger(__name__)
-
-
-@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
-class TestProj(unittest.TestCase):
-
- def setUp(self):
- if ocl is None:
- return
- # ~ if sys.platform.startswith('darwin'):
- # ~ self.skipTest("Projection is not implemented on CPU for OS X yet")
- self.getfiles()
- n_angles = self.sino.shape[0]
- self.proj = projection.Projection(self.phantom.shape, n_angles)
- if self.proj.compiletime_workgroup_size < 16 * 16:
- self.skipTest("Current implementation of OpenCL projection is not supported on this platform yet")
-
- def tearDown(self):
- self.phantom = None
- self.sino = None
- self.proj = None
-
- def getfiles(self):
- # load 512x512 MRI phantom
- self.phantom = np.load(utilstest.getfile("Brain512.npz"))["data"]
- # load sinogram computed with PyHST
- self.sino = np.load(utilstest.getfile("sino500_pyhst.npz"))["data"]
-
- def measure(self):
- "Common measurement of timings"
- t1 = time.time()
- try:
- result = self.proj.projection(self.phantom)
- except RuntimeError as msg:
- logger.error(msg)
- return
- t2 = time.time()
- return t2 - t1, result
-
- def compare(self, res):
- """
- Compare a result with the reference reconstruction.
- Only the valid reconstruction zone (inscribed circle) is taken into account
- """
- # Compare with the original phantom.
- # TODO: compare a standard projection
- ref = self.sino
- return np.max(np.abs(res - ref))
-
- @unittest.skipUnless(ocl and mako, "pyopencl is missing")
- def test_proj(self):
- """
- tests Projection
- """
- # Test single reconstruction
- # --------------------------
- t, res = self.measure()
- if t is None:
- logger.info("test_proj: skipped")
- else:
- logger.info("test_proj: time = %.3fs" % t)
- err = self.compare(res)
- msg = str("Max error = %e" % err)
- logger.info(msg)
- # Interpolation differs at some lines, giving relative error of 10/50000
- self.assertTrue(err < 20., "Max error is too high")
- # Test multiple reconstructions
- # -----------------------------
- res0 = np.copy(res)
- for i in range(10):
- res = self.proj.projection(self.phantom)
- errmax = np.max(np.abs(res - res0))
- self.assertTrue(errmax < 1.e-6, "Max error is too high")
-
-
-def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(TestProj("test_proj"))
- return testSuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/opencl/test/test_sparse.py b/silx/opencl/test/test_sparse.py
deleted file mode 100644
index 76a6a0a..0000000
--- a/silx/opencl/test/test_sparse.py
+++ /dev/null
@@ -1,203 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test of the sparse module"""
-
-import numpy as np
-import unittest
-import logging
-from itertools import product
-from ..common import ocl
-if ocl:
- import pyopencl.array as parray
- from silx.opencl.sparse import CSR
-try:
- import scipy.sparse as sp
-except ImportError:
- sp = None
-logger = logging.getLogger(__name__)
-
-
-
-def generate_sparse_random_data(
- shape=(1000,),
- data_min=0, data_max=100,
- density=0.1,
- use_only_integers=True,
- dtype="f"):
- """
- Generate random sparse data where.
-
- Parameters
- ------------
- shape: tuple
- Output data shape.
- data_min: int or float
- Minimum value of data
- data_max: int or float
- Maximum value of data
- density: float
- Density of non-zero elements in the output data.
- Low value of density mean low number of non-zero elements.
- use_only_integers: bool
- If set to True, the output data items will be primarily integers,
- possibly casted to float if dtype is a floating-point type.
- This can be used for ease of debugging.
- dtype: str or numpy.dtype
- Output data type
- """
- mask = np.random.binomial(1, density, size=shape)
- if use_only_integers:
- d = np.random.randint(data_min, high=data_max, size=shape)
- else:
- d = data_min + (data_max - data_min) * np.random.rand(*shape)
- return (d * mask).astype(dtype)
-
-
-
-@unittest.skipUnless(ocl and sp, "PyOpenCl/scipy is missing")
-class TestCSR(unittest.TestCase):
- """Test CSR format"""
-
- def setUp(self):
- # Test possible configurations
- input_on_device = [False, True]
- output_on_device = [False, True]
- dtypes = [np.float32, np.int32, np.uint16]
- self._test_configs = list(product(input_on_device, output_on_device, dtypes))
-
-
- def compute_ref_sparsification(self, array):
- ref_sparse = sp.csr_matrix(array)
- return ref_sparse
-
-
- def test_sparsification(self):
- for input_on_device, output_on_device, dtype in self._test_configs:
- self._test_sparsification(input_on_device, output_on_device, dtype)
-
-
- def _test_sparsification(self, input_on_device, output_on_device, dtype):
- current_config = "input on device: %s, output on device: %s, dtype: %s" % (
- str(input_on_device), str(output_on_device), str(dtype)
- )
- logger.debug("CSR: %s" % current_config)
- # Generate data and reference CSR
- array = generate_sparse_random_data(shape=(512, 511), dtype=dtype)
- ref_sparse = self.compute_ref_sparsification(array)
- # Sparsify on device
- csr = CSR(array.shape, dtype=dtype)
- if input_on_device:
- # The array has to be flattened
- arr = parray.to_device(csr.queue, array.ravel())
- else:
- arr = array
- if output_on_device:
- d_data = parray.empty_like(csr.data)
- d_indices = parray.empty_like(csr.indices)
- d_indptr = parray.empty_like(csr.indptr)
- d_data.fill(0)
- d_indices.fill(0)
- d_indptr.fill(0)
- output = (d_data, d_indices, d_indptr)
- else:
- output = None
- data, indices, indptr = csr.sparsify(arr, output=output)
- if output_on_device:
- data = data.get()
- indices = indices.get()
- indptr = indptr.get()
- # Compare
- nnz = ref_sparse.nnz
- self.assertTrue(
- np.allclose(data[:nnz], ref_sparse.data),
- "something wrong with sparsified data (%s)"
- % current_config
- )
- self.assertTrue(
- np.allclose(indices[:nnz], ref_sparse.indices),
- "something wrong with sparsified indices (%s)"
- % current_config
- )
- self.assertTrue(
- np.allclose(indptr, ref_sparse.indptr),
- "something wrong with sparsified indices pointers (indptr) (%s)"
- % current_config
- )
-
-
- def test_desparsification(self):
- for input_on_device, output_on_device, dtype in self._test_configs:
- self._test_desparsification(input_on_device, output_on_device, dtype)
-
-
- def _test_desparsification(self, input_on_device, output_on_device, dtype):
- current_config = "input on device: %s, output on device: %s, dtype: %s" % (
- str(input_on_device), str(output_on_device), str(dtype)
- )
- logger.debug("CSR: %s" % current_config)
- # Generate data and reference CSR
- array = generate_sparse_random_data(shape=(512, 511), dtype=dtype)
- ref_sparse = self.compute_ref_sparsification(array)
- # De-sparsify on device
- csr = CSR(array.shape, dtype=dtype, max_nnz=ref_sparse.nnz)
- if input_on_device:
- data = parray.to_device(csr.queue, ref_sparse.data)
- indices = parray.to_device(csr.queue, ref_sparse.indices)
- indptr = parray.to_device(csr.queue, ref_sparse.indptr)
- else:
- data = ref_sparse.data
- indices = ref_sparse.indices
- indptr = ref_sparse.indptr
- if output_on_device:
- d_arr = parray.empty_like(csr.array)
- d_arr.fill(0)
- output = d_arr
- else:
- output = None
- arr = csr.densify(data, indices, indptr, output=output)
- if output_on_device:
- arr = arr.get()
- # Compare
- self.assertTrue(
- np.allclose(arr.reshape(array.shape), array),
- "something wrong with densified data (%s)"
- % current_config
- )
-
-
-
-def suite():
- suite = unittest.TestSuite()
- suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestCSR)
- )
- return suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
-
-
diff --git a/silx/opencl/test/test_stats.py b/silx/opencl/test/test_stats.py
deleted file mode 100644
index 8baf05e..0000000
--- a/silx/opencl/test/test_stats.py
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Project: Sift implementation in Python + OpenCL
-# https://github.com/silx-kit/silx
-#
-# Permission is hereby granted, free of charge, to any person
-# obtaining a copy of this software and associated documentation
-# files (the "Software"), to deal in the Software without
-# restriction, including without limitation the rights to use,
-# copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following
-# conditions:
-#
-# The above copyright notice and this permission notice shall be
-# included in all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
-# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
-# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
-# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
-# OTHER DEALINGS IN THE SOFTWARE.
-
-"""
-Simple test of an addition
-"""
-__authors__ = ["Henri Payno, Jérôme Kieffer"]
-__contact__ = "jerome.kieffer@esrf.eu"
-__license__ = "MIT"
-__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "19/05/2021"
-
-import logging
-import time
-import numpy
-
-import unittest
-from ..common import ocl
-if ocl:
- import pyopencl
- import pyopencl.array
- from ..statistics import StatResults, Statistics
-from ..utils import get_opencl_code
-logger = logging.getLogger(__name__)
-
-
-@unittest.skipUnless(ocl, "PyOpenCl is missing")
-class TestStatistics(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
- cls.size = 1 << 20 # 1 million elements
- cls.data = numpy.random.randint(0, 65000, cls.size).astype("uint16")
- fdata = cls.data.astype("float64")
- t0 = time.perf_counter()
- std = fdata.std()
- cls.ref = StatResults(fdata.min(), fdata.max(), float(fdata.size),
- fdata.sum(), fdata.mean(), std ** 2,
- std)
- t1 = time.perf_counter()
- cls.ref_time = t1 - t0
-
- @classmethod
- def tearDownClass(cls):
- cls.size = cls.ref = cls.data = cls.ref_time = None
-
- @classmethod
- def validate(cls, res):
- return (
- (res.min == cls.ref.min) and
- (res.max == cls.ref.max) and
- (res.cnt == cls.ref.cnt) and
- abs(res.mean - cls.ref.mean) < 0.01 and
- abs(res.std - cls.ref.std) < 0.1)
-
- def test_measurement(self):
- """
- tests that all devices are working properly ...
- """
- logger.info("Reference results: %s", self.ref)
- for pid, platform in enumerate(ocl.platforms):
- for did, device in enumerate(platform.devices):
- try:
- s = Statistics(template=self.data, platformid=pid, deviceid=did)
- except Exception as err:
- failed_init = True
- res = StatResults(0, 0, 0, 0, 0, 0, 0)
- print(err)
- else:
- failed_init = False
- for comp in ("single", "double", "comp"):
- t0 = time.perf_counter()
- res = s(self.data, comp=comp)
- t1 = time.perf_counter()
- logger.info("Runtime on %s/%s : %.3fms x%.1f", platform, device, 1000 * (t1 - t0), self.ref_time / (t1 - t0))
-
- if failed_init or not self.validate(res):
- logger.error("failed_init %s; Computation modes %s", failed_init, comp)
- logger.error("Failed on platform %s device %s", platform, device)
- logger.error("Reference results: %s", self.ref)
- logger.error("Faulty results: %s", res)
- self.assertTrue(False, f"Stat calculation failed on {platform},{device} in mode {comp}")
-
-
-def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(TestStatistics("test_measurement"))
- return testSuite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/resources/gui/icons/compare-align-auto.svg b/silx/resources/gui/icons/compare-align-auto.svg
deleted file mode 100644
index de82c30..0000000
--- a/silx/resources/gui/icons/compare-align-auto.svg
+++ /dev/null
@@ -1,4 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg id="svg44" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata50"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/></cc:Work></rdf:RDF></metadata><radialGradient id="a" cx="22.443" cy="21.502" r="0" gradientUnits="userSpaceOnUse"><stop id="stop21" stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop id="stop23" stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop id="stop25" stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><path id="path34" d="m22.443 21.502" fill="url(#a)"/><path id="path36" d="m10.992 6.764s4.839-0.584 5.992 4.366" fill="none" stroke="#FFF" stroke-miterlimit="10" stroke-width="1.2"/><g id="g4597" transform="matrix(.89618 0 0 .89618 33.643 30.672)"><rect id="rect2-6" x="-34.289" y="-27.796" width="26.026" height="26.026" ry="0" fill="#fab058" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.5622"/><text id="text4553" x="-33.067287" y="-5.5593224" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551" x="-33.067287" y="-5.5593224" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">A</tspan></text>
-</g><g id="g4602" transform="matrix(.50611 .17057 -.17057 .50611 -5.8136 18.919)"><rect id="rect2-6-4" x="33.767" y="-32.267" width="26.026" height="26.026" ry="0" fill="#d5fa58" stroke="#000" stroke-miterlimit="2" stroke-width="2.6213"/><text id="text4553-1" x="36.864368" y="-10.030853" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551-7" x="36.864368" y="-10.030853" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">B</tspan></text>
-</g></svg>
diff --git a/silx/resources/gui/icons/compare-align-center.svg b/silx/resources/gui/icons/compare-align-center.svg
deleted file mode 100644
index 1888820..0000000
--- a/silx/resources/gui/icons/compare-align-center.svg
+++ /dev/null
@@ -1,4 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg id="svg44" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata50"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata><radialGradient id="a" cx="22.443" cy="21.502" r="0" gradientTransform="translate(-6.443 -5.502)" gradientUnits="userSpaceOnUse"><stop id="stop21" stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop id="stop23" stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop id="stop25" stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><path id="path34" d="m16 16" fill="url(#a)"/><g id="g4597" transform="matrix(.89618 0 0 .89618 35.067 29.248)"><rect id="rect2-6" x="-34.289" y="-27.796" width="26.026" height="26.026" ry="0" fill="#fab058" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.5622"/>
-</g><g id="g4602" transform="matrix(.70181 0 0 .70181 -16.83 29.513)"><rect id="rect2-6-4" x="33.767" y="-32.267" width="26.026" height="26.026" ry="0" fill="#d5fa58" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.9948"/><text id="text4553-1" x="36.864368" y="-10.030853" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551-7" x="36.864368" y="-10.030853" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">B</tspan></text>
-</g></svg>
diff --git a/silx/resources/gui/icons/compare-align-origin.svg b/silx/resources/gui/icons/compare-align-origin.svg
deleted file mode 100644
index efccf50..0000000
--- a/silx/resources/gui/icons/compare-align-origin.svg
+++ /dev/null
@@ -1,4 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg id="svg44" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata50"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata><radialGradient id="a" cx="22.443" cy="21.502" r="0" gradientTransform="translate(1.4237 -1.4237)" gradientUnits="userSpaceOnUse"><stop id="stop21" stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop id="stop23" stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop id="stop25" stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><path id="path34" d="m23.867 20.078" fill="url(#a)"/><g id="g4597" transform="matrix(.89618 0 0 .89618 35.067 29.248)"><rect id="rect2-6" x="-34.289" y="-27.796" width="26.026" height="26.026" ry="0" fill="#fab058" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.5622"/>
-</g><g id="g4602" transform="matrix(.70181 0 0 .70181 -19.285 27.058)"><rect id="rect2-6-4" x="33.767" y="-32.267" width="26.026" height="26.026" ry="0" fill="#d5fa58" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.9948"/><text id="text4553-1" x="36.864368" y="-10.030853" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551-7" x="36.864368" y="-10.030853" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">B</tspan></text>
-</g></svg>
diff --git a/silx/resources/gui/icons/compare-align-stretch.svg b/silx/resources/gui/icons/compare-align-stretch.svg
deleted file mode 100644
index 4c4b653..0000000
--- a/silx/resources/gui/icons/compare-align-stretch.svg
+++ /dev/null
@@ -1,4 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg id="svg44" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata50"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata><radialGradient id="a" cx="22.443" cy="21.502" r="0" gradientTransform="translate(-6.443 -5.502)" gradientUnits="userSpaceOnUse"><stop id="stop21" stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop id="stop23" stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop id="stop25" stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><path id="path34" d="m16 16" fill="url(#a)"/><g id="g4597" transform="matrix(.89618 0 0 .89618 35.067 29.248)"><rect id="rect2-6" x="-34.289" y="-27.796" width="26.026" height="26.026" ry="0" fill="#d5fa58" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.5622"/>
-</g><g id="g4602" transform="matrix(.70866 0 0 .70866 -17.151 29.645)"><text id="text4553-1" x="36.864368" y="-10.030853" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551-7" x="36.864368" y="-10.030853" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">B</tspan></text>
-</g><path id="rect969" d="m6.3051 6.1695h5.4237l-5.4237 5.4237z" color="#000000" fill="#f0f"/><path id="rect969-5" d="m26.034 6.1695h-5.4237l5.4237 5.4237z" color="#000000" fill="#f0f"/><path id="rect969-5-3" d="m26.034 25.763h-5.4237l5.4237-5.4237z" color="#000000" fill="#f0f"/><path id="rect969-5-3-5" d="m6.3051 25.763h5.4237l-5.4237-5.4237z" color="#000000" fill="#f0f"/></svg>
diff --git a/silx/resources/gui/icons/math-peak-search.svg b/silx/resources/gui/icons/math-peak-search.svg
deleted file mode 100644
index 2c19792..0000000
--- a/silx/resources/gui/icons/math-peak-search.svg
+++ /dev/null
@@ -1,2 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><defs><filter id="a" color-interpolation-filters="sRGB"><feGaussianBlur stdDeviation="1.2128746"/></filter></defs><path d="m4.356 26.781c0.66-0.935 1.841-0.809 2.729-1.399 0.703-0.467 0.856-1.623 0.992-2.349 0.218-1.165-0.362-4.839 1.218-5.27 1.004-0.274 1.677-0.422 2.422-1.176 1.721-1.742 1.883-4.988 2.669-7.182 0.504-1.407 1.142-1.524 1.711-0.079 0.35 0.886 0.697 1.771 1.017 2.668 0.689 1.934 1.256 3.931 1.737 5.926 0.45 1.865 0.957 3.707 1.576 5.523 0.279 0.821 0.38 1.479 1.177 1.893 1.154 0.598 1.675-0.925 1.896-1.673 0.278-0.937 0.439-1.908 0.69-2.854 0.455-1.711 0.864 0.714 1.019 1.371 0.442 1.884 0.466 3.932 1.071 5.769 0.181 0.549 1.05 0.314 0.867-0.238-0.398-1.209-0.782-9.396-2.967-8.609-1.242 0.448-1.363 3.699-1.672 4.738-0.364 1.226-1.034-0.032-1.215-0.635-0.366-1.225-0.775-2.429-1.108-3.664-0.629-2.33-1.193-4.659-1.927-6.96-0.276-0.867-1.45-6-3.046-5.583-2.015 0.528-2.388 4.501-2.846 6.112-0.615 2.163-1.571 3.309-3.726 3.896-0.864 0.236-1.143 0.979-1.28 1.771-0.3 1.735 0.738 5.357-1.488 6.215-1.107 0.426-1.578 0.317-2.295 1.332-0.334 0.478 0.447 0.927 0.779 0.457z"/><g transform="translate(1.6271 .13559)" filter="url(#a)"><path d="m2.1425 16.187c-0.417 0.236-1.12 0.115-1.557-0.271-0.442-0.39-0.455-0.906-0.039-1.147l7.33-4.184c0.422-0.242 1.121-0.119 1.56 0.27 0.44 0.392 0.457 0.901 0.035 1.146l-7.329 4.186z" stroke="#00a651" stroke-miterlimit="10" stroke-width=".1"/><path d="m14.176 2.8136c-1.8408-0.22181-3.7106 0.0891-5.25 0.96875-1.5391 0.88172-2.4552 2.2584-2.5625 3.75-0.10727 1.4916 0.57148 3.0357 1.9375 4.25 2.7388 2.4255 7.203 2.9807 10.281 1.2188 1.5391-0.87925 2.4546-2.2587 2.5625-3.75 0.10787-1.4913-0.5729-3.0355-1.9375-4.25-1.3708-1.2142-3.1904-1.9657-5.0312-2.1875zm-0.15625 1.5625c1.5617 0.18769 3.0903 0.77817 4.1875 1.75 1.0904 0.97048 1.5071 2.0373 1.4375 3-0.06963 0.96271-0.62261 1.8827-1.8125 2.5625-2.3797 1.3621-6.3401 0.90923-8.5312-1.0312-1.092-0.9707-1.5067-2.0686-1.4375-3.0312 0.06923-0.96267 0.62157-1.849 1.8125-2.5312 1.1906-0.68035 2.7821-0.90644 4.3437-0.71875z" color="#000000" style="block-progression:tb;text-indent:0;text-transform:none"/><path d="m30.572 31.718c0.247 0.361 0.019 0.865-0.506 1.109-0.531 0.246-1.174 0.141-1.42-0.221l-4.346-6.416c-0.255-0.369-0.025-0.869 0.502-1.111 0.533-0.244 1.163-0.146 1.422 0.227l4.348 6.412z" stroke="#00a651" stroke-miterlimit="10" stroke-width=".1"/><path d="m21.551 15.595c-0.87491 0.08975-1.7393 0.30814-2.5625 0.6875-1.6444 0.76154-2.8268 2.0268-3.3438 3.4688-0.51696 1.4419-0.34202 3.0547 0.59375 4.4375v0.03125c1.8808 2.7617 5.9597 3.6148 9.25 2.0938 1.6461-0.76046 2.8267-2.0267 3.3438-3.4688 0.5171-1.442 0.34525-3.0565-0.59375-4.4375-1.4049-2.0747-4.0628-3.0818-6.6875-2.8125zm0.15625 1.5c2.128-0.19847 4.2576 0.6445 5.2812 2.1562 0.684 1.006 0.7729 2.0713 0.40625 3.0938s-1.2054 1.9812-2.5312 2.5938c-2.6497 1.2249-6.0018 0.45384-7.375-1.5625-0.68223-1.0082-0.80429-2.1019-0.4375-3.125s1.2379-1.9803 2.5625-2.5938c0.66327-0.30564 1.3844-0.49634 2.0938-0.5625z" color="#000000" style="block-progression:tb;text-indent:0;text-transform:none"/></g><g stroke="#00a651" stroke-miterlimit="10"><path d="m3.222 15.385c-0.417 0.236-1.12 0.115-1.557-0.271-0.442-0.39-0.455-0.906-0.039-1.147l7.33-4.184c0.422-0.242 1.121-0.119 1.56 0.27 0.44 0.392 0.457 0.901 0.035 1.146l-7.329 4.186z" fill="#00a651" stroke-width=".1"/><path d="m19.291 11.538c-2.729 1.562-6.936 1.054-9.401-1.129-2.458-2.185-2.241-5.219 0.489-6.783 2.73-1.56 6.936-1.054 9.404 1.132 2.455 2.185 2.237 5.221-0.492 6.78z" fill="none" stroke-width="1.5"/></g><g stroke="#00a651" stroke-miterlimit="10"><path d="m31.651 30.916c0.247 0.361 0.019 0.865-0.506 1.109-0.531 0.246-1.174 0.141-1.42-0.221l-4.346-6.416c-0.255-0.369-0.025-0.869 0.502-1.111 0.533-0.244 1.163-0.146 1.422 0.227l4.348 6.412z" fill="#00a651" stroke-width=".1"/><path d="m28.693 18.014c1.623 2.387 0.53 5.436-2.442 6.809-2.97 1.373-6.686 0.547-8.313-1.842-1.618-2.391-0.526-5.438 2.443-6.813 2.973-1.37 6.693-0.545 8.312 1.846z" fill="none" stroke-width="1.5"/></g></svg>
diff --git a/silx/resources/gui/icons/remove.svg b/silx/resources/gui/icons/remove.svg
deleted file mode 100644
index 4ac0f67..0000000
--- a/silx/resources/gui/icons/remove.svg
+++ /dev/null
@@ -1,2 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><radialGradient id="c" cx="22.443" cy="21.502" r="0" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><path d="m22.443 21.502" fill="url(#c)"/><linearGradient id="d" x1="22.414" x2="22.473" y1="21.502" y2="21.502" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m22.443 24.002c0.038 0 0.038-5 0-5s-0.038 5 0 5z" fill="url(#d)"/><path d="m8.293 10.899c5.462 5.68 10.925 11.36 16.387 17.04-0.814-0.847 0.851-4.115 0-5-5.462-5.68-10.924-11.36-16.387-17.04 0.814 0.847-0.851 4.116 0 5z" fill="#ed1c24"/><path d="m24.452 5.675c-5.434 5.658-10.869 11.317-16.304 16.975-0.851 0.886 0.814 4.152 0 5 5.435-5.658 10.869-11.317 16.304-16.976 0.851-0.884-0.814-4.152 0-4.999z" fill="#ed1c24"/></svg>
diff --git a/silx/resources/gui/icons/zoom-back.svg b/silx/resources/gui/icons/zoom-back.svg
deleted file mode 100644
index cf47b8f..0000000
--- a/silx/resources/gui/icons/zoom-back.svg
+++ /dev/null
@@ -1,2 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="d" x1="20.887" x2="23.374" y1="21.759" y2="18.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><radialGradient id="c" cx="13.206" cy="8.4126" r="9.1344" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="b" x1="4.605" x2="18.267" y1="12.302" y2="12.302" gradientUnits="userSpaceOnUse"><stop stop-color="#FFF" offset="0"/><stop offset="1"/></linearGradient><radialGradient id="a" cx="22.443" cy="21.502" r="0" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><polygon points="28.606 22.356 26.103 25.571 15.723 17.758 18.174 14.502" fill="url(#d)" stroke="#808285" stroke-miterlimit="10" stroke-width=".2"/><circle cx="11.483" cy="12.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><path d="m17.967 12.302c0 3.594-3.039 6.507-6.518 6.507s-6.544-2.913-6.544-6.507 3.065-6.507 6.544-6.507 6.518 2.913 6.518 6.507z" fill="url(#c)" stroke="url(#b)" stroke-miterlimit="10" stroke-width=".6"/><path d="m22.443 21.502" fill="url(#a)"/><path d="m10.992 6.764s4.839-0.584 5.992 4.366" fill="none" stroke="#FFF" stroke-miterlimit="10" stroke-width="1.2"/><g transform="matrix(-1 0 0 1 23.132 -18.833)" fill="#F00" stroke="#F00" stroke-miterlimit="10"><path d="m4.7543 24.006c-10.964-0.107-10.073 10.653-10.266 10.974 0 0 0.193-7.139 10.267-6.713v-4.261z" stroke-width=".1"/><path d="m4.7543 22.329c0-0.17 0.122-0.243 0.271-0.16l6.169 3.403c0.149 0.083 0.157 0.23 0.018 0.328l-6.204 4.348c-0.14 0.098-0.254 0.038-0.254-0.132v-7.787z"/></g></svg>
diff --git a/silx/resources/gui/icons/zoom-in.svg b/silx/resources/gui/icons/zoom-in.svg
deleted file mode 100644
index f062a7d..0000000
--- a/silx/resources/gui/icons/zoom-in.svg
+++ /dev/null
@@ -1,2 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="e" x1="19.887" x2="22.374" y1="23.759" y2="20.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><polygon points="27.606 24.356 25.103 27.571 14.723 19.758 17.174 16.502" fill="url(#e)" stroke="#808285" stroke-miterlimit="10" stroke-width=".1"/><circle cx="10.483" cy="14.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><radialGradient id="f" cx="12.253" cy="10.413" r="9.1342" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="g" x1="3.4521" x2="17.514" y1="14.302" y2="14.302" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m17.014 14.301c0 3.594-3.038 6.507-6.517 6.507s-6.544-2.913-6.544-6.507 3.065-6.507 6.544-6.507 6.517 2.914 6.517 6.507z" fill="url(#f)" stroke="url(#g)" stroke-miterlimit="10"/><radialGradient id="h" cx="21.443" cy="23.502" r="0" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><path d="m21.443 23.502" fill="url(#h)"/><path d="m9.177 9.151s4.405-1.127 6.307 3.42" fill="none" stroke="#fff" stroke-miterlimit="10"/><g fill="#00a651" stroke="#00a651" stroke-miterlimit="10"><rect x="24.483" y="7.225" width="1.239" height="8.379"/><rect x="20.913" y="10.796" width="8.38" height="1.237"/></g></svg>
diff --git a/silx/resources/gui/icons/zoom-original.svg b/silx/resources/gui/icons/zoom-original.svg
deleted file mode 100644
index f20556b..0000000
--- a/silx/resources/gui/icons/zoom-original.svg
+++ /dev/null
@@ -1,2 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="e" x1="20.888" x2="23.375" y1="23.759" y2="20.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><polygon points="28.606 24.356 26.103 27.571 15.723 19.758 18.174 16.502" fill="url(#e)" stroke="#808285" stroke-miterlimit="10" stroke-width=".1"/><circle cx="11.483" cy="14.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><radialGradient id="f" cx="13.253" cy="10.413" r="9.1342" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="g" x1="4.4521" x2="18.514" y1="14.302" y2="14.302" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m18.014 14.301c0 3.594-3.038 6.507-6.517 6.507s-6.544-2.913-6.544-6.507 3.065-6.507 6.544-6.507 6.517 2.914 6.517 6.507z" fill="url(#f)" stroke="url(#g)" stroke-miterlimit="10"/><radialGradient id="h" cx="22.443" cy="23.502" r="0" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><path d="m22.443 23.502" fill="url(#h)"/><path d="m10.177 9.151s4.405-1.127 6.307 3.42" fill="none" stroke="#fff" stroke-miterlimit="10"/><g fill="#ed1c24" stroke="#ed1c24" stroke-miterlimit="10" stroke-width="2.5"><line x1="7.257" x2="25.712" y1="24.906" y2="6.518"/><line x1="7.392" x2="25.575" y1="6.371" y2="25.053"/></g></svg>
diff --git a/silx/resources/gui/icons/zoom-out.svg b/silx/resources/gui/icons/zoom-out.svg
deleted file mode 100644
index fe3b8dd..0000000
--- a/silx/resources/gui/icons/zoom-out.svg
+++ /dev/null
@@ -1,2 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="e" x1="19.887" x2="22.374" y1="22.759" y2="19.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><polygon points="27.606 23.356 25.103 26.571 14.723 18.758 17.174 15.502" fill="url(#e)" stroke="#808285" stroke-miterlimit="10" stroke-width=".1"/><circle cx="10.483" cy="13.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><radialGradient id="f" cx="12.253" cy="9.4126" r="9.1342" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="g" x1="3.4521" x2="17.514" y1="13.302" y2="13.302" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m17.014 13.301c0 3.594-3.038 6.507-6.517 6.507s-6.544-2.914-6.544-6.507 3.065-6.507 6.544-6.507 6.517 2.914 6.517 6.507z" fill="url(#f)" stroke="url(#g)" stroke-miterlimit="10"/><radialGradient id="h" cx="21.443" cy="22.502" r="0" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><path d="m21.443 22.502" fill="url(#h)"/><path d="m9.177 8.151s4.405-1.127 6.307 3.42" fill="none" stroke="#fff" stroke-miterlimit="10"/><rect x="20.304" y="7.802" width="7.377" height=".988" fill="#ed1c24" stroke="#ed1c24" stroke-miterlimit="10" stroke-width="2"/></svg>
diff --git a/silx/resources/gui/icons/zoom.svg b/silx/resources/gui/icons/zoom.svg
deleted file mode 100644
index 448f3b9..0000000
--- a/silx/resources/gui/icons/zoom.svg
+++ /dev/null
@@ -1,2 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="e" x1="20.887" x2="23.374" y1="21.759" y2="18.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><polygon points="28.606 22.356 26.103 25.571 15.723 17.758 18.174 14.502" fill="url(#e)" stroke="#808285" stroke-miterlimit="10" stroke-width=".2"/><circle cx="11.483" cy="12.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><radialGradient id="f" cx="13.206" cy="8.4126" r="9.1344" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="g" x1="4.605" x2="18.267" y1="12.302" y2="12.302" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m17.967 12.302c0 3.594-3.039 6.507-6.518 6.507s-6.544-2.913-6.544-6.507 3.065-6.507 6.544-6.507 6.518 2.913 6.518 6.507z" fill="url(#f)" stroke="url(#g)" stroke-miterlimit="10" stroke-width=".6"/><radialGradient id="h" cx="22.443" cy="21.502" r="0" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#517180" stop-opacity=".5317" offset=".6832"/><stop stop-color="#414042" stop-opacity=".5" offset="1"/></radialGradient><path d="m22.443 21.502" fill="url(#h)"/><path d="m10.992 6.764s4.839-0.584 5.992 4.366" fill="none" stroke="#fff" stroke-miterlimit="10" stroke-width="1.2"/></svg>
diff --git a/silx/setup.py b/silx/setup.py
deleted file mode 100644
index 2575ebf..0000000
--- a/silx/setup.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "26/07/2018"
-
-from numpy.distutils.misc_util import Configuration
-
-
-def configuration(parent_package='', top_path=None):
- config = Configuration('silx', parent_package, top_path)
- config.add_subpackage('gui')
- config.add_subpackage('io')
- config.add_subpackage('math')
- config.add_subpackage('image')
- config.add_subpackage('opencl')
- config.add_subpackage('resources')
- config.add_subpackage('sx')
- config.add_subpackage('test')
- config.add_subpackage('third_party')
- config.add_subpackage('utils')
- config.add_subpackage('app')
- config.add_subpackage("examples", "../examples")
-
- return config
-
-
-if __name__ == "__main__":
- from numpy.distutils.core import setup
-
- setup(configuration=configuration)
diff --git a/silx/sx/_plot.py b/silx/sx/_plot.py
deleted file mode 100644
index 5746492..0000000
--- a/silx/sx/_plot.py
+++ /dev/null
@@ -1,623 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This module adds convenient functions to use plot widgets from the console.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "06/11/2018"
-
-
-import collections
-try:
- from collections import abc
-except ImportError: # Python2 support
- import collections as abc
-import logging
-import weakref
-
-import numpy
-import six
-
-from ..utils.weakref import WeakList
-from ..gui import qt
-from ..gui.plot import Plot1D, Plot2D, ScatterView
-from ..gui.plot import items
-from ..gui import colors
-from ..gui.plot.tools import roi
-from ..gui.plot.items import roi as roi_items
-from ..gui.plot.tools.toolbars import InteractiveModeToolBar
-
-_logger = logging.getLogger(__name__)
-
-_plots = WeakList()
-"""List of widgets created through plot and imshow"""
-
-
-def plot(*args, **kwargs):
- """
- Plot curves in a :class:`~silx.gui.plot.PlotWindow.Plot1D` widget.
-
- How to use:
-
- >>> from silx import sx
- >>> import numpy
-
- Plot a single curve given some values:
-
- >>> values = numpy.random.random(100)
- >>> plot_1curve = sx.plot(values, title='Random data')
-
- Plot a single curve given the x and y values:
-
- >>> angles = numpy.linspace(0, numpy.pi, 100)
- >>> sin_a = numpy.sin(angles)
- >>> plot_sinus = sx.plot(angles, sin_a, xlabel='angle (radian)', ylabel='sin(a)')
-
- Plot many curves by giving a 2D array, provided xn, yn arrays:
-
- >>> plot_curves = sx.plot(x0, y0, x1, y1, x2, y2, ...)
-
- Plot curve with style giving a style string:
-
- >>> plot_styled = sx.plot(x0, y0, 'ro-', x1, y1, 'b.')
-
- Supported symbols:
-
- - 'o' circle
- - '.' point
- - ',' pixel
- - '+' cross
- - 'x' x-cross
- - 'd' diamond
- - 's' square
-
- Supported types of line:
-
- - ' ' no line
- - '-' solid line
- - '--' dashed line
- - '-.' dash-dot line
- - ':' dotted line
-
- If provided, the names arguments color, linestyle, linewidth and marker
- override any style provided to a curve.
-
- This function supports a subset of `matplotlib.pyplot.plot
- <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.plot>`_
- arguments.
-
- :param str color: Color to use for all curves (default: None)
- :param str linestyle: Type of line to use for all curves (default: None)
- :param float linewidth: With of all the curves (default: 1)
- :param str marker: Symbol to use for all the curves (default: None)
- :param str title: The title of the Plot widget (default: None)
- :param str xlabel: The label of the X axis (default: None)
- :param str ylabel: The label of the Y axis (default: None)
- :return: The widget plotting the curve(s)
- :rtype: silx.gui.plot.Plot1D
- """
- plt = Plot1D()
- if 'title' in kwargs:
- plt.setGraphTitle(kwargs['title'])
- if 'xlabel' in kwargs:
- plt.getXAxis().setLabel(kwargs['xlabel'])
- if 'ylabel' in kwargs:
- plt.getYAxis().setLabel(kwargs['ylabel'])
-
- color = kwargs.get('color')
- linestyle = kwargs.get('linestyle')
- linewidth = kwargs.get('linewidth')
- marker = kwargs.get('marker')
-
- # Parse args and store curves as (x, y, style string)
- args = list(args)
- curves = []
- while args:
- first_arg = args.pop(0) # Process an arg
-
- if len(args) == 0:
- # Last curve defined as (y,)
- curves.append((numpy.arange(len(first_arg)), first_arg, None))
- else:
- second_arg = args.pop(0)
- if isinstance(second_arg, six.string_types):
- # curve defined as (y, style)
- y = first_arg
- style = second_arg
- curves.append((numpy.arange(len(y)), y, style))
- else: # second_arg must be an array-like
- x = first_arg
- y = second_arg
- if len(args) >= 1 and isinstance(args[0], six.string_types):
- # Curve defined as (x, y, style)
- style = args.pop(0)
- curves.append((x, y, style))
- else:
- # Curve defined as (x, y)
- curves.append((x, y, None))
-
- for index, curve in enumerate(curves):
- x, y, style = curve
-
- # Default style
- curve_symbol, curve_linestyle, curve_color = None, None, None
-
- # Parse style
- if style:
- # Handle color first
- possible_colors = [c for c in colors.COLORDICT if style.startswith(c)]
- if possible_colors: # Take the longest string matching a color name
- curve_color = possible_colors[0]
- for c in possible_colors[1:]:
- if len(c) > len(curve_color):
- curve_color = c
- style = style[len(curve_color):]
-
- if style:
- # Run twice to handle inversion symbol/linestyle
- for _i in range(2):
- # Handle linestyle
- for line in (' ', '--', '-', '-.', ':'):
- if style.endswith(line):
- curve_linestyle = line
- style = style[:-len(line)]
- break
-
- # Handle symbol
- for curve_marker in ('o', '.', ',', '+', 'x', 'd', 's'):
- if style.endswith(curve_marker):
- curve_symbol = style[-1]
- style = style[:-1]
- break
-
- # As in matplotlib, marker, linestyle and color override other style
- plt.addCurve(x, y,
- legend=('curve_%d' % index),
- symbol=marker or curve_symbol,
- linestyle=linestyle or curve_linestyle,
- linewidth=linewidth,
- color=color or curve_color)
-
- plt.show()
- _plots.insert(0, plt)
- return plt
-
-
-def imshow(data=None, cmap=None, norm=colors.Colormap.LINEAR,
- vmin=None, vmax=None,
- aspect=False,
- origin='upper', scale=(1., 1.),
- title='', xlabel='X', ylabel='Y'):
- """
- Plot an image in a :class:`~silx.gui.plot.PlotWindow.Plot2D` widget.
-
- How to use:
-
- >>> from silx import sx
- >>> import numpy
-
- >>> data = numpy.random.random(1024 * 1024).reshape(1024, 1024)
- >>> plt = sx.imshow(data, title='Random data')
-
- By default, the image origin is displayed in the upper left
- corner of the plot. To invert the Y axis, and place the image origin
- in the lower left corner of the plot, use the *origin* parameter:
-
- >>> plt = sx.imshow(data, origin='lower')
-
- This function supports a subset of `matplotlib.pyplot.imshow
- <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow>`_
- arguments.
-
- :param data: data to plot as an image
- :type data: numpy.ndarray-like with 2 dimensions
- :param str cmap: The name of the colormap to use for the plot. It also
- supports a numpy array containing a RGB LUT, or a `colors.Colormap`
- instance.
- :param str norm: The normalization of the colormap:
- 'linear' (default) or 'log'
- :param float vmin: The value to use for the min of the colormap
- :param float vmax: The value to use for the max of the colormap
- :param bool aspect: True to keep aspect ratio (Default: False)
- :param origin: Either image origin as the Y axis orientation:
- 'upper' (default) or 'lower'
- or the coordinates (ox, oy) of the image origin in the plot.
- :type origin: str or 2-tuple of floats
- :param scale: (sx, sy) The scale of the image in the plot
- (i.e., the size of the image's pixel in plot coordinates)
- :type scale: 2-tuple of floats
- :param str title: The title of the Plot widget
- :param str xlabel: The label of the X axis
- :param str ylabel: The label of the Y axis
- :return: The widget plotting the image
- :rtype: silx.gui.plot.Plot2D
- """
- plt = Plot2D()
- plt.setGraphTitle(title)
- plt.getXAxis().setLabel(xlabel)
- plt.getYAxis().setLabel(ylabel)
-
- # Update default colormap with input parameters
- colormap = plt.getDefaultColormap()
- if isinstance(cmap, colors.Colormap):
- colormap = cmap
- plt.setDefaultColormap(colormap)
- elif isinstance(cmap, numpy.ndarray):
- colormap.setColors(cmap)
- elif cmap is not None:
- colormap.setName(cmap)
- assert norm in colors.Colormap.NORMALIZATIONS
- colormap.setNormalization(norm)
- colormap.setVMin(vmin)
- colormap.setVMax(vmax)
-
- # Handle aspect
- if aspect in (None, False, 'auto', 'normal'):
- plt.setKeepDataAspectRatio(False)
- elif aspect in (True, 'equal') or aspect == 1:
- plt.setKeepDataAspectRatio(True)
- else:
- _logger.warning(
- 'imshow: Unhandled aspect argument: %s', str(aspect))
-
- # Handle matplotlib-like origin
- if origin in ('upper', 'lower'):
- plt.setYAxisInverted(origin == 'upper')
- origin = 0., 0. # Set origin to the definition of silx
-
- if data is not None:
- data = numpy.array(data, copy=True)
-
- assert data.ndim in (2, 3) # data or RGB(A)
- if data.ndim == 3:
- assert data.shape[-1] in (3, 4) # RGB(A) image
-
- plt.addImage(data, origin=origin, scale=scale)
-
- plt.show()
- _plots.insert(0, plt)
- return plt
-
-
-def scatter(x=None, y=None, value=None, size=None,
- marker=None,
- cmap=None, norm=colors.Colormap.LINEAR,
- vmin=None, vmax=None):
- """
- Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget.
-
- How to use:
-
- >>> from silx import sx
- >>> import numpy
-
- >>> x = numpy.random.random(100)
- >>> y = numpy.random.random(100)
- >>> values = numpy.random.random(100)
- >>> plt = sx.scatter(x, y, values, cmap='viridis')
-
- Supported symbols:
-
- - 'o' circle
- - '.' point
- - ',' pixel
- - '+' cross
- - 'x' x-cross
- - 'd' diamond
- - 's' square
-
- This function supports a subset of `matplotlib.pyplot.scatter
- <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`_
- arguments.
-
- :param numpy.ndarray x: 1D array-like of x coordinates
- :param numpy.ndarray y: 1D array-like of y coordinates
- :param numpy.ndarray value: 1D array-like of data values
- :param float size: Size^2 of the markers
- :param str marker: Symbol used to represent the points
- :param str cmap: The name of the colormap to use for the plot
- :param str norm: The normalization of the colormap:
- 'linear' (default) or 'log'
- :param float vmin: The value to use for the min of the colormap
- :param float vmax: The value to use for the max of the colormap
- :return: The widget plotting the scatter plot
- :rtype: silx.gui.plot.ScatterView.ScatterView
- """
- plt = ScatterView()
-
- # Update default colormap with input parameters
- colormap = plt.getPlotWidget().getDefaultColormap()
- if cmap is not None:
- colormap.setName(cmap)
- assert norm in colors.Colormap.NORMALIZATIONS
- colormap.setNormalization(norm)
- colormap.setVMin(vmin)
- colormap.setVMax(vmax)
- plt.getPlotWidget().setDefaultColormap(colormap)
-
- if x is not None and y is not None: # Add a scatter plot
- x = numpy.array(x, copy=True).reshape(-1)
- y = numpy.array(y, copy=True).reshape(-1)
- assert len(x) == len(y)
-
- if value is None:
- value = numpy.ones(len(x), dtype=numpy.float32)
-
- elif isinstance(value, abc.Iterable):
- value = numpy.array(value, copy=True).reshape(-1)
- assert len(x) == len(value)
-
- else:
- value = numpy.ones(len(x), dtype=numpy.float64) * value
-
- plt.setData(x, y, value)
- item = plt.getScatterItem()
- if marker is not None:
- item.setSymbol(marker)
- if size is not None:
- item.setSymbolSize(numpy.sqrt(size))
-
- plt.resetZoom()
-
- plt.show()
- _plots.insert(0, plt.getPlotWidget())
- return plt
-
-
-class _GInputResult(tuple):
- """Object storing :func:`ginput` result
-
- :param position: Selected point coordinates in the plot (x, y)
- :param Item item: Plot item under the selected position
- :param indices: Selected indices in the data of the item.
- For a curve it is a list of indices, for an image it is (row, column)
- :param data: Value of data at selected indices.
- For a curve it is an array of values, for an image it is a single value
- """
-
- def __new__(cls, position, item, indices, data):
- return super(_GInputResult, cls).__new__(cls, position)
-
- def __init__(self, position, item, indices, data):
- self._itemRef = weakref.ref(item) if item is not None else None
- self._indices = numpy.array(indices, copy=True)
- if isinstance(data, abc.Iterable):
- self._data = numpy.array(data, copy=True)
- else:
- self._data = data
-
- def getItem(self):
- """Returns the item at the selected position if any.
-
- :return: plot item under the selected postion.
- It is None if there was no item at that position or if
- it is no more in the plot.
- :rtype: silx.gui.plot.items.Item"""
- return None if self._itemRef is None else self._itemRef()
-
- def getIndices(self):
- """Returns indices in data array at the select position
-
- :return: 1D array of indices for curve and (row, column) for images
- :rtype: numpy.ndarray
- """
- return numpy.array(self._indices, copy=True)
-
- def getData(self):
- """Returns data value at the selected position.
-
- For curves, an array of (x, y) values close to the point is returned.
- For images, either a single value or a RGB(A) array is returned.
-
- :return: 2D array of (x, y) data values for curves (Nx2),
- a single value for data images and RGB(A) array for images.
- """
- if isinstance(self._data, numpy.ndarray):
- return numpy.array(self._data, copy=True)
- else:
- return self._data
-
-
-class _GInputHandler(roi.InteractiveRegionOfInterestManager):
- """Implements :func:`ginput`
-
- :param PlotWidget plot:
- :param int n: Max number of points to request
- :param float timeout: Timeout in seconds
- """
-
- def __init__(self, plot, n, timeout):
- super(_GInputHandler, self).__init__(plot)
-
- self._timeout = timeout
- self.__selections = collections.OrderedDict()
-
- window = plot.window() # Retrieve window containing PlotWidget
- statusBar = window.statusBar()
- self.sigMessageChanged.connect(statusBar.showMessage)
- self.setMaxRois(n)
- self.setValidationMode(self.ValidationMode.AUTO_ENTER)
- self.sigRoiAdded.connect(self.__added)
- self.sigRoiAboutToBeRemoved.connect(self.__removed)
-
- def exec_(self):
- """Request user inputs
-
- :return: List of selection points information
- """
- plot = self.parent()
- if plot is None:
- return
-
- window = plot.window() # Retrieve window containing PlotWidget
-
- # Add ROI point interactive mode action
- for toolbar in window.findChildren(qt.QToolBar):
- if isinstance(toolbar, InteractiveModeToolBar):
- break
- else: # Add a toolbar
- toolbar = qt.QToolBar()
- window.addToolBar(toolbar)
- toolbar.addAction(self.getInteractionModeAction(roi_items.PointROI))
-
- super(_GInputHandler, self).exec_(roiClass=roi_items.PointROI, timeout=self._timeout)
-
- if isinstance(toolbar, InteractiveModeToolBar):
- toolbar.removeAction(self.getInteractionModeAction(roi_items.PointROI))
- else:
- toolbar.setParent(None)
-
- return tuple(self.__selections.values())
-
- def __updateSelection(self, roi):
- """Perform picking and update selection list
-
- :param RegionOfInterest roi:
- """
- plot = self.parent()
- if plot is None:
- return # No plot, abort
-
- if not isinstance(roi, roi_items.PointROI):
- # Only handle points
- raise RuntimeError("Unexpected item")
-
- x, y = roi.getPosition()
- xPixel, yPixel = plot.dataToPixel(x, y, axis='left', check=False)
-
- # Pick item at selected position
- pickingResult = plot._pickTopMost(
- xPixel, yPixel,
- lambda item: isinstance(item, (items.ImageBase, items.Curve)))
-
- if pickingResult is None:
- result = _GInputResult((x, y),
- item=None,
- indices=numpy.array((), dtype=int),
- data=None)
- else:
- item = pickingResult.getItem()
- indices = pickingResult.getIndices(copy=True)
-
- if isinstance(item, items.Curve):
- xData = item.getXData(copy=False)[indices]
- yData = item.getYData(copy=False)[indices]
- result = _GInputResult((x, y),
- item=item,
- indices=indices,
- data=numpy.array((xData, yData)).T)
-
- elif isinstance(item, items.ImageBase):
- row, column = indices[0][0], indices[1][0]
- data = item.getData(copy=False)[row, column]
- result = _GInputResult((x, y),
- item=item,
- indices=(row, column),
- data=data)
-
- self.__selections[roi] = result
-
- def __added(self, roi):
- """Handle new ROI added
-
- :param RegionOfInterest roi:
- """
- if isinstance(roi, roi_items.PointROI):
- # Only handle points
- roi.setName('%d' % len(self.__selections))
- self.__updateSelection(roi)
- roi.sigRegionChanged.connect(self.__regionChanged)
-
- def __removed(self, roi):
- """Handle ROI removed"""
- if self.__selections.pop(roi, None) is not None:
- roi.sigRegionChanged.disconnect(self.__regionChanged)
-
- def __regionChanged(self):
- """Handle update of a ROI"""
- roi = self.sender()
- self.__updateSelection(roi)
-
-
-def ginput(n=1, timeout=30, plot=None):
- """Get input points on a plot.
-
- If no plot is provided, it uses a plot widget created with
- either :func:`silx.sx.plot` or :func:`silx.sx.imshow`.
-
- How to use:
-
- >>> from silx import sx
-
- >>> sx.imshow(image) # Plot the image
- >>> sx.ginput(1) # Request selection on the image plot
- ((0.598, 1.234))
-
- How to get more information about the selected positions:
-
- >>> positions = sx.ginput(1)
-
- >>> positions[0].getData() # Returns value(s) at selected position
-
- >>> positions[0].getIndices() # Returns data indices at selected position
-
- >>> positions[0].getItem() # Returns plot item at selected position
-
- :param int n: Number of points the user need to select
- :param float timeout: Timeout in seconds before ginput returns
- event if selection is not completed
- :param silx.gui.plot.PlotWidget.PlotWidget plot: An optional PlotWidget
- from which to get input
- :return: List of clicked points coordinates (x, y) in plot
- :raise ValueError: If provided plot is not a PlotWidget
- """
- if plot is None:
- # Select most recent visible plot widget
- for widget in _plots:
- if widget.isVisible():
- plot = widget
- break
- else: # If no plot widget is visible, take the most recent one
- try:
- plot = _plots[0]
- except IndexError:
- pass
- else:
- plot.show()
-
- if plot is None:
- _logger.warning('No plot available to perform ginput, create one')
- plot = Plot1D()
- plot.show()
- _plots.insert(0, plot)
-
- plot.raise_() # So window becomes the top level one
-
- _logger.info('Performing ginput with plot widget %s', str(plot))
- handler = _GInputHandler(plot, n, timeout)
- points = handler.exec_()
-
- return points
diff --git a/silx/test/__init__.py b/silx/test/__init__.py
deleted file mode 100644
index 2063ab5..0000000
--- a/silx/test/__init__.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""This package provides access to the full silx test suite.
-
-It is possible to disable tests depending on Qt by setting
-`silx.test.utils.test_options.WITH_QT_TEST = False`
-It will skip all tests from :mod:`silx.test.gui`.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "21/12/2018"
-
-
-import logging
-import unittest
-
-from silx.test.utils import test_options
-
-
-logger = logging.getLogger(__name__)
-
-
-def suite():
- # In case Qt tests are not run, do not load sx as it loads Qt
- # instead add a skipped test class to the suite
- if not test_options.WITH_QT_TEST:
- # Explicitly disabled tests
- msg = "silx.sx tests disabled %s" % test_options.WITH_QT_TEST_REASON
- logger.warning(msg)
-
- class SkipSXTest(unittest.TestCase):
- def runTest(self):
- self.skipTest(test_options.WITH_QT_TEST_REASON)
-
- def test_sx_suite():
- suite = unittest.TestSuite()
- suite.addTest(SkipSXTest())
- return suite
- else:
- from .test_sx import suite as test_sx_suite
-
- from . import test_version
- from . import test_resources
- from ..io import test as test_io
- from ..math import test as test_math
- from ..image import test as test_image
- from ..gui import test as test_gui
- from ..utils import test as test_utils
- from ..opencl import test as test_ocl
- from ..app import test as test_app
-
- test_suite = unittest.TestSuite()
- # test sx first cause qui tests load ipython module
- test_suite.addTest(test_sx_suite())
- test_suite.addTest(test_gui.suite())
- # then test no-gui tests
- test_suite.addTest(test_utils.suite())
- test_suite.addTest(test_version.suite())
- test_suite.addTest(test_resources.suite())
- test_suite.addTest(test_io.suite())
- test_suite.addTest(test_math.suite())
- test_suite.addTest(test_image.suite())
- test_suite.addTest(test_ocl.suite())
- test_suite.addTest(test_app.suite())
-
- return test_suite
-
-
-def run_tests(*args, **kwargs):
- """Run test complete test_suite
-
- Provided arguments are passed to :class:`unittest.TextTestRunner`.
- """
- test_options.configure()
- runner = unittest.TextTestRunner(*args, **kwargs)
- if not runner.run(suite()).wasSuccessful():
- print("Test suite failed")
- return 1
- else:
- print("Test suite succeeded")
- return 0
diff --git a/silx/test/test_resources.py b/silx/test/test_resources.py
deleted file mode 100644
index 7f5f432..0000000
--- a/silx/test/test_resources.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test for resource files management."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "08/03/2019"
-
-
-import os
-import unittest
-import shutil
-import tempfile
-
-import silx.resources
-
-
-class TestResources(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
- super(TestResources, cls).setUpClass()
-
- cls.tmpDirectory = tempfile.mkdtemp(prefix="resource_")
- os.mkdir(os.path.join(cls.tmpDirectory, "gui"))
- destination_dir = os.path.join(cls.tmpDirectory, "gui", "icons")
- os.mkdir(destination_dir)
- source = silx.resources.resource_filename("gui/icons/zoom-in.png")
- destination = os.path.join(destination_dir, "foo.png")
- shutil.copy(source, destination)
- source = silx.resources.resource_filename("gui/icons/zoom-out.svg")
- destination = os.path.join(destination_dir, "close.png")
- shutil.copy(source, destination)
-
- @classmethod
- def tearDownClass(cls):
- super(TestResources, cls).tearDownClass()
- shutil.rmtree(cls.tmpDirectory)
-
- def setUp(self):
- # Store the original configuration
- self._oldResources = dict(silx.resources._RESOURCE_DIRECTORIES)
- unittest.TestCase.setUp(self)
-
- def tearDown(self):
- unittest.TestCase.tearDown(self)
- # Restiture the original configuration
- silx.resources._RESOURCE_DIRECTORIES = self._oldResources
-
- def test_resource_dir(self):
- """Get a resource directory"""
- icons_dirname = silx.resources.resource_filename('gui/icons/')
- self.assertTrue(os.path.isdir(icons_dirname))
-
- def test_resource_file(self):
- """Get a resource file name"""
- filename = silx.resources.resource_filename('gui/icons/colormap.png')
- self.assertTrue(os.path.isfile(filename))
-
- def test_resource_nonexistent(self):
- """Get a non existent resource"""
- filename = silx.resources.resource_filename('non_existent_file.txt')
- self.assertFalse(os.path.exists(filename))
-
- def test_isdir(self):
- self.assertTrue(silx.resources.is_dir('gui/icons'))
-
- def test_not_isdir(self):
- self.assertFalse(silx.resources.is_dir('gui/icons/colormap.png'))
-
- def test_list_dir(self):
- result = silx.resources.list_dir('gui/icons')
- self.assertTrue(len(result) > 10)
-
- # With prefixed resources
-
- def test_resource_dir_with_prefix(self):
- """Get a resource directory"""
- icons_dirname = silx.resources.resource_filename('silx:gui/icons/')
- self.assertTrue(os.path.isdir(icons_dirname))
-
- def test_resource_file_with_prefix(self):
- """Get a resource file name"""
- filename = silx.resources.resource_filename('silx:gui/icons/colormap.png')
- self.assertTrue(os.path.isfile(filename))
-
- def test_resource_nonexistent_with_prefix(self):
- """Get a non existent resource"""
- filename = silx.resources.resource_filename('silx:non_existent_file.txt')
- self.assertFalse(os.path.exists(filename))
-
- def test_isdir_with_prefix(self):
- self.assertTrue(silx.resources.is_dir('silx:gui/icons'))
-
- def test_not_isdir_with_prefix(self):
- self.assertFalse(silx.resources.is_dir('silx:gui/icons/colormap.png'))
-
- def test_list_dir_with_prefix(self):
- result = silx.resources.list_dir('silx:gui/icons')
- self.assertTrue(len(result) > 10)
-
- # Test new repository
-
- def test_repository_not_exists(self):
- """The resource from 'test' is available"""
- self.assertRaises(ValueError, silx.resources.resource_filename, 'test:foo.png')
-
- def test_adding_test_directory(self):
- """The resource from 'test' is available"""
- silx.resources.register_resource_directory("test", "silx.test.resources", forced_path=self.tmpDirectory)
- path = silx.resources.resource_filename('test:gui/icons/foo.png')
- self.assertTrue(os.path.exists(path))
-
- def test_adding_test_directory_no_override(self):
- """The resource from 'silx' is still available"""
- silx.resources.register_resource_directory("test", "silx.test.resources", forced_path=self.tmpDirectory)
- filename1 = silx.resources.resource_filename('gui/icons/close.png')
- filename2 = silx.resources.resource_filename('silx:gui/icons/close.png')
- filename3 = silx.resources.resource_filename('test:gui/icons/close.png')
- self.assertTrue(os.path.isfile(filename1))
- self.assertTrue(os.path.isfile(filename2))
- self.assertTrue(os.path.isfile(filename3))
- self.assertEqual(filename1, filename2)
- self.assertNotEqual(filename1, filename3)
-
- def test_adding_test_directory_non_existing(self):
- """A resource while not exists in test is not available anyway it exists
- in silx"""
- silx.resources.register_resource_directory("test", "silx.test.resources", forced_path=self.tmpDirectory)
- resource_name = "gui/icons/colormap.png"
- path = silx.resources.resource_filename('test:' + resource_name)
- path2 = silx.resources.resource_filename('silx:' + resource_name)
- self.assertFalse(os.path.exists(path))
- self.assertTrue(os.path.exists(path2))
-
-
-class TestResourcesWithoutPkgResources(TestResources):
-
- @classmethod
- def setUpClass(cls):
- super(TestResourcesWithoutPkgResources, cls).setUpClass()
- cls._old = silx.resources.pkg_resources
- silx.resources.pkg_resources = None
-
- @classmethod
- def tearDownClass(cls):
- silx.resources.pkg_resources = cls._old
- del cls._old
- super(TestResourcesWithoutPkgResources, cls).tearDownClass()
-
-
-class TestResourcesWithCustomDirectory(TestResources):
-
- @classmethod
- def setUpClass(cls):
- super(TestResourcesWithCustomDirectory, cls).setUpClass()
- cls._old = silx.resources._RESOURCES_DIR
- base = os.path.dirname(silx.resources.__file__)
- silx.resources._RESOURCES_DIR = base
-
- @classmethod
- def tearDownClass(cls):
- silx.resources._RESOURCES_DIR = cls._old
- del cls._old
- super(TestResourcesWithCustomDirectory, cls).tearDownClass()
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestResources))
- test_suite.addTest(loadTests(TestResourcesWithoutPkgResources))
- test_suite.addTest(loadTests(TestResourcesWithCustomDirectory))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/test/test_sx.py b/silx/test/test_sx.py
deleted file mode 100644
index a32cc06..0000000
--- a/silx/test/test_sx.py
+++ /dev/null
@@ -1,292 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "06/11/2018"
-
-
-import logging
-import unittest
-import numpy
-
-from silx.utils.testutils import ParametricTestCase
-from silx.test.utils import test_options
-
-from silx.gui import qt
-# load TestCaseQt before sx
-from silx.gui.utils.testutils import TestCaseQt
-from silx.gui.colors import rgba
-from silx.gui.colors import Colormap
-
-
-_logger = logging.getLogger(__name__)
-
-
-class SXTest(TestCaseQt, ParametricTestCase):
- """Test the sx module"""
-
- def _expose_and_close(self, plot):
- self.qWaitForWindowExposed(plot)
- self.qapp.processEvents()
- plot.setAttribute(qt.Qt.WA_DeleteOnClose)
- plot.close()
-
- def test_plot(self):
- """Test plot function"""
- from silx import sx # Lazy loading to avoid it to create QApplication
-
- y = numpy.random.random(100)
- x = numpy.arange(len(y)) * 0.5
-
- # Nothing
- plt = sx.plot()
- self._expose_and_close(plt)
-
- # y
- plt = sx.plot(y, title='y')
- self._expose_and_close(plt)
-
- # y, style
- plt = sx.plot(y, 'blued ', title='y, "blued "')
- self._expose_and_close(plt)
-
- # x, y
- plt = sx.plot(x, y, title='x, y')
- self._expose_and_close(plt)
-
- # x, y, style
- plt = sx.plot(x, y, 'ro-', xlabel='x', title='x, y, "ro-"')
- self._expose_and_close(plt)
-
- # x, y, style, y
- plt = sx.plot(x, y, 'ro-', y ** 2, xlabel='x', ylabel='y',
- title='x, y, "ro-", y ** 2')
- self._expose_and_close(plt)
-
- # x, y, style, y, style
- plt = sx.plot(x, y, 'ro-', y ** 2, 'b--',
- title='x, y, "ro-", y ** 2, "b--"')
- self._expose_and_close(plt)
-
- # x, y, style, x, y, style
- plt = sx.plot(x, y, 'ro-', x, y ** 2, 'b--',
- title='x, y, "ro-", x, y ** 2, "b--"')
- self._expose_and_close(plt)
-
- # x, y, x, y
- plt = sx.plot(x, y, x, y ** 2, title='x, y, x, y ** 2')
- self._expose_and_close(plt)
-
- def test_imshow(self):
- """Test imshow function"""
- from silx import sx # Lazy loading to avoid it to create QApplication
-
- img = numpy.arange(100.).reshape(10, 10) + 1
-
- # Nothing
- plt = sx.imshow()
- self._expose_and_close(plt)
-
- # image
- plt = sx.imshow(img)
- self._expose_and_close(plt)
-
- # image, named cmap
- plt = sx.imshow(img, cmap='jet', title='jet cmap')
- self._expose_and_close(plt)
-
- # image, custom colormap
- plt = sx.imshow(img, cmap=Colormap(), title='custom colormap')
- self._expose_and_close(plt)
-
- # image, log cmap
- plt = sx.imshow(img, norm='log', title='log cmap')
- self._expose_and_close(plt)
-
- # image, fixed range
- plt = sx.imshow(img, vmin=10, vmax=20,
- title='[10,20] cmap')
- self._expose_and_close(plt)
-
- # image, keep ratio
- plt = sx.imshow(img, aspect=True,
- title='keep ratio')
- self._expose_and_close(plt)
-
- # image, change origin and scale
- plt = sx.imshow(img, origin=(10, 10), scale=(2, 2),
- title='origin=(10, 10), scale=(2, 2)')
- self._expose_and_close(plt)
-
- # image, origin='lower'
- plt = sx.imshow(img, origin='upper', title='origin="lower"')
- self._expose_and_close(plt)
-
- def test_scatter(self):
- """Test scatter function"""
- from silx import sx # Lazy loading to avoid it to create QApplication
-
- x = numpy.arange(100)
- y = numpy.arange(100)
- values = numpy.arange(100)
-
- # simple scatter
- plt = sx.scatter(x, y, values)
- self._expose_and_close(plt)
-
- # No value
- plt = sx.scatter(x, y, values)
- self._expose_and_close(plt)
-
- # single value
- plt = sx.scatter(x, y, 10.)
- self._expose_and_close(plt)
-
- # set size
- plt = sx.scatter(x, y, values, size=20)
- self._expose_and_close(plt)
-
- # set colormap
- plt = sx.scatter(x, y, values, cmap='jet')
- self._expose_and_close(plt)
-
- # set colormap range
- plt = sx.scatter(x, y, values, vmin=2, vmax=50)
- self._expose_and_close(plt)
-
- # set colormap normalisation
- plt = sx.scatter(x, y, values, norm='log')
- self._expose_and_close(plt)
-
- def test_ginput(self):
- """Test ginput function
-
- This does NOT perform interactive tests
- """
- from silx import sx # Lazy loading to avoid it to create QApplication
-
- for create_plot in (sx.plot, sx.imshow, sx.scatter):
- with self.subTest(create_plot.__name__):
- plt = create_plot()
- self.qWaitForWindowExposed(plt)
- self.qapp.processEvents()
-
- result = sx.ginput(1, timeout=0.1)
- self.assertEqual(len(result), 0)
-
- plt.setAttribute(qt.Qt.WA_DeleteOnClose)
- plt.close()
-
- @unittest.skipUnless(test_options.WITH_GL_TEST,
- test_options.WITH_GL_TEST_REASON)
- def test_contour3d(self):
- """Test contour3d function"""
- from silx import sx # Lazy loading to avoid it to create QApplication
-
- coords = numpy.linspace(-10, 10, 64)
- z = coords.reshape(-1, 1, 1)
- y = coords.reshape(1, -1, 1)
- x = coords.reshape(1, 1, -1)
- data = numpy.sin(x * y * z) / (x * y * z)
-
- # Just data
- window = sx.contour3d(data)
-
- isosurfaces = window.getIsosurfaces()
- self.assertEqual(len(isosurfaces), 1)
-
- if not window.getPlot3DWidget().isValid():
- self.skipTest("OpenGL context is not valid")
-
- # N contours + color
- colors = ['red', 'green', 'blue']
- window = sx.contour3d(data, copy=False, contours=len(colors),
- color=colors)
-
- isosurfaces = window.getIsosurfaces()
- self.assertEqual(len(isosurfaces), len(colors))
- for iso, color in zip(isosurfaces, colors):
- self.assertEqual(rgba(iso.getColor()), rgba(color))
-
- # by isolevel, single color
- contours = 0.2, 0.5
- window = sx.contour3d(data, copy=False, contours=contours,
- color='yellow')
-
- isosurfaces = window.getIsosurfaces()
- self.assertEqual(len(isosurfaces), len(contours))
- for iso, level in zip(isosurfaces, contours):
- self.assertEqual(iso.getLevel(), level)
- self.assertEqual(rgba(iso.getColor()),
- rgba('yellow'))
-
- # Single isolevel, colormap
- window = sx.contour3d(data, copy=False, contours=0.5,
- colormap='gray', vmin=0.6, opacity=0.4)
-
- isosurfaces = window.getIsosurfaces()
- self.assertEqual(len(isosurfaces), 1)
- self.assertEqual(isosurfaces[0].getLevel(), 0.5)
- self.assertEqual(rgba(isosurfaces[0].getColor()),
- (0., 0., 0., 0.4))
-
- @unittest.skipUnless(test_options.WITH_GL_TEST,
- test_options.WITH_GL_TEST_REASON)
- def test_points3d(self):
- """Test points3d function"""
- from silx import sx # Lazy loading to avoid it to create QApplication
-
- x = numpy.random.random(1024)
- y = numpy.random.random(1024)
- z = numpy.random.random(1024)
- values = numpy.random.random(1024)
-
- # 3D positions, no value
- window = sx.points3d(x, y, z)
-
- if not window.getSceneWidget().isValid():
- self.skipTest("OpenGL context is not valid")
-
- # 3D positions, values
- window = sx.points3d(x, y, z, values, mode='2dsquare',
- colormap='magma', vmin=0.4, vmax=0.5)
-
- # 2D positions, no value
- window = sx.points3d(x, y)
-
- # 2D positions, values
- window = sx.points3d(x, y, values=values, mode=',',
- colormap='magma', vmin=0.4, vmax=0.5)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(SXTest))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/test/test_version.py b/silx/test/test_version.py
deleted file mode 100644
index bb91e4e..0000000
--- a/silx/test/test_version.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Basic test of top-level package import and existence of version info."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "26/02/2016"
-
-import unittest
-
-import silx
-
-
-class TestVersion(unittest.TestCase):
- def test_version(self):
- self.assertTrue(isinstance(silx.version, str))
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestVersion))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/test/utils.py b/silx/test/utils.py
deleted file mode 100644
index 77746c6..0000000
--- a/silx/test/utils.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Utilities for writing tests.
-
-- :func:`temp_dir` provides a with context to create/delete a temporary
- directory.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "03/01/2019"
-
-
-import sys
-import contextlib
-import os
-import numpy
-import shutil
-import tempfile
-from ..resources import ExternalResources
-
-
-utilstest = ExternalResources(project="silx",
- url_base="http://www.silx.org/pub/silx/",
- env_key="SILX_DATA",
- timeout=60)
-"This is the instance to be used. Singleton-like feature provided by module"
-
-
-class _TestOptions(object):
-
- def __init__(self):
- self.WITH_QT_TEST = True
- """Qt tests are included"""
-
- self.WITH_QT_TEST_REASON = ""
- """Reason for Qt tests are disabled if any"""
-
- self.WITH_OPENCL_TEST = True
- """OpenCL tests are included"""
-
- self.WITH_GL_TEST = True
- """OpenGL tests are included"""
-
- self.WITH_GL_TEST_REASON = ""
- """Reason for OpenGL tests are disabled if any"""
-
- self.TEST_LOW_MEM = False
- """Skip tests using too much memory"""
-
- def configure(self, parsed_options=None):
- """Configure the TestOptions class from the command line arguments and the
- environment variables
- """
- if parsed_options is not None and not parsed_options.gui:
- self.WITH_QT_TEST = False
- self.WITH_QT_TEST_REASON = "Skipped by command line"
- elif os.environ.get('WITH_QT_TEST', 'True') == 'False':
- self.WITH_QT_TEST = False
- self.WITH_QT_TEST_REASON = "Skipped by WITH_QT_TEST env var"
- elif sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
- self.WITH_QT_TEST = False
- self.WITH_QT_TEST_REASON = "DISPLAY env variable not set"
-
- if (parsed_options is not None and not parsed_options.opencl) or os.environ.get('SILX_OPENCL', 'True') == 'False':
- self.WITH_OPENCL_TEST = False
- # That's an easy way to skip OpenCL tests
- # It disable the use of OpenCL on the full silx project
- os.environ['SILX_OPENCL'] = "False"
-
- if parsed_options is not None and not parsed_options.opengl:
- self.WITH_GL_TEST = False
- self.WITH_GL_TEST_REASON = "Skipped by command line"
- elif os.environ.get('WITH_GL_TEST', 'True') == 'False':
- self.WITH_GL_TEST = False
- self.WITH_GL_TEST_REASON = "Skipped by WITH_GL_TEST env var"
- elif sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
- self.WITH_GL_TEST = False
- self.WITH_GL_TEST_REASON = "DISPLAY env variable not set"
- else:
- try:
- import OpenGL
- except ImportError:
- self.WITH_GL_TEST = False
- self.WITH_GL_TEST_REASON = "OpenGL package not available"
-
- if (parsed_options is not None and parsed_options.low_mem) or os.environ.get('SILX_TEST_LOW_MEM', 'True') == 'False':
- self.TEST_LOW_MEM = True
-
- if self.WITH_QT_TEST:
- try:
- from silx.gui import qt
- except ImportError:
- pass
- else:
- if sys.platform == "win32" and qt.qVersion() == "5.9.2":
- self.SKIP_TEST_FOR_ISSUE_936 = True
-
- def add_parser_argument(self, parser):
- """Add extrat arguments to the test argument parser
-
- :param ArgumentParser parser: An argument parser
- """
-
- parser.add_argument("-x", "--no-gui", dest="gui", default=True,
- action="store_false",
- help="Disable the test of the graphical use interface")
- parser.add_argument("-g", "--no-opengl", dest="opengl", default=True,
- action="store_false",
- help="Disable tests using OpenGL")
- parser.add_argument("-o", "--no-opencl", dest="opencl", default=True,
- action="store_false",
- help="Disable the test of the OpenCL part")
- parser.add_argument("-l", "--low-mem", dest="low_mem", default=False,
- action="store_true",
- help="Disable test with large memory consumption (>100Mbyte")
-
-
-test_options = _TestOptions()
-"""Singleton providing configuration information for all the tests"""
-
-
-# Temporary directory context #################################################
-
-@contextlib.contextmanager
-def temp_dir():
- """with context providing a temporary directory.
-
- >>> import os.path
- >>> with temp_dir() as tmp:
- ... print(os.path.isdir(tmp)) # Use tmp directory
- """
- tmp_dir = tempfile.mkdtemp()
- try:
- yield tmp_dir
- finally:
- shutil.rmtree(tmp_dir)
-
-
-# Synthetic data and random noise #############################################
-def add_gaussian_noise(y, stdev=1., mean=0.):
- """Add random gaussian noise to synthetic data.
-
- :param ndarray y: Array of synthetic data
- :param float mean: Mean of the gaussian distribution of noise.
- :param float stdev: Standard deviation of the gaussian distribution of
- noise.
- :return: Array of data with noise added
- """
- noise = numpy.random.normal(mean, stdev, size=y.size)
- noise.shape = y.shape
- return y + noise
-
-
-def add_poisson_noise(y):
- """Add random noise from a poisson distribution to synthetic data.
-
- :param ndarray y: Array of synthetic data
- :return: Array of data with noise added
- """
- yn = numpy.random.poisson(y)
- yn.shape = y.shape
- return yn
-
-
-def add_relative_noise(y, max_noise=5.):
- """Add relative random noise to synthetic data. The maximum noise level
- is given in percents.
-
- An array of noise in the interval [-max_noise, max_noise] (continuous
- uniform distribution) is generated, and applied to the data the
- following way:
-
- :math:`yn = y * (1. + noise / 100.)`
-
- :param ndarray y: Array of synthetic data
- :param float max_noise: Maximum percentage of noise
- :return: Array of data with noise added
- """
- noise = max_noise * (2 * numpy.random.random(size=y.size) - 1)
- noise.shape = y.shape
- return y * (1. + noise / 100.)
diff --git a/silx/third_party/setup.py b/silx/third_party/setup.py
deleted file mode 100644
index dd3d302..0000000
--- a/silx/third_party/setup.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# coding: ascii
-#
-# JK: Numpy.distutils which imports this does not handle utf-8 in version<1.12
-#
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-__authors__ = ["Valentin Valls"]
-__license__ = "MIT"
-__date__ = "23/04/2018"
-
-import os
-from numpy.distutils.misc_util import Configuration
-
-
-def configuration(parent_package='', top_path=None):
- config = Configuration('third_party', parent_package, top_path)
- # includes _local only if it is available
- local_path = os.path.join(top_path, "silx", "third_party", "_local")
- if os.path.exists(local_path):
- config.add_subpackage('_local')
- config.add_subpackage('_local.scipy_spatial')
- return config
-
-
-if __name__ == "__main__":
- from numpy.distutils.core import setup
- setup(configuration=configuration)
diff --git a/silx/utils/ExternalResources.py b/silx/utils/ExternalResources.py
deleted file mode 100644
index e21381c..0000000
--- a/silx/utils/ExternalResources.py
+++ /dev/null
@@ -1,320 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Helper to access to external resources.
-"""
-
-__authors__ = ["Thomas Vincent", "J. Kieffer"]
-__license__ = "MIT"
-__date__ = "08/03/2019"
-
-
-import os
-import threading
-import json
-import logging
-import tempfile
-import unittest
-import six
-
-logger = logging.getLogger(__name__)
-
-
-class ExternalResources(object):
- """Utility class which allows to download test-data from www.silx.org
- and manage the temporary data during the tests.
-
- """
-
- def __init__(self, project,
- url_base,
- env_key=None,
- timeout=60):
- """Constructor of the class
-
- :param str project: name of the project, like "silx"
- :param str url_base: base URL for the data, like "http://www.silx.org/pub"
- :param str env_key: name of the environment variable which contains the
- test_data directory, like "SILX_DATA".
- If None (default), then the name of the
- environment variable is built from the project argument:
- "<PROJECT>_DATA".
- The environment variable is optional: in case it is not set,
- a directory in the temporary folder is used.
- :param timeout: time in seconds before it breaks
- """
- self.project = project
- self._initialized = False
- self.sem = threading.Semaphore()
-
- self.env_key = env_key or (self.project.upper() + "_TESTDATA")
- self.url_base = url_base
- self.all_data = set()
- self.timeout = timeout
- self._data_home = None
-
- @property
- def data_home(self):
- """Returns the data_home path and make sure it exists in the file
- system."""
- if self._data_home is not None:
- return self._data_home
-
- data_home = os.environ.get(self.env_key)
- if data_home is None:
- try:
- import getpass
- name = getpass.getuser()
- except Exception:
- if "getlogin" in dir(os):
- name = os.getlogin()
- elif "USER" in os.environ:
- name = os.environ["USER"]
- elif "USERNAME" in os.environ:
- name = os.environ["USERNAME"]
- else:
- name = "uid" + str(os.getuid())
-
- basename = "%s_testdata_%s" % (self.project, name)
- data_home = os.path.join(tempfile.gettempdir(), basename)
- if not os.path.exists(data_home):
- os.makedirs(data_home)
- self._data_home = data_home
- return data_home
-
- def _initialize_data(self):
- """Initialize for downloading test data"""
- if not self._initialized:
- with self.sem:
- if not self._initialized:
- self.testdata = os.path.join(self.data_home, "all_testdata.json")
- if os.path.exists(self.testdata):
- with open(self.testdata) as f:
- self.all_data = set(json.load(f))
- self._initialized = True
-
- def clean_up(self):
- pass
-
- def getfile(self, filename):
- """Downloads the requested file from web-server available
- at https://www.silx.org/pub/silx/
-
- :param: relative name of the image.
- :return: full path of the locally saved file.
- """
- logger.debug("ExternalResources.getfile('%s')", filename)
-
- if not self._initialized:
- self._initialize_data()
-
- fullfilename = os.path.abspath(os.path.join(self.data_home, filename))
-
- if not os.path.isfile(fullfilename):
- logger.debug("Trying to download image %s, timeout set to %ss",
- filename, self.timeout)
- dictProxies = {}
- if "http_proxy" in os.environ:
- dictProxies['http'] = os.environ["http_proxy"]
- dictProxies['https'] = os.environ["http_proxy"]
- if "https_proxy" in os.environ:
- dictProxies['https'] = os.environ["https_proxy"]
- if dictProxies:
- proxy_handler = six.moves.urllib.request.ProxyHandler(dictProxies)
- opener = six.moves.urllib.request.build_opener(proxy_handler).open
- else:
- opener = six.moves.urllib.request.urlopen
-
- logger.debug("wget %s/%s", self.url_base, filename)
- try:
- data = opener("%s/%s" % (self.url_base, filename),
- data=None, timeout=self.timeout).read()
- logger.info("Image %s successfully downloaded.", filename)
- except six.moves.urllib.error.URLError:
- raise unittest.SkipTest("network unreachable.")
-
- if not os.path.isdir(os.path.dirname(fullfilename)):
- # Create sub-directory if needed
- os.makedirs(os.path.dirname(fullfilename))
-
- try:
- with open(fullfilename, "wb") as outfile:
- outfile.write(data)
- except IOError:
- raise IOError("unable to write downloaded \
- data to disk at %s" % self.data_home)
-
- if not os.path.isfile(fullfilename):
- raise RuntimeError(
- """Could not automatically download test images %s!
- If you are behind a firewall, please set both environment variable http_proxy and https_proxy.
- This even works under windows !
- Otherwise please try to download the images manually from
- %s/%s""" % (filename, self.url_base, filename))
-
- if filename not in self.all_data:
- self.all_data.add(filename)
- image_list = list(self.all_data)
- image_list.sort()
- try:
- with open(self.testdata, "w") as fp:
- json.dump(image_list, fp, indent=4)
- except IOError:
- logger.debug("Unable to save JSON list")
-
- return fullfilename
-
- def getdir(self, dirname):
- """Downloads the requested tarball from the server
- https://www.silx.org/pub/silx/
- and unzips it into the data directory
-
- :param: relative name of the image.
- :return: list of files with their full path.
- """
- lodn = dirname.lower()
- if (lodn.endswith("tar") or lodn.endswith("tgz") or
- lodn.endswith("tbz2") or lodn.endswith("tar.gz") or
- lodn.endswith("tar.bz2")):
- import tarfile
- engine = tarfile.TarFile.open
- elif lodn.endswith("zip"):
- import zipfile
- engine = zipfile.ZipFile
- else:
- raise RuntimeError("Unsupported archive format. Only tar and zip "
- "are currently supported")
- full_path = self.getfile(dirname)
- with engine(full_path, mode="r") as fd:
- output = os.path.join(self.data_home, dirname + "__content")
- fd.extractall(output)
- if lodn.endswith("zip"):
- result = [os.path.join(output, i) for i in fd.namelist()]
- else:
- result = [os.path.join(output, i) for i in fd.getnames()]
- return result
-
- def get_file_and_repack(self, filename):
- """
- Download the requested file, decompress and repack it to bz2 and gz.
-
- :param str filename: name of the image.
- :rtype: str
- :return: full path of the locally saved file
- """
- if not self._initialized:
- self._initialize_data()
- if filename not in self.all_data:
- self.all_data.add(filename)
- image_list = list(self.all_data)
- image_list.sort()
- try:
- with open(self.testdata, "w") as fp:
- json.dump(image_list, fp, indent=4)
- except IOError:
- logger.debug("Unable to save JSON list")
- baseimage = os.path.basename(filename)
- logger.info("UtilsTest.getimage('%s')" % baseimage)
-
- if not os.path.exists(self.data_home):
- os.makedirs(self.data_home)
- fullimagename = os.path.abspath(os.path.join(self.data_home, baseimage))
-
- if baseimage.endswith(".bz2"):
- bzip2name = baseimage
- basename = baseimage[:-4]
- gzipname = basename + ".gz"
- elif baseimage.endswith(".gz"):
- gzipname = baseimage
- basename = baseimage[:-3]
- bzip2name = basename + ".bz2"
- else:
- basename = baseimage
- gzipname = baseimage + "gz2"
- bzip2name = basename + ".bz2"
-
- fullimagename_gz = os.path.abspath(os.path.join(self.data_home, gzipname))
- fullimagename_raw = os.path.abspath(os.path.join(self.data_home, basename))
- fullimagename_bz2 = os.path.abspath(os.path.join(self.data_home, bzip2name))
-
- # The files are recreated from the bz2 file
- if not os.path.isfile(fullimagename_bz2):
- self.getfile(bzip2name)
- if not os.path.isfile(fullimagename_bz2):
- raise RuntimeError(
- """Could not automatically download test images %s!
- If you are behind a firewall, please set the environment variable http_proxy.
- Otherwise please try to download the images manually from
- %s""" % (self.url_base, filename))
-
- try:
- import bz2
- except ImportError:
- raise RuntimeError("bz2 library is needed to decompress data")
- try:
- import gzip
- except ImportError:
- gzip = None
-
- raw_file_exists = os.path.isfile(fullimagename_raw)
- gz_file_exists = os.path.isfile(fullimagename_gz)
- if not raw_file_exists or not gz_file_exists:
- with open(fullimagename_bz2, "rb") as f:
- data = f.read()
- decompressed = bz2.decompress(data)
-
- if not raw_file_exists:
- try:
- with open(fullimagename_raw, "wb") as fullimage:
- fullimage.write(decompressed)
- except IOError:
- raise IOError("unable to write decompressed \
- data to disk at %s" % self.data_home)
-
- if not gz_file_exists:
- if gzip is None:
- raise RuntimeError("gzip library is expected to recompress data")
- try:
- gzip.open(fullimagename_gz, "wb").write(decompressed)
- except IOError:
- raise IOError("unable to write gzipped \
- data to disk at %s" % self.data_home)
-
- return fullimagename
-
- def download_all(self, imgs=None):
- """Download all data needed for the test/benchmarks
-
- :param imgs: list of files to download, by default all
- :return: list of path with all files
- """
- if not self._initialized:
- self._initialize_data()
- if not imgs:
- imgs = self.all_data
- res = []
- for fn in imgs:
- logger.info("Downloading from silx.org: %s", fn)
- res.append(self.getfile(fn))
- return res
diff --git a/silx/utils/_have_openmp.pxd b/silx/utils/_have_openmp.pxd
deleted file mode 100644
index 40a2857..0000000
--- a/silx/utils/_have_openmp.pxd
+++ /dev/null
@@ -1,49 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-"""
-Store in a Cython module if it was compiled with OpenMP
-
-You have to patch the setup module like that:
-
-.. code-block:: python
-
- silx_include = os.path.join(top_path, "silx", "utils", "include")
- config.add_extension('my_extension',
- include_dirs=[silx_include],
- ...)
-
-Then you can include it like that in your Cython module:
-
-.. code-block:: python
-
- include "../../utils/_have_openmp.pxi"
-
-"""
-
-
-cdef extern from "silx_store_openmp.h":
- int COMPILED_WITH_OPENMP
-_COMPILED_WITH_OPENMP = COMPILED_WITH_OPENMP
diff --git a/silx/utils/array_like.py b/silx/utils/array_like.py
deleted file mode 100644
index 1a2e72e..0000000
--- a/silx/utils/array_like.py
+++ /dev/null
@@ -1,596 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Functions and classes for array-like objects, implementing common numpy
-array features for datasets or nested sequences, while trying to avoid copying
-data.
-
-Classes:
-
- - :class:`DatasetView`: Similar to a numpy view, to access
- a h5py dataset as if it was transposed, without casting it into a
- numpy array (this lets h5py handle reading the data from the
- file into memory, as needed).
- - :class:`ListOfImages`: Similar to a numpy view, to access
- a list of 2D numpy arrays as if it was a 3D array (possibly transposed),
- without casting it into a numpy array.
-
-Functions:
-
- - :func:`is_array`
- - :func:`is_list_of_arrays`
- - :func:`is_nested_sequence`
- - :func:`get_shape`
- - :func:`get_dtype`
- - :func:`get_concatenated_dtype`
-
-"""
-
-from __future__ import absolute_import, print_function, division
-
-import sys
-
-import numpy
-import six
-import numbers
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "26/04/2017"
-
-
-def is_array(obj):
- """Return True if object implements necessary attributes to be
- considered similar to a numpy array.
-
- Attributes needed are "shape", "dtype", "__getitem__"
- and "__array__".
-
- :param obj: Array-like object (numpy array, h5py dataset...)
- :return: boolean
- """
- # add more required attribute if necessary
- for attr in ("shape", "dtype", "__array__", "__getitem__"):
- if not hasattr(obj, attr):
- return False
- return True
-
-
-def is_list_of_arrays(obj):
- """Return True if object is a sequence of numpy arrays,
- e.g. a list of images as 2D arrays.
-
- :param obj: list of arrays
- :return: boolean"""
- # object must not be a numpy array
- if is_array(obj):
- return False
-
- # object must have a __len__ method
- if not hasattr(obj, "__len__"):
- return False
-
- # all elements in sequence must be arrays
- for arr in obj:
- if not is_array(arr):
- return False
-
- return True
-
-
-def is_nested_sequence(obj):
- """Return True if object is a nested sequence.
-
- A simple 1D sequence is considered to be a nested sequence.
-
- Numpy arrays and h5py datasets are not considered to be nested sequences.
-
- To test if an object is a nested sequence in a more general sense,
- including arrays and datasets, use::
-
- is_nested_sequence(obj) or is_array(obj)
-
- :param obj: nested sequence (numpy array, h5py dataset...)
- :return: boolean"""
- # object must not be a numpy array
- if is_array(obj):
- return False
-
- if not hasattr(obj, "__len__"):
- return False
-
- # obj must not be a list of (lists of) numpy arrays
- subsequence = obj
- while hasattr(subsequence, "__len__"):
- if is_array(subsequence):
- return False
- # strings cause infinite loops
- if isinstance(subsequence, six.string_types + (six.binary_type, )):
- return True
- subsequence = subsequence[0]
-
- # object has __len__ and is not an array
- return True
-
-
-def get_shape(array_like):
- """Return shape of an array like object.
-
- In case the object is a nested sequence but not an array or dataset
- (list of lists, tuples...), the size of each dimension is assumed to be
- uniform, and is deduced from the length of the first sequence.
-
- :param array_like: Array like object: numpy array, hdf5 dataset,
- multi-dimensional sequence
- :return: Shape of array, as a tuple of integers
- """
- if hasattr(array_like, "shape"):
- return array_like.shape
-
- shape = []
- subsequence = array_like
- while hasattr(subsequence, "__len__"):
- shape.append(len(subsequence))
- # strings cause infinite loops
- if isinstance(subsequence, six.string_types + (six.binary_type, )):
- break
- subsequence = subsequence[0]
-
- return tuple(shape)
-
-
-def get_dtype(array_like):
- """Return dtype of an array like object.
-
- In the case of a nested sequence, the type of the first value
- is inspected.
-
- :param array_like: Array like object: numpy array, hdf5 dataset,
- multi-dimensional nested sequence
- :return: numpy dtype of object
- """
- if hasattr(array_like, "dtype"):
- return array_like.dtype
-
- subsequence = array_like
- while hasattr(subsequence, "__len__"):
- # strings cause infinite loops
- if isinstance(subsequence, six.string_types + (six.binary_type, )):
- break
- subsequence = subsequence[0]
-
- return numpy.dtype(type(subsequence))
-
-
-def get_concatenated_dtype(arrays):
- """Return dtype of array resulting of concatenation
- of a list of arrays (without actually concatenating
- them).
-
- :param arrays: list of numpy arrays
- :return: resulting dtype after concatenating arrays
- """
- dtypes = {a.dtype for a in arrays}
- dummy = []
- for dt in dtypes:
- dummy.append(numpy.zeros((1, 1), dtype=dt))
- return numpy.array(dummy).dtype
-
-
-class ListOfImages(object):
- """This class provides a way to access values and slices in a stack of
- images stored as a list of 2D numpy arrays, without creating a 3D numpy
- array first.
-
- A transposition can be specified, as a 3-tuple of dimensions in the wanted
- order. For example, to transpose from ``xyz`` ``(0, 1, 2)`` into ``yzx``,
- the transposition tuple is ``(1, 2, 0)``
-
- All the 2D arrays in the list must have the same shape.
-
- The global dtype of the stack of images is the one that would be obtained
- by casting the list of 2D arrays into a 3D numpy array.
-
- :param images: list of 2D numpy arrays, or :class:`ListOfImages` object
- :param transposition: Tuple of dimension numbers in the wanted order
- """
- def __init__(self, images, transposition=None):
- """
-
- """
- super(ListOfImages, self).__init__()
-
- # if images is a ListOfImages instance, get the underlying data
- # as a list of 2D arrays
- if isinstance(images, ListOfImages):
- images = images.images
-
- # test stack of images is as expected
- assert is_list_of_arrays(images), \
- "Image stack must be a list of arrays"
- image0_shape = images[0].shape
- for image in images:
- assert image.ndim == 2, \
- "Images must be 2D numpy arrays"
- assert image.shape == image0_shape, \
- "All images must have the same shape"
-
- self.images = images
- """List of images"""
-
- self.shape = (len(images), ) + image0_shape
- """Tuple of array dimensions"""
- self.dtype = get_concatenated_dtype(images)
- """Data-type of the global array"""
- self.ndim = 3
- """Number of array dimensions"""
-
- self.size = len(images) * image0_shape[0] * image0_shape[1]
- """Number of elements in the array."""
-
- self.transposition = list(range(self.ndim))
- """List of dimension indices, in an order depending on the
- specified transposition. By default this is simply
- [0, ..., self.ndim], but it can be changed by specifying a different
- ``transposition`` parameter at initialization.
-
- Use :meth:`transpose`, to create a new :class:`ListOfImages`
- with a different :attr:`transposition`.
- """
-
- if transposition is not None:
- assert len(transposition) == self.ndim
- assert set(transposition) == set(list(range(self.ndim))), \
- "Transposition must be a sequence containing all dimensions"
- self.transposition = transposition
- self.__sort_shape()
-
- def __sort_shape(self):
- """Sort shape in the order defined in :attr:`transposition`
- """
- new_shape = tuple(self.shape[dim] for dim in self.transposition)
- self.shape = new_shape
-
- def __sort_indices(self, indices):
- """Return array indices sorted in the order needed
- to access data in the original non-transposed images.
-
- :param indices: Tuple of ndim indices, in the order needed
- to access the transposed view
- :return: Sorted tuple of indices, to access original data
- """
- assert len(indices) == self.ndim
- sorted_indices = tuple(idx for (_, idx) in
- sorted(zip(self.transposition, indices)))
- return sorted_indices
-
- def __array__(self, dtype=None):
- """Cast the images into a numpy array, and return it.
-
- If a transposition has been done on this images, return
- a transposed view of a numpy array."""
- return numpy.transpose(numpy.array(self.images, dtype=dtype),
- self.transposition)
-
- def __len__(self):
- return self.shape[0]
-
- def transpose(self, transposition=None):
- """Return a re-ordered (dimensions permutated)
- :class:`ListOfImages`.
-
- The returned object refers to
- the same images but with a different :attr:`transposition`.
-
- :param List[int] transposition: List/tuple of dimension numbers in the
- wanted order.
- If ``None`` (default), reverse the dimensions.
- :return: new :class:`ListOfImages` object
- """
- # by default, reverse the dimensions
- if transposition is None:
- transposition = list(reversed(self.transposition))
-
- # If this ListOfImages is already transposed, sort new transposition
- # relative to old transposition
- elif list(self.transposition) != list(range(self.ndim)):
- transposition = [self.transposition[i] for i in transposition]
-
- return ListOfImages(self.images,
- transposition)
-
- @property
- def T(self):
- """
- Same as self.transpose()
-
- :return: DatasetView with dimensions reversed."""
- return self.transpose()
-
- def __getitem__(self, item):
- """Handle a subset of numpy indexing with regards to the dimension
- order as specified in :attr:`transposition`
-
- Following features are **not supported**:
-
- - fancy indexing using numpy arrays
- - using ellipsis objects
-
- :param item: Index
- :return: value or slice as a numpy array
- """
- # 1-D slicing -> n-D slicing (n=1)
- if not hasattr(item, "__len__"):
- # first dimension index is given
- item = [item]
- # following dimensions are indexed with : (all elements)
- item += [slice(None) for _i in range(self.ndim - 1)]
-
- # n-dimensional slicing
- if len(item) != self.ndim:
- raise IndexError(
- "N-dim slicing requires a tuple of N indices/slices. " +
- "Needed dimensions: %d" % self.ndim)
-
- # get list of indices sorted in the original images order
- sorted_indices = self.__sort_indices(item)
- list_idx, array_idx = sorted_indices[0], sorted_indices[1:]
-
- images_selection = self.images[list_idx]
-
- # now we must transpose the output data
- output_dimensions = []
- frozen_dimensions = []
- for i, idx in enumerate(item):
- # slices and sequences
- if not isinstance(idx, numbers.Integral):
- output_dimensions.append(self.transposition[i])
- # regular integer index
- else:
- # whenever a dimension is fixed (indexed by an integer)
- # the number of output dimension is reduced
- frozen_dimensions.append(self.transposition[i])
-
- # decrement output dimensions that are above frozen dimensions
- for frozen_dim in reversed(sorted(frozen_dimensions)):
- for i, out_dim in enumerate(output_dimensions):
- if out_dim > frozen_dim:
- output_dimensions[i] -= 1
-
- assert (len(output_dimensions) + len(frozen_dimensions)) == self.ndim
- assert set(output_dimensions) == set(range(len(output_dimensions)))
-
- # single list elements selected
- if isinstance(images_selection, numpy.ndarray):
- return numpy.transpose(images_selection[array_idx],
- axes=output_dimensions)
- # muliple list elements selected
- else:
- # apply selection first
- output_stack = []
- for img in images_selection:
- output_stack.append(img[array_idx])
- # then cast into a numpy array, and transpose
- return numpy.transpose(numpy.array(output_stack),
- axes=output_dimensions)
-
- def min(self):
- """
- :return: Global minimum value
- """
- min_value = self.images[0].min()
- if len(self.images) > 1:
- for img in self.images[1:]:
- min_value = min(min_value, img.min())
- return min_value
-
- def max(self):
- """
- :return: Global maximum value
- """
- max_value = self.images[0].max()
- if len(self.images) > 1:
- for img in self.images[1:]:
- max_value = max(max_value, img.max())
- return max_value
-
-
-class DatasetView(object):
- """This class provides a way to transpose a dataset without
- casting it into a numpy array. This way, the dataset in a file need not
- necessarily be integrally read into memory to view it in a different
- transposition.
-
- .. note::
- The performances depend a lot on the way the dataset was written
- to file. Depending on the chunking strategy, reading a complete 2D slice
- in an unfavorable direction may still require the entire dataset to
- be read from disk.
-
- :param dataset: h5py dataset
- :param transposition: List of dimensions sorted in the order of
- transposition (relative to the original h5py dataset)
- """
- def __init__(self, dataset, transposition=None):
- """
-
- """
- super(DatasetView, self).__init__()
- self.dataset = dataset
- """original dataset"""
-
- self.shape = dataset.shape
- """Tuple of array dimensions"""
- self.dtype = dataset.dtype
- """Data-type of the array’s element"""
- self.ndim = len(dataset.shape)
- """Number of array dimensions"""
-
- size = 0
- if self.ndim:
- size = 1
- for dimsize in self.shape:
- size *= dimsize
- self.size = size
- """Number of elements in the array."""
-
- self.transposition = list(range(self.ndim))
- """List of dimension indices, in an order depending on the
- specified transposition. By default this is simply
- [0, ..., self.ndim], but it can be changed by specifying a different
- `transposition` parameter at initialization.
-
- Use :meth:`transpose`, to create a new :class:`DatasetView`
- with a different :attr:`transposition`.
- """
-
- if transposition is not None:
- assert len(transposition) == self.ndim
- assert set(transposition) == set(list(range(self.ndim))), \
- "Transposition must be a list containing all dimensions"
- self.transposition = transposition
- self.__sort_shape()
-
- def __sort_shape(self):
- """Sort shape in the order defined in :attr:`transposition`
- """
- new_shape = tuple(self.shape[dim] for dim in self.transposition)
- self.shape = new_shape
-
- def __sort_indices(self, indices):
- """Return array indices sorted in the order needed
- to access data in the original non-transposed dataset.
-
- :param indices: Tuple of ndim indices, in the order needed
- to access the view
- :return: Sorted tuple of indices, to access original data
- """
- assert len(indices) == self.ndim
- sorted_indices = tuple(idx for (_, idx) in
- sorted(zip(self.transposition, indices)))
- return sorted_indices
-
- def __getitem__(self, item):
- """Handle fancy indexing with regards to the dimension order as
- specified in :attr:`transposition`
-
- The supported fancy-indexing syntax is explained at
- http://docs.h5py.org/en/latest/high/dataset.html#fancy-indexing.
-
- Additional restrictions exist if the data has been transposed:
-
- - numpy boolean array indexing is not supported
- - ellipsis objects are not supported
-
- :param item: Index, possibly fancy index (must be supported by h5py)
- :return: Sliced numpy array or numpy scalar
- """
- # no transposition, let the original dataset handle indexing
- if self.transposition == list(range(self.ndim)):
- return self.dataset[item]
-
- # 1-D slicing: create a list of indices to switch to n-D slicing
- if not hasattr(item, "__len__"):
- # first dimension index (list index) is given
- item = [item]
- # following dimensions are indexed with slices representing all elements
- item += [slice(None) for _i in range(self.ndim - 1)]
-
- # n-dimensional slicing
- if len(item) != self.ndim:
- raise IndexError(
- "N-dim slicing requires a tuple of N indices/slices. " +
- "Needed dimensions: %d" % self.ndim)
-
- # get list of indices sorted in the original dataset order
- sorted_indices = self.__sort_indices(item)
-
- output_data_not_transposed = self.dataset[sorted_indices]
-
- # now we must transpose the output data
- output_dimensions = []
- frozen_dimensions = []
- for i, idx in enumerate(item):
- # slices and sequences
- if not isinstance(idx, int):
- output_dimensions.append(self.transposition[i])
- # regular integer index
- else:
- # whenever a dimension is fixed (indexed by an integer)
- # the number of output dimension is reduced
- frozen_dimensions.append(self.transposition[i])
-
- # decrement output dimensions that are above frozen dimensions
- for frozen_dim in reversed(sorted(frozen_dimensions)):
- for i, out_dim in enumerate(output_dimensions):
- if out_dim > frozen_dim:
- output_dimensions[i] -= 1
-
- assert (len(output_dimensions) + len(frozen_dimensions)) == self.ndim
- assert set(output_dimensions) == set(range(len(output_dimensions)))
-
- return numpy.transpose(output_data_not_transposed,
- axes=output_dimensions)
-
- def __array__(self, dtype=None):
- """Cast the dataset into a numpy array, and return it.
-
- If a transposition has been done on this dataset, return
- a transposed view of a numpy array."""
- return numpy.transpose(numpy.array(self.dataset, dtype=dtype),
- self.transposition)
-
- def __len__(self):
- return self.shape[0]
-
- def transpose(self, transposition=None):
- """Return a re-ordered (dimensions permutated)
- :class:`DatasetView`.
-
- The returned object refers to
- the same dataset but with a different :attr:`transposition`.
-
- :param List[int] transposition: List of dimension numbers in the wanted order.
- If ``None`` (default), reverse the dimensions.
- :return: Transposed DatasetView
- """
- # by default, reverse the dimensions
- if transposition is None:
- transposition = list(reversed(self.transposition))
-
- # If this DatasetView is already transposed, sort new transposition
- # relative to old transposition
- elif list(self.transposition) != list(range(self.ndim)):
- transposition = [self.transposition[i] for i in transposition]
-
- return DatasetView(self.dataset,
- transposition)
-
- @property
- def T(self):
- """
- Same as self.transpose()
-
- :return: DatasetView with dimensions reversed."""
- return self.transpose()
diff --git a/silx/utils/debug.py b/silx/utils/debug.py
deleted file mode 100644
index 5459448..0000000
--- a/silx/utils/debug.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-
-
-import inspect
-import types
-import logging
-
-import six
-
-
-debug_logger = logging.getLogger("silx.DEBUG")
-
-_indent = 0
-
-
-def log_method(func, class_name=None):
- """Decorator to inject a warning log before an after any function/method.
-
- .. code-block:: python
-
- @log_method
- def foo():
- return None
-
- :param callable func: The function to patch
- :param str class_name: In case a method, provide the class name
- """
- def wrapper(*args, **kwargs):
- global _indent
-
- indent = " " * _indent
- func_name = func.func_name if six.PY2 else func.__name__
- if class_name is not None:
- name = "%s.%s" % (class_name, func_name)
- else:
- name = "%s" % (func_name)
-
- debug_logger.warning("%s%s" % (indent, name))
- _indent += 1
- result = func(*args, **kwargs)
- _indent -= 1
- debug_logger.warning("%sreturn (%s)" % (indent, name))
- return result
- return wrapper
-
-
-def log_all_methods(base_class):
- """Decorator to inject a warning log before an after any method provided by
- a class.
-
- .. code-block:: python
-
- @log_all_methods
- class Foo(object):
-
- def a(self):
- return None
-
- def b(self):
- return self.a()
-
- Here is the output when calling the `b` method.
-
- .. code-block::
-
- WARNING:silx.DEBUG:_Foobar.b
- WARNING:silx.DEBUG: _Foobar.a
- WARNING:silx.DEBUG: return (_Foobar.a)
- WARNING:silx.DEBUG:return (_Foobar.b)
-
- :param class base_class: The class to patch
- """
- methodTypes = (types.MethodType, types.FunctionType, types.BuiltinFunctionType, types.BuiltinMethodType)
- for name, func in inspect.getmembers(base_class):
- if isinstance(func, methodTypes):
- if func.__name__ not in ["__subclasshook__", "__new__"]:
- # patching __new__ in Python2 break the object, then we skip it
- setattr(base_class, name, log_method(func, base_class.__name__))
-
- return base_class
diff --git a/silx/utils/html.py b/silx/utils/html.py
deleted file mode 100644
index aab25f2..0000000
--- a/silx/utils/html.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Utils function relative to HTML
-"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "19/09/2016"
-
-
-def escape(string, quote=True):
- """Returns a string where HTML metacharacters are properly escaped.
-
- Compatibility layer to avoid incompatibilities between Python versions,
- Qt versions and Qt bindings.
-
- >>> import silx.utils.html
- >>> silx.utils.html.escape("<html>")
- >>> "&lt;html&gt;"
-
- .. note:: Since Python 3.3 you can use the `html` module. For previous
- version, it is provided by `sgi` module.
- .. note:: Qt4 provides it with `Qt.escape` while Qt5 provide it with
- `QString.toHtmlEscaped`. But `QString` is not exposed by `PyQt` or
- `PySide`.
-
- :param str string: Human readable string.
- :param bool quote: Escape quote and double quotes (default behaviour).
- :returns: Valid HTML syntax to display the input string.
- :rtype: str
- """
- string = string.replace("&", "&amp;") # must be done first
- string = string.replace("<", "&lt;")
- string = string.replace(">", "&gt;")
- if quote:
- string = string.replace("'", "&apos;")
- string = string.replace("\"", "&quot;")
- return string
diff --git a/silx/utils/proxy.py b/silx/utils/proxy.py
deleted file mode 100644
index 8711799..0000000
--- a/silx/utils/proxy.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Module containing proxy objects"""
-
-from __future__ import absolute_import, print_function, division
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "02/10/2017"
-
-
-import functools
-import six
-
-
-class Proxy(object):
- """Create a proxy of an object.
-
- Provides default methods and property using :meth:`__getattr__` and special
- method by redefining them one by one.
- Special methods are defined as properties, as a result if the `obj` method
- is not defined, the property code fail and the special method will not be
- visible.
- """
-
- __slots__ = ["__obj", "__weakref__"]
-
- def __init__(self, obj):
- object.__setattr__(self, "_Proxy__obj", obj)
-
- __class__ = property(lambda self: self.__obj.__class__)
-
- def __getattr__(self, name):
- return getattr(self.__obj, name)
-
- __setattr__ = property(lambda self: self.__obj.__setattr__)
- __delattr__ = property(lambda self: self.__obj.__delattr__)
-
- # binary comparator methods
-
- __lt__ = property(lambda self: self.__obj.__lt__)
- __le__ = property(lambda self: self.__obj.__le__)
- __eq__ = property(lambda self: self.__obj.__eq__)
- __ne__ = property(lambda self: self.__obj.__ne__)
- __gt__ = property(lambda self: self.__obj.__gt__)
- __ge__ = property(lambda self: self.__obj.__ge__)
-
- if six.PY2:
- __cmp__ = property(lambda self: self.__obj.__cmp__)
-
- # binary numeric methods
-
- __add__ = property(lambda self: self.__obj.__add__)
- __radd__ = property(lambda self: self.__obj.__radd__)
- __iadd__ = property(lambda self: self.__obj.__iadd__)
- __sub__ = property(lambda self: self.__obj.__sub__)
- __rsub__ = property(lambda self: self.__obj.__rsub__)
- __isub__ = property(lambda self: self.__obj.__isub__)
- __mul__ = property(lambda self: self.__obj.__mul__)
- __rmul__ = property(lambda self: self.__obj.__rmul__)
- __imul__ = property(lambda self: self.__obj.__imul__)
-
- if six.PY2:
- # Only part of Python 2
- # Python 3 uses __truediv__ and __floordiv__
- __div__ = property(lambda self: self.__obj.__div__)
- __rdiv__ = property(lambda self: self.__obj.__rdiv__)
- __idiv__ = property(lambda self: self.__obj.__idiv__)
-
- __truediv__ = property(lambda self: self.__obj.__truediv__)
- __rtruediv__ = property(lambda self: self.__obj.__rtruediv__)
- __itruediv__ = property(lambda self: self.__obj.__itruediv__)
- __floordiv__ = property(lambda self: self.__obj.__floordiv__)
- __rfloordiv__ = property(lambda self: self.__obj.__rfloordiv__)
- __ifloordiv__ = property(lambda self: self.__obj.__ifloordiv__)
- __mod__ = property(lambda self: self.__obj.__mod__)
- __rmod__ = property(lambda self: self.__obj.__rmod__)
- __imod__ = property(lambda self: self.__obj.__imod__)
- __divmod__ = property(lambda self: self.__obj.__divmod__)
- __rdivmod__ = property(lambda self: self.__obj.__rdivmod__)
- __pow__ = property(lambda self: self.__obj.__pow__)
- __rpow__ = property(lambda self: self.__obj.__rpow__)
- __ipow__ = property(lambda self: self.__obj.__ipow__)
- __lshift__ = property(lambda self: self.__obj.__lshift__)
- __rlshift__ = property(lambda self: self.__obj.__rlshift__)
- __ilshift__ = property(lambda self: self.__obj.__ilshift__)
- __rshift__ = property(lambda self: self.__obj.__rshift__)
- __rrshift__ = property(lambda self: self.__obj.__rrshift__)
- __irshift__ = property(lambda self: self.__obj.__irshift__)
-
- # binary logical methods
-
- __and__ = property(lambda self: self.__obj.__and__)
- __rand__ = property(lambda self: self.__obj.__rand__)
- __iand__ = property(lambda self: self.__obj.__iand__)
- __xor__ = property(lambda self: self.__obj.__xor__)
- __rxor__ = property(lambda self: self.__obj.__rxor__)
- __ixor__ = property(lambda self: self.__obj.__ixor__)
- __or__ = property(lambda self: self.__obj.__or__)
- __ror__ = property(lambda self: self.__obj.__ror__)
- __ior__ = property(lambda self: self.__obj.__ior__)
-
- # unary methods
-
- __neg__ = property(lambda self: self.__obj.__neg__)
- __pos__ = property(lambda self: self.__obj.__pos__)
- __abs__ = property(lambda self: self.__obj.__abs__)
- __invert__ = property(lambda self: self.__obj.__invert__)
- if six.PY3:
- __floor__ = property(lambda self: self.__obj.__floor__)
- __ceil__ = property(lambda self: self.__obj.__ceil__)
- __round__ = property(lambda self: self.__obj.__round__)
-
- # cast
-
- __repr__ = property(lambda self: self.__obj.__repr__)
- __str__ = property(lambda self: self.__obj.__str__)
- __complex__ = property(lambda self: self.__obj.__complex__)
- __int__ = property(lambda self: self.__obj.__int__)
- __float__ = property(lambda self: self.__obj.__float__)
- __hash__ = property(lambda self: self.__obj.__hash__)
- if six.PY2:
- __long__ = property(lambda self: self.__obj.__long__)
- __oct__ = property(lambda self: self.__obj.__oct__)
- __hex__ = property(lambda self: self.__obj.__hex__)
- __unicode__ = property(lambda self: self.__obj.__unicode__)
- __nonzero__ = property(lambda self: lambda: bool(self.__obj))
- if six.PY3:
- __bytes__ = property(lambda self: self.__obj.__bytes__)
- __bool__ = property(lambda self: lambda: bool(self.__obj))
- __format__ = property(lambda self: self.__obj.__format__)
-
- # container
-
- __len__ = property(lambda self: self.__obj.__len__)
- if six.PY3:
- __length_hint__ = property(lambda self: self.__obj.__length_hint__)
- __getitem__ = property(lambda self: self.__obj.__getitem__)
- __missing__ = property(lambda self: self.__obj.__missing__)
- __setitem__ = property(lambda self: self.__obj.__setitem__)
- __delitem__ = property(lambda self: self.__obj.__delitem__)
- __iter__ = property(lambda self: self.__obj.__iter__)
- __reversed__ = property(lambda self: self.__obj.__reversed__)
- __contains__ = property(lambda self: self.__obj.__contains__)
-
- if six.PY2:
- __getslice__ = property(lambda self: self.__obj.__getslice__)
- __setslice__ = property(lambda self: self.__obj.__setslice__)
- __delslice__ = property(lambda self: self.__obj.__delslice__)
-
- # pickle
-
- __reduce__ = property(lambda self: self.__obj.__reduce__)
- __reduce_ex__ = property(lambda self: self.__obj.__reduce_ex__)
-
- # async
-
- if six.PY3:
- __await__ = property(lambda self: self.__obj.__await__)
- __aiter__ = property(lambda self: self.__obj.__aiter__)
- __anext__ = property(lambda self: self.__obj.__anext__)
- __aenter__ = property(lambda self: self.__obj.__aenter__)
- __aexit__ = property(lambda self: self.__obj.__aexit__)
-
- # other
-
- __index__ = property(lambda self: self.__obj.__index__)
- if six.PY2:
- __coerce__ = property(lambda self: self.__obj.__coerce__)
-
- if six.PY3:
- __next__ = property(lambda self: self.__obj.__next__)
-
- __enter__ = property(lambda self: self.__obj.__enter__)
- __exit__ = property(lambda self: self.__obj.__exit__)
-
- __concat__ = property(lambda self: self.__obj.__concat__)
- __iconcat__ = property(lambda self: self.__obj.__iconcat__)
-
- if six.PY2:
- __repeat__ = property(lambda self: self.__obj.__repeat__)
- __irepeat__ = property(lambda self: self.__obj.__irepeat__)
-
- __call__ = property(lambda self: self.__obj.__call__)
-
-
-def _docstring(dest, origin):
- """Implementation of docstring decorator.
-
- It patches dest.__doc__.
- """
- if not isinstance(dest, type) and isinstance(origin, type):
- # func is not a class, but origin is, get the method with the same name
- try:
- origin = getattr(origin, dest.__name__)
- except AttributeError:
- raise ValueError(
- "origin class has no %s method" % dest.__name__)
-
- dest.__doc__ = origin.__doc__
- return dest
-
-
-def docstring(origin):
- """Decorator to initialize the docstring from another source.
-
- This is useful to duplicate a docstring for inheritance and composition.
-
- If origin is a method or a function, it copies its docstring.
- If origin is a class, the docstring is copied from the method
- of that class which has the same name as the method/function
- being decorated.
-
- :param origin:
- The method, function or class from which to get the docstring
- :raises ValueError:
- If the origin class has not method n case the
- """
- return functools.partial(_docstring, origin=origin)
diff --git a/silx/utils/test/__init__.py b/silx/utils/test/__init__.py
deleted file mode 100755
index b35feee..0000000
--- a/silx/utils/test/__init__.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-__authors__ = ["T. Vincent", "P. Knobel"]
-__license__ = "MIT"
-__date__ = "08/03/2019"
-
-
-import unittest
-from . import test_weakref
-from . import test_html
-from . import test_array_like
-from . import test_launcher
-from . import test_deprecation
-from . import test_proxy
-from . import test_debug
-from . import test_number
-from . import test_external_resources
-from . import test_enum
-from . import test_testutils
-from . import test_retry
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(test_weakref.suite())
- test_suite.addTest(test_html.suite())
- test_suite.addTest(test_array_like.suite())
- test_suite.addTest(test_launcher.suite())
- test_suite.addTest(test_deprecation.suite())
- test_suite.addTest(test_proxy.suite())
- test_suite.addTest(test_debug.suite())
- test_suite.addTest(test_number.suite())
- test_suite.addTest(test_external_resources.suite())
- test_suite.addTest(test_enum.suite())
- test_suite.addTest(test_testutils.suite())
- test_suite.addTest(test_retry.suite())
- return test_suite
diff --git a/silx/utils/test/test_array_like.py b/silx/utils/test/test_array_like.py
deleted file mode 100644
index fe92db5..0000000
--- a/silx/utils/test/test_array_like.py
+++ /dev/null
@@ -1,445 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for array_like module"""
-
-__authors__ = ["P. Knobel"]
-__license__ = "MIT"
-__date__ = "09/01/2017"
-
-import h5py
-import numpy
-import os
-import tempfile
-import unittest
-
-from ..array_like import DatasetView, ListOfImages
-from ..array_like import get_dtype, get_concatenated_dtype, get_shape,\
- is_array, is_nested_sequence, is_list_of_arrays
-
-
-class TestTransposedDatasetView(unittest.TestCase):
-
- def setUp(self):
- # dataset attributes
- self.ndim = 3
- self.original_shape = (5, 10, 20)
- self.size = 1
- for dim in self.original_shape:
- self.size *= dim
-
- self.volume = numpy.arange(self.size).reshape(self.original_shape)
-
- self.tempdir = tempfile.mkdtemp()
- self.h5_fname = os.path.join(self.tempdir, "tempfile.h5")
- with h5py.File(self.h5_fname, "w") as f:
- f["volume"] = self.volume
-
- self.h5f = h5py.File(self.h5_fname, "r")
-
- self.all_permutations = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0),
- (2, 0, 1), (2, 1, 0)]
-
- def tearDown(self):
- self.h5f.close()
- os.unlink(self.h5_fname)
- os.rmdir(self.tempdir)
-
- def _testSize(self, obj):
- """These assertions apply to all following test cases"""
- self.assertEqual(obj.ndim, self.ndim)
- self.assertEqual(obj.size, self.size)
- size_from_shape = 1
- for dim in obj.shape:
- size_from_shape *= dim
- self.assertEqual(size_from_shape, self.size)
-
- for dim in self.original_shape:
- self.assertIn(dim, obj.shape)
-
- def testNoTransposition(self):
- """no transposition (transposition = (0, 1, 2))"""
- a = DatasetView(self.h5f["volume"])
-
- self.assertEqual(a.shape, self.original_shape)
- self._testSize(a)
-
- # reversing the dimensions twice results in no change
- rtrans = list(reversed(range(self.ndim)))
- self.assertTrue(numpy.array_equal(
- a,
- a.transpose(rtrans).transpose(rtrans)))
-
- for i in range(a.shape[0]):
- for j in range(a.shape[1]):
- for k in range(a.shape[2]):
- self.assertEqual(self.h5f["volume"][i, j, k],
- a[i, j, k])
-
- def _testTransposition(self, transposition):
- """test transposed dataset
-
- :param tuple transposition: List of dimensions (0... n-1) sorted
- in the desired order
- """
- a = DatasetView(self.h5f["volume"],
- transposition=transposition)
- self._testSize(a)
-
- # sort shape of transposed object, to hopefully find the original shape
- sorted_shape = tuple(dim_size for (_, dim_size) in
- sorted(zip(transposition, a.shape)))
- self.assertEqual(sorted_shape, self.original_shape)
-
- a_as_array = numpy.array(self.h5f["volume"]).transpose(transposition)
-
- # test the __array__ method
- self.assertTrue(numpy.array_equal(
- numpy.array(a),
- a_as_array))
-
- # test slicing
- for selection in [(2, slice(None), slice(None)),
- (slice(None), 1, slice(0, 8)),
- (slice(0, 3), slice(None), 3),
- (1, 3, slice(None)),
- (slice(None), 2, 1),
- (4, slice(1, 9, 2), 2)]:
- self.assertIsInstance(a[selection], numpy.ndarray)
- self.assertTrue(numpy.array_equal(
- a[selection],
- a_as_array[selection]))
-
- # test the DatasetView.__getitem__ for single values
- # (step adjusted to test at least 3 indices in each dimension)
- for i in range(0, a.shape[0], a.shape[0] // 3):
- for j in range(0, a.shape[1], a.shape[1] // 3):
- for k in range(0, a.shape[2], a.shape[2] // 3):
- sorted_indices = tuple(idx for (_, idx) in
- sorted(zip(transposition, [i, j, k])))
- viewed_value = a[i, j, k]
- corresponding_original_value = self.h5f["volume"][sorted_indices]
- self.assertEqual(viewed_value,
- corresponding_original_value)
-
- # reversing the dimensions twice results in no change
- rtrans = list(reversed(range(self.ndim)))
- self.assertTrue(numpy.array_equal(
- a,
- a.transpose(rtrans).transpose(rtrans)))
-
- # test .T property
- self.assertTrue(numpy.array_equal(
- a.T,
- a.transpose(rtrans)))
-
- def testTransposition012(self):
- """transposition = (0, 1, 2)
- (should be the same as testNoTransposition)"""
- self._testTransposition((0, 1, 2))
-
- def testTransposition021(self):
- """transposition = (0, 2, 1)"""
- self._testTransposition((0, 2, 1))
-
- def testTransposition102(self):
- """transposition = (1, 0, 2)"""
- self._testTransposition((1, 0, 2))
-
- def testTransposition120(self):
- """transposition = (1, 2, 0)"""
- self._testTransposition((1, 2, 0))
-
- def testTransposition201(self):
- """transposition = (2, 0, 1)"""
- self._testTransposition((2, 0, 1))
-
- def testTransposition210(self):
- """transposition = (2, 1, 0)"""
- self._testTransposition((2, 1, 0))
-
- def testAllDoubleTranspositions(self):
- for trans1 in self.all_permutations:
- for trans2 in self.all_permutations:
- self._testDoubleTransposition(trans1, trans2)
-
- def _testDoubleTransposition(self, transposition1, transposition2):
- a = DatasetView(self.h5f["volume"],
- transposition=transposition1).transpose(transposition2)
-
- b = self.volume.transpose(transposition1).transpose(transposition2)
-
- self.assertTrue(numpy.array_equal(a, b),
- "failed with double transposition %s %s" % (transposition1, transposition2))
-
- def test1DIndex(self):
- a = DatasetView(self.h5f["volume"])
- self.assertTrue(numpy.array_equal(self.volume[1],
- a[1]))
-
- b = DatasetView(self.h5f["volume"], transposition=(1, 0, 2))
- self.assertTrue(numpy.array_equal(self.volume[:, 1, :],
- b[1]))
-
-
-class TestTransposedListOfImages(unittest.TestCase):
- def setUp(self):
- # images attributes
- self.ndim = 3
- self.original_shape = (5, 10, 20)
- self.size = 1
- for dim in self.original_shape:
- self.size *= dim
-
- volume = numpy.arange(self.size).reshape(self.original_shape)
-
- self.images = []
- for i in range(self.original_shape[0]):
- self.images.append(
- volume[i])
-
- self.images_as_3D_array = numpy.array(self.images)
-
- self.all_permutations = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0),
- (2, 0, 1), (2, 1, 0)]
-
- def tearDown(self):
- pass
-
- def _testSize(self, obj):
- """These assertions apply to all following test cases"""
- self.assertEqual(obj.ndim, self.ndim)
- self.assertEqual(obj.size, self.size)
- size_from_shape = 1
- for dim in obj.shape:
- size_from_shape *= dim
- self.assertEqual(size_from_shape, self.size)
-
- for dim in self.original_shape:
- self.assertIn(dim, obj.shape)
-
- def testNoTransposition(self):
- """no transposition (transposition = (0, 1, 2))"""
- a = ListOfImages(self.images)
-
- self.assertEqual(a.shape, self.original_shape)
- self._testSize(a)
-
- for i in range(a.shape[0]):
- for j in range(a.shape[1]):
- for k in range(a.shape[2]):
- self.assertEqual(self.images[i][j, k],
- a[i, j, k])
-
- # reversing the dimensions twice results in no change
- rtrans = list(reversed(range(self.ndim)))
- self.assertTrue(numpy.array_equal(
- a,
- a.transpose(rtrans).transpose(rtrans)))
-
- # test .T property
- self.assertTrue(numpy.array_equal(
- a.T,
- a.transpose(rtrans)))
-
- def _testTransposition(self, transposition):
- """test transposed dataset
-
- :param tuple transposition: List of dimensions (0... n-1) sorted
- in the desired order
- """
- a = ListOfImages(self.images,
- transposition=transposition)
- self._testSize(a)
-
- # sort shape of transposed object, to hopefully find the original shape
- sorted_shape = tuple(dim_size for (_, dim_size) in
- sorted(zip(transposition, a.shape)))
- self.assertEqual(sorted_shape, self.original_shape)
-
- a_as_array = numpy.array(self.images).transpose(transposition)
-
- # test the DatasetView.__array__ method
- self.assertTrue(numpy.array_equal(
- numpy.array(a),
- a_as_array))
-
- # test slicing
- for selection in [(2, slice(None), slice(None)),
- (slice(None), 1, slice(0, 8)),
- (slice(0, 3), slice(None), 3),
- (1, 3, slice(None)),
- (slice(None), 2, 1),
- (4, slice(1, 9, 2), 2)]:
- self.assertIsInstance(a[selection], numpy.ndarray)
- self.assertTrue(numpy.array_equal(
- a[selection],
- a_as_array[selection]))
-
- # test the DatasetView.__getitem__ for single values
- # (step adjusted to test at least 3 indices in each dimension)
- for i in range(0, a.shape[0], a.shape[0] // 3):
- for j in range(0, a.shape[1], a.shape[1] // 3):
- for k in range(0, a.shape[2], a.shape[2] // 3):
- viewed_value = a[i, j, k]
- sorted_indices = tuple(idx for (_, idx) in
- sorted(zip(transposition, [i, j, k])))
- corresponding_original_value = self.images[sorted_indices[0]][sorted_indices[1:]]
- self.assertEqual(viewed_value,
- corresponding_original_value)
-
- # reversing the dimensions twice results in no change
- rtrans = list(reversed(range(self.ndim)))
- self.assertTrue(numpy.array_equal(
- a,
- a.transpose(rtrans).transpose(rtrans)))
-
- # test .T property
- self.assertTrue(numpy.array_equal(
- a.T,
- a.transpose(rtrans)))
-
- def _testDoubleTransposition(self, transposition1, transposition2):
- a = ListOfImages(self.images,
- transposition=transposition1).transpose(transposition2)
-
- b = self.images_as_3D_array.transpose(transposition1).transpose(transposition2)
-
- self.assertTrue(numpy.array_equal(a, b),
- "failed with double transposition %s %s" % (transposition1, transposition2))
-
- def testTransposition012(self):
- """transposition = (0, 1, 2)
- (should be the same as testNoTransposition)"""
- self._testTransposition((0, 1, 2))
-
- def testTransposition021(self):
- """transposition = (0, 2, 1)"""
- self._testTransposition((0, 2, 1))
-
- def testTransposition102(self):
- """transposition = (1, 0, 2)"""
- self._testTransposition((1, 0, 2))
-
- def testTransposition120(self):
- """transposition = (1, 2, 0)"""
- self._testTransposition((1, 2, 0))
-
- def testTransposition201(self):
- """transposition = (2, 0, 1)"""
- self._testTransposition((2, 0, 1))
-
- def testTransposition210(self):
- """transposition = (2, 1, 0)"""
- self._testTransposition((2, 1, 0))
-
- def testAllDoubleTranspositions(self):
- for trans1 in self.all_permutations:
- for trans2 in self.all_permutations:
- self._testDoubleTransposition(trans1, trans2)
-
- def test1DIndex(self):
- a = ListOfImages(self.images)
- self.assertTrue(numpy.array_equal(self.images[1],
- a[1]))
-
- b = ListOfImages(self.images, transposition=(1, 0, 2))
- self.assertTrue(numpy.array_equal(self.images_as_3D_array[:, 1, :],
- b[1]))
-
-
-class TestFunctions(unittest.TestCase):
- """Test functions to guess the dtype and shape of an array_like
- object"""
- def testListOfLists(self):
- l = [[0, 1, 2], [2, 3, 4]]
- self.assertEqual(get_dtype(l),
- numpy.dtype(int))
- self.assertEqual(get_shape(l),
- (2, 3))
- self.assertTrue(is_nested_sequence(l))
- self.assertFalse(is_array(l))
- self.assertFalse(is_list_of_arrays(l))
-
- l = [[0., 1.], [2., 3.]]
- self.assertEqual(get_dtype(l),
- numpy.dtype(float))
- self.assertEqual(get_shape(l),
- (2, 2))
- self.assertTrue(is_nested_sequence(l))
- self.assertFalse(is_array(l))
- self.assertFalse(is_list_of_arrays(l))
-
- # concatenated dtype of int and float
- l = [numpy.array([[0, 1, 2], [2, 3, 4]]),
- numpy.array([[0., 1., 2.], [2., 3., 4.]])]
-
- self.assertEqual(get_concatenated_dtype(l),
- numpy.array(l).dtype)
- self.assertEqual(get_shape(l),
- (2, 2, 3))
- self.assertFalse(is_nested_sequence(l))
- self.assertFalse(is_array(l))
- self.assertTrue(is_list_of_arrays(l))
-
- def testNumpyArray(self):
- a = numpy.array([[0, 1], [2, 3]])
- self.assertEqual(get_dtype(a),
- a.dtype)
- self.assertFalse(is_nested_sequence(a))
- self.assertTrue(is_array(a))
- self.assertFalse(is_list_of_arrays(a))
-
- def testH5pyDataset(self):
- a = numpy.array([[0, 1], [2, 3]])
-
- tempdir = tempfile.mkdtemp()
- h5_fname = os.path.join(tempdir, "tempfile.h5")
- with h5py.File(h5_fname, "w") as h5f:
- h5f["dataset"] = a
- d = h5f["dataset"]
-
- self.assertEqual(get_dtype(d),
- numpy.dtype(int))
- self.assertFalse(is_nested_sequence(d))
- self.assertTrue(is_array(d))
- self.assertFalse(is_list_of_arrays(d))
-
- os.unlink(h5_fname)
- os.rmdir(tempdir)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestTransposedDatasetView))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestTransposedListOfImages))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestFunctions))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/test/test_debug.py b/silx/utils/test/test_debug.py
deleted file mode 100644
index da08960..0000000
--- a/silx/utils/test/test_debug.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for debug module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "27/02/2018"
-
-
-import unittest
-from silx.utils import debug
-from silx.utils import testutils
-
-
-@debug.log_all_methods
-class _Foobar(object):
-
- def a(self):
- return None
-
- def b(self):
- return self.a()
-
- def random_args(self, *args, **kwargs):
- return args, kwargs
-
- def named_args(self, a, b):
- return a + 1, b + 1
-
-
-class TestDebug(unittest.TestCase):
- """Tests for debug module."""
-
- def logB(self):
- """
- Can be used to check the log output using:
- `./run_tests.py silx.utils.test.test_debug.TestDebug.logB -v`
- """
- print()
- test = _Foobar()
- test.b()
-
- @testutils.test_logging(debug.debug_logger.name, warning=2)
- def testMethod(self):
- test = _Foobar()
- test.a()
-
- @testutils.test_logging(debug.debug_logger.name, warning=4)
- def testInterleavedMethod(self):
- test = _Foobar()
- test.b()
-
- @testutils.test_logging(debug.debug_logger.name, warning=2)
- def testNamedArgument(self):
- # Arguments arre still provided to the patched method
- test = _Foobar()
- result = test.named_args(10, 11)
- self.assertEqual(result, (11, 12))
-
- @testutils.test_logging(debug.debug_logger.name, warning=2)
- def testRandomArguments(self):
- # Arguments arre still provided to the patched method
- test = _Foobar()
- result = test.random_args("foo", 50, a=10, b=100)
- self.assertEqual(result[0], ("foo", 50))
- self.assertEqual(result[1], {"a": 10, "b": 100})
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestDebug))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/test/test_deprecation.py b/silx/utils/test/test_deprecation.py
deleted file mode 100644
index 0aa06a0..0000000
--- a/silx/utils/test/test_deprecation.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for html module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import unittest
-from .. import deprecation
-from silx.utils import testutils
-
-
-class TestDeprecation(unittest.TestCase):
- """Tests for deprecation module."""
-
- @deprecation.deprecated
- def deprecatedWithoutParam(self):
- pass
-
- @deprecation.deprecated(reason="r", replacement="r", since_version="v")
- def deprecatedWithParams(self):
- pass
-
- @deprecation.deprecated(reason="r", replacement="r", since_version="v", only_once=True)
- def deprecatedOnlyOnce(self):
- pass
-
- @deprecation.deprecated(reason="r", replacement="r", since_version="v", only_once=False)
- def deprecatedEveryTime(self):
- pass
-
- @testutils.test_logging(deprecation.depreclog.name, warning=1)
- def testAnnotationWithoutParam(self):
- self.deprecatedWithoutParam()
-
- @testutils.test_logging(deprecation.depreclog.name, warning=1)
- def testAnnotationWithParams(self):
- self.deprecatedWithParams()
-
- @testutils.test_logging(deprecation.depreclog.name, warning=3)
- def testLoggedEveryTime(self):
- """Logged everytime cause it is 3 different locations"""
- self.deprecatedOnlyOnce()
- self.deprecatedOnlyOnce()
- self.deprecatedOnlyOnce()
-
- @testutils.test_logging(deprecation.depreclog.name, warning=1)
- def testLoggedSingleTime(self):
- def log():
- self.deprecatedOnlyOnce()
- log()
- log()
- log()
-
- @testutils.test_logging(deprecation.depreclog.name, warning=3)
- def testLoggedEveryTime2(self):
- self.deprecatedEveryTime()
- self.deprecatedEveryTime()
- self.deprecatedEveryTime()
-
- @testutils.test_logging(deprecation.depreclog.name, warning=1)
- def testWarning(self):
- deprecation.deprecated_warning(type_="t", name="n")
-
- def testBacktrace(self):
- testLogging = testutils.TestLogging(deprecation.depreclog.name)
- with testLogging:
- self.deprecatedEveryTime()
- message = testLogging.records[0].getMessage()
- filename = __file__.replace(".pyc", ".py")
- self.assertTrue(filename in message)
- self.assertTrue("testBacktrace" in message)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestDeprecation))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/test/test_enum.py b/silx/utils/test/test_enum.py
deleted file mode 100644
index a72da46..0000000
--- a/silx/utils/test/test_enum.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests of Enum class with extra class methods"""
-
-from __future__ import absolute_import
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "29/04/2019"
-
-
-import sys
-import unittest
-
-import enum
-from silx.utils.enum import Enum
-
-
-class TestEnum(unittest.TestCase):
- """Tests for enum module."""
-
- def test(self):
- """Test with Enum"""
- class Success(Enum):
- A = 1
- B = 'B'
- self._check_enum_content(Success)
-
- @unittest.skipIf(sys.version_info.major <= 2, 'Python3 only')
- def test(self):
- """Test Enum with member redefinition"""
- with self.assertRaises(TypeError):
- class Failure(Enum):
- A = 1
- A = 'B'
-
- def test_unique(self):
- """Test with enum.unique"""
- with self.assertRaises(ValueError):
- @enum.unique
- class Failure(Enum):
- A = 1
- B = 1
-
- @enum.unique
- class Success(Enum):
- A = 1
- B = 'B'
- self._check_enum_content(Success)
-
- def _check_enum_content(self, enum_):
- """Check that the content of an enum is: <A: 1, B: 2>.
-
- :param Enum enum_:
- """
- self.assertEqual(enum_.members(), (enum_.A, enum_.B))
- self.assertEqual(enum_.names(), ('A', 'B'))
- self.assertEqual(enum_.values(), (1, 'B'))
-
- self.assertEqual(enum_.from_value(1), enum_.A)
- self.assertEqual(enum_.from_value('B'), enum_.B)
- with self.assertRaises(ValueError):
- enum_.from_value(3)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite.addTest(loadTests(TestEnum))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/test/test_external_resources.py b/silx/utils/test/test_external_resources.py
deleted file mode 100644
index 8576029..0000000
--- a/silx/utils/test/test_external_resources.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Test for resource files management."""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "08/03/2019"
-
-
-import os
-import unittest
-import shutil
-import socket
-import six
-
-from silx.utils.ExternalResources import ExternalResources
-
-
-def isSilxWebsiteAvailable():
- try:
- six.moves.urllib.request.urlopen('http://www.silx.org', timeout=1)
- return True
- except six.moves.urllib.error.URLError:
- return False
- except socket.timeout:
- # This exception is still received in Python 2.7
- return False
-
-
-class TestExternalResources(unittest.TestCase):
- """This is a test for the ExternalResources"""
-
- @classmethod
- def setUpClass(cls):
- if not isSilxWebsiteAvailable():
- raise unittest.SkipTest("Network or silx website not available")
-
- def setUp(self):
- self.resources = ExternalResources("toto", "http://www.silx.org/pub/silx/")
-
- def tearDown(self):
- if self.resources.data_home:
- shutil.rmtree(self.resources.data_home)
- self.resources = None
-
- def test_download(self):
- "test the download from silx.org"
- f = self.resources.getfile("lena.png")
- self.assertTrue(os.path.exists(f))
- di = self.resources.getdir("source.tar.gz")
- for fi in di:
- self.assertTrue(os.path.exists(fi))
-
- def test_download_all(self):
- "test the download of all files from silx.org"
- filename = self.resources.getfile("lena.png")
- directory = "source.tar.gz"
- filelist = self.resources.getdir(directory)
- # download file and remove it to create a json mapping file
- os.remove(filename)
- directory_path = os.path.commonprefix(filelist)
- # Make sure we will rmtree a dangerous path like "/"
- self.assertIn(self.resources.data_home, directory_path)
- shutil.rmtree(directory_path)
- filelist = self.resources.download_all()
- self.assertGreater(len(filelist), 1, "At least 2 items were downloaded")
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestExternalResources))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/test/test_html.py b/silx/utils/test/test_html.py
deleted file mode 100644
index 4af8560..0000000
--- a/silx/utils/test/test_html.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for html module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "19/09/2016"
-
-
-import unittest
-from .. import html
-
-
-class TestHtml(unittest.TestCase):
- """Tests for html module."""
-
- def testLtGt(self):
- result = html.escape("<html>'\"")
- self.assertEqual("&lt;html&gt;&apos;&quot;", result)
-
- def testLtAmpGt(self):
- # '&' have to be escaped first
- result = html.escape("<&>")
- self.assertEqual("&lt;&amp;&gt;", result)
-
- def testNoQuotes(self):
- result = html.escape("\"m&m's\"", quote=False)
- self.assertEqual("\"m&amp;m's\"", result)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestHtml))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/test/test_launcher.py b/silx/utils/test/test_launcher.py
deleted file mode 100644
index c64ac9a..0000000
--- a/silx/utils/test/test_launcher.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for html module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "17/01/2018"
-
-
-import sys
-import unittest
-from silx.utils.testutils import ParametricTestCase
-from .. import launcher
-
-
-class CallbackMock():
-
- def __init__(self, result=None):
- self._execute_count = 0
- self._execute_argv = None
- self._result = result
-
- def execute(self, argv):
- self._execute_count = self._execute_count + 1
- self._execute_argv = argv
- return self._result
-
- def __call__(self, argv):
- return self.execute(argv)
-
-
-class TestLauncherCommand(unittest.TestCase):
- """Tests for launcher class."""
-
- def testEnv(self):
- command = launcher.LauncherCommand("foo")
- old = sys.argv
- params = ["foo", "bar"]
- with command.get_env(params):
- self.assertEqual(params, sys.argv)
- self.assertEqual(sys.argv, old)
-
- def testEnvWhileException(self):
- command = launcher.LauncherCommand("foo")
- old = sys.argv
- params = ["foo", "bar"]
- try:
- with command.get_env(params):
- raise RuntimeError()
- except RuntimeError:
- pass
- self.assertEqual(sys.argv, old)
-
- def testExecute(self):
- params = ["foo", "bar"]
- callback = CallbackMock(result=42)
- command = launcher.LauncherCommand("foo", function=callback)
- status = command.execute(params)
- self.assertEqual(callback._execute_count, 1)
- self.assertEqual(callback._execute_argv, params)
- self.assertEqual(status, 42)
-
-
-class TestModuleCommand(ParametricTestCase):
-
- def setUp(self):
- module_name = "silx.utils.test.test_launcher_command"
- command = launcher.LauncherCommand("foo", module_name=module_name)
- self.command = command
-
- def testHelp(self):
- status = self.command.execute(["--help"])
- self.assertEqual(status, 0)
-
- def testException(self):
- try:
- self.command.execute(["exception"])
- self.fail()
- except RuntimeError:
- pass
-
- def testCall(self):
- status = self.command.execute([])
- self.assertEqual(status, 0)
-
- def testError(self):
- status = self.command.execute(["error"])
- self.assertEqual(status, -1)
-
-
-class TestLauncher(ParametricTestCase):
- """Tests for launcher class."""
-
- def testCallCommand(self):
- callback = CallbackMock(result=42)
- runner = launcher.Launcher(prog="prog")
- command = launcher.LauncherCommand("foo", function=callback)
- runner.add_command(command=command)
- status = runner.execute(["prog", "foo", "param1", "param2"])
- self.assertEqual(status, 42)
- self.assertEqual(callback._execute_argv, ["prog foo", "param1", "param2"])
- self.assertEqual(callback._execute_count, 1)
-
- def testAddCommand(self):
- runner = launcher.Launcher(prog="prog")
- module_name = "silx.utils.test.test_launcher_command"
- runner.add_command("foo", module_name=module_name)
- status = runner.execute(["prog", "foo"])
- self.assertEqual(status, 0)
-
- def testCallHelpOnCommand(self):
- callback = CallbackMock(result=42)
- runner = launcher.Launcher(prog="prog")
- command = launcher.LauncherCommand("foo", function=callback)
- runner.add_command(command=command)
- status = runner.execute(["prog", "--help", "foo"])
- self.assertEqual(status, 42)
- self.assertEqual(callback._execute_argv, ["prog foo", "--help"])
- self.assertEqual(callback._execute_count, 1)
-
- def testCallHelpOnCommand2(self):
- callback = CallbackMock(result=42)
- runner = launcher.Launcher(prog="prog")
- command = launcher.LauncherCommand("foo", function=callback)
- runner.add_command(command=command)
- status = runner.execute(["prog", "help", "foo"])
- self.assertEqual(status, 42)
- self.assertEqual(callback._execute_argv, ["prog foo", "--help"])
- self.assertEqual(callback._execute_count, 1)
-
- def testCallHelpOnUnknownCommand(self):
- callback = CallbackMock(result=42)
- runner = launcher.Launcher(prog="prog")
- command = launcher.LauncherCommand("foo", function=callback)
- runner.add_command(command=command)
- status = runner.execute(["prog", "help", "foo2"])
- self.assertEqual(status, -1)
-
- def testNotAvailableCommand(self):
- callback = CallbackMock(result=42)
- runner = launcher.Launcher(prog="prog")
- command = launcher.LauncherCommand("foo", function=callback)
- runner.add_command(command=command)
- status = runner.execute(["prog", "foo2", "param1", "param2"])
- self.assertEqual(status, -1)
- self.assertEqual(callback._execute_count, 0)
-
- def testCallHelp(self):
- callback = CallbackMock(result=42)
- runner = launcher.Launcher(prog="prog")
- command = launcher.LauncherCommand("foo", function=callback)
- runner.add_command(command=command)
- status = runner.execute(["prog", "help"])
- self.assertEqual(status, 0)
- self.assertEqual(callback._execute_count, 0)
-
- def testCommonCommands(self):
- runner = launcher.Launcher()
- tests = [
- ["prog"],
- ["prog", "--help"],
- ["prog", "--version"],
- ["prog", "help", "--help"],
- ["prog", "help", "help"],
- ]
- for arguments in tests:
- with self.subTest(args=tests):
- status = runner.execute(arguments)
- self.assertEqual(status, 0)
-
-
-def suite():
- loader = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loader(TestLauncherCommand))
- test_suite.addTest(loader(TestLauncher))
- test_suite.addTest(loader(TestModuleCommand))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/test/test_number.py b/silx/utils/test/test_number.py
deleted file mode 100644
index 4ac9636..0000000
--- a/silx/utils/test/test_number.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for silx.uitls.number module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "05/06/2018"
-
-import logging
-import numpy
-import unittest
-import pkg_resources
-from silx.utils import number
-from silx.utils import testutils
-
-_logger = logging.getLogger(__name__)
-
-
-class TestConversionTypes(testutils.ParametricTestCase):
-
- def testEmptyFail(self):
- self.assertRaises(ValueError, number.min_numerical_convertible_type, "")
-
- def testStringFail(self):
- self.assertRaises(ValueError, number.min_numerical_convertible_type, "a")
-
- def testInteger(self):
- dtype = number.min_numerical_convertible_type("1456")
- self.assertTrue(numpy.issubdtype(dtype, numpy.unsignedinteger))
-
- def testTrailledInteger(self):
- dtype = number.min_numerical_convertible_type(" \t\n\r1456\t\n\r")
- self.assertTrue(numpy.issubdtype(dtype, numpy.unsignedinteger))
-
- def testPositiveInteger(self):
- dtype = number.min_numerical_convertible_type("+1456")
- self.assertTrue(numpy.issubdtype(dtype, numpy.unsignedinteger))
-
- def testNegativeInteger(self):
- dtype = number.min_numerical_convertible_type("-1456")
- self.assertTrue(numpy.issubdtype(dtype, numpy.signedinteger))
-
- def testIntegerExponential(self):
- dtype = number.min_numerical_convertible_type("14e10")
- self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
-
- def testIntegerPositiveExponential(self):
- dtype = number.min_numerical_convertible_type("14e+10")
- self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
-
- def testIntegerNegativeExponential(self):
- dtype = number.min_numerical_convertible_type("14e-10")
- self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
-
- def testNumberDecimal(self):
- dtype = number.min_numerical_convertible_type("14.5")
- self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
-
- def testPositiveNumberDecimal(self):
- dtype = number.min_numerical_convertible_type("+14.5")
- self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
-
- def testNegativeNumberDecimal(self):
- dtype = number.min_numerical_convertible_type("-14.5")
- self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
-
- def testDecimal(self):
- dtype = number.min_numerical_convertible_type(".50")
- self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
-
- def testPositiveDecimal(self):
- dtype = number.min_numerical_convertible_type("+.5")
- self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
-
- def testNegativeDecimal(self):
- dtype = number.min_numerical_convertible_type("-.5")
- self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
-
- def testMantissa16(self):
- dtype = number.min_numerical_convertible_type("1.50")
- self.assertEqual(dtype, numpy.float16)
-
- def testFloat32(self):
- dtype = number.min_numerical_convertible_type("-23.172")
- self.assertEqual(dtype, numpy.float32)
-
- def testMantissa32(self):
- dtype = number.min_numerical_convertible_type("1400.50")
- self.assertEqual(dtype, numpy.float32)
-
- def testMantissa64(self):
- dtype = number.min_numerical_convertible_type("10000.000010")
- self.assertEqual(dtype, numpy.float64)
-
- def testMantissa80(self):
- self.skipIfFloat80NotSupported()
- dtype = number.min_numerical_convertible_type("1000000000.00001013")
-
- if pkg_resources.parse_version(numpy.version.version) <= pkg_resources.parse_version("1.10.4"):
- # numpy 1.8.2 -> Debian 8
- # Checking a float128 precision with numpy 1.8.2 using abs(diff) is not working.
- # It looks like the difference is done using float64 (diff == 0.0)
- expected = (numpy.longdouble, numpy.float64)
- else:
- expected = (numpy.longdouble, )
- self.assertIn(dtype, expected)
-
- def testExponent32(self):
- dtype = number.min_numerical_convertible_type("14.0e30")
- self.assertEqual(dtype, numpy.float32)
-
- def testExponent64(self):
- dtype = number.min_numerical_convertible_type("14.0e300")
- self.assertEqual(dtype, numpy.float64)
-
- def testExponent80(self):
- self.skipIfFloat80NotSupported()
- dtype = number.min_numerical_convertible_type("14.0e3000")
- self.assertEqual(dtype, numpy.longdouble)
-
- def testFloat32ToString(self):
- value = str(numpy.float32(numpy.pi))
- dtype = number.min_numerical_convertible_type(value)
- self.assertIn(dtype, (numpy.float32, numpy.float64))
-
- def skipIfFloat80NotSupported(self):
- if number.is_longdouble_64bits():
- self.skipTest("float-80bits not supported")
-
- def testLosePrecisionUsingFloat80(self):
- self.skipIfFloat80NotSupported()
- if pkg_resources.parse_version(numpy.version.version) <= pkg_resources.parse_version("1.10.4"):
- self.skipTest("numpy > 1.10.4 expected")
- # value does not fit even in a 128 bits mantissa
- value = "1.0340282366920938463463374607431768211456"
- func = testutils.test_logging(number._logger.name, warning=1)
- func = func(number.min_numerical_convertible_type)
- dtype = func(value)
- self.assertIn(dtype, (numpy.longdouble, ))
-
- def testMillisecondEpochTime(self):
- datetimes = ['1465803236.495412',
- '1465803236.999362',
- '1465803237.504311',
- '1465803238.009261',
- '1465803238.512211',
- '1465803239.016160',
- '1465803239.520110',
- '1465803240.026059',
- '1465803240.529009']
- for datetime in datetimes:
- with self.subTest(datetime=datetime):
- dtype = number.min_numerical_convertible_type(datetime)
- self.assertEqual(dtype, numpy.float64)
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestConversionTypes))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest="suite")
diff --git a/silx/utils/test/test_proxy.py b/silx/utils/test/test_proxy.py
deleted file mode 100644
index 72b4d21..0000000
--- a/silx/utils/test/test_proxy.py
+++ /dev/null
@@ -1,344 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for weakref module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "02/10/2017"
-
-
-import unittest
-import pickle
-import numpy
-from silx.utils.proxy import Proxy, docstring
-
-
-class Thing(object):
-
- def __init__(self, value):
- self.value = value
-
- def __getitem__(self, selection):
- return selection + 1
-
- def method(self, value):
- return value + 2
-
-
-class InheritedProxy(Proxy):
- """Inheriting the proxy allow to specialisze methods"""
-
- def __init__(self, obj, value):
- Proxy.__init__(self, obj)
- self.value = value + 2
-
- def __getitem__(self, selection):
- return selection + 3
-
- def method(self, value):
- return value + 4
-
-
-class TestProxy(unittest.TestCase):
- """Test that the proxy behave as expected"""
-
- def text_init(self):
- obj = Thing(10)
- p = Proxy(obj)
- self.assertTrue(isinstance(p, Thing))
- self.assertTrue(isinstance(p, Proxy))
-
- # methods and properties
-
- def test_has_special_method(self):
- obj = Thing(10)
- p = Proxy(obj)
- self.assertTrue(hasattr(p, "__getitem__"))
-
- def test_missing_special_method(self):
- obj = Thing(10)
- p = Proxy(obj)
- self.assertFalse(hasattr(p, "__and__"))
-
- def test_method(self):
- obj = Thing(10)
- p = Proxy(obj)
- self.assertEqual(p.method(10), obj.method(10))
-
- def test_property(self):
- obj = Thing(10)
- p = Proxy(obj)
- self.assertEqual(p.value, obj.value)
-
- # special functions
-
- def test_getitem(self):
- obj = Thing(10)
- p = Proxy(obj)
- self.assertEqual(p[10], obj[10])
-
- def test_setitem(self):
- obj = numpy.array([10, 20, 30])
- p = Proxy(obj)
- p[0] = 20
- self.assertEqual(obj[0], 20)
-
- def test_slice(self):
- obj = numpy.arange(20)
- p = Proxy(obj)
- expected = obj[0:10:2]
- result = p[0:10:2]
- self.assertEqual(list(result), list(expected))
-
- # binary comparator methods
-
- def test_lt(self):
- obj = numpy.array([20])
- p = Proxy(obj)
- expected = obj < obj
- result = p < p
- self.assertEqual(result, expected)
-
- # binary numeric methods
-
- def test_add(self):
- obj = numpy.array([20])
- proxy = Proxy(obj)
- expected = obj + obj
- result = proxy + proxy
- self.assertEqual(result, expected)
-
- def test_iadd(self):
- expected = numpy.array([20])
- expected += 10
- obj = numpy.array([20])
- result = Proxy(obj)
- result += 10
- self.assertEqual(result, expected)
-
- def test_radd(self):
- obj = numpy.array([20])
- p = Proxy(obj)
- expected = 10 + obj
- result = 10 + p
- self.assertEqual(result, expected)
-
- # binary logical methods
-
- def test_and(self):
- obj = numpy.array([20])
- p = Proxy(obj)
- expected = obj & obj
- result = p & p
- self.assertEqual(result, expected)
-
- def test_iand(self):
- expected = numpy.array([20])
- expected &= 10
- obj = numpy.array([20])
- result = Proxy(obj)
- result &= 10
- self.assertEqual(result, expected)
-
- def test_rand(self):
- obj = numpy.array([20])
- p = Proxy(obj)
- expected = 10 & obj
- result = 10 & p
- self.assertEqual(result, expected)
-
- # unary methods
-
- def test_neg(self):
- obj = numpy.array([20])
- p = Proxy(obj)
- expected = -obj
- result = -p
- self.assertEqual(result, expected)
-
- def test_round(self):
- obj = 20.5
- p = Proxy(obj)
- expected = round(obj)
- result = round(p)
- self.assertEqual(result, expected)
-
- # cast
-
- def test_bool(self):
- obj = True
- p = Proxy(obj)
- if p:
- pass
- else:
- self.fail()
-
- def test_str(self):
- obj = Thing(10)
- p = Proxy(obj)
- expected = str(obj)
- result = str(p)
- self.assertEqual(result, expected)
-
- def test_repr(self):
- obj = Thing(10)
- p = Proxy(obj)
- expected = repr(obj)
- result = repr(p)
- self.assertEqual(result, expected)
-
- def test_text_bool(self):
- obj = ""
- p = Proxy(obj)
- if p:
- self.fail()
- else:
- pass
-
- def test_text_str(self):
- obj = "a"
- p = Proxy(obj)
- expected = str(obj)
- result = str(p)
- self.assertEqual(result, expected)
-
- def test_text_repr(self):
- obj = "a"
- p = Proxy(obj)
- expected = repr(obj)
- result = repr(p)
- self.assertEqual(result, expected)
-
- def test_hash(self):
- obj = [0, 1, 2]
- p = Proxy(obj)
- with self.assertRaises(TypeError):
- hash(p)
- obj = (0, 1, 2)
- p = Proxy(obj)
- hash(p)
-
-
-class TestInheritedProxy(unittest.TestCase):
- """Test that inheriting the Proxy class behave as expected"""
-
- # methods and properties
-
- def test_method(self):
- obj = Thing(10)
- p = InheritedProxy(obj, 11)
- self.assertEqual(p.method(10), 11 + 3)
-
- def test_property(self):
- obj = Thing(10)
- p = InheritedProxy(obj, 11)
- self.assertEqual(p.value, 11 + 2)
-
- # special functions
-
- def test_getitem(self):
- obj = Thing(10)
- p = InheritedProxy(obj, 11)
- self.assertEqual(p[12], 12 + 3)
-
-
-class TestPickle(unittest.TestCase):
-
- def test_dumps(self):
- obj = Thing(10)
- p = Proxy(obj)
- expected = pickle.dumps(obj)
- result = pickle.dumps(p)
- self.assertEqual(result, expected)
-
- def test_loads(self):
- obj = Thing(10)
- p = Proxy(obj)
- obj2 = pickle.loads(pickle.dumps(p))
- self.assertTrue(isinstance(obj2, Thing))
- self.assertFalse(isinstance(obj2, Proxy))
- self.assertEqual(obj.value, obj2.value)
-
-
-class TestDocstring(unittest.TestCase):
- """Test docstring decorator"""
-
- class Base(object):
- def method(self):
- """Docstring"""
- pass
-
- def test_inheritance(self):
- class Derived(TestDocstring.Base):
- @docstring(TestDocstring.Base)
- def method(self):
- pass
-
- self.assertEqual(Derived.method.__doc__,
- TestDocstring.Base.method.__doc__)
-
- def test_composition(self):
- class Composed(object):
- def __init__(self):
- self._base = TestDocstring.Base()
-
- @docstring(TestDocstring.Base)
- def method(self):
- return self._base.method()
-
- @docstring(TestDocstring.Base.method)
- def renamed(self):
- return self._base.method()
-
- self.assertEqual(Composed.method.__doc__,
- TestDocstring.Base.method.__doc__)
-
- self.assertEqual(Composed.renamed.__doc__,
- TestDocstring.Base.method.__doc__)
-
- def test_function(self):
- def f():
- """Docstring"""
- pass
-
- @docstring(f)
- def g():
- pass
-
- self.assertEqual(f.__doc__, g.__doc__)
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestProxy))
- test_suite.addTest(loadTests(TestPickle))
- test_suite.addTest(loadTests(TestInheritedProxy))
- test_suite.addTest(loadTests(TestDocstring))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/test/test_retry.py b/silx/utils/test/test_retry.py
deleted file mode 100644
index d223f44..0000000
--- a/silx/utils/test/test_retry.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ############################################################################*/
-"""Tests for retry utilities"""
-
-__authors__ = ["W. de Nolf"]
-__license__ = "MIT"
-__date__ = "05/02/2020"
-
-
-import unittest
-import os
-import sys
-import time
-import tempfile
-
-from .. import retry
-
-
-def _cause_segfault():
- import ctypes
-
- i = ctypes.c_char(b"a")
- j = ctypes.pointer(i)
- c = 0
- while True:
- j[c] = b"a"
- c += 1
-
-
-def _submain(filename, kwcheck=None, ncausefailure=0, faildelay=0):
- assert filename
- assert kwcheck
- sys.stderr = open(os.devnull, "w")
-
- with open(filename, mode="r") as f:
- failcounter = int(f.readline().strip())
-
- if failcounter < ncausefailure:
- time.sleep(faildelay)
- failcounter += 1
- with open(filename, mode="w") as f:
- f.write(str(failcounter))
- if failcounter % 2:
- raise retry.RetryError
- else:
- _cause_segfault()
- return True
-
-
-_wsubmain = retry.retry_in_subprocess()(_submain)
-
-
-class TestRetry(unittest.TestCase):
- def setUp(self):
- self.test_dir = tempfile.mkdtemp()
- self.ctr_file = os.path.join(self.test_dir, "failcounter.txt")
-
- def tearDown(self):
- if os.path.exists(self.ctr_file):
- os.unlink(self.ctr_file)
- os.rmdir(self.test_dir)
-
- def test_retry(self):
- ncausefailure = 3
- faildelay = 0.1
- sufficient_timeout = ncausefailure * (faildelay + 10)
- insufficient_timeout = ncausefailure * faildelay * 0.5
-
- @retry.retry()
- def method(check, kwcheck=None):
- assert check
- assert kwcheck
- nonlocal failcounter
- if failcounter < ncausefailure:
- time.sleep(faildelay)
- failcounter += 1
- raise retry.RetryError
- return True
-
- failcounter = 0
- kw = {
- "kwcheck": True,
- "retry_timeout": sufficient_timeout,
- }
- self.assertTrue(method(True, **kw))
-
- failcounter = 0
- kw = {
- "kwcheck": True,
- "retry_timeout": insufficient_timeout,
- }
- with self.assertRaises(retry.RetryTimeoutError):
- method(True, **kw)
-
- def test_retry_contextmanager(self):
- ncausefailure = 3
- faildelay = 0.1
- sufficient_timeout = ncausefailure * (faildelay + 10)
- insufficient_timeout = ncausefailure * faildelay * 0.5
-
- @retry.retry_contextmanager()
- def context(check, kwcheck=None):
- assert check
- assert kwcheck
- nonlocal failcounter
- if failcounter < ncausefailure:
- time.sleep(faildelay)
- failcounter += 1
- raise retry.RetryError
- yield True
-
- failcounter = 0
- kw = {"kwcheck": True, "retry_timeout": sufficient_timeout}
- with context(True, **kw) as result:
- self.assertTrue(result)
-
- failcounter = 0
- kw = {"kwcheck": True, "retry_timeout": insufficient_timeout}
- with self.assertRaises(retry.RetryTimeoutError):
- with context(True, **kw) as result:
- pass
-
- def test_retry_in_subprocess(self):
- ncausefailure = 3
- faildelay = 0.1
- sufficient_timeout = ncausefailure * (faildelay + 10)
- insufficient_timeout = ncausefailure * faildelay * 0.5
-
- kw = {
- "ncausefailure": ncausefailure,
- "faildelay": faildelay,
- "kwcheck": True,
- "retry_timeout": sufficient_timeout,
- }
- with open(self.ctr_file, mode="w") as f:
- f.write("0")
- self.assertTrue(_wsubmain(self.ctr_file, **kw))
-
- kw = {
- "ncausefailure": ncausefailure,
- "faildelay": faildelay,
- "kwcheck": True,
- "retry_timeout": insufficient_timeout,
- }
- with open(self.ctr_file, mode="w") as f:
- f.write("0")
- with self.assertRaises(retry.RetryTimeoutError):
- _wsubmain(self.ctr_file, **kw)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestRetry))
- return test_suite
-
-
-if __name__ == "__main__":
- unittest.main(defaultTest="suite")
diff --git a/silx/utils/test/test_testutils.py b/silx/utils/test/test_testutils.py
deleted file mode 100755
index c72a3d8..0000000
--- a/silx/utils/test/test_testutils.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for testutils module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "18/11/2019"
-
-
-import unittest
-import logging
-from .. import testutils
-
-
-class TestTestLogging(unittest.TestCase):
- """Tests for TestLogging."""
-
- def testRight(self):
- logger = logging.getLogger(__name__ + "testRight")
- listener = testutils.TestLogging(logger, error=1)
- with listener:
- logger.error("expected")
- logger.info("ignored")
-
- def testCustomLevel(self):
- logger = logging.getLogger(__name__ + "testCustomLevel")
- listener = testutils.TestLogging(logger, error=1)
- with listener:
- logger.error("expected")
- logger.log(666, "custom level have to be ignored")
-
- def testWrong(self):
- logger = logging.getLogger(__name__ + "testWrong")
- listener = testutils.TestLogging(logger, error=1)
- with self.assertRaises(RuntimeError):
- with listener:
- logger.error("expected")
- logger.error("not expected")
-
- def testManyErrors(self):
- logger = logging.getLogger(__name__ + "testManyErrors")
- listener = testutils.TestLogging(logger, error=1, warning=2)
- with self.assertRaises(RuntimeError):
- with listener:
- pass
-
- def testCanBeChecked(self):
- logger = logging.getLogger(__name__ + "testCanBreak")
- listener = testutils.TestLogging(logger, error=1, warning=2)
- with self.assertRaises(RuntimeError):
- with listener:
- logger.error("aaa")
- logger.warning("aaa")
- self.assertFalse(listener.can_be_checked())
- logger.error("aaa")
- # Here we know that it's already wrong without a big cost
- self.assertTrue(listener.can_be_checked())
-
- def testWithAs(self):
- logger = logging.getLogger(__name__ + "testCanBreak")
- with testutils.TestLogging(logger) as listener:
- logger.error("aaa")
- self.assertIsNotNone(listener)
-
- def testErrorMessage(self):
- logger = logging.getLogger(__name__ + "testCanBreak")
- listener = testutils.TestLogging(logger, error=1, warning=2)
- with self.assertRaisesRegex(RuntimeError, "aaabbb"):
- with listener:
- logger.error("aaa")
- logger.warning("aaabbb")
- logger.error("aaa")
-
-
-def suite():
- loadTests = unittest.defaultTestLoader.loadTestsFromTestCase
- test_suite = unittest.TestSuite()
- test_suite.addTest(loadTests(TestTestLogging))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/test/test_weakref.py b/silx/utils/test/test_weakref.py
deleted file mode 100644
index 001193d..0000000
--- a/silx/utils/test/test_weakref.py
+++ /dev/null
@@ -1,330 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Tests for weakref module"""
-
-__authors__ = ["V. Valls"]
-__license__ = "MIT"
-__date__ = "15/09/2016"
-
-
-import unittest
-from .. import weakref
-
-
-class Dummy(object):
- """Dummy class to use it as geanie pig"""
- def inc(self, a):
- return a + 1
-
- def __lt__(self, other):
- return True
-
-
-def dummy_inc(a):
- """Dummy function to use it as geanie pig"""
- return a + 1
-
-
-class TestWeakMethod(unittest.TestCase):
- """Tests for weakref.WeakMethod"""
-
- def testMethod(self):
- dummy = Dummy()
- callable_ = weakref.WeakMethod(dummy.inc)
- self.assertEqual(callable_()(10), 11)
-
- def testMethodWithDeadObject(self):
- dummy = Dummy()
- callable_ = weakref.WeakMethod(dummy.inc)
- dummy = None
- self.assertIsNone(callable_())
-
- def testMethodWithDeadFunction(self):
- dummy = Dummy()
- dummy.inc2 = lambda self, a: a + 1
- callable_ = weakref.WeakMethod(dummy.inc2)
- dummy.inc2 = None
- self.assertIsNone(callable_())
-
- def testFunction(self):
- callable_ = weakref.WeakMethod(dummy_inc)
- self.assertEqual(callable_()(10), 11)
-
- def testDeadFunction(self):
- def inc(a):
- return a + 1
- callable_ = weakref.WeakMethod(inc)
- inc = None
- self.assertIsNone(callable_())
-
- def testLambda(self):
- store = lambda a: a + 1 # noqa: E731
- callable_ = weakref.WeakMethod(store)
- self.assertEqual(callable_()(10), 11)
-
- def testDeadLambda(self):
- callable_ = weakref.WeakMethod(lambda a: a + 1)
- self.assertIsNone(callable_())
-
- def testCallbackOnDeadObject(self):
- self.__count = 0
-
- def callback(ref):
- self.__count += 1
- self.assertIs(callable_, ref)
- dummy = Dummy()
- callable_ = weakref.WeakMethod(dummy.inc, callback)
- dummy = None
- self.assertEqual(self.__count, 1)
-
- def testCallbackOnDeadMethod(self):
- self.__count = 0
-
- def callback(ref):
- self.__count += 1
- self.assertIs(callable_, ref)
- dummy = Dummy()
- dummy.inc2 = lambda self, a: a + 1
- callable_ = weakref.WeakMethod(dummy.inc2, callback)
- dummy.inc2 = None
- self.assertEqual(self.__count, 1)
-
- def testCallbackOnDeadFunction(self):
- self.__count = 0
-
- def callback(ref):
- self.__count += 1
- self.assertIs(callable_, ref)
- store = lambda a: a + 1 # noqa: E731
- callable_ = weakref.WeakMethod(store, callback)
- store = None
- self.assertEqual(self.__count, 1)
-
- def testEquals(self):
- dummy = Dummy()
- callable1 = weakref.WeakMethod(dummy.inc)
- callable2 = weakref.WeakMethod(dummy.inc)
- self.assertEqual(callable1, callable2)
-
- def testInSet(self):
- callable_set = set([])
- dummy = Dummy()
- callable_set.add(weakref.WeakMethod(dummy.inc))
- callable_ = weakref.WeakMethod(dummy.inc)
- self.assertIn(callable_, callable_set)
-
- def testInDict(self):
- callable_dict = {}
- dummy = Dummy()
- callable_dict[weakref.WeakMethod(dummy.inc)] = 10
- callable_ = weakref.WeakMethod(dummy.inc)
- self.assertEqual(callable_dict.get(callable_), 10)
-
-
-class TestWeakMethodProxy(unittest.TestCase):
-
- def testMethod(self):
- dummy = Dummy()
- callable_ = weakref.WeakMethodProxy(dummy.inc)
- self.assertEqual(callable_(10), 11)
-
- def testMethodWithDeadObject(self):
- dummy = Dummy()
- method = weakref.WeakMethodProxy(dummy.inc)
- dummy = None
- self.assertRaises(ReferenceError, method, 9)
-
-
-class TestWeakList(unittest.TestCase):
- """Tests for weakref.WeakList"""
-
- def setUp(self):
- self.list = weakref.WeakList()
- self.object1 = Dummy()
- self.object2 = Dummy()
- self.list.append(self.object1)
- self.list.append(self.object2)
-
- def testAppend(self):
- obj = Dummy()
- self.list.append(obj)
- self.assertEqual(len(self.list), 3)
- obj = None
- self.assertEqual(len(self.list), 2)
-
- def testRemove(self):
- self.list.remove(self.object1)
- self.assertEqual(len(self.list), 1)
-
- def testPop(self):
- obj = self.list.pop(0)
- self.assertIs(obj, self.object1)
- self.assertEqual(len(self.list), 1)
-
- def testGetItem(self):
- self.assertIs(self.object1, self.list[0])
-
- def testGetItemSlice(self):
- objects = self.list[:]
- self.assertEqual(len(objects), 2)
- self.assertIs(self.object1, objects[0])
- self.assertIs(self.object2, objects[1])
-
- def testIter(self):
- obj_list = list(self.list)
- self.assertEqual(len(obj_list), 2)
- self.assertIs(self.object1, obj_list[0])
-
- def testLen(self):
- self.assertEqual(len(self.list), 2)
-
- def testSetItem(self):
- obj = Dummy()
- self.list[0] = obj
- self.assertIsNot(self.object1, self.list[0])
- obj = None
- self.assertEqual(len(self.list), 1)
-
- def testSetItemSlice(self):
- obj = Dummy()
- self.list[:] = [obj, obj]
- self.assertEqual(len(self.list), 2)
- self.assertIs(obj, self.list[0])
- self.assertIs(obj, self.list[1])
- obj = None
- self.assertEqual(len(self.list), 0)
-
- def testDelItem(self):
- del self.list[0]
- self.assertEqual(len(self.list), 1)
- self.assertIs(self.object2, self.list[0])
-
- def testDelItemSlice(self):
- del self.list[:]
- self.assertEqual(len(self.list), 0)
-
- def testContains(self):
- self.assertIn(self.object1, self.list)
-
- def testAdd(self):
- others = [Dummy()]
- l = self.list + others
- self.assertIs(l[0], self.object1)
- self.assertEqual(len(l), 3)
- others = None
- self.assertEqual(len(l), 2)
-
- def testExtend(self):
- others = [Dummy()]
- self.list.extend(others)
- self.assertIs(self.list[0], self.object1)
- self.assertEqual(len(self.list), 3)
- others = None
- self.assertEqual(len(self.list), 2)
-
- def testIadd(self):
- others = [Dummy()]
- self.list += others
- self.assertIs(self.list[0], self.object1)
- self.assertEqual(len(self.list), 3)
- others = None
- self.assertEqual(len(self.list), 2)
-
- def testMul(self):
- l = self.list * 2
- self.assertIs(l[0], self.object1)
- self.assertEqual(len(l), 4)
- self.object1 = None
- self.assertEqual(len(l), 2)
- self.assertIs(l[0], self.object2)
- self.assertIs(l[1], self.object2)
-
- def testImul(self):
- self.list *= 2
- self.assertIs(self.list[0], self.object1)
- self.assertEqual(len(self.list), 4)
- self.object1 = None
- self.assertEqual(len(self.list), 2)
- self.assertIs(self.list[0], self.object2)
- self.assertIs(self.list[1], self.object2)
-
- def testCount(self):
- self.list.append(self.object2)
- self.assertEqual(self.list.count(self.object1), 1)
- self.assertEqual(self.list.count(self.object2), 2)
-
- def testIndex(self):
- self.assertEqual(self.list.index(self.object1), 0)
- self.assertEqual(self.list.index(self.object2), 1)
-
- def testInsert(self):
- obj = Dummy()
- self.list.insert(1, obj)
- self.assertEqual(len(self.list), 3)
- self.assertIs(self.list[1], obj)
- obj = None
- self.assertEqual(len(self.list), 2)
-
- def testReverse(self):
- self.list.reverse()
- self.assertEqual(len(self.list), 2)
- self.assertIs(self.list[0], self.object2)
- self.assertIs(self.list[1], self.object1)
-
- def testReverted(self):
- new_list = reversed(self.list)
- self.assertEqual(len(new_list), 2)
- self.assertIs(self.list[1], self.object2)
- self.assertIs(self.list[0], self.object1)
- self.assertIs(new_list[0], self.object2)
- self.assertIs(new_list[1], self.object1)
- self.object1 = None
- self.assertEqual(len(new_list), 1)
-
- def testStr(self):
- self.assertNotEqual(self.list.__str__(), "[]")
-
- def testRepr(self):
- self.assertNotEqual(self.list.__repr__(), "[]")
-
- def testSort(self):
- # only a coverage
- self.list.sort()
- self.assertEqual(len(self.list), 2)
-
-
-def suite():
- test_suite = unittest.TestSuite()
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestWeakMethod))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestWeakMethodProxy))
- test_suite.addTest(
- unittest.defaultTestLoader.loadTestsFromTestCase(TestWeakList))
- return test_suite
-
-
-if __name__ == '__main__':
- unittest.main(defaultTest='suite')
diff --git a/silx/utils/testutils.py b/silx/utils/testutils.py
deleted file mode 100755
index 434beee..0000000
--- a/silx/utils/testutils.py
+++ /dev/null
@@ -1,333 +0,0 @@
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Utilities for writing tests.
-
-- :class:`ParametricTestCase` provides a :meth:`TestCase.subTest` replacement
- for Python < 3.4
-- :class:`TestLogging` with context or the :func:`test_logging` decorator
- enables testing the number of logging messages of different levels.
-"""
-
-__authors__ = ["T. Vincent"]
-__license__ = "MIT"
-__date__ = "26/01/2018"
-
-
-import contextlib
-import functools
-import logging
-import sys
-import unittest
-
-_logger = logging.getLogger(__name__)
-
-
-if sys.hexversion >= 0x030400F0: # Python >= 3.4
- class ParametricTestCase(unittest.TestCase):
- pass
-else:
- class ParametricTestCase(unittest.TestCase):
- """TestCase with subTest support for Python < 3.4.
-
- Add subTest method to support parametric tests.
- API is the same, but behavior differs:
- If a subTest fails, the following ones are not run.
- """
-
- _subtest_msg = None # Class attribute to provide a default value
-
- @contextlib.contextmanager
- def subTest(self, msg=None, **params):
- """Use as unittest.TestCase.subTest method in Python >= 3.4."""
- # Format arguments as: '[msg] (key=value, ...)'
- param_str = ', '.join(['%s=%s' % (k, v) for k, v in params.items()])
- self._subtest_msg = '[%s] (%s)' % (msg or '', param_str)
- yield
- self._subtest_msg = None
-
- def shortDescription(self):
- short_desc = super(ParametricTestCase, self).shortDescription()
- if self._subtest_msg is not None:
- # Append subTest message to shortDescription
- short_desc = ' '.join(
- [msg for msg in (short_desc, self._subtest_msg) if msg])
-
- return short_desc if short_desc else None
-
-
-def parameterize(test_case_class, *args, **kwargs):
- """Create a suite containing all tests taken from the given
- subclass, passing them the parameters.
-
- .. code-block:: python
-
- class TestParameterizedCase(unittest.TestCase):
- def __init__(self, methodName='runTest', foo=None):
- unittest.TestCase.__init__(self, methodName)
- self.foo = foo
-
- def suite():
- testSuite = unittest.TestSuite()
- testSuite.addTest(parameterize(TestParameterizedCase, foo=10))
- testSuite.addTest(parameterize(TestParameterizedCase, foo=50))
- return testSuite
- """
- test_loader = unittest.TestLoader()
- test_names = test_loader.getTestCaseNames(test_case_class)
- suite = unittest.TestSuite()
- for name in test_names:
- suite.addTest(test_case_class(name, *args, **kwargs))
- return suite
-
-
-class LoggingRuntimeError(RuntimeError):
- """Raised when the `TestLogging` fails"""
-
- def __init__(self, msg, records):
- super(LoggingRuntimeError, self).__init__(msg)
- self.records = records
-
- def __str__(self):
- return super(LoggingRuntimeError, self).__str__() + " -> " + str(self.records)
-
-
-class TestLogging(logging.Handler):
- """Context checking the number of logging messages from a specified Logger.
-
- It disables propagation of logging message while running.
-
- This is meant to be used as a with statement, for example:
-
- >>> with TestLogging(logger, error=2, warning=0):
- >>> pass # Run tests here expecting 2 ERROR and no WARNING from logger
- ...
-
- :param logger: Name or instance of the logger to test.
- (Default: root logger)
- :type logger: str or :class:`logging.Logger`
- :param int critical: Expected number of CRITICAL messages.
- Default: Do not check.
- :param int error: Expected number of ERROR messages.
- Default: Do not check.
- :param int warning: Expected number of WARNING messages.
- Default: Do not check.
- :param int info: Expected number of INFO messages.
- Default: Do not check.
- :param int debug: Expected number of DEBUG messages.
- Default: Do not check.
- :param int notset: Expected number of NOTSET messages.
- Default: Do not check.
- :raises RuntimeError: If the message counts are the expected ones.
- """
-
- def __init__(self, logger=None, critical=None, error=None,
- warning=None, info=None, debug=None, notset=None):
- if logger is None:
- logger = logging.getLogger()
- elif not isinstance(logger, logging.Logger):
- logger = logging.getLogger(logger)
- self.logger = logger
-
- self.records = []
-
- self.expected_count_by_level = {
- logging.CRITICAL: critical,
- logging.ERROR: error,
- logging.WARNING: warning,
- logging.INFO: info,
- logging.DEBUG: debug,
- logging.NOTSET: notset
- }
-
- self._expected_count = sum([v for k, v in self.expected_count_by_level.items() if v is not None])
- """Amount of any logging expected"""
-
- super(TestLogging, self).__init__()
-
- def __enter__(self):
- """Context (i.e., with) support"""
- self.records = [] # Reset recorded LogRecords
- self.logger.addHandler(self)
- self.logger.propagate = False
- # ensure no log message is ignored
- self.entry_level = self.logger.level * 1
- self.logger.setLevel(logging.DEBUG)
- self.entry_disabled = self.logger.disabled
- self.logger.disabled = False
- return self
-
- def can_be_checked(self):
- """Returns True if this listener have received enough messages to
- be valid, and then checked.
-
- This can be useful for asynchronous wait of messages. It allows process
- an early break, instead of waiting much time in an active loop.
- """
- return len(self.records) >= self._expected_count
-
- def get_count_by_level(self):
- """Returns the current message count by level.
- """
- count = {
- logging.CRITICAL: 0,
- logging.ERROR: 0,
- logging.WARNING: 0,
- logging.INFO: 0,
- logging.DEBUG: 0,
- logging.NOTSET: 0
- }
- for record in self.records:
- level = record.levelno
- if level in count:
- count[level] = count[level] + 1
- return count
-
- def __exit__(self, exc_type, exc_value, traceback):
- """Context (i.e., with) support"""
- self.logger.removeHandler(self)
- self.logger.propagate = True
- self.logger.setLevel(self.entry_level)
- self.logger.disabled = self.entry_disabled
-
- count_by_level = self.get_count_by_level()
-
- # Remove keys which does not matter
- ignored = [r for r, v in self.expected_count_by_level.items() if v is None]
- expected_count_by_level = dict(self.expected_count_by_level)
- for i in ignored:
- del count_by_level[i]
- del expected_count_by_level[i]
-
- if count_by_level != expected_count_by_level:
- # Re-send record logs through logger as they where masked
- # to help debug
- message = ""
- for level in count_by_level.keys():
- if message != "":
- message += ", "
- count = count_by_level[level]
- expected_count = expected_count_by_level[level]
- message += "%d %s (got %d)" % (expected_count, logging.getLevelName(level), count)
-
- raise LoggingRuntimeError(
- 'Expected %s' % message, records=list(self.records))
-
- def emit(self, record):
- """Override :meth:`logging.Handler.emit`"""
- self.records.append(record)
-
-
-def test_logging(logger=None, critical=None, error=None,
- warning=None, info=None, debug=None, notset=None):
- """Decorator checking number of logging messages.
-
- Propagation of logging messages is disabled by this decorator.
-
- In case the expected number of logging messages is not found, it raises
- a RuntimeError.
-
- >>> class Test(unittest.TestCase):
- ... @test_logging('module_logger_name', error=2, warning=0)
- ... def test(self):
- ... pass # Test expecting 2 ERROR and 0 WARNING messages
-
- :param logger: Name or instance of the logger to test.
- (Default: root logger)
- :type logger: str or :class:`logging.Logger`
- :param int critical: Expected number of CRITICAL messages.
- Default: Do not check.
- :param int error: Expected number of ERROR messages.
- Default: Do not check.
- :param int warning: Expected number of WARNING messages.
- Default: Do not check.
- :param int info: Expected number of INFO messages.
- Default: Do not check.
- :param int debug: Expected number of DEBUG messages.
- Default: Do not check.
- :param int notset: Expected number of NOTSET messages.
- Default: Do not check.
- """
- def decorator(func):
- test_context = TestLogging(logger, critical, error,
- warning, info, debug, notset)
-
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- with test_context:
- result = func(*args, **kwargs)
- return result
- return wrapper
- return decorator
-
-
-# Simulate missing library context
-class EnsureImportError(object):
- """This context manager allows to simulate the unavailability
- of a library, even if it is actually available. It ensures that
- an ImportError is raised if the code inside the context tries to
- import the module.
-
- It can be used to test that a correct fallback library is used,
- or that the expected error code is returned.
-
- Trivial example::
-
- from silx.utils.testutils import EnsureImportError
-
- with EnsureImportError("h5py"):
- try:
- import h5py
- except ImportError:
- print("Good")
-
- .. note::
-
- This context manager does not remove the library from the namespace,
- if it is already imported. It only ensures that any attempt to import
- it again will cause an ImportError to be raised.
- """
- def __init__(self, name):
- """
-
- :param str name: Name of module to be hidden (e.g. "h5py")
- """
- self.module_name = name
-
- def __enter__(self):
- """Simulate failed import by setting sys.modules[name]=None"""
- if self.module_name not in sys.modules:
- self._delete_on_exit = True
- self._backup = None
- else:
- self._delete_on_exit = False
- self._backup = sys.modules[self.module_name]
- sys.modules[self.module_name] = None
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- """Restore previous state"""
- if self._delete_on_exit:
- del sys.modules[self.module_name]
- else:
- sys.modules[self.module_name] = self._backup
diff --git a/src/silx/__init__.py b/src/silx/__init__.py
new file mode 100644
index 0000000..0ad0357
--- /dev/null
+++ b/src/silx/__init__.py
@@ -0,0 +1,58 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""The silx package contains the following main sub-packages:
+
+- silx.gui: Qt widgets for data visualization and data file browsing
+- silx.image: Some processing functions for 2D images
+- silx.io: Reading and writing data files (HDF5/NeXus, SPEC, ...)
+- silx.math: Some processing functions for 1D, 2D, 3D, nD arrays
+- silx.opencl: OpenCL-based data processing
+- silx.sx: High-level silx functions suited for (I)Python console.
+- silx.utils: Miscellaneous convenient functions
+
+See silx documentation: http://www.silx.org/doc/silx/latest/
+"""
+
+from __future__ import absolute_import, print_function, division
+
+__authors__ = ["Jérôme Kieffer"]
+__license__ = "MIT"
+__date__ = "26/04/2018"
+
+import os as _os
+import logging as _logging
+from ._config import Config as _Config
+
+config = _Config()
+"""Global configuration shared with the whole library"""
+
+# Attach a do nothing logging handler for silx
+_logging.getLogger(__name__).addHandler(_logging.NullHandler())
+
+
+project = _os.path.basename(_os.path.dirname(_os.path.abspath(__file__)))
+
+from ._version import __date__ as date # noqa
+from ._version import version, version_info, hexversion, strictversion # noqa
diff --git a/silx/__main__.py b/src/silx/__main__.py
index f832a09..f832a09 100644
--- a/silx/__main__.py
+++ b/src/silx/__main__.py
diff --git a/silx/_config.py b/src/silx/_config.py
index fb0e409..fb0e409 100644
--- a/silx/_config.py
+++ b/src/silx/_config.py
diff --git a/src/silx/_version.py b/src/silx/_version.py
new file mode 100644
index 0000000..feb2639
--- /dev/null
+++ b/src/silx/_version.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Unique place where the version number is defined.
+
+provides:
+* version = "1.2.3" or "1.2.3-beta4"
+* version_info = named tuple (1,2,3,"beta",4)
+* hexversion: 0x010203B4
+* strictversion = "1.2.3b4
+* debianversion = "1.2.3~beta4"
+* calc_hexversion: the function to transform a version_tuple into an integer
+
+This is called hexversion since it only really looks meaningful when viewed as the
+result of passing it to the built-in hex() function.
+The version_info value may be used for a more human-friendly encoding of the same information.
+
+The hexversion is a 32-bit number with the following layout:
+Bits (big endian order) Meaning
+1-8 PY_MAJOR_VERSION (the 2 in 2.1.0a3)
+9-16 PY_MINOR_VERSION (the 1 in 2.1.0a3)
+17-24 PY_MICRO_VERSION (the 0 in 2.1.0a3)
+25-28 PY_RELEASE_LEVEL (0xA for alpha, 0xB for beta, 0xC for release candidate and 0xF for final)
+29-32 PY_RELEASE_SERIAL (the 3 in 2.1.0a3, zero for final releases)
+
+Thus 2.1.0a3 is hexversion 0x020100a3.
+
+"""
+
+from __future__ import absolute_import, print_function, division
+__authors__ = ["Jérôme Kieffer"]
+__license__ = "MIT"
+__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "30/09/2020"
+__status__ = "production"
+__docformat__ = 'restructuredtext'
+__all__ = ["date", "version_info", "strictversion", "hexversion", "debianversion",
+ "calc_hexversion"]
+
+RELEASE_LEVEL_VALUE = {"dev": 0,
+ "alpha": 10,
+ "beta": 11,
+ "candidate": 12,
+ "final": 15}
+
+PRERELEASE_NORMALIZED_NAME = {"dev": "a",
+ "alpha": "a",
+ "beta": "b",
+ "candidate": "rc"}
+
+MAJOR = 1
+MINOR = 0
+MICRO = 0
+RELEV = "final" # <16
+SERIAL = 0 # <16
+
+date = __date__
+
+from collections import namedtuple
+_version_info = namedtuple("version_info", ["major", "minor", "micro", "releaselevel", "serial"])
+
+version_info = _version_info(MAJOR, MINOR, MICRO, RELEV, SERIAL)
+
+strictversion = version = debianversion = "%d.%d.%d" % version_info[:3]
+if version_info.releaselevel != "final":
+ _prerelease = PRERELEASE_NORMALIZED_NAME[version_info[3]]
+ version += "-%s%s" % (_prerelease, version_info[-1])
+ debianversion += "~adev%i" % version_info[-1] if RELEV == "dev" else "~%s%i" % (_prerelease, version_info[-1])
+ strictversion += _prerelease + str(version_info[-1])
+
+
+def calc_hexversion(major=0, minor=0, micro=0, releaselevel="dev", serial=0):
+ """Calculate the hexadecimal version number from the tuple version_info:
+
+ :param major: integer
+ :param minor: integer
+ :param micro: integer
+ :param relev: integer or string
+ :param serial: integer
+ :return: integer always increasing with revision numbers
+ """
+ try:
+ releaselevel = int(releaselevel)
+ except ValueError:
+ releaselevel = RELEASE_LEVEL_VALUE.get(releaselevel, 0)
+
+ hex_version = int(serial)
+ hex_version |= releaselevel * 1 << 4
+ hex_version |= int(micro) * 1 << 8
+ hex_version |= int(minor) * 1 << 16
+ hex_version |= int(major) * 1 << 24
+ return hex_version
+
+
+hexversion = calc_hexversion(*version_info)
+
+if __name__ == "__main__":
+ print(version)
diff --git a/silx/app/__init__.py b/src/silx/app/__init__.py
index 3af680c..3af680c 100644
--- a/silx/app/__init__.py
+++ b/src/silx/app/__init__.py
diff --git a/src/silx/app/convert.py b/src/silx/app/convert.py
new file mode 100644
index 0000000..43baf7e
--- /dev/null
+++ b/src/silx/app/convert.py
@@ -0,0 +1,548 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Convert silx supported data files into HDF5 files"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/02/2019"
+
+import ast
+import os
+import argparse
+from glob import glob
+import logging
+import re
+import time
+import numpy
+
+import silx.io
+from silx.io.specfile import is_specfile
+from silx.io.fioh5 import is_fiofile
+from silx.io import fabioh5
+
+_logger = logging.getLogger(__name__)
+"""Module logger"""
+
+
+def c_format_string_to_re(pattern_string):
+ """
+
+ :param pattern_string: C style format string with integer patterns
+ (e.g. "%d", "%04d").
+ Not supported: fixed length padded with whitespaces (e.g "%4d", "%-4d")
+ :return: Equivalent regular expression (e.g. "\\d+", "\\d{4}")
+ """
+ # escape dots and backslashes
+ pattern_string = pattern_string.replace("\\", "\\\\")
+ pattern_string = pattern_string.replace(".", r"\.")
+
+ # %d
+ pattern_string = pattern_string.replace("%d", r"([-+]?\d+)")
+
+ # %0nd
+ for sub_pattern in re.findall(r"%0\d+d", pattern_string):
+ n = int(re.search(r"%0(\d+)d", sub_pattern).group(1))
+ if n == 1:
+ re_sub_pattern = r"([+-]?\d)"
+ else:
+ re_sub_pattern = r"([\d+-]\d{%d})" % (n - 1)
+ pattern_string = pattern_string.replace(sub_pattern, re_sub_pattern, 1)
+
+ return pattern_string
+
+
+def drop_indices_before_begin(filenames, regex, begin):
+ """
+
+ :param List[str] filenames: list of filenames
+ :param str regex: Regexp used to find indices in a filename
+ :param str begin: Comma separated list of begin indices
+ :return: List of filenames with only indices >= begin
+ """
+ begin_indices = list(map(int, begin.split(",")))
+ output_filenames = []
+ for fname in filenames:
+ m = re.match(regex, fname)
+ file_indices = list(map(int, m.groups()))
+ if len(file_indices) != len(begin_indices):
+ raise IOError("Number of indices found in filename "
+ "does not match number of parsed end indices.")
+ good_indices = True
+ for i, fidx in enumerate(file_indices):
+ if fidx < begin_indices[i]:
+ good_indices = False
+ if good_indices:
+ output_filenames.append(fname)
+ return output_filenames
+
+
+def drop_indices_after_end(filenames, regex, end):
+ """
+
+ :param List[str] filenames: list of filenames
+ :param str regex: Regexp used to find indices in a filename
+ :param str end: Comma separated list of end indices
+ :return: List of filenames with only indices <= end
+ """
+ end_indices = list(map(int, end.split(",")))
+ output_filenames = []
+ for fname in filenames:
+ m = re.match(regex, fname)
+ file_indices = list(map(int, m.groups()))
+ if len(file_indices) != len(end_indices):
+ raise IOError("Number of indices found in filename "
+ "does not match number of parsed end indices.")
+ good_indices = True
+ for i, fidx in enumerate(file_indices):
+ if fidx > end_indices[i]:
+ good_indices = False
+ if good_indices:
+ output_filenames.append(fname)
+ return output_filenames
+
+
+def are_files_missing_in_series(filenames, regex):
+ """Return True if any file is missing in a list of filenames
+ that are supposed to follow a pattern.
+
+ :param List[str] filenames: list of filenames
+ :param str regex: Regexp used to find indices in a filename
+ :return: boolean
+ :raises AssertionError: if a filename does not match the regexp
+ """
+ previous_indices = None
+ for fname in filenames:
+ m = re.match(regex, fname)
+ assert m is not None, \
+ "regex %s does not match filename %s" % (fname, regex)
+ new_indices = list(map(int, m.groups()))
+ if previous_indices is not None:
+ for old_idx, new_idx in zip(previous_indices, new_indices):
+ if (new_idx - old_idx) > 1:
+ _logger.error("Index increment > 1 in file series: "
+ "previous idx %d, next idx %d",
+ old_idx, new_idx)
+ return True
+ previous_indices = new_indices
+ return False
+
+
+def are_all_specfile(filenames):
+ """Return True if all files in a list are SPEC files.
+ :param List[str] filenames: list of filenames
+ """
+ for fname in filenames:
+ if not is_specfile(fname):
+ return False
+ return True
+
+
+def contains_specfile(filenames):
+ """Return True if any file in a list are SPEC files.
+ :param List[str] filenames: list of filenames
+ """
+ for fname in filenames:
+ if is_specfile(fname):
+ return True
+ return False
+
+
+def contains_fiofile(filenames):
+ """Return True if any file in a list are FIO files.
+ :param List[str] filenames: list of filenames
+ """
+ for fname in filenames:
+ if is_fiofile(fname):
+ return True
+ return False
+
+
+def are_all_fiofile(filenames):
+ """Return True if all files in a list are FIO files.
+ :param List[str] filenames: list of filenames
+ """
+ for fname in filenames:
+ if not is_fiofile(fname):
+ return False
+ return True
+
+
+def main(argv):
+ """
+ Main function to launch the converter as an application
+
+ :param argv: Command line arguments
+ :returns: exit status
+ """
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ 'input_files',
+ nargs="*",
+ help='Input files (EDF, TIFF, FIO, SPEC...). When specifying '
+ 'multiple files, you cannot specify both fabio images '
+ 'and SPEC (or FIO) files. Multiple SPEC or FIO files will '
+ 'simply be concatenated, with one entry per scan. '
+ 'Multiple image files will be merged into a single '
+ 'entry with a stack of images.')
+ # input_files and --filepattern are mutually exclusive
+ parser.add_argument(
+ '--file-pattern',
+ help='File name pattern for loading a series of indexed image files '
+ '(toto_%%04d.edf). This argument is incompatible with argument '
+ 'input_files. If an output URI with a HDF5 path is provided, '
+ 'only the content of the NXdetector group will be copied there. '
+ 'If no HDF5 path, or just "/", is given, a complete NXdata '
+ 'structure will be created.')
+ parser.add_argument(
+ '-o', '--output-uri',
+ default=time.strftime("%Y%m%d-%H%M%S") + '.h5',
+ help='Output file name (HDF5). An URI can be provided to write'
+ ' the data into a specific group in the output file: '
+ '/path/to/file::/path/to/group. '
+ 'If not provided, the filename defaults to a timestamp:'
+ ' YYYYmmdd-HHMMSS.h5')
+ parser.add_argument(
+ '-m', '--mode',
+ default="w-",
+ help='Write mode: "r+" (read/write, file must exist), '
+ '"w" (write, existing file is lost), '
+ '"w-" (write, fail if file exists) or '
+ '"a" (read/write if exists, create otherwise)')
+ parser.add_argument(
+ '--begin',
+ help='First file index, or first file indices to be considered. '
+ 'This argument only makes sense when used together with '
+ '--file-pattern. Provide as many start indices as there '
+ 'are indices in the file pattern, separated by commas. '
+ 'Examples: "--filepattern toto_%%d.edf --begin 100", '
+ ' "--filepattern toto_%%d_%%04d_%%02d.edf --begin 100,2000,5".')
+ parser.add_argument(
+ '--end',
+ help='Last file index, or last file indices to be considered. '
+ 'The same rules as with argument --begin apply. '
+ 'Example: "--filepattern toto_%%d_%%d.edf --end 199,1999"')
+ parser.add_argument(
+ '--add-root-group',
+ action="store_true",
+ help='This option causes each input file to be written to a '
+ 'specific root group with the same name as the file. When '
+ 'merging multiple input files, this can help preventing conflicts'
+ ' when datasets have the same name (see --overwrite-data). '
+ 'This option is ignored when using --file-pattern.')
+ parser.add_argument(
+ '--overwrite-data',
+ action="store_true",
+ help='If the output path exists and an input dataset has the same'
+ ' name as an existing output dataset, overwrite the output '
+ 'dataset (in modes "r+" or "a").')
+ parser.add_argument(
+ '--min-size',
+ type=int,
+ default=500,
+ help='Minimum number of elements required to be in a dataset to '
+ 'apply compression or chunking (default 500).')
+ parser.add_argument(
+ '--chunks',
+ nargs="?",
+ const="auto",
+ help='Chunk shape. Provide an argument that evaluates as a python '
+ 'tuple (e.g. "(1024, 768)"). If this option is provided without '
+ 'specifying an argument, the h5py library will guess a chunk for '
+ 'you. Note that if you specify an explicit chunking shape, it '
+ 'will be applied identically to all datasets with a large enough '
+ 'size (see --min-size). ')
+ parser.add_argument(
+ '--compression',
+ nargs="?",
+ const="gzip",
+ help='Compression filter. By default, the datasets in the output '
+ 'file are not compressed. If this option is specified without '
+ 'argument, the GZIP compression is used. Additional compression '
+ 'filters may be available, depending on your HDF5 installation.')
+
+ def check_gzip_compression_opts(value):
+ ivalue = int(value)
+ if ivalue < 0 or ivalue > 9:
+ raise argparse.ArgumentTypeError(
+ "--compression-opts must be an int from 0 to 9")
+ return ivalue
+
+ parser.add_argument(
+ '--compression-opts',
+ type=check_gzip_compression_opts,
+ help='Compression options. For "gzip", this may be an integer from '
+ '0 to 9, with a default of 4. This is only supported for GZIP.')
+ parser.add_argument(
+ '--shuffle',
+ action="store_true",
+ help='Enables the byte shuffle filter. This may improve the compression '
+ 'ratio for block oriented compressors like GZIP or LZF.')
+ parser.add_argument(
+ '--fletcher32',
+ action="store_true",
+ help='Adds a checksum to each chunk to detect data corruption.')
+ parser.add_argument(
+ '--debug',
+ action="store_true",
+ default=False,
+ help='Set logging system in debug mode')
+
+ options = parser.parse_args(argv[1:])
+
+ if options.debug:
+ logging.root.setLevel(logging.DEBUG)
+
+ # Import after parsing --debug
+ try:
+ # it should be loaded before h5py
+ import hdf5plugin # noqa
+ except ImportError:
+ _logger.debug("Backtrace", exc_info=True)
+ hdf5plugin = None
+
+ import h5py
+
+ try:
+ from silx.io.convert import write_to_h5
+ except ImportError:
+ _logger.debug("Backtrace", exc_info=True)
+ write_to_h5 = None
+
+ if hdf5plugin is None:
+ message = "Module 'hdf5plugin' is not installed. It supports additional hdf5"\
+ + " compressions. You can install it using \"pip install hdf5plugin\"."
+ _logger.debug(message)
+
+ # Process input arguments (mutually exclusive arguments)
+ if bool(options.input_files) == bool(options.file_pattern is not None):
+ if not options.input_files:
+ message = "You must specify either input files (at least one), "
+ message += "or a file pattern."
+ else:
+ message = "You cannot specify input files and a file pattern"
+ message += " at the same time."
+ _logger.error(message)
+ return -1
+ elif options.input_files:
+ # some shells (windows) don't interpret wildcard characters (*, ?, [])
+ old_input_list = list(options.input_files)
+ options.input_files = []
+ for fname in old_input_list:
+ globbed_files = glob(fname)
+ if not globbed_files:
+ # no files found, keep the name as it is, to raise an error later
+ options.input_files += [fname]
+ else:
+ # glob does not sort files, but the bash shell does
+ options.input_files += sorted(globbed_files)
+ else:
+ # File series
+ dirname = os.path.dirname(options.file_pattern)
+ file_pattern_re = c_format_string_to_re(options.file_pattern) + "$"
+ files_in_dir = glob(os.path.join(dirname, "*"))
+ _logger.debug("""
+ Processing file_pattern
+ dirname: %s
+ file_pattern_re: %s
+ files_in_dir: %s
+ """, dirname, file_pattern_re, files_in_dir)
+
+ options.input_files = sorted(list(filter(lambda name: re.match(file_pattern_re, name),
+ files_in_dir)))
+ _logger.debug("options.input_files: %s", options.input_files)
+
+ if options.begin is not None:
+ options.input_files = drop_indices_before_begin(options.input_files,
+ file_pattern_re,
+ options.begin)
+ _logger.debug("options.input_files after applying --begin: %s",
+ options.input_files)
+
+ if options.end is not None:
+ options.input_files = drop_indices_after_end(options.input_files,
+ file_pattern_re,
+ options.end)
+ _logger.debug("options.input_files after applying --end: %s",
+ options.input_files)
+
+ if are_files_missing_in_series(options.input_files,
+ file_pattern_re):
+ _logger.error("File missing in the file series. Aborting.")
+ return -1
+
+ if not options.input_files:
+ _logger.error("No file matching --file-pattern found.")
+ return -1
+
+ # Test that the output path is writeable
+ if "::" in options.output_uri:
+ output_name, hdf5_path = options.output_uri.split("::")
+ else:
+ output_name, hdf5_path = options.output_uri, "/"
+
+ if os.path.isfile(output_name):
+ if options.mode == "w-":
+ _logger.error("Output file %s exists and mode is 'w-' (default)."
+ " Aborting. To append data to an existing file, "
+ "use 'a' or 'r+'.",
+ output_name)
+ return -1
+ elif not os.access(output_name, os.W_OK):
+ _logger.error("Output file %s exists and is not writeable.",
+ output_name)
+ return -1
+ elif options.mode == "w":
+ _logger.info("Output file %s exists and mode is 'w'. "
+ "Overwriting existing file.", output_name)
+ elif options.mode in ["a", "r+"]:
+ _logger.info("Appending data to existing file %s.",
+ output_name)
+ else:
+ if options.mode == "r+":
+ _logger.error("Output file %s does not exist and mode is 'r+'"
+ " (append, file must exist). Aborting.",
+ output_name)
+ return -1
+ else:
+ _logger.info("Creating new output file %s.",
+ output_name)
+
+ # Test that all input files exist and are readable
+ bad_input = False
+ for fname in options.input_files:
+ if not os.access(fname, os.R_OK):
+ _logger.error("Cannot read input file %s.",
+ fname)
+ bad_input = True
+ if bad_input:
+ _logger.error("Aborting.")
+ return -1
+
+ # create_dataset special args
+ create_dataset_args = {}
+ if options.chunks is not None:
+ if options.chunks.lower() in ["auto", "true"]:
+ create_dataset_args["chunks"] = True
+ else:
+ try:
+ chunks = ast.literal_eval(options.chunks)
+ except (ValueError, SyntaxError):
+ _logger.error("Invalid --chunks argument %s", options.chunks)
+ return -1
+ if not isinstance(chunks, (tuple, list)):
+ _logger.error("--chunks argument str does not evaluate to a tuple")
+ return -1
+ else:
+ nitems = numpy.prod(chunks)
+ nbytes = nitems * 8
+ if nbytes > 10**6:
+ _logger.warning("Requested chunk size might be larger than"
+ " the default 1MB chunk cache, for float64"
+ " data. This can dramatically affect I/O "
+ "performances.")
+ create_dataset_args["chunks"] = chunks
+
+ if options.compression is not None:
+ try:
+ compression = int(options.compression)
+ except ValueError:
+ compression = options.compression
+ create_dataset_args["compression"] = compression
+
+ if options.compression_opts is not None:
+ create_dataset_args["compression_opts"] = options.compression_opts
+
+ if options.shuffle:
+ create_dataset_args["shuffle"] = True
+
+ if options.fletcher32:
+ create_dataset_args["fletcher32"] = True
+
+ if (len(options.input_files) > 1 and
+ not contains_specfile(options.input_files) and
+ not contains_fiofile(options.input_files) and
+ not options.add_root_group) or options.file_pattern is not None:
+ # File series -> stack of images
+ input_group = fabioh5.File(file_series=options.input_files)
+ if hdf5_path != "/":
+ # we want to append only data and headers to an existing file
+ input_group = input_group["/scan_0/instrument/detector_0"]
+ with h5py.File(output_name, mode=options.mode) as h5f:
+ write_to_h5(input_group, h5f,
+ h5path=hdf5_path,
+ overwrite_data=options.overwrite_data,
+ create_dataset_args=create_dataset_args,
+ min_size=options.min_size)
+
+ elif len(options.input_files) == 1 or \
+ are_all_specfile(options.input_files) or\
+ are_all_fiofile(options.input_files) or\
+ options.add_root_group:
+ # single file, or spec files
+ h5paths_and_groups = []
+ for input_name in options.input_files:
+ hdf5_path_for_file = hdf5_path
+ if options.add_root_group:
+ hdf5_path_for_file = hdf5_path.rstrip("/") + "/" + os.path.basename(input_name)
+ try:
+ h5paths_and_groups.append((hdf5_path_for_file,
+ silx.io.open(input_name)))
+ except IOError:
+ _logger.error("Cannot read file %s. If this is a file format "
+ "supported by the fabio library, you can try to"
+ " install fabio (`pip install fabio`)."
+ " Aborting conversion.",
+ input_name)
+ return -1
+
+ with h5py.File(output_name, mode=options.mode) as h5f:
+ for hdf5_path_for_file, input_group in h5paths_and_groups:
+ write_to_h5(input_group, h5f,
+ h5path=hdf5_path_for_file,
+ overwrite_data=options.overwrite_data,
+ create_dataset_args=create_dataset_args,
+ min_size=options.min_size)
+
+ else:
+ # multiple file, SPEC and fabio images mixed
+ _logger.error("Multiple files with incompatible formats specified. "
+ "You can provide multiple SPEC files or multiple image "
+ "files, but not both.")
+ return -1
+
+ with h5py.File(output_name, mode="r+") as h5f:
+ # append "silx convert" to the creator attribute, for NeXus files
+ previous_creator = h5f.attrs.get("creator", u"")
+ creator = "silx convert (v%s)" % silx.version
+ # only if it not already there
+ if creator not in previous_creator:
+ if not previous_creator:
+ new_creator = creator
+ else:
+ new_creator = previous_creator + "; " + creator
+ h5f.attrs["creator"] = numpy.array(
+ new_creator,
+ dtype=h5py.special_dtype(vlen=str))
+
+ return 0
diff --git a/silx/app/setup.py b/src/silx/app/setup.py
index 85c3662..85c3662 100644
--- a/silx/app/setup.py
+++ b/src/silx/app/setup.py
diff --git a/src/silx/app/test/__init__.py b/src/silx/app/test/__init__.py
new file mode 100644
index 0000000..7790ee5
--- /dev/null
+++ b/src/silx/app/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/app/test/test_convert.py b/src/silx/app/test/test_convert.py
new file mode 100644
index 0000000..2148db5
--- /dev/null
+++ b/src/silx/app/test/test_convert.py
@@ -0,0 +1,156 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module testing silx.app.convert"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import os
+import sys
+import tempfile
+import unittest
+import io
+import gc
+import h5py
+
+import silx
+from .. import convert
+from silx.utils import testutils
+from silx.io.utils import h5py_read_dataset
+
+
+# content of a spec file
+sftext = """#F /tmp/sf.dat
+#E 1455180875
+#D Thu Feb 11 09:54:35 2016
+#C imaging User = opid17
+#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
+#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
+#o0 pshg mrtu mrtd
+#o2 ss1vo ss1ho ss1vg
+
+#J0 Seconds IA ion.mono Current
+#J1 xbpmc2 idgap1 Inorm
+
+#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
+#D Thu Feb 11 09:55:20 2016
+#T 0.2 (Seconds)
+#P0 180.005 -0.66875 0.87125
+#P1 14.74255 16.197579 12.238283
+#N 4
+#L MRTSlit UP second column 3rd_col
+-1.23 5.89 8
+8.478100E+01 5 1.56
+3.14 2.73 -3.14
+1.2 2.3 3.4
+
+#S 1 aaaaaa
+#D Thu Feb 11 10:00:32 2016
+#@MCADEV 1
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#N 3
+#L uno duo
+1 2
+@A 0 1 2
+@A 10 9 8
+3 4
+@A 3.1 4 5
+@A 7 6 5
+5 6
+@A 6 7.7 8
+@A 4 3 2
+"""
+
+
+class TestConvertCommand(unittest.TestCase):
+ """Test command line parsing"""
+
+ def testHelp(self):
+ # option -h must cause a `raise SystemExit` or a `return 0`
+ try:
+ result = convert.main(["convert", "--help"])
+ except SystemExit as e:
+ result = e.args[0]
+ self.assertEqual(result, 0)
+
+ def testWrongOption(self):
+ # presence of a wrong option must cause a SystemExit or a return
+ # with a non-zero status
+ try:
+ result = convert.main(["convert", "--foo"])
+ except SystemExit as e:
+ result = e.args[0]
+ self.assertNotEqual(result, 0)
+
+ @testutils.validate_logging(convert._logger.name, error=3)
+ # one error log per missing file + one "Aborted" error log
+ def testWrongFiles(self):
+ result = convert.main(["convert", "foo.spec", "bar.edf"])
+ self.assertNotEqual(result, 0)
+
+ def testFile(self):
+ # create a writable temp directory
+ tempdir = tempfile.mkdtemp()
+
+ # write a temporary SPEC file
+ specname = os.path.join(tempdir, "input.dat")
+ with io.open(specname, "wb") as fd:
+ if sys.version_info < (3, ):
+ fd.write(sftext)
+ else:
+ fd.write(bytes(sftext, 'ascii'))
+
+ # convert it
+ h5name = os.path.join(tempdir, "output.h5")
+ assert not os.path.isfile(h5name)
+ command_list = ["convert", "-m", "w",
+ specname, "-o", h5name]
+ result = convert.main(command_list)
+
+ self.assertEqual(result, 0)
+ self.assertTrue(os.path.isfile(h5name))
+
+ with h5py.File(h5name, "r") as h5f:
+ title12 = h5py_read_dataset(h5f["/1.2/title"])
+ if sys.version_info < (3, ):
+ title12 = title12.encode("utf-8")
+ self.assertEqual(title12,
+ "aaaaaa")
+
+ creator = h5f.attrs.get("creator")
+ self.assertIsNotNone(creator, "No creator attribute in NXroot group")
+ if sys.version_info < (3, ):
+ creator = creator.encode("utf-8")
+ self.assertIn("silx convert (v%s)" % silx.version, creator)
+
+ # delete input file
+ gc.collect() # necessary to free spec file on Windows
+ os.unlink(specname)
+ os.unlink(h5name)
+ os.rmdir(tempdir)
diff --git a/src/silx/app/test_.py b/src/silx/app/test_.py
new file mode 100644
index 0000000..2b6bdf8
--- /dev/null
+++ b/src/silx/app/test_.py
@@ -0,0 +1,45 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Launch unittests of the library"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/01/2018"
+
+
+def main(argv):
+ """
+ Main function to launch the unittests as an application
+
+ :param argv: Command line arguments
+ :returns: exit status
+ """
+ import silx.test
+ import pytest
+
+ if silx.test.run_tests(args=argv[1:]) == pytest.ExitCode.OK:
+ exit_status = 0
+ else:
+ exit_status = 1
+ return exit_status
diff --git a/src/silx/app/view/About.py b/src/silx/app/view/About.py
new file mode 100644
index 0000000..85f1450
--- /dev/null
+++ b/src/silx/app/view/About.py
@@ -0,0 +1,258 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""About box for Silx viewer"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "05/07/2018"
+
+import os
+import sys
+
+from silx.gui import qt
+from silx.gui import icons
+
+_LICENSE_TEMPLATE = """<p align="center">
+<b>Copyright (C) {year} European Synchrotron Radiation Facility</b>
+</p>
+
+<p align="justify">
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+</p>
+
+<p align="justify">
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+</p>
+
+<p align="justify">
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+</p>
+"""
+
+
+class About(qt.QDialog):
+ """
+ Util dialog to display an common about box for all the silx GUIs.
+ """
+
+ def __init__(self, parent=None):
+ """
+ :param files_: List of HDF5 or Spec files (pathes or
+ :class:`silx.io.spech5.SpecH5` or :class:`h5py.File`
+ instances)
+ """
+ super(About, self).__init__(parent)
+ self.__createLayout()
+ self.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
+ self.setModal(True)
+ self.setApplicationName(None)
+
+ def __createLayout(self):
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(24, 15, 24, 20)
+ layout.setSpacing(8)
+
+ self.__label = qt.QLabel(self)
+ self.__label.setWordWrap(True)
+ flags = self.__label.textInteractionFlags()
+ flags = flags | qt.Qt.TextSelectableByKeyboard
+ flags = flags | qt.Qt.TextSelectableByMouse
+ self.__label.setTextInteractionFlags(flags)
+ self.__label.setOpenExternalLinks(True)
+ self.__label.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Preferred)
+
+ licenseButton = qt.QPushButton(self)
+ licenseButton.setText("License...")
+ licenseButton.clicked.connect(self.__displayLicense)
+ licenseButton.setAutoDefault(False)
+
+ self.__options = qt.QDialogButtonBox()
+ self.__options.addButton(licenseButton, qt.QDialogButtonBox.ActionRole)
+ okButton = self.__options.addButton(qt.QDialogButtonBox.Ok)
+ okButton.setDefault(True)
+ okButton.clicked.connect(self.accept)
+
+ layout.addWidget(self.__label)
+ layout.addWidget(self.__options)
+ layout.setStretch(0, 100)
+ layout.setStretch(1, 0)
+
+ def getHtmlLicense(self):
+ """Returns the text license in HTML format.
+
+ :rtype: str
+ """
+ from silx._version import __date__ as date
+ year = date.split("/")[2]
+ info = dict(
+ year=year
+ )
+ textLicense = _LICENSE_TEMPLATE.format(**info)
+ return textLicense
+
+ def __displayLicense(self):
+ """Displays the license used by silx."""
+ text = self.getHtmlLicense()
+ licenseDialog = qt.QMessageBox(self)
+ licenseDialog.setWindowTitle("License")
+ licenseDialog.setText(text)
+ licenseDialog.exec()
+
+ def setApplicationName(self, name):
+ self.__applicationName = name
+ if name is None:
+ self.setWindowTitle("About")
+ else:
+ self.setWindowTitle("About %s" % name)
+ self.__updateText()
+
+ @staticmethod
+ def __formatOptionalLibraries(name, isAvailable):
+ """Utils to format availability of features"""
+ if isAvailable:
+ template = '<b>%s</b> is <font color="green">loaded</font>'
+ else:
+ template = '<b>%s</b> is <font color="red">not loaded</font>'
+ return template % name
+
+ @staticmethod
+ def __formatOptionalFilters(name, isAvailable):
+ """Utils to format availability of features"""
+ if isAvailable:
+ template = '<b>%s</b> is <font color="green">available</font>'
+ else:
+ template = '<b>%s</b> is <font color="red">not available</font>'
+ return template % name
+
+ def __updateText(self):
+ """Update the content of the dialog according to the settings."""
+ import silx._version
+ import h5py.version
+
+ message = """<table>
+ <tr><td width="50%" align="center" valign="middle">
+ <img src="{silx_image_path}" width="100" />
+ </td><td width="50%" align="center" valign="middle">
+ <b>{application_name}</b>
+ <br />
+ <br />{silx_version}
+ <br />
+ <br /><a href="{project_url}">Upstream project on GitHub</a>
+ </td></tr>
+ </table>
+ <dl>
+ <dt><b>Silx version</b></dt><dd>{silx_version}</dd>
+ <dt><b>HDF5 version</b></dt><dd>{hdf5_version}</dd>
+ <dt><b>h5py version</b></dt><dd>{h5py_version}</dd>
+ <dt><b>Qt version</b></dt><dd>{qt_version}</dd>
+ <dt><b>Qt binding</b></dt><dd>{qt_binding}</dd>
+ <dt><b>Python version</b></dt><dd>{python_version}</dd>
+ <dt><b>Optional libraries</b></dt><dd>{optional_lib}</dd>
+ </dl>
+ <p>
+ Copyright (C) <a href="{esrf_url}">European Synchrotron Radiation Facility</a>
+ </p>
+ """
+
+ optionals = []
+ if h5py.version.hdf5_version_tuple >= (1, 10, 2):
+ # Previous versions only return True if the filter was first used
+ # to decode a dataset
+ import h5py.h5z
+ FILTER_LZ4 = 32004
+ FILTER_BITSHUFFLE = 32008
+ filters = [
+ ("HDF5 LZ4 filter", FILTER_LZ4),
+ ("HDF5 Bitshuffle filter", FILTER_BITSHUFFLE),
+ ]
+ for name, filterId in filters:
+ isAvailable = h5py.h5z.filter_avail(filterId)
+ optionals.append(self.__formatOptionalFilters(name, isAvailable))
+ else:
+ optionals.append(self.__formatOptionalLibraries("hdf5plugin", "hdf5plugin" in sys.modules))
+
+ # Access to the logo in SVG or PNG
+ logo = icons.getQFile("silx:" + os.path.join("gui", "logo", "silx"))
+
+ info = dict(
+ application_name=self.__applicationName,
+ esrf_url="http://www.esrf.eu",
+ project_url="https://github.com/silx-kit/silx",
+ silx_version=silx._version.version,
+ h5py_version=h5py.version.version,
+ hdf5_version=h5py.version.hdf5_version,
+ qt_binding=qt.BINDING,
+ qt_version=qt.qVersion(),
+ python_version=sys.version.replace("\n", "<br />"),
+ optional_lib="<br />".join(optionals),
+ silx_image_path=logo.fileName()
+ )
+
+ self.__label.setText(message.format(**info))
+ self.__updateSize()
+
+ def __updateSize(self):
+ """Force the size to a QMessageBox like size."""
+ if qt.BINDING in ("PySide2", "PyQt5"):
+ screenSize = qt.QApplication.desktop().availableGeometry(qt.QCursor.pos()).size()
+ else: # Qt6
+ screenSize = qt.QApplication.instance().primaryScreen().availableGeometry().size()
+ hardLimit = min(screenSize.width() - 480, 1000)
+ if screenSize.width() <= 1024:
+ hardLimit = screenSize.width()
+ softLimit = min(screenSize.width() / 2, 420)
+
+ layoutMinimumSize = self.layout().totalMinimumSize()
+ width = layoutMinimumSize.width()
+ if width > softLimit:
+ width = softLimit
+ if width > hardLimit:
+ width = hardLimit
+
+ height = layoutMinimumSize.height()
+ self.setFixedSize(width, height)
+
+ @staticmethod
+ def about(parent, applicationName):
+ """Displays a silx about box with title and text text.
+
+ :param qt.QWidget parent: The parent widget
+ :param str title: The title of the dialog
+ :param str applicationName: The content of the dialog
+ """
+ dialog = About(parent)
+ dialog.setApplicationName(applicationName)
+ dialog.exec()
diff --git a/src/silx/app/view/ApplicationContext.py b/src/silx/app/view/ApplicationContext.py
new file mode 100644
index 0000000..324f3b8
--- /dev/null
+++ b/src/silx/app/view/ApplicationContext.py
@@ -0,0 +1,195 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Browse a data file with a GUI"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/05/2018"
+
+import weakref
+import logging
+
+import silx
+from silx.gui.data.DataViews import DataViewHooks
+from silx.gui.colors import Colormap
+from silx.gui.dialog.ColormapDialog import ColormapDialog
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ApplicationContext(DataViewHooks):
+ """
+ Store the conmtext of the application
+
+ It overwrites the DataViewHooks to custom the use of the DataViewer for
+ the silx view application.
+
+ - Create a single colormap shared with all the views
+ - Create a single colormap dialog shared with all the views
+ """
+
+ def __init__(self, parent, settings=None):
+ self.__parent = weakref.ref(parent)
+ self.__defaultColormap = None
+ self.__defaultColormapDialog = None
+ self.__settings = settings
+ self.__recentFiles = []
+
+ def getSettings(self):
+ """Returns actual application settings.
+
+ :rtype: qt.QSettings
+ """
+ return self.__settings
+
+ def restoreLibrarySettings(self):
+ """Restore the library settings, which must be done early"""
+ settings = self.__settings
+ if settings is None:
+ return
+ settings.beginGroup("library")
+ plotBackend = settings.value("plot.backend", "")
+ plotImageYAxisOrientation = settings.value("plot-image.y-axis-orientation", "")
+ settings.endGroup()
+
+ # Use matplotlib backend by default
+ silx.config.DEFAULT_PLOT_BACKEND = \
+ "opengl" if plotBackend == "opengl" else "matplotlib"
+ if plotImageYAxisOrientation != "":
+ silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = plotImageYAxisOrientation
+
+ def restoreSettings(self):
+ """Restore the settings of all the application"""
+ settings = self.__settings
+ if settings is None:
+ return
+ parent = self.__parent()
+ parent.restoreSettings(settings)
+
+ settings.beginGroup("colormap")
+ byteArray = settings.value("default", None)
+ if byteArray is not None:
+ try:
+ colormap = Colormap()
+ colormap.restoreState(byteArray)
+ self.__defaultColormap = colormap
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ settings.endGroup()
+
+ self.__recentFiles = []
+ settings.beginGroup("recent-files")
+ for index in range(1, 10 + 1):
+ if not settings.contains("path%d" % index):
+ break
+ filePath = settings.value("path%d" % index)
+ self.__recentFiles.append(filePath)
+ settings.endGroup()
+
+ def saveSettings(self):
+ """Save the settings of all the application"""
+ settings = self.__settings
+ if settings is None:
+ return
+ parent = self.__parent()
+ parent.saveSettings(settings)
+
+ if self.__defaultColormap is not None:
+ settings.beginGroup("colormap")
+ settings.setValue("default", self.__defaultColormap.saveState())
+ settings.endGroup()
+
+ settings.beginGroup("library")
+ settings.setValue("plot.backend", silx.config.DEFAULT_PLOT_BACKEND)
+ settings.setValue("plot-image.y-axis-orientation", silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION)
+ settings.endGroup()
+
+ settings.beginGroup("recent-files")
+ for index in range(0, 11):
+ key = "path%d" % (index + 1)
+ if index < len(self.__recentFiles):
+ filePath = self.__recentFiles[index]
+ settings.setValue(key, filePath)
+ else:
+ settings.remove(key)
+ settings.endGroup()
+
+ def getRecentFiles(self):
+ """Returns the list of recently opened files.
+
+ The list is limited to the last 10 entries. The newest file path is
+ in first.
+
+ :rtype: List[str]
+ """
+ return self.__recentFiles
+
+ def pushRecentFile(self, filePath):
+ """Push a new recent file to the list.
+
+ If the file is duplicated in the list, all duplications are removed
+ before inserting the new filePath.
+
+ If the list becan bigger than 10 items, oldest paths are removed.
+
+ :param filePath: File path to push
+ """
+ # Remove old occurencies
+ self.__recentFiles[:] = (f for f in self.__recentFiles if f != filePath)
+ self.__recentFiles.insert(0, filePath)
+ while len(self.__recentFiles) > 10:
+ self.__recentFiles.pop()
+
+ def clearRencentFiles(self):
+ """Clear the history of the rencent files.
+ """
+ self.__recentFiles[:] = []
+
+ def getColormap(self, view):
+ """Returns a default colormap.
+
+ Override from DataViewHooks
+
+ :rtype: Colormap
+ """
+ if self.__defaultColormap is None:
+ self.__defaultColormap = Colormap(name="viridis")
+ return self.__defaultColormap
+
+ def getColormapDialog(self, view):
+ """Returns a shared color dialog as default for all the views.
+
+ Override from DataViewHooks
+
+ :rtype: ColorDialog
+ """
+ if self.__defaultColormapDialog is None:
+ parent = self.__parent()
+ if parent is None:
+ return None
+ dialog = ColormapDialog(parent=parent)
+ dialog.setModal(False)
+ self.__defaultColormapDialog = dialog
+ return self.__defaultColormapDialog
diff --git a/src/silx/app/view/CustomNxdataWidget.py b/src/silx/app/view/CustomNxdataWidget.py
new file mode 100644
index 0000000..8c6cd39
--- /dev/null
+++ b/src/silx/app/view/CustomNxdataWidget.py
@@ -0,0 +1,1002 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+
+"""Widget to custom NXdata groups"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "15/06/2018"
+
+import logging
+import numpy
+import weakref
+
+from silx.gui import qt
+from silx.io import commonh5
+import silx.io.nxdata
+from silx.gui.hdf5._utils import Hdf5DatasetMimeData
+from silx.gui.data.TextFormatter import TextFormatter
+from silx.gui.hdf5.Hdf5Formatter import Hdf5Formatter
+from silx.gui import icons
+
+
+_logger = logging.getLogger(__name__)
+_formatter = TextFormatter()
+_hdf5Formatter = Hdf5Formatter(textFormatter=_formatter)
+
+
+class _RowItems(qt.QStandardItem):
+ """Define the list of items used for a specific row."""
+
+ def type(self):
+ return qt.QStandardItem.UserType + 1
+
+ def getRowItems(self):
+ """Returns the list of items used for a specific row.
+
+ The first item should be this class.
+
+ :rtype: List[qt.QStandardItem]
+ """
+ raise NotImplementedError()
+
+
+class _DatasetItemRow(_RowItems):
+ """Define a row which can contain a dataset."""
+
+ def __init__(self, label="", dataset=None):
+ """Constructor"""
+ super(_DatasetItemRow, self).__init__(label)
+ self.setEditable(False)
+ self.setDropEnabled(False)
+ self.setDragEnabled(False)
+
+ self.__name = qt.QStandardItem()
+ self.__name.setEditable(False)
+ self.__name.setDropEnabled(True)
+
+ self.__type = qt.QStandardItem()
+ self.__type.setEditable(False)
+ self.__type.setDropEnabled(False)
+ self.__type.setDragEnabled(False)
+
+ self.__shape = qt.QStandardItem()
+ self.__shape.setEditable(False)
+ self.__shape.setDropEnabled(False)
+ self.__shape.setDragEnabled(False)
+
+ self.setDataset(dataset)
+
+ def getDefaultFormatter(self):
+ """Get the formatter used to display dataset informations.
+
+ :rtype: Hdf5Formatter
+ """
+ return _hdf5Formatter
+
+ def setDataset(self, dataset):
+ """Set the dataset stored in this item.
+
+ :param Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset] dataset:
+ The dataset to store.
+ """
+ self.__dataset = dataset
+ if self.__dataset is not None:
+ name = self.__dataset.name
+
+ if silx.io.is_dataset(dataset):
+ type_ = self.getDefaultFormatter().humanReadableType(dataset)
+ shape = self.getDefaultFormatter().humanReadableShape(dataset)
+
+ if dataset.shape is None:
+ icon_name = "item-none"
+ elif len(dataset.shape) < 4:
+ icon_name = "item-%ddim" % len(dataset.shape)
+ else:
+ icon_name = "item-ndim"
+ icon = icons.getQIcon(icon_name)
+ else:
+ type_ = ""
+ shape = ""
+ icon = qt.QIcon()
+ else:
+ name = ""
+ type_ = ""
+ shape = ""
+ icon = qt.QIcon()
+
+ self.__icon = icon
+ self.__name.setText(name)
+ self.__name.setDragEnabled(self.__dataset is not None)
+ self.__name.setIcon(self.__icon)
+ self.__type.setText(type_)
+ self.__shape.setText(shape)
+
+ parent = self.parent()
+ if parent is not None:
+ self.parent()._datasetUpdated()
+
+ def getDataset(self):
+ """Returns the dataset stored within the item."""
+ return self.__dataset
+
+ def getRowItems(self):
+ """Returns the list of items used for a specific row.
+
+ The first item should be this class.
+
+ :rtype: List[qt.QStandardItem]
+ """
+ return [self, self.__name, self.__type, self.__shape]
+
+
+class _DatasetAxisItemRow(_DatasetItemRow):
+ """Define a row describing an axis."""
+
+ def __init__(self):
+ """Constructor"""
+ super(_DatasetAxisItemRow, self).__init__()
+
+ def setAxisId(self, axisId):
+ """Set the id of the axis (the first axis is 0)
+
+ :param int axisId: Identifier of this axis.
+ """
+ self.__axisId = axisId
+ label = "Axis %d" % (axisId + 1)
+ self.setText(label)
+
+ def getAxisId(self):
+ """Returns the identifier of this axis.
+
+ :rtype: int
+ """
+ return self.__axisId
+
+
+class _NxDataItem(qt.QStandardItem):
+ """
+ Define a custom NXdata.
+ """
+
+ def __init__(self):
+ """Constructor"""
+ qt.QStandardItem.__init__(self)
+ self.__error = None
+ self.__title = None
+ self.__axes = []
+ self.__virtual = None
+
+ item = _DatasetItemRow("Signal", None)
+ self.appendRow(item.getRowItems())
+ self.__signal = item
+
+ self.setEditable(False)
+ self.setDragEnabled(False)
+ self.setDropEnabled(False)
+ self.__setError(None)
+
+ def getRowItems(self):
+ """Returns the list of items used for a specific row.
+
+ The first item should be this class.
+
+ :rtype: List[qt.QStandardItem]
+ """
+ row = [self]
+ for _ in range(3):
+ item = qt.QStandardItem("")
+ item.setEditable(False)
+ item.setDragEnabled(False)
+ item.setDropEnabled(False)
+ row.append(item)
+ return row
+
+ def _datasetUpdated(self):
+ """Called when the NXdata contained of the item have changed.
+
+ It invalidates the NXdata stored and send an event `sigNxdataUpdated`.
+ """
+ self.__virtual = None
+ self.__setError(None)
+ model = self.model()
+ if model is not None:
+ model.sigNxdataUpdated.emit(self.index())
+
+ def createVirtualGroup(self):
+ """Returns a new virtual Group using a NeXus NXdata structure to store
+ data
+
+ :rtype: silx.io.commonh5.Group
+ """
+ name = ""
+ if self.__title is not None:
+ name = self.__title
+ virtual = commonh5.Group(name)
+ virtual.attrs["NX_class"] = "NXdata"
+
+ if self.__title is not None:
+ virtual.attrs["title"] = self.__title
+
+ if self.__signal is not None:
+ signal = self.__signal.getDataset()
+ if signal is not None:
+ # Could be done using a link instead of a copy
+ node = commonh5.DatasetProxy("signal", target=signal)
+ virtual.attrs["signal"] = "signal"
+ virtual.add_node(node)
+
+ axesAttr = []
+ for i, axis in enumerate(self.__axes):
+ if axis is None:
+ name = "."
+ else:
+ axis = axis.getDataset()
+ if axis is None:
+ name = "."
+ else:
+ name = "axis%d" % i
+ node = commonh5.DatasetProxy(name, target=axis)
+ virtual.add_node(node)
+ axesAttr.append(name)
+
+ if axesAttr != []:
+ virtual.attrs["axes"] = numpy.array(axesAttr)
+
+ validator = silx.io.nxdata.NXdata(virtual)
+ if not validator.is_valid:
+ message = "<html>"
+ message += "This NXdata is not consistant"
+ message += "<ul>"
+ for issue in validator.issues:
+ message += "<li>%s</li>" % issue
+ message += "</ul>"
+ message += "</html>"
+ self.__setError(message)
+ else:
+ self.__setError(None)
+ return virtual
+
+ def isValid(self):
+ """Returns true if the stored NXdata is valid
+
+ :rtype: bool
+ """
+ return self.__error is None
+
+ def getVirtualGroup(self):
+ """Returns a cached virtual Group using a NeXus NXdata structure to
+ store data.
+
+ If the stored NXdata was invalidated, :meth:`createVirtualGroup` is
+ internally called to update the cache.
+
+ :rtype: silx.io.commonh5.Group
+ """
+ if self.__virtual is None:
+ self.__virtual = self.createVirtualGroup()
+ return self.__virtual
+
+ def getTitle(self):
+ """Returns the title of the NXdata
+
+ :rtype: str
+ """
+ return self.text()
+
+ def setTitle(self, title):
+ """Set the title of the NXdata
+
+ :param str title: The title of this NXdata
+ """
+ self.setText(title)
+
+ def __setError(self, error):
+ """Set the error message in case of the current state of the stored
+ NXdata is not valid.
+
+ :param str error: Message to display
+ """
+ self.__error = error
+ style = qt.QApplication.style()
+ if error is None:
+ message = ""
+ icon = style.standardIcon(qt.QStyle.SP_DirLinkIcon)
+ else:
+ message = error
+ icon = style.standardIcon(qt.QStyle.SP_MessageBoxCritical)
+ self.setIcon(icon)
+ self.setToolTip(message)
+
+ def getError(self):
+ """Returns the error message in case the NXdata is not valid.
+
+ :rtype: str"""
+ return self.__error
+
+ def setSignalDataset(self, dataset):
+ """Set the dataset to use as signal with this NXdata.
+
+ :param Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset] dataset:
+ The dataset to use as signal.
+ """
+
+ self.__signal.setDataset(dataset)
+ self._datasetUpdated()
+
+ def getSignalDataset(self):
+ """Returns the dataset used as signal.
+
+ :rtype: Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset]
+ """
+ return self.__signal.getDataset()
+
+ def setAxesDatasets(self, datasets):
+ """Set all the available dataset used as axes.
+
+ Axes will be created or removed from the GUI in order to provide the
+ same amount of requested axes.
+
+ A `None` element is an axes with no dataset.
+
+ :param List[Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset,None]] datasets:
+ List of dataset to use as axes.
+ """
+ for i, dataset in enumerate(datasets):
+ if i < len(self.__axes):
+ mustAppend = False
+ item = self.__axes[i]
+ else:
+ mustAppend = True
+ item = _DatasetAxisItemRow()
+ item.setAxisId(i)
+ item.setDataset(dataset)
+ if mustAppend:
+ self.__axes.append(item)
+ self.appendRow(item.getRowItems())
+
+ # Clean up extra axis
+ for i in range(len(datasets), len(self.__axes)):
+ item = self.__axes.pop(len(datasets))
+ self.removeRow(item.row())
+
+ self._datasetUpdated()
+
+ def getAxesDatasets(self):
+ """Returns available axes as dataset.
+
+ A `None` element is an axes with no dataset.
+
+ :rtype: List[Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset,None]]
+ """
+ datasets = []
+ for axis in self.__axes:
+ datasets.append(axis.getDataset())
+ return datasets
+
+
+class _Model(qt.QStandardItemModel):
+ """Model storing a list of custom NXdata items.
+
+ Supports drag and drop of datasets.
+ """
+
+ sigNxdataUpdated = qt.Signal(qt.QModelIndex)
+ """Emitted when stored NXdata was edited"""
+
+ def __init__(self, parent=None):
+ """Constructor"""
+ qt.QStandardItemModel.__init__(self, parent)
+ root = self.invisibleRootItem()
+ root.setDropEnabled(True)
+ root.setDragEnabled(False)
+
+ def supportedDropActions(self):
+ """Inherited method to redefine supported drop actions."""
+ return qt.Qt.CopyAction | qt.Qt.MoveAction
+
+ def mimeTypes(self):
+ """Inherited method to redefine draggable mime types."""
+ return [Hdf5DatasetMimeData.MIME_TYPE]
+
+ def mimeData(self, indexes):
+ """
+ Returns an object that contains serialized items of data corresponding
+ to the list of indexes specified.
+
+ :param List[qt.QModelIndex] indexes: List of indexes
+ :rtype: qt.QMimeData
+ """
+ if len(indexes) > 1:
+ return None
+ if len(indexes) == 0:
+ return None
+
+ qindex = indexes[0]
+ qindex = self.index(qindex.row(), 0, parent=qindex.parent())
+ item = self.itemFromIndex(qindex)
+ if isinstance(item, _DatasetItemRow):
+ dataset = item.getDataset()
+ if dataset is None:
+ return None
+ else:
+ mimeData = Hdf5DatasetMimeData(dataset=item.getDataset())
+ else:
+ mimeData = None
+ return mimeData
+
+ def dropMimeData(self, mimedata, action, row, column, parentIndex):
+ """Inherited method to handle a drop operation to this model."""
+ if action == qt.Qt.IgnoreAction:
+ return True
+
+ if mimedata.hasFormat(Hdf5DatasetMimeData.MIME_TYPE):
+ if row != -1 or column != -1:
+ # It is not a drop on a specific item
+ return False
+ item = self.itemFromIndex(parentIndex)
+ if item is None or item is self.invisibleRootItem():
+ # Drop at the end
+ dataset = mimedata.dataset()
+ if silx.io.is_dataset(dataset):
+ self.createFromSignal(dataset)
+ elif silx.io.is_group(dataset):
+ nxdata = dataset
+ try:
+ self.createFromNxdata(nxdata)
+ except ValueError:
+ _logger.error("Error while dropping a group as an NXdata")
+ _logger.debug("Backtrace", exc_info=True)
+ return False
+ else:
+ _logger.error("Dropping a wrong object")
+ return False
+ else:
+ item = item.parent().child(item.row(), 0)
+ if not isinstance(item, _DatasetItemRow):
+ # Dropped at a bad place
+ return False
+ dataset = mimedata.dataset()
+ if silx.io.is_dataset(dataset):
+ item.setDataset(dataset)
+ else:
+ _logger.error("Dropping a wrong object")
+ return False
+ return True
+
+ return False
+
+ def __getNxdataByTitle(self, title):
+ """Returns an NXdata item by its title, else None.
+
+ :rtype: Union[_NxDataItem,None]
+ """
+ for row in range(self.rowCount()):
+ qindex = self.index(row, 0)
+ item = self.itemFromIndex(qindex)
+ if item.getTitle() == title:
+ return item
+ return None
+
+ def findFreeNxdataTitle(self):
+ """Returns an NXdata title which is not yet used.
+
+ :rtype: str
+ """
+ for i in range(self.rowCount() + 1):
+ name = "NXData #%d" % (i + 1)
+ group = self.__getNxdataByTitle(name)
+ if group is None:
+ break
+ return name
+
+ def createNewNxdata(self, name=None):
+ """Create a new NXdata item.
+
+ :param Union[str,None] name: A title for the new NXdata
+ """
+ item = _NxDataItem()
+ if name is None:
+ name = self.findFreeNxdataTitle()
+ item.setTitle(name)
+ self.appendRow(item.getRowItems())
+
+ def createFromSignal(self, dataset):
+ """Create a new NXdata item from a signal dataset.
+
+ This signal will also define an amount of axes according to its number
+ of dimensions.
+
+ :param Union[numpy.ndarray,h5py.Dataset,silx.io.commonh5.Dataset] dataset:
+ A dataset uses as signal.
+ """
+
+ item = _NxDataItem()
+ name = self.findFreeNxdataTitle()
+ item.setTitle(name)
+ item.setSignalDataset(dataset)
+ item.setAxesDatasets([None] * len(dataset.shape))
+ self.appendRow(item.getRowItems())
+
+ def createFromNxdata(self, nxdata):
+ """Create a new custom NXdata item from an existing NXdata group.
+
+ If the NXdata is not valid, nothing is created, and an exception is
+ returned.
+
+ :param Union[h5py.Group,silx.io.commonh5.Group] nxdata: An h5py group
+ following the NXData specification.
+ :raise ValueError:If `nxdata` is not valid.
+ """
+ validator = silx.io.nxdata.NXdata(nxdata)
+ if validator.is_valid:
+ item = _NxDataItem()
+ title = validator.title
+ if title in [None or ""]:
+ title = self.findFreeNxdataTitle()
+ item.setTitle(title)
+ item.setSignalDataset(validator.signal)
+ item.setAxesDatasets(validator.axes)
+ self.appendRow(item.getRowItems())
+ else:
+ raise ValueError("Not a valid NXdata")
+
+ def removeNxdataItem(self, item):
+ """Remove an NXdata item from this model.
+
+ :param _NxDataItem item: An item
+ """
+ if isinstance(item, _NxDataItem):
+ parent = item.parent()
+ assert(parent is None)
+ model = item.model()
+ model.removeRow(item.row())
+ else:
+ _logger.error("Unexpected item")
+
+ def appendAxisToNxdataItem(self, item):
+ """Append a new axes to this item (or the NXdata item own by this item).
+
+ :param Union[_NxDataItem,qt.QStandardItem] item: An item
+ """
+ if item is not None and not isinstance(item, _NxDataItem):
+ item = item.parent()
+ nxdataItem = item
+ if isinstance(item, _NxDataItem):
+ datasets = nxdataItem.getAxesDatasets()
+ datasets.append(None)
+ nxdataItem.setAxesDatasets(datasets)
+ else:
+ _logger.error("Unexpected item")
+
+ def removeAxisItem(self, item):
+ """Remove an axis item from this model.
+
+ :param _DatasetAxisItemRow item: An axis item
+ """
+ if isinstance(item, _DatasetAxisItemRow):
+ axisId = item.getAxisId()
+ nxdataItem = item.parent()
+ datasets = nxdataItem.getAxesDatasets()
+ del datasets[axisId]
+ nxdataItem.setAxesDatasets(datasets)
+ else:
+ _logger.error("Unexpected item")
+
+
+class CustomNxDataToolBar(qt.QToolBar):
+ """A specialised toolbar to manage custom NXdata model and items."""
+
+ def __init__(self, parent=None):
+ """Constructor"""
+ super(CustomNxDataToolBar, self).__init__(parent=parent)
+ self.__nxdataWidget = None
+ self.__initContent()
+ # Initialize action state
+ self.__currentSelectionChanged(qt.QModelIndex(), qt.QModelIndex())
+
+ def __initContent(self):
+ """Create all expected actions and set the content of this toolbar."""
+ action = qt.QAction("Create a new custom NXdata", self)
+ action.setIcon(icons.getQIcon("nxdata-create"))
+ action.triggered.connect(self.__createNewNxdata)
+ self.addAction(action)
+ self.__addNxDataAction = action
+
+ action = qt.QAction("Remove the selected NXdata", self)
+ action.setIcon(icons.getQIcon("nxdata-remove"))
+ action.triggered.connect(self.__removeSelectedNxdata)
+ self.addAction(action)
+ self.__removeNxDataAction = action
+
+ self.addSeparator()
+
+ action = qt.QAction("Create a new axis to the selected NXdata", self)
+ action.setIcon(icons.getQIcon("nxdata-axis-add"))
+ action.triggered.connect(self.__appendNewAxisToSelectedNxdata)
+ self.addAction(action)
+ self.__addNxDataAxisAction = action
+
+ action = qt.QAction("Remove the selected NXdata axis", self)
+ action.setIcon(icons.getQIcon("nxdata-axis-remove"))
+ action.triggered.connect(self.__removeSelectedAxis)
+ self.addAction(action)
+ self.__removeNxDataAxisAction = action
+
+ def __getSelectedItem(self):
+ """Get the selected item from the linked CustomNxdataWidget.
+
+ :rtype: qt.QStandardItem
+ """
+ selectionModel = self.__nxdataWidget.selectionModel()
+ index = selectionModel.currentIndex()
+ if not index.isValid():
+ return
+ model = self.__nxdataWidget.model()
+ index = model.index(index.row(), 0, index.parent())
+ item = model.itemFromIndex(index)
+ return item
+
+ def __createNewNxdata(self):
+ """Create a new NXdata item to the linked CustomNxdataWidget."""
+ if self.__nxdataWidget is None:
+ return
+ model = self.__nxdataWidget.model()
+ model.createNewNxdata()
+
+ def __removeSelectedNxdata(self):
+ """Remove the NXdata item currently selected in the linked
+ CustomNxdataWidget."""
+ if self.__nxdataWidget is None:
+ return
+ model = self.__nxdataWidget.model()
+ item = self.__getSelectedItem()
+ model.removeNxdataItem(item)
+
+ def __appendNewAxisToSelectedNxdata(self):
+ """Append a new axis to the NXdata item currently selected in the
+ linked CustomNxdataWidget."""
+ if self.__nxdataWidget is None:
+ return
+ model = self.__nxdataWidget.model()
+ item = self.__getSelectedItem()
+ model.appendAxisToNxdataItem(item)
+
+ def __removeSelectedAxis(self):
+ """Remove the axis item currently selected in the linked
+ CustomNxdataWidget."""
+ if self.__nxdataWidget is None:
+ return
+ model = self.__nxdataWidget.model()
+ item = self.__getSelectedItem()
+ model.removeAxisItem(item)
+
+ def setCustomNxDataWidget(self, widget):
+ """Set the linked CustomNxdataWidget to this toolbar."""
+ assert(isinstance(widget, CustomNxdataWidget))
+ if self.__nxdataWidget is not None:
+ selectionModel = self.__nxdataWidget.selectionModel()
+ selectionModel.currentChanged.disconnect(self.__currentSelectionChanged)
+ self.__nxdataWidget = widget
+ if self.__nxdataWidget is not None:
+ selectionModel = self.__nxdataWidget.selectionModel()
+ selectionModel.currentChanged.connect(self.__currentSelectionChanged)
+
+ def __currentSelectionChanged(self, current, previous):
+ """Update the actions according to the linked CustomNxdataWidget
+ item selection"""
+ if not current.isValid():
+ item = None
+ else:
+ model = self.__nxdataWidget.model()
+ index = model.index(current.row(), 0, current.parent())
+ item = model.itemFromIndex(index)
+ self.__removeNxDataAction.setEnabled(isinstance(item, _NxDataItem))
+ self.__removeNxDataAxisAction.setEnabled(isinstance(item, _DatasetAxisItemRow))
+ self.__addNxDataAxisAction.setEnabled(isinstance(item, _NxDataItem) or isinstance(item, _DatasetItemRow))
+
+
+class _HashDropZones(qt.QStyledItemDelegate):
+ """Delegate item displaying a drop zone when the item do not contains
+ dataset."""
+
+ def __init__(self, parent=None):
+ """Constructor"""
+ super(_HashDropZones, self).__init__(parent)
+ pen = qt.QPen()
+ pen.setColor(qt.QColor("#D0D0D0"))
+ pen.setStyle(qt.Qt.DotLine)
+ pen.setWidth(2)
+ self.__dropPen = pen
+
+ def paint(self, painter, option, index):
+ """
+ Paint the item
+
+ :param qt.QPainter painter: A painter
+ :param qt.QStyleOptionViewItem option: Options of the item to paint
+ :param qt.QModelIndex index: Index of the item to paint
+ """
+ displayDropZone = False
+ if index.isValid():
+ model = index.model()
+ rowIndex = model.index(index.row(), 0, index.parent())
+ rowItem = model.itemFromIndex(rowIndex)
+ if isinstance(rowItem, _DatasetItemRow):
+ displayDropZone = rowItem.getDataset() is None
+
+ if displayDropZone:
+ painter.save()
+
+ # Draw background if selected
+ if option.state & qt.QStyle.State_Selected:
+ colorGroup = qt.QPalette.Inactive
+ if option.state & qt.QStyle.State_Active:
+ colorGroup = qt.QPalette.Active
+ if not option.state & qt.QStyle.State_Enabled:
+ colorGroup = qt.QPalette.Disabled
+ brush = option.palette.brush(colorGroup, qt.QPalette.Highlight)
+ painter.fillRect(option.rect, brush)
+
+ painter.setPen(self.__dropPen)
+ painter.drawRect(option.rect.adjusted(3, 3, -3, -3))
+ painter.restore()
+ else:
+ qt.QStyledItemDelegate.paint(self, painter, option, index)
+
+
+class CustomNxdataWidget(qt.QTreeView):
+ """Widget providing a table displaying and allowing to custom virtual
+ NXdata."""
+
+ sigNxdataItemUpdated = qt.Signal(qt.QStandardItem)
+ """Emitted when the NXdata from an NXdata item was edited"""
+
+ sigNxdataItemRemoved = qt.Signal(qt.QStandardItem)
+ """Emitted when an NXdata item was removed"""
+
+ def __init__(self, parent=None):
+ """Constructor"""
+ qt.QTreeView.__init__(self, parent=None)
+ self.__model = _Model(self)
+ self.__model.setColumnCount(4)
+ self.__model.setHorizontalHeaderLabels(["Name", "Dataset", "Type", "Shape"])
+ self.setModel(self.__model)
+
+ self.setItemDelegateForColumn(1, _HashDropZones(self))
+
+ self.__model.sigNxdataUpdated.connect(self.__nxdataUpdate)
+ self.__model.rowsAboutToBeRemoved.connect(self.__rowsAboutToBeRemoved)
+ self.__model.rowsAboutToBeInserted.connect(self.__rowsAboutToBeInserted)
+
+ header = self.header()
+ header.setSectionResizeMode(0, qt.QHeaderView.ResizeToContents)
+ header.setSectionResizeMode(1, qt.QHeaderView.Stretch)
+ header.setSectionResizeMode(2, qt.QHeaderView.ResizeToContents)
+ header.setSectionResizeMode(3, qt.QHeaderView.ResizeToContents)
+
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ self.setDropIndicatorShown(True)
+ self.setDragDropOverwriteMode(True)
+ self.setDragEnabled(True)
+ self.viewport().setAcceptDrops(True)
+
+ self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ self.customContextMenuRequested[qt.QPoint].connect(self.__executeContextMenu)
+
+ def __rowsAboutToBeInserted(self, parentIndex, start, end):
+ # FIXME: workaround for https://github.com/silx-kit/silx/issues/1919
+ # Uses of ResizeToContents looks to break nice update of cells with Qt5
+ # This patch make the view blinking
+ self.repaint()
+
+ def __rowsAboutToBeRemoved(self, parentIndex, start, end):
+ """Called when an item was removed from the model."""
+ items = []
+ model = self.model()
+ for index in range(start, end):
+ qindex = model.index(index, 0, parent=parentIndex)
+ item = self.__model.itemFromIndex(qindex)
+ if isinstance(item, _NxDataItem):
+ items.append(item)
+ for item in items:
+ self.sigNxdataItemRemoved.emit(item)
+
+ # FIXME: workaround for https://github.com/silx-kit/silx/issues/1919
+ # Uses of ResizeToContents looks to break nice update of cells with Qt5
+ # This patch make the view blinking
+ self.repaint()
+
+ def __nxdataUpdate(self, index):
+ """Called when a virtual NXdata was updated from the model."""
+ model = self.model()
+ item = model.itemFromIndex(index)
+ self.sigNxdataItemUpdated.emit(item)
+
+ def createDefaultContextMenu(self, index):
+ """Create a default context menu at this position.
+
+ :param qt.QModelIndex index: Index of the item
+ """
+ index = self.__model.index(index.row(), 0, parent=index.parent())
+ item = self.__model.itemFromIndex(index)
+
+ menu = qt.QMenu()
+
+ weakself = weakref.proxy(self)
+
+ if isinstance(item, _NxDataItem):
+ action = qt.QAction("Add a new axis", menu)
+ action.triggered.connect(lambda: weakself.model().appendAxisToNxdataItem(item))
+ action.setIcon(icons.getQIcon("nxdata-axis-add"))
+ action.setIconVisibleInMenu(True)
+ menu.addAction(action)
+ menu.addSeparator()
+ action = qt.QAction("Remove this NXdata", menu)
+ action.triggered.connect(lambda: weakself.model().removeNxdataItem(item))
+ action.setIcon(icons.getQIcon("remove"))
+ action.setIconVisibleInMenu(True)
+ menu.addAction(action)
+ else:
+ if isinstance(item, _DatasetItemRow):
+ if item.getDataset() is not None:
+ action = qt.QAction("Remove this dataset", menu)
+ action.triggered.connect(lambda: item.setDataset(None))
+ menu.addAction(action)
+
+ if isinstance(item, _DatasetAxisItemRow):
+ menu.addSeparator()
+ action = qt.QAction("Remove this axis", menu)
+ action.triggered.connect(lambda: weakself.model().removeAxisItem(item))
+ action.setIcon(icons.getQIcon("remove"))
+ action.setIconVisibleInMenu(True)
+ menu.addAction(action)
+
+ return menu
+
+ def __executeContextMenu(self, point):
+ """Execute the context menu at this position."""
+ index = self.indexAt(point)
+ menu = self.createDefaultContextMenu(index)
+ if menu is None or menu.isEmpty():
+ return
+ menu.exec(qt.QCursor.pos())
+
+ def removeDatasetsFrom(self, root):
+ """
+ Remove all datasets provided by this root
+
+ :param root: The root file of datasets to remove
+ """
+ for row in range(self.__model.rowCount()):
+ qindex = self.__model.index(row, 0)
+ item = self.model().itemFromIndex(qindex)
+
+ edited = False
+ datasets = item.getAxesDatasets()
+ for i, dataset in enumerate(datasets):
+ if dataset is not None:
+ # That's an approximation, IS can't be used as h5py generates
+ # To objects for each requests to a node
+ if dataset.file.filename == root.file.filename:
+ datasets[i] = None
+ edited = True
+ if edited:
+ item.setAxesDatasets(datasets)
+
+ dataset = item.getSignalDataset()
+ if dataset is not None:
+ # That's an approximation, IS can't be used as h5py generates
+ # To objects for each requests to a node
+ if dataset.file.filename == root.file.filename:
+ item.setSignalDataset(None)
+
+ def replaceDatasetsFrom(self, removedRoot, loadedRoot):
+ """
+ Replace any dataset from any NXdata items using the same dataset name
+ from another root.
+
+ Usually used when a file was synchronized.
+
+ :param removedRoot: The h5py root file which is replaced
+ (which have to be removed)
+ :param loadedRoot: The new h5py root file which have to be used
+ instread.
+ """
+ for row in range(self.__model.rowCount()):
+ qindex = self.__model.index(row, 0)
+ item = self.model().itemFromIndex(qindex)
+
+ edited = False
+ datasets = item.getAxesDatasets()
+ for i, dataset in enumerate(datasets):
+ newDataset = self.__replaceDatasetRoot(dataset, removedRoot, loadedRoot)
+ if dataset is not newDataset:
+ datasets[i] = newDataset
+ edited = True
+ if edited:
+ item.setAxesDatasets(datasets)
+
+ dataset = item.getSignalDataset()
+ newDataset = self.__replaceDatasetRoot(dataset, removedRoot, loadedRoot)
+ if dataset is not newDataset:
+ item.setSignalDataset(newDataset)
+
+ def __replaceDatasetRoot(self, dataset, fromRoot, toRoot):
+ """
+ Replace the dataset by the same dataset name from another root.
+ """
+ if dataset is None:
+ return None
+
+ if dataset.file is None:
+ # Not from the expected root
+ return dataset
+
+ # That's an approximation, IS can't be used as h5py generates
+ # To objects for each requests to a node
+ if dataset.file.filename == fromRoot.file.filename:
+ # Try to find the same dataset name
+ try:
+ return toRoot[dataset.name]
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ return None
+ else:
+ # Not from the expected root
+ return dataset
+
+ def selectedItems(self):
+ """Returns the list of selected items containing NXdata
+
+ :rtype: List[qt.QStandardItem]
+ """
+ result = []
+ for qindex in self.selectedIndexes():
+ if qindex.column() != 0:
+ continue
+ if not qindex.isValid():
+ continue
+ item = self.__model.itemFromIndex(qindex)
+ if not isinstance(item, _NxDataItem):
+ continue
+ result.append(item)
+ return result
+
+ def selectedNxdata(self):
+ """Returns the list of selected NXdata
+
+ :rtype: List[silx.io.commonh5.Group]
+ """
+ result = []
+ for qindex in self.selectedIndexes():
+ if qindex.column() != 0:
+ continue
+ if not qindex.isValid():
+ continue
+ item = self.__model.itemFromIndex(qindex)
+ if not isinstance(item, _NxDataItem):
+ continue
+ result.append(item.getVirtualGroup())
+ return result
diff --git a/silx/app/view/DataPanel.py b/src/silx/app/view/DataPanel.py
index 5d87381..5d87381 100644
--- a/silx/app/view/DataPanel.py
+++ b/src/silx/app/view/DataPanel.py
diff --git a/src/silx/app/view/Viewer.py b/src/silx/app/view/Viewer.py
new file mode 100644
index 0000000..7e5e4c9
--- /dev/null
+++ b/src/silx/app/view/Viewer.py
@@ -0,0 +1,962 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Browse a data file with a GUI"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "15/01/2019"
+
+
+import os
+import collections
+import logging
+import functools
+
+import silx.io.nxdata
+from silx.gui import qt
+from silx.gui import icons
+import silx.gui.hdf5
+from .ApplicationContext import ApplicationContext
+from .CustomNxdataWidget import CustomNxdataWidget
+from .CustomNxdataWidget import CustomNxDataToolBar
+from . import utils
+from silx.gui.utils import projecturl
+from .DataPanel import DataPanel
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Viewer(qt.QMainWindow):
+ """
+ This window allows to browse a data file like images or HDF5 and it's
+ content.
+ """
+
+ def __init__(self, parent=None, settings=None):
+ """
+ Constructor
+ """
+
+ qt.QMainWindow.__init__(self, parent)
+ self.setWindowTitle("Silx viewer")
+
+ silxIcon = icons.getQIcon("silx")
+ self.setWindowIcon(silxIcon)
+
+ self.__context = self.createApplicationContext(settings)
+ self.__context.restoreLibrarySettings()
+
+ self.__dialogState = None
+ self.__customNxDataItem = None
+ self.__treeview = silx.gui.hdf5.Hdf5TreeView(self)
+ self.__treeview.setExpandsOnDoubleClick(False)
+ """Silx HDF5 TreeView"""
+
+ rightPanel = qt.QSplitter(self)
+ rightPanel.setOrientation(qt.Qt.Vertical)
+ self.__splitter2 = rightPanel
+
+ self.__displayIt = None
+ self.__treeWindow = self.__createTreeWindow(self.__treeview)
+
+ # Custom the model to be able to manage the life cycle of the files
+ treeModel = silx.gui.hdf5.Hdf5TreeModel(self.__treeview, ownFiles=False)
+ treeModel.sigH5pyObjectLoaded.connect(self.__h5FileLoaded)
+ treeModel.sigH5pyObjectRemoved.connect(self.__h5FileRemoved)
+ treeModel.sigH5pyObjectSynchronized.connect(self.__h5FileSynchonized)
+ treeModel.setDatasetDragEnabled(True)
+ self.__treeModelSorted = silx.gui.hdf5.NexusSortFilterProxyModel(self.__treeview)
+ self.__treeModelSorted.setSourceModel(treeModel)
+ self.__treeModelSorted.sort(0, qt.Qt.AscendingOrder)
+ self.__treeModelSorted.setSortCaseSensitivity(qt.Qt.CaseInsensitive)
+
+ self.__treeview.setModel(self.__treeModelSorted)
+ rightPanel.addWidget(self.__treeWindow)
+
+ self.__customNxdata = CustomNxdataWidget(self)
+ self.__customNxdata.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ # optimise the rendering
+ self.__customNxdata.setUniformRowHeights(True)
+ self.__customNxdata.setIconSize(qt.QSize(16, 16))
+ self.__customNxdata.setExpandsOnDoubleClick(False)
+
+ self.__customNxdataWindow = self.__createCustomNxdataWindow(self.__customNxdata)
+ self.__customNxdataWindow.setVisible(False)
+ rightPanel.addWidget(self.__customNxdataWindow)
+
+ rightPanel.setStretchFactor(1, 1)
+ rightPanel.setCollapsible(0, False)
+ rightPanel.setCollapsible(1, False)
+
+ self.__dataPanel = DataPanel(self, self.__context)
+
+ spliter = qt.QSplitter(self)
+ spliter.addWidget(rightPanel)
+ spliter.addWidget(self.__dataPanel)
+ spliter.setStretchFactor(1, 1)
+ spliter.setCollapsible(0, False)
+ spliter.setCollapsible(1, False)
+ self.__splitter = spliter
+
+ main_panel = qt.QWidget(self)
+ layout = qt.QVBoxLayout()
+ layout.addWidget(spliter)
+ layout.setStretchFactor(spliter, 1)
+ main_panel.setLayout(layout)
+
+ self.setCentralWidget(main_panel)
+
+ self.__treeview.activated.connect(self.displaySelectedData)
+ self.__customNxdata.activated.connect(self.displaySelectedCustomData)
+ self.__customNxdata.sigNxdataItemRemoved.connect(self.__customNxdataRemoved)
+ self.__customNxdata.sigNxdataItemUpdated.connect(self.__customNxdataUpdated)
+ self.__treeview.addContextMenuCallback(self.customContextMenu)
+
+ treeModel = self.__treeview.findHdf5TreeModel()
+ columns = list(treeModel.COLUMN_IDS)
+ columns.remove(treeModel.VALUE_COLUMN)
+ columns.remove(treeModel.NODE_COLUMN)
+ columns.remove(treeModel.DESCRIPTION_COLUMN)
+ columns.insert(1, treeModel.DESCRIPTION_COLUMN)
+ self.__treeview.header().setSections(columns)
+
+ self._iconUpward = icons.getQIcon('plot-yup')
+ self._iconDownward = icons.getQIcon('plot-ydown')
+
+ self.createActions()
+ self.createMenus()
+ self.__context.restoreSettings()
+
+ def createApplicationContext(self, settings):
+ return ApplicationContext(self, settings)
+
+ def __createTreeWindow(self, treeView):
+ toolbar = qt.QToolBar(self)
+ toolbar.setIconSize(qt.QSize(16, 16))
+ toolbar.setStyleSheet("QToolBar { border: 0px }")
+
+ action = qt.QAction(toolbar)
+ action.setIcon(icons.getQIcon("view-refresh"))
+ action.setText("Refresh")
+ action.setToolTip("Refresh all selected items")
+ action.triggered.connect(self.__refreshSelected)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_F5))
+ toolbar.addAction(action)
+ treeView.addAction(action)
+ self.__refreshAction = action
+
+ # Another shortcut for refresh
+ action = qt.QAction(toolbar)
+ action.setShortcut(qt.QKeySequence(qt.Qt.ControlModifier + qt.Qt.Key_R))
+ treeView.addAction(action)
+ action.triggered.connect(self.__refreshSelected)
+
+ action = qt.QAction(toolbar)
+ # action.setIcon(icons.getQIcon("view-refresh"))
+ action.setText("Close")
+ action.setToolTip("Close selected item")
+ action.triggered.connect(self.__removeSelected)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_Delete))
+ treeView.addAction(action)
+ self.__closeAction = action
+
+ toolbar.addSeparator()
+
+ action = qt.QAction(toolbar)
+ action.setIcon(icons.getQIcon("tree-expand-all"))
+ action.setText("Expand all")
+ action.setToolTip("Expand all selected items")
+ action.triggered.connect(self.__expandAllSelected)
+ action.setShortcut(qt.QKeySequence(qt.Qt.ControlModifier + qt.Qt.Key_Plus))
+ toolbar.addAction(action)
+ treeView.addAction(action)
+ self.__expandAllAction = action
+
+ action = qt.QAction(toolbar)
+ action.setIcon(icons.getQIcon("tree-collapse-all"))
+ action.setText("Collapse all")
+ action.setToolTip("Collapse all selected items")
+ action.triggered.connect(self.__collapseAllSelected)
+ action.setShortcut(qt.QKeySequence(qt.Qt.ControlModifier + qt.Qt.Key_Minus))
+ toolbar.addAction(action)
+ treeView.addAction(action)
+ self.__collapseAllAction = action
+
+ action = qt.QAction("&Sort file content", toolbar)
+ action.setIcon(icons.getQIcon("tree-sort"))
+ action.setToolTip("Toggle sorting of file content")
+ action.setCheckable(True)
+ action.setChecked(True)
+ action.triggered.connect(self.setContentSorted)
+ toolbar.addAction(action)
+ treeView.addAction(action)
+ self._sortContentAction = action
+
+ widget = qt.QWidget(self)
+ layout = qt.QVBoxLayout(widget)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ layout.addWidget(toolbar)
+ layout.addWidget(treeView)
+ return widget
+
+ def __removeSelected(self):
+ """Close selected items"""
+ qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
+
+ selection = self.__treeview.selectionModel()
+ indexes = selection.selectedIndexes()
+ selectedItems = []
+ model = self.__treeview.model()
+ h5files = set([])
+ while len(indexes) > 0:
+ index = indexes.pop(0)
+ if index.column() != 0:
+ continue
+ h5 = model.data(index, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ rootIndex = index
+ # Reach the root of the tree
+ while rootIndex.parent().isValid():
+ rootIndex = rootIndex.parent()
+ rootRow = rootIndex.row()
+ relativePath = self.__getRelativePath(model, rootIndex, index)
+ selectedItems.append((rootRow, relativePath))
+ h5files.add(h5.file)
+
+ if len(h5files) != 0:
+ model = self.__treeview.findHdf5TreeModel()
+ for h5 in h5files:
+ row = model.h5pyObjectRow(h5)
+ model.removeH5pyObject(h5)
+
+ qt.QApplication.restoreOverrideCursor()
+
+ def __refreshSelected(self):
+ """Refresh all selected items
+ """
+ qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
+
+ selection = self.__treeview.selectionModel()
+ indexes = selection.selectedIndexes()
+ selectedItems = []
+ model = self.__treeview.model()
+ h5files = set([])
+ while len(indexes) > 0:
+ index = indexes.pop(0)
+ if index.column() != 0:
+ continue
+ h5 = model.data(index, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ rootIndex = index
+ # Reach the root of the tree
+ while rootIndex.parent().isValid():
+ rootIndex = rootIndex.parent()
+ rootRow = rootIndex.row()
+ relativePath = self.__getRelativePath(model, rootIndex, index)
+ selectedItems.append((rootRow, relativePath))
+ h5files.add(h5.file)
+
+ if len(h5files) == 0:
+ qt.QApplication.restoreOverrideCursor()
+ return
+
+ model = self.__treeview.findHdf5TreeModel()
+ for h5 in h5files:
+ self.__synchronizeH5pyObject(h5)
+
+ model = self.__treeview.model()
+ itemSelection = qt.QItemSelection()
+ for rootRow, relativePath in selectedItems:
+ rootIndex = model.index(rootRow, 0, qt.QModelIndex())
+ index = self.__indexFromPath(model, rootIndex, relativePath)
+ if index is None:
+ continue
+ indexEnd = model.index(index.row(), model.columnCount() - 1, index.parent())
+ itemSelection.select(index, indexEnd)
+ selection.select(itemSelection, qt.QItemSelectionModel.ClearAndSelect)
+
+ qt.QApplication.restoreOverrideCursor()
+
+ def __synchronizeH5pyObject(self, h5):
+ model = self.__treeview.findHdf5TreeModel()
+ # This is buggy right now while h5py do not allow to close a file
+ # while references are still used.
+ # FIXME: The architecture have to be reworked to support this feature.
+ # model.synchronizeH5pyObject(h5)
+
+ filename = h5.filename
+ row = model.h5pyObjectRow(h5)
+ index = self.__treeview.model().index(row, 0, qt.QModelIndex())
+ paths = self.__getPathFromExpandedNodes(self.__treeview, index)
+ model.removeH5pyObject(h5)
+ model.insertFile(filename, row)
+ index = self.__treeview.model().index(row, 0, qt.QModelIndex())
+ self.__expandNodesFromPaths(self.__treeview, index, paths)
+
+ def __getRelativePath(self, model, rootIndex, index):
+ """Returns a relative path from an index to his rootIndex.
+
+ If the path is empty the index is also the rootIndex.
+ """
+ path = ""
+ while index.isValid():
+ if index == rootIndex:
+ return path
+ name = model.data(index)
+ if path == "":
+ path = name
+ else:
+ path = name + "/" + path
+ index = index.parent()
+
+ # index is not a children of rootIndex
+ raise ValueError("index is not a children of the rootIndex")
+
+ def __getPathFromExpandedNodes(self, view, rootIndex):
+ """Return relative path from the root index of the extended nodes"""
+ model = view.model()
+ rootPath = None
+ paths = []
+ indexes = [rootIndex]
+ while len(indexes):
+ index = indexes.pop(0)
+ if not view.isExpanded(index):
+ continue
+
+ node = model.data(index, role=silx.gui.hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE)
+ path = node._getCanonicalName()
+ if rootPath is None:
+ rootPath = path
+ path = path[len(rootPath):]
+ paths.append(path)
+
+ for child in range(model.rowCount(index)):
+ childIndex = model.index(child, 0, index)
+ indexes.append(childIndex)
+ return paths
+
+ def __indexFromPath(self, model, rootIndex, path):
+ elements = path.split("/")
+ if elements[0] == "":
+ elements.pop(0)
+ index = rootIndex
+ while len(elements) != 0:
+ element = elements.pop(0)
+ found = False
+ for child in range(model.rowCount(index)):
+ childIndex = model.index(child, 0, index)
+ name = model.data(childIndex)
+ if element == name:
+ index = childIndex
+ found = True
+ break
+ if not found:
+ return None
+ return index
+
+ def __expandNodesFromPaths(self, view, rootIndex, paths):
+ model = view.model()
+ for path in paths:
+ index = self.__indexFromPath(model, rootIndex, path)
+ if index is not None:
+ view.setExpanded(index, True)
+
+ def __expandAllSelected(self):
+ """Expand all selected items of the tree.
+
+ The depth is fixed to avoid infinite loop with recurssive links.
+ """
+ qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)
+
+ selection = self.__treeview.selectionModel()
+ indexes = selection.selectedIndexes()
+ model = self.__treeview.model()
+ while len(indexes) > 0:
+ index = indexes.pop(0)
+ if isinstance(index, tuple):
+ index, depth = index
+ else:
+ depth = 0
+ if index.column() != 0:
+ continue
+
+ if depth > 10:
+ # Avoid infinite loop with recursive links
+ break
+
+ if model.hasChildren(index):
+ self.__treeview.setExpanded(index, True)
+ for row in range(model.rowCount(index)):
+ childIndex = model.index(row, 0, index)
+ indexes.append((childIndex, depth + 1))
+ qt.QApplication.restoreOverrideCursor()
+
+ def __collapseAllSelected(self):
+ """Collapse all selected items of the tree.
+
+ The depth is fixed to avoid infinite loop with recurssive links.
+ """
+ selection = self.__treeview.selectionModel()
+ indexes = selection.selectedIndexes()
+ model = self.__treeview.model()
+ while len(indexes) > 0:
+ index = indexes.pop(0)
+ if isinstance(index, tuple):
+ index, depth = index
+ else:
+ depth = 0
+ if index.column() != 0:
+ continue
+
+ if depth > 10:
+ # Avoid infinite loop with recursive links
+ break
+
+ if model.hasChildren(index):
+ self.__treeview.setExpanded(index, False)
+ for row in range(model.rowCount(index)):
+ childIndex = model.index(row, 0, index)
+ indexes.append((childIndex, depth + 1))
+
+ def __createCustomNxdataWindow(self, customNxdataWidget):
+ toolbar = CustomNxDataToolBar(self)
+ toolbar.setCustomNxDataWidget(customNxdataWidget)
+ toolbar.setIconSize(qt.QSize(16, 16))
+ toolbar.setStyleSheet("QToolBar { border: 0px }")
+
+ widget = qt.QWidget(self)
+ layout = qt.QVBoxLayout(widget)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ layout.addWidget(toolbar)
+ layout.addWidget(customNxdataWidget)
+ return widget
+
+ def __h5FileLoaded(self, loadedH5):
+ self.__context.pushRecentFile(loadedH5.file.filename)
+ if loadedH5.file.filename == self.__displayIt:
+ self.__displayIt = None
+ self.displayData(loadedH5)
+
+ def __h5FileRemoved(self, removedH5):
+ self.__dataPanel.removeDatasetsFrom(removedH5)
+ self.__customNxdata.removeDatasetsFrom(removedH5)
+ removedH5.close()
+
+ def __h5FileSynchonized(self, removedH5, loadedH5):
+ self.__dataPanel.replaceDatasetsFrom(removedH5, loadedH5)
+ self.__customNxdata.replaceDatasetsFrom(removedH5, loadedH5)
+ removedH5.close()
+
+ def closeEvent(self, event):
+ self.__context.saveSettings()
+
+ # Clean up as much as possible Python objects
+ self.displayData(None)
+ customModel = self.__customNxdata.model()
+ customModel.clear()
+ hdf5Model = self.__treeview.findHdf5TreeModel()
+ hdf5Model.clear()
+
+ def saveSettings(self, settings):
+ """Save the window settings to this settings object
+
+ :param qt.QSettings settings: Initialized settings
+ """
+ isFullScreen = bool(self.windowState() & qt.Qt.WindowFullScreen)
+ if isFullScreen:
+ # show in normal to catch the normal geometry
+ self.showNormal()
+
+ settings.beginGroup("mainwindow")
+ settings.setValue("size", self.size())
+ settings.setValue("pos", self.pos())
+ settings.setValue("full-screen", isFullScreen)
+ settings.endGroup()
+
+ settings.beginGroup("mainlayout")
+ settings.setValue("spliter", self.__splitter.sizes())
+ settings.setValue("spliter2", self.__splitter2.sizes())
+ isVisible = self.__customNxdataWindow.isVisible()
+ settings.setValue("custom-nxdata-window-visible", isVisible)
+ settings.endGroup()
+
+ settings.beginGroup("content")
+ isSorted = self._sortContentAction.isChecked()
+ settings.setValue("is-sorted", isSorted)
+ settings.endGroup()
+
+ if isFullScreen:
+ self.showFullScreen()
+
+ def restoreSettings(self, settings):
+ """Restore the window settings using this settings object
+
+ :param qt.QSettings settings: Initialized settings
+ """
+ settings.beginGroup("mainwindow")
+ size = settings.value("size", qt.QSize(640, 480))
+ pos = settings.value("pos", qt.QPoint())
+ isFullScreen = settings.value("full-screen", False)
+ try:
+ if not isinstance(isFullScreen, bool):
+ isFullScreen = utils.stringToBool(isFullScreen)
+ except ValueError:
+ isFullScreen = False
+ settings.endGroup()
+
+ settings.beginGroup("mainlayout")
+ try:
+ data = settings.value("spliter")
+ data = [int(d) for d in data]
+ self.__splitter.setSizes(data)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ try:
+ data = settings.value("spliter2")
+ data = [int(d) for d in data]
+ self.__splitter2.setSizes(data)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ isVisible = settings.value("custom-nxdata-window-visible", False)
+ try:
+ if not isinstance(isVisible, bool):
+ isVisible = utils.stringToBool(isVisible)
+ except ValueError:
+ isVisible = False
+ self.__customNxdataWindow.setVisible(isVisible)
+ self._displayCustomNxdataWindow.setChecked(isVisible)
+
+ settings.endGroup()
+
+ settings.beginGroup("content")
+ isSorted = settings.value("is-sorted", True)
+ try:
+ if not isinstance(isSorted, bool):
+ isSorted = utils.stringToBool(isSorted)
+ except ValueError:
+ isSorted = True
+ self.setContentSorted(isSorted)
+ settings.endGroup()
+
+ if not pos.isNull():
+ self.move(pos)
+ if not size.isNull():
+ self.resize(size)
+ if isFullScreen:
+ self.showFullScreen()
+
+ def createActions(self):
+ action = qt.QAction("E&xit", self)
+ action.setShortcuts(qt.QKeySequence.Quit)
+ action.setStatusTip("Exit the application")
+ action.triggered.connect(self.close)
+ self._exitAction = action
+
+ action = qt.QAction("&Open...", self)
+ action.setStatusTip("Open a file")
+ action.triggered.connect(self.open)
+ self._openAction = action
+
+ menu = qt.QMenu("Open Recent", self)
+ menu.setStatusTip("Open a recently opened file")
+ self._openRecentMenu = menu
+
+ action = qt.QAction("Close All", self)
+ action.setStatusTip("Close all opened files")
+ action.triggered.connect(self.closeAll)
+ self._closeAllAction = action
+
+ action = qt.QAction("&About", self)
+ action.setStatusTip("Show the application's About box")
+ action.triggered.connect(self.about)
+ self._aboutAction = action
+
+ action = qt.QAction("&Documentation", self)
+ action.setStatusTip("Show the Silx library's documentation")
+ action.triggered.connect(self.showDocumentation)
+ self._documentationAction = action
+
+ # Plot backend
+
+ self._plotBackendMenu = qt.QMenu("Plot rendering backend", self)
+ self._plotBackendMenu.setStatusTip("Select plot rendering backend")
+
+ group = qt.QActionGroup(self)
+ group.setExclusive(True)
+
+ action = qt.QAction("matplotlib", self)
+ action.setStatusTip("Plot will be rendered using matplotlib")
+ action.setCheckable(True)
+ action.triggered.connect(self.__forceMatplotlibBackend)
+ group.addAction(action)
+ self._plotBackendMenu.addAction(action)
+ self._usePlotWithMatplotlib = action
+
+ action = qt.QAction("OpenGL", self)
+ action.setStatusTip("Plot will be rendered using OpenGL")
+ action.setCheckable(True)
+ action.triggered.connect(self.__forceOpenglBackend)
+ group.addAction(action)
+ self._plotBackendMenu.addAction(action)
+ self._usePlotWithOpengl = action
+
+ # Plot image orientation
+
+ self._plotImageOrientationMenu = qt.QMenu(
+ "Default plot image y-axis orientation", self)
+ self._plotImageOrientationMenu.setStatusTip(
+ "Select the default y-axis orientation used by plot displaying images")
+
+ group = qt.QActionGroup(self)
+ group.setExclusive(True)
+
+ action = qt.QAction("Downward, origin on top", self)
+ action.setIcon(self._iconDownward)
+ action.setStatusTip("Plot images will use a downward Y-axis orientation")
+ action.setCheckable(True)
+ action.triggered.connect(self.__forcePlotImageDownward)
+ group.addAction(action)
+ self._plotImageOrientationMenu.addAction(action)
+ self._useYAxisOrientationDownward = action
+
+ action = qt.QAction("Upward, origin on bottom", self)
+ action.setIcon(self._iconUpward)
+ action.setStatusTip("Plot images will use a upward Y-axis orientation")
+ action.setCheckable(True)
+ action.triggered.connect(self.__forcePlotImageUpward)
+ group.addAction(action)
+ self._plotImageOrientationMenu.addAction(action)
+ self._useYAxisOrientationUpward = action
+
+ # Windows
+
+ action = qt.QAction("Show custom NXdata selector", self)
+ action.setStatusTip("Show a widget which allow to create plot by selecting data and axes")
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_F6))
+ action.toggled.connect(self.__toggleCustomNxdataWindow)
+ self._displayCustomNxdataWindow = action
+
+ def __toggleCustomNxdataWindow(self):
+ isVisible = self._displayCustomNxdataWindow.isChecked()
+ self.__customNxdataWindow.setVisible(isVisible)
+
+ def __updateFileMenu(self):
+ files = self.__context.getRecentFiles()
+ self._openRecentMenu.clear()
+ self._openRecentMenu.setEnabled(len(files) != 0)
+ if len(files) != 0:
+ for filePath in files:
+ baseName = os.path.basename(filePath)
+ action = qt.QAction(baseName, self)
+ action.setToolTip(filePath)
+ action.triggered.connect(functools.partial(self.__openRecentFile, filePath))
+ self._openRecentMenu.addAction(action)
+ self._openRecentMenu.addSeparator()
+ baseName = os.path.basename(filePath)
+ action = qt.QAction("Clear history", self)
+ action.setToolTip("Clear the history of the recent files")
+ action.triggered.connect(self.__clearRecentFile)
+ self._openRecentMenu.addAction(action)
+
+ def __clearRecentFile(self):
+ self.__context.clearRencentFiles()
+
+ def __openRecentFile(self, filePath):
+ self.appendFile(filePath)
+
+ def __updateOptionMenu(self):
+ """Update the state of the checked options as it is based on global
+ environment values."""
+
+ # plot backend
+
+ title = self._plotBackendMenu.title().split(": ", 1)[0]
+ self._plotBackendMenu.setTitle("%s: %s" % (title, silx.config.DEFAULT_PLOT_BACKEND))
+
+ action = self._usePlotWithMatplotlib
+ action.setChecked(silx.config.DEFAULT_PLOT_BACKEND in ["matplotlib", "mpl"])
+ title = action.text().split(" (", 1)[0]
+ if not action.isChecked():
+ title += " (applied after application restart)"
+ action.setText(title)
+
+ action = self._usePlotWithOpengl
+ action.setChecked(silx.config.DEFAULT_PLOT_BACKEND in ["opengl", "gl"])
+ title = action.text().split(" (", 1)[0]
+ if not action.isChecked():
+ title += " (applied after application restart)"
+ action.setText(title)
+
+ # plot orientation
+
+ menu = self._plotImageOrientationMenu
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward":
+ menu.setIcon(self._iconDownward)
+ else:
+ menu.setIcon(self._iconUpward)
+
+ action = self._useYAxisOrientationDownward
+ action.setChecked(silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == "downward")
+ title = action.text().split(" (", 1)[0]
+ if not action.isChecked():
+ title += " (applied after application restart)"
+ action.setText(title)
+
+ action = self._useYAxisOrientationUpward
+ action.setChecked(silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION != "downward")
+ title = action.text().split(" (", 1)[0]
+ if not action.isChecked():
+ title += " (applied after application restart)"
+ action.setText(title)
+
+ def createMenus(self):
+ fileMenu = self.menuBar().addMenu("&File")
+ fileMenu.addAction(self._openAction)
+ fileMenu.addMenu(self._openRecentMenu)
+ fileMenu.addAction(self._closeAllAction)
+ fileMenu.addSeparator()
+ fileMenu.addAction(self._exitAction)
+ fileMenu.aboutToShow.connect(self.__updateFileMenu)
+
+ optionMenu = self.menuBar().addMenu("&Options")
+ optionMenu.addMenu(self._plotImageOrientationMenu)
+ optionMenu.addMenu(self._plotBackendMenu)
+ optionMenu.aboutToShow.connect(self.__updateOptionMenu)
+
+ viewMenu = self.menuBar().addMenu("&Views")
+ viewMenu.addAction(self._displayCustomNxdataWindow)
+
+ helpMenu = self.menuBar().addMenu("&Help")
+ helpMenu.addAction(self._aboutAction)
+ helpMenu.addAction(self._documentationAction)
+
+ def open(self):
+ dialog = self.createFileDialog()
+ if self.__dialogState is None:
+ currentDirectory = os.getcwd()
+ dialog.setDirectory(currentDirectory)
+ else:
+ dialog.restoreState(self.__dialogState)
+
+ result = dialog.exec()
+ if not result:
+ return
+
+ self.__dialogState = dialog.saveState()
+
+ filenames = dialog.selectedFiles()
+ for filename in filenames:
+ self.appendFile(filename)
+
+ def closeAll(self):
+ """Close all currently opened files"""
+ model = self.__treeview.findHdf5TreeModel()
+ model.clear()
+
+ def createFileDialog(self):
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Open")
+ dialog.setModal(True)
+
+ # NOTE: hdf5plugin have to be loaded before
+ extensions = collections.OrderedDict()
+ for description, ext in silx.io.supported_extensions().items():
+ extensions[description] = " ".join(sorted(list(ext)))
+
+ # Add extensions supported by fabio
+ extensions["NeXus layout from EDF files"] = "*.edf"
+ extensions["NeXus layout from TIFF image files"] = "*.tif *.tiff"
+ extensions["NeXus layout from CBF files"] = "*.cbf"
+ extensions["NeXus layout from MarCCD image files"] = "*.mccd"
+
+ all_supported_extensions = set()
+ for name, exts in extensions.items():
+ exts = exts.split(" ")
+ all_supported_extensions.update(exts)
+ all_supported_extensions = sorted(list(all_supported_extensions))
+
+ filters = []
+ filters.append("All supported files (%s)" % " ".join(all_supported_extensions))
+ for name, extension in extensions.items():
+ filters.append("%s (%s)" % (name, extension))
+ filters.append("All files (*)")
+
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.ExistingFiles)
+ return dialog
+
+ def about(self):
+ from .About import About
+ About.about(self, "Silx viewer")
+
+ def showDocumentation(self):
+ subpath = "index.html"
+ url = projecturl.getDocumentationUrl(subpath)
+ qt.QDesktopServices.openUrl(qt.QUrl(url))
+
+ def setContentSorted(self, sort):
+ """Set whether file content should be sorted or not.
+
+ :param bool sort:
+ """
+ sort = bool(sort)
+ if sort != self.isContentSorted():
+
+ # save expanded nodes
+ pathss = []
+ root = qt.QModelIndex()
+ model = self.__treeview.model()
+ for i in range(model.rowCount(root)):
+ index = model.index(i, 0, root)
+ paths = self.__getPathFromExpandedNodes(self.__treeview, index)
+ pathss.append(paths)
+
+ self.__treeview.setModel(
+ self.__treeModelSorted if sort else self.__treeModelSorted.sourceModel())
+ self._sortContentAction.setChecked(self.isContentSorted())
+
+ # restore expanded nodes
+ model = self.__treeview.model()
+ for i in range(model.rowCount(root)):
+ index = model.index(i, 0, root)
+ paths = pathss.pop(0)
+ self.__expandNodesFromPaths(self.__treeview, index, paths)
+
+ def isContentSorted(self):
+ """Returns whether the file content is sorted or not.
+
+ :rtype: bool
+ """
+ return self.__treeview.model() is self.__treeModelSorted
+
+ def __forcePlotImageDownward(self):
+ silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = "downward"
+
+ def __forcePlotImageUpward(self):
+ silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION = "upward"
+
+ def __forceMatplotlibBackend(self):
+ silx.config.DEFAULT_PLOT_BACKEND = "matplotlib"
+
+ def __forceOpenglBackend(self):
+ silx.config.DEFAULT_PLOT_BACKEND = "opengl"
+
+ def appendFile(self, filename):
+ if self.__displayIt is None:
+ # Store the file to display it (loading could be async)
+ self.__displayIt = filename
+ self.__treeview.findHdf5TreeModel().appendFile(filename)
+
+ def displaySelectedData(self):
+ """Called to update the dataviewer with the selected data.
+ """
+ selected = list(self.__treeview.selectedH5Nodes(ignoreBrokenLinks=False))
+ if len(selected) == 1:
+ # Update the viewer for a single selection
+ data = selected[0]
+ self.__dataPanel.setData(data)
+ else:
+ _logger.debug("Too many data selected")
+
+ def displayData(self, data):
+ """Called to update the dataviewer with a secific data.
+ """
+ self.__dataPanel.setData(data)
+
+ def displaySelectedCustomData(self):
+ selected = list(self.__customNxdata.selectedItems())
+ if len(selected) == 1:
+ # Update the viewer for a single selection
+ item = selected[0]
+ self.__dataPanel.setCustomDataItem(item)
+ else:
+ _logger.debug("Too many items selected")
+
+ def __customNxdataRemoved(self, item):
+ if self.__dataPanel.getCustomNxdataItem() is item:
+ self.__dataPanel.setCustomDataItem(None)
+
+ def __customNxdataUpdated(self, item):
+ if self.__dataPanel.getCustomNxdataItem() is item:
+ self.__dataPanel.setCustomDataItem(item)
+
+ def __makeSureCustomNxDataWindowIsVisible(self):
+ if not self.__customNxdataWindow.isVisible():
+ self.__customNxdataWindow.setVisible(True)
+ self._displayCustomNxdataWindow.setChecked(True)
+
+ def useAsNewCustomSignal(self, h5dataset):
+ self.__makeSureCustomNxDataWindowIsVisible()
+ model = self.__customNxdata.model()
+ model.createFromSignal(h5dataset)
+
+ def useAsNewCustomNxdata(self, h5nxdata):
+ self.__makeSureCustomNxDataWindowIsVisible()
+ model = self.__customNxdata.model()
+ model.createFromNxdata(h5nxdata)
+
+ def customContextMenu(self, event):
+ """Called to populate the context menu
+
+ :param silx.gui.hdf5.Hdf5ContextMenuEvent event: Event
+ containing expected information to populate the context menu
+ """
+ selectedObjects = event.source().selectedH5Nodes(ignoreBrokenLinks=False)
+ menu = event.menu()
+
+ if not menu.isEmpty():
+ menu.addSeparator()
+
+ for obj in selectedObjects:
+ h5 = obj.h5py_object
+
+ name = obj.name
+ if name.startswith("/"):
+ name = name[1:]
+ if name == "":
+ name = "the root"
+
+ action = qt.QAction("Show %s" % name, event.source())
+ action.triggered.connect(lambda: self.displayData(h5))
+ menu.addAction(action)
+
+ if silx.io.is_dataset(h5):
+ action = qt.QAction("Use as a new custom signal", event.source())
+ action.triggered.connect(lambda: self.useAsNewCustomSignal(h5))
+ menu.addAction(action)
+
+ if silx.io.is_group(h5) and silx.io.nxdata.is_valid_nxdata(h5):
+ action = qt.QAction("Use as a new custom NXdata", event.source())
+ action.triggered.connect(lambda: self.useAsNewCustomNxdata(h5))
+ menu.addAction(action)
+
+ if silx.io.is_file(h5):
+ action = qt.QAction("Close %s" % obj.local_filename, event.source())
+ action.triggered.connect(lambda: self.__treeview.findHdf5TreeModel().removeH5pyObject(h5))
+ menu.addAction(action)
+ action = qt.QAction("Synchronize %s" % obj.local_filename, event.source())
+ action.triggered.connect(lambda: self.__synchronizeH5pyObject(h5))
+ menu.addAction(action)
diff --git a/silx/app/view/__init__.py b/src/silx/app/view/__init__.py
index 229c44e..229c44e 100644
--- a/silx/app/view/__init__.py
+++ b/src/silx/app/view/__init__.py
diff --git a/src/silx/app/view/main.py b/src/silx/app/view/main.py
new file mode 100644
index 0000000..dbc6a2b
--- /dev/null
+++ b/src/silx/app/view/main.py
@@ -0,0 +1,186 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Module containing launcher of the `silx view` application"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2019"
+
+import argparse
+import logging
+import os
+import signal
+import sys
+
+
+_logger = logging.getLogger(__name__)
+"""Module logger"""
+
+
+def createParser():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ 'files',
+ nargs=argparse.ZERO_OR_MORE,
+ help='Data file to show (h5 file, edf files, spec files)')
+ parser.add_argument(
+ '--debug',
+ dest="debug",
+ action="store_true",
+ default=False,
+ help='Set logging system in debug mode')
+ parser.add_argument(
+ '--use-opengl-plot',
+ dest="use_opengl_plot",
+ action="store_true",
+ default=False,
+ help='Use OpenGL for plots (instead of matplotlib)')
+ parser.add_argument(
+ '-f', '--fresh',
+ dest="fresh_preferences",
+ action="store_true",
+ default=False,
+ help='Start the application using new fresh user preferences')
+ parser.add_argument(
+ '--hdf5-file-locking',
+ dest="hdf5_file_locking",
+ action="store_true",
+ default=False,
+ help='Start the application with HDF5 file locking enabled (it is disabled by default)')
+ return parser
+
+
+def createWindow(parent, settings):
+ from .Viewer import Viewer
+ window = Viewer(parent=None, settings=settings)
+ return window
+
+
+def mainQt(options):
+ """Part of the main depending on Qt"""
+ if options.debug:
+ logging.root.setLevel(logging.DEBUG)
+
+ #
+ # Import most of the things here to be sure to use the right logging level
+ #
+
+ # Use max opened files hard limit as soft limit
+ try:
+ import resource
+ except ImportError:
+ _logger.debug("No resource module available")
+ else:
+ if hasattr(resource, 'RLIMIT_NOFILE'):
+ try:
+ hard_nofile = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
+ resource.setrlimit(resource.RLIMIT_NOFILE, (hard_nofile, hard_nofile))
+ except (ValueError, OSError):
+ _logger.warning("Failed to retrieve and set the max opened files limit")
+ else:
+ _logger.debug("Set max opened files to %d", hard_nofile)
+
+ # This needs to be done prior to load HDF5
+ hdf5_file_locking = 'TRUE' if options.hdf5_file_locking else 'FALSE'
+ _logger.info('Set HDF5_USE_FILE_LOCKING=%s', hdf5_file_locking)
+ os.environ['HDF5_USE_FILE_LOCKING'] = hdf5_file_locking
+
+ try:
+ # it should be loaded before h5py
+ import hdf5plugin # noqa
+ except ImportError:
+ _logger.debug("Backtrace", exc_info=True)
+
+ import h5py
+
+ import silx
+ import silx.utils.files
+ from silx.gui import qt
+ # Make sure matplotlib is configured
+ # Needed for Debian 8: compatibility between Qt4/Qt5 and old matplotlib
+ import silx.gui.utils.matplotlib # noqa
+
+ app = qt.QApplication([])
+ qt.QLocale.setDefault(qt.QLocale.c())
+
+ def sigintHandler(*args):
+ """Handler for the SIGINT signal."""
+ qt.QApplication.quit()
+
+ signal.signal(signal.SIGINT, sigintHandler)
+ sys.excepthook = qt.exceptionHandler
+
+ timer = qt.QTimer()
+ timer.start(500)
+ # Application have to wake up Python interpreter, else SIGINT is not
+ # catched
+ timer.timeout.connect(lambda: None)
+
+ settings = qt.QSettings(qt.QSettings.IniFormat,
+ qt.QSettings.UserScope,
+ "silx",
+ "silx-view",
+ None)
+ if options.fresh_preferences:
+ settings.clear()
+
+ window = createWindow(parent=None, settings=settings)
+ window.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+
+ if options.use_opengl_plot:
+ # It have to be done after the settings (after the Viewer creation)
+ silx.config.DEFAULT_PLOT_BACKEND = "opengl"
+
+ # NOTE: under Windows, cmd does not convert `*.tif` into existing files
+ options.files = silx.utils.files.expand_filenames(options.files)
+
+ for filename in options.files:
+ # TODO: Would be nice to add a process widget and a cancel button
+ try:
+ window.appendFile(filename)
+ except IOError as e:
+ _logger.error(e.args[0])
+ _logger.debug("Backtrace", exc_info=True)
+
+ window.show()
+ result = app.exec()
+ # remove ending warnings relative to QTimer
+ app.deleteLater()
+ return result
+
+
+def main(argv):
+ """
+ Main function to launch the viewer as an application
+
+ :param argv: Command line arguments
+ :returns: exit status
+ """
+ parser = createParser()
+ options = parser.parse_args(argv[1:])
+ mainQt(options)
+
+
+if __name__ == '__main__':
+ main(sys.argv)
diff --git a/silx/app/view/setup.py b/src/silx/app/view/setup.py
index fa076cb..fa076cb 100644
--- a/silx/app/view/setup.py
+++ b/src/silx/app/view/setup.py
diff --git a/src/silx/app/view/test/__init__.py b/src/silx/app/view/test/__init__.py
new file mode 100644
index 0000000..7790ee5
--- /dev/null
+++ b/src/silx/app/view/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/app/view/test/test_launcher.py b/src/silx/app/view/test/test_launcher.py
new file mode 100644
index 0000000..4f7aaa5
--- /dev/null
+++ b/src/silx/app/view/test/test_launcher.py
@@ -0,0 +1,140 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module testing silx.app.view"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/06/2018"
+
+
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+import logging
+import subprocess
+import pytest
+
+from .. import main
+from silx import __main__ as silx_main
+
+_logger = logging.getLogger(__name__)
+
+
+@pytest.mark.usefixtures("qapp")
+class TestLauncher(unittest.TestCase):
+ """Test command line parsing"""
+
+ def testHelp(self):
+ # option -h must cause a raise SystemExit or a return 0
+ try:
+ parser = main.createParser()
+ parser.parse_args(["view", "--help"])
+ result = 0
+ except SystemExit as e:
+ result = e.args[0]
+ self.assertEqual(result, 0)
+
+ def testWrongOption(self):
+ try:
+ parser = main.createParser()
+ parser.parse_args(["view", "--foo"])
+ self.fail()
+ except SystemExit as e:
+ result = e.args[0]
+ self.assertNotEqual(result, 0)
+
+ def testWrongFile(self):
+ try:
+ parser = main.createParser()
+ result = parser.parse_args(["view", "__file.not.found__"])
+ result = 0
+ except SystemExit as e:
+ result = e.args[0]
+ self.assertEqual(result, 0)
+
+ def executeAsScript(self, filename, *args):
+ """Execute a command line.
+
+ Log output as debug in case of bad return code.
+ """
+ env = self.createTestEnv()
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Copy file to temporary dir to avoid import from current dir.
+ script = os.path.join(tmpdir, 'launcher.py')
+ shutil.copyfile(filename, script)
+ command_line = [sys.executable, script] + list(args)
+
+ _logger.info("Execute: %s", " ".join(command_line))
+ p = subprocess.Popen(command_line,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ env=env)
+ out, err = p.communicate()
+ _logger.info("Return code: %d", p.returncode)
+ try:
+ out = out.decode('utf-8')
+ except UnicodeError:
+ pass
+ try:
+ err = err.decode('utf-8')
+ except UnicodeError:
+ pass
+
+ if p.returncode != 0:
+ _logger.info("stdout:")
+ _logger.info("%s", out)
+ _logger.info("stderr:")
+ _logger.info("%s", err)
+ else:
+ _logger.debug("stdout:")
+ _logger.debug("%s", out)
+ _logger.debug("stderr:")
+ _logger.debug("%s", err)
+ self.assertEqual(p.returncode, 0)
+
+ def createTestEnv(self):
+ """
+ Returns an associated environment with a working project.
+ """
+ env = dict((str(k), str(v)) for k, v in os.environ.items())
+ env["PYTHONPATH"] = os.pathsep.join(sys.path)
+ return env
+
+ def testExecuteViewHelp(self):
+ """Test if the main module is well connected.
+
+ Uses subprocess to avoid to parasite the current environment.
+ """
+ self.executeAsScript(main.__file__, "--help")
+
+ def testExecuteSilxViewHelp(self):
+ """Test if the main module is well connected.
+
+ Uses subprocess to avoid to parasite the current environment.
+ """
+ self.executeAsScript(silx_main.__file__, "view", "--help")
diff --git a/src/silx/app/view/test/test_view.py b/src/silx/app/view/test/test_view.py
new file mode 100644
index 0000000..e236e42
--- /dev/null
+++ b/src/silx/app/view/test/test_view.py
@@ -0,0 +1,388 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module testing silx.app.view"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/06/2018"
+
+
+import weakref
+import numpy
+import h5py
+import pytest
+
+from silx.gui import qt
+from silx.app.view.Viewer import Viewer
+from silx.app.view.About import About
+from silx.app.view.DataPanel import DataPanel
+from silx.app.view.CustomNxdataWidget import CustomNxdataWidget
+from silx.gui.hdf5._utils import Hdf5DatasetMimeData
+from silx.gui.utils.testutils import TestCaseQt
+from silx.io import commonh5
+
+
+@pytest.fixture(scope="module")
+def data_h5(tmpdir_factory):
+ filename = tmpdir_factory.mktemp("data").join("data.h5")
+ filename = str(filename)
+ f = h5py.File(filename, "w")
+ g = f.create_group("arrays")
+ g.create_dataset("scalar", data=10)
+ g.create_dataset("integers", data=numpy.array([10, 20, 30]))
+ f.close()
+ return filename
+
+
+@pytest.fixture(scope="module")
+def data2_h5(tmpdir_factory):
+ filename = tmpdir_factory.mktemp("data").join("data2.h5")
+ filename = str(filename)
+ f = h5py.File(filename, "w")
+ g = f.create_group("arrays")
+ g.create_dataset("scalar", data=20)
+ g.create_dataset("integers", data=numpy.array([10, 20, 30]))
+ f.close()
+ return filename
+
+
+@pytest.fixture(scope="class")
+def data_class_attr(request, data_h5, data2_h5):
+ """Provides test_options as class attribute
+
+ Used as transition from TestCase to pytest
+ """
+ request.cls.data_h5 = data_h5
+ request.cls.data2_h5 = data2_h5
+
+
+@pytest.mark.usefixtures("qapp")
+class TestViewer(TestCaseQt):
+ """Test for Viewer class"""
+
+ def testConstruct(self):
+ widget = Viewer()
+ self.qWaitForWindowExposed(widget)
+
+ def testDestroy(self):
+ widget = Viewer()
+ ref = weakref.ref(widget)
+ widget = None
+ self.qWaitForDestroy(ref)
+
+
+@pytest.mark.usefixtures("qapp")
+class TestAbout(TestCaseQt):
+ """Test for About box class"""
+
+ def testConstruct(self):
+ widget = About()
+ self.qWaitForWindowExposed(widget)
+
+ def testLicense(self):
+ widget = About()
+ widget.getHtmlLicense()
+ self.qWaitForWindowExposed(widget)
+
+ def testDestroy(self):
+ widget = About()
+ ref = weakref.ref(widget)
+ widget = None
+ self.qWaitForDestroy(ref)
+
+
+@pytest.mark.usefixtures("qapp")
+@pytest.mark.usefixtures("data_class_attr")
+class TestDataPanel(TestCaseQt):
+
+ def testConstruct(self):
+ widget = DataPanel()
+ self.qWaitForWindowExposed(widget)
+
+ def testDestroy(self):
+ widget = DataPanel()
+ ref = weakref.ref(widget)
+ widget = None
+ self.qWaitForDestroy(ref)
+
+ def testHeaderLabelPaintEvent(self):
+ widget = DataPanel()
+ data = numpy.array([1, 2, 3, 4, 5])
+ widget.setData(data)
+ # Expected to execute HeaderLabel.paintEvent
+ widget.setVisible(True)
+ self.qWaitForWindowExposed(widget)
+
+ def testData(self):
+ widget = DataPanel()
+ data = numpy.array([1, 2, 3, 4, 5])
+ widget.setData(data)
+ self.assertIs(widget.getData(), data)
+ self.assertIs(widget.getCustomNxdataItem(), None)
+
+ def testDataNone(self):
+ widget = DataPanel()
+ widget.setData(None)
+ self.assertIs(widget.getData(), None)
+ self.assertIs(widget.getCustomNxdataItem(), None)
+
+ def testCustomDataItem(self):
+ class CustomDataItemMock(object):
+ def getVirtualGroup(self):
+ return None
+
+ def text(self):
+ return ""
+
+ data = CustomDataItemMock()
+ widget = DataPanel()
+ widget.setCustomDataItem(data)
+ self.assertIs(widget.getData(), None)
+ self.assertIs(widget.getCustomNxdataItem(), data)
+
+ def testCustomDataItemNone(self):
+ data = None
+ widget = DataPanel()
+ widget.setCustomDataItem(data)
+ self.assertIs(widget.getData(), None)
+ self.assertIs(widget.getCustomNxdataItem(), data)
+
+ def testRemoveDatasetsFrom(self):
+ f = h5py.File(self.data_h5, mode='r')
+ try:
+ widget = DataPanel()
+ widget.setData(f["arrays/scalar"])
+ widget.removeDatasetsFrom(f)
+ self.assertIs(widget.getData(), None)
+ finally:
+ widget.setData(None)
+ f.close()
+
+ def testReplaceDatasetsFrom(self):
+ f = h5py.File(self.data_h5, mode='r')
+ f2 = h5py.File(self.data2_h5, mode='r')
+ try:
+ widget = DataPanel()
+ widget.setData(f["arrays/scalar"])
+ self.assertEqual(widget.getData()[()], 10)
+ widget.replaceDatasetsFrom(f, f2)
+ self.assertEqual(widget.getData()[()], 20)
+ finally:
+ widget.setData(None)
+ f.close()
+ f2.close()
+
+
+@pytest.mark.usefixtures("qapp")
+@pytest.mark.usefixtures("data_class_attr")
+class TestCustomNxdataWidget(TestCaseQt):
+
+ def testConstruct(self):
+ widget = CustomNxdataWidget()
+ self.qWaitForWindowExposed(widget)
+
+ def testDestroy(self):
+ widget = CustomNxdataWidget()
+ ref = weakref.ref(widget)
+ widget = None
+ self.qWaitForDestroy(ref)
+
+ def testCreateNxdata(self):
+ widget = CustomNxdataWidget()
+ model = widget.model()
+ model.createNewNxdata()
+ model.createNewNxdata("Foo")
+ widget.setVisible(True)
+ self.qWaitForWindowExposed(widget)
+
+ def testCreateNxdataFromDataset(self):
+ widget = CustomNxdataWidget()
+ model = widget.model()
+ signal = commonh5.Dataset("foo", data=numpy.array([[[5]]]))
+ model.createFromSignal(signal)
+ widget.setVisible(True)
+ self.qWaitForWindowExposed(widget)
+
+ def testCreateNxdataFromNxdata(self):
+ widget = CustomNxdataWidget()
+ model = widget.model()
+ data = numpy.array([[[5]]])
+ nxdata = commonh5.Group("foo")
+ nxdata.attrs["NX_class"] = "NXdata"
+ nxdata.attrs["signal"] = "signal"
+ nxdata.create_dataset("signal", data=data)
+ model.createFromNxdata(nxdata)
+ widget.setVisible(True)
+ self.qWaitForWindowExposed(widget)
+
+ def testCreateBadNxdata(self):
+ widget = CustomNxdataWidget()
+ model = widget.model()
+ signal = commonh5.Dataset("foo", data=numpy.array([[[5]]]))
+ model.createFromSignal(signal)
+ axis = commonh5.Dataset("foo", data=numpy.array([[[5]]]))
+ nxdataIndex = model.index(0, 0)
+ item = model.itemFromIndex(nxdataIndex)
+ item.setAxesDatasets([axis])
+ nxdata = item.getVirtualGroup()
+ self.assertIsNotNone(nxdata)
+ self.assertFalse(item.isValid())
+
+ def testRemoveDatasetsFrom(self):
+ f = h5py.File(self.data_h5, mode='r')
+ try:
+ widget = CustomNxdataWidget()
+ model = widget.model()
+ dataset = f["arrays/integers"]
+ model.createFromSignal(dataset)
+ widget.removeDatasetsFrom(f)
+ finally:
+ model.clear()
+ f.close()
+
+ def testReplaceDatasetsFrom(self):
+ f = h5py.File(self.data_h5, mode='r')
+ f2 = h5py.File(self.data2_h5, mode='r')
+ try:
+ widget = CustomNxdataWidget()
+ model = widget.model()
+ dataset = f["arrays/integers"]
+ model.createFromSignal(dataset)
+ widget.replaceDatasetsFrom(f, f2)
+ finally:
+ model.clear()
+ f.close()
+ f2.close()
+
+
+@pytest.mark.usefixtures("qapp")
+class TestCustomNxdataWidgetInteraction(TestCaseQt):
+ """Test CustomNxdataWidget with user interaction"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+
+ self.widget = CustomNxdataWidget()
+ self.model = self.widget.model()
+ data = numpy.array([[[5]]])
+ dataset = commonh5.Dataset("foo", data=data)
+ self.model.createFromSignal(dataset)
+ self.selectionModel = self.widget.selectionModel()
+
+ def tearDown(self):
+ self.selectionModel = None
+ self.model.clear()
+ self.model = None
+ self.widget = None
+ TestCaseQt.tearDown(self)
+
+ def testSelectedNxdata(self):
+ index = self.model.index(0, 0)
+ self.selectionModel.setCurrentIndex(index, qt.QItemSelectionModel.ClearAndSelect)
+ nxdata = self.widget.selectedNxdata()
+ self.assertEqual(len(nxdata), 1)
+ self.assertIsNot(nxdata[0], None)
+
+ def testSelectedItems(self):
+ index = self.model.index(0, 0)
+ self.selectionModel.setCurrentIndex(index, qt.QItemSelectionModel.ClearAndSelect)
+ items = self.widget.selectedItems()
+ self.assertEqual(len(items), 1)
+ self.assertIsNot(items[0], None)
+ self.assertIsInstance(items[0], qt.QStandardItem)
+
+ def testRowsAboutToBeRemoved(self):
+ self.model.removeRow(0)
+ self.qWaitForWindowExposed(self.widget)
+
+ def testPaintItems(self):
+ self.widget.expandAll()
+ self.widget.setVisible(True)
+ self.qWaitForWindowExposed(self.widget)
+
+ def testCreateDefaultContextMenu(self):
+ nxDataIndex = self.model.index(0, 0)
+ menu = self.widget.createDefaultContextMenu(nxDataIndex)
+ self.assertIsNot(menu, None)
+ self.assertIsInstance(menu, qt.QMenu)
+
+ signalIndex = self.model.index(0, 0, nxDataIndex)
+ menu = self.widget.createDefaultContextMenu(signalIndex)
+ self.assertIsNot(menu, None)
+ self.assertIsInstance(menu, qt.QMenu)
+
+ axesIndex = self.model.index(1, 0, nxDataIndex)
+ menu = self.widget.createDefaultContextMenu(axesIndex)
+ self.assertIsNot(menu, None)
+ self.assertIsInstance(menu, qt.QMenu)
+
+ def testDropNewDataset(self):
+ dataset = commonh5.Dataset("foo", numpy.array([1, 2, 3, 4]))
+ mimedata = Hdf5DatasetMimeData(dataset=dataset)
+ self.model.dropMimeData(mimedata, qt.Qt.CopyAction, -1, -1, qt.QModelIndex())
+ self.assertEqual(self.model.rowCount(qt.QModelIndex()), 2)
+
+ def testDropNewNxdata(self):
+ data = numpy.array([[[5]]])
+ nxdata = commonh5.Group("foo")
+ nxdata.attrs["NX_class"] = "NXdata"
+ nxdata.attrs["signal"] = "signal"
+ nxdata.create_dataset("signal", data=data)
+ mimedata = Hdf5DatasetMimeData(dataset=nxdata)
+ self.model.dropMimeData(mimedata, qt.Qt.CopyAction, -1, -1, qt.QModelIndex())
+ self.assertEqual(self.model.rowCount(qt.QModelIndex()), 2)
+
+ def testDropAxisDataset(self):
+ dataset = commonh5.Dataset("foo", numpy.array([1, 2, 3, 4]))
+ mimedata = Hdf5DatasetMimeData(dataset=dataset)
+ nxDataIndex = self.model.index(0, 0)
+ axesIndex = self.model.index(1, 0, nxDataIndex)
+ self.model.dropMimeData(mimedata, qt.Qt.CopyAction, -1, -1, axesIndex)
+ self.assertEqual(self.model.rowCount(qt.QModelIndex()), 1)
+ item = self.model.itemFromIndex(axesIndex)
+ self.assertIsNot(item.getDataset(), None)
+
+ def testMimeData(self):
+ nxDataIndex = self.model.index(0, 0)
+ signalIndex = self.model.index(0, 0, nxDataIndex)
+ mimeData = self.model.mimeData([signalIndex])
+ self.assertIsNot(mimeData, None)
+ self.assertIsInstance(mimeData, qt.QMimeData)
+
+ def testRemoveNxdataItem(self):
+ nxdataIndex = self.model.index(0, 0)
+ item = self.model.itemFromIndex(nxdataIndex)
+ self.model.removeNxdataItem(item)
+
+ def testAppendAxisToNxdataItem(self):
+ nxdataIndex = self.model.index(0, 0)
+ item = self.model.itemFromIndex(nxdataIndex)
+ self.model.appendAxisToNxdataItem(item)
+
+ def testRemoveAxisItem(self):
+ nxdataIndex = self.model.index(0, 0)
+ axesIndex = self.model.index(1, 0, nxdataIndex)
+ item = self.model.itemFromIndex(axesIndex)
+ self.model.removeAxisItem(item)
diff --git a/silx/app/view/utils.py b/src/silx/app/view/utils.py
index 80167c8..80167c8 100644
--- a/silx/app/view/utils.py
+++ b/src/silx/app/view/utils.py
diff --git a/src/silx/conftest.py b/src/silx/conftest.py
new file mode 100644
index 0000000..53b3edc
--- /dev/null
+++ b/src/silx/conftest.py
@@ -0,0 +1,130 @@
+import pytest
+import logging
+import os
+
+
+logger = logging.getLogger(__name__)
+
+
+def _set_qt_binding(binding):
+ if binding is not None:
+ binding = binding.lower()
+ if binding == "pyqt5":
+ logger.info("Force using PyQt5")
+ import PyQt5.QtCore # noqa
+ elif binding == "pyside2":
+ logger.info("Force using PySide2")
+ import PySide2.QtCore # noqa
+ elif binding == "pyside6":
+ logger.info("Force using PySide6")
+ import PySide6.QtCore # noqa
+ else:
+ raise ValueError("Qt binding '%s' is unknown" % binding)
+
+
+def pytest_addoption(parser):
+ parser.addoption("--qt-binding", type=str, default=None, dest="qt_binding",
+ help="Force using a Qt binding: 'PyQt5', 'PySide2', 'PySide6'")
+ parser.addoption("--no-gui", dest="gui", default=True,
+ action="store_false",
+ help="Disable the test of the graphical use interface")
+ parser.addoption("--no-opengl", dest="opengl", default=True,
+ action="store_false",
+ help="Disable tests using OpenGL")
+ parser.addoption("--no-opencl", dest="opencl", default=True,
+ action="store_false",
+ help="Disable the test of the OpenCL part")
+ parser.addoption("--low-mem", dest="low_mem", default=False,
+ action="store_true",
+ help="Disable test with large memory consumption (>100Mbyte")
+
+
+def pytest_configure(config):
+ if not config.getoption('opencl', True):
+ os.environ['SILX_OPENCL'] = 'False' # Disable OpenCL support in silx
+
+ _set_qt_binding(config.option.qt_binding)
+
+
+@pytest.fixture(scope="session")
+def test_options(request):
+ from .test import utils
+ options = utils._TestOptions()
+ options.configure(request.config.option)
+ yield options
+
+
+@pytest.fixture(scope="class")
+def test_options_class_attr(request, test_options):
+ """Provides test_options as class attribute
+
+ Used as transition from TestCase to pytest
+ """
+ request.cls.test_options = test_options
+
+
+@pytest.fixture(scope="session")
+def use_opengl(test_options):
+ """Fixture to flag test using a OpenGL.
+
+ This can be skipped with `--no-opengl`.
+ """
+ if not test_options.WITH_GL_TEST:
+ pytest.skip(test_options.WITH_GL_TEST_REASON, allow_module_level=True)
+
+
+@pytest.fixture(scope="session")
+def use_opencl(test_options):
+ """Fixture to flag test using a OpenCL.
+
+ This can be skipped with `--no-opencl`.
+ """
+ if not test_options.WITH_OPENCL_TEST:
+ pytest.skip(test_options.WITH_OPENCL_TEST_REASON, allow_module_level=True)
+
+
+@pytest.fixture(scope="session")
+def use_large_memory(test_options):
+ """Fixture to flag test using a large memory consumption.
+
+ This can be skipped with `--low-mem`.
+ """
+ if test_options.TEST_LOW_MEM:
+ pytest.skip(test_options.TEST_LOW_MEM_REASON, allow_module_level=True)
+
+
+@pytest.fixture(scope="session")
+def use_gui(test_options):
+ """Fixture to flag test using GUI.
+
+ This can be skipped with `--no-gui`.
+ """
+ if not test_options.WITH_QT_TEST:
+ pytest.skip(test_options.WITH_QT_TEST_REASON, allow_module_level=True)
+
+
+@pytest.fixture(scope="session")
+def qapp(use_gui, xvfb, request):
+ _set_qt_binding(request.config.option.qt_binding)
+
+ from silx.gui import qt
+ app = qt.QApplication.instance()
+ if app is None:
+ app = qt.QApplication([])
+ try:
+ yield app
+ finally:
+ if app is not None:
+ app.closeAllWindows()
+
+
+@pytest.fixture
+def qapp_utils(qapp):
+ """Helper containing method to deal with QApplication and widget"""
+ from silx.gui.utils.testutils import TestCaseQt
+ utils = TestCaseQt()
+ utils.setUpClass()
+ utils.setUp()
+ yield utils
+ utils.tearDown()
+ utils.tearDownClass()
diff --git a/silx/gui/__init__.py b/src/silx/gui/__init__.py
index b796e20..b796e20 100644
--- a/silx/gui/__init__.py
+++ b/src/silx/gui/__init__.py
diff --git a/silx/gui/_glutils/Context.py b/src/silx/gui/_glutils/Context.py
index c62dbb9..c62dbb9 100644
--- a/silx/gui/_glutils/Context.py
+++ b/src/silx/gui/_glutils/Context.py
diff --git a/src/silx/gui/_glutils/FramebufferTexture.py b/src/silx/gui/_glutils/FramebufferTexture.py
new file mode 100644
index 0000000..d12a6e0
--- /dev/null
+++ b/src/silx/gui/_glutils/FramebufferTexture.py
@@ -0,0 +1,168 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Association of a texture and a framebuffer object for off-screen rendering.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/07/2016"
+
+
+import logging
+
+from . import gl
+from .Texture import Texture
+
+
+_logger = logging.getLogger(__name__)
+
+
+class FramebufferTexture(object):
+ """Framebuffer with a texture.
+
+ Aimed at off-screen rendering to texture.
+
+ :param internalFormat: OpenGL texture internal format
+ :param shape: Shape (height, width) of the framebuffer and texture
+ :type shape: 2-tuple of int
+ :param stencilFormat: Stencil renderbuffer format
+ :param depthFormat: Depth renderbuffer format
+ :param kwargs: Extra arguments for :class:`Texture` constructor
+ """
+
+ _PACKED_FORMAT = gl.GL_DEPTH24_STENCIL8, gl.GL_DEPTH_STENCIL
+
+ def __init__(self,
+ internalFormat,
+ shape,
+ stencilFormat=gl.GL_DEPTH24_STENCIL8,
+ depthFormat=gl.GL_DEPTH24_STENCIL8,
+ **kwargs):
+
+ self._texture = Texture(internalFormat, shape=shape, **kwargs)
+ self._texture.prepare()
+
+ self._previousFramebuffer = 0 # Used by with statement
+
+ self._name = gl.glGenFramebuffers(1)
+
+ with self: # Bind FBO
+ # Attachments
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER,
+ gl.GL_COLOR_ATTACHMENT0,
+ gl.GL_TEXTURE_2D,
+ self._texture.name,
+ 0)
+
+ height, width = self._texture.shape
+
+ if stencilFormat is not None:
+ self._stencilId = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._stencilId)
+ gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
+ stencilFormat,
+ width, height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
+ gl.GL_STENCIL_ATTACHMENT,
+ gl.GL_RENDERBUFFER,
+ self._stencilId)
+ else:
+ self._stencilId = None
+
+ if depthFormat is not None:
+ if self._stencilId and depthFormat in self._PACKED_FORMAT:
+ self._depthId = self._stencilId
+ else:
+ self._depthId = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self._depthId)
+ gl.glRenderbufferStorage(gl.GL_RENDERBUFFER,
+ depthFormat,
+ width, height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER,
+ gl.GL_DEPTH_ATTACHMENT,
+ gl.GL_RENDERBUFFER,
+ self._depthId)
+ else:
+ self._depthId = None
+
+ status = gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER)
+ if status != gl.GL_FRAMEBUFFER_COMPLETE:
+ _logger.error(
+ "OpenGL framebuffer initialization not complete, display may fail (error %d)",
+ status)
+
+ @property
+ def shape(self):
+ """Shape of the framebuffer (height, width)"""
+ return self._texture.shape
+
+ @property
+ def texture(self):
+ """The texture this framebuffer is rendering to.
+
+ The life-cycle of the texture is managed by this object"""
+ return self._texture
+
+ @property
+ def name(self):
+ """OpenGL name of the framebuffer"""
+ if self._name is not None:
+ return self._name
+ else:
+ raise RuntimeError("No OpenGL framebuffer resource, \
+ discard has already been called")
+
+ def bind(self):
+ """Bind this framebuffer for rendering"""
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.name)
+
+ # with statement
+
+ def __enter__(self):
+ self._previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ self.bind()
+
+ def __exit__(self, exctype, excvalue, traceback):
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self._previousFramebuffer)
+ self._previousFramebuffer = None
+
+ def discard(self):
+ """Delete associated OpenGL resources including texture"""
+ if self._name is not None:
+ gl.glDeleteFramebuffers(self._name)
+ self._name = None
+
+ if self._stencilId is not None:
+ gl.glDeleteRenderbuffers(self._stencilId)
+ if self._stencilId == self._depthId:
+ self._depthId = None
+ self._stencilId = None
+ if self._depthId is not None:
+ gl.glDeleteRenderbuffers(self._depthId)
+ self._depthId = None
+
+ self._texture.discard() # Also discard the texture
+ else:
+ _logger.warning("Discard has already been called")
diff --git a/src/silx/gui/_glutils/OpenGLWidget.py b/src/silx/gui/_glutils/OpenGLWidget.py
new file mode 100644
index 0000000..2ca4649
--- /dev/null
+++ b/src/silx/gui/_glutils/OpenGLWidget.py
@@ -0,0 +1,422 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a compatibility layer for OpenGL widget.
+
+It provides a compatibility layer for Qt OpenGL widget used in silx
+across Qt<=5.3 QtOpenGL.QGLWidget and QOpenGLWidget.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/11/2019"
+
+
+import logging
+import sys
+
+from .. import qt
+from ..utils.glutils import isOpenGLAvailable
+from .._glutils import gl
+
+
+_logger = logging.getLogger(__name__)
+
+
+if not hasattr(qt, 'QOpenGLWidget') and not hasattr(qt, 'QGLWidget'):
+ OpenGLWidget = None
+
+else:
+ if hasattr(qt, 'QOpenGLWidget'): # PyQt>=5.4
+ _logger.info('Using QOpenGLWidget')
+ _BaseOpenGLWidget = qt.QOpenGLWidget
+
+ else:
+ _logger.info('Using QGLWidget')
+ _BaseOpenGLWidget = qt.QGLWidget
+
+ class _OpenGLWidget(_BaseOpenGLWidget):
+ """Wrapper over QOpenGLWidget and QGLWidget"""
+
+ sigOpenGLContextError = qt.Signal(str)
+ """Signal emitted when an OpenGL context error is detected at runtime.
+
+ It provides the error reason as a str.
+ """
+
+ def __init__(self, parent,
+ alphaBufferSize=0,
+ depthBufferSize=24,
+ stencilBufferSize=8,
+ version=(2, 0),
+ f=qt.Qt.WindowFlags()):
+ # True if using QGLWidget, False if using QOpenGLWidget
+ self.__legacy = not hasattr(qt, 'QOpenGLWidget')
+
+ self.__devicePixelRatio = 1.0
+ self.__requestedOpenGLVersion = int(version[0]), int(version[1])
+ self.__isValid = False
+
+ if self.__legacy: # QGLWidget
+ format_ = qt.QGLFormat()
+ format_.setAlphaBufferSize(alphaBufferSize)
+ format_.setAlpha(alphaBufferSize != 0)
+ format_.setDepthBufferSize(depthBufferSize)
+ format_.setDepth(depthBufferSize != 0)
+ format_.setStencilBufferSize(stencilBufferSize)
+ format_.setStencil(stencilBufferSize != 0)
+ format_.setVersion(*self.__requestedOpenGLVersion)
+ format_.setDoubleBuffer(True)
+
+ super(_OpenGLWidget, self).__init__(format_, parent, None, f)
+
+ else: # QOpenGLWidget
+ super(_OpenGLWidget, self).__init__(parent, f)
+
+ format_ = qt.QSurfaceFormat()
+ format_.setAlphaBufferSize(alphaBufferSize)
+ format_.setDepthBufferSize(depthBufferSize)
+ format_.setStencilBufferSize(stencilBufferSize)
+ format_.setVersion(*self.__requestedOpenGLVersion)
+ format_.setSwapBehavior(qt.QSurfaceFormat.DoubleBuffer)
+ self.setFormat(format_)
+
+ # Enable receiving mouse move events when no buttons are pressed
+ self.setMouseTracking(True)
+
+ def getDevicePixelRatio(self):
+ """Returns the ratio device-independent / device pixel size
+
+ It should be either 1.0 or 2.0.
+
+ :return: Scale factor between screen and Qt units
+ :rtype: float
+ """
+ return self.__devicePixelRatio
+
+ def getRequestedOpenGLVersion(self):
+ """Returns the requested OpenGL version.
+
+ :return: (major, minor)
+ :rtype: 2-tuple of int"""
+ return self.__requestedOpenGLVersion
+
+ def getOpenGLVersion(self):
+ """Returns the available OpenGL version.
+
+ :return: (major, minor)
+ :rtype: 2-tuple of int"""
+ if self.__legacy: # QGLWidget
+ supportedVersion = 0, 0
+
+ # Go through all OpenGL version flags checking support
+ flags = self.format().openGLVersionFlags()
+ for version in ((1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
+ (2, 0), (2, 1),
+ (3, 0), (3, 1), (3, 2), (3, 3),
+ (4, 0)):
+ versionFlag = getattr(qt.QGLFormat,
+ 'OpenGL_Version_%d_%d' % version)
+ if not versionFlag & flags:
+ break
+ supportedVersion = version
+ return supportedVersion
+
+ else: # QOpenGLWidget
+ return self.format().version()
+
+ # QOpenGLWidget methods
+
+ def isValid(self):
+ """Returns True if OpenGL is available.
+
+ This adds extra checks to Qt isValid method.
+
+ :rtype: bool
+ """
+ return self.__isValid and super(_OpenGLWidget, self).isValid()
+
+ def defaultFramebufferObject(self):
+ """Returns the framebuffer object handle.
+
+ See :meth:`QOpenGLWidget.defaultFramebufferObject`
+ """
+ if self.__legacy: # QGLWidget
+ return 0
+ else: # QOpenGLWidget
+ return super(_OpenGLWidget, self).defaultFramebufferObject()
+
+ # *GL overridden methods
+
+ def initializeGL(self):
+ parent = self.parent()
+ if parent is None:
+ _logger.error('_OpenGLWidget has no parent')
+ return
+
+ # Check OpenGL version
+ if self.getOpenGLVersion() >= self.getRequestedOpenGLVersion():
+ try:
+ gl.glGetError() # clear any previous error (if any)
+ version = gl.glGetString(gl.GL_VERSION)
+ except:
+ version = None
+
+ if version:
+ self.__isValid = True
+ else:
+ errMsg = 'OpenGL not available'
+ if sys.platform.startswith('linux'):
+ errMsg += ': If connected remotely, ' \
+ 'GLX forwarding might be disabled.'
+ _logger.error(errMsg)
+ self.sigOpenGLContextError.emit(errMsg)
+ self.__isValid = False
+
+ else:
+ errMsg = 'OpenGL %d.%d not available' % \
+ self.getRequestedOpenGLVersion()
+ _logger.error('OpenGL widget disabled: %s', errMsg)
+ self.sigOpenGLContextError.emit(errMsg)
+ self.__isValid = False
+
+ if self.isValid():
+ parent.initializeGL()
+
+ def paintGL(self):
+ parent = self.parent()
+ if parent is None:
+ _logger.error('_OpenGLWidget has no parent')
+ return
+
+ devicePixelRatio = self.window().windowHandle().devicePixelRatio()
+
+ if devicePixelRatio != self.getDevicePixelRatio():
+ # Update devicePixelRatio and call resizeOpenGL
+ # as resizeGL is not always called.
+ self.__devicePixelRatio = devicePixelRatio
+ self.makeCurrent()
+ parent.resizeGL(self.width(), self.height())
+
+ if self.isValid():
+ parent.paintGL()
+
+ def resizeGL(self, width, height):
+ parent = self.parent()
+ if parent is None:
+ _logger.error('_OpenGLWidget has no parent')
+ return
+
+ if self.isValid():
+ # Call parent resizeGL with device-independent pixel unit
+ # This works over both QGLWidget and QOpenGLWidget
+ parent.resizeGL(self.width(), self.height())
+
+
+class OpenGLWidget(qt.QWidget):
+ """OpenGL widget wrapper over QGLWidget and QOpenGLWidget
+
+ This wrapper API implements a subset of QOpenGLWidget API.
+ The constructor takes a different set of arguments.
+ Methods returning object like :meth:`context` returns either
+ QGL* or QOpenGL* objects.
+
+ :param parent: Parent widget see :class:`QWidget`
+ :param int alphaBufferSize:
+ Size in bits of the alpha channel (default: 0).
+ Set to 0 to disable alpha channel.
+ :param int depthBufferSize:
+ Size in bits of the depth buffer (default: 24).
+ Set to 0 to disable depth buffer.
+ :param int stencilBufferSize:
+ Size in bits of the stencil buffer (default: 8).
+ Set to 0 to disable stencil buffer
+ :param version: Requested OpenGL version (default: (2, 0)).
+ :type version: 2-tuple of int
+ :param f: see :class:`QWidget`
+ """
+
+ def __init__(self, parent=None,
+ alphaBufferSize=0,
+ depthBufferSize=24,
+ stencilBufferSize=8,
+ version=(2, 0),
+ f=qt.Qt.WindowFlags()):
+ super(OpenGLWidget, self).__init__(parent, f)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.setLayout(layout)
+
+ self.__context = None
+
+ _check = isOpenGLAvailable(version=version, runtimeCheck=False)
+ if _OpenGLWidget is None or not _check:
+ _logger.error('OpenGL-based widget disabled: %s', _check.error)
+ self.__openGLWidget = None
+ label = self._createErrorQLabel(_check.error)
+ self.layout().addWidget(label)
+
+ else:
+ self.__openGLWidget = _OpenGLWidget(
+ parent=self,
+ alphaBufferSize=alphaBufferSize,
+ depthBufferSize=depthBufferSize,
+ stencilBufferSize=stencilBufferSize,
+ version=version,
+ f=f)
+ # Async connection need, otherwise issue when hiding OpenGL
+ # widget while doing the rendering..
+ self.__openGLWidget.sigOpenGLContextError.connect(
+ self._handleOpenGLInitError, qt.Qt.QueuedConnection)
+ self.layout().addWidget(self.__openGLWidget)
+
+ @staticmethod
+ def _createErrorQLabel(error):
+ """Create QLabel displaying error message in place of OpenGL widget
+
+ :param str error: The error message to display"""
+ label = qt.QLabel()
+ label.setText('OpenGL-based widget disabled:\n%s' % error)
+ label.setAlignment(qt.Qt.AlignCenter)
+ label.setWordWrap(True)
+ return label
+
+ def _handleOpenGLInitError(self, error):
+ """Handle runtime errors in OpenGL widget"""
+ if self.__openGLWidget is not None:
+ self.__openGLWidget.setVisible(False)
+ self.__openGLWidget.setParent(None)
+ self.__openGLWidget = None
+
+ label = self._createErrorQLabel(error)
+ self.layout().addWidget(label)
+
+ # Additional API
+
+ def getDevicePixelRatio(self):
+ """Returns the ratio device-independent / device pixel size
+
+ It should be either 1.0 or 2.0.
+
+ :return: Scale factor between screen and Qt units
+ :rtype: float
+ """
+ if self.__openGLWidget is None:
+ return 1.
+ else:
+ return self.__openGLWidget.getDevicePixelRatio()
+
+ def getDotsPerInch(self):
+ """Returns current screen resolution as device pixels per inch.
+
+ :rtype: float
+ """
+ screen = self.window().windowHandle().screen()
+ if screen is not None:
+ # TODO check if this is correct on different OS/screen
+ # OK on macOS10.12/qt5.13.2
+ dpi = screen.physicalDotsPerInch() * self.getDevicePixelRatio()
+ else: # Fallback
+ dpi = 96. * self.getDevicePixelRatio()
+ return dpi
+
+ def getOpenGLVersion(self):
+ """Returns the available OpenGL version.
+
+ :return: (major, minor)
+ :rtype: 2-tuple of int"""
+ if self.__openGLWidget is None:
+ return 0, 0
+ else:
+ return self.__openGLWidget.getOpenGLVersion()
+
+ # QOpenGLWidget API
+
+ def isValid(self):
+ """Returns True if OpenGL with the requested version is available.
+
+ :rtype: bool
+ """
+ if self.__openGLWidget is None:
+ return False
+ else:
+ return self.__openGLWidget.isValid()
+
+ def context(self):
+ """Return Qt OpenGL context object or None.
+
+ See :meth:`QOpenGLWidget.context` and :meth:`QGLWidget.context`
+ """
+ if self.__openGLWidget is None:
+ return None
+ else:
+ # Keep a reference on QOpenGLContext to make
+ # else PyQt5 keeps creating a new one.
+ self.__context = self.__openGLWidget.context()
+ return self.__context
+
+ def defaultFramebufferObject(self):
+ """Returns the framebuffer object handle.
+
+ See :meth:`QOpenGLWidget.defaultFramebufferObject`
+ """
+ if self.__openGLWidget is None:
+ return 0
+ else:
+ return self.__openGLWidget.defaultFramebufferObject()
+
+ def makeCurrent(self):
+ """Make the underlying OpenGL widget's context current.
+
+ See :meth:`QOpenGLWidget.makeCurrent`
+ """
+ if self.__openGLWidget is not None:
+ self.__openGLWidget.makeCurrent()
+
+ def update(self):
+ """Async update of the OpenGL widget.
+
+ See :meth:`QOpenGLWidget.update`
+ """
+ if self.__openGLWidget is not None:
+ self.__openGLWidget.update()
+
+ # QOpenGLWidget API to override
+
+ def initializeGL(self):
+ """Override to implement OpenGL initialization."""
+ pass
+
+ def paintGL(self):
+ """Override to implement OpenGL rendering."""
+ pass
+
+ def resizeGL(self, width, height):
+ """Override to implement resize of OpenGL framebuffer.
+
+ :param int width: Width in device-independent pixels
+ :param int height: Height in device-independent pixels
+ """
+ pass
diff --git a/silx/gui/_glutils/Program.py b/src/silx/gui/_glutils/Program.py
index 87eec5f..87eec5f 100644
--- a/silx/gui/_glutils/Program.py
+++ b/src/silx/gui/_glutils/Program.py
diff --git a/silx/gui/_glutils/Texture.py b/src/silx/gui/_glutils/Texture.py
index c72135a..c72135a 100644
--- a/silx/gui/_glutils/Texture.py
+++ b/src/silx/gui/_glutils/Texture.py
diff --git a/silx/gui/_glutils/VertexBuffer.py b/src/silx/gui/_glutils/VertexBuffer.py
index b74b748..b74b748 100644
--- a/silx/gui/_glutils/VertexBuffer.py
+++ b/src/silx/gui/_glutils/VertexBuffer.py
diff --git a/silx/gui/_glutils/__init__.py b/src/silx/gui/_glutils/__init__.py
index e88affd..e88affd 100644
--- a/silx/gui/_glutils/__init__.py
+++ b/src/silx/gui/_glutils/__init__.py
diff --git a/src/silx/gui/_glutils/font.py b/src/silx/gui/_glutils/font.py
new file mode 100644
index 0000000..3ea474d
--- /dev/null
+++ b/src/silx/gui/_glutils/font.py
@@ -0,0 +1,156 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Text rasterisation feature leveraging Qt font and text layout support."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+
+import logging
+import numpy
+
+from ..utils.image import convertQImageToArray
+from .. import qt
+
+_logger = logging.getLogger(__name__)
+
+
+def getDefaultFontFamily():
+ """Returns the default font family of the application"""
+ return qt.QApplication.instance().font().family()
+
+
+# Font weights
+ULTRA_LIGHT = 0
+"""Lightest characters: Minimum font weight"""
+
+LIGHT = 25
+"""Light characters"""
+
+NORMAL = 50
+"""Normal characters"""
+
+SEMI_BOLD = 63
+"""Between normal and bold characters"""
+
+BOLD = 74
+"""Thicker characters"""
+
+BLACK = 87
+"""Really thick characters"""
+
+ULTRA_BLACK = 99
+"""Thickest characters: Maximum font weight"""
+
+
+def rasterText(text, font,
+ size=-1,
+ weight=-1,
+ italic=False,
+ devicePixelRatio=1.0):
+ """Raster text using Qt.
+
+ It supports multiple lines.
+
+ :param str text: The text to raster
+ :param font: Font name or QFont to use
+ :type font: str or :class:`QFont`
+ :param int size:
+ Font size in points
+ Used only if font is given as name.
+ :param int weight:
+ Font weight in [0, 99], see QFont.Weight.
+ Used only if font is given as name.
+ :param bool italic:
+ True for italic font (default: False).
+ Used only if font is given as name.
+ :param float devicePixelRatio:
+ The current ratio between device and device-independent pixel
+ (default: 1.0)
+ :return: Corresponding image in gray scale and baseline offset from top
+ :rtype: (HxW numpy.ndarray of uint8, int)
+ """
+ if not text:
+ _logger.info("Trying to raster empty text, replaced by white space")
+ text = ' ' # Replace empty text by white space to produce an image
+
+ if not isinstance(font, qt.QFont):
+ font = qt.QFont(font, size, weight, italic)
+
+ # get text size
+ image = qt.QImage(1, 1, qt.QImage.Format_RGB888)
+ painter = qt.QPainter()
+ painter.begin(image)
+ painter.setPen(qt.Qt.white)
+ painter.setFont(font)
+ bounds = painter.boundingRect(
+ qt.QRect(0, 0, 4096, 4096), qt.Qt.TextExpandTabs, text)
+ painter.end()
+
+ metrics = qt.QFontMetrics(font)
+
+ # This does not provide the correct text bbox on macOS
+ # size = metrics.size(qt.Qt.TextExpandTabs, text)
+ # bounds = metrics.boundingRect(
+ # qt.QRect(0, 0, size.width(), size.height()),
+ # qt.Qt.TextExpandTabs,
+ # text)
+
+ # Add extra border and handle devicePixelRatio
+ width = bounds.width() * devicePixelRatio + 2
+ # align line size to 32 bits to ease conversion to numpy array
+ width = 4 * ((width + 3) // 4)
+ image = qt.QImage(int(width),
+ int(bounds.height() * devicePixelRatio + 2),
+ qt.QImage.Format_RGB888)
+ image.setDevicePixelRatio(devicePixelRatio)
+
+ # TODO if Qt5 use Format_Grayscale8 instead
+ image.fill(0)
+
+ # Raster text
+ painter = qt.QPainter()
+ painter.begin(image)
+ painter.setPen(qt.Qt.white)
+ painter.setFont(font)
+ painter.drawText(bounds, qt.Qt.TextExpandTabs, text)
+ painter.end()
+
+ array = convertQImageToArray(image)
+
+ # RGB to R
+ array = numpy.ascontiguousarray(array[:, :, 0])
+
+ # Remove leading and trailing empty columns but one on each side
+ column_cumsum = numpy.cumsum(numpy.sum(array, axis=0))
+ array = array[:, column_cumsum.argmin():column_cumsum.argmax() + 2]
+
+ # Remove leading and trailing empty rows but one on each side
+ row_cumsum = numpy.cumsum(numpy.sum(array, axis=1))
+ min_row = row_cumsum.argmin()
+ array = array[min_row:row_cumsum.argmax() + 2, :]
+
+ return array, metrics.ascent() - min_row
diff --git a/silx/gui/_glutils/gl.py b/src/silx/gui/_glutils/gl.py
index 608d9ce..608d9ce 100644
--- a/silx/gui/_glutils/gl.py
+++ b/src/silx/gui/_glutils/gl.py
diff --git a/src/silx/gui/_glutils/utils.py b/src/silx/gui/_glutils/utils.py
new file mode 100644
index 0000000..5886599
--- /dev/null
+++ b/src/silx/gui/_glutils/utils.py
@@ -0,0 +1,123 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides conversion functions between OpenGL and numpy types.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+import numpy
+
+from OpenGL.constants import BYTE_SIZES as _BYTE_SIZES
+from OpenGL.constants import ARRAY_TO_GL_TYPE_MAPPING as _ARRAY_TO_GL_TYPE_MAPPING
+
+
+def sizeofGLType(type_):
+ """Returns the size in bytes of an element of type `type_`"""
+ return _BYTE_SIZES[type_]
+
+
+def isSupportedGLType(type_):
+ """Test if a numpy type or dtype can be converted to a GL type."""
+ return numpy.dtype(type_).char in _ARRAY_TO_GL_TYPE_MAPPING
+
+
+def numpyToGLType(type_):
+ """Returns the GL type corresponding the provided numpy type or dtype."""
+ return _ARRAY_TO_GL_TYPE_MAPPING[numpy.dtype(type_).char]
+
+
+def segmentTrianglesIntersection(segment, triangles):
+ """Check for segment/triangles intersection.
+
+ This is based on signed tetrahedron volume comparison.
+
+ See A. Kensler, A., Shirley, P.
+ Optimizing Ray-Triangle Intersection via Automated Search.
+ Symposium on Interactive Ray Tracing, vol. 0, p33-38 (2006)
+
+ :param numpy.ndarray segment:
+ Segment end points as a 2x3 array of coordinates
+ :param numpy.ndarray triangles:
+ Nx3x3 array of triangles
+ :return: (triangle indices, segment parameter, barycentric coord)
+ Indices of intersected triangles, "depth" along the segment
+ of the intersection point and barycentric coordinates of intersection
+ point in the triangle.
+ :rtype: List[numpy.ndarray]
+ """
+ # TODO triangles from vertices + indices
+ # TODO early rejection? e.g., check segment bbox vs triangle bbox
+ segment = numpy.asarray(segment)
+ assert segment.ndim == 2
+ assert segment.shape == (2, 3)
+
+ triangles = numpy.asarray(triangles)
+ assert triangles.ndim == 3
+ assert triangles.shape[1] == 3
+
+ # Test line/triangles intersection
+ d = segment[1] - segment[0]
+ t0s0 = segment[0] - triangles[:, 0, :]
+ edge01 = triangles[:, 1, :] - triangles[:, 0, :]
+ edge02 = triangles[:, 2, :] - triangles[:, 0, :]
+
+ dCrossEdge02 = numpy.cross(d, edge02)
+ t0s0CrossEdge01 = numpy.cross(t0s0, edge01)
+ volume = numpy.sum(dCrossEdge02 * edge01, axis=1)
+ del edge01
+ subVolumes = numpy.empty((len(triangles), 3), dtype=triangles.dtype)
+ subVolumes[:, 1] = numpy.sum(dCrossEdge02 * t0s0, axis=1)
+ del dCrossEdge02
+ subVolumes[:, 2] = numpy.sum(t0s0CrossEdge01 * d, axis=1)
+ subVolumes[:, 0] = volume - subVolumes[:, 1] - subVolumes[:, 2]
+ intersect = numpy.logical_or(
+ numpy.all(subVolumes >= 0., axis=1), # All positive
+ numpy.all(subVolumes <= 0., axis=1)) # All negative
+ intersect = numpy.where(intersect)[0] # Indices of intersected triangles
+
+ # Get barycentric coordinates
+ with numpy.errstate(invalid="ignore"):
+ barycentric = subVolumes[intersect] / volume[intersect].reshape(-1, 1)
+ del subVolumes
+
+ # Test segment/triangles intersection
+ volAlpha = numpy.sum(t0s0CrossEdge01[intersect] * edge02[intersect], axis=1)
+ with numpy.errstate(invalid="ignore"):
+ t = volAlpha / volume[intersect] # segment parameter of intersected triangles
+ del t0s0CrossEdge01
+ del edge02
+ del volAlpha
+ del volume
+
+ inSegmentMask = numpy.logical_and(t >= 0., t <= 1.)
+ intersect = intersect[inSegmentMask]
+ t = t[inSegmentMask]
+ barycentric = barycentric[inSegmentMask]
+
+ # Sort intersecting triangles by t
+ indices = numpy.argsort(t)
+ return intersect[indices], t[indices], barycentric[indices]
diff --git a/src/silx/gui/colors.py b/src/silx/gui/colors.py
new file mode 100755
index 0000000..12046cf
--- /dev/null
+++ b/src/silx/gui/colors.py
@@ -0,0 +1,1036 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides API to manage colors.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent", "H.Payno"]
+__license__ = "MIT"
+__date__ = "29/01/2019"
+
+import numpy
+import logging
+
+from silx.gui import qt
+from silx.gui.utils import blockSignals
+from silx.math import colormap as _colormap
+from silx.utils.exceptions import NotEditableError
+from silx.utils import deprecation
+
+
+_logger = logging.getLogger(__name__)
+
+try:
+ import silx.gui.utils.matplotlib # noqa Initalize matplotlib
+ from matplotlib import cm as _matplotlib_cm
+ from matplotlib.pyplot import colormaps as _matplotlib_colormaps
+except ImportError:
+ _logger.info("matplotlib not available, only embedded colormaps available")
+ _matplotlib_cm = None
+ _matplotlib_colormaps = None
+
+
+_COLORDICT = {}
+"""Dictionary of common colors."""
+
+_COLORDICT['b'] = _COLORDICT['blue'] = '#0000ff'
+_COLORDICT['r'] = _COLORDICT['red'] = '#ff0000'
+_COLORDICT['g'] = _COLORDICT['green'] = '#00ff00'
+_COLORDICT['k'] = _COLORDICT['black'] = '#000000'
+_COLORDICT['w'] = _COLORDICT['white'] = '#ffffff'
+_COLORDICT['pink'] = '#ff66ff'
+_COLORDICT['brown'] = '#a52a2a'
+_COLORDICT['orange'] = '#ff9900'
+_COLORDICT['violet'] = '#6600ff'
+_COLORDICT['gray'] = _COLORDICT['grey'] = '#a0a0a4'
+# _COLORDICT['darkGray'] = _COLORDICT['darkGrey'] = '#808080'
+# _COLORDICT['lightGray'] = _COLORDICT['lightGrey'] = '#c0c0c0'
+_COLORDICT['y'] = _COLORDICT['yellow'] = '#ffff00'
+_COLORDICT['m'] = _COLORDICT['magenta'] = '#ff00ff'
+_COLORDICT['c'] = _COLORDICT['cyan'] = '#00ffff'
+_COLORDICT['darkBlue'] = '#000080'
+_COLORDICT['darkRed'] = '#800000'
+_COLORDICT['darkGreen'] = '#008000'
+_COLORDICT['darkBrown'] = '#660000'
+_COLORDICT['darkCyan'] = '#008080'
+_COLORDICT['darkYellow'] = '#808000'
+_COLORDICT['darkMagenta'] = '#800080'
+_COLORDICT['transparent'] = '#00000000'
+
+
+# FIXME: It could be nice to expose a functional API instead of that attribute
+COLORDICT = _COLORDICT
+
+
+DEFAULT_MIN_LIN = 0
+"""Default min value if in linear normalization"""
+DEFAULT_MAX_LIN = 1
+"""Default max value if in linear normalization"""
+
+
+def rgba(color, colorDict=None):
+ """Convert color code '#RRGGBB' and '#RRGGBBAA' to a tuple (R, G, B, A)
+ of floats.
+
+ It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
+ QColor as color argument.
+
+ :param str color: The color to convert
+ :param dict colorDict: A dictionary of color name conversion to color code
+ :returns: RGBA colors as floats in [0., 1.]
+ :rtype: tuple
+ """
+ if colorDict is None:
+ colorDict = _COLORDICT
+
+ if hasattr(color, 'getRgb'): # QColor support
+ color = color.getRgb()
+
+ values = numpy.asarray(color).ravel()
+
+ if values.dtype.kind in 'iuf': # integer or float
+ # Color is an array
+ assert len(values) in (3, 4)
+
+ # Convert from integers in [0, 255] to float in [0, 1]
+ if values.dtype.kind in 'iu':
+ values = values / 255.
+
+ # Clip to [0, 1]
+ values[values < 0.] = 0.
+ values[values > 1.] = 1.
+
+ if len(values) == 3:
+ return values[0], values[1], values[2], 1.
+ else:
+ return tuple(values)
+
+ # We assume color is a string
+ if not color.startswith('#'):
+ color = colorDict[color]
+
+ assert len(color) in (7, 9) and color[0] == '#'
+ r = int(color[1:3], 16) / 255.
+ g = int(color[3:5], 16) / 255.
+ b = int(color[5:7], 16) / 255.
+ a = int(color[7:9], 16) / 255. if len(color) == 9 else 1.
+ return r, g, b, a
+
+
+def greyed(color, colorDict=None):
+ """Convert color code '#RRGGBB' and '#RRGGBBAA' to a grey color
+ (R, G, B, A).
+
+ It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
+ QColor as color argument.
+
+ :param str color: The color to convert
+ :param dict colorDict: A dictionary of color name conversion to color code
+ :returns: RGBA colors as floats in [0., 1.]
+ :rtype: tuple
+ """
+ r, g, b, a = rgba(color=color, colorDict=colorDict)
+ g = 0.21 * r + 0.72 * g + 0.07 * b
+ return g, g, g, a
+
+
+def asQColor(color):
+ """Convert color code '#RRGGBB' and '#RRGGBBAA' to a `qt.QColor`.
+
+ It also supports RGB(A) from uint8 in [0, 255], float in [0, 1], and
+ QColor as color argument.
+
+ :param str color: The color to convert
+ :rtype: qt.QColor
+ """
+ color = rgba(color)
+ return qt.QColor.fromRgbF(*color)
+
+
+def cursorColorForColormap(colormapName):
+ """Get a color suitable for overlay over a colormap.
+
+ :param str colormapName: The name of the colormap.
+ :return: Name of the color.
+ :rtype: str
+ """
+ return _colormap.get_colormap_cursor_color(colormapName)
+
+
+# Colormap loader
+
+def _registerColormapFromMatplotlib(name, cursor_color='black', preferred=False):
+ colormap = _matplotlib_cm.get_cmap(name)
+ lut = colormap(numpy.linspace(0, 1, colormap.N, endpoint=True))
+ colors = _colormap.array_to_rgba8888(lut)
+ registerLUT(name, colors, cursor_color, preferred)
+
+
+def _getColormap(name):
+ """Returns the color LUT corresponding to a colormap name
+ :param str name: Name of the colormap to load
+ :returns: Corresponding table of colors
+ :rtype: numpy.ndarray
+ :raise ValueError: If no colormap corresponds to name
+ """
+ name = str(name)
+ try:
+ return _colormap.get_colormap_lut(name)
+ except ValueError:
+ # Colormap is not available, try to load it from matplotlib
+ _registerColormapFromMatplotlib(name, 'black', False)
+ return _colormap.get_colormap_lut(name)
+
+
+class Colormap(qt.QObject):
+ """Description of a colormap
+
+ If no `name` nor `colors` are provided, a default gray LUT is used.
+
+ :param str name: Name of the colormap
+ :param tuple colors: optional, custom colormap.
+ Nx3 or Nx4 numpy array of RGB(A) colors,
+ either uint8 or float in [0, 1].
+ If 'name' is None, then this array is used as the colormap.
+ :param str normalization: Normalization: 'linear' (default) or 'log'
+ :param vmin: Lower bound of the colormap or None for autoscale (default)
+ :type vmin: Union[None, float]
+ :param vmax: Upper bounds of the colormap or None for autoscale (default)
+ :type vmax: Union[None, float]
+ """
+
+ LINEAR = 'linear'
+ """constant for linear normalization"""
+
+ LOGARITHM = 'log'
+ """constant for logarithmic normalization"""
+
+ SQRT = 'sqrt'
+ """constant for square root normalization"""
+
+ GAMMA = 'gamma'
+ """Constant for gamma correction normalization"""
+
+ ARCSINH = 'arcsinh'
+ """constant for inverse hyperbolic sine normalization"""
+
+ _BASIC_NORMALIZATIONS = {
+ LINEAR: _colormap.LinearNormalization(),
+ LOGARITHM: _colormap.LogarithmicNormalization(),
+ SQRT: _colormap.SqrtNormalization(),
+ ARCSINH: _colormap.ArcsinhNormalization(),
+ }
+ """Normalizations without parameters"""
+
+ NORMALIZATIONS = LINEAR, LOGARITHM, SQRT, GAMMA, ARCSINH
+ """Tuple of managed normalizations"""
+
+ MINMAX = 'minmax'
+ """constant for autoscale using min/max data range"""
+
+ STDDEV3 = 'stddev3'
+ """constant for autoscale using mean +/- 3*std(data)
+ with a clamp on min/max of the data"""
+
+ AUTOSCALE_MODES = (MINMAX, STDDEV3)
+ """Tuple of managed auto scale algorithms"""
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the colormap has changed."""
+
+ _DEFAULT_NAN_COLOR = 255, 255, 255, 0
+
+ def __init__(self, name=None, colors=None, normalization=LINEAR, vmin=None, vmax=None, autoscaleMode=MINMAX):
+ qt.QObject.__init__(self)
+ self._editable = True
+ self.__gamma = 2.0
+ # Default NaN color: fully transparent white
+ self.__nanColor = numpy.array(self._DEFAULT_NAN_COLOR, dtype=numpy.uint8)
+
+ assert normalization in Colormap.NORMALIZATIONS
+ assert autoscaleMode in Colormap.AUTOSCALE_MODES
+
+ if normalization is Colormap.LOGARITHM:
+ if (vmin is not None and vmin < 0) or (vmax is not None and vmax < 0):
+ m = "Unsuported vmin (%s) and/or vmax (%s) given for a log scale."
+ m += ' Autoscale will be performed.'
+ m = m % (vmin, vmax)
+ _logger.warning(m)
+ vmin = None
+ vmax = None
+
+ self._name = None
+ self._colors = None
+
+ if colors is not None and name is not None:
+ deprecation.deprecated_warning("Argument",
+ name="silx.gui.plot.Colors",
+ reason="name and colors can't be used at the same time",
+ since_version="0.10.0",
+ skip_backtrace_count=1)
+
+ colors = None
+
+ if name is not None:
+ self.setName(name) # And resets colormap LUT
+ elif colors is not None:
+ self.setColormapLUT(colors)
+ else:
+ # Default colormap is grey
+ self.setName("gray")
+
+ self._normalization = str(normalization)
+ self._autoscaleMode = str(autoscaleMode)
+ self._vmin = float(vmin) if vmin is not None else None
+ self._vmax = float(vmax) if vmax is not None else None
+ self.__warnBadVmin = True
+ self.__warnBadVmax = True
+
+ def setFromColormap(self, other):
+ """Set this colormap using information from the `other` colormap.
+
+ :param ~silx.gui.colors.Colormap other: Colormap to use as reference.
+ """
+ if not self.isEditable():
+ raise NotEditableError('Colormap is not editable')
+ if self == other:
+ return
+ with blockSignals(self):
+ name = other.getName()
+ if name is not None:
+ self.setName(name)
+ else:
+ self.setColormapLUT(other.getColormapLUT())
+ self.setNaNColor(other.getNaNColor())
+ self.setNormalization(other.getNormalization())
+ self.setGammaNormalizationParameter(
+ other.getGammaNormalizationParameter())
+ self.setAutoscaleMode(other.getAutoscaleMode())
+ self.setVRange(*other.getVRange())
+ self.setEditable(other.isEditable())
+ self.sigChanged.emit()
+
+ def getNColors(self, nbColors=None):
+ """Returns N colors computed by sampling the colormap regularly.
+
+ :param nbColors:
+ The number of colors in the returned array or None for the default value.
+ The default value is the size of the colormap LUT.
+ :type nbColors: int or None
+ :return: 2D array of uint8 of shape (nbColors, 4)
+ :rtype: numpy.ndarray
+ """
+ # Handle default value for nbColors
+ if nbColors is None:
+ return numpy.array(self._colors, copy=True)
+ else:
+ nbColors = int(nbColors)
+ colormap = self.copy()
+ colormap.setNormalization(Colormap.LINEAR)
+ colormap.setVRange(vmin=0, vmax=nbColors - 1)
+ colors = colormap.applyToData(
+ numpy.arange(nbColors, dtype=numpy.int32))
+ return colors
+
+ def getName(self):
+ """Return the name of the colormap
+ :rtype: str
+ """
+ return self._name
+
+ def setName(self, name):
+ """Set the name of the colormap to use.
+
+ :param str name: The name of the colormap.
+ At least the following names are supported: 'gray',
+ 'reversed gray', 'temperature', 'red', 'green', 'blue', 'jet',
+ 'viridis', 'magma', 'inferno', 'plasma'.
+ """
+ name = str(name)
+ if self._name == name:
+ return
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ if name not in self.getSupportedColormaps():
+ raise ValueError("Colormap name '%s' is not supported" % name)
+ self._name = name
+ self._colors = _getColormap(self._name)
+ self.sigChanged.emit()
+
+ def getColormapLUT(self, copy=True):
+ """Return the list of colors for the colormap or None if not set.
+
+ This returns None if the colormap was set with :meth:`setName`.
+ Use :meth:`getNColors` to get the colormap LUT for any colormap.
+
+ :param bool copy: If true a copy of the numpy array is provided
+ :return: the list of colors for the colormap or None if not set
+ :rtype: numpy.ndarray or None
+ """
+ if self._name is None:
+ return numpy.array(self._colors, copy=copy)
+ else:
+ return None
+
+ def setColormapLUT(self, colors):
+ """Set the colors of the colormap.
+
+ :param numpy.ndarray colors: the colors of the LUT.
+ If float, it is converted from [0, 1] to uint8 range.
+ Otherwise it is casted to uint8.
+
+ .. warning: this will set the value of name to None
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ assert colors is not None
+
+ colors = numpy.array(colors, copy=False)
+ if colors.shape == ():
+ raise TypeError("An array is expected for 'colors' argument. '%s' was found." % type(colors))
+ assert len(colors) != 0
+ assert colors.ndim >= 2
+ colors.shape = -1, colors.shape[-1]
+ self._colors = _colormap.array_to_rgba8888(colors)
+ self._name = None
+ self.sigChanged.emit()
+
+ def getNaNColor(self):
+ """Returns the color to use for Not-A-Number floating point value.
+
+ :rtype: QColor
+ """
+ return qt.QColor(*self.__nanColor)
+
+ def setNaNColor(self, color):
+ """Set the color to use for Not-A-Number floating point value.
+
+ :param color: RGB(A) color to use for NaN values
+ :type color: QColor, str, tuple of uint8 or float in [0., 1.]
+ """
+ color = (numpy.array(rgba(color)) * 255).astype(numpy.uint8)
+ if not numpy.array_equal(self.__nanColor, color):
+ self.__nanColor = color
+ self.sigChanged.emit()
+
+ def getNormalization(self):
+ """Return the normalization of the colormap.
+
+ See :meth:`setNormalization` for returned values.
+
+ :return: the normalization of the colormap
+ :rtype: str
+ """
+ return self._normalization
+
+ def setNormalization(self, norm):
+ """Set the colormap normalization.
+
+ Accepted normalizations: 'log', 'linear', 'sqrt'
+
+ :param str norm: the norm to set
+ """
+ assert norm in self.NORMALIZATIONS
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ norm = str(norm)
+ if norm != self._normalization:
+ self._normalization = norm
+ self.__warnBadVmin = True
+ self.__warnBadVmax = True
+ self.sigChanged.emit()
+
+ def setGammaNormalizationParameter(self, gamma: float) -> None:
+ """Set the gamma correction parameter.
+
+ Only used for gamma correction normalization.
+
+ :param float gamma:
+ :raise ValueError: If gamma is not valid
+ """
+ if gamma < 0. or not numpy.isfinite(gamma):
+ raise ValueError("Gamma value not supported")
+ if gamma != self.__gamma:
+ self.__gamma = gamma
+ self.sigChanged.emit()
+
+ def getGammaNormalizationParameter(self) -> float:
+ """Returns the gamma correction parameter value.
+
+ :rtype: float
+ """
+ return self.__gamma
+
+ def getAutoscaleMode(self):
+ """Return the autoscale mode of the colormap ('minmax' or 'stddev3')
+
+ :rtype: str
+ """
+ return self._autoscaleMode
+
+ def setAutoscaleMode(self, mode):
+ """Set the autoscale mode: either 'minmax' or 'stddev3'
+
+ :param str mode: the mode to set
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ assert mode in self.AUTOSCALE_MODES
+ if mode != self._autoscaleMode:
+ self._autoscaleMode = mode
+ self.sigChanged.emit()
+
+ def isAutoscale(self):
+ """Return True if both min and max are in autoscale mode"""
+ return self._vmin is None and self._vmax is None
+
+ def getVMin(self):
+ """Return the lower bound of the colormap
+
+ :return: the lower bound of the colormap
+ :rtype: float or None
+ """
+ return self._vmin
+
+ def setVMin(self, vmin):
+ """Set the minimal value of the colormap
+
+ :param float vmin: Lower bound of the colormap or None for autoscale
+ (default)
+ value)
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ if vmin is not None:
+ if self._vmax is not None and vmin > self._vmax:
+ err = "Can't set vmin because vmin >= vmax. " \
+ "vmin = %s, vmax = %s" % (vmin, self._vmax)
+ raise ValueError(err)
+
+ if vmin != self._vmin:
+ self._vmin = vmin
+ self.__warnBadVmin = True
+ self.sigChanged.emit()
+
+ def getVMax(self):
+ """Return the upper bounds of the colormap or None
+
+ :return: the upper bounds of the colormap or None
+ :rtype: float or None
+ """
+ return self._vmax
+
+ def setVMax(self, vmax):
+ """Set the maximal value of the colormap
+
+ :param float vmax: Upper bounds of the colormap or None for autoscale
+ (default)
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ if vmax is not None:
+ if self._vmin is not None and vmax < self._vmin:
+ err = "Can't set vmax because vmax <= vmin. " \
+ "vmin = %s, vmax = %s" % (self._vmin, vmax)
+ raise ValueError(err)
+
+ if vmax != self._vmax:
+ self._vmax = vmax
+ self.__warnBadVmax = True
+ self.sigChanged.emit()
+
+ def isEditable(self):
+ """ Return if the colormap is editable or not
+
+ :return: editable state of the colormap
+ :rtype: bool
+ """
+ return self._editable
+
+ def setEditable(self, editable):
+ """
+ Set the editable state of the colormap
+
+ :param bool editable: is the colormap editable
+ """
+ assert type(editable) is bool
+ self._editable = editable
+ self.sigChanged.emit()
+
+ def _getNormalizer(self):
+ """Returns normalizer object"""
+ normalization = self.getNormalization()
+ if normalization == self.GAMMA:
+ return _colormap.GammaNormalization(self.getGammaNormalizationParameter())
+ else:
+ return self._BASIC_NORMALIZATIONS[normalization]
+
+ def _computeAutoscaleRange(self, data):
+ """Compute the data range which will be used in autoscale mode.
+
+ :param numpy.ndarray data: The data for which to compute the range
+ :return: (vmin, vmax) range
+ """
+ return self._getNormalizer().autoscale(
+ data, mode=self.getAutoscaleMode())
+
+ def getColormapRange(self, data=None):
+ """Return (vmin, vmax) the range of the colormap for the given data or item.
+
+ :param Union[numpy.ndarray,~silx.gui.plot.items.ColormapMixIn] data:
+ The data or item to use for autoscale bounds.
+ :return: (vmin, vmax) corresponding to the colormap applied to data if provided.
+ :rtype: tuple
+ """
+ vmin = self._vmin
+ vmax = self._vmax
+ assert vmin is None or vmax is None or vmin <= vmax # TODO handle this in setters
+
+ normalizer = self._getNormalizer()
+
+ # Handle invalid bounds as autoscale
+ if vmin is not None and not normalizer.is_valid(vmin):
+ if self.__warnBadVmin:
+ self.__warnBadVmin = False
+ _logger.info(
+ 'Invalid vmin, switching to autoscale for lower bound')
+ vmin = None
+ if vmax is not None and not normalizer.is_valid(vmax):
+ if self.__warnBadVmax:
+ self.__warnBadVmax = False
+ _logger.info(
+ 'Invalid vmax, switching to autoscale for upper bound')
+ vmax = None
+
+ if vmin is None or vmax is None: # Handle autoscale
+ from .plot.items.core import ColormapMixIn # avoid cyclic import
+ if isinstance(data, ColormapMixIn):
+ min_, max_ = data._getColormapAutoscaleRange(self)
+ # Make sure min_, max_ are not None
+ min_ = normalizer.DEFAULT_RANGE[0] if min_ is None else min_
+ max_ = normalizer.DEFAULT_RANGE[1] if max_ is None else max_
+ else:
+ min_, max_ = normalizer.autoscale(
+ data, mode=self.getAutoscaleMode())
+
+ if vmin is None: # Set vmin respecting provided vmax
+ vmin = min_ if vmax is None else min(min_, vmax)
+
+ if vmax is None:
+ vmax = max(max_, vmin) # Handle max_ <= 0 for log scale
+
+ return vmin, vmax
+
+ def getVRange(self):
+ """Get the bounds of the colormap
+
+ :rtype: Tuple(Union[float,None],Union[float,None])
+ :returns: A tuple of 2 values for min and max. Or None instead of float
+ for autoscale
+ """
+ return self.getVMin(), self.getVMax()
+
+ def setVRange(self, vmin, vmax):
+ """Set the bounds of the colormap
+
+ :param vmin: Lower bound of the colormap or None for autoscale
+ (default)
+ :param vmax: Upper bounds of the colormap or None for autoscale
+ (default)
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ if vmin is not None and vmax is not None:
+ if vmin > vmax:
+ err = "Can't set vmin and vmax because vmin >= vmax " \
+ "vmin = %s, vmax = %s" % (vmin, vmax)
+ raise ValueError(err)
+
+ if self._vmin == vmin and self._vmax == vmax:
+ return
+
+ if vmin != self._vmin:
+ self.__warnBadVmin = True
+ self._vmin = vmin
+ if vmax != self._vmax:
+ self.__warnBadVmax = True
+ self._vmax = vmax
+ self.sigChanged.emit()
+
+ def __getitem__(self, item):
+ if item == 'autoscale':
+ return self.isAutoscale()
+ elif item == 'name':
+ return self.getName()
+ elif item == 'normalization':
+ return self.getNormalization()
+ elif item == 'vmin':
+ return self.getVMin()
+ elif item == 'vmax':
+ return self.getVMax()
+ elif item == 'colors':
+ return self.getColormapLUT()
+ elif item == 'autoscaleMode':
+ return self.getAutoscaleMode()
+ else:
+ raise KeyError(item)
+
+ def _toDict(self):
+ """Return the equivalent colormap as a dictionary
+ (old colormap representation)
+
+ :return: the representation of the Colormap as a dictionary
+ :rtype: dict
+ """
+ return {
+ 'name': self._name,
+ 'colors': self.getColormapLUT(),
+ 'vmin': self._vmin,
+ 'vmax': self._vmax,
+ 'autoscale': self.isAutoscale(),
+ 'normalization': self.getNormalization(),
+ 'autoscaleMode': self.getAutoscaleMode(),
+ }
+
+ def _setFromDict(self, dic):
+ """Set values to the colormap from a dictionary
+
+ :param dict dic: the colormap as a dictionary
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ name = dic['name'] if 'name' in dic else None
+ colors = dic['colors'] if 'colors' in dic else None
+ if name is not None and colors is not None:
+ if isinstance(colors, int):
+ # Filter out argument which was supported but never used
+ _logger.info("Unused 'colors' from colormap dictionary filterer.")
+ colors = None
+ vmin = dic['vmin'] if 'vmin' in dic else None
+ vmax = dic['vmax'] if 'vmax' in dic else None
+ if 'normalization' in dic:
+ normalization = dic['normalization']
+ else:
+ warn = 'Normalization not given in the dictionary, '
+ warn += 'set by default to ' + Colormap.LINEAR
+ _logger.warning(warn)
+ normalization = Colormap.LINEAR
+
+ if name is None and colors is None:
+ err = 'The colormap should have a name defined or a tuple of colors'
+ raise ValueError(err)
+ if normalization not in Colormap.NORMALIZATIONS:
+ err = 'Given normalization is not recognized (%s)' % normalization
+ raise ValueError(err)
+
+ autoscaleMode = dic.get('autoscaleMode', Colormap.MINMAX)
+ if autoscaleMode not in Colormap.AUTOSCALE_MODES:
+ err = 'Given autoscale mode is not recognized (%s)' % autoscaleMode
+ raise ValueError(err)
+
+ # If autoscale, then set boundaries to None
+ if dic.get('autoscale', False):
+ vmin, vmax = None, None
+
+ if name is not None:
+ self.setName(name)
+ else:
+ self.setColormapLUT(colors)
+ self._vmin = vmin
+ self._vmax = vmax
+ self._autoscale = True if (vmin is None and vmax is None) else False
+ self._normalization = normalization
+ self._autoscaleMode = autoscaleMode
+
+ self.__warnBadVmin = True
+ self.__warnBadVmax = True
+ self.sigChanged.emit()
+
+ @staticmethod
+ def _fromDict(dic):
+ colormap = Colormap()
+ colormap._setFromDict(dic)
+ return colormap
+
+ def copy(self):
+ """Return a copy of the Colormap.
+
+ :rtype: silx.gui.colors.Colormap
+ """
+ colormap = Colormap(name=self._name,
+ colors=self.getColormapLUT(),
+ vmin=self._vmin,
+ vmax=self._vmax,
+ normalization=self.getNormalization(),
+ autoscaleMode=self.getAutoscaleMode())
+ colormap.setNaNColor(self.getNaNColor())
+ colormap.setGammaNormalizationParameter(
+ self.getGammaNormalizationParameter())
+ colormap.setEditable(self.isEditable())
+ return colormap
+
+ def applyToData(self, data, reference=None):
+ """Apply the colormap to the data
+
+ :param Union[numpy.ndarray,~silx.gui.plot.item.ColormapMixIn] data:
+ The data to convert or the item for which to apply the colormap.
+ :param Union[numpy.ndarray,~silx.gui.plot.item.ColormapMixIn,None] reference:
+ The data or item to use as reference to compute autoscale
+ """
+ if reference is None:
+ reference = data
+ vmin, vmax = self.getColormapRange(reference)
+
+ if hasattr(data, "getColormappedData"): # Use item's data
+ data = data.getColormappedData(copy=False)
+
+ return _colormap.cmap(
+ data,
+ self._colors,
+ vmin,
+ vmax,
+ self._getNormalizer(),
+ self.__nanColor)
+
+ @staticmethod
+ def getSupportedColormaps():
+ """Get the supported colormap names as a tuple of str.
+
+ The list should at least contain and start by:
+
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue',
+ 'viridis', 'magma', 'inferno', 'plasma')
+
+ :rtype: tuple
+ """
+ registered_colormaps = _colormap.get_registered_colormaps()
+ colormaps = set(registered_colormaps)
+ if _matplotlib_colormaps is not None:
+ colormaps.update(_matplotlib_colormaps())
+
+ # Put registered_colormaps first
+ colormaps = tuple(cmap for cmap in sorted(colormaps)
+ if cmap not in registered_colormaps)
+ return registered_colormaps + colormaps
+
+ def __str__(self):
+ return str(self._toDict())
+
+ def __eq__(self, other):
+ """Compare colormap values and not pointers"""
+ if other is None:
+ return False
+ if not isinstance(other, Colormap):
+ return False
+ if self.getNormalization() != other.getNormalization():
+ return False
+ if self.getNormalization() == self.GAMMA:
+ delta = self.getGammaNormalizationParameter() - other.getGammaNormalizationParameter()
+ if abs(delta) > 0.001:
+ return False
+ return (self.getName() == other.getName() and
+ self.getAutoscaleMode() == other.getAutoscaleMode() and
+ self.getVMin() == other.getVMin() and
+ self.getVMax() == other.getVMax() and
+ numpy.array_equal(self.getColormapLUT(), other.getColormapLUT())
+ )
+
+ _SERIAL_VERSION = 3
+
+ def restoreState(self, byteArray):
+ """
+ Read the colormap state from a QByteArray.
+
+ :param qt.QByteArray byteArray: Stream containing the state
+ :return: True if the restoration sussseed
+ :rtype: bool
+ """
+ if self.isEditable() is False:
+ raise NotEditableError('Colormap is not editable')
+ stream = qt.QDataStream(byteArray, qt.QIODevice.ReadOnly)
+
+ className = stream.readQString()
+ if className != self.__class__.__name__:
+ _logger.warning("Classname mismatch. Found %s." % className)
+ return False
+
+ version = stream.readUInt32()
+ if version not in numpy.arange(1, self._SERIAL_VERSION+1):
+ _logger.warning("Serial version mismatch. Found %d." % version)
+ return False
+
+ name = stream.readQString()
+ isNull = stream.readBool()
+ if not isNull:
+ vmin = stream.readQVariant()
+ else:
+ vmin = None
+ isNull = stream.readBool()
+ if not isNull:
+ vmax = stream.readQVariant()
+ else:
+ vmax = None
+
+ normalization = stream.readQString()
+ if normalization == Colormap.GAMMA:
+ gamma = stream.readFloat()
+ else:
+ gamma = None
+
+ if version == 1:
+ autoscaleMode = Colormap.MINMAX
+ else:
+ autoscaleMode = stream.readQString()
+
+ if version <= 2:
+ nanColor = self._DEFAULT_NAN_COLOR
+ else:
+ nanColor = stream.readInt32(), stream.readInt32(), stream.readInt32(), stream.readInt32()
+
+ # emit change event only once
+ old = self.blockSignals(True)
+ try:
+ self.setName(name)
+ self.setNormalization(normalization)
+ self.setAutoscaleMode(autoscaleMode)
+ self.setVRange(vmin, vmax)
+ if gamma is not None:
+ self.setGammaNormalizationParameter(gamma)
+ self.setNaNColor(nanColor)
+ finally:
+ self.blockSignals(old)
+ self.sigChanged.emit()
+ return True
+
+ def saveState(self):
+ """
+ Save state of the colomap into a QDataStream.
+
+ :rtype: qt.QByteArray
+ """
+ data = qt.QByteArray()
+ stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
+
+ stream.writeQString(self.__class__.__name__)
+ stream.writeUInt32(self._SERIAL_VERSION)
+ stream.writeQString(self.getName())
+ stream.writeBool(self.getVMin() is None)
+ if self.getVMin() is not None:
+ stream.writeQVariant(self.getVMin())
+ stream.writeBool(self.getVMax() is None)
+ if self.getVMax() is not None:
+ stream.writeQVariant(self.getVMax())
+ stream.writeQString(self.getNormalization())
+ if self.getNormalization() == Colormap.GAMMA:
+ stream.writeFloat(self.getGammaNormalizationParameter())
+ stream.writeQString(self.getAutoscaleMode())
+ nanColor = self.getNaNColor()
+ stream.writeInt32(nanColor.red())
+ stream.writeInt32(nanColor.green())
+ stream.writeInt32(nanColor.blue())
+ stream.writeInt32(nanColor.alpha())
+
+ return data
+
+
+_PREFERRED_COLORMAPS = None
+"""
+Tuple of preferred colormap names accessed with :meth:`preferredColormaps`.
+"""
+
+_DEFAULT_PREFERRED_COLORMAPS = (
+ 'gray', 'reversed gray', 'red', 'green', 'blue',
+ 'viridis', 'cividis', 'magma', 'inferno', 'plasma',
+ 'temperature',
+ 'jet', 'hsv'
+)
+
+
+def preferredColormaps():
+ """Returns the name of the preferred colormaps.
+
+ This list is used by widgets allowing to change the colormap
+ like the :class:`ColormapDialog` as a subset of colormap choices.
+
+ :rtype: tuple of str
+ """
+ global _PREFERRED_COLORMAPS
+ if _PREFERRED_COLORMAPS is None:
+ # Initialize preferred colormaps
+ setPreferredColormaps(_DEFAULT_PREFERRED_COLORMAPS)
+ return tuple(_PREFERRED_COLORMAPS)
+
+
+def setPreferredColormaps(colormaps):
+ """Set the list of preferred colormap names.
+
+ Warning: If a colormap name is not available
+ it will be removed from the list.
+
+ :param colormaps: Not empty list of colormap names
+ :type colormaps: iterable of str
+ :raise ValueError: if the list of available preferred colormaps is empty.
+ """
+ supportedColormaps = Colormap.getSupportedColormaps()
+ colormaps = [cmap for cmap in colormaps if cmap in supportedColormaps]
+ if len(colormaps) == 0:
+ raise ValueError("Cannot set preferred colormaps to an empty list")
+
+ global _PREFERRED_COLORMAPS
+ _PREFERRED_COLORMAPS = colormaps
+
+
+def registerLUT(name, colors, cursor_color='black', preferred=True):
+ """Register a custom LUT to be used with `Colormap` objects.
+
+ It can override existing LUT names.
+
+ :param str name: Name of the LUT as defined to configure colormaps
+ :param numpy.ndarray colors: The custom LUT to register.
+ Nx3 or Nx4 numpy array of RGB(A) colors,
+ either uint8 or float in [0, 1].
+ :param bool preferred: If true, this LUT will be displayed as part of the
+ preferred colormaps in dialogs.
+ :param str cursor_color: Color used to display overlay over images using
+ colormap with this LUT.
+ """
+ _colormap.register_colormap(name, colors, cursor_color)
+
+ if preferred:
+ # Invalidate the preferred cache
+ global _PREFERRED_COLORMAPS
+ if _PREFERRED_COLORMAPS is not None:
+ if name not in _PREFERRED_COLORMAPS:
+ _PREFERRED_COLORMAPS.append(name)
+ else:
+ # The cache is not yet loaded, it's fine
+ pass
+
+
+# Load some colormaps from matplotlib by default
+if _matplotlib_cm is not None:
+ _registerColormapFromMatplotlib('jet', cursor_color='pink', preferred=True)
+ _registerColormapFromMatplotlib('hsv', cursor_color='black', preferred=True)
diff --git a/src/silx/gui/conftest.py b/src/silx/gui/conftest.py
new file mode 100644
index 0000000..74b5c19
--- /dev/null
+++ b/src/silx/gui/conftest.py
@@ -0,0 +1,5 @@
+import pytest
+
+@pytest.fixture(autouse=True)
+def auto_qapp(qapp):
+ pass
diff --git a/src/silx/gui/console.py b/src/silx/gui/console.py
new file mode 100644
index 0000000..953b6a1
--- /dev/null
+++ b/src/silx/gui/console.py
@@ -0,0 +1,202 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides an IPython console widget.
+
+You can push variables - any python object - to the
+console's interactive namespace. This provides users with an advanced way
+of interacting with your program. For instance, if your program has a
+:class:`PlotWidget` or a :class:`PlotWindow`, you can push a reference to
+these widgets to allow your users to add curves, save data to files… by using
+the widgets' methods from the console.
+
+.. note::
+
+ This module has a dependency on
+ `qtconsole <https://pypi.org/project/qtconsole/>`_.
+ An ``ImportError`` will be raised if it is
+ imported while the dependencies are not satisfied.
+
+Basic usage example::
+
+ from silx.gui import qt
+ from silx.gui.console import IPythonWidget
+
+ app = qt.QApplication([])
+
+ hello_button = qt.QPushButton("Hello World!", None)
+ hello_button.show()
+
+ console = IPythonWidget()
+ console.show()
+ console.pushVariables({"the_button": hello_button})
+
+ app.exec()
+
+This program will display a console widget and a push button in two separate
+windows. You will be able to interact with the button from the console,
+for example change its text::
+
+ >>> the_button.setText("Spam spam")
+
+An IPython interactive console is a powerful tool that enables you to work
+with data and plot it.
+See `this tutorial <https://plot.ly/python/ipython-notebook-tutorial/>`_
+for more information on some of the rich features of IPython.
+"""
+__authors__ = ["Tim Rae", "V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/05/2016"
+
+import logging
+
+from . import qt
+
+_logger = logging.getLogger(__name__)
+
+
+# This widget cannot be used inside an interactive IPython shell.
+# It would raise MultipleInstanceError("Multiple incompatible subclass
+# instances of InProcessInteractiveShell are being created").
+try:
+ __IPYTHON__
+except NameError:
+ pass # Not in IPython
+else:
+ msg = "Module " + __name__ + " cannot be used within an IPython shell"
+ raise ImportError(msg)
+
+try:
+ from qtconsole.rich_jupyter_widget import RichJupyterWidget as \
+ _RichJupyterWidget
+except ImportError:
+ try:
+ from qtconsole.rich_ipython_widget import RichJupyterWidget as \
+ _RichJupyterWidget
+ except ImportError:
+ from qtconsole.rich_ipython_widget import RichIPythonWidget as \
+ _RichJupyterWidget
+
+from qtconsole.inprocess import QtInProcessKernelManager
+
+try:
+ from ipykernel import version_info as _ipykernel_version_info
+except ImportError:
+ _ipykernel_version_info = None
+
+
+class IPythonWidget(_RichJupyterWidget):
+ """Live IPython console widget.
+
+ .. image:: img/IPythonWidget.png
+
+ :param custom_banner: Custom welcome message to be printed at the top of
+ the console.
+ """
+
+ def __init__(self, parent=None, custom_banner=None, *args, **kwargs):
+ if parent is not None:
+ kwargs["parent"] = parent
+ super(IPythonWidget, self).__init__(*args, **kwargs)
+ if custom_banner is not None:
+ self.banner = custom_banner
+ self.setWindowTitle(self.banner)
+ self.kernel_manager = kernel_manager = QtInProcessKernelManager()
+ kernel_manager.start_kernel()
+
+ # Monkey-patch to workaround issue:
+ # https://github.com/ipython/ipykernel/issues/370
+ if (_ipykernel_version_info is not None and
+ _ipykernel_version_info[0] > 4 and
+ _ipykernel_version_info[:3] <= (5, 1, 0)):
+ def _abort_queues(*args, **kwargs):
+ pass
+ kernel_manager.kernel._abort_queues = _abort_queues
+
+ self.kernel_client = kernel_client = self._kernel_manager.client()
+ kernel_client.start_channels()
+
+ def stop():
+ kernel_client.stop_channels()
+ kernel_manager.shutdown_kernel()
+ self.exit_requested.connect(stop)
+
+ def sizeHint(self):
+ """Return a reasonable default size for usage in :class:`PlotWindow`"""
+ return qt.QSize(500, 300)
+
+ def pushVariables(self, variable_dict):
+ """ Given a dictionary containing name / value pairs, push those
+ variables to the IPython console widget.
+
+ :param variable_dict: Dictionary of variables to be pushed to the
+ console's interactive namespace (```{variable_name: object, …}```)
+ """
+ self.kernel_manager.kernel.shell.push(variable_dict)
+
+
+class IPythonDockWidget(qt.QDockWidget):
+ """Dock Widget including a :class:`IPythonWidget` inside
+ a vertical layout.
+
+ .. image:: img/IPythonDockWidget.png
+
+ :param available_vars: Dictionary of variables to be pushed to the
+ console's interactive namespace: ``{"variable_name": object, …}``
+ :param custom_banner: Custom welcome message to be printed at the top of
+ the console
+ :param title: Dock widget title
+ :param parent: Parent :class:`qt.QMainWindow` containing this
+ :class:`qt.QDockWidget`
+ """
+ def __init__(self, parent=None, available_vars=None, custom_banner=None,
+ title="Console"):
+ super(IPythonDockWidget, self).__init__(title, parent)
+
+ self.ipyconsole = IPythonWidget(custom_banner=custom_banner)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self.ipyconsole)
+
+ if available_vars is not None:
+ self.ipyconsole.pushVariables(available_vars)
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
+
+
+def main():
+ """Run a Qt app with an IPython console"""
+ app = qt.QApplication([])
+ widget = IPythonDockWidget()
+ widget.show()
+ app.exec()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/src/silx/gui/data/ArrayTableModel.py b/src/silx/gui/data/ArrayTableModel.py
new file mode 100644
index 0000000..23b0bb2
--- /dev/null
+++ b/src/silx/gui/data/ArrayTableModel.py
@@ -0,0 +1,650 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module defines a data model for displaying and editing arrays of any
+number of dimensions in a table view.
+"""
+from __future__ import division
+import numpy
+import logging
+from silx.gui import qt
+from silx.gui.data.TextFormatter import TextFormatter
+
+__authors__ = ["V.A. Sole"]
+__license__ = "MIT"
+__date__ = "27/09/2017"
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _is_array(data):
+ """Return True if object implements all necessary attributes to be used
+ as a numpy array.
+
+ :param object data: Array-like object (numpy array, h5py dataset...)
+ :return: boolean
+ """
+ # add more required attribute if necessary
+ for attr in ("shape", "dtype"):
+ if not hasattr(data, attr):
+ return False
+ return True
+
+
+class ArrayTableModel(qt.QAbstractTableModel):
+ """This data model provides access to 2D slices in a N-dimensional
+ array.
+
+ A slice for a 3-D array is characterized by a perspective (the number of
+ the axis orthogonal to the slice) and an index at which the slice
+ intersects the orthogonal axis.
+
+ In the n-D case, only slices parallel to the last two axes are handled. A
+ slice is therefore characterized by a list of indices locating the
+ slice on all the :math:`n - 2` orthogonal axes.
+
+ :param parent: Parent QObject
+ :param data: Numpy array, or object implementing a similar interface
+ (e.g. h5py dataset)
+ :param str fmt: Format string for representing numerical values.
+ Default is ``"%g"``.
+ :param sequence[int] perspective: See documentation
+ of :meth:`setPerspective`.
+ """
+
+ MAX_NUMBER_OF_SECTIONS = 10e6
+ """Maximum number of displayed rows and columns"""
+
+ def __init__(self, parent=None, data=None, perspective=None):
+ qt.QAbstractTableModel.__init__(self, parent)
+
+ self._array = None
+ """n-dimensional numpy array"""
+
+ self._bgcolors = None
+ """(n+1)-dimensional numpy array containing RGB(A) color data
+ for the background color
+ """
+
+ self._fgcolors = None
+ """(n+1)-dimensional numpy array containing RGB(A) color data
+ for the foreground color
+ """
+
+ self._formatter = None
+ """Formatter for text representation of data"""
+
+ formatter = TextFormatter(self)
+ formatter.setUseQuoteForText(False)
+ self.setFormatter(formatter)
+
+ self._index = None
+ """This attribute stores the slice index, as a list of indices
+ where the frame intersects orthogonal axis."""
+
+ self._perspective = None
+ """Sequence of dimensions orthogonal to the frame to be viewed.
+ For an array with ``n`` dimensions, this is a sequence of ``n-2``
+ integers. the first dimension is numbered ``0``.
+ By default, the data frames use the last two dimensions as their axes
+ and therefore the perspective is a sequence of the first ``n-2``
+ dimensions.
+ For example, for a 5-D array, the default perspective is ``(0, 1, 2)``
+ and the default frames axes are ``(3, 4)``."""
+
+ # set _data and _perspective
+ self.setArrayData(data, perspective=perspective)
+
+ def _getRowDim(self):
+ """The row axis is the first axis parallel to the frames
+ (lowest dimension number)
+
+ Return None for 0-D (scalar) or 1-D arrays
+ """
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 2:
+ # scalar or 1D array: no row index
+ return None
+ # take all dimensions and remove the orthogonal ones
+ frame_axes = set(range(0, n_dimensions)) - set(self._perspective)
+ # sanity check
+ assert len(frame_axes) == 2
+ return min(frame_axes)
+
+ def _getColumnDim(self):
+ """The column axis is the second (highest dimension) axis parallel
+ to the frames
+
+ Return None for 0-D (scalar)
+ """
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 1:
+ # scalar: no column index
+ return None
+ frame_axes = set(range(0, n_dimensions)) - set(self._perspective)
+ # sanity check
+ assert (len(frame_axes) == 2) if n_dimensions > 1 else (len(frame_axes) == 1)
+ return max(frame_axes)
+
+ def _getIndexTuple(self, table_row, table_col):
+ """Return the n-dimensional index of a value in the original array,
+ based on its row and column indices in the table view
+
+ :param table_row: Row index (0-based) of a table cell
+ :param table_col: Column index (0-based) of a table cell
+ :return: Tuple of indices of the element in the numpy array
+ """
+ row_dim = self._getRowDim()
+ col_dim = self._getColumnDim()
+
+ # get indices on all orthogonal axes
+ selection = list(self._index)
+ # insert indices on parallel axes
+ if row_dim is not None:
+ selection.insert(row_dim, table_row)
+ if col_dim is not None:
+ selection.insert(col_dim, table_col)
+ return tuple(selection)
+
+ # Methods to be implemented to subclass QAbstractTableModel
+ def rowCount(self, parent_idx=None):
+ """QAbstractTableModel method
+ Return number of rows to be displayed in table"""
+ row_dim = self._getRowDim()
+ if row_dim is None:
+ # 0-D and 1-D arrays
+ return 1
+ return min(self._array.shape[row_dim], self.MAX_NUMBER_OF_SECTIONS)
+
+ def columnCount(self, parent_idx=None):
+ """QAbstractTableModel method
+ Return number of columns to be displayed in table"""
+ col_dim = self._getColumnDim()
+ if col_dim is None:
+ # 0-D array
+ return 1
+ return min(self._array.shape[col_dim], self.MAX_NUMBER_OF_SECTIONS)
+
+ def __isClipped(self, orientation=qt.Qt.Vertical) -> bool:
+ """Returns whether or not array is clipped in a given orientation"""
+ if orientation == qt.Qt.Vertical:
+ dim = self._getRowDim()
+ else:
+ dim = self._getColumnDim()
+ return (dim is not None and
+ self._array.shape[dim] > self.MAX_NUMBER_OF_SECTIONS)
+
+ def __isClippedIndex(self, index) -> bool:
+ """Returns whether or not index's cell represents clipped data."""
+ if not index.isValid():
+ return False
+ if index.row() == self.MAX_NUMBER_OF_SECTIONS - 2:
+ return self.__isClipped(qt.Qt.Vertical)
+ if index.column() == self.MAX_NUMBER_OF_SECTIONS - 2:
+ return self.__isClipped(qt.Qt.Horizontal)
+ return False
+
+ def __clippedData(self, role=qt.Qt.DisplayRole):
+ """Return data for cells representing clipped data"""
+ if role == qt.Qt.DisplayRole:
+ return "..."
+ elif role == qt.Qt.ToolTipRole:
+ return "Dataset is too large: display is clipped"
+ else:
+ return None
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if index.isValid():
+ if self.__isClippedIndex(index): # Special displayed for clipped data
+ return self.__clippedData(role)
+
+ row, column = index.row(), index.column()
+
+ # When clipped, display last data of the array in last column of the table
+ if (self.__isClipped(qt.Qt.Vertical) and
+ row == self.MAX_NUMBER_OF_SECTIONS - 1):
+ row = self._array.shape[self._getRowDim()] - 1
+ if (self.__isClipped(qt.Qt.Horizontal) and
+ column == self.MAX_NUMBER_OF_SECTIONS - 1):
+ column = self._array.shape[self._getColumnDim()] - 1
+
+ selection = self._getIndexTuple(row, column)
+
+ if role == qt.Qt.DisplayRole:
+ return self._formatter.toString(self._array[selection], self._array.dtype)
+
+ if role == qt.Qt.BackgroundRole and self._bgcolors is not None:
+ r, g, b = self._bgcolors[selection][0:3]
+ if self._bgcolors.shape[-1] == 3:
+ return qt.QColor(r, g, b)
+ if self._bgcolors.shape[-1] == 4:
+ a = self._bgcolors[selection][3]
+ return qt.QColor(r, g, b, a)
+
+ if role == qt.Qt.ForegroundRole:
+ if self._fgcolors is not None:
+ r, g, b = self._fgcolors[selection][0:3]
+ if self._fgcolors.shape[-1] == 3:
+ return qt.QColor(r, g, b)
+ if self._fgcolors.shape[-1] == 4:
+ a = self._fgcolors[selection][3]
+ return qt.QColor(r, g, b, a)
+
+ # no fg color given, use black or white
+ # based on luminosity threshold
+ elif self._bgcolors is not None:
+ r, g, b = self._bgcolors[selection][0:3]
+ lum = 0.21 * r + 0.72 * g + 0.07 * b
+ if lum < 128:
+ return qt.QColor(qt.Qt.white)
+ else:
+ return qt.QColor(qt.Qt.black)
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method
+ Return the 0-based row or column index, for display in the
+ horizontal and vertical headers"""
+ if self.__isClipped(orientation): # Header is clipped
+ if section == self.MAX_NUMBER_OF_SECTIONS - 2:
+ # Represent clipped data
+ return self.__clippedData(role)
+
+ elif section == self.MAX_NUMBER_OF_SECTIONS - 1:
+ # Display last index from data not table
+ if role == qt.Qt.DisplayRole:
+ if orientation == qt.Qt.Vertical:
+ dim = self._getRowDim()
+ else:
+ dim = self._getColumnDim()
+ return str(self._array.shape[dim] - 1)
+ else:
+ return None
+
+ if role == qt.Qt.DisplayRole:
+ return "%d" % section
+ return None
+
+ def flags(self, index):
+ """QAbstractTableModel method to inform the view whether data
+ is editable or not."""
+ if not self._editable or self.__isClippedIndex(index):
+ return qt.QAbstractTableModel.flags(self, index)
+ return qt.QAbstractTableModel.flags(self, index) | qt.Qt.ItemIsEditable
+
+ def setData(self, index, value, role=None):
+ """QAbstractTableModel method to handle editing data.
+ Cast the new value into the same format as the array before editing
+ the array value."""
+ if index.isValid() and role == qt.Qt.EditRole:
+ try:
+ # cast value to same type as array
+ v = numpy.array(value, dtype=self._array.dtype).item()
+ except ValueError:
+ return False
+
+ selection = self._getIndexTuple(index.row(),
+ index.column())
+ self._array[selection] = v
+ self.dataChanged.emit(index, index)
+ return True
+ else:
+ return False
+
+ # Public methods
+ def setArrayData(self, data, copy=True,
+ perspective=None, editable=False):
+ """Set the data array and the viewing perspective.
+
+ You can set ``copy=False`` if you need more performances, when dealing
+ with a large numpy array. In this case, a simple reference to the data
+ is used to access the data, rather than a copy of the array.
+
+ .. warning::
+
+ Any change to the data model will affect your original data
+ array, when using a reference rather than a copy..
+
+ :param data: n-dimensional numpy array, or any object that can be
+ converted to a numpy array using ``numpy.array(data)`` (e.g.
+ a nested sequence).
+ :param bool copy: If *True* (default), a copy of the array is stored
+ and the original array is not modified if the table is edited.
+ If *False*, then the behavior depends on the data type:
+ if possible (if the original array is a proper numpy array)
+ a reference to the original array is used.
+ :param perspective: See documentation of :meth:`setPerspective`.
+ If None, the default perspective is the list of the first ``n-2``
+ dimensions, to view frames parallel to the last two axes.
+ :param bool editable: Flag to enable editing data. Default *False*.
+ """
+ self.beginResetModel()
+
+ if data is None:
+ # empty array
+ self._array = numpy.array([])
+ elif copy:
+ # copy requested (default)
+ self._array = numpy.array(data, copy=True)
+ if hasattr(data, "dtype"):
+ # Avoid to lose the monkey-patched h5py dtype
+ self._array.dtype = data.dtype
+ elif not _is_array(data):
+ raise TypeError("data is not a proper array. Try setting" +
+ " copy=True to convert it into a numpy array" +
+ " (this will cause the data to be copied!)")
+ # # copy not requested, but necessary
+ # _logger.warning(
+ # "data is not an array-like object. " +
+ # "Data must be copied.")
+ # self._array = numpy.array(data, copy=True)
+ else:
+ # Copy explicitly disabled & data implements required attributes.
+ # We can use a reference.
+ self._array = data
+
+ # reset colors to None if new data shape is inconsistent
+ valid_color_shapes = (self._array.shape + (3,),
+ self._array.shape + (4,))
+ if self._bgcolors is not None:
+ if self._bgcolors.shape not in valid_color_shapes:
+ self._bgcolors = None
+ if self._fgcolors is not None:
+ if self._fgcolors.shape not in valid_color_shapes:
+ self._fgcolors = None
+
+ self.setEditable(editable)
+
+ self._index = [0 for _i in range((len(self._array.shape) - 2))]
+ self._perspective = tuple(perspective) if perspective is not None else\
+ tuple(range(0, len(self._array.shape) - 2))
+
+ self.endResetModel()
+
+ def setArrayColors(self, bgcolors=None, fgcolors=None):
+ """Set the colors for all table cells by passing an array
+ of RGB or RGBA values (integers between 0 and 255).
+
+ The shape of the colors array must be consistent with the data shape.
+
+ If the data array is n-dimensional, the colors array must be
+ (n+1)-dimensional, with the first n-dimensions identical to the data
+ array dimensions, and the last dimension length-3 (RGB) or
+ length-4 (RGBA).
+
+ :param bgcolors: RGB or RGBA colors array, defining the background color
+ for each cell in the table.
+ :param fgcolors: RGB or RGBA colors array, defining the foreground color
+ (text color) for each cell in the table.
+ """
+ # array must be RGB or RGBA
+ valid_shapes = (self._array.shape + (3,), self._array.shape + (4,))
+ errmsg = "Inconsistent shape for color array, should be %s or %s" % valid_shapes
+
+ if bgcolors is not None:
+ if not _is_array(bgcolors):
+ bgcolors = numpy.array(bgcolors)
+ assert bgcolors.shape in valid_shapes, errmsg
+
+ self._bgcolors = bgcolors
+
+ if fgcolors is not None:
+ if not _is_array(fgcolors):
+ fgcolors = numpy.array(fgcolors)
+ assert fgcolors.shape in valid_shapes, errmsg
+
+ self._fgcolors = fgcolors
+
+ def setEditable(self, editable):
+ """Set flags to make the data editable.
+
+ .. warning::
+
+ If the data is a reference to a h5py dataset open in read-only
+ mode, setting *editable=True* will fail and print a warning.
+
+ .. warning::
+
+ Making the data editable means that the underlying data structure
+ in this data model will be modified.
+ If the data is a reference to a public object (open with
+ ``copy=False``), this could have side effects. If it is a
+ reference to an HDF5 dataset, this means the file will be
+ modified.
+
+ :param bool editable: Flag to enable editing data.
+ :return: True if setting desired flag succeeded, False if it failed.
+ """
+ self._editable = editable
+ if hasattr(self._array, "file"):
+ if hasattr(self._array.file, "mode"):
+ if editable and self._array.file.mode == "r":
+ _logger.warning(
+ "Data is a HDF5 dataset open in read-only " +
+ "mode. Editing must be disabled.")
+ self._editable = False
+ return False
+ return True
+
+ def getData(self, copy=True):
+ """Return a copy of the data array, or a reference to it
+ if *copy=False* is passed as parameter.
+
+ In case the shape was modified, to convert 0-D or 1-D data
+ into 2-D data, the original shape is restored in the returned data.
+
+ :param bool copy: If *True* (default), return a copy of the data. If
+ *False*, return a reference.
+ :return: numpy array of data, or reference to original data object
+ if *copy=False*
+ """
+ data = self._array if not copy else numpy.array(self._array, copy=True)
+ return data
+
+ def setFrameIndex(self, index):
+ """Set the active slice index.
+
+ This method is only relevant to arrays with at least 3 dimensions.
+
+ :param index: Index of the active slice in the array.
+ In the general n-D case, this is a sequence of :math:`n - 2`
+ indices where the slice intersects the respective orthogonal axes.
+ :raise IndexError: If any index in the index sequence is out of bound
+ on its respective axis.
+ """
+ shape = self._array.shape
+ if len(shape) < 3:
+ # index is ignored
+ return
+
+ self.beginResetModel()
+
+ if len(shape) == 3:
+ len_ = shape[self._perspective[0]]
+ # accept integers as index in the case of 3-D arrays
+ if not hasattr(index, "__len__"):
+ self._index = [index]
+ else:
+ self._index = index
+ if not 0 <= self._index[0] < len_:
+ raise ValueError("Index must be a positive integer " +
+ "lower than %d" % len_)
+ else:
+ # general n-D case
+ for i_, idx in enumerate(index):
+ if not 0 <= idx < shape[self._perspective[i_]]:
+ raise IndexError("Invalid index %d " % idx +
+ "not in range 0-%d" % (shape[i_] - 1))
+ self._index = index
+
+ self.endResetModel()
+
+ def setFormatter(self, formatter):
+ """Set the formatter object to be used to display data from the model
+
+ :param TextFormatter formatter: Formatter to use
+ """
+ if formatter is self._formatter:
+ return
+
+ self.beginResetModel()
+
+ if self._formatter is not None:
+ self._formatter.formatChanged.disconnect(self.__formatChanged)
+
+ self._formatter = formatter
+ if self._formatter is not None:
+ self._formatter.formatChanged.connect(self.__formatChanged)
+
+ self.endResetModel()
+
+ def getFormatter(self):
+ """Returns the text formatter used.
+
+ :rtype: TextFormatter
+ """
+ return self._formatter
+
+ def __formatChanged(self):
+ """Called when the format changed.
+ """
+ self.reset()
+
+ def setPerspective(self, perspective):
+ """Set the perspective by defining a sequence listing all axes
+ orthogonal to the frame or 2-D slice to be visualized.
+
+ Alternatively, you can use :meth:`setFrameAxes` for the complementary
+ approach of specifying the two axes parallel to the frame.
+
+ In the 1-D or 2-D case, this parameter is irrelevant.
+
+ In the 3-D case, if the unit vectors describing
+ your axes are :math:`\vec{x}, \vec{y}, \vec{z}`, a perspective of 0
+ means you slices are parallel to :math:`\vec{y}\vec{z}`, 1 means they
+ are parallel to :math:`\vec{x}\vec{z}` and 2 means they
+ are parallel to :math:`\vec{x}\vec{y}`.
+
+ In the n-D case, this parameter is a sequence of :math:`n-2` axes
+ numbers.
+ For instance if you want to display 2-D frames whose axes are the
+ second and third dimensions of a 5-D array, set the perspective to
+ ``(0, 3, 4)``.
+
+ :param perspective: Sequence of dimensions/axes orthogonal to the
+ frames.
+ :raise: IndexError if any value in perspective is higher than the
+ number of dimensions minus one (first dimension is 0), or
+ if the number of values is different from the number of dimensions
+ minus two.
+ """
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 3:
+ _logger.warning(
+ "perspective is not relevant for 1D and 2D arrays")
+ return
+
+ if not hasattr(perspective, "__len__"):
+ # we can tolerate an integer for 3-D array
+ if n_dimensions == 3:
+ perspective = [perspective]
+ else:
+ raise ValueError("perspective must be a sequence of integers")
+
+ # ensure unicity of dimensions in perspective
+ perspective = tuple(set(perspective))
+
+ if len(perspective) != n_dimensions - 2 or\
+ min(perspective) < 0 or max(perspective) >= n_dimensions:
+ raise IndexError(
+ "Invalid perspective " + str(perspective) +
+ " for %d-D array " % n_dimensions +
+ "with shape " + str(self._array.shape))
+
+ self.beginResetModel()
+
+ self._perspective = perspective
+
+ # reset index
+ self._index = [0 for _i in range(n_dimensions - 2)]
+
+ self.endResetModel()
+
+ def setFrameAxes(self, row_axis, col_axis):
+ """Set the perspective by specifying the two axes parallel to the frame
+ to be visualised.
+
+ The complementary approach of defining the orthogonal axes can be used
+ with :meth:`setPerspective`.
+
+ :param int row_axis: Index (0-based) of the first dimension used as a frame
+ axis
+ :param int col_axis: Index (0-based) of the 2nd dimension used as a frame
+ axis
+ :raise: IndexError if axes are invalid
+ """
+ if row_axis > col_axis:
+ _logger.warning("The dimension of the row axis must be lower " +
+ "than the dimension of the column axis. Swapping.")
+ row_axis, col_axis = min(row_axis, col_axis), max(row_axis, col_axis)
+
+ n_dimensions = len(self._array.shape)
+ if n_dimensions < 3:
+ _logger.warning(
+ "Frame axes cannot be changed for 1D and 2D arrays")
+ return
+
+ perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis})
+
+ if len(perspective) != n_dimensions - 2 or\
+ min(perspective) < 0 or max(perspective) >= n_dimensions:
+ raise IndexError(
+ "Invalid perspective " + str(perspective) +
+ " for %d-D array " % n_dimensions +
+ "with shape " + str(self._array.shape))
+
+ self.beginResetModel()
+
+ self._perspective = perspective
+ # reset index
+ self._index = [0 for _i in range(n_dimensions - 2)]
+
+ self.endResetModel()
+
+
+if __name__ == "__main__":
+ app = qt.QApplication([])
+ w = qt.QTableView()
+ d = numpy.random.normal(0, 1, (5, 1000, 1000))
+ for i in range(5):
+ d[i, :, :] += i * 10
+ m = ArrayTableModel(data=d)
+ w.setModel(m)
+ m.setFrameIndex(3)
+ # m.setArrayData(numpy.ones((100,)))
+ w.show()
+ app.exec()
diff --git a/src/silx/gui/data/ArrayTableWidget.py b/src/silx/gui/data/ArrayTableWidget.py
new file mode 100644
index 0000000..baef5f4
--- /dev/null
+++ b/src/silx/gui/data/ArrayTableWidget.py
@@ -0,0 +1,492 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines a widget designed to display data arrays with any
+number of dimensions as 2D frames (images, slices) in a table view.
+The dimensions not displayed in the table can be browsed using improved
+sliders.
+
+The widget uses a TableView that relies on a custom abstract item
+model: :class:`silx.gui.data.ArrayTableModel`.
+"""
+from __future__ import division
+import sys
+
+from silx.gui import qt
+from silx.gui.widgets.TableWidget import TableView
+from .ArrayTableModel import ArrayTableModel
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/01/2017"
+
+
+class AxesSelector(qt.QWidget):
+ """Widget with two combo-boxes to select two dimensions among
+ all possible dimensions of an n-dimensional array.
+
+ The first combobox contains values from :math:`0` to :math:`n-2`.
+
+ The choices in the 2nd CB depend on the value selected in the first one.
+ If the value selected in the first CB is :math:`m`, the second one lets you
+ select values from :math:`m+1` to :math:`n-1`.
+
+ The two axes can be used to select the row axis and the column axis t
+ display a slice of the array data in a table view.
+ """
+ sigDimensionsChanged = qt.Signal(int, int)
+ """Signal emitted whenever one of the comboboxes is changed.
+ The signal carries the two selected dimensions."""
+
+ def __init__(self, parent=None, n=None):
+ qt.QWidget.__init__(self, parent)
+ self.layout = qt.QHBoxLayout(self)
+ self.layout.setContentsMargins(0, 2, 0, 2)
+ self.layout.setSpacing(10)
+
+ self.rowsCB = qt.QComboBox(self)
+ self.columnsCB = qt.QComboBox(self)
+
+ self.layout.addWidget(qt.QLabel("Rows dimension", self))
+ self.layout.addWidget(self.rowsCB)
+ self.layout.addWidget(qt.QLabel(" ", self))
+ self.layout.addWidget(qt.QLabel("Columns dimension", self))
+ self.layout.addWidget(self.columnsCB)
+ self.layout.addStretch(1)
+
+ self._slotsAreConnected = False
+ if n is not None:
+ self.setNDimensions(n)
+
+ def setNDimensions(self, n):
+ """Initialize combo-boxes depending on number of dimensions of array.
+ Initially, the rows dimension is the second-to-last one, and the
+ columns dimension is the last one.
+
+ Link the CBs together. MAke them emit a signal when their value is
+ changed.
+
+ :param int n: Number of dimensions of array
+ """
+ # remember the number of dimensions and the rows dimension
+ self.n = n
+ self._rowsDim = n - 2
+
+ # ensure slots are disconnected before (re)initializing widget
+ if self._slotsAreConnected:
+ self.rowsCB.currentIndexChanged.disconnect(self._rowDimChanged)
+ self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged)
+
+ self._clear()
+ self.rowsCB.addItems([str(i) for i in range(n - 1)])
+ self.rowsCB.setCurrentIndex(n - 2)
+ if n >= 1:
+ self.columnsCB.addItem(str(n - 1))
+ self.columnsCB.setCurrentIndex(0)
+
+ # reconnect slots
+ self.rowsCB.currentIndexChanged.connect(self._rowDimChanged)
+ self.columnsCB.currentIndexChanged.connect(self._colDimChanged)
+ self._slotsAreConnected = True
+
+ # emit new dimensions
+ if n > 2:
+ self.sigDimensionsChanged.emit(n - 2, n - 1)
+
+ def setDimensions(self, row_dim, col_dim):
+ """Set the rows and columns dimensions.
+
+ The rows dimension must be lower than the columns dimension.
+
+ :param int row_dim: Rows dimension
+ :param int col_dim: Columns dimension
+ """
+ if row_dim >= col_dim:
+ raise IndexError("Row dimension must be lower than column dimension")
+ if not (0 <= row_dim < self.n - 1):
+ raise IndexError("Row dimension must be between 0 and %d" % (self.n - 2))
+ if not (row_dim < col_dim <= self.n - 1):
+ raise IndexError("Col dimension must be between %d and %d" % (row_dim + 1, self.n - 1))
+
+ # set the rows dimension; this triggers an update of columnsCB
+ self.rowsCB.setCurrentIndex(row_dim)
+ # columnsCB first item is "row_dim + 1". So index of "col_dim" is
+ # col_dim - (row_dim + 1)
+ self.columnsCB.setCurrentIndex(col_dim - row_dim - 1)
+
+ def getDimensions(self):
+ """Return a 2-tuple of the rows dimension and the columns dimension.
+
+ :return: 2-tuple of axes numbers (row_dimension, col_dimension)
+ """
+ return self._getRowDim(), self._getColDim()
+
+ def _clear(self):
+ """Empty the combo-boxes"""
+ self.rowsCB.clear()
+ self.columnsCB.clear()
+
+ def _getRowDim(self):
+ """Get rows dimension, selected in :attr:`rowsCB`
+ """
+ # rows combobox contains elements "0", ..."n-2",
+ # so the selected dim is always equal to the index
+ return self.rowsCB.currentIndex()
+
+ def _getColDim(self):
+ """Get columns dimension, selected in :attr:`columnsCB`"""
+ # columns combobox contains elements "row_dim+1", "row_dim+2", ..., "n-1"
+ # so the selected dim is equal to row_dim + 1 + index
+ return self._rowsDim + 1 + self.columnsCB.currentIndex()
+
+ def _rowDimChanged(self):
+ """Update columns combobox when the rows dimension is changed.
+
+ Emit :attr:`sigDimensionsChanged`"""
+ old_col_dim = self._getColDim()
+ new_row_dim = self._getRowDim()
+
+ # clear cols CB
+ self.columnsCB.currentIndexChanged.disconnect(self._colDimChanged)
+ self.columnsCB.clear()
+ # refill cols CB
+ for i in range(new_row_dim + 1, self.n):
+ self.columnsCB.addItem(str(i))
+
+ # keep previous col dimension if possible
+ new_col_cb_idx = old_col_dim - (new_row_dim + 1)
+ if new_col_cb_idx < 0:
+ # if row_dim is now greater than the previous col_dim,
+ # we select a new col_dim = row_dim + 1 (first element in cols CB)
+ new_col_cb_idx = 0
+ self.columnsCB.setCurrentIndex(new_col_cb_idx)
+
+ # reconnect slot
+ self.columnsCB.currentIndexChanged.connect(self._colDimChanged)
+
+ self._rowsDim = new_row_dim
+
+ self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim())
+
+ def _colDimChanged(self):
+ """Emit :attr:`sigDimensionsChanged`"""
+ self.sigDimensionsChanged.emit(self._getRowDim(), self._getColDim())
+
+
+def _get_shape(array_like):
+ """Return shape of an array like object.
+
+ In case the object is a nested sequence (list of lists, tuples...),
+ the size of each dimension is assumed to be uniform, and is deduced from
+ the length of the first sequence.
+
+ :param array_like: Array like object: numpy array, hdf5 dataset,
+ multi-dimensional sequence
+ :return: Shape of array, as a tuple of integers
+ """
+ if hasattr(array_like, "shape"):
+ return array_like.shape
+
+ shape = []
+ subsequence = array_like
+ while hasattr(subsequence, "__len__"):
+ shape.append(len(subsequence))
+ subsequence = subsequence[0]
+
+ return tuple(shape)
+
+
+class ArrayTableWidget(qt.QWidget):
+ """This widget is designed to display data of 2D frames (images, slices)
+ in a table view. The widget can load any n-dimensional array, and display
+ any 2-D frame/slice in the array.
+
+ The index of the dimensions orthogonal to the displayed frame can be set
+ interactively using a browser widget (sliders, buttons and text entries).
+
+ To set the data, use :meth:`setArrayData`.
+ To select the perspective, use :meth:`setPerspective` or
+ use :meth:`setFrameAxes`.
+ To select the frame, use :meth:`setFrameIndex`.
+
+ .. image:: img/ArrayTableWidget.png
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: parent QWidget
+ :param labels: list of labels for each dimension of the array
+ """
+ qt.QWidget.__init__(self, parent)
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(0)
+
+ self.browserContainer = qt.QWidget(self)
+ self.browserLayout = qt.QGridLayout(self.browserContainer)
+ self.browserLayout.setContentsMargins(0, 0, 0, 0)
+ self.browserLayout.setSpacing(0)
+
+ self._dimensionLabelsText = []
+ """List of text labels sorted in the increasing order of the dimension
+ they apply to."""
+ self._browserLabels = []
+ """List of QLabel widgets."""
+ self._browserWidgets = []
+ """List of HorizontalSliderWithBrowser widgets."""
+
+ self.axesSelector = AxesSelector(self)
+
+ self.view = TableView(self)
+
+ self.mainLayout.addWidget(self.browserContainer)
+ self.mainLayout.addWidget(self.axesSelector)
+ self.mainLayout.addWidget(self.view)
+
+ self.model = ArrayTableModel(self)
+ self.view.setModel(self.model)
+
+ def setArrayData(self, data, labels=None, copy=True, editable=False):
+ """Set the data array. Update frame browsers and labels.
+
+ :param data: Numpy array or similar object (e.g. nested sequence,
+ h5py dataset...)
+ :param labels: list of labels for each dimension of the array, or
+ boolean ``True`` to use default labels ("dimension 0",
+ "dimension 1", ...). `None` to disable labels (default).
+ :param bool copy: If *True*, store a copy of *data* in the model. If
+ *False*, store a reference to *data* if possible (only possible if
+ *data* is a proper numpy array or an object that implements the
+ same methods).
+ :param bool editable: Flag to enable editing data. Default is *False*
+ """
+ self._data_shape = _get_shape(data)
+
+ n_widgets = len(self._browserWidgets)
+ n_dimensions = len(self._data_shape)
+
+ # Reset text of labels
+ self._dimensionLabelsText = []
+ for i in range(n_dimensions):
+ if labels in [True, 1]:
+ label_text = "Dimension %d" % i
+ elif labels is None or i >= len(labels):
+ label_text = ""
+ else:
+ label_text = labels[i]
+ self._dimensionLabelsText.append(label_text)
+
+ # not enough widgets, create new ones (we need n_dim - 2)
+ for i in range(n_widgets, n_dimensions - 2):
+ browser = HorizontalSliderWithBrowser(self.browserContainer)
+ self.browserLayout.addWidget(browser, i, 1)
+ self._browserWidgets.append(browser)
+ browser.valueChanged.connect(self._browserSlot)
+ browser.setEnabled(False)
+ browser.hide()
+
+ label = qt.QLabel(self.browserContainer)
+ self._browserLabels.append(label)
+ self.browserLayout.addWidget(label, i, 0)
+ label.hide()
+
+ n_widgets = len(self._browserWidgets)
+ for i in range(n_widgets):
+ label = self._browserLabels[i]
+ browser = self._browserWidgets[i]
+
+ if (i + 2) < n_dimensions:
+ label.setText(self._dimensionLabelsText[i])
+ browser.setRange(0, self._data_shape[i] - 1)
+ browser.setEnabled(True)
+ browser.show()
+ if labels is not None:
+ label.show()
+ else:
+ label.hide()
+ else:
+ browser.setEnabled(False)
+ browser.hide()
+ label.hide()
+
+ # set model
+ self.model.setArrayData(data, copy=copy, editable=editable)
+ # some linux distributions need this call
+ self.view.setModel(self.model)
+ if editable:
+ self.view.enableCut()
+ self.view.enablePaste()
+
+ # initialize & connect axesSelector
+ self.axesSelector.setNDimensions(n_dimensions)
+ self.axesSelector.sigDimensionsChanged.connect(self.setFrameAxes)
+
+ def setArrayColors(self, bgcolors=None, fgcolors=None):
+ """Set the colors for all table cells by passing an array
+ of RGB or RGBA values (integers between 0 and 255).
+
+ The shape of the colors array must be consistent with the data shape.
+
+ If the data array is n-dimensional, the colors array must be
+ (n+1)-dimensional, with the first n-dimensions identical to the data
+ array dimensions, and the last dimension length-3 (RGB) or
+ length-4 (RGBA).
+
+ :param bgcolors: RGB or RGBA colors array, defining the background color
+ for each cell in the table.
+ :param fgcolors: RGB or RGBA colors array, defining the foreground color
+ (text color) for each cell in the table.
+ """
+ self.model.setArrayColors(bgcolors, fgcolors)
+
+ def displayAxesSelector(self, isVisible):
+ """Allow to display or hide the axes selector.
+
+ :param bool isVisible: True to display the axes selector.
+ """
+ self.axesSelector.setVisible(isVisible)
+
+ def setFrameIndex(self, index):
+ """Set the active slice/image index in the n-dimensional array.
+
+ A frame is a 2D array extracted from an array. This frame is
+ necessarily parallel to 2 axes, and orthogonal to all other axes.
+
+ The index of a frame is a sequence of indices along the orthogonal
+ axes, where the frame intersects the respective axis. The indices
+ are listed in the same order as the corresponding dimensions of the
+ data array.
+
+ For example, it the data array has 5 dimensions, and we are
+ considering frames whose parallel axes are the 2nd and 4th dimensions
+ of the array, the frame index will be a sequence of length 3
+ corresponding to the indices where the frame intersects the 1st, 3rd
+ and 5th axes.
+
+ :param index: Sequence of indices defining the active data slice in
+ a n-dimensional array. The sequence length is :math:`n-2`
+ :raise: IndexError if any index in the index sequence is out of bound
+ on its respective axis.
+ """
+ self.model.setFrameIndex(index)
+
+ def _resetBrowsers(self, perspective):
+ """Adjust limits for browsers based on the perspective and the
+ size of the corresponding dimensions. Reset the index to 0.
+ Update the dimension in the labels.
+
+ :param perspective: Sequence of axes/dimensions numbers (0-based)
+ defining the axes orthogonal to the frame.
+ """
+ # for 3D arrays we can accept an int rather than a 1-tuple
+ if not hasattr(perspective, "__len__"):
+ perspective = [perspective]
+
+ # perspective must be sorted
+ perspective = sorted(perspective)
+
+ n_dimensions = len(self._data_shape)
+ for i in range(n_dimensions - 2):
+ browser = self._browserWidgets[i]
+ label = self._browserLabels[i]
+ browser.setRange(0, self._data_shape[perspective[i]] - 1)
+ browser.setValue(0)
+ label.setText(self._dimensionLabelsText[perspective[i]])
+
+ def setPerspective(self, perspective):
+ """Set the *perspective* by specifying which axes are orthogonal
+ to the frame.
+
+ For the opposite approach (defining parallel axes), use
+ :meth:`setFrameAxes` instead.
+
+ :param perspective: Sequence of unique axes numbers (0-based) defining
+ the orthogonal axes. For a n-dimensional array, the sequence
+ length is :math:`n-2`. The order is of the sequence is not taken
+ into account (the dimensions are displayed in increasing order
+ in the widget).
+ """
+ self.model.setPerspective(perspective)
+ self._resetBrowsers(perspective)
+
+ def setFrameAxes(self, row_axis, col_axis):
+ """Set the *perspective* by specifying which axes are parallel
+ to the frame.
+
+ For the opposite approach (defining orthogonal axes), use
+ :meth:`setPerspective` instead.
+
+ :param int row_axis: Index (0-based) of the first dimension used as a frame
+ axis
+ :param int col_axis: Index (0-based) of the 2nd dimension used as a frame
+ axis
+ """
+ self.model.setFrameAxes(row_axis, col_axis)
+ n_dimensions = len(self._data_shape)
+ perspective = tuple(set(range(0, n_dimensions)) - {row_axis, col_axis})
+ self._resetBrowsers(perspective)
+
+ def _browserSlot(self, value):
+ index = []
+ for browser in self._browserWidgets:
+ if browser.isEnabled():
+ index.append(browser.value())
+ self.setFrameIndex(index)
+ self.view.reset()
+
+ def getData(self, copy=True):
+ """Return a copy of the data array, or a reference to it if
+ *copy=False* is passed as parameter.
+
+ :param bool copy: If *True* (default), return a copy of the data. If
+ *False*, return a reference.
+ :return: Numpy array of data, or reference to original data object
+ if *copy=False*
+ """
+ return self.model.getData(copy=copy)
+
+
+def main():
+ import numpy
+ a = qt.QApplication([])
+ d = numpy.random.normal(0, 1, (4, 5, 1000, 1000))
+ for j in range(4):
+ for i in range(5):
+ d[j, i, :, :] += i + 10 * j
+ w = ArrayTableWidget()
+ if "2" in sys.argv:
+ print("sending a single image")
+ w.setArrayData(d[0, 0])
+ elif "3" in sys.argv:
+ print("sending 5 images")
+ w.setArrayData(d[0])
+ else:
+ print("sending 4 * 5 images ")
+ w.setArrayData(d, labels=True)
+ w.show()
+ a.exec()
+
+if __name__ == "__main__":
+ main()
diff --git a/silx/gui/data/DataViewer.py b/src/silx/gui/data/DataViewer.py
index 2e51439..2e51439 100644
--- a/silx/gui/data/DataViewer.py
+++ b/src/silx/gui/data/DataViewer.py
diff --git a/silx/gui/data/DataViewerFrame.py b/src/silx/gui/data/DataViewerFrame.py
index 9bfb95b..9bfb95b 100644
--- a/silx/gui/data/DataViewerFrame.py
+++ b/src/silx/gui/data/DataViewerFrame.py
diff --git a/silx/gui/data/DataViewerSelector.py b/src/silx/gui/data/DataViewerSelector.py
index a1e9947..a1e9947 100644
--- a/silx/gui/data/DataViewerSelector.py
+++ b/src/silx/gui/data/DataViewerSelector.py
diff --git a/silx/gui/data/DataViews.py b/src/silx/gui/data/DataViews.py
index b18a813..b18a813 100644
--- a/silx/gui/data/DataViews.py
+++ b/src/silx/gui/data/DataViews.py
diff --git a/src/silx/gui/data/Hdf5TableView.py b/src/silx/gui/data/Hdf5TableView.py
new file mode 100644
index 0000000..9d65a84
--- /dev/null
+++ b/src/silx/gui/data/Hdf5TableView.py
@@ -0,0 +1,634 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module define model and widget to display 1D slices from numpy
+array using compound data types or hdf5 databases.
+"""
+from __future__ import division
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/02/2019"
+
+import collections
+import functools
+import os.path
+import logging
+import h5py
+import numpy
+
+from silx.gui import qt
+import silx.io
+from .TextFormatter import TextFormatter
+import silx.gui.hdf5
+from silx.gui.widgets import HierarchicalTableView
+from ..hdf5.Hdf5Formatter import Hdf5Formatter
+from ..hdf5._utils import htmlFromDict
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _CellData(object):
+ """Store a table item
+ """
+ def __init__(self, value=None, isHeader=False, span=None, tooltip=None):
+ """
+ Constructor
+
+ :param str value: Label of this property
+ :param bool isHeader: True if the cell is an header
+ :param tuple span: Tuple of row, column span
+ """
+ self.__value = value
+ self.__isHeader = isHeader
+ self.__span = span
+ self.__tooltip = tooltip
+
+ def isHeader(self):
+ """Returns true if the property is a sub-header title.
+
+ :rtype: bool
+ """
+ return self.__isHeader
+
+ def value(self):
+ """Returns the value of the item.
+ """
+ return self.__value
+
+ def span(self):
+ """Returns the span size of the cell.
+
+ :rtype: tuple
+ """
+ return self.__span
+
+ def tooltip(self):
+ """Returns the tooltip of the item.
+
+ :rtype: tuple
+ """
+ return self.__tooltip
+
+ def invalidateValue(self):
+ self.__value = None
+
+ def invalidateToolTip(self):
+ self.__tooltip = None
+
+ def data(self, role):
+ return None
+
+
+class _TableData(object):
+ """Modelize a table with header, row and column span.
+
+ It is mostly defined as a row based table.
+ """
+
+ def __init__(self, columnCount):
+ """Constructor.
+
+ :param int columnCount: Define the number of column of the table
+ """
+ self.__colCount = columnCount
+ self.__data = []
+
+ def rowCount(self):
+ """Returns the number of rows.
+
+ :rtype: int
+ """
+ return len(self.__data)
+
+ def columnCount(self):
+ """Returns the number of columns.
+
+ :rtype: int
+ """
+ return self.__colCount
+
+ def clear(self):
+ """Remove all the cells of the table"""
+ self.__data = []
+
+ def cellAt(self, row, column):
+ """Returns the cell at the row column location. Else None if there is
+ nothing.
+
+ :rtype: _CellData
+ """
+ if row < 0:
+ return None
+ if column < 0:
+ return None
+ if row >= len(self.__data):
+ return None
+ cells = self.__data[row]
+ if column >= len(cells):
+ return None
+ return cells[column]
+
+ def addHeaderRow(self, headerLabel):
+ """Append the table with header on the full row.
+
+ :param str headerLabel: label of the header.
+ """
+ item = _CellData(value=headerLabel, isHeader=True, span=(1, self.__colCount))
+ self.__data.append([item])
+
+ def addHeaderValueRow(self, headerLabel, value, tooltip=None):
+ """Append the table with a row using the first column as an header and
+ other cells as a single cell for the value.
+
+ :param str headerLabel: label of the header.
+ :param object value: value to store.
+ """
+ header = _CellData(value=headerLabel, isHeader=True)
+ value = _CellData(value=value, span=(1, self.__colCount), tooltip=tooltip)
+ self.__data.append([header, value])
+
+ def addRow(self, *args):
+ """Append the table with a row using arguments for each cells
+
+ :param list[object] args: List of cell values for the row
+ """
+ row = []
+ for value in args:
+ if not isinstance(value, _CellData):
+ value = _CellData(value=value)
+ row.append(value)
+ self.__data.append(row)
+
+
+class _CellFilterAvailableData(_CellData):
+ """Cell rendering for availability of a filter"""
+
+ _states = {
+ True: ("Available", qt.QColor(0x000000), None, None),
+ False: ("Not available", qt.QColor(0xFFFFFF), qt.QColor(0xFF0000),
+ "You have to install this filter on your system to be able to read this dataset"),
+ "na": ("n.a.", qt.QColor(0x000000), None,
+ "This version of h5py/hdf5 is not able to display the information"),
+ }
+
+ def __init__(self, filterId):
+ if h5py.version.hdf5_version_tuple >= (1, 10, 2):
+ # Previous versions only returns True if the filter was first used
+ # to decode a dataset
+ self.__availability = h5py.h5z.filter_avail(filterId)
+ else:
+ self.__availability = "na"
+ _CellData.__init__(self)
+
+ def value(self):
+ state = self._states[self.__availability]
+ return state[0]
+
+ def tooltip(self):
+ state = self._states[self.__availability]
+ return state[3]
+
+ def data(self, role=qt.Qt.DisplayRole):
+ state = self._states[self.__availability]
+ if role == qt.Qt.ForegroundRole:
+ return state[1]
+ elif role == qt.Qt.BackgroundRole:
+ return state[2]
+ else:
+ return None
+
+
+class Hdf5TableModel(HierarchicalTableView.HierarchicalTableModel):
+ """This data model provides access to HDF5 node content (File, Group,
+ Dataset). Main info, like name, file, attributes... are displayed
+ """
+
+ def __init__(self, parent=None, data=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Parent object
+ :param object data: An h5py-like object (file, group or dataset)
+ """
+ super(Hdf5TableModel, self).__init__(parent)
+
+ self.__obj = None
+ self.__data = _TableData(columnCount=5)
+ self.__formatter = None
+ self.__hdf5Formatter = Hdf5Formatter(self)
+ formatter = TextFormatter(self)
+ self.setFormatter(formatter)
+ self.setObject(data)
+
+ def rowCount(self, parent_idx=None):
+ """Returns number of rows to be displayed in table"""
+ return self.__data.rowCount()
+
+ def columnCount(self, parent_idx=None):
+ """Returns number of columns to be displayed in table"""
+ return self.__data.columnCount()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if not index.isValid():
+ return None
+
+ cell = self.__data.cellAt(index.row(), index.column())
+ if cell is None:
+ return None
+
+ if role == self.SpanRole:
+ return cell.span()
+ elif role == self.IsHeaderRole:
+ return cell.isHeader()
+ elif role in (qt.Qt.DisplayRole, qt.Qt.EditRole):
+ value = cell.value()
+ if callable(value):
+ try:
+ value = value(self.__obj)
+ except Exception:
+ cell.invalidateValue()
+ raise
+ return value
+ elif role == qt.Qt.ToolTipRole:
+ value = cell.tooltip()
+ if callable(value):
+ try:
+ value = value(self.__obj)
+ except Exception:
+ cell.invalidateToolTip()
+ raise
+ return value
+ else:
+ return cell.data(role)
+ return None
+
+ def isSupportedObject(self, h5pyObject):
+ """
+ Returns true if the provided object can be modelized using this model.
+ """
+ isSupported = False
+ isSupported = isSupported or silx.io.is_group(h5pyObject)
+ isSupported = isSupported or silx.io.is_dataset(h5pyObject)
+ isSupported = isSupported or isinstance(h5pyObject, silx.gui.hdf5.H5Node)
+ return isSupported
+
+ def setObject(self, h5pyObject):
+ """Set the h5py-like object exposed by the model
+
+ :param h5pyObject: A h5py-like object. It can be a `h5py.Dataset`,
+ a `h5py.File`, a `h5py.Group`. It also can be a,
+ `silx.gui.hdf5.H5Node` which is needed to display some local path
+ information.
+ """
+ self.beginResetModel()
+
+ if h5pyObject is None or self.isSupportedObject(h5pyObject):
+ self.__obj = h5pyObject
+ else:
+ _logger.warning("Object class %s unsupported. Object ignored.", type(h5pyObject))
+ self.__initProperties()
+
+ self.endResetModel()
+
+ def __formatHdf5Type(self, dataset):
+ """Format the HDF5 type"""
+ return self.__hdf5Formatter.humanReadableHdf5Type(dataset)
+
+ def __attributeTooltip(self, attribute):
+ attributeDict = collections.OrderedDict()
+ if hasattr(attribute, "shape"):
+ attributeDict["Shape"] = self.__hdf5Formatter.humanReadableShape(attribute)
+ attributeDict["Data type"] = self.__hdf5Formatter.humanReadableType(attribute, full=True)
+ html = htmlFromDict(attributeDict, title="HDF5 Attribute")
+ return html
+
+ def __formatDType(self, dataset):
+ """Format the numpy dtype"""
+ return self.__hdf5Formatter.humanReadableType(dataset, full=True)
+
+ def __formatShape(self, dataset):
+ """Format the shape"""
+ if dataset.shape is None or len(dataset.shape) <= 1:
+ return self.__hdf5Formatter.humanReadableShape(dataset)
+ size = dataset.size
+ shape = self.__hdf5Formatter.humanReadableShape(dataset)
+ return u"%s = %s" % (shape, size)
+
+ def __formatChunks(self, dataset):
+ """Format the shape"""
+ chunks = dataset.chunks
+ if chunks is None:
+ return ""
+ shape = " \u00D7 ".join([str(i) for i in chunks])
+ sizes = numpy.product(chunks)
+ text = "%s = %s" % (shape, sizes)
+ return text
+
+ def __initProperties(self):
+ """Initialize the list of available properties according to the defined
+ h5py-like object."""
+ self.__data.clear()
+ if self.__obj is None:
+ return
+
+ obj = self.__obj
+
+ hdf5obj = obj
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ hdf5obj = obj.h5py_object
+
+ if silx.io.is_file(hdf5obj):
+ objectType = "File"
+ elif silx.io.is_group(hdf5obj):
+ objectType = "Group"
+ elif silx.io.is_dataset(hdf5obj):
+ objectType = "Dataset"
+ else:
+ objectType = obj.__class__.__name__
+ self.__data.addHeaderRow(headerLabel="HDF5 %s" % objectType)
+
+ SEPARATOR = "::"
+
+ self.__data.addHeaderRow(headerLabel="Path info")
+ showPhysicalLocation = True
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ # helpful informations if the object come from an HDF5 tree
+ self.__data.addHeaderValueRow("Basename", lambda x: x.local_basename)
+ self.__data.addHeaderValueRow("Name", lambda x: x.local_name)
+ local = lambda x: x.local_filename + SEPARATOR + x.local_name
+ self.__data.addHeaderValueRow("Local", local)
+ else:
+ # it's a real H5py object
+ self.__data.addHeaderValueRow("Basename", lambda x: os.path.basename(x.name))
+ self.__data.addHeaderValueRow("Name", lambda x: x.name)
+ if obj.file is not None:
+ self.__data.addHeaderValueRow("File", lambda x: x.file.filename)
+ if hasattr(obj, "path"):
+ # That's a link
+ if hasattr(obj, "filename"):
+ # External link
+ link = lambda x: x.filename + SEPARATOR + x.path
+ else:
+ # Soft link
+ link = lambda x: x.path
+ self.__data.addHeaderValueRow("Link", link)
+ showPhysicalLocation = False
+
+ # External data (nothing to do with external links)
+ nExtSources = 0
+ firstExtSource = None
+ extType = None
+ if silx.io.is_dataset(hdf5obj):
+ if hasattr(hdf5obj, "is_virtual"):
+ if hdf5obj.is_virtual:
+ extSources = hdf5obj.virtual_sources()
+ if extSources:
+ firstExtSource = extSources[0].file_name + SEPARATOR + extSources[0].dset_name
+ extType = "Virtual"
+ nExtSources = len(extSources)
+ if hasattr(hdf5obj, "external"):
+ extSources = hdf5obj.external
+ if extSources:
+ firstExtSource = extSources[0][0]
+ extType = "Raw"
+ nExtSources = len(extSources)
+
+ if showPhysicalLocation:
+ def _physical_location(x):
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ return x.physical_filename + SEPARATOR + x.physical_name
+ elif silx.io.is_file(obj):
+ return x.filename + SEPARATOR + x.name
+ elif obj.file is not None:
+ return x.file.filename + SEPARATOR + x.name
+ else:
+ # Guess it is a virtual node
+ return "No physical location"
+
+ self.__data.addHeaderValueRow("Physical", _physical_location)
+
+ if extType:
+ def _first_source(x):
+ # Absolute path
+ if os.path.isabs(firstExtSource):
+ return firstExtSource
+
+ # Relative path with respect to the file directory
+ if isinstance(obj, silx.gui.hdf5.H5Node):
+ filename = x.physical_filename
+ elif silx.io.is_file(obj):
+ filename = x.filename
+ elif obj.file is not None:
+ filename = x.file.filename
+ else:
+ return firstExtSource
+
+ if firstExtSource[0] == ".":
+ firstExtSource.pop(0)
+ return os.path.join(os.path.dirname(filename), firstExtSource)
+
+ self.__data.addHeaderRow(headerLabel="External sources")
+ self.__data.addHeaderValueRow("Type", extType)
+ self.__data.addHeaderValueRow("Count", str(nExtSources))
+ self.__data.addHeaderValueRow("First", _first_source)
+
+ if hasattr(obj, "dtype"):
+
+ self.__data.addHeaderRow(headerLabel="Data info")
+
+ if hasattr(obj, "id") and hasattr(obj.id, "get_type"):
+ # display the HDF5 type
+ self.__data.addHeaderValueRow("HDF5 type", self.__formatHdf5Type)
+ self.__data.addHeaderValueRow("dtype", self.__formatDType)
+ if hasattr(obj, "shape"):
+ self.__data.addHeaderValueRow("shape", self.__formatShape)
+ if hasattr(obj, "chunks") and obj.chunks is not None:
+ self.__data.addHeaderValueRow("chunks", self.__formatChunks)
+
+ # relative to compression
+ # h5py expose compression, compression_opts but are not initialized
+ # for external plugins, then we use id
+ # h5py also expose fletcher32 and shuffle attributes, but it is also
+ # part of the filters
+ if hasattr(obj, "shape") and hasattr(obj, "id"):
+ if hasattr(obj.id, "get_create_plist"):
+ dcpl = obj.id.get_create_plist()
+ if dcpl.get_nfilters() > 0:
+ self.__data.addHeaderRow(headerLabel="Compression info")
+ pos = _CellData(value="Position", isHeader=True)
+ hdf5id = _CellData(value="HDF5 ID", isHeader=True)
+ name = _CellData(value="Name", isHeader=True)
+ options = _CellData(value="Options", isHeader=True)
+ availability = _CellData(value="", isHeader=True)
+ self.__data.addRow(pos, hdf5id, name, options, availability)
+ for index in range(dcpl.get_nfilters()):
+ filterId, name, options = self.__getFilterInfo(obj, index)
+ pos = _CellData(value=str(index))
+ hdf5id = _CellData(value=str(filterId))
+ name = _CellData(value=name)
+ options = _CellData(value=options)
+ availability = _CellFilterAvailableData(filterId=filterId)
+ self.__data.addRow(pos, hdf5id, name, options, availability)
+
+ if hasattr(obj, "attrs"):
+ if len(obj.attrs) > 0:
+ self.__data.addHeaderRow(headerLabel="Attributes")
+ for key in sorted(obj.attrs.keys()):
+ callback = lambda key, x: self.__formatter.toString(x.attrs[key])
+ callbackTooltip = lambda key, x: self.__attributeTooltip(x.attrs[key])
+ self.__data.addHeaderValueRow(headerLabel=key,
+ value=functools.partial(callback, key),
+ tooltip=functools.partial(callbackTooltip, key))
+
+ def __getFilterInfo(self, dataset, filterIndex):
+ """Get a tuple of readable info from dataset filters
+
+ :param h5py.Dataset dataset: A h5py dataset
+ :param int filterId:
+ """
+ try:
+ dcpl = dataset.id.get_create_plist()
+ info = dcpl.get_filter(filterIndex)
+ filterId, _flags, cdValues, name = info
+ name = self.__formatter.toString(name)
+ options = " ".join([self.__formatter.toString(i) for i in cdValues])
+ return (filterId, name, options)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ return (None, None, None)
+
+ def object(self):
+ """Returns the internal object modelized.
+
+ :rtype: An h5py-like object
+ """
+ return self.__obj
+
+ def setFormatter(self, formatter):
+ """Set the formatter object to be used to display data from the model
+
+ :param TextFormatter formatter: Formatter to use
+ """
+ if formatter is self.__formatter:
+ return
+
+ self.__hdf5Formatter.setTextFormatter(formatter)
+
+ self.beginResetModel()
+
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.disconnect(self.__formatChanged)
+
+ self.__formatter = formatter
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+
+ self.endResetModel()
+
+ def getFormatter(self):
+ """Returns the text formatter used.
+
+ :rtype: TextFormatter
+ """
+ return self.__formatter
+
+ def __formatChanged(self):
+ """Called when the format changed.
+ """
+ self.reset()
+
+
+class Hdf5TableItemDelegate(HierarchicalTableView.HierarchicalItemDelegate):
+ """Item delegate the :class:`Hdf5TableView` with read-only text editor"""
+
+ def createEditor(self, parent, option, index):
+ """See :meth:`QStyledItemDelegate.createEditor`"""
+ editor = super().createEditor(parent, option, index)
+ if isinstance(editor, qt.QLineEdit):
+ editor.setReadOnly(True)
+ editor.deselect()
+ editor.textChanged.connect(self.__textChanged, qt.Qt.QueuedConnection)
+ self.installEventFilter(editor)
+ return editor
+
+ def __textChanged(self, text):
+ sender = self.sender()
+ if sender is not None:
+ sender.deselect()
+
+ def eventFilter(self, watched, event):
+ eventType = event.type()
+ if eventType == qt.QEvent.FocusIn:
+ watched.selectAll()
+ qt.QTimer.singleShot(0, watched.selectAll)
+ elif eventType == qt.QEvent.FocusOut:
+ watched.deselect()
+ return super().eventFilter(watched, event)
+
+
+class Hdf5TableView(HierarchicalTableView.HierarchicalTableView):
+ """A widget to display metadata about a HDF5 node using a table."""
+
+ def __init__(self, parent=None):
+ super(Hdf5TableView, self).__init__(parent)
+ self.setModel(Hdf5TableModel(self))
+ self.setItemDelegate(Hdf5TableItemDelegate(self))
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+ def isSupportedData(self, data):
+ """
+ Returns true if the provided object can be modelized using this model.
+ """
+ return self.model().isSupportedObject(data)
+
+ def setData(self, data):
+ """Set the h5py-like object exposed by the model
+
+ :param data: A h5py-like object. It can be a `h5py.Dataset`,
+ a `h5py.File`, a `h5py.Group`. It also can be a,
+ `silx.gui.hdf5.H5Node` which is needed to display some local path
+ information.
+ """
+ model = self.model()
+
+ model.setObject(data)
+ header = self.horizontalHeader()
+ header.setSectionResizeMode(0, qt.QHeaderView.Fixed)
+ header.setSectionResizeMode(1, qt.QHeaderView.ResizeToContents)
+ header.setSectionResizeMode(2, qt.QHeaderView.Stretch)
+ header.setSectionResizeMode(3, qt.QHeaderView.ResizeToContents)
+ header.setSectionResizeMode(4, qt.QHeaderView.ResizeToContents)
+ header.setStretchLastSection(False)
+
+ for row in range(model.rowCount()):
+ for column in range(model.columnCount()):
+ index = model.index(row, column)
+ if (index.isValid() and index.data(
+ HierarchicalTableView.HierarchicalTableModel.IsHeaderRole) is False):
+ self.openPersistentEditor(index)
diff --git a/src/silx/gui/data/HexaTableView.py b/src/silx/gui/data/HexaTableView.py
new file mode 100644
index 0000000..9e00a7b
--- /dev/null
+++ b/src/silx/gui/data/HexaTableView.py
@@ -0,0 +1,272 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module defines model and widget to display raw data using an
+hexadecimal viewer.
+"""
+from __future__ import division
+
+import collections
+
+import numpy
+
+from silx.gui import qt
+import silx.io.utils
+from silx.gui.widgets.TableWidget import CopySelectedCellsAction
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/05/2018"
+
+
+class _VoidConnector(object):
+ """Byte connector to a numpy.void data.
+
+ It uses a cache of 32 x 1KB and a direct read access API from HDF5.
+ """
+
+ def __init__(self, data):
+ self.__cache = collections.OrderedDict()
+ self.__len = data.itemsize
+ self.__data = data
+
+ def __getBuffer(self, bufferId):
+ if bufferId not in self.__cache:
+ pos = bufferId << 10
+ data = self.__data
+ if hasattr(data, "tobytes"):
+ data = data.tobytes()[pos:pos + 1024]
+ else:
+ # Old fashion
+ data = data.data[pos:pos + 1024]
+
+ self.__cache[bufferId] = data
+ if len(self.__cache) > 32:
+ self.__cache.popitem()
+ else:
+ data = self.__cache[bufferId]
+ return data
+
+ def __getitem__(self, pos):
+ """Returns the value of the byte at the given position.
+
+ :param uint pos: Position of the byte
+ :rtype: int
+ """
+ bufferId = pos >> 10
+ bufferPos = pos & 0b1111111111
+ data = self.__getBuffer(bufferId)
+ return data[bufferPos]
+
+ def __len__(self):
+ """
+ Returns the number of available bytes.
+
+ :rtype: uint
+ """
+ return self.__len
+
+
+class HexaTableModel(qt.QAbstractTableModel):
+ """This data model provides access to a numpy void data.
+
+ Bytes are displayed one by one as a hexadecimal viewer.
+
+ The 16th first columns display bytes as hexadecimal, the last column
+ displays the same data as ASCII.
+
+ :param qt.QObject parent: Parent object
+ :param data: A numpy array or a h5py dataset
+ """
+ def __init__(self, parent=None, data=None):
+ qt.QAbstractTableModel.__init__(self, parent)
+
+ self.__data = None
+ self.__connector = None
+ self.setArrayData(data)
+
+ if hasattr(qt.QFontDatabase, "systemFont"): # Qt >= 5.2
+ self.__font = qt.QFontDatabase.systemFont(qt.QFontDatabase.FixedFont)
+ else: # Qt < 5.2
+ self.__font = qt.QFont("Monospace")
+ self.__font.setStyleHint(qt.QFont.TypeWriter)
+ self.__palette = qt.QPalette()
+
+ def rowCount(self, parent_idx=None):
+ """Returns number of rows to be displayed in table"""
+ if self.__connector is None:
+ return 0
+ return ((len(self.__connector) - 1) >> 4) + 1
+
+ def columnCount(self, parent_idx=None):
+ """Returns number of columns to be displayed in table"""
+ return 0x10 + 1
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if not index.isValid():
+ return None
+
+ if self.__connector is None:
+ return None
+
+ row = index.row()
+ column = index.column()
+
+ if role == qt.Qt.DisplayRole:
+ if column == 0x10:
+ start = (row << 4)
+ text = ""
+ for i in range(0x10):
+ pos = start + i
+ if pos >= len(self.__connector):
+ break
+ value = self.__connector[pos]
+ if value > 0x20 and value < 0x7F:
+ text += chr(value)
+ else:
+ text += "."
+ return text
+ else:
+ pos = (row << 4) + column
+ if pos < len(self.__connector):
+ value = self.__connector[pos]
+ return "%02X" % value
+ else:
+ return ""
+ elif role == qt.Qt.FontRole:
+ return self.__font
+
+ elif role == qt.Qt.BackgroundRole:
+ pos = (row << 4) + column
+ if column != 0x10 and pos >= len(self.__connector):
+ return self.__palette.color(qt.QPalette.Disabled, qt.QPalette.Window)
+ else:
+ return None
+
+ return None
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """Returns the 0-based row or column index, for display in the
+ horizontal and vertical headers"""
+ if section == -1:
+ # PyQt4 send -1 when there is columns but no rows
+ return None
+
+ if role == qt.Qt.DisplayRole:
+ if orientation == qt.Qt.Vertical:
+ return "%02X" % (section << 4)
+ if orientation == qt.Qt.Horizontal:
+ if section == 0x10:
+ return "ASCII"
+ else:
+ return "%02X" % section
+ elif role == qt.Qt.FontRole:
+ return self.__font
+ elif role == qt.Qt.TextAlignmentRole:
+ if orientation == qt.Qt.Vertical:
+ return qt.Qt.AlignRight
+ if orientation == qt.Qt.Horizontal:
+ if section == 0x10:
+ return qt.Qt.AlignLeft
+ else:
+ return qt.Qt.AlignCenter
+ return None
+
+ def flags(self, index):
+ """QAbstractTableModel method to inform the view whether data
+ is editable or not.
+ """
+ row = index.row()
+ column = index.column()
+ pos = (row << 4) + column
+ if column != 0x10 and pos >= len(self.__connector):
+ return qt.Qt.NoItemFlags
+ return qt.QAbstractTableModel.flags(self, index)
+
+ def setArrayData(self, data):
+ """Set the data array.
+
+ :param data: A numpy object or a dataset.
+ """
+ self.beginResetModel()
+
+ self.__connector = None
+ self.__data = data
+ if self.__data is not None:
+ if silx.io.utils.is_dataset(self.__data):
+ data = data[()]
+ elif isinstance(self.__data, numpy.ndarray):
+ data = data[()]
+ self.__connector = _VoidConnector(data)
+
+ self.endResetModel()
+
+ def arrayData(self):
+ """Returns the internal data.
+
+ :rtype: numpy.ndarray of h5py.Dataset
+ """
+ return self.__data
+
+
+class HexaTableView(qt.QTableView):
+ """TableView using HexaTableModel as default model.
+
+ It customs the column size to provide a better layout.
+ """
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: parent QWidget
+ """
+ qt.QTableView.__init__(self, parent)
+
+ model = HexaTableModel(self)
+ self.setModel(model)
+ self._copyAction = CopySelectedCellsAction(self)
+ self.addAction(self._copyAction)
+
+ def copy(self):
+ self._copyAction.trigger()
+
+ def setArrayData(self, data):
+ """Set the data array.
+
+ :param data: A numpy object or a dataset.
+ """
+ self.model().setArrayData(data)
+ self.__fixHeader()
+
+ def __fixHeader(self):
+ """Update the view according to the state of the auto-resize"""
+ header = self.horizontalHeader()
+ header.setDefaultSectionSize(30)
+ header.setStretchLastSection(True)
+ for i in range(0x10):
+ header.setSectionResizeMode(i, qt.QHeaderView.Fixed)
+ header.setSectionResizeMode(0x10, qt.QHeaderView.Stretch)
diff --git a/src/silx/gui/data/NXdataWidgets.py b/src/silx/gui/data/NXdataWidgets.py
new file mode 100644
index 0000000..54ea287
--- /dev/null
+++ b/src/silx/gui/data/NXdataWidgets.py
@@ -0,0 +1,1086 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines widgets used by _NXdataView.
+"""
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "12/11/2018"
+
+import logging
+import numpy
+
+from silx.gui import qt
+from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
+from silx.gui.plot import Plot1D, Plot2D, StackView, ScatterView, items
+from silx.gui.plot.ComplexImageView import ComplexImageView
+from silx.gui.colors import Colormap
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+
+from silx.math.calibration import ArrayCalibration, NoCalibration, LinearCalibration
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ArrayCurvePlot(qt.QWidget):
+ """
+ Widget for plotting a curve from a multi-dimensional signal array
+ and a 1D axis array.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last dimension must have the same length as
+ the axis array.
+
+ The widget provides sliders to select indices on the first (n - 1)
+ dimensions of the signal array, and buttons to add/replace selected
+ curves to the plot.
+
+ This widget also handles simple 2D or 3D scatter plots (third dimension
+ displayed as colour of points).
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayCurvePlot, self).__init__(parent)
+
+ self.__signals = None
+ self.__signals_names = None
+ self.__signal_errors = None
+ self.__axis = None
+ self.__axis_name = None
+ self.__x_axis_errors = None
+ self.__values = None
+
+ self._plot = Plot1D(self)
+
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+
+ self._plot.sigActiveCurveChanged.connect(self._setYLabelFromActiveLegend)
+
+ layout = qt.QVBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot)
+ layout.addWidget(self._selector)
+
+ self.setLayout(layout)
+
+ def getPlot(self):
+ """Returns the plot used for the display
+
+ :rtype: Plot1D
+ """
+ return self._plot
+
+ def setCurvesData(self, ys, x=None,
+ yerror=None, xerror=None,
+ ylabels=None, xlabel=None, title=None,
+ xscale=None, yscale=None):
+ """
+
+ :param List[ndarray] ys: List of arrays to be represented by the y (vertical) axis.
+ It can be multiple n-D array whose last dimension must
+ have the same length as x (and values must be None)
+ :param ndarray x: 1-D dataset used as the curve's x values. If provided,
+ its lengths must be equal to the length of the last dimension of
+ ``y`` (and equal to the length of ``value``, for a scatter plot).
+ :param ndarray yerror: Single array of errors for y (same shape), or None.
+ There can only be one array, and it applies to the first/main y
+ (no y errors for auxiliary_signals curves).
+ :param ndarray xerror: 1-D dataset of errors for x, or None
+ :param str ylabels: Labels for each curve's Y axis
+ :param str xlabel: Label for X axis
+ :param str title: Graph title
+ :param str xscale: Scale of X axis in (None, 'linear', 'log')
+ :param str yscale: Scale of Y axis in (None, 'linear', 'log')
+ """
+ self.__signals = ys
+ self.__signals_names = ylabels or (["Y"] * len(ys))
+ self.__signal_errors = yerror
+ self.__axis = x
+ self.__axis_name = xlabel
+ self.__x_axis_errors = xerror
+
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateCurve)
+ self.__selector_is_connected = False
+ self._selector.setData(ys[0])
+ self._selector.setAxisNames(["Y"])
+
+ if len(ys[0].shape) < 2:
+ self._selector.hide()
+ else:
+ self._selector.show()
+
+ self._plot.setGraphTitle(title or "")
+ if xscale is not None:
+ self._plot.getXAxis().setScale(
+ 'log' if xscale == 'log' else 'linear')
+ if yscale is not None:
+ self._plot.getYAxis().setScale(
+ 'log' if yscale == 'log' else 'linear')
+ self._updateCurve()
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateCurve)
+ self.__selector_is_connected = True
+
+ def _updateCurve(self):
+ selection = self._selector.selection()
+ ys = [sig[selection] for sig in self.__signals]
+ y0 = ys[0]
+ len_y = len(y0)
+ x = self.__axis
+ if x is None:
+ x = numpy.arange(len_y)
+ elif numpy.isscalar(x) or len(x) == 1:
+ # constant axis
+ x = x * numpy.ones_like(y0)
+ elif len(x) == 2 and len_y != 2:
+ # linear calibration a + b * x
+ x = x[0] + x[1] * numpy.arange(len_y)
+
+ # Only remove curves that will no longer belong to the plot
+ # So remaining curves keep their settings
+ for item in self._plot.getItems():
+ if (isinstance(item, items.Curve) and
+ item.getName() not in self.__signals_names):
+ self._plot.remove(item)
+
+ for i in range(len(self.__signals)):
+ legend = self.__signals_names[i]
+
+ # errors only supported for primary signal in NXdata
+ y_errors = None
+ if i == 0 and self.__signal_errors is not None:
+ y_errors = self.__signal_errors[self._selector.selection()]
+ self._plot.addCurve(x, ys[i], legend=legend,
+ xerror=self.__x_axis_errors,
+ yerror=y_errors)
+ if i == 0:
+ self._plot.setActiveCurve(legend)
+
+ self._plot.resetZoom()
+ self._plot.getXAxis().setLabel(self.__axis_name)
+ self._plot.getYAxis().setLabel(self.__signals_names[0])
+
+ def _setYLabelFromActiveLegend(self, previous_legend, new_legend):
+ for ylabel in self.__signals_names:
+ if new_legend is not None and new_legend == ylabel:
+ self._plot.getYAxis().setLabel(ylabel)
+ break
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._plot.clear()
+
+
+class XYVScatterPlot(qt.QWidget):
+ """
+ Widget for plotting one or more scatters
+ (with identical x, y coordinates).
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(XYVScatterPlot, self).__init__(parent)
+
+ self.__y_axis = None
+ """1D array"""
+ self.__y_axis_name = None
+ self.__values = None
+ """List of 1D arrays (for multiple scatters with identical
+ x, y coordinates)"""
+
+ self.__x_axis = None
+ self.__x_axis_name = None
+ self.__x_axis_errors = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+ self.__y_axis_errors = None
+
+ self._plot = ScatterView(self)
+ self._plot.setColormap(Colormap(name="viridis",
+ vmin=None, vmax=None,
+ normalization=Colormap.LINEAR))
+
+ self._slider = HorizontalSliderWithBrowser(parent=self)
+ self._slider.setMinimum(0)
+ self._slider.setValue(0)
+ self._slider.valueChanged[int].connect(self._sliderIdxChanged)
+ self._slider.setToolTip("Select auxiliary signals")
+
+ layout = qt.QGridLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot, 0, 0)
+ layout.addWidget(self._slider, 1, 0)
+
+ self.setLayout(layout)
+
+ def _sliderIdxChanged(self, value):
+ self._updateScatter()
+
+ def getScatterView(self):
+ """Returns the :class:`ScatterView` used for the display
+
+ :rtype: ScatterView
+ """
+ return self._plot
+
+ def getPlot(self):
+ """Returns the plot used for the display
+
+ :rtype: PlotWidget
+ """
+ return self._plot.getPlotWidget()
+
+ def setScattersData(self, y, x, values,
+ yerror=None, xerror=None,
+ ylabel=None, xlabel=None,
+ title="", scatter_titles=None,
+ xscale=None, yscale=None):
+ """
+
+ :param ndarray y: 1D array for y (vertical) coordinates.
+ :param ndarray x: 1D array for x coordinates.
+ :param List[ndarray] values: List of 1D arrays of values.
+ This will be used to compute the color map and assign colors
+ to the points. There should be as many arrays in the list as
+ scatters to be represented.
+ :param ndarray yerror: 1D array of errors for y (same shape), or None.
+ :param ndarray xerror: 1D array of errors for x, or None
+ :param str ylabel: Label for Y axis
+ :param str xlabel: Label for X axis
+ :param str title: Main graph title
+ :param List[str] scatter_titles: Subtitles (one per scatter)
+ :param str xscale: Scale of X axis in (None, 'linear', 'log')
+ :param str yscale: Scale of Y axis in (None, 'linear', 'log')
+ """
+ self.__y_axis = y
+ self.__x_axis = x
+ self.__x_axis_name = xlabel or "X"
+ self.__y_axis_name = ylabel or "Y"
+ self.__x_axis_errors = xerror
+ self.__y_axis_errors = yerror
+ self.__values = values
+
+ self.__graph_title = title or ""
+ self.__scatter_titles = scatter_titles
+
+ self._slider.valueChanged[int].disconnect(self._sliderIdxChanged)
+ self._slider.setMaximum(len(values) - 1)
+ if len(values) > 1:
+ self._slider.show()
+ else:
+ self._slider.hide()
+ self._slider.setValue(0)
+ self._slider.valueChanged[int].connect(self._sliderIdxChanged)
+
+ if xscale is not None:
+ self._plot.getXAxis().setScale(
+ 'log' if xscale == 'log' else 'linear')
+ if yscale is not None:
+ self._plot.getYAxis().setScale(
+ 'log' if yscale == 'log' else 'linear')
+
+ self._updateScatter()
+
+ def _updateScatter(self):
+ x = self.__x_axis
+ y = self.__y_axis
+
+ idx = self._slider.value()
+
+ if self.__graph_title:
+ title = self.__graph_title # main NXdata @title
+ if len(self.__scatter_titles) > 1:
+ # Append dataset name only when there is many datasets
+ title += '\n' + self.__scatter_titles[idx]
+ else:
+ title = self.__scatter_titles[idx] # scatter dataset name
+
+ self._plot.setGraphTitle(title)
+ self._plot.setData(x, y, self.__values[idx],
+ xerror=self.__x_axis_errors,
+ yerror=self.__y_axis_errors)
+ self._plot.resetZoom()
+ self._plot.getXAxis().setLabel(self.__x_axis_name)
+ self._plot.getYAxis().setLabel(self.__y_axis_name)
+
+ def clear(self):
+ self._plot.getPlotWidget().clear()
+
+
+class ArrayImagePlot(qt.QWidget):
+ """
+ Widget for plotting an image from a multi-dimensional signal array
+ and two 1D axes array.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last two dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 2) dimensions of
+ the signal array, and the plot is updated to show the image corresponding
+ to the selection.
+
+ If one or both of the axes does not have regularly spaced values, the
+ the image is plotted as a coloured scatter plot.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayImagePlot, self).__init__(parent)
+
+ self.__signals = None
+ self.__signals_names = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+
+ self._plot = Plot2D(self)
+ self._plot.setDefaultColormap(Colormap(name="viridis",
+ vmin=None, vmax=None,
+ normalization=Colormap.LINEAR))
+ self._plot.getIntensityHistogramAction().setVisible(True)
+ self._plot.setKeepDataAspectRatio(True)
+ maskToolWidget = self._plot.getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
+
+ # not closable
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self._selector.selectionChanged.connect(self._updateImage)
+
+ self._auxSigSlider = HorizontalSliderWithBrowser(parent=self)
+ self._auxSigSlider.setMinimum(0)
+ self._auxSigSlider.setValue(0)
+ self._auxSigSlider.valueChanged[int].connect(self._sliderIdxChanged)
+ self._auxSigSlider.setToolTip("Select auxiliary signals")
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._plot)
+ layout.addWidget(self._selector)
+ layout.addWidget(self._auxSigSlider)
+
+ self.setLayout(layout)
+
+ def _sliderIdxChanged(self, value):
+ self._updateImage()
+
+ def getPlot(self):
+ """Returns the plot used for the display
+
+ :rtype: Plot2D
+ """
+ return self._plot
+
+ def setImageData(self, signals,
+ x_axis=None, y_axis=None,
+ signals_names=None,
+ xlabel=None, ylabel=None,
+ title=None, isRgba=False,
+ xscale=None, yscale=None):
+ """
+
+ :param signals: list of n-D datasets, whose last 2 dimensions are used as the
+ image's values, or list of 3D datasets interpreted as RGBA image.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param signals_names: Names for each image, used as subtitle and legend.
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param title: Graph title
+ :param isRgba: True if data is a 3D RGBA image
+ :param str xscale: Scale of X axis in (None, 'linear', 'log')
+ :param str yscale: Scale of Y axis in (None, 'linear', 'log')
+ """
+ self._selector.selectionChanged.disconnect(self._updateImage)
+ self._auxSigSlider.valueChanged.disconnect(self._sliderIdxChanged)
+
+ self.__signals = signals
+ self.__signals_names = signals_names
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__title = title
+
+ self._selector.clear()
+ if not isRgba:
+ self._selector.setAxisNames(["Y", "X"])
+ img_ndim = 2
+ else:
+ self._selector.setAxisNames(["Y", "X", "RGB(A) channel"])
+ img_ndim = 3
+ self._selector.setData(signals[0])
+
+ if len(signals[0].shape) <= img_ndim:
+ self._selector.hide()
+ else:
+ self._selector.show()
+
+ self._auxSigSlider.setMaximum(len(signals) - 1)
+ if len(signals) > 1:
+ self._auxSigSlider.show()
+ else:
+ self._auxSigSlider.hide()
+ self._auxSigSlider.setValue(0)
+
+ self._axis_scales = xscale, yscale
+ self._updateImage()
+ self._plot.resetZoom()
+
+ self._selector.selectionChanged.connect(self._updateImage)
+ self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged)
+
+ def _updateImage(self):
+ selection = self._selector.selection()
+ auxSigIdx = self._auxSigSlider.value()
+
+ legend = self.__signals_names[auxSigIdx]
+
+ images = [img[selection] for img in self.__signals]
+ image = images[auxSigIdx]
+
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+
+ if x_axis is None and y_axis is None:
+ xcalib = NoCalibration()
+ ycalib = NoCalibration()
+ else:
+ if x_axis is None:
+ # no calibration
+ x_axis = numpy.arange(image.shape[1])
+ elif numpy.isscalar(x_axis) or len(x_axis) == 1:
+ # constant axis
+ x_axis = x_axis * numpy.ones((image.shape[1], ))
+ elif len(x_axis) == 2:
+ # linear calibration
+ x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1]
+
+ if y_axis is None:
+ y_axis = numpy.arange(image.shape[0])
+ elif numpy.isscalar(y_axis) or len(y_axis) == 1:
+ y_axis = y_axis * numpy.ones((image.shape[0], ))
+ elif len(y_axis) == 2:
+ y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1]
+
+ xcalib = ArrayCalibration(x_axis)
+ ycalib = ArrayCalibration(y_axis)
+
+ self._plot.remove(kind=("scatter", "image",))
+ if xcalib.is_affine() and ycalib.is_affine():
+ # regular image
+ xorigin, xscale = xcalib(0), xcalib.get_slope()
+ yorigin, yscale = ycalib(0), ycalib.get_slope()
+ origin = (xorigin, yorigin)
+ scale = (xscale, yscale)
+
+ self._plot.getXAxis().setScale('linear')
+ self._plot.getYAxis().setScale('linear')
+ self._plot.addImage(image, legend=legend,
+ origin=origin, scale=scale,
+ replace=True, resetzoom=False)
+ else:
+ xaxisscale, yaxisscale = self._axis_scales
+
+ if xaxisscale is not None:
+ self._plot.getXAxis().setScale(
+ 'log' if xaxisscale == 'log' else 'linear')
+ if yaxisscale is not None:
+ self._plot.getYAxis().setScale(
+ 'log' if yaxisscale == 'log' else 'linear')
+
+ scatterx, scattery = numpy.meshgrid(x_axis, y_axis)
+ # fixme: i don't think this can handle "irregular" RGBA images
+ self._plot.addScatter(numpy.ravel(scatterx),
+ numpy.ravel(scattery),
+ numpy.ravel(image),
+ legend=legend)
+
+ if self.__title:
+ title = self.__title
+ if len(self.__signals_names) > 1:
+ # Append dataset name only when there is many datasets
+ title += '\n' + self.__signals_names[auxSigIdx]
+ else:
+ title = self.__signals_names[auxSigIdx]
+ self._plot.setGraphTitle(title)
+ self._plot.getXAxis().setLabel(self.__x_axis_name)
+ self._plot.getYAxis().setLabel(self.__y_axis_name)
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._plot.clear()
+
+
+class ArrayComplexImagePlot(qt.QWidget):
+ """
+ Widget for plotting an image of complex from a multi-dimensional signal array
+ and two 1D axes array.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last two dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 2) dimensions of
+ the signal array, and the plot is updated to show the image corresponding
+ to the selection.
+
+ If one or both of the axes does not have regularly spaced values, the
+ the image is plotted as a coloured scatter plot.
+ """
+ def __init__(self, parent=None, colormap=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayComplexImagePlot, self).__init__(parent)
+
+ self.__signals = None
+ self.__signals_names = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+
+ self._plot = ComplexImageView(self)
+ if colormap is not None:
+ for mode in (ComplexImageView.ComplexMode.ABSOLUTE,
+ ComplexImageView.ComplexMode.SQUARE_AMPLITUDE,
+ ComplexImageView.ComplexMode.REAL,
+ ComplexImageView.ComplexMode.IMAGINARY):
+ self._plot.setColormap(colormap, mode)
+
+ self._plot.getPlot().getIntensityHistogramAction().setVisible(True)
+ self._plot.setKeepDataAspectRatio(True)
+ maskToolWidget = self._plot.getPlot().getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
+
+ # not closable
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self._selector.selectionChanged.connect(self._updateImage)
+
+ self._auxSigSlider = HorizontalSliderWithBrowser(parent=self)
+ self._auxSigSlider.setMinimum(0)
+ self._auxSigSlider.setValue(0)
+ self._auxSigSlider.valueChanged[int].connect(self._sliderIdxChanged)
+ self._auxSigSlider.setToolTip("Select auxiliary signals")
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._plot)
+ layout.addWidget(self._selector)
+ layout.addWidget(self._auxSigSlider)
+
+ self.setLayout(layout)
+
+ def _sliderIdxChanged(self, value):
+ self._updateImage()
+
+ def getPlot(self):
+ """Returns the plot used for the display
+
+ :rtype: PlotWidget
+ """
+ return self._plot.getPlot()
+
+ def setImageData(self, signals,
+ x_axis=None, y_axis=None,
+ signals_names=None,
+ xlabel=None, ylabel=None,
+ title=None):
+ """
+
+ :param signals: list of n-D datasets, whose last 2 dimensions are used as the
+ image's values, or list of 3D datasets interpreted as RGBA image.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param signals_names: Names for each image, used as subtitle and legend.
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param title: Graph title
+ """
+ self._selector.selectionChanged.disconnect(self._updateImage)
+ self._auxSigSlider.valueChanged.disconnect(self._sliderIdxChanged)
+
+ self.__signals = signals
+ self.__signals_names = signals_names
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__title = title
+
+ self._selector.clear()
+ self._selector.setAxisNames(["Y", "X"])
+ self._selector.setData(signals[0])
+
+ if len(signals[0].shape) <= 2:
+ self._selector.hide()
+ else:
+ self._selector.show()
+
+ self._auxSigSlider.setMaximum(len(signals) - 1)
+ if len(signals) > 1:
+ self._auxSigSlider.show()
+ else:
+ self._auxSigSlider.hide()
+ self._auxSigSlider.setValue(0)
+
+ self._updateImage()
+ self._plot.getPlot().resetZoom()
+
+ self._selector.selectionChanged.connect(self._updateImage)
+ self._auxSigSlider.valueChanged.connect(self._sliderIdxChanged)
+
+ def _updateImage(self):
+ selection = self._selector.selection()
+ auxSigIdx = self._auxSigSlider.value()
+
+ images = [img[selection] for img in self.__signals]
+ image = images[auxSigIdx]
+
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+
+ if x_axis is None and y_axis is None:
+ xcalib = NoCalibration()
+ ycalib = NoCalibration()
+ else:
+ if x_axis is None:
+ # no calibration
+ x_axis = numpy.arange(image.shape[1])
+ elif numpy.isscalar(x_axis) or len(x_axis) == 1:
+ # constant axis
+ x_axis = x_axis * numpy.ones((image.shape[1], ))
+ elif len(x_axis) == 2:
+ # linear calibration
+ x_axis = x_axis[0] * numpy.arange(image.shape[1]) + x_axis[1]
+
+ if y_axis is None:
+ y_axis = numpy.arange(image.shape[0])
+ elif numpy.isscalar(y_axis) or len(y_axis) == 1:
+ y_axis = y_axis * numpy.ones((image.shape[0], ))
+ elif len(y_axis) == 2:
+ y_axis = y_axis[0] * numpy.arange(image.shape[0]) + y_axis[1]
+
+ xcalib = ArrayCalibration(x_axis)
+ ycalib = ArrayCalibration(y_axis)
+
+ self._plot.setData(image)
+ if xcalib.is_affine():
+ xorigin, xscale = xcalib(0), xcalib.get_slope()
+ else:
+ _logger.warning("Unsupported complex image X axis calibration")
+ xorigin, xscale = 0., 1.
+
+ if ycalib.is_affine():
+ yorigin, yscale = ycalib(0), ycalib.get_slope()
+ else:
+ _logger.warning("Unsupported complex image Y axis calibration")
+ yorigin, yscale = 0., 1.
+
+ self._plot.setOrigin((xorigin, yorigin))
+ self._plot.setScale((xscale, yscale))
+
+ if self.__title:
+ title = self.__title
+ if len(self.__signals_names) > 1:
+ # Append dataset name only when there is many datasets
+ title += '\n' + self.__signals_names[auxSigIdx]
+ else:
+ title = self.__signals_names[auxSigIdx]
+ self._plot.setGraphTitle(title)
+ self._plot.getXAxis().setLabel(self.__x_axis_name)
+ self._plot.getYAxis().setLabel(self.__y_axis_name)
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._plot.setData(None)
+
+
+class ArrayStackPlot(qt.QWidget):
+ """
+ Widget for plotting a n-D array (n >= 3) as a stack of images.
+ Three axis arrays can be provided to calibrate the axes.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last 3 dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 3) dimensions of
+ the signal array, and the plot is updated to load the stack corresponding
+ to the selection.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayStackPlot, self).__init__(parent)
+
+ self.__signal = None
+ self.__signal_name = None
+ # the Z, Y, X axes apply to the last three dimensions of the signal
+ # (in that order)
+ self.__z_axis = None
+ self.__z_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+
+ self._stack_view = StackView(self)
+ maskToolWidget = self._stack_view.getPlotWidget().getMaskToolsDockWidget().widget()
+ maskToolWidget.setItemMaskUpdated(True)
+
+ self._hline = qt.QFrame(self)
+ self._hline.setFrameStyle(qt.QFrame.HLine)
+ self._hline.setFrameShadow(qt.QFrame.Sunken)
+ self._legend = qt.QLabel(self)
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._stack_view)
+ layout.addWidget(self._hline)
+ layout.addWidget(self._legend)
+ layout.addWidget(self._selector)
+
+ self.setLayout(layout)
+
+ def getStackView(self):
+ """Returns the plot used for the display
+
+ :rtype: StackView
+ """
+ return self._stack_view
+
+ def setStackData(self, signal,
+ x_axis=None, y_axis=None, z_axis=None,
+ signal_name=None,
+ xlabel=None, ylabel=None, zlabel=None,
+ title=None):
+ """
+
+ :param signal: n-D dataset, whose last 3 dimensions are used as the
+ 3D stack values.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param z_axis: 1-D dataset used as the image's z. If provided,
+ its lengths must be equal to the length of the 3rd to last
+ dimension of ``signal``.
+ :param signal_name: Label used in the legend
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param zlabel: Label for Z axis
+ :param title: Graph title
+ """
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateStack)
+ self.__selector_is_connected = False
+
+ self.__signal = signal
+ self.__signal_name = signal_name or ""
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__z_axis = z_axis
+ self.__z_axis_name = zlabel
+
+ self._selector.setData(signal)
+ self._selector.setAxisNames(["Y", "X", "Z"])
+
+ self._stack_view.setGraphTitle(title or "")
+ # by default, the z axis is the image position (dimension not plotted)
+ self._stack_view.getPlotWidget().getXAxis().setLabel(self.__x_axis_name or "X")
+ self._stack_view.getPlotWidget().getYAxis().setLabel(self.__y_axis_name or "Y")
+
+ self._updateStack()
+
+ ndims = len(signal.shape)
+ self._stack_view.setFirstStackDimension(ndims - 3)
+
+ # the legend label shows the selection slice producing the volume
+ # (only interesting for ndim > 3)
+ if ndims > 3:
+ self._selector.setVisible(True)
+ self._legend.setVisible(True)
+ self._hline.setVisible(True)
+ else:
+ self._selector.setVisible(False)
+ self._legend.setVisible(False)
+ self._hline.setVisible(False)
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateStack)
+ self.__selector_is_connected = True
+
+ @staticmethod
+ def _get_origin_scale(axis):
+ """Assuming axis is a regularly spaced 1D array,
+ return a tuple (origin, scale) where:
+ - origin = axis[0]
+ - scale = (axis[n-1] - axis[0]) / (n -1)
+ :param axis: 1D numpy array
+ :return: Tuple (axis[0], (axis[-1] - axis[0]) / (len(axis) - 1))
+ """
+ return axis[0], (axis[-1] - axis[0]) / (len(axis) - 1)
+
+ def _updateStack(self):
+ """Update displayed stack according to the current axes selector
+ data."""
+ stk = self._selector.selectedData()
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+ z_axis = self.__z_axis
+
+ calibrations = []
+ for axis in [z_axis, y_axis, x_axis]:
+
+ if axis is None:
+ calibrations.append(NoCalibration())
+ elif len(axis) == 2:
+ calibrations.append(
+ LinearCalibration(y_intercept=axis[0],
+ slope=axis[1]))
+ else:
+ calibrations.append(ArrayCalibration(axis))
+
+ legend = self.__signal_name + "["
+ for sl in self._selector.selection():
+ if sl == slice(None):
+ legend += ":, "
+ else:
+ legend += str(sl) + ", "
+ legend = legend[:-2] + "]"
+ self._legend.setText("Displayed data: " + legend)
+
+ self._stack_view.setStack(stk, calibrations=calibrations)
+ self._stack_view.setLabels(
+ labels=[self.__z_axis_name,
+ self.__y_axis_name,
+ self.__x_axis_name])
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self._stack_view.clear()
+
+
+class ArrayVolumePlot(qt.QWidget):
+ """
+ Widget for plotting a n-D array (n >= 3) as a 3D scalar field.
+ Three axis arrays can be provided to calibrate the axes.
+
+ The signal array can have an arbitrary number of dimensions, the only
+ limitation being that the last 3 dimensions must have the same length as
+ the axes arrays.
+
+ Sliders are provided to select indices on the first (n - 3) dimensions of
+ the signal array, and the plot is updated to load the stack corresponding
+ to the selection.
+ """
+ def __init__(self, parent=None):
+ """
+
+ :param parent: Parent QWidget
+ """
+ super(ArrayVolumePlot, self).__init__(parent)
+
+ self.__signal = None
+ self.__signal_name = None
+ # the Z, Y, X axes apply to the last three dimensions of the signal
+ # (in that order)
+ self.__z_axis = None
+ self.__z_axis_name = None
+ self.__y_axis = None
+ self.__y_axis_name = None
+ self.__x_axis = None
+ self.__x_axis_name = None
+
+ from ._VolumeWindow import VolumeWindow
+
+ self._view = VolumeWindow(self)
+
+ self._hline = qt.QFrame(self)
+ self._hline.setFrameStyle(qt.QFrame.HLine)
+ self._hline.setFrameShadow(qt.QFrame.Sunken)
+ self._legend = qt.QLabel(self)
+ self._selector = NumpyAxesSelector(self)
+ self._selector.setNamedAxesSelectorVisibility(False)
+ self.__selector_is_connected = False
+
+ layout = qt.QVBoxLayout()
+ layout.addWidget(self._view)
+ layout.addWidget(self._hline)
+ layout.addWidget(self._legend)
+ layout.addWidget(self._selector)
+
+ self.setLayout(layout)
+
+ def getVolumeView(self):
+ """Returns the plot used for the display
+
+ :rtype: SceneWindow
+ """
+ return self._view
+
+ def setData(self, signal,
+ x_axis=None, y_axis=None, z_axis=None,
+ signal_name=None,
+ xlabel=None, ylabel=None, zlabel=None,
+ title=None):
+ """
+
+ :param signal: n-D dataset, whose last 3 dimensions are used as the
+ 3D stack values.
+ :param x_axis: 1-D dataset used as the image's x coordinates. If
+ provided, its lengths must be equal to the length of the last
+ dimension of ``signal``.
+ :param y_axis: 1-D dataset used as the image's y. If provided,
+ its lengths must be equal to the length of the 2nd to last
+ dimension of ``signal``.
+ :param z_axis: 1-D dataset used as the image's z. If provided,
+ its lengths must be equal to the length of the 3rd to last
+ dimension of ``signal``.
+ :param signal_name: Label used in the legend
+ :param xlabel: Label for X axis
+ :param ylabel: Label for Y axis
+ :param zlabel: Label for Z axis
+ :param title: Graph title
+ """
+ if self.__selector_is_connected:
+ self._selector.selectionChanged.disconnect(self._updateVolume)
+ self.__selector_is_connected = False
+
+ self.__signal = signal
+ self.__signal_name = signal_name or ""
+ self.__x_axis = x_axis
+ self.__x_axis_name = xlabel
+ self.__y_axis = y_axis
+ self.__y_axis_name = ylabel
+ self.__z_axis = z_axis
+ self.__z_axis_name = zlabel
+
+ self._selector.setData(signal)
+ self._selector.setAxisNames(["Y", "X", "Z"])
+
+ self._updateVolume()
+
+ # the legend label shows the selection slice producing the volume
+ # (only interesting for ndim > 3)
+ if signal.ndim > 3:
+ self._selector.setVisible(True)
+ self._legend.setVisible(True)
+ self._hline.setVisible(True)
+ else:
+ self._selector.setVisible(False)
+ self._legend.setVisible(False)
+ self._hline.setVisible(False)
+
+ if not self.__selector_is_connected:
+ self._selector.selectionChanged.connect(self._updateVolume)
+ self.__selector_is_connected = True
+
+ def _updateVolume(self):
+ """Update displayed stack according to the current axes selector
+ data."""
+ x_axis = self.__x_axis
+ y_axis = self.__y_axis
+ z_axis = self.__z_axis
+
+ offset = []
+ scale = []
+ for axis in [x_axis, y_axis, z_axis]:
+ if axis is None:
+ calibration = NoCalibration()
+ elif len(axis) == 2:
+ calibration = LinearCalibration(
+ y_intercept=axis[0], slope=axis[1])
+ else:
+ calibration = ArrayCalibration(axis)
+ if not calibration.is_affine():
+ _logger.warning("Axis has not linear values, ignored")
+ offset.append(0.)
+ scale.append(1.)
+ else:
+ offset.append(calibration(0))
+ scale.append(calibration.get_slope())
+
+ legend = self.__signal_name + "["
+ for sl in self._selector.selection():
+ if sl == slice(None):
+ legend += ":, "
+ else:
+ legend += str(sl) + ", "
+ legend = legend[:-2] + "]"
+ self._legend.setText("Displayed data: " + legend)
+
+ # Update SceneWidget
+ data = self._selector.selectedData()
+
+ volumeView = self.getVolumeView()
+ volumeView.setData(data, offset=offset, scale=scale)
+ volumeView.setAxesLabels(
+ self.__x_axis_name, self.__y_axis_name, self.__z_axis_name)
+
+ def clear(self):
+ old = self._selector.blockSignals(True)
+ self._selector.clear()
+ self._selector.blockSignals(old)
+ self.getVolumeView().clear()
diff --git a/silx/gui/data/NumpyAxesSelector.py b/src/silx/gui/data/NumpyAxesSelector.py
index e6da0d4..e6da0d4 100644
--- a/silx/gui/data/NumpyAxesSelector.py
+++ b/src/silx/gui/data/NumpyAxesSelector.py
diff --git a/src/silx/gui/data/RecordTableView.py b/src/silx/gui/data/RecordTableView.py
new file mode 100644
index 0000000..ea73c62
--- /dev/null
+++ b/src/silx/gui/data/RecordTableView.py
@@ -0,0 +1,439 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module define model and widget to display 1D slices from numpy
+array using compound data types or hdf5 databases.
+"""
+from __future__ import division
+
+import itertools
+import numpy
+from silx.gui import qt
+import silx.io
+from .TextFormatter import TextFormatter
+from silx.gui.widgets.TableWidget import CopySelectedCellsAction
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/08/2018"
+
+
+class _MultiLineItem(qt.QItemDelegate):
+ """Draw a multiline text without hiding anything.
+
+ The paint method display a cell without any wrap. And an editor is
+ available to scroll into the selected cell.
+ """
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: Parent of the widget
+ """
+ qt.QItemDelegate.__init__(self, parent)
+ self.__textOptions = qt.QTextOption()
+ self.__textOptions.setFlags(qt.QTextOption.IncludeTrailingSpaces |
+ qt.QTextOption.ShowTabsAndSpaces)
+ self.__textOptions.setWrapMode(qt.QTextOption.NoWrap)
+ self.__textOptions.setAlignment(qt.Qt.AlignTop | qt.Qt.AlignLeft)
+
+ def paint(self, painter, option, index):
+ """
+ Write multiline text without using any wrap or any alignment according
+ to the cell size.
+
+ :param qt.QPainter painter: Painter context used to displayed the cell
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ painter.save()
+
+ # set colors
+ painter.setPen(qt.QPen(qt.Qt.NoPen))
+ if option.state & qt.QStyle.State_Selected:
+ brush = option.palette.highlight()
+ painter.setBrush(brush)
+ else:
+ brush = index.data(qt.Qt.BackgroundRole)
+ if brush is None:
+ # default background color for a cell
+ brush = qt.Qt.white
+ painter.setBrush(brush)
+ painter.drawRect(option.rect)
+
+ if index.isValid():
+ if option.state & qt.QStyle.State_Selected:
+ brush = option.palette.highlightedText()
+ else:
+ brush = index.data(qt.Qt.ForegroundRole)
+ if brush is None:
+ brush = option.palette.text()
+ painter.setPen(qt.QPen(brush.color()))
+ text = index.data(qt.Qt.DisplayRole)
+ painter.drawText(qt.QRectF(option.rect), text, self.__textOptions)
+
+ painter.restore()
+
+ def createEditor(self, parent, option, index):
+ """
+ Returns the widget used to edit the item specified by index for editing.
+
+ We use it not to edit the content but to show the content with a
+ convenient scroll bar.
+
+ :param qt.QWidget parent: Parent of the widget
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ if not index.isValid():
+ return super(_MultiLineItem, self).createEditor(parent, option, index)
+
+ editor = qt.QTextEdit(parent)
+ editor.setReadOnly(True)
+ return editor
+
+ def setEditorData(self, editor, index):
+ """
+ Read data from the model and feed the editor.
+
+ :param qt.QWidget editor: Editor widget
+ :param qt.QIndex index: Index of the data to display
+ """
+ text = index.model().data(index, qt.Qt.EditRole)
+ editor.setText(text)
+
+ def updateEditorGeometry(self, editor, option, index):
+ """
+ Update the geometry of the editor according to the changes of the view.
+
+ :param qt.QWidget editor: Editor widget
+ :param qt.QStyleOptionViewItem option: Control how the editor is shown
+ :param qt.QIndex index: Index of the data to display
+ """
+ editor.setGeometry(option.rect)
+
+
+class RecordTableModel(qt.QAbstractTableModel):
+ """This data model provides access to 1D slices from numpy array using
+ compound data types or hdf5 databases.
+
+ Each entries are displayed in a single row, and each columns contain a
+ specific field of the compound type.
+
+ It also allows to display 1D arrays of simple data types.
+ array.
+
+ :param qt.QObject parent: Parent object
+ :param numpy.ndarray data: A numpy array or a h5py dataset
+ """
+
+ MAX_NUMBER_OF_ROWS = 10e6
+ """Maximum number of display values of the dataset"""
+
+ def __init__(self, parent=None, data=None):
+ qt.QAbstractTableModel.__init__(self, parent)
+
+ self.__data = None
+ self.__is_array = False
+ self.__fields = None
+ self.__formatter = None
+ self.__editFormatter = None
+ self.setFormatter(TextFormatter(self))
+
+ # set _data
+ self.setArrayData(data)
+
+ # Methods to be implemented to subclass QAbstractTableModel
+ def rowCount(self, parent_idx=None):
+ """Returns number of rows to be displayed in table"""
+ if self.__data is None:
+ return 0
+ elif not self.__is_array:
+ return 1
+ else:
+ return min(len(self.__data), self.MAX_NUMBER_OF_ROWS)
+
+ def columnCount(self, parent_idx=None):
+ """Returns number of columns to be displayed in table"""
+ if self.__fields is None:
+ return 1
+ else:
+ return len(self.__fields)
+
+ def __clippedData(self, role=qt.Qt.DisplayRole):
+ """Return data for cells representing clipped data"""
+ if role == qt.Qt.DisplayRole:
+ return "..."
+ elif role == qt.Qt.ToolTipRole:
+ return "Dataset is too large: display is clipped"
+ else:
+ return None
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ """QAbstractTableModel method to access data values
+ in the format ready to be displayed"""
+ if not index.isValid():
+ return None
+
+ if self.__data is None:
+ return None
+
+ # Special display of one before last data for clipped table
+ if self.__isClipped() and index.row() == self.rowCount() - 2:
+ return self.__clippedData(role)
+
+ if self.__is_array:
+ row = index.row()
+ if row >= self.rowCount():
+ return None
+ elif self.__isClipped() and row == self.rowCount() - 1:
+ # Clipped array, display last value at the end
+ data = self.__data[-1]
+ else:
+ data = self.__data[row]
+ else:
+ if index.row() > 0:
+ return None
+ data = self.__data
+
+ if self.__fields is not None:
+ if index.column() >= len(self.__fields):
+ return None
+ key = self.__fields[index.column()][1]
+ data = data[key[0]]
+ if len(key) > 1:
+ data = data[key[1]]
+
+ # no dtype in case of 1D array of unicode objects (#2093)
+ dtype = getattr(data, "dtype", None)
+
+ if role == qt.Qt.DisplayRole:
+ return self.__formatter.toString(data, dtype=dtype)
+ elif role == qt.Qt.EditRole:
+ return self.__editFormatter.toString(data, dtype=dtype)
+ return None
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ """Returns the 0-based row or column index, for display in the
+ horizontal and vertical headers"""
+ if section == -1:
+ # PyQt4 send -1 when there is columns but no rows
+ return None
+
+ # Handle clipping of huge tables
+ if (self.__isClipped() and
+ orientation == qt.Qt.Vertical and
+ section == self.rowCount() - 2):
+ return self.__clippedData(role)
+
+ if role == qt.Qt.DisplayRole:
+ if orientation == qt.Qt.Vertical:
+ if not self.__is_array:
+ return "Scalar"
+ elif section == self.MAX_NUMBER_OF_ROWS - 1:
+ return str(len(self.__data) - 1)
+ else:
+ return str(section)
+ if orientation == qt.Qt.Horizontal:
+ if self.__fields is None:
+ if section == 0:
+ return "Data"
+ else:
+ return None
+ else:
+ if section < len(self.__fields):
+ return self.__fields[section][0]
+ else:
+ return None
+ return None
+
+ def flags(self, index):
+ """QAbstractTableModel method to inform the view whether data
+ is editable or not.
+ """
+ return qt.QAbstractTableModel.flags(self, index)
+
+ def __isClipped(self) -> bool:
+ """Returns whether the displayed array is clipped or not"""
+ return self.__data is not None and self.__is_array and len(self.__data) > self.MAX_NUMBER_OF_ROWS
+
+ def setArrayData(self, data):
+ """Set the data array and the viewing perspective.
+
+ You can set ``copy=False`` if you need more performances, when dealing
+ with a large numpy array. In this case, a simple reference to the data
+ is used to access the data, rather than a copy of the array.
+
+ .. warning::
+
+ Any change to the data model will affect your original data
+ array, when using a reference rather than a copy..
+
+ :param data: 1D numpy array, or any object that can be
+ converted to a numpy array using ``numpy.array(data)`` (e.g.
+ a nested sequence).
+ """
+ self.beginResetModel()
+
+ self.__data = data
+ if isinstance(data, numpy.ndarray):
+ self.__is_array = True
+ elif silx.io.is_dataset(data) and data.shape != tuple():
+ self.__is_array = True
+ else:
+ self.__is_array = False
+
+ self.__fields = []
+ if data is not None:
+ if data.dtype.fields is not None:
+ fields = sorted(data.dtype.fields.items(), key=lambda e: e[1][1])
+ for name, (dtype, _index) in fields:
+ if dtype.shape != tuple():
+ keys = itertools.product(*[range(x) for x in dtype.shape])
+ for key in keys:
+ label = "%s%s" % (name, list(key))
+ array_key = (name, key)
+ self.__fields.append((label, array_key))
+ else:
+ self.__fields.append((name, (name,)))
+ else:
+ self.__fields = None
+
+ self.endResetModel()
+
+ def arrayData(self):
+ """Returns the internal data.
+
+ :rtype: numpy.ndarray of h5py.Dataset
+ """
+ return self.__data
+
+ def setFormatter(self, formatter):
+ """Set the formatter object to be used to display data from the model
+
+ :param TextFormatter formatter: Formatter to use
+ """
+ if formatter is self.__formatter:
+ return
+
+ self.beginResetModel()
+
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.disconnect(self.__formatChanged)
+
+ self.__formatter = formatter
+ self.__editFormatter = TextFormatter(formatter)
+ self.__editFormatter.setUseQuoteForText(False)
+
+ if self.__formatter is not None:
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+
+ self.endResetModel()
+
+ def getFormatter(self):
+ """Returns the text formatter used.
+
+ :rtype: TextFormatter
+ """
+ return self.__formatter
+
+ def __formatChanged(self):
+ """Called when the format changed.
+ """
+ self.__editFormatter = TextFormatter(self, self.getFormatter())
+ self.__editFormatter.setUseQuoteForText(False)
+ self.reset()
+
+
+class _ShowEditorProxyModel(qt.QIdentityProxyModel):
+ """
+ Allow to custom the flag edit of the model
+ """
+
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QObject arent: parent object
+ """
+ super(_ShowEditorProxyModel, self).__init__(parent)
+ self.__forceEditable = False
+
+ def flags(self, index):
+ flag = qt.QIdentityProxyModel.flags(self, index)
+ if self.__forceEditable:
+ flag = flag | qt.Qt.ItemIsEditable
+ return flag
+
+ def forceCellEditor(self, show):
+ """
+ Enable the editable flag to allow to display cell editor.
+ """
+ if self.__forceEditable == show:
+ return
+ self.beginResetModel()
+ self.__forceEditable = show
+ self.endResetModel()
+
+
+class RecordTableView(qt.QTableView):
+ """TableView using DatabaseTableModel as default model.
+ """
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param qt.QWidget parent: parent QWidget
+ """
+ qt.QTableView.__init__(self, parent)
+
+ model = _ShowEditorProxyModel(self)
+ self._model = RecordTableModel()
+ model.setSourceModel(self._model)
+ self.setModel(model)
+
+ self.__multilineView = _MultiLineItem(self)
+ self.setEditTriggers(qt.QAbstractItemView.AllEditTriggers)
+ self._copyAction = CopySelectedCellsAction(self)
+ self.addAction(self._copyAction)
+
+ def copy(self):
+ self._copyAction.trigger()
+
+ def setArrayData(self, data):
+ model = self.model()
+ sourceModel = model.sourceModel()
+ sourceModel.setArrayData(data)
+
+ if data is not None:
+ if issubclass(data.dtype.type, (numpy.string_, numpy.unicode_)):
+ # TODO it would be nice to also fix fields
+ # but using it only for string array is already very useful
+ self.setItemDelegateForColumn(0, self.__multilineView)
+ model.forceCellEditor(True)
+ else:
+ self.setItemDelegateForColumn(0, None)
+ model.forceCellEditor(False)
diff --git a/src/silx/gui/data/TextFormatter.py b/src/silx/gui/data/TextFormatter.py
new file mode 100644
index 0000000..b6baca4
--- /dev/null
+++ b/src/silx/gui/data/TextFormatter.py
@@ -0,0 +1,386 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a class sharred by widget from the
+data module to format data as text in the same way."""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "24/07/2018"
+
+import logging
+import numbers
+
+import numpy
+
+from silx.gui import qt
+
+import h5py
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TextFormatter(qt.QObject):
+ """Formatter to convert data to string.
+
+ The method :meth:`toString` returns a formatted string from an input data
+ using parameters set to this object.
+
+ It support most python and numpy data, expecting dictionary. Unsupported
+ data are displayed using the string representation of the object (`str`).
+
+ It provides a set of parameters to custom the formatting of integer and
+ float values (:meth:`setIntegerFormat`, :meth:`setFloatFormat`).
+
+ It also allows to custom the use of quotes to display text data
+ (:meth:`setUseQuoteForText`), and custom unit used to display imaginary
+ numbers (:meth:`setImaginaryUnit`).
+
+ The object emit an event `formatChanged` every time a parametter is
+ changed.
+ """
+
+ formatChanged = qt.Signal()
+ """Emitted when properties of the formatter change."""
+
+ def __init__(self, parent=None, formatter=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Owner of the object
+ :param TextFormatter formatter: Instantiate this object from the
+ formatter
+ """
+ qt.QObject.__init__(self, parent)
+ if formatter is not None:
+ self.__integerFormat = formatter.integerFormat()
+ self.__floatFormat = formatter.floatFormat()
+ self.__useQuoteForText = formatter.useQuoteForText()
+ self.__imaginaryUnit = formatter.imaginaryUnit()
+ self.__enumFormat = formatter.enumFormat()
+ else:
+ self.__integerFormat = "%d"
+ self.__floatFormat = "%g"
+ self.__useQuoteForText = True
+ self.__imaginaryUnit = u"j"
+ self.__enumFormat = u"%(name)s(%(value)d)"
+
+ def integerFormat(self):
+ """Returns the format string controlling how the integer data
+ are formated by this object.
+
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+
+ :rtype: str
+ """
+ return self.__integerFormat
+
+ def setIntegerFormat(self, value):
+ """Set format string controlling how the integer data are
+ formated by this object.
+
+ :param str value: Format string (e.g. "%d", "%i", "%08i").
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+ """
+ if self.__integerFormat == value:
+ return
+ self.__integerFormat = value
+ self.formatChanged.emit()
+
+ def floatFormat(self):
+ """Returns the format string controlling how the floating-point data
+ are formated by this object.
+
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+
+ :rtype: str
+ """
+ return self.__floatFormat
+
+ def setFloatFormat(self, value):
+ """Set format string controlling how the floating-point data are
+ formated by this object.
+
+ :param str value: Format string (e.g. "%.3f", "%d", "%-10.2f",
+ "%10.3e").
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+ """
+ if self.__floatFormat == value:
+ return
+ self.__floatFormat = value
+ self.formatChanged.emit()
+
+ def useQuoteForText(self):
+ """Returns true if the string data are formatted using double quotes.
+
+ Else, no quotes are used.
+ """
+ return self.__integerFormat
+
+ def setUseQuoteForText(self, useQuote):
+ """Set the use of quotes to delimit string data.
+
+ :param bool useQuote: True to use quotes.
+ """
+ if self.__useQuoteForText == useQuote:
+ return
+ self.__useQuoteForText = useQuote
+ self.formatChanged.emit()
+
+ def imaginaryUnit(self):
+ """Returns the unit display for imaginary numbers.
+
+ :rtype: str
+ """
+ return self.__imaginaryUnit
+
+ def setImaginaryUnit(self, imaginaryUnit):
+ """Set the unit display for imaginary numbers.
+
+ :param str imaginaryUnit: Unit displayed after imaginary numbers
+ """
+ if self.__imaginaryUnit == imaginaryUnit:
+ return
+ self.__imaginaryUnit = imaginaryUnit
+ self.formatChanged.emit()
+
+ def setEnumFormat(self, value):
+ """Set format string controlling how the enum data are
+ formated by this object.
+
+ :param str value: Format string (e.g. "%(name)s(%(value)d)").
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+ """
+ if self.__enumFormat == value:
+ return
+ self.__enumFormat = value
+ self.formatChanged.emit()
+
+ def enumFormat(self):
+ """Returns the format string controlling how the enum data
+ are formated by this object.
+
+ This is the C-style format string used by python when formatting
+ strings with the modulus operator.
+
+ :rtype: str
+ """
+ return self.__enumFormat
+
+ def __formatText(self, text):
+ if self.__useQuoteForText:
+ text = "\"%s\"" % text.replace("\\", "\\\\").replace("\"", "\\\"")
+ return text
+
+ def __formatBinary(self, data):
+ if isinstance(data, numpy.void):
+ data = data.item()
+ if isinstance(data, numpy.ndarray):
+ # Before numpy 1.15.0 the item API was returning a numpy array
+ data = data.astype(numpy.uint8)
+ else:
+ # Now it is supposed to be a bytes type
+ pass
+ data = ["\\x%02X" % d for d in data]
+ if self.__useQuoteForText:
+ return "b\"%s\"" % "".join(data)
+ else:
+ return "".join(data)
+
+ def __formatSafeAscii(self, data):
+ data = [chr(d) if (d > 0x20 and d < 0x7F) else "\\x%02X" % d for d in data]
+ if self.__useQuoteForText:
+ data = [c if c != '"' else "\\" + c for c in data]
+ return "b\"%s\"" % "".join(data)
+ else:
+ return "".join(data)
+
+ def __formatCharString(self, data):
+ """Format text of char.
+
+ From the specifications we expect to have ASCII, but we also allow
+ CP1252 in some ceases as fallback.
+
+ If no encoding fits, it will display a readable ASCII chars, with
+ escaped chars (using the python syntax) for non decoded characters.
+
+ :param data: A binary string of char expected in ASCII
+ :rtype: str
+ """
+ try:
+ text = "%s" % data.decode("ascii")
+ return self.__formatText(text)
+ except UnicodeDecodeError:
+ # Here we can spam errors, this is definitly a badly
+ # generated file
+ _logger.error("Invalid ASCII string %s.", data)
+ if data == b"\xB0":
+ _logger.error("Fallback using cp1252 encoding")
+ return self.__formatText(u"\u00B0")
+ return self.__formatSafeAscii(data)
+
+ def __formatH5pyObject(self, data, dtype):
+ # That's an HDF5 object
+ ref = h5py.check_dtype(ref=dtype)
+ if ref is not None:
+ if bool(data):
+ return "REF"
+ else:
+ return "NULL_REF"
+ vlen = h5py.check_dtype(vlen=dtype)
+ if vlen is not None:
+ if vlen == str:
+ # HDF5 UTF8
+ # With h5py>=3 reading dataset returns bytes
+ if isinstance(data, (bytes, numpy.bytes_)):
+ try:
+ data = data.decode("utf-8")
+ except UnicodeDecodeError:
+ self.__formatSafeAscii(data)
+ return self.__formatText(data)
+ elif vlen == bytes:
+ # HDF5 ASCII
+ return self.__formatCharString(data)
+ elif isinstance(vlen, numpy.dtype):
+ return self.toString(data, vlen)
+ return None
+
+ def toString(self, data, dtype=None):
+ """Format a data into a string using formatter options
+
+ :param object data: Data to render
+ :param dtype: enforce a dtype (mostly used to remember the h5py dtype,
+ special h5py dtypes are not propagated from array to items)
+ :rtype: str
+ """
+ if isinstance(data, tuple):
+ text = [self.toString(d) for d in data]
+ return "(" + " ".join(text) + ")"
+ elif isinstance(data, list):
+ text = [self.toString(d) for d in data]
+ return "[" + " ".join(text) + "]"
+ elif isinstance(data, numpy.ndarray):
+ if dtype is None:
+ dtype = data.dtype
+ if data.shape == ():
+ # it is a scaler
+ return self.toString(data[()], dtype)
+ else:
+ text = [self.toString(d, dtype) for d in data]
+ return "[" + " ".join(text) + "]"
+ if dtype is not None and dtype.kind == 'O':
+ text = self.__formatH5pyObject(data, dtype)
+ if text is not None:
+ return text
+ elif isinstance(data, numpy.void):
+ if dtype is None:
+ dtype = data.dtype
+ if dtype.fields is not None:
+ text = []
+ for index, field in enumerate(dtype.fields.items()):
+ text.append(field[0] + ":" + self.toString(data[index], field[1][0]))
+ return "(" + " ".join(text) + ")"
+ return self.__formatBinary(data)
+ elif isinstance(data, (numpy.unicode_, str)):
+ return self.__formatText(data)
+ elif isinstance(data, (numpy.string_, bytes)):
+ if dtype is None and hasattr(data, "dtype"):
+ dtype = data.dtype
+ if dtype is not None:
+ # Maybe a sub item from HDF5
+ if dtype.kind == 'S':
+ return self.__formatCharString(data)
+ elif dtype.kind == 'O':
+ text = self.__formatH5pyObject(data, dtype)
+ if text is not None:
+ return text
+ try:
+ # Try ascii/utf-8
+ text = "%s" % data.decode("utf-8")
+ return self.__formatText(text)
+ except UnicodeDecodeError:
+ pass
+ return self.__formatBinary(data)
+ elif isinstance(data, str):
+ text = "%s" % data
+ return self.__formatText(text)
+ elif isinstance(data, (numpy.integer)):
+ if dtype is None:
+ dtype = data.dtype
+ enumType = h5py.check_dtype(enum=dtype)
+ if enumType is not None:
+ for key, value in enumType.items():
+ if value == data:
+ result = {}
+ result["name"] = key
+ result["value"] = data
+ return self.__enumFormat % result
+ return self.__integerFormat % data
+ elif isinstance(data, (numbers.Integral)):
+ return self.__integerFormat % data
+ elif isinstance(data, (numbers.Real, numpy.floating)):
+ # It have to be done before complex checking
+ return self.__floatFormat % data
+ elif isinstance(data, (numpy.complexfloating, numbers.Complex)):
+ text = ""
+ if data.real != 0:
+ text += self.__floatFormat % data.real
+ if data.real != 0 and data.imag != 0:
+ if data.imag < 0:
+ template = self.__floatFormat + " - " + self.__floatFormat + self.__imaginaryUnit
+ params = (data.real, -data.imag)
+ else:
+ template = self.__floatFormat + " + " + self.__floatFormat + self.__imaginaryUnit
+ params = (data.real, data.imag)
+ else:
+ if data.imag != 0:
+ template = self.__floatFormat + self.__imaginaryUnit
+ params = (data.imag)
+ else:
+ template = self.__floatFormat
+ params = (data.real)
+ return template % params
+ elif isinstance(data, h5py.h5r.Reference):
+ dtype = h5py.special_dtype(ref=h5py.Reference)
+ text = self.__formatH5pyObject(data, dtype)
+ return text
+ elif isinstance(data, h5py.h5r.RegionReference):
+ dtype = h5py.special_dtype(ref=h5py.RegionReference)
+ text = self.__formatH5pyObject(data, dtype)
+ return text
+ elif isinstance(data, numpy.object_) or dtype is not None:
+ if dtype is None:
+ dtype = data.dtype
+ text = self.__formatH5pyObject(data, dtype)
+ if text is not None:
+ return text
+ # That's a numpy object
+ return str(data)
+ return str(data)
diff --git a/silx/gui/data/_RecordPlot.py b/src/silx/gui/data/_RecordPlot.py
index 5be792f..5be792f 100644
--- a/silx/gui/data/_RecordPlot.py
+++ b/src/silx/gui/data/_RecordPlot.py
diff --git a/silx/gui/data/_VolumeWindow.py b/src/silx/gui/data/_VolumeWindow.py
index 03b6876..03b6876 100644
--- a/silx/gui/data/_VolumeWindow.py
+++ b/src/silx/gui/data/_VolumeWindow.py
diff --git a/silx/gui/data/__init__.py b/src/silx/gui/data/__init__.py
index 560062d..560062d 100644
--- a/silx/gui/data/__init__.py
+++ b/src/silx/gui/data/__init__.py
diff --git a/silx/gui/data/setup.py b/src/silx/gui/data/setup.py
index 23ccbdd..23ccbdd 100644
--- a/silx/gui/data/setup.py
+++ b/src/silx/gui/data/setup.py
diff --git a/src/silx/gui/data/test/__init__.py b/src/silx/gui/data/test/__init__.py
new file mode 100644
index 0000000..7790ee5
--- /dev/null
+++ b/src/silx/gui/data/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/data/test/test_arraywidget.py b/src/silx/gui/data/test/test_arraywidget.py
new file mode 100644
index 0000000..c84a34f
--- /dev/null
+++ b/src/silx/gui/data/test/test_arraywidget.py
@@ -0,0 +1,316 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+import os
+import tempfile
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.data import ArrayTableWidget
+from silx.gui.data.ArrayTableModel import ArrayTableModel
+from silx.gui.utils.testutils import TestCaseQt
+
+import h5py
+
+
+class TestArrayWidget(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+ def setUp(self):
+ super(TestArrayWidget, self).setUp()
+ self.aw = ArrayTableWidget.ArrayTableWidget()
+
+ def tearDown(self):
+ del self.aw
+ super(TestArrayWidget, self).tearDown()
+
+ def testShow(self):
+ """test for errors"""
+ self.aw.show()
+ self.qWaitForWindowExposed(self.aw)
+
+ def testSetData0D(self):
+ a = 1
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ # scalar/0D data has no frame index
+ self.assertEqual(len(self.aw.model._index), 0)
+ # and no perspective
+ self.assertEqual(len(self.aw.model._perspective), 0)
+
+ def testSetData1D(self):
+ a = [1, 2]
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ # 1D data has no frame index
+ self.assertEqual(len(self.aw.model._index), 0)
+ # and no perspective
+ self.assertEqual(len(self.aw.model._perspective), 0)
+
+ def testSetData4D(self):
+ a = numpy.reshape(numpy.linspace(0.213, 1.234, 1250),
+ (5, 5, 5, 10))
+ self.aw.setArrayData(a)
+
+ # default perspective (0, 1)
+ self.assertEqual(list(self.aw.model._perspective),
+ [0, 1])
+ self.aw.setPerspective((1, 3))
+ self.assertEqual(list(self.aw.model._perspective),
+ [1, 3])
+
+ b = self.aw.getData(copy=True)
+ self.assertTrue(numpy.array_equal(a, b))
+
+ # 4D data has a 2-tuple as frame index
+ self.assertEqual(len(self.aw.model._index), 2)
+ # default index is (0, 0)
+ self.assertEqual(list(self.aw.model._index),
+ [0, 0])
+ self.aw.setFrameIndex((3, 1))
+
+ self.assertEqual(list(self.aw.model._index),
+ [3, 1])
+
+ def testColors(self):
+ a = numpy.arange(256, dtype=numpy.uint8)
+ self.aw.setArrayData(a)
+
+ bgcolor = numpy.empty(a.shape + (3,), dtype=numpy.uint8)
+ # Black & white palette
+ bgcolor[..., 0] = a
+ bgcolor[..., 1] = a
+ bgcolor[..., 2] = a
+
+ fgcolor = numpy.bitwise_xor(bgcolor, 255)
+
+ self.aw.setArrayColors(bgcolor, fgcolor)
+
+ # test colors are as expected in model
+ for i in range(256):
+ # all RGB channels for BG equal to data value
+ self.assertEqual(
+ self.aw.model.data(self.aw.model.index(0, i),
+ role=qt.Qt.BackgroundRole),
+ qt.QColor(i, i, i),
+ "Unexpected background color"
+ )
+
+ # all RGB channels for FG equal to XOR(data value, 255)
+ self.assertEqual(
+ self.aw.model.data(self.aw.model.index(0, i),
+ role=qt.Qt.ForegroundRole),
+ qt.QColor(i ^ 255, i ^ 255, i ^ 255),
+ "Unexpected text color"
+ )
+
+ # test colors are reset to None when a new data array is loaded
+ # with different shape
+ self.aw.setArrayData(numpy.arange(300))
+
+ for i in range(300):
+ # all RGB channels for BG equal to data value
+ self.assertIsNone(
+ self.aw.model.data(self.aw.model.index(0, i),
+ role=qt.Qt.BackgroundRole))
+
+ def testDefaultFlagNotEditable(self):
+ """editable should be False by default, in setArrayData"""
+ self.aw.setArrayData([[0]])
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertFalse(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ def testFlagEditable(self):
+ self.aw.setArrayData([[0]], editable=True)
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertTrue(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ def testFlagNotEditable(self):
+ self.aw.setArrayData([[0]], editable=False)
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertFalse(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ def testReferenceReturned(self):
+ """when setting the data with copy=False and
+ retrieving it with getData(copy=False), we should recover
+ the same original object.
+ """
+ # n-D (n >=2)
+ a0 = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
+ (10, 10, 10))
+ self.aw.setArrayData(a0, copy=False)
+ a1 = self.aw.getData(copy=False)
+
+ self.assertIs(a0, a1)
+
+ # 1D
+ b0 = numpy.linspace(0.213, 1.234, 1000)
+ self.aw.setArrayData(b0, copy=False)
+ b1 = self.aw.getData(copy=False)
+ self.assertIs(b0, b1)
+
+ def testClipping(self):
+ """Test clipping of large arrays"""
+ self.aw.show()
+ self.qWaitForWindowExposed(self.aw)
+
+ data = numpy.arange(ArrayTableModel.MAX_NUMBER_OF_SECTIONS + 10)
+
+ for shape in [(1, -1), (-1, 1)]:
+ with self.subTest(shape=shape):
+ self.aw.setArrayData(data.reshape(shape), editable=True)
+ self.qapp.processEvents()
+
+
+class TestH5pyArrayWidget(TestCaseQt):
+ """Basic test for ArrayTableWidget with a dataset.
+
+ Test flags, for dataset open in read-only or read-write modes"""
+ def setUp(self):
+ super(TestH5pyArrayWidget, self).setUp()
+ self.aw = ArrayTableWidget.ArrayTableWidget()
+ self.data = numpy.reshape(numpy.linspace(0.213, 1.234, 1000),
+ (10, 10, 10))
+ # create an h5py file with a dataset
+ self.tempdir = tempfile.mkdtemp()
+ self.h5_fname = os.path.join(self.tempdir, "array.h5")
+ h5f = h5py.File(self.h5_fname, mode='w')
+ h5f["my_array"] = self.data
+ h5f["my_scalar"] = 3.14
+ h5f["my_1D_array"] = numpy.array(numpy.arange(1000))
+ h5f.close()
+
+ def tearDown(self):
+ del self.aw
+ os.unlink(self.h5_fname)
+ os.rmdir(self.tempdir)
+ super(TestH5pyArrayWidget, self).tearDown()
+
+ def testShow(self):
+ self.aw.show()
+ self.qWaitForWindowExposed(self.aw)
+
+ def testReadOnly(self):
+ """Open H5 dataset in read-only mode, ensure the model is not editable."""
+ h5f = h5py.File(self.h5_fname, "r")
+ a = h5f["my_array"]
+ # ArrayTableModel relies on following condition
+ self.assertTrue(a.file.mode == "r")
+
+ self.aw.setArrayData(a, copy=False, editable=True)
+
+ self.assertIsInstance(a, h5py.Dataset) # simple sanity check
+ # internal representation must be a reference to original data (copy=False)
+ self.assertIsInstance(self.aw.model._array, h5py.Dataset)
+ self.assertTrue(self.aw.model._array.file.mode == "r")
+
+ b = self.aw.getData()
+ self.assertTrue(numpy.array_equal(self.data, b))
+
+ # model must have detected read-only dataset and disabled editing
+ self.assertFalse(self.aw.model._editable)
+ idx = self.aw.model.createIndex(0, 0)
+ self.assertFalse(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+
+ # force editing read-only datasets raises IOError
+ self.assertRaises(IOError, self.aw.model.setData,
+ idx, 123.4, role=qt.Qt.EditRole)
+ h5f.close()
+
+ def testReadWrite(self):
+ h5f = h5py.File(self.h5_fname, "r+")
+ a = h5f["my_array"]
+ self.assertTrue(a.file.mode == "r+")
+
+ self.aw.setArrayData(a, copy=False, editable=True)
+ b = self.aw.getData(copy=False)
+ self.assertTrue(numpy.array_equal(self.data, b))
+
+ idx = self.aw.model.createIndex(0, 0)
+ # model is editable
+ self.assertTrue(
+ self.aw.model.flags(idx) & qt.Qt.ItemIsEditable)
+ h5f.close()
+
+ def testSetData0D(self):
+ h5f = h5py.File(self.h5_fname, "r+")
+ a = h5f["my_scalar"]
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ h5f.close()
+
+ def testSetData1D(self):
+ h5f = h5py.File(self.h5_fname, "r+")
+ a = h5f["my_1D_array"]
+ self.aw.setArrayData(a)
+ b = self.aw.getData(copy=True)
+
+ self.assertTrue(numpy.array_equal(a, b))
+
+ h5f.close()
+
+ def testReferenceReturned(self):
+ """when setting the data with copy=False and
+ retrieving it with getData(copy=False), we should recover
+ the same original object.
+
+ This only works for array with at least 2D. For 1D and 0D
+ arrays, a view is created at some point, which in the case
+ of an hdf5 dataset creates a copy."""
+ h5f = h5py.File(self.h5_fname, "r+")
+
+ # n-D
+ a0 = h5f["my_array"]
+ self.aw.setArrayData(a0, copy=False)
+ a1 = self.aw.getData(copy=False)
+ self.assertIs(a0, a1)
+
+ # 1D
+ b0 = h5f["my_1D_array"]
+ self.aw.setArrayData(b0, copy=False)
+ b1 = self.aw.getData(copy=False)
+ self.assertIs(b0, b1)
+
+ h5f.close()
diff --git a/src/silx/gui/data/test/test_dataviewer.py b/src/silx/gui/data/test/test_dataviewer.py
new file mode 100644
index 0000000..30b76ce
--- /dev/null
+++ b/src/silx/gui/data/test/test_dataviewer.py
@@ -0,0 +1,304 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "19/02/2019"
+
+import os
+import tempfile
+import pytest
+from contextlib import contextmanager
+
+import numpy
+from ..DataViewer import DataViewer
+from ..DataViews import DataView
+from .. import DataViews
+
+from silx.gui import qt
+
+from silx.gui.data.DataViewerFrame import DataViewerFrame
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.utils.testutils import TestCaseQt
+
+import h5py
+
+
+class _DataViewMock(DataView):
+ """Dummy view to display nothing"""
+
+ def __init__(self, parent):
+ DataView.__init__(self, parent)
+
+ def axesNames(self, data, info):
+ return []
+
+ def createWidget(self, parent):
+ return qt.QLabel(parent)
+
+ def getDataPriority(self, data, info):
+ return 0
+
+
+class _TestAbstractDataViewer(TestCaseQt):
+ __test__ = False # ignore abstract class
+
+ def create_widget(self):
+ # Avoid to raise an error when testing the full module
+ self.skipTest("Not implemented")
+
+ @contextmanager
+ def h5_temporary_file(self):
+ # create tmp file
+ fd, tmp_name = tempfile.mkstemp(suffix=".h5")
+ os.close(fd)
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ # create h5 data
+ h5file = h5py.File(tmp_name, "w")
+ h5file["data"] = data
+ yield h5file
+ # clean up
+ h5file.close()
+ os.unlink(tmp_name)
+
+ def test_text_data(self):
+ data_list = ["aaa", int, 8, self]
+ widget = self.create_widget()
+ for data in data_list:
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+
+ def test_plot_1d_data(self):
+ data = numpy.arange(3 ** 1)
+ data.shape = [3] * 1
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViews.PLOT1D_MODE, availableModes)
+
+ def test_image_data(self):
+ data = numpy.arange(3 ** 2)
+ data.shape = [3] * 2
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViews.IMAGE_MODE, availableModes)
+
+ def test_image_bool(self):
+ data = numpy.zeros((10, 10), dtype=bool)
+ data[::2, ::2] = True
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViews.IMAGE_MODE, availableModes)
+
+ def test_image_complex_data(self):
+ data = numpy.arange(3 ** 2, dtype=numpy.complex64)
+ data.shape = [3] * 2
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+ self.assertIn(DataViews.IMAGE_MODE, availableModes)
+
+ def test_plot_3d_data(self):
+ data = numpy.arange(3 ** 3)
+ data.shape = [3] * 3
+ widget = self.create_widget()
+ widget.setData(data)
+ availableModes = set([v.modeId() for v in widget.currentAvailableViews()])
+ try:
+ import silx.gui.plot3d # noqa
+ self.assertIn(DataViews.PLOT3D_MODE, availableModes)
+ except ImportError:
+ self.assertIn(DataViews.STACK_MODE, availableModes)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayMode())
+
+ def test_array_1d_data(self):
+ data = numpy.array(["aaa"] * (3 ** 1))
+ data.shape = [3] * 1
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
+
+ def test_array_2d_data(self):
+ data = numpy.array(["aaa"] * (3 ** 2))
+ data.shape = [3] * 2
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
+
+ def test_array_4d_data(self):
+ data = numpy.array(["aaa"] * (3 ** 4))
+ data.shape = [3] * 4
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
+
+ def test_record_4d_data(self):
+ data = numpy.zeros(3 ** 4, dtype='3int8, float32, (2,3)float64')
+ data.shape = [3] * 4
+ widget = self.create_widget()
+ widget.setData(data)
+ self.assertEqual(DataViews.RAW_MODE, widget.displayedView().modeId())
+
+ def test_3d_h5_dataset(self):
+ with self.h5_temporary_file() as h5file:
+ dataset = h5file["data"]
+ widget = self.create_widget()
+ widget.setData(dataset)
+
+ def test_data_event(self):
+ listener = SignalListener()
+ widget = self.create_widget()
+ widget.dataChanged.connect(listener)
+ widget.setData(10)
+ widget.setData(None)
+ self.assertEqual(listener.callCount(), 2)
+
+ def test_display_mode_event(self):
+ listener = SignalListener()
+ widget = self.create_widget()
+ widget.displayedViewChanged.connect(listener)
+ widget.setData(10)
+ widget.setData(None)
+ modes = [v.modeId() for v in listener.arguments(argumentIndex=0)]
+ self.assertEqual(modes, [DataViews.RAW_MODE, DataViews.EMPTY_MODE])
+ listener.clear()
+
+ def test_change_display_mode(self):
+ data = numpy.arange(10 ** 4)
+ data.shape = [10] * 4
+ widget = self.create_widget()
+ widget.setData(data)
+ widget.setDisplayMode(DataViews.PLOT1D_MODE)
+ self.assertEqual(widget.displayedView().modeId(), DataViews.PLOT1D_MODE)
+ widget.setDisplayMode(DataViews.IMAGE_MODE)
+ self.assertEqual(widget.displayedView().modeId(), DataViews.IMAGE_MODE)
+ widget.setDisplayMode(DataViews.RAW_MODE)
+ self.assertEqual(widget.displayedView().modeId(), DataViews.RAW_MODE)
+ widget.setDisplayMode(DataViews.EMPTY_MODE)
+ self.assertEqual(widget.displayedView().modeId(), DataViews.EMPTY_MODE)
+
+ def test_create_default_views(self):
+ widget = self.create_widget()
+ views = widget.createDefaultViews()
+ self.assertTrue(len(views) > 0)
+
+ def test_add_view(self):
+ widget = self.create_widget()
+ view = _DataViewMock(widget)
+ widget.addView(view)
+ self.assertTrue(view in widget.availableViews())
+ self.assertTrue(view in widget.currentAvailableViews())
+
+ def test_remove_view(self):
+ widget = self.create_widget()
+ widget.setData("foobar")
+ view = widget.currentAvailableViews()[0]
+ widget.removeView(view)
+ self.assertTrue(view not in widget.availableViews())
+ self.assertTrue(view not in widget.currentAvailableViews())
+
+ def test_replace_view(self):
+ widget = self.create_widget()
+ view = _DataViewMock(widget)
+ widget.replaceView(DataViews.RAW_MODE,
+ view)
+ self.assertIsNone(widget.getViewFromModeId(DataViews.RAW_MODE))
+ self.assertTrue(view in widget.availableViews())
+ self.assertTrue(view in widget.currentAvailableViews())
+
+ def test_replace_view_in_composite(self):
+ # replace a view that is a child of a composite view
+ widget = self.create_widget()
+ view = _DataViewMock(widget)
+ replaced = widget.replaceView(DataViews.NXDATA_INVALID_MODE,
+ view)
+ self.assertTrue(replaced)
+ nxdata_view = widget.getViewFromModeId(DataViews.NXDATA_MODE)
+ self.assertNotIn(DataViews.NXDATA_INVALID_MODE,
+ [v.modeId() for v in nxdata_view.getViews()])
+ self.assertTrue(view in nxdata_view.getViews())
+
+
+class TestDataViewer(_TestAbstractDataViewer):
+ __test__ = True # because _TestAbstractDataViewer is ignored
+ def create_widget(self):
+ return DataViewer()
+
+
+class TestDataViewerFrame(_TestAbstractDataViewer):
+ __test__ = True # because _TestAbstractDataViewer is ignored
+ def create_widget(self):
+ return DataViewerFrame()
+
+
+class TestDataView(TestCaseQt):
+
+ def createComplexData(self):
+ line = [1, 2j, 3 + 3j, 4]
+ image = [line, line, line, line]
+ cube = [image, image, image, image]
+ data = numpy.array(cube, dtype=numpy.complex64)
+ return data
+
+ def createDataViewWithData(self, dataViewClass, data):
+ viewer = dataViewClass(None)
+ widget = viewer.getWidget()
+ viewer.setData(data)
+ return widget
+
+ def testCurveWithComplex(self):
+ data = self.createComplexData()
+ dataViewClass = DataViews._Plot1dView
+ widget = self.createDataViewWithData(dataViewClass, data[0, 0])
+ self.qWaitForWindowExposed(widget)
+
+ def testImageWithComplex(self):
+ data = self.createComplexData()
+ dataViewClass = DataViews._Plot2dView
+ widget = self.createDataViewWithData(dataViewClass, data[0])
+ self.qWaitForWindowExposed(widget)
+
+ @pytest.mark.usefixtures("use_opengl")
+ def testCubeWithComplex(self):
+ try:
+ import silx.gui.plot3d # noqa
+ except ImportError:
+ self.skipTest("OpenGL not available")
+ data = self.createComplexData()
+ dataViewClass = DataViews._Plot3dView
+ widget = self.createDataViewWithData(dataViewClass, data)
+ self.qWaitForWindowExposed(widget)
+
+ def testImageStackWithComplex(self):
+ data = self.createComplexData()
+ dataViewClass = DataViews._StackView
+ widget = self.createDataViewWithData(dataViewClass, data)
+ self.qWaitForWindowExposed(widget)
diff --git a/src/silx/gui/data/test/test_numpyaxesselector.py b/src/silx/gui/data/test/test_numpyaxesselector.py
new file mode 100644
index 0000000..37b8d3e
--- /dev/null
+++ b/src/silx/gui/data/test/test_numpyaxesselector.py
@@ -0,0 +1,150 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/01/2018"
+
+import os
+import tempfile
+import unittest
+from contextlib import contextmanager
+
+import numpy
+
+from silx.gui.data.NumpyAxesSelector import NumpyAxesSelector
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.utils.testutils import TestCaseQt
+
+import h5py
+
+
+class TestNumpyAxesSelector(TestCaseQt):
+
+ def test_creation(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ widget = NumpyAxesSelector()
+ widget.setVisible(True)
+
+ def test_none(self):
+ data = numpy.arange(3 * 3 * 3)
+ widget = NumpyAxesSelector()
+ widget.setData(data)
+ widget.setData(None)
+ result = widget.selectedData()
+ self.assertIsNone(result)
+
+ def test_output_samedim(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ expectedResult = data
+
+ widget = NumpyAxesSelector()
+ widget.setAxisNames(["x", "y", "z"])
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_output_moredim(self):
+ data = numpy.arange(3 * 3 * 3 * 3)
+ data.shape = 3, 3, 3, 3
+ expectedResult = data
+
+ widget = NumpyAxesSelector()
+ widget.setAxisNames(["x", "y", "z", "boum"])
+ widget.setData(data[0])
+ result = widget.selectedData()
+ self.assertIsNone(result)
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_output_lessdim(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ expectedResult = data[0]
+
+ widget = NumpyAxesSelector()
+ widget.setAxisNames(["y", "x"])
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_output_1dim(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ expectedResult = data[0, 0, 0]
+
+ widget = NumpyAxesSelector()
+ widget.setData(data)
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ @contextmanager
+ def h5_temporary_file(self):
+ # create tmp file
+ fd, tmp_name = tempfile.mkstemp(suffix=".h5")
+ os.close(fd)
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ # create h5 data
+ h5file = h5py.File(tmp_name, "w")
+ h5file["data"] = data
+ yield h5file
+ # clean up
+ h5file.close()
+ os.unlink(tmp_name)
+
+ def test_h5py_dataset(self):
+ with self.h5_temporary_file() as h5file:
+ dataset = h5file["data"]
+ expectedResult = dataset[0]
+
+ widget = NumpyAxesSelector()
+ widget.setData(dataset)
+ widget.setAxisNames(["y", "x"])
+ result = widget.selectedData()
+ self.assertTrue(numpy.array_equal(result, expectedResult))
+
+ def test_data_event(self):
+ data = numpy.arange(3 * 3 * 3)
+ widget = NumpyAxesSelector()
+ listener = SignalListener()
+ widget.dataChanged.connect(listener)
+ widget.setData(data)
+ widget.setData(None)
+ self.assertEqual(listener.callCount(), 2)
+
+ def test_selected_data_event(self):
+ data = numpy.arange(3 * 3 * 3)
+ data.shape = 3, 3, 3
+ widget = NumpyAxesSelector()
+ listener = SignalListener()
+ widget.selectionChanged.connect(listener)
+ widget.setData(data)
+ widget.setAxisNames(["x"])
+ widget.setData(None)
+ self.assertEqual(listener.callCount(), 3)
+ listener.clear()
diff --git a/src/silx/gui/data/test/test_textformatter.py b/src/silx/gui/data/test/test_textformatter.py
new file mode 100644
index 0000000..af41def
--- /dev/null
+++ b/src/silx/gui/data/test/test_textformatter.py
@@ -0,0 +1,199 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/12/2017"
+
+import unittest
+import shutil
+import tempfile
+
+import numpy
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.utils.testutils import SignalListener
+from ..TextFormatter import TextFormatter
+from silx.io.utils import h5py_read_dataset
+
+import h5py
+
+
+class TestTextFormatter(TestCaseQt):
+
+ def test_copy(self):
+ formatter = TextFormatter()
+ copy = TextFormatter(formatter=formatter)
+ self.assertIsNot(formatter, copy)
+ copy.setFloatFormat("%.3f")
+ self.assertEqual(formatter.integerFormat(), copy.integerFormat())
+ self.assertNotEqual(formatter.floatFormat(), copy.floatFormat())
+ self.assertEqual(formatter.useQuoteForText(), copy.useQuoteForText())
+ self.assertEqual(formatter.imaginaryUnit(), copy.imaginaryUnit())
+
+ def test_event(self):
+ listener = SignalListener()
+ formatter = TextFormatter()
+ formatter.formatChanged.connect(listener)
+ formatter.setFloatFormat("%.3f")
+ formatter.setIntegerFormat("%03i")
+ formatter.setUseQuoteForText(False)
+ formatter.setImaginaryUnit("z")
+ self.assertEqual(listener.callCount(), 4)
+
+ def test_int(self):
+ formatter = TextFormatter()
+ formatter.setIntegerFormat("%05i")
+ result = formatter.toString(512)
+ self.assertEqual(result, "00512")
+
+ def test_float(self):
+ formatter = TextFormatter()
+ formatter.setFloatFormat("%.3f")
+ result = formatter.toString(1.3)
+ self.assertEqual(result, "1.300")
+
+ def test_complex(self):
+ formatter = TextFormatter()
+ formatter.setFloatFormat("%.1f")
+ formatter.setImaginaryUnit("i")
+ result = formatter.toString(1.0 + 5j)
+ result = result.replace(" ", "")
+ self.assertEqual(result, "1.0+5.0i")
+
+ def test_string(self):
+ formatter = TextFormatter()
+ formatter.setIntegerFormat("%.1f")
+ formatter.setImaginaryUnit("z")
+ result = formatter.toString("toto")
+ self.assertEqual(result, '"toto"')
+
+ def test_numpy_void(self):
+ formatter = TextFormatter()
+ result = formatter.toString(numpy.void(b"\xFF"))
+ self.assertEqual(result, 'b"\\xFF"')
+
+ def test_char_cp1252(self):
+ # degree character in cp1252
+ formatter = TextFormatter()
+ result = formatter.toString(numpy.bytes_(b"\xB0"))
+ self.assertEqual(result, u'"\u00B0"')
+
+
+class TestTextFormatterWithH5py(TestCaseQt):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestTextFormatterWithH5py, cls).setUpClass()
+
+ cls.tmpDirectory = tempfile.mkdtemp()
+ cls.h5File = h5py.File("%s/formatter.h5" % cls.tmpDirectory, mode="w")
+ cls.formatter = TextFormatter()
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestTextFormatterWithH5py, cls).tearDownClass()
+ cls.h5File.close()
+ cls.h5File = None
+ shutil.rmtree(cls.tmpDirectory)
+
+ def create_dataset(self, data, dtype=None):
+ testName = "%s" % self.id()
+ dataset = self.h5File.create_dataset(testName, data=data, dtype=dtype)
+ return dataset
+
+ def read_dataset(self, d):
+ return self.formatter.toString(d[()], dtype=d.dtype)
+
+ def testAscii(self):
+ d = self.create_dataset(data=b"abc")
+ result = self.read_dataset(d)
+ self.assertEqual(result, '"abc"')
+
+ def testUnicode(self):
+ d = self.create_dataset(data=u"i\u2661cookies")
+ result = self.read_dataset(d)
+ self.assertEqual(len(result), 11)
+ self.assertEqual(result, u'"i\u2661cookies"')
+
+ def testBadAscii(self):
+ d = self.create_dataset(data=b"\xF0\x9F\x92\x94")
+ result = self.read_dataset(d)
+ self.assertEqual(result, 'b"\\xF0\\x9F\\x92\\x94"')
+
+ def testVoid(self):
+ d = self.create_dataset(data=numpy.void(b"abc\xF0"))
+ result = self.read_dataset(d)
+ self.assertEqual(result, 'b"\\x61\\x62\\x63\\xF0"')
+
+ def testEnum(self):
+ dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
+ d = numpy.array(42, dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(result, 'BLUE(42)')
+
+ def testRef(self):
+ dtype = h5py.special_dtype(ref=h5py.Reference)
+ d = numpy.array(self.h5File.ref, dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(result, 'REF')
+
+ def testArrayAscii(self):
+ d = self.create_dataset(data=[b"abc"])
+ result = self.read_dataset(d)
+ self.assertEqual(result, '["abc"]')
+
+ def testArrayUnicode(self):
+ dtype = h5py.special_dtype(vlen=str)
+ d = numpy.array([u"i\u2661cookies"], dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(len(result), 13)
+ self.assertEqual(result, u'["i\u2661cookies"]')
+
+ def testArrayBadAscii(self):
+ d = self.create_dataset(data=[b"\xF0\x9F\x92\x94"])
+ result = self.read_dataset(d)
+ self.assertEqual(result, '[b"\\xF0\\x9F\\x92\\x94"]')
+
+ def testArrayVoid(self):
+ d = self.create_dataset(data=numpy.void([b"abc\xF0"]))
+ result = self.read_dataset(d)
+ self.assertEqual(result, '[b"\\x61\\x62\\x63\\xF0"]')
+
+ def testArrayEnum(self):
+ dtype = h5py.special_dtype(enum=('i', {"RED": 0, "GREEN": 1, "BLUE": 42}))
+ d = numpy.array([42, 1, 100], dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(result, '[BLUE(42) GREEN(1) 100]')
+
+ def testArrayRef(self):
+ dtype = h5py.special_dtype(ref=h5py.Reference)
+ d = numpy.array([self.h5File.ref, None], dtype=dtype)
+ d = self.create_dataset(data=d)
+ result = self.read_dataset(d)
+ self.assertEqual(result, '[REF NULL_REF]')
diff --git a/src/silx/gui/dialog/AbstractDataFileDialog.py b/src/silx/gui/dialog/AbstractDataFileDialog.py
new file mode 100644
index 0000000..5272f48
--- /dev/null
+++ b/src/silx/gui/dialog/AbstractDataFileDialog.py
@@ -0,0 +1,1731 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains an :class:`AbstractDataFileDialog`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "05/03/2019"
+
+
+import sys
+import os
+import logging
+import functools
+from distutils.version import LooseVersion
+
+import numpy
+
+import silx.io.url
+from silx.gui import qt
+from silx.gui.hdf5.Hdf5TreeModel import Hdf5TreeModel
+from . import utils
+from .FileTypeComboBox import FileTypeComboBox
+
+import fabio
+
+
+_logger = logging.getLogger(__name__)
+
+
+DEFAULT_SIDEBAR_URL = True
+"""Set it to false to disable initilializing of the sidebar urls with the
+default Qt list. This could allow to disable a behaviour known to segfault on
+some version of PyQt."""
+
+
+class _IconProvider(object):
+
+ FileDialogToParentDir = qt.QStyle.SP_CustomBase + 1
+
+ FileDialogToParentFile = qt.QStyle.SP_CustomBase + 2
+
+ def __init__(self):
+ self.__iconFileDialogToParentDir = None
+ self.__iconFileDialogToParentFile = None
+
+ def _createIconToParent(self, standardPixmap):
+ """
+
+ FIXME: It have to be tested for some OS (arrow icon do not have always
+ the same direction)
+ """
+ style = qt.QApplication.style()
+ baseIcon = style.standardIcon(qt.QStyle.SP_FileDialogToParent)
+ backgroundIcon = style.standardIcon(standardPixmap)
+ icon = qt.QIcon()
+
+ sizes = baseIcon.availableSizes()
+ sizes = sorted(sizes, key=lambda s: s.height())
+ sizes = filter(lambda s: s.height() < 100, sizes)
+ sizes = list(sizes)
+ if len(sizes) > 0:
+ baseSize = sizes[-1]
+ else:
+ baseSize = baseIcon.availableSizes()[0]
+ size = qt.QSize(baseSize.width(), baseSize.height() * 3 // 2)
+
+ modes = [qt.QIcon.Normal, qt.QIcon.Disabled]
+ for mode in modes:
+ pixmap = qt.QPixmap(size)
+ pixmap.fill(qt.Qt.transparent)
+ painter = qt.QPainter(pixmap)
+ painter.drawPixmap(0, 0, backgroundIcon.pixmap(baseSize, mode=mode))
+ painter.drawPixmap(0, size.height() // 3, baseIcon.pixmap(baseSize, mode=mode))
+ painter.end()
+ icon.addPixmap(pixmap, mode=mode)
+
+ return icon
+
+ def getFileDialogToParentDir(self):
+ if self.__iconFileDialogToParentDir is None:
+ self.__iconFileDialogToParentDir = self._createIconToParent(qt.QStyle.SP_DirIcon)
+ return self.__iconFileDialogToParentDir
+
+ def getFileDialogToParentFile(self):
+ if self.__iconFileDialogToParentFile is None:
+ self.__iconFileDialogToParentFile = self._createIconToParent(qt.QStyle.SP_FileIcon)
+ return self.__iconFileDialogToParentFile
+
+ def icon(self, kind):
+ if kind == self.FileDialogToParentDir:
+ return self.getFileDialogToParentDir()
+ elif kind == self.FileDialogToParentFile:
+ return self.getFileDialogToParentFile()
+ else:
+ style = qt.QApplication.style()
+ icon = style.standardIcon(kind)
+ return icon
+
+
+class _SideBar(qt.QListView):
+ """Sidebar containing shortcuts for common directories"""
+
+ def __init__(self, parent=None):
+ super(_SideBar, self).__init__(parent)
+ self.__iconProvider = qt.QFileIconProvider()
+ self.setUniformItemSizes(True)
+ model = qt.QStandardItemModel(self)
+ self.setModel(model)
+ self._initModel()
+ self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+
+ def iconProvider(self):
+ return self.__iconProvider
+
+ def _initModel(self):
+ urls = self._getDefaultUrls()
+ self.setUrls(urls)
+
+ def _getDefaultUrls(self):
+ """Returns the default shortcuts.
+
+ It uses the default QFileDialog shortcuts if it is possible, else
+ provides a link to the computer's root and the user's home.
+
+ :rtype: List[str]
+ """
+ urls = []
+ version = LooseVersion(qt.qVersion())
+ feed_sidebar = True
+
+ if not DEFAULT_SIDEBAR_URL:
+ _logger.debug("Skip default sidebar URLs (from setted variable)")
+ feed_sidebar = False
+ elif version < LooseVersion("5.11.2") and qt.BINDING == "PyQt5" and sys.platform in ["linux", "linux2"]:
+ # Avoid segfault on PyQt5 + gtk
+ _logger.debug("Skip default sidebar URLs (avoid PyQt5 segfault)")
+ feed_sidebar = False
+
+ if feed_sidebar:
+ # Get default shortcut
+ # There is no other way
+ d = qt.QFileDialog(self)
+ # Needed to be able to reach the sidebar urls
+ d.setOption(qt.QFileDialog.DontUseNativeDialog, True)
+ urls = d.sidebarUrls()
+ d.deleteLater()
+ d = None
+
+ if len(urls) == 0:
+ urls.append(qt.QUrl("file://"))
+ urls.append(qt.QUrl.fromLocalFile(qt.QDir.homePath()))
+
+ return urls
+
+ def setSelectedPath(self, path):
+ selected = None
+ model = self.model()
+ for i in range(model.rowCount()):
+ index = model.index(i, 0)
+ url = model.data(index, qt.Qt.UserRole)
+ urlPath = url.toLocalFile()
+ if path == urlPath:
+ selected = index
+
+ selectionModel = self.selectionModel()
+ if selected is not None:
+ selectionModel.setCurrentIndex(selected, qt.QItemSelectionModel.ClearAndSelect)
+ else:
+ selectionModel.clear()
+
+ def setUrls(self, urls):
+ model = self.model()
+ model.clear()
+
+ names = {}
+ names[qt.QDir.rootPath()] = "Computer"
+ names[qt.QDir.homePath()] = "Home"
+
+ style = qt.QApplication.style()
+ iconProvider = self.iconProvider()
+ for url in urls:
+ path = url.toLocalFile()
+ if path == "":
+ if sys.platform != "win32":
+ url = qt.QUrl(qt.QDir.rootPath())
+ name = "Computer"
+ icon = style.standardIcon(qt.QStyle.SP_ComputerIcon)
+ else:
+ fileInfo = qt.QFileInfo(path)
+ name = names.get(path, fileInfo.fileName())
+ icon = iconProvider.icon(fileInfo)
+
+ if icon.isNull():
+ icon = style.standardIcon(qt.QStyle.SP_MessageBoxCritical)
+
+ item = qt.QStandardItem()
+ item.setText(name)
+ item.setIcon(icon)
+ item.setData(url, role=qt.Qt.UserRole)
+ model.appendRow(item)
+
+ def urls(self):
+ result = []
+ model = self.model()
+ for i in range(model.rowCount()):
+ index = model.index(i, 0)
+ url = model.data(index, qt.Qt.UserRole)
+ result.append(url)
+ return result
+
+ def sizeHint(self):
+ index = self.model().index(0, 0)
+ return self.sizeHintForIndex(index) + qt.QSize(2 * self.frameWidth(), 2 * self.frameWidth())
+
+
+class _Browser(qt.QStackedWidget):
+
+ activated = qt.Signal(qt.QModelIndex)
+ selected = qt.Signal(qt.QModelIndex)
+ rootIndexChanged = qt.Signal(qt.QModelIndex)
+
+ def __init__(self, parent, listView, detailView):
+ qt.QStackedWidget.__init__(self, parent)
+ self.__listView = listView
+ self.__detailView = detailView
+ self.insertWidget(0, self.__listView)
+ self.insertWidget(1, self.__detailView)
+
+ self.__listView.activated.connect(self.__emitActivated)
+ self.__detailView.activated.connect(self.__emitActivated)
+
+ def __emitActivated(self, index):
+ self.activated.emit(index)
+
+ def __emitSelected(self, selected, deselected):
+ index = self.selectedIndex()
+ if index is not None:
+ self.selected.emit(index)
+
+ def selectedIndex(self):
+ if self.currentIndex() == 0:
+ selectionModel = self.__listView.selectionModel()
+ else:
+ selectionModel = self.__detailView.selectionModel()
+
+ if selectionModel is None:
+ return None
+
+ indexes = selectionModel.selectedIndexes()
+ # Filter non-main columns
+ indexes = [i for i in indexes if i.column() == 0]
+ if len(indexes) == 1:
+ index = indexes[0]
+ return index
+ return None
+
+ def model(self):
+ """Returns the current model."""
+ if self.currentIndex() == 0:
+ return self.__listView.model()
+ else:
+ return self.__detailView.model()
+
+ def selectIndex(self, index):
+ if self.currentIndex() == 0:
+ selectionModel = self.__listView.selectionModel()
+ else:
+ selectionModel = self.__detailView.selectionModel()
+ if selectionModel is None:
+ return
+ selectionModel.setCurrentIndex(index, qt.QItemSelectionModel.ClearAndSelect)
+
+ def viewMode(self):
+ """Returns the current view mode.
+
+ :rtype: qt.QFileDialog.ViewMode
+ """
+ if self.currentIndex() == 0:
+ return qt.QFileDialog.List
+ elif self.currentIndex() == 1:
+ return qt.QFileDialog.Detail
+ else:
+ assert(False)
+
+ def setViewMode(self, mode):
+ """Set the current view mode.
+
+ :param qt.QFileDialog.ViewMode mode: The new view mode
+ """
+ if mode == qt.QFileDialog.Detail:
+ self.showDetails()
+ elif mode == qt.QFileDialog.List:
+ self.showList()
+ else:
+ assert(False)
+
+ def showList(self):
+ self.__listView.show()
+ self.__detailView.hide()
+ self.setCurrentIndex(0)
+
+ def showDetails(self):
+ self.__listView.hide()
+ self.__detailView.show()
+ self.setCurrentIndex(1)
+ self.__detailView.updateGeometry()
+
+ def clear(self):
+ self.__listView.setRootIndex(qt.QModelIndex())
+ self.__detailView.setRootIndex(qt.QModelIndex())
+ selectionModel = self.__listView.selectionModel()
+ if selectionModel is not None:
+ selectionModel.selectionChanged.disconnect()
+ selectionModel.clear()
+ selectionModel = self.__detailView.selectionModel()
+ if selectionModel is not None:
+ selectionModel.selectionChanged.disconnect()
+ selectionModel.clear()
+ self.__listView.setModel(None)
+ self.__detailView.setModel(None)
+
+ def setRootIndex(self, index, model=None):
+ """Sets the root item to the item at the given index.
+ """
+ rootIndex = self.__listView.rootIndex()
+ newModel = model or index.model()
+ assert(newModel is not None)
+
+ if rootIndex is None or rootIndex.model() is not newModel:
+ # update the model
+ selectionModel = self.__listView.selectionModel()
+ if selectionModel is not None:
+ selectionModel.selectionChanged.disconnect()
+ selectionModel.clear()
+ selectionModel = self.__detailView.selectionModel()
+ if selectionModel is not None:
+ selectionModel.selectionChanged.disconnect()
+ selectionModel.clear()
+ pIndex = qt.QPersistentModelIndex(index)
+ self.__listView.setModel(newModel)
+ # changing the model of the tree view change the index mapping
+ # that is why we are using a persistance model index
+ self.__detailView.setModel(newModel)
+ index = newModel.index(pIndex.row(), pIndex.column(), pIndex.parent())
+ selectionModel = self.__listView.selectionModel()
+ selectionModel.selectionChanged.connect(self.__emitSelected)
+ selectionModel = self.__detailView.selectionModel()
+ selectionModel.selectionChanged.connect(self.__emitSelected)
+
+ self.__listView.setRootIndex(index)
+ self.__detailView.setRootIndex(index)
+ self.rootIndexChanged.emit(index)
+
+ def rootIndex(self):
+ """Returns the model index of the model's root item. The root item is
+ the parent item to the view's toplevel items. The root can be invalid.
+ """
+ return self.__listView.rootIndex()
+
+ __serialVersion = 1
+ """Store the current version of the serialized data"""
+
+ def visualRect(self, index):
+ """Returns the rectangle on the viewport occupied by the item at index.
+
+ :param qt.QModelIndex index: An index
+ :rtype: QRect
+ """
+ if self.currentIndex() == 0:
+ return self.__listView.visualRect(index)
+ else:
+ return self.__detailView.visualRect(index)
+
+ def viewport(self):
+ """Returns the viewport widget.
+
+ :param qt.QModelIndex index: An index
+ :rtype: QRect
+ """
+ if self.currentIndex() == 0:
+ return self.__listView.viewport()
+ else:
+ return self.__detailView.viewport()
+
+ def restoreState(self, state):
+ """Restores the dialogs's layout, history and current directory to the
+ state specified.
+
+ :param qt.QByeArray state: Stream containing the new state
+ :rtype: bool
+ """
+ stream = qt.QDataStream(state, qt.QIODevice.ReadOnly)
+
+ nameId = stream.readQString()
+ if nameId != "Browser":
+ _logger.warning("Stored state contains an invalid name id. Browser restoration cancelled.")
+ return False
+
+ version = stream.readInt32()
+ if version != self.__serialVersion:
+ _logger.warning("Stored state contains an invalid version. Browser restoration cancelled.")
+ return False
+
+ headerData = stream.readQVariant()
+ self.__detailView.header().restoreState(headerData)
+
+ viewMode = stream.readInt32()
+ self.setViewMode(viewMode)
+ return True
+
+ def saveState(self):
+ """Saves the state of the dialog's layout.
+
+ :rtype: qt.QByteArray
+ """
+ data = qt.QByteArray()
+ stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
+
+ nameId = u"Browser"
+ stream.writeQString(nameId)
+ stream.writeInt32(self.__serialVersion)
+ stream.writeQVariant(self.__detailView.header().saveState())
+ stream.writeInt32(self.viewMode())
+
+ return data
+
+
+class _FabioData(object):
+
+ def __init__(self, fabioFile):
+ self.__fabioFile = fabioFile
+
+ @property
+ def dtype(self):
+ # Let say it is a valid type
+ return numpy.dtype("float")
+
+ @property
+ def shape(self):
+ if self.__fabioFile.nframes == 0:
+ return None
+ if self.__fabioFile.nframes == 1:
+ return [slice(None), slice(None)]
+ return [self.__fabioFile.nframes, slice(None), slice(None)]
+
+ def __getitem__(self, selector):
+ if self.__fabioFile.nframes == 1 and selector == tuple():
+ return self.__fabioFile.data
+ if isinstance(selector, tuple) and len(selector) == 1:
+ selector = selector[0]
+
+ if isinstance(selector, int):
+ if 0 <= selector < self.__fabioFile.nframes:
+ if self.__fabioFile.nframes == 1:
+ return self.__fabioFile.data
+ else:
+ frame = self.__fabioFile.getframe(selector)
+ return frame.data
+ else:
+ raise ValueError("Invalid selector %s" % selector)
+ else:
+ raise TypeError("Unsupported selector type %s" % type(selector))
+
+
+class _PathEdit(qt.QLineEdit):
+ pass
+
+
+class _CatchResizeEvent(qt.QObject):
+
+ resized = qt.Signal(qt.QResizeEvent)
+
+ def __init__(self, parent, target):
+ super(_CatchResizeEvent, self).__init__(parent)
+ self.__target = target
+ self.__target_oldResizeEvent = self.__target.resizeEvent
+ self.__target.resizeEvent = self.__resizeEvent
+
+ def __resizeEvent(self, event):
+ result = self.__target_oldResizeEvent(event)
+ self.resized.emit(event)
+ return result
+
+
+class AbstractDataFileDialog(qt.QDialog):
+ """The `AbstractFileDialog` provides a generic GUI to create a custom dialog
+ allowing to access to file resources like HDF5 files or HDF5 datasets.
+
+ .. image:: img/abstractdatafiledialog.png
+
+ The dialog contains:
+
+ - Shortcuts: It provides few links to have a fast access of browsing
+ locations.
+ - Browser: It provides a display to browse throw the file system and inside
+ HDF5 files or fabio files. A file format selector is provided.
+ - URL: Display the URL available to reach the data using
+ :meth:`silx.io.get_data`, :meth:`silx.io.open`.
+ - Data selector: A widget to apply a sub selection of the browsed dataset.
+ This widget can be provided, else nothing will be used.
+ - Data preview: A widget to preview the selected data, which is the result
+ of the filter from the data selector.
+ This widget can be provided, else nothing will be used.
+ - Preview's toolbar: Provides tools used to custom data preview or data
+ selector.
+ This widget can be provided, else nothing will be used.
+ - Buttons to validate the dialog
+ """
+
+ _defaultIconProvider = None
+ """Lazy loaded default icon provider"""
+
+ def __init__(self, parent=None):
+ super(AbstractDataFileDialog, self).__init__(parent)
+ self._init()
+
+ def _init(self):
+ self.setWindowTitle("Open")
+
+ self.__openedFiles = []
+ """Store the list of files opened by the model itself."""
+ # FIXME: It should be managed one by one by Hdf5Item itself
+
+ self.__directory = None
+ self.__directoryLoadedFilter = None
+ self.__errorWhileLoadingFile = None
+ self.__selectedFile = None
+ self.__selectedData = None
+ self.__currentHistory = []
+ """Store history of URLs, last index one is the latest one"""
+ self.__currentHistoryLocation = -1
+ """Store the location in the history. Bigger is older"""
+
+ self.__processing = 0
+ """Number of asynchronous processing tasks"""
+ self.__h5 = None
+ self.__fabio = None
+
+ # On Qt5 a safe icon provider is still needed to avoid freeze
+ _logger.debug("Uses default QFileSystemModel with a SafeFileIconProvider")
+ self.__fileModel = qt.QFileSystemModel(self)
+ from .SafeFileIconProvider import SafeFileIconProvider
+ iconProvider = SafeFileIconProvider()
+ self.__fileModel.setIconProvider(iconProvider)
+
+ # The common file dialog filter only on Mac OS X
+ self.__fileModel.setNameFilterDisables(sys.platform == "darwin")
+ self.__fileModel.setReadOnly(True)
+ self.__fileModel.directoryLoaded.connect(self.__directoryLoaded)
+
+ self.__dataModel = Hdf5TreeModel(self)
+
+ self.__createWidgets()
+ self.__initLayout()
+ self.__showAsListView()
+
+ path = os.getcwd()
+ self.__fileModel_setRootPath(path)
+
+ self.__clearData()
+ self.__updatePath()
+
+ # Update the file model filter
+ self.__fileTypeCombo.setCurrentIndex(0)
+ self.__filterSelected(0)
+
+ # It is not possible to override the QObject destructor nor
+ # to access to the content of the Python object with the `destroyed`
+ # signal cause the Python method was already removed with the QWidget,
+ # while the QObject still exists.
+ # We use a static method plus explicit references to objects to
+ # release. The callback do not use any ref to self.
+ onDestroy = functools.partial(self._closeFileList, self.__openedFiles)
+ self.destroyed.connect(onDestroy)
+
+ @staticmethod
+ def _closeFileList(fileList):
+ """Static method to close explicit references to internal objects."""
+ _logger.debug("Clear AbstractDataFileDialog")
+ for obj in fileList:
+ _logger.debug("Close file %s", obj.filename)
+ obj.close()
+ fileList[:] = []
+
+ def done(self, result):
+ self._clear()
+ super(AbstractDataFileDialog, self).done(result)
+
+ def _clear(self):
+ """Explicit method to clear data stored in the dialog.
+ After this call it is not anymore possible to use the widget.
+
+ This method is triggered by the destruction of the object and the
+ QDialog :meth:`done`. Then it can be triggered more than once.
+ """
+ _logger.debug("Clear dialog")
+ self.__errorWhileLoadingFile = None
+ self.__clearData()
+ if self.__fileModel is not None:
+ # Cache the directory before cleaning the model
+ self.__directory = self.directory()
+ self.__browser.clear()
+ self.__closeFile()
+ self.__fileModel = None
+ self.__dataModel = None
+
+ def hasPendingEvents(self):
+ """Returns true if the dialog have asynchronous tasks working on the
+ background."""
+ return self.__processing > 0
+
+ # User interface
+
+ def __createWidgets(self):
+ self.__sidebar = self._createSideBar()
+ if self.__sidebar is not None:
+ sideBarModel = self.__sidebar.selectionModel()
+ sideBarModel.selectionChanged.connect(self.__shortcutSelected)
+ self.__sidebar.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+
+ listView = qt.QListView(self)
+ listView.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ listView.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ listView.setResizeMode(qt.QListView.Adjust)
+ listView.setWrapping(True)
+ listView.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+ listView.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ utils.patchToConsumeReturnKey(listView)
+
+ treeView = qt.QTreeView(self)
+ treeView.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ treeView.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ treeView.setRootIsDecorated(False)
+ treeView.setItemsExpandable(False)
+ treeView.setSortingEnabled(True)
+ treeView.header().setSortIndicator(0, qt.Qt.AscendingOrder)
+ treeView.header().setStretchLastSection(False)
+ treeView.setTextElideMode(qt.Qt.ElideMiddle)
+ treeView.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+ treeView.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ treeView.setDragDropMode(qt.QAbstractItemView.InternalMove)
+ utils.patchToConsumeReturnKey(treeView)
+
+ self.__browser = _Browser(self, listView, treeView)
+ self.__browser.activated.connect(self.__browsedItemActivated)
+ self.__browser.selected.connect(self.__browsedItemSelected)
+ self.__browser.rootIndexChanged.connect(self.__rootIndexChanged)
+ self.__browser.setObjectName("browser")
+
+ self.__previewWidget = self._createPreviewWidget(self)
+
+ self.__fileTypeCombo = FileTypeComboBox(self)
+ self.__fileTypeCombo.setObjectName("fileTypeCombo")
+ self.__fileTypeCombo.setDuplicatesEnabled(False)
+ self.__fileTypeCombo.setSizeAdjustPolicy(qt.QComboBox.AdjustToMinimumContentsLengthWithIcon)
+ self.__fileTypeCombo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ self.__fileTypeCombo.activated[int].connect(self.__filterSelected)
+ self.__fileTypeCombo.setFabioUrlSupproted(self._isFabioFilesSupported())
+
+ self.__pathEdit = _PathEdit(self)
+ self.__pathEdit.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ self.__pathEdit.textChanged.connect(self.__textChanged)
+ self.__pathEdit.setObjectName("url")
+ utils.patchToConsumeReturnKey(self.__pathEdit)
+
+ self.__buttons = qt.QDialogButtonBox(self)
+ self.__buttons.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
+ types = qt.QDialogButtonBox.Open | qt.QDialogButtonBox.Cancel
+ self.__buttons.setStandardButtons(types)
+ self.__buttons.button(qt.QDialogButtonBox.Cancel).setObjectName("cancel")
+ self.__buttons.button(qt.QDialogButtonBox.Open).setObjectName("open")
+
+ self.__buttons.accepted.connect(self.accept)
+ self.__buttons.rejected.connect(self.reject)
+
+ self.__browseToolBar = self._createBrowseToolBar()
+ self.__backwardAction.setEnabled(False)
+ self.__forwardAction.setEnabled(False)
+ self.__fileDirectoryAction.setEnabled(False)
+ self.__parentFileDirectoryAction.setEnabled(False)
+
+ self.__selectorWidget = self._createSelectorWidget(self)
+ if self.__selectorWidget is not None:
+ self.__selectorWidget.selectionChanged.connect(self.__selectorWidgetChanged)
+
+ self.__previewToolBar = self._createPreviewToolbar(self, self.__previewWidget, self.__selectorWidget)
+
+ self.__dataIcon = qt.QLabel(self)
+ self.__dataIcon.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed)
+ self.__dataIcon.setScaledContents(True)
+ self.__dataIcon.setMargin(2)
+ self.__dataIcon.setAlignment(qt.Qt.AlignCenter)
+
+ self.__dataInfo = qt.QLabel(self)
+ self.__dataInfo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+
+ def _createSideBar(self):
+ sidebar = _SideBar(self)
+ sidebar.setObjectName("sidebar")
+ return sidebar
+
+ def iconProvider(self):
+ iconProvider = self.__class__._defaultIconProvider
+ if iconProvider is None:
+ iconProvider = _IconProvider()
+ self.__class__._defaultIconProvider = iconProvider
+ return iconProvider
+
+ def _createBrowseToolBar(self):
+ toolbar = qt.QToolBar(self)
+ toolbar.setIconSize(qt.QSize(16, 16))
+ iconProvider = self.iconProvider()
+
+ backward = qt.QAction(toolbar)
+ backward.setText("Back")
+ backward.setObjectName("backwardAction")
+ backward.setIcon(iconProvider.icon(qt.QStyle.SP_ArrowBack))
+ backward.triggered.connect(self.__navigateBackward)
+ self.__backwardAction = backward
+
+ forward = qt.QAction(toolbar)
+ forward.setText("Forward")
+ forward.setObjectName("forwardAction")
+ forward.setIcon(iconProvider.icon(qt.QStyle.SP_ArrowForward))
+ forward.triggered.connect(self.__navigateForward)
+ self.__forwardAction = forward
+
+ parentDirectory = qt.QAction(toolbar)
+ parentDirectory.setText("Go to parent")
+ parentDirectory.setObjectName("toParentAction")
+ parentDirectory.setIcon(iconProvider.icon(qt.QStyle.SP_FileDialogToParent))
+ parentDirectory.triggered.connect(self.__navigateToParent)
+ self.__toParentAction = parentDirectory
+
+ fileDirectory = qt.QAction(toolbar)
+ fileDirectory.setText("Root of the file")
+ fileDirectory.setObjectName("toRootFileAction")
+ fileDirectory.setIcon(iconProvider.icon(iconProvider.FileDialogToParentFile))
+ fileDirectory.triggered.connect(self.__navigateToParentFile)
+ self.__fileDirectoryAction = fileDirectory
+
+ parentFileDirectory = qt.QAction(toolbar)
+ parentFileDirectory.setText("Parent directory of the file")
+ parentFileDirectory.setObjectName("toDirectoryAction")
+ parentFileDirectory.setIcon(iconProvider.icon(iconProvider.FileDialogToParentDir))
+ parentFileDirectory.triggered.connect(self.__navigateToParentDir)
+ self.__parentFileDirectoryAction = parentFileDirectory
+
+ listView = qt.QAction(toolbar)
+ listView.setText("List view")
+ listView.setObjectName("listModeAction")
+ listView.setIcon(iconProvider.icon(qt.QStyle.SP_FileDialogListView))
+ listView.triggered.connect(self.__showAsListView)
+ listView.setCheckable(True)
+
+ detailView = qt.QAction(toolbar)
+ detailView.setText("Detail view")
+ detailView.setObjectName("detailModeAction")
+ detailView.setIcon(iconProvider.icon(qt.QStyle.SP_FileDialogDetailedView))
+ detailView.triggered.connect(self.__showAsDetailedView)
+ detailView.setCheckable(True)
+
+ self.__listViewAction = listView
+ self.__detailViewAction = detailView
+
+ toolbar.addAction(backward)
+ toolbar.addAction(forward)
+ toolbar.addSeparator()
+ toolbar.addAction(parentDirectory)
+ toolbar.addAction(fileDirectory)
+ toolbar.addAction(parentFileDirectory)
+ toolbar.addSeparator()
+ toolbar.addAction(listView)
+ toolbar.addAction(detailView)
+
+ toolbar.setStyleSheet("QToolBar { border: 0px }")
+
+ return toolbar
+
+ def __initLayout(self):
+ sideBarLayout = qt.QVBoxLayout()
+ sideBarLayout.setContentsMargins(0, 0, 0, 0)
+ dummyToolBar = qt.QWidget(self)
+ dummyToolBar.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ dummyCombo = qt.QWidget(self)
+ dummyCombo.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ sideBarLayout.addWidget(dummyToolBar)
+ if self.__sidebar is not None:
+ sideBarLayout.addWidget(self.__sidebar)
+ sideBarLayout.addWidget(dummyCombo)
+ sideBarWidget = qt.QWidget(self)
+ sideBarWidget.setLayout(sideBarLayout)
+
+ dummyCombo.setFixedHeight(self.__fileTypeCombo.height())
+ self.__resizeCombo = _CatchResizeEvent(self, self.__fileTypeCombo)
+ self.__resizeCombo.resized.connect(lambda e: dummyCombo.setFixedHeight(e.size().height()))
+
+ dummyToolBar.setFixedHeight(self.__browseToolBar.height())
+ self.__resizeToolbar = _CatchResizeEvent(self, self.__browseToolBar)
+ self.__resizeToolbar.resized.connect(lambda e: dummyToolBar.setFixedHeight(e.size().height()))
+
+ datasetSelection = qt.QWidget(self)
+ layoutLeft = qt.QVBoxLayout()
+ layoutLeft.setContentsMargins(0, 0, 0, 0)
+ layoutLeft.addWidget(self.__browseToolBar)
+ layoutLeft.addWidget(self.__browser)
+ layoutLeft.addWidget(self.__fileTypeCombo)
+ datasetSelection.setLayout(layoutLeft)
+ datasetSelection.setSizePolicy(qt.QSizePolicy.MinimumExpanding, qt.QSizePolicy.Expanding)
+
+ infoLayout = qt.QHBoxLayout()
+ infoLayout.setContentsMargins(0, 0, 0, 0)
+ infoLayout.addWidget(self.__dataIcon)
+ infoLayout.addWidget(self.__dataInfo)
+
+ dataFrame = qt.QFrame(self)
+ dataFrame.setFrameShape(qt.QFrame.StyledPanel)
+ layout = qt.QVBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ layout.addWidget(self.__previewWidget)
+ layout.addLayout(infoLayout)
+ dataFrame.setLayout(layout)
+
+ dataSelection = qt.QWidget(self)
+ dataLayout = qt.QVBoxLayout()
+ dataLayout.setContentsMargins(0, 0, 0, 0)
+ if self.__previewToolBar is not None:
+ dataLayout.addWidget(self.__previewToolBar)
+ else:
+ # Add dummy space
+ dummyToolbar2 = qt.QWidget(self)
+ dummyToolbar2.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ dummyToolbar2.setFixedHeight(self.__browseToolBar.height())
+ self.__resizeToolbar = _CatchResizeEvent(self, self.__browseToolBar)
+ self.__resizeToolbar.resized.connect(lambda e: dummyToolbar2.setFixedHeight(e.size().height()))
+ dataLayout.addWidget(dummyToolbar2)
+
+ dataLayout.addWidget(dataFrame)
+ if self.__selectorWidget is not None:
+ dataLayout.addWidget(self.__selectorWidget)
+ else:
+ # Add dummy space
+ dummyCombo2 = qt.QWidget(self)
+ dummyCombo2.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ dummyCombo2.setFixedHeight(self.__fileTypeCombo.height())
+ self.__resizeToolbar = _CatchResizeEvent(self, self.__fileTypeCombo)
+ self.__resizeToolbar.resized.connect(lambda e: dummyCombo2.setFixedHeight(e.size().height()))
+ dataLayout.addWidget(dummyCombo2)
+ dataSelection.setLayout(dataLayout)
+
+ self.__splitter = qt.QSplitter(self)
+ self.__splitter.setContentsMargins(0, 0, 0, 0)
+ self.__splitter.addWidget(sideBarWidget)
+ self.__splitter.addWidget(datasetSelection)
+ self.__splitter.addWidget(dataSelection)
+ self.__splitter.setStretchFactor(1, 10)
+
+ bottomLayout = qt.QHBoxLayout()
+ bottomLayout.setContentsMargins(0, 0, 0, 0)
+ bottomLayout.addWidget(self.__pathEdit)
+ bottomLayout.addWidget(self.__buttons)
+
+ layout = qt.QVBoxLayout(self)
+ layout.addWidget(self.__splitter)
+ layout.addLayout(bottomLayout)
+
+ self.setLayout(layout)
+ self.updateGeometry()
+
+ # Logic
+
+ def __navigateBackward(self):
+ """Navigate through the history one step backward."""
+ if len(self.__currentHistory) > 0 and self.__currentHistoryLocation > 0:
+ self.__currentHistoryLocation -= 1
+ url = self.__currentHistory[self.__currentHistoryLocation]
+ self.selectUrl(url)
+
+ def __navigateForward(self):
+ """Navigate through the history one step forward."""
+ if len(self.__currentHistory) > 0 and self.__currentHistoryLocation < len(self.__currentHistory) - 1:
+ self.__currentHistoryLocation += 1
+ url = self.__currentHistory[self.__currentHistoryLocation]
+ self.selectUrl(url)
+
+ def __navigateToParent(self):
+ index = self.__browser.rootIndex()
+ if index.model() is self.__fileModel:
+ # browse throw the file system
+ index = index.parent()
+ path = self.__fileModel.filePath(index)
+ self.__fileModel_setRootPath(path)
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__updatePath()
+ elif index.model() is self.__dataModel:
+ index = index.parent()
+ if index.isValid():
+ # browse throw the hdf5
+ self.__browser.setRootIndex(index)
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__updatePath()
+ else:
+ # go back to the file system
+ self.__navigateToParentDir()
+ else:
+ # Root of the file system (my computer)
+ pass
+
+ def __navigateToParentFile(self):
+ index = self.__browser.rootIndex()
+ if index.model() is self.__dataModel:
+ index = self.__dataModel.indexFromH5Object(self.__h5)
+ self.__browser.setRootIndex(index)
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__updatePath()
+
+ def __navigateToParentDir(self):
+ index = self.__browser.rootIndex()
+ if index.model() is self.__dataModel:
+ path = os.path.dirname(self.__h5.file.filename)
+ index = self.__fileModel.index(path)
+ self.__browser.setRootIndex(index)
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__closeFile()
+ self.__updatePath()
+
+ def viewMode(self):
+ """Returns the current view mode.
+
+ :rtype: qt.QFileDialog.ViewMode
+ """
+ return self.__browser.viewMode()
+
+ def setViewMode(self, mode):
+ """Set the current view mode.
+
+ :param qt.QFileDialog.ViewMode mode: The new view mode
+ """
+ if mode == qt.QFileDialog.Detail:
+ self.__browser.showDetails()
+ self.__listViewAction.setChecked(False)
+ self.__detailViewAction.setChecked(True)
+ elif mode == qt.QFileDialog.List:
+ self.__browser.showList()
+ self.__listViewAction.setChecked(True)
+ self.__detailViewAction.setChecked(False)
+ else:
+ assert(False)
+
+ def __showAsListView(self):
+ self.setViewMode(qt.QFileDialog.List)
+
+ def __showAsDetailedView(self):
+ self.setViewMode(qt.QFileDialog.Detail)
+
+ def __shortcutSelected(self):
+ self.__browser.selectIndex(qt.QModelIndex())
+ self.__clearData()
+ self.__updatePath()
+ selectionModel = self.__sidebar.selectionModel()
+ indexes = selectionModel.selectedIndexes()
+ if len(indexes) == 1:
+ index = indexes[0]
+ url = self.__sidebar.model().data(index, role=qt.Qt.UserRole)
+ path = url.toLocalFile()
+ self.__fileModel_setRootPath(path)
+
+ def __browsedItemActivated(self, index):
+ if not index.isValid():
+ return
+ if index.model() is self.__fileModel:
+ path = self.__fileModel.filePath(index)
+ if self.__fileModel.isDir(index):
+ self.__fileModel_setRootPath(path)
+ if os.path.isfile(path):
+ self.__fileActivated(index)
+ elif index.model() is self.__dataModel:
+ obj = self.__dataModel.data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ if silx.io.is_group(obj):
+ self.__browser.setRootIndex(index)
+ else:
+ assert(False)
+
+ def __browsedItemSelected(self, index):
+ self.__dataSelected(index)
+ self.__updatePath()
+
+ def __fileModel_setRootPath(self, path):
+ """Set the root path of the fileModel with a filter on the
+ directoryLoaded event.
+
+ Without this filter an extra event is received (at least with PyQt4)
+ when we use for the first time the sidebar.
+
+ :param str path: Path to load
+ """
+ assert(path is not None)
+ if path != "" and not os.path.exists(path):
+ return
+ if self.hasPendingEvents():
+ # Make sure the asynchronous fileModel setRootPath is finished
+ qt.QApplication.instance().processEvents()
+
+ if self.__directoryLoadedFilter is not None:
+ if utils.samefile(self.__directoryLoadedFilter, path):
+ return
+ self.__directoryLoadedFilter = path
+ self.__processing += 1
+ if self.__fileModel is None:
+ return
+ index = self.__fileModel.setRootPath(path)
+ if not index.isValid():
+ # There is a problem with this path
+ # No asynchronous process will be waked up
+ self.__processing -= 1
+ self.__browser.setRootIndex(index, model=self.__fileModel)
+ self.__clearData()
+ self.__updatePath()
+
+ def __directoryLoaded(self, path):
+ if self.__directoryLoadedFilter is not None:
+ if not utils.samefile(self.__directoryLoadedFilter, path):
+ # Filter event which should not arrive in PyQt4
+ # The first click on the sidebar sent 2 events
+ self.__processing -= 1
+ return
+ if self.__fileModel is None:
+ return
+ index = self.__fileModel.index(path)
+ self.__browser.setRootIndex(index, model=self.__fileModel)
+ self.__updatePath()
+ self.__processing -= 1
+
+ def __closeFile(self):
+ self.__openedFiles[:] = []
+ self.__fileDirectoryAction.setEnabled(False)
+ self.__parentFileDirectoryAction.setEnabled(False)
+ if self.__h5 is not None:
+ self.__dataModel.removeH5pyObject(self.__h5)
+ self.__h5.close()
+ self.__h5 = None
+ if self.__fabio is not None:
+ if hasattr(self.__fabio, "close"):
+ self.__fabio.close()
+ self.__fabio = None
+
+ def __openFabioFile(self, filename):
+ self.__closeFile()
+ try:
+ self.__fabio = fabio.open(filename)
+ self.__openedFiles.append(self.__fabio)
+ self.__selectedFile = filename
+ except Exception as e:
+ _logger.error("Error while loading file %s: %s", filename, e.args[0])
+ _logger.debug("Backtrace", exc_info=True)
+ self.__errorWhileLoadingFile = filename, e.args[0]
+ return False
+ else:
+ return True
+
+ def __openSilxFile(self, filename):
+ self.__closeFile()
+ try:
+ self.__h5 = silx.io.open(filename)
+ self.__openedFiles.append(self.__h5)
+ self.__selectedFile = filename
+ except IOError as e:
+ _logger.error("Error while loading file %s: %s", filename, e.args[0])
+ _logger.debug("Backtrace", exc_info=True)
+ self.__errorWhileLoadingFile = filename, e.args[0]
+ return False
+ else:
+ self.__fileDirectoryAction.setEnabled(True)
+ self.__parentFileDirectoryAction.setEnabled(True)
+ self.__dataModel.insertH5pyObject(self.__h5)
+ return True
+
+ def __isSilxHavePriority(self, filename):
+ """Silx have priority when there is a specific decoder
+ """
+ _, ext = os.path.splitext(filename)
+ ext = "*%s" % ext
+ formats = silx.io.supported_extensions(flat_formats=False)
+ for extensions in formats.values():
+ if ext in extensions:
+ return True
+ return False
+
+ def __openFile(self, filename):
+ codec = self.__fileTypeCombo.currentCodec()
+ openners = []
+ if codec.is_autodetect():
+ if self.__isSilxHavePriority(filename):
+ openners.append(self.__openSilxFile)
+ if self._isFabioFilesSupported():
+ openners.append(self.__openFabioFile)
+ else:
+ if self._isFabioFilesSupported():
+ openners.append(self.__openFabioFile)
+ openners.append(self.__openSilxFile)
+ elif codec.is_silx_codec():
+ openners.append(self.__openSilxFile)
+ elif self._isFabioFilesSupported() and codec.is_fabio_codec():
+ # It is requested to use fabio, anyway fabio is here or not
+ openners.append(self.__openFabioFile)
+
+ for openner in openners:
+ ref = openner(filename)
+ if ref is not None:
+ return True
+ return False
+
+ def __fileActivated(self, index):
+ self.__selectedFile = None
+ path = self.__fileModel.filePath(index)
+ if os.path.isfile(path):
+ loaded = self.__openFile(path)
+ if loaded:
+ if self.__h5 is not None:
+ index = self.__dataModel.indexFromH5Object(self.__h5)
+ self.__browser.setRootIndex(index)
+ elif self.__fabio is not None:
+ data = _FabioData(self.__fabio)
+ self.__setData(data)
+ self.__updatePath()
+ else:
+ self.__clearData()
+
+ def __dataSelected(self, index):
+ selectedData = None
+ if index is not None:
+ if index.model() is self.__dataModel:
+ obj = self.__dataModel.data(index, self.__dataModel.H5PY_OBJECT_ROLE)
+ if self._isDataSupportable(obj):
+ selectedData = obj
+ elif index.model() is self.__fileModel:
+ self.__closeFile()
+ if self._isFabioFilesSupported():
+ path = self.__fileModel.filePath(index)
+ if os.path.isfile(path):
+ codec = self.__fileTypeCombo.currentCodec()
+ is_fabio_decoder = codec.is_fabio_codec()
+ is_fabio_have_priority = not codec.is_silx_codec() and not self.__isSilxHavePriority(path)
+ if is_fabio_decoder or is_fabio_have_priority:
+ # Then it's flat frame container
+ self.__openFabioFile(path)
+ if self.__fabio is not None:
+ selectedData = _FabioData(self.__fabio)
+ else:
+ assert(False)
+
+ self.__setData(selectedData)
+
+ def __filterSelected(self, index):
+ filters = self.__fileTypeCombo.itemExtensions(index)
+ self.__fileModel.setNameFilters(list(filters))
+
+ def __setData(self, data):
+ self.__data = data
+
+ if data is not None and self._isDataSupportable(data):
+ if self.__selectorWidget is not None:
+ self.__selectorWidget.setData(data)
+ if not self.__selectorWidget.isUsed():
+ # Needed to fake the fact we have to reset the zoom in preview
+ self.__selectedData = None
+ self.__setSelectedData(data)
+ self.__selectorWidget.hide()
+ else:
+ self.__selectorWidget.setVisible(self.__selectorWidget.hasVisibleSelectors())
+ # Needed to fake the fact we have to reset the zoom in preview
+ self.__selectedData = None
+ self.__selectorWidget.selectionChanged.emit()
+ else:
+ # Needed to fake the fact we have to reset the zoom in preview
+ self.__selectedData = None
+ self.__setSelectedData(data)
+ else:
+ self.__clearData()
+ self.__updatePath()
+
+ def _isDataSupported(self, data):
+ """Check if the data can be returned by the dialog.
+
+ If true, this data can be returned by the dialog and the open button
+ while be enabled. If false the button will be disabled.
+
+ :rtype: bool
+ """
+ raise NotImplementedError()
+
+ def _isDataSupportable(self, data):
+ """Check if the selected data can be supported at one point.
+
+ If true, the data selector will be checked and it will update the data
+ preview. Else the selecting is disabled.
+
+ :rtype: bool
+ """
+ raise NotImplementedError()
+
+ def __clearData(self):
+ """Clear the data part of the GUI"""
+ if self.__previewWidget is not None:
+ self.__previewWidget.setData(None)
+ if self.__selectorWidget is not None:
+ self.__selectorWidget.setData(None)
+ self.__selectorWidget.hide()
+ self.__selectedData = None
+ self.__data = None
+ self.__updateDataInfo()
+ button = self.__buttons.button(qt.QDialogButtonBox.Open)
+ button.setEnabled(False)
+
+ def __selectorWidgetChanged(self):
+ data = self.__selectorWidget.getSelectedData(self.__data)
+ self.__setSelectedData(data)
+
+ def __setSelectedData(self, data):
+ """Set the data selected by the dialog.
+
+ If :meth:`_isDataSupported` returns false, this function will be
+ inhibited and no data will be selected.
+ """
+ if isinstance(data, _FabioData):
+ data = data[()]
+ if self.__previewWidget is not None:
+ fromDataSelector = self.__selectedData is not None
+ self.__previewWidget.setData(data, fromDataSelector=fromDataSelector)
+ if self._isDataSupported(data):
+ self.__selectedData = data
+ else:
+ self.__clearData()
+ return
+ self.__updateDataInfo()
+ self.__updatePath()
+ button = self.__buttons.button(qt.QDialogButtonBox.Open)
+ button.setEnabled(True)
+
+ def __updateDataInfo(self):
+ if self.__errorWhileLoadingFile is not None:
+ filename, message = self.__errorWhileLoadingFile
+ message = "<b>Error while loading file '%s'</b><hr/>%s" % (filename, message)
+ size = self.__dataInfo.height()
+ icon = self.style().standardIcon(qt.QStyle.SP_MessageBoxCritical)
+ pixmap = icon.pixmap(size, size)
+
+ self.__dataInfo.setText("Error while loading file")
+ self.__dataInfo.setToolTip(message)
+ self.__dataIcon.setToolTip(message)
+ self.__dataIcon.setVisible(True)
+ self.__dataIcon.setPixmap(pixmap)
+
+ self.__errorWhileLoadingFile = None
+ return
+
+ self.__dataIcon.setVisible(False)
+ self.__dataInfo.setToolTip("")
+ if self.__selectedData is None:
+ self.__dataInfo.setText("No data selected")
+ else:
+ text = self._displayedDataInfo(self.__data, self.__selectedData)
+ self.__dataInfo.setVisible(text is not None)
+ if text is not None:
+ self.__dataInfo.setText(text)
+
+ def _displayedDataInfo(self, dataBeforeSelection, dataAfterSelection):
+ """Returns the text displayed under the data preview.
+
+ This zone is used to display error in case or problem of data selection
+ or problems with IO.
+
+ :param numpy.ndarray dataAfterSelection: Data as it is after the
+ selection widget (basically the data from the preview widget)
+ :param numpy.ndarray dataAfterSelection: Data as it is before the
+ selection widget (basically the data from the browsing widget)
+ :rtype: bool
+ """
+ return None
+
+ def __createUrlFromIndex(self, index, useSelectorWidget=True):
+ if index.model() is self.__fileModel:
+ filename = self.__fileModel.filePath(index)
+ dataPath = None
+ elif index.model() is self.__dataModel:
+ obj = self.__dataModel.data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ filename = obj.file.filename
+ dataPath = obj.name
+ else:
+ # root of the computer
+ filename = ""
+ dataPath = None
+
+ if useSelectorWidget and self.__selectorWidget is not None and self.__selectorWidget.isUsed():
+ slicing = self.__selectorWidget.slicing()
+ if slicing == tuple():
+ slicing = None
+ else:
+ slicing = None
+
+ if self.__fabio is not None:
+ scheme = "fabio"
+ elif self.__h5 is not None:
+ scheme = "silx"
+ else:
+ if os.path.isfile(filename):
+ codec = self.__fileTypeCombo.currentCodec()
+ if codec.is_fabio_codec():
+ scheme = "fabio"
+ elif codec.is_silx_codec():
+ scheme = "silx"
+ else:
+ scheme = None
+ else:
+ scheme = None
+
+ url = silx.io.url.DataUrl(file_path=filename, data_path=dataPath, data_slice=slicing, scheme=scheme)
+ return url
+
+ def __updatePath(self):
+ index = self.__browser.selectedIndex()
+ if index is None:
+ index = self.__browser.rootIndex()
+ url = self.__createUrlFromIndex(index)
+ if url.path() != self.__pathEdit.text():
+ old = self.__pathEdit.blockSignals(True)
+ self.__pathEdit.setText(url.path())
+ self.__pathEdit.blockSignals(old)
+
+ def __rootIndexChanged(self, index):
+ url = self.__createUrlFromIndex(index, useSelectorWidget=False)
+
+ currentUrl = None
+ if 0 <= self.__currentHistoryLocation < len(self.__currentHistory):
+ currentUrl = self.__currentHistory[self.__currentHistoryLocation]
+
+ if currentUrl is None or currentUrl != url.path():
+ # clean up the forward history
+ self.__currentHistory = self.__currentHistory[0:self.__currentHistoryLocation + 1]
+ self.__currentHistory.append(url.path())
+ self.__currentHistoryLocation += 1
+
+ if index.model() != self.__dataModel:
+ if sys.platform == "win32":
+ # path == ""
+ isRoot = not index.isValid()
+ else:
+ # path in ["", "/"]
+ isRoot = not index.isValid() or not index.parent().isValid()
+ else:
+ isRoot = False
+
+ if index.isValid():
+ self.__dataSelected(index)
+ self.__toParentAction.setEnabled(not isRoot)
+ self.__updateActionHistory()
+ self.__updateSidebar()
+
+ def __updateSidebar(self):
+ """Called when the current directory location change"""
+ if self.__sidebar is None:
+ return
+ selectionModel = self.__sidebar.selectionModel()
+ selectionModel.selectionChanged.disconnect(self.__shortcutSelected)
+ index = self.__browser.rootIndex()
+ if index.model() == self.__fileModel:
+ path = self.__fileModel.filePath(index)
+ self.__sidebar.setSelectedPath(path)
+ elif index.model() is None:
+ path = ""
+ self.__sidebar.setSelectedPath(path)
+ else:
+ selectionModel.clear()
+ selectionModel.selectionChanged.connect(self.__shortcutSelected)
+
+ def __updateActionHistory(self):
+ self.__forwardAction.setEnabled(len(self.__currentHistory) - 1 > self.__currentHistoryLocation)
+ self.__backwardAction.setEnabled(self.__currentHistoryLocation > 0)
+
+ def __textChanged(self, text):
+ self.__pathChanged()
+
+ def _isFabioFilesSupported(self):
+ """Returns true fabio files can be loaded.
+ """
+ return True
+
+ def _isLoadableUrl(self, url):
+ """Returns true if the URL is loadable by this dialog.
+
+ :param DataUrl url: The requested URL
+ """
+ return True
+
+ def __pathChanged(self):
+ url = silx.io.url.DataUrl(path=self.__pathEdit.text())
+ if url.is_valid() or url.path() == "":
+ if url.path() in ["", "/"] or url.file_path() in ["", "/"]:
+ self.__fileModel_setRootPath(qt.QDir.rootPath())
+ elif os.path.exists(url.file_path()):
+ rootIndex = None
+ if os.path.isdir(url.file_path()):
+ self.__fileModel_setRootPath(url.file_path())
+ index = self.__fileModel.index(url.file_path())
+ elif os.path.isfile(url.file_path()):
+ if self._isLoadableUrl(url):
+ if url.scheme() == "silx":
+ loaded = self.__openSilxFile(url.file_path())
+ elif url.scheme() == "fabio" and self._isFabioFilesSupported():
+ loaded = self.__openFabioFile(url.file_path())
+ else:
+ loaded = self.__openFile(url.file_path())
+ else:
+ loaded = False
+ if loaded:
+ if self.__h5 is not None:
+ rootIndex = self.__dataModel.indexFromH5Object(self.__h5)
+ elif self.__fabio is not None:
+ index = self.__fileModel.index(url.file_path())
+ rootIndex = index
+ if rootIndex is None:
+ index = self.__fileModel.index(url.file_path())
+ index = index.parent()
+
+ if rootIndex is not None:
+ if rootIndex.model() == self.__dataModel:
+ if url.data_path() is not None:
+ dataPath = url.data_path()
+ if dataPath in self.__h5:
+ obj = self.__h5[dataPath]
+ else:
+ path = utils.findClosestSubPath(self.__h5, dataPath)
+ if path is None:
+ path = "/"
+ obj = self.__h5[path]
+
+ if silx.io.is_file(obj):
+ self.__browser.setRootIndex(rootIndex)
+ elif silx.io.is_group(obj):
+ index = self.__dataModel.indexFromH5Object(obj)
+ self.__browser.setRootIndex(index)
+ else:
+ index = self.__dataModel.indexFromH5Object(obj)
+ self.__browser.setRootIndex(index.parent())
+ self.__browser.selectIndex(index)
+ else:
+ self.__browser.setRootIndex(rootIndex)
+ self.__clearData()
+ elif rootIndex.model() == self.__fileModel:
+ # that's a fabio file
+ self.__browser.setRootIndex(rootIndex.parent())
+ self.__browser.selectIndex(rootIndex)
+ # data = _FabioData(self.__fabio)
+ # self.__setData(data)
+ else:
+ assert(False)
+ else:
+ self.__browser.setRootIndex(index, model=self.__fileModel)
+ self.__clearData()
+
+ if self.__selectorWidget is not None:
+ self.__selectorWidget.selectSlicing(url.data_slice())
+ else:
+ self.__errorWhileLoadingFile = (url.file_path(), "File not found")
+ self.__clearData()
+ else:
+ self.__errorWhileLoadingFile = (url.file_path(), "Path invalid")
+ self.__clearData()
+
+ def previewToolbar(self):
+ return self.__previewToolbar
+
+ def previewWidget(self):
+ return self.__previewWidget
+
+ def selectorWidget(self):
+ return self.__selectorWidget
+
+ def _createPreviewToolbar(self, parent, dataPreviewWidget, dataSelectorWidget):
+ return None
+
+ def _createPreviewWidget(self, parent):
+ return None
+
+ def _createSelectorWidget(self, parent):
+ return None
+
+ # Selected file
+
+ def setDirectory(self, path):
+ """Sets the data dialog's current directory."""
+ self.__fileModel_setRootPath(path)
+
+ def selectedFile(self):
+ """Returns the file path containing the selected data.
+
+ :rtype: str
+ """
+ return self.__selectedFile
+
+ def selectFile(self, filename):
+ """Sets the data dialog's current file."""
+ self.__directoryLoadedFilter = ""
+ old = self.__pathEdit.blockSignals(True)
+ try:
+ self.__pathEdit.setText(filename)
+ finally:
+ self.__pathEdit.blockSignals(old)
+ self.__pathChanged()
+
+ # Selected data
+
+ def selectUrl(self, url):
+ """Sets the data dialog's current data url.
+
+ :param Union[str,DataUrl] url: URL identifying a data (it can be a
+ `DataUrl` object)
+ """
+ if isinstance(url, silx.io.url.DataUrl):
+ url = url.path()
+ self.__directoryLoadedFilter = ""
+ old = self.__pathEdit.blockSignals(True)
+ try:
+ self.__pathEdit.setText(url)
+ finally:
+ self.__pathEdit.blockSignals(old)
+ self.__pathChanged()
+
+ def selectedUrl(self):
+ """Returns the URL from the file system to the data.
+
+ If the dialog is not validated, the path can be an intermediat
+ selected path, or an invalid path.
+
+ :rtype: str
+ """
+ return self.__pathEdit.text()
+
+ def selectedDataUrl(self):
+ """Returns the URL as a :class:`DataUrl` from the file system to the
+ data.
+
+ If the dialog is not validated, the path can be an intermediat
+ selected path, or an invalid path.
+
+ :rtype: DataUrl
+ """
+ url = self.selectedUrl()
+ return silx.io.url.DataUrl(url)
+
+ def directory(self):
+ """Returns the path from the current browsed directory.
+
+ :rtype: str
+ """
+ if self.__directory is not None:
+ # At post execution, returns the cache
+ return self.__directory
+
+ index = self.__browser.rootIndex()
+ if index.model() is self.__fileModel:
+ path = self.__fileModel.filePath(index)
+ return path
+ elif index.model() is self.__dataModel:
+ path = os.path.dirname(self.__h5.file.filename)
+ return path
+ else:
+ return ""
+
+ def _selectedData(self):
+ """Returns the internal selected data
+
+ :rtype: numpy.ndarray
+ """
+ return self.__selectedData
+
+ # Filters
+
+ def selectedNameFilter(self):
+ """Returns the filter that the user selected in the file dialog."""
+ return self.__fileTypeCombo.currentText()
+
+ # History
+
+ def history(self):
+ """Returns the browsing history of the filedialog as a list of paths.
+
+ :rtype: List<str>
+ """
+ if len(self.__currentHistory) <= 1:
+ return []
+ history = self.__currentHistory[0:self.__currentHistoryLocation]
+ return list(history)
+
+ def setHistory(self, history):
+ self.__currentHistory = []
+ self.__currentHistory.extend(history)
+ self.__currentHistoryLocation = len(self.__currentHistory) - 1
+ self.__updateActionHistory()
+
+ # Colormap
+
+ def colormap(self):
+ if self.__previewWidget is None:
+ return None
+ return self.__previewWidget.colormap()
+
+ def setColormap(self, colormap):
+ if self.__previewWidget is None:
+ raise RuntimeError("No preview widget defined")
+ self.__previewWidget.setColormap(colormap)
+
+ # Sidebar
+
+ def setSidebarUrls(self, urls):
+ """Sets the urls that are located in the sidebar."""
+ if self.__sidebar is None:
+ return
+ self.__sidebar.setUrls(urls)
+
+ def sidebarUrls(self):
+ """Returns a list of urls that are currently in the sidebar."""
+ if self.__sidebar is None:
+ return []
+ return self.__sidebar.urls()
+
+ # State
+
+ __serialVersion = 1
+ """Store the current version of the serialized data"""
+
+ @classmethod
+ def qualifiedName(cls):
+ return "%s.%s" % (cls.__module__, cls.__name__)
+
+ def restoreState(self, state):
+ """Restores the dialogs's layout, history and current directory to the
+ state specified.
+
+ :param qt.QByteArray state: Stream containing the new state
+ :rtype: bool
+ """
+ stream = qt.QDataStream(state, qt.QIODevice.ReadOnly)
+
+ qualifiedName = stream.readQString()
+ if qualifiedName != self.qualifiedName():
+ _logger.warning("Stored state contains an invalid qualified name. %s restoration cancelled.", self.__class__.__name__)
+ return False
+
+ version = stream.readInt32()
+ if version != self.__serialVersion:
+ _logger.warning("Stored state contains an invalid version. %s restoration cancelled.", self.__class__.__name__)
+ return False
+
+ result = True
+
+ splitterData = stream.readQVariant()
+ sidebarUrls = stream.readQStringList()
+ history = stream.readQStringList()
+ workingDirectory = stream.readQString()
+ browserData = stream.readQVariant()
+ viewMode = stream.readInt32()
+ colormapData = stream.readQVariant()
+
+ result &= self.__splitter.restoreState(splitterData)
+ sidebarUrls = [qt.QUrl(s) for s in sidebarUrls]
+ self.setSidebarUrls(list(sidebarUrls))
+ history = [s for s in history]
+ self.setHistory(list(history))
+ if workingDirectory is not None:
+ self.setDirectory(workingDirectory)
+ result &= self.__browser.restoreState(browserData)
+ self.setViewMode(viewMode)
+ colormap = self.colormap()
+ if colormap is not None:
+ result &= self.colormap().restoreState(colormapData)
+
+ return result
+
+ def saveState(self):
+ """Saves the state of the dialog's layout, history and current
+ directory.
+
+ :rtype: qt.QByteArray
+ """
+ data = qt.QByteArray()
+ stream = qt.QDataStream(data, qt.QIODevice.WriteOnly)
+
+ s = self.qualifiedName()
+ stream.writeQString(u"%s" % s)
+ stream.writeInt32(self.__serialVersion)
+ stream.writeQVariant(self.__splitter.saveState())
+ strings = [u"%s" % s.toString() for s in self.sidebarUrls()]
+ stream.writeQStringList(strings)
+ strings = [u"%s" % s for s in self.history()]
+ stream.writeQStringList(strings)
+ stream.writeQString(u"%s" % self.directory())
+ stream.writeQVariant(self.__browser.saveState())
+ stream.writeInt32(self.viewMode())
+ colormap = self.colormap()
+ if colormap is not None:
+ stream.writeQVariant(self.colormap().saveState())
+ else:
+ stream.writeQVariant(None)
+
+ return data
diff --git a/src/silx/gui/dialog/ColormapDialog.py b/src/silx/gui/dialog/ColormapDialog.py
new file mode 100644
index 0000000..2506e2a
--- /dev/null
+++ b/src/silx/gui/dialog/ColormapDialog.py
@@ -0,0 +1,1775 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A QDialog widget to set-up the colormap.
+
+It uses a description of colormaps as dict compatible with :class:`Plot`.
+
+To run the following sample code, a QApplication must be initialized.
+
+Create the colormap dialog and set the colormap description and data range:
+
+>>> from silx.gui.dialog.ColormapDialog import ColormapDialog
+>>> from silx.gui.colors import Colormap
+
+>>> dialog = ColormapDialog()
+>>> colormap = Colormap(name='red', normalization='log',
+... vmin=1., vmax=2.)
+
+>>> dialog.setColormap(colormap)
+>>> colormap.setVRange(1., 100.) # This scale the width of the plot area
+>>> dialog.show()
+
+Get the colormap description (compatible with :class:`Plot`) from the dialog:
+
+>>> cmap = dialog.getColormap()
+>>> cmap.getName()
+'red'
+
+It is also possible to display an histogram of the image in the dialog.
+This updates the data range with the range of the bins.
+
+>>> import numpy
+>>> image = numpy.random.normal(size=512 * 512).reshape(512, -1)
+>>> hist, bin_edges = numpy.histogram(image, bins=10)
+>>> dialog.setHistogram(hist, bin_edges)
+
+The updates of the colormap description are also available through the signal:
+:attr:`ColormapDialog.sigColormapChanged`.
+""" # noqa
+
+__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+import enum
+import logging
+
+import numpy
+
+from .. import qt
+from .. import utils
+from ..colors import Colormap, cursorColorForColormap
+from ..plot import PlotWidget
+from ..plot.items.axis import Axis
+from ..plot.items import BoundingRect
+from silx.gui.widgets.FloatEdit import FloatEdit
+import weakref
+from silx.math.combo import min_max
+from silx.gui.plot import items
+from silx.gui import icons
+from silx.gui.qt import inspect as qtinspect
+from silx.gui.widgets.ColormapNameComboBox import ColormapNameComboBox
+from silx.gui.widgets.WaitingPushButton import WaitingPushButton
+from silx.math.histogram import Histogramnd
+from silx.utils import deprecation
+from silx.gui.plot.items.roi import RectangleROI
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+
+_logger = logging.getLogger(__name__)
+
+_colormapIconPreview = {}
+
+
+class _DataRefHolder(items.Item, items.ColormapMixIn):
+ """Holder for a weakref of a numpy array.
+
+ It provides features from `ColormapMixIn`.
+ """
+
+ def __init__(self, dataRef):
+ items.Item.__init__(self)
+ items.ColormapMixIn.__init__(self)
+ self.__dataRef = dataRef
+ self._updated(items.ItemChangedType.DATA)
+
+ def getColormappedData(self, copy=True):
+ return self.__dataRef()
+
+
+class _BoundaryWidget(qt.QWidget):
+ """Widget to edit a boundary of the colormap (vmin or vmax)"""
+
+ sigAutoScaleChanged = qt.Signal(object)
+ """Signal emitted when the autoscale was changed
+
+ True is sent as an argument if autoscale is set to true.
+ """
+
+ sigValueChanged = qt.Signal(object)
+ """Signal emitted when value is changed
+
+ The new value is sent as an argument.
+ """
+
+ def __init__(self, parent=None, value=0.0):
+ qt.QWidget.__init__(self, parent=parent)
+ self.setLayout(qt.QHBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._numVal = FloatEdit(parent=self, value=value)
+ self.layout().addWidget(self._numVal)
+ self._autoCB = qt.QCheckBox('auto', parent=self)
+ self.layout().addWidget(self._autoCB)
+ self._autoCB.setChecked(False)
+ self._autoCB.setVisible(False)
+
+ self._autoCB.toggled.connect(self._autoToggled)
+ self._numVal.textEdited.connect(self.__textEdited)
+ self._numVal.editingFinished.connect(self.__editingFinished)
+ self.setFocusProxy(self._numVal)
+
+ self.__textWasEdited = False
+ """True if the text was edited, in order to send an event
+ at the end of the user interaction"""
+
+ self.__realValue = None
+ """Store the real value set by setValue, to avoid
+ rounding of the widget"""
+
+ def __textEdited(self):
+ self.__textWasEdited = True
+
+ def __editingFinished(self):
+ if self.__textWasEdited:
+ value = self._numVal.value()
+ self.__realValue = value
+ with utils.blockSignals(self._numVal):
+ # Fix the formatting
+ self._numVal.setValue(self.__realValue)
+ self.sigValueChanged.emit(value)
+ self.__textWasEdited = False
+
+ def isAutoChecked(self):
+ return self._autoCB.isChecked()
+
+ def getValue(self):
+ """Returns the stored range. If autoscale is
+ enabled, this returns None.
+ """
+ if self._autoCB.isChecked():
+ return None
+ if self.__realValue is not None:
+ return self.__realValue
+ return self._numVal.value()
+
+ def _autoToggled(self, enabled):
+ self._numVal.setEnabled(not enabled)
+ self._updateDisplayedText()
+ self.sigAutoScaleChanged.emit(enabled)
+
+ def _updateDisplayedText(self):
+ self.__textWasEdited = False
+ if self._autoCB.isChecked() and self.__realValue is not None:
+ with utils.blockSignals(self._numVal):
+ self._numVal.setValue(self.__realValue)
+
+ def setValue(self, value, isAuto=False):
+ """Set the value of the boundary.
+
+ :param float value: A finite value for the boundary
+ :param bool isAuto: If true, the finite value was automatically computed
+ from the data, else it is a fixed custom value.
+ """
+ assert value is not None
+ self._autoCB.setChecked(isAuto)
+ with utils.blockSignals(self._numVal):
+ if isAuto or self.__realValue != value:
+ if not self.__textWasEdited:
+ self._numVal.setValue(value)
+ self.__realValue = value
+ self._numVal.setEnabled(not isAuto)
+
+
+class _AutoscaleModeComboBox(qt.QComboBox):
+
+ DATA = {
+ Colormap.MINMAX: ("Min/max", "Use the data min/max"),
+ Colormap.STDDEV3: ("Mean ± 3 × stddev", "Use the data mean ± 3 × standard deviation"),
+ }
+
+ def __init__(self, parent: qt.QWidget):
+ super(_AutoscaleModeComboBox, self).__init__(parent=parent)
+ self.currentIndexChanged.connect(self.__updateTooltip)
+ self._init()
+
+ def _init(self):
+ for mode in Colormap.AUTOSCALE_MODES:
+ label, tooltip = self.DATA.get(mode, (mode, None))
+ self.addItem(label, mode)
+ if tooltip is not None:
+ self.setItemData(self.count() - 1, tooltip, qt.Qt.ToolTipRole)
+
+ def setCurrentIndex(self, index):
+ self.__updateTooltip(index)
+ super(_AutoscaleModeComboBox, self).setCurrentIndex(index)
+
+ def __updateTooltip(self, index):
+ if index > -1:
+ tooltip = self.itemData(index, qt.Qt.ToolTipRole)
+ else:
+ tooltip = ""
+ self.setToolTip(tooltip)
+
+ def currentMode(self):
+ index = self.currentIndex()
+ return self.itemData(index)
+
+ def setCurrentMode(self, mode):
+ for index in range(self.count()):
+ if mode == self.itemData(index):
+ self.setCurrentIndex(index)
+ return
+ if mode is None:
+ # If None was not a value
+ self.setCurrentIndex(-1)
+ return
+ self.addItem(mode, mode)
+ self.setCurrentIndex(self.count() - 1)
+
+
+class _AutoScaleButtons(qt.QWidget):
+
+ autoRangeChanged = qt.Signal(object)
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent=parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ self.setFocusPolicy(qt.Qt.NoFocus)
+
+ self._bothAuto = qt.QPushButton(self)
+ self._bothAuto.setText("Autoscale")
+ self._bothAuto.setToolTip("Enable/disable the autoscale for both min and max")
+ self._bothAuto.setCheckable(True)
+ self._bothAuto.toggled[bool].connect(self.__bothToggled)
+ self._bothAuto.setFocusPolicy(qt.Qt.TabFocus)
+
+ self._minAuto = qt.QCheckBox(self)
+ self._minAuto.setText("")
+ self._minAuto.setToolTip("Enable/disable the autoscale for min")
+ self._minAuto.toggled[bool].connect(self.__minToggled)
+ self._minAuto.setFocusPolicy(qt.Qt.TabFocus)
+
+ self._maxAuto = qt.QCheckBox(self)
+ self._maxAuto.setText("")
+ self._maxAuto.setToolTip("Enable/disable the autoscale for max")
+ self._maxAuto.toggled[bool].connect(self.__maxToggled)
+ self._maxAuto.setFocusPolicy(qt.Qt.TabFocus)
+
+ layout.addStretch(1)
+ layout.addWidget(self._minAuto)
+ layout.addSpacing(20)
+ layout.addWidget(self._bothAuto)
+ layout.addSpacing(20)
+ layout.addWidget(self._maxAuto)
+ layout.addStretch(1)
+
+ def __bothToggled(self, checked):
+ autoRange = checked, checked
+ self.setAutoRange(autoRange)
+ self.autoRangeChanged.emit(autoRange)
+
+ def __minToggled(self, checked):
+ autoRange = self.getAutoRange()
+ self.setAutoRange(autoRange)
+ self.autoRangeChanged.emit(autoRange)
+
+ def __maxToggled(self, checked):
+ autoRange = self.getAutoRange()
+ self.setAutoRange(autoRange)
+ self.autoRangeChanged.emit(autoRange)
+
+ def setAutoRangeFromColormap(self, colormap):
+ vRange = colormap.getVRange()
+ autoRange = vRange[0] is None, vRange[1] is None
+ self.setAutoRange(autoRange)
+
+ def setAutoRange(self, autoRange):
+ if autoRange[0] == autoRange[1]:
+ with utils.blockSignals(self._bothAuto):
+ self._bothAuto.setChecked(autoRange[0])
+ else:
+ with utils.blockSignals(self._bothAuto):
+ self._bothAuto.setChecked(False)
+ with utils.blockSignals(self._minAuto):
+ self._minAuto.setChecked(autoRange[0])
+ with utils.blockSignals(self._maxAuto):
+ self._maxAuto.setChecked(autoRange[1])
+
+ def getAutoRange(self):
+ return self._minAuto.isChecked(), self._maxAuto.isChecked()
+
+
+@enum.unique
+class _DataInPlotMode(enum.Enum):
+ """Enum for each mode of display of the data in the plot."""
+ RANGE = 'range'
+ HISTOGRAM = 'histogram'
+
+
+class _ColormapHistogram(qt.QWidget):
+ """Display the colormap and the data as a plot."""
+
+ sigRangeMoving = qt.Signal(object, object)
+ """Emitted when a mouse interaction moves the location
+ of the colormap range in the plot.
+
+ This signal contains 2 elements:
+
+ - vmin: A float value if this range was moved, else None
+ - vmax: A float value if this range was moved, else None
+ """
+
+ sigRangeMoved = qt.Signal(object, object)
+ """Emitted when a mouse interaction stop.
+
+ This signal contains 2 elements:
+
+ - vmin: A float value if this range was moved, else None
+ - vmax: A float value if this range was moved, else None
+ """
+
+ def __init__(self, parent):
+ qt.QWidget.__init__(self, parent=parent)
+ self._dataInPlotMode = _DataInPlotMode.RANGE
+ self._finiteRange = None, None
+ self._initPlot()
+
+ self._histogramData = {}
+ """Histogram displayed in the plot"""
+
+ self._dragging = False, False
+ """True, if the min or the max handle is dragging"""
+
+ self._dataRange = {}
+ """Histogram displayed in the plot"""
+
+ self._invalidated = False
+
+ def paintEvent(self, event):
+ if self._invalidated:
+ self._updateDataInPlot()
+ self._invalidated = False
+ self._updateMarkerPosition()
+ return super(_ColormapHistogram, self).paintEvent(event)
+
+ def getFiniteRange(self):
+ """Returns the colormap range as displayed in the plot."""
+ return self._finiteRange
+
+ def setFiniteRange(self, vRange):
+ """Set the colormap range to use in the plot.
+
+ Here there is no concept of auto. The values should
+ not be None, except if there is no range or marker
+ to display.
+ """
+ # Do not reset the limit for handle about to be dragged
+ if self._dragging[0]:
+ vRange = self._finiteRange[0], vRange[1]
+ if self._dragging[1]:
+ vRange = vRange[0], self._finiteRange[1]
+
+ if vRange == self._finiteRange:
+ return
+
+ self._finiteRange = vRange
+ self.update()
+
+ def getColormap(self):
+ return self.parent().getColormap()
+
+ def _getNormalizedHistogram(self):
+ """Return an histogram already normalized according to the colormap
+ normalization.
+
+ Returns a tuple edges, counts
+ """
+ norm = self._getNorm()
+ histogram = self._histogramData.get(norm, None)
+ if histogram is None:
+ histogram = self._computeNormalizedHistogram()
+ self._histogramData[norm] = histogram
+ return histogram
+
+ def _computeNormalizedHistogram(self):
+ colormap = self.getColormap()
+ if colormap is None:
+ norm = Colormap.LINEAR
+ else:
+ norm = colormap.getNormalization()
+
+ # Try to use the histogram defined in the dialog
+ histo = self.parent()._getHistogram()
+ if histo is not None:
+ counts, edges = histo
+ normalizer = Colormap(normalization=norm)._getNormalizer()
+ mask = normalizer.is_valid(edges[:-1]) # Check lower bin edges only
+ firstValid = numpy.argmax(mask) # edges increases monotonically
+ if firstValid == 0: # Mask is all False or all True
+ return (counts, edges) if mask[0] else (None, None)
+ else: # Clip to valid values
+ return counts[firstValid:], edges[firstValid:]
+
+ data = self.parent()._getArray()
+ if data is None:
+ return None, None
+ dataRange = self._getNormalizedDataRange()
+ if dataRange[0] is None or dataRange[1] is None:
+ return None, None
+ counts, edges = self.parent().computeHistogram(data, scale=norm, dataRange=dataRange)
+ return counts, edges
+
+ def _getNormalizedDataRange(self):
+ """Return a data range already normalized according to the colormap
+ normalization.
+
+ Returns a tuple with min and max
+ """
+ norm = self._getNorm()
+ dataRange = self._dataRange.get(norm, None)
+ if dataRange is None:
+ dataRange = self._computeNormalizedDataRange()
+ self._dataRange[norm] = dataRange
+ return dataRange
+
+ def _computeNormalizedDataRange(self):
+ colormap = self.getColormap()
+ if colormap is None:
+ norm = Colormap.LINEAR
+ else:
+ norm = colormap.getNormalization()
+
+ # Try to use the one defined in the dialog
+ dataRange = self.parent()._getDataRange()
+ if dataRange is not None:
+ if norm in (Colormap.LINEAR, Colormap.GAMMA, Colormap.ARCSINH):
+ return dataRange[0], dataRange[2]
+ elif norm == Colormap.LOGARITHM:
+ return dataRange[1], dataRange[2]
+ elif norm == Colormap.SQRT:
+ return dataRange[1], dataRange[2]
+ else:
+ _logger.error("Undefined %s normalization", norm)
+
+ # Try to use the histogram defined in the dialog
+ histo = self.parent()._getHistogram()
+ if histo is not None:
+ _histo, edges = histo
+ normalizer = Colormap(normalization=norm)._getNormalizer()
+ edges = edges[normalizer.is_valid(edges)]
+ if edges.size == 0:
+ return None, None
+ else:
+ dataRange = min_max(edges, finite=True)
+ return dataRange.minimum, dataRange.maximum
+
+ item = self.parent()._getItem()
+ if item is not None:
+ # Trick to reach data range using colormap cache
+ cm = Colormap()
+ cm.setVRange(None, None)
+ cm.setNormalization(norm)
+ dataRange = item._getColormapAutoscaleRange(cm)
+ return dataRange
+
+ # If there is no item, there is no data
+ return None, None
+
+ def _getDisplayableRange(self):
+ """Returns the selected min/max range to apply to the data,
+ according to the used scale.
+
+ One or both limits can be None in case it is not displayable in the
+ current axes scale.
+
+ :returns: Tuple{float, float}
+ """
+ scale = self._plot.getXAxis().getScale()
+
+ def isDisplayable(pos):
+ if pos is None:
+ return False
+ if scale == Axis.LOGARITHMIC:
+ return pos > 0.0
+ return True
+
+ posMin, posMax = self.getFiniteRange()
+ if not isDisplayable(posMin):
+ posMin = None
+ if not isDisplayable(posMax):
+ posMax = None
+
+ return posMin, posMax
+
+ def _initPlot(self):
+ """Init the plot to display the range and the values"""
+ self._plot = PlotWidget(self)
+ self._plot.setDataMargins(0.125, 0.125, 0.125, 0.125)
+ self._plot.getXAxis().setLabel("Data Values")
+ self._plot.getYAxis().setLabel("")
+ self._plot.setInteractiveMode('select', zoomOnWheel=False)
+ self._plot.setActiveCurveHandling(False)
+ self._plot.setMinimumSize(qt.QSize(250, 200))
+ self._plot.sigPlotSignal.connect(self._plotEventReceived)
+ palette = self.palette()
+ color = palette.color(qt.QPalette.Normal, qt.QPalette.Window)
+ self._plot.setBackgroundColor(color)
+ self._plot.setDataBackgroundColor("white")
+
+ lut = numpy.arange(256)
+ lut.shape = 1, -1
+ self._plot.addImage(lut, legend='lut')
+ self._lutItem = self._plot._getItem("image", "lut")
+ self._lutItem.setVisible(False)
+
+ self._plot.addScatter(x=[], y=[], value=[], legend='lut2')
+ self._lutItem2 = self._plot._getItem("scatter", "lut2")
+ self._lutItem2.setVisible(False)
+ self.__lutY = numpy.array([-0.05] * 256)
+ self.__lutV = numpy.arange(256)
+
+ self._bound = BoundingRect()
+ self._plot.addItem(self._bound)
+ self._bound.setVisible(True)
+
+ # Add plot for histogram
+ self._plotToolbar = qt.QToolBar(self)
+ self._plotToolbar.setFloatable(False)
+ self._plotToolbar.setMovable(False)
+ self._plotToolbar.setIconSize(qt.QSize(8, 8))
+ self._plotToolbar.setStyleSheet("QToolBar { border: 0px }")
+ self._plotToolbar.setOrientation(qt.Qt.Vertical)
+
+ group = qt.QActionGroup(self._plotToolbar)
+ group.setExclusive(True)
+
+ action = qt.QAction("Data range", self)
+ action.setToolTip("Display the data range within the colormap range. A fast data processing have to be done.")
+ action.setIcon(icons.getQIcon('colormap-range'))
+ action.setCheckable(True)
+ action.setData(_DataInPlotMode.RANGE)
+ action.setChecked(action.data() == self._dataInPlotMode)
+ self._plotToolbar.addAction(action)
+ group.addAction(action)
+ action = qt.QAction("Histogram", self)
+ action.setToolTip("Display the data histogram within the colormap range. A slow data processing have to be done. ")
+ action.setIcon(icons.getQIcon('colormap-histogram'))
+ action.setCheckable(True)
+ action.setData(_DataInPlotMode.HISTOGRAM)
+ action.setChecked(action.data() == self._dataInPlotMode)
+ self._plotToolbar.addAction(action)
+ group.addAction(action)
+ group.triggered.connect(self._displayDataInPlotModeChanged)
+
+ plotBoxLayout = qt.QHBoxLayout()
+ plotBoxLayout.setContentsMargins(0, 0, 0, 0)
+ plotBoxLayout.setSpacing(2)
+ plotBoxLayout.addWidget(self._plotToolbar)
+ plotBoxLayout.addWidget(self._plot)
+ plotBoxLayout.setSizeConstraint(qt.QLayout.SetMinimumSize)
+ self.setLayout(plotBoxLayout)
+
+ def _plotEventReceived(self, event):
+ """Handle events from the plot"""
+ kind = event['event']
+
+ if kind == 'markerMoving':
+ value = event['xdata']
+ if event['label'] == 'Min':
+ self._dragging = True, False
+ self._finiteRange = value, self._finiteRange[1]
+ self._last = value, None
+ self.sigRangeMoving.emit(*self._last)
+ elif event['label'] == 'Max':
+ self._dragging = False, True
+ self._finiteRange = self._finiteRange[0], value
+ self._last = None, value
+ self.sigRangeMoving.emit(*self._last)
+ self._updateLutItem(self._finiteRange)
+ elif kind == 'markerMoved':
+ self.sigRangeMoved.emit(*self._last)
+ self._plot.resetZoom()
+ self._dragging = False, False
+ else:
+ pass
+
+ def _updateMarkerPosition(self):
+ colormap = self.getColormap()
+ posMin, posMax = self._getDisplayableRange()
+
+ if colormap is None:
+ isDraggable = False
+ else:
+ isDraggable = colormap.isEditable()
+
+ with utils.blockSignals(self):
+ if posMin is not None and not self._dragging[0]:
+ self._plot.addXMarker(
+ posMin,
+ legend='Min',
+ text='Min',
+ draggable=isDraggable,
+ color="blue",
+ constraint=self._plotMinMarkerConstraint)
+ if posMax is not None and not self._dragging[1]:
+ self._plot.addXMarker(
+ posMax,
+ legend='Max',
+ text='Max',
+ draggable=isDraggable,
+ color="blue",
+ constraint=self._plotMaxMarkerConstraint)
+
+ self._updateLutItem((posMin, posMax))
+ self._plot.resetZoom()
+
+ def _updateLutItem(self, vRange):
+ colormap = self.getColormap()
+ if colormap is None:
+ return
+
+ if vRange is None:
+ posMin, posMax = self._getDisplayableRange()
+ else:
+ posMin, posMax = vRange
+ if posMin is None or posMax is None:
+ self._lutItem.setVisible(False)
+ pos = posMax if posMin is None else posMin
+ if pos is not None:
+ self._bound.setBounds((pos, pos, -0.1, 0))
+ else:
+ self._bound.setBounds((0, 0, -0.1, 0))
+ else:
+ norm = colormap.getNormalization()
+ normColormap = colormap.copy()
+ normColormap.setEditable(True)
+ normColormap.setVRange(0, 255)
+ normColormap.setNormalization(Colormap.LINEAR)
+ if norm == Colormap.LINEAR:
+ scale = (posMax - posMin) / 256
+ self._lutItem.setColormap(normColormap)
+ self._lutItem.setOrigin((posMin, -0.09))
+ self._lutItem.setScale((scale, 0.08))
+ self._lutItem.setVisible(True)
+ self._lutItem2.setVisible(False)
+ elif norm == Colormap.LOGARITHM:
+ self._lutItem2.setVisible(False)
+ self._lutItem2.setColormap(normColormap)
+ xx = numpy.geomspace(posMin, posMax, 256)
+ self._lutItem2.setData(x=xx,
+ y=self.__lutY,
+ value=self.__lutV,
+ copy=False)
+ self._lutItem2.setSymbol("|")
+ self._lutItem2.setVisible(True)
+ self._lutItem.setVisible(False)
+ else:
+ # Fallback: Display with linear axis and applied normalization
+ self._lutItem2.setVisible(False)
+ normColormap.setNormalization(norm)
+ self._lutItem2.setColormap(normColormap)
+ xx = numpy.linspace(posMin, posMax, 256, endpoint=True)
+ self._lutItem2.setData(
+ x=xx,
+ y=self.__lutY,
+ value=self.__lutV,
+ copy=False)
+ self._lutItem2.setSymbol("|")
+ self._lutItem2.setVisible(True)
+ self._lutItem.setVisible(False)
+
+ self._bound.setBounds((posMin, posMax, -0.1, 1))
+
+ def _plotMinMarkerConstraint(self, x, y):
+ """Constraint of the min marker"""
+ _vmin, vmax = self.getFiniteRange()
+ if vmax is None:
+ return x, y
+ return min(x, vmax), y
+
+ def _plotMaxMarkerConstraint(self, x, y):
+ """Constraint of the max marker"""
+ vmin, _vmax = self.getFiniteRange()
+ if vmin is None:
+ return x, y
+ return max(x, vmin), y
+
+ def _setDataInPlotMode(self, mode):
+ if self._dataInPlotMode == mode:
+ return
+ self._dataInPlotMode = mode
+ self._updateDataInPlot()
+
+ def _displayDataInPlotModeChanged(self, action):
+ mode = action.data()
+ self._setDataInPlotMode(mode)
+
+ def invalidateData(self):
+ self._histogramData = {}
+ self._dataRange = {}
+ self._invalidated = True
+ self.update()
+
+ def _updateDataInPlot(self):
+ mode = self._dataInPlotMode
+
+ norm = self._getNorm()
+ if norm == Colormap.LINEAR:
+ scale = Axis.LINEAR
+ elif norm == Colormap.LOGARITHM:
+ scale = Axis.LOGARITHMIC
+ else:
+ scale = Axis.LINEAR
+
+ axis = self._plot.getXAxis()
+ axis.setScale(scale)
+
+ if mode == _DataInPlotMode.RANGE:
+ dataRange = self._getNormalizedDataRange()
+ xmin, xmax = dataRange
+ if xmax is None or xmin is None:
+ self._plot.remove(legend='Data', kind='histogram')
+ else:
+ histogram = numpy.array([1])
+ bin_edges = numpy.array([xmin, xmax])
+ self._plot.addHistogram(histogram,
+ bin_edges,
+ legend="Data",
+ color='gray',
+ align='center',
+ fill=True,
+ z=1)
+
+ elif mode == _DataInPlotMode.HISTOGRAM:
+ histogram, bin_edges = self._getNormalizedHistogram()
+ if histogram is None or bin_edges is None:
+ self._plot.remove(legend='Data', kind='histogram')
+ else:
+ histogram = numpy.array(histogram, copy=True)
+ bin_edges = numpy.array(bin_edges, copy=True)
+ with numpy.errstate(invalid='ignore'):
+ norm_histogram = histogram / numpy.nanmax(histogram)
+ self._plot.addHistogram(norm_histogram,
+ bin_edges,
+ legend="Data",
+ color='gray',
+ align='center',
+ fill=True,
+ z=1)
+ else:
+ _logger.error("Mode unsupported")
+
+ def sizeHint(self):
+ return self.layout().minimumSize()
+
+ def updateLut(self):
+ self._updateLutItem(None)
+
+ def _getNorm(self):
+ colormap = self.getColormap()
+ if colormap is None:
+ return Axis.LINEAR
+ else:
+ norm = colormap.getNormalization()
+ return norm
+
+ def updateNormalization(self):
+ self._updateDataInPlot()
+ self.update()
+
+
+class ColormapDialog(qt.QDialog):
+ """A QDialog widget to set the colormap.
+
+ :param parent: See :class:`QDialog`
+ :param str title: The QDialog title
+ """
+
+ visibleChanged = qt.Signal(bool)
+ """This event is sent when the dialog visibility change"""
+
+ def __init__(self, parent=None, title="Colormap Dialog"):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle(title)
+
+ self.__aboutToDelete = False
+ self._colormap = None
+
+ self._data = None
+ """Weak ref to an external numpy array
+ """
+ self._itemHolder = None
+ """Hard ref to a private item (used as holder to the data)
+ This allow to reuse the item cache
+ """
+ self._item = None
+ """Weak ref to an external item"""
+
+ self._colormapChange = utils.LockReentrant()
+ """Used as a semaphore to avoid editing the colormap object when we are
+ only attempt to display it.
+ Used instead of n connect and disconnect of the sigChanged. The
+ disconnection to sigChanged was also limiting when this colormapdialog
+ is used in the colormapaction and associated to the activeImageChanged.
+ (because the activeImageChanged is send when the colormap changed and
+ the self.setcolormap is a callback)
+ """
+
+ self.__colormapInvalidated = False
+ self.__dataInvalidated = False
+
+ self._histogramData = None
+
+ self._dataRange = None
+ """If defined 3-tuple containing information from a data:
+ minimum, positive minimum, maximum"""
+
+ self._colormapStoredState = None
+
+ # Colormap row
+ self._comboBoxColormap = ColormapNameComboBox(parent=self)
+ self._comboBoxColormap.currentIndexChanged[int].connect(self._comboBoxColormapUpdated)
+
+ # Normalization row
+ self._comboBoxNormalization = qt.QComboBox(parent=self)
+ normalizations = [
+ ('Linear', Colormap.LINEAR),
+ ('Gamma correction', Colormap.GAMMA),
+ ('Arcsinh', Colormap.ARCSINH),
+ ('Logarithmic', Colormap.LOGARITHM),
+ ('Square root', Colormap.SQRT)]
+ for name, userData in normalizations:
+ try:
+ icon = icons.getQIcon("colormap-norm-%s" % userData)
+ except:
+ icon = qt.QIcon()
+ self._comboBoxNormalization.addItem(icon, name, userData)
+ self._comboBoxNormalization.currentIndexChanged[int].connect(
+ self._normalizationUpdated)
+
+ self._gammaSpinBox = qt.QDoubleSpinBox(parent=self)
+ self._gammaSpinBox.setEnabled(False)
+ self._gammaSpinBox.setRange(0., 1000.)
+ self._gammaSpinBox.setDecimals(4)
+ if hasattr(qt.QDoubleSpinBox, "setStepType"):
+ # Introduced in Qt 5.12
+ self._gammaSpinBox.setStepType(qt.QDoubleSpinBox.AdaptiveDecimalStepType)
+ else:
+ self._gammaSpinBox.setSingleStep(0.1)
+ self._gammaSpinBox.valueChanged.connect(self._gammaUpdated)
+ self._gammaSpinBox.setValue(2.)
+
+ autoScaleCombo = _AutoscaleModeComboBox(self)
+ autoScaleCombo.currentIndexChanged.connect(self._autoscaleModeUpdated)
+ self._autoScaleCombo = autoScaleCombo
+
+ # Min row
+ self._minValue = _BoundaryWidget(parent=self, value=1.0)
+ self._minValue.sigAutoScaleChanged.connect(self._minAutoscaleUpdated)
+ self._minValue.sigValueChanged.connect(self._minValueUpdated)
+
+ # Max row
+ self._maxValue = _BoundaryWidget(parent=self, value=10.0)
+ self._maxValue.sigAutoScaleChanged.connect(self._maxAutoscaleUpdated)
+ self._maxValue.sigValueChanged.connect(self._maxValueUpdated)
+
+ self._autoButtons = _AutoScaleButtons(self)
+ self._autoButtons.autoRangeChanged.connect(self._autoRangeButtonsUpdated)
+
+ rangeLayout = qt.QGridLayout()
+ miniFont = qt.QFont(self.font())
+ miniFont.setPixelSize(8)
+ labelMin = qt.QLabel("Min", self)
+ labelMin.setFont(miniFont)
+ labelMin.setAlignment(qt.Qt.AlignHCenter)
+ labelMax = qt.QLabel("Max", self)
+ labelMax.setAlignment(qt.Qt.AlignHCenter)
+ labelMax.setFont(miniFont)
+ rangeLayout.addWidget(labelMin, 0, 0)
+ rangeLayout.addWidget(labelMax, 0, 1)
+ rangeLayout.addWidget(self._minValue, 1, 0)
+ rangeLayout.addWidget(self._maxValue, 1, 1)
+ rangeLayout.addWidget(self._autoButtons, 2, 0, 1, -1, qt.Qt.AlignCenter)
+
+ self._histoWidget = _ColormapHistogram(self)
+ self._histoWidget.sigRangeMoving.connect(self._histogramRangeMoving)
+ self._histoWidget.sigRangeMoved.connect(self._histogramRangeMoved)
+
+ # Scale to buttons
+ self._visibleAreaButton = qt.QPushButton(self)
+ self._visibleAreaButton.setEnabled(False)
+ self._visibleAreaButton.setText("Visible Area")
+ self._visibleAreaButton.clicked.connect(
+ self._handleScaleToVisibleAreaClicked,
+ type=qt.Qt.QueuedConnection)
+
+ # Place-holder for selected area ROI manager
+ self._roiForColormapManager = None
+
+ self._selectedAreaButton = WaitingPushButton(self)
+ self._selectedAreaButton.setEnabled(False)
+ self._selectedAreaButton.setText("Selection")
+ self._selectedAreaButton.setIcon(icons.getQIcon("add-shape-rectangle"))
+ self._selectedAreaButton.setCheckable(True)
+ self._selectedAreaButton.setDisabledWhenWaiting(False)
+ self._selectedAreaButton.toggled.connect(
+ self._handleScaleToSelectionToggled,
+ type=qt.Qt.QueuedConnection)
+
+ # define modal buttons
+ types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
+ self._buttonsModal = qt.QDialogButtonBox(parent=self)
+ self._buttonsModal.setStandardButtons(types)
+ self._buttonsModal.accepted.connect(self.accept)
+ self._buttonsModal.rejected.connect(self.reject)
+
+ # define non modal buttons
+ types = qt.QDialogButtonBox.Close | qt.QDialogButtonBox.Reset
+ self._buttonsNonModal = qt.QDialogButtonBox(parent=self)
+ self._buttonsNonModal.setStandardButtons(types)
+ button = self._buttonsNonModal.button(qt.QDialogButtonBox.Close)
+ button.clicked.connect(self.accept)
+ button.setDefault(True)
+ button = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
+ button.clicked.connect(self.resetColormap)
+
+ self._buttonsModal.setFocus(qt.Qt.OtherFocusReason)
+ self._buttonsNonModal.setFocus(qt.Qt.OtherFocusReason)
+
+ # Set the colormap to default values
+ self.setColormap(Colormap(name='gray', normalization='linear',
+ vmin=None, vmax=None))
+
+ self.setModal(self.isModal())
+
+ formLayout = qt.QFormLayout(self)
+ formLayout.setContentsMargins(10, 10, 10, 10)
+ formLayout.addRow('Colormap:', self._comboBoxColormap)
+ formLayout.addRow('Normalization:', self._comboBoxNormalization)
+ formLayout.addRow('Gamma:', self._gammaSpinBox)
+ formLayout.addRow(self._histoWidget)
+ formLayout.addRow(rangeLayout)
+ label = qt.QLabel('Mode:', self)
+ self._autoscaleModeLabel = label
+ label.setToolTip("Mode for autoscale. Algorithm used to find range in auto scale.")
+ formLayout.addItem(qt.QSpacerItem(1, 1, qt.QSizePolicy.Fixed, qt.QSizePolicy.Fixed))
+ formLayout.addRow(label, autoScaleCombo)
+
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._visibleAreaButton)
+ layout.addWidget(self._selectedAreaButton)
+ self._scaleToAreaGroup = qt.QGroupBox('Scale to:', self)
+ self._scaleToAreaGroup.setLayout(layout)
+ self._scaleToAreaGroup.setVisible(False)
+ formLayout.addRow(self._scaleToAreaGroup)
+
+ formLayout.addRow(self._buttonsModal)
+ formLayout.addRow(self._buttonsNonModal)
+ formLayout.setSizeConstraint(qt.QLayout.SetMinimumSize)
+
+ self.setTabOrder(self._comboBoxColormap, self._comboBoxNormalization)
+ self.setTabOrder(self._comboBoxNormalization, self._gammaSpinBox)
+ self.setTabOrder(self._gammaSpinBox, self._minValue)
+ self.setTabOrder(self._minValue, self._maxValue)
+ self.setTabOrder(self._maxValue, self._autoButtons)
+ self.setTabOrder(self._autoButtons, self._autoScaleCombo)
+ self.setTabOrder(self._autoScaleCombo, self._visibleAreaButton)
+ self.setTabOrder(self._visibleAreaButton, self._selectedAreaButton)
+ self.setTabOrder(self._selectedAreaButton, self._buttonsModal)
+ self.setTabOrder(self._buttonsModal, self._buttonsNonModal)
+
+ self.setFixedSize(self.sizeHint())
+ self._applyColormap()
+
+ def _invalidateColormap(self):
+ if self.isVisible():
+ self._applyColormap()
+ else:
+ self.__colormapInvalidated = True
+
+ def _invalidateData(self):
+ if self.isVisible():
+ self._updateWidgetRange()
+ self._histoWidget.invalidateData()
+ else:
+ self.__dataInvalidated = True
+
+ def _validate(self):
+ if self.__colormapInvalidated:
+ self._applyColormap()
+ if self.__dataInvalidated:
+ self._histoWidget.invalidateData()
+ if self.__dataInvalidated or self.__colormapInvalidated:
+ self._updateWidgetRange()
+ self.__dataInvalidated = False
+ self.__colormapInvalidated = False
+
+ def showEvent(self, event):
+ self.visibleChanged.emit(True)
+ super(ColormapDialog, self).showEvent(event)
+ if self.isVisible():
+ self._validate()
+
+ def closeEvent(self, event):
+ if not self.isModal():
+ self.accept()
+ super(ColormapDialog, self).closeEvent(event)
+
+ def hideEvent(self, event):
+ self.visibleChanged.emit(False)
+ super(ColormapDialog, self).hideEvent(event)
+
+ def close(self):
+ self.accept()
+ qt.QDialog.close(self)
+
+ def setModal(self, modal):
+ assert type(modal) is bool
+ self._buttonsNonModal.setVisible(not modal)
+ self._buttonsModal.setVisible(modal)
+ qt.QDialog.setModal(self, modal)
+
+ def event(self, event):
+ if event.type() == qt.QEvent.DeferredDelete:
+ self.__aboutToDelete = True
+ return super(ColormapDialog, self).event(event)
+
+ def exec(self):
+ wasModal = self.isModal()
+ self.setModal(True)
+ result = super(ColormapDialog, self).exec()
+ if not self.__aboutToDelete:
+ self.setModal(wasModal)
+ return result
+
+ def exec_(self): # Qt5 compatibility wrapper
+ return self.exec()
+
+ def _getFiniteColormapRange(self):
+ """Return a colormap range where auto ranges are fixed
+ according to the available data.
+ """
+ colormap = self.getColormap()
+ if colormap is None:
+ return 1, 10
+
+ item = self._getItem()
+ if item is not None:
+ return colormap.getColormapRange(item)
+ # If there is not item, there is no data
+ return colormap.getColormapRange(None)
+
+ @staticmethod
+ def computeDataRange(data):
+ """Compute the data range as used by :meth:`setDataRange`.
+
+ :param data: The data to process
+ :rtype: List[Union[None,float]]
+ """
+ if data is None or len(data) == 0:
+ return None, None, None
+
+ dataRange = min_max(data, min_positive=True, finite=True)
+ if dataRange.minimum is None:
+ # Only non-finite data
+ dataRange = None
+
+ if dataRange is not None:
+ dataRange = dataRange.minimum, dataRange.min_positive, dataRange.maximum
+
+ if dataRange is None or len(dataRange) != 3:
+ qt.QMessageBox.warning(
+ None, "No Data",
+ "Image data does not contain any real value")
+ dataRange = 1., 1., 10.
+
+ return dataRange
+
+ @staticmethod
+ def computeHistogram(data, scale=Axis.LINEAR, dataRange=None):
+ """Compute the data histogram as used by :meth:`setHistogram`.
+
+ :param data: The data to process
+ :param dataRange: Optional range to compute the histogram, which is a
+ tuple of min, max
+ :rtype: Tuple(List(float),List(float)
+ """
+ # For compatibility
+ if scale == Axis.LOGARITHMIC:
+ scale = Colormap.LOGARITHM
+
+ if data is None:
+ return None, None
+
+ if len(data) == 0:
+ return None, None
+
+ if data.ndim == 3: # RGB(A) images
+ _logger.info('Converting current image from RGB(A) to grayscale\
+ in order to compute the intensity distribution')
+ data = (data[:,:, 0] * 0.299 +
+ data[:,:, 1] * 0.587 +
+ data[:,:, 2] * 0.114)
+
+ # bad hack: get 256 continuous bins in the case we have a B&W
+ normalizeData = True
+ if numpy.issubdtype(data.dtype, numpy.ubyte):
+ normalizeData = False
+ elif numpy.issubdtype(data.dtype, numpy.integer):
+ if dataRange is not None:
+ xmin, xmax = dataRange
+ if xmin is not None and xmax is not None:
+ normalizeData = (xmax - xmin) > 255
+
+ if normalizeData:
+ if scale == Colormap.LOGARITHM:
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ data = numpy.log10(data)
+
+ if dataRange is not None:
+ xmin, xmax = dataRange
+ if xmin is None:
+ return None, None
+ if normalizeData:
+ if scale == Colormap.LOGARITHM:
+ xmin, xmax = numpy.log10(xmin), numpy.log10(xmax)
+ else:
+ xmin, xmax = min_max(data, min_positive=False, finite=True)
+
+ if xmin is None:
+ return None, None
+
+ nbins = min(256, int(numpy.sqrt(data.size)))
+ data_range = xmin, xmax
+
+ # bad hack: get 256 bins in the case we have a B&W
+ if numpy.issubdtype(data.dtype, numpy.integer):
+ if nbins > xmax - xmin:
+ nbins = int(xmax - xmin)
+
+ nbins = max(2, nbins)
+ data = data.ravel().astype(numpy.float32)
+
+ histogram = Histogramnd(data, n_bins=nbins, histo_range=data_range)
+ bins = histogram.edges[0]
+ if normalizeData:
+ if scale == Colormap.LOGARITHM:
+ bins = 10 ** bins
+ return histogram.histo, bins
+
+ def _getItem(self):
+ if self._itemHolder is not None:
+ return self._itemHolder
+ if self._item is None:
+ return None
+ return self._item()
+
+ def setItem(self, item):
+ """Store the plot item.
+
+ According to the state of the dialog, the item will be used to display
+ the data range or the histogram of the data using :meth:`setDataRange`
+ and :meth:`setHistogram`
+ """
+ # While event from items are not supported, we can't ignore dup items
+ # old = self._getItem()
+ # if old is item:
+ # return
+ self._data = None
+ self._itemHolder = None
+ try:
+ if item is None:
+ self._item = None
+ else:
+ if not isinstance(item, items.ColormapMixIn):
+ self._item = None
+ raise ValueError("Item %s is not supported" % item)
+ self._item = weakref.ref(item, self._itemAboutToFinalize)
+ finally:
+ self._syncScaleToButtonsEnabled()
+ self._dataRange = None
+ self._histogramData = None
+ self._invalidateData()
+
+ def _getData(self):
+ if self._data is None:
+ return None
+ return self._data()
+
+ def setData(self, data):
+ """Store the data
+
+ According to the state of the dialog, the data will be used to display
+ the data range or the histogram of the data using :meth:`setDataRange`
+ and :meth:`setHistogram`
+ """
+ oldData = self._getData()
+ if oldData is data:
+ return
+
+ self._item = None
+ self._syncScaleToButtonsEnabled()
+ if data is None:
+ self._data = None
+ self._itemHolder = None
+ else:
+ self._data = weakref.ref(data, self._dataAboutToFinalize)
+ self._itemHolder = _DataRefHolder(self._data)
+
+ self._dataRange = None
+ self._histogramData = None
+
+ self._invalidateData()
+
+ def _getArray(self):
+ data = self._getData()
+ if data is not None:
+ return data
+ item = self._getItem()
+ if item is not None:
+ return item.getColormappedData(copy=False)
+ return None
+
+ def _colormapAboutToFinalize(self, weakrefColormap):
+ """Callback when the data weakref is about to be finalized."""
+ if self._colormap is weakrefColormap and qtinspect.isValid(self):
+ self.setColormap(None)
+
+ def _dataAboutToFinalize(self, weakrefData):
+ """Callback when the data weakref is about to be finalized."""
+ if self._data is weakrefData and qtinspect.isValid(self):
+ self.setData(None)
+
+ def _itemAboutToFinalize(self, weakref):
+ """Callback when the data weakref is about to be finalized."""
+ if self._item is weakref and qtinspect.isValid(self):
+ self.setItem(None)
+
+ @deprecation.deprecated(reason="It is private data", since_version="0.13")
+ def getHistogram(self):
+ histo = self._getHistogram()
+ if histo is None:
+ return None
+ counts, bin_edges = histo
+ return numpy.array(counts, copy=True), numpy.array(bin_edges, copy=True)
+
+ def _getHistogram(self):
+ """Returns the histogram defined by the dialog as metadata
+ to describe the data in order to speed up the dialog.
+
+ :return: (hist, bin_edges)
+ :rtype: 2-tuple of numpy arrays"""
+ return self._histogramData
+
+ def setHistogram(self, hist=None, bin_edges=None):
+ """Set the histogram to display.
+
+ This update the data range with the bounds of the bins.
+
+ :param hist: array-like of counts or None to hide histogram
+ :param bin_edges: array-like of bins edges or None to hide histogram
+ """
+ if hist is None or bin_edges is None:
+ self._histogramData = None
+ else:
+ self._histogramData = numpy.array(hist), numpy.array(bin_edges)
+
+ self._invalidateData()
+
+ def getColormap(self):
+ """Return the colormap description.
+
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ if self._colormap is None:
+ return None
+ return self._colormap()
+
+ def resetColormap(self):
+ """
+ Reset the colormap state before modification.
+
+ ..note :: the colormap reference state is the state when set or the
+ state when validated
+ """
+ colormap = self.getColormap()
+ if colormap is not None and self._colormapStoredState is not None:
+ if colormap != self._colormapStoredState:
+ with self._colormapChange:
+ colormap.setFromColormap(self._colormapStoredState)
+ self._applyColormap()
+
+ def _getDataRange(self):
+ """Returns the data range defined by the dialog as metadata
+ to describe the data in order to speed up the dialog.
+
+ :return: (minimum, positiveMin, maximum)
+ :rtype: 3-tuple of floats or None"""
+ return self._dataRange
+
+ def setDataRange(self, minimum=None, positiveMin=None, maximum=None):
+ """Set the range of data to use for the range of the histogram area.
+
+ :param float minimum: The minimum of the data
+ :param float positiveMin: The positive minimum of the data
+ :param float maximum: The maximum of the data
+ """
+ self._dataRange = minimum, positiveMin, maximum
+ self._invalidateData()
+
+ def _setColormapRange(self, xmin, xmax):
+ """Set a new range to the held colormap and update the
+ widget."""
+ colormap = self.getColormap()
+ if colormap is not None:
+ with self._colormapChange:
+ colormap.setVRange(xmin, xmax)
+ self._updateWidgetRange()
+
+ def setColormapRangeFromDataBounds(self, bounds):
+ """Set the range of the colormap from current item and rect.
+
+ If there is no ColormapMixIn item attached to the ColormapDialog,
+ nothing is done.
+
+ :param Union[List[float],None] bounds:
+ (xmin, xmax, ymin, ymax) Rectangular region in data space
+ """
+ if bounds is None:
+ return None # no-op
+
+ colormap = self.getColormap()
+ if colormap is None:
+ return # no-op
+
+ item = self._getItem()
+ if not isinstance(item, items.ColormapMixIn):
+ return None # no-op
+
+ data = item.getColormappedData(copy=False)
+
+ xmin, xmax, ymin, ymax = bounds
+
+ if isinstance(item, items.ImageBase):
+ ox, oy = item.getOrigin()
+ sx, sy = item.getScale()
+
+ ystart = max(0, int((ymin - oy) / sy))
+ ystop = max(0, int(numpy.ceil((ymax - oy) / sy)))
+ xstart = max(0, int((xmin - ox) / sx))
+ xstop = max(0, int(numpy.ceil((xmax - ox) / sx)))
+
+ subset = data[ystart:ystop, xstart:xstop]
+
+ elif isinstance(item, items.Scatter):
+ x = item.getXData(copy=False)
+ y = item.getYData(copy=False)
+ subset = data[
+ numpy.logical_and(
+ numpy.logical_and(xmin <= x, x <= xmax),
+ numpy.logical_and(ymin <= y, y <= ymax))]
+
+ if subset.size == 0:
+ return # no-op
+
+ vmin, vmax = colormap._computeAutoscaleRange(subset)
+ self._setColormapRange(vmin, vmax)
+
+ def _updateWidgetRange(self):
+ """Update the colormap range displayed into the widget."""
+ xmin, xmax = self._getFiniteColormapRange()
+ colormap = self.getColormap()
+ if colormap is not None:
+ vRange = colormap.getVRange()
+ autoMin, autoMax = (r is None for r in vRange)
+ else:
+ autoMin, autoMax = False, False
+
+ with utils.blockSignals(self._minValue):
+ self._minValue.setValue(xmin, autoMin)
+ with utils.blockSignals(self._maxValue):
+ self._maxValue.setValue(xmax, autoMax)
+ with utils.blockSignals(self._histoWidget):
+ self._histoWidget.setFiniteRange((xmin, xmax))
+ with utils.blockSignals(self._autoButtons):
+ self._autoButtons.setAutoRange((autoMin, autoMax))
+ self._autoscaleModeLabel.setEnabled(autoMin or autoMax)
+
+ def accept(self):
+ self.storeCurrentState()
+ qt.QDialog.accept(self)
+
+ def storeCurrentState(self):
+ """
+ save the current value sof the colormap if the user want to undo is
+ modifications
+ """
+ colormap = self.getColormap()
+ if colormap is not None:
+ self._colormapStoredState = colormap.copy()
+ else:
+ self._colormapStoredState = None
+
+ def reject(self):
+ self.resetColormap()
+ qt.QDialog.reject(self)
+
+ def setColormap(self, colormap):
+ """Set the colormap description
+
+ :param ~silx.gui.colors.Colormap colormap: the colormap to edit
+ """
+ assert colormap is None or isinstance(colormap, Colormap)
+ if self._colormapChange.locked():
+ return
+
+ oldColormap = self.getColormap()
+ if oldColormap is colormap:
+ return
+ if oldColormap is not None:
+ oldColormap.sigChanged.disconnect(self._applyColormap)
+
+ if colormap is not None:
+ colormap.sigChanged.connect(self._applyColormap)
+ colormap = weakref.ref(colormap, self._colormapAboutToFinalize)
+
+ self._colormap = colormap
+ self.storeCurrentState()
+ self._invalidateColormap()
+
+ def _updateResetButton(self):
+ resetButton = self._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
+ rStateEnabled = False
+ colormap = self.getColormap()
+ if colormap is not None and colormap.isEditable():
+ # can reset only in the case the colormap changed
+ rStateEnabled = colormap != self._colormapStoredState
+ resetButton.setEnabled(rStateEnabled)
+
+ def _applyColormap(self):
+ self._updateResetButton()
+ if self._colormapChange.locked():
+ return
+
+ self._syncScaleToButtonsEnabled()
+
+ colormap = self.getColormap()
+ if colormap is None:
+ self._comboBoxColormap.setEnabled(False)
+ self._comboBoxNormalization.setEnabled(False)
+ self._gammaSpinBox.setEnabled(False)
+ self._autoScaleCombo.setEnabled(False)
+ self._minValue.setEnabled(False)
+ self._maxValue.setEnabled(False)
+ self._autoButtons.setEnabled(False)
+ self._autoscaleModeLabel.setEnabled(False)
+ self._histoWidget.setVisible(False)
+ self._histoWidget.setFiniteRange((None, None))
+ else:
+ assert colormap.getNormalization() in Colormap.NORMALIZATIONS
+ with utils.blockSignals(self._comboBoxColormap):
+ self._comboBoxColormap.setCurrentLut(colormap)
+ self._comboBoxColormap.setEnabled(colormap.isEditable())
+ with utils.blockSignals(self._comboBoxNormalization):
+ index = self._comboBoxNormalization.findData(
+ colormap.getNormalization())
+ if index < 0:
+ _logger.error('Unsupported normalization: %s' %
+ colormap.getNormalization())
+ else:
+ self._comboBoxNormalization.setCurrentIndex(index)
+ self._comboBoxNormalization.setEnabled(colormap.isEditable())
+ with utils.blockSignals(self._gammaSpinBox):
+ self._gammaSpinBox.setValue(
+ colormap.getGammaNormalizationParameter())
+ self._gammaSpinBox.setEnabled(
+ colormap.getNormalization() == 'gamma' and
+ colormap.isEditable())
+ with utils.blockSignals(self._autoScaleCombo):
+ self._autoScaleCombo.setCurrentMode(colormap.getAutoscaleMode())
+ self._autoScaleCombo.setEnabled(colormap.isEditable())
+ with utils.blockSignals(self._autoButtons):
+ self._autoButtons.setEnabled(colormap.isEditable())
+ self._autoButtons.setAutoRangeFromColormap(colormap)
+
+ vmin, vmax = colormap.getVRange()
+ if vmin is None or vmax is None:
+ # Compute it only if needed
+ dataRange = self._getFiniteColormapRange()
+ else:
+ dataRange = vmin, vmax
+
+ with utils.blockSignals(self._minValue):
+ self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None)
+ self._minValue.setEnabled(colormap.isEditable())
+ with utils.blockSignals(self._maxValue):
+ self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None)
+ self._maxValue.setEnabled(colormap.isEditable())
+ self._autoscaleModeLabel.setEnabled(vmin is None or vmax is None)
+
+ with utils.blockSignals(self._histoWidget):
+ self._histoWidget.setVisible(True)
+ self._histoWidget.setFiniteRange(dataRange)
+ self._histoWidget.updateNormalization()
+
+ def _comboBoxColormapUpdated(self):
+ """Callback executed when the combo box with the colormap LUT
+ is updated by user input.
+ """
+ colormap = self.getColormap()
+ if colormap is not None:
+ with self._colormapChange:
+ name = self._comboBoxColormap.getCurrentName()
+ if name is not None:
+ colormap.setName(name)
+ else:
+ lut = self._comboBoxColormap.getCurrentColors()
+ colormap.setColormapLUT(lut)
+ self._histoWidget.updateLut()
+
+ def _autoRangeButtonsUpdated(self, autoRange):
+ """Callback executed when the autoscale buttons widget
+ is updated by user input.
+ """
+ dataRange = self._getFiniteColormapRange()
+
+ # Final colormap range
+ vmin = (dataRange[0] if not autoRange[0] else None)
+ vmax = (dataRange[1] if not autoRange[1] else None)
+
+ with self._colormapChange:
+ colormap = self.getColormap()
+ colormap.setVRange(vmin, vmax)
+
+ with utils.blockSignals(self._minValue):
+ self._minValue.setValue(vmin or dataRange[0], isAuto=vmin is None)
+ with utils.blockSignals(self._maxValue):
+ self._maxValue.setValue(vmax or dataRange[1], isAuto=vmax is None)
+
+ self._updateWidgetRange()
+
+ def _normalizationUpdated(self, index):
+ """Callback executed when the normalization widget
+ is updated by user input.
+ """
+ colormap = self.getColormap()
+ if colormap is not None:
+ normalization = self._comboBoxNormalization.itemData(index)
+ self._gammaSpinBox.setEnabled(normalization == 'gamma')
+
+ with self._colormapChange:
+ colormap.setNormalization(normalization)
+ self._histoWidget.updateNormalization()
+
+ self._updateWidgetRange()
+
+ def _gammaUpdated(self, value):
+ """Callback used to update the gamma normalization parameter"""
+ colormap = self.getColormap()
+ if colormap is not None:
+ colormap.setGammaNormalizationParameter(value)
+
+ def _autoscaleModeUpdated(self):
+ """Callback executed when the autoscale mode widget
+ is updated by user input.
+ """
+ mode = self._autoScaleCombo.currentMode()
+
+ colormap = self.getColormap()
+ if colormap is not None:
+ with self._colormapChange:
+ colormap.setAutoscaleMode(mode)
+
+ self._updateWidgetRange()
+
+ def _minAutoscaleUpdated(self, autoEnabled):
+ """Callback executed when the min autoscale from
+ the lineedit is updated by user input"""
+ colormap = self.getColormap()
+ xmin, xmax = colormap.getVRange()
+ if autoEnabled:
+ xmin = None
+ else:
+ xmin, _xmax = self._getFiniteColormapRange()
+ self._setColormapRange(xmin, xmax)
+
+ def _maxAutoscaleUpdated(self, autoEnabled):
+ """Callback executed when the max autoscale from
+ the lineedit is updated by user input"""
+ colormap = self.getColormap()
+ xmin, xmax = colormap.getVRange()
+ if autoEnabled:
+ xmax = None
+ else:
+ _xmin, xmax = self._getFiniteColormapRange()
+ self._setColormapRange(xmin, xmax)
+
+ def _minValueUpdated(self, value):
+ """Callback executed when the lineedit min value is
+ updated by user input"""
+ xmin = value
+ xmax = self._maxValue.getValue()
+ if xmax is not None and xmin > xmax:
+ # FIXME: This should be done in the widget itself
+ xmin = xmax
+ with utils.blockSignals(self._minValue):
+ self._minValue.setValue(xmin)
+ self._setColormapRange(xmin, xmax)
+
+ def _maxValueUpdated(self, value):
+ """Callback executed when the lineedit max value is
+ updated by user input"""
+ xmin = self._minValue.getValue()
+ xmax = value
+ if xmin is not None and xmin > xmax:
+ # FIXME: This should be done in the widget itself
+ xmax = xmin
+ with utils.blockSignals(self._maxValue):
+ self._maxValue.setValue(xmax)
+ self._setColormapRange(xmin, xmax)
+
+ def _histogramRangeMoving(self, vmin, vmax):
+ """Callback executed when for colormap range displayed in
+ the histogram widget is moving.
+
+ :param vmin: Update of the minimum range, else None
+ :param vmax: Update of the maximum range, else None
+ """
+ colormap = self.getColormap()
+ if vmin is not None:
+ with self._colormapChange:
+ colormap.setVMin(vmin)
+ self._minValue.setValue(vmin)
+ if vmax is not None:
+ with self._colormapChange:
+ colormap.setVMax(vmax)
+ self._maxValue.setValue(vmax)
+
+ def _histogramRangeMoved(self, vmin, vmax):
+ """Callback executed when for colormap range displayed in
+ the histogram widget has finished to move
+ """
+ xmin = self._minValue.getValue()
+ xmax = self._maxValue.getValue()
+ if vmin is None:
+ vmin = xmin
+ if vmax is None:
+ vmax = xmax
+ self._setColormapRange(vmin, vmax)
+
+ def _syncScaleToButtonsEnabled(self):
+ """Set the state of scale to buttons according to current item and colormap"""
+ colormap = self.getColormap()
+ enabled = self._item is not None and colormap is not None and colormap.isEditable()
+ self._scaleToAreaGroup.setVisible(enabled)
+ self._visibleAreaButton.setEnabled(enabled)
+ if not enabled:
+ self._selectedAreaButton.setChecked(False)
+ self._selectedAreaButton.setEnabled(enabled)
+
+ def _handleScaleToVisibleAreaClicked(self):
+ """Set colormap range from current item's visible area"""
+ item = self._getItem()
+ if item is None:
+ return # no-op
+
+ bounds = item.getVisibleBounds()
+ if bounds is None:
+ return # no-op
+
+ self.setColormapRangeFromDataBounds(bounds)
+
+ def _handleScaleToSelectionToggled(self, checked=False):
+ """Handle toggle of scale to selected are button"""
+ # Reset any previous ROI manager
+ if self._roiForColormapManager is not None:
+ self._roiForColormapManager.clear()
+ self._roiForColormapManager.stop()
+ self._roiForColormapManager = None
+
+ if not checked: # Reset button status
+ self._selectedAreaButton.setWaiting(False)
+ self._selectedAreaButton.setText("Selection")
+ return
+
+ item = self._getItem()
+ if item is None:
+ self._selectedAreaButton.setChecked(False)
+ return # no-op
+
+ plotWidget = item.getPlot()
+ if plotWidget is None:
+ self._selectedAreaButton.setChecked(False)
+ return # no-op
+
+ self._selectedAreaButton.setWaiting(True)
+ self._selectedAreaButton.setText("Draw Area...")
+
+ self._roiForColormapManager = RegionOfInterestManager(parent=plotWidget)
+ cmap = self.getColormap()
+ self._roiForColormapManager.setColor(
+ 'black' if cmap is None else cursorColorForColormap(cmap.getName()))
+ self._roiForColormapManager.sigInteractiveModeFinished.connect(
+ self.__roiInteractiveModeFinished)
+ self._roiForColormapManager.sigInteractiveRoiFinalized.connect(self.__roiFinalized)
+ self._roiForColormapManager.start(RectangleROI)
+
+ def __roiInteractiveModeFinished(self):
+ self._selectedAreaButton.setChecked(False)
+
+ def __roiFinalized(self, roi):
+ self._selectedAreaButton.setChecked(False)
+ if roi is not None:
+ ox, oy = roi.getOrigin()
+ width, height = roi.getSize()
+ self.setColormapRangeFromDataBounds((ox, ox+width, oy, oy+height))
+
+ def keyPressEvent(self, event):
+ """Override key handling.
+
+ It disables leaving the dialog when editing a text field.
+
+ But several press of Return key can be use to validate and close the
+ dialog.
+ """
+ if event.key() in (qt.Qt.Key_Enter, qt.Qt.Key_Return):
+ # Bypass QDialog keyPressEvent
+ # To avoid leaving the dialog when pressing enter on a text field
+ if self._minValue.hasFocus():
+ nextFocus = self._maxValue
+ elif self._maxValue.hasFocus():
+ if self.isModal():
+ nextFocus = self._buttonsModal.button(qt.QDialogButtonBox.Apply)
+ else:
+ nextFocus = self._buttonsNonModal.button(qt.QDialogButtonBox.Close)
+ else:
+ nextFocus = None
+ if nextFocus is not None:
+ nextFocus.setFocus(qt.Qt.OtherFocusReason)
+ else:
+ super(ColormapDialog, self).keyPressEvent(event)
diff --git a/src/silx/gui/dialog/DataFileDialog.py b/src/silx/gui/dialog/DataFileDialog.py
new file mode 100644
index 0000000..0d0382d
--- /dev/null
+++ b/src/silx/gui/dialog/DataFileDialog.py
@@ -0,0 +1,340 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains an :class:`DataFileDialog`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "14/02/2018"
+
+import enum
+import logging
+from silx.gui import qt
+from silx.gui.hdf5.Hdf5Formatter import Hdf5Formatter
+import silx.io
+from .AbstractDataFileDialog import AbstractDataFileDialog
+
+import fabio
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _DataPreview(qt.QWidget):
+ """Provide a preview of the selected image"""
+
+ def __init__(self, parent=None):
+ super(_DataPreview, self).__init__(parent)
+
+ self.__formatter = Hdf5Formatter(self)
+ self.__data = None
+ self.__info = qt.QTableView(self)
+ self.__model = qt.QStandardItemModel(self)
+ self.__info.setModel(self.__model)
+ self.__info.horizontalHeader().hide()
+ self.__info.horizontalHeader().setStretchLastSection(True)
+ layout = qt.QVBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self.__info)
+ self.setLayout(layout)
+
+ def colormap(self):
+ return None
+
+ def setColormap(self, colormap):
+ # Ignored
+ pass
+
+ def sizeHint(self):
+ return qt.QSize(200, 200)
+
+ def setData(self, data, fromDataSelector=False):
+ self.__info.setEnabled(data is not None)
+ if data is None:
+ self.__model.clear()
+ else:
+ self.__model.clear()
+
+ if silx.io.is_dataset(data):
+ kind = "Dataset"
+ elif silx.io.is_group(data):
+ kind = "Group"
+ elif silx.io.is_file(data):
+ kind = "File"
+ else:
+ kind = "Unknown"
+
+ headers = []
+
+ basename = data.name.split("/")[-1]
+ if basename == "":
+ basename = "/"
+ headers.append("Basename")
+ self.__model.appendRow([qt.QStandardItem(basename)])
+ headers.append("Kind")
+ self.__model.appendRow([qt.QStandardItem(kind)])
+ if hasattr(data, "dtype"):
+ headers.append("Type")
+ text = self.__formatter.humanReadableType(data)
+ self.__model.appendRow([qt.QStandardItem(text)])
+ if hasattr(data, "shape"):
+ headers.append("Shape")
+ text = self.__formatter.humanReadableShape(data)
+ self.__model.appendRow([qt.QStandardItem(text)])
+ if hasattr(data, "attrs") and "NX_class" in data.attrs:
+ headers.append("NX_class")
+ value = data.attrs["NX_class"]
+ formatter = self.__formatter.textFormatter()
+ old = formatter.useQuoteForText()
+ formatter.setUseQuoteForText(False)
+ text = self.__formatter.textFormatter().toString(value)
+ formatter.setUseQuoteForText(old)
+ self.__model.appendRow([qt.QStandardItem(text)])
+ self.__model.setVerticalHeaderLabels(headers)
+ self.__data = data
+
+ def __imageItem(self):
+ image = self.__plot.getImage("data")
+ return image
+
+ def data(self):
+ if self.__data is not None:
+ if hasattr(self.__data, "name"):
+ # in case of HDF5
+ if self.__data.name is None:
+ # The dataset was closed
+ self.__data = None
+ return self.__data
+
+ def clear(self):
+ self.__data = None
+ self.__info.setText("")
+
+
+class DataFileDialog(AbstractDataFileDialog):
+ """The `DataFileDialog` class provides a dialog that allow users to select
+ any datasets or groups from an HDF5-like file.
+
+ The `DataFileDialog` class enables a user to traverse the file system in
+ order to select an HDF5-like file. Then to traverse the file to select an
+ HDF5 node.
+
+ .. image:: img/datafiledialog.png
+
+ The selected data is any kind of group or dataset. It can be restricted
+ to only existing datasets or only existing groups using
+ :meth:`setFilterMode`. A callback can be defining using
+ :meth:`setFilterCallback` to filter even more data which can be returned.
+
+ Filtering data which can be returned by a `DataFileDialog` can be done like
+ that:
+
+ .. code-block:: python
+
+ # Force to return only a dataset
+ dialog = DataFileDialog()
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingDataset)
+
+ .. code-block:: python
+
+ def customFilter(obj):
+ if "NX_class" in obj.attrs:
+ return obj.attrs["NX_class"] in [b"NXentry", u"NXentry"]
+ return False
+
+ # Force to return an NX entry
+ dialog = DataFileDialog()
+ # 1st, filter out everything which is not a group
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingGroup)
+ # 2nd, check what NX_class is an NXentry
+ dialog.setFilterCallback(customFilter)
+
+ Executing a `DataFileDialog` can be done like that:
+
+ .. code-block:: python
+
+ dialog = DataFileDialog()
+ result = dialog.exec()
+ if result:
+ print("Selection:")
+ print(dialog.selectedFile())
+ print(dialog.selectedUrl())
+ else:
+ print("Nothing selected")
+
+ If the selection is a dataset you can access to the data using
+ :meth:`selectedData`.
+
+ If the selection is a group or if you want to read the selected object on
+ your own you can use the `silx.io` API.
+
+ .. code-block:: python
+
+ url = dialog.selectedUrl()
+ with silx.io.open(url) as data:
+ pass
+
+ Or by loading the file first
+
+ .. code-block:: python
+
+ url = dialog.selectedDataUrl()
+ with silx.io.open(url.file_path()) as h5:
+ data = h5[url.data_path()]
+
+ Or by using `h5py` library
+
+ .. code-block:: python
+
+ url = dialog.selectedDataUrl()
+ with h5py.File(url.file_path(), mode="r") as h5:
+ data = h5[url.data_path()]
+ """
+
+ class FilterMode(enum.Enum):
+ """This enum is used to indicate what the user may select in the
+ dialog; i.e. what the dialog will return if the user clicks OK."""
+
+ AnyNode = 0
+ """Any existing node from an HDF5-like file."""
+ ExistingDataset = 1
+ """An existing HDF5-like dataset."""
+ ExistingGroup = 2
+ """An existing HDF5-like group. A file root is a group."""
+
+ def __init__(self, parent=None):
+ AbstractDataFileDialog.__init__(self, parent=parent)
+ self.__filter = DataFileDialog.FilterMode.AnyNode
+ self.__filterCallback = None
+
+ def selectedData(self):
+ """Returns the selected data by using the :meth:`silx.io.get_data`
+ API with the selected URL provided by the dialog.
+
+ If the URL identify a group of a file it will raise an exception. For
+ group or file you have to use on your own the API :meth:`silx.io.open`.
+
+ :rtype: numpy.ndarray
+ :raise ValueError: If the URL do not link to a dataset
+ """
+ url = self.selectedUrl()
+ return silx.io.get_data(url)
+
+ def _createPreviewWidget(self, parent):
+ previewWidget = _DataPreview(parent)
+ previewWidget.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ return previewWidget
+
+ def _createSelectorWidget(self, parent):
+ # There is no selector
+ return None
+
+ def _createPreviewToolbar(self, parent, dataPreviewWidget, dataSelectorWidget):
+ # There is no toolbar
+ return None
+
+ def _isDataSupportable(self, data):
+ """Check if the selected data can be supported at one point.
+
+ If true, the data selector will be checked and it will update the data
+ preview. Else the selecting is disabled.
+
+ :rtype: bool
+ """
+ # Everything is supported
+ return True
+
+ def _isFabioFilesSupported(self):
+ # Everything is supported
+ return False
+
+ def _isDataSupported(self, data):
+ """Check if the data can be returned by the dialog.
+
+ If true, this data can be returned by the dialog and the open button
+ will be enabled. If false the button will be disabled.
+
+ :rtype: bool
+ """
+ if self.__filter == DataFileDialog.FilterMode.AnyNode:
+ accepted = True
+ elif self.__filter == DataFileDialog.FilterMode.ExistingDataset:
+ accepted = silx.io.is_dataset(data)
+ elif self.__filter == DataFileDialog.FilterMode.ExistingGroup:
+ accepted = silx.io.is_group(data)
+ else:
+ raise ValueError("Filter %s is not supported" % self.__filter)
+ if not accepted:
+ return False
+ if self.__filterCallback is not None:
+ try:
+ return self.__filterCallback(data)
+ except Exception:
+ _logger.error("Error while executing custom callback", exc_info=True)
+ return False
+ return True
+
+ def setFilterCallback(self, callback):
+ """Set the filter callback. This filter is applied only if the filter
+ mode (:meth:`filterMode`) first accepts the selected data.
+
+ It is not supposed to be set while the dialog is being used.
+
+ :param callable callback: Define a custom function returning a boolean
+ and taking as argument an h5-like node. If the function returns true
+ the dialog can return the associated URL.
+ """
+ self.__filterCallback = callback
+
+ def setFilterMode(self, mode):
+ """Set the filter mode.
+
+ It is not supposed to be set while the dialog is being used.
+
+ :param DataFileDialog.FilterMode mode: The new filter.
+ """
+ self.__filter = mode
+
+ def fileMode(self):
+ """Returns the filter mode.
+
+ :rtype: DataFileDialog.FilterMode
+ """
+ return self.__filter
+
+ def _displayedDataInfo(self, dataBeforeSelection, dataAfterSelection):
+ """Returns the text displayed under the data preview.
+
+ This zone is used to display error in case or problem of data selection
+ or problems with IO.
+
+ :param numpy.ndarray dataAfterSelection: Data as it is after the
+ selection widget (basically the data from the preview widget)
+ :param numpy.ndarray dataAfterSelection: Data as it is before the
+ selection widget (basically the data from the browsing widget)
+ :rtype: bool
+ """
+ return u""
diff --git a/src/silx/gui/dialog/DatasetDialog.py b/src/silx/gui/dialog/DatasetDialog.py
new file mode 100644
index 0000000..c5ee295
--- /dev/null
+++ b/src/silx/gui/dialog/DatasetDialog.py
@@ -0,0 +1,122 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a dialog widget to select a HDF5 dataset in a
+tree.
+
+.. autoclass:: DatasetDialog
+ :members: addFile, addGroup, getSelectedDataUrl, setMode
+
+"""
+from .GroupDialog import _Hdf5ItemSelectionDialog
+import silx.io
+from silx.io.url import DataUrl
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/09/2018"
+
+
+class DatasetDialog(_Hdf5ItemSelectionDialog):
+ """This :class:`QDialog` uses a :class:`silx.gui.hdf5.Hdf5TreeView` to
+ provide a HDF5 dataset selection dialog.
+
+ The information identifying the selected node is provided as a
+ :class:`silx.io.url.DataUrl`.
+
+ Example:
+
+ .. code-block:: python
+
+ dialog = DatasetDialog()
+ dialog.addFile(filepath1)
+ dialog.addFile(filepath2)
+
+ if dialog.exec():
+ print("File path: %s" % dialog.getSelectedDataUrl().file_path())
+ print("HDF5 dataset path : %s " % dialog.getSelectedDataUrl().data_path())
+ else:
+ print("Operation cancelled :(")
+
+ """
+ def __init__(self, parent=None):
+ _Hdf5ItemSelectionDialog.__init__(self, parent)
+
+ # customization for groups
+ self.setWindowTitle("HDF5 dataset selection")
+
+ self._header.setSections([self._model.NAME_COLUMN,
+ self._model.NODE_COLUMN,
+ self._model.LINK_COLUMN,
+ self._model.TYPE_COLUMN,
+ self._model.SHAPE_COLUMN])
+ self._selectDatasetStatusText = "Select a dataset or type a new dataset name"
+
+ def setMode(self, mode):
+ """Set dialog mode DatasetDialog.SaveMode or DatasetDialog.LoadMode
+
+ :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
+ """
+ _Hdf5ItemSelectionDialog.setMode(self, mode)
+ if mode == DatasetDialog.SaveMode:
+ self._selectDatasetStatusText = "Select a dataset or type a new dataset name"
+ elif mode == DatasetDialog.LoadMode:
+ self._selectDatasetStatusText = "Select a dataset"
+
+ def _onActivation(self, idx):
+ # double-click or enter press: filter for datasets
+ nodes = list(self._tree.selectedH5Nodes())
+ node = nodes[0]
+ if silx.io.is_dataset(node.h5py_object):
+ self.accept()
+
+ def _updateUrl(self):
+ # overloaded to filter for datasets
+ nodes = list(self._tree.selectedH5Nodes())
+ newDatasetName = self._lineEditNewItem.text()
+ isDatasetSelected = False
+ if nodes:
+ node = nodes[0]
+ if silx.io.is_dataset(node.h5py_object):
+ data_path = node.local_name
+ isDatasetSelected = True
+ elif silx.io.is_group(node.h5py_object):
+ data_path = node.local_name
+ if newDatasetName.lstrip("/"):
+ if not data_path.endswith("/"):
+ data_path += "/"
+ data_path += newDatasetName.lstrip("/")
+ isDatasetSelected = True
+
+ if isDatasetSelected:
+ self._selectedUrl = DataUrl(file_path=node.local_filename,
+ data_path=data_path)
+ self._okButton.setEnabled(True)
+ self._labelSelection.setText(
+ self._selectedUrl.path())
+ else:
+ self._selectedUrl = None
+ self._okButton.setEnabled(False)
+ self._labelSelection.setText(self._selectDatasetStatusText)
diff --git a/silx/gui/dialog/FileTypeComboBox.py b/src/silx/gui/dialog/FileTypeComboBox.py
index 92529bc..92529bc 100644
--- a/silx/gui/dialog/FileTypeComboBox.py
+++ b/src/silx/gui/dialog/FileTypeComboBox.py
diff --git a/src/silx/gui/dialog/GroupDialog.py b/src/silx/gui/dialog/GroupDialog.py
new file mode 100644
index 0000000..e129a51
--- /dev/null
+++ b/src/silx/gui/dialog/GroupDialog.py
@@ -0,0 +1,230 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a dialog widget to select a HDF5 group in a
+tree.
+
+.. autoclass:: GroupDialog
+ :members: addFile, addGroup, getSelectedDataUrl, setMode
+
+"""
+from silx.gui import qt
+from silx.gui.hdf5.Hdf5TreeView import Hdf5TreeView
+import silx.io
+from silx.io.url import DataUrl
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "22/03/2018"
+
+
+class _Hdf5ItemSelectionDialog(qt.QDialog):
+ SaveMode = 1
+ """Mode used to set the HDF5 item selection dialog to *save* mode.
+ This adds a text field to type in a new item name."""
+
+ LoadMode = 2
+ """Mode used to set the HDF5 item selection dialog to *load* mode.
+ Only existing items of the HDF5 file can be selected in this mode."""
+
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle("HDF5 item selection")
+
+ self._tree = Hdf5TreeView(self)
+ self._tree.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ self._tree.activated.connect(self._onActivation)
+ self._tree.selectionModel().selectionChanged.connect(
+ self._onSelectionChange)
+
+ self._model = self._tree.findHdf5TreeModel()
+
+ self._header = self._tree.header()
+
+ self._newItemWidget = qt.QWidget(self)
+ newItemLayout = qt.QVBoxLayout(self._newItemWidget)
+ self._labelNewItem = qt.QLabel(self._newItemWidget)
+ self._labelNewItem.setText("Create new item in selected group (optional):")
+ self._lineEditNewItem = qt.QLineEdit(self._newItemWidget)
+ self._lineEditNewItem.setToolTip(
+ "Specify the name of a new item "
+ "to be created in the selected group.")
+ self._lineEditNewItem.textChanged.connect(
+ self._onNewItemNameChange)
+ newItemLayout.addWidget(self._labelNewItem)
+ newItemLayout.addWidget(self._lineEditNewItem)
+
+ _labelSelectionTitle = qt.QLabel(self)
+ _labelSelectionTitle.setText("Current selection")
+ self._labelSelection = qt.QLabel(self)
+ self._labelSelection.setStyleSheet("color: gray")
+ self._labelSelection.setWordWrap(True)
+ self._labelSelection.setText("Select an item")
+
+ buttonBox = qt.QDialogButtonBox()
+ self._okButton = buttonBox.addButton(qt.QDialogButtonBox.Ok)
+ self._okButton.setEnabled(False)
+ buttonBox.addButton(qt.QDialogButtonBox.Cancel)
+
+ buttonBox.accepted.connect(self.accept)
+ buttonBox.rejected.connect(self.reject)
+
+ vlayout = qt.QVBoxLayout(self)
+ vlayout.addWidget(self._tree)
+ vlayout.addWidget(self._newItemWidget)
+ vlayout.addWidget(_labelSelectionTitle)
+ vlayout.addWidget(self._labelSelection)
+ vlayout.addWidget(buttonBox)
+ self.setLayout(vlayout)
+
+ self.setMinimumWidth(400)
+
+ self._selectedUrl = None
+
+ def _onSelectionChange(self, old, new):
+ self._updateUrl()
+
+ def _onNewItemNameChange(self, text):
+ self._updateUrl()
+
+ def _onActivation(self, idx):
+ # double-click or enter press
+ self.accept()
+
+ def setMode(self, mode):
+ """Set dialog mode DatasetDialog.SaveMode or DatasetDialog.LoadMode
+
+ :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
+ """
+ if mode == self.LoadMode:
+ # hide "Create new item" field
+ self._lineEditNewItem.clear()
+ self._newItemWidget.hide()
+ elif mode == self.SaveMode:
+ self._newItemWidget.show()
+ else:
+ raise ValueError("Invalid DatasetDialog mode %s" % mode)
+
+ def addFile(self, path):
+ """Add a HDF5 file to the tree.
+ All groups it contains will be selectable in the dialog.
+
+ :param str path: File path
+ """
+ self._model.insertFile(path)
+
+ def addGroup(self, group):
+ """Add a HDF5 group to the tree. This group and all its subgroups
+ will be selectable in the dialog.
+
+ :param h5py.Group group: HDF5 group
+ """
+ self._model.insertH5pyObject(group)
+
+ def _updateUrl(self):
+ nodes = list(self._tree.selectedH5Nodes())
+ subgroupName = self._lineEditNewItem.text()
+ if nodes:
+ node = nodes[0]
+ data_path = node.local_name
+ if subgroupName.lstrip("/"):
+ if not data_path.endswith("/"):
+ data_path += "/"
+ data_path += subgroupName.lstrip("/")
+ self._selectedUrl = DataUrl(file_path=node.local_filename,
+ data_path=data_path)
+ self._okButton.setEnabled(True)
+ self._labelSelection.setText(
+ self._selectedUrl.path())
+
+ def getSelectedDataUrl(self):
+ """Return a :class:`DataUrl` with a file path and a data path.
+ Return None if the dialog was cancelled.
+
+ :return: :class:`silx.io.url.DataUrl` object pointing to the
+ selected HDF5 item.
+ """
+ return self._selectedUrl
+
+
+class GroupDialog(_Hdf5ItemSelectionDialog):
+ """This :class:`QDialog` uses a :class:`silx.gui.hdf5.Hdf5TreeView` to
+ provide a HDF5 group selection dialog.
+
+ The information identifying the selected node is provided as a
+ :class:`silx.io.url.DataUrl`.
+
+ Example:
+
+ .. code-block:: python
+
+ dialog = GroupDialog()
+ dialog.addFile(filepath1)
+ dialog.addFile(filepath2)
+
+ if dialog.exec():
+ print("File path: %s" % dialog.getSelectedDataUrl().file_path())
+ print("HDF5 group path : %s " % dialog.getSelectedDataUrl().data_path())
+ else:
+ print("Operation cancelled :(")
+
+ """
+ def __init__(self, parent=None):
+ _Hdf5ItemSelectionDialog.__init__(self, parent)
+
+ # customization for groups
+ self.setWindowTitle("HDF5 group selection")
+
+ self._header.setSections([self._model.NAME_COLUMN,
+ self._model.NODE_COLUMN,
+ self._model.LINK_COLUMN])
+
+ def _onActivation(self, idx):
+ # double-click or enter press: filter for groups
+ nodes = list(self._tree.selectedH5Nodes())
+ node = nodes[0]
+ if silx.io.is_group(node.h5py_object):
+ self.accept()
+
+ def _updateUrl(self):
+ # overloaded to filter for groups
+ nodes = list(self._tree.selectedH5Nodes())
+ subgroupName = self._lineEditNewItem.text()
+ if nodes:
+ node = nodes[0]
+ if silx.io.is_group(node.h5py_object):
+ data_path = node.local_name
+ if subgroupName.lstrip("/"):
+ if not data_path.endswith("/"):
+ data_path += "/"
+ data_path += subgroupName.lstrip("/")
+ self._selectedUrl = DataUrl(file_path=node.local_filename,
+ data_path=data_path)
+ self._okButton.setEnabled(True)
+ self._labelSelection.setText(
+ self._selectedUrl.path())
+ else:
+ self._selectedUrl = None
+ self._okButton.setEnabled(False)
+ self._labelSelection.setText("Select a group")
diff --git a/src/silx/gui/dialog/ImageFileDialog.py b/src/silx/gui/dialog/ImageFileDialog.py
new file mode 100644
index 0000000..83c6d95
--- /dev/null
+++ b/src/silx/gui/dialog/ImageFileDialog.py
@@ -0,0 +1,354 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains an :class:`ImageFileDialog`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "05/03/2019"
+
+import logging
+from silx.gui.plot import actions
+from silx.gui import qt
+from silx.gui.plot.PlotWidget import PlotWidget
+from .AbstractDataFileDialog import AbstractDataFileDialog
+import silx.io
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _ImageSelection(qt.QWidget):
+ """Provide a widget allowing to select an image from an hypercube by
+ selecting a slice."""
+
+ selectionChanged = qt.Signal()
+ """Emitted when the selection change."""
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+ self.__shape = None
+ self.__axis = []
+ layout = qt.QVBoxLayout()
+ self.setLayout(layout)
+
+ def hasVisibleSelectors(self):
+ return self.__visibleSliders > 0
+
+ def isUsed(self):
+ if self.__shape is None:
+ return False
+ return len(self.__shape) > 2
+
+ def getSelectedData(self, data):
+ slicing = self.slicing()
+ image = data[slicing]
+ return image
+
+ def setData(self, data):
+ if data is None:
+ self.__visibleSliders = 0
+ return
+
+ shape = data.shape
+ if self.__shape is not None:
+ # clean up
+ for widget in self.__axis:
+ self.layout().removeWidget(widget)
+ widget.deleteLater()
+ self.__axis = []
+
+ self.__shape = shape
+ self.__visibleSliders = 0
+
+ if shape is not None:
+ # create expected axes
+ for index in range(len(shape) - 2):
+ axis = qt.QSlider(self)
+ axis.setMinimum(0)
+ axis.setMaximum(shape[index] - 1)
+ axis.setOrientation(qt.Qt.Horizontal)
+ if shape[index] == 1:
+ axis.setVisible(False)
+ else:
+ self.__visibleSliders += 1
+
+ axis.valueChanged.connect(self.__axisValueChanged)
+ self.layout().addWidget(axis)
+ self.__axis.append(axis)
+
+ self.selectionChanged.emit()
+
+ def __axisValueChanged(self):
+ self.selectionChanged.emit()
+
+ def slicing(self):
+ slicing = []
+ for axes in self.__axis:
+ slicing.append(axes.value())
+ return tuple(slicing)
+
+ def setSlicing(self, slicing):
+ for i, value in enumerate(slicing):
+ if i > len(self.__axis):
+ break
+ self.__axis[i].setValue(value)
+
+ def selectSlicing(self, slicing):
+ """Select a slicing.
+
+ The provided value could be unconsistent and therefore is not supposed
+ to be retrivable with a getter.
+
+ :param Union[None,Tuple[int]] slicing:
+ """
+ if slicing is None:
+ # Create a default slicing
+ needed = self.__visibleSliders
+ slicing = (0,) * needed
+ if len(slicing) < self.__visibleSliders:
+ slicing = slicing + (0,) * (self.__visibleSliders - len(slicing))
+ self.setSlicing(slicing)
+
+
+class _ImagePreview(qt.QWidget):
+ """Provide a preview of the selected image"""
+
+ def __init__(self, parent=None):
+ super(_ImagePreview, self).__init__(parent)
+
+ self.__data = None
+ self.__plot = PlotWidget(self)
+ self.__plot.setAxesDisplayed(False)
+ self.__plot.setKeepDataAspectRatio(True)
+ layout = qt.QVBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self.__plot)
+ self.setLayout(layout)
+
+ def resizeEvent(self, event):
+ self.__updateConstraints()
+ return qt.QWidget.resizeEvent(self, event)
+
+ def sizeHint(self):
+ return qt.QSize(200, 200)
+
+ def plot(self):
+ return self.__plot
+
+ def setData(self, data, fromDataSelector=False):
+ if data is None:
+ self.clear()
+ return
+
+ resetzoom = not fromDataSelector
+ previousImage = self.data()
+ if previousImage is not None and data.shape != previousImage.shape:
+ resetzoom = True
+
+ self.__plot.addImage(legend="data", data=data, resetzoom=resetzoom)
+ self.__data = data
+ self.__updateConstraints()
+
+ def __updateConstraints(self):
+ """
+ Update the constraints depending on the size of the widget
+ """
+ image = self.data()
+ if image is None:
+ return
+ size = self.size()
+ if size.width() == 0 or size.height() == 0:
+ return
+
+ heightData, widthData = image.shape
+
+ widthContraint = heightData * size.width() / size.height()
+ if widthContraint > widthData:
+ heightContraint = heightData
+ else:
+ heightContraint = heightData * size.height() / size.width()
+ widthContraint = widthData
+
+ midWidth, midHeight = widthData * 0.5, heightData * 0.5
+ heightContraint, widthContraint = heightContraint * 0.5, widthContraint * 0.5
+
+ axis = self.__plot.getXAxis()
+ axis.setLimitsConstraints(midWidth - widthContraint, midWidth + widthContraint)
+ axis = self.__plot.getYAxis()
+ axis.setLimitsConstraints(midHeight - heightContraint, midHeight + heightContraint)
+
+ def __imageItem(self):
+ image = self.__plot.getImage("data")
+ return image
+
+ def data(self):
+ if self.__data is not None:
+ if hasattr(self.__data, "name"):
+ # in case of HDF5
+ if self.__data.name is None:
+ # The dataset was closed
+ self.__data = None
+ return self.__data
+
+ def colormap(self):
+ image = self.__imageItem()
+ if image is not None:
+ return image.getColormap()
+ return self.__plot.getDefaultColormap()
+
+ def setColormap(self, colormap):
+ self.__plot.setDefaultColormap(colormap)
+
+ def clear(self):
+ self.__data = None
+ image = self.__imageItem()
+ if image is not None:
+ self.__plot.removeImage(legend="data")
+
+
+class ImageFileDialog(AbstractDataFileDialog):
+ """The `ImageFileDialog` class provides a dialog that allow users to select
+ an image from a file.
+
+ The `ImageFileDialog` class enables a user to traverse the file system in
+ order to select one file. Then to traverse the file to select a frame or
+ a slice of a dataset.
+
+ .. image:: img/imagefiledialog_h5.png
+
+ It supports fast access to image files using `FabIO`. Which is not the case
+ of the default silx API. Image files still also can be available using the
+ NeXus layout, by editing the file type combo box.
+
+ .. image:: img/imagefiledialog_edf.png
+
+ The selected data is an numpy array with 2 dimension.
+
+ Using an `ImageFileDialog` can be done like that.
+
+ .. code-block:: python
+
+ dialog = ImageFileDialog()
+ result = dialog.exec()
+ if result:
+ print("Selection:")
+ print(dialog.selectedFile())
+ print(dialog.selectedUrl())
+ print(dialog.selectedImage())
+ else:
+ print("Nothing selected")
+ """
+
+ def selectedImage(self):
+ """Returns the selected image data as numpy
+
+ :rtype: numpy.ndarray
+ """
+ url = self.selectedUrl()
+ return silx.io.get_data(url)
+
+ def _createPreviewWidget(self, parent):
+ previewWidget = _ImagePreview(parent)
+ previewWidget.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ return previewWidget
+
+ def _createSelectorWidget(self, parent):
+ return _ImageSelection(parent)
+
+ def _createPreviewToolbar(self, parent, dataPreviewWidget, dataSelectorWidget):
+ plot = dataPreviewWidget.plot()
+ toolbar = qt.QToolBar(parent)
+ toolbar.setIconSize(qt.QSize(16, 16))
+ toolbar.setStyleSheet("QToolBar { border: 0px }")
+ toolbar.addAction(actions.mode.ZoomModeAction(plot, parent))
+ toolbar.addAction(actions.mode.PanModeAction(plot, parent))
+ toolbar.addSeparator()
+ toolbar.addAction(actions.control.ResetZoomAction(plot, parent))
+ toolbar.addSeparator()
+ toolbar.addAction(actions.control.ColormapAction(plot, parent))
+ return toolbar
+
+ def _isDataSupportable(self, data):
+ """Check if the selected data can be supported at one point.
+
+ If true, the data selector will be checked and it will update the data
+ preview. Else the selecting is disabled.
+
+ :rtype: bool
+ """
+ if not hasattr(data, "dtype"):
+ # It is not an HDF5 dataset nor a fabio image wrapper
+ return False
+
+ if data is None or data.shape is None:
+ return False
+
+ if data.dtype.kind not in set(["f", "u", "i", "b"]):
+ return False
+
+ dim = len(data.shape)
+ return dim >= 2
+
+ def _isFabioFilesSupported(self):
+ return True
+
+ def _isDataSupported(self, data):
+ """Check if the data can be returned by the dialog.
+
+ If true, this data can be returned by the dialog and the open button
+ while be enabled. If false the button will be disabled.
+
+ :rtype: bool
+ """
+ dim = len(data.shape)
+ return dim == 2
+
+ def _displayedDataInfo(self, dataBeforeSelection, dataAfterSelection):
+ """Returns the text displayed under the data preview.
+
+ This zone is used to display error in case or problem of data selection
+ or problems with IO.
+
+ :param numpy.ndarray dataAfterSelection: Data as it is after the
+ selection widget (basically the data from the preview widget)
+ :param numpy.ndarray dataAfterSelection: Data as it is before the
+ selection widget (basically the data from the browsing widget)
+ :rtype: bool
+ """
+ destination = self.__formatShape(dataAfterSelection.shape)
+ source = self.__formatShape(dataBeforeSelection.shape)
+ return u"%s \u2192 %s" % (source, destination)
+
+ def __formatShape(self, shape):
+ result = []
+ for s in shape:
+ if isinstance(s, slice):
+ v = u"\u2026"
+ else:
+ v = str(s)
+ result.append(v)
+ return u" \u00D7 ".join(result)
diff --git a/silx/gui/dialog/SafeFileIconProvider.py b/src/silx/gui/dialog/SafeFileIconProvider.py
index 1e06b64..1e06b64 100644
--- a/silx/gui/dialog/SafeFileIconProvider.py
+++ b/src/silx/gui/dialog/SafeFileIconProvider.py
diff --git a/src/silx/gui/dialog/SafeFileSystemModel.py b/src/silx/gui/dialog/SafeFileSystemModel.py
new file mode 100644
index 0000000..1ec7153
--- /dev/null
+++ b/src/silx/gui/dialog/SafeFileSystemModel.py
@@ -0,0 +1,802 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains an :class:`SafeFileSystemModel`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "22/11/2017"
+
+import sys
+import os.path
+import logging
+import weakref
+
+from silx.gui import qt
+from .SafeFileIconProvider import SafeFileIconProvider
+
+_logger = logging.getLogger(__name__)
+
+
+class _Item(object):
+
+ def __init__(self, fileInfo):
+ self.__fileInfo = fileInfo
+ self.__parent = None
+ self.__children = None
+ self.__absolutePath = None
+
+ def isDrive(self):
+ if sys.platform == "win32":
+ return self.parent().parent() is None
+ else:
+ return False
+
+ def isRoot(self):
+ return self.parent() is None
+
+ def isFile(self):
+ """
+ Returns true if the path is a file.
+
+ It avoid to access to the `Qt.QFileInfo` in case the file is a drive.
+ """
+ if self.isDrive():
+ return False
+ return self.__fileInfo.isFile()
+
+ def isDir(self):
+ """
+ Returns true if the path is a directory.
+
+ The default `qt.QFileInfo.isDir` can freeze the file system with
+ network drives. This function avoid the freeze in case of browsing
+ the root.
+ """
+ if self.isDrive():
+ # A drive is a directory, we don't have to synchronize the
+ # drive to know that
+ return True
+ return self.__fileInfo.isDir()
+
+ def absoluteFilePath(self):
+ """
+ Returns an absolute path including the file name.
+
+ This function uses in most cases the default
+ `qt.QFileInfo.absoluteFilePath`. But it is known to freeze the file
+ system with network drives.
+
+ This function uses `qt.QFileInfo.filePath` in case of root drives, to
+ avoid this kind of issues. In case of drive, the result is the same,
+ while the file path is already absolute.
+
+ :rtype: str
+ """
+ if self.__absolutePath is None:
+ if self.isRoot():
+ path = ""
+ elif self.isDrive():
+ path = self.__fileInfo.filePath()
+ else:
+ path = os.path.join(self.parent().absoluteFilePath(), self.__fileInfo.fileName())
+ if path == "":
+ return "/"
+ self.__absolutePath = path
+ return self.__absolutePath
+
+ def child(self):
+ self.populate()
+ return self.__children
+
+ def childAt(self, position):
+ self.populate()
+ return self.__children[position]
+
+ def childCount(self):
+ self.populate()
+ return len(self.__children)
+
+ def indexOf(self, item):
+ self.populate()
+ return self.__children.index(item)
+
+ def parent(self):
+ parent = self.__parent
+ if parent is None:
+ return None
+ return parent()
+
+ def filePath(self):
+ return self.__fileInfo.filePath()
+
+ def fileName(self):
+ if self.isDrive():
+ name = self.absoluteFilePath()
+ if name[-1] == "/":
+ name = name[:-1]
+ return name
+ return os.path.basename(self.absoluteFilePath())
+
+ def fileInfo(self):
+ """
+ Returns the Qt file info.
+
+ :rtype: Qt.QFileInfo
+ """
+ return self.__fileInfo
+
+ def _setParent(self, parent):
+ self.__parent = weakref.ref(parent)
+
+ def findChildrenByPath(self, path):
+ if path == "":
+ return self
+ path = path.replace("\\", "/")
+ if path[-1] == "/":
+ path = path[:-1]
+ names = path.split("/")
+ caseSensitive = qt.QFSFileEngine(path).caseSensitive()
+ count = len(names)
+ cursor = self
+ for name in names:
+ for item in cursor.child():
+ if caseSensitive:
+ same = item.fileName() == name
+ else:
+ same = item.fileName().lower() == name.lower()
+ if same:
+ cursor = item
+ count -= 1
+ break
+ else:
+ return None
+ if count == 0:
+ break
+ else:
+ return None
+ return cursor
+
+ def populate(self):
+ if self.__children is not None:
+ return
+ self.__children = []
+ if self.isRoot():
+ items = qt.QDir.drives()
+ else:
+ directory = qt.QDir(self.absoluteFilePath())
+ filters = qt.QDir.AllEntries | qt.QDir.Hidden | qt.QDir.System
+ items = directory.entryInfoList(filters)
+ for fileInfo in items:
+ i = _Item(fileInfo)
+ self.__children.append(i)
+ i._setParent(self)
+
+
+class _RawFileSystemModel(qt.QAbstractItemModel):
+ """
+ This class implement a file system model and try to avoid freeze. On Qt4,
+ :class:`qt.QFileSystemModel` is known to freeze the file system when
+ network drives are available.
+
+ To avoid this behaviour, this class does not use
+ `qt.QFileInfo.absoluteFilePath` nor `qt.QFileInfo.canonicalPath` to reach
+ information on drives.
+
+ This model do not take care of sorting and filtering. This features are
+ managed by another model, by composition.
+
+ And because it is the end of life of Qt4, we do not implement asynchronous
+ loading of files as it is done by :class:`qt.QFileSystemModel`, nor some
+ useful features.
+ """
+
+ __directoryLoadedSync = qt.Signal(str)
+ """This signal is connected asynchronously to a slot. It allows to
+ emit directoryLoaded as an asynchronous signal."""
+
+ directoryLoaded = qt.Signal(str)
+ """This signal is emitted when the gatherer thread has finished to load the
+ path."""
+
+ rootPathChanged = qt.Signal(str)
+ """This signal is emitted whenever the root path has been changed to a
+ newPath."""
+
+ NAME_COLUMN = 0
+ SIZE_COLUMN = 1
+ TYPE_COLUMN = 2
+ LAST_MODIFIED_COLUMN = 3
+
+ def __init__(self, parent=None):
+ qt.QAbstractItemModel.__init__(self, parent)
+ self.__computer = _Item(qt.QFileInfo())
+ self.__header = "Name", "Size", "Type", "Last modification"
+ self.__currentPath = ""
+ self.__iconProvider = SafeFileIconProvider()
+ self.__directoryLoadedSync.connect(self.__emitDirectoryLoaded, qt.Qt.QueuedConnection)
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ if orientation == qt.Qt.Horizontal:
+ if role == qt.Qt.DisplayRole:
+ return self.__header[section]
+ if role == qt.Qt.TextAlignmentRole:
+ return qt.Qt.AlignRight if section == 1 else qt.Qt.AlignLeft
+ return None
+
+ def flags(self, index):
+ if not index.isValid():
+ return 0
+ return qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable
+
+ def columnCount(self, parent=qt.QModelIndex()):
+ return len(self.__header)
+
+ def rowCount(self, parent=qt.QModelIndex()):
+ item = self.__getItem(parent)
+ return item.childCount()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ if not index.isValid():
+ return None
+
+ column = index.column()
+ if role in [qt.Qt.DisplayRole, qt.Qt.EditRole]:
+ if column == self.NAME_COLUMN:
+ return self.__displayName(index)
+ elif column == self.SIZE_COLUMN:
+ return self.size(index)
+ elif column == self.TYPE_COLUMN:
+ return self.type(index)
+ elif column == self.LAST_MODIFIED_COLUMN:
+ return self.lastModified(index)
+ else:
+ _logger.warning("data: invalid display value column %d", index.column())
+ elif role == qt.QFileSystemModel.FilePathRole:
+ return self.filePath(index)
+ elif role == qt.QFileSystemModel.FileNameRole:
+ return self.fileName(index)
+ elif role == qt.Qt.DecorationRole:
+ if column == self.NAME_COLUMN:
+ icon = self.fileIcon(index)
+ if icon is None or icon.isNull():
+ if self.isDir(index):
+ self.__iconProvider.icon(qt.QFileIconProvider.Folder)
+ else:
+ self.__iconProvider.icon(qt.QFileIconProvider.File)
+ return icon
+ elif role == qt.Qt.TextAlignmentRole:
+ if column == self.SIZE_COLUMN:
+ return qt.Qt.AlignRight
+ elif role == qt.QFileSystemModel.FilePermissions:
+ return self.permissions(index)
+
+ return None
+
+ def index(self, *args, **kwargs):
+ path_api = False
+ path_api |= len(args) >= 1 and isinstance(args[0], str)
+ path_api |= "path" in kwargs
+
+ if path_api:
+ return self.__indexFromPath(*args, **kwargs)
+ else:
+ return self.__index(*args, **kwargs)
+
+ def __index(self, row, column, parent=qt.QModelIndex()):
+ if parent.isValid() and parent.column() != 0:
+ return None
+
+ parentItem = self.__getItem(parent)
+ item = parentItem.childAt(row)
+ return self.createIndex(row, column, item)
+
+ def __indexFromPath(self, path, column=0):
+ """
+ Uses the index(str) C++ API
+
+ :rtype: qt.QModelIndex
+ """
+ if path == "":
+ return qt.QModelIndex()
+
+ item = self.__computer.findChildrenByPath(path)
+ if item is None:
+ return qt.QModelIndex()
+
+ return self.createIndex(item.parent().indexOf(item), column, item)
+
+ def parent(self, index):
+ if not index.isValid():
+ return qt.QModelIndex()
+
+ item = self.__getItem(index)
+ if index is None:
+ return qt.QModelIndex()
+
+ parent = item.parent()
+ if parent is None or parent is self.__computer:
+ return qt.QModelIndex()
+
+ return self.createIndex(parent.parent().indexOf(parent), 0, parent)
+
+ def __emitDirectoryLoaded(self, path):
+ self.directoryLoaded.emit(path)
+
+ def __emitRootPathChanged(self, path):
+ self.rootPathChanged.emit(path)
+
+ def __getItem(self, index):
+ if not index.isValid():
+ return self.__computer
+ item = index.internalPointer()
+ return item
+
+ def fileIcon(self, index):
+ item = self.__getItem(index)
+ if self.__iconProvider is not None:
+ fileInfo = item.fileInfo()
+ result = self.__iconProvider.icon(fileInfo)
+ else:
+ style = qt.QApplication.instance().style()
+ if item.isRoot():
+ result = style.standardIcon(qt.QStyle.SP_ComputerIcon)
+ elif item.isDrive():
+ result = style.standardIcon(qt.QStyle.SP_DriveHDIcon)
+ elif item.isDir():
+ result = style.standardIcon(qt.QStyle.SP_DirIcon)
+ else:
+ result = style.standardIcon(qt.QStyle.SP_FileIcon)
+ return result
+
+ def _item(self, index):
+ item = self.__getItem(index)
+ return item
+
+ def fileInfo(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo()
+ return result
+
+ def __fileIcon(self, index):
+ item = self.__getItem(index)
+ result = item.fileName()
+ return result
+
+ def __displayName(self, index):
+ item = self.__getItem(index)
+ result = item.fileName()
+ return result
+
+ def fileName(self, index):
+ item = self.__getItem(index)
+ result = item.fileName()
+ return result
+
+ def filePath(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo().filePath()
+ return result
+
+ def isDir(self, index):
+ item = self.__getItem(index)
+ result = item.isDir()
+ return result
+
+ def lastModified(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo().lastModified()
+ return result
+
+ def permissions(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo().permissions()
+ return result
+
+ def size(self, index):
+ item = self.__getItem(index)
+ result = item.fileInfo().size()
+ return result
+
+ def type(self, index):
+ item = self.__getItem(index)
+ if self.__iconProvider is not None:
+ fileInfo = item.fileInfo()
+ result = self.__iconProvider.type(fileInfo)
+ else:
+ if item.isRoot():
+ result = "Computer"
+ elif item.isDrive():
+ result = "Drive"
+ elif item.isDir():
+ result = "Directory"
+ else:
+ fileInfo = item.fileInfo()
+ result = fileInfo.suffix()
+ return result
+
+ # File manipulation
+
+ # bool remove(const QModelIndex & index) const
+ # bool rmdir(const QModelIndex & index) const
+ # QModelIndex mkdir(const QModelIndex & parent, const QString & name)
+
+ # Configuration
+
+ def rootDirectory(self):
+ return qt.QDir(self.rootPath())
+
+ def rootPath(self):
+ return self.__currentPath
+
+ def setRootPath(self, path):
+ if self.__currentPath == path:
+ return
+ self.__currentPath = path
+ item = self.__computer.findChildrenByPath(path)
+ self.__emitRootPathChanged(path)
+ if item is None or item.parent() is None:
+ return qt.QModelIndex()
+ index = self.createIndex(item.parent().indexOf(item), 0, item)
+ self.__directoryLoadedSync.emit(path)
+ return index
+
+ def iconProvider(self):
+ # FIXME: invalidate the model
+ return self.__iconProvider
+
+ def setIconProvider(self, provider):
+ # FIXME: invalidate the model
+ self.__iconProvider = provider
+
+ # bool resolveSymlinks() const
+ # void setResolveSymlinks(bool enable)
+
+ def setNameFilterDisables(self, enable):
+ return None
+
+ def nameFilterDisables(self):
+ return None
+
+ def myComputer(self, role=qt.Qt.DisplayRole):
+ return None
+
+ def setNameFilters(self, filters):
+ return
+
+ def nameFilters(self):
+ return None
+
+ def filter(self):
+ return self.__filters
+
+ def setFilter(self, filters):
+ return
+
+ def setReadOnly(self, enable):
+ assert(enable is True)
+
+ def isReadOnly(self):
+ return False
+
+
+class SafeFileSystemModel(qt.QSortFilterProxyModel):
+ """
+ This class implement a file system model and try to avoid freeze. On Qt4,
+ :class:`qt.QFileSystemModel` is known to freeze the file system when
+ network drives are available.
+
+ To avoid this behaviour, this class does not use
+ `qt.QFileInfo.absoluteFilePath` nor `qt.QFileInfo.canonicalPath` to reach
+ information on drives.
+
+ And because it is the end of life of Qt4, we do not implement asynchronous
+ loading of files as it is done by :class:`qt.QFileSystemModel`, nor some
+ useful features.
+ """
+
+ def __init__(self, parent=None):
+ qt.QSortFilterProxyModel.__init__(self, parent=parent)
+ self.__nameFilterDisables = sys.platform == "darwin"
+ self.__nameFilters = []
+ self.__filters = qt.QDir.AllEntries | qt.QDir.NoDotAndDotDot | qt.QDir.AllDirs
+ sourceModel = _RawFileSystemModel(self)
+ self.setSourceModel(sourceModel)
+
+ @property
+ def directoryLoaded(self):
+ return self.sourceModel().directoryLoaded
+
+ @property
+ def rootPathChanged(self):
+ return self.sourceModel().rootPathChanged
+
+ def index(self, *args, **kwargs):
+ path_api = False
+ path_api |= len(args) >= 1 and isinstance(args[0], str)
+ path_api |= "path" in kwargs
+
+ if path_api:
+ return self.__indexFromPath(*args, **kwargs)
+ else:
+ return self.__index(*args, **kwargs)
+
+ def __index(self, row, column, parent=qt.QModelIndex()):
+ return qt.QSortFilterProxyModel.index(self, row, column, parent)
+
+ def __indexFromPath(self, path, column=0):
+ """
+ Uses the index(str) C++ API
+
+ :rtype: qt.QModelIndex
+ """
+ if path == "":
+ return qt.QModelIndex()
+
+ index = self.sourceModel().index(path, column)
+ index = self.mapFromSource(index)
+ return index
+
+ def lessThan(self, leftSourceIndex, rightSourceIndex):
+ sourceModel = self.sourceModel()
+ sortColumn = self.sortColumn()
+ if sortColumn == _RawFileSystemModel.NAME_COLUMN:
+ leftItem = sourceModel._item(leftSourceIndex)
+ rightItem = sourceModel._item(rightSourceIndex)
+ if sys.platform != "darwin":
+ # Sort directories before files
+ leftIsDir = leftItem.isDir()
+ rightIsDir = rightItem.isDir()
+ if leftIsDir ^ rightIsDir:
+ return leftIsDir
+ return leftItem.fileName().lower() < rightItem.fileName().lower()
+ elif sortColumn == _RawFileSystemModel.SIZE_COLUMN:
+ left = sourceModel.fileInfo(leftSourceIndex)
+ right = sourceModel.fileInfo(rightSourceIndex)
+ return left.size() < right.size()
+ elif sortColumn == _RawFileSystemModel.TYPE_COLUMN:
+ left = sourceModel.type(leftSourceIndex)
+ right = sourceModel.type(rightSourceIndex)
+ return left < right
+ elif sortColumn == _RawFileSystemModel.LAST_MODIFIED_COLUMN:
+ left = sourceModel.fileInfo(leftSourceIndex)
+ right = sourceModel.fileInfo(rightSourceIndex)
+ return left.lastModified() < right.lastModified()
+ else:
+ _logger.warning("Unsupported sorted column %d", sortColumn)
+
+ return False
+
+ def __filtersAccepted(self, item, filters):
+ """
+ Check individual flag filters.
+ """
+ if not (filters & (qt.QDir.Dirs | qt.QDir.AllDirs)):
+ # Hide dirs
+ if item.isDir():
+ return False
+ if not (filters & qt.QDir.Files):
+ # Hide files
+ if item.isFile():
+ return False
+ if not (filters & qt.QDir.Drives):
+ # Hide drives
+ if item.isDrive():
+ return False
+
+ fileInfo = item.fileInfo()
+ if fileInfo is None:
+ return False
+
+ filterPermissions = (filters & qt.QDir.PermissionMask) != 0
+ if filterPermissions and (filters & (qt.QDir.Dirs | qt.QDir.Files)):
+ if (filters & qt.QDir.Readable):
+ # Hide unreadable
+ if not fileInfo.isReadable():
+ return False
+ if (filters & qt.QDir.Writable):
+ # Hide unwritable
+ if not fileInfo.isWritable():
+ return False
+ if (filters & qt.QDir.Executable):
+ # Hide unexecutable
+ if not fileInfo.isExecutable():
+ return False
+
+ if (filters & qt.QDir.NoSymLinks):
+ # Hide sym links
+ if fileInfo.isSymLink():
+ return False
+
+ if not (filters & qt.QDir.System):
+ # Hide system
+ if not item.isDir() and not item.isFile():
+ return False
+
+ fileName = item.fileName()
+ isDot = fileName == "."
+ isDotDot = fileName == ".."
+
+ if not (filters & qt.QDir.Hidden):
+ # Hide hidden
+ if not (isDot or isDotDot) and fileInfo.isHidden():
+ return False
+
+ if filters & (qt.QDir.NoDot | qt.QDir.NoDotDot | qt.QDir.NoDotAndDotDot):
+ # Hide parent/self references
+ if filters & qt.QDir.NoDot:
+ if isDot:
+ return False
+ if filters & qt.QDir.NoDotDot:
+ if isDotDot:
+ return False
+ if filters & qt.QDir.NoDotAndDotDot:
+ if isDot or isDotDot:
+ return False
+
+ return True
+
+ def filterAcceptsRow(self, sourceRow, sourceParent):
+ if not sourceParent.isValid():
+ return True
+
+ sourceModel = self.sourceModel()
+ index = sourceModel.index(sourceRow, 0, sourceParent)
+ if not index.isValid():
+ return True
+ item = sourceModel._item(index)
+
+ filters = self.__filters
+
+ if item.isDrive():
+ # Let say a user always have access to a drive
+ # It avoid to access to fileInfo then avoid to freeze the file
+ # system
+ return True
+
+ if not self.__filtersAccepted(item, filters):
+ return False
+
+ if self.__nameFilterDisables:
+ return True
+
+ if item.isDir() and (filters & qt.QDir.AllDirs):
+ # dont apply the filters to directory names
+ return True
+
+ return self.__nameFiltersAccepted(item)
+
+ def __nameFiltersAccepted(self, item):
+ if len(self.__nameFilters) == 0:
+ return True
+
+ fileName = item.fileName()
+ for reg in self.__nameFilters:
+ if reg.exactMatch(fileName):
+ return True
+ return False
+
+ def setNameFilterDisables(self, enable):
+ self.__nameFilterDisables = enable
+ self.invalidate()
+
+ def nameFilterDisables(self):
+ return self.__nameFilterDisables
+
+ def myComputer(self, role=qt.Qt.DisplayRole):
+ return self.sourceModel().myComputer(role)
+
+ def setNameFilters(self, filters):
+ self.__nameFilters = []
+ isCaseSensitive = self.__filters & qt.QDir.CaseSensitive
+ caseSensitive = qt.Qt.CaseSensitive if isCaseSensitive else qt.Qt.CaseInsensitive
+ for f in filters:
+ reg = qt.QRegExp(f, caseSensitive, qt.QRegExp.Wildcard)
+ self.__nameFilters.append(reg)
+ self.invalidate()
+
+ def nameFilters(self):
+ return [f.pattern() for f in self.__nameFilters]
+
+ def filter(self):
+ return self.__filters
+
+ def setFilter(self, filters):
+ self.__filters = filters
+ # In case of change of case sensitivity
+ self.setNameFilters(self.nameFilters())
+ self.invalidate()
+
+ def setReadOnly(self, enable):
+ assert(enable is True)
+
+ def isReadOnly(self):
+ return False
+
+ def rootPath(self):
+ return self.sourceModel().rootPath()
+
+ def setRootPath(self, path):
+ index = self.sourceModel().setRootPath(path)
+ index = self.mapFromSource(index)
+ return index
+
+ def flags(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ filters = sourceModel.flags(index)
+
+ if self.__nameFilterDisables and not sourceModel.isDir(index):
+ item = sourceModel._item(index)
+ if not self.__nameFiltersAccepted(item):
+ filters &= ~qt.Qt.ItemIsEnabled
+
+ return filters
+
+ def fileIcon(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.fileIcon(index)
+
+ def fileInfo(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.fileInfo(index)
+
+ def fileName(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.fileName(index)
+
+ def filePath(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.filePath(index)
+
+ def isDir(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.isDir(index)
+
+ def lastModified(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.lastModified(index)
+
+ def permissions(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.permissions(index)
+
+ def size(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.size(index)
+
+ def type(self, index):
+ sourceModel = self.sourceModel()
+ index = self.mapToSource(index)
+ return sourceModel.type(index)
diff --git a/silx/gui/dialog/__init__.py b/src/silx/gui/dialog/__init__.py
index 77c5949..77c5949 100644
--- a/silx/gui/dialog/__init__.py
+++ b/src/silx/gui/dialog/__init__.py
diff --git a/silx/gui/dialog/setup.py b/src/silx/gui/dialog/setup.py
index 48ab8d8..48ab8d8 100644
--- a/silx/gui/dialog/setup.py
+++ b/src/silx/gui/dialog/setup.py
diff --git a/src/silx/gui/dialog/test/__init__.py b/src/silx/gui/dialog/test/__init__.py
new file mode 100644
index 0000000..71128fb
--- /dev/null
+++ b/src/silx/gui/dialog/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/dialog/test/test_colormapdialog.py b/src/silx/gui/dialog/test/test_colormapdialog.py
new file mode 100644
index 0000000..16a5ab2
--- /dev/null
+++ b/src/silx/gui/dialog/test/test_colormapdialog.py
@@ -0,0 +1,395 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ColormapDialog"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "09/11/2018"
+
+
+import pytest
+import weakref
+
+from silx.gui import qt
+from silx.gui.dialog import ColormapDialog
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.colors import Colormap, preferredColormaps
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.plot.items.image import ImageData
+
+import numpy
+
+
+@pytest.fixture
+def colormap():
+ colormap = Colormap(name='gray',
+ vmin=10.0, vmax=20.0,
+ normalization='linear')
+ yield colormap
+
+
+@pytest.fixture
+def colormapDialog(qapp, qapp_utils):
+ dialog = ColormapDialog.ColormapDialog()
+ dialog.setAttribute(qt.Qt.WA_DeleteOnClose)
+ yield weakref.proxy(dialog)
+ qapp.processEvents()
+ from silx.gui.qt import inspect
+ if inspect.isValid(dialog):
+ dialog.close()
+ qapp.processEvents()
+
+
+@pytest.fixture
+def colormap_class_attr(request, qapp_utils, colormap, colormapDialog):
+ """Provides few fixtures to a class as class attribute
+
+ Used as transition from TestCase to pytest
+ """
+ request.cls.qapp_utils = qapp_utils
+ request.cls.colormap = colormap
+ request.cls.colormapDiag = colormapDialog
+ yield
+ request.cls.qapp_utils = None
+ request.cls.colormap = None
+ request.cls.colormapDiag = None
+
+
+@pytest.mark.usefixtures("colormap_class_attr")
+class TestColormapDialog(TestCaseQt, ParametricTestCase):
+
+ def testGUIEdition(self):
+ """Make sure the colormap is correctly edited and also that the
+ modification are correctly updated if an other colormapdialog is
+ editing the same colormap"""
+ colormapDiag2 = ColormapDialog.ColormapDialog()
+ colormapDiag2.setColormap(self.colormap)
+ colormapDiag2.show()
+ self.colormapDiag.setColormap(self.colormap)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+
+ self.colormapDiag._comboBoxColormap._setCurrentName('red')
+ self.colormapDiag._comboBoxNormalization.setCurrentIndex(
+ self.colormapDiag._comboBoxNormalization.findData(Colormap.LOGARITHM))
+ self.assertTrue(self.colormap.getName() == 'red')
+ self.assertTrue(self.colormapDiag.getColormap().getName() == 'red')
+ self.assertTrue(self.colormap.getNormalization() == 'log')
+ self.assertTrue(self.colormap.getVMin() == 10)
+ self.assertTrue(self.colormap.getVMax() == 20)
+ # checked second colormap dialog
+ self.assertTrue(colormapDiag2._comboBoxColormap.getCurrentName() == 'red')
+ self.assertEqual(colormapDiag2._comboBoxNormalization.currentData(),
+ Colormap.LOGARITHM)
+ self.assertTrue(int(colormapDiag2._minValue.getValue()) == 10)
+ self.assertTrue(int(colormapDiag2._maxValue.getValue()) == 20)
+ colormapDiag2.close()
+
+ def testGUIModalOk(self):
+ """Make sure the colormap is modified if gone through accept"""
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.setModal(True)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.colormapDiag._maxValue.sigAutoScaleChanged.emit(True)
+ self.mouseClick(
+ widget=self.colormapDiag._buttonsModal.button(qt.QDialogButtonBox.Ok),
+ button=qt.Qt.LeftButton
+ )
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.assertTrue(self.colormap.getVMax() is None)
+ self.assertTrue(self.colormap.isAutoscale() is True)
+
+ def testGUIModalCancel(self):
+ """Make sure the colormap is not modified if gone through reject"""
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.setModal(True)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.mouseClick(
+ widget=self.colormapDiag._buttonsModal.button(qt.QDialogButtonBox.Cancel),
+ button=qt.Qt.LeftButton
+ )
+ self.assertTrue(self.colormap.getVMin() is not None)
+
+ def testGUIModalClose(self):
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.setModal(False)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.mouseClick(
+ widget=self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Close),
+ button=qt.Qt.LeftButton
+ )
+ self.assertTrue(self.colormap.getVMin() is None)
+
+ def testGUIModalReset(self):
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.setModal(False)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.mouseClick(
+ widget=self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset),
+ button=qt.Qt.LeftButton
+ )
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag.close()
+
+ def testGUIClose(self):
+ """Make sure the colormap is modify if go through reject"""
+ assert self.colormap.isAutoscale() is False
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertTrue(self.colormap.getVMin() is not None)
+ self.colormapDiag._minValue.sigAutoScaleChanged.emit(True)
+ self.assertTrue(self.colormap.getVMin() is None)
+ self.colormapDiag.close()
+ self.qapp.processEvents()
+ self.assertTrue(self.colormap.getVMin() is None)
+
+ def testSetColormapIsCorrect(self):
+ """Make sure the interface fir the colormap when set a new colormap"""
+ self.colormap.setName('red')
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ for norm in (Colormap.NORMALIZATIONS):
+ for autoscale in (True, False):
+ if autoscale is True:
+ self.colormap.setVRange(None, None)
+ else:
+ self.colormap.setVRange(11, 101)
+ self.colormap.setNormalization(norm)
+ with self.subTest(colormap=self.colormap):
+ self.colormapDiag.setColormap(self.colormap)
+ self.assertEqual(
+ self.colormapDiag._comboBoxNormalization.currentData(), norm)
+ self.assertTrue(
+ self.colormapDiag._comboBoxColormap.getCurrentName() == 'red')
+ self.assertTrue(
+ self.colormapDiag._minValue.isAutoChecked() == autoscale)
+ self.assertTrue(
+ self.colormapDiag._maxValue.isAutoChecked() == autoscale)
+ if autoscale is False:
+ self.assertTrue(self.colormapDiag._minValue.getValue() == 11)
+ self.assertTrue(self.colormapDiag._maxValue.getValue() == 101)
+ self.assertTrue(self.colormapDiag._minValue.isEnabled())
+ self.assertTrue(self.colormapDiag._maxValue.isEnabled())
+ else:
+ self.assertFalse(self.colormapDiag._minValue._numVal.isEnabled())
+ self.assertFalse(self.colormapDiag._maxValue._numVal.isEnabled())
+
+ def testColormapDel(self):
+ """Check behavior if the colormap has been deleted outside. For now
+ we make sure the colormap is still running and nothing more"""
+ colormap = Colormap(name='gray')
+ self.colormapDiag.setColormap(colormap)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ colormap = None
+ self.assertTrue(self.colormapDiag.getColormap() is None)
+ self.colormapDiag._comboBoxColormap._setCurrentName('blue')
+
+ def testColormapEditedOutside(self):
+ """Make sure the GUI is still up to date if the colormap is modified
+ outside"""
+ self.colormapDiag.setColormap(self.colormap)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+
+ self.colormap.setName('red')
+ self.assertTrue(
+ self.colormapDiag._comboBoxColormap.getCurrentName() == 'red')
+ self.colormap.setNormalization(Colormap.LOGARITHM)
+ self.assertEqual(self.colormapDiag._comboBoxNormalization.currentData(),
+ Colormap.LOGARITHM)
+ self.colormap.setVRange(11, 201)
+ self.assertTrue(self.colormapDiag._minValue.getValue() == 11)
+ self.assertTrue(self.colormapDiag._maxValue.getValue() == 201)
+ self.assertTrue(self.colormapDiag._minValue._numVal.isEnabled())
+ self.assertTrue(self.colormapDiag._maxValue._numVal.isEnabled())
+ self.assertFalse(self.colormapDiag._minValue.isAutoChecked())
+ self.assertFalse(self.colormapDiag._maxValue.isAutoChecked())
+ self.colormap.setVRange(None, None)
+ self.assertFalse(self.colormapDiag._minValue._numVal.isEnabled())
+ self.assertFalse(self.colormapDiag._maxValue._numVal.isEnabled())
+ self.assertTrue(self.colormapDiag._minValue.isAutoChecked())
+ self.assertTrue(self.colormapDiag._maxValue.isAutoChecked())
+
+ def testSetColormapScenario(self):
+ """Test of a simple scenario of a colormap dialog editing several
+ colormap"""
+ colormap1 = Colormap(name='gray', vmin=10.0, vmax=20.0,
+ normalization='linear')
+ colormap2 = Colormap(name='red', vmin=10.0, vmax=20.0,
+ normalization='log')
+ colormap3 = Colormap(name='blue', vmin=None, vmax=None,
+ normalization='linear')
+ self.colormapDiag.setColormap(self.colormap)
+ self.colormapDiag.setColormap(colormap1)
+ del colormap1
+ self.colormapDiag.setColormap(colormap2)
+ del colormap2
+ self.colormapDiag.setColormap(colormap3)
+ del colormap3
+
+ def testNotPreferredColormap(self):
+ """Test that the colormapEditor is able to edit a colormap which is not
+ part of the 'prefered colormap'
+ """
+ def getFirstNotPreferredColormap():
+ cms = Colormap.getSupportedColormaps()
+ preferred = preferredColormaps()
+ for cm in cms:
+ if cm not in preferred:
+ return cm
+ return None
+
+ colormapName = getFirstNotPreferredColormap()
+ assert colormapName is not None
+ colormap = Colormap(name=colormapName)
+ self.colormapDiag.setColormap(colormap)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ cb = self.colormapDiag._comboBoxColormap
+ self.assertTrue(cb.getCurrentName() == colormapName)
+ cb.setCurrentIndex(0)
+ index = cb.findLutName(colormapName)
+ assert index != 0 # if 0 then the rest of the test has no sense
+ cb.setCurrentIndex(index)
+ self.assertTrue(cb.getCurrentName() == colormapName)
+
+ def testColormapEditableMode(self):
+ """Test that the colormapDialog is correctly updated when changing the
+ colormap editable status"""
+ colormap = Colormap(normalization='linear', vmin=1.0, vmax=10.0)
+ self.colormapDiag.show()
+ self.qapp.processEvents()
+ self.colormapDiag.setColormap(colormap)
+ for editable in (True, False):
+ with self.subTest(editable=editable):
+ colormap.setEditable(editable)
+ self.assertTrue(
+ self.colormapDiag._comboBoxColormap.isEnabled() is editable)
+ self.assertTrue(
+ self.colormapDiag._minValue.isEnabled() is editable)
+ self.assertTrue(
+ self.colormapDiag._maxValue.isEnabled() is editable)
+ self.assertTrue(
+ self.colormapDiag._comboBoxNormalization.isEnabled() is editable)
+
+ # Make sure the reset button is also set to enable when edition mode is
+ # False
+ self.colormapDiag.setModal(False)
+ colormap.setEditable(True)
+ self.colormapDiag._comboBoxNormalization.setCurrentIndex(
+ self.colormapDiag._comboBoxNormalization.findData(Colormap.LOGARITHM))
+ resetButton = self.colormapDiag._buttonsNonModal.button(qt.QDialogButtonBox.Reset)
+ self.assertTrue(resetButton.isEnabled())
+ colormap.setEditable(False)
+ self.assertFalse(resetButton.isEnabled())
+
+ def testImageData(self):
+ data = numpy.random.rand(5, 5)
+ self.colormapDiag.setData(data)
+
+ def testEmptyData(self):
+ data = numpy.empty((10, 0))
+ self.colormapDiag.setData(data)
+
+ def testNoneData(self):
+ data = numpy.random.rand(5, 5)
+ self.colormapDiag.setData(data)
+ self.colormapDiag.setData(None)
+
+ def testImageItem(self):
+ """Check that an ImageData plot item can be used"""
+ dialog = self.colormapDiag
+ colormap = Colormap(name='gray', vmin=None, vmax=None)
+ data = numpy.arange(3**2).reshape(3, 3)
+ item = ImageData()
+ item.setData(data, copy=False)
+
+ dialog.setColormap(colormap)
+ dialog.show()
+ self.qapp.processEvents()
+ dialog.setItem(item)
+ vrange = dialog._getFiniteColormapRange()
+ self.assertEqual(vrange, (0, 8))
+
+ def testItemDel(self):
+ """Check that the plot items are not hard linked to the dialog"""
+ dialog = self.colormapDiag
+ colormap = Colormap(name='gray', vmin=None, vmax=None)
+ data = numpy.arange(3**2).reshape(3, 3)
+ item = ImageData()
+ item.setData(data, copy=False)
+
+ dialog.setColormap(colormap)
+ dialog.show()
+ self.qapp.processEvents()
+ dialog.setItem(item)
+ previousRange = dialog._getFiniteColormapRange()
+ del item
+ vrange = dialog._getFiniteColormapRange()
+ self.assertNotEqual(vrange, previousRange)
+
+ def testDataDel(self):
+ """Check that the data are not hard linked to the dialog"""
+ dialog = self.colormapDiag
+ colormap = Colormap(name='gray', vmin=None, vmax=None)
+ data = numpy.arange(5)
+
+ dialog.setColormap(colormap)
+ dialog.show()
+ self.qapp.processEvents()
+ dialog.setData(data)
+ previousRange = dialog._getFiniteColormapRange()
+ del data
+ vrange = dialog._getFiniteColormapRange()
+ self.assertNotEqual(vrange, previousRange)
+
+ def testDeleteWhileExec(self):
+ colormapDiag = self.colormapDiag
+ self.colormapDiag = None
+ qt.QTimer.singleShot(1000, colormapDiag.deleteLater)
+ result = colormapDiag.exec()
+ self.assertEqual(result, 0)
diff --git a/src/silx/gui/dialog/test/test_datafiledialog.py b/src/silx/gui/dialog/test/test_datafiledialog.py
new file mode 100644
index 0000000..8411c67
--- /dev/null
+++ b/src/silx/gui/dialog/test/test_datafiledialog.py
@@ -0,0 +1,924 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "08/03/2019"
+
+
+import unittest
+import tempfile
+import numpy
+import shutil
+import os
+import io
+import weakref
+import fabio
+import h5py
+import silx.io.url
+from silx.gui import qt
+from silx.gui.utils import testutils
+from ..DataFileDialog import DataFileDialog
+from silx.gui.hdf5 import Hdf5TreeModel
+
+_tmpDirectory = None
+
+
+def setUpModule():
+ global _tmpDirectory
+ _tmpDirectory = tempfile.mkdtemp(prefix=__name__)
+
+ data = numpy.arange(100 * 100)
+ data.shape = 100, 100
+
+ filename = _tmpDirectory + "/singleimage.edf"
+ image = fabio.edfimage.EdfImage(data=data)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/data.h5"
+ f = h5py.File(filename, "w")
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+ f["nxdata/foo"] = 10
+ f["nxdata"].attrs["NX_class"] = u"NXdata"
+ f.close()
+
+ directory = os.path.join(_tmpDirectory, "data")
+ os.mkdir(directory)
+ filename = os.path.join(directory, "data.h5")
+ f = h5py.File(filename, "w")
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+ f["nxdata/foo"] = 10
+ f["nxdata"].attrs["NX_class"] = u"NXdata"
+ f.close()
+
+ filename = _tmpDirectory + "/badformat.h5"
+ with io.open(filename, "wb") as f:
+ f.write(b"{\nHello Nurse!")
+
+
+def tearDownModule():
+ global _tmpDirectory
+ shutil.rmtree(_tmpDirectory)
+ _tmpDirectory = None
+
+
+class _UtilsMixin(object):
+
+ def createDialog(self):
+ self._deleteDialog()
+ self._dialog = self._createDialog()
+ return self._dialog
+
+ def _createDialog(self):
+ return DataFileDialog()
+
+ def _deleteDialog(self):
+ if not hasattr(self, "_dialog"):
+ return
+ if self._dialog is not None:
+ ref = weakref.ref(self._dialog)
+ self._dialog = None
+ self.qWaitForDestroy(ref)
+
+ def qWaitForPendingActions(self, dialog):
+ for _ in range(20):
+ if not dialog.hasPendingEvents():
+ return
+ self.qWait(10)
+ raise RuntimeError("Still have pending actions")
+
+ def assertSamePath(self, path1, path2):
+ path1_ = os.path.normcase(path1)
+ path2_ = os.path.normcase(path2)
+ if path1_ != path2_:
+ # Use the unittest API to log and display error
+ self.assertEqual(path1, path2)
+
+ def assertNotSamePath(self, path1, path2):
+ path1_ = os.path.normcase(path1)
+ path2_ = os.path.normcase(path2)
+ if path1_ == path2_:
+ # Use the unittest API to log and display error
+ self.assertNotEqual(path1, path2)
+
+
+class TestDataFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def testDisplayAndKeyEscape(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ self.keyClick(dialog, qt.Qt.Key_Escape)
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickCancel(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="cancel")[0]
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.assertFalse(dialog.isVisible())
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickLockedOpen(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.mouseClick(button, qt.Qt.LeftButton)
+ # open button locked, dialog is not closed
+ self.assertTrue(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testSelectRoot_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertTrue(url.data_path() is not None)
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ def testSelectGroup_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/group")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ def testSelectDataset_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/scalar")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ def testClickOnBackToParentTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toParentAction")[0]
+ toParentButton = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ self.assertSamePath(url.text(), path)
+ # test
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ self.assertSamePath(url.text(), path)
+
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory + "/data")
+
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory)
+
+ def testClickOnBackToRootTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toRootFileAction")[0]
+ button = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), path)
+ self.assertTrue(button.isEnabled())
+ # test
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ self.assertSamePath(url.text(), path)
+ # self.assertFalse(button.isEnabled())
+
+ def testClickOnBackToDirectoryTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toDirectoryAction")[0]
+ button = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ self.assertSamePath(url.text(), path)
+ self.assertTrue(button.isEnabled())
+ # test
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory)
+ self.assertFalse(button.isEnabled())
+
+ # FIXME: There is an unreleased qt.QWidget without nameObject
+ # No idea where it come from.
+ self.allowedLeakingWidgets = 1
+
+ def testClickOnHistoryTools(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ forwardAction = testutils.findChildren(dialog, qt.QAction, name="forwardAction")[0]
+ backwardAction = testutils.findChildren(dialog, qt.QAction, name="backwardAction")[0]
+ filename = _tmpDirectory + "/data.h5"
+
+ dialog.setDirectory(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ # No way to use QTest.mouseDClick with QListView, QListWidget
+ # Then we feed the history using selectPath
+ dialog.selectUrl(filename)
+ self.qWaitForPendingActions(dialog)
+ path2 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ dialog.selectUrl(path2)
+ self.qWaitForPendingActions(dialog)
+ path3 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group").path()
+ dialog.selectUrl(path3)
+ self.qWaitForPendingActions(dialog)
+ self.assertFalse(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+
+ button = testutils.getQToolButtonFromAction(backwardAction)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertTrue(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+ self.assertSamePath(url.text(), path2)
+
+ button = testutils.getQToolButtonFromAction(forwardAction)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertFalse(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+ self.assertSamePath(url.text(), path3)
+
+ def testSelectImageFromEdf(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/singleimage.edf"
+ url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/scan_0/instrument/detector_0/data")
+ dialog.selectUrl(url.path())
+ self.assertEqual(dialog._selectedData().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), url.path())
+
+ def testSelectImage(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog._selectedData().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectScalar(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/scalar").path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog._selectedData()[()], 10)
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectGroup(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ uri = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group")
+ dialog.selectUrl(uri.path())
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertTrue(silx.io.is_group(dialog._selectedData()))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ uri = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertSamePath(uri.data_path(), "/group")
+
+ def testSelectRoot(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ uri = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/")
+ dialog.selectUrl(uri.path())
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertTrue(silx.io.is_file(dialog._selectedData()))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ uri = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertSamePath(uri.data_path(), "/")
+
+ def testSelectH5_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ index = browser.rootIndex().model().index(filename)
+ # click
+ browser.selectIndex(index)
+ # double click
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectBadFileFormat_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/badformat.h5"
+ index = browser.model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertSamePath(dialog.selectedUrl(), filename)
+
+ def _countSelectableItems(self, model, rootIndex):
+ selectable = 0
+ for i in range(model.rowCount(rootIndex)):
+ index = model.index(i, 0, rootIndex)
+ flags = model.flags(index)
+ isEnabled = (int(flags) & qt.Qt.ItemIsEnabled) != 0
+ if isEnabled:
+ selectable += 1
+ return selectable
+
+ def testFilterExtensions(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4)
+
+
+class TestDataFileDialog_FilterDataset(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def _createDialog(self):
+ dialog = DataFileDialog()
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingDataset)
+ return dialog
+
+ def testSelectGroup_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertFalse(button.isEnabled())
+
+ def testSelectDataset_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/scalar")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ data = dialog.selectedData()
+ self.assertEqual(data, 10)
+
+
+class TestDataFileDialog_FilterGroup(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def _createDialog(self):
+ dialog = DataFileDialog()
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingGroup)
+ return dialog
+
+ def testSelectGroup_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/group")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ self.assertRaises(Exception, dialog.selectedData)
+
+ def testSelectDataset_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/scalar"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertFalse(button.isEnabled())
+
+
+class TestDataFileDialog_FilterNXdata(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def _createDialog(self):
+ def customFilter(obj):
+ if "NX_class" in obj.attrs:
+ return obj.attrs["NX_class"] == u"NXdata"
+ return False
+
+ dialog = DataFileDialog()
+ dialog.setFilterMode(DataFileDialog.FilterMode.ExistingGroup)
+ dialog.setFilterCallback(customFilter)
+ return dialog
+
+ def testSelectGroupRefused_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/group"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertFalse(button.isEnabled())
+
+ self.assertRaises(Exception, dialog.selectedData)
+
+ def testSelectNXdataAccepted_Activate(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/data.h5"
+ dialog.selectFile(os.path.dirname(filename))
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ # select, then double click on the file
+ index = browser.rootIndex().model().indexFromH5Object(dialog._AbstractDataFileDialog__h5["/nxdata"])
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/nxdata")
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+
+class TestDataFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def _createDialog(self):
+ dialog = DataFileDialog()
+ return dialog
+
+ def testSaveRestoreState(self):
+ dialog = self.createDialog()
+ dialog.setDirectory(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ state = dialog.saveState()
+ dialog = None
+
+ dialog2 = self.createDialog()
+ result = dialog2.restoreState(state)
+ self.assertTrue(result)
+ dialog2 = None
+
+ def printState(self):
+ """
+ Print state of the ImageFileDialog.
+
+ Can be used to add or regenerate `STATE_VERSION1_QT4` or
+ `STATE_VERSION1_QT5`.
+
+ >>> ./run_tests.py -v silx.gui.dialog.test.test_datafiledialog.TestDataFileDialogApi.printState
+ """
+ dialog = self.createDialog()
+ dialog.setDirectory("")
+ dialog.setHistory([])
+ dialog.setSidebarUrls([])
+ state = dialog.saveState()
+ string = ""
+ strings = []
+ for i in range(state.size()):
+ d = state.data()[i]
+ if not isinstance(d, int):
+ d = ord(d)
+ if d > 0x20 and d < 0x7F:
+ string += chr(d)
+ else:
+ string += "\\x%02X" % d
+ if len(string) > 60:
+ strings.append(string)
+ string = ""
+ strings.append(string)
+ strings = ["b'%s'" % s for s in strings]
+ print()
+ print("\\\n".join(strings))
+
+ STATE_VERSION1_QT4 = b''\
+ b'\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
+ b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i'\
+ b'\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00'\
+ b'a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00\xFF\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
+ b'\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00\x00'\
+ b'}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00\xFF\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00\x00\x81'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00\x00\x00\x04'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
+ b'\x01\xFF\xFF\xFF\xFF'
+ """Serialized state on Qt4. Generated using :meth:`printState`"""
+
+ STATE_VERSION1_QT5 = b''\
+ b'\x00\x00\x00Z\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
+ b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00a\x00F\x00i'\
+ b'\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00D\x00a\x00t\x00'\
+ b'a\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00\xFF\x00\x00'\
+ b'\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
+ b'\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00\x00\x00'\
+ b'\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00r\x00'\
+ b'\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87\x00\x00\x00\xFF'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00\x00'\
+ b'\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00d\x00\x00'\
+ b'\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00'\
+ b'\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00'\
+ b'\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03\xE8\x00\xFF'\
+ b'\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x01'
+ """Serialized state on Qt5. Generated using :meth:`printState`"""
+
+ def testAvoidRestoreRegression_Version1(self):
+ version = qt.qVersion().split(".")[0]
+ if version == "4":
+ state = self.STATE_VERSION1_QT4
+ elif version == "5":
+ state = self.STATE_VERSION1_QT5
+ else:
+ self.skipTest("Resource not available")
+
+ state = qt.QByteArray(state)
+ dialog = self.createDialog()
+ result = dialog.restoreState(state)
+ self.assertTrue(result)
+
+ def testRestoreRobusness(self):
+ """What's happen if you try to open a config file with a different
+ binding."""
+ state = qt.QByteArray(self.STATE_VERSION1_QT4)
+ dialog = self.createDialog()
+ dialog.restoreState(state)
+ state = qt.QByteArray(self.STATE_VERSION1_QT5)
+ dialog = None
+ dialog = self.createDialog()
+ dialog.restoreState(state)
+
+ def testRestoreNonExistingDirectory(self):
+ directory = os.path.join(_tmpDirectory, "dir")
+ os.mkdir(directory)
+ dialog = self.createDialog()
+ dialog.setDirectory(directory)
+ self.qWaitForPendingActions(dialog)
+ state = dialog.saveState()
+ os.rmdir(directory)
+ dialog = None
+
+ dialog2 = self.createDialog()
+ result = dialog2.restoreState(state)
+ self.assertTrue(result)
+ self.assertNotEqual(dialog2.directory(), directory)
+
+ def testHistory(self):
+ dialog = self.createDialog()
+ history = dialog.history()
+ dialog.setHistory([])
+ self.assertEqual(dialog.history(), [])
+ dialog.setHistory(history)
+ self.assertEqual(dialog.history(), history)
+
+ def testSidebarUrls(self):
+ dialog = self.createDialog()
+ urls = dialog.sidebarUrls()
+ dialog.setSidebarUrls([])
+ self.assertEqual(dialog.sidebarUrls(), [])
+ dialog.setSidebarUrls(urls)
+ self.assertEqual(dialog.sidebarUrls(), urls)
+
+ def testDirectory(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(dialog.directory(), _tmpDirectory)
+
+ def testBadFileFormat(self):
+ dialog = self.createDialog()
+ dialog.selectUrl(_tmpDirectory + "/badformat.h5")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadPath(self):
+ dialog = self.createDialog()
+ dialog.selectUrl("#$%/#$%")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadSubpath(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+
+ filename = _tmpDirectory + "/data.h5"
+ url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/foobar")
+ dialog.selectUrl(url.path())
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNotNone(dialog._selectedData())
+
+ # an existing node is browsed, but the wrong path is selected
+ index = browser.rootIndex()
+ obj = index.model().data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertEqual(obj.name, "/group")
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/group")
+
+ def testUnsupportedSlicingPath(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+ dialog.selectUrl(_tmpDirectory + "/data.h5?path=/cube&slice=0")
+ self.qWaitForPendingActions(dialog)
+ data = dialog._selectedData()
+ if data is None:
+ # Maybe nothing is selected
+ self.assertTrue(True)
+ else:
+ # Maybe the cube is selected but not sliced
+ self.assertEqual(len(data.shape), 3)
diff --git a/src/silx/gui/dialog/test/test_imagefiledialog.py b/src/silx/gui/dialog/test/test_imagefiledialog.py
new file mode 100644
index 0000000..9e204b9
--- /dev/null
+++ b/src/silx/gui/dialog/test/test_imagefiledialog.py
@@ -0,0 +1,772 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "08/03/2019"
+
+
+import unittest
+import tempfile
+import numpy
+import shutil
+import os
+import io
+import weakref
+import fabio
+import h5py
+import silx.io.url
+from silx.gui import qt
+from silx.gui.utils import testutils
+from ..ImageFileDialog import ImageFileDialog
+from silx.gui.colors import Colormap
+from silx.gui.hdf5 import Hdf5TreeModel
+
+_tmpDirectory = None
+
+
+def setUpModule():
+ global _tmpDirectory
+ _tmpDirectory = tempfile.mkdtemp(prefix=__name__)
+
+ data = numpy.arange(100 * 100)
+ data.shape = 100, 100
+
+ filename = _tmpDirectory + "/singleimage.edf"
+ image = fabio.edfimage.EdfImage(data=data)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/multiframe.edf"
+ image = fabio.edfimage.EdfImage(data=data)
+ image.append_frame(data=data + 1)
+ image.append_frame(data=data + 2)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/singleimage.msk"
+ image = fabio.fit2dmaskimage.Fit2dMaskImage(data=data % 2 == 1)
+ image.write(filename)
+
+ filename = _tmpDirectory + "/data.h5"
+ with h5py.File(filename, "w") as f:
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["single_frame"] = [data + 5]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+
+ directory = os.path.join(_tmpDirectory, "data")
+ os.mkdir(directory)
+ filename = os.path.join(directory, "data.h5")
+ with h5py.File(filename, "w") as f:
+ f["scalar"] = 10
+ f["image"] = data
+ f["cube"] = [data, data + 1, data + 2]
+ f["single_frame"] = [data + 5]
+ f["complex_image"] = data * 1j
+ f["group/image"] = data
+
+ filename = _tmpDirectory + "/badformat.edf"
+ with io.open(filename, "wb") as f:
+ f.write(b"{\nHello Nurse!")
+
+
+def tearDownModule():
+ global _tmpDirectory
+ shutil.rmtree(_tmpDirectory)
+ _tmpDirectory = None
+
+
+class _UtilsMixin(object):
+
+ def createDialog(self):
+ self._deleteDialog()
+ self._dialog = self._createDialog()
+ return self._dialog
+
+ def _createDialog(self):
+ return ImageFileDialog()
+
+ def _deleteDialog(self):
+ if not hasattr(self, "_dialog"):
+ return
+ if self._dialog is not None:
+ ref = weakref.ref(self._dialog)
+ self._dialog = None
+ self.qWaitForDestroy(ref)
+
+ def qWaitForPendingActions(self, dialog):
+ for _ in range(20):
+ if not dialog.hasPendingEvents():
+ return
+ self.qWait(10)
+ raise RuntimeError("Still have pending actions")
+
+ def assertSamePath(self, path1, path2):
+ path1_ = os.path.normcase(path1)
+ path2_ = os.path.normcase(path2)
+ if path1_ != path2_:
+ # Use the unittest API to log and display error
+ self.assertEqual(path1, path2)
+
+ def assertNotSamePath(self, path1, path2):
+ path1_ = os.path.normcase(path1)
+ path2_ = os.path.normcase(path2)
+ if path1_ == path2_:
+ # Use the unittest API to log and display error
+ self.assertNotEqual(path1, path2)
+
+
+class TestImageFileDialogInteraction(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def testDisplayAndKeyEscape(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ self.keyClick(dialog, qt.Qt.Key_Escape)
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickCancel(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="cancel")[0]
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.assertFalse(dialog.isVisible())
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickLockedOpen(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.mouseClick(button, qt.Qt.LeftButton)
+ # open button locked, dialog is not closed
+ self.assertTrue(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Rejected)
+
+ def testDisplayAndClickOpen(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ self.assertTrue(dialog.isVisible())
+ filename = _tmpDirectory + "/singleimage.edf"
+ dialog.selectFile(filename)
+ self.qWaitForPendingActions(dialog)
+
+ button = testutils.findChildren(dialog, qt.QPushButton, name="open")[0]
+ self.assertTrue(button.isEnabled())
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.assertFalse(dialog.isVisible())
+ self.assertEqual(dialog.result(), qt.QDialog.Accepted)
+
+ def testClickOnShortcut(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ sidebar = testutils.findChildren(dialog, qt.QListView, name="sidebar")[0]
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ dialog.setDirectory(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+
+ self.assertSamePath(url.text(), _tmpDirectory)
+
+ urls = sidebar.urls()
+ if len(urls) == 0:
+ self.skipTest("No sidebar path")
+ path = urls[0].path()
+ if path != "" and not os.path.exists(path):
+ self.skipTest("Sidebar path do not exists")
+
+ index = sidebar.model().index(0, 0)
+ # rect = sidebar.visualRect(index)
+ # self.mouseClick(sidebar, qt.Qt.LeftButton, pos=rect.center())
+ # Using mouse click is not working, let's use the selection API
+ sidebar.selectionModel().select(index, qt.QItemSelectionModel.ClearAndSelect)
+ self.qWaitForPendingActions(dialog)
+
+ index = browser.rootIndex()
+ if not index.isValid():
+ path = ""
+ else:
+ path = index.model().filePath(index)
+ self.assertNotSamePath(_tmpDirectory, path)
+ self.assertNotSamePath(url.text(), _tmpDirectory)
+
+ def testClickOnDetailView(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ action = testutils.findChildren(dialog, qt.QAction, name="detailModeAction")[0]
+ detailModeButton = testutils.getQToolButtonFromAction(action)
+ self.mouseClick(detailModeButton, qt.Qt.LeftButton)
+ self.assertEqual(dialog.viewMode(), qt.QFileDialog.Detail)
+
+ action = testutils.findChildren(dialog, qt.QAction, name="listModeAction")[0]
+ listModeButton = testutils.getQToolButtonFromAction(action)
+ self.mouseClick(listModeButton, qt.Qt.LeftButton)
+ self.assertEqual(dialog.viewMode(), qt.QFileDialog.List)
+
+ def testClickOnBackToParentTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toParentAction")[0]
+ toParentButton = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ self.assertSamePath(url.text(), path)
+ # test
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ self.assertSamePath(url.text(), path)
+
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory + "/data")
+
+ self.mouseClick(toParentButton, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory)
+
+ def testClickOnBackToRootTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toRootFileAction")[0]
+ button = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), path)
+ self.assertTrue(button.isEnabled())
+ # test
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ self.assertSamePath(url.text(), path)
+ # self.assertFalse(button.isEnabled())
+
+ def testClickOnBackToDirectoryTool(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ action = testutils.findChildren(dialog, qt.QAction, name="toDirectoryAction")[0]
+ button = testutils.getQToolButtonFromAction(action)
+ filename = _tmpDirectory + "/data.h5"
+
+ # init state
+ path = silx.io.url.DataUrl(file_path=filename, data_path="/group/image").path()
+ dialog.selectUrl(path)
+ self.qWaitForPendingActions(dialog)
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/image").path()
+ self.assertSamePath(url.text(), path)
+ self.assertTrue(button.isEnabled())
+ # test
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(url.text(), _tmpDirectory)
+ self.assertFalse(button.isEnabled())
+
+ # FIXME: There is an unreleased qt.QWidget without nameObject
+ # No idea where it come from.
+ self.allowedLeakingWidgets = 1
+
+ def testClickOnHistoryTools(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ url = testutils.findChildren(dialog, qt.QLineEdit, name="url")[0]
+ forwardAction = testutils.findChildren(dialog, qt.QAction, name="forwardAction")[0]
+ backwardAction = testutils.findChildren(dialog, qt.QAction, name="backwardAction")[0]
+ filename = _tmpDirectory + "/data.h5"
+
+ dialog.setDirectory(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ # No way to use QTest.mouseDClick with QListView, QListWidget
+ # Then we feed the history using selectPath
+ dialog.selectUrl(filename)
+ self.qWaitForPendingActions(dialog)
+ path2 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ dialog.selectUrl(path2)
+ self.qWaitForPendingActions(dialog)
+ path3 = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group").path()
+ dialog.selectUrl(path3)
+ self.qWaitForPendingActions(dialog)
+ self.assertFalse(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+
+ button = testutils.getQToolButtonFromAction(backwardAction)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertTrue(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+ self.assertSamePath(url.text(), path2)
+
+ button = testutils.getQToolButtonFromAction(forwardAction)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qWaitForPendingActions(dialog)
+ self.assertFalse(forwardAction.isEnabled())
+ self.assertTrue(backwardAction.isEnabled())
+ self.assertSamePath(url.text(), path3)
+
+ def testSelectImageFromEdf(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/singleimage.edf"
+ path = filename
+ dialog.selectUrl(path)
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path()
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectImageFromEdf_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/singleimage.edf"
+ path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path()
+ index = browser.rootIndex().model().index(filename)
+ # click
+ browser.selectIndex(index)
+ # double click
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectFrameFromEdf(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/multiframe.edf"
+ path = silx.io.url.DataUrl(scheme="fabio", file_path=filename, data_slice=(1,)).path()
+ dialog.selectUrl(path)
+ # test
+ image = dialog.selectedImage()
+ self.assertEqual(image.shape, (100, 100))
+ self.assertEqual(image[0, 0], 1)
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectImageFromMsk(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/singleimage.msk"
+ path = silx.io.url.DataUrl(scheme="fabio", file_path=filename).path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectImageFromH5(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/image").path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectH5_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/").path()
+ index = browser.rootIndex().model().index(filename)
+ # click
+ browser.selectIndex(index)
+ # double click
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectFrameFromH5(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/cube", data_slice=(1, )).path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertEqual(dialog.selectedImage()[0, 0], 1)
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectSingleFrameFromH5(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ filename = _tmpDirectory + "/data.h5"
+ path = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/single_frame", data_slice=(0, )).path()
+ dialog.selectUrl(path)
+ # test
+ self.assertEqual(dialog.selectedImage().shape, (100, 100))
+ self.assertEqual(dialog.selectedImage()[0, 0], 5)
+ self.assertSamePath(dialog.selectedFile(), filename)
+ self.assertSamePath(dialog.selectedUrl(), path)
+
+ def testSelectBadFileFormat_Activate(self):
+ dialog = self.createDialog()
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+
+ # init state
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filename = _tmpDirectory + "/badformat.edf"
+ index = browser.model().index(filename)
+ browser.selectIndex(index)
+ browser.activated.emit(index)
+ self.qWaitForPendingActions(dialog)
+ # test
+ self.assertSamePath(dialog.selectedUrl(), filename)
+
+ def _countSelectableItems(self, model, rootIndex):
+ selectable = 0
+ for i in range(model.rowCount(rootIndex)):
+ index = model.index(i, 0, rootIndex)
+ flags = model.flags(index)
+ isEnabled = (int(flags) & qt.Qt.ItemIsEnabled) != 0
+ if isEnabled:
+ selectable += 1
+ return selectable
+
+ def testFilterExtensions(self):
+ dialog = self.createDialog()
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+ filters = testutils.findChildren(dialog, qt.QWidget, name="fileTypeCombo")[0]
+ dialog.show()
+ self.qWaitForWindowExposed(dialog)
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 6)
+
+ codecName = fabio.edfimage.EdfImage.codec_name()
+ index = filters.indexFromCodec(codecName)
+ filters.setCurrentIndex(index)
+ filters.activated[int].emit(index)
+ self.qWait(50)
+ self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 4)
+
+ codecName = fabio.fit2dmaskimage.Fit2dMaskImage.codec_name()
+ index = filters.indexFromCodec(codecName)
+ filters.setCurrentIndex(index)
+ filters.activated[int].emit(index)
+ self.qWait(50)
+ self.assertEqual(self._countSelectableItems(browser.model(), browser.rootIndex()), 2)
+
+
+class TestImageFileDialogApi(testutils.TestCaseQt, _UtilsMixin):
+
+ def tearDown(self):
+ self._deleteDialog()
+ testutils.TestCaseQt.tearDown(self)
+
+ def testSaveRestoreState(self):
+ dialog = self.createDialog()
+ dialog.setDirectory(_tmpDirectory)
+ colormap = Colormap(normalization=Colormap.LOGARITHM)
+ dialog.setColormap(colormap)
+ self.qWaitForPendingActions(dialog)
+ state = dialog.saveState()
+ dialog = None
+
+ dialog2 = self.createDialog()
+ result = dialog2.restoreState(state)
+ self.qWaitForPendingActions(dialog2)
+ self.assertTrue(result)
+ self.assertEqual(dialog2.colormap().getNormalization(), "log")
+
+ def printState(self):
+ """
+ Print state of the ImageFileDialog.
+
+ Can be used to add or regenerate `STATE_VERSION1_QT4` or
+ `STATE_VERSION1_QT5`.
+
+ >>> ./run_tests.py -v silx.gui.dialog.test.test_imagefiledialog.TestImageFileDialogApi.printState
+ """
+ dialog = self.createDialog()
+ colormap = Colormap(normalization=Colormap.LOGARITHM)
+ dialog.setDirectory("")
+ dialog.setHistory([])
+ dialog.setColormap(colormap)
+ dialog.setSidebarUrls([])
+ state = dialog.saveState()
+ string = ""
+ strings = []
+ for i in range(state.size()):
+ d = state.data()[i]
+ if not isinstance(d, int):
+ d = ord(d)
+ if d > 0x20 and d < 0x7F:
+ string += chr(d)
+ else:
+ string += "\\x%02X" % d
+ if len(string) > 60:
+ strings.append(string)
+ string = ""
+ strings.append(string)
+ strings = ["b'%s'" % s for s in strings]
+ print()
+ print("\\\n".join(strings))
+
+ STATE_VERSION1_QT4 = b''\
+ b'\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
+ b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F'\
+ b'\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00'\
+ b'a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g'\
+ b'\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00"\x00\x00\x00'\
+ b'\xFF\x00\x00\x00\x00\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
+ b'\xFF\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x06\x01\x00\x00\x00\x01\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C\x00'\
+ b'\x00\x00\x00}\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s\x00e\x00'\
+ b'r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00Z\x00\x00\x00'\
+ b'\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF\xFF\xFF\x00'\
+ b'\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x01\x90\x00'\
+ b'\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x01\x00'\
+ b'\x00\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00'\
+ b'o\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00'\
+ b'r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g'
+ """Serialized state on Qt4. Generated using :meth:`printState`"""
+
+ STATE_VERSION1_QT5 = b''\
+ b'\x00\x00\x00^\x00s\x00i\x00l\x00x\x00.\x00g\x00u\x00i\x00.\x00'\
+ b'd\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00a\x00g\x00e\x00F'\
+ b'\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g\x00.\x00I\x00m\x00'\
+ b'a\x00g\x00e\x00F\x00i\x00l\x00e\x00D\x00i\x00a\x00l\x00o\x00g'\
+ b'\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00#\x00\x00\x00'\
+ b'\xFF\x00\x00\x00\x01\x00\x00\x00\x03\xFF\xFF\xFF\xFF\xFF\xFF\xFF'\
+ b'\xFF\xFF\xFF\xFF\xFF\x01\xFF\xFF\xFF\xFF\x01\x00\x00\x00\x01\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0C'\
+ b'\x00\x00\x00\x00\xAA\x00\x00\x00\x0E\x00B\x00r\x00o\x00w\x00s'\
+ b'\x00e\x00r\x00\x00\x00\x01\x00\x00\x00\x0C\x00\x00\x00\x00\x87'\
+ b'\x00\x00\x00\xFF\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x01\x90\x00\x00\x00\x04\x01\x01\x00'\
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00d\xFF\xFF'\
+ b'\xFF\xFF\x00\x00\x00\x81\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00'\
+ b'\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00'\
+ b'\x01\x00\x00\x00\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00'\
+ b'\x00\x00\x00\x00d\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03'\
+ b'\xE8\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00'\
+ b'\x00\x0C\x00\x00\x00\x000\x00\x00\x00\x10\x00C\x00o\x00l\x00o'\
+ b'\x00r\x00m\x00a\x00p\x00\x00\x00\x01\x00\x00\x00\x08\x00g\x00'\
+ b'r\x00a\x00y\x01\x01\x00\x00\x00\x06\x00l\x00o\x00g'
+ """Serialized state on Qt5. Generated using :meth:`printState`"""
+
+ def testAvoidRestoreRegression_Version1(self):
+ version = qt.qVersion().split(".")[0]
+ if version == "4":
+ state = self.STATE_VERSION1_QT4
+ elif version == "5":
+ state = self.STATE_VERSION1_QT5
+ else:
+ self.skipTest("Resource not available")
+
+ state = qt.QByteArray(state)
+ dialog = self.createDialog()
+ result = dialog.restoreState(state)
+ self.assertTrue(result)
+ colormap = dialog.colormap()
+ self.assertEqual(colormap.getNormalization(), "log")
+
+ def testRestoreRobusness(self):
+ """What's happen if you try to open a config file with a different
+ binding."""
+ state = qt.QByteArray(self.STATE_VERSION1_QT4)
+ dialog = self.createDialog()
+ dialog.restoreState(state)
+ state = qt.QByteArray(self.STATE_VERSION1_QT5)
+ dialog = None
+ dialog = self.createDialog()
+ dialog.restoreState(state)
+
+ def testRestoreNonExistingDirectory(self):
+ directory = os.path.join(_tmpDirectory, "dir")
+ os.mkdir(directory)
+ dialog = self.createDialog()
+ dialog.setDirectory(directory)
+ self.qWaitForPendingActions(dialog)
+ state = dialog.saveState()
+ os.rmdir(directory)
+ dialog = None
+
+ dialog2 = self.createDialog()
+ result = dialog2.restoreState(state)
+ self.assertTrue(result)
+ self.assertNotEqual(dialog2.directory(), directory)
+
+ def testHistory(self):
+ dialog = self.createDialog()
+ history = dialog.history()
+ dialog.setHistory([])
+ self.assertEqual(dialog.history(), [])
+ dialog.setHistory(history)
+ self.assertEqual(dialog.history(), history)
+
+ def testSidebarUrls(self):
+ dialog = self.createDialog()
+ urls = dialog.sidebarUrls()
+ dialog.setSidebarUrls([])
+ self.assertEqual(dialog.sidebarUrls(), [])
+ dialog.setSidebarUrls(urls)
+ self.assertEqual(dialog.sidebarUrls(), urls)
+
+ def testColomap(self):
+ dialog = self.createDialog()
+ colormap = dialog.colormap()
+ self.assertEqual(colormap.getNormalization(), "linear")
+ colormap = Colormap(normalization=Colormap.LOGARITHM)
+ dialog.setColormap(colormap)
+ self.assertEqual(colormap.getNormalization(), "log")
+
+ def testDirectory(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+ dialog.selectUrl(_tmpDirectory)
+ self.qWaitForPendingActions(dialog)
+ self.assertSamePath(dialog.directory(), _tmpDirectory)
+
+ def testBadDataType(self):
+ dialog = self.createDialog()
+ dialog.selectUrl(_tmpDirectory + "/data.h5::/complex_image")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadDataShape(self):
+ dialog = self.createDialog()
+ dialog.selectUrl(_tmpDirectory + "/data.h5::/unknown")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadDataFormat(self):
+ dialog = self.createDialog()
+ dialog.selectUrl(_tmpDirectory + "/badformat.edf")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadPath(self):
+ dialog = self.createDialog()
+ dialog.selectUrl("#$%/#$%")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ def testBadSubpath(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+
+ browser = testutils.findChildren(dialog, qt.QWidget, name="browser")[0]
+
+ filename = _tmpDirectory + "/data.h5"
+ url = silx.io.url.DataUrl(scheme="silx", file_path=filename, data_path="/group/foobar")
+ dialog.selectUrl(url.path())
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
+
+ # an existing node is browsed, but the wrong path is selected
+ index = browser.rootIndex()
+ obj = index.model().data(index, role=Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertEqual(obj.name, "/group")
+ url = silx.io.url.DataUrl(dialog.selectedUrl())
+ self.assertEqual(url.data_path(), "/group")
+
+ def testBadSlicingPath(self):
+ dialog = self.createDialog()
+ self.qWaitForPendingActions(dialog)
+ dialog.selectUrl(_tmpDirectory + "/data.h5::/cube[a;45,-90]")
+ self.qWaitForPendingActions(dialog)
+ self.assertIsNone(dialog._selectedData())
diff --git a/src/silx/gui/dialog/utils.py b/src/silx/gui/dialog/utils.py
new file mode 100644
index 0000000..4c48930
--- /dev/null
+++ b/src/silx/gui/dialog/utils.py
@@ -0,0 +1,99 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module contains utilitaries used by other dialog modules.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "25/10/2017"
+
+import os
+import types
+
+from silx.gui import qt
+
+
+def samefile(path1, path2):
+ """Portable :func:`os.path.samepath` function.
+
+ :param str path1: A path to a file
+ :param str path2: Another path to a file
+ :rtype: bool
+ """
+ if path1 == path2:
+ return True
+ if path1 == "":
+ return False
+ if path2 == "":
+ return False
+ return os.path.samefile(path1, path2)
+
+
+def findClosestSubPath(hdf5Object, path):
+ """Find the closest existing path from the hdf5Object using a subset of the
+ provided path.
+
+ Returns None if no path found. It is possible if the path is a relative
+ path.
+
+ :param h5py.Node hdf5Object: An HDF5 node
+ :param str path: A path
+ :rtype: str
+ """
+ if path in ["", "/"]:
+ return "/"
+ names = path.split("/")
+ if path[0] == "/":
+ names.pop(0)
+ for i in range(len(names)):
+ n = len(names) - i
+ path2 = "/".join(names[0:n])
+ if path2 == "":
+ return ""
+ if path2 in hdf5Object:
+ return path2
+
+ if path[0] == "/":
+ return "/"
+ return None
+
+
+def patchToConsumeReturnKey(widget):
+ """
+ Monkey-patch a widget to consume the return key instead of propagating it
+ to the dialog.
+ """
+ assert(not hasattr(widget, "_oldKeyPressEvent"))
+
+ def keyPressEvent(self, event):
+ k = event.key()
+ result = self._oldKeyPressEvent(event)
+ if k in [qt.Qt.Key_Return, qt.Qt.Key_Enter]:
+ event.accept()
+ return result
+
+ widget._oldKeyPressEvent = widget.keyPressEvent
+ widget.keyPressEvent = types.MethodType(keyPressEvent, widget)
diff --git a/src/silx/gui/fit/BackgroundWidget.py b/src/silx/gui/fit/BackgroundWidget.py
new file mode 100644
index 0000000..7703ee1
--- /dev/null
+++ b/src/silx/gui/fit/BackgroundWidget.py
@@ -0,0 +1,534 @@
+# coding: utf-8
+#/*##########################################################################
+# Copyright (C) 2004-2021 V.A. Sole, European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# #########################################################################*/
+"""This module provides a background configuration widget
+:class:`BackgroundWidget` and a corresponding dialog window
+:class:`BackgroundDialog`.
+
+.. image:: img/BackgroundDialog.png
+ :height: 300px
+"""
+import sys
+import numpy
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+from silx.math.fit import filters
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/06/2017"
+
+
+class HorizontalSpacer(qt.QWidget):
+ def __init__(self, *args):
+ qt.QWidget.__init__(self, *args)
+ self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Fixed))
+
+
+class BackgroundParamWidget(qt.QWidget):
+ """Background configuration composite widget.
+
+ Strip and snip filters parameters can be adjusted using input widgets.
+
+ Updating the widgets causes :attr:`sigBackgroundParamWidgetSignal` to
+ be emitted.
+ """
+ sigBackgroundParamWidgetSignal = qt.pyqtSignal(object)
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.mainLayout = qt.QGridLayout(self)
+ self.mainLayout.setColumnStretch(1, 1)
+
+ # Algorithm choice ---------------------------------------------------
+ self.algorithmComboLabel = qt.QLabel(self)
+ self.algorithmComboLabel.setText("Background algorithm")
+ self.algorithmCombo = qt.QComboBox(self)
+ self.algorithmCombo.addItem("Strip")
+ self.algorithmCombo.addItem("Snip")
+ self.algorithmCombo.activated[int].connect(
+ self._algorithmComboActivated)
+
+ # Strip parameters ---------------------------------------------------
+ self.stripWidthLabel = qt.QLabel(self)
+ self.stripWidthLabel.setText("Strip Width")
+
+ self.stripWidthSpin = qt.QSpinBox(self)
+ self.stripWidthSpin.setMaximum(100)
+ self.stripWidthSpin.setMinimum(1)
+ self.stripWidthSpin.valueChanged[int].connect(self._emitSignal)
+
+ self.stripIterLabel = qt.QLabel(self)
+ self.stripIterLabel.setText("Strip Iterations")
+ self.stripIterValue = qt.QLineEdit(self)
+ validator = qt.QIntValidator(self.stripIterValue)
+ self.stripIterValue._v = validator
+ self.stripIterValue.setText("0")
+ self.stripIterValue.editingFinished[()].connect(self._emitSignal)
+ self.stripIterValue.setToolTip(
+ "Number of iterations for strip algorithm.\n" +
+ "If greater than 999, an 2nd pass of strip filter is " +
+ "applied to remove artifacts created by first pass.")
+
+ # Snip parameters ----------------------------------------------------
+ self.snipWidthLabel = qt.QLabel(self)
+ self.snipWidthLabel.setText("Snip Width")
+
+ self.snipWidthSpin = qt.QSpinBox(self)
+ self.snipWidthSpin.setMaximum(300)
+ self.snipWidthSpin.setMinimum(0)
+ self.snipWidthSpin.valueChanged[int].connect(self._emitSignal)
+
+
+ # Smoothing parameters -----------------------------------------------
+ self.smoothingFlagCheck = qt.QCheckBox(self)
+ self.smoothingFlagCheck.setText("Smoothing Width (Savitsky-Golay)")
+ self.smoothingFlagCheck.toggled.connect(self._smoothingToggled)
+
+ self.smoothingSpin = qt.QSpinBox(self)
+ self.smoothingSpin.setMinimum(3)
+ #self.smoothingSpin.setMaximum(40)
+ self.smoothingSpin.setSingleStep(2)
+ self.smoothingSpin.valueChanged[int].connect(self._emitSignal)
+
+ # Anchors ------------------------------------------------------------
+
+ self.anchorsGroup = qt.QWidget(self)
+ anchorsLayout = qt.QHBoxLayout(self.anchorsGroup)
+ anchorsLayout.setSpacing(2)
+ anchorsLayout.setContentsMargins(0, 0, 0, 0)
+
+ self.anchorsFlagCheck = qt.QCheckBox(self.anchorsGroup)
+ self.anchorsFlagCheck.setText("Use anchors")
+ self.anchorsFlagCheck.setToolTip(
+ "Define X coordinates of points that must remain fixed")
+ self.anchorsFlagCheck.stateChanged[int].connect(
+ self._anchorsToggled)
+ anchorsLayout.addWidget(self.anchorsFlagCheck)
+
+ maxnchannel = 16384 * 4 # Fixme ?
+ self.anchorsList = []
+ num_anchors = 4
+ for i in range(num_anchors):
+ anchorSpin = qt.QSpinBox(self.anchorsGroup)
+ anchorSpin.setMinimum(0)
+ anchorSpin.setMaximum(maxnchannel)
+ anchorSpin.valueChanged[int].connect(self._emitSignal)
+ anchorsLayout.addWidget(anchorSpin)
+ self.anchorsList.append(anchorSpin)
+
+ # Layout ------------------------------------------------------------
+ self.mainLayout.addWidget(self.algorithmComboLabel, 0, 0)
+ self.mainLayout.addWidget(self.algorithmCombo, 0, 2)
+ self.mainLayout.addWidget(self.stripWidthLabel, 1, 0)
+ self.mainLayout.addWidget(self.stripWidthSpin, 1, 2)
+ self.mainLayout.addWidget(self.stripIterLabel, 2, 0)
+ self.mainLayout.addWidget(self.stripIterValue, 2, 2)
+ self.mainLayout.addWidget(self.snipWidthLabel, 3, 0)
+ self.mainLayout.addWidget(self.snipWidthSpin, 3, 2)
+ self.mainLayout.addWidget(self.smoothingFlagCheck, 4, 0)
+ self.mainLayout.addWidget(self.smoothingSpin, 4, 2)
+ self.mainLayout.addWidget(self.anchorsGroup, 5, 0, 1, 4)
+
+ # Initialize interface -----------------------------------------------
+ self._setAlgorithm("strip")
+ self.smoothingFlagCheck.setChecked(False)
+ self._smoothingToggled(is_checked=False)
+ self.anchorsFlagCheck.setChecked(False)
+ self._anchorsToggled(is_checked=False)
+
+ def _algorithmComboActivated(self, algorithm_index):
+ self._setAlgorithm("strip" if algorithm_index == 0 else "snip")
+
+ def _setAlgorithm(self, algorithm):
+ """Enable/disable snip and snip input widgets, depending on the
+ chosen algorithm.
+ :param algorithm: "snip" or "strip"
+ """
+ if algorithm not in ["strip", "snip"]:
+ raise ValueError(
+ "Unknown background filter algorithm %s" % algorithm)
+
+ self.algorithm = algorithm
+ self.stripWidthSpin.setEnabled(algorithm == "strip")
+ self.stripIterValue.setEnabled(algorithm == "strip")
+ self.snipWidthSpin.setEnabled(algorithm == "snip")
+
+ def _smoothingToggled(self, is_checked):
+ """Enable/disable smoothing input widgets, emit dictionary"""
+ self.smoothingSpin.setEnabled(is_checked)
+ self._emitSignal()
+
+ def _anchorsToggled(self, is_checked):
+ """Enable/disable all spin widgets defining anchor X coordinates,
+ emit signal.
+ """
+ for anchor_spin in self.anchorsList:
+ anchor_spin.setEnabled(is_checked)
+ self._emitSignal()
+
+ def setParameters(self, ddict):
+ """Set values for all input widgets.
+
+ :param dict ddict: Input dictionary, must have the same
+ keys as the dictionary output by :meth:`getParameters`
+ """
+ if "algorithm" in ddict:
+ self._setAlgorithm(ddict["algorithm"])
+
+ if "SnipWidth" in ddict:
+ self.snipWidthSpin.setValue(int(ddict["SnipWidth"]))
+
+ if "StripWidth" in ddict:
+ self.stripWidthSpin.setValue(int(ddict["StripWidth"]))
+
+ if "StripIterations" in ddict:
+ self.stripIterValue.setText("%d" % int(ddict["StripIterations"]))
+
+ if "SmoothingFlag" in ddict:
+ self.smoothingFlagCheck.setChecked(bool(ddict["SmoothingFlag"]))
+
+ if "SmoothingWidth" in ddict:
+ self.smoothingSpin.setValue(int(ddict["SmoothingWidth"]))
+
+ if "AnchorsFlag" in ddict:
+ self.anchorsFlagCheck.setChecked(bool(ddict["AnchorsFlag"]))
+
+ if "AnchorsList" in ddict:
+ anchorslist = ddict["AnchorsList"]
+ if anchorslist in [None, 'None']:
+ anchorslist = []
+ for spin in self.anchorsList:
+ spin.setValue(0)
+
+ i = 0
+ for value in anchorslist:
+ self.anchorsList[i].setValue(int(value))
+ i += 1
+
+ def getParameters(self):
+ """Return dictionary of parameters defined in the GUI
+
+ The returned dictionary contains following values:
+
+ - *algorithm*: *"strip"* or *"snip"*
+ - *StripWidth*: width of strip iterator
+ - *StripIterations*: number of iterations
+ - *StripThreshold*: curvature parameter (currently fixed to 1.0)
+ - *SnipWidth*: width of snip algorithm
+ - *SmoothingFlag*: flag to enable/disable smoothing
+ - *SmoothingWidth*: width of Savitsky-Golay smoothing filter
+ - *AnchorsFlag*: flag to enable/disable anchors
+ - *AnchorsList*: list of anchors (X coordinates of fixed values)
+ """
+ stripitertext = self.stripIterValue.text()
+ stripiter = int(stripitertext) if len(stripitertext) else 0
+
+ return {"algorithm": self.algorithm,
+ "StripThreshold": 1.0,
+ "SnipWidth": self.snipWidthSpin.value(),
+ "StripIterations": stripiter,
+ "StripWidth": self.stripWidthSpin.value(),
+ "SmoothingFlag": self.smoothingFlagCheck.isChecked(),
+ "SmoothingWidth": self.smoothingSpin.value(),
+ "AnchorsFlag": self.anchorsFlagCheck.isChecked(),
+ "AnchorsList": [spin.value() for spin in self.anchorsList]}
+
+ def _emitSignal(self, dummy=None):
+ self.sigBackgroundParamWidgetSignal.emit(
+ {'event': 'ParametersChanged',
+ 'parameters': self.getParameters()})
+
+
+class BackgroundWidget(qt.QWidget):
+ """Background configuration widget, with a plot to preview the results.
+
+ Strip and snip filters parameters can be adjusted using input widgets,
+ and the computed backgrounds are plotted next to the original data to
+ show the result."""
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+ self.setWindowTitle("Strip and SNIP Configuration Window")
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+ self.parametersWidget = BackgroundParamWidget(self)
+ self.graphWidget = PlotWidget(parent=self)
+ self.mainLayout.addWidget(self.parametersWidget)
+ self.mainLayout.addWidget(self.graphWidget)
+ self._x = None
+ self._y = None
+ self.parametersWidget.sigBackgroundParamWidgetSignal.connect(self._slot)
+
+ def getParameters(self):
+ """Return dictionary of parameters defined in the GUI
+
+ The returned dictionary contains following values:
+
+ - *algorithm*: *"strip"* or *"snip"*
+ - *StripWidth*: width of strip iterator
+ - *StripIterations*: number of iterations
+ - *StripThreshold*: strip curvature (currently fixed to 1.0)
+ - *SnipWidth*: width of snip algorithm
+ - *SmoothingFlag*: flag to enable/disable smoothing
+ - *SmoothingWidth*: width of Savitsky-Golay smoothing filter
+ - *AnchorsFlag*: flag to enable/disable anchors
+ - *AnchorsList*: list of anchors (X coordinates of fixed values)
+ """
+ return self.parametersWidget.getParameters()
+
+ def setParameters(self, ddict):
+ """Set values for all input widgets.
+
+ :param dict ddict: Input dictionary, must have the same
+ keys as the dictionary output by :meth:`getParameters`
+ """
+ return self.parametersWidget.setParameters(ddict)
+
+ def setData(self, x, y, xmin=None, xmax=None):
+ """Set data for the original curve, and _update strip and snip
+ curves accordingly.
+
+ :param x: Array or sequence of curve abscissa values
+ :param y: Array or sequence of curve ordinate values
+ :param xmin: Min value to be displayed on the X axis
+ :param xmax: Max value to be displayed on the X axis
+ """
+ self._x = x
+ self._y = y
+ self._xmin = xmin
+ self._xmax = xmax
+ self._update(resetzoom=True)
+
+ def _slot(self, ddict):
+ self._update()
+
+ def _update(self, resetzoom=False):
+ """Compute strip and snip backgrounds, update the curves
+ """
+ if self._y is None:
+ return
+
+ pars = self.getParameters()
+
+ # smoothed data
+ y = numpy.ravel(numpy.array(self._y)).astype(numpy.float64)
+ if pars["SmoothingFlag"]:
+ ysmooth = filters.savitsky_golay(y, pars['SmoothingWidth'])
+ f = [0.25, 0.5, 0.25]
+ ysmooth[1:-1] = numpy.convolve(ysmooth, f, mode=0)
+ ysmooth[0] = 0.5 * (ysmooth[0] + ysmooth[1])
+ ysmooth[-1] = 0.5 * (ysmooth[-1] + ysmooth[-2])
+ else:
+ ysmooth = y
+
+
+ # loop for anchors
+ x = self._x
+ niter = pars['StripIterations']
+ anchors_indices = []
+ if pars['AnchorsFlag'] and pars['AnchorsList'] is not None:
+ ravelled = x
+ for channel in pars['AnchorsList']:
+ if channel <= ravelled[0]:
+ continue
+ index = numpy.nonzero(ravelled >= channel)[0]
+ if len(index):
+ index = min(index)
+ if index > 0:
+ anchors_indices.append(index)
+
+ stripBackground = filters.strip(ysmooth,
+ w=pars['StripWidth'],
+ niterations=niter,
+ factor=pars['StripThreshold'],
+ anchors=anchors_indices)
+
+ if niter >= 1000:
+ # final smoothing
+ stripBackground = filters.strip(stripBackground,
+ w=1,
+ niterations=50*pars['StripWidth'],
+ factor=pars['StripThreshold'],
+ anchors=anchors_indices)
+
+ if len(anchors_indices) == 0:
+ anchors_indices = [0, len(ysmooth)-1]
+ anchors_indices.sort()
+ snipBackground = 0.0 * ysmooth
+ lastAnchor = 0
+ for anchor in anchors_indices:
+ if (anchor > lastAnchor) and (anchor < len(ysmooth)):
+ snipBackground[lastAnchor:anchor] =\
+ filters.snip1d(ysmooth[lastAnchor:anchor],
+ pars['SnipWidth'])
+ lastAnchor = anchor
+ if lastAnchor < len(ysmooth):
+ snipBackground[lastAnchor:] =\
+ filters.snip1d(ysmooth[lastAnchor:],
+ pars['SnipWidth'])
+
+ self.graphWidget.addCurve(x, y,
+ legend='Input Data',
+ replace=True,
+ resetzoom=resetzoom)
+ self.graphWidget.addCurve(x, stripBackground,
+ legend='Strip Background',
+ resetzoom=False)
+ self.graphWidget.addCurve(x, snipBackground,
+ legend='SNIP Background',
+ resetzoom=False)
+ if self._xmin is not None and self._xmax is not None:
+ self.graphWidget.getXAxis().setLimits(self._xmin, self._xmax)
+
+
+class BackgroundDialog(qt.QDialog):
+ """QDialog window featuring a :class:`BackgroundWidget`"""
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle("Strip and Snip Configuration Window")
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+ self.parametersWidget = BackgroundWidget(self)
+ self.mainLayout.addWidget(self.parametersWidget)
+ hbox = qt.QWidget(self)
+ hboxLayout = qt.QHBoxLayout(hbox)
+ hboxLayout.setContentsMargins(0, 0, 0, 0)
+ hboxLayout.setSpacing(2)
+ self.okButton = qt.QPushButton(hbox)
+ self.okButton.setText("OK")
+ self.okButton.setAutoDefault(False)
+ self.dismissButton = qt.QPushButton(hbox)
+ self.dismissButton.setText("Cancel")
+ self.dismissButton.setAutoDefault(False)
+ hboxLayout.addWidget(HorizontalSpacer(hbox))
+ hboxLayout.addWidget(self.okButton)
+ hboxLayout.addWidget(self.dismissButton)
+ self.mainLayout.addWidget(hbox)
+ self.dismissButton.clicked.connect(self.reject)
+ self.okButton.clicked.connect(self.accept)
+
+ self.output = {}
+ """Configuration dictionary containing following fields:
+
+ - *SmoothingFlag*
+ - *SmoothingWidth*
+ - *StripWidth*
+ - *StripIterations*
+ - *StripThreshold*
+ - *SnipWidth*
+ - *AnchorsFlag*
+ - *AnchorsList*
+ """
+
+ # self.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(self.updateOutput)
+
+ # def updateOutput(self, ddict):
+ # self.output = ddict
+
+ def accept(self):
+ """Update :attr:`output`, then call :meth:`QDialog.accept`
+ """
+ self.output = self.getParameters()
+ super(BackgroundDialog, self).accept()
+
+ def sizeHint(self):
+ return qt.QSize(int(1.5*qt.QDialog.sizeHint(self).width()),
+ qt.QDialog.sizeHint(self).height())
+
+ def setData(self, x, y, xmin=None, xmax=None):
+ """See :meth:`BackgroundWidget.setData`"""
+ return self.parametersWidget.setData(x, y, xmin, xmax)
+
+ def getParameters(self):
+ """See :meth:`BackgroundWidget.getParameters`"""
+ return self.parametersWidget.getParameters()
+
+ def setParameters(self, ddict):
+ """See :meth:`BackgroundWidget.setPrintGeometry`"""
+ return self.parametersWidget.setParameters(ddict)
+
+ def setDefault(self, ddict):
+ """Alias for :meth:`setPrintGeometry`"""
+ return self.setParameters(ddict)
+
+
+def getBgDialog(parent=None, default=None, modal=True):
+ """Instantiate and return a bg configuration dialog, adapted
+ for configuring standard background theories from
+ :mod:`silx.math.fit.bgtheories`.
+
+ :return: Instance of :class:`BackgroundDialog`
+ """
+ bgd = BackgroundDialog(parent=parent)
+ # apply default to newly added pages
+ bgd.setParameters(default)
+
+ return bgd
+
+
+def main():
+ # synthetic data
+ from silx.math.fit.functions import sum_gauss
+
+ x = numpy.arange(5000)
+ # (height1, center1, fwhm1, ...) 5 peaks
+ params1 = (50, 500, 100,
+ 20, 2000, 200,
+ 50, 2250, 100,
+ 40, 3000, 75,
+ 23, 4000, 150)
+ y0 = sum_gauss(x, *params1)
+
+ # random values between [-1;1]
+ noise = 2 * numpy.random.random(5000) - 1
+ # make it +- 5%
+ noise *= 0.05
+
+ # 2 gaussians with very large fwhm, as background signal
+ actual_bg = sum_gauss(x, 15, 3500, 3000, 5, 1000, 1500)
+
+ # Add 5% random noise to gaussians and add background
+ y = y0 + numpy.average(y0) * noise + actual_bg
+
+ # Open widget
+ a = qt.QApplication(sys.argv)
+ a.lastWindowClosed.connect(a.quit)
+
+ def mySlot(ddict):
+ print(ddict)
+
+ w = BackgroundDialog()
+ w.parametersWidget.parametersWidget.sigBackgroundParamWidgetSignal.connect(mySlot)
+ w.setData(x, y)
+ w.exec()
+ #a.exec()
+
+if __name__ == "__main__":
+ main()
diff --git a/src/silx/gui/fit/FitConfig.py b/src/silx/gui/fit/FitConfig.py
new file mode 100644
index 0000000..48ebca2
--- /dev/null
+++ b/src/silx/gui/fit/FitConfig.py
@@ -0,0 +1,543 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2004-2021 V.A. Sole, European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""This module defines widgets used to build a fit configuration dialog.
+The resulting dialog widget outputs a dictionary of configuration parameters.
+"""
+from silx.gui import qt
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+
+class TabsDialog(qt.QDialog):
+ """Dialog widget containing a QTabWidget :attr:`tabWidget`
+ and a buttons:
+
+ # - buttonHelp
+ - buttonDefaults
+ - buttonOk
+ - buttonCancel
+
+ This dialog defines a __len__ returning the number of tabs,
+ and an __iter__ method yielding the tab widgets.
+ """
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.tabWidget = qt.QTabWidget(self)
+
+ layout = qt.QVBoxLayout(self)
+ layout.addWidget(self.tabWidget)
+
+ layout2 = qt.QHBoxLayout(None)
+
+ # self.buttonHelp = qt.QPushButton(self)
+ # self.buttonHelp.setText("Help")
+ # layout2.addWidget(self.buttonHelp)
+
+ self.buttonDefault = qt.QPushButton(self)
+ self.buttonDefault.setText("Undo changes")
+ layout2.addWidget(self.buttonDefault)
+
+ spacer = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ layout2.addItem(spacer)
+
+ self.buttonOk = qt.QPushButton(self)
+ self.buttonOk.setText("OK")
+ layout2.addWidget(self.buttonOk)
+
+ self.buttonCancel = qt.QPushButton(self)
+ self.buttonCancel.setText("Cancel")
+ layout2.addWidget(self.buttonCancel)
+
+ layout.addLayout(layout2)
+
+ self.buttonOk.clicked.connect(self.accept)
+ self.buttonCancel.clicked.connect(self.reject)
+
+ def __len__(self):
+ """Return number of tabs"""
+ return self.tabWidget.count()
+
+ def __iter__(self):
+ """Return the next tab widget in :attr:`tabWidget` every
+ time this method is called.
+
+ :return: Tab widget
+ :rtype: QWidget
+ """
+ for widget_index in range(len(self)):
+ yield self.tabWidget.widget(widget_index)
+
+ def addTab(self, page, label):
+ """Add a new tab
+
+ :param page: Content of new page. Must be a widget with
+ a get() method returning a dictionary.
+ :param str label: Tab label
+ """
+ self.tabWidget.addTab(page, label)
+
+ def getTabLabels(self):
+ """
+ Return a list of all tab labels in :attr:`tabWidget`
+ """
+ return [self.tabWidget.tabText(i) for i in range(len(self))]
+
+
+class TabsDialogData(TabsDialog):
+ """This dialog adds a data attribute to :class:`TabsDialog`.
+
+ Data input in widgets, such as text entries or checkboxes, is stored in an
+ attribute :attr:`output` when the user clicks the OK button.
+
+ A default dictionary can be supplied when this dialog is initialized, to
+ be used as default data for :attr:`output`.
+ """
+ def __init__(self, parent=None, modal=True, default=None):
+ """
+
+ :param parent: Parent :class:`QWidget`
+ :param modal: If `True`, dialog is modal, meaning this dialog remains
+ in front of it's parent window and disables it until the user is
+ done interacting with the dialog
+ :param default: Default dictionary, used to initialize and reset
+ :attr:`output`.
+ """
+ TabsDialog.__init__(self, parent)
+ self.setModal(modal)
+ self.setWindowTitle("Fit configuration")
+
+ self.output = {}
+
+ self.default = {} if default is None else default
+
+ self.buttonDefault.clicked.connect(self._resetDefault)
+ # self.keyPressEvent(qt.Qt.Key_Enter).
+
+ def keyPressEvent(self, event):
+ """Redefining this method to ignore Enter key
+ (for some reason it activates buttonDefault callback which
+ resets all widgets)
+ """
+ if event.key() in [qt.Qt.Key_Enter, qt.Qt.Key_Return]:
+ return
+ TabsDialog.keyPressEvent(self, event)
+
+ def accept(self):
+ """When *OK* is clicked, update :attr:`output` with data from
+ various widgets
+ """
+ self.output.update(self.default)
+
+ # loop over all tab widgets (uses TabsDialog.__iter__)
+ for tabWidget in self:
+ self.output.update(tabWidget.get())
+
+ # avoid pathological None cases
+ for key in self.output.keys():
+ if self.output[key] is None:
+ if key in self.default:
+ self.output[key] = self.default[key]
+ super(TabsDialogData, self).accept()
+
+ def reject(self):
+ """When the *Cancel* button is clicked, reinitialize :attr:`output`
+ and quit
+ """
+ self.setDefault()
+ super(TabsDialogData, self).reject()
+
+ def _resetDefault(self, checked):
+ self.setDefault()
+
+ def setDefault(self, newdefault=None):
+ """Reinitialize :attr:`output` with :attr:`default` or with
+ new dictionary ``newdefault`` if provided.
+ Call :meth:`setDefault` for each tab widget, if available.
+ """
+ self.output = {}
+ if newdefault is None:
+ newdefault = self.default
+ else:
+ self.default = newdefault
+ self.output.update(newdefault)
+
+ for tabWidget in self:
+ if hasattr(tabWidget, "setDefault"):
+ tabWidget.setDefault(self.output)
+
+
+class ConstraintsPage(qt.QGroupBox):
+ """Checkable QGroupBox widget filled with QCheckBox widgets,
+ to configure the fit estimation for standard fit theories.
+ """
+ def __init__(self, parent=None, title="Set constraints"):
+ super(ConstraintsPage, self).__init__(parent)
+ self.setTitle(title)
+ self.setToolTip("Disable 'Set constraints' to remove all " +
+ "constraints on all fit parameters")
+ self.setCheckable(True)
+
+ layout = qt.QVBoxLayout(self)
+ self.setLayout(layout)
+
+ self.positiveHeightCB = qt.QCheckBox("Force positive height/area", self)
+ self.positiveHeightCB.setToolTip("Fit must find positive peaks")
+ layout.addWidget(self.positiveHeightCB)
+
+ self.positionInIntervalCB = qt.QCheckBox("Force position in interval", self)
+ self.positionInIntervalCB.setToolTip(
+ "Fit must position peak within X limits")
+ layout.addWidget(self.positionInIntervalCB)
+
+ self.positiveFwhmCB = qt.QCheckBox("Force positive FWHM", self)
+ self.positiveFwhmCB.setToolTip("Fit must find a positive FWHM")
+ layout.addWidget(self.positiveFwhmCB)
+
+ self.sameFwhmCB = qt.QCheckBox("Force same FWHM for all peaks", self)
+ self.sameFwhmCB.setToolTip("Fit must find same FWHM for all peaks")
+ layout.addWidget(self.sameFwhmCB)
+
+ self.quotedEtaCB = qt.QCheckBox("Force Eta between 0 and 1", self)
+ self.quotedEtaCB.setToolTip(
+ "Fit must find Eta between 0 and 1 for pseudo-Voigt function")
+ layout.addWidget(self.quotedEtaCB)
+
+ layout.addStretch()
+
+ self.setDefault()
+
+ def setDefault(self, default_dict=None):
+ """Set default state for all widgets.
+
+ :param default_dict: If a default config dictionary is provided as
+ a parameter, its values are used as default state."""
+ if default_dict is None:
+ default_dict = {}
+ # this one uses reverse logic: if checked, NoConstraintsFlag must be False
+ self.setChecked(
+ not default_dict.get('NoConstraintsFlag', False))
+ self.positiveHeightCB.setChecked(
+ default_dict.get('PositiveHeightAreaFlag', True))
+ self.positionInIntervalCB.setChecked(
+ default_dict.get('QuotedPositionFlag', False))
+ self.positiveFwhmCB.setChecked(
+ default_dict.get('PositiveFwhmFlag', True))
+ self.sameFwhmCB.setChecked(
+ default_dict.get('SameFwhmFlag', False))
+ self.quotedEtaCB.setChecked(
+ default_dict.get('QuotedEtaFlag', False))
+
+ def get(self):
+ """Return a dictionary of constraint flags, to be processed by the
+ :meth:`configure` method of the selected fit theory."""
+ ddict = {
+ 'NoConstraintsFlag': not self.isChecked(),
+ 'PositiveHeightAreaFlag': self.positiveHeightCB.isChecked(),
+ 'QuotedPositionFlag': self.positionInIntervalCB.isChecked(),
+ 'PositiveFwhmFlag': self.positiveFwhmCB.isChecked(),
+ 'SameFwhmFlag': self.sameFwhmCB.isChecked(),
+ 'QuotedEtaFlag': self.quotedEtaCB.isChecked(),
+ }
+ return ddict
+
+
+class SearchPage(qt.QWidget):
+ def __init__(self, parent=None):
+ super(SearchPage, self).__init__(parent)
+ layout = qt.QVBoxLayout(self)
+
+ self.manualFwhmGB = qt.QGroupBox("Define FWHM manually", self)
+ self.manualFwhmGB.setCheckable(True)
+ self.manualFwhmGB.setToolTip(
+ "If disabled, the FWHM parameter used for peak search is " +
+ "estimated based on the highest peak in the data")
+ layout.addWidget(self.manualFwhmGB)
+ # ------------ GroupBox fwhm--------------------------
+ layout2 = qt.QHBoxLayout(self.manualFwhmGB)
+ self.manualFwhmGB.setLayout(layout2)
+
+ label = qt.QLabel("Fwhm Points", self.manualFwhmGB)
+ layout2.addWidget(label)
+
+ self.fwhmPointsSpin = qt.QSpinBox(self.manualFwhmGB)
+ self.fwhmPointsSpin.setRange(0, 999999)
+ self.fwhmPointsSpin.setToolTip("Typical peak fwhm (number of data points)")
+ layout2.addWidget(self.fwhmPointsSpin)
+ # ----------------------------------------------------
+
+ self.manualScalingGB = qt.QGroupBox("Define scaling manually", self)
+ self.manualScalingGB.setCheckable(True)
+ self.manualScalingGB.setToolTip(
+ "If disabled, the Y scaling used for peak search is " +
+ "estimated automatically")
+ layout.addWidget(self.manualScalingGB)
+ # ------------ GroupBox scaling-----------------------
+ layout3 = qt.QHBoxLayout(self.manualScalingGB)
+ self.manualScalingGB.setLayout(layout3)
+
+ label = qt.QLabel("Y Scaling", self.manualScalingGB)
+ layout3.addWidget(label)
+
+ self.yScalingEntry = qt.QLineEdit(self.manualScalingGB)
+ self.yScalingEntry.setToolTip(
+ "Data values will be multiplied by this value prior to peak" +
+ " search")
+ self.yScalingEntry.setValidator(qt.QDoubleValidator(self))
+ layout3.addWidget(self.yScalingEntry)
+ # ----------------------------------------------------
+
+ # ------------------- grid layout --------------------
+ containerWidget = qt.QWidget(self)
+ layout4 = qt.QHBoxLayout(containerWidget)
+ containerWidget.setLayout(layout4)
+
+ label = qt.QLabel("Sensitivity", containerWidget)
+ layout4.addWidget(label)
+
+ self.sensitivityEntry = qt.QLineEdit(containerWidget)
+ self.sensitivityEntry.setToolTip(
+ "Peak search sensitivity threshold, expressed as a multiple " +
+ "of the standard deviation of the noise.\nMinimum value is 1 " +
+ "(to be detected, peak must be higher than the estimated noise)")
+ sensivalidator = qt.QDoubleValidator(self)
+ sensivalidator.setBottom(1.0)
+ self.sensitivityEntry.setValidator(sensivalidator)
+ layout4.addWidget(self.sensitivityEntry)
+ # ----------------------------------------------------
+ layout.addWidget(containerWidget)
+
+ self.forcePeakPresenceCB = qt.QCheckBox("Force peak presence", self)
+ self.forcePeakPresenceCB.setToolTip(
+ "If peak search algorithm is unsuccessful, place one peak " +
+ "at the maximum of the curve")
+ layout.addWidget(self.forcePeakPresenceCB)
+
+ layout.addStretch()
+
+ self.setDefault()
+
+ def setDefault(self, default_dict=None):
+ """Set default values for all widgets.
+
+ :param default_dict: If a default config dictionary is provided as
+ a parameter, its values are used as default values."""
+ if default_dict is None:
+ default_dict = {}
+ self.manualFwhmGB.setChecked(
+ not default_dict.get('AutoFwhm', True))
+ self.fwhmPointsSpin.setValue(
+ default_dict.get('FwhmPoints', 8))
+ self.sensitivityEntry.setText(
+ str(default_dict.get('Sensitivity', 1.0)))
+ self.manualScalingGB.setChecked(
+ not default_dict.get('AutoScaling', False))
+ self.yScalingEntry.setText(
+ str(default_dict.get('Yscaling', 1.0)))
+ self.forcePeakPresenceCB.setChecked(
+ default_dict.get('ForcePeakPresence', False))
+
+ def get(self):
+ """Return a dictionary of peak search parameters, to be processed by
+ the :meth:`configure` method of the selected fit theory."""
+ ddict = {
+ 'AutoFwhm': not self.manualFwhmGB.isChecked(),
+ 'FwhmPoints': self.fwhmPointsSpin.value(),
+ 'Sensitivity': safe_float(self.sensitivityEntry.text()),
+ 'AutoScaling': not self.manualScalingGB.isChecked(),
+ 'Yscaling': safe_float(self.yScalingEntry.text()),
+ 'ForcePeakPresence': self.forcePeakPresenceCB.isChecked()
+ }
+ return ddict
+
+
+class BackgroundPage(qt.QGroupBox):
+ """Background subtraction configuration, specific to fittheories
+ estimation functions."""
+ def __init__(self, parent=None,
+ title="Subtract strip background prior to estimation"):
+ super(BackgroundPage, self).__init__(parent)
+ self.setTitle(title)
+ self.setCheckable(True)
+ self.setToolTip(
+ "The strip algorithm strips away peaks to compute the " +
+ "background signal.\nAt each iteration, a sample is compared " +
+ "to the average of the two samples at a given distance in both" +
+ " directions,\n and if its value is higher than the average,"
+ "it is replaced by the average.")
+
+ layout = qt.QGridLayout(self)
+ self.setLayout(layout)
+
+ for i, label_text in enumerate(
+ ["Strip width (in samples)",
+ "Number of iterations",
+ "Strip threshold factor"]):
+ label = qt.QLabel(label_text)
+ layout.addWidget(label, i, 0)
+
+ self.stripWidthSpin = qt.QSpinBox(self)
+ self.stripWidthSpin.setToolTip(
+ "Width, in number of samples, of the strip operator")
+ self.stripWidthSpin.setRange(1, 999999)
+
+ layout.addWidget(self.stripWidthSpin, 0, 1)
+
+ self.numIterationsSpin = qt.QSpinBox(self)
+ self.numIterationsSpin.setToolTip(
+ "Number of iterations of the strip algorithm")
+ self.numIterationsSpin.setRange(1, 999999)
+ layout.addWidget(self.numIterationsSpin, 1, 1)
+
+ self.thresholdFactorEntry = qt.QLineEdit(self)
+ self.thresholdFactorEntry.setToolTip(
+ "Factor used by the strip algorithm to decide whether a sample" +
+ "value should be stripped.\nThe value must be higher than the " +
+ "average of the 2 samples at +- w times this factor.\n")
+ self.thresholdFactorEntry.setValidator(qt.QDoubleValidator(self))
+ layout.addWidget(self.thresholdFactorEntry, 2, 1)
+
+ self.smoothStripGB = qt.QGroupBox("Apply smoothing prior to strip", self)
+ self.smoothStripGB.setCheckable(True)
+ self.smoothStripGB.setToolTip(
+ "Apply a smoothing before subtracting strip background" +
+ " in fit and estimate processes")
+ smoothlayout = qt.QHBoxLayout(self.smoothStripGB)
+ label = qt.QLabel("Smoothing width (Savitsky-Golay)")
+ smoothlayout.addWidget(label)
+ self.smoothingWidthSpin = qt.QSpinBox(self)
+ self.smoothingWidthSpin.setToolTip(
+ "Width parameter for Savitsky-Golay smoothing (number of samples, must be odd)")
+ self.smoothingWidthSpin.setRange(3, 101)
+ self.smoothingWidthSpin.setSingleStep(2)
+ smoothlayout.addWidget(self.smoothingWidthSpin)
+
+ layout.addWidget(self.smoothStripGB, 3, 0, 1, 2)
+
+ layout.setRowStretch(4, 1)
+
+ self.setDefault()
+
+ def setDefault(self, default_dict=None):
+ """Set default values for all widgets.
+
+ :param default_dict: If a default config dictionary is provided as
+ a parameter, its values are used as default values."""
+ if default_dict is None:
+ default_dict = {}
+
+ self.setChecked(
+ default_dict.get('StripBackgroundFlag', True))
+
+ self.stripWidthSpin.setValue(
+ default_dict.get('StripWidth', 2))
+ self.numIterationsSpin.setValue(
+ default_dict.get('StripIterations', 5000))
+ self.thresholdFactorEntry.setText(
+ str(default_dict.get('StripThreshold', 1.0)))
+ self.smoothStripGB.setChecked(
+ default_dict.get('SmoothingFlag', False))
+ self.smoothingWidthSpin.setValue(
+ default_dict.get('SmoothingWidth', 3))
+
+ def get(self):
+ """Return a dictionary of background subtraction parameters, to be
+ processed by the :meth:`configure` method of the selected fit theory.
+ """
+ ddict = {
+ 'StripBackgroundFlag': self.isChecked(),
+ 'StripWidth': self.stripWidthSpin.value(),
+ 'StripIterations': self.numIterationsSpin.value(),
+ 'StripThreshold': safe_float(self.thresholdFactorEntry.text()),
+ 'SmoothingFlag': self.smoothStripGB.isChecked(),
+ 'SmoothingWidth': self.smoothingWidthSpin.value()
+ }
+ return ddict
+
+
+def safe_float(string_, default=1.0):
+ """Convert a string into a float.
+ If the conversion fails, return the default value.
+ """
+ try:
+ ret = float(string_)
+ except ValueError:
+ return default
+ else:
+ return ret
+
+
+def safe_int(string_, default=1):
+ """Convert a string into a integer.
+ If the conversion fails, return the default value.
+ """
+ try:
+ ret = int(float(string_))
+ except ValueError:
+ return default
+ else:
+ return ret
+
+
+def getFitConfigDialog(parent=None, default=None, modal=True):
+ """Instantiate and return a fit configuration dialog, adapted
+ for configuring standard fit theories from
+ :mod:`silx.math.fit.fittheories`.
+
+ :return: Instance of :class:`TabsDialogData` with 3 tabs:
+ :class:`ConstraintsPage`, :class:`SearchPage` and
+ :class:`BackgroundPage`
+ """
+ tdd = TabsDialogData(parent=parent, default=default)
+ tdd.addTab(ConstraintsPage(), label="Constraints")
+ tdd.addTab(SearchPage(), label="Peak search")
+ tdd.addTab(BackgroundPage(), label="Background")
+ # apply default to newly added pages
+ tdd.setDefault()
+
+ return tdd
+
+
+def main():
+ a = qt.QApplication([])
+
+ mw = qt.QMainWindow()
+ mw.show()
+
+ tdd = getFitConfigDialog(mw, default={"a": 1})
+ tdd.show()
+ tdd.exec()
+ print("TabsDialogData result: ", tdd.result())
+ print("TabsDialogData output: ", tdd.output)
+
+ a.exec()
+
+if __name__ == "__main__":
+ main()
diff --git a/src/silx/gui/fit/FitWidget.py b/src/silx/gui/fit/FitWidget.py
new file mode 100644
index 0000000..52ecafe
--- /dev/null
+++ b/src/silx/gui/fit/FitWidget.py
@@ -0,0 +1,751 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""This module provides a widget designed to configure and run a fitting
+process with constraints on parameters.
+
+The main class is :class:`FitWidget`. It relies on
+:mod:`silx.math.fit.fitmanager`, which relies on :func:`silx.math.fit.leastsq`.
+
+The user can choose between functions before running the fit. These function can
+be user defined, or by default are loaded from
+:mod:`silx.math.fit.fittheories`.
+"""
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/07/2018"
+
+import logging
+import sys
+import traceback
+
+from silx.math.fit import fittheories
+from silx.math.fit import fitmanager, functions
+from silx.gui import qt
+from .FitWidgets import (FitActionsButtons, FitStatusLines,
+ FitConfigWidget, ParametersTab)
+from .FitConfig import getFitConfigDialog
+from .BackgroundWidget import getBgDialog, BackgroundDialog
+from ...utils.deprecation import deprecated
+
+DEBUG = 0
+_logger = logging.getLogger(__name__)
+
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+
+class FitWidget(qt.QWidget):
+ """This widget can be used to configure, run and display results of a
+ fitting process.
+
+ The standard steps for using this widget is to initialize it, then load
+ the data to be fitted.
+
+ Optionally, you can also load user defined fit theories. If you skip this
+ step, a series of default fit functions will be presented (gaussian-like
+ functions), and you can later load your custom fit theories from an
+ external file using the GUI.
+
+ A fit theory is a fit function and its associated features:
+
+ - estimation function,
+ - list of parameter names
+ - numerical derivative algorithm
+ - configuration widget
+
+ Once the widget is up and running, the user may select a fit theory and a
+ background theory, change configuration parameters specific to the theory
+ run the estimation, set constraints on parameters and run the actual fit.
+
+ The results are displayed in a table.
+
+ .. image:: img/FitWidget.png
+ """
+ sigFitWidgetSignal = qt.Signal(object)
+ """This signal is emitted by the estimation and fit methods.
+ It carries a dictionary with two items:
+
+ - *event*: one of the following strings
+
+ - *EstimateStarted*,
+ - *FitStarted*
+ - *EstimateFinished*,
+ - *FitFinished*
+ - *EstimateFailed*
+ - *FitFailed*
+
+ - *data*: None, or fit/estimate results (see documentation for
+ :attr:`silx.math.fit.fitmanager.FitManager.fit_results`)
+ """
+
+ def __init__(self, parent=None, title=None, fitmngr=None,
+ enableconfig=True, enablestatus=True, enablebuttons=True):
+ """
+
+ :param parent: Parent widget
+ :param title: Window title
+ :param fitmngr: User defined instance of
+ :class:`silx.math.fit.fitmanager.FitManager`, or ``None``
+ :param enableconfig: If ``True``, activate widgets to modify the fit
+ configuration (select between several fit functions or background
+ functions, apply global constraints, peak search parameters…)
+ :param enablestatus: If ``True``, add a fit status widget, to display
+ a message when fit estimation is available and when fit results
+ are available, as well as a measure of the fit error.
+ :param enablebuttons: If ``True``, add buttons to run estimation and
+ fitting.
+ """
+ if title is None:
+ title = "FitWidget"
+ qt.QWidget.__init__(self, parent)
+
+ self.setWindowTitle(title)
+ layout = qt.QVBoxLayout(self)
+
+ self.fitmanager = self._setFitManager(fitmngr)
+ """Instance of :class:`FitManager`.
+ This is the underlying data model of this FitWidget.
+
+ If no custom theories are defined, the default ones from
+ :mod:`silx.math.fit.fittheories` are imported.
+ """
+
+ # reference fitmanager.configure method for direct access
+ self.configure = self.fitmanager.configure
+ self.fitconfig = self.fitmanager.fitconfig
+
+ self.configdialogs = {}
+ """This dictionary defines the fit configuration widgets
+ associated with the fit theories in :attr:`fitmanager.theories`
+
+ Keys must correspond to existing theory names, i.e. existing keys
+ in :attr:`fitmanager.theories`.
+
+ Values must be instances of QDialog widgets with an additional
+ *output* attribute, a dictionary storing configuration parameters
+ interpreted by the corresponding fit theory.
+
+ The dialog can also define a *setDefault* method to initialize the
+ widget values with values in a dictionary passed as a parameter.
+ This will be executed first.
+
+ In case the widget does not actually inherit :class:`QDialog`, it
+ must at least implement the following methods (executed in this
+ particular order):
+
+ - :meth:`show`: should cause the widget to become visible to the
+ user)
+ - :meth:`exec`: should run while the user is interacting with the
+ widget, interrupting the rest of the program. It should
+ typically end (*return*) when the user clicks an *OK*
+ or a *Cancel* button.
+ - :meth:`result`: must return ``True`` if the new configuration in
+ attribute :attr:`output` is to be accepted (user clicked *OK*),
+ or return ``False`` if :attr:`output` is to be rejected (user
+ clicked *Cancel*)
+
+ To associate a custom configuration widget with a fit theory, use
+ :meth:`associateConfigDialog`. E.g.::
+
+ fw = FitWidget()
+ my_config_widget = MyGaussianConfigWidget(parent=fw)
+ fw.associateConfigDialog(theory_name="Gaussians",
+ config_widget=my_config_widget)
+ """
+
+ self.bgconfigdialogs = {}
+ """Same as :attr:`configdialogs`, except that the widget is associated
+ with a background theory in :attr:`fitmanager.bgtheories`"""
+
+ self._associateDefaultConfigDialogs()
+
+ self.guiConfig = None
+ """Configuration widget at the top of FitWidget, to select
+ fit function, background function, and open an advanced
+ configuration dialog."""
+
+ self.guiParameters = ParametersTab(self)
+ """Table widget for display of fit parameters and constraints"""
+
+ if enableconfig:
+ self.guiConfig = FitConfigWidget(self)
+ """Function selector and configuration widget"""
+
+ self.guiConfig.FunConfigureButton.clicked.connect(
+ self.__funConfigureGuiSlot)
+ self.guiConfig.BgConfigureButton.clicked.connect(
+ self.__bgConfigureGuiSlot)
+
+ self.guiConfig.WeightCheckBox.setChecked(
+ self.fitconfig.get("WeightFlag", False))
+ self.guiConfig.WeightCheckBox.stateChanged[int].connect(self.weightEvent)
+
+ if qt.BINDING in ('PySide2', 'PyQt5'):
+ self.guiConfig.BkgComBox.activated[str].connect(self.bkgEvent)
+ self.guiConfig.FunComBox.activated[str].connect(self.funEvent)
+ else: # Qt6
+ self.guiConfig.BkgComBox.textActivated.connect(self.bkgEvent)
+ self.guiConfig.FunComBox.textActivated.connect(self.funEvent)
+
+ self._populateFunctions()
+
+ layout.addWidget(self.guiConfig)
+
+ layout.addWidget(self.guiParameters)
+
+ if enablestatus:
+ self.guistatus = FitStatusLines(self)
+ """Status bar"""
+ layout.addWidget(self.guistatus)
+
+ if enablebuttons:
+ self.guibuttons = FitActionsButtons(self)
+ """Widget with estimate, start fit and dismiss buttons"""
+ self.guibuttons.EstimateButton.clicked.connect(self.estimate)
+ self.guibuttons.EstimateButton.setEnabled(False)
+ self.guibuttons.StartFitButton.clicked.connect(self.startFit)
+ self.guibuttons.StartFitButton.setEnabled(False)
+ self.guibuttons.DismissButton.clicked.connect(self.dismiss)
+ layout.addWidget(self.guibuttons)
+
+ def _setFitManager(self, fitinstance):
+ """Initialize a :class:`FitManager` instance, to be assigned to
+ :attr:`fitmanager`, or use a custom FitManager instance.
+
+ :param fitinstance: Existing instance of FitManager, possibly
+ customized by the user, or None to load a default instance."""
+ if isinstance(fitinstance, fitmanager.FitManager):
+ # customized
+ fitmngr = fitinstance
+ else:
+ # initialize default instance
+ fitmngr = fitmanager.FitManager()
+
+ # initialize the default fitting functions in case
+ # none is present
+ if not len(fitmngr.theories):
+ fitmngr.loadtheories(fittheories)
+
+ return fitmngr
+
+ def _associateDefaultConfigDialogs(self):
+ """Fill :attr:`bgconfigdialogs` and :attr:`configdialogs` by calling
+ :meth:`associateConfigDialog` with default config dialog widgets.
+ """
+ # associate silx.gui.fit.FitConfig with all theories
+ # Users can later associate their own custom dialogs to
+ # replace the default.
+ configdialog = getFitConfigDialog(parent=self,
+ default=self.fitconfig)
+ for theory in self.fitmanager.theories:
+ self.associateConfigDialog(theory, configdialog)
+ for bgtheory in self.fitmanager.bgtheories:
+ self.associateConfigDialog(bgtheory, configdialog,
+ theory_is_background=True)
+
+ # associate silx.gui.fit.BackgroundWidget with Strip and Snip
+ bgdialog = getBgDialog(parent=self,
+ default=self.fitconfig)
+ for bgtheory in ["Strip", "Snip"]:
+ if bgtheory in self.fitmanager.bgtheories:
+ self.associateConfigDialog(bgtheory, bgdialog,
+ theory_is_background=True)
+
+ def _populateFunctions(self):
+ """Fill combo-boxes with fit theories and background theories
+ loaded by :attr:`fitmanager`.
+ Run :meth:`fitmanager.configure` to ensure the custom configuration
+ of the selected theory has been loaded into :attr:`fitconfig`"""
+ for theory_name in self.fitmanager.bgtheories:
+ self.guiConfig.BkgComBox.addItem(theory_name)
+ self.guiConfig.BkgComBox.setItemData(
+ self.guiConfig.BkgComBox.findText(theory_name),
+ self.fitmanager.bgtheories[theory_name].description,
+ qt.Qt.ToolTipRole)
+
+ for theory_name in self.fitmanager.theories:
+ self.guiConfig.FunComBox.addItem(theory_name)
+ self.guiConfig.FunComBox.setItemData(
+ self.guiConfig.FunComBox.findText(theory_name),
+ self.fitmanager.theories[theory_name].description,
+ qt.Qt.ToolTipRole)
+
+ # - activate selected fit theory (if any)
+ # - activate selected bg theory (if any)
+ configuration = self.fitmanager.configure()
+ if self.fitmanager.selectedtheory is None:
+ # take the first one by default
+ self.guiConfig.FunComBox.setCurrentIndex(1)
+ self.funEvent(list(self.fitmanager.theories.keys())[0])
+ else:
+ idx = list(self.fitmanager.theories).index(self.fitmanager.selectedtheory)
+ self.guiConfig.FunComBox.setCurrentIndex(idx + 1)
+ self.funEvent(self.fitmanager.selectedtheory)
+
+ if self.fitmanager.selectedbg is None:
+ self.guiConfig.BkgComBox.setCurrentIndex(1)
+ self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0])
+ else:
+ idx = list(self.fitmanager.bgtheories).index(self.fitmanager.selectedbg)
+ self.guiConfig.BkgComBox.setCurrentIndex(idx + 1)
+ self.bkgEvent(self.fitmanager.selectedbg)
+
+ configuration.update(self.configure())
+
+ @deprecated(replacement='setData', since_version='0.3.0')
+ def setdata(self, x, y, sigmay=None, xmin=None, xmax=None):
+ self.setData(x, y, sigmay, xmin, xmax)
+
+ def setData(self, x=None, y=None, sigmay=None, xmin=None, xmax=None):
+ """Set data to be fitted.
+
+ :param x: Abscissa data. If ``None``, :attr:`xdata`` is set to
+ ``numpy.array([0.0, 1.0, 2.0, ..., len(y)-1])``
+ :type x: Sequence or numpy array or None
+ :param y: The dependant data ``y = f(x)``. ``y`` must have the same
+ shape as ``x`` if ``x`` is not ``None``.
+ :type y: Sequence or numpy array or None
+ :param sigmay: The uncertainties in the ``ydata`` array. These are
+ used as weights in the least-squares problem.
+ If ``None``, the uncertainties are assumed to be 1.
+ :type sigmay: Sequence or numpy array or None
+ :param xmin: Lower value of x values to use for fitting
+ :param xmax: Upper value of x values to use for fitting
+ """
+ if y is None:
+ self.guibuttons.EstimateButton.setEnabled(False)
+ self.guibuttons.StartFitButton.setEnabled(False)
+ else:
+ self.guibuttons.EstimateButton.setEnabled(True)
+ self.guibuttons.StartFitButton.setEnabled(True)
+ self.fitmanager.setdata(x=x, y=y, sigmay=sigmay,
+ xmin=xmin, xmax=xmax)
+ for config_dialog in self.bgconfigdialogs.values():
+ if isinstance(config_dialog, BackgroundDialog):
+ config_dialog.setData(x, y, xmin=xmin, xmax=xmax)
+
+ def associateConfigDialog(self, theory_name, config_widget,
+ theory_is_background=False):
+ """Associate an instance of custom configuration dialog widget to
+ a fit theory or to a background theory.
+
+ This adds or modifies an item in the correspondence table
+ :attr:`configdialogs` or :attr:`bgconfigdialogs`.
+
+ :param str theory_name: Name of fit theory. This must be a key of dict
+ :attr:`fitmanager.theories`
+ :param config_widget: Custom configuration widget. See documentation
+ for :attr:`configdialogs`
+ :param bool theory_is_background: If flag is *True*, add dialog to
+ :attr:`bgconfigdialogs` rather than :attr:`configdialogs`
+ (default).
+ :raise: KeyError if parameter ``theory_name`` does not match an
+ existing fit theory or background theory in :attr:`fitmanager`.
+ :raise: AttributeError if the widget does not implement the mandatory
+ methods (*show*, *exec*, *result*, *setDefault*) or the mandatory
+ attribute (*output*).
+ """
+ theories = self.fitmanager.bgtheories if theory_is_background else\
+ self.fitmanager.theories
+
+ if theory_name not in theories:
+ raise KeyError("%s does not match an existing fitmanager theory")
+
+ if config_widget is not None:
+ if (not hasattr(config_widget, "exec") and
+ not hasattr(config_widget, "exec_")):
+ raise AttributeError(
+ "Custom configuration widget must define exec or exec_")
+
+ for mandatory_attr in ["show", "result", "output"]:
+ if not hasattr(config_widget, mandatory_attr):
+ raise AttributeError(
+ "Custom configuration widget must define " +
+ "attribute or method " + mandatory_attr)
+
+ if theory_is_background:
+ self.bgconfigdialogs[theory_name] = config_widget
+ else:
+ self.configdialogs[theory_name] = config_widget
+
+ def _emitSignal(self, ddict):
+ """Emit pyqtSignal after estimation completed
+ (``ddict = {'event': 'EstimateFinished', 'data': fit_results}``)
+ and after fit completed
+ (``ddict = {'event': 'FitFinished', 'data': fit_results}``)"""
+ self.sigFitWidgetSignal.emit(ddict)
+
+ def __funConfigureGuiSlot(self):
+ """Open an advanced configuration dialog widget"""
+ self.__configureGui(dialog_type="function")
+
+ def __bgConfigureGuiSlot(self):
+ """Open an advanced configuration dialog widget"""
+ self.__configureGui(dialog_type="background")
+
+ def __configureGui(self, newconfiguration=None, dialog_type="function"):
+ """Open an advanced configuration dialog widget to get a configuration
+ dictionary, or use a supplied configuration dictionary. Call
+ :meth:`configure` with this dictionary as a parameter. Update the gui
+ accordingly. Reinitialize the fit results in the table and in
+ :attr:`fitmanager`.
+
+ :param newconfiguration: User supplied configuration dictionary. If ``None``,
+ open a dialog widget that returns a dictionary."""
+ configuration = self.configure()
+ # get new dictionary
+ if newconfiguration is None:
+ newconfiguration = self.configureDialog(configuration, dialog_type)
+ # update configuration
+ configuration.update(self.configure(**newconfiguration))
+ # set fit function theory
+ try:
+ i = 1 + \
+ list(self.fitmanager.theories.keys()).index(
+ self.fitmanager.selectedtheory)
+ self.guiConfig.FunComBox.setCurrentIndex(i)
+ self.funEvent(self.fitmanager.selectedtheory)
+ except ValueError:
+ _logger.error("Function not in list %s",
+ self.fitmanager.selectedtheory)
+ self.funEvent(list(self.fitmanager.theories.keys())[0])
+ # current background
+ try:
+ i = 1 + \
+ list(self.fitmanager.bgtheories.keys()).index(
+ self.fitmanager.selectedbg)
+ self.guiConfig.BkgComBox.setCurrentIndex(i)
+ self.bkgEvent(self.fitmanager.selectedbg)
+ except ValueError:
+ _logger.error("Background not in list %s",
+ self.fitmanager.selectedbg)
+ self.bkgEvent(list(self.fitmanager.bgtheories.keys())[0])
+
+ # update the Gui
+ self.__initialParameters()
+
+ def configureDialog(self, oldconfiguration, dialog_type="function"):
+ """Display a dialog, allowing the user to define fit configuration
+ parameters.
+
+ By default, a common dialog is used for all fit theories. But if the
+ defined a custom dialog using :meth:`associateConfigDialog`, it is
+ used instead.
+
+ :param dict oldconfiguration: Dictionary containing previous configuration
+ :param str dialog_type: "function" or "background"
+ :return: User defined parameters in a dictionary
+ """
+ newconfiguration = {}
+ newconfiguration.update(oldconfiguration)
+
+ if dialog_type == "function":
+ theory = self.fitmanager.selectedtheory
+ configdialog = self.configdialogs[theory]
+ elif dialog_type == "background":
+ theory = self.fitmanager.selectedbg
+ configdialog = self.bgconfigdialogs[theory]
+
+ # this should only happen if a user specifically associates None
+ # with a theory, to have no configuration option
+ if configdialog is None:
+ return {}
+
+ # update state of configdialog before showing it
+ if hasattr(configdialog, "setDefault"):
+ configdialog.setDefault(newconfiguration)
+ configdialog.show()
+ if hasattr(configdialog, "exec"):
+ configdialog.exec()
+ else: # Qt5 compatibility
+ configdialog.exec_()
+ if configdialog.result():
+ newconfiguration.update(configdialog.output)
+
+ return newconfiguration
+
+ def estimate(self):
+ """Run parameter estimation function then emit
+ :attr:`sigFitWidgetSignal` with a dictionary containing a status
+ message and a list of fit parameters estimations
+ in the format defined in
+ :attr:`silx.math.fit.fitmanager.FitManager.fit_results`
+
+ The emitted dictionary has an *"event"* key that can have
+ following values:
+
+ - *'EstimateStarted'*
+ - *'EstimateFailed'*
+ - *'EstimateFinished'*
+ """
+ try:
+ theory_name = self.fitmanager.selectedtheory
+ estimation_function = self.fitmanager.theories[theory_name].estimate
+ if estimation_function is not None:
+ ddict = {'event': 'EstimateStarted',
+ 'data': None}
+ self._emitSignal(ddict)
+ self.fitmanager.estimate(callback=self.fitStatus)
+ else:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Information)
+ text = "Function does not define a way to estimate\n"
+ text += "the initial parameters. Please, fill them\n"
+ text += "yourself in the table and press Start Fit\n"
+ msg.setText(text)
+ msg.setWindowTitle('FitWidget Message')
+ msg.exec()
+ return
+ except Exception as e: # noqa (we want to catch and report all errors)
+ _logger.warning('Estimate error: %s', traceback.format_exc())
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setWindowTitle("Estimate Error")
+ msg.setText("Error on estimate: %s" % e)
+ msg.exec()
+ ddict = {
+ 'event': 'EstimateFailed',
+ 'data': None}
+ self._emitSignal(ddict)
+ return
+
+ self.guiParameters.fillFromFit(
+ self.fitmanager.fit_results, view='Fit')
+ self.guiParameters.removeAllViews(keep='Fit')
+ ddict = {
+ 'event': 'EstimateFinished',
+ 'data': self.fitmanager.fit_results}
+ self._emitSignal(ddict)
+
+ @deprecated(replacement='startFit', since_version='0.3.0')
+ def startfit(self):
+ self.startFit()
+
+ def startFit(self):
+ """Run fit, then emit :attr:`sigFitWidgetSignal` with a dictionary
+ containing a status message and a list of fit
+ parameters results in the format defined in
+ :attr:`silx.math.fit.fitmanager.FitManager.fit_results`
+
+ The emitted dictionary has an *"event"* key that can have
+ following values:
+
+ - *'FitStarted'*
+ - *'FitFailed'*
+ - *'FitFinished'*
+ """
+ self.fitmanager.fit_results = self.guiParameters.getFitResults()
+ try:
+ ddict = {'event': 'FitStarted',
+ 'data': None}
+ self._emitSignal(ddict)
+ self.fitmanager.runfit(callback=self.fitStatus)
+ except Exception as e: # noqa (we want to catch and report all errors)
+ _logger.warning('Estimate error: %s', traceback.format_exc())
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setWindowTitle("Fit Error")
+ msg.setText("Error on Fit: %s" % e)
+ msg.exec()
+ ddict = {
+ 'event': 'FitFailed',
+ 'data': None
+ }
+ self._emitSignal(ddict)
+ return
+
+ self.guiParameters.fillFromFit(
+ self.fitmanager.fit_results, view='Fit')
+ self.guiParameters.removeAllViews(keep='Fit')
+ ddict = {
+ 'event': 'FitFinished',
+ 'data': self.fitmanager.fit_results
+ }
+ self._emitSignal(ddict)
+ return
+
+ def bkgEvent(self, bgtheory):
+ """Select background theory, then reinitialize parameters"""
+ bgtheory = str(bgtheory)
+ if bgtheory in self.fitmanager.bgtheories:
+ self.fitmanager.setbackground(bgtheory)
+ else:
+ functionsfile = qt.QFileDialog.getOpenFileName(
+ self, "Select python module with your function(s)", "",
+ "Python Files (*.py);;All Files (*)")
+
+ if len(functionsfile):
+ try:
+ self.fitmanager.loadbgtheories(functionsfile)
+ except ImportError:
+ qt.QMessageBox.critical(self, "ERROR",
+ "Function not imported")
+ return
+ else:
+ # empty the ComboBox
+ while self.guiConfig.BkgComBox.count() > 1:
+ self.guiConfig.BkgComBox.removeItem(1)
+ # and fill it again
+ for key in self.fitmanager.bgtheories:
+ self.guiConfig.BkgComBox.addItem(str(key))
+
+ i = 1 + \
+ list(self.fitmanager.bgtheories.keys()).index(
+ self.fitmanager.selectedbg)
+ self.guiConfig.BkgComBox.setCurrentIndex(i)
+ self.__initialParameters()
+
+ def funEvent(self, theoryname):
+ """Select a fit theory to be used for fitting. If this theory exists
+ in :attr:`fitmanager`, use it. Then, reinitialize table.
+
+ :param theoryname: Name of the fit theory to use for fitting. If this theory
+ exists in :attr:`fitmanager`, use it. Else, open a file dialog to open
+ a custom fit function definition file with
+ :meth:`fitmanager.loadtheories`.
+ """
+ theoryname = str(theoryname)
+ if theoryname in self.fitmanager.theories:
+ self.fitmanager.settheory(theoryname)
+ else:
+ # open a load file dialog
+ functionsfile = qt.QFileDialog.getOpenFileName(
+ self, "Select python module with your function(s)", "",
+ "Python Files (*.py);;All Files (*)")
+
+ if len(functionsfile):
+ try:
+ self.fitmanager.loadtheories(functionsfile)
+ except ImportError:
+ qt.QMessageBox.critical(self, "ERROR",
+ "Function not imported")
+ return
+ else:
+ # empty the ComboBox
+ while self.guiConfig.FunComBox.count() > 1:
+ self.guiConfig.FunComBox.removeItem(1)
+ # and fill it again
+ for key in self.fitmanager.theories:
+ self.guiConfig.FunComBox.addItem(str(key))
+
+ i = 1 + \
+ list(self.fitmanager.theories.keys()).index(
+ self.fitmanager.selectedtheory)
+ self.guiConfig.FunComBox.setCurrentIndex(i)
+ self.__initialParameters()
+
+ def weightEvent(self, flag):
+ """This is called when WeightCheckBox is clicked, to configure the
+ *WeightFlag* field in :attr:`fitmanager.fitconfig` and set weights
+ in the least-square problem."""
+ self.configure(WeightFlag=flag)
+ if flag:
+ self.fitmanager.enableweight()
+ else:
+ # set weights back to 1
+ self.fitmanager.disableweight()
+
+ def __initialParameters(self):
+ """Fill the fit parameters names with names of the parameters of
+ the selected background theory and the selected fit theory.
+ Initialize :attr:`fitmanager.fit_results` with these names, and
+ initialize the table with them. This creates a view called "Fit"
+ in :attr:`guiParameters`"""
+ self.fitmanager.parameter_names = []
+ self.fitmanager.fit_results = []
+ for pname in self.fitmanager.bgtheories[self.fitmanager.selectedbg].parameters:
+ self.fitmanager.parameter_names.append(pname)
+ self.fitmanager.fit_results.append({'name': pname,
+ 'estimation': 0,
+ 'group': 0,
+ 'code': 'FREE',
+ 'cons1': 0,
+ 'cons2': 0,
+ 'fitresult': 0.0,
+ 'sigma': 0.0,
+ 'xmin': None,
+ 'xmax': None})
+ if self.fitmanager.selectedtheory is not None:
+ theory = self.fitmanager.selectedtheory
+ for pname in self.fitmanager.theories[theory].parameters:
+ self.fitmanager.parameter_names.append(pname + "1")
+ self.fitmanager.fit_results.append({'name': pname + "1",
+ 'estimation': 0,
+ 'group': 1,
+ 'code': 'FREE',
+ 'cons1': 0,
+ 'cons2': 0,
+ 'fitresult': 0.0,
+ 'sigma': 0.0,
+ 'xmin': None,
+ 'xmax': None})
+
+ self.guiParameters.fillFromFit(
+ self.fitmanager.fit_results, view='Fit')
+
+ def fitStatus(self, data):
+ """Set *status* and *chisq* in status bar"""
+ if 'chisq' in data:
+ if data['chisq'] is None:
+ self.guistatus.ChisqLine.setText(" ")
+ else:
+ chisq = data['chisq']
+ self.guistatus.ChisqLine.setText("%6.2f" % chisq)
+
+ if 'status' in data:
+ status = data['status']
+ self.guistatus.StatusLine.setText(str(status))
+
+ def dismiss(self):
+ """Close FitWidget"""
+ self.close()
+
+
+if __name__ == "__main__":
+ import numpy
+
+ x = numpy.arange(1500).astype(numpy.float64)
+ constant_bg = 3.14
+
+ p = [1000, 100., 30.0,
+ 500, 300., 25.,
+ 1700, 500., 35.,
+ 750, 700., 30.0,
+ 1234, 900., 29.5,
+ 302, 1100., 30.5,
+ 75, 1300., 21.]
+ y = functions.sum_gauss(x, *p) + constant_bg
+
+ a = qt.QApplication(sys.argv)
+ w = FitWidget()
+ w.setData(x=x, y=y)
+ w.show()
+ a.exec()
diff --git a/src/silx/gui/fit/FitWidgets.py b/src/silx/gui/fit/FitWidgets.py
new file mode 100644
index 0000000..0fcc6b7
--- /dev/null
+++ b/src/silx/gui/fit/FitWidgets.py
@@ -0,0 +1,555 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""Collection of widgets used to build
+:class:`silx.gui.fit.FitWidget.FitWidget`"""
+
+from collections import OrderedDict
+
+from silx.gui import qt
+from silx.gui.fit.Parameters import Parameters
+
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "13/10/2016"
+
+
+class FitActionsButtons(qt.QWidget):
+ """Widget with 3 ``QPushButton``:
+
+ The buttons can be accessed as public attributes::
+
+ - ``EstimateButton``
+ - ``StartFitButton``
+ - ``DismissButton``
+
+ You will typically need to access these attributes to connect the buttons
+ to actions. For instance, if you have 3 functions ``estimate``,
+ ``runfit`` and ``dismiss``, you can connect them like this::
+
+ >>> fit_actions_buttons = FitActionsButtons()
+ >>> fit_actions_buttons.EstimateButton.clicked.connect(estimate)
+ >>> fit_actions_buttons.StartFitButton.clicked.connect(runfit)
+ >>> fit_actions_buttons.DismissButton.clicked.connect(dismiss)
+
+ """
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.resize(234, 53)
+
+ grid_layout = qt.QGridLayout(self)
+ grid_layout.setContentsMargins(11, 11, 11, 11)
+ grid_layout.setSpacing(6)
+ layout = qt.QHBoxLayout(None)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(6)
+
+ self.EstimateButton = qt.QPushButton(self)
+ self.EstimateButton.setText("Estimate")
+ layout.addWidget(self.EstimateButton)
+ spacer = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ layout.addItem(spacer)
+
+ self.StartFitButton = qt.QPushButton(self)
+ self.StartFitButton.setText("Start Fit")
+ layout.addWidget(self.StartFitButton)
+ spacer_2 = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ layout.addItem(spacer_2)
+
+ self.DismissButton = qt.QPushButton(self)
+ self.DismissButton.setText("Dismiss")
+ layout.addWidget(self.DismissButton)
+
+ grid_layout.addLayout(layout, 0, 0)
+
+
+class FitStatusLines(qt.QWidget):
+ """Widget with 2 greyed out write-only ``QLineEdit``.
+
+ These text widgets can be accessed as public attributes::
+
+ - ``StatusLine``
+ - ``ChisqLine``
+
+ You will typically need to access these widgets to update the displayed
+ text::
+
+ >>> fit_status_lines = FitStatusLines()
+ >>> fit_status_lines.StatusLine.setText("Ready")
+ >>> fit_status_lines.ChisqLine.setText("%6.2f" % 0.01)
+
+ """
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.resize(535, 47)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(6)
+
+ self.StatusLabel = qt.QLabel(self)
+ self.StatusLabel.setText("Status:")
+ layout.addWidget(self.StatusLabel)
+
+ self.StatusLine = qt.QLineEdit(self)
+ self.StatusLine.setText("Ready")
+ self.StatusLine.setReadOnly(1)
+ layout.addWidget(self.StatusLine)
+
+ self.ChisqLabel = qt.QLabel(self)
+ self.ChisqLabel.setText("Reduced chisq:")
+ layout.addWidget(self.ChisqLabel)
+
+ self.ChisqLine = qt.QLineEdit(self)
+ self.ChisqLine.setMaximumSize(qt.QSize(16000, 32767))
+ self.ChisqLine.setText("")
+ self.ChisqLine.setReadOnly(1)
+ layout.addWidget(self.ChisqLine)
+
+
+class FitConfigWidget(qt.QWidget):
+ """Widget whose purpose is to select a fit theory and a background
+ theory, load a new fit theory definition file and provide
+ a "Configure" button to open an advanced configuration dialog.
+
+ This is used in :class:`silx.gui.fit.FitWidget.FitWidget`, to offer
+ an interface to quickly modify the main parameters prior to running a fit:
+
+ - select a fitting function through :attr:`FunComBox`
+ - select a background function through :attr:`BkgComBox`
+ - open a dialog for modifying advanced parameters through
+ :attr:`FunConfigureButton`
+ """
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+
+ self.setWindowTitle("FitConfigGUI")
+
+ layout = qt.QGridLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(6)
+
+ self.FunLabel = qt.QLabel(self)
+ self.FunLabel.setText("Function")
+ layout.addWidget(self.FunLabel, 0, 0)
+
+ self.FunComBox = qt.QComboBox(self)
+ self.FunComBox.addItem("Add Function(s)")
+ self.FunComBox.setItemData(self.FunComBox.findText("Add Function(s)"),
+ "Load fit theories from a file",
+ qt.Qt.ToolTipRole)
+ layout.addWidget(self.FunComBox, 0, 1)
+
+ self.BkgLabel = qt.QLabel(self)
+ self.BkgLabel.setText("Background")
+ layout.addWidget(self.BkgLabel, 1, 0)
+
+ self.BkgComBox = qt.QComboBox(self)
+ self.BkgComBox.addItem("Add Background(s)")
+ self.BkgComBox.setItemData(self.BkgComBox.findText("Add Background(s)"),
+ "Load background theories from a file",
+ qt.Qt.ToolTipRole)
+ layout.addWidget(self.BkgComBox, 1, 1)
+
+ self.FunConfigureButton = qt.QPushButton(self)
+ self.FunConfigureButton.setText("Configure")
+ self.FunConfigureButton.setToolTip(
+ "Open a configuration dialog for the selected function")
+ layout.addWidget(self.FunConfigureButton, 0, 2)
+
+ self.BgConfigureButton = qt.QPushButton(self)
+ self.BgConfigureButton.setText("Configure")
+ self.BgConfigureButton.setToolTip(
+ "Open a configuration dialog for the selected background")
+ layout.addWidget(self.BgConfigureButton, 1, 2)
+
+ self.WeightCheckBox = qt.QCheckBox(self)
+ self.WeightCheckBox.setText("Weighted fit")
+ self.WeightCheckBox.setToolTip(
+ "Enable usage of weights in the least-square problem.\n Use" +
+ " the uncertainties (sigma) if provided, else use sqrt(y).")
+
+ layout.addWidget(self.WeightCheckBox, 0, 3, 2, 1)
+
+ layout.setColumnStretch(4, 1)
+
+
+class ParametersTab(qt.QTabWidget):
+ """This widget provides tabs to display and modify fit parameters. Each
+ tab contains a table with fit data such as parameter names, estimated
+ values, fit constraints, and final fit results.
+
+ The usual way to initialize the table is to fill it with the fit
+ parameters from a :class:`silx.math.fit.fitmanager.FitManager` object, after
+ the estimation process or after the final fit.
+
+ In the following example we use a :class:`ParametersTab` to display the
+ results of two separate fits::
+
+ from silx.math.fit import fittheories
+ from silx.math.fit import fitmanager
+ from silx.math.fit import functions
+ from silx.gui import qt
+ import numpy
+
+ a = qt.QApplication([])
+
+ # Create synthetic data
+ x = numpy.arange(1000)
+ y1 = functions.sum_gauss(x, 100, 400, 100)
+
+ fit = fitmanager.FitManager(x=x, y=y1)
+
+ fitfuns = fittheories.FitTheories()
+ fit.addtheory(theory="Gaussian",
+ function=functions.sum_gauss,
+ parameters=("height", "peak center", "fwhm"),
+ estimate=fitfuns.estimate_height_position_fwhm)
+ fit.settheory('Gaussian')
+ fit.configure(PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ AutoFwhm=True,)
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ # Show first fit result in a tab in our widget
+ w = ParametersTab()
+ w.show()
+ w.fillFromFit(fit.fit_results, view='Gaussians')
+
+ # new synthetic data
+ y2 = functions.sum_splitgauss(x,
+ 100, 400, 100, 40,
+ 10, 600, 50, 500,
+ 80, 850, 10, 50)
+ fit.setData(x=x, y=y2)
+
+ # Define new theory
+ fit.addtheory(theory="Asymetric gaussian",
+ function=functions.sum_splitgauss,
+ parameters=("height", "peak center", "left fwhm", "right fwhm"),
+ estimate=fitfuns.estimate_splitgauss)
+ fit.settheory('Asymetric gaussian')
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ # Show first fit result in another tab in our widget
+ w.fillFromFit(fit.fit_results, view='Asymetric gaussians')
+ a.exec()
+
+ """
+
+ def __init__(self, parent=None, name="FitParameters"):
+ """
+
+ :param parent: Parent widget
+ :param name: Widget title
+ """
+ qt.QTabWidget.__init__(self, parent)
+ self.setWindowTitle(name)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self.views = OrderedDict()
+ """Dictionary of views. Keys are view names,
+ items are :class:`Parameters` widgets"""
+
+ self.latest_view = None
+ """Name of latest view"""
+
+ # the widgets/tables themselves
+ self.tables = {}
+ """Dictionary of :class:`silx.gui.fit.parameters.Parameters` objects.
+ These objects store fit results
+ """
+
+ self.setContentsMargins(10, 10, 10, 10)
+
+ def setView(self, view=None, fitresults=None):
+ """Add or update a table. Fill it with data from a fit
+
+ :param view: Tab name to be added or updated. If ``None``, use the
+ latest view.
+ :param fitresults: Fit data to be added to the table
+ :raise: KeyError if no view name specified and no latest view
+ available.
+ """
+ if view is None:
+ if self.latest_view is not None:
+ view = self.latest_view
+ else:
+ raise KeyError(
+ "No view available. You must specify a view" +
+ " name the first time you call this method."
+ )
+
+ if view in self.tables.keys():
+ table = self.tables[view]
+ else:
+ # create the parameters instance
+ self.tables[view] = Parameters(self)
+ table = self.tables[view]
+ self.views[view] = table
+ self.addTab(table, str(view))
+
+ if fitresults is not None:
+ table.fillFromFit(fitresults)
+
+ self.setCurrentWidget(self.views[view])
+ self.latest_view = view
+
+ def renameView(self, oldname=None, newname=None):
+ """Rename a view (tab)
+
+ :param oldname: Name of the view to be renamed
+ :param newname: New name of the view"""
+ error = 1
+ if newname is not None:
+ if newname not in self.views.keys():
+ if oldname in self.views.keys():
+ parameterlist = self.tables[oldname].getFitResults()
+ self.setView(view=newname, fitresults=parameterlist)
+ self.removeView(oldname)
+ error = 0
+ return error
+
+ def fillFromFit(self, fitparameterslist, view=None):
+ """Update a view with data from a fit (alias for :meth:`setView`)
+
+ :param view: Tab name to be added or updated (default: latest view)
+ :param fitparameterslist: Fit data to be added to the table
+ """
+ self.setView(view=view, fitresults=fitparameterslist)
+
+ def getFitResults(self, name=None):
+ """Call :meth:`getFitResults` for the
+ :class:`silx.gui.fit.parameters.Parameters` corresponding to the
+ latest table or to the named table (if ``name`` is not
+ ``None``). This return a list of dictionaries in the format used by
+ :class:`silx.math.fit.fitmanager.FitManager` to store fit parameter
+ results.
+
+ :param name: View name.
+ """
+ if name is None:
+ name = self.latest_view
+ return self.tables[name].getFitResults()
+
+ def removeView(self, name):
+ """Remove a view by name.
+
+ :param name: View name.
+ """
+ if name in self.views:
+ index = self.indexOf(self.tables[name])
+ self.removeTab(index)
+ index = self.indexOf(self.views[name])
+ self.removeTab(index)
+ del self.tables[name]
+ del self.views[name]
+
+ def removeAllViews(self, keep=None):
+ """Remove all views, except the one specified (argument
+ ``keep``)
+
+ :param keep: Name of the view to be kept."""
+ for view in self.tables:
+ if view != keep:
+ self.removeView(view)
+
+ def getHtmlText(self, name=None):
+ """Return the table data as HTML
+
+ :param name: View name."""
+ if name is None:
+ name = self.latest_view
+ table = self.tables[name]
+ lemon = ("#%x%x%x" % (255, 250, 205)).upper()
+ hcolor = ("#%x%x%x" % (230, 240, 249)).upper()
+ text = ""
+ text += "<nobr>"
+ text += "<table>"
+ text += "<tr>"
+ ncols = table.columnCount()
+ for l in range(ncols):
+ text += ('<td align="left" bgcolor="%s"><b>' % hcolor)
+ text += str(table.horizontalHeaderItem(l).text())
+ text += "</b></td>"
+ text += "</tr>"
+ nrows = table.rowCount()
+ for r in range(nrows):
+ text += "<tr>"
+ item = table.item(r, 0)
+ newtext = ""
+ if item is not None:
+ newtext = str(item.text())
+ if len(newtext):
+ color = "white"
+ b = "<b>"
+ else:
+ b = ""
+ color = lemon
+ try:
+ # MyQTable item has color defined
+ cc = table.item(r, 0).color
+ cc = ("#%x%x%x" % (cc.red(), cc.green(), cc.blue())).upper()
+ color = cc
+ except:
+ pass
+ for c in range(ncols):
+ item = table.item(r, c)
+ newtext = ""
+ if item is not None:
+ newtext = str(item.text())
+ if len(newtext):
+ finalcolor = color
+ else:
+ finalcolor = "white"
+ if c < 2:
+ text += ('<td align="left" bgcolor="%s">%s' %
+ (finalcolor, b))
+ else:
+ text += ('<td align="right" bgcolor="%s">%s' %
+ (finalcolor, b))
+ text += newtext
+ if len(b):
+ text += "</td>"
+ else:
+ text += "</b></td>"
+ item = table.item(r, 0)
+ newtext = ""
+ if item is not None:
+ newtext = str(item.text())
+ if len(newtext):
+ text += "</b>"
+ text += "</tr>"
+ text += "\n"
+ text += "</table>"
+ text += "</nobr>"
+ return text
+
+ def getText(self, name=None):
+ """Return the table data as CSV formatted text, using tabulation
+ characters as separators.
+
+ :param name: View name."""
+ if name is None:
+ name = self.latest_view
+ table = self.tables[name]
+ text = ""
+ ncols = table.columnCount()
+ for l in range(ncols):
+ text += (str(table.horizontalHeaderItem(l).text())) + "\t"
+ text += "\n"
+ nrows = table.rowCount()
+ for r in range(nrows):
+ for c in range(ncols):
+ newtext = ""
+ if c != 4:
+ item = table.item(r, c)
+ if item is not None:
+ newtext = str(item.text())
+ else:
+ item = table.cellWidget(r, c)
+ if item is not None:
+ newtext = str(item.currentText())
+ text += newtext + "\t"
+ text += "\n"
+ text += "\n"
+ return text
+
+
+def test():
+ from silx.math.fit import fittheories
+ from silx.math.fit import fitmanager
+ from silx.math.fit import functions
+ from silx.gui.plot.PlotWindow import PlotWindow
+ import numpy
+
+ a = qt.QApplication([])
+
+ x = numpy.arange(1000)
+ y1 = functions.sum_gauss(x, 100, 400, 100)
+
+ fit = fitmanager.FitManager(x=x, y=y1)
+
+ fitfuns = fittheories.FitTheories()
+ fit.addtheory(name="Gaussian",
+ function=functions.sum_gauss,
+ parameters=("height", "peak center", "fwhm"),
+ estimate=fitfuns.estimate_height_position_fwhm)
+ fit.settheory('Gaussian')
+ fit.configure(PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ AutoFwhm=True,)
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ w = ParametersTab()
+ w.show()
+ w.fillFromFit(fit.fit_results, view='Gaussians')
+
+ y2 = functions.sum_splitgauss(x,
+ 100, 400, 100, 40,
+ 10, 600, 50, 500,
+ 80, 850, 10, 50)
+ fit.setdata(x=x, y=y2)
+
+ # Define new theory
+ fit.addtheory(name="Asymetric gaussian",
+ function=functions.sum_splitgauss,
+ parameters=("height", "peak center", "left fwhm", "right fwhm"),
+ estimate=fitfuns.estimate_splitgauss)
+ fit.settheory('Asymetric gaussian')
+
+ # Fit
+ fit.estimate()
+ fit.runfit()
+
+ w.fillFromFit(fit.fit_results, view='Asymetric gaussians')
+
+ # Plot
+ pw = PlotWindow(control=True)
+ pw.addCurve(x, y1, "Gaussians")
+ pw.addCurve(x, y2, "Asymetric gaussians")
+ pw.show()
+
+ a.exec()
+
+
+if __name__ == "__main__":
+ test()
diff --git a/src/silx/gui/fit/Parameters.py b/src/silx/gui/fit/Parameters.py
new file mode 100644
index 0000000..daa72f3
--- /dev/null
+++ b/src/silx/gui/fit/Parameters.py
@@ -0,0 +1,882 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ######################################################################### */
+"""This module defines a table widget that is specialized in displaying fit
+parameter results and associated constraints."""
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "25/11/2016"
+
+import sys
+from collections import OrderedDict
+
+from silx.gui import qt
+from silx.gui.widgets.TableWidget import TableWidget
+
+
+def float_else_zero(sstring):
+ """Return converted string to float. If conversion fail, return zero.
+
+ :param sstring: String to be converted
+ :return: ``float(sstrinq)`` if ``sstring`` can be converted to float
+ (e.g. ``"3.14"``), else ``0``
+ """
+ try:
+ return float(sstring)
+ except ValueError:
+ return 0
+
+
+class QComboTableItem(qt.QComboBox):
+ """:class:`qt.QComboBox` augmented with a ``sigCellChanged`` signal
+ to emit a tuple of ``(row, column)`` coordinates when the value is
+ changed.
+
+ This signal can be used to locate the modified combo box in a table.
+
+ :param row: Row number of the table cell containing this widget
+ :param col: Column number of the table cell containing this widget"""
+ sigCellChanged = qt.Signal(int, int)
+ """Signal emitted when this ``QComboBox`` is activated.
+ A ``(row, column)`` tuple is passed."""
+
+ def __init__(self, parent=None, row=None, col=None):
+ self._row = row
+ self._col = col
+ qt.QComboBox.__init__(self, parent)
+ self.activated[int].connect(self._cellChanged)
+
+ def _cellChanged(self, idx): # noqa
+ self.sigCellChanged.emit(self._row, self._col)
+
+
+class QCheckBoxItem(qt.QCheckBox):
+ """:class:`qt.QCheckBox` augmented with a ``sigCellChanged`` signal
+ to emit a tuple of ``(row, column)`` coordinates when the check box has
+ been clicked on.
+
+ This signal can be used to locate the modified check box in a table.
+
+ :param row: Row number of the table cell containing this widget
+ :param col: Column number of the table cell containing this widget"""
+ sigCellChanged = qt.Signal(int, int)
+ """Signal emitted when this ``QCheckBox`` is clicked.
+ A ``(row, column)`` tuple is passed."""
+
+ def __init__(self, parent=None, row=None, col=None):
+ self._row = row
+ self._col = col
+ qt.QCheckBox.__init__(self, parent)
+ self.clicked.connect(self._cellChanged)
+
+ def _cellChanged(self):
+ self.sigCellChanged.emit(self._row, self._col)
+
+
+class Parameters(TableWidget):
+ """:class:`TableWidget` customized to display fit results
+ and to interact with :class:`FitManager` objects.
+
+ Data and references to cell widgets are kept in a dictionary
+ attribute :attr:`parameters`.
+
+ :param parent: Parent widget
+ :param labels: Column headers. If ``None``, default headers will be used.
+ :type labels: List of strings or None
+ :param paramlist: List of fit parameters to be displayed for each fitted
+ peak.
+ :type paramlist: list[str] or None
+ """
+ def __init__(self, parent=None, paramlist=None):
+ TableWidget.__init__(self, parent)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ labels = ['Parameter', 'Estimation', 'Fit Value', 'Sigma',
+ 'Constraints', 'Min/Parame', 'Max/Factor/Delta']
+ tooltips = ["Fit parameter name",
+ "Estimated value for fit parameter. You can edit this column.",
+ "Actual value for parameter, after fit",
+ "Uncertainty (same unit as the parameter)",
+ "Constraint to be applied to the parameter for fit",
+ "First parameter for constraint (name of another param or min value)",
+ "Second parameter for constraint (max value, or factor/delta)"]
+
+ self.columnKeys = ['name', 'estimation', 'fitresult',
+ 'sigma', 'code', 'val1', 'val2']
+ """This list assigns shorter keys to refer to columns than the
+ displayed labels."""
+
+ self.__configuring = False
+
+ # column headers and associated tooltips
+ self.setColumnCount(len(labels))
+
+ for i, label in enumerate(labels):
+ item = self.horizontalHeaderItem(i)
+ if item is None:
+ item = qt.QTableWidgetItem(label,
+ qt.QTableWidgetItem.Type)
+ self.setHorizontalHeaderItem(i, item)
+
+ item.setText(label)
+ if tooltips is not None:
+ item.setToolTip(tooltips[i])
+
+ # resize columns
+ for col_key in ["name", "estimation", "sigma", "val1", "val2"]:
+ col_idx = self.columnIndexByField(col_key)
+ self.resizeColumnToContents(col_idx)
+
+ # Initialize the table with one line per supplied parameter
+ paramlist = paramlist if paramlist is not None else []
+ self.parameters = OrderedDict()
+ """This attribute stores all the data in an ordered dictionary.
+ New data can be added using :meth:`newParameterLine`.
+ Existing data can be modified using :meth:`configureLine`
+
+ Keys of the dictionary are:
+
+ - 'name': parameter name
+ - 'line': line index for the parameter in the table
+ - 'estimation'
+ - 'fitresult'
+ - 'sigma'
+ - 'code': constraint code (one of the elements of
+ :attr:`code_options`)
+ - 'val1': first parameter related to constraint, formatted
+ as a string, as typed in the table
+ - 'val2': second parameter related to constraint, formatted
+ as a string, as typed in the table
+ - 'cons1': scalar representation of 'val1'
+ (e.g. when val1 is the name of a fit parameter, cons1
+ will be the line index of this parameter)
+ - 'cons2': scalar representation of 'val2'
+ - 'vmin': equal to 'val1' when 'code' is "QUOTED"
+ - 'vmax': equal to 'val2' when 'code' is "QUOTED"
+ - 'relatedto': name of related parameter when this parameter
+ is constrained to another parameter (same as 'val1')
+ - 'factor': same as 'val2' when 'code' is 'FACTOR'
+ - 'delta': same as 'val2' when 'code' is 'DELTA'
+ - 'sum': same as 'val2' when 'code' is 'SUM'
+ - 'group': group index for the parameter
+ - 'xmin': data range minimum
+ - 'xmax': data range maximum
+ """
+ for line, param in enumerate(paramlist):
+ self.newParameterLine(param, line)
+
+ self.code_options = ["FREE", "POSITIVE", "QUOTED", "FIXED",
+ "FACTOR", "DELTA", "SUM", "IGNORE", "ADD"]
+ """Possible values in the combo boxes in the 'Constraints' column.
+ """
+
+ # connect signal
+ self.cellChanged[int, int].connect(self.onCellChanged)
+
+ def newParameterLine(self, param, line):
+ """Add a line to the :class:`QTableWidget`.
+
+ Each line represents one of the fit parameters for one of
+ the fitted peaks.
+
+ :param param: Name of the fit parameter
+ :type param: str
+ :param line: 0-based line index
+ :type line: int
+ """
+ # get current number of lines
+ nlines = self.rowCount()
+ self.__configuring = True
+ if line >= nlines:
+ self.setRowCount(line + 1)
+
+ # default configuration for fit parameters
+ self.parameters[param] = OrderedDict((('line', line),
+ ('estimation', '0'),
+ ('fitresult', ''),
+ ('sigma', ''),
+ ('code', 'FREE'),
+ ('val1', ''),
+ ('val2', ''),
+ ('cons1', 0),
+ ('cons2', 0),
+ ('vmin', '0'),
+ ('vmax', '1'),
+ ('relatedto', ''),
+ ('factor', '1.0'),
+ ('delta', '0.0'),
+ ('sum', '0.0'),
+ ('group', ''),
+ ('name', param),
+ ('xmin', None),
+ ('xmax', None)))
+ self.setReadWrite(param, 'estimation')
+ self.setReadOnly(param, ['name', 'fitresult', 'sigma', 'val1', 'val2'])
+
+ # Constraint codes
+ a = []
+ for option in self.code_options:
+ a.append(option)
+
+ code_column_index = self.columnIndexByField('code')
+ cellWidget = self.cellWidget(line, code_column_index)
+ if cellWidget is None:
+ cellWidget = QComboTableItem(self, row=line,
+ col=code_column_index)
+ cellWidget.addItems(a)
+ self.setCellWidget(line, code_column_index, cellWidget)
+ cellWidget.sigCellChanged[int, int].connect(self.onCellChanged)
+ self.parameters[param]['code_item'] = cellWidget
+ self.parameters[param]['relatedto_item'] = None
+ self.__configuring = False
+
+ def columnIndexByField(self, field):
+ """
+
+ :param field: Field name (column key)
+ :return: Index of the column with this field name
+ """
+ return self.columnKeys.index(field)
+
+ def fillFromFit(self, fitresults):
+ """Fill table with values from a list of dictionaries
+ (see :attr:`silx.math.fit.fitmanager.FitManager.fit_results`)
+
+ :param fitresults: List of parameters as recorded
+ in the ``paramlist`` attribute of a :class:`FitManager` object
+ :type fitresults: list[dict]
+ """
+ self.setRowCount(len(fitresults))
+
+ # Reinitialize and fill self.parameters
+ self.parameters = OrderedDict()
+ for (line, param) in enumerate(fitresults):
+ self.newParameterLine(param['name'], line)
+
+ for param in fitresults:
+ name = param['name']
+ code = str(param['code'])
+ if code not in self.code_options:
+ # convert code from int to descriptive string
+ code = self.code_options[int(code)]
+ val1 = param['cons1']
+ val2 = param['cons2']
+ estimation = param['estimation']
+ group = param['group']
+ sigma = param['sigma']
+ fitresult = param['fitresult']
+
+ xmin = param.get('xmin')
+ xmax = param.get('xmax')
+
+ self.configureLine(name=name,
+ code=code,
+ val1=val1, val2=val2,
+ estimation=estimation,
+ fitresult=fitresult,
+ sigma=sigma,
+ group=group,
+ xmin=xmin, xmax=xmax)
+
+ def getConfiguration(self):
+ """Return ``FitManager.paramlist`` dictionary
+ encapsulated in another dictionary"""
+ return {'parameters': self.getFitResults()}
+
+ def setConfiguration(self, ddict):
+ """Fill table with values from a ``FitManager.paramlist`` dictionary
+ encapsulated in another dictionary"""
+ self.fillFromFit(ddict['parameters'])
+
+ def getFitResults(self):
+ """Return fit parameters as a list of dictionaries in the format used
+ by :class:`FitManager` (attribute ``paramlist``).
+ """
+ fitparameterslist = []
+ for param in self.parameters:
+ fitparam = {}
+ name = param
+ estimation, [code, cons1, cons2] = self.getEstimationConstraints(name)
+ buf = str(self.parameters[param]['fitresult'])
+ xmin = self.parameters[param]['xmin']
+ xmax = self.parameters[param]['xmax']
+ if len(buf):
+ fitresult = float(buf)
+ else:
+ fitresult = 0.0
+ buf = str(self.parameters[param]['sigma'])
+ if len(buf):
+ sigma = float(buf)
+ else:
+ sigma = 0.0
+ buf = str(self.parameters[param]['group'])
+ if len(buf):
+ group = float(buf)
+ else:
+ group = 0
+ fitparam['name'] = name
+ fitparam['estimation'] = estimation
+ fitparam['fitresult'] = fitresult
+ fitparam['sigma'] = sigma
+ fitparam['group'] = group
+ fitparam['code'] = code
+ fitparam['cons1'] = cons1
+ fitparam['cons2'] = cons2
+ fitparam['xmin'] = xmin
+ fitparam['xmax'] = xmax
+ fitparameterslist.append(fitparam)
+ return fitparameterslist
+
+ def onCellChanged(self, row, col):
+ """Slot called when ``cellChanged`` signal is emitted.
+ Checks the validity of the new text in the cell, then calls
+ :meth:`configureLine` to update the internal ``self.parameters``
+ dictionary.
+
+ :param row: Row number of the changed cell (0-based index)
+ :param col: Column number of the changed cell (0-based index)
+ """
+ if (col != self.columnIndexByField("code")) and (col != -1):
+ if row != self.currentRow():
+ return
+ if col != self.currentColumn():
+ return
+ if self.__configuring:
+ return
+ param = list(self.parameters)[row]
+ field = self.columnKeys[col]
+ oldvalue = self.parameters[param][field]
+ if col != 4:
+ item = self.item(row, col)
+ if item is not None:
+ newvalue = item.text()
+ else:
+ newvalue = ''
+ else:
+ # this is the combobox
+ widget = self.cellWidget(row, col)
+ newvalue = widget.currentText()
+ if self.validate(param, field, oldvalue, newvalue):
+ paramdict = {"name": param, field: newvalue}
+ self.configureLine(**paramdict)
+ else:
+ if field == 'code':
+ # New code not valid, try restoring the old one
+ index = self.code_options.index(oldvalue)
+ self.__configuring = True
+ try:
+ self.parameters[param]['code_item'].setCurrentIndex(index)
+ finally:
+ self.__configuring = False
+ else:
+ paramdict = {"name": param, field: oldvalue}
+ self.configureLine(**paramdict)
+
+ def validate(self, param, field, oldvalue, newvalue):
+ """Check validity of ``newvalue`` when a cell's value is modified.
+
+ :param param: Fit parameter name
+ :param field: Column name
+ :param oldvalue: Cell value before change attempt
+ :param newvalue: New value to be validated
+ :return: True if new cell value is valid, else False
+ """
+ if field == 'code':
+ return self.setCodeValue(param, oldvalue, newvalue)
+ # FIXME: validate() shouldn't have side effects. Move this bit to configureLine()?
+ if field == 'val1' and str(self.parameters[param]['code']) in ['DELTA', 'FACTOR', 'SUM']:
+ _, candidates = self.getRelatedCandidates(param)
+ # We expect val1 to be a fit parameter name
+ if str(newvalue) in candidates:
+ return True
+ else:
+ return False
+ # except for code, val1 and name (which is read-only and does not need
+ # validation), all fields must always be convertible to float
+ else:
+ try:
+ float(str(newvalue))
+ except ValueError:
+ return False
+ return True
+
+ def setCodeValue(self, param, oldvalue, newvalue):
+ """Update 'code' and 'relatedto' fields when code cell is
+ changed.
+
+ :param param: Fit parameter name
+ :param oldvalue: Cell value before change attempt
+ :param newvalue: New value to be validated
+ :return: ``True`` if code was successfully updated
+ """
+
+ if str(newvalue) in ['FREE', 'POSITIVE', 'QUOTED', 'FIXED']:
+ self.configureLine(name=param,
+ code=newvalue)
+ if str(oldvalue) == 'IGNORE':
+ self.freeRestOfGroup(param)
+ return True
+ elif str(newvalue) in ['FACTOR', 'DELTA', 'SUM']:
+ # I should check here that some parameter is set
+ best, candidates = self.getRelatedCandidates(param)
+ if len(candidates) == 0:
+ return False
+ self.configureLine(name=param,
+ code=newvalue,
+ relatedto=best)
+ if str(oldvalue) == 'IGNORE':
+ self.freeRestOfGroup(param)
+ return True
+
+ elif str(newvalue) == 'IGNORE':
+ # I should check if the group can be ignored
+ # for the time being I just fix all of them to ignore
+ group = int(float(str(self.parameters[param]['group'])))
+ candidates = []
+ for param in self.parameters.keys():
+ if group == int(float(str(self.parameters[param]['group']))):
+ candidates.append(param)
+ # print candidates
+ # I should check here if there is any relation to them
+ for param in candidates:
+ self.configureLine(name=param,
+ code=newvalue)
+ return True
+ elif str(newvalue) == 'ADD':
+ group = int(float(str(self.parameters[param]['group'])))
+ if group == 0:
+ # One cannot add a background group
+ return False
+ i = 0
+ for param in self.parameters:
+ if i <= int(float(str(self.parameters[param]['group']))):
+ i += 1
+ if (group == 0) and (i == 1): # FIXME: why +1?
+ i += 1
+ self.addGroup(i, group)
+ return False
+ elif str(newvalue) == 'SHOW':
+ print(self.getEstimationConstraints(param))
+ return False
+
+ def addGroup(self, newg, gtype):
+ """Add a fit parameter group with the same fit parameters as an
+ existing group.
+
+ This function is called when the user selects "ADD" in the
+ "constraints" combobox.
+
+ :param int newg: New group number
+ :param int gtype: Group number whose parameters we want to copy
+
+ """
+ newparam = []
+ # loop through parameters until we encounter group number `gtype`
+ for param in list(self.parameters):
+ paramgroup = int(float(str(self.parameters[param]['group'])))
+ # copy parameter names in group number `gtype`
+ if paramgroup == gtype:
+ # but replace `gtype` with `newg`
+ newparam.append(param.rstrip("0123456789") + "%d" % newg)
+
+ xmin = self.parameters[param]['xmin']
+ xmax = self.parameters[param]['xmax']
+
+ # Add new parameters (one table line per parameter) and configureLine each
+ # one by updating xmin and xmax to the same values as group `gtype`
+ line = len(list(self.parameters))
+ for param in newparam:
+ self.newParameterLine(param, line)
+ line += 1
+ for param in newparam:
+ self.configureLine(name=param, group=newg, xmin=xmin, xmax=xmax)
+
+ def freeRestOfGroup(self, workparam):
+ """Set ``code`` to ``"FREE"`` for all fit parameters belonging to
+ the same group as ``workparam``. This is done when the entire group
+ of parameters was previously ignored and one of them has his code
+ set to something different than ``"IGNORE"``.
+
+ :param workparam: Fit parameter name
+ """
+ if workparam in self.parameters.keys():
+ group = int(float(str(self.parameters[workparam]['group'])))
+ for param in self.parameters:
+ if param != workparam and\
+ group == int(float(str(self.parameters[param]['group']))):
+ self.configureLine(name=param,
+ code='FREE',
+ cons1=0,
+ cons2=0,
+ val1='',
+ val2='')
+
+ def getRelatedCandidates(self, workparam):
+ """If fit parameter ``workparam`` has a constraint that involves other
+ fit parameters, find possible candidates and try to guess which one
+ is the most likely.
+
+ :param workparam: Fit parameter name
+ :return: (best_candidate, possible_candidates) tuple
+ :rtype: (str, list[str])
+ """
+ candidates = []
+ for param_name in self.parameters:
+ if param_name != workparam:
+ # ignore parameters that are fixed by a constraint
+ if str(self.parameters[param_name]['code']) not in\
+ ['IGNORE', 'FACTOR', 'DELTA', 'SUM']:
+ candidates.append(param_name)
+ # take the previous one (before code cell changed) if possible
+ if str(self.parameters[workparam]['relatedto']) in candidates:
+ best = str(self.parameters[workparam]['relatedto'])
+ return best, candidates
+ # take the first with same base name (after removing numbers)
+ for param_name in candidates:
+ basename = param_name.rstrip("0123456789")
+ try:
+ pos = workparam.index(basename)
+ if pos == 0:
+ best = param_name
+ return best, candidates
+ except ValueError:
+ pass
+ # take the first
+ return candidates[0], candidates
+
+ def setReadOnly(self, parameter, fields):
+ """Make table cells read-only by setting it's flags and omitting
+ flag ``qt.Qt.ItemIsEditable``
+
+ :param parameter: Fit parameter names identifying the rows
+ :type parameter: str or list[str]
+ :param fields: Field names identifying the columns
+ :type fields: str or list[str]
+ """
+ editflags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled
+ self.setField(parameter, fields, editflags)
+
+ def setReadWrite(self, parameter, fields):
+ """Make table cells read-write by setting it's flags including
+ flag ``qt.Qt.ItemIsEditable``
+
+ :param parameter: Fit parameter names identifying the rows
+ :type parameter: str or list[str]
+ :param fields: Field names identifying the columns
+ :type fields: str or list[str]
+ """
+ editflags = qt.Qt.ItemIsSelectable |\
+ qt.Qt.ItemIsEnabled |\
+ qt.Qt.ItemIsEditable
+ self.setField(parameter, fields, editflags)
+
+ def setField(self, parameter, fields, edit_flags):
+ """Set text and flags in a table cell.
+
+ :param parameter: Fit parameter names identifying the rows
+ :type parameter: str or list[str]
+ :param fields: Field names identifying the columns
+ :type fields: str or list[str]
+ :param edit_flags: Flag combination, e.g::
+
+ qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable
+ """
+ if isinstance(parameter, list) or \
+ isinstance(parameter, tuple):
+ paramlist = parameter
+ else:
+ paramlist = [parameter]
+ if isinstance(fields, list) or \
+ isinstance(fields, tuple):
+ fieldlist = fields
+ else:
+ fieldlist = [fields]
+
+ # Set _configuring flag to ignore cellChanged signals in
+ # self.onCellChanged
+ _oldvalue = self.__configuring
+ self.__configuring = True
+
+ # 2D loop through parameter list and field list
+ # to update their cells
+ for param in paramlist:
+ row = list(self.parameters.keys()).index(param)
+ for field in fieldlist:
+ col = self.columnIndexByField(field)
+ if field != 'code':
+ key = field + "_item"
+ item = self.item(row, col)
+ if item is None:
+ item = qt.QTableWidgetItem()
+ item.setText(self.parameters[param][field])
+ self.setItem(row, col, item)
+ else:
+ item.setText(self.parameters[param][field])
+ self.parameters[param][key] = item
+ item.setFlags(edit_flags)
+
+ # Restore previous _configuring flag
+ self.__configuring = _oldvalue
+
+ def configureLine(self, name, code=None, val1=None, val2=None,
+ sigma=None, estimation=None, fitresult=None,
+ group=None, xmin=None, xmax=None, relatedto=None,
+ cons1=None, cons2=None):
+ """This function updates values in a line of the table
+
+ :param name: Name of the parameter (serves as unique identifier for
+ a line).
+ :param code: Constraint code *FREE, FIXED, POSITIVE, DELTA, FACTOR,
+ SUM, QUOTED, IGNORE*
+ :param val1: Constraint 1 (can be the index or name of another
+ parameter for code *DELTA, FACTOR, SUM*, or a min value
+ for code *QUOTED*)
+ :param val2: Constraint 2
+ :param sigma: Standard deviation for a fit parameter
+ :param estimation: Estimated initial value for a fit parameter (used
+ as input to iterative fit)
+ :param fitresult: Final result of fit
+ :param group: Group number of a fit parameter (peak number when doing
+ multi-peak fitting, as each peak corresponds to a group
+ of several consecutive parameters)
+ :param xmin:
+ :param xmax:
+ :param relatedto: Index or name of another fit parameter
+ to which this parameter is related to (constraints)
+ :param cons1: similar meaning to ``val1``, but is always a number
+ :param cons2: similar meaning to ``val2``, but is always a number
+ :return:
+ """
+ paramlist = list(self.parameters.keys())
+
+ if name not in self.parameters:
+ raise KeyError("'%s' is not in the parameter list" % name)
+
+ # update code first, if specified
+ if code is not None:
+ code = str(code)
+ self.parameters[name]['code'] = code
+ # update combobox
+ index = self.parameters[name]['code_item'].findText(code)
+ self.parameters[name]['code_item'].setCurrentIndex(index)
+ else:
+ # set code to previous value, used later for setting val1 val2
+ code = self.parameters[name]['code']
+
+ # val1 and sigma have special formats
+ if val1 is not None:
+ fmt = None if self.parameters[name]['code'] in\
+ ['DELTA', 'FACTOR', 'SUM'] else "%8g"
+ self._updateField(name, "val1", val1, fmat=fmt)
+
+ if sigma is not None:
+ self._updateField(name, "sigma", sigma, fmat="%6.3g")
+
+ # other fields are formatted as "%8g"
+ keys_params = (("val2", val2), ("estimation", estimation),
+ ("fitresult", fitresult))
+ for key, value in keys_params:
+ if value is not None:
+ self._updateField(name, key, value, fmat="%8g")
+
+ # the rest of the parameters are treated as strings and don't need
+ # validation
+ keys_params = (("group", group), ("xmin", xmin),
+ ("xmax", xmax), ("relatedto", relatedto),
+ ("cons1", cons1), ("cons2", cons2))
+ for key, value in keys_params:
+ if value is not None:
+ self.parameters[name][key] = str(value)
+
+ # val1 and val2 have different meanings depending on the code
+ if code == 'QUOTED':
+ if val1 is not None:
+ self.parameters[name]['vmin'] = self.parameters[name]['val1']
+ else:
+ self.parameters[name]['val1'] = self.parameters[name]['vmin']
+ if val2 is not None:
+ self.parameters[name]['vmax'] = self.parameters[name]['val2']
+ else:
+ self.parameters[name]['val2'] = self.parameters[name]['vmax']
+
+ # cons1 and cons2 are scalar representations of val1 and val2
+ self.parameters[name]['cons1'] =\
+ float_else_zero(self.parameters[name]['val1'])
+ self.parameters[name]['cons2'] =\
+ float_else_zero(self.parameters[name]['val2'])
+
+ # cons1, cons2 = min(val1, val2), max(val1, val2)
+ if self.parameters[name]['cons1'] > self.parameters[name]['cons2']:
+ self.parameters[name]['cons1'], self.parameters[name]['cons2'] =\
+ self.parameters[name]['cons2'], self.parameters[name]['cons1']
+
+ elif code in ['DELTA', 'SUM', 'FACTOR']:
+ # For these codes, val1 is the fit parameter name on which the
+ # constraint depends
+ if val1 is not None and val1 in paramlist:
+ self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
+
+ elif val1 is not None:
+ # val1 could be the index of the fit parameter
+ try:
+ self.parameters[name]['relatedto'] = paramlist[int(val1)]
+ except ValueError:
+ self.parameters[name]['relatedto'] = self.parameters[name]["val1"]
+
+ elif relatedto is not None:
+ # code changed, val1 not specified but relatedto specified:
+ # set val1 to relatedto (pre-fill best guess)
+ self.parameters[name]["val1"] = relatedto
+
+ # update fields "delta", "sum" or "factor"
+ key = code.lower()
+ self.parameters[name][key] = self.parameters[name]["val2"]
+
+ # FIXME: val1 is sometimes specified as an index rather than a param name
+ self.parameters[name]['val1'] = self.parameters[name]['relatedto']
+
+ # cons1 is the index of the fit parameter in the ordered dictionary
+ if self.parameters[name]['val1'] in paramlist:
+ self.parameters[name]['cons1'] =\
+ paramlist.index(self.parameters[name]['val1'])
+
+ # cons2 is the constraint value (factor, delta or sum)
+ try:
+ self.parameters[name]['cons2'] =\
+ float(str(self.parameters[name]['val2']))
+ except ValueError:
+ self.parameters[name]['cons2'] = 1.0 if code == "FACTOR" else 0.0
+
+ elif code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
+ self.parameters[name]['val1'] = ""
+ self.parameters[name]['val2'] = ""
+ self.parameters[name]['cons1'] = 0
+ self.parameters[name]['cons2'] = 0
+
+ self._updateCellRWFlags(name, code)
+
+ def _updateField(self, name, field, value, fmat=None):
+ """Update field in ``self.parameters`` dictionary, if the new value
+ is valid.
+
+ :param name: Fit parameter name
+ :param field: Field name
+ :param value: New value to assign
+ :type value: String
+ :param fmat: Format string (e.g. "%8g") to be applied if value represents
+ a scalar. If ``None``, format is not modified. If ``value`` is an
+ empty string, ``fmat`` is ignored.
+ """
+ if value is not None:
+ oldvalue = self.parameters[name][field]
+ if fmat is not None:
+ newvalue = fmat % float(value) if value != "" else ""
+ else:
+ newvalue = value
+ self.parameters[name][field] = newvalue if\
+ self.validate(name, field, oldvalue, newvalue) else\
+ oldvalue
+
+ def _updateCellRWFlags(self, name, code=None):
+ """Set read-only or read-write flags in a row,
+ depending on the constraint code
+
+ :param name: Fit parameter name identifying the row
+ :param code: Constraint code, in `'FREE', 'POSITIVE', 'IGNORE',`
+ `'FIXED', 'FACTOR', 'DELTA', 'SUM', 'ADD'`
+ :return:
+ """
+ if code in ['FREE', 'POSITIVE', 'IGNORE', 'FIXED']:
+ self.setReadWrite(name, 'estimation')
+ self.setReadOnly(name, ['fitresult', 'sigma', 'val1', 'val2'])
+ else:
+ self.setReadWrite(name, ['estimation', 'val1', 'val2'])
+ self.setReadOnly(name, ['fitresult', 'sigma'])
+
+ def getEstimationConstraints(self, param):
+ """
+ Return tuple ``(estimation, constraints)`` where ``estimation`` is the
+ value in the ``estimate`` field and ``constraints`` are the relevant
+ constraints according to the active code
+ """
+ estimation = None
+ constraints = None
+ if param in self.parameters.keys():
+ buf = str(self.parameters[param]['estimation'])
+ if len(buf):
+ estimation = float(buf)
+ else:
+ estimation = 0
+ if str(self.parameters[param]['code']) in self.code_options:
+ code = self.code_options.index(
+ str(self.parameters[param]['code']))
+ else:
+ code = str(self.parameters[param]['code'])
+ cons1 = self.parameters[param]['cons1']
+ cons2 = self.parameters[param]['cons2']
+ constraints = [code, cons1, cons2]
+ return estimation, constraints
+
+
+def main(args):
+ from silx.math.fit import fittheories
+ from silx.math.fit import fitmanager
+ try:
+ from PyMca5 import PyMcaDataDir
+ except ImportError:
+ raise ImportError("This demo requires PyMca data. Install PyMca5.")
+ import numpy
+ import os
+ app = qt.QApplication(args)
+ tab = Parameters(paramlist=['Height', 'Position', 'FWHM'])
+ tab.showGrid()
+ tab.configureLine(name='Height', estimation='1234', group=0)
+ tab.configureLine(name='Position', code='FIXED', group=1)
+ tab.configureLine(name='FWHM', group=1)
+
+ y = numpy.loadtxt(os.path.join(PyMcaDataDir.PYMCA_DATA_DIR,
+ "XRFSpectrum.mca")) # FIXME
+
+ x = numpy.arange(len(y)) * 0.0502883 - 0.492773
+ fit = fitmanager.FitManager()
+ fit.setdata(x=x, y=y, xmin=20, xmax=150)
+
+ fit.loadtheories(fittheories)
+
+ fit.settheory('ahypermet')
+ fit.configure(Yscaling=1.,
+ PositiveFwhmFlag=True,
+ PositiveHeightAreaFlag=True,
+ FwhmPoints=16,
+ QuotedPositionFlag=1,
+ HypermetTails=1)
+ fit.setbackground('Linear')
+ fit.estimate()
+ fit.runfit()
+ tab.fillFromFit(fit.fit_results)
+ tab.show()
+ app.exec()
+
+if __name__ == "__main__":
+ main(sys.argv)
diff --git a/silx/gui/fit/__init__.py b/src/silx/gui/fit/__init__.py
index e4fd3ab..e4fd3ab 100644
--- a/silx/gui/fit/__init__.py
+++ b/src/silx/gui/fit/__init__.py
diff --git a/silx/gui/fit/setup.py b/src/silx/gui/fit/setup.py
index 6672363..6672363 100644
--- a/silx/gui/fit/setup.py
+++ b/src/silx/gui/fit/setup.py
diff --git a/src/silx/gui/fit/test/__init__.py b/src/silx/gui/fit/test/__init__.py
new file mode 100644
index 0000000..71128fb
--- /dev/null
+++ b/src/silx/gui/fit/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/fit/test/testBackgroundWidget.py b/src/silx/gui/fit/test/testBackgroundWidget.py
new file mode 100644
index 0000000..b8570f7
--- /dev/null
+++ b/src/silx/gui/fit/test/testBackgroundWidget.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+
+from .. import BackgroundWidget
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+class TestBackgroundWidget(TestCaseQt):
+ def setUp(self):
+ super(TestBackgroundWidget, self).setUp()
+ self.bgdialog = BackgroundWidget.BackgroundDialog()
+ self.bgdialog.setData(list([0, 1, 2, 3]),
+ list([0, 1, 4, 8]))
+ self.qWaitForWindowExposed(self.bgdialog)
+
+ def tearDown(self):
+ del self.bgdialog
+ super(TestBackgroundWidget, self).tearDown()
+
+ def testShow(self):
+ self.bgdialog.show()
+ self.bgdialog.hide()
+
+ def testAccept(self):
+ self.bgdialog.accept()
+ self.assertTrue(self.bgdialog.result())
+
+ def testReject(self):
+ self.bgdialog.reject()
+ self.assertFalse(self.bgdialog.result())
+
+ def testDefaultOutput(self):
+ self.bgdialog.accept()
+ output = self.bgdialog.output
+
+ for key in ["algorithm", "StripThreshold", "SnipWidth",
+ "StripIterations", "StripWidth", "SmoothingFlag",
+ "SmoothingWidth", "AnchorsFlag", "AnchorsList"]:
+ self.assertIn(key, output)
+
+ self.assertFalse(output["AnchorsFlag"])
+ self.assertEqual(output["StripWidth"], 1)
+ self.assertEqual(output["SmoothingFlag"], False)
+ self.assertEqual(output["SmoothingWidth"], 3)
diff --git a/src/silx/gui/fit/test/testFitConfig.py b/src/silx/gui/fit/test/testFitConfig.py
new file mode 100644
index 0000000..53da2dd
--- /dev/null
+++ b/src/silx/gui/fit/test/testFitConfig.py
@@ -0,0 +1,84 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for :class:`FitConfig`"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+from .. import FitConfig
+
+
+class TestFitConfig(TestCaseQt):
+ """Basic test for FitWidget"""
+
+ def setUp(self):
+ super(TestFitConfig, self).setUp()
+ self.fit_config = FitConfig.getFitConfigDialog(modal=False)
+ self.qWaitForWindowExposed(self.fit_config)
+
+ def tearDown(self):
+ del self.fit_config
+ super(TestFitConfig, self).tearDown()
+
+ def testShow(self):
+ self.fit_config.show()
+ self.fit_config.hide()
+
+ def testAccept(self):
+ self.fit_config.accept()
+ self.assertTrue(self.fit_config.result())
+
+ def testReject(self):
+ self.fit_config.reject()
+ self.assertFalse(self.fit_config.result())
+
+ def testDefaultOutput(self):
+ self.fit_config.accept()
+ output = self.fit_config.output
+
+ for key in ["AutoFwhm",
+ "PositiveHeightAreaFlag",
+ "QuotedPositionFlag",
+ "PositiveFwhmFlag",
+ "SameFwhmFlag",
+ "QuotedEtaFlag",
+ "NoConstraintsFlag",
+ "FwhmPoints",
+ "Sensitivity",
+ "Yscaling",
+ "ForcePeakPresence",
+ "StripBackgroundFlag",
+ "StripWidth",
+ "StripIterations",
+ "StripThreshold",
+ "SmoothingFlag"]:
+ self.assertIn(key, output)
+
+ self.assertTrue(output["AutoFwhm"])
+ self.assertEqual(output["StripWidth"], 2)
diff --git a/src/silx/gui/fit/test/testFitWidget.py b/src/silx/gui/fit/test/testFitWidget.py
new file mode 100644
index 0000000..abe9d89
--- /dev/null
+++ b/src/silx/gui/fit/test/testFitWidget.py
@@ -0,0 +1,124 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for :class:`FitWidget`"""
+
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+
+from ... import qt
+from .. import FitWidget
+
+from ....math.fit.fittheory import FitTheory
+from ....math.fit.fitmanager import FitManager
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+class TestFitWidget(TestCaseQt):
+ """Basic test for FitWidget"""
+
+ def setUp(self):
+ super(TestFitWidget, self).setUp()
+ self.fit_widget = FitWidget()
+ self.fit_widget.show()
+ self.qWaitForWindowExposed(self.fit_widget)
+
+ def tearDown(self):
+ self.fit_widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.fit_widget.close()
+ del self.fit_widget
+ super(TestFitWidget, self).tearDown()
+
+ def testShow(self):
+ pass
+
+ def testInteract(self):
+ self.mouseClick(self.fit_widget, qt.Qt.LeftButton)
+ self.keyClick(self.fit_widget, qt.Qt.Key_Enter)
+ self.qapp.processEvents()
+
+ def testCustomConfigWidget(self):
+ class CustomConfigWidget(qt.QDialog):
+ def __init__(self):
+ qt.QDialog.__init__(self)
+ self.setModal(True)
+ self.ok = qt.QPushButton("ok", self)
+ self.ok.clicked.connect(self.accept)
+ cancel = qt.QPushButton("cancel", self)
+ cancel.clicked.connect(self.reject)
+ layout = qt.QVBoxLayout(self)
+ layout.addWidget(self.ok)
+ layout.addWidget(cancel)
+ self.output = {"hello": "world"}
+
+ def fitfun(x, a, b):
+ return a * x + b
+
+ x = list(range(0, 100))
+ y = [fitfun(x_, 2, 3) for x_ in x]
+
+ def conf(**kw):
+ return {"spam": "eggs",
+ "hello": "world!"}
+
+ theory = FitTheory(
+ function=fitfun,
+ parameters=["a", "b"],
+ configure=conf)
+
+ fitmngr = FitManager()
+ fitmngr.setdata(x, y)
+ fitmngr.addtheory("foo", theory)
+ fitmngr.addtheory("bar", theory)
+ fitmngr.addbgtheory("spam", theory)
+
+ fw = FitWidget(fitmngr=fitmngr)
+ fw.associateConfigDialog("spam", CustomConfigWidget(),
+ theory_is_background=True)
+ fw.associateConfigDialog("foo", CustomConfigWidget())
+ fw.show()
+ self.qWaitForWindowExposed(fw)
+
+ fw.bgconfigdialogs["spam"].accept()
+ self.assertTrue(fw.bgconfigdialogs["spam"].result())
+
+ self.assertEqual(fw.bgconfigdialogs["spam"].output,
+ {"hello": "world"})
+
+ fw.bgconfigdialogs["spam"].reject()
+ self.assertFalse(fw.bgconfigdialogs["spam"].result())
+
+ fw.configdialogs["foo"].accept()
+ self.assertTrue(fw.configdialogs["foo"].result())
+
+ # todo: figure out how to click fw.configdialog.ok to close dialog
+ # open dialog
+ # self.mouseClick(fw.guiConfig.FunConfigureButton, qt.Qt.LeftButton)
+ # clove dialog
+ # self.mouseClick(fw.configdialogs["foo"].ok, qt.Qt.LeftButton)
+ # self.qapp.processEvents()
diff --git a/src/silx/gui/hdf5/Hdf5Formatter.py b/src/silx/gui/hdf5/Hdf5Formatter.py
new file mode 100644
index 0000000..6c3de41
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5Formatter.py
@@ -0,0 +1,240 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a class sharred by widgets to format HDF5 data as
+text."""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "06/06/2018"
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.data.TextFormatter import TextFormatter
+
+import h5py
+
+
+class Hdf5Formatter(qt.QObject):
+ """Formatter to convert HDF5 data to string.
+ """
+
+ formatChanged = qt.Signal()
+ """Emitted when properties of the formatter change."""
+
+ def __init__(self, parent=None, textFormatter=None):
+ """
+ Constructor
+
+ :param qt.QObject parent: Owner of the object
+ :param TextFormatter formatter: Text formatter
+ """
+ qt.QObject.__init__(self, parent)
+ if textFormatter is not None:
+ self.__formatter = textFormatter
+ else:
+ self.__formatter = TextFormatter(self)
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+
+ def textFormatter(self):
+ """Returns the used text formatter
+
+ :rtype: TextFormatter
+ """
+ return self.__formatter
+
+ def setTextFormatter(self, textFormatter):
+ """Set the text formatter to be used
+
+ :param TextFormatter textFormatter: The text formatter to use
+ """
+ if textFormatter is None:
+ raise ValueError("Formatter expected but None found")
+ if self.__formatter is textFormatter:
+ return
+ self.__formatter.formatChanged.disconnect(self.__formatChanged)
+ self.__formatter = textFormatter
+ self.__formatter.formatChanged.connect(self.__formatChanged)
+ self.__formatChanged()
+
+ def __formatChanged(self):
+ self.formatChanged.emit()
+
+ def humanReadableShape(self, dataset):
+ if dataset.shape is None:
+ return "none"
+ if dataset.shape == tuple():
+ return "scalar"
+ shape = [str(i) for i in dataset.shape]
+ text = u" \u00D7 ".join(shape)
+ return text
+
+ def humanReadableValue(self, dataset):
+ if dataset.shape is None:
+ return "No data"
+
+ dtype = dataset.dtype
+ if dataset.dtype.type == numpy.void:
+ if dtype.fields is None:
+ return "Raw data"
+
+ if dataset.shape == tuple():
+ numpy_object = dataset[()]
+ text = self.__formatter.toString(numpy_object, dtype=dataset.dtype)
+ else:
+ if dataset.size < 5 and dataset.compression is None:
+ numpy_object = dataset[0:5]
+ text = self.__formatter.toString(numpy_object, dtype=dataset.dtype)
+ else:
+ dimension = len(dataset.shape)
+ if dataset.compression is not None:
+ text = "Compressed %dD data" % dimension
+ else:
+ text = "%dD data" % dimension
+ return text
+
+ def humanReadableType(self, dataset, full=False):
+ if hasattr(dataset, "dtype"):
+ dtype = dataset.dtype
+ else:
+ # Fallback...
+ dtype = type(dataset)
+ return self.humanReadableDType(dtype, full)
+
+ def humanReadableDType(self, dtype, full=False):
+ if dtype == bytes or numpy.issubdtype(dtype, numpy.string_):
+ text = "string"
+ if full:
+ text = "ASCII " + text
+ return text
+ elif dtype == str or numpy.issubdtype(dtype, numpy.unicode_):
+ text = "string"
+ if full:
+ text = "UTF-8 " + text
+ return text
+ elif dtype.type == numpy.object_:
+ ref = h5py.check_dtype(ref=dtype)
+ if ref is not None:
+ return "reference"
+ vlen = h5py.check_dtype(vlen=dtype)
+ if vlen is not None:
+ text = self.humanReadableDType(vlen, full=full)
+ if full:
+ text = "variable-length " + text
+ return text
+ return "object"
+ elif dtype.type == numpy.bool_:
+ return "bool"
+ elif dtype.type == numpy.void:
+ if dtype.fields is None:
+ return "opaque"
+ else:
+ if not full:
+ return "compound"
+ else:
+ fields = sorted(dtype.fields.items(), key=lambda e: e[1][1])
+ compound = [d[1][0] for d in fields]
+ compound = [self.humanReadableDType(d) for d in compound]
+ return "compound(%s)" % ", ".join(compound)
+ elif numpy.issubdtype(dtype, numpy.integer):
+ enumType = h5py.check_dtype(enum=dtype)
+ if enumType is not None:
+ return "enum"
+
+ text = str(dtype.newbyteorder('N'))
+ if numpy.issubdtype(dtype, numpy.floating):
+ if hasattr(numpy, "float128") and dtype == numpy.float128:
+ text = "float80"
+ if full:
+ text += " (padding 128bits)"
+ elif hasattr(numpy, "float96") and dtype == numpy.float96:
+ text = "float80"
+ if full:
+ text += " (padding 96bits)"
+
+ if full:
+ if dtype.byteorder == "<":
+ text = "Little-endian " + text
+ elif dtype.byteorder == ">":
+ text = "Big-endian " + text
+ elif dtype.byteorder == "=":
+ text = "Native " + text
+
+ dtype = dtype.newbyteorder('N')
+ return text
+
+ def humanReadableHdf5Type(self, dataset):
+ """Format the internal HDF5 type as a string"""
+ t = dataset.id.get_type()
+ class_ = t.get_class()
+ if class_ == h5py.h5t.NO_CLASS:
+ return "NO_CLASS"
+ elif class_ == h5py.h5t.INTEGER:
+ return "INTEGER"
+ elif class_ == h5py.h5t.FLOAT:
+ return "FLOAT"
+ elif class_ == h5py.h5t.TIME:
+ return "TIME"
+ elif class_ == h5py.h5t.STRING:
+ charset = t.get_cset()
+ strpad = t.get_strpad()
+ text = ""
+
+ if strpad == h5py.h5t.STR_NULLTERM:
+ text += "NULLTERM"
+ elif strpad == h5py.h5t.STR_NULLPAD:
+ text += "NULLPAD"
+ elif strpad == h5py.h5t.STR_SPACEPAD:
+ text += "SPACEPAD"
+ else:
+ text += "UNKNOWN_STRPAD"
+
+ if t.is_variable_str():
+ text += " VARIABLE"
+
+ if charset == h5py.h5t.CSET_ASCII:
+ text += " ASCII"
+ elif charset == h5py.h5t.CSET_UTF8:
+ text += " UTF8"
+ else:
+ text += " UNKNOWN_CSET"
+
+ return text + " STRING"
+ elif class_ == h5py.h5t.BITFIELD:
+ return "BITFIELD"
+ elif class_ == h5py.h5t.OPAQUE:
+ return "OPAQUE"
+ elif class_ == h5py.h5t.COMPOUND:
+ return "COMPOUND"
+ elif class_ == h5py.h5t.REFERENCE:
+ return "REFERENCE"
+ elif class_ == h5py.h5t.ENUM:
+ return "ENUM"
+ elif class_ == h5py.h5t.VLEN:
+ return "VLEN"
+ elif class_ == h5py.h5t.ARRAY:
+ return "ARRAY"
+ else:
+ return "UNKNOWN_CLASS"
diff --git a/src/silx/gui/hdf5/Hdf5HeaderView.py b/src/silx/gui/hdf5/Hdf5HeaderView.py
new file mode 100644
index 0000000..7255ce0
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5HeaderView.py
@@ -0,0 +1,184 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "16/06/2017"
+
+
+from .. import qt
+from .Hdf5TreeModel import Hdf5TreeModel
+
+
+class Hdf5HeaderView(qt.QHeaderView):
+ """
+ Default HDF5 header
+
+ Manage auto-resize and context menu to display/hide columns
+ """
+
+ def __init__(self, orientation, parent=None):
+ """
+ Constructor
+
+ :param orientation qt.Qt.Orientation: Orientation of the header
+ :param parent qt.QWidget: Parent of the widget
+ """
+ super(Hdf5HeaderView, self).__init__(orientation, parent)
+ self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ self.customContextMenuRequested.connect(self.__createContextMenu)
+
+ # default initialization done by QTreeView for it's own header
+ self.setSectionsClickable(True)
+ self.setSectionsMovable(True)
+ self.setDefaultAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
+ self.setStretchLastSection(True)
+
+ self.__auto_resize = True
+ self.__hide_columns_popup = True
+
+ def setModel(self, model):
+ """Override model to configure view when a model is expected
+
+ `qt.QHeaderView.setSectionResizeMode` expect already existing columns
+ to work.
+
+ :param model qt.QAbstractItemModel: A model
+ """
+ super(Hdf5HeaderView, self).setModel(model)
+ self.__updateAutoResize()
+
+ def __updateAutoResize(self):
+ """Update the view according to the state of the auto-resize"""
+ if self.__auto_resize:
+ self.setSectionResizeMode(Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.ResizeToContents)
+ self.setSectionResizeMode(Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.ResizeToContents)
+ self.setSectionResizeMode(Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.ResizeToContents)
+ self.setSectionResizeMode(Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.ResizeToContents)
+ self.setSectionResizeMode(Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.ResizeToContents)
+ else:
+ self.setSectionResizeMode(Hdf5TreeModel.NAME_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.TYPE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.SHAPE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.VALUE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.DESCRIPTION_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.NODE_COLUMN, qt.QHeaderView.Interactive)
+ self.setSectionResizeMode(Hdf5TreeModel.LINK_COLUMN, qt.QHeaderView.Interactive)
+
+ def setAutoResizeColumns(self, autoResize):
+ """Enable/disable auto-resize. When auto-resized, the header take care
+ of the content of the column to set fixed size of some of them, or to
+ auto fix the size according to the content.
+
+ :param autoResize bool: Enable/disable auto-resize
+ """
+ if self.__auto_resize == autoResize:
+ return
+ self.__auto_resize = autoResize
+ self.__updateAutoResize()
+
+ def hasAutoResizeColumns(self):
+ """Is auto-resize enabled.
+
+ :rtype: bool
+ """
+ return self.__auto_resize
+
+ autoResizeColumns = qt.Property(bool, hasAutoResizeColumns, setAutoResizeColumns)
+ """Property to enable/disable auto-resize."""
+
+ def setEnableHideColumnsPopup(self, enablePopup):
+ """Enable/disable a popup to allow to hide/show each column of the
+ model.
+
+ :param bool enablePopup: Enable/disable popup to hide/show columns
+ """
+ self.__hide_columns_popup = enablePopup
+
+ def hasHideColumnsPopup(self):
+ """Is popup to hide/show columns is enabled.
+
+ :rtype: bool
+ """
+ return self.__hide_columns_popup
+
+ enableHideColumnsPopup = qt.Property(bool, hasHideColumnsPopup, setAutoResizeColumns)
+ """Property to enable/disable popup allowing to hide/show columns."""
+
+ def __genHideSectionEvent(self, column):
+ """Generate a callback which change the column visibility according to
+ the event parameter
+
+ :param int column: logical id of the column
+ :rtype: callable
+ """
+ return lambda checked: self.setSectionHidden(column, not checked)
+
+ def __createContextMenu(self, pos):
+ """Callback to create and display a context menu
+
+ :param pos qt.QPoint: Requested position for the context menu
+ """
+ if not self.__hide_columns_popup:
+ return
+
+ model = self.model()
+ if model.columnCount() > 1:
+ menu = qt.QMenu(self)
+ menu.setTitle("Display/hide columns")
+
+ action = qt.QAction("Display/hide column", self)
+ action.setEnabled(False)
+ menu.addAction(action)
+
+ for column in range(model.columnCount()):
+ if column == 0:
+ # skip the main column
+ continue
+ text = model.headerData(column, qt.Qt.Horizontal, qt.Qt.DisplayRole)
+ action = qt.QAction("%s displayed" % text, self)
+ action.setCheckable(True)
+ action.setChecked(not self.isSectionHidden(column))
+ action.toggled.connect(self.__genHideSectionEvent(column))
+ menu.addAction(action)
+
+ menu.popup(self.viewport().mapToGlobal(pos))
+
+ def setSections(self, logicalIndexes):
+ """
+ Defines order of visible sections by logical indexes.
+
+ Use `Hdf5TreeModel.NAME_COLUMN` to set the list.
+
+ :param list logicalIndexes: List of logical indexes to display
+ """
+ for pos, column_id in enumerate(logicalIndexes):
+ current_pos = self.visualIndex(column_id)
+ self.moveSection(current_pos, pos)
+ self.setSectionHidden(column_id, False)
+ for column_id in set(range(self.model().columnCount())) - set(logicalIndexes):
+ self.setSectionHidden(column_id, True)
diff --git a/silx/gui/hdf5/Hdf5Item.py b/src/silx/gui/hdf5/Hdf5Item.py
index e07f835..e07f835 100755
--- a/silx/gui/hdf5/Hdf5Item.py
+++ b/src/silx/gui/hdf5/Hdf5Item.py
diff --git a/silx/gui/hdf5/Hdf5LoadingItem.py b/src/silx/gui/hdf5/Hdf5LoadingItem.py
index f11d252..f11d252 100644
--- a/silx/gui/hdf5/Hdf5LoadingItem.py
+++ b/src/silx/gui/hdf5/Hdf5LoadingItem.py
diff --git a/silx/gui/hdf5/Hdf5Node.py b/src/silx/gui/hdf5/Hdf5Node.py
index be16535..be16535 100644
--- a/silx/gui/hdf5/Hdf5Node.py
+++ b/src/silx/gui/hdf5/Hdf5Node.py
diff --git a/src/silx/gui/hdf5/Hdf5TreeModel.py b/src/silx/gui/hdf5/Hdf5TreeModel.py
new file mode 100644
index 0000000..a32f7cf
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5TreeModel.py
@@ -0,0 +1,742 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/03/2019"
+
+
+import os
+import logging
+import functools
+from .. import qt
+from .. import icons
+from .Hdf5Node import Hdf5Node
+from .Hdf5Item import Hdf5Item
+from .Hdf5LoadingItem import Hdf5LoadingItem
+from . import _utils
+from ... import io as silx_io
+
+_logger = logging.getLogger(__name__)
+
+
+def _createRootLabel(h5obj):
+ """
+ Create label for the very first npde of the tree.
+
+ :param h5obj: The h5py object to display in the GUI
+ :type h5obj: h5py-like object
+ :rtpye: str
+ """
+ if silx_io.is_file(h5obj):
+ label = os.path.basename(h5obj.filename)
+ else:
+ filename = os.path.basename(h5obj.file.filename)
+ path = h5obj.name
+ if path.startswith("/"):
+ path = path[1:]
+ label = "%s::%s" % (filename, path)
+ return label
+
+
+class LoadingItemRunnable(qt.QRunnable):
+ """Runner to process item loading from a file"""
+
+ class __Signals(qt.QObject):
+ """Signal holder"""
+ itemReady = qt.Signal(object, object, object)
+ runnerFinished = qt.Signal(object)
+
+ def __init__(self, filename, item):
+ """Constructor
+
+ :param LoadingItemWorker worker: Object holding data and signals
+ """
+ super(LoadingItemRunnable, self).__init__()
+ self.filename = filename
+ self.oldItem = item
+ self.signals = self.__Signals()
+
+ def setFile(self, filename, item):
+ self.filenames.append((filename, item))
+
+ @property
+ def itemReady(self):
+ return self.signals.itemReady
+
+ @property
+ def runnerFinished(self):
+ return self.signals.runnerFinished
+
+ def __loadItemTree(self, oldItem, h5obj):
+ """Create an item tree used by the GUI from an h5py object.
+
+ :param Hdf5Node oldItem: The current item displayed the GUI
+ :param h5py.File h5obj: The h5py object to display in the GUI
+ :rtpye: Hdf5Node
+ """
+ text = _createRootLabel(h5obj)
+ item = Hdf5Item(text=text, obj=h5obj, parent=oldItem.parent, populateAll=True)
+ return item
+
+ def run(self):
+ """Process the file loading. The worker is used as holder
+ of the data and the signal. The result is sent as a signal.
+ """
+ h5file = None
+ try:
+ h5file = silx_io.open(self.filename)
+ newItem = self.__loadItemTree(self.oldItem, h5file)
+ error = None
+ except IOError as e:
+ # Should be logged
+ error = e
+ newItem = None
+ if h5file is not None:
+ h5file.close()
+
+ self.itemReady.emit(self.oldItem, newItem, error)
+ self.runnerFinished.emit(self)
+
+ def autoDelete(self):
+ return True
+
+
+class Hdf5TreeModel(qt.QAbstractItemModel):
+ """Tree model storing a list of :class:`h5py.File` like objects.
+
+ The main column display the :class:`h5py.File` list and there hierarchy.
+ Other columns display information on node hierarchy.
+ """
+
+ H5PY_ITEM_ROLE = qt.Qt.UserRole
+ """Role to reach h5py item from an item index"""
+
+ H5PY_OBJECT_ROLE = qt.Qt.UserRole + 1
+ """Role to reach h5py object from an item index"""
+
+ USER_ROLE = qt.Qt.UserRole + 2
+ """Start of range of available user role for derivative models"""
+
+ NAME_COLUMN = 0
+ """Column id containing HDF5 node names"""
+
+ TYPE_COLUMN = 1
+ """Column id containing HDF5 dataset types"""
+
+ SHAPE_COLUMN = 2
+ """Column id containing HDF5 dataset shapes"""
+
+ VALUE_COLUMN = 3
+ """Column id containing HDF5 dataset values"""
+
+ DESCRIPTION_COLUMN = 4
+ """Column id containing HDF5 node description/title/message"""
+
+ NODE_COLUMN = 5
+ """Column id containing HDF5 node type"""
+
+ LINK_COLUMN = 6
+ """Column id containing HDF5 link type"""
+
+ COLUMN_IDS = [
+ NAME_COLUMN,
+ TYPE_COLUMN,
+ SHAPE_COLUMN,
+ VALUE_COLUMN,
+ DESCRIPTION_COLUMN,
+ NODE_COLUMN,
+ LINK_COLUMN,
+ ]
+ """List of logical columns available"""
+
+ sigH5pyObjectLoaded = qt.Signal(object)
+ """Emitted when a new root item was loaded and inserted to the model."""
+
+ sigH5pyObjectRemoved = qt.Signal(object)
+ """Emitted when a root item is removed from the model."""
+
+ sigH5pyObjectSynchronized = qt.Signal(object, object)
+ """Emitted when an item was synchronized."""
+
+ def __init__(self, parent=None, ownFiles=True):
+ """
+ Constructor
+
+ :param qt.QWidget parent: Parent widget
+ :param bool ownFiles: If true (default) the model will manage the files
+ life cycle when they was added using path (like DnD).
+ """
+ super(Hdf5TreeModel, self).__init__(parent)
+
+ self.header_labels = [None] * len(self.COLUMN_IDS)
+ self.header_labels[self.NAME_COLUMN] = 'Name'
+ self.header_labels[self.TYPE_COLUMN] = 'Type'
+ self.header_labels[self.SHAPE_COLUMN] = 'Shape'
+ self.header_labels[self.VALUE_COLUMN] = 'Value'
+ self.header_labels[self.DESCRIPTION_COLUMN] = 'Description'
+ self.header_labels[self.NODE_COLUMN] = 'Node'
+ self.header_labels[self.LINK_COLUMN] = 'Link'
+
+ # Create items
+ self.__root = Hdf5Node()
+ self.__fileDropEnabled = True
+ self.__fileMoveEnabled = True
+ self.__datasetDragEnabled = False
+
+ self.__animatedIcon = icons.getWaitIcon()
+ self.__animatedIcon.iconChanged.connect(self.__updateLoadingItems)
+ self.__runnerSet = set([])
+
+ # store used icons to avoid the cache to release it
+ self.__icons = []
+ self.__icons.append(icons.getQIcon("item-none"))
+ self.__icons.append(icons.getQIcon("item-0dim"))
+ self.__icons.append(icons.getQIcon("item-1dim"))
+ self.__icons.append(icons.getQIcon("item-2dim"))
+ self.__icons.append(icons.getQIcon("item-3dim"))
+ self.__icons.append(icons.getQIcon("item-ndim"))
+
+ self.__ownFiles = ownFiles
+ self.__openedFiles = []
+ """Store the list of files opened by the model itself."""
+ # FIXME: It should be managed one by one by Hdf5Item itself
+
+ # It is not possible to override the QObject destructor nor
+ # to access to the content of the Python object with the `destroyed`
+ # signal cause the Python method was already removed with the QWidget,
+ # while the QObject still exists.
+ # We use a static method plus explicit references to objects to
+ # release. The callback do not use any ref to self.
+ onDestroy = functools.partial(self._closeFileList, self.__openedFiles)
+ self.destroyed.connect(onDestroy)
+
+ @staticmethod
+ def _closeFileList(fileList):
+ """Static method to close explicit references to internal objects."""
+ _logger.debug("Clear Hdf5TreeModel")
+ for obj in fileList:
+ _logger.debug("Close file %s", obj.filename)
+ obj.close()
+ fileList[:] = []
+
+ def _closeOpened(self):
+ """Close files which was opened by this model.
+
+ File are opened by the model when it was inserted using
+ `insertFileAsync`, `insertFile`, `appendFile`."""
+ self._closeFileList(self.__openedFiles)
+
+ def __updateLoadingItems(self, icon):
+ for i in range(self.__root.childCount()):
+ item = self.__root.child(i)
+ if isinstance(item, Hdf5LoadingItem):
+ index1 = self.index(i, 0, qt.QModelIndex())
+ index2 = self.index(i, self.columnCount() - 1, qt.QModelIndex())
+ self.dataChanged.emit(index1, index2)
+
+ def __itemReady(self, oldItem, newItem, error):
+ """Called at the end of a concurent file loading, when the loading
+ item is ready. AN error is defined if an exception occured when
+ loading the newItem .
+
+ :param Hdf5Node oldItem: current displayed item
+ :param Hdf5Node newItem: item loaded, or None if error is defined
+ :param Exception error: An exception, or None if newItem is defined
+ """
+ row = self.__root.indexOfChild(oldItem)
+
+ rootIndex = qt.QModelIndex()
+ self.beginRemoveRows(rootIndex, row, row)
+ self.__root.removeChildAtIndex(row)
+ self.endRemoveRows()
+
+ if newItem is not None:
+ rootIndex = qt.QModelIndex()
+ if self.__ownFiles:
+ self.__openedFiles.append(newItem.obj)
+ self.beginInsertRows(rootIndex, row, row)
+ self.__root.insertChild(row, newItem)
+ self.endInsertRows()
+
+ if isinstance(oldItem, Hdf5LoadingItem):
+ self.sigH5pyObjectLoaded.emit(newItem.obj)
+ else:
+ self.sigH5pyObjectSynchronized.emit(oldItem.obj, newItem.obj)
+
+ # FIXME the error must be displayed
+
+ def isFileDropEnabled(self):
+ return self.__fileDropEnabled
+
+ def setFileDropEnabled(self, enabled):
+ self.__fileDropEnabled = enabled
+
+ fileDropEnabled = qt.Property(bool, isFileDropEnabled, setFileDropEnabled)
+ """Property to enable/disable file dropping in the model."""
+
+ def isDatasetDragEnabled(self):
+ return self.__datasetDragEnabled
+
+ def setDatasetDragEnabled(self, enabled):
+ self.__datasetDragEnabled = enabled
+
+ datasetDragEnabled = qt.Property(bool, isDatasetDragEnabled, setDatasetDragEnabled)
+ """Property to enable/disable drag of datasets."""
+
+ def isFileMoveEnabled(self):
+ return self.__fileMoveEnabled
+
+ def setFileMoveEnabled(self, enabled):
+ self.__fileMoveEnabled = enabled
+
+ fileMoveEnabled = qt.Property(bool, isFileMoveEnabled, setFileMoveEnabled)
+ """Property to enable/disable drag-and-drop of files to
+ change the ordering in the model."""
+
+ def supportedDropActions(self):
+ if self.__fileMoveEnabled or self.__fileDropEnabled:
+ return qt.Qt.CopyAction | qt.Qt.MoveAction
+ else:
+ return 0
+
+ def mimeTypes(self):
+ types = []
+ if self.__fileMoveEnabled or self.__datasetDragEnabled:
+ types.append(_utils.Hdf5DatasetMimeData.MIME_TYPE)
+ return types
+
+ def mimeData(self, indexes):
+ """
+ Returns an object that contains serialized items of data corresponding
+ to the list of indexes specified.
+
+ :param List[qt.QModelIndex] indexes: List of indexes
+ :rtype: qt.QMimeData
+ """
+ if len(indexes) == 0:
+ return None
+
+ indexes = [i for i in indexes if i.column() == 0]
+ if len(indexes) > 1:
+ raise NotImplementedError("Drag of multi rows is not implemented")
+ if len(indexes) == 0:
+ raise NotImplementedError("Drag of cell is not implemented")
+
+ node = self.nodeFromIndex(indexes[0])
+
+ if self.__fileMoveEnabled and node.parent is self.__root:
+ mimeData = _utils.Hdf5DatasetMimeData(node=node, isRoot=True)
+ elif self.__datasetDragEnabled:
+ mimeData = _utils.Hdf5DatasetMimeData(node=node)
+ else:
+ mimeData = None
+ return mimeData
+
+ def flags(self, index):
+ defaultFlags = qt.QAbstractItemModel.flags(self, index)
+
+ if index.isValid():
+ node = self.nodeFromIndex(index)
+ if self.__fileMoveEnabled and node.parent is self.__root:
+ # that's a root
+ return qt.Qt.ItemIsDragEnabled | defaultFlags
+ elif self.__datasetDragEnabled:
+ return qt.Qt.ItemIsDragEnabled | defaultFlags
+ return defaultFlags
+ elif self.__fileDropEnabled or self.__fileMoveEnabled:
+ return qt.Qt.ItemIsDropEnabled | defaultFlags
+ else:
+ return defaultFlags
+
+ def dropMimeData(self, mimedata, action, row, column, parentIndex):
+ if action == qt.Qt.IgnoreAction:
+ return True
+
+ if self.__fileMoveEnabled and mimedata.hasFormat(_utils.Hdf5DatasetMimeData.MIME_TYPE):
+ if mimedata.isRoot():
+ dragNode = mimedata.node()
+ parentNode = self.nodeFromIndex(parentIndex)
+ if parentNode is not dragNode.parent:
+ return False
+
+ if row == -1:
+ # append to the parent
+ row = parentNode.childCount()
+ else:
+ # insert at row
+ pass
+
+ dragNodeParent = dragNode.parent
+ sourceRow = dragNodeParent.indexOfChild(dragNode)
+ self.moveRow(parentIndex, sourceRow, parentIndex, row)
+ return True
+
+ if self.__fileDropEnabled and mimedata.hasFormat("text/uri-list"):
+
+ parentNode = self.nodeFromIndex(parentIndex)
+ if parentNode is not self.__root:
+ while(parentNode is not self.__root):
+ node = parentNode
+ parentNode = node.parent
+ row = parentNode.indexOfChild(node)
+ else:
+ if row == -1:
+ row = self.__root.childCount()
+
+ messages = []
+ for url in mimedata.urls():
+ try:
+ self.insertFileAsync(url.toLocalFile(), row)
+ row += 1
+ except IOError as e:
+ messages.append(e.args[0])
+ if len(messages) > 0:
+ title = "Error occurred when loading files"
+ message = "<html>%s:<ul><li>%s</li><ul></html>" % (title, "</li><li>".join(messages))
+ qt.QMessageBox.critical(None, title, message)
+ return True
+
+ return False
+
+ def headerData(self, section, orientation, role=qt.Qt.DisplayRole):
+ if orientation == qt.Qt.Horizontal:
+ if role in [qt.Qt.DisplayRole, qt.Qt.EditRole]:
+ return self.header_labels[section]
+ return None
+
+ def insertNode(self, row, node):
+ if row == -1:
+ row = self.__root.childCount()
+ self.beginInsertRows(qt.QModelIndex(), row, row)
+ self.__root.insertChild(row, node)
+ self.endInsertRows()
+
+ def moveRow(self, sourceParentIndex, sourceRow, destinationParentIndex, destinationRow):
+ if sourceRow == destinationRow or sourceRow == destinationRow - 1:
+ # abort move, same place
+ return
+ return self.moveRows(sourceParentIndex, sourceRow, 1, destinationParentIndex, destinationRow)
+
+ def moveRows(self, sourceParentIndex, sourceRow, count, destinationParentIndex, destinationRow):
+ self.beginMoveRows(sourceParentIndex, sourceRow, sourceRow, destinationParentIndex, destinationRow)
+ sourceNode = self.nodeFromIndex(sourceParentIndex)
+ destinationNode = self.nodeFromIndex(destinationParentIndex)
+
+ if sourceNode is destinationNode and sourceRow < destinationRow:
+ item = sourceNode.child(sourceRow)
+ destinationNode.insertChild(destinationRow, item)
+ sourceNode.removeChildAtIndex(sourceRow)
+ else:
+ item = sourceNode.removeChildAtIndex(sourceRow)
+ destinationNode.insertChild(destinationRow, item)
+
+ self.endMoveRows()
+ return True
+
+ def index(self, row, column, parent=qt.QModelIndex()):
+ try:
+ node = self.nodeFromIndex(parent)
+ return self.createIndex(row, column, node.child(row))
+ except IndexError:
+ return qt.QModelIndex()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ node = self.nodeFromIndex(index)
+
+ if role == self.H5PY_ITEM_ROLE:
+ return node
+
+ if role == self.H5PY_OBJECT_ROLE:
+ return node.obj
+
+ if index.column() == self.NAME_COLUMN:
+ return node.dataName(role)
+ elif index.column() == self.TYPE_COLUMN:
+ return node.dataType(role)
+ elif index.column() == self.SHAPE_COLUMN:
+ return node.dataShape(role)
+ elif index.column() == self.VALUE_COLUMN:
+ return node.dataValue(role)
+ elif index.column() == self.DESCRIPTION_COLUMN:
+ return node.dataDescription(role)
+ elif index.column() == self.NODE_COLUMN:
+ return node.dataNode(role)
+ elif index.column() == self.LINK_COLUMN:
+ return node.dataLink(role)
+ else:
+ return None
+
+ def columnCount(self, parent=qt.QModelIndex()):
+ return len(self.COLUMN_IDS)
+
+ def hasChildren(self, parent=qt.QModelIndex()):
+ node = self.nodeFromIndex(parent)
+ if node is None:
+ return 0
+ return node.hasChildren()
+
+ def rowCount(self, parent=qt.QModelIndex()):
+ node = self.nodeFromIndex(parent)
+ if node is None:
+ return 0
+ return node.childCount()
+
+ def parent(self, child):
+ if not child.isValid():
+ return qt.QModelIndex()
+
+ node = self.nodeFromIndex(child)
+
+ if node is None:
+ return qt.QModelIndex()
+
+ parent = node.parent
+
+ if parent is None:
+ return qt.QModelIndex()
+
+ grandparent = parent.parent
+ if grandparent is None:
+ return qt.QModelIndex()
+ row = grandparent.indexOfChild(parent)
+
+ assert row != - 1
+ return self.createIndex(row, 0, parent)
+
+ def nodeFromIndex(self, index):
+ return index.internalPointer() if index.isValid() else self.__root
+
+ def _closeFileIfOwned(self, node):
+ """"Close the file if it was loaded from a filename or a
+ drag-and-drop"""
+ obj = node.obj
+ for f in self.__openedFiles:
+ if f is obj:
+ _logger.debug("Close file %s", obj.filename)
+ obj.close()
+ self.__openedFiles.remove(obj)
+
+ def synchronizeIndex(self, index):
+ """
+ Synchronize a file a given its index.
+
+ Basically close it and load it again.
+
+ :param qt.QModelIndex index: Index of the item to update
+ """
+ node = self.nodeFromIndex(index)
+ if node.parent is not self.__root:
+ return
+
+ filename = node.obj.filename
+ self.insertFileAsync(filename, index.row(), synchronizingNode=node)
+
+ def h5pyObjectRow(self, h5pyObject):
+ for row in range(self.__root.childCount()):
+ item = self.__root.child(row)
+ if item.obj == h5pyObject:
+ return row
+ return -1
+
+ def synchronizeH5pyObject(self, h5pyObject):
+ """
+ Synchronize a h5py object in all the tree.
+
+ Basically close it and load it again.
+
+ :param h5py.File h5pyObject: A :class:`h5py.File` object.
+ """
+ index = 0
+ while index < self.__root.childCount():
+ item = self.__root.child(index)
+ if item.obj == h5pyObject:
+ qindex = self.index(index, 0, qt.QModelIndex())
+ self.synchronizeIndex(qindex)
+ index += 1
+
+ def removeIndex(self, index):
+ """
+ Remove an item from the model using its index.
+
+ :param qt.QModelIndex index: Index of the item to remove
+ """
+ node = self.nodeFromIndex(index)
+ if node.parent != self.__root:
+ return
+ self._closeFileIfOwned(node)
+ self.beginRemoveRows(qt.QModelIndex(), index.row(), index.row())
+ self.__root.removeChildAtIndex(index.row())
+ self.endRemoveRows()
+ self.sigH5pyObjectRemoved.emit(node.obj)
+
+ def removeH5pyObject(self, h5pyObject):
+ """
+ Remove an item from the model using the holding h5py object.
+ It can remove more than one item.
+
+ :param h5py.File h5pyObject: A :class:`h5py.File` object.
+ """
+ index = 0
+ while index < self.__root.childCount():
+ item = self.__root.child(index)
+ if item.obj == h5pyObject:
+ qindex = self.index(index, 0, qt.QModelIndex())
+ self.removeIndex(qindex)
+ else:
+ index += 1
+
+ def insertH5pyObject(self, h5pyObject, text=None, row=-1):
+ """Append an HDF5 object from h5py to the tree.
+
+ :param h5pyObject: File handle/descriptor for a :class:`h5py.File`
+ or any other class of h5py file structure.
+ """
+ if text is None:
+ text = _createRootLabel(h5pyObject)
+ if row == -1:
+ row = self.__root.childCount()
+ self.insertNode(row, Hdf5Item(text=text, obj=h5pyObject, parent=self.__root))
+
+ def hasPendingOperations(self):
+ return len(self.__runnerSet) > 0
+
+ def insertFileAsync(self, filename, row=-1, synchronizingNode=None):
+ if not os.path.isfile(filename):
+ raise IOError("Filename '%s' must be a file path" % filename)
+
+ # create temporary item
+ if synchronizingNode is None:
+ text = os.path.basename(filename)
+ item = Hdf5LoadingItem(text=text, parent=self.__root, animatedIcon=self.__animatedIcon)
+ self.insertNode(row, item)
+ else:
+ item = synchronizingNode
+
+ # start loading the real one
+ runnable = LoadingItemRunnable(filename, item)
+ runnable.itemReady.connect(self.__itemReady)
+ runnable.runnerFinished.connect(self.__releaseRunner)
+ self.__runnerSet.add(runnable)
+ qt.silxGlobalThreadPool().start(runnable)
+
+ def __releaseRunner(self, runner):
+ self.__runnerSet.remove(runner)
+
+ def insertFile(self, filename, row=-1):
+ """Load a HDF5 file into the data model.
+
+ :param filename: file path.
+ """
+ try:
+ h5file = silx_io.open(filename)
+ if self.__ownFiles:
+ self.__openedFiles.append(h5file)
+ self.sigH5pyObjectLoaded.emit(h5file)
+ self.insertH5pyObject(h5file, row=row)
+ except IOError:
+ _logger.debug("File '%s' can't be read.", filename, exc_info=True)
+ raise
+
+ def clear(self):
+ """Remove all the content of the model"""
+ for _ in range(self.rowCount()):
+ qindex = self.index(0, 0, qt.QModelIndex())
+ self.removeIndex(qindex)
+
+ def appendFile(self, filename):
+ self.insertFile(filename, -1)
+
+ def indexFromH5Object(self, h5Object):
+ """Returns a model index from an h5py-like object.
+
+ :param object h5Object: An h5py-like object
+ :rtype: qt.QModelIndex
+ """
+ if h5Object is None:
+ return qt.QModelIndex()
+
+ filename = h5Object.file.filename
+
+ # Seach for the right roots
+ rootIndices = []
+ for index in range(self.rowCount(qt.QModelIndex())):
+ index = self.index(index, 0, qt.QModelIndex())
+ obj = self.data(index, Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ if obj.file.filename == filename:
+ # We can have many roots with different subtree of the same
+ # root
+ rootIndices.append(index)
+
+ if len(rootIndices) == 0:
+ # No root found
+ return qt.QModelIndex()
+
+ path = h5Object.name + "/"
+ path = path.replace("//", "/")
+
+ # Search for the right node
+ found = False
+ foundIndices = []
+ for _ in range(1000 * len(rootIndices)):
+ # Avoid too much iterations, in case of recurssive links
+ if len(foundIndices) == 0:
+ if len(rootIndices) == 0:
+ # Nothing found
+ break
+ # Start fron a new root
+ foundIndices.append(rootIndices.pop(0))
+
+ obj = self.data(index, Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ p = obj.name + "/"
+ p = p.replace("//", "/")
+ if path == p:
+ found = True
+ break
+
+ parentIndex = foundIndices[-1]
+ for index in range(self.rowCount(parentIndex)):
+ index = self.index(index, 0, parentIndex)
+ obj = self.data(index, Hdf5TreeModel.H5PY_OBJECT_ROLE)
+
+ p = obj.name + "/"
+ p = p.replace("//", "/")
+ if path == p:
+ foundIndices.append(index)
+ found = True
+ break
+ elif path.startswith(p):
+ foundIndices.append(index)
+ break
+ else:
+ # Nothing found, start again with another root
+ foundIndices = []
+
+ if found:
+ break
+
+ if found:
+ return foundIndices[-1]
+ return qt.QModelIndex()
diff --git a/src/silx/gui/hdf5/Hdf5TreeView.py b/src/silx/gui/hdf5/Hdf5TreeView.py
new file mode 100644
index 0000000..b276618
--- /dev/null
+++ b/src/silx/gui/hdf5/Hdf5TreeView.py
@@ -0,0 +1,269 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "30/04/2018"
+
+
+import logging
+from .. import qt
+from ...utils import weakref as silxweakref
+from .Hdf5TreeModel import Hdf5TreeModel
+from .Hdf5HeaderView import Hdf5HeaderView
+from .NexusSortFilterProxyModel import NexusSortFilterProxyModel
+from .Hdf5Item import Hdf5Item
+from . import _utils
+
+_logger = logging.getLogger(__name__)
+
+
+class Hdf5TreeView(qt.QTreeView):
+ """TreeView which allow to browse HDF5 file structure.
+
+ .. image:: img/Hdf5TreeView.png
+
+ It provides columns width auto-resizing and additional
+ signals.
+
+ The default model is a :class:`NexusSortFilterProxyModel` sourcing
+ a :class:`Hdf5TreeModel`. The :class:`Hdf5TreeModel` is reachable using
+ :meth:`findHdf5TreeModel`. The default header is :class:`Hdf5HeaderView`.
+
+ Context menu is managed by the :meth:`setContextMenuPolicy` with the value
+ Qt.CustomContextMenu. This policy must not be changed, otherwise context
+ menus will not work anymore. You can use :meth:`addContextMenuCallback` and
+ :meth:`removeContextMenuCallback` to add your custum actions according
+ to the selected objects.
+ """
+ def __init__(self, parent=None):
+ """
+ Constructor
+
+ :param parent qt.QWidget: The parent widget
+ """
+ qt.QTreeView.__init__(self, parent)
+
+ model = self.createDefaultModel()
+ self.setModel(model)
+
+ self.setHeader(Hdf5HeaderView(qt.Qt.Horizontal, self))
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.sortByColumn(0, qt.Qt.AscendingOrder)
+ # optimise the rendering
+ self.setUniformRowHeights(True)
+
+ self.setIconSize(qt.QSize(16, 16))
+ self.setAcceptDrops(True)
+ self.setDragEnabled(True)
+ self.setDragDropMode(qt.QAbstractItemView.DragDrop)
+ self.showDropIndicator()
+
+ self.__context_menu_callbacks = silxweakref.WeakList()
+ self.setContextMenuPolicy(qt.Qt.CustomContextMenu)
+ self.customContextMenuRequested.connect(self._createContextMenu)
+
+ def createDefaultModel(self):
+ """Creates and returns the default model.
+
+ Inherite to custom the default model"""
+ model = Hdf5TreeModel(self)
+ proxy_model = NexusSortFilterProxyModel(self)
+ proxy_model.setSourceModel(model)
+ return proxy_model
+
+ def __removeContextMenuProxies(self, ref):
+ """Callback to remove dead proxy from the list"""
+ self.__context_menu_callbacks.remove(ref)
+
+ def _createContextMenu(self, pos):
+ """
+ Create context menu.
+
+ :param pos qt.QPoint: Position of the context menu
+ """
+ actions = []
+
+ menu = qt.QMenu(self)
+
+ hovered_index = self.indexAt(pos)
+ hovered_node = self.model().data(hovered_index, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ if hovered_node is None or not isinstance(hovered_node, Hdf5Item):
+ return
+
+ hovered_object = _utils.H5Node(hovered_node)
+ event = _utils.Hdf5ContextMenuEvent(self, menu, hovered_object)
+
+ for callback in self.__context_menu_callbacks:
+ try:
+ callback(event)
+ except KeyboardInterrupt:
+ raise
+ except Exception:
+ # make sure no user callback crash the application
+ _logger.error("Error while calling callback", exc_info=True)
+ pass
+
+ if not menu.isEmpty():
+ for action in actions:
+ menu.addAction(action)
+ menu.popup(self.viewport().mapToGlobal(pos))
+
+ def addContextMenuCallback(self, callback):
+ """Register a context menu callback.
+
+ The callback will be called when a context menu is requested with the
+ treeview and the list of selected h5py objects in parameters. The
+ callback must return a list of :class:`qt.QAction` object.
+
+ Callbacks are stored as saferef. The object must store a reference by
+ itself.
+ """
+ self.__context_menu_callbacks.append(callback)
+
+ def removeContextMenuCallback(self, callback):
+ """Unregister a context menu callback"""
+ self.__context_menu_callbacks.remove(callback)
+
+ def findHdf5TreeModel(self):
+ """Find the Hdf5TreeModel from the stack of model filters.
+
+ :returns: A Hdf5TreeModel, else None
+ :rtype: Hdf5TreeModel
+ """
+ model = self.model()
+ while model is not None:
+ if isinstance(model, qt.QAbstractProxyModel):
+ model = model.sourceModel()
+ else:
+ break
+ if model is None:
+ return None
+ if isinstance(model, Hdf5TreeModel):
+ return model
+ else:
+ return None
+
+ def dragEnterEvent(self, event):
+ model = self.findHdf5TreeModel()
+ if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
+ self.setState(qt.QAbstractItemView.DraggingState)
+ event.accept()
+ else:
+ qt.QTreeView.dragEnterEvent(self, event)
+
+ def dragMoveEvent(self, event):
+ model = self.findHdf5TreeModel()
+ if model is not None and model.isFileDropEnabled() and event.mimeData().hasFormat("text/uri-list"):
+ event.setDropAction(qt.Qt.CopyAction)
+ event.accept()
+ else:
+ qt.QTreeView.dragMoveEvent(self, event)
+
+ def selectedH5Nodes(self, ignoreBrokenLinks=True):
+ """Returns selected h5py objects like :class:`h5py.File`,
+ :class:`h5py.Group`, :class:`h5py.Dataset` or mimicked objects.
+
+ :param ignoreBrokenLinks bool: Returns objects which are not not
+ broken links.
+ :rtype: iterator(:class:`_utils.H5Node`)
+ """
+ for index in self.selectedIndexes():
+ if index.column() != 0:
+ continue
+ item = self.model().data(index, Hdf5TreeModel.H5PY_ITEM_ROLE)
+ if item is None:
+ continue
+ if isinstance(item, Hdf5Item):
+ if ignoreBrokenLinks and item.isBrokenObj():
+ continue
+ yield _utils.H5Node(item)
+
+ def __intermediateModels(self, index):
+ """Returns intermediate models from the view model to the
+ model of the index."""
+ models = []
+ targetModel = index.model()
+ model = self.model()
+ while model is not None:
+ if model is targetModel:
+ # found
+ return models
+ models.append(model)
+ if isinstance(model, qt.QAbstractProxyModel):
+ model = model.sourceModel()
+ else:
+ break
+ raise RuntimeError("Model from the requested index is not reachable from this view")
+
+ def mapToModel(self, index):
+ """Map an index from any model reachable by the view to an index from
+ the very first model connected to the view.
+
+ :param qt.QModelIndex index: Index from the Hdf5Tree model
+ :rtype: qt.QModelIndex
+ :return: Index from the model connected to the view
+ """
+ if not index.isValid():
+ return index
+ models = self.__intermediateModels(index)
+ for model in reversed(models):
+ index = model.mapFromSource(index)
+ return index
+
+ def setSelectedH5Node(self, h5Object):
+ """
+ Select the specified node of the tree using an h5py node.
+
+ - If the item is found, parent items are expended, and then the item
+ is selected.
+ - If the item is not found, the selection do not change.
+ - A none argument allow to deselect everything
+
+ :param h5py.Node h5Object: The node to select
+ """
+ if h5Object is None:
+ self.setCurrentIndex(qt.QModelIndex())
+ return
+
+ model = self.findHdf5TreeModel()
+ index = model.indexFromH5Object(h5Object)
+ index = self.mapToModel(index)
+ if index.isValid():
+ # Update the GUI
+ i = index
+ while i.isValid():
+ self.expand(i)
+ i = i.parent()
+ self.setCurrentIndex(index)
+
+ def mousePressEvent(self, event):
+ """Override mousePressEvent to provide a consistante compatible API
+ between Qt4 and Qt5
+ """
+ super(Hdf5TreeView, self).mousePressEvent(event)
+ if event.button() != qt.Qt.LeftButton:
+ qindex = self.indexAt(event.pos())
+ self.clicked.emit(qindex)
diff --git a/silx/gui/hdf5/NexusSortFilterProxyModel.py b/src/silx/gui/hdf5/NexusSortFilterProxyModel.py
index 9c3533f..9c3533f 100644
--- a/silx/gui/hdf5/NexusSortFilterProxyModel.py
+++ b/src/silx/gui/hdf5/NexusSortFilterProxyModel.py
diff --git a/silx/gui/hdf5/__init__.py b/src/silx/gui/hdf5/__init__.py
index 1b5a602..1b5a602 100644
--- a/silx/gui/hdf5/__init__.py
+++ b/src/silx/gui/hdf5/__init__.py
diff --git a/src/silx/gui/hdf5/_utils.py b/src/silx/gui/hdf5/_utils.py
new file mode 100644
index 0000000..8f32252
--- /dev/null
+++ b/src/silx/gui/hdf5/_utils.py
@@ -0,0 +1,461 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a set of helper class and function used by the
+package `silx.gui.hdf5` package.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2019"
+
+
+from html import escape
+import logging
+import os.path
+
+import silx.io.utils
+import silx.io.url
+from .. import qt
+
+_logger = logging.getLogger(__name__)
+
+
+class Hdf5ContextMenuEvent(object):
+ """Hold information provided to context menu callbacks."""
+
+ def __init__(self, source, menu, hoveredObject):
+ """
+ Constructor
+
+ :param QWidget source: Widget source
+ :param QMenu menu: Context menu which will be displayed
+ :param H5Node hoveredObject: Hovered H5 node
+ """
+ self.__source = source
+ self.__menu = menu
+ self.__hoveredObject = hoveredObject
+
+ def source(self):
+ """Source of the event
+
+ :rtype: Hdf5TreeView
+ """
+ return self.__source
+
+ def menu(self):
+ """Menu which will be displayed
+
+ :rtype: qt.QMenu
+ """
+ return self.__menu
+
+ def hoveredObject(self):
+ """Item content hovered by the mouse when the context menu was
+ requested
+
+ :rtype: H5Node
+ """
+ return self.__hoveredObject
+
+
+def htmlFromDict(dictionary, title=None):
+ """Generate a readable HTML from a dictionary
+
+ :param dict dictionary: A Dictionary
+ :rtype: str
+ """
+ result = """<html>
+ <head>
+ <style type="text/css">
+ ul { -qt-list-indent: 0; list-style: none; }
+ li > b {display: inline-block; min-width: 4em; font-weight: bold; }
+ </style>
+ </head>
+ <body>
+ """
+ if title is not None:
+ result += "<b>%s</b>" % escape(title)
+ result += "<ul>"
+ for key, value in dictionary.items():
+ result += "<li><b>%s</b>: %s</li>" % (escape(key), escape(value))
+ result += "</ul>"
+ result += "</body></html>"
+ return result
+
+
+class Hdf5DatasetMimeData(qt.QMimeData):
+ """Mimedata class to identify an internal drag and drop of a Hdf5Node."""
+
+ MIME_TYPE = "application/x-internal-h5py-dataset"
+
+ SILX_URI_TYPE = "application/x-silx-uri"
+
+ def __init__(self, node=None, dataset=None, isRoot=False):
+ qt.QMimeData.__init__(self)
+ self.__dataset = dataset
+ self.__node = node
+ self.__isRoot = isRoot
+ self.setData(self.MIME_TYPE, "".encode(encoding='utf-8'))
+ if node is not None:
+ h5Node = H5Node(node)
+ silxUrl = h5Node.url
+ self.setText(silxUrl)
+ self.setData(self.SILX_URI_TYPE, silxUrl.encode(encoding='utf-8'))
+
+ def isRoot(self):
+ return self.__isRoot
+
+ def node(self):
+ return self.__node
+
+ def dataset(self):
+ if self.__node is not None:
+ return self.__node.obj
+ return self.__dataset
+
+
+class H5Node(object):
+ """Adapter over an h5py object to provide missing informations from h5py
+ nodes, like internal node path and filename (which are not provided by
+ :mod:`h5py` for soft and external links).
+
+ It also provides an abstraction to reach node type for mimicked h5py
+ objects.
+ """
+
+ def __init__(self, h5py_item=None):
+ """Constructor
+
+ :param Hdf5Item h5py_item: An Hdf5Item
+ """
+ self.__h5py_object = h5py_item.obj
+ self.__h5py_target = None
+ self.__h5py_item = h5py_item
+
+ def __getattr__(self, name):
+ if hasattr(self.__h5py_object, name):
+ attr = getattr(self.__h5py_object, name)
+ return attr
+ raise AttributeError("H5Node has no attribute %s" % name)
+
+ def __get_target(self, obj):
+ """
+ Return the actual physical target of the provided object.
+
+ Objects can contains links in the middle of the path, this function
+ check each groups and remove this prefix in case of the link by the
+ link of the path.
+
+ :param obj: A valid h5py object (File, group or dataset)
+ :type obj: h5py.Dataset or h5py.Group or h5py.File
+ :rtype: h5py.Dataset or h5py.Group or h5py.File
+ """
+ elements = obj.name.split("/")
+ if obj.name == "/":
+ return obj
+ elif obj.name.startswith("/"):
+ elements.pop(0)
+ path = ""
+ subpath = ""
+ while len(elements) > 0:
+ e = elements.pop(0)
+ subpath = path + "/" + e
+ link = obj.parent.get(subpath, getlink=True)
+ classlink = silx.io.utils.get_h5_class(link)
+
+ if classlink == silx.io.utils.H5Type.EXTERNAL_LINK:
+ subpath = "/".join(elements)
+ external_obj = obj.parent.get(self.basename + "/" + subpath)
+ return self.__get_target(external_obj)
+ elif classlink == silx.io.utils.H5Type.SOFT_LINK:
+ # Restart from this stat
+ root_elements = link.path.split("/")
+ if link.path == "/":
+ path = ""
+ root_elements = []
+ elif link.path.startswith("/"):
+ path = ""
+ root_elements.pop(0)
+
+ for name in reversed(root_elements):
+ elements.insert(0, name)
+ else:
+ path = subpath
+
+ return obj.file[path]
+
+ @property
+ def h5py_target(self):
+ if self.__h5py_target is not None:
+ return self.__h5py_target
+ self.__h5py_target = self.__get_target(self.__h5py_object)
+ return self.__h5py_target
+
+ @property
+ def h5py_object(self):
+ """Returns the internal h5py node.
+
+ :rtype: h5py.File or h5py.Group or h5py.Dataset
+ """
+ return self.__h5py_object
+
+ @property
+ def h5type(self):
+ """Returns the node type, as an H5Type.
+
+ :rtype: H5Node
+ """
+ return silx.io.utils.get_h5_class(self.__h5py_object)
+
+ @property
+ def ntype(self):
+ """Returns the node type, as an h5py class.
+
+ :rtype:
+ :class:`h5py.File`, :class:`h5py.Group` or :class:`h5py.Dataset`
+ """
+ type_ = self.h5type
+ return silx.io.utils.h5type_to_h5py_class(type_)
+
+ @property
+ def basename(self):
+ """Returns the basename of this h5py node. It is the last identifier of
+ the path.
+
+ :rtype: str
+ """
+ return self.__h5py_object.name.split("/")[-1]
+
+ @property
+ def is_broken(self):
+ """Returns true if the node is a broken link.
+
+ :rtype: bool
+ """
+ if self.__h5py_item is None:
+ raise RuntimeError("h5py_item is not defined")
+ return self.__h5py_item.isBrokenObj()
+
+ @property
+ def local_name(self):
+ """Returns the path from the master file root to this node.
+
+ For links, this path is not equal to the h5py one.
+
+ :rtype: str
+ """
+ if self.__h5py_item is None:
+ raise RuntimeError("h5py_item is not defined")
+
+ result = []
+ item = self.__h5py_item
+ while item is not None:
+ # stop before the root item (item without parent)
+ if item.parent.parent is None:
+ name = item.obj.name
+ if name != "/":
+ result.append(item.obj.name)
+ break
+ else:
+ result.append(item.basename)
+ item = item.parent
+ if item is None:
+ raise RuntimeError("The item does not have parent holding h5py.File")
+ if result == []:
+ return "/"
+ if not result[-1].startswith("/"):
+ result.append("")
+ result.reverse()
+ name = "/".join(result)
+ return name
+
+ def __get_local_file(self):
+ """Returns the file of the root of this tree
+
+ :rtype: h5py.File
+ """
+ item = self.__h5py_item
+ while item.parent.parent is not None:
+ class_ = silx.io.utils.get_h5_class(class_=item.h5pyClass)
+ if class_ == silx.io.utils.H5Type.FILE:
+ break
+ item = item.parent
+
+ class_ = silx.io.utils.get_h5_class(class_=item.h5pyClass)
+ if class_ == silx.io.utils.H5Type.FILE:
+ return item.obj
+ else:
+ return item.obj.file
+
+ @property
+ def local_file(self):
+ """Returns the master file in which is this node.
+
+ For path containing external links, this file is not equal to the h5py
+ one.
+
+ :rtype: h5py.File
+ :raises RuntimeException: If no file are found
+ """
+ return self.__get_local_file()
+
+ @property
+ def local_filename(self):
+ """Returns the filename from the master file of this node.
+
+ For path containing external links, this path is not equal to the
+ filename provided by h5py.
+
+ :rtype: str
+ :raises RuntimeException: If no file are found
+ """
+ return self.local_file.filename
+
+ @property
+ def local_basename(self):
+ """Returns the basename from the master file root to this node.
+
+ For path containing links, this basename can be different than the
+ basename provided by h5py.
+
+ :rtype: str
+ """
+ class_ = self.__h5py_item.h5Class
+ if class_ is not None and class_ == silx.io.utils.H5Type.FILE:
+ return ""
+ return self.__h5py_item.basename
+
+ @property
+ def physical_file(self):
+ """Returns the physical file in which is this node.
+
+ .. versionadded:: 0.6
+
+ :rtype: h5py.File
+ :raises RuntimeError: If no file are found
+ """
+ class_ = silx.io.utils.get_h5_class(self.__h5py_object)
+ if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
+ # It means the link is broken
+ raise RuntimeError("No file node found")
+ if class_ == silx.io.utils.H5Type.SOFT_LINK:
+ # It means the link is broken
+ return self.local_file
+
+ physical_obj = self.h5py_target
+ return physical_obj.file
+
+ @property
+ def physical_name(self):
+ """Returns the path from the location this h5py node is physically
+ stored.
+
+ For broken links, this filename can be different from the
+ filename provided by h5py.
+
+ :rtype: str
+ """
+ class_ = silx.io.utils.get_h5_class(self.__h5py_object)
+ if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
+ # It means the link is broken
+ return self.__h5py_object.path
+ if class_ == silx.io.utils.H5Type.SOFT_LINK:
+ # It means the link is broken
+ return self.__h5py_object.path
+
+ physical_obj = self.h5py_target
+ return physical_obj.name
+
+ @property
+ def physical_filename(self):
+ """Returns the filename from the location this h5py node is physically
+ stored.
+
+ For broken links, this filename can be different from the
+ filename provided by h5py.
+
+ :rtype: str
+ """
+ class_ = silx.io.utils.get_h5_class(self.__h5py_object)
+ if class_ == silx.io.utils.H5Type.EXTERNAL_LINK:
+ # It means the link is broken
+ return self.__h5py_object.filename
+ if class_ == silx.io.utils.H5Type.SOFT_LINK:
+ # It means the link is broken
+ return self.local_file.filename
+
+ return self.physical_file.filename
+
+ @property
+ def physical_basename(self):
+ """Returns the basename from the location this h5py node is physically
+ stored.
+
+ For broken links, this basename can be different from the
+ basename provided by h5py.
+
+ :rtype: str
+ """
+ return self.physical_name.split("/")[-1]
+
+ @property
+ def data_url(self):
+ """Returns a :class:`silx.io.url.DataUrl` object identify this node in the file
+ system.
+
+ :rtype: ~silx.io.url.DataUrl
+ """
+ absolute_filename = os.path.abspath(self.local_filename)
+ return silx.io.url.DataUrl(scheme="silx",
+ file_path=absolute_filename,
+ data_path=self.local_name)
+
+ @property
+ def url(self):
+ """Returns an URL object identifying this node in the file
+ system.
+
+ This URL can be used in different ways.
+
+ .. code-block:: python
+
+ # Parsing the URL
+ import silx.io.url
+ dataurl = silx.io.url.DataUrl(item.url)
+ # dataurl provides access to URL fields
+
+ # Open a numpy array
+ import silx.io
+ dataset = silx.io.get_data(item.url)
+
+ # Open an hdf5 object (URL targetting a file or a group)
+ import silx.io
+ with silx.io.open(item.url) as h5:
+ ...your stuff...
+
+ :rtype: str
+ """
+ data_url = self.data_url
+ return data_url.path()
diff --git a/silx/gui/hdf5/setup.py b/src/silx/gui/hdf5/setup.py
index 786a851..786a851 100644
--- a/silx/gui/hdf5/setup.py
+++ b/src/silx/gui/hdf5/setup.py
diff --git a/src/silx/gui/hdf5/test/__init__.py b/src/silx/gui/hdf5/test/__init__.py
new file mode 100644
index 0000000..71128fb
--- /dev/null
+++ b/src/silx/gui/hdf5/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/hdf5/test/test_hdf5.py b/src/silx/gui/hdf5/test/test_hdf5.py
new file mode 100755
index 0000000..9b1b88a
--- /dev/null
+++ b/src/silx/gui/hdf5/test/test_hdf5.py
@@ -0,0 +1,1092 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "12/03/2019"
+
+
+import time
+import os
+import unittest
+import tempfile
+import numpy
+from pkg_resources import parse_version
+from contextlib import contextmanager
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import hdf5
+from silx.gui.utils.testutils import SignalListener
+from silx.io import commonh5
+import weakref
+
+import h5py
+import pytest
+
+
+h5py2_9 = parse_version(h5py.version.version) >= parse_version('2.9.0')
+
+
+@pytest.fixture(scope="class")
+def useH5File(request, tmpdir_factory):
+ tmp = tmpdir_factory.mktemp("test_hdf5")
+ request.cls.filename = os.path.join(tmp, "data.h5")
+ # create h5 data
+ with h5py.File(request.cls.filename, "w") as f:
+ g = f.create_group("arrays")
+ g.create_dataset("scalar", data=10)
+ yield
+
+
+def create_NXentry(group, name):
+ attrs = {"NX_class": "NXentry"}
+ node = commonh5.Group(name, parent=group, attrs=attrs)
+ group.add_node(node)
+ return node
+
+
+@pytest.mark.usefixtures("useH5File")
+class TestHdf5TreeModel(TestCaseQt):
+
+ def setUp(self):
+ super(TestHdf5TreeModel, self).setUp()
+
+ def waitForPendingOperations(self, model):
+ for _ in range(10):
+ if not model.hasPendingOperations():
+ break
+ self.qWait(10)
+ else:
+ raise RuntimeError("Still waiting for a pending operation")
+
+ @contextmanager
+ def h5TempFile(self):
+ # create tmp file
+ fd, tmp_name = tempfile.mkstemp(suffix=".h5")
+ os.close(fd)
+ # create h5 data
+ h5file = h5py.File(tmp_name, "w")
+ g = h5file.create_group("arrays")
+ g.create_dataset("scalar", data=10)
+ h5file.close()
+ yield tmp_name
+ # clean up
+ os.unlink(tmp_name)
+
+ def testCreate(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertIsNotNone(model)
+
+ def testAppendFilename(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.appendFile(self.filename)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ # clean up
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def testAppendBadFilename(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertRaises(IOError, model.appendFile, "#%$")
+
+ def testInsertFilename(self):
+ try:
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertFile(self.filename)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertIsNotNone(h5File)
+ finally:
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def testInsertFilenameAsync(self):
+ try:
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertFileAsync(self.filename)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5LoadingItem.Hdf5LoadingItem)
+ self.waitForPendingOperations(model)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
+ finally:
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def testInsertObject(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertH5pyObject(h5)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+
+ def testRemoveObject(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertH5pyObject(h5)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ model.removeH5pyObject(h5)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+
+ def testSynchronizeObject(self):
+ h5 = h5py.File(self.filename, mode="r")
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(h5)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ index = model.index(0, 0, qt.QModelIndex())
+ node1 = model.nodeFromIndex(index)
+ model.synchronizeH5pyObject(h5)
+ self.waitForPendingOperations(model)
+ # Now h5 was loaded from it's filename
+ # Another ref is owned by the model
+ h5.close()
+
+ index = model.index(0, 0, qt.QModelIndex())
+ node2 = model.nodeFromIndex(index)
+ self.assertIsNot(node1, node2)
+ # after sync
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertIsNotNone(h5File)
+ h5File = None
+ # delete the model
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def testFileMoveState(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.isFileMoveEnabled(), True)
+ model.setFileMoveEnabled(False)
+ self.assertEqual(model.isFileMoveEnabled(), False)
+
+ def testFileDropState(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.isFileDropEnabled(), True)
+ model.setFileDropEnabled(False)
+ self.assertEqual(model.isFileDropEnabled(), False)
+
+ def testSupportedDrop(self):
+ model = hdf5.Hdf5TreeModel()
+ self.assertNotEqual(model.supportedDropActions(), 0)
+
+ model.setFileMoveEnabled(False)
+ model.setFileDropEnabled(False)
+ self.assertEqual(model.supportedDropActions(), 0)
+
+ model.setFileMoveEnabled(False)
+ model.setFileDropEnabled(True)
+ self.assertNotEqual(model.supportedDropActions(), 0)
+
+ model.setFileMoveEnabled(True)
+ model.setFileDropEnabled(False)
+ self.assertNotEqual(model.supportedDropActions(), 0)
+
+ def testCloseFile(self):
+ """A file inserted as a filename is open and closed internally."""
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertFile(self.filename)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ index = model.index(0, 0)
+ h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ model.removeIndex(index)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ self.assertFalse(bool(h5File.id.valid), "The HDF5 file was not closed")
+
+ def testNotCloseFile(self):
+ """A file inserted as an h5py object is not open (then not closed)
+ internally."""
+ try:
+ h5File = h5py.File(self.filename, mode="r")
+ model = hdf5.Hdf5TreeModel()
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ model.insertH5pyObject(h5File)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ index = model.index(0, 0)
+ h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ model.removeIndex(index)
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 0)
+ self.assertTrue(bool(h5File.id.valid), "The HDF5 file was unexpetedly closed")
+ finally:
+ h5File.close()
+
+ def testDropExternalFile(self):
+ model = hdf5.Hdf5TreeModel()
+ mimeData = qt.QMimeData()
+ mimeData.setUrls([qt.QUrl.fromLocalFile(self.filename)])
+ model.dropMimeData(mimeData, qt.Qt.CopyAction, 0, 0, qt.QModelIndex())
+ self.assertEqual(model.rowCount(qt.QModelIndex()), 1)
+ # after sync
+ self.waitForPendingOperations(model)
+ index = model.index(0, 0, qt.QModelIndex())
+ self.assertIsInstance(model.nodeFromIndex(index), hdf5.Hdf5Item.Hdf5Item)
+ # clean up
+ index = model.index(0, 0, qt.QModelIndex())
+ h5File = model.data(index, role=hdf5.Hdf5TreeModel.H5PY_OBJECT_ROLE)
+ self.assertIsNotNone(h5File)
+ h5File = None
+ ref = weakref.ref(model)
+ model = None
+ self.qWaitForDestroy(ref)
+
+ def getRowDataAsDict(self, model, row):
+ displayed = {}
+ roles = [qt.Qt.DisplayRole, qt.Qt.DecorationRole, qt.Qt.ToolTipRole, qt.Qt.TextAlignmentRole]
+ for column in range(0, model.columnCount(qt.QModelIndex())):
+ index = model.index(0, column, qt.QModelIndex())
+ for role in roles:
+ datum = model.data(index, role)
+ displayed[column, role] = datum
+ return displayed
+
+ def getItemName(self, model, row):
+ index = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, qt.QModelIndex())
+ return model.data(index, qt.Qt.DisplayRole)
+
+ def testFileData(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(h5)
+ displayed = self.getRowDataAsDict(model, row=0)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock")
+ self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], None)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "File")
+
+ def testGroupData(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ d = h5.create_group("foo")
+ d.attrs["desc"] = "fooo"
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(d)
+ displayed = self.getRowDataAsDict(model, row=0)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
+ self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "fooo")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Group")
+
+ def testDatasetData(self):
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ value = numpy.array([1, 2, 3])
+ d = h5.create_dataset("foo", data=value)
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(d)
+ displayed = self.getRowDataAsDict(model, row=0)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DisplayRole], "1.mock::foo")
+ self.assertIsInstance(displayed[hdf5.Hdf5TreeModel.NAME_COLUMN, qt.Qt.DecorationRole], qt.QIcon)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.TYPE_COLUMN, qt.Qt.DisplayRole], value.dtype.name)
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.SHAPE_COLUMN, qt.Qt.DisplayRole], "3")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.VALUE_COLUMN, qt.Qt.DisplayRole], "[1 2 3]")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.DESCRIPTION_COLUMN, qt.Qt.DisplayRole], "[1 2 3]")
+ self.assertEqual(displayed[hdf5.Hdf5TreeModel.NODE_COLUMN, qt.Qt.DisplayRole], "Dataset")
+
+ def testDropLastAsFirst(self):
+ model = hdf5.Hdf5TreeModel()
+ h5_1 = commonh5.File("/foo/bar/1.mock", "w")
+ h5_2 = commonh5.File("/foo/bar/2.mock", "w")
+ model.insertH5pyObject(h5_1)
+ model.insertH5pyObject(h5_2)
+ self.assertEqual(self.getItemName(model, 0), "1.mock")
+ self.assertEqual(self.getItemName(model, 1), "2.mock")
+ index = model.index(1, 0, qt.QModelIndex())
+ mimeData = model.mimeData([index])
+ model.dropMimeData(mimeData, qt.Qt.MoveAction, 0, 0, qt.QModelIndex())
+ self.assertEqual(self.getItemName(model, 0), "2.mock")
+ self.assertEqual(self.getItemName(model, 1), "1.mock")
+
+ def testDropFirstAsLast(self):
+ model = hdf5.Hdf5TreeModel()
+ h5_1 = commonh5.File("/foo/bar/1.mock", "w")
+ h5_2 = commonh5.File("/foo/bar/2.mock", "w")
+ model.insertH5pyObject(h5_1)
+ model.insertH5pyObject(h5_2)
+ self.assertEqual(self.getItemName(model, 0), "1.mock")
+ self.assertEqual(self.getItemName(model, 1), "2.mock")
+ index = model.index(0, 0, qt.QModelIndex())
+ mimeData = model.mimeData([index])
+ model.dropMimeData(mimeData, qt.Qt.MoveAction, 2, 0, qt.QModelIndex())
+ self.assertEqual(self.getItemName(model, 0), "2.mock")
+ self.assertEqual(self.getItemName(model, 1), "1.mock")
+
+ def testRootParent(self):
+ model = hdf5.Hdf5TreeModel()
+ h5_1 = commonh5.File("/foo/bar/1.mock", "w")
+ model.insertH5pyObject(h5_1)
+ index = model.index(0, 0, qt.QModelIndex())
+ index = model.parent(index)
+ self.assertEqual(index, qt.QModelIndex())
+
+
+@pytest.mark.usefixtures("useH5File")
+class TestHdf5TreeModelSignals(TestCaseQt):
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.model = hdf5.Hdf5TreeModel()
+ self.h5 = h5py.File(self.filename, mode='r')
+ self.model.insertH5pyObject(self.h5)
+
+ self.listener = SignalListener()
+ self.model.sigH5pyObjectLoaded.connect(self.listener.partial(signal="loaded"))
+ self.model.sigH5pyObjectRemoved.connect(self.listener.partial(signal="removed"))
+ self.model.sigH5pyObjectSynchronized.connect(self.listener.partial(signal="synchronized"))
+
+ def tearDown(self):
+ self.signals = None
+ ref = weakref.ref(self.model)
+ self.model = None
+ self.qWaitForDestroy(ref)
+ self.h5.close()
+ self.h5 = None
+ TestCaseQt.tearDown(self)
+
+ def waitForPendingOperations(self, model):
+ for _ in range(10):
+ if not model.hasPendingOperations():
+ break
+ self.qWait(10)
+ else:
+ raise RuntimeError("Still waiting for a pending operation")
+
+ def testInsert(self):
+ h5 = h5py.File(self.filename, mode='r')
+ self.model.insertH5pyObject(h5)
+ self.assertEqual(self.listener.callCount(), 0)
+
+ def testLoaded(self):
+ self.model.insertFile(self.filename)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertEqual(self.listener.karguments(argumentName="signal")[0], "loaded")
+ self.assertIsNot(self.listener.arguments(callIndex=0)[0], self.h5)
+ self.assertEqual(self.listener.arguments(callIndex=0)[0].filename, self.filename)
+
+ def testRemoved(self):
+ self.model.removeH5pyObject(self.h5)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertEqual(self.listener.karguments(argumentName="signal")[0], "removed")
+ self.assertIs(self.listener.arguments(callIndex=0)[0], self.h5)
+
+ def testSynchonized(self):
+ self.model.synchronizeH5pyObject(self.h5)
+ self.waitForPendingOperations(self.model)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertEqual(self.listener.karguments(argumentName="signal")[0], "synchronized")
+ self.assertIs(self.listener.arguments(callIndex=0)[0], self.h5)
+ self.assertIsNot(self.listener.arguments(callIndex=0)[1], self.h5)
+
+
+class TestNexusSortFilterProxyModel(TestCaseQt):
+
+ def getChildNames(self, model, index):
+ count = model.rowCount(index)
+ result = []
+ for row in range(0, count):
+ itemIndex = model.index(row, hdf5.Hdf5TreeModel.NAME_COLUMN, index)
+ name = model.data(itemIndex, qt.Qt.DisplayRole)
+ result.append(name)
+ return result
+
+ def testNXentryStartTime(self):
+ """Test NXentry with start_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ create_NXentry(h5, "a").create_dataset("start_time", data=numpy.string_("2015"))
+ create_NXentry(h5, "b").create_dataset("start_time", data=numpy.string_("2013"))
+ create_NXentry(h5, "c").create_dataset("start_time", data=numpy.string_("2014"))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.DescendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "c", "b"])
+
+ def testNXentryStartTimeInArray(self):
+ """Test NXentry with start_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ create_NXentry(h5, "a").create_dataset("start_time", data=numpy.array([numpy.string_("2015")]))
+ create_NXentry(h5, "b").create_dataset("start_time", data=numpy.array([numpy.string_("2013")]))
+ create_NXentry(h5, "c").create_dataset("start_time", data=numpy.array([numpy.string_("2014")]))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.DescendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "c", "b"])
+
+ def testNXentryEndTimeInArray(self):
+ """Test NXentry with end_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ create_NXentry(h5, "a").create_dataset("end_time", data=numpy.array([numpy.string_("2015")]))
+ create_NXentry(h5, "b").create_dataset("end_time", data=numpy.array([numpy.string_("2013")]))
+ create_NXentry(h5, "c").create_dataset("end_time", data=numpy.array([numpy.string_("2014")]))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.DescendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "c", "b"])
+
+ def testNXentryName(self):
+ """Test NXentry without start_time or end_time"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ create_NXentry(h5, "a")
+ create_NXentry(h5, "c")
+ create_NXentry(h5, "b")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "b", "c"])
+
+ def testStartTime(self):
+ """If it is not NXentry, start_time is not used"""
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("a").create_dataset("start_time", data=numpy.string_("2015"))
+ h5.create_group("b").create_dataset("start_time", data=numpy.string_("2013"))
+ h5.create_group("c").create_dataset("start_time", data=numpy.string_("2014"))
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "b", "c"])
+
+ def testName(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("a")
+ h5.create_group("c")
+ h5.create_group("b")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a", "b", "c"])
+
+ def testNumber(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("a1")
+ h5.create_group("a20")
+ h5.create_group("a3")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a1", "a3", "a20"])
+
+ def testMultiNumber(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("a1-1")
+ h5.create_group("a20-1")
+ h5.create_group("a3-1")
+ h5.create_group("a3-20")
+ h5.create_group("a3-3")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["a1-1", "a3-1", "a3-3", "a3-20", "a20-1"])
+
+ def testUnconsistantTypes(self):
+ model = hdf5.Hdf5TreeModel()
+ h5 = commonh5.File("/foo/bar/1.mock", "w")
+ h5.create_group("aaa100")
+ h5.create_group("100aaa")
+ model.insertH5pyObject(h5)
+
+ proxy = hdf5.NexusSortFilterProxyModel()
+ proxy.setSourceModel(model)
+ proxy.sort(0, qt.Qt.AscendingOrder)
+ names = self.getChildNames(proxy, proxy.index(0, 0, qt.QModelIndex()))
+ self.assertListEqual(names, ["100aaa", "aaa100"])
+
+
+@pytest.fixture(scope='class')
+def useH5Model(request, tmpdir_factory):
+ # Create HDF5 files
+ tmp = tmpdir_factory.mktemp("test_hdf5")
+ filename = os.path.join(tmp, "base.h5")
+ extH5FileName = os.path.join(tmp, "base__external.h5")
+ extDatFileName = os.path.join(tmp, "base__external.dat")
+
+ externalh5 = h5py.File(extH5FileName, mode="w")
+ externalh5["target/dataset"] = 50
+ externalh5["target/link"] = h5py.SoftLink("/target/dataset")
+ externalh5["/ext/vds0"] = [0, 1]
+ externalh5["/ext/vds1"] = [2, 3]
+ externalh5.close()
+
+ numpy.array([0,1,10,10,2,3]).tofile(extDatFileName)
+
+ h5 = h5py.File(filename, mode="w")
+ h5["group/dataset"] = 50
+ h5["link/soft_link"] = h5py.SoftLink("/group/dataset")
+ h5["link/soft_link_to_group"] = h5py.SoftLink("/group")
+ h5["link/soft_link_to_link"] = h5py.SoftLink("/link/soft_link")
+ h5["link/soft_link_to_file"] = h5py.SoftLink("/")
+ h5["group/soft_link_relative"] = h5py.SoftLink("dataset")
+ h5["link/external_link"] = h5py.ExternalLink(extH5FileName, "/target/dataset")
+ h5["link/external_link_to_link"] = h5py.ExternalLink(extH5FileName, "/target/link")
+ h5["broken_link/external_broken_file"] = h5py.ExternalLink(extH5FileName + "_not_exists", "/target/link")
+ h5["broken_link/external_broken_link"] = h5py.ExternalLink(extH5FileName, "/target/not_exists")
+ h5["broken_link/soft_broken_link"] = h5py.SoftLink("/group/not_exists")
+ h5["broken_link/soft_link_to_broken_link"] = h5py.SoftLink("/group/not_exists")
+ if h5py2_9:
+ layout = h5py.VirtualLayout((2,2), dtype=int)
+ layout[0] = h5py.VirtualSource("base__external.h5", name="/ext/vds0", shape=(2,), dtype=int)
+ layout[1] = h5py.VirtualSource("base__external.h5", name="/ext/vds1", shape=(2,), dtype=int)
+ h5.create_group("/ext")
+ h5["/ext"].create_virtual_dataset("virtual", layout)
+ external = [("base__external.dat", 0, 2*8), ("base__external.dat", 4*8, 2*8)]
+ h5["/ext"].create_dataset("raw", shape=(2,2), dtype=int, external=external)
+ h5.close()
+
+ with h5py.File(filename, mode="r") as h5File:
+ # Create model
+ request.cls.model = hdf5.Hdf5TreeModel()
+ request.cls.model.insertH5pyObject(h5File)
+ yield
+ ref = weakref.ref(request.cls.model)
+ request.cls.model = None
+ TestCaseQt.qWaitForDestroy(ref)
+
+
+@pytest.mark.usefixtures('useH5Model')
+class _TestModelBase(TestCaseQt):
+ def getIndexFromPath(self, model, path):
+ """
+ :param qt.QAbstractItemModel: model
+ """
+ index = qt.QModelIndex()
+ for name in path:
+ for row in range(model.rowCount(index)):
+ i = model.index(row, 0, index)
+ label = model.data(i)
+ if label == name:
+ index = i
+ break
+ else:
+ raise RuntimeError("Path not found")
+ return index
+
+ def getH5ItemFromPath(self, model, path):
+ index = self.getIndexFromPath(model, path)
+ return model.data(index, hdf5.Hdf5TreeModel.H5PY_ITEM_ROLE)
+
+
+class TestH5Item(_TestModelBase):
+
+ def testFile(self):
+ path = ["base.h5"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testGroup(self):
+ path = ["base.h5", "group"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testDataset(self):
+ path = ["base.h5", "group", "dataset"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testSoftLink(self):
+ path = ["base.h5", "link", "soft_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testSoftLinkToLink(self):
+ path = ["base.h5", "link", "soft_link_to_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testSoftLinkRelative(self):
+ path = ["base.h5", "group", "soft_link_relative"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testExternalLink(self):
+ path = ["base.h5", "link", "external_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testExternalLinkToLink(self):
+ path = ["base.h5", "link", "external_link_to_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testExternalBrokenFile(self):
+ path = ["base.h5", "broken_link", "external_broken_file"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testExternalBrokenLink(self):
+ path = ["base.h5", "broken_link", "external_broken_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "External")
+
+ def testSoftBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_broken_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testSoftLinkToBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_link_to_broken_link"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Soft")
+
+ def testDatasetFromSoftLinkToGroup(self):
+ path = ["base.h5", "link", "soft_link_to_group", "dataset"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ def testDatasetFromSoftLinkToFile(self):
+ path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "")
+
+ @pytest.mark.skipif(not h5py2_9, reason="requires h5py>=2.9")
+ def testExternalVirtual(self):
+ path = ["base.h5", "ext", "virtual"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "Virtual")
+
+ @pytest.mark.skipif(not h5py2_9, reason="requires h5py>=2.9")
+ def testExternalRaw(self):
+ path = ["base.h5", "ext", "raw"]
+ h5item = self.getH5ItemFromPath(self.model, path)
+
+ self.assertEqual(h5item.dataLink(qt.Qt.DisplayRole), "ExtRaw")
+
+
+class TestH5Node(_TestModelBase):
+
+ def getH5NodeFromPath(self, model, path):
+ item = self.getH5ItemFromPath(model, path)
+ h5node = hdf5.H5Node(item)
+ return h5node
+
+ def testFile(self):
+ path = ["base.h5"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "")
+ self.assertEqual(h5node.physical_name, "/")
+ self.assertEqual(h5node.local_basename, "")
+ self.assertEqual(h5node.local_name, "/")
+
+ def testGroup(self):
+ path = ["base.h5", "group"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "group")
+ self.assertEqual(h5node.physical_name, "/group")
+ self.assertEqual(h5node.local_basename, "group")
+ self.assertEqual(h5node.local_name, "/group")
+
+ def testDataset(self):
+ path = ["base.h5", "group", "dataset"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "dataset")
+ self.assertEqual(h5node.local_name, "/group/dataset")
+
+ def testSoftLink(self):
+ path = ["base.h5", "link", "soft_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "soft_link")
+ self.assertEqual(h5node.local_name, "/link/soft_link")
+
+ def testSoftLinkToLink(self):
+ path = ["base.h5", "link", "soft_link_to_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "soft_link_to_link")
+ self.assertEqual(h5node.local_name, "/link/soft_link_to_link")
+
+ def testSoftLinkRelative(self):
+ path = ["base.h5", "group", "soft_link_relative"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "soft_link_relative")
+ self.assertEqual(h5node.local_name, "/group/soft_link_relative")
+
+ def testExternalLink(self):
+ path = ["base.h5", "link", "external_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.local_filename)
+ self.assertIn("base__external.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/target/dataset")
+ self.assertEqual(h5node.local_basename, "external_link")
+ self.assertEqual(h5node.local_name, "/link/external_link")
+
+ def testExternalLinkToLink(self):
+ path = ["base.h5", "link", "external_link_to_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.local_filename)
+ self.assertIn("base__external.h5", h5node.physical_filename)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/target/dataset")
+ self.assertEqual(h5node.local_basename, "external_link_to_link")
+ self.assertEqual(h5node.local_name, "/link/external_link_to_link")
+
+ def testExternalBrokenFile(self):
+ path = ["base.h5", "broken_link", "external_broken_file"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.local_filename)
+ self.assertIn("not_exists", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "link")
+ self.assertEqual(h5node.physical_name, "/target/link")
+ self.assertEqual(h5node.local_basename, "external_broken_file")
+ self.assertEqual(h5node.local_name, "/broken_link/external_broken_file")
+
+ def testExternalBrokenLink(self):
+ path = ["base.h5", "broken_link", "external_broken_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertNotEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.local_filename)
+ self.assertIn("__external", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "not_exists")
+ self.assertEqual(h5node.physical_name, "/target/not_exists")
+ self.assertEqual(h5node.local_basename, "external_broken_link")
+ self.assertEqual(h5node.local_name, "/broken_link/external_broken_link")
+
+ def testSoftBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_broken_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "not_exists")
+ self.assertEqual(h5node.physical_name, "/group/not_exists")
+ self.assertEqual(h5node.local_basename, "soft_broken_link")
+ self.assertEqual(h5node.local_name, "/broken_link/soft_broken_link")
+
+ def testSoftLinkToBrokenLink(self):
+ path = ["base.h5", "broken_link", "soft_link_to_broken_link"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "not_exists")
+ self.assertEqual(h5node.physical_name, "/group/not_exists")
+ self.assertEqual(h5node.local_basename, "soft_link_to_broken_link")
+ self.assertEqual(h5node.local_name, "/broken_link/soft_link_to_broken_link")
+
+ def testDatasetFromSoftLinkToGroup(self):
+ path = ["base.h5", "link", "soft_link_to_group", "dataset"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "dataset")
+ self.assertEqual(h5node.local_name, "/link/soft_link_to_group/dataset")
+
+ def testDatasetFromSoftLinkToFile(self):
+ path = ["base.h5", "link", "soft_link_to_file", "link", "soft_link_to_group", "dataset"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "dataset")
+ self.assertEqual(h5node.physical_name, "/group/dataset")
+ self.assertEqual(h5node.local_basename, "dataset")
+ self.assertEqual(h5node.local_name, "/link/soft_link_to_file/link/soft_link_to_group/dataset")
+
+ @pytest.mark.skipif(not h5py2_9, reason="requires h5py>=2.9")
+ def testExternalVirtual(self):
+ path = ["base.h5", "ext", "virtual"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "virtual")
+ self.assertEqual(h5node.physical_name, "/ext/virtual")
+ self.assertEqual(h5node.local_basename, "virtual")
+ self.assertEqual(h5node.local_name, "/ext/virtual")
+
+ @pytest.mark.skipif(not h5py2_9, reason="requires h5py>=2.9")
+ def testExternalRaw(self):
+ path = ["base.h5", "ext", "raw"]
+ h5node = self.getH5NodeFromPath(self.model, path)
+
+ self.assertEqual(h5node.physical_filename, h5node.local_filename)
+ self.assertIn("base.h5", h5node.physical_filename)
+ self.assertEqual(h5node.physical_basename, "raw")
+ self.assertEqual(h5node.physical_name, "/ext/raw")
+ self.assertEqual(h5node.local_basename, "raw")
+ self.assertEqual(h5node.local_name, "/ext/raw")
+
+
+class TestHdf5TreeView(TestCaseQt):
+ """Test to check that icons module."""
+
+ def setUp(self):
+ super(TestHdf5TreeView, self).setUp()
+
+ def testCreate(self):
+ view = hdf5.Hdf5TreeView()
+ self.assertIsNotNone(view)
+
+ def testContextMenu(self):
+ view = hdf5.Hdf5TreeView()
+ view._createContextMenu(qt.QPoint(0, 0))
+
+ def testSelection_OriginalModel(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ item = tree.create_group("a/b/c/d")
+ item.create_group("e").create_group("f")
+
+ view = hdf5.Hdf5TreeView()
+ view.findHdf5TreeModel().insertH5pyObject(tree)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(item, selected.h5py_object)
+
+ def testSelection_Simple(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ item = tree.create_group("a/b/c/d")
+ item.create_group("e").create_group("f")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(item, selected.h5py_object)
+
+ def testSelection_NotFound(self):
+ tree2 = commonh5.File("/foo/bar/2.mock", "w")
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ item = tree.create_group("a/b/c/d")
+ item.create_group("e").create_group("f")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(tree2)
+
+ selection = list(view.selectedH5Nodes())
+ self.assertEqual(len(selection), 0)
+
+ def testSelection_ManyGroupFromSameFile(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ group1 = tree.create_group("a1")
+ group2 = tree.create_group("a2")
+ group3 = tree.create_group("a3")
+ group1.create_group("b/c/d")
+ item = group2.create_group("b/c/d")
+ group3.create_group("b/c/d")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(group1)
+ model.insertH5pyObject(group2)
+ model.insertH5pyObject(group3)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(item, selected.h5py_object)
+
+ def testSelection_RootFromSubTree(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ group = tree.create_group("a1")
+ group.create_group("b/c/d")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(group)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(group)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(group, selected.h5py_object)
+
+ def testSelection_FileFromSubTree(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ group = tree.create_group("a1")
+ group.create_group("b").create_group("b").create_group("d")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(group)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(tree)
+
+ selection = list(view.selectedH5Nodes())
+ self.assertEqual(len(selection), 0)
+
+ def testSelection_Tree(self):
+ tree1 = commonh5.File("/foo/bar/1.mock", "w")
+ tree2 = commonh5.File("/foo/bar/2.mock", "w")
+ tree3 = commonh5.File("/foo/bar/3.mock", "w")
+ tree1.create_group("a/b/c")
+ tree2.create_group("a/b/c")
+ tree3.create_group("a/b/c")
+ item = tree2
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree1)
+ model.insertH5pyObject(tree2)
+ model.insertH5pyObject(tree3)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertIs(item, selected.h5py_object)
+
+ def testSelection_RecurssiveLink(self):
+ """
+ Recurssive link selection
+
+ This example is not really working as expected cause commonh5 do not
+ support recurssive links.
+ But item.name == "/a/b" and the result is found.
+ """
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+ group = tree.create_group("a")
+ group.add_node(commonh5.SoftLink("b", "/"))
+
+ item = tree["/a/b/a/b/a/b/a/b/a/b/a/b/a/b/a/b"]
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(item)
+
+ selected = list(view.selectedH5Nodes())[0]
+ self.assertEqual(item.name, selected.h5py_object.name)
+
+ def testSelection_SelectNone(self):
+ tree = commonh5.File("/foo/bar/1.mock", "w")
+
+ model = hdf5.Hdf5TreeModel()
+ model.insertH5pyObject(tree)
+ view = hdf5.Hdf5TreeView()
+ view.setModel(model)
+ view.setSelectedH5Node(tree)
+ view.setSelectedH5Node(None)
+
+ selection = list(view.selectedH5Nodes())
+ self.assertEqual(len(selection), 0)
diff --git a/silx/gui/icons.py b/src/silx/gui/icons.py
index 1493b92..1493b92 100644
--- a/silx/gui/icons.py
+++ b/src/silx/gui/icons.py
diff --git a/src/silx/gui/plot/AlphaSlider.py b/src/silx/gui/plot/AlphaSlider.py
new file mode 100644
index 0000000..da55b1e
--- /dev/null
+++ b/src/silx/gui/plot/AlphaSlider.py
@@ -0,0 +1,300 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module defines slider widgets interacting with the transparency
+of an image on a :class:`PlotWidget`
+
+Classes:
+--------
+
+- :class:`BaseAlphaSlider` (abstract class)
+- :class:`NamedImageAlphaSlider`
+- :class:`ActiveImageAlphaSlider`
+
+Example:
+--------
+
+This widget can, for instance, be added to a plot toolbar.
+
+.. code-block:: python
+
+ import numpy
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.ImageAlphaSlider import NamedImageAlphaSlider
+
+ app = qt.QApplication([])
+ pw = PlotWidget()
+
+ img0 = numpy.arange(200*150).reshape((200, 150))
+ pw.addImage(img0, legend="my background", z=0, origin=(50, 50))
+
+ x, y = numpy.meshgrid(numpy.linspace(-10, 10, 200),
+ numpy.linspace(-10, 5, 150),
+ indexing="ij")
+ img1 = numpy.asarray(numpy.sin(x * y) / (x * y),
+ dtype='float32')
+
+ pw.addImage(img1, legend="my data", z=1,
+ replace=False)
+
+ alpha_slider = NamedImageAlphaSlider(parent=pw,
+ plot=pw,
+ legend="my data")
+ alpha_slider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", pw)
+ toolbar.addWidget(alpha_slider)
+ pw.addToolBar(toolbar)
+
+ pw.show()
+ app.exec()
+
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/03/2017"
+
+import logging
+
+from silx.gui import qt
+
+_logger = logging.getLogger(__name__)
+
+
+class BaseAlphaSlider(qt.QSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of a plot primitive (image, scatter or curve).
+
+ Internally, the slider stores its state as an integer between
+ 0 and 255. This is the value emitted by the :attr:`valueChanged`
+ signal.
+
+ The method :meth:`getAlpha` returns the corresponding opacity/alpha
+ as a float between 0. and 1. (with a step of :math:`\frac{1}{255}`).
+
+ You must subclass this class and implement :meth:`getItem`.
+ """
+ sigAlphaChanged = qt.Signal(float)
+ """Emits the alpha value when the slider's value changes,
+ as a float between 0. and 1."""
+
+ def __init__(self, parent=None, plot=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Parent plot widget
+ """
+ assert plot is not None
+ super(BaseAlphaSlider, self).__init__(parent)
+
+ self.plot = plot
+
+ self.setRange(0, 255)
+
+ # if already connected to an item, use its alpha as initial value
+ if self.getItem() is None:
+ self.setValue(255)
+ self.setEnabled(False)
+ else:
+ alpha = self.getItem().getAlpha()
+ self.setValue(round(255*alpha))
+
+ self.valueChanged.connect(self._valueChanged)
+
+ def getItem(self):
+ """You must implement this class to define which item
+ to work on. It must return an item that inherits
+ :class:`silx.gui.plot.items.core.AlphaMixIn`.
+
+ :return: Item on which to operate, or None
+ :rtype: :class:`silx.plot.items.Item`
+ """
+ raise NotImplementedError(
+ "BaseAlphaSlider must be subclassed to " +
+ "implement getItem()")
+
+ def getAlpha(self):
+ """Get the opacity, as a float between 0. and 1.
+
+ :return: Alpha value in [0., 1.]
+ :rtype: float
+ """
+ return self.value() / 255.
+
+ def _valueChanged(self, value):
+ self._updateItem()
+ self.sigAlphaChanged.emit(value / 255.)
+
+ def _updateItem(self):
+ """Update the item's alpha channel.
+ """
+ item = self.getItem()
+ if item is not None:
+ item.setAlpha(self.getAlpha())
+
+
+class ActiveImageAlphaSlider(BaseAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of the **active image**.
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+
+ See documentation of :class:`BaseAlphaSlider`
+ """
+ def __init__(self, parent=None, plot=None):
+ """
+
+ :param parent: Parent QWidget
+ :param plot: Plot widget on which to operate
+ """
+ super(ActiveImageAlphaSlider, self).__init__(parent, plot)
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+
+ def getItem(self):
+ return self.plot.getActiveImage()
+
+ def _activeImageChanged(self, previous, new):
+ """Activate or deactivate slider depending on presence of a new
+ active image.
+ Apply transparency value to new active image.
+
+ :param previous: Legend of previous active image, or None
+ :param new: Legend of new active image, or None
+ """
+ if new is not None and not self.isEnabled():
+ self.setEnabled(True)
+ elif new is None and self.isEnabled():
+ self.setEnabled(False)
+
+ self._updateItem()
+
+
+class NamedItemAlphaSlider(BaseAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of an item (defined by its kind and legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str kind: Kind of item whose transparency is to be
+ controlled: "scatter", "image" or "curve".
+ :param str legend: Legend of item whose transparency is to be
+ controlled.
+ """
+ def __init__(self, parent=None, plot=None,
+ kind=None, legend=None):
+ self._item_legend = legend
+ self._item_kind = kind
+
+ super(NamedItemAlphaSlider, self).__init__(parent, plot)
+
+ self._updateState()
+ plot.sigContentChanged.connect(self._onContentChanged)
+
+ def _onContentChanged(self, action, kind, legend):
+ if legend == self._item_legend and kind == self._item_kind:
+ if action == "add":
+ self.setEnabled(True)
+ elif action == "remove":
+ self.setEnabled(False)
+
+ def _updateState(self):
+ """Enable or disable widget based on item's availability."""
+ if self.getItem() is not None:
+ self.setEnabled(True)
+ else:
+ self.setEnabled(False)
+
+ def getItem(self):
+ """Return plot item currently associated to this widget (can be
+ a curve, an image, a scatter...)
+
+ :rtype: subclass of :class:`silx.gui.plot.items.Item`"""
+ if self._item_legend is None or self._item_kind is None:
+ return None
+ return self.plot._getItem(kind=self._item_kind,
+ legend=self._item_legend)
+
+ def setLegend(self, legend):
+ """Associate a different item (of the same kind) to the slider.
+
+ :param legend: New legend of item whose transparency is to be
+ controlled.
+ """
+ self._item_legend = legend
+ self._updateState()
+
+ def getLegend(self):
+ """Return legend of the item currently controlled by this slider.
+
+ :return: Image legend associated to the slider
+ """
+ return self._item_kind
+
+ def setItemKind(self, legend):
+ """Associate a different item (of the same kind) to the slider.
+
+ :param legend: New legend of item whose transparency is to be
+ controlled.
+ """
+ self._item_legend = legend
+ self._updateState()
+
+ def getItemKind(self):
+ """Return kind of the item currently controlled by this slider.
+
+ :return: Item kind ("image", "scatter"...)
+ :rtype: str on None
+ """
+ return self._item_kind
+
+
+class NamedImageAlphaSlider(NamedItemAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of an image (defined by its legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str legend: Legend of image whose transparency is to be
+ controlled.
+ """
+ def __init__(self, parent=None, plot=None, legend=None):
+ NamedItemAlphaSlider.__init__(self, parent, plot,
+ kind="image", legend=legend)
+
+
+class NamedScatterAlphaSlider(NamedItemAlphaSlider):
+ """Slider widget to be used in a plot toolbar to control the
+ transparency of a scatter (defined by its legend).
+
+ :param parent: Parent QWidget
+ :param plot: Plot on which to operate
+ :param str legend: Legend of scatter whose transparency is to be
+ controlled.
+ """
+ def __init__(self, parent=None, plot=None, legend=None):
+ NamedItemAlphaSlider.__init__(self, parent, plot,
+ kind="scatter", legend=legend)
diff --git a/src/silx/gui/plot/ColorBar.py b/src/silx/gui/plot/ColorBar.py
new file mode 100644
index 0000000..8cafc06
--- /dev/null
+++ b/src/silx/gui/plot/ColorBar.py
@@ -0,0 +1,883 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module containing several widgets associated to a colormap.
+"""
+
+__authors__ = ["H. Payno", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+import weakref
+import numpy
+
+from ._utils import ticklayout
+from .. import qt
+from ..qt import inspect as qt_inspect
+from silx.gui import colors
+from silx.math.colormap import LogarithmicNormalization
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ColorBarWidget(qt.QWidget):
+ """Colorbar widget displaying a colormap
+
+ It uses a description of colormap as dict compatible with :class:`Plot`.
+
+ .. image:: img/linearColorbar.png
+ :width: 80px
+ :align: center
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> from silx.gui.plot import Plot2D
+ >>> from silx.gui.plot.ColorBar import ColorBarWidget
+
+ >>> plot = Plot2D() # Create a plot widget
+ >>> plot.show()
+
+ >>> colorbar = ColorBarWidget(plot=plot, legend='Colormap') # Associate the colorbar with it
+ >>> colorbar.show()
+
+ Initializer parameters:
+
+ :param parent: See :class:`QWidget`
+ :param plot: PlotWidget the colorbar is attached to (optional)
+ :param str legend: the label to set to the colorbar
+ """
+ sigVisibleChanged = qt.Signal(bool)
+ """Emitted when the property `visible` have changed."""
+
+ def __init__(self, parent=None, plot=None, legend=None):
+ self._isConnected = False
+ self._plotRef = None
+ self._colormap = None
+ self._data = None
+
+ super(ColorBarWidget, self).__init__(parent)
+
+ self.__buildGUI()
+ self.setLegend(legend)
+ self.setPlot(plot)
+
+ def __buildGUI(self):
+ self.setLayout(qt.QHBoxLayout())
+
+ # create color scale widget
+ self._colorScale = ColorScaleBar(parent=self,
+ colormap=None)
+ self.layout().addWidget(self._colorScale)
+
+ # legend (is the right group)
+ self.legend = _VerticalLegend('', self)
+ self.layout().addWidget(self.legend)
+
+ self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
+
+ def getPlot(self):
+ """Returns the :class:`Plot` associated to this widget or None"""
+ return None if self._plotRef is None else self._plotRef()
+
+ def setPlot(self, plot):
+ """Associate a plot to the ColorBar
+
+ :param plot: the plot to associate with the colorbar.
+ If None will remove any connection with a previous plot.
+ """
+ self._disconnectPlot()
+ self._plotRef = None if plot is None else weakref.ref(plot)
+ self._connectPlot()
+
+ def _disconnectPlot(self):
+ """Disconnect from Plot signals"""
+ if self._isConnected:
+ self._isConnected = False
+ plot = self.getPlot()
+ if plot is not None and qt_inspect.isValid(plot):
+ plot.sigActiveImageChanged.disconnect(
+ self._activeImageChanged)
+ plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChanged)
+ plot.sigPlotSignal.disconnect(self._defaultColormapChanged)
+
+ def _connectPlot(self):
+ """Connect to Plot signals"""
+ plot = self.getPlot()
+ if plot is not None and not self._isConnected:
+ activeImageLegend = plot.getActiveImage(just_legend=True)
+ activeScatterLegend = plot._getActiveItem(
+ kind='scatter', just_legend=True)
+ if activeImageLegend is None and activeScatterLegend is None:
+ # Show plot default colormap
+ self._syncWithDefaultColormap()
+ elif activeImageLegend is not None: # Show active image colormap
+ self._activeImageChanged(None, activeImageLegend)
+ elif activeScatterLegend is not None: # Show active scatter colormap
+ self._activeScatterChanged(None, activeScatterLegend)
+
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ plot.sigPlotSignal.connect(self._defaultColormapChanged)
+ self._isConnected = True
+
+ def setVisible(self, isVisible):
+ qt.QWidget.setVisible(self, isVisible)
+ self.sigVisibleChanged.emit(isVisible)
+
+ def showEvent(self, event):
+ self._connectPlot()
+
+ def hideEvent(self, event):
+ self._disconnectPlot()
+
+ def getColormap(self):
+ """Returns the colormap displayed in the colorbar.
+
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self.getColorScaleBar().getColormap()
+
+ def setColormap(self, colormap, data=None):
+ """Set the colormap to be displayed.
+
+ :param ~silx.gui.colors.Colormap colormap:
+ The colormap to apply on the ColorBarWidget
+ :param Union[numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
+ The data to display or item, needed if the colormap require an autoscale
+ """
+ self._data = data
+ self.getColorScaleBar().setColormap(colormap=colormap,
+ data=data)
+ if self._colormap is not None:
+ self._colormap.sigChanged.disconnect(self._colormapHasChanged)
+ self._colormap = colormap
+ if self._colormap is not None:
+ self._colormap.sigChanged.connect(self._colormapHasChanged)
+
+ def _colormapHasChanged(self):
+ """handler of the Colormap.sigChanged signal
+ """
+ assert self._colormap is not None
+ self.setColormap(colormap=self._colormap,
+ data=self._data)
+
+ def setLegend(self, legend):
+ """Set the legend displayed along the colorbar
+
+ :param str legend: The label
+ """
+ if legend is None or legend == "":
+ self.legend.hide()
+ self.legend.setText("")
+ else:
+ assert type(legend) is str
+ self.legend.show()
+ self.legend.setText(legend)
+
+ def getLegend(self):
+ """
+ Returns the legend displayed along the colorbar
+
+ :return: return the legend displayed along the colorbar
+ :rtype: str
+ """
+ return self.legend.text()
+
+ def _activeScatterChanged(self, previous, legend):
+ """Handle plot active scatter changed"""
+ plot = self.getPlot()
+
+ # Do not handle active scatter while there is an image
+ if plot.getActiveImage() is not None:
+ return
+
+ if legend is None: # No active scatter, display no colormap
+ self.setColormap(colormap=None)
+ return
+
+ # Sync with active scatter
+ scatter = plot._getActiveItem(kind='scatter')
+
+ self.setColormap(colormap=scatter.getColormap(),
+ data=scatter)
+
+ def _activeImageChanged(self, previous, legend):
+ """Handle plot active image changed"""
+ plot = self.getPlot()
+
+ if legend is None: # No active image, try with active scatter
+ activeScatterLegend = plot._getActiveItem(
+ kind='scatter', just_legend=True)
+ # No more active image, use active scatter if any
+ self._activeScatterChanged(None, activeScatterLegend)
+ else:
+ # Sync with active image
+ image = plot.getActiveImage()
+
+ # RGB(A) image, display default colormap
+ array = image.getData(copy=False)
+ if array.ndim != 2:
+ self.setColormap(colormap=None)
+ return
+
+ # data image, sync with image colormap
+ # do we need the copy here : used in the case we are changing
+ # vmin and vmax but should have already be done by the plot
+ self.setColormap(colormap=image.getColormap(), data=image)
+
+ def _defaultColormapChanged(self, event):
+ """Handle plot default colormap changed"""
+ if event['event'] == 'defaultColormapChanged':
+ plot = self.getPlot()
+ if (plot is not None and
+ plot.getActiveImage() is None and
+ plot._getActiveItem(kind='scatter') is None):
+ # No active item, take default colormap update into account
+ self._syncWithDefaultColormap()
+
+ def _syncWithDefaultColormap(self):
+ """Update colorbar according to plot default colormap"""
+ self.setColormap(self.getPlot().getDefaultColormap())
+
+ def getColorScaleBar(self):
+ """
+
+ :return: return the :class:`ColorScaleBar` used to display ColorScale
+ and ticks"""
+ return self._colorScale
+
+
+class _VerticalLegend(qt.QLabel):
+ """Display vertically the given text
+ """
+ def __init__(self, text, parent=None):
+ """
+
+ :param text: the legend
+ :param parent: the Qt parent if any
+ """
+ qt.QLabel.__init__(self, text, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+ painter.setFont(self.font())
+
+ painter.translate(0, self.rect().height())
+ painter.rotate(270)
+ newRect = qt.QRect(0, 0, self.rect().height(), self.rect().width())
+
+ painter.drawText(newRect, qt.Qt.AlignHCenter, self.text())
+
+ fm = qt.QFontMetrics(self.font())
+ preferedHeight = fm.width(self.text())
+ preferedWidth = fm.height()
+ self.setFixedWidth(preferedWidth)
+ self.setMinimumHeight(preferedHeight)
+
+
+class ColorScaleBar(qt.QWidget):
+ """This class is making the composition of a :class:`_ColorScale` and a
+ :class:`_TickBar`.
+
+ It is the simplest widget displaying ticks and colormap gradient.
+
+ .. image:: img/colorScaleBar.png
+ :width: 150px
+ :align: center
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> colormap = Colormap(name='gray',
+ ... norm='log',
+ ... vmin=1,
+ ... vmax=100000,
+ ... )
+ >>> colorscale = ColorScaleBar(parent=None,
+ ... colormap=colormap )
+ >>> colorscale.show()
+
+ Initializer parameters :
+
+ :param colormap: the colormap to be displayed
+ :param parent: the Qt parent if any
+ :param displayTicksValues: display the ticks value or only the '-'
+ """
+
+ _TEXT_MARGIN = 5
+ """The tick bar need a margin to display all labels at the correct place.
+ So the ColorScale should have the same margin in order for both to fit"""
+
+ def __init__(self, parent=None, colormap=None, data=None,
+ displayTicksValues=True):
+ super(ColorScaleBar, self).__init__(parent)
+
+ self.minVal = None
+ """Value set to the _minLabel"""
+ self.maxVal = None
+ """Value set to the _maxLabel"""
+
+ self.setLayout(qt.QGridLayout())
+
+ # create the left side group (ColorScale)
+ self.colorScale = _ColorScale(colormap=colormap,
+ data=data,
+ parent=self,
+ margin=ColorScaleBar._TEXT_MARGIN)
+ if colormap:
+ vmin, vmax = colormap.getColormapRange(data)
+ normalizer = colormap._getNormalizer()
+ else:
+ vmin, vmax = colors.DEFAULT_MIN_LIN, colors.DEFAULT_MAX_LIN
+ normalizer = None
+
+ self.tickbar = _TickBar(vmin=vmin,
+ vmax=vmax,
+ normalizer=normalizer,
+ parent=self,
+ displayValues=displayTicksValues,
+ margin=ColorScaleBar._TEXT_MARGIN)
+
+ self.layout().addWidget(self.tickbar, 1, 0, 1, 1, qt.Qt.AlignRight)
+ self.layout().addWidget(self.colorScale, 1, 1, qt.Qt.AlignLeft)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.layout().setSpacing(0)
+
+ # max label
+ self._maxLabel = qt.QLabel(str(1.0), parent=self)
+ self._maxLabel.setToolTip(str(0.0))
+ self.layout().addWidget(self._maxLabel, 0, 0, 1, 2, qt.Qt.AlignRight)
+
+ # min label
+ self._minLabel = qt.QLabel(str(0.0), parent=self)
+ self._minLabel.setToolTip(str(0.0))
+ self.layout().addWidget(self._minLabel, 2, 0, 1, 2, qt.Qt.AlignRight)
+
+ self.layout().setSizeConstraint(qt.QLayout.SetMinAndMaxSize)
+ self.layout().setColumnStretch(0, 1)
+ self.layout().setRowStretch(1, 1)
+
+ def getTickBar(self):
+ """
+
+ :return: the instanciation of the :class:`_TickBar`
+ """
+ return self.tickbar
+
+ def getColorScale(self):
+ """
+
+ :return: the instanciation of the :class:`_ColorScale`
+ """
+ return self.colorScale
+
+ def getColormap(self):
+ """
+
+ :returns: the colormap.
+ :rtype: :class:`.Colormap`
+ """
+ return self.colorScale.getColormap()
+
+ def setColormap(self, colormap, data=None):
+ """Set the new colormap to be displayed
+
+ :param Colormap colormap: the colormap to set
+ :param Union[numpy.ndarray,~silx.gui.plot.items.Item] data:
+ The data or item to display, needed if the colormap requires an autoscale
+ """
+ self.colorScale.setColormap(colormap, data)
+
+ if colormap is not None:
+ vmin, vmax = colormap.getColormapRange(data)
+ normalizer = colormap._getNormalizer()
+ else:
+ vmin, vmax = None, None
+ normalizer = None
+
+ self.tickbar.update(vmin=vmin,
+ vmax=vmax,
+ normalizer=normalizer)
+ self._setMinMaxLabels(vmin, vmax)
+
+ def setMinMaxVisible(self, val=True):
+ """Change visibility of the min label and the max label
+
+ :param val: if True, set the labels visible, otherwise set it not visible
+ """
+ self._minLabel.setVisible(val)
+ self._maxLabel.setVisible(val)
+
+ def _updateMinMax(self):
+ """Update the min and max label if we are in the case of the
+ configuration 'minMaxValueOnly'"""
+ if self.minVal is None:
+ text, tooltip = '', ''
+ else:
+ if self.minVal == 0 or 0 <= numpy.log10(abs(self.minVal)) < 7:
+ text = '%.7g' % self.minVal
+ else:
+ text = '%.2e' % self.minVal
+ tooltip = repr(self.minVal)
+
+ self._minLabel.setText(text)
+ self._minLabel.setToolTip(tooltip)
+
+ if self.maxVal is None:
+ text, tooltip = '', ''
+ else:
+ if self.maxVal == 0 or 0 <= numpy.log10(abs(self.maxVal)) < 7:
+ text = '%.7g' % self.maxVal
+ else:
+ text = '%.2e' % self.maxVal
+ tooltip = repr(self.maxVal)
+
+ self._maxLabel.setText(text)
+ self._maxLabel.setToolTip(tooltip)
+
+ def _setMinMaxLabels(self, minVal, maxVal):
+ """Change the value of the min and max labels to be displayed.
+
+ :param minVal: the minimal value of the TickBar (not str)
+ :param maxVal: the maximal value of the TickBar (not str)
+ """
+ # bad hack to try to display has much information as possible
+ self.minVal = minVal
+ self.maxVal = maxVal
+ self._updateMinMax()
+
+ def resizeEvent(self, event):
+ qt.QWidget.resizeEvent(self, event)
+ self._updateMinMax()
+
+
+class _ColorScale(qt.QWidget):
+ """Widget displaying the colormap colorScale.
+
+ Show matching value between the gradient color (from the colormap) at mouse
+ position and value.
+
+ .. image:: img/colorScale.png
+ :width: 20px
+ :align: center
+
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> colormap = Colormap(name='viridis',
+ ... norm='log',
+ ... vmin=1,
+ ... vmax=100000,
+ ... )
+ >>> colorscale = ColorScale(parent=None,
+ ... colormap=colormap)
+ >>> colorscale.show()
+
+ Initializer parameters :
+
+ :param colormap: the colormap to be displayed
+ :param parent: the Qt parent if any
+ :param int margin: the top and left margin to apply.
+ :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
+ The data or item to use for getting the range for autoscale colormap.
+
+ .. warning:: Value drawing will be
+ done at the center of ticks. So if no margin is done your values
+ drawing might not be fully done for extrems values.
+ """
+
+ _NB_CONTROL_POINTS = 256
+
+ def __init__(self, colormap, parent=None, margin=5, data=None):
+ qt.QWidget.__init__(self, parent)
+ self._colormap = None
+ self.margin = margin
+ self.setColormap(colormap, data)
+
+ self.setLayout(qt.QVBoxLayout())
+ self.setSizePolicy(qt.QSizePolicy.Fixed, qt.QSizePolicy.Expanding)
+ # needed to get the mouse event without waiting for button click
+ self.setMouseTracking(True)
+ self.setMargin(margin)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self.setMinimumHeight(self._NB_CONTROL_POINTS // 2 + 2 * self.margin)
+ self.setFixedWidth(25)
+
+ def setColormap(self, colormap, data=None):
+ """Set the new colormap to be displayed
+
+ :param dict colormap: the colormap to set
+ :param Union[None,numpy.ndarray,~silx.gui.plot.items.ColormapMixin] data:
+ Optional data for which to compute colormap range.
+ """
+ self._colormap = colormap
+ self.setEnabled(colormap is not None)
+
+ if colormap is None:
+ self.vmin, self.vmax = None, None
+ else:
+ assert colormap.getNormalization() in colors.Colormap.NORMALIZATIONS
+ self.vmin, self.vmax = self._colormap.getColormapRange(data=data)
+ self._updateColorGradient()
+ self.update()
+
+ def getColormap(self):
+ """Returns the colormap
+
+ :rtype: :class:`.Colormap`
+ """
+ return None if self._colormap is None else self._colormap
+
+ def _updateColorGradient(self):
+ """Compute the color gradient"""
+ colormap = self.getColormap()
+ if colormap is None:
+ return
+
+ indices = numpy.linspace(0., 1., self._NB_CONTROL_POINTS)
+ colors = colormap.getNColors(nbColors=self._NB_CONTROL_POINTS)
+ self._gradient = qt.QLinearGradient(0, 1, 0, 0)
+ self._gradient.setCoordinateMode(qt.QGradient.StretchToDeviceMode)
+ self._gradient.setStops(
+ [(i, qt.QColor(*color)) for i, color in zip(indices, colors)]
+ )
+
+ def paintEvent(self, event):
+ """"""
+ painter = qt.QPainter(self)
+ if self.getColormap() is not None:
+ painter.setBrush(self._gradient)
+ penColor = self.palette().color(qt.QPalette.Active,
+ qt.QPalette.WindowText)
+ else:
+ penColor = self.palette().color(qt.QPalette.Disabled,
+ qt.QPalette.WindowText)
+ painter.setPen(penColor)
+
+ painter.drawRect(qt.QRect(
+ 0,
+ self.margin,
+ self.width() - 1,
+ self.height() - 2 * self.margin - 1))
+
+ def mouseMoveEvent(self, event):
+ tooltip = str(self.getValueFromRelativePosition(
+ self._getRelativePosition(event.y())))
+ qt.QToolTip.showText(event.globalPos(), tooltip, self)
+ super(_ColorScale, self).mouseMoveEvent(event)
+
+ def _getRelativePosition(self, yPixel):
+ """yPixel : pixel position into _ColorScale widget reference
+ """
+ # widgets are bottom-top referencial but we display in top-bottom referential
+ return 1. - (yPixel - self.margin) / float(self.height() - 2 * self.margin)
+
+ def getValueFromRelativePosition(self, value):
+ """Return the value in the colorMap from a relative position in the
+ ColorScaleBar (y)
+
+ :param value: float value in [0, 1]
+ :return: the value in [colormap['vmin'], colormap['vmax']]
+ """
+ colormap = self.getColormap()
+ if colormap is None:
+ return
+
+ value = numpy.clip(value, 0., 1.)
+ normalizer = colormap._getNormalizer()
+ normMin, normMax = normalizer.apply([self.vmin, self.vmax], self.vmin, self.vmax)
+
+ return normalizer.revert(
+ normMin + (normMax - normMin) * value, self.vmin, self.vmax)
+
+ def setMargin(self, margin):
+ """Define the margin to fit with a TickBar object.
+ This is needed since we can only paint on the viewport of the widget.
+ Didn't work with a simple setContentsMargins
+
+ :param int margin: the margin to apply on the top and bottom.
+ """
+ self.margin = int(margin)
+ self.update()
+
+
+class _TickBar(qt.QWidget):
+ """Bar grouping the ticks displayed
+
+ To run the following sample code, a QApplication must be initialized.
+
+ >>> bar = _TickBar(1, 1000, norm='log', parent=None, displayValues=True)
+ >>> bar.show()
+
+ .. image:: img/tickbar.png
+ :width: 40px
+ :align: center
+
+ :param int vmin: smaller value of the range of values
+ :param int vmax: higher value of the range of values
+ :param normalizer: Normalization object.
+ :param parent: the Qt parent if any
+ :param bool displayValues: if True display the values close to the tick,
+ Otherwise only signal it by '-'
+ :param int nticks: the number of tick we want to display. Should be an
+ unsigned int ot None. If None, let the Tick bar find the optimal
+ number of ticks from the tick density.
+ :param int margin: margin to set on the top and bottom
+ """
+ _WIDTH_DISP_VAL = 45
+ """widget width when displayed with ticks labels"""
+ _WIDTH_NO_DISP_VAL = 10
+ """widget width when displayed without ticks labels"""
+ _FONT_SIZE = 10
+ """font size for ticks labels"""
+ _LINE_WIDTH = 10
+ """width of the line to mark a tick"""
+
+ DEFAULT_TICK_DENSITY = 0.015
+
+ def __init__(self, vmin, vmax, normalizer, parent=None, displayValues=True,
+ nticks=None, margin=5):
+ super(_TickBar, self).__init__(parent)
+ self.margin = margin
+ self._nticks = None
+ self.ticks = ()
+ self.subTicks = ()
+ self._forcedDisplayType = None
+ self.ticksDensity = _TickBar.DEFAULT_TICK_DENSITY
+
+ self._vmin = vmin
+ self._vmax = vmax
+ self._normalizer = normalizer
+ self.displayValues = displayValues
+ self.setTicksNumber(nticks)
+
+ self.setMargin(margin)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ self._resetWidth()
+
+ def setTicksValuesVisible(self, val):
+ self.displayValues = val
+ self._resetWidth()
+
+ def _resetWidth(self):
+ width = self._WIDTH_DISP_VAL if self.displayValues else self._WIDTH_NO_DISP_VAL
+ self.setFixedWidth(width)
+
+ def update(self, vmin, vmax, normalizer):
+ self._vmin = vmin
+ self._vmax = vmax
+ self._normalizer = normalizer
+ self.computeTicks()
+ qt.QWidget.update(self)
+
+ def setMargin(self, margin):
+ """Define the margin to fit with a _ColorScale object.
+ This is needed since we can only paint on the viewport of the widget
+
+ :param int margin: the margin to apply on the top and bottom.
+ """
+ self.margin = margin
+
+ def setTicksNumber(self, nticks):
+ """Set the number of ticks to display.
+
+ :param nticks: the number of tick to be display. Should be an
+ unsigned int ot None. If None, let the :class:`_TickBar` find the
+ optimal number of ticks from the tick density.
+ """
+ self._nticks = nticks
+ self.computeTicks()
+ qt.QWidget.update(self)
+
+ def setTicksDensity(self, density):
+ """If you let :class:`_TickBar` deal with the number of ticks
+ (nticks=None) then you can specify a ticks density to be displayed.
+ """
+ if density < 0.0:
+ raise ValueError('Density should be a positive value')
+ self.ticksDensity = density
+
+ def computeTicks(self):
+ """This function compute ticks values labels. It is called at each
+ update and each resize event.
+ Deal only with linear and log scale.
+ """
+ nticks = self._nticks
+ if nticks is None:
+ nticks = self._getOptimalNbTicks()
+
+ if self._vmin == self._vmax:
+ # No range: no ticks
+ self.ticks = ()
+ self.subTicks = ()
+ elif isinstance(self._normalizer, LogarithmicNormalization):
+ self._computeTicksLog(nticks)
+ else: # Fallback: use linear
+ self._computeTicksLin(nticks)
+
+ # update the form
+ font = qt.QFont()
+ font.setPixelSize(_TickBar._FONT_SIZE)
+
+ self.form = self._getFormat(font)
+
+ def _computeTicksLog(self, nticks):
+ logMin = numpy.log10(self._vmin)
+ logMax = numpy.log10(self._vmax)
+ lowBound, highBound, spacing, self._nfrac = ticklayout.niceNumbersForLog10(logMin,
+ logMax,
+ nticks)
+ self.ticks = numpy.power(10., numpy.arange(lowBound, highBound, spacing))
+ if spacing == 1:
+ self.subTicks = ticklayout.computeLogSubTicks(ticks=self.ticks,
+ lowBound=numpy.power(10., lowBound),
+ highBound=numpy.power(10., highBound))
+ else:
+ self.subTicks = []
+
+ def resizeEvent(self, event):
+ qt.QWidget.resizeEvent(self, event)
+ self.computeTicks()
+
+ def _computeTicksLin(self, nticks):
+ _min, _max, _spacing, self._nfrac = ticklayout.niceNumbers(self._vmin,
+ self._vmax,
+ nticks)
+
+ self.ticks = numpy.arange(_min, _max, _spacing)
+ self.subTicks = []
+
+ def _getOptimalNbTicks(self):
+ return max(2, int(round(self.ticksDensity * self.rect().height())))
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+ font = painter.font()
+ font.setPixelSize(_TickBar._FONT_SIZE)
+ painter.setFont(font)
+
+ # paint ticks
+ for val in self.ticks:
+ self._paintTick(val, painter, majorTick=True)
+
+ # paint subticks
+ for val in self.subTicks:
+ self._paintTick(val, painter, majorTick=False)
+
+ def _getRelativePosition(self, val):
+ """Return the relative position of val according to min and max value
+ """
+ if self._normalizer is None:
+ return 0.
+ normMin, normMax, normVal = self._normalizer.apply(
+ [self._vmin, self._vmax, val],
+ self._vmin,
+ self._vmax)
+
+ if normMin == normMax:
+ return 0.
+ else:
+ return 1. - (normVal - normMin) / (normMax - normMin)
+
+ def _paintTick(self, val, painter, majorTick=True):
+ """
+
+ :param bool majorTick: if False will never draw text and will set a line
+ with a smaller width
+ """
+ fm = qt.QFontMetrics(painter.font())
+ viewportHeight = self.rect().height() - self.margin * 2 - 1
+ relativePos = self._getRelativePosition(val)
+ height = int(viewportHeight * relativePos + self.margin)
+ lineWidth = _TickBar._LINE_WIDTH
+ if majorTick is False:
+ lineWidth /= 2
+
+ painter.drawLine(qt.QLine(int(self.width() - lineWidth),
+ height,
+ self.width(),
+ height))
+
+ if self.displayValues and majorTick is True:
+ painter.drawText(qt.QPoint(0, int(height + fm.height() / 2)),
+ self.form.format(val))
+
+ def setDisplayType(self, disType):
+ """Set the type of display we want to set for ticks labels
+
+ :param str disType: The type of display we want to set. disType values
+ can be :
+
+ - 'std' for standard, meaning only a formatting on the number of
+ digits is done
+ - 'e' for scientific display
+ - None to let the _TickBar guess the best display for this kind of data.
+ """
+ if disType not in (None, 'std', 'e'):
+ raise ValueError("display type not recognized, value should be in (None, 'std', 'e'")
+ self._forcedDisplayType = disType
+
+ def _getStandardFormat(self):
+ return "{0:.%sf}" % self._nfrac
+
+ def _getFormat(self, font):
+ if self._forcedDisplayType is None:
+ return self._guessType(font)
+ elif self._forcedDisplayType == 'std':
+ return self._getStandardFormat()
+ elif self._forcedDisplayType == 'e':
+ return self._getScientificForm()
+ else:
+ err = 'Forced type for display %s is not recognized' % self._forcedDisplayType
+ raise ValueError(err)
+
+ def _getScientificForm(self):
+ return "{0:.0e}"
+
+ def _guessType(self, font):
+ """Try fo find the better format to display the tick's labels
+
+ :param QFont font: the font we want to use during the painting
+ """
+ form = self._getStandardFormat()
+
+ fm = qt.QFontMetrics(font)
+ width = 0
+ for tick in self.ticks:
+ width = max(fm.boundingRect(form.format(tick)).width(), width)
+
+ # if the length of the string are too long we are moving to scientific
+ # display
+ if width > _TickBar._WIDTH_DISP_VAL - _TickBar._LINE_WIDTH:
+ return self._getScientificForm()
+ else:
+ return form
diff --git a/silx/gui/plot/Colormap.py b/src/silx/gui/plot/Colormap.py
index 22fea7f..22fea7f 100644
--- a/silx/gui/plot/Colormap.py
+++ b/src/silx/gui/plot/Colormap.py
diff --git a/silx/gui/plot/ColormapDialog.py b/src/silx/gui/plot/ColormapDialog.py
index 7c66cb8..7c66cb8 100644
--- a/silx/gui/plot/ColormapDialog.py
+++ b/src/silx/gui/plot/ColormapDialog.py
diff --git a/silx/gui/plot/Colors.py b/src/silx/gui/plot/Colors.py
index 277e104..277e104 100644
--- a/silx/gui/plot/Colors.py
+++ b/src/silx/gui/plot/Colors.py
diff --git a/src/silx/gui/plot/CompareImages.py b/src/silx/gui/plot/CompareImages.py
new file mode 100644
index 0000000..857fc79
--- /dev/null
+++ b/src/silx/gui/plot/CompareImages.py
@@ -0,0 +1,1259 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A widget dedicated to compare 2 images.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/07/2018"
+
+
+import enum
+import logging
+import numpy
+import weakref
+import collections
+import math
+
+import silx.image.bilinear
+from silx.gui import qt
+from silx.gui import plot
+from silx.gui import icons
+from silx.gui.colors import Colormap
+from silx.gui.plot import tools
+from silx.utils.weakref import WeakMethodProxy
+
+_logger = logging.getLogger(__name__)
+
+from silx.opencl import ocl
+if ocl is not None:
+ try:
+ from silx.opencl import sift
+ except ImportError:
+ # sift module is not available (e.g., in official Debian packages)
+ sift = None
+else: # No OpenCL device or no pyopencl
+ sift = None
+
+
+@enum.unique
+class VisualizationMode(enum.Enum):
+ """Enum for each visualization mode available."""
+ ONLY_A = 'a'
+ ONLY_B = 'b'
+ VERTICAL_LINE = 'vline'
+ HORIZONTAL_LINE = 'hline'
+ COMPOSITE_RED_BLUE_GRAY = "rbgchannel"
+ COMPOSITE_RED_BLUE_GRAY_NEG = "rbgnegchannel"
+ COMPOSITE_A_MINUS_B = "aminusb"
+
+
+@enum.unique
+class AlignmentMode(enum.Enum):
+ """Enum for each alignment mode available."""
+ ORIGIN = 'origin'
+ CENTER = 'center'
+ STRETCH = 'stretch'
+ AUTO = 'auto'
+
+
+AffineTransformation = collections.namedtuple("AffineTransformation",
+ ["tx", "ty", "sx", "sy", "rot"])
+"""Contains a 2D affine transformation: translation, scale and rotation"""
+
+
+class CompareImagesToolBar(qt.QToolBar):
+ """ToolBar containing specific tools to custom the configuration of a
+ :class:`CompareImages` widget
+
+ Use :meth:`setCompareWidget` to connect this toolbar to a specific
+ :class:`CompareImages` widget.
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ """
+ def __init__(self, parent=None):
+ qt.QToolBar.__init__(self, parent)
+
+ self.__compareWidget = None
+
+ menu = qt.QMenu(self)
+ self.__visualizationToolButton = qt.QToolButton(self)
+ self.__visualizationToolButton.setMenu(menu)
+ self.__visualizationToolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ self.addWidget(self.__visualizationToolButton)
+ self.__visualizationGroup = qt.QActionGroup(self)
+ self.__visualizationGroup.setExclusive(True)
+ self.__visualizationGroup.triggered.connect(self.__visualizationModeChanged)
+
+ icon = icons.getQIcon("compare-mode-a")
+ action = qt.QAction(icon, "Display the first image only", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_A))
+ action.setProperty("mode", VisualizationMode.ONLY_A)
+ menu.addAction(action)
+ self.__aModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-b")
+ action = qt.QAction(icon, "Display the second image only", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_B))
+ action.setProperty("mode", VisualizationMode.ONLY_B)
+ menu.addAction(action)
+ self.__bModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-vline")
+ action = qt.QAction(icon, "Vertical compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_V))
+ action.setProperty("mode", VisualizationMode.VERTICAL_LINE)
+ menu.addAction(action)
+ self.__vlineModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-hline")
+ action = qt.QAction(icon, "Horizontal compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_H))
+ action.setProperty("mode", VisualizationMode.HORIZONTAL_LINE)
+ menu.addAction(action)
+ self.__hlineModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-rb-channel")
+ action = qt.QAction(icon, "Blue/red compare mode (additive mode)", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_C))
+ action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY)
+ menu.addAction(action)
+ self.__brChannelModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-rbneg-channel")
+ action = qt.QAction(icon, "Yellow/cyan compare mode (subtractive mode)", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
+ action.setProperty("mode", VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG)
+ menu.addAction(action)
+ self.__ycChannelModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-mode-a-minus-b")
+ action = qt.QAction(icon, "Raw A minus B compare mode", self)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ action.setShortcut(qt.QKeySequence(qt.Qt.Key_W))
+ action.setProperty("mode", VisualizationMode.COMPOSITE_A_MINUS_B)
+ menu.addAction(action)
+ self.__ycChannelModeAction = action
+ self.__visualizationGroup.addAction(action)
+
+ menu = qt.QMenu(self)
+ self.__alignmentToolButton = qt.QToolButton(self)
+ self.__alignmentToolButton.setMenu(menu)
+ self.__alignmentToolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ self.addWidget(self.__alignmentToolButton)
+ self.__alignmentGroup = qt.QActionGroup(self)
+ self.__alignmentGroup.setExclusive(True)
+ self.__alignmentGroup.triggered.connect(self.__alignmentModeChanged)
+
+ icon = icons.getQIcon("compare-align-origin")
+ action = qt.QAction(icon, "Align images on their upper-left pixel", self)
+ action.setProperty("mode", AlignmentMode.ORIGIN)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__originAlignAction = action
+ menu.addAction(action)
+ self.__alignmentGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-align-center")
+ action = qt.QAction(icon, "Center images", self)
+ action.setProperty("mode", AlignmentMode.CENTER)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__centerAlignAction = action
+ menu.addAction(action)
+ self.__alignmentGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-align-stretch")
+ action = qt.QAction(icon, "Stretch the second image on the first one", self)
+ action.setProperty("mode", AlignmentMode.STRETCH)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__stretchAlignAction = action
+ menu.addAction(action)
+ self.__alignmentGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-align-auto")
+ action = qt.QAction(icon, "Auto-alignment of the second image", self)
+ action.setProperty("mode", AlignmentMode.AUTO)
+ action.setIconVisibleInMenu(True)
+ action.setCheckable(True)
+ self.__autoAlignAction = action
+ menu.addAction(action)
+ if sift is None:
+ action.setEnabled(False)
+ action.setToolTip("Sift module is not available")
+ self.__alignmentGroup.addAction(action)
+
+ icon = icons.getQIcon("compare-keypoints")
+ action = qt.QAction(icon, "Display/hide alignment keypoints", self)
+ action.setCheckable(True)
+ action.triggered.connect(self.__keypointVisibilityChanged)
+ self.addAction(action)
+ self.__displayKeypoints = action
+
+ def setCompareWidget(self, widget):
+ """
+ Connect this tool bar to a specific :class:`CompareImages` widget.
+
+ :param Union[None,CompareImages] widget: The widget to connect with.
+ """
+ compareWidget = self.getCompareWidget()
+ if compareWidget is not None:
+ compareWidget.sigConfigurationChanged.disconnect(self.__updateSelectedActions)
+ compareWidget = widget
+ if compareWidget is None:
+ self.__compareWidget = None
+ else:
+ self.__compareWidget = weakref.ref(compareWidget)
+ if compareWidget is not None:
+ widget.sigConfigurationChanged.connect(self.__updateSelectedActions)
+ self.__updateSelectedActions()
+
+ def getCompareWidget(self):
+ """Returns the connected widget.
+
+ :rtype: CompareImages
+ """
+ if self.__compareWidget is None:
+ return None
+ else:
+ return self.__compareWidget()
+
+ def __updateSelectedActions(self):
+ """
+ Update the state of this tool bar according to the state of the
+ connected :class:`CompareImages` widget.
+ """
+ widget = self.getCompareWidget()
+ if widget is None:
+ return
+
+ mode = widget.getVisualizationMode()
+ action = None
+ for a in self.__visualizationGroup.actions():
+ actionMode = a.property("mode")
+ if mode == actionMode:
+ action = a
+ break
+ old = self.__visualizationGroup.blockSignals(True)
+ if action is not None:
+ # Check this action
+ action.setChecked(True)
+ else:
+ action = self.__visualizationGroup.checkedAction()
+ if action is not None:
+ # Uncheck this action
+ action.setChecked(False)
+ self.__updateVisualizationMenu()
+ self.__visualizationGroup.blockSignals(old)
+
+ mode = widget.getAlignmentMode()
+ action = None
+ for a in self.__alignmentGroup.actions():
+ actionMode = a.property("mode")
+ if mode == actionMode:
+ action = a
+ break
+ old = self.__alignmentGroup.blockSignals(True)
+ if action is not None:
+ # Check this action
+ action.setChecked(True)
+ else:
+ action = self.__alignmentGroup.checkedAction()
+ if action is not None:
+ # Uncheck this action
+ action.setChecked(False)
+ self.__updateAlignmentMenu()
+ self.__alignmentGroup.blockSignals(old)
+
+ def __visualizationModeChanged(self, selectedAction):
+ """Called when user requesting changes of the visualization mode.
+ """
+ self.__updateVisualizationMenu()
+ widget = self.getCompareWidget()
+ if widget is not None:
+ mode = selectedAction.property("mode")
+ widget.setVisualizationMode(mode)
+
+ def __updateVisualizationMenu(self):
+ """Update the state of the action containing visualization menu.
+ """
+ selectedAction = self.__visualizationGroup.checkedAction()
+ if selectedAction is not None:
+ self.__visualizationToolButton.setText(selectedAction.text())
+ self.__visualizationToolButton.setIcon(selectedAction.icon())
+ self.__visualizationToolButton.setToolTip(selectedAction.toolTip())
+ else:
+ self.__visualizationToolButton.setText("")
+ self.__visualizationToolButton.setIcon(qt.QIcon())
+ self.__visualizationToolButton.setToolTip("")
+
+ def __alignmentModeChanged(self, selectedAction):
+ """Called when user requesting changes of the alignment mode.
+ """
+ self.__updateAlignmentMenu()
+ widget = self.getCompareWidget()
+ if widget is not None:
+ mode = selectedAction.property("mode")
+ widget.setAlignmentMode(mode)
+
+ def __updateAlignmentMenu(self):
+ """Update the state of the action containing alignment menu.
+ """
+ selectedAction = self.__alignmentGroup.checkedAction()
+ if selectedAction is not None:
+ self.__alignmentToolButton.setText(selectedAction.text())
+ self.__alignmentToolButton.setIcon(selectedAction.icon())
+ self.__alignmentToolButton.setToolTip(selectedAction.toolTip())
+ else:
+ self.__alignmentToolButton.setText("")
+ self.__alignmentToolButton.setIcon(qt.QIcon())
+ self.__alignmentToolButton.setToolTip("")
+
+ def __keypointVisibilityChanged(self):
+ """Called when action managing keypoints visibility changes"""
+ widget = self.getCompareWidget()
+ if widget is not None:
+ keypointsVisible = self.__displayKeypoints.isChecked()
+ widget.setKeypointsVisible(keypointsVisible)
+
+
+class CompareImagesStatusBar(qt.QStatusBar):
+ """StatusBar containing specific information contained in a
+ :class:`CompareImages` widget
+
+ Use :meth:`setCompareWidget` to connect this toolbar to a specific
+ :class:`CompareImages` widget.
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ """
+ def __init__(self, parent=None):
+ qt.QStatusBar.__init__(self, parent)
+ self.setSizeGripEnabled(False)
+ self.layout().setSpacing(0)
+ self.__compareWidget = None
+ self._label1 = qt.QLabel(self)
+ self._label1.setFrameShape(qt.QFrame.WinPanel)
+ self._label1.setFrameShadow(qt.QFrame.Sunken)
+ self._label2 = qt.QLabel(self)
+ self._label2.setFrameShape(qt.QFrame.WinPanel)
+ self._label2.setFrameShadow(qt.QFrame.Sunken)
+ self._transform = qt.QLabel(self)
+ self._transform.setFrameShape(qt.QFrame.WinPanel)
+ self._transform.setFrameShadow(qt.QFrame.Sunken)
+ self.addWidget(self._label1)
+ self.addWidget(self._label2)
+ self.addWidget(self._transform)
+ self._pos = None
+ self._updateStatusBar()
+
+ def setCompareWidget(self, widget):
+ """
+ Connect this tool bar to a specific :class:`CompareImages` widget.
+
+ :param Union[None,CompareImages] widget: The widget to connect with.
+ """
+ compareWidget = self.getCompareWidget()
+ if compareWidget is not None:
+ compareWidget.getPlot().sigPlotSignal.disconnect(self.__plotSignalReceived)
+ compareWidget.sigConfigurationChanged.disconnect(self.__dataChanged)
+ compareWidget = widget
+ if compareWidget is None:
+ self.__compareWidget = None
+ else:
+ self.__compareWidget = weakref.ref(compareWidget)
+ if compareWidget is not None:
+ compareWidget.getPlot().sigPlotSignal.connect(self.__plotSignalReceived)
+ compareWidget.sigConfigurationChanged.connect(self.__dataChanged)
+
+ def getCompareWidget(self):
+ """Returns the connected widget.
+
+ :rtype: CompareImages
+ """
+ if self.__compareWidget is None:
+ return None
+ else:
+ return self.__compareWidget()
+
+ def __plotSignalReceived(self, event):
+ """Called when old style signals at emmited from the plot."""
+ if event["event"] == "mouseMoved":
+ x, y = event["x"], event["y"]
+ self.__mouseMoved(x, y)
+
+ def __mouseMoved(self, x, y):
+ """Called when mouse move over the plot."""
+ self._pos = x, y
+ self._updateStatusBar()
+
+ def __dataChanged(self):
+ """Called when internal data from the connected widget changes."""
+ self._updateStatusBar()
+
+ def _formatData(self, data):
+ """Format pixel of an image.
+
+ It supports intensity, RGB, and RGBA.
+
+ :param Union[int,float,numpy.ndarray,str]: Value of a pixel
+ :rtype: str
+ """
+ if data is None:
+ return "No data"
+ if isinstance(data, (int, numpy.integer)):
+ return "%d" % data
+ if isinstance(data, (float, numpy.floating)):
+ return "%f" % data
+ if isinstance(data, numpy.ndarray):
+ # RGBA value
+ if data.shape == (3,):
+ return "R:%d G:%d B:%d" % (data[0], data[1], data[2])
+ elif data.shape == (4,):
+ return "R:%d G:%d B:%d A:%d" % (data[0], data[1], data[2], data[3])
+ _logger.debug("Unsupported data format %s. Cast it to string.", type(data))
+ return str(data)
+
+ def _updateStatusBar(self):
+ """Update the content of the status bar"""
+ widget = self.getCompareWidget()
+ if widget is None:
+ self._label1.setText("Image1: NA")
+ self._label2.setText("Image2: NA")
+ self._transform.setVisible(False)
+ else:
+ transform = widget.getTransformation()
+ self._transform.setVisible(transform is not None)
+ if transform is not None:
+ has_notable_translation = not numpy.isclose(transform.tx, 0.0, atol=0.01) \
+ or not numpy.isclose(transform.ty, 0.0, atol=0.01)
+ has_notable_scale = not numpy.isclose(transform.sx, 1.0, atol=0.01) \
+ or not numpy.isclose(transform.sy, 1.0, atol=0.01)
+ has_notable_rotation = not numpy.isclose(transform.rot, 0.0, atol=0.01)
+
+ strings = []
+ if has_notable_translation:
+ strings.append("Translation")
+ if has_notable_scale:
+ strings.append("Scale")
+ if has_notable_rotation:
+ strings.append("Rotation")
+ if strings == []:
+ has_translation = not numpy.isclose(transform.tx, 0.0) \
+ or not numpy.isclose(transform.ty, 0.0)
+ has_scale = not numpy.isclose(transform.sx, 1.0) \
+ or not numpy.isclose(transform.sy, 1.0)
+ has_rotation = not numpy.isclose(transform.rot, 0.0)
+ if has_translation or has_scale or has_rotation:
+ text = "No big changes"
+ else:
+ text = "No changes"
+ else:
+ text = "+".join(strings)
+ self._transform.setText("Align: " + text)
+
+ strings = []
+ if not numpy.isclose(transform.ty, 0.0):
+ strings.append("Translation x: %0.3fpx" % transform.tx)
+ if not numpy.isclose(transform.ty, 0.0):
+ strings.append("Translation y: %0.3fpx" % transform.ty)
+ if not numpy.isclose(transform.sx, 1.0):
+ strings.append("Scale x: %0.3f" % transform.sx)
+ if not numpy.isclose(transform.sy, 1.0):
+ strings.append("Scale y: %0.3f" % transform.sy)
+ if not numpy.isclose(transform.rot, 0.0):
+ strings.append("Rotation: %0.3fdeg" % (transform.rot * 180 / numpy.pi))
+ if strings == []:
+ text = "No transformation"
+ else:
+ text = "\n".join(strings)
+ self._transform.setToolTip(text)
+
+ if self._pos is None:
+ self._label1.setText("Image1: NA")
+ self._label2.setText("Image2: NA")
+ else:
+ data1, data2 = widget.getRawPixelData(self._pos[0], self._pos[1])
+ if isinstance(data1, str):
+ self._label1.setToolTip(data1)
+ text1 = "NA"
+ else:
+ self._label1.setToolTip("")
+ text1 = self._formatData(data1)
+ if isinstance(data2, str):
+ self._label2.setToolTip(data2)
+ text2 = "NA"
+ else:
+ self._label2.setToolTip("")
+ text2 = self._formatData(data2)
+ self._label1.setText("Image1: %s" % text1)
+ self._label2.setText("Image2: %s" % text2)
+
+
+class CompareImages(qt.QMainWindow):
+ """Widget providing tools to compare 2 images.
+
+ .. image:: img/CompareImages.png
+
+ :param Union[qt.QWidget,None] parent: Parent of this widget.
+ :param backend: The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ VisualizationMode = VisualizationMode
+ """Available visualization modes"""
+
+ AlignmentMode = AlignmentMode
+ """Available alignment modes"""
+
+ sigConfigurationChanged = qt.Signal()
+ """Emitted when the configuration of the widget (visualization mode,
+ alignement mode...) have changed."""
+
+ def __init__(self, parent=None, backend=None):
+ qt.QMainWindow.__init__(self, parent)
+ self._resetZoomActive = True
+ self._colormap = Colormap()
+ """Colormap shared by all modes, except the compose images (rgb image)"""
+ self._colormapKeyPoints = Colormap('spring')
+ """Colormap used for sift keypoints"""
+
+ if parent is None:
+ self.setWindowTitle('Compare images')
+ else:
+ self.setWindowFlags(qt.Qt.Widget)
+
+ self.__transformation = None
+ self.__raw1 = None
+ self.__raw2 = None
+ self.__data1 = None
+ self.__data2 = None
+ self.__previousSeparatorPosition = None
+
+ self.__plot = plot.PlotWidget(parent=self, backend=backend)
+ self.__plot.setDefaultColormap(self._colormap)
+ self.__plot.getXAxis().setLabel('Columns')
+ self.__plot.getYAxis().setLabel('Rows')
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ self.__plot.getYAxis().setInverted(True)
+
+ self.__plot.setKeepDataAspectRatio(True)
+ self.__plot.sigPlotSignal.connect(self.__plotSlot)
+ self.__plot.setAxesDisplayed(False)
+
+ self.setCentralWidget(self.__plot)
+
+ legend = VisualizationMode.VERTICAL_LINE.name
+ self.__plot.addXMarker(
+ 0,
+ legend=legend,
+ text='',
+ draggable=True,
+ color='blue',
+ constraint=WeakMethodProxy(self.__separatorConstraint))
+ self.__vline = self.__plot._getMarker(legend)
+
+ legend = VisualizationMode.HORIZONTAL_LINE.name
+ self.__plot.addYMarker(
+ 0,
+ legend=legend,
+ text='',
+ draggable=True,
+ color='blue',
+ constraint=WeakMethodProxy(self.__separatorConstraint))
+ self.__hline = self.__plot._getMarker(legend)
+
+ # default values
+ self.__visualizationMode = ""
+ self.__alignmentMode = ""
+ self.__keypointsVisible = True
+
+ self.setAlignmentMode(AlignmentMode.ORIGIN)
+ self.setVisualizationMode(VisualizationMode.VERTICAL_LINE)
+ self.setKeypointsVisible(False)
+
+ # Toolbars
+
+ self._createToolBars(self.__plot)
+ if self._interactiveModeToolBar is not None:
+ self.addToolBar(self._interactiveModeToolBar)
+ if self._imageToolBar is not None:
+ self.addToolBar(self._imageToolBar)
+ if self._compareToolBar is not None:
+ self.addToolBar(self._compareToolBar)
+
+ # Statusbar
+
+ self._createStatusBar(self.__plot)
+ if self._statusBar is not None:
+ self.setStatusBar(self._statusBar)
+
+ def _createStatusBar(self, plot):
+ self._statusBar = CompareImagesStatusBar(self)
+ self._statusBar.setCompareWidget(self)
+
+ def _createToolBars(self, plot):
+ """Create tool bars displayed by the widget"""
+ toolBar = tools.InteractiveModeToolBar(parent=self, plot=plot)
+ self._interactiveModeToolBar = toolBar
+ toolBar = tools.ImageToolBar(parent=self, plot=plot)
+ self._imageToolBar = toolBar
+ toolBar = CompareImagesToolBar(self)
+ toolBar.setCompareWidget(self)
+ self._compareToolBar = toolBar
+
+ def getPlot(self):
+ """Returns the plot which is used to display the images.
+
+ :rtype: silx.gui.plot.PlotWidget
+ """
+ return self.__plot
+
+ def getColormap(self):
+ """
+
+ :return: colormap used for compare image
+ :rtype: silx.gui.colors.Colormap
+ """
+ return self._colormap
+
+ def getRawPixelData(self, x, y):
+ """Return the raw pixel of each image data from axes positions.
+
+ If the coordinate is outside of the image it returns None element in
+ the tuple.
+
+ The pixel is reach from the raw data image without filter or
+ transformation. But the coordinate x and y are in the reference of the
+ current displayed mode.
+
+ :param float x: X-coordinate of the pixel in the current displayed plot
+ :param float y: Y-coordinate of the pixel in the current displayed plot
+ :return: A tuple of for each images containing pixel information. It
+ could be a scalar value or an array in case of RGB/RGBA informations.
+ It also could be a string containing information is some cases.
+ :rtype: Tuple(Union[int,float,numpy.ndarray,str],Union[int,float,numpy.ndarray,str])
+ """
+ data2 = None
+ alignmentMode = self.__alignmentMode
+ raw1, raw2 = self.__raw1, self.__raw2
+ if alignmentMode == AlignmentMode.ORIGIN:
+ x1 = x
+ y1 = y
+ x2 = x
+ y2 = y
+ elif alignmentMode == AlignmentMode.CENTER:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ x1 = x - (xx - raw1.shape[1]) * 0.5
+ x2 = x - (xx - raw2.shape[1]) * 0.5
+ y1 = y - (yy - raw1.shape[0]) * 0.5
+ y2 = y - (yy - raw2.shape[0]) * 0.5
+ elif alignmentMode == AlignmentMode.STRETCH:
+ x1 = x
+ y1 = y
+ x2 = x * raw2.shape[1] / raw1.shape[1]
+ y2 = x * raw2.shape[1] / raw1.shape[1]
+ elif alignmentMode == AlignmentMode.AUTO:
+ x1 = x
+ y1 = y
+ # Not implemented
+ data2 = "Not implemented with sift"
+ else:
+ assert(False)
+
+ x1, y1 = int(x1), int(y1)
+ if raw1 is None or y1 < 0 or y1 >= raw1.shape[0] or x1 < 0 or x1 >= raw1.shape[1]:
+ data1 = None
+ else:
+ data1 = raw1[y1, x1]
+
+ if data2 is None:
+ x2, y2 = int(x2), int(y2)
+ if raw2 is None or y2 < 0 or y2 >= raw2.shape[0] or x2 < 0 or x2 >= raw2.shape[1]:
+ data2 = None
+ else:
+ data2 = raw2[y2, x2]
+
+ return data1, data2
+
+ def setVisualizationMode(self, mode):
+ """Set the visualization mode.
+
+ :param str mode: New visualization to display the image comparison
+ """
+ if self.__visualizationMode == mode:
+ return
+ previousMode = self.getVisualizationMode()
+ self.__visualizationMode = mode
+ mode = self.getVisualizationMode()
+ self.__vline.setVisible(mode == VisualizationMode.VERTICAL_LINE)
+ self.__hline.setVisible(mode == VisualizationMode.HORIZONTAL_LINE)
+ visModeRawDisplay = (VisualizationMode.ONLY_A,
+ VisualizationMode.ONLY_B,
+ VisualizationMode.VERTICAL_LINE,
+ VisualizationMode.HORIZONTAL_LINE)
+ updateColormap = not(previousMode in visModeRawDisplay and
+ mode in visModeRawDisplay)
+ self.__updateData(updateColormap=updateColormap)
+ self.sigConfigurationChanged.emit()
+
+ def getVisualizationMode(self):
+ """Returns the current interaction mode."""
+ return self.__visualizationMode
+
+ def setAlignmentMode(self, mode):
+ """Set the alignment mode.
+
+ :param str mode: New alignement to apply to images
+ """
+ if self.__alignmentMode == mode:
+ return
+ self.__alignmentMode = mode
+ self.__updateData(updateColormap=False)
+ self.sigConfigurationChanged.emit()
+
+ def getAlignmentMode(self):
+ """Returns the current selected alignemnt mode."""
+ return self.__alignmentMode
+
+ def setKeypointsVisible(self, isVisible):
+ """Set keypoints visibility.
+
+ :param bool isVisible: If True, keypoints are displayed (if some)
+ """
+ if self.__keypointsVisible == isVisible:
+ return
+ self.__keypointsVisible = isVisible
+ self.__updateKeyPoints()
+ self.sigConfigurationChanged.emit()
+
+ def __setDefaultAlignmentMode(self):
+ """Reset the alignemnt mode to the default value"""
+ self.setAlignmentMode(AlignmentMode.ORIGIN)
+
+ def __plotSlot(self, event):
+ """Handle events from the plot"""
+ if event['event'] in ('markerMoving', 'markerMoved'):
+ mode = self.getVisualizationMode()
+ legend = mode.name
+ if event['label'] == legend:
+ if mode == VisualizationMode.VERTICAL_LINE:
+ value = int(float(str(event['xdata'])))
+ elif mode == VisualizationMode.HORIZONTAL_LINE:
+ value = int(float(str(event['ydata'])))
+ else:
+ assert(False)
+ if self.__previousSeparatorPosition != value:
+ self.__separatorMoved(value)
+ self.__previousSeparatorPosition = value
+
+ def __separatorConstraint(self, x, y):
+ """Manage contains on the separators to clamp them inside the images."""
+ if self.__data1 is None:
+ return 0, 0
+ x = int(x)
+ if x < 0:
+ x = 0
+ elif x > self.__data1.shape[1]:
+ x = self.__data1.shape[1]
+ y = int(y)
+ if y < 0:
+ y = 0
+ elif y > self.__data1.shape[0]:
+ y = self.__data1.shape[0]
+ return x, y
+
+ def __updateSeparators(self):
+ """Redraw images according to the current state of the separators.
+ """
+ mode = self.getVisualizationMode()
+ if mode == VisualizationMode.VERTICAL_LINE:
+ pos = self.__vline.getXPosition()
+ self.__separatorMoved(pos)
+ self.__previousSeparatorPosition = pos
+ elif mode == VisualizationMode.HORIZONTAL_LINE:
+ pos = self.__hline.getYPosition()
+ self.__separatorMoved(pos)
+ self.__previousSeparatorPosition = pos
+ else:
+ self.__image1.setOrigin((0, 0))
+ self.__image2.setOrigin((0, 0))
+
+ def __separatorMoved(self, pos):
+ """Called when vertical or horizontal separators have moved.
+
+ Update the displayed images.
+ """
+ if self.__data1 is None:
+ return
+
+ mode = self.getVisualizationMode()
+ if mode == VisualizationMode.VERTICAL_LINE:
+ pos = int(pos)
+ if pos <= 0:
+ pos = 0
+ elif pos >= self.__data1.shape[1]:
+ pos = self.__data1.shape[1]
+ data1 = self.__data1[:, 0:pos]
+ data2 = self.__data2[:, pos:]
+ self.__image1.setData(data1, copy=False)
+ self.__image2.setData(data2, copy=False)
+ self.__image2.setOrigin((pos, 0))
+ elif mode == VisualizationMode.HORIZONTAL_LINE:
+ pos = int(pos)
+ if pos <= 0:
+ pos = 0
+ elif pos >= self.__data1.shape[0]:
+ pos = self.__data1.shape[0]
+ data1 = self.__data1[0:pos, :]
+ data2 = self.__data2[pos:, :]
+ self.__image1.setData(data1, copy=False)
+ self.__image2.setData(data2, copy=False)
+ self.__image2.setOrigin((0, pos))
+ else:
+ assert(False)
+
+ def setData(self, image1, image2, updateColormap=True):
+ """Set images to compare.
+
+ Images can contains floating-point or integer values, or RGB and RGBA
+ values, but should have comparable intensities.
+
+ RGB and RGBA images are provided as an array as `[width,height,channels]`
+ of usigned integer 8-bits or floating-points between 0.0 to 1.0.
+
+ :param numpy.ndarray image1: The first image
+ :param numpy.ndarray image2: The second image
+ """
+ self.__raw1 = image1
+ self.__raw2 = image2
+ self.__updateData(updateColormap=updateColormap)
+ if self.isAutoResetZoom():
+ self.__plot.resetZoom()
+
+ def setImage1(self, image1, updateColormap=True):
+ """Set image1 to be compared.
+
+ Images can contains floating-point or integer values, or RGB and RGBA
+ values, but should have comparable intensities.
+
+ RGB and RGBA images are provided as an array as `[width,height,channels]`
+ of usigned integer 8-bits or floating-points between 0.0 to 1.0.
+
+ :param numpy.ndarray image1: The first image
+ """
+ self.__raw1 = image1
+ self.__updateData(updateColormap=updateColormap)
+ if self.isAutoResetZoom():
+ self.__plot.resetZoom()
+
+ def setImage2(self, image2, updateColormap=True):
+ """Set image2 to be compared.
+
+ Images can contains floating-point or integer values, or RGB and RGBA
+ values, but should have comparable intensities.
+
+ RGB and RGBA images are provided as an array as `[width,height,channels]`
+ of usigned integer 8-bits or floating-points between 0.0 to 1.0.
+
+ :param numpy.ndarray image2: The second image
+ """
+ self.__raw2 = image2
+ self.__updateData(updateColormap=updateColormap)
+ if self.isAutoResetZoom():
+ self.__plot.resetZoom()
+
+ def __updateKeyPoints(self):
+ """Update the displayed keypoints using cached keypoints.
+ """
+ if self.__keypointsVisible:
+ data = self.__matching_keypoints
+ else:
+ data = [], [], []
+ self.__plot.addScatter(x=data[0],
+ y=data[1],
+ z=1,
+ value=data[2],
+ colormap=self._colormapKeyPoints,
+ legend="keypoints")
+
+ def __updateData(self, updateColormap):
+ """Compute aligned image when the alignment mode changes.
+
+ This function cache input images which are used when
+ vertical/horizontal separators moves.
+ """
+ raw1, raw2 = self.__raw1, self.__raw2
+ if raw1 is None or raw2 is None:
+ return
+
+ alignmentMode = self.getAlignmentMode()
+ self.__transformation = None
+
+ if alignmentMode == AlignmentMode.ORIGIN:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size, transparent=True)
+ data2 = self.__createMarginImage(raw2, size, transparent=True)
+ self.__matching_keypoints = [0.0], [0.0], [1.0]
+ elif alignmentMode == AlignmentMode.CENTER:
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size, transparent=True, center=True)
+ data2 = self.__createMarginImage(raw2, size, transparent=True, center=True)
+ self.__matching_keypoints = ([data1.shape[1] // 2],
+ [data1.shape[0] // 2],
+ [1.0])
+ elif alignmentMode == AlignmentMode.STRETCH:
+ data1 = raw1
+ data2 = self.__rescaleImage(raw2, data1.shape)
+ self.__matching_keypoints = ([0, data1.shape[1], data1.shape[1], 0],
+ [0, 0, data1.shape[0], data1.shape[0]],
+ [1.0, 1.0, 1.0, 1.0])
+ elif alignmentMode == AlignmentMode.AUTO:
+ # TODO: sift implementation do not support RGBA images
+ yy = max(raw1.shape[0], raw2.shape[0])
+ xx = max(raw1.shape[1], raw2.shape[1])
+ size = yy, xx
+ data1 = self.__createMarginImage(raw1, size)
+ data2 = self.__createMarginImage(raw2, size)
+ self.__matching_keypoints = [0.0], [0.0], [1.0]
+ try:
+ data1, data2 = self.__createSiftData(data1, data2)
+ if data2 is None:
+ raise ValueError("Unexpected None value")
+ except Exception as e:
+ # TODO: Display it on the GUI
+ _logger.error(e)
+ self.__setDefaultAlignmentMode()
+ return
+ else:
+ assert(False)
+
+ mode = self.getVisualizationMode()
+ if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
+ data1 = self.__composeImage(data1, data2, mode)
+ data2 = numpy.empty((0, 0))
+ elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
+ data1 = self.__composeImage(data1, data2, mode)
+ data2 = numpy.empty((0, 0))
+ elif mode == VisualizationMode.COMPOSITE_A_MINUS_B:
+ data1 = self.__composeImage(data1, data2, mode)
+ data2 = numpy.empty((0, 0))
+ elif mode == VisualizationMode.ONLY_A:
+ data2 = numpy.empty((0, 0))
+ elif mode == VisualizationMode.ONLY_B:
+ data1 = numpy.empty((0, 0))
+
+ self.__data1, self.__data2 = data1, data2
+ self.__plot.addImage(data1, z=0, legend="image1", resetzoom=False)
+ self.__plot.addImage(data2, z=0, legend="image2", resetzoom=False)
+ self.__image1 = self.__plot.getImage("image1")
+ self.__image2 = self.__plot.getImage("image2")
+ self.__updateKeyPoints()
+
+ # Set the separator into the middle
+ if self.__previousSeparatorPosition is None:
+ value = self.__data1.shape[1] // 2
+ self.__vline.setPosition(value, 0)
+ value = self.__data1.shape[0] // 2
+ self.__hline.setPosition(0, value)
+ self.__updateSeparators()
+ if updateColormap:
+ self.__updateColormap()
+
+ def __updateColormap(self):
+ # TODO: The colormap histogram will still be wrong
+ mode1 = self.__getImageMode(self.__data1)
+ mode2 = self.__getImageMode(self.__data2)
+ if mode1 == "intensity" and mode1 == mode2:
+ if self.__data1.size == 0:
+ vmin = self.__data2.min()
+ vmax = self.__data2.max()
+ elif self.__data2.size == 0:
+ vmin = self.__data1.min()
+ vmax = self.__data1.max()
+ else:
+ vmin = min(self.__data1.min(), self.__data2.min())
+ vmax = max(self.__data1.max(), self.__data2.max())
+ colormap = self.getColormap()
+ colormap.setVRange(vmin=vmin, vmax=vmax)
+ self.__image1.setColormap(colormap)
+ self.__image2.setColormap(colormap)
+
+ def __getImageMode(self, image):
+ """Returns a value identifying the way the image is stored in the
+ array.
+
+ :param numpy.ndarray image: Image to check
+ :rtype: str
+ """
+ if len(image.shape) == 2:
+ return "intensity"
+ elif len(image.shape) == 3:
+ if image.shape[2] == 3:
+ return "rgb"
+ elif image.shape[2] == 4:
+ return "rgba"
+ raise TypeError("'image' argument is not an image.")
+
+ def __rescaleImage(self, image, shape):
+ """Rescale an image to the requested shape.
+
+ :rtype: numpy.ndarray
+ """
+ mode = self.__getImageMode(image)
+ if mode == "intensity":
+ data = self.__rescaleArray(image, shape)
+ elif mode == "rgb":
+ data = numpy.empty((shape[0], shape[1], 3), dtype=image.dtype)
+ for c in range(3):
+ data[:, :, c] = self.__rescaleArray(image[:, :, c], shape)
+ elif mode == "rgba":
+ data = numpy.empty((shape[0], shape[1], 4), dtype=image.dtype)
+ for c in range(4):
+ data[:, :, c] = self.__rescaleArray(image[:, :, c], shape)
+ return data
+
+ def __composeImage(self, data1, data2, mode):
+ """Returns an RBG image containing composition of data1 and data2 in 2
+ different channels
+
+ :param numpy.ndarray data1: First image
+ :param numpy.ndarray data1: Second image
+ :param VisualizationMode mode: Composition mode.
+ :rtype: numpy.ndarray
+ """
+ assert(data1.shape[0:2] == data2.shape[0:2])
+ if mode == VisualizationMode.COMPOSITE_A_MINUS_B:
+ # TODO: this calculation has no interest of generating a 'composed'
+ # rgb image, this could be moved in an other function or doc
+ # should be modified
+ _type = data1.dtype
+ result = data1.astype(numpy.float64) - data2.astype(numpy.float64)
+ return result
+ mode1 = self.__getImageMode(data1)
+ if mode1 in ["rgb", "rgba"]:
+ intensity1 = self.__luminosityImage(data1)
+ vmin1, vmax1 = 0.0, 1.0
+ else:
+ intensity1 = data1
+ vmin1, vmax1 = data1.min(), data1.max()
+
+ mode2 = self.__getImageMode(data2)
+ if mode2 in ["rgb", "rgba"]:
+ intensity2 = self.__luminosityImage(data2)
+ vmin2, vmax2 = 0.0, 1.0
+ else:
+ intensity2 = data2
+ vmin2, vmax2 = data2.min(), data2.max()
+
+ vmin, vmax = min(vmin1, vmin2) * 1.0, max(vmax1, vmax2) * 1.0
+ shape = data1.shape
+ result = numpy.empty((shape[0], shape[1], 3), dtype=numpy.uint8)
+ a = (intensity1 - vmin) * (1.0 / (vmax - vmin)) * 255.0
+ b = (intensity2 - vmin) * (1.0 / (vmax - vmin)) * 255.0
+ if mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY:
+ result[:, :, 0] = a
+ result[:, :, 1] = (a + b) / 2
+ result[:, :, 2] = b
+ elif mode == VisualizationMode.COMPOSITE_RED_BLUE_GRAY_NEG:
+ result[:, :, 0] = 255 - b
+ result[:, :, 1] = 255 - (a + b) / 2
+ result[:, :, 2] = 255 - a
+ return result
+
+ def __luminosityImage(self, image):
+ """Returns the luminosity channel from an RBG(A) image.
+ The alpha channel is ignored.
+
+ :rtype: numpy.ndarray
+ """
+ mode = self.__getImageMode(image)
+ assert(mode in ["rgb", "rgba"])
+ is_uint8 = image.dtype.type == numpy.uint8
+ # luminosity
+ image = 0.21 * image[..., 0] + 0.72 * image[..., 1] + 0.07 * image[..., 2]
+ if is_uint8:
+ image = image / 255.0
+ return image
+
+ def __rescaleArray(self, image, shape):
+ """Rescale a 2D array to the requested shape.
+
+ :rtype: numpy.ndarray
+ """
+ y, x = numpy.ogrid[:shape[0], :shape[1]]
+ y, x = y * 1.0 * (image.shape[0] - 1) / (shape[0] - 1), x * 1.0 * (image.shape[1] - 1) / (shape[1] - 1)
+ b = silx.image.bilinear.BilinearImage(image)
+ # TODO: could be optimized using strides
+ x2d = numpy.zeros_like(y) + x
+ y2d = numpy.zeros_like(x) + y
+ result = b.map_coordinates((y2d, x2d))
+ return result
+
+ def __createMarginImage(self, image, size, transparent=False, center=False):
+ """Returns a new image with margin to respect the requested size.
+
+ :rtype: numpy.ndarray
+ """
+ assert(image.shape[0] <= size[0])
+ assert(image.shape[1] <= size[1])
+ if image.shape == size:
+ return image
+ mode = self.__getImageMode(image)
+
+ if center:
+ pos0 = size[0] // 2 - image.shape[0] // 2
+ pos1 = size[1] // 2 - image.shape[1] // 2
+ else:
+ pos0, pos1 = 0, 0
+
+ if mode == "intensity":
+ data = numpy.zeros(size, dtype=image.dtype)
+ data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1]] = image
+ # TODO: It is maybe possible to put NaN on the margin
+ else:
+ if transparent:
+ data = numpy.zeros((size[0], size[1], 4), dtype=numpy.uint8)
+ else:
+ data = numpy.zeros((size[0], size[1], 3), dtype=numpy.uint8)
+ depth = min(data.shape[2], image.shape[2])
+ data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 0:depth] = image[:, :, 0:depth]
+ if transparent and depth == 3:
+ data[pos0:pos0 + image.shape[0], pos1:pos1 + image.shape[1], 3] = 255
+ return data
+
+ def __toAffineTransformation(self, sift_result):
+ """Returns an affine transformation from the sift result.
+
+ :param dict sift_result: Result of sift when using `all_result=True`
+ :rtype: AffineTransformation
+ """
+ offset = sift_result["offset"]
+ matrix = sift_result["matrix"]
+
+ tx = offset[0]
+ ty = offset[1]
+ a = matrix[0, 0]
+ b = matrix[0, 1]
+ c = matrix[1, 0]
+ d = matrix[1, 1]
+ rot = math.atan2(-b, a)
+ sx = (-1.0 if a < 0 else 1.0) * math.sqrt(a**2 + b**2)
+ sy = (-1.0 if d < 0 else 1.0) * math.sqrt(c**2 + d**2)
+ return AffineTransformation(tx, ty, sx, sy, rot)
+
+ def getTransformation(self):
+ """Retuns the affine transformation applied to the second image to align
+ it to the first image.
+
+ This result is only valid for sift alignment.
+
+ :rtype: Union[None,AffineTransformation]
+ """
+ return self.__transformation
+
+ def __createSiftData(self, image, second_image):
+ """Generate key points and aligned images from 2 images.
+
+ If no keypoints matches, unaligned data are anyway returns.
+
+ :rtype: Tuple(numpy.ndarray,numpy.ndarray)
+ """
+ devicetype = "GPU"
+
+ # Compute base image
+ sift_ocl = sift.SiftPlan(template=image, devicetype=devicetype)
+ keypoints = sift_ocl(image)
+
+ # Check image compatibility
+ second_keypoints = sift_ocl(second_image)
+ mp = sift.MatchPlan()
+ match = mp(keypoints, second_keypoints)
+ _logger.info("Number of Keypoints within image 1: %i" % keypoints.size)
+ _logger.info(" within image 2: %i" % second_keypoints.size)
+
+ self.__matching_keypoints = (match[:].x[:, 0],
+ match[:].y[:, 0],
+ match[:].scale[:, 0])
+ matching_keypoints = match.shape[0]
+ _logger.info("Matching keypoints: %i" % matching_keypoints)
+ if matching_keypoints == 0:
+ return image, second_image
+
+ # TODO: Problem here is we have to compute 2 time sift
+ # The first time to extract matching keypoints, second time
+ # to extract the aligned image.
+
+ # Normalize the second image
+ sa = sift.LinearAlign(image, devicetype=devicetype)
+ data1 = image
+ # TODO: Create a sift issue: if data1 is RGB and data2 intensity
+ # it returns None, while extracting manually keypoints (above) works
+ result = sa.align(second_image, return_all=True)
+ data2 = result["result"]
+ self.__transformation = self.__toAffineTransformation(result)
+ return data1, data2
+
+ def setAutoResetZoom(self, activate=True):
+ """
+
+ :param bool activate: True if we want to activate the automatic
+ plot reset zoom when setting images.
+ """
+ self._resetZoomActive = activate
+
+ def isAutoResetZoom(self):
+ """
+
+ :return: True if the automatic call to resetzoom is activated
+ :rtype: bool
+ """
+ return self._resetZoomActive
diff --git a/src/silx/gui/plot/ComplexImageView.py b/src/silx/gui/plot/ComplexImageView.py
new file mode 100644
index 0000000..4eee3b0
--- /dev/null
+++ b/src/silx/gui/plot/ComplexImageView.py
@@ -0,0 +1,518 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget to view 2D complex data.
+
+The :class:`ComplexImageView` widget is dedicated to visualize a single 2D dataset
+of complex data.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["Vincent Favre-Nicolin", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+import collections
+import numpy
+
+from ...utils.deprecation import deprecated
+from .. import qt, icons
+from .PlotWindow import Plot2D
+from . import items
+from .items import ImageComplexData
+from silx.gui.widgets.FloatEdit import FloatEdit
+
+_logger = logging.getLogger(__name__)
+
+
+# Widgets
+
+class _AmplitudeRangeDialog(qt.QDialog):
+ """QDialog asking for the amplitude range to display."""
+
+ sigRangeChanged = qt.Signal(tuple)
+ """Signal emitted when the range has changed.
+
+ It provides the new range as a 2-tuple: (max, delta)
+ """
+
+ def __init__(self,
+ parent=None,
+ amplitudeRange=None,
+ displayedRange=(None, 2)):
+ super(_AmplitudeRangeDialog, self).__init__(parent)
+ self.setWindowTitle('Set Displayed Amplitude Range')
+
+ if amplitudeRange is not None:
+ amplitudeRange = min(amplitudeRange), max(amplitudeRange)
+ self._amplitudeRange = amplitudeRange
+ self._defaultDisplayedRange = displayedRange
+
+ layout = qt.QFormLayout()
+ self.setLayout(layout)
+
+ if self._amplitudeRange is not None:
+ min_, max_ = self._amplitudeRange
+ layout.addRow(
+ qt.QLabel('Data Amplitude Range: [%g, %g]' % (min_, max_)))
+
+ self._maxLineEdit = FloatEdit(parent=self)
+ self._maxLineEdit.validator().setBottom(0.)
+ self._maxLineEdit.setAlignment(qt.Qt.AlignRight)
+
+ self._maxLineEdit.editingFinished.connect(self._rangeUpdated)
+ layout.addRow('Displayed Max.:', self._maxLineEdit)
+
+ self._autoscale = qt.QCheckBox('autoscale')
+ self._autoscale.toggled.connect(self._autoscaleCheckBoxToggled)
+ layout.addRow('', self._autoscale)
+
+ self._deltaLineEdit = FloatEdit(parent=self)
+ self._deltaLineEdit.validator().setBottom(1.)
+ self._deltaLineEdit.setAlignment(qt.Qt.AlignRight)
+ self._deltaLineEdit.editingFinished.connect(self._rangeUpdated)
+ layout.addRow('Displayed delta (log10 unit):', self._deltaLineEdit)
+
+ buttons = qt.QDialogButtonBox(self)
+ buttons.addButton(qt.QDialogButtonBox.Ok)
+ buttons.addButton(qt.QDialogButtonBox.Cancel)
+ buttons.accepted.connect(self.accept)
+ buttons.rejected.connect(self.reject)
+ layout.addRow(buttons)
+
+ # Set dialog from default values
+ self._resetDialogToDefault()
+
+ self.rejected.connect(self._handleRejected)
+
+ def _resetDialogToDefault(self):
+ """Set Widgets of the dialog from range information
+ """
+ max_, delta = self._defaultDisplayedRange
+
+ if max_ is not None: # Not in autoscale
+ displayedMax = max_
+ elif self._amplitudeRange is not None: # Autoscale with data
+ displayedMax = self._amplitudeRange[1]
+ else: # Autoscale without data
+ displayedMax = ''
+ if displayedMax == "":
+ self._maxLineEdit.setText("")
+ else:
+ self._maxLineEdit.setValue(displayedMax)
+ self._maxLineEdit.setEnabled(max_ is not None)
+
+ self._deltaLineEdit.setValue(delta)
+
+ self._autoscale.setChecked(self._defaultDisplayedRange[0] is None)
+
+ def getRangeInfo(self):
+ """Returns the current range as a 2-tuple (max, delta (in log10))"""
+ if self._autoscale.isChecked():
+ max_ = None
+ else:
+ maxStr = self._maxLineEdit.text()
+ max_ = self._maxLineEdit.value() if maxStr else None
+ return max_, self._deltaLineEdit.value() if self._deltaLineEdit.text() else 2
+
+ def _handleRejected(self):
+ """Reset range info to default when rejected"""
+ self._resetDialogToDefault()
+ self._rangeUpdated()
+
+ def _rangeUpdated(self):
+ """Handle QLineEdit editing finised"""
+ self.sigRangeChanged.emit(self.getRangeInfo())
+
+ def _autoscaleCheckBoxToggled(self, checked):
+ """Handle autoscale checkbox state changes"""
+ if checked: # Use default values
+ if self._amplitudeRange is None:
+ max_ = ''
+ else:
+ max_ = self._amplitudeRange[1]
+ if max_ == "":
+ self._maxLineEdit.setText("")
+ else:
+ self._maxLineEdit.setValue(max_)
+ self._maxLineEdit.setEnabled(not checked)
+ self._rangeUpdated()
+
+
+class _ComplexDataToolButton(qt.QToolButton):
+ """QToolButton providing choices of complex data visualization modes
+
+ :param parent: See :class:`QToolButton`
+ :param plot: The :class:`ComplexImageView` to control
+ """
+
+ _MODES = collections.OrderedDict([
+ (ImageComplexData.ComplexMode.ABSOLUTE, ('math-amplitude', 'Amplitude')),
+ (ImageComplexData.ComplexMode.SQUARE_AMPLITUDE,
+ ('math-square-amplitude', 'Square amplitude')),
+ (ImageComplexData.ComplexMode.PHASE, ('math-phase', 'Phase')),
+ (ImageComplexData.ComplexMode.REAL, ('math-real', 'Real part')),
+ (ImageComplexData.ComplexMode.IMAGINARY,
+ ('math-imaginary', 'Imaginary part')),
+ (ImageComplexData.ComplexMode.AMPLITUDE_PHASE,
+ ('math-phase-color', 'Amplitude and Phase')),
+ (ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE,
+ ('math-phase-color-log', 'Log10(Amp.) and Phase'))
+ ])
+
+ _RANGE_DIALOG_TEXT = 'Set Amplitude Range...'
+
+ def __init__(self, parent=None, plot=None):
+ super(_ComplexDataToolButton, self).__init__(parent=parent)
+
+ assert plot is not None
+ self._plot2DComplex = plot
+
+ menu = qt.QMenu(self)
+ menu.triggered.connect(self._triggered)
+ self.setMenu(menu)
+
+ for mode, info in self._MODES.items():
+ icon, text = info
+ action = qt.QAction(icons.getQIcon(icon), text, self)
+ action.setData(mode)
+ action.setIconVisibleInMenu(True)
+ menu.addAction(action)
+
+ self._rangeDialogAction = qt.QAction(self)
+ self._rangeDialogAction.setText(self._RANGE_DIALOG_TEXT)
+ menu.addAction(self._rangeDialogAction)
+
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ self._modeChanged(self._plot2DComplex.getComplexMode())
+ self._plot2DComplex.sigVisualizationModeChanged.connect(
+ self._modeChanged)
+
+ def _modeChanged(self, mode):
+ """Handle change of visualization modes"""
+ icon, text = self._MODES[mode]
+ self.setIcon(icons.getQIcon(icon))
+ self.setToolTip('Display the ' + text.lower())
+ self._rangeDialogAction.setEnabled(
+ mode == ImageComplexData.ComplexMode.LOG10_AMPLITUDE_PHASE)
+
+ def _triggered(self, action):
+ """Handle triggering of menu actions"""
+ actionText = action.text()
+
+ if actionText == self._RANGE_DIALOG_TEXT: # Show dialog
+ # Get amplitude range
+ data = self._plot2DComplex.getData(copy=False)
+
+ if data.size > 0:
+ absolute = numpy.absolute(data)
+ dataRange = (numpy.nanmin(absolute), numpy.nanmax(absolute))
+ else:
+ dataRange = None
+
+ # Show dialog
+ dialog = _AmplitudeRangeDialog(
+ parent=self,
+ amplitudeRange=dataRange,
+ displayedRange=self._plot2DComplex._getAmplitudeRangeInfo())
+ dialog.sigRangeChanged.connect(self._rangeChanged)
+ dialog.exec()
+ dialog.sigRangeChanged.disconnect(self._rangeChanged)
+
+ else: # update mode
+ mode = action.data()
+ if isinstance(mode, ImageComplexData.ComplexMode):
+ self._plot2DComplex.setComplexMode(mode)
+
+ def _rangeChanged(self, range_):
+ """Handle updates of range in the dialog"""
+ self._plot2DComplex._setAmplitudeRangeInfo(*range_)
+
+
+class ComplexImageView(qt.QWidget):
+ """Display an image of complex data and allow to choose the visualization.
+
+ :param parent: See :class:`QMainWindow`
+ """
+
+ ComplexMode = ImageComplexData.ComplexMode
+ """Complex Modes enumeration"""
+
+ sigDataChanged = qt.Signal()
+ """Signal emitted when data has changed."""
+
+ sigVisualizationModeChanged = qt.Signal(object)
+ """Signal emitted when the visualization mode has changed.
+
+ It provides the new visualization mode.
+ """
+
+ def __init__(self, parent=None):
+ super(ComplexImageView, self).__init__(parent)
+ if parent is None:
+ self.setWindowTitle('ComplexImageView')
+
+ self._plot2D = Plot2D(self)
+
+ layout = qt.QHBoxLayout(self)
+ layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot2D)
+ self.setLayout(layout)
+
+ # Create and add image to the plot
+ self._plotImage = ImageComplexData()
+ self._plotImage.setName('__ComplexImageView__complex_image__')
+ self._plotImage.sigItemChanged.connect(self._itemChanged)
+ self._plot2D.addItem(self._plotImage)
+ self._plot2D.setActiveImage(self._plotImage.getName())
+
+ toolBar = qt.QToolBar('Complex', self)
+ toolBar.addWidget(
+ _ComplexDataToolButton(parent=self, plot=self))
+
+ self._plot2D.insertToolBar(self._plot2D.getProfileToolbar(), toolBar)
+
+ def _itemChanged(self, event):
+ """Handle item changed signal"""
+ if event is items.ItemChangedType.DATA:
+ self.sigDataChanged.emit()
+ elif event is items.ItemChangedType.VISUALIZATION_MODE:
+ mode = self.getComplexMode()
+ self.sigVisualizationModeChanged.emit(mode)
+
+ def getPlot(self):
+ """Return the PlotWidget displaying the data"""
+ return self._plot2D
+
+ def setData(self, data=None, copy=True):
+ """Set the complex data to display.
+
+ :param numpy.ndarray data: 2D complex data
+ :param bool copy: True (default) to copy the data,
+ False to use provided data (do not modify!).
+ """
+ if data is None:
+ data = numpy.zeros((0, 0), dtype=numpy.complex64)
+
+ previousData = self._plotImage.getComplexData(copy=False)
+
+ self._plotImage.setData(data, copy=copy)
+
+ if previousData.shape != data.shape:
+ self.getPlot().resetZoom()
+
+ def getData(self, copy=True):
+ """Get the currently displayed complex data.
+
+ :param bool copy: True (default) to return a copy of the data,
+ False to return internal data (do not modify!).
+ :return: The complex data array.
+ :rtype: numpy.ndarray of complex with 2 dimensions
+ """
+ return self._plotImage.getComplexData(copy=copy)
+
+ def getDisplayedData(self, copy=True):
+ """Returns the displayed data depending on the visualization mode
+
+ WARNING: The returned data can be a uint8 RGBA image
+
+ :param bool copy: True (default) to return a copy of the data,
+ False to return internal data (do not modify!)
+ :rtype: numpy.ndarray of float with 2 dims or RGBA image (uint8).
+ """
+ mode = self.getComplexMode()
+ if mode in (self.ComplexMode.AMPLITUDE_PHASE,
+ self.ComplexMode.LOG10_AMPLITUDE_PHASE):
+ return self._plotImage.getRgbaImageData(copy=copy)
+ else:
+ return self._plotImage.getData(copy=copy)
+
+ # Backward compatibility
+
+ Mode = ComplexMode
+
+ @classmethod
+ @deprecated(replacement='supportedComplexModes', since_version='0.11.0')
+ def getSupportedVisualizationModes(cls):
+ return cls.supportedComplexModes()
+
+ @deprecated(replacement='setComplexMode', since_version='0.11.0')
+ def setVisualizationMode(self, mode):
+ return self.setComplexMode(mode)
+
+ @deprecated(replacement='getComplexMode', since_version='0.11.0')
+ def getVisualizationMode(self):
+ return self.getComplexMode()
+
+ # Image item proxy
+
+ @staticmethod
+ def supportedComplexModes():
+ """Returns the supported visualization modes.
+
+ Supported visualization modes are:
+
+ - amplitude: The absolute value provided by numpy.absolute
+ - phase: The phase (or argument) provided by numpy.angle
+ - real: Real part
+ - imaginary: Imaginary part
+ - amplitude_phase: Color-coded phase with amplitude as alpha.
+ - log10_amplitude_phase:
+ Color-coded phase with log10(amplitude) as alpha.
+
+ :rtype: List[ComplexMode]
+ """
+ return ImageComplexData.supportedComplexModes()
+
+ def setComplexMode(self, mode):
+ """Set the mode of visualization of the complex data.
+
+ See :meth:`supportedComplexModes` for the list of
+ supported modes.
+
+ How-to change visualization mode::
+
+ widget = ComplexImageView()
+ widget.setComplexMode(ComplexImageView.ComplexMode.PHASE)
+ # or
+ widget.setComplexMode('phase')
+
+ :param Unions[ComplexMode,str] mode: The mode to use.
+ """
+ self._plotImage.setComplexMode(mode)
+
+ def getComplexMode(self):
+ """Get the current visualization mode of the complex data.
+
+ :rtype: ComplexMode
+ """
+ return self._plotImage.getComplexMode()
+
+ def _setAmplitudeRangeInfo(self, max_=None, delta=2):
+ """Set the amplitude range to display for 'log10_amplitude_phase' mode.
+
+ :param max_: Max of the amplitude range.
+ If None it autoscales to data max.
+ :param float delta: Delta range in log10 to display
+ """
+ self._plotImage._setAmplitudeRangeInfo(max_, delta)
+
+ def _getAmplitudeRangeInfo(self):
+ """Returns the amplitude range to use for 'log10_amplitude_phase' mode.
+
+ :return: (max, delta), if max is None, then it autoscales to data max
+ :rtype: 2-tuple"""
+ return self._plotImage._getAmplitudeRangeInfo()
+
+ def setColormap(self, colormap, mode=None):
+ """Set the colormap to use for amplitude, phase, real or imaginary.
+
+ WARNING: This colormap is not used when displaying both
+ amplitude and phase.
+
+ :param ~silx.gui.colors.Colormap colormap: The colormap
+ :param ComplexMode mode: If specified, set the colormap of this specific mode
+ """
+ self._plotImage.setColormap(colormap, mode)
+
+ def getColormap(self, mode=None):
+ """Returns the colormap used to display the data.
+
+ :param ComplexMode mode: If specified, set the colormap of this specific mode
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self._plotImage.getColormap(mode=mode)
+
+ def getOrigin(self):
+ """Returns the offset from origin at which to display the image.
+
+ :rtype: 2-tuple of float
+ """
+ return self._plotImage.getOrigin()
+
+ def setOrigin(self, origin):
+ """Set the offset from origin at which to display the image.
+
+ :param origin: (ox, oy) Offset from origin
+ :type origin: float or 2-tuple of float
+ """
+ self._plotImage.setOrigin(origin)
+
+ def getScale(self):
+ """Returns the scale of the image in data coordinates.
+
+ :rtype: 2-tuple of float
+ """
+ return self._plotImage.getScale()
+
+ def setScale(self, scale):
+ """Set the scale of the image
+
+ :param scale: (sx, sy) Scale of the image
+ :type scale: float or 2-tuple of float
+ """
+ self._plotImage.setScale(scale)
+
+ # PlotWidget API proxy
+
+ def getXAxis(self):
+ """Returns the X axis
+
+ :rtype: :class:`.items.Axis`
+ """
+ return self.getPlot().getXAxis()
+
+ def getYAxis(self):
+ """Returns an Y axis
+
+ :rtype: :class:`.items.Axis`
+ """
+ return self.getPlot().getYAxis(axis='left')
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str."""
+ return self.getPlot().getGraphTitle()
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ self.getPlot().setGraphTitle(title)
+
+ def setKeepDataAspectRatio(self, flag):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ self.getPlot().setKeepDataAspectRatio(flag)
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self.getPlot().isKeepDataAspectRatio()
diff --git a/src/silx/gui/plot/CurvesROIWidget.py b/src/silx/gui/plot/CurvesROIWidget.py
new file mode 100644
index 0000000..132d398
--- /dev/null
+++ b/src/silx/gui/plot/CurvesROIWidget.py
@@ -0,0 +1,1581 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+Widget to handle regions of interest (:class:`ROI`) on curves displayed in a
+:class:`PlotWindow`.
+
+This widget is meant to work with :class:`PlotWindow`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent", "H. Payno"]
+__license__ = "MIT"
+__date__ = "13/03/2018"
+
+from collections import OrderedDict
+import logging
+import os
+import sys
+import functools
+import numpy
+from silx.io import dictdump
+from silx.utils import deprecation
+from silx.utils.weakref import WeakMethodProxy
+from silx.utils.proxy import docstring
+from .. import icons, qt
+from silx.math.combo import min_max
+import weakref
+from silx.gui.widgets.TableWidget import TableWidget
+from . import items
+from .items.roi import _RegionOfInterestBase
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CurvesROIWidget(qt.QWidget):
+ """
+ Widget displaying a table of ROI information.
+
+ Implements also the following behavior:
+
+ * if the roiTable has no ROI when showing create the default ICR one
+
+ :param parent: See :class:`QWidget`
+ :param str name: The title of this widget
+ """
+
+ sigROIWidgetSignal = qt.Signal(object)
+ """Signal of ROIs modifications.
+
+ Modification information if given as a dict with an 'event' key
+ providing the type of events.
+
+ Type of events:
+
+ - AddROI, DelROI, LoadROI and ResetROI with keys: 'roilist', 'roidict'
+ - selectionChanged with keys: 'row', 'col' 'roi', 'key', 'colheader',
+ 'rowheader'
+ """
+
+ sigROISignal = qt.Signal(object)
+
+ def __init__(self, parent=None, name=None, plot=None):
+ super(CurvesROIWidget, self).__init__(parent)
+ if name is not None:
+ self.setWindowTitle(name)
+ self.__lastSigROISignal = None
+ """Store the last value emitted for the sigRoiSignal. In the case the
+ active curve change we need to add this extra step in order to make
+ sure we won't send twice the sigROISignal.
+ This come from the fact sigROISignal is connected to the
+ activeROIChanged signal which is emitted when raw and net counts
+ values are changing but are not embed in the sigROISignal.
+ """
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._showAllMarkers = False
+ self.currentROI = None
+
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ self.headerLabel = qt.QLabel(self)
+ self.headerLabel.setAlignment(qt.Qt.AlignHCenter)
+ self.setHeader()
+ layout.addWidget(self.headerLabel)
+
+ widgetAllCheckbox = qt.QWidget(parent=self)
+ self._showAllCheckBox = qt.QCheckBox("show all ROI",
+ parent=widgetAllCheckbox)
+ widgetAllCheckbox.setLayout(qt.QHBoxLayout())
+ spacer = qt.QWidget(parent=widgetAllCheckbox)
+ spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Fixed)
+ widgetAllCheckbox.layout().addWidget(spacer)
+ widgetAllCheckbox.layout().addWidget(self._showAllCheckBox)
+ layout.addWidget(widgetAllCheckbox)
+
+ self.roiTable = ROITable(self, plot=plot)
+ rheight = self.roiTable.horizontalHeader().sizeHint().height()
+ self.roiTable.setMinimumHeight(4 * rheight)
+ layout.addWidget(self.roiTable)
+ self._roiFileDir = qt.QDir.home().absolutePath()
+ self._showAllCheckBox.toggled.connect(self.roiTable.showAllMarkers)
+
+ hbox = qt.QWidget(self)
+ hboxlayout = qt.QHBoxLayout(hbox)
+ hboxlayout.setContentsMargins(0, 0, 0, 0)
+ hboxlayout.setSpacing(0)
+
+ hboxlayout.addStretch(0)
+
+ self.addButton = qt.QPushButton(hbox)
+ self.addButton.setText("Add ROI")
+ self.addButton.setToolTip('Create a new ROI')
+ self.delButton = qt.QPushButton(hbox)
+ self.delButton.setText("Delete ROI")
+ self.addButton.setToolTip('Remove the selected ROI')
+ self.resetButton = qt.QPushButton(hbox)
+ self.resetButton.setText("Reset")
+ self.addButton.setToolTip('Clear all created ROIs. We only let the '
+ 'default ROI')
+
+ hboxlayout.addWidget(self.addButton)
+ hboxlayout.addWidget(self.delButton)
+ hboxlayout.addWidget(self.resetButton)
+
+ hboxlayout.addStretch(0)
+
+ self.loadButton = qt.QPushButton(hbox)
+ self.loadButton.setText("Load")
+ self.loadButton.setToolTip('Load ROIs from a .ini file')
+ self.saveButton = qt.QPushButton(hbox)
+ self.saveButton.setText("Save")
+ self.loadButton.setToolTip('Save ROIs to a .ini file')
+ hboxlayout.addWidget(self.loadButton)
+ hboxlayout.addWidget(self.saveButton)
+ layout.setStretchFactor(self.headerLabel, 0)
+ layout.setStretchFactor(self.roiTable, 1)
+ layout.setStretchFactor(hbox, 0)
+
+ layout.addWidget(hbox)
+
+ # Signal / Slot connections
+ self.addButton.clicked.connect(self._add)
+ self.delButton.clicked.connect(self._del)
+ self.resetButton.clicked.connect(self._reset)
+
+ self.loadButton.clicked.connect(self._load)
+ self.saveButton.clicked.connect(self._save)
+
+ self.roiTable.activeROIChanged.connect(self._emitCurrentROISignal)
+
+ self._isConnected = False # True if connected to plot signals
+ self._isInit = False
+
+ # expose API
+ self.getROIListAndDict = self.roiTable.getROIListAndDict
+
+ def getPlotWidget(self):
+ """Returns the associated PlotWidget or None
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ @property
+ def roiFileDir(self):
+ """The directory from which to load/save ROI from/to files."""
+ if not os.path.isdir(self._roiFileDir):
+ self._roiFileDir = qt.QDir.home().absolutePath()
+ return self._roiFileDir
+
+ @roiFileDir.setter
+ def roiFileDir(self, roiFileDir):
+ self._roiFileDir = str(roiFileDir)
+
+ def setRois(self, rois, order=None):
+ return self.roiTable.setRois(rois, order)
+
+ def getRois(self, order=None):
+ return self.roiTable.getRois(order)
+
+ def setMiddleROIMarkerFlag(self, flag=True):
+ return self.roiTable.setMiddleROIMarkerFlag(flag)
+
+ def _add(self):
+ """Add button clicked handler"""
+ def getNextRoiName():
+ rois = self.roiTable.getRois(order=None)
+ roisNames = []
+ [roisNames.append(roiName) for roiName in rois]
+ nrois = len(rois)
+ if nrois == 0:
+ return "ICR"
+ else:
+ i = 1
+ newroi = "newroi %d" % i
+ while newroi in roisNames:
+ i += 1
+ newroi = "newroi %d" % i
+ return newroi
+ roi = ROI(name=getNextRoiName())
+
+ if roi.getName() == "ICR":
+ roi.setType("Default")
+ else:
+ roi.setType(self.getPlotWidget().getXAxis().getLabel())
+
+ xmin, xmax = self.getPlotWidget().getXAxis().getLimits()
+ fromdata = xmin + 0.25 * (xmax - xmin)
+ todata = xmin + 0.75 * (xmax - xmin)
+ if roi.isICR():
+ fromdata, dummy0, todata, dummy1 = self._getAllLimits()
+ roi.setFrom(fromdata)
+ roi.setTo(todata)
+ self.roiTable.addRoi(roi)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "AddROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def _del(self):
+ """Delete button clicked handler"""
+ self.roiTable.deleteActiveRoi()
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "DelROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def _reset(self):
+ """Reset button clicked handler"""
+ self.roiTable.clear()
+ old = self.blockSignals(True) # avoid several sigROISignal emission
+ self._add()
+ self.blockSignals(old)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "ResetROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def _load(self):
+ """Load button clicked handler"""
+ dialog = qt.QFileDialog(self)
+ dialog.setNameFilters(
+ ['INI File *.ini', 'JSON File *.json', 'All *.*'])
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.roiFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ # pyflakes bug http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=666494
+ outputFile = dialog.selectedFiles()[0]
+ dialog.close()
+
+ self.roiFileDir = os.path.dirname(outputFile)
+ self.roiTable.load(outputFile)
+
+ # back compatibility pymca roi signals
+ ddict = {}
+ ddict['event'] = "LoadROI"
+ ddict['roilist'] = self.roiTable.roidict.values()
+ ddict['roidict'] = self.roiTable.roidict
+ self.sigROIWidgetSignal.emit(ddict)
+ # end back compatibility pymca roi signals
+
+ def load(self, filename):
+ """Load ROI widget information from a file storing a dict of ROI.
+
+ :param str filename: The file from which to load ROI
+ """
+ self.roiTable.load(filename)
+
+ def _save(self):
+ """Save button clicked handler"""
+ dialog = qt.QFileDialog(self)
+ dialog.setNameFilters(['INI File *.ini', 'JSON File *.json'])
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.roiFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ outputFile = dialog.selectedFiles()[0]
+ extension = '.' + dialog.selectedNameFilter().split('.')[-1]
+ dialog.close()
+
+ if not outputFile.endswith(extension):
+ outputFile += extension
+
+ if os.path.exists(outputFile):
+ try:
+ os.remove(outputFile)
+ except IOError:
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Input Output Error: %s" % (sys.exc_info()[1]))
+ msg.exec()
+ return
+ self.roiFileDir = os.path.dirname(outputFile)
+ self.save(outputFile)
+
+ def save(self, filename):
+ """Save current ROIs of the widget as a dict of ROI to a file.
+
+ :param str filename: The file to which to save the ROIs
+ """
+ self.roiTable.save(filename)
+
+ def setHeader(self, text='ROIs'):
+ """Set the header text of this widget"""
+ self.headerLabel.setText("<b>%s<\b>" % text)
+
+ @deprecation.deprecated(replacement="calculateRois",
+ reason="CamelCase convention",
+ since_version="0.7")
+ def calculateROIs(self, *args, **kw):
+ self.calculateRois(*args, **kw)
+
+ def calculateRois(self, roiList=None, roiDict=None):
+ """Compute ROI information"""
+ return self.roiTable.calculateRois()
+
+ def showAllMarkers(self, _show=True):
+ self.roiTable.showAllMarkers(_show)
+
+ def _getAllLimits(self):
+ """Retrieve the limits based on the curves."""
+ plot = self.getPlotWidget()
+ curves = () if plot is None else plot.getAllCurves()
+ if not curves:
+ return 1.0, 1.0, 100., 100.
+
+ xmin, ymin = None, None
+ xmax, ymax = None, None
+
+ for curve in curves:
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+ if xmin is None:
+ xmin = x.min()
+ else:
+ xmin = min(xmin, x.min())
+ if xmax is None:
+ xmax = x.max()
+ else:
+ xmax = max(xmax, x.max())
+ if ymin is None:
+ ymin = y.min()
+ else:
+ ymin = min(ymin, y.min())
+ if ymax is None:
+ ymax = y.max()
+ else:
+ ymax = max(ymax, y.max())
+
+ return xmin, ymin, xmax, ymax
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
+
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ # if no ROI existing yet, add the default one
+ if self.roiTable.rowCount() == 0:
+ old = self.blockSignals(True) # avoid several sigROISignal emission
+ self._add()
+ self.blockSignals(old)
+ self.calculateRois()
+
+ def fillFromROIDict(self, *args, **kwargs):
+ self.roiTable.fillFromROIDict(*args, **kwargs)
+
+ def _emitCurrentROISignal(self):
+ ddict = {}
+ ddict['event'] = "currentROISignal"
+ if self.roiTable.activeRoi is not None:
+ ddict['ROI'] = self.roiTable.activeRoi.toDict()
+ ddict['current'] = self.roiTable.activeRoi.getName()
+ else:
+ ddict['current'] = None
+
+ if self.__lastSigROISignal != ddict:
+ self.__lastSigROISignal = ddict
+ self.sigROISignal.emit(ddict)
+
+ @property
+ def currentRoi(self):
+ return self.roiTable.activeRoi
+
+
+class _FloatItem(qt.QTableWidgetItem):
+ """
+ Simple QTableWidgetItem overloading the < operator to deal with ordering
+ """
+ def __init__(self):
+ qt.QTableWidgetItem.__init__(self, type=qt.QTableWidgetItem.Type)
+
+ def __lt__(self, other):
+ if self.text() in ('', ROITable.INFO_NOT_FOUND):
+ return False
+ if other.text() in ('', ROITable.INFO_NOT_FOUND):
+ return True
+ return float(self.text()) < float(other.text())
+
+
+class ROITable(TableWidget):
+ """Table widget displaying ROI information.
+
+ See :class:`QTableWidget` for constructor arguments.
+
+ Behavior: listen at the active curve changed only when the widget is
+ visible. Otherwise won't compute the row and net counts...
+ """
+
+ activeROIChanged = qt.Signal()
+ """Signal emitted when the active roi changed or when the value of the
+ active roi are changing"""
+
+ COLUMNS_INDEX = OrderedDict([
+ ('ID', 0),
+ ('ROI', 1),
+ ('Type', 2),
+ ('From', 3),
+ ('To', 4),
+ ('Raw Counts', 5),
+ ('Net Counts', 6),
+ ('Raw Area', 7),
+ ('Net Area', 8),
+ ])
+
+ COLUMNS = list(COLUMNS_INDEX.keys())
+
+ INFO_NOT_FOUND = '????????'
+
+ def __init__(self, parent=None, plot=None, rois=None):
+ super(ROITable, self).__init__(parent)
+ self._showAllMarkers = False
+ self._userIsEditingRoi = False
+ """bool used to avoid conflict when editing the ROI object"""
+ self._isConnected = False
+ self._roiToItems = {}
+ self._roiDict = {}
+ """dict of ROI object. Key is ROi id, value is the ROI object"""
+ self._markersHandler = _RoiMarkerManager()
+
+ """
+ Associate for each marker legend used when the `_showAllMarkers` option
+ is active a roi.
+ """
+ self.setColumnCount(len(self.COLUMNS))
+ self.setPlot(plot)
+ self.__setTooltip()
+ self.setSortingEnabled(True)
+ self.itemChanged.connect(self._itemChanged)
+
+ @property
+ def roidict(self):
+ return self._getRoiDict()
+
+ @property
+ def activeRoi(self):
+ return self._markersHandler._activeRoi
+
+ def _getRoiDict(self):
+ ddict = {}
+ for id in self._roiDict:
+ ddict[self._roiDict[id].getName()] = self._roiDict[id]
+ return ddict
+
+ def clear(self):
+ """
+ .. note:: clear the interface only. keep the roidict...
+ """
+ self._markersHandler.clear()
+ self._roiToItems = {}
+ self._roiDict = {}
+
+ qt.QTableWidget.clear(self)
+ self.setRowCount(0)
+ self.setHorizontalHeaderLabels(self.COLUMNS)
+ header = self.horizontalHeader()
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+ self.sortByColumn(0, qt.Qt.AscendingOrder)
+ self.hideColumn(self.COLUMNS_INDEX['ID'])
+
+ def setPlot(self, plot):
+ self.clear()
+ self.plot = plot
+
+ def __setTooltip(self):
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['ROI']).setToolTip(
+ 'Region of interest identifier')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Type']).setToolTip(
+ 'Type of the ROI')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['From']).setToolTip(
+ 'X-value of the min point')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['To']).setToolTip(
+ 'X-value of the max point')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Raw Counts']).setToolTip(
+ 'Estimation of the integral between y=0 and the selected curve')
+ self.horizontalHeaderItem(self.COLUMNS_INDEX['Net Counts']).setToolTip(
+ 'Estimation of the integral between the segment [maxPt, minPt] '
+ 'and the selected curve')
+
+ def setRois(self, rois, order=None):
+ """Set the ROIs by providing a dictionary of ROI information.
+
+ The dictionary keys are the ROI names.
+ Each value is a sub-dictionary of ROI info with the following fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+
+ :param roidict: Dictionary of ROIs
+ :param str order: Field used for ordering the ROIs.
+ One of "from", "to", "type".
+ None (default) for no ordering, or same order as specified
+ in parameter ``roidict`` if provided as an OrderedDict.
+ """
+ assert order in [None, "from", "to", "type"]
+ self.clear()
+
+ # backward compatibility since 0.10.0
+ if isinstance(rois, dict):
+ for roiName, roi in rois.items():
+ if isinstance(roi, ROI):
+ _roi = roi
+ else:
+ roi['name'] = roiName
+ _roi = ROI._fromDict(roi)
+ self.addRoi(_roi)
+ else:
+ for roi in rois:
+ assert isinstance(roi, ROI)
+ self.addRoi(roi)
+ self._updateMarkers()
+
+ def addRoi(self, roi):
+ """
+
+ :param :class:`ROI` roi: roi to add to the table
+ """
+ assert isinstance(roi, ROI)
+ self._getItem(name='ID', row=None, roi=roi)
+ self._roiDict[roi.getID()] = roi
+ self._markersHandler.add(roi, _RoiMarkerHandler(roi, self.plot))
+ self._updateRoiInfo(roi.getID())
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
+ roi.getID())
+ roi.sigChanged.connect(callback)
+ # set it as the active one
+ self.setActiveRoi(roi)
+
+ def _getItem(self, name, row, roi):
+ if row:
+ item = self.item(row, self.COLUMNS_INDEX[name])
+ else:
+ item = None
+ if item:
+ return item
+ else:
+ if name == 'ID':
+ assert roi
+ if roi.getID() in self._roiToItems:
+ return self._roiToItems[roi.getID()]
+ else:
+ # create a new row
+ row = self.rowCount()
+ self.setRowCount(self.rowCount() + 1)
+ item = qt.QTableWidgetItem(str(roi.getID()),
+ type=qt.QTableWidgetItem.Type)
+ self._roiToItems[roi.getID()] = item
+ elif name == 'ROI':
+ item = qt.QTableWidgetItem(roi.getName() if roi else '',
+ type=qt.QTableWidgetItem.Type)
+ if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable)
+ elif name == 'Type':
+ item = qt.QTableWidgetItem(type=qt.QTableWidgetItem.Type)
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
+ elif name in ('To', 'From'):
+ item = _FloatItem()
+ if roi.getName().upper() in ('ICR', 'DEFAULT'):
+ item.setFlags(qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled)
+ else:
+ item.setFlags(qt.Qt.ItemIsSelectable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsEditable)
+ elif name in ('Raw Counts', 'Net Counts', 'Raw Area', 'Net Area'):
+ item = _FloatItem()
+ item.setFlags((qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled))
+ else:
+ raise ValueError('item type not recognized')
+
+ self.setItem(row, self.COLUMNS_INDEX[name], item)
+ return item
+
+ def _itemChanged(self, item):
+ def getRoi():
+ IDItem = self.item(item.row(), self.COLUMNS_INDEX['ID'])
+ assert IDItem
+ id = int(IDItem.text())
+ assert id in self._roiDict
+ roi = self._roiDict[id]
+ return roi
+
+ def signalChanged(roi):
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ self._userIsEditingRoi = True
+ if item.column() in (self.COLUMNS_INDEX['To'], self.COLUMNS_INDEX['From']):
+ roi = getRoi()
+
+ if item.text() not in ('', self.INFO_NOT_FOUND):
+ try:
+ value = float(item.text())
+ except ValueError:
+ value = 0
+ changed = False
+ if item.column() == self.COLUMNS_INDEX['To']:
+ if value != roi.getTo():
+ roi.setTo(value)
+ changed = True
+ else:
+ assert(item.column() == self.COLUMNS_INDEX['From'])
+ if value != roi.getFrom():
+ roi.setFrom(value)
+ changed = True
+ if changed:
+ self._updateMarker(roi.getName())
+ signalChanged(roi)
+
+ if item.column() is self.COLUMNS_INDEX['ROI']:
+ roi = getRoi()
+ if roi.getName() != item.text():
+ roi.setName(item.text())
+ self._markersHandler.getMarkerHandler(roi.getID()).updateTexts()
+ signalChanged(roi)
+
+ self._userIsEditingRoi = False
+
+ def deleteActiveRoi(self):
+ """
+ remove the current active roi
+ """
+ activeItems = self.selectedItems()
+ if len(activeItems) == 0:
+ return
+ old = self.blockSignals(True) # avoid several emission of sigROISignal
+ roiToRm = set()
+ for item in activeItems:
+ row = item.row()
+ itemID = self.item(row, self.COLUMNS_INDEX['ID'])
+ roiToRm.add(self._roiDict[int(itemID.text())])
+ [self.removeROI(roi) for roi in roiToRm]
+ self.blockSignals(old)
+ self.setActiveRoi(None)
+
+ def removeROI(self, roi):
+ """
+ remove the requested roi
+
+ :param str name: the name of the roi to remove from the table
+ """
+ if roi and roi.getID() in self._roiToItems:
+ item = self._roiToItems[roi.getID()]
+ self.removeRow(item.row())
+ del self._roiToItems[roi.getID()]
+
+ assert roi.getID() in self._roiDict
+ del self._roiDict[roi.getID()]
+ self._markersHandler.remove(roi)
+
+ callback = functools.partial(WeakMethodProxy(self._updateRoiInfo),
+ roi.getID())
+ roi.sigChanged.connect(callback)
+
+ def setActiveRoi(self, roi):
+ """
+ Define the given roi as the active one.
+
+ .. warning:: this roi should already be registred / added to the table
+
+ :param :class:`ROI` roi: the roi to defined as active
+ """
+ if roi is None:
+ self.clearSelection()
+ self._markersHandler.setActiveRoi(None)
+ self.activeROIChanged.emit()
+ else:
+ assert isinstance(roi, ROI)
+ if roi and roi.getID() in self._roiToItems.keys():
+ # avoid several call back to setActiveROI
+ old = self.blockSignals(True)
+ self.selectRow(self._roiToItems[roi.getID()].row())
+ self.blockSignals(old)
+ self._markersHandler.setActiveRoi(roi)
+ self.activeROIChanged.emit()
+
+ def _updateRoiInfo(self, roiID):
+ if self._userIsEditingRoi is True:
+ return
+ if roiID not in self._roiDict:
+ return
+ roi = self._roiDict[roiID]
+ if roi.isICR():
+ activeCurve = self.plot.getActiveCurve()
+ if activeCurve:
+ xData = activeCurve.getXData()
+ if len(xData) > 0:
+ min, max = min_max(xData)
+ roi.blockSignals(True)
+ roi.setFrom(min)
+ roi.setTo(max)
+ roi.blockSignals(False)
+
+ itemID = self._getItem(name='ID', roi=roi, row=None)
+ itemName = self._getItem(name='ROI', row=itemID.row(), roi=roi)
+ itemName.setText(roi.getName())
+
+ itemType = self._getItem(name='Type', row=itemID.row(), roi=roi)
+ itemType.setText(roi.getType() or self.INFO_NOT_FOUND)
+
+ itemFrom = self._getItem(name='From', row=itemID.row(), roi=roi)
+ fromdata = str(roi.getFrom()) if roi.getFrom() is not None else self.INFO_NOT_FOUND
+ itemFrom.setText(fromdata)
+
+ itemTo = self._getItem(name='To', row=itemID.row(), roi=roi)
+ todata = str(roi.getTo()) if roi.getTo() is not None else self.INFO_NOT_FOUND
+ itemTo.setText(todata)
+
+ rawCounts, netCounts = roi.computeRawAndNetCounts(
+ curve=self.plot.getActiveCurve(just_legend=False))
+ itemRawCounts = self._getItem(name='Raw Counts', row=itemID.row(),
+ roi=roi)
+ rawCounts = str(rawCounts) if rawCounts is not None else self.INFO_NOT_FOUND
+ itemRawCounts.setText(rawCounts)
+
+ itemNetCounts = self._getItem(name='Net Counts', row=itemID.row(),
+ roi=roi)
+ netCounts = str(netCounts) if netCounts is not None else self.INFO_NOT_FOUND
+ itemNetCounts.setText(netCounts)
+
+ rawArea, netArea = roi.computeRawAndNetArea(
+ curve=self.plot.getActiveCurve(just_legend=False))
+ itemRawArea = self._getItem(name='Raw Area', row=itemID.row(),
+ roi=roi)
+ rawArea = str(rawArea) if rawArea is not None else self.INFO_NOT_FOUND
+ itemRawArea.setText(rawArea)
+
+ itemNetArea = self._getItem(name='Net Area', row=itemID.row(),
+ roi=roi)
+ netArea = str(netArea) if netArea is not None else self.INFO_NOT_FOUND
+ itemNetArea.setText(netArea)
+
+ if self.activeRoi and roi.getID() == self.activeRoi.getID():
+ self.activeROIChanged.emit()
+
+ def currentChanged(self, current, previous):
+ if previous and current.row() != previous.row() and current.row() >= 0:
+ roiItem = self.item(current.row(),
+ self.COLUMNS_INDEX['ID'])
+
+ assert roiItem
+ self.setActiveRoi(self._roiDict[int(roiItem.text())])
+ self._markersHandler.updateAllMarkers()
+ qt.QTableWidget.currentChanged(self, current, previous)
+
+ @deprecation.deprecated(reason="Removed",
+ replacement="roidict and roidict.values()",
+ since_version="0.10.0")
+ def getROIListAndDict(self):
+ """
+
+ :return: the list of roi objects and the dictionary of roi name to roi
+ object.
+ """
+ roidict = self._roiDict
+ return list(roidict.values()), roidict
+
+ def calculateRois(self, roiList=None, roiDict=None):
+ """
+ Update values of all registred rois (raw and net counts in particular)
+
+ :param roiList: deprecated parameter
+ :param roiDict: deprecated parameter
+ """
+ if roiDict:
+ deprecation.deprecated_warning(name='roiDict', type_='Parameter',
+ reason='Unused parameter',
+ since_version="0.10.0")
+ if roiList:
+ deprecation.deprecated_warning(name='roiList', type_='Parameter',
+ reason='Unused parameter',
+ since_version="0.10.0")
+
+ for roiID in self._roiDict:
+ self._updateRoiInfo(roiID)
+
+ def _updateMarker(self, roiID):
+ """Make sure the marker of the given roi name is updated"""
+ if self._showAllMarkers or (self.activeRoi
+ and self.activeRoi.getName() == roiID):
+ self._updateMarkers()
+
+ def _updateMarkers(self):
+ if self._showAllMarkers is True:
+ self._markersHandler.updateMarkers()
+ else:
+ if not self.activeRoi or not self.plot:
+ return
+ assert isinstance(self.activeRoi, ROI)
+ markerHandler = self._markersHandler.getMarkerHandler(self.activeRoi.getID())
+ if markerHandler is not None:
+ markerHandler.updateMarkers()
+
+ def getRois(self, order):
+ """
+ Return the currently defined ROIs, as an ordered dict.
+
+ The dictionary keys are the ROI names.
+ Each value is a :class:`ROI` object..
+
+ :param order: Field used for ordering the ROIs.
+ One of "from", "to", "type", "netcounts", "rawcounts".
+ None (default) to get the same order as displayed in the widget.
+ :return: Ordered dictionary of ROI information
+ """
+
+ if order is None or order.lower() == "none":
+ ordered_roilist = list(self._roiDict.values())
+ res = OrderedDict([(roi.getName(), self._roiDict[roi.getID()]) for roi in ordered_roilist])
+ else:
+ assert order in ["from", "to", "type", "netcounts", "rawcounts"]
+ ordered_roilist = sorted(self._roiDict.keys(),
+ key=lambda roi_id: self._roiDict[roi_id].get(order))
+ res = OrderedDict([(roi.getName(), self._roiDict[id]) for id in ordered_roilist])
+
+ return res
+
+ def save(self, filename):
+ """
+ Save current ROIs of the widget as a dict of ROI to a file.
+
+ :param str filename: The file to which to save the ROIs
+ """
+ roilist = []
+ roidict = {}
+ for roiID, roi in self._roiDict.items():
+ roilist.append(roi.toDict())
+ roidict[roi.getName()] = roi.toDict()
+ datadict = {'ROI': {'roilist': roilist, 'roidict': roidict}}
+ dictdump.dump(datadict, filename)
+
+ def load(self, filename):
+ """
+ Load ROI widget information from a file storing a dict of ROI.
+
+ :param str filename: The file from which to load ROI
+ """
+ roisDict = dictdump.load(filename)
+ rois = []
+
+ # Remove rawcounts and netcounts from ROIs
+ for roiDict in roisDict['ROI']['roidict'].values():
+ roiDict.pop('rawcounts', None)
+ roiDict.pop('netcounts', None)
+ rois.append(ROI._fromDict(roiDict))
+
+ self.setRois(rois)
+
+ def showAllMarkers(self, _show=True):
+ """
+
+ :param bool _show: if true show all the markers of all the ROIs
+ boundaries otherwise will only show the one of
+ the active ROI.
+ """
+ self._markersHandler.setShowAllMarkers(_show)
+
+ def setMiddleROIMarkerFlag(self, flag=True):
+ """
+ Activate or deactivate middle marker.
+
+ This allows shifting both min and max limits at once, by dragging
+ a marker located in the middle.
+
+ :param bool flag: True to activate middle ROI marker
+ """
+ self._markersHandler._middleROIMarkerFlag = flag
+
+ def _handleROIMarkerEvent(self, ddict):
+ """Handle plot signals related to marker events."""
+ if ddict['event'] == 'markerMoved':
+ label = ddict['label']
+ roiID = self._markersHandler.getRoiID(markerID=label)
+ if roiID is not None:
+ # avoid several emission of sigROISignal
+ old = self.blockSignals(True)
+ self._markersHandler.changePosition(markerID=label,
+ x=ddict['x'])
+ self.blockSignals(old)
+ self._updateRoiInfo(roiID)
+
+ def showEvent(self, event):
+ self._visibilityChangedHandler(visible=True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self._visibilityChangedHandler(visible=False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _visibilityChangedHandler(self, visible):
+ """Handle widget's visibility updates.
+
+ It is connected to plot signals only when visible.
+ """
+ if visible:
+ assert self.plot
+ if self._isConnected is False:
+ self.plot.sigPlotSignal.connect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ self._isConnected = True
+ self.calculateRois()
+ else:
+ if self._isConnected:
+ self.plot.sigPlotSignal.disconnect(self._handleROIMarkerEvent)
+ self.plot.sigActiveCurveChanged.disconnect(self._activeCurveChanged)
+ self._isConnected = False
+
+ def _activeCurveChanged(self, curve):
+ self.calculateRois()
+
+ def setCountsVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX['Raw Counts'])
+ self.showColumn(self.COLUMNS_INDEX['Net Counts'])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX['Raw Counts'])
+ self.hideColumn(self.COLUMNS_INDEX['Net Counts'])
+
+ def setAreaVisible(self, visible):
+ """
+ Display the columns relative to areas or not
+
+ :param bool visible: True if the columns 'Raw Area' and 'Net Area'
+ should be visible.
+ """
+ if visible is True:
+ self.showColumn(self.COLUMNS_INDEX['Raw Area'])
+ self.showColumn(self.COLUMNS_INDEX['Net Area'])
+ else:
+ self.hideColumn(self.COLUMNS_INDEX['Raw Area'])
+ self.hideColumn(self.COLUMNS_INDEX['Net Area'])
+
+ def fillFromROIDict(self, roilist=(), roidict=None, currentroi=None):
+ """
+ This function API is kept for compatibility.
+ But `setRois` should be preferred.
+
+ Set the ROIs by providing a list of ROI names and a dictionary
+ of ROI information for each ROI.
+ The ROI names must match an existing dictionary key.
+ The name list is used to provide an order for the ROIs.
+ The dictionary's values are sub-dictionaries containing 3
+ mandatory fields:
+
+ - ``"from"``: x coordinate of the left limit, as a float
+ - ``"to"``: x coordinate of the right limit, as a float
+ - ``"type"``: type of ROI, as a string (e.g "channels", "energy")
+
+ :param roilist: List of ROI names (keys of roidict)
+ :type roilist: List
+ :param dict roidict: Dict of ROI information
+ :param currentroi: Name of the selected ROI or None (no selection)
+ """
+ if roidict is not None:
+ self.setRois(roidict)
+ else:
+ self.setRois(roilist)
+ if currentroi:
+ self.setActiveRoi(currentroi)
+
+
+_indexNextROI = 0
+
+
+class ROI(_RegionOfInterestBase):
+ """The Region Of Interest is defined by:
+
+ - A name
+ - A type. The type is the label of the x axis. This can be used to apply or
+ not some ROI to a curve and do some post processing.
+ - The x coordinate of the left limit (fromdata)
+ - The x coordinate of the right limit (todata)
+
+ :param str: name of the ROI
+ :param fromdata: left limit of the roi
+ :param todata: right limit of the roi
+ :param type: type of the ROI
+ """
+
+ sigChanged = qt.Signal()
+ """Signal emitted when the ROI is edited"""
+
+ def __init__(self, name, fromdata=None, todata=None, type_=None):
+ _RegionOfInterestBase.__init__(self)
+ self.setName(name)
+ global _indexNextROI
+ self._id = _indexNextROI
+ _indexNextROI += 1
+
+ self._fromdata = fromdata
+ self._todata = todata
+ self._type = type_ or 'Default'
+
+ self.sigItemChanged.connect(self.__itemChanged)
+
+ def __itemChanged(self, event):
+ """Handle name change"""
+ if event == items.ItemChangedType.NAME:
+ self.sigChanged.emit()
+
+ def getID(self):
+ """
+
+ :return int: the unique ID of the ROI
+ """
+ return self._id
+
+ def setType(self, type_):
+ """
+
+ :param str type_:
+ """
+ if self._type != type_:
+ self._type = type_
+ self.sigChanged.emit()
+
+ def getType(self):
+ """
+
+ :return str: the type of the ROI.
+ """
+ return self._type
+
+ def setFrom(self, frm):
+ """
+
+ :param frm: set x coordinate of the left limit
+ """
+ if self._fromdata != frm:
+ self._fromdata = frm
+ self.sigChanged.emit()
+
+ def getFrom(self):
+ """
+
+ :return: x coordinate of the left limit
+ """
+ return self._fromdata
+
+ def setTo(self, to):
+ """
+
+ :param to: x coordinate of the right limit
+ """
+ if self._todata != to:
+ self._todata = to
+ self.sigChanged.emit()
+
+ def getTo(self):
+ """
+
+ :return: x coordinate of the right limit
+ """
+ return self._todata
+
+ def getMiddle(self):
+ """
+
+ :return: middle position between 'from' and 'to' values
+ """
+ return 0.5 * (self.getFrom() + self.getTo())
+
+ def toDict(self):
+ """
+
+ :return: dict containing the roi parameters
+ """
+ ddict = {
+ 'type': self._type,
+ 'name': self.getName(),
+ 'from': self._fromdata,
+ 'to': self._todata,
+ }
+ if hasattr(self, '_extraInfo'):
+ ddict.update(self._extraInfo)
+ return ddict
+
+ @staticmethod
+ def _fromDict(dic):
+ assert 'name' in dic
+ roi = ROI(name=dic['name'])
+ roi._extraInfo = {}
+ for key in dic:
+ if key == 'from':
+ roi.setFrom(dic['from'])
+ elif key == 'to':
+ roi.setTo(dic['to'])
+ elif key == 'type':
+ roi.setType(dic['type'])
+ else:
+ roi._extraInfo[key] = dic[key]
+
+ return roi
+
+ def isICR(self):
+ """
+
+ :return: True if the ROI is the `ICR`
+ """
+ return self.getName() == 'ICR'
+
+ def computeRawAndNetCounts(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw count: Points values sum of the curve in the defined Region Of
+ Interest.
+
+ .. image:: img/rawCounts.png
+
+ - Net count: Raw counts minus background
+
+ .. image:: img/netCounts.png
+
+ :param CurveItem curve:
+ :return tuple: rawCount, netCount
+ """
+ assert isinstance(curve, items.Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ idx = numpy.nonzero((self._fromdata <= x) &
+ (x <= self._todata))[0]
+ if len(idx):
+ xw = x[idx]
+ yw = y[idx]
+ rawCounts = yw.sum(dtype=numpy.float64)
+ deltaX = xw[-1] - xw[0]
+ deltaY = yw[-1] - yw[0]
+ if deltaX > 0.0:
+ slope = (deltaY / deltaX)
+ background = yw[0] + slope * (xw - xw[0])
+ netCounts = (rawCounts -
+ background.sum(dtype=numpy.float64))
+ else:
+ netCounts = 0.0
+ else:
+ rawCounts = 0.0
+ netCounts = 0.0
+ return rawCounts, netCounts
+
+ def computeRawAndNetArea(self, curve):
+ """Compute the Raw and net counts in the ROI for the given curve.
+
+ - Raw area: integral of the curve between the min ROI point and the
+ max ROI point to the y = 0 line.
+
+ .. image:: img/rawArea.png
+
+ - Net area: Raw counts minus background
+
+ .. image:: img/netArea.png
+
+ :param CurveItem curve:
+ :return tuple: rawArea, netArea
+ """
+ assert isinstance(curve, items.Curve) or curve is None
+
+ if curve is None:
+ return None, None
+
+ x = curve.getXData(copy=False)
+ y = curve.getYData(copy=False)
+
+ y = y[(x >= self._fromdata) & (x <= self._todata)]
+ x = x[(x >= self._fromdata) & (x <= self._todata)]
+
+ if x.size == 0:
+ return 0.0, 0.0
+
+ rawArea = numpy.trapz(y, x=x)
+ # to speed up and avoid an intersection calculation we are taking the
+ # closest index to the ROI
+ closestXLeftIndex = (numpy.abs(x - self.getFrom())).argmin()
+ closestXRightIndex = (numpy.abs(x - self.getTo())).argmin()
+ yBackground = y[closestXLeftIndex], y[closestXRightIndex]
+ background = numpy.trapz(yBackground, x=x)
+ netArea = rawArea - background
+ return rawArea, netArea
+
+ @docstring(_RegionOfInterestBase)
+ def contains(self, position):
+ return self._fromdata <= position[0] <= self._todata
+
+
+class _RoiMarkerManager(object):
+ """
+ Deal with all the ROI markers
+ """
+ def __init__(self):
+ self._roiMarkerHandlers = {}
+ self._middleROIMarkerFlag = False
+ self._showAllMarkers = False
+ self._activeRoi = None
+
+ def setActiveRoi(self, roi):
+ self._activeRoi = roi
+ self.updateAllMarkers()
+
+ def setShowAllMarkers(self, show):
+ if show != self._showAllMarkers:
+ self._showAllMarkers = show
+ self.updateAllMarkers()
+
+ def add(self, roi, markersHandler):
+ assert isinstance(roi, ROI)
+ assert isinstance(markersHandler, _RoiMarkerHandler)
+ if roi.getID() in self._roiMarkerHandlers:
+ raise ValueError('roi with the same ID already existing')
+ else:
+ self._roiMarkerHandlers[roi.getID()] = markersHandler
+
+ def getMarkerHandler(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ return self._roiMarkerHandlers[roiID]
+ else:
+ return None
+
+ def clear(self):
+ roisHandler = list(self._roiMarkerHandlers.values())
+ for roiHandler in roisHandler:
+ self.remove(roiHandler.roi)
+
+ def remove(self, roi):
+ if roi is None:
+ return
+ assert isinstance(roi, ROI)
+ if roi.getID() in self._roiMarkerHandlers:
+ self._roiMarkerHandlers[roi.getID()].clear()
+ del self._roiMarkerHandlers[roi.getID()]
+
+ def hasMarker(self, markerID):
+ assert type(markerID) is str
+ return self.getMarker(markerID) is not None
+
+ def changePosition(self, markerID, x):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError('Marker %s not register' % markerID)
+ markerHandler.changePosition(markerID=markerID, x=x)
+
+ def updateMarker(self, markerID):
+ markerHandler = self.getMarker(markerID)
+ if markerHandler is None:
+ raise ValueError('Marker %s not register' % markerID)
+ roiID = self.getRoiID(markerID)
+ visible = (self._activeRoi and self._activeRoi.getID() == roiID) or self._showAllMarkers is True
+ markerHandler.setVisible(visible)
+ markerHandler.updateAllMarkers()
+
+ def updateRoiMarkers(self, roiID):
+ if roiID in self._roiMarkerHandlers:
+ visible = ((self._activeRoi and self._activeRoi.getID() == roiID)
+ or self._showAllMarkers is True)
+ _roi = self._roiMarkerHandlers[roiID]._roi()
+ if _roi and not _roi.isICR():
+ self._roiMarkerHandlers[roiID].showMiddleMarker(self._middleROIMarkerFlag)
+ self._roiMarkerHandlers[roiID].setVisible(visible)
+ self._roiMarkerHandlers[roiID].updateMarkers()
+
+ def getMarker(self, markerID):
+ assert type(markerID) is str
+ for marker in list(self._roiMarkerHandlers.values()):
+ if marker.hasMarker(markerID):
+ return marker
+
+ def updateMarkers(self):
+ for markerHandler in list(self._roiMarkerHandlers.values()):
+ markerHandler.updateMarkers()
+
+ def getRoiID(self, markerID):
+ for roiID, markerHandler in self._roiMarkerHandlers.items():
+ if markerHandler.hasMarker(markerID):
+ return roiID
+ return None
+
+ def setShowMiddleMarkers(self, show):
+ self._middleROIMarkerFlag = show
+ self._roiMarkerHandlers.updateAllMarkers()
+
+ def updateAllMarkers(self):
+ for roiID in self._roiMarkerHandlers:
+ self.updateRoiMarkers(roiID)
+
+ def getVisibleRois(self):
+ res = {}
+ for roiID, roiHandler in self._roiMarkerHandlers.items():
+ markers = (roiHandler.getMarker('min'), roiHandler.getMarker('max'),
+ roiHandler.getMarker('middle'))
+ for marker in markers:
+ if marker.isVisible():
+ if roiID not in res:
+ res[roiID] = []
+ res[roiID].append(marker)
+ return res
+
+
+class _RoiMarkerHandler(object):
+ """Used to deal with ROI markers used in ROITable"""
+ def __init__(self, roi, plot):
+ assert roi and isinstance(roi, ROI)
+ assert plot
+
+ self._roi = weakref.ref(roi)
+ self._plot = weakref.ref(plot)
+ self._draggable = False if roi.isICR() else True
+ self._color = 'black' if roi.isICR() else 'blue'
+ self._displayMidMarker = False
+ self._visible = True
+
+ @property
+ def draggable(self):
+ return self._draggable
+
+ @property
+ def plot(self):
+ return self._plot()
+
+ def clear(self):
+ if self.plot and self.roi:
+ self.plot.removeMarker(self._markerID('min'))
+ self.plot.removeMarker(self._markerID('max'))
+ self.plot.removeMarker(self._markerID('middle'))
+
+ @property
+ def roi(self):
+ return self._roi()
+
+ def setVisible(self, visible):
+ if visible != self._visible:
+ self._visible = visible
+ self.updateMarkers()
+
+ def showMiddleMarker(self, visible):
+ if self.draggable is False and visible is True:
+ _logger.warning("ROI is not draggable. Won't display middle marker")
+ return
+ self._displayMidMarker = visible
+ self.getMarker('middle').setVisible(self._displayMidMarker)
+
+ def updateMarkers(self):
+ if self.roi is None:
+ return
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+ self._updateMiddleMarkerPos()
+
+ def _updateMinMarkerPos(self):
+ self.getMarker('min').setPosition(x=self.roi.getFrom(), y=None)
+ self.getMarker('min').setVisible(self._visible)
+
+ def _updateMaxMarkerPos(self):
+ self.getMarker('max').setPosition(x=self.roi.getTo(), y=None)
+ self.getMarker('max').setVisible(self._visible)
+
+ def _updateMiddleMarkerPos(self):
+ self.getMarker('middle').setPosition(x=self.roi.getMiddle(), y=None)
+ self.getMarker('middle').setVisible(self._displayMidMarker and self._visible)
+
+ def getMarker(self, markerType):
+ if self.plot is None:
+ return None
+ assert markerType in ('min', 'max', 'middle')
+ if self.plot._getMarker(self._markerID(markerType)) is None:
+ assert self.roi
+ if markerType == 'min':
+ val = self.roi.getFrom()
+ elif markerType == 'max':
+ val = self.roi.getTo()
+ else:
+ val = self.roi.getMiddle()
+
+ _color = self._color
+ if markerType == 'middle':
+ _color = 'yellow'
+ self.plot.addXMarker(val,
+ legend=self._markerID(markerType),
+ text=self.getMarkerName(markerType),
+ color=_color,
+ draggable=self.draggable)
+ return self.plot._getMarker(self._markerID(markerType))
+
+ def _markerID(self, markerType):
+ assert markerType in ('min', 'max', 'middle')
+ assert self.roi
+ return '_'.join((str(self.roi.getID()), markerType))
+
+ def getMarkerName(self, markerType):
+ assert markerType in ('min', 'max', 'middle')
+ assert self.roi
+ return ' '.join((self.roi.getName(), markerType))
+
+ def updateTexts(self):
+ self.getMarker('min').setText(self.getMarkerName('min'))
+ self.getMarker('max').setText(self.getMarkerName('max'))
+ self.getMarker('middle').setText(self.getMarkerName('middle'))
+
+ def changePosition(self, markerID, x):
+ assert self.hasMarker(markerID)
+ markerType = self._getMarkerType(markerID)
+ assert markerType is not None
+ if self.roi is None:
+ return
+ if markerType == 'min':
+ self.roi.setFrom(x)
+ self._updateMiddleMarkerPos()
+ elif markerType == 'max':
+ self.roi.setTo(x)
+ self._updateMiddleMarkerPos()
+ else:
+ delta = x - 0.5 * (self.roi.getFrom() + self.roi.getTo())
+ self.roi.setFrom(self.roi.getFrom() + delta)
+ self.roi.setTo(self.roi.getTo() + delta)
+ self._updateMinMarkerPos()
+ self._updateMaxMarkerPos()
+
+ def hasMarker(self, marker):
+ return marker in (self._markerID('min'),
+ self._markerID('max'),
+ self._markerID('middle'))
+
+ def _getMarkerType(self, markerID):
+ if markerID.endswith('_min'):
+ return 'min'
+ elif markerID.endswith('_max'):
+ return 'max'
+ elif markerID.endswith('_middle'):
+ return 'middle'
+ else:
+ return None
+
+
+class CurvesROIDockWidget(qt.QDockWidget):
+ """QDockWidget with a :class:`CurvesROIWidget` connected to a PlotWindow.
+
+ It makes the link between the :class:`CurvesROIWidget` and the PlotWindow.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: :class:`.PlotWindow` instance on which to operate
+ :param name: See :class:`QDockWidget`
+ """
+ sigROISignal = qt.Signal(object)
+ """Deprecated signal for backward compatibility with silx < 0.7.
+ Prefer connecting directly to :attr:`CurvesRoiWidget.sigRoiSignal`
+ """
+
+ def __init__(self, parent=None, plot=None, name=None):
+ super(CurvesROIDockWidget, self).__init__(name, parent)
+
+ assert plot is not None
+ self.plot = plot
+ self.roiWidget = CurvesROIWidget(self, name, plot=plot)
+ """Main widget of type :class:`CurvesROIWidget`"""
+
+ # convenience methods to offer a simpler API allowing to ignore
+ # the details of the underlying implementation
+ # (ALL DEPRECATED)
+ self.calculateROIs = self.calculateRois = self.roiWidget.calculateRois
+ self.setRois = self.roiWidget.setRois
+ self.getRois = self.roiWidget.getRois
+
+ self.roiWidget.sigROISignal.connect(self._forwardSigROISignal)
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self.roiWidget)
+
+ self.setAreaVisible = self.roiWidget.roiTable.setAreaVisible
+ self.setCountsVisible = self.roiWidget.roiTable.setCountsVisible
+
+ def _forwardSigROISignal(self, ddict):
+ # emit deprecated signal for backward compatibility (silx < 0.7)
+ self.sigROISignal.emit(ddict)
+
+ def toggleViewAction(self):
+ """Returns a checkable action that shows or closes this widget.
+
+ See :class:`QMainWindow`.
+ """
+ action = super(CurvesROIDockWidget, self).toggleViewAction()
+ action.setIcon(icons.getQIcon('plot-roi'))
+ return action
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
+ qt.QDockWidget.showEvent(self, event)
+
+ @property
+ def currentROI(self):
+ return self.roiWidget.currentRoi
diff --git a/src/silx/gui/plot/ImageStack.py b/src/silx/gui/plot/ImageStack.py
new file mode 100644
index 0000000..1588a31
--- /dev/null
+++ b/src/silx/gui/plot/ImageStack.py
@@ -0,0 +1,640 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Image stack view with data prefetch capabilty."""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "04/03/2019"
+
+
+from silx.gui import icons, qt
+from silx.gui.plot import Plot2D
+from silx.gui.utils import concurrent
+from silx.io.url import DataUrl
+from silx.io.utils import get_data
+from collections import OrderedDict
+from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
+import time
+import threading
+import typing
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+class _PlotWithWaitingLabel(qt.QWidget):
+ """Image plot widget with an overlay 'waiting' status.
+ """
+
+ class AnimationThread(threading.Thread):
+ def __init__(self, label):
+ self.running = True
+ self._label = label
+ self.animated_icon = icons.getWaitIcon()
+ self.animated_icon.register(self._label)
+ super(_PlotWithWaitingLabel.AnimationThread, self).__init__()
+
+ def run(self):
+ while self.running:
+ time.sleep(0.05)
+ icon = self.animated_icon.currentIcon()
+ self.future_result = concurrent.submitToQtMainThread(
+ self._label.setPixmap, icon.pixmap(30, state=qt.QIcon.On))
+
+ def stop(self):
+ """Stop the update thread"""
+ if self.running:
+ self.animated_icon.unregister(self._label)
+ self.running = False
+ self.join(2)
+
+ def __init__(self, parent):
+ super(_PlotWithWaitingLabel, self).__init__(parent=parent)
+ self._autoResetZoom = True
+ layout = qt.QStackedLayout(self)
+ layout.setStackingMode(qt.QStackedLayout.StackAll)
+
+ self._waiting_label = qt.QLabel(parent=self)
+ self._waiting_label.setAlignment(qt.Qt.AlignHCenter | qt.Qt.AlignVCenter)
+ layout.addWidget(self._waiting_label)
+
+ self._plot = Plot2D(parent=self)
+ layout.addWidget(self._plot)
+
+ self.updateThread = _PlotWithWaitingLabel.AnimationThread(self._waiting_label)
+ self.updateThread.start()
+
+ def close(self) -> bool:
+ super(_PlotWithWaitingLabel, self).close()
+ self.stopUpdateThread()
+
+ def stopUpdateThread(self):
+ self.updateThread.stop()
+
+ def setAutoResetZoom(self, reset):
+ """
+ Should we reset the zoom when adding an image (eq. when browsing)
+
+ :param bool reset:
+ """
+ self._autoResetZoom = reset
+ if self._autoResetZoom:
+ self._plot.resetZoom()
+
+ def isAutoResetZoom(self):
+ """
+
+ :return: True if a reset is done when the image change
+ :rtype: bool
+ """
+ return self._autoResetZoom
+
+ def setWaiting(self, activate=True):
+ if activate is True:
+ self._plot.clear()
+ self._waiting_label.show()
+ else:
+ self._waiting_label.hide()
+
+ def setData(self, data):
+ self.setWaiting(activate=False)
+ self._plot.addImage(data=data, resetzoom=self._autoResetZoom)
+
+ def clear(self):
+ self._plot.clear()
+ self.setWaiting(False)
+
+ def getPlotWidget(self):
+ return self._plot
+
+
+class _HorizontalSlider(HorizontalSliderWithBrowser):
+
+ sigCurrentUrlIndexChanged = qt.Signal(int)
+
+ def __init__(self, parent):
+ super(_HorizontalSlider, self).__init__(parent=parent)
+ # connect signal / slot
+ self.valueChanged.connect(self._urlChanged)
+
+ def setUrlIndex(self, index):
+ self.setValue(index)
+ self.sigCurrentUrlIndexChanged.emit(index)
+
+ def _urlChanged(self, value):
+ self.sigCurrentUrlIndexChanged.emit(value)
+
+
+class UrlList(qt.QWidget):
+ """List of URLs the user to select an URL"""
+
+ sigCurrentUrlChanged = qt.Signal(str)
+ """Signal emitted when the active/current url change"""
+
+ def __init__(self, parent=None):
+ super(UrlList, self).__init__(parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setSpacing(0)
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._listWidget = qt.QListWidget(parent=self)
+ self.layout().addWidget(self._listWidget)
+
+ # connect signal / Slot
+ self._listWidget.currentItemChanged.connect(self._notifyCurrentUrlChanged)
+
+ # expose API
+ self.currentItem = self._listWidget.currentItem
+
+ def setUrls(self, urls: list) -> None:
+ url_names = []
+ [url_names.append(url.path()) for url in urls]
+ self._listWidget.addItems(url_names)
+
+ def _notifyCurrentUrlChanged(self, current, previous):
+ if current is None:
+ pass
+ else:
+ self.sigCurrentUrlChanged.emit(current.text())
+
+ def setUrl(self, url: DataUrl) -> None:
+ assert isinstance(url, DataUrl)
+ sel_items = self._listWidget.findItems(url.path(), qt.Qt.MatchExactly)
+ if sel_items is None:
+ _logger.warning(url.path(), ' is not registered in the list.')
+ elif len(sel_items) > 0:
+ item = sel_items[0]
+ self._listWidget.setCurrentItem(item)
+ self.sigCurrentUrlChanged.emit(item.text())
+
+ def clear(self):
+ self._listWidget.clear()
+
+
+class _ToggleableUrlSelectionTable(qt.QWidget):
+
+ _BUTTON_ICON = qt.QStyle.SP_ToolBarHorizontalExtensionButton # noqa
+
+ sigCurrentUrlChanged = qt.Signal(str)
+ """Signal emitted when the active/current url change"""
+
+ def __init__(self, parent=None) -> None:
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QGridLayout())
+ self._toggleButton = qt.QPushButton(parent=self)
+ self.layout().addWidget(self._toggleButton, 0, 2, 1, 1)
+ self._toggleButton.setSizePolicy(qt.QSizePolicy.Fixed,
+ qt.QSizePolicy.Fixed)
+
+ self._urlsTable = UrlList(parent=self)
+ self.layout().addWidget(self._urlsTable, 1, 1, 1, 2)
+
+ # set up
+ self._setButtonIcon(show=True)
+
+ # Signal / slot connection
+ self._toggleButton.clicked.connect(self.toggleUrlSelectionTable)
+ self._urlsTable.sigCurrentUrlChanged.connect(self._propagateSignal)
+
+ # expose API
+ self.setUrls = self._urlsTable.setUrls
+ self.setUrl = self._urlsTable.setUrl
+ self.currentItem = self._urlsTable.currentItem
+
+ def toggleUrlSelectionTable(self):
+ visible = not self.urlSelectionTableIsVisible()
+ self._setButtonIcon(show=visible)
+ self._urlsTable.setVisible(visible)
+
+ def _setButtonIcon(self, show):
+ style = qt.QApplication.instance().style()
+ # return a QIcon
+ icon = style.standardIcon(self._BUTTON_ICON)
+ if show is False:
+ pixmap = icon.pixmap(32, 32).transformed(qt.QTransform().scale(-1, 1))
+ icon = qt.QIcon(pixmap)
+ self._toggleButton.setIcon(icon)
+
+ def urlSelectionTableIsVisible(self):
+ return self._urlsTable.isVisible()
+
+ def _propagateSignal(self, url):
+ self.sigCurrentUrlChanged.emit(url)
+
+ def clear(self):
+ self._urlsTable.clear()
+
+
+class UrlLoader(qt.QThread):
+ """
+ Thread use to load DataUrl
+ """
+ def __init__(self, parent, url):
+ super(UrlLoader, self).__init__(parent=parent)
+ assert isinstance(url, DataUrl)
+ self.url = url
+ self.data = None
+
+ def run(self):
+ try:
+ self.data = get_data(self.url)
+ except IOError:
+ self.data = None
+
+
+class ImageStack(qt.QMainWindow):
+ """Widget loading on the fly images contained the given urls.
+
+ It prefetches images close to the displayed one.
+ """
+
+ N_PRELOAD = 10
+
+ sigLoaded = qt.Signal(str)
+ """Signal emitted when new data is available"""
+
+ sigCurrentUrlChanged = qt.Signal(str)
+ """Signal emitted when the current url change"""
+
+ def __init__(self, parent=None) -> None:
+ super(ImageStack, self).__init__(parent)
+ self.__n_prefetch = ImageStack.N_PRELOAD
+ self._loadingThreads = []
+ self.setWindowFlags(qt.Qt.Widget)
+ self._current_url = None
+ self._url_loader = UrlLoader
+ "class to instantiate for loading urls"
+
+ # main widget
+ self._plot = _PlotWithWaitingLabel(parent=self)
+ self._plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.setWindowTitle("Image stack")
+ self.setCentralWidget(self._plot)
+
+ # dock widget: url table
+ self._tableDockWidget = qt.QDockWidget(parent=self)
+ self._urlsTable = _ToggleableUrlSelectionTable(parent=self)
+ self._tableDockWidget.setWidget(self._urlsTable)
+ self._tableDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable)
+ self.addDockWidget(qt.Qt.RightDockWidgetArea, self._tableDockWidget)
+ # dock widget: qslider
+ self._sliderDockWidget = qt.QDockWidget(parent=self)
+ self._slider = _HorizontalSlider(parent=self)
+ self._sliderDockWidget.setWidget(self._slider)
+ self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._sliderDockWidget)
+ self._sliderDockWidget.setFeatures(qt.QDockWidget.DockWidgetMovable)
+
+ self.reset()
+
+ # connect signal / slot
+ self._urlsTable.sigCurrentUrlChanged.connect(self.setCurrentUrl)
+ self._slider.sigCurrentUrlIndexChanged.connect(self.setCurrentUrlIndex)
+
+ def close(self) -> bool:
+ self._freeLoadingThreads()
+ self._plot.close()
+ super(ImageStack, self).close()
+
+ def setUrlLoaderClass(self, urlLoader: typing.Type[UrlLoader]) -> None:
+ """
+
+ :param urlLoader: define the class to call for loading urls.
+ warning: this should be a class object and not a
+ class instance.
+ """
+ assert isinstance(urlLoader, type(UrlLoader))
+ self._url_loader = urlLoader
+
+ def getUrlLoaderClass(self):
+ """
+
+ :return: class to instantiate for loading urls
+ :rtype: typing.Type[UrlLoader]
+ """
+ return self._url_loader
+
+ def _freeLoadingThreads(self):
+ for thread in self._loadingThreads:
+ thread.blockSignals(True)
+ thread.wait(5)
+ self._loadingThreads.clear()
+
+ def getPlotWidget(self) -> Plot2D:
+ """
+ Returns the PlotWidget contained in this window
+
+ :return: PlotWidget contained in this window
+ :rtype: Plot2D
+ """
+ return self._plot.getPlotWidget()
+
+ def reset(self) -> None:
+ """Clear the plot and remove any link to url"""
+ self._freeLoadingThreads()
+ self._urls = None
+ self._urlIndexes = None
+ self._urlData = OrderedDict({})
+ self._current_url = None
+ self._plot.clear()
+ self._urlsTable.clear()
+ self._slider.setMaximum(-1)
+
+ def _preFetch(self, urls: list) -> None:
+ """Pre-fetch the given urls if necessary
+
+ :param urls: list of DataUrl to prefetch
+ :type: list
+ """
+ for url in urls:
+ if url.path() not in self._urlData:
+ self._load(url)
+
+ def _load(self, url):
+ """
+ Launch background load of a DataUrl
+
+ :param url:
+ :type: DataUrl
+ """
+ assert isinstance(url, DataUrl)
+ url_path = url.path()
+ assert url_path in self._urlIndexes
+ loader = self._url_loader(parent=self, url=url)
+ loader.finished.connect(self._urlLoaded, qt.Qt.QueuedConnection)
+ self._loadingThreads.append(loader)
+ loader.start()
+
+ def _urlLoaded(self) -> None:
+ """
+
+ :param url: restul of DataUrl.path() function
+ :return:
+ """
+ sender = self.sender()
+ assert isinstance(sender, UrlLoader)
+ url = sender.url.path()
+ if url in self._urlIndexes:
+ self._urlData[url] = sender.data
+ if self.getCurrentUrl().path() == url:
+ self._plot.setData(self._urlData[url])
+ if sender in self._loadingThreads:
+ self._loadingThreads.remove(sender)
+ self.sigLoaded.emit(url)
+
+ def setNPrefetch(self, n: int) -> None:
+ """
+ Define the number of url to prefetch around
+
+ :param int n: number of url to prefetch on left and right sides.
+ In total n*2 DataUrl will be prefetch
+ """
+ self.__n_prefetch = n
+ current_url = self.getCurrentUrl()
+ if current_url is not None:
+ self.setCurrentUrl(current_url)
+
+ def getNPrefetch(self) -> int:
+ """
+
+ :return: number of url to prefetch on left and right sides. In total
+ will load 2* NPrefetch DataUrls
+ """
+ return self.__n_prefetch
+
+ def setUrls(self, urls: list) -> None:
+ """list of urls within an index. Warning: urls should contain an image
+ compatible with the silx.gui.plot.Plot class
+
+ :param urls: urls we want to set in the stack. Key is the index
+ (position in the stack), value is the DataUrl
+ :type: list
+ """
+ def createUrlIndexes():
+ indexes = OrderedDict()
+ for index, url in enumerate(urls):
+ indexes[index] = url
+ return indexes
+
+ urls_with_indexes = createUrlIndexes()
+ urlsToIndex = self._urlsToIndex(urls_with_indexes)
+ self.reset()
+ self._urls = urls_with_indexes
+ self._urlIndexes = urlsToIndex
+
+ old_url_table = self._urlsTable.blockSignals(True)
+ self._urlsTable.setUrls(urls=list(self._urls.values()))
+ self._urlsTable.blockSignals(old_url_table)
+
+ old_slider = self._slider.blockSignals(True)
+ self._slider.setMinimum(0)
+ self._slider.setMaximum(len(self._urls) - 1)
+ self._slider.blockSignals(old_slider)
+
+ if self.getCurrentUrl() in self._urls:
+ self.setCurrentUrl(self.getCurrentUrl())
+ else:
+ if len(self._urls.keys()) > 0:
+ first_url = self._urls[list(self._urls.keys())[0]]
+ self.setCurrentUrl(first_url)
+
+ def getUrls(self) -> tuple:
+ """
+
+ :return: tuple of urls
+ :rtype: tuple
+ """
+ return tuple(self._urlIndexes.keys())
+
+ def _getNextUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]:
+ """
+ return the next url in the stack
+
+ :param url: url for which we want the next url
+ :type: DataUrl
+ :return: next url in the stack or None if `url` is the last one
+ :rtype: Union[None, DataUrl]
+ """
+ assert isinstance(url, DataUrl)
+ if self._urls is None:
+ return None
+ else:
+ index = self._urlIndexes[url.path()]
+ indexes = list(self._urls.keys())
+ res = list(filter(lambda x: x > index, indexes))
+ if len(res) == 0:
+ return None
+ else:
+ return self._urls[res[0]]
+
+ def _getPreviousUrl(self, url: DataUrl) -> typing.Union[None, DataUrl]:
+ """
+ return the previous url in the stack
+
+ :param url: url for which we want the previous url
+ :type: DataUrl
+ :return: next url in the stack or None if `url` is the last one
+ :rtype: Union[None, DataUrl]
+ """
+ if self._urls is None:
+ return None
+ else:
+ index = self._urlIndexes[url.path()]
+ indexes = list(self._urls.keys())
+ res = list(filter(lambda x: x < index, indexes))
+ if len(res) == 0:
+ return None
+ else:
+ return self._urls[res[-1]]
+
+ def _getNNextUrls(self, n: int, url: DataUrl) -> list:
+ """
+ Deduce the next urls in the stack after `url`
+
+ :param n: the number of url store after `url`
+ :type: int
+ :param url: url for which we want n next url
+ :type: DataUrl
+ :return: list of next urls.
+ :rtype: list
+ """
+ res = []
+ next_free = self._getNextUrl(url=url)
+ while len(res) < n and next_free is not None:
+ assert isinstance(next_free, DataUrl)
+ res.append(next_free)
+ next_free = self._getNextUrl(res[-1])
+ return res
+
+ def _getNPreviousUrls(self, n: int, url: DataUrl):
+ """
+ Deduce the previous urls in the stack after `url`
+
+ :param n: the number of url store after `url`
+ :type: int
+ :param url: url for which we want n previous url
+ :type: DataUrl
+ :return: list of previous urls.
+ :rtype: list
+ """
+ res = []
+ next_free = self._getPreviousUrl(url=url)
+ while len(res) < n and next_free is not None:
+ res.insert(0, next_free)
+ next_free = self._getPreviousUrl(res[0])
+ return res
+
+ def setCurrentUrlIndex(self, index: int):
+ """
+ Define the url to be displayed
+
+ :param index: url to be displayed
+ :type: int
+ """
+ if index < 0:
+ return
+ if self._urls is None:
+ return
+ elif index >= len(self._urls):
+ raise ValueError('requested index out of bounds')
+ else:
+ return self.setCurrentUrl(self._urls[index])
+
+ def setCurrentUrl(self, url: typing.Union[DataUrl, str]) -> None:
+ """
+ Define the url to be displayed
+
+ :param url: url to be displayed
+ :type: DataUrl
+ """
+ assert isinstance(url, (DataUrl, str))
+ if isinstance(url, str):
+ url = DataUrl(path=url)
+ if url != self._current_url:
+ self._current_url = url
+ self.sigCurrentUrlChanged.emit(url.path())
+
+ old_url_table = self._urlsTable.blockSignals(True)
+ old_slider = self._slider.blockSignals(True)
+
+ self._urlsTable.setUrl(url)
+ self._slider.setUrlIndex(self._urlIndexes[url.path()])
+ if self._current_url is None:
+ self._plot.clear()
+ else:
+ if self._current_url.path() in self._urlData:
+ self._plot.setData(self._urlData[url.path()])
+ else:
+ self._load(url)
+ self._notifyLoading()
+ self._preFetch(self._getNNextUrls(self.__n_prefetch, url))
+ self._preFetch(self._getNPreviousUrls(self.__n_prefetch, url))
+ self._urlsTable.blockSignals(old_url_table)
+ self._slider.blockSignals(old_slider)
+
+ def getCurrentUrl(self) -> typing.Union[None, DataUrl]:
+ """
+
+ :return: url currently displayed
+ :rtype: Union[None, DataUrl]
+ """
+ return self._current_url
+
+ def getCurrentUrlIndex(self) -> typing.Union[None, int]:
+ """
+
+ :return: index of the url currently displayed
+ :rtype: Union[None, int]
+ """
+ if self._current_url is None:
+ return None
+ else:
+ return self._urlIndexes[self._current_url.path()]
+
+ @staticmethod
+ def _urlsToIndex(urls):
+ """util, return a dictionary with url as key and index as value"""
+ res = {}
+ for index, url in urls.items():
+ res[url.path()] = index
+ return res
+
+ def _notifyLoading(self):
+ """display a simple image of loading..."""
+ self._plot.setWaiting(activate=True)
+
+ def setAutoResetZoom(self, reset):
+ """
+ Should we reset the zoom when adding an image (eq. when browsing)
+
+ :param bool reset:
+ """
+ self._plot.setAutoResetZoom(reset)
+
+ def isAutoResetZoom(self) -> bool:
+ """
+
+ :return: True if a reset is done when the image change
+ :rtype: bool
+ """
+ return self._plot.isAutoResetZoom()
diff --git a/src/silx/gui/plot/ImageView.py b/src/silx/gui/plot/ImageView.py
new file mode 100644
index 0000000..f8b830a
--- /dev/null
+++ b/src/silx/gui/plot/ImageView.py
@@ -0,0 +1,1057 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying a 2D image with histograms on its sides.
+
+The :class:`ImageView` implements this widget, and
+:class:`ImageViewMainWindow` provides a main window with additional toolbar
+and status bar.
+
+Basic usage of :class:`ImageView` is through the following methods:
+
+- :meth:`ImageView.getColormap`, :meth:`ImageView.setColormap` to update the
+ default colormap to use and update the currently displayed image.
+- :meth:`ImageView.setImage` to update the displayed image.
+
+For an example of use, see `imageview.py` in :ref:`sample-code`.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/04/2018"
+
+
+import logging
+import numpy
+import collections
+from typing import Union
+
+import silx
+from .. import qt
+from .. import colors
+from .. import icons
+
+from . import items, PlotWindow, PlotWidget, actions
+from ..colors import Colormap
+from ..colors import cursorColorForColormap
+from .tools import LimitsToolBar
+from .Profile import ProfileToolBar
+from ...utils.proxy import docstring
+from ...utils.deprecation import deprecated
+from ...utils.enum import Enum
+from .tools.RadarView import RadarView
+from .utils.axis import SyncAxes
+from ..utils import blockSignals
+from . import _utils
+from .tools.profile import manager
+from .tools.profile import rois
+from .actions import PlotAction
+
+_logger = logging.getLogger(__name__)
+
+
+ProfileSumResult = collections.namedtuple("ProfileResult",
+ ["dataXRange", "dataYRange",
+ 'histoH', 'histoHRange',
+ 'histoV', 'histoVRange',
+ "xCoords", "xData",
+ "yCoords", "yData"])
+
+
+def computeProfileSumOnRange(imageItem, xRange, yRange, cache=None):
+ """
+ Compute a full vertical and horizontal profile on an image item using a
+ a range in the plot referential.
+
+ Optionally takes a previous computed result to be able to skip the
+ computation.
+
+ :rtype: ProfileSumResult
+ """
+ data = imageItem.getValueData(copy=False)
+ origin = imageItem.getOrigin()
+ scale = imageItem.getScale()
+ height, width = data.shape
+
+ xMin, xMax = xRange
+ yMin, yMax = yRange
+
+ # Convert plot area limits to image coordinates
+ # and work in image coordinates (i.e., in pixels)
+ xMin = int((xMin - origin[0]) / scale[0])
+ xMax = int((xMax - origin[0]) / scale[0])
+ yMin = int((yMin - origin[1]) / scale[1])
+ yMax = int((yMax - origin[1]) / scale[1])
+
+ if (xMin >= width or xMax < 0 or
+ yMin >= height or yMax < 0):
+ return None
+
+ # The image is at least partly in the plot area
+ # Get the visible bounds in image coords (i.e., in pixels)
+ subsetXMin = 0 if xMin < 0 else xMin
+ subsetXMax = (width if xMax >= width else xMax) + 1
+ subsetYMin = 0 if yMin < 0 else yMin
+ subsetYMax = (height if yMax >= height else yMax) + 1
+
+ if cache is not None:
+ if ((subsetXMin, subsetXMax) == cache.dataXRange and
+ (subsetYMin, subsetYMax) == cache.dataYRange):
+ # The visible area of data is the same
+ return cache
+
+ # Rebuild histograms for visible area
+ visibleData = data[subsetYMin:subsetYMax,
+ subsetXMin:subsetXMax]
+ histoHVisibleData = numpy.nansum(visibleData, axis=0)
+ histoVVisibleData = numpy.nansum(visibleData, axis=1)
+ histoHMin = numpy.nanmin(histoHVisibleData)
+ histoHMax = numpy.nanmax(histoHVisibleData)
+ histoVMin = numpy.nanmin(histoVVisibleData)
+ histoVMax = numpy.nanmax(histoVVisibleData)
+
+ # Convert to histogram curve and update plots
+ # Taking into account origin and scale
+ coords = numpy.arange(2 * histoHVisibleData.size)
+ xCoords = (coords + 1) // 2 + subsetXMin
+ xCoords = origin[0] + scale[0] * xCoords
+ xData = numpy.take(histoHVisibleData, coords // 2)
+ coords = numpy.arange(2 * histoVVisibleData.size)
+ yCoords = (coords + 1) // 2 + subsetYMin
+ yCoords = origin[1] + scale[1] * yCoords
+ yData = numpy.take(histoVVisibleData, coords // 2)
+
+ result = ProfileSumResult(
+ dataXRange=(subsetXMin, subsetXMax),
+ dataYRange=(subsetYMin, subsetYMax),
+ histoH=histoHVisibleData,
+ histoHRange=(histoHMin, histoHMax),
+ histoV=histoVVisibleData,
+ histoVRange=(histoVMin, histoVMax),
+ xCoords=xCoords,
+ xData=xData,
+ yCoords=yCoords,
+ yData=yData)
+
+ return result
+
+
+class _SideHistogram(PlotWidget):
+ """
+ Widget displaying one of the side profile of the ImageView.
+
+ Implement ProfileWindow
+ """
+
+ sigClose = qt.Signal()
+
+ sigMouseMoved = qt.Signal(float, float)
+
+ def __init__(self, parent=None, backend=None, direction=qt.Qt.Horizontal):
+ super(_SideHistogram, self).__init__(parent=parent, backend=backend)
+ self._direction = direction
+ self.sigPlotSignal.connect(self._plotEvents)
+ self._color = "blue"
+ self.__profile = None
+ self.__profileSum = None
+
+ def _plotEvents(self, eventDict):
+ """Callback for horizontal histogram plot events."""
+ if eventDict['event'] == 'mouseMoved':
+ self.sigMouseMoved.emit(eventDict['x'], eventDict['y'])
+
+ def setProfileColor(self, color):
+ self._color = color
+
+ def setProfileSum(self, result):
+ self.__profileSum = result
+ if self.__profile is None:
+ self.__drawProfileSum()
+
+ def prepareWidget(self, roi):
+ """Implements `ProfileWindow`"""
+ pass
+
+ def setRoiProfile(self, roi):
+ """Implements `ProfileWindow`"""
+ if roi is None:
+ return
+ self._roiColor = colors.rgba(roi.getColor())
+
+ def getProfile(self):
+ """Implements `ProfileWindow`"""
+ return self.__profile
+
+ def setProfile(self, data):
+ """Implements `ProfileWindow`"""
+ self.__profile = data
+ if data is None:
+ self.__drawProfileSum()
+ else:
+ self.__drawProfile()
+
+ def __drawProfileSum(self):
+ """Only draw the profile sum on the plot.
+
+ Other elements are removed
+ """
+ profileSum = self.__profileSum
+
+ try:
+ self.removeCurve('profile')
+ except Exception:
+ pass
+
+ if profileSum is None:
+ try:
+ self.removeCurve('profilesum')
+ except Exception:
+ pass
+ return
+
+ if self._direction == qt.Qt.Horizontal:
+ xx, yy = profileSum.xCoords, profileSum.xData
+ elif self._direction == qt.Qt.Vertical:
+ xx, yy = profileSum.yData, profileSum.yCoords
+ else:
+ assert False
+
+ self.addCurve(xx, yy,
+ xlabel='', ylabel='',
+ legend="profilesum",
+ color=self._color,
+ linestyle='-',
+ selectable=False,
+ resetzoom=False)
+
+ self.__updateLimits()
+
+ def __drawProfile(self):
+ """Only draw the profile on the plot.
+
+ Other elements are removed
+ """
+ profile = self.__profile
+
+ try:
+ self.removeCurve('profilesum')
+ except Exception:
+ pass
+
+ if profile is None:
+ try:
+ self.removeCurve('profile')
+ except Exception:
+ pass
+ self.setProfileSum(self.__profileSum)
+ return
+
+ if self._direction == qt.Qt.Horizontal:
+ xx, yy = profile.coords, profile.profile
+ elif self._direction == qt.Qt.Vertical:
+ xx, yy = profile.profile, profile.coords
+ else:
+ assert False
+
+ self.addCurve(xx,
+ yy,
+ legend="profile",
+ color=self._roiColor,
+ resetzoom=False)
+
+ self.__updateLimits()
+
+ def __updateLimits(self):
+ if self.__profile:
+ data = self.__profile.profile
+ vMin = numpy.nanmin(data)
+ vMax = numpy.nanmax(data)
+ elif self.__profileSum is not None:
+ if self._direction == qt.Qt.Horizontal:
+ vMin, vMax = self.__profileSum.histoHRange
+ elif self._direction == qt.Qt.Vertical:
+ vMin, vMax = self.__profileSum.histoVRange
+ else:
+ assert False
+ else:
+ vMin, vMax = 0, 0
+
+ # Tune the result using the data margins
+ margins = self.getDataMargins()
+ if self._direction == qt.Qt.Horizontal:
+ _, _, vMin, vMax = _utils.addMarginsToLimits(margins, False, False, 0, 0, vMin, vMax)
+ elif self._direction == qt.Qt.Vertical:
+ vMin, vMax, _, _ = _utils.addMarginsToLimits(margins, False, False, vMin, vMax, 0, 0)
+ else:
+ assert False
+
+ if self._direction == qt.Qt.Horizontal:
+ dataAxis = self.getYAxis()
+ elif self._direction == qt.Qt.Vertical:
+ dataAxis = self.getXAxis()
+ else:
+ assert False
+
+ with blockSignals(dataAxis):
+ dataAxis.setLimits(vMin, vMax)
+
+
+class ShowSideHistogramsAction(PlotAction):
+ """QAction to change visibility of side histogram of a :class:`.ImageView`.
+
+ :param plot: :class:`.ImageView` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(ShowSideHistogramsAction, self).__init__(
+ plot, icon='side-histograms', text='Show/hide side histograms',
+ tooltip='Show/hide side histogram',
+ triggered=self._actionTriggered,
+ checkable=True, parent=parent)
+
+ def _actionTriggered(self, checked=False):
+ if self.plot.isSideHistogramDisplayed() != checked:
+ self.plot.setSideHistogramDisplayed(checked)
+
+
+class AggregationModeAction(qt.QWidgetAction):
+ """Action providing few filters to the image"""
+
+ sigAggregationModeChanged = qt.Signal()
+
+ def __init__(self, parent):
+ qt.QWidgetAction.__init__(self, parent)
+
+ toolButton = qt.QToolButton(parent)
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("No filter")
+ filterAction.setCheckable(True)
+ filterAction.setChecked(True)
+ filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.NONE)
+ densityNoFilterAction = filterAction
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("Max filter")
+ filterAction.setCheckable(True)
+ filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MAX)
+ densityMaxFilterAction = filterAction
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("Mean filter")
+ filterAction.setCheckable(True)
+ filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MEAN)
+ densityMeanFilterAction = filterAction
+
+ filterAction = qt.QAction(self)
+ filterAction.setText("Min filter")
+ filterAction.setCheckable(True)
+ filterAction.setProperty("aggregation", items.ImageDataAggregated.Aggregation.MIN)
+ densityMinFilterAction = filterAction
+
+ densityGroup = qt.QActionGroup(self)
+ densityGroup.setExclusive(True)
+ densityGroup.addAction(densityNoFilterAction)
+ densityGroup.addAction(densityMaxFilterAction)
+ densityGroup.addAction(densityMeanFilterAction)
+ densityGroup.addAction(densityMinFilterAction)
+ densityGroup.triggered.connect(self._aggregationModeChanged)
+ self.__densityGroup = densityGroup
+
+ filterMenu = qt.QMenu(toolButton)
+ filterMenu.addAction(densityNoFilterAction)
+ filterMenu.addAction(densityMaxFilterAction)
+ filterMenu.addAction(densityMeanFilterAction)
+ filterMenu.addAction(densityMinFilterAction)
+
+ toolButton.setPopupMode(qt.QToolButton.InstantPopup)
+ toolButton.setMenu(filterMenu)
+ toolButton.setText("Data filters")
+ toolButton.setToolTip("Enable/disable filter on the image")
+ icon = icons.getQIcon("aggregation-mode")
+ toolButton.setIcon(icon)
+ toolButton.setText("Pixel aggregation filter")
+
+ self.setDefaultWidget(toolButton)
+
+ def _aggregationModeChanged(self):
+ self.sigAggregationModeChanged.emit()
+
+ def setAggregationMode(self, mode):
+ """Set an Aggregated enum from ImageDataAggregated"""
+ for a in self.__densityGroup.actions():
+ if a.property("aggregation") is mode:
+ a.setChecked(True)
+
+ def getAggregationMode(self):
+ """Returns an Aggregated enum from ImageDataAggregated"""
+ densityAction = self.__densityGroup.checkedAction()
+ if densityAction is None:
+ return items.ImageDataAggregated.Aggregation.NONE
+ return densityAction.property("aggregation")
+
+
+class ImageView(PlotWindow):
+ """Display a single image with horizontal and vertical histograms.
+
+ Use :meth:`setImage` to control the displayed image.
+ This class also provides the :class:`silx.gui.plot.Plot` API.
+
+ The :class:`ImageView` inherits from :class:`.PlotWindow` (which provides
+ the toolbars) and also exposes :class:`.PlotWidget` API for further
+ plot control (plot title, axes labels, aspect ratio, ...).
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ HISTOGRAMS_COLOR = 'blue'
+ """Color to use for the side histograms."""
+
+ HISTOGRAMS_HEIGHT = 200
+ """Height in pixels of the side histograms."""
+
+ IMAGE_MIN_SIZE = 200
+ """Minimum size in pixels of the image area."""
+
+ # Qt signals
+ valueChanged = qt.Signal(float, float, object)
+ """Signals that the data value under the cursor has changed.
+
+ It provides: row, column, data value.
+
+ When the cursor is over an histogram, either row or column is Nan
+ and the provided data value is the histogram value
+ (i.e., the sum along the corresponding row/column).
+ Row and columns are either Nan or integer values.
+ """
+
+ class ProfileWindowBehavior(Enum):
+ """ImageView's profile window behavior options"""
+
+ POPUP = 'popup'
+ """All profiles are displayed in pop-up windows"""
+
+ EMBEDDED = 'embedded'
+ """Horizontal, vertical and cross profiles are displayed in
+ sides widgets, others are displayed in pop-up windows.
+ """
+
+ def __init__(self, parent=None, backend=None):
+ self._imageLegend = '__ImageView__image' + str(id(self))
+ self._cache = None # Store currently visible data information
+
+ super(ImageView, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=False,
+ logScale=False, grid=False,
+ curveStyle=False, colormap=True,
+ aspectRatio=True, yInverted=True,
+ copy=True, save=True, print_=True,
+ control=False, position=False,
+ roi=False, mask=True)
+
+ # Enable mask synchronisation to use it in profiles
+ maskToolsWidget = self.getMaskToolsDockWidget().widget()
+ maskToolsWidget.setItemMaskUpdated(True)
+
+ self.__showSideHistogramsAction = ShowSideHistogramsAction(self, self)
+ self.__showSideHistogramsAction.setChecked(True)
+
+ self.__aggregationModeAction = AggregationModeAction(self)
+ self.__aggregationModeAction.sigAggregationModeChanged.connect(self._aggregationModeChanged)
+
+ if parent is None:
+ self.setWindowTitle('ImageView')
+
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ self.getYAxis().setInverted(True)
+
+ self._initWidgets(backend)
+
+ toolBar = self.toolBar()
+ toolBar.addAction(self.__showSideHistogramsAction)
+ toolBar.addAction(self.__aggregationModeAction)
+
+ self.__profileWindowBehavior = self.ProfileWindowBehavior.POPUP
+ self.__profile = ProfileToolBar(plot=self)
+ self.addToolBar(self.__profile)
+
+ def _initWidgets(self, backend):
+ """Set-up layout and plots."""
+ self._histoHPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Horizontal)
+ widgetHandle = self._histoHPlot.getWidgetHandle()
+ widgetHandle.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+ widgetHandle.setMaximumHeight(self.HISTOGRAMS_HEIGHT)
+ self._histoHPlot.setInteractiveMode('zoom')
+ self._histoHPlot.setDataMargins(0., 0., 0.1, 0.1)
+ self._histoHPlot.sigMouseMoved.connect(self._mouseMovedOnHistoH)
+ self._histoHPlot.setProfileColor(self.HISTOGRAMS_COLOR)
+
+ self._histoVPlot = _SideHistogram(backend=backend, parent=self, direction=qt.Qt.Vertical)
+ widgetHandle = self._histoVPlot.getWidgetHandle()
+ widgetHandle.setMinimumWidth(self.HISTOGRAMS_HEIGHT)
+ widgetHandle.setMaximumWidth(self.HISTOGRAMS_HEIGHT)
+ self._histoVPlot.setInteractiveMode('zoom')
+ self._histoVPlot.setDataMargins(0.1, 0.1, 0., 0.)
+ self._histoVPlot.sigMouseMoved.connect(self._mouseMovedOnHistoV)
+ self._histoVPlot.setProfileColor(self.HISTOGRAMS_COLOR)
+
+ self.setPanWithArrowKeys(True)
+ self.setInteractiveMode('zoom') # Color set in setColormap
+ self.sigPlotSignal.connect(self._imagePlotCB)
+ self.sigActiveImageChanged.connect(self._activeImageChangedSlot)
+
+ self._radarView = RadarView(parent=self)
+ self._radarView.setPlotWidget(self)
+
+ self.__syncXAxis = SyncAxes([self.getXAxis(), self._histoHPlot.getXAxis()])
+ self.__syncYAxis = SyncAxes([self.getYAxis(), self._histoVPlot.getYAxis()])
+
+ self.__setCentralWidget()
+
+ def __setCentralWidget(self):
+ """Set central widget with all its content"""
+ layout = qt.QGridLayout()
+ layout.addWidget(self.getWidgetHandle(), 0, 0)
+ layout.addWidget(self._histoVPlot, 0, 1)
+ layout.addWidget(self._histoHPlot, 1, 0)
+ layout.addWidget(self._radarView, 1, 1, 1, 2)
+ layout.addWidget(self.getColorBarWidget(), 0, 2)
+
+ self._radarView.setMinimumWidth(self.IMAGE_MIN_SIZE)
+ self._radarView.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+ self._histoHPlot.setMinimumWidth(self.IMAGE_MIN_SIZE)
+ self._histoVPlot.setMinimumHeight(self.HISTOGRAMS_HEIGHT)
+
+ layout.setColumnStretch(0, 1)
+ layout.setColumnStretch(1, 0)
+ layout.setRowStretch(0, 1)
+ layout.setRowStretch(1, 0)
+
+ layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ centralWidget = qt.QWidget(self)
+ centralWidget.setLayout(layout)
+ self.setCentralWidget(centralWidget)
+
+ @docstring(PlotWidget)
+ def setBackend(self, backend):
+ # Use PlotWidget here since we override PlotWindow behavior
+ PlotWidget.setBackend(self, backend)
+ self.__setCentralWidget()
+
+ def _dirtyCache(self):
+ self._cache = None
+
+ def getAggregationModeAction(self):
+ return self.__aggregationModeAction
+
+ def _aggregationModeChanged(self):
+ item = self._getItem("image", self._imageLegend)
+ if item is None:
+ return
+ aggregationMode = self.__aggregationModeAction.getAggregationMode()
+ if aggregationMode is not None and isinstance(item, items.ImageDataAggregated):
+ item.setAggregationMode(aggregationMode)
+ else:
+ # It means the item type have to be changed
+ self.removeImage(self._imageLegend)
+ image = item.getData(copy=False)
+ if image is None:
+ return
+ origin = item.getOrigin()
+ scale = item.getScale()
+ self.setImage(image, origin, scale, copy=False, resetzoom=False)
+
+ def getShowSideHistogramsAction(self):
+ return self.__showSideHistogramsAction
+
+ def setSideHistogramDisplayed(self, show):
+ """Display or not the side histograms"""
+ if self.isSideHistogramDisplayed() == show:
+ return
+ self._histoHPlot.setVisible(show)
+ self._histoVPlot.setVisible(show)
+ self._radarView.setVisible(show)
+ self.__showSideHistogramsAction.setChecked(show)
+ if show:
+ # Probably have to be computed
+ self._updateHistograms()
+
+ def isSideHistogramDisplayed(self):
+ """True if the side histograms are displayed"""
+ return self._histoHPlot.isVisible()
+
+ def _updateHistograms(self):
+ """Update histograms content using current active image."""
+ if not self.isSideHistogramDisplayed():
+ # The histogram computation can be skipped
+ return
+
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ xRange = self.getXAxis().getLimits()
+ yRange = self.getYAxis().getLimits()
+ result = computeProfileSumOnRange(activeImage, xRange, yRange, self._cache)
+ self._cache = result
+ self._histoHPlot.setProfileSum(result)
+ self._histoVPlot.setProfileSum(result)
+
+ # Plots event listeners
+
+ def _imagePlotCB(self, eventDict):
+ """Callback for imageView plot events."""
+ if eventDict['event'] == 'mouseMoved':
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ data = activeImage.getData(copy=False)
+ height, width = data.shape[0:2]
+
+ # Get corresponding coordinate in image
+ origin = activeImage.getOrigin()
+ scale = activeImage.getScale()
+ if (eventDict['x'] >= origin[0] and
+ eventDict['y'] >= origin[1]):
+ x = int((eventDict['x'] - origin[0]) / scale[0])
+ y = int((eventDict['y'] - origin[1]) / scale[1])
+
+ if x >= 0 and x < width and y >= 0 and y < height:
+ self.valueChanged.emit(float(x), float(y),
+ data[y][x])
+
+ elif eventDict['event'] == 'limitsChanged':
+ self._updateHistograms()
+
+ def _mouseMovedOnHistoH(self, x, y):
+ if self._cache is None:
+ return
+ activeImage = self.getActiveImage()
+ if activeImage is None:
+ return
+
+ xOrigin = activeImage.getOrigin()[0]
+ xScale = activeImage.getScale()[0]
+
+ minValue = xOrigin + xScale * self._cache.dataXRange[0]
+
+ if x >= minValue:
+ data = self._cache.histoH
+ column = int((x - minValue) / xScale)
+ if column >= 0 and column < data.shape[0]:
+ self.valueChanged.emit(
+ float('nan'),
+ float(column + self._cache.dataXRange[0]),
+ data[column])
+
+ def _mouseMovedOnHistoV(self, x, y):
+ if self._cache is None:
+ return
+ activeImage = self.getActiveImage()
+ if activeImage is None:
+ return
+
+ yOrigin = activeImage.getOrigin()[1]
+ yScale = activeImage.getScale()[1]
+
+ minValue = yOrigin + yScale * self._cache.dataYRange[0]
+
+ if y >= minValue:
+ data = self._cache.histoV
+ row = int((y - minValue) / yScale)
+ if row >= 0 and row < data.shape[0]:
+ self.valueChanged.emit(
+ float(row + self._cache.dataYRange[0]),
+ float('nan'),
+ data[row])
+
+ def _activeImageChangedSlot(self, previous, legend):
+ """Handle Plot active image change.
+
+ Resets side histograms cache
+ """
+ self._dirtyCache()
+ self._updateHistograms()
+
+ def setProfileWindowBehavior(self, behavior: Union[str, ProfileWindowBehavior]):
+ """Set where profile widgets are displayed.
+
+ :param ProfileWindowBehavior behavior:
+ - 'popup': All profiles are displayed in pop-up windows
+ - 'embedded': Horizontal, vertical and cross profiles are displayed in
+ sides widgets, others are displayed in pop-up windows.
+ """
+ behavior = self.ProfileWindowBehavior.from_value(behavior)
+ if behavior is not self.getProfileWindowBehavior():
+ manager = self.__profile.getProfileManager()
+ manager.clearProfile()
+ manager.requestUpdateAllProfile()
+
+ if behavior is self.ProfileWindowBehavior.EMBEDDED:
+ horizontalProfileWindow = self._histoHPlot
+ verticalProfileWindow = self._histoVPlot
+ else:
+ horizontalProfileWindow = None
+ verticalProfileWindow = None
+
+ manager.setSpecializedProfileWindow(
+ rois.ProfileImageHorizontalLineROI, horizontalProfileWindow
+ )
+ manager.setSpecializedProfileWindow(
+ rois.ProfileImageVerticalLineROI, verticalProfileWindow
+ )
+ self.__profileWindowBehavior = behavior
+
+ def getProfileWindowBehavior(self) -> ProfileWindowBehavior:
+ """Returns current profile display behavior.
+
+ See :meth:`setProfileWindowBehavior` and :class:`ProfileWindowBehavior`
+ """
+ return self.__profileWindowBehavior
+
+ def getProfileToolBar(self):
+ """"Returns profile tools attached to this plot.
+
+ :rtype: silx.gui.plot.PlotTools.ProfileToolBar
+ """
+ return self.__profile
+
+ @property
+ @deprecated(replacement="getProfileToolBar()")
+ def profile(self):
+ return self.getProfileToolBar()
+
+ def getHistogram(self, axis):
+ """Return the histogram and corresponding row or column extent.
+
+ The returned value when an histogram is available is a dict with keys:
+
+ - 'data': numpy array of the histogram values.
+ - 'extent': (start, end) row or column index.
+ end index is not included in the histogram.
+
+ :param str axis: 'x' for horizontal, 'y' for vertical
+ :return: The histogram and its extent as a dict or None.
+ :rtype: dict
+ """
+ assert axis in ('x', 'y')
+ if self._cache is None:
+ return None
+ else:
+ if axis == 'x':
+ return dict(
+ data=numpy.array(self._cache.histoH, copy=True),
+ extent=self._cache.dataXRange)
+ else:
+ return dict(
+ data=numpy.array(self._cache.histoV, copy=True),
+ extent=(self._cache.dataYRange))
+
+ def radarView(self):
+ """Get the lower right radarView widget."""
+ return self._radarView
+
+ def setRadarView(self, radarView):
+ """Change the lower right radarView widget.
+
+ :param RadarView radarView: Widget subclassing RadarView to replace
+ the lower right corner widget.
+ """
+ self._radarView = radarView
+ self._radarView.setPlotWidget(self)
+ self.centralWidget().layout().addWidget(self._radarView, 1, 1)
+
+ # High-level API
+
+ def getColormap(self):
+ """Get the default colormap description.
+
+ :return: A description of the current colormap.
+ See :meth:`setColormap` for details.
+ :rtype: dict
+ """
+ return self.getDefaultColormap()
+
+ def setColormap(self, colormap=None, normalization=None,
+ autoscale=None, vmin=None, vmax=None, colors=None):
+ """Set the default colormap and update active image.
+
+ Parameters that are not provided are taken from the current colormap.
+
+ The colormap parameter can also be a dict with the following keys:
+
+ - *name*: string. The colormap to use:
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ - *normalization*: string. The mapping to use for the colormap:
+ either 'linear' or 'log'.
+ - *autoscale*: bool. Whether to use autoscale (True)
+ or range provided by keys 'vmin' and 'vmax' (False).
+ - *vmin*: float. The minimum value of the range to use if 'autoscale'
+ is False.
+ - *vmax*: float. The maximum value of the range to use if 'autoscale'
+ is False.
+ - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
+ List of RGB or RGBA colors to use (only if name is None)
+
+ :param colormap: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ Or the description of the colormap as a dict.
+ :type colormap: dict or str.
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param bool autoscale: Whether to use autoscale (True)
+ or [vmin, vmax] range (False).
+ :param float vmin: The minimum value of the range to use if
+ 'autoscale' is False.
+ :param float vmax: The maximum value of the range to use if
+ 'autoscale' is False.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ """
+ cmap = self.getDefaultColormap()
+
+ if isinstance(colormap, Colormap):
+ # Replace colormap
+ cmap = colormap
+
+ self.setDefaultColormap(cmap)
+
+ # Update active image colormap
+ activeImage = self.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ activeImage.setColormap(cmap)
+
+ elif isinstance(colormap, dict):
+ # Support colormap parameter as a dict
+ assert normalization is None
+ assert autoscale is None
+ assert vmin is None
+ assert vmax is None
+ assert colors is None
+ cmap._setFromDict(colormap)
+
+ else:
+ if colormap is not None:
+ cmap.setName(colormap)
+ if normalization is not None:
+ cmap.setNormalization(normalization)
+ if autoscale:
+ cmap.setVRange(None, None)
+ else:
+ if vmin is not None:
+ cmap.setVMin(vmin)
+ if vmax is not None:
+ cmap.setVMax(vmax)
+ if colors is not None:
+ cmap.setColormapLUT(colors)
+
+ cursorColor = cursorColorForColormap(cmap.getName())
+ self.setInteractiveMode('zoom', color=cursorColor)
+
+ def setImage(self, image, origin=(0, 0), scale=(1., 1.),
+ copy=True, reset=None, resetzoom=True):
+ """Set the image to display.
+
+ :param image: A 2D array representing the image or None to empty plot.
+ :type image: numpy.ndarray-like with 2 dimensions or None.
+ :param origin: The (x, y) position of the origin of the image.
+ Default: (0, 0).
+ The origin is the lower left corner of the image when
+ the Y axis is not inverted.
+ :type origin: Tuple of 2 floats: (origin x, origin y).
+ :param scale: The scale factor to apply to the image on X and Y axes.
+ Default: (1, 1).
+ It is the size of a pixel in the coordinates of the axes.
+ Scales must be positive numbers.
+ :type scale: Tuple of 2 floats: (scale x, scale y).
+ :param bool copy: Whether to copy image data (default) or not.
+ :param bool reset: Deprecated. Alias for `resetzoom`.
+ :param bool resetzoom: Whether to reset zoom and ROI (default) or not.
+ """
+ self._dirtyCache()
+
+ if reset is not None:
+ resetzoom = reset
+
+ assert len(origin) == 2
+ assert len(scale) == 2
+ assert scale[0] > 0
+ assert scale[1] > 0
+
+ if image is None:
+ self.remove(self._imageLegend, kind='image')
+ return
+
+ data = numpy.array(image, order='C', copy=copy)
+ if data.size == 0:
+ self.remove(self._imageLegend, kind='image')
+ return
+
+ assert data.ndim == 2 or (data.ndim == 3 and data.shape[2] in (3, 4))
+
+ aggregation = self.getAggregationModeAction().getAggregationMode()
+ if data.ndim != 2 and aggregation is not None:
+ # RGB/A with aggregation is not supported
+ aggregation = items.ImageDataAggregated.Aggregation.NONE
+
+ if aggregation is items.ImageDataAggregated.Aggregation.NONE:
+ self.addImage(data,
+ legend=self._imageLegend,
+ origin=origin, scale=scale,
+ colormap=self.getColormap(),
+ resetzoom=False)
+ else:
+ item = self._getItem("image", self._imageLegend)
+ if isinstance(item, items.ImageDataAggregated):
+ item.setData(data)
+ item.setOrigin(origin)
+ item.setScale(scale)
+ else:
+ if isinstance(item, items.ImageDataAggregated):
+ imageItem = item
+ wasCreated = False
+ else:
+ if item is not None:
+ self.removeImage(self._imageLegend)
+ imageItem = items.ImageDataAggregated()
+ imageItem.setName(self._imageLegend)
+ imageItem.setColormap(self.getColormap())
+ wasCreated = True
+ imageItem.setData(data)
+ imageItem.setOrigin(origin)
+ imageItem.setScale(scale)
+ imageItem.setAggregationMode(aggregation)
+ if wasCreated:
+ self.addItem(imageItem)
+
+ self.setActiveImage(self._imageLegend)
+ self._updateHistograms()
+ if resetzoom:
+ self.resetZoom()
+
+
+# ImageViewMainWindow #########################################################
+
+class ImageViewMainWindow(ImageView):
+ """:class:`ImageView` with additional toolbars
+
+ Adds extra toolbar and a status bar to :class:`ImageView`.
+ """
+ def __init__(self, parent=None, backend=None):
+ self._dataInfo = None
+ super(ImageViewMainWindow, self).__init__(parent, backend)
+ self.setWindowFlags(qt.Qt.Window)
+
+ self.getXAxis().setLabel('X')
+ self.getYAxis().setLabel('Y')
+ self.setGraphTitle('Image')
+
+ # Add toolbars and status bar
+ self.addToolBar(qt.Qt.BottomToolBarArea, LimitsToolBar(plot=self))
+
+ menu = self.menuBar().addMenu('File')
+ menu.addAction(self.getOutputToolBar().getSaveAction())
+ menu.addAction(self.getOutputToolBar().getPrintAction())
+ menu.addSeparator()
+ action = menu.addAction('Quit')
+ action.triggered[bool].connect(qt.QApplication.instance().quit)
+
+ menu = self.menuBar().addMenu('Edit')
+ menu.addAction(self.getOutputToolBar().getCopyAction())
+ menu.addSeparator()
+ menu.addAction(self.getResetZoomAction())
+ menu.addAction(self.getColormapAction())
+ menu.addAction(actions.control.KeepAspectRatioAction(self, self))
+ menu.addAction(actions.control.YAxisInvertedAction(self, self))
+ menu.addAction(self.getShowSideHistogramsAction())
+
+ self.__profileMenu = self.menuBar().addMenu('Profile')
+ self.__updateProfileMenu()
+
+ # Connect to ImageView's signal
+ self.valueChanged.connect(self._statusBarSlot)
+
+ def __updateProfileMenu(self):
+ """Update actions available in 'Profile' menu"""
+ profile = self.getProfileToolBar()
+ self.__profileMenu.clear()
+ self.__profileMenu.addAction(profile.hLineAction)
+ self.__profileMenu.addAction(profile.vLineAction)
+ self.__profileMenu.addAction(profile.crossAction)
+ self.__profileMenu.addAction(profile.lineAction)
+ self.__profileMenu.addAction(profile.clearAction)
+
+ def _formatValueToString(self, value):
+ try:
+ if isinstance(value, numpy.ndarray):
+ if len(value) == 4:
+ return "RGBA: %.3g, %.3g, %.3g, %.3g" % (value[0], value[1], value[2], value[3])
+ elif len(value) == 3:
+ return "RGB: %.3g, %.3g, %.3g" % (value[0], value[1], value[2])
+ else:
+ return "Value: %g" % value
+ except Exception:
+ _logger.error("Error while formatting pixel value", exc_info=True)
+ pass
+ return "Value: %s" % value
+
+ def _statusBarSlot(self, row, column, value):
+ """Update status bar with coordinates/value from plots."""
+ if numpy.isnan(row):
+ msg = 'Column: %d, Sum: %g' % (int(column), value)
+ elif numpy.isnan(column):
+ msg = 'Row: %d, Sum: %g' % (int(row), value)
+ else:
+ msg_value = self._formatValueToString(value)
+ msg = 'Position: (%d, %d), %s' % (int(row), int(column), msg_value)
+ if self._dataInfo is not None:
+ msg = self._dataInfo + ', ' + msg
+
+ self.statusBar().showMessage(msg)
+
+ @docstring(ImageView)
+ def setProfileWindowBehavior(self, behavior: str):
+ super().setProfileWindowBehavior(behavior)
+ self.__updateProfileMenu()
+
+ @docstring(ImageView)
+ def setImage(self, image, *args, **kwargs):
+ if hasattr(image, 'dtype') and hasattr(image, 'shape'):
+ assert image.ndim == 2 or (image.ndim == 3 and image.shape[2] in (3, 4))
+ height, width = image.shape[0:2]
+ dataInfo = 'Data: %dx%d (%s)' % (width, height, str(image.dtype))
+ else:
+ dataInfo = None
+
+ if self._dataInfo != dataInfo:
+ self._dataInfo = dataInfo
+ self.statusBar().showMessage(self._dataInfo)
+
+ # Set the new image in ImageView widget
+ super(ImageViewMainWindow, self).setImage(image, *args, **kwargs)
diff --git a/silx/gui/plot/Interaction.py b/src/silx/gui/plot/Interaction.py
index 6213889..6213889 100644
--- a/silx/gui/plot/Interaction.py
+++ b/src/silx/gui/plot/Interaction.py
diff --git a/src/silx/gui/plot/ItemsSelectionDialog.py b/src/silx/gui/plot/ItemsSelectionDialog.py
new file mode 100644
index 0000000..c0504b0
--- /dev/null
+++ b/src/silx/gui/plot/ItemsSelectionDialog.py
@@ -0,0 +1,286 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a dialog widget to select plot items.
+
+.. autoclass:: ItemsSelectionDialog
+
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/06/2017"
+
+import logging
+
+from silx.gui import qt
+from silx.gui.plot.PlotWidget import PlotWidget
+
+_logger = logging.getLogger(__name__)
+
+
+class KindsSelector(qt.QListWidget):
+ """List widget allowing to select plot item kinds
+ ("curve", "scatter", "image"...)
+ """
+ sigSelectedKindsChanged = qt.Signal(list)
+
+ def __init__(self, parent=None, kinds=None):
+ """
+
+ :param parent: Parent QWidget or None
+ :param tuple(str) kinds: Sequence of kinds. If None, the default
+ behavior is to provide a checkbox for all possible item kinds.
+ """
+ qt.QListWidget.__init__(self, parent)
+
+ self.plot_item_kinds = []
+
+ self.setAvailableKinds(kinds if kinds is not None else PlotWidget.ITEM_KINDS)
+
+ self.setSelectionMode(qt.QAbstractItemView.ExtendedSelection)
+ self.selectAll()
+
+ self.itemSelectionChanged.connect(self.emitSigKindsSelectionChanged)
+
+ def emitSigKindsSelectionChanged(self):
+ self.sigSelectedKindsChanged.emit(self.selectedKinds)
+
+ @property
+ def selectedKinds(self):
+ """Tuple of all selected kinds (as strings)."""
+ # check for updates when self.itemSelectionChanged
+ return [item.text() for item in self.selectedItems()]
+
+ def setAvailableKinds(self, kinds):
+ """Set a list of kinds to be displayed.
+
+ :param list[str] kinds: Sequence of kinds
+ """
+ self.plot_item_kinds = kinds
+
+ self.clear()
+ for kind in self.plot_item_kinds:
+ item = qt.QListWidgetItem(self)
+ item.setText(kind)
+ self.addItem(item)
+
+ def selectAll(self):
+ """Select all available kinds."""
+ if self.selectionMode() in [qt.QAbstractItemView.SingleSelection,
+ qt.QAbstractItemView.NoSelection]:
+ raise RuntimeError("selectAll requires a multiple selection mode")
+ for i in range(self.count()):
+ self.item(i).setSelected(True)
+
+
+class PlotItemsSelector(qt.QTableWidget):
+ """Table widget displaying the legend and kind of all
+ plot items corresponding to a list of specified kinds.
+
+ Selected plot items are provided as property :attr:`selectedPlotItems`.
+ You can be warned of selection changes by listening to signal
+ :attr:`itemSelectionChanged`.
+ """
+ def __init__(self, parent=None, plot=None):
+ if plot is None or not isinstance(plot, PlotWidget):
+ raise AttributeError("parameter plot is required")
+ self.plot = plot
+ """:class:`PlotWidget` instance"""
+
+ self.plot_item_kinds = None
+ """List of plot item kinds (strings)"""
+
+ qt.QTableWidget.__init__(self, parent)
+
+ self.setColumnCount(2)
+
+ self.setSelectionBehavior(qt.QTableWidget.SelectRows)
+
+ def _clear(self):
+ self.clear()
+ self.setHorizontalHeaderLabels(["legend", "type"])
+
+ def setAllKindsFilter(self):
+ """Display all kinds of plot items."""
+ self.setKindsFilter(PlotWidget.ITEM_KINDS)
+
+ def setKindsFilter(self, kinds):
+ """Set list of all kinds of plot items to be displayed.
+
+ :param list[str] kinds: Sequence of kinds
+ """
+ if not set(kinds) <= set(PlotWidget.ITEM_KINDS):
+ raise KeyError("Illegal plot item kinds: %s" %
+ set(kinds) - set(PlotWidget.ITEM_KINDS))
+ self.plot_item_kinds = kinds
+
+ self.updatePlotItems()
+
+ def updatePlotItems(self):
+ self._clear()
+
+ # respect order of kinds as set in method setKindsFilter
+ itemsAndKind = []
+ for kind in self.plot_item_kinds:
+ itemClasses = self.plot._KIND_TO_CLASSES[kind]
+ for item in self.plot.getItems():
+ if isinstance(item, itemClasses) and item.isVisible():
+ itemsAndKind.append((item, kind))
+
+ self.setRowCount(len(itemsAndKind))
+
+ for index, (item, kind) in enumerate(itemsAndKind):
+ legend_twitem = qt.QTableWidgetItem(item.getName())
+ self.setItem(index, 0, legend_twitem)
+
+ kind_twitem = qt.QTableWidgetItem(kind)
+ self.setItem(index, 1, kind_twitem)
+
+ @property
+ def selectedPlotItems(self):
+ """List of all selected items"""
+ selection_model = self.selectionModel()
+ selected_rows_idx = selection_model.selectedRows()
+ selected_rows = [idx.row() for idx in selected_rows_idx]
+
+ items = []
+ for row in selected_rows:
+ legend = self.item(row, 0).text()
+ kind = self.item(row, 1).text()
+ item = self.plot._getItem(kind, legend)
+ if item is not None:
+ items.append(item)
+
+ return items
+
+
+class ItemsSelectionDialog(qt.QDialog):
+ """This widget is a modal dialog allowing to select one or more plot
+ items, in a table displaying their legend and kind.
+
+ Public methods:
+
+ - :meth:`getSelectedItems`
+ - :meth:`setAvailableKinds`
+ - :meth:`setItemsSelectionMode`
+
+ This widget inherits QDialog and therefore implements the usual
+ dialog methods, e.g. :meth:`exec`.
+
+ A trivial usage example would be::
+
+ isd = ItemsSelectionDialog(plot=my_plot_widget)
+ isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection)
+ result = isd.exec()
+ if result:
+ for item in isd.getSelectedItems():
+ print(item.getName(), type(item))
+ else:
+ print("Selection cancelled")
+ """
+ def __init__(self, parent=None, plot=None):
+ if plot is None or not isinstance(plot, PlotWidget):
+ raise AttributeError("parameter plot is required")
+ qt.QDialog.__init__(self, parent)
+
+ self.setWindowTitle("Plot items selector")
+
+ kind_selector_label = qt.QLabel("Filter item kinds:", self)
+ item_selector_label = qt.QLabel("Select items:", self)
+
+ self.kind_selector = KindsSelector(self)
+ self.kind_selector.setToolTip(
+ "select one or more item kinds to show them in the item list")
+
+ self.item_selector = PlotItemsSelector(self, plot)
+ self.item_selector.setToolTip("select items")
+
+ self.item_selector.setKindsFilter(self.kind_selector.selectedKinds)
+ self.kind_selector.sigSelectedKindsChanged.connect(
+ self.item_selector.setKindsFilter
+ )
+
+ okb = qt.QPushButton("OK", self)
+ okb.clicked.connect(self.accept)
+
+ cancelb = qt.QPushButton("Cancel", self)
+ cancelb.clicked.connect(self.reject)
+
+ layout = qt.QGridLayout(self)
+ layout.addWidget(kind_selector_label, 0, 0)
+ layout.addWidget(item_selector_label, 0, 1)
+ layout.addWidget(self.kind_selector, 1, 0)
+ layout.addWidget(self.item_selector, 1, 1)
+ layout.addWidget(okb, 2, 0)
+ layout.addWidget(cancelb, 2, 1)
+
+ self.setLayout(layout)
+
+ def getSelectedItems(self):
+ """Return a list of selected plot items
+
+ :return: List of selected plot items
+ :rtype: list[silx.gui.plot.items.Item]"""
+ return self.item_selector.selectedPlotItems
+
+ def setAvailableKinds(self, kinds):
+ """Set a list of kinds to be displayed.
+
+ :param list[str] kinds: Sequence of kinds
+ """
+ self.kind_selector.setAvailableKinds(kinds)
+
+ def selectAllKinds(self):
+ self.kind_selector.selectAll()
+
+ def setItemsSelectionMode(self, mode):
+ """Set selection mode for plot item (single item selection,
+ multiple...).
+
+ :param mode: One of :class:`QTableWidget` selection modes
+ """
+ if mode == self.item_selector.SingleSelection:
+ self.item_selector.setToolTip(
+ "Select one item by clicking on it.")
+ elif mode == self.item_selector.MultiSelection:
+ self.item_selector.setToolTip(
+ "Select one or more items by clicking with the left mouse"
+ " button.\nYou can unselect items by clicking them again.\n"
+ "Multiple items can be toggled by dragging the mouse over them.")
+ elif mode == self.item_selector.ExtendedSelection:
+ self.item_selector.setToolTip(
+ "Select one or more items. You can select multiple items "
+ "by keeping the Ctrl key pushed when clicking.\nYou can "
+ "select a range of items by clicking on the first and "
+ "last while keeping the Shift key pushed.")
+ elif mode == self.item_selector.ContiguousSelection:
+ self.item_selector.setToolTip(
+ "Select one item by clicking on it. If you press the Shift"
+ " key while clicking on a second item,\nall items between "
+ "the two will be selected.")
+ elif mode == self.item_selector.NoSelection:
+ raise ValueError("The NoSelection mode is not allowed "
+ "in this context.")
+ self.item_selector.setSelectionMode(mode)
diff --git a/src/silx/gui/plot/LegendSelector.py b/src/silx/gui/plot/LegendSelector.py
new file mode 100755
index 0000000..d439387
--- /dev/null
+++ b/src/silx/gui/plot/LegendSelector.py
@@ -0,0 +1,1039 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget displaying curves legends and allowing to operate on curves.
+
+This widget is meant to work with :class:`PlotWindow`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Rueter", "T. Vincent"]
+__license__ = "MIT"
+__data__ = "16/10/2017"
+
+
+import logging
+import weakref
+
+import numpy
+
+from .. import qt, colors
+from ..widgets.LegendIconWidget import LegendIconWidget
+from . import items
+
+
+_logger = logging.getLogger(__name__)
+
+
+class LegendIcon(LegendIconWidget):
+ """Object displaying a curve linestyle and symbol.
+
+ :param QWidget parent: See :class:`QWidget`
+ :param Union[~silx.gui.plot.items.Curve,None] curve:
+ Curve with which to synchronize
+ """
+
+ def __init__(self, parent=None, curve=None):
+ super(LegendIcon, self).__init__(parent)
+ self._curveRef = None
+ self.setCurve(curve)
+
+ def getCurve(self):
+ """Returns curve associated to this widget
+
+ :rtype: Union[~silx.gui.plot.items.Curve,None]
+ """
+ return None if self._curveRef is None else self._curveRef()
+
+ def setCurve(self, curve):
+ """Set the curve with which to synchronize this widget.
+
+ :param curve: Union[~silx.gui.plot.items.Curve,None]
+ """
+ assert curve is None or isinstance(curve, items.Curve)
+
+ previousCurve = self.getCurve()
+ if curve == previousCurve:
+ return
+
+ if previousCurve is not None:
+ previousCurve.sigItemChanged.disconnect(self._curveChanged)
+
+ self._curveRef = None if curve is None else weakref.ref(curve)
+
+ if curve is not None:
+ curve.sigItemChanged.connect(self._curveChanged)
+
+ self._update()
+
+ def _update(self):
+ """Update widget according to current curve state.
+ """
+ curve = self.getCurve()
+ if curve is None:
+ _logger.error('Curve no more exists')
+ self.setEnabled(False)
+ return
+
+ style = curve.getCurrentStyle()
+
+ self.setEnabled(curve.isVisible())
+ self.setSymbol(style.getSymbol())
+ self.setLineWidth(style.getLineWidth())
+ self.setLineStyle(style.getLineStyle())
+
+ color = style.getColor()
+ if numpy.array(color, copy=False).ndim != 1:
+ # array of colors, use transparent black
+ color = 0., 0., 0., 0.
+ color = colors.rgba(color) # Make sure it is float in [0, 1]
+ alpha = curve.getAlpha()
+ color = qt.QColor.fromRgbF(
+ color[0], color[1], color[2], color[3] * alpha)
+ self.setLineColor(color)
+ self.setSymbolColor(color)
+ self.update() # TODO this should not be needed
+
+ def _curveChanged(self, event):
+ """Handle update of curve item
+
+ :param event: Kind of change
+ """
+ if event in (items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.SYMBOL,
+ items.ItemChangedType.SYMBOL_SIZE,
+ items.ItemChangedType.LINE_WIDTH,
+ items.ItemChangedType.LINE_STYLE,
+ items.ItemChangedType.COLOR,
+ items.ItemChangedType.ALPHA,
+ items.ItemChangedType.HIGHLIGHTED,
+ items.ItemChangedType.HIGHLIGHTED_STYLE):
+ self._update()
+
+
+class LegendModel(qt.QAbstractListModel):
+ """Data model of curve legends.
+
+ It holds the information of the curve:
+
+ - color
+ - line width
+ - line style
+ - visibility of the lines
+ - symbol
+ - visibility of the symbols
+ """
+ iconColorRole = qt.Qt.UserRole + 0
+ iconLineWidthRole = qt.Qt.UserRole + 1
+ iconLineStyleRole = qt.Qt.UserRole + 2
+ showLineRole = qt.Qt.UserRole + 3
+ iconSymbolRole = qt.Qt.UserRole + 4
+ showSymbolRole = qt.Qt.UserRole + 5
+
+ def __init__(self, legendList=None, parent=None):
+ super(LegendModel, self).__init__(parent)
+ if legendList is None:
+ legendList = []
+ self.legendList = []
+ self.insertLegendList(0, legendList)
+ self._palette = qt.QPalette()
+
+ def __getitem__(self, idx):
+ if idx >= len(self.legendList):
+ raise IndexError('list index out of range')
+ return self.legendList[idx]
+
+ def rowCount(self, modelIndex=None):
+ return len(self.legendList)
+
+ def flags(self, index):
+ return (qt.Qt.ItemIsEditable |
+ qt.Qt.ItemIsEnabled |
+ qt.Qt.ItemIsSelectable)
+
+ def data(self, modelIndex, role):
+ if modelIndex.isValid:
+ idx = modelIndex.row()
+ else:
+ return None
+ if idx >= len(self.legendList):
+ raise IndexError('list index out of range')
+
+ item = self.legendList[idx]
+ isActive = item[1].get("active", False)
+ if role == qt.Qt.DisplayRole:
+ # Data to be rendered in the form of text
+ legend = str(item[0])
+ return legend
+ elif role == qt.Qt.SizeHintRole:
+ # size = qt.QSize(200,50)
+ _logger.warning('LegendModel -- size hint role not implemented')
+ return qt.QSize()
+ elif role == qt.Qt.TextAlignmentRole:
+ alignment = qt.Qt.AlignVCenter | qt.Qt.AlignLeft
+ return alignment
+ elif role == qt.Qt.BackgroundRole:
+ # Background color, must be QBrush
+ if isActive:
+ brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.Highlight)
+ elif idx % 2:
+ brush = qt.QBrush(qt.QColor(240, 240, 240))
+ else:
+ brush = qt.QBrush(qt.Qt.white)
+ return brush
+ elif role == qt.Qt.ForegroundRole:
+ # ForegroundRole color, must be QBrush
+ if isActive:
+ brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.HighlightedText)
+ else:
+ brush = self._palette.brush(qt.QPalette.Normal, qt.QPalette.WindowText)
+ return brush
+ elif role == qt.Qt.CheckStateRole:
+ return bool(item[2]) # item[2] == True
+ elif role == qt.Qt.ToolTipRole or role == qt.Qt.StatusTipRole:
+ return ''
+ elif role == self.iconColorRole:
+ return item[1]['color']
+ elif role == self.iconLineWidthRole:
+ return item[1]['linewidth']
+ elif role == self.iconLineStyleRole:
+ return item[1]['linestyle']
+ elif role == self.iconSymbolRole:
+ return item[1]['symbol']
+ elif role == self.showLineRole:
+ return item[3]
+ elif role == self.showSymbolRole:
+ return item[4]
+ else:
+ _logger.info('Unkown role requested: %s', str(role))
+ return None
+
+ def setData(self, modelIndex, value, role):
+ if modelIndex.isValid:
+ idx = modelIndex.row()
+ else:
+ return None
+ if idx >= len(self.legendList):
+ # raise IndexError('list index out of range')
+ _logger.warning(
+ 'setData -- List index out of range, idx: %d', idx)
+ return None
+
+ item = self.legendList[idx]
+ try:
+ if role == qt.Qt.DisplayRole:
+ # Set legend
+ item[0] = str(value)
+ elif role == self.iconColorRole:
+ item[1]['color'] = qt.QColor(value)
+ elif role == self.iconLineWidthRole:
+ item[1]['linewidth'] = int(value)
+ elif role == self.iconLineStyleRole:
+ item[1]['linestyle'] = str(value)
+ elif role == self.iconSymbolRole:
+ item[1]['symbol'] = str(value)
+ elif role == qt.Qt.CheckStateRole:
+ item[2] = value
+ elif role == self.showLineRole:
+ item[3] = value
+ elif role == self.showSymbolRole:
+ item[4] = value
+ except ValueError:
+ _logger.warning('Conversion failed:\n\tvalue: %s\n\trole: %s',
+ str(value), str(role))
+ # Can that be right? Read docs again..
+ self.dataChanged.emit(modelIndex, modelIndex)
+ return True
+
+ def insertLegendList(self, row, llist):
+ """
+ :param int row: Determines after which row the items are inserted
+ :param llist: Carries the new legend information
+ :type llist: List
+ """
+ modelIndex = self.createIndex(row, 0)
+ count = len(llist)
+ super(LegendModel, self).beginInsertRows(modelIndex,
+ row,
+ row + count)
+ head = self.legendList[0:row]
+ tail = self.legendList[row:]
+ new = []
+ for (legend, icon) in llist:
+ linestyle = icon.get('linestyle', None)
+ if LegendIconWidget.isEmptyLineStyle(linestyle):
+ # Curve had no line, give it one and hide it
+ # So when toggle line, it will display a solid line
+ showLine = False
+ icon['linestyle'] = '-'
+ else:
+ showLine = True
+
+ symbol = icon.get('symbol', None)
+ if LegendIconWidget.isEmptySymbol(symbol):
+ # Curve had no symbol, give it one and hide it
+ # So when toggle symbol, it will display 'o'
+ showSymbol = False
+ icon['symbol'] = 'o'
+ else:
+ showSymbol = True
+
+ selected = icon.get('selected', True)
+ item = [legend,
+ icon,
+ selected,
+ showLine,
+ showSymbol]
+ new.append(item)
+ self.legendList = head + new + tail
+ super(LegendModel, self).endInsertRows()
+ return True
+
+ def insertRows(self, row, count, modelIndex=qt.QModelIndex()):
+ raise NotImplementedError('Use LegendModel.insertLegendList instead')
+
+ def removeRow(self, row):
+ return self.removeRows(row, 1)
+
+ def removeRows(self, row, count, modelIndex=qt.QModelIndex()):
+ length = len(self.legendList)
+ if length == 0:
+ # Nothing to do..
+ return True
+ if row < 0 or row >= length:
+ raise IndexError('Index out of range -- ' +
+ 'idx: %d, len: %d' % (row, length))
+ if count == 0:
+ return False
+ super(LegendModel, self).beginRemoveRows(modelIndex,
+ row,
+ row + count)
+ del(self.legendList[row:row + count])
+ super(LegendModel, self).endRemoveRows()
+ return True
+
+ def setEditor(self, event, editor):
+ """
+ :param str event: String that identifies the editor
+ :param editor: Widget used to change data in the underlying model
+ :type editor: QWidget
+ """
+ if event not in self.eventList:
+ raise ValueError('setEditor -- Event must be in %s' %
+ str(self.eventList))
+ self.editorDict[event] = editor
+
+
+class LegendListItemWidget(qt.QItemDelegate):
+ """Object displaying a single item (i.e., a row) in the list."""
+
+ # Notice: LegendListItem does NOT inherit
+ # from QObject, it cannot emit signals!
+
+ def __init__(self, parent=None, itemType=0):
+ super(LegendListItemWidget, self).__init__(parent)
+
+ # Dictionary to render checkboxes
+ self.cbDict = {}
+ self.labelDict = {}
+ self.iconDict = {}
+
+ # Keep checkbox and legend to get sizeHint
+ self.checkbox = qt.QCheckBox()
+ self.legend = qt.QLabel()
+ self.icon = LegendIcon()
+
+ # Context Menu and Editors
+ self.contextMenu = None
+
+ def paint(self, painter, option, modelIndex):
+ """
+ Here be docs..
+
+ :param QPainter painter:
+ :param QStyleOptionViewItem option:
+ :param QModelIndex modelIndex:
+ """
+ painter.save()
+ rect = option.rect
+
+ # Calculate the icon rectangle
+ iconSize = self.icon.sizeHint()
+ # Calculate icon position
+ x = rect.left() + 2
+ y = rect.top() + int(.5 * (rect.height() - iconSize.height()))
+ iconRect = qt.QRect(qt.QPoint(x, y), iconSize)
+
+ # Calculate label rectangle
+ legendSize = qt.QSize(rect.width() - iconSize.width() - 30,
+ rect.height())
+ # Calculate label position
+ x = rect.left() + iconRect.width()
+ y = rect.top()
+ labelRect = qt.QRect(qt.QPoint(x, y), legendSize)
+ labelRect.translate(qt.QPoint(10, 0))
+
+ # Calculate the checkbox rectangle
+ x = rect.right() - 30
+ y = rect.top()
+ chBoxRect = qt.QRect(qt.QPoint(x, y), rect.bottomRight())
+
+ # Remember the rectangles
+ idx = modelIndex.row()
+ self.cbDict[idx] = chBoxRect
+ self.iconDict[idx] = iconRect
+ self.labelDict[idx] = labelRect
+
+ # Draw background first!
+ if option.state & qt.QStyle.State_MouseOver:
+ backgroundBrush = option.palette.highlight()
+ else:
+ backgroundBrush = modelIndex.data(qt.Qt.BackgroundRole)
+ painter.fillRect(rect, backgroundBrush)
+
+ # Draw label
+ legendText = modelIndex.data(qt.Qt.DisplayRole)
+ textBrush = modelIndex.data(qt.Qt.ForegroundRole)
+ textAlign = modelIndex.data(qt.Qt.TextAlignmentRole)
+ painter.setBrush(textBrush)
+ painter.setFont(self.legend.font())
+ painter.setPen(textBrush.color())
+ painter.drawText(labelRect, textAlign, legendText)
+
+ # Draw icon
+ iconColor = modelIndex.data(LegendModel.iconColorRole)
+ iconLineWidth = modelIndex.data(LegendModel.iconLineWidthRole)
+ iconLineStyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ iconSymbol = modelIndex.data(LegendModel.iconSymbolRole)
+ icon = LegendIcon()
+ icon.resize(iconRect.size())
+ icon.move(iconRect.topRight())
+ icon.showSymbol = modelIndex.data(LegendModel.showSymbolRole)
+ icon.showLine = modelIndex.data(LegendModel.showLineRole)
+ icon.setSymbolColor(iconColor)
+ icon.setLineColor(iconColor)
+ icon.setLineWidth(iconLineWidth)
+ icon.setLineStyle(iconLineStyle)
+ icon.setSymbol(iconSymbol)
+ icon.symbolOutlineBrush = backgroundBrush
+ icon.paint(painter, iconRect, option.palette)
+
+ # Draw the checkbox
+ if modelIndex.data(qt.Qt.CheckStateRole):
+ checkState = qt.Qt.Checked
+ else:
+ checkState = qt.Qt.Unchecked
+
+ self.drawCheck(
+ painter, qt.QStyleOptionViewItem(), chBoxRect, checkState)
+
+ painter.restore()
+
+ def editorEvent(self, event, model, option, modelIndex):
+ # From the docs:
+ # Mouse events are sent to editorEvent()
+ # even if they don't start editing of the item.
+ if event.button() == qt.Qt.RightButton and self.contextMenu:
+ self.contextMenu.exec(event.globalPos(), modelIndex)
+ return True
+ elif event.button() == qt.Qt.LeftButton:
+ # Check if checkbox was clicked
+ idx = modelIndex.row()
+ cbRect = self.cbDict[idx]
+ if cbRect.contains(event.pos()):
+ # Toggle checkbox
+ model.setData(modelIndex,
+ not modelIndex.data(qt.Qt.CheckStateRole),
+ qt.Qt.CheckStateRole)
+ event.ignore()
+ return True
+ else:
+ return super(LegendListItemWidget, self).editorEvent(
+ event, model, option, modelIndex)
+
+ def createEditor(self, parent, option, idx):
+ _logger.info('### Editor request ###')
+
+ def sizeHint(self, option, idx):
+ # return qt.QSize(68,24)
+ iconSize = self.icon.sizeHint()
+ legendSize = self.legend.sizeHint()
+ checkboxSize = self.checkbox.sizeHint()
+ height = max([iconSize.height(),
+ legendSize.height(),
+ checkboxSize.height()]) + 4
+ width = iconSize.width() + legendSize.width() + checkboxSize.width()
+ return qt.QSize(width, height)
+
+
+class LegendListView(qt.QListView):
+ """Widget displaying a list of curve legends, line style and symbol."""
+
+ sigLegendSignal = qt.Signal(object)
+ """Signal emitting a dict when an action is triggered by the user."""
+
+ __mouseClickedEvent = 'mouseClicked'
+ __checkBoxClickedEvent = 'checkBoxClicked'
+ __legendClickedEvent = 'legendClicked'
+
+ def __init__(self, parent=None, model=None, contextMenu=None):
+ super(LegendListView, self).__init__(parent)
+ self.__lastButton = None
+ self.__lastClickPos = None
+ self.__lastModelIdx = None
+ # Set default delegate
+ self.setItemDelegate(LegendListItemWidget())
+ # Set default editors
+ # self.setSizePolicy(qt.QSizePolicy.MinimumExpanding,
+ # qt.QSizePolicy.MinimumExpanding)
+ # Set edit triggers by hand using self.edit(QModelIndex)
+ # in mousePressEvent (better to control than signals)
+ self.setEditTriggers(qt.QAbstractItemView.NoEditTriggers)
+
+ # Control layout
+ # self.setBatchSize(2)
+ # self.setLayoutMode(qt.QListView.Batched)
+ # self.setFlow(qt.QListView.LeftToRight)
+
+ # Control selection
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+ if model is None:
+ model = LegendModel(parent=self)
+ self.setModel(model)
+ self.setContextMenu(contextMenu)
+
+ def setLegendList(self, legendList, row=None):
+ if row is not None:
+ model = self.model()
+ model.insertLegendList(row, legendList)
+ elif len(legendList) != self.model().rowCount():
+ self.clear()
+ model = self.model()
+ model.insertLegendList(0, legendList)
+ else:
+ model = self.model()
+ for i, (new_legend, icon) in enumerate(legendList):
+ modelIndex = model.index(i)
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ if new_legend != legend:
+ model.setData(modelIndex, new_legend, qt.Qt.DisplayRole)
+
+ color = modelIndex.data(LegendModel.iconColorRole)
+ new_color = icon.get('color', None)
+ if new_color != color:
+ model.setData(modelIndex, new_color, LegendModel.iconColorRole)
+
+ linewidth = modelIndex.data(LegendModel.iconLineWidthRole)
+ new_linewidth = icon.get('linewidth', 1.0)
+ if new_linewidth != linewidth:
+ model.setData(modelIndex, new_linewidth, LegendModel.iconLineWidthRole)
+
+ linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ new_linestyle = icon.get('linestyle', None)
+ visible = not LegendIconWidget.isEmptyLineStyle(new_linestyle)
+ model.setData(modelIndex, visible, LegendModel.showLineRole)
+ if new_linestyle != linestyle:
+ model.setData(modelIndex, new_linestyle, LegendModel.iconLineStyleRole)
+
+ symbol = modelIndex.data(LegendModel.iconSymbolRole)
+ new_symbol = icon.get('symbol', None)
+ visible = not LegendIconWidget.isEmptySymbol(new_symbol)
+ model.setData(modelIndex, visible, LegendModel.showSymbolRole)
+ if new_symbol != symbol:
+ model.setData(modelIndex, new_symbol, LegendModel.iconSymbolRole)
+
+ selected = modelIndex.data(qt.Qt.CheckStateRole)
+ new_selected = icon.get('selected', True)
+ if new_selected != selected:
+ model.setData(modelIndex, new_selected, qt.Qt.CheckStateRole)
+ _logger.debug('LegendListView.setLegendList(legendList) finished')
+
+ def clear(self):
+ model = self.model()
+ model.removeRows(0, model.rowCount())
+ _logger.debug('LegendListView.clear() finished')
+
+ def setContextMenu(self, contextMenu=None):
+ delegate = self.itemDelegate()
+ if isinstance(delegate, LegendListItemWidget) and self.model():
+ if contextMenu is None:
+ delegate.contextMenu = LegendListContextMenu(self.model())
+ delegate.contextMenu.sigContextMenu.connect(
+ self._contextMenuSlot)
+ else:
+ delegate.contextMenu = contextMenu
+
+ def __getitem__(self, idx):
+ model = self.model()
+ try:
+ item = model[idx]
+ except ValueError:
+ item = None
+ return item
+
+ def _contextMenuSlot(self, ddict):
+ self.sigLegendSignal.emit(ddict)
+
+ def mousePressEvent(self, event):
+ self.__lastButton = event.button()
+ self.__lastPosition = event.pos()
+ super(LegendListView, self).mousePressEvent(event)
+ # call _handleMouseClick after editing was handled
+ # If right click (context menu) is aborted, no
+ # signal is emitted..
+ self._handleMouseClick(self.indexAt(self.__lastPosition))
+
+ def mouseDoubleClickEvent(self, event):
+ self.__lastButton = event.button()
+ self.__lastPosition = event.pos()
+ super(LegendListView, self).mouseDoubleClickEvent(event)
+ # call _handleMouseClick after editing was handled
+ # If right click (context menu) is aborted, no
+ # signal is emitted..
+ self._handleMouseClick(self.indexAt(self.__lastPosition))
+
+ def mouseMoveEvent(self, event):
+ # LegendListView.mouseMoveEvent is overwritten
+ # to suppress unwanted behavior in the delegate.
+ pass
+
+ def mouseReleaseEvent(self, event):
+ # LegendListView.mouseReleaseEvent is overwritten
+ # to subpress unwanted behavior in the delegate.
+ pass
+
+ def _handleMouseClick(self, modelIndex):
+ """
+ Distinguish between mouse click on Legend
+ and mouse click on CheckBox by setting the
+ currentCheckState attribute in LegendListItem.
+
+ Emits signal sigLegendSignal(ddict)
+
+ :param QModelIndex modelIndex: index of the clicked item
+ """
+ _logger.debug('self._handleMouseClick called')
+ if self.__lastButton not in [qt.Qt.LeftButton,
+ qt.Qt.RightButton]:
+ return
+ if not modelIndex.isValid():
+ _logger.debug('_handleMouseClick -- Invalid QModelIndex')
+ return
+ # model = self.model()
+ idx = modelIndex.row()
+
+ delegate = self.itemDelegate()
+ cbClicked = False
+ if isinstance(delegate, LegendListItemWidget):
+ for cbRect in delegate.cbDict.values():
+ if cbRect.contains(self.__lastPosition):
+ cbClicked = True
+ break
+
+ # TODO: Check for doubleclicks on legend/icon and spawn editors
+
+ ddict = {
+ 'legend': str(modelIndex.data(qt.Qt.DisplayRole)),
+ 'icon': {
+ 'linewidth': str(modelIndex.data(
+ LegendModel.iconLineWidthRole)),
+ 'linestyle': str(modelIndex.data(
+ LegendModel.iconLineStyleRole)),
+ 'symbol': str(modelIndex.data(LegendModel.iconSymbolRole))
+ },
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data())
+ }
+ if self.__lastButton == qt.Qt.RightButton:
+ _logger.debug('Right clicked')
+ ddict['button'] = "right"
+ ddict['event'] = self.__mouseClickedEvent
+ elif cbClicked:
+ _logger.debug('CheckBox clicked')
+ ddict['button'] = "left"
+ ddict['event'] = self.__checkBoxClickedEvent
+ else:
+ _logger.debug('Legend clicked')
+ ddict['button'] = "left"
+ ddict['event'] = self.__legendClickedEvent
+ _logger.debug(' idx: %d\n ddict: %s', idx, str(ddict))
+ self.sigLegendSignal.emit(ddict)
+
+
+class LegendListContextMenu(qt.QMenu):
+ """Contextual menu associated to items in a :class:`LegendListView`."""
+
+ sigContextMenu = qt.Signal(object)
+ """Signal emitting a dict upon contextual menu actions."""
+
+ def __init__(self, model):
+ super(LegendListContextMenu, self).__init__(parent=None)
+ self.model = model
+
+ self.addAction('Set Active', self.setActiveAction)
+ self.addAction('Map to left', self.mapToLeftAction)
+ self.addAction('Map to right', self.mapToRightAction)
+
+ self._pointsAction = self.addAction(
+ 'Points', self.togglePointsAction)
+ self._pointsAction.setCheckable(True)
+
+ self._linesAction = self.addAction('Lines', self.toggleLinesAction)
+ self._linesAction.setCheckable(True)
+
+ self.addAction('Remove curve', self.removeItemAction)
+ self.addAction('Rename curve', self.renameItemAction)
+
+ def exec(self, pos, idx):
+ self.__currentIdx = idx
+
+ # Set checkable action state
+ modelIndex = self.currentIdx()
+ self._pointsAction.setChecked(
+ modelIndex.data(LegendModel.showSymbolRole))
+ self._linesAction.setChecked(
+ modelIndex.data(LegendModel.showLineRole))
+
+ super(LegendListContextMenu, self).popup(pos)
+
+ def exec_(self, pos, idx): # Qt5-like compatibility
+ return self.exec(pos, idx)
+
+ def currentIdx(self):
+ return self.__currentIdx
+
+ def mapToLeftAction(self):
+ _logger.debug('LegendListContextMenu.mapToLeftAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "mapToLeft"
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def mapToRightAction(self):
+ _logger.debug('LegendListContextMenu.mapToRightAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "mapToRight"
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def removeItemAction(self):
+ _logger.debug('LegendListContextMenu.removeCurveAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "removeCurve"
+ }
+ self.model.removeRow(modelIndex.row())
+ self.sigContextMenu.emit(ddict)
+
+ def renameItemAction(self):
+ _logger.debug('LegendListContextMenu.renameCurveAction called')
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "renameCurve"
+ }
+ self.sigContextMenu.emit(ddict)
+
+ def toggleLinesAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ }
+ linestyle = modelIndex.data(LegendModel.iconLineStyleRole)
+ visible = not modelIndex.data(LegendModel.showLineRole)
+ _logger.debug('toggleLinesAction -- lines visible: %s', str(visible))
+ ddict['event'] = "toggleLine"
+ ddict['line'] = visible
+ ddict['linestyle'] = linestyle if visible else ''
+ self.model.setData(modelIndex, visible, LegendModel.showLineRole)
+ self.sigContextMenu.emit(ddict)
+
+ def togglePointsAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ }
+ flag = modelIndex.data(LegendModel.showSymbolRole)
+ symbol = modelIndex.data(LegendModel.iconSymbolRole)
+ visible = not flag or LegendIconWidget.isEmptySymbol(symbol)
+ _logger.debug(
+ 'togglePointsAction -- Symbols visible: %s', str(visible))
+
+ ddict['event'] = "togglePoints"
+ ddict['points'] = visible
+ ddict['symbol'] = symbol if visible else ''
+ self.model.setData(modelIndex, visible, LegendModel.showSymbolRole)
+ self.sigContextMenu.emit(ddict)
+
+ def setActiveAction(self):
+ modelIndex = self.currentIdx()
+ legend = str(modelIndex.data(qt.Qt.DisplayRole))
+ _logger.debug('setActiveAction -- active curve: %s', legend)
+ ddict = {
+ 'legend': legend,
+ 'label': legend,
+ 'selected': modelIndex.data(qt.Qt.CheckStateRole),
+ 'type': str(modelIndex.data()),
+ 'event': "setActiveCurve",
+ }
+ self.sigContextMenu.emit(ddict)
+
+
+class RenameCurveDialog(qt.QDialog):
+ """Dialog box to input the name of a curve."""
+
+ def __init__(self, parent=None, current="", curves=()):
+ super(RenameCurveDialog, self).__init__(parent)
+ self.setWindowTitle("Rename Curve %s" % current)
+ self.curves = curves
+ layout = qt.QVBoxLayout(self)
+ self.lineEdit = qt.QLineEdit(self)
+ self.lineEdit.setText(current)
+ self.hbox = qt.QWidget(self)
+ self.hboxLayout = qt.QHBoxLayout(self.hbox)
+ self.hboxLayout.addStretch(1)
+ self.okButton = qt.QPushButton(self.hbox)
+ self.okButton.setText('OK')
+ self.hboxLayout.addWidget(self.okButton)
+ self.cancelButton = qt.QPushButton(self.hbox)
+ self.cancelButton.setText('Cancel')
+ self.hboxLayout.addWidget(self.cancelButton)
+ self.hboxLayout.addStretch(1)
+ layout.addWidget(self.lineEdit)
+ layout.addWidget(self.hbox)
+ self.okButton.clicked.connect(self.preAccept)
+ self.cancelButton.clicked.connect(self.reject)
+
+ def preAccept(self):
+ text = str(self.lineEdit.text())
+ addedText = ""
+ if len(text):
+ if text not in self.curves:
+ self.accept()
+ return
+ else:
+ addedText = "Curve already exists."
+ text = "Invalid Curve Name"
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setWindowTitle(text)
+ text += "\n%s" % addedText
+ msg.setText(text)
+ msg.exec()
+
+ def getText(self):
+ return str(self.lineEdit.text())
+
+
+class LegendsDockWidget(qt.QDockWidget):
+ """QDockWidget with a :class:`LegendSelector` connected to a PlotWindow.
+
+ It makes the link between the LegendListView widget and the PlotWindow.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: :class:`.PlotWindow` instance on which to operate
+ """
+
+ def __init__(self, parent=None, plot=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._isConnected = False # True if widget connected to plot signals
+
+ super(LegendsDockWidget, self).__init__("Legends", parent)
+
+ self._legendWidget = LegendListView()
+
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self.setWidget(self._legendWidget)
+
+ self.visibilityChanged.connect(
+ self._visibilityChangedHandler)
+
+ self._legendWidget.sigLegendSignal.connect(self._legendSignalHandler)
+
+ @property
+ def plot(self):
+ """The :class:`.PlotWindow` this widget is attached to."""
+ return self._plotRef()
+
+ def renameCurve(self, oldLegend, newLegend):
+ """Change the name of a curve using remove and addCurve
+
+ :param str oldLegend: The legend of the curve to be changed
+ :param str newLegend: The new legend of the curve
+ """
+ is_active = self.plot.getActiveCurve(just_legend=True) == oldLegend
+ curve = self.plot.getCurve(oldLegend)
+ self.plot.remove(oldLegend, kind='curve')
+ self.plot.addCurve(curve.getXData(copy=False),
+ curve.getYData(copy=False),
+ legend=newLegend,
+ info=curve.getInfo(),
+ color=curve.getColor(),
+ symbol=curve.getSymbol(),
+ linewidth=curve.getLineWidth(),
+ linestyle=curve.getLineStyle(),
+ xlabel=curve.getXLabel(),
+ ylabel=curve.getYLabel(),
+ xerror=curve.getXErrorData(copy=False),
+ yerror=curve.getYErrorData(copy=False),
+ z=curve.getZValue(),
+ selectable=curve.isSelectable(),
+ fill=curve.isFill(),
+ resetzoom=False)
+ if is_active:
+ self.plot.setActiveCurve(newLegend)
+
+ def _legendSignalHandler(self, ddict):
+ """Handles events from the LegendListView signal"""
+ _logger.debug("Legend signal ddict = %s", str(ddict))
+
+ if ddict['event'] == "legendClicked":
+ if ddict['button'] == "left":
+ self.plot.setActiveCurve(ddict['legend'])
+
+ elif ddict['event'] == "removeCurve":
+ self.plot.removeCurve(ddict['legend'])
+
+ elif ddict['event'] == "renameCurve":
+ curveList = self.plot.getAllCurves(just_legend=True)
+ oldLegend = ddict['legend']
+ dialog = RenameCurveDialog(self.plot, oldLegend, curveList)
+ ret = dialog.exec()
+ if ret:
+ newLegend = dialog.getText()
+ self.renameCurve(oldLegend, newLegend)
+
+ elif ddict['event'] == "setActiveCurve":
+ self.plot.setActiveCurve(ddict['legend'])
+
+ elif ddict['event'] == "checkBoxClicked":
+ self.plot.hideCurve(ddict['legend'], not ddict['selected'])
+
+ elif ddict['event'] in ["mapToRight", "mapToLeft"]:
+ legend = ddict['legend']
+ curve = self.plot.getCurve(legend)
+ yaxis = 'right' if ddict['event'] == 'mapToRight' else 'left'
+ self.plot.addCurve(x=curve.getXData(copy=False),
+ y=curve.getYData(copy=False),
+ legend=curve.getName(),
+ info=curve.getInfo(),
+ yaxis=yaxis)
+
+ elif ddict['event'] == "togglePoints":
+ legend = ddict['legend']
+ curve = self.plot.getCurve(legend)
+ symbol = ddict['symbol'] if ddict['points'] else ''
+ self.plot.addCurve(x=curve.getXData(copy=False),
+ y=curve.getYData(copy=False),
+ legend=curve.getName(),
+ info=curve.getInfo(),
+ symbol=symbol)
+
+ elif ddict['event'] == "toggleLine":
+ legend = ddict['legend']
+ curve = self.plot.getCurve(legend)
+ linestyle = ddict['linestyle'] if ddict['line'] else ''
+ self.plot.addCurve(x=curve.getXData(copy=False),
+ y=curve.getYData(copy=False),
+ legend=curve.getName(),
+ info=curve.getInfo(),
+ linestyle=linestyle)
+
+ else:
+ _logger.debug("unhandled event %s", str(ddict['event']))
+
+ def updateLegends(self, *args):
+ """Sync the LegendSelector widget displayed info with the plot.
+ """
+ legendList = []
+ for curve in self.plot.getAllCurves(withhidden=True):
+ legend = curve.getName()
+ # Use active color if curve is active
+ isActive = legend == self.plot.getActiveCurve(just_legend=True)
+ style = curve.getCurrentStyle()
+ color = style.getColor()
+ if numpy.array(color, copy=False).ndim != 1:
+ # array of colors, use transparent black
+ color = 0., 0., 0., 0.
+
+ curveInfo = {
+ 'color': qt.QColor.fromRgbF(*color),
+ 'linewidth': style.getLineWidth(),
+ 'linestyle': style.getLineStyle(),
+ 'symbol': style.getSymbol(),
+ 'selected': not self.plot.isCurveHidden(legend),
+ 'active': isActive}
+ legendList.append((legend, curveInfo))
+
+ self._legendWidget.setLegendList(legendList)
+
+ def _visibilityChangedHandler(self, visible):
+ if visible:
+ self.updateLegends()
+ if not self._isConnected:
+ self.plot.sigContentChanged.connect(self.updateLegends)
+ self.plot.sigActiveCurveChanged.connect(self.updateLegends)
+ self._isConnected = True
+ else:
+ if self._isConnected:
+ self.plot.sigContentChanged.disconnect(self.updateLegends)
+ self.plot.sigActiveCurveChanged.disconnect(self.updateLegends)
+ self._isConnected = False
+
+ def showEvent(self, event):
+ """Make sure this widget is raised when it is shown
+ (when it is first created as a tab in PlotWindow or when it is shown
+ again after hiding).
+ """
+ self.raise_()
diff --git a/silx/gui/plot/LimitsHistory.py b/src/silx/gui/plot/LimitsHistory.py
index a323548..a323548 100644
--- a/silx/gui/plot/LimitsHistory.py
+++ b/src/silx/gui/plot/LimitsHistory.py
diff --git a/src/silx/gui/plot/MaskToolsWidget.py b/src/silx/gui/plot/MaskToolsWidget.py
new file mode 100644
index 0000000..522be48
--- /dev/null
+++ b/src/silx/gui/plot/MaskToolsWidget.py
@@ -0,0 +1,919 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget providing a set of tools to draw masks on a PlotWidget.
+
+This widget is meant to work with :class:`silx.gui.plot.PlotWidget`.
+
+- :class:`ImageMask`: Handle mask bitmap update and history
+- :class:`MaskToolsWidget`: GUI for :class:`Mask`
+- :class:`MaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
+"""
+from __future__ import division
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+import os
+import sys
+import numpy
+import logging
+import collections
+import h5py
+
+from silx.image import shapes
+from silx.io.utils import NEXUS_HDF5_EXT, is_dataset
+from silx.gui.dialog.DatasetDialog import DatasetDialog
+
+from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
+from . import items
+from ..colors import cursorColorForColormap, rgba
+from .. import qt
+from ..utils import LockReentrant
+
+from silx.third_party.EdfFile import EdfFile
+from silx.third_party.TiffIO import TiffIO
+
+import fabio
+
+_logger = logging.getLogger(__name__)
+
+_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT])
+
+
+def _selectDataset(filename, mode=DatasetDialog.SaveMode):
+ """Open a dialog to prompt the user to select a dataset in
+ a hdf5 file.
+
+ :param str filename: name of an existing HDF5 file
+ :param mode: DatasetDialog.SaveMode or DatasetDialog.LoadMode
+ :rtype: str
+ :return: Name of selected dataset
+ """
+ dialog = DatasetDialog()
+ dialog.addFile(filename)
+ dialog.setWindowTitle("Select a 2D dataset")
+ dialog.setMode(mode)
+ if not dialog.exec():
+ return None
+ return dialog.getSelectedDataUrl().data_path()
+
+
+class ImageMask(BaseMask):
+ """A 2D mask field with update operations.
+
+ Coords follows (row, column) convention and are in mask array coords.
+
+ This is meant for internal use by :class:`MaskToolsWidget`.
+ """
+
+ def __init__(self, image=None):
+ """
+
+ :param image: :class:`silx.gui.plot.items.ImageBase` instance
+ """
+ BaseMask.__init__(self, image)
+ self.reset(shape=(0, 0)) # Init the mask with a 2D shape
+
+ def getDataValues(self):
+ """Return image data as a 2D or 3D array (if it is a RGBA image).
+
+ :rtype: 2D or 3D numpy.ndarray
+ """
+ return self._dataItem.getData(copy=False)
+
+ def save(self, filename, kind):
+ """Save current mask in a file
+
+ :param str filename: The file where to save to mask
+ :param str kind: The kind of file to save in 'edf', 'tif', 'npy', 'h5'
+ or 'msk' (if FabIO is installed)
+ :raise Exception: Raised if the file writing fail
+ """
+ if kind == 'edf':
+ edfFile = EdfFile(filename, access="w+")
+ header = {"program_name": "silx-mask", "masked_value": "nonzero"}
+ edfFile.WriteImage(header, self.getMask(copy=False), Append=0)
+
+ elif kind == 'tif':
+ tiffFile = TiffIO(filename, mode='w')
+ tiffFile.writeImage(self.getMask(copy=False), software='silx')
+
+ elif kind == 'npy':
+ try:
+ numpy.save(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+
+ elif ("." + kind) in NEXUS_HDF5_EXT:
+ self._saveToHdf5(filename, self.getMask(copy=False))
+
+ elif kind == 'msk':
+ try:
+ data = self.getMask(copy=False)
+ image = fabio.fabioimage.FabioImage(data=data)
+ image = image.convert(fabio.fit2dmaskimage.Fit2dMaskImage)
+ image.save(filename)
+ except Exception:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("Mask file can't be written")
+ else:
+ raise ValueError("Format '%s' is not supported" % kind)
+
+ @staticmethod
+ def _saveToHdf5(filename, mask):
+ """Save a mask array to a HDF5 file.
+
+ :param str filename: name of an existing HDF5 file
+ :param numpy.ndarray mask: Mask array.
+ :returns: True if operation succeeded, False otherwise.
+ """
+ if not os.path.exists(filename):
+ # create new file
+ with h5py.File(filename, "w") as _h5f:
+ pass
+ dataPath = _selectDataset(filename)
+ if dataPath is None:
+ return False
+ with h5py.File(filename, "a") as h5f:
+ existing_ds = h5f.get(dataPath)
+ if existing_ds is not None:
+ reply = qt.QMessageBox.question(
+ None,
+ "Confirm overwrite",
+ "Do you want to overwrite an existing dataset?",
+ qt.QMessageBox.Yes | qt.QMessageBox.No)
+ if reply != qt.QMessageBox.Yes:
+ return False
+ del h5f[dataPath]
+ try:
+ h5f.create_dataset(dataPath, data=mask)
+ except Exception:
+ return False
+ return True
+
+ # Drawing operations
+ def updateRectangle(self, level, row, col, height, width, mask=True):
+ """Mask/Unmask a rectangle of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int row: Starting row of the rectangle
+ :param int col: Starting column of the rectangle
+ :param int height:
+ :param int width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ assert 0 < level < 256
+ if row + height <= 0 or col + width <= 0:
+ return # Rectangle outside image, avoid negative indices
+ selection = self._mask[max(0, row):row + height + 1,
+ max(0, col):col + width + 1]
+ if mask:
+ selection[:,:] = level
+ else:
+ selection[selection == level] = 0
+ self._notify()
+
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask a polygon of the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (row, col)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ fill = shapes.polygon_fill_mask(vertices, self._mask.shape)
+ if mask:
+ self._mask[fill != 0] = level
+ else:
+ self._mask[numpy.logical_and(fill != 0,
+ self._mask == level)] = 0
+ self._notify()
+
+ def updatePoints(self, level, rows, cols, mask=True):
+ """Mask/Unmask points with given coordinates.
+
+ :param int level: Mask level to update.
+ :param rows: Rows of selected points
+ :type rows: 1D numpy.ndarray
+ :param cols: Columns of selected points
+ :type cols: 1D numpy.ndarray
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ valid = numpy.logical_and(
+ numpy.logical_and(rows >= 0, cols >= 0),
+ numpy.logical_and(rows < self._mask.shape[0],
+ cols < self._mask.shape[1]))
+ rows, cols = rows[valid], cols[valid]
+
+ if mask:
+ self._mask[rows, cols] = level
+ else:
+ inMask = self._mask[rows, cols] == level
+ self._mask[rows[inMask], cols[inMask]] = 0
+ self._notify()
+
+ def updateDisk(self, level, crow, ccol, radius, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Disk center row.
+ :param int ccol: Disk center column.
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.circle_fill(crow, ccol, radius)
+ self.updatePoints(level, rows, cols, mask)
+
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.ellipse_fill(crow, ccol, radius_r, radius_c)
+ self.updatePoints(level, rows, cols, mask)
+
+ def updateLine(self, level, row0, col0, row1, col1, width, mask=True):
+ """Mask/Unmask a line of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int row0: Row of the starting point.
+ :param int col0: Column of the starting point.
+ :param int row1: Row of the end point.
+ :param int col1: Column of the end point.
+ :param int width: Width of the line in mask array unit.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ rows, cols = shapes.draw_line(row0, col0, row1, col1, width)
+ self.updatePoints(level, rows, cols, mask)
+
+
+class MaskToolsWidget(BaseMaskToolsWidget):
+ """Widget with tools for drawing mask on an image in a PlotWidget."""
+
+ _maxLevelNumber = 255
+
+ def __init__(self, parent=None, plot=None):
+ super(MaskToolsWidget, self).__init__(parent, plot,
+ mask=ImageMask())
+ self._origin = (0., 0.) # Mask origin in plot
+ self._scale = (1., 1.) # Mask scale in plot
+ self._z = 1 # Mask layer in plot
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8) # Store image
+
+ self.__itemMaskUpdatedLock = LockReentrant()
+ self.__itemMaskUpdated = False
+
+ def __maskStateChanged(self) -> None:
+ """Handle mask commit to update item mask"""
+ item = self._mask.getDataItem()
+ if item is not None:
+ with self.__itemMaskUpdatedLock:
+ item.setMaskData(self._mask.getMask(copy=True), copy=False)
+
+ def setItemMaskUpdated(self, enabled: bool) -> None:
+ """Toggle item mask and mask tool synchronisation.
+
+ :param bool enabled: True to synchronise. Default: False
+ """
+ enabled = bool(enabled)
+ if enabled != self.__itemMaskUpdated:
+ if self.__itemMaskUpdated:
+ self._mask.sigStateChanged.disconnect(self.__maskStateChanged)
+ self.__itemMaskUpdated = enabled
+ if self.__itemMaskUpdated:
+ # Synchronize item and tool mask
+ self._setMaskedImage(self._mask.getDataItem())
+ self._mask.sigStateChanged.connect(self.__maskStateChanged)
+
+ def isItemMaskUpdated(self) -> bool:
+ """Returns whether or not item and mask tool masks are synchronised.
+
+ :rtype: bool
+ """
+ return self.__itemMaskUpdated
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask:
+ The array to use for the mask or None to reset the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 2-tuple if successful.
+ The mask can be cropped or padded to fit active image,
+ the returned shape is that of the active image.
+ """
+ if mask is None:
+ self.resetSelectionMask()
+ return self._data.shape[:2]
+
+ mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
+ if len(mask.shape) != 2:
+ _logger.error('Not an image, shape: %d', len(mask.shape))
+ return None
+
+ # Handle mask with single level
+ if self.multipleMasks() == 'single':
+ mask = numpy.array(mask != 0, dtype=numpy.uint8)
+
+ # if mask has not changed, do nothing
+ if numpy.array_equal(mask, self.getSelectionMask()):
+ return mask.shape
+
+ if self._data.shape[0:2] == (0, 0) or mask.shape == self._data.shape[0:2]:
+ self._mask.setMask(mask, copy=copy)
+ self._mask.commit()
+ return mask.shape
+ else:
+ _logger.warning('Mask has not the same size as current image.'
+ ' Mask will be cropped or padded to fit image'
+ ' dimensions. %s != %s',
+ str(mask.shape), str(self._data.shape))
+ resizedMask = numpy.zeros(self._data.shape[0:2],
+ dtype=numpy.uint8)
+ height = min(self._data.shape[0], mask.shape[0])
+ width = min(self._data.shape[1], mask.shape[1])
+ resizedMask[:height,:width] = mask[:height,:width]
+ self._mask.setMask(resizedMask, copy=False)
+ self._mask.commit()
+ return resizedMask.shape
+
+ # Handle mask refresh on the plot
+ def _updatePlotMask(self):
+ """Update mask image in plot"""
+ mask = self.getSelectionMask(copy=False)
+ if mask is not None:
+ # get the mask from the plot
+ maskItem = self.plot.getImage(self._maskName)
+ mustBeAdded = maskItem is None
+ if mustBeAdded:
+ maskItem = items.MaskImageData()
+ maskItem.setName(self._maskName)
+ # update the items
+ maskItem.setData(mask, copy=False)
+ maskItem.setColormap(self._colormap)
+ maskItem.setOrigin(self._origin)
+ maskItem.setScale(self._scale)
+ maskItem.setZValue(self._z)
+
+ if mustBeAdded:
+ self.plot.addItem(maskItem)
+
+ elif self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ def showEvent(self, event):
+ try:
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare)
+ except (RuntimeError, TypeError):
+ pass
+
+ # Sync with current active image
+ self._setMaskedImage(self.plot.getActiveImage())
+ self.plot.sigActiveImageChanged.connect(self._activeImageChanged)
+
+ def hideEvent(self, event):
+ try:
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChanged)
+ except (RuntimeError, TypeError):
+ pass
+
+ image = self.getMaskedItem()
+ if image is not None:
+ try:
+ image.sigItemChanged.disconnect(self.__imageChanged)
+ except (RuntimeError, TypeError):
+ pass # TODO should not happen
+
+ if self.isMaskInteractionActivated():
+ # Disable drawing tool
+ self.browseAction.trigger()
+
+ if self.isItemMaskUpdated(): # No "after-care"
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.setDataItem(None)
+ self._mask.reset()
+
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ elif self.getSelectionMask(copy=False) is not None:
+ self.plot.sigActiveImageChanged.connect(
+ self._activeImageChangedAfterCare)
+
+ def _activeImageChanged(self, previous, current):
+ """Reacts upon active image change.
+
+ Only handle change of active image items here.
+ """
+ if previous != current:
+ image = self.plot.getActiveImage()
+ if image is not None and image.getName() == self._maskName:
+ image = None # Active image is the mask
+ self._setMaskedImage(image)
+
+ def _setOverlayColorForImage(self, image):
+ """Set the color of overlay adapted to image
+
+ :param image: :class:`.items.ImageBase` object to set color for.
+ """
+ if isinstance(image, items.ColormapMixIn):
+ colormap = image.getColormap()
+ self._defaultOverlayColor = rgba(
+ cursorColorForColormap(colormap['name']))
+ else:
+ self._defaultOverlayColor = rgba('black')
+
+ def _activeImageChangedAfterCare(self, *args):
+ """Check synchro of active image and mask when mask widget is hidden.
+
+ If active image has no more the same size as the mask, the mask is
+ removed, otherwise it is adjusted to origin, scale and z.
+ """
+ activeImage = self.plot.getActiveImage()
+ if activeImage is None or activeImage.getName() == self._maskName:
+ # No active image or active image is the mask...
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.setDataItem(None)
+ self._mask.reset()
+
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare)
+ else:
+ self._setOverlayColorForImage(activeImage)
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+
+ self._origin = activeImage.getOrigin()
+ self._scale = activeImage.getScale()
+ self._z = activeImage.getZValue() + 1
+ self._data = activeImage.getData(copy=False)
+ if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
+ # Image has not the same size, remove mask and stop listening
+ if self.plot.getImage(self._maskName):
+ self.plot.remove(self._maskName, kind='image')
+
+ self.plot.sigActiveImageChanged.disconnect(
+ self._activeImageChangedAfterCare)
+ else:
+ # Refresh in case origin, scale, z changed
+ self._mask.setDataItem(activeImage)
+ self._updatePlotMask()
+
+ def _setMaskedImage(self, image):
+ """Change the image that is used a reference to author the mask"""
+ previous = self.getMaskedItem()
+ if previous is not None and self.isVisible():
+ # Disconnect from previous image
+ try:
+ previous.sigItemChanged.disconnect(self.__imageChanged)
+ except (RuntimeError, TypeError):
+ pass # TODO fixme should not happen
+
+ # Set the image
+ self._mask.setDataItem(image)
+
+ if image is None: # No image, disable mask
+ self.setEnabled(False)
+
+ self._data = numpy.zeros((0, 0), dtype=numpy.uint8)
+ self._mask.reset()
+ self._mask.commit()
+
+ self._updateInteractiveMode()
+
+ else: # Update and connect to image's sigItemChanged
+ if self.isItemMaskUpdated():
+ if image.getMaskData(copy=False) is None:
+ # Image item has no mask: use current mask from the tool
+ image.setMaskData(
+ self.getSelectionMask(copy=False), copy=True)
+ else: # Image item has a mask: set it in tool
+ self.setSelectionMask(
+ image.getMaskData(copy=False), copy=True)
+ self._mask.resetHistory()
+ self.__imageUpdated()
+ if self.isVisible():
+ image.sigItemChanged.connect(self.__imageChanged)
+
+ def __imageChanged(self, event):
+ """Reacts upon image item changes"""
+ image = self._mask.getDataItem()
+ if image is None:
+ _logger.error("Mask is not attached to an image")
+ return
+
+ if event in (items.ItemChangedType.COLORMAP,
+ items.ItemChangedType.DATA,
+ items.ItemChangedType.POSITION,
+ items.ItemChangedType.SCALE,
+ items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.ZVALUE):
+ self.__imageUpdated()
+
+ elif (event == items.ItemChangedType.MASK and
+ self.isItemMaskUpdated() and
+ not self.__itemMaskUpdatedLock.locked()):
+ # Update mask from the image item unless mask tool is updating it
+ self.setSelectionMask(image.getMaskData(copy=False), copy=True)
+
+ def __imageUpdated(self):
+ """Synchronize mask with current state of the image"""
+ image = self._mask.getDataItem()
+ if image is None:
+ _logger.error("No active image while expecting one")
+ return
+
+ self._setOverlayColorForImage(image)
+
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+
+ self._origin = image.getOrigin()
+ self._scale = image.getScale()
+ self._z = image.getZValue() + 1
+ self._data = image.getData(copy=False)
+ self._mask.setDataItem(image)
+ if self._data.shape[:2] != self._mask.getMask(copy=False).shape:
+ self._mask.reset(self._data.shape[:2])
+ self._mask.commit()
+ else:
+ # Refresh in case origin, scale, z changed
+ self._updatePlotMask()
+
+ # Visible and with data
+ self.setEnabled(image.isVisible() and self._data.size != 0)
+
+ # Threshold tools only available for data with colormap
+ self.thresholdGroup.setEnabled(self._data.ndim == 2)
+
+ self._updateInteractiveMode()
+
+ # Handle whole mask operations
+ def load(self, filename):
+ """Load a mask from an image file.
+
+ :param str filename: File name from which to load the mask
+ :raise Exception: An exception in case of failure
+ :raise RuntimeWarning: In case the mask was applied but with some
+ import changes to notice
+ """
+ _, extension = os.path.splitext(filename)
+ extension = extension.lower()[1:]
+
+ if extension == "npy":
+ try:
+ mask = numpy.load(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy file.', filename)
+ elif extension in ["tif", "tiff"]:
+ try:
+ image = TiffIO(filename, mode="r")
+ mask = image.getImage(0)
+ except Exception as e:
+ _logger.error("Can't load filename %s", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise e
+ elif extension == "edf":
+ try:
+ mask = EdfFile(filename, access='r').GetData(0)
+ except Exception as e:
+ _logger.error("Can't load filename %s", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise e
+ elif extension == "msk":
+ try:
+ mask = fabio.open(filename).data
+ except Exception as e:
+ _logger.error("Can't load fit2d mask file")
+ _logger.debug("Backtrace", exc_info=True)
+ raise e
+ elif ("." + extension) in NEXUS_HDF5_EXT:
+ mask = self._loadFromHdf5(filename)
+ if mask is None:
+ raise IOError("Could not load mask from HDF5 dataset")
+ else:
+ msg = "Extension '%s' is not supported."
+ raise RuntimeError(msg % extension)
+
+ effectiveMaskShape = self.setSelectionMask(mask, copy=False)
+ if effectiveMaskShape is None:
+ return
+ if mask.shape != effectiveMaskShape:
+ msg = 'Mask was resized from %s to %s'
+ msg = msg % (str(mask.shape), str(effectiveMaskShape))
+ raise RuntimeWarning(msg)
+
+ def _loadMask(self):
+ """Open load mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Load Mask")
+ dialog.setModal(1)
+
+ extensions = collections.OrderedDict()
+ extensions["EDF files"] = "*.edf"
+ extensions["TIFF files"] = "*.tif *.tiff"
+ extensions["NumPy binary files"] = "*.npy"
+ extensions["HDF5 files"] = _HDF5_EXT_STR
+ # Fit2D mask is displayed anyway fabio is here or not
+ # to show to the user that the option exists
+ extensions["Fit2D mask files"] = "*.msk"
+
+ filters = []
+ filters.append("All supported files (%s)" % " ".join(extensions.values()))
+ for name, extension in extensions.items():
+ filters.append("%s (%s)" % (name, extension))
+ filters.append("All files (*)")
+
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.load(filename)
+ except RuntimeWarning as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Warning)
+ msg.setText("Mask loaded but an operation was applied.\n" + message)
+ msg.exec()
+ except Exception as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot load mask from file. " + message)
+ msg.exec()
+
+ @staticmethod
+ def _loadFromHdf5(filename):
+ """Load a mask array from a HDF5 file.
+
+ :param str filename: name of an existing HDF5 file
+ :returns: A mask as a numpy array, or None if the interactive dialog
+ was cancelled
+ """
+ dataPath = _selectDataset(filename, mode=DatasetDialog.LoadMode)
+ if dataPath is None:
+ return None
+
+ with h5py.File(filename, "r") as h5f:
+ dataset = h5f.get(dataPath)
+ if not is_dataset(dataset):
+ raise IOError("%s is not a dataset" % dataPath)
+ mask = dataset[()]
+ return mask
+
+ def _saveMask(self):
+ """Open Save mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Save Mask")
+ dialog.setOption(dialog.DontUseNativeDialog)
+ dialog.setModal(1)
+ hdf5Filter = 'HDF5 (%s)' % _HDF5_EXT_STR
+ filters = [
+ 'EDF (*.edf)',
+ 'TIFF (*.tif)',
+ 'NumPy binary file (*.npy)',
+ hdf5Filter,
+ # Fit2D mask is displayed anyway fabio is here or not
+ # to show to the user that the option exists
+ 'Fit2D mask (*.msk)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.maskFileDir)
+
+ def onFilterSelection(filt_):
+ # disable overwrite confirmation for HDF5,
+ # because we append the data to existing files
+ if filt_ == hdf5Filter:
+ dialog.setOption(dialog.DontConfirmOverwrite)
+ else:
+ dialog.setOption(dialog.DontConfirmOverwrite, False)
+
+ dialog.filterSelected.connect(onFilterSelection)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if "HDF5" in nameFilter:
+ has_allowed_ext = False
+ for ext in NEXUS_HDF5_EXT:
+ if (len(filename) > len(ext) and
+ filename[-len(ext):].lower() == ext.lower()):
+ has_allowed_ext = True
+ extension = ext
+ if not has_allowed_ext:
+ extension = ".h5"
+ filename += ".h5"
+ else:
+ # convert filter name to extension name with the .
+ extension = nameFilter.split()[-1][2:-1]
+ if not filename.lower().endswith(extension):
+ filename += extension
+
+ if os.path.exists(filename) and "HDF5" not in nameFilter:
+ try:
+ os.remove(filename)
+ except IOError as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Removing existing file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save.\n"
+ "Input Output Error: %s" % strerror)
+ msg.exec()
+ return
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.save(filename, extension[1:])
+ except Exception as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Saving mask file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save file %s\n%s" % (filename, strerror))
+ msg.exec()
+
+ def resetSelectionMask(self):
+ """Reset the mask"""
+ self._mask.reset(shape=self._data.shape[:2])
+ self._mask.commit()
+
+ def _plotDrawEvent(self, event):
+ """Handle draw events from the plot"""
+ if (self._drawingMode is None or
+ event['event'] not in ('drawingProgress', 'drawingFinished')):
+ return
+
+ if not len(self._data):
+ return
+
+ level = self.levelSpinBox.value()
+
+ if self._drawingMode == 'rectangle':
+ if event['event'] == 'drawingFinished':
+ # Convert from plot to array coords
+ doMask = self._isMasking()
+ ox, oy = self._origin
+ sx, sy = self._scale
+
+ height = int(abs(event['height'] / sy))
+ width = int(abs(event['width'] / sx))
+
+ row = int((event['y'] - oy) / sy)
+ if sy < 0:
+ row -= height
+
+ col = int((event['x'] - ox) / sx)
+ if sx < 0:
+ col -= width
+
+ self._mask.updateRectangle(
+ level,
+ row=row,
+ col=col,
+ height=height,
+ width=width,
+ mask=doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'ellipse':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ center = (event['points'][0] - self._origin) / self._scale
+ size = event['points'][1] / self._scale
+ center = center.astype(numpy.int64) # (row, col)
+ self._mask.updateEllipse(level, center[1], center[0], size[1], size[0], doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'polygon':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ # Convert from plot to array coords
+ vertices = (event['points'] - self._origin) / self._scale
+ vertices = vertices.astype(numpy.int64)[:, (1, 0)] # (row, col)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'pencil':
+ doMask = self._isMasking()
+ # convert from plot to array coords
+ col, row = (event['points'][-1] - self._origin) / self._scale
+ col, row = int(col), int(row)
+ brushSize = self._getPencilWidth()
+
+ if self._lastPencilPos != (row, col):
+ if self._lastPencilPos is not None:
+ # Draw the line
+ self._mask.updateLine(
+ level,
+ self._lastPencilPos[0], self._lastPencilPos[1],
+ row, col,
+ brushSize,
+ doMask)
+
+ # Draw the very first, or last point
+ self._mask.updateDisk(level, row, col, brushSize / 2., doMask)
+
+ if event['event'] == 'drawingFinished':
+ self._mask.commit()
+ self._lastPencilPos = None
+ else:
+ self._lastPencilPos = row, col
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
+
+ def _loadRangeFromColormapTriggered(self):
+ """Set range from active image colormap range"""
+ activeImage = self.plot.getActiveImage()
+ if (isinstance(activeImage, items.ColormapMixIn) and
+ activeImage.getName() != self._maskName):
+ # Update thresholds according to colormap
+ colormap = activeImage.getColormap()
+ if colormap['autoscale']:
+ min_ = numpy.nanmin(activeImage.getData(copy=False))
+ max_ = numpy.nanmax(activeImage.getData(copy=False))
+ else:
+ min_, max_ = colormap['vmin'], colormap['vmax']
+ self.minLineEdit.setText(str(min_))
+ self.maxLineEdit.setText(str(max_))
+
+
+class MaskToolsDockWidget(BaseMaskToolsDockWidget):
+ """:class:`MaskToolsWidget` embedded in a QDockWidget.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: The PlotWidget this widget is operating on
+ :paran str name: The title of this widget
+ """
+
+ def __init__(self, parent=None, plot=None, name='Mask'):
+ widget = MaskToolsWidget(plot=plot)
+ super(MaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/silx/gui/plot/PlotActions.py b/src/silx/gui/plot/PlotActions.py
index dd16221..dd16221 100644
--- a/silx/gui/plot/PlotActions.py
+++ b/src/silx/gui/plot/PlotActions.py
diff --git a/silx/gui/plot/PlotEvents.py b/src/silx/gui/plot/PlotEvents.py
index 83f253c..83f253c 100644
--- a/silx/gui/plot/PlotEvents.py
+++ b/src/silx/gui/plot/PlotEvents.py
diff --git a/src/silx/gui/plot/PlotInteraction.py b/src/silx/gui/plot/PlotInteraction.py
new file mode 100644
index 0000000..6ebe6b1
--- /dev/null
+++ b/src/silx/gui/plot/PlotInteraction.py
@@ -0,0 +1,1746 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Implementation of the interaction for the :class:`Plot`."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/02/2019"
+
+
+import math
+import numpy
+import time
+import weakref
+
+from .. import colors
+from .. import qt
+from . import items
+from .Interaction import (ClickOrDrag, LEFT_BTN, RIGHT_BTN, MIDDLE_BTN,
+ State, StateMachine)
+from .PlotEvents import (prepareCurveSignal, prepareDrawingSignal,
+ prepareHoverSignal, prepareImageSignal,
+ prepareMarkerSignal, prepareMouseSignal)
+
+from .backends.BackendBase import (CURSOR_POINTING, CURSOR_SIZE_HOR,
+ CURSOR_SIZE_VER, CURSOR_SIZE_ALL)
+
+from ._utils import (FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX,
+ applyZoomToPlot)
+
+
+# Base class ##################################################################
+
+class _PlotInteraction(object):
+ """Base class for interaction handler.
+
+ It provides a weakref to the plot and methods to set/reset overlay.
+ """
+ def __init__(self, plot):
+ """Init.
+
+ :param plot: The plot to apply modifications to.
+ """
+ self._needReplot = False
+ self._selectionAreas = set()
+ self._plot = weakref.ref(plot) # Avoid cyclic-ref
+
+ @property
+ def plot(self):
+ plot = self._plot()
+ assert plot is not None
+ return plot
+
+ def setSelectionArea(self, points, fill, color, name='', shape='polygon'):
+ """Set a polygon selection area overlaid on the plot.
+ Multiple simultaneous areas are supported through the name parameter.
+
+ :param points: The 2D coordinates of the points of the polygon
+ :type points: An iterable of (x, y) coordinates
+ :param str fill: The fill mode: 'hatch', 'solid' or 'none'
+ :param color: RGBA color to use or None to disable display
+ :type color: list or tuple of 4 float in the range [0, 1]
+ :param name: The key associated with this selection area
+ :param str shape: Shape of the area in 'polygon', 'polylines'
+ """
+ assert shape in ('polygon', 'polylines')
+
+ if color is None:
+ return
+
+ points = numpy.asarray(points)
+
+ # TODO Not very nice, but as is for now
+ legend = '__SELECTION_AREA__' + name
+
+ fill = fill != 'none' # TODO not very nice either
+
+ greyed = colors.greyed(color)[0]
+ if greyed < 0.5:
+ color2 = "white"
+ else:
+ color2 = "black"
+
+ self.plot.addShape(points[:, 0], points[:, 1], legend=legend,
+ replace=False,
+ shape=shape, fill=fill,
+ color=color, linebgcolor=color2, linestyle="--",
+ overlay=True)
+
+ self._selectionAreas.add(legend)
+
+ def resetSelectionArea(self):
+ """Remove all selection areas set by setSelectionArea."""
+ for legend in self._selectionAreas:
+ self.plot.remove(legend, kind='item')
+ self._selectionAreas = set()
+
+
+# Zoom/Pan ####################################################################
+
+class _ZoomOnWheel(ClickOrDrag, _PlotInteraction):
+ """:class:`ClickOrDrag` state machine with zooming on mouse wheel.
+
+ Base class for :class:`Pan` and :class:`Zoom`
+ """
+
+ _DOUBLE_CLICK_TIMEOUT = 0.4
+
+ class Idle(ClickOrDrag.Idle):
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.machine.plot, scaleF, (x, y))
+
+ def click(self, x, y, btn):
+ """Handle clicks by sending events
+
+ :param int x: Mouse X position in pixels
+ :param int y: Mouse Y position in pixels
+ :param btn: Clicked mouse button
+ """
+ if btn == LEFT_BTN:
+ lastClickTime, lastClickPos = self._lastClick
+
+ # Signal mouse double clicked event first
+ if (time.time() - lastClickTime) <= self._DOUBLE_CLICK_TIMEOUT:
+ # Use position of first click
+ eventDict = prepareMouseSignal('mouseDoubleClicked', 'left',
+ *lastClickPos)
+ self.plot.notify(**eventDict)
+
+ self._lastClick = 0., None
+ else:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal('mouseClicked', 'left',
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**eventDict)
+
+ self._lastClick = time.time(), (dataPos[0], dataPos[1], x, y)
+
+ elif btn == RIGHT_BTN:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal('mouseClicked', 'right',
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**eventDict)
+
+ def __init__(self, plot, **kwargs):
+ """Init.
+
+ :param plot: The plot to apply modifications to.
+ """
+ self._lastClick = 0., None
+
+ _PlotInteraction.__init__(self, plot)
+ ClickOrDrag.__init__(self, **kwargs)
+
+
+# Pan #########################################################################
+
+class Pan(_ZoomOnWheel):
+ """Pan plot content and zoom on wheel state machine."""
+
+ def _pixelToData(self, x, y):
+ xData, yData = self.plot.pixelToData(x, y)
+ _, y2Data = self.plot.pixelToData(x, y, axis='right')
+ return xData, yData, y2Data
+
+ def beginDrag(self, x, y, btn):
+ self._previousDataPos = self._pixelToData(x, y)
+
+ def drag(self, x, y, btn):
+ xData, yData, y2Data = self._pixelToData(x, y)
+ lastX, lastY, lastY2 = self._previousDataPos
+
+ xMin, xMax = self.plot.getXAxis().getLimits()
+ yMin, yMax = self.plot.getYAxis().getLimits()
+ y2Min, y2Max = self.plot.getYAxis(axis='right').getLimits()
+
+ if self.plot.getXAxis()._isLogarithmic():
+ try:
+ dx = math.log10(xData) - math.log10(lastX)
+ newXMin = pow(10., (math.log10(xMin) - dx))
+ newXMax = pow(10., (math.log10(xMax) - dx))
+ except (ValueError, OverflowError):
+ newXMin, newXMax = xMin, xMax
+
+ # Makes sure both values stays in positive float32 range
+ if newXMin < FLOAT32_MINPOS or newXMax > FLOAT32_SAFE_MAX:
+ newXMin, newXMax = xMin, xMax
+ else:
+ dx = xData - lastX
+ newXMin, newXMax = xMin - dx, xMax - dx
+
+ # Makes sure both values stays in float32 range
+ if newXMin < FLOAT32_SAFE_MIN or newXMax > FLOAT32_SAFE_MAX:
+ newXMin, newXMax = xMin, xMax
+
+ if self.plot.getYAxis()._isLogarithmic():
+ try:
+ dy = math.log10(yData) - math.log10(lastY)
+ newYMin = pow(10., math.log10(yMin) - dy)
+ newYMax = pow(10., math.log10(yMax) - dy)
+
+ dy2 = math.log10(y2Data) - math.log10(lastY2)
+ newY2Min = pow(10., math.log10(y2Min) - dy2)
+ newY2Max = pow(10., math.log10(y2Max) - dy2)
+ except (ValueError, OverflowError):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+
+ # Makes sure y and y2 stays in positive float32 range
+ if (newYMin < FLOAT32_MINPOS or newYMax > FLOAT32_SAFE_MAX or
+ newY2Min < FLOAT32_MINPOS or newY2Max > FLOAT32_SAFE_MAX):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+ else:
+ dy = yData - lastY
+ dy2 = y2Data - lastY2
+ newYMin, newYMax = yMin - dy, yMax - dy
+ newY2Min, newY2Max = y2Min - dy2, y2Max - dy2
+
+ # Makes sure y and y2 stays in float32 range
+ if (newYMin < FLOAT32_SAFE_MIN or
+ newYMax > FLOAT32_SAFE_MAX or
+ newY2Min < FLOAT32_SAFE_MIN or
+ newY2Max > FLOAT32_SAFE_MAX):
+ newYMin, newYMax = yMin, yMax
+ newY2Min, newY2Max = y2Min, y2Max
+
+ self.plot.setLimits(newXMin, newXMax,
+ newYMin, newYMax,
+ newY2Min, newY2Max)
+
+ self._previousDataPos = self._pixelToData(x, y)
+
+ def endDrag(self, startPos, endPos, btn):
+ del self._previousDataPos
+
+ def cancel(self):
+ pass
+
+
+# Zoom ########################################################################
+
+class Zoom(_ZoomOnWheel):
+ """Zoom-in/out state machine.
+
+ Zoom-in on selected area, zoom-out on right click,
+ and zoom on mouse wheel.
+ """
+
+ SURFACE_THRESHOLD = 5
+
+ def __init__(self, plot, color):
+ self.color = color
+
+ super(Zoom, self).__init__(plot)
+ self.plot.getLimitsHistory().clear()
+
+ def _areaWithAspectRatio(self, x0, y0, x1, y1):
+ _plotLeft, _plotTop, plotW, plotH = self.plot.getPlotBoundsInPixels()
+
+ areaX0, areaY0, areaX1, areaY1 = x0, y0, x1, y1
+
+ if plotH != 0.:
+ plotRatio = plotW / float(plotH)
+ width, height = math.fabs(x1 - x0), math.fabs(y1 - y0)
+
+ if height != 0. and width != 0.:
+ if width / height > plotRatio:
+ areaHeight = width / plotRatio
+ areaX0, areaX1 = x0, x1
+ center = 0.5 * (y0 + y1)
+ areaY0 = center - numpy.sign(y1 - y0) * 0.5 * areaHeight
+ areaY1 = center + numpy.sign(y1 - y0) * 0.5 * areaHeight
+ else:
+ areaWidth = height * plotRatio
+ areaY0, areaY1 = y0, y1
+ center = 0.5 * (x0 + x1)
+ areaX0 = center - numpy.sign(x1 - x0) * 0.5 * areaWidth
+ areaX1 = center + numpy.sign(x1 - x0) * 0.5 * areaWidth
+
+ return areaX0, areaY0, areaX1, areaY1
+
+ def beginDrag(self, x, y, btn):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.x0, self.y0 = x, y
+
+ def drag(self, x1, y1, btn):
+ if self.color is None:
+ return # Do not draw zoom area
+
+ dataPos = self.plot.pixelToData(x1, y1)
+ assert dataPos is not None
+
+ if self.plot.isKeepDataAspectRatio():
+ area = self._areaWithAspectRatio(self.x0, self.y0, x1, y1)
+ areaX0, areaY0, areaX1, areaY1 = area
+ areaPoints = ((areaX0, areaY0),
+ (areaX1, areaY0),
+ (areaX1, areaY1),
+ (areaX0, areaY1))
+ areaPoints = numpy.array([self.plot.pixelToData(
+ x, y, check=False) for (x, y) in areaPoints])
+
+ if self.color != 'video inverted':
+ areaColor = list(self.color)
+ areaColor[3] *= 0.25
+ else:
+ areaColor = [1., 1., 1., 1.]
+
+ self.setSelectionArea(areaPoints,
+ fill='none',
+ color=areaColor,
+ name="zoomedArea")
+
+ corners = ((self.x0, self.y0),
+ (self.x0, y1),
+ (x1, y1),
+ (x1, self.y0))
+ corners = numpy.array([self.plot.pixelToData(x, y, check=False)
+ for (x, y) in corners])
+
+ self.setSelectionArea(corners, fill='none', color=self.color)
+
+ def _zoom(self, x0, y0, x1, y1):
+ """Zoom to the rectangle view x0,y0 x1,y1.
+ """
+ startPos = x0, y0
+ endPos = x1, y1
+
+ # Store current zoom state in stack
+ self.plot.getLimitsHistory().push()
+
+ if self.plot.isKeepDataAspectRatio():
+ x0, y0, x1, y1 = self._areaWithAspectRatio(x0, y0, x1, y1)
+
+ # Convert to data space and set limits
+ x0, y0 = self.plot.pixelToData(x0, y0, check=False)
+
+ dataPos = self.plot.pixelToData(
+ startPos[0], startPos[1], axis="right", check=False)
+ y2_0 = dataPos[1]
+
+ x1, y1 = self.plot.pixelToData(x1, y1, check=False)
+
+ dataPos = self.plot.pixelToData(
+ endPos[0], endPos[1], axis="right", check=False)
+ y2_1 = dataPos[1]
+
+ xMin, xMax = min(x0, x1), max(x0, x1)
+ yMin, yMax = min(y0, y1), max(y0, y1)
+ y2Min, y2Max = min(y2_0, y2_1), max(y2_0, y2_1)
+
+ self.plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+ def endDrag(self, startPos, endPos, btn):
+ x0, y0 = startPos
+ x1, y1 = endPos
+
+ if abs(x0 - x1) * abs(y0 - y1) >= self.SURFACE_THRESHOLD:
+ # Avoid empty zoom area
+ self._zoom(x0, y0, x1, y1)
+
+ self.resetSelectionArea()
+
+ def cancel(self):
+ if isinstance(self.state, self.states['drag']):
+ self.resetSelectionArea()
+
+
+# Select ######################################################################
+
+class Select(StateMachine, _PlotInteraction):
+ """Base class for drawing selection areas."""
+
+ def __init__(self, plot, parameters, states, state):
+ """Init a state machine.
+
+ :param plot: The plot to apply changes to.
+ :param dict parameters: A dict of parameters such as color.
+ :param dict states: The states of the state machine.
+ :param str state: The name of the initial state.
+ """
+ _PlotInteraction.__init__(self, plot)
+ self.parameters = parameters
+ StateMachine.__init__(self, states, state)
+
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.plot, scaleF, (x, y))
+
+ @property
+ def color(self):
+ return self.parameters.get('color', None)
+
+
+class SelectPolygon(Select):
+ """Drawing selection polygon area state machine."""
+
+ DRAG_THRESHOLD_DIST = 4
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self._firstPos = dataPos
+ self.points = [dataPos, dataPos]
+
+ self.updateFirstPoint()
+
+ def updateFirstPoint(self):
+ """Update drawing first point, using self._firstPos"""
+ x, y = self.machine.plot.dataToPixel(*self._firstPos, check=False)
+
+ offset = self.machine.getDragThreshold()
+ points = [(x - offset, y - offset),
+ (x - offset, y + offset),
+ (x + offset, y + offset),
+ (x + offset, y - offset)]
+ points = [self.machine.plot.pixelToData(xpix, ypix, check=False)
+ for xpix, ypix in points]
+ self.machine.setSelectionArea(points, fill=None,
+ color=self.machine.color,
+ name='first_point')
+
+ def updateSelectionArea(self):
+ """Update drawing selection area using self.points"""
+ self.machine.setSelectionArea(self.points,
+ fill='hatch',
+ color=self.machine.color)
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'polygon',
+ self.points,
+ self.machine.parameters)
+ self.machine.plot.notify(**eventDict)
+
+ def validate(self):
+ if len(self.points) > 2:
+ self.closePolygon()
+ else:
+ # It would be nice to have a cancel event.
+ # The plot is not aware that the interaction was cancelled
+ self.machine.cancel()
+
+ def closePolygon(self):
+ self.machine.resetSelectionArea()
+ self.points[-1] = self.points[0]
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'polygon',
+ self.points,
+ self.machine.parameters)
+ self.machine.plot.notify(**eventDict)
+ self.goto('idle')
+
+ def onWheel(self, x, y, angle):
+ self.machine.onWheel(x, y, angle)
+ self.updateFirstPoint()
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ # checking if the position is close to the first point
+ # if yes : closing the "loop"
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos,
+ check=False)
+ dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
+
+ threshold = self.machine.getDragThreshold()
+
+ # Only allow to close polygon after first point
+ if len(self.points) > 2 and dx <= threshold and dy <= threshold:
+ self.closePolygon()
+ return False
+
+ # Update polygon last point not too close to previous one
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.updateSelectionArea()
+
+ # checking that the new points isnt the same (within range)
+ # of the previous one
+ # This has to be done because sometimes the mouse release event
+ # is caught right after entering the Select state (i.e : press
+ # in Idle state, but with a slightly different position that
+ # the mouse press. So we had the two first vertices that were
+ # almost identical.
+ previousPos = self.machine.plot.dataToPixel(*self.points[-2],
+ check=False)
+ dx, dy = abs(previousPos[0] - x), abs(previousPos[1] - y)
+ if dx >= threshold or dy >= threshold:
+ self.points.append(dataPos)
+ else:
+ self.points[-1] = dataPos
+
+ return True
+ return False
+
+ def onMove(self, x, y):
+ firstPos = self.machine.plot.dataToPixel(*self._firstPos,
+ check=False)
+ dx, dy = abs(firstPos[0] - x), abs(firstPos[1] - y)
+ threshold = self.machine.getDragThreshold()
+
+ if dx <= threshold and dy <= threshold:
+ x, y = firstPos # Snap to first point
+
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ self.points[-1] = dataPos
+ self.updateSelectionArea()
+
+ def __init__(self, plot, parameters):
+ states = {
+ 'idle': SelectPolygon.Idle,
+ 'select': SelectPolygon.Select
+ }
+ super(SelectPolygon, self).__init__(plot, parameters,
+ states, 'idle')
+
+ def cancel(self):
+ if isinstance(self.state, self.states['select']):
+ self.resetSelectionArea()
+
+ def getDragThreshold(self):
+ """Return dragging ratio with device to pixel ratio applied.
+
+ :rtype: float
+ """
+ ratio = self.plot.window().windowHandle().devicePixelRatio()
+ return self.DRAG_THRESHOLD_DIST * ratio
+
+
+class Select2Points(Select):
+ """Base class for drawing selection based on 2 input points."""
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('start', x, y)
+ return True
+
+ class Start(State):
+ def enterState(self, x, y):
+ self.machine.beginSelect(x, y)
+
+ def onMove(self, x, y):
+ self.goto('select', x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.onMove(x, y)
+
+ def onMove(self, x, y):
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.endSelect(x, y)
+ self.goto('idle')
+
+ def __init__(self, plot, parameters):
+ states = {
+ 'idle': Select2Points.Idle,
+ 'start': Select2Points.Start,
+ 'select': Select2Points.Select
+ }
+ super(Select2Points, self).__init__(plot, parameters,
+ states, 'idle')
+
+ def beginSelect(self, x, y):
+ pass
+
+ def select(self, x, y):
+ pass
+
+ def endSelect(self, x, y):
+ pass
+
+ def cancelSelect(self):
+ pass
+
+ def cancel(self):
+ if isinstance(self.state, self.states['select']):
+ self.cancelSelect()
+
+
+class SelectEllipse(Select2Points):
+ """Drawing ellipse selection area state machine."""
+ def beginSelect(self, x, y):
+ self.center = self.plot.pixelToData(x, y)
+ assert self.center is not None
+
+ def _getEllipseSize(self, pointInEllipse):
+ """
+ Returns the size from the center to the bounding box of the ellipse.
+
+ :param Tuple[float,float] pointInEllipse: A point of the ellipse
+ :rtype: Tuple[float,float]
+ """
+ x = abs(self.center[0] - pointInEllipse[0])
+ y = abs(self.center[1] - pointInEllipse[1])
+ if x == 0 or y == 0:
+ return x, y
+ # Ellipse definitions
+ # e: eccentricity
+ # a: length fron center to bounding box width
+ # b: length fron center to bounding box height
+ # Equations
+ # (1) b < a
+ # (2) For x,y a point in the ellipse: x^2/a^2 + y^2/b^2 = 1
+ # (3) b = a * sqrt(1-e^2)
+ # (4) e = sqrt(a^2 - b^2) / a
+
+ # The eccentricity of the ellipse defined by a,b=x,y is the same
+ # as the one we are searching for.
+ swap = x < y
+ if swap:
+ x, y = y, x
+ e = math.sqrt(x**2 - y**2) / x
+ # From (2) using (3) to replace b
+ # a^2 = x^2 + y^2 / (1-e^2)
+ a = math.sqrt(x**2 + y**2 / (1.0 - e**2))
+ b = a * math.sqrt(1 - e**2)
+ if swap:
+ a, b = b, a
+ return a, b
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ # Circle used for circle preview
+ nbpoints = 27.
+ angles = numpy.arange(nbpoints) * numpy.pi * 2.0 / nbpoints
+ circleShape = numpy.array((numpy.cos(angles) * width,
+ numpy.sin(angles) * height)).T
+ circleShape += numpy.array(self.center)
+
+ self.setSelectionArea(circleShape,
+ shape="polygon",
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'ellipse',
+ (self.center, (width, height)),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ width, height = self._getEllipseSize(dataPos)
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'ellipse',
+ (self.center, (width, height)),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectRectangle(Select2Points):
+ """Drawing rectangle selection area state machine."""
+ def beginSelect(self, x, y):
+ self.startPt = self.plot.pixelToData(x, y)
+ assert self.startPt is not None
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ self.setSelectionArea((self.startPt,
+ (self.startPt[0], dataPos[1]),
+ dataPos,
+ (dataPos[0], self.startPt[1])),
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'rectangle',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'rectangle',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectLine(Select2Points):
+ """Drawing line selection area state machine."""
+ def beginSelect(self, x, y):
+ self.startPt = self.plot.pixelToData(x, y)
+ assert self.startPt is not None
+
+ def select(self, x, y):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ self.setSelectionArea((self.startPt, dataPos),
+ fill='hatch',
+ color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'line',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'line',
+ (self.startPt, dataPos),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class Select1Point(Select):
+ """Base class for drawing selection area based on one input point."""
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.onMove(x, y)
+
+ def onMove(self, x, y):
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.machine.endSelect(x, y)
+ self.goto('idle')
+
+ def onWheel(self, x, y, angle):
+ self.machine.onWheel(x, y, angle) # Call select default wheel
+ self.machine.select(x, y)
+
+ def __init__(self, plot, parameters):
+ states = {
+ 'idle': Select1Point.Idle,
+ 'select': Select1Point.Select
+ }
+ super(Select1Point, self).__init__(plot, parameters, states, 'idle')
+
+ def select(self, x, y):
+ pass
+
+ def endSelect(self, x, y):
+ pass
+
+ def cancelSelect(self):
+ pass
+
+ def cancel(self):
+ if isinstance(self.state, self.states['select']):
+ self.cancelSelect()
+
+
+class SelectHLine(Select1Point):
+ """Drawing a horizontal line selection area state machine."""
+ def _hLine(self, y):
+ """Return points in data coords of the segment visible in the plot.
+
+ Supports non-orthogonal axes.
+ """
+ left, _top, width, _height = self.plot.getPlotBoundsInPixels()
+
+ dataPos1 = self.plot.pixelToData(left, y, check=False)
+ dataPos2 = self.plot.pixelToData(left + width, y, check=False)
+ return dataPos1, dataPos2
+
+ def select(self, x, y):
+ points = self._hLine(y)
+ self.setSelectionArea(points, fill='hatch', color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'hline',
+ points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'hline',
+ self._hLine(y),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class SelectVLine(Select1Point):
+ """Drawing a vertical line selection area state machine."""
+ def _vLine(self, x):
+ """Return points in data coords of the segment visible in the plot.
+
+ Supports non-orthogonal axes.
+ """
+ _left, top, _width, height = self.plot.getPlotBoundsInPixels()
+
+ dataPos1 = self.plot.pixelToData(x, top, check=False)
+ dataPos2 = self.plot.pixelToData(x, top + height, check=False)
+ return dataPos1, dataPos2
+
+ def select(self, x, y):
+ points = self._vLine(x)
+ self.setSelectionArea(points, fill='hatch', color=self.color)
+
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'vline',
+ points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ self.resetSelectionArea()
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'vline',
+ self._vLine(x),
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+
+class DrawFreeHand(Select):
+ """Interaction for drawing pencil. It display the preview of the pencil
+ before pressing the mouse.
+ """
+
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.goto('select', x, y)
+ return True
+
+ def onMove(self, x, y):
+ self.machine.updatePencilShape(x, y)
+
+ def onLeave(self):
+ self.machine.cancel()
+
+ class Select(State):
+ def enterState(self, x, y):
+ self.__isOut = False
+ self.machine.setFirstPoint(x, y)
+
+ def onMove(self, x, y):
+ self.machine.updatePencilShape(x, y)
+ self.machine.select(x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ if self.__isOut:
+ self.machine.resetSelectionArea()
+ self.machine.endSelect(x, y)
+ self.goto('idle')
+
+ def onEnter(self):
+ self.__isOut = False
+
+ def onLeave(self):
+ self.__isOut = True
+
+ def __init__(self, plot, parameters):
+ # Circle used for pencil preview
+ angle = numpy.arange(13.) * numpy.pi * 2.0 / 13.
+ size = parameters.get('width', 1.) * 0.5
+ self._circle = size * numpy.array((numpy.cos(angle),
+ numpy.sin(angle))).T
+
+ states = {
+ 'idle': DrawFreeHand.Idle,
+ 'select': DrawFreeHand.Select
+ }
+ super(DrawFreeHand, self).__init__(plot, parameters, states, 'idle')
+
+ @property
+ def width(self):
+ return self.parameters.get('width', None)
+
+ def setFirstPoint(self, x, y):
+ self._points = []
+ self.select(x, y)
+
+ def updatePencilShape(self, x, y):
+ center = self.plot.pixelToData(x, y, check=False)
+ assert center is not None
+
+ polygon = center + self._circle
+
+ self.setSelectionArea(polygon, fill='none', color=self.color)
+
+ def select(self, x, y):
+ pos = self.plot.pixelToData(x, y, check=False)
+ if len(self._points) > 0:
+ if self._points[-1] == pos:
+ # Skip same points
+ return
+ self._points.append(pos)
+ eventDict = prepareDrawingSignal('drawingProgress',
+ 'polylines',
+ self._points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ def endSelect(self, x, y):
+ pos = self.plot.pixelToData(x, y, check=False)
+ if len(self._points) > 0:
+ if self._points[-1] != pos:
+ # Append if different
+ self._points.append(pos)
+
+ eventDict = prepareDrawingSignal('drawingFinished',
+ 'polylines',
+ self._points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+ self._points = None
+
+ def cancelSelect(self):
+ self.resetSelectionArea()
+
+ def cancel(self):
+ self.resetSelectionArea()
+
+
+class SelectFreeLine(ClickOrDrag, _PlotInteraction):
+ """Base class for drawing free lines with tools such as pencil."""
+
+ def __init__(self, plot, parameters):
+ """Init a state machine.
+
+ :param plot: The plot to apply changes to.
+ :param dict parameters: A dict of parameters such as color.
+ """
+ # self.DRAG_THRESHOLD_SQUARE_DIST = 1 # Disable first move threshold
+ self._points = []
+ ClickOrDrag.__init__(self)
+ _PlotInteraction.__init__(self, plot)
+ self.parameters = parameters
+
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.plot, scaleF, (x, y))
+
+ @property
+ def color(self):
+ return self.parameters.get('color', None)
+
+ def click(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self._processEvent(x, y, isLast=True)
+
+ def beginDrag(self, x, y, btn):
+ self._processEvent(x, y, isLast=False)
+
+ def drag(self, x, y, btn):
+ self._processEvent(x, y, isLast=False)
+
+ def endDrag(self, startPos, endPos, btn):
+ x, y = endPos
+ self._processEvent(x, y, isLast=True)
+
+ def cancel(self):
+ self.resetSelectionArea()
+ self._points = []
+
+ def _processEvent(self, x, y, isLast):
+ dataPos = self.plot.pixelToData(x, y, check=False)
+ isNewPoint = not self._points or dataPos != self._points[-1]
+
+ if isNewPoint:
+ self._points.append(dataPos)
+
+ if isNewPoint or isLast:
+ eventDict = prepareDrawingSignal(
+ 'drawingFinished' if isLast else 'drawingProgress',
+ 'polylines',
+ self._points,
+ self.parameters)
+ self.plot.notify(**eventDict)
+
+ if not isLast:
+ self.setSelectionArea(self._points, fill='none', color=self.color,
+ shape='polylines')
+ else:
+ self.cancel()
+
+
+# ItemInteraction #############################################################
+
+class ItemsInteraction(ClickOrDrag, _PlotInteraction):
+ """Interaction with items (markers, curves and images).
+
+ This class provides selection and dragging of plot primitives
+ that support those interaction.
+ It is also meant to be combined with the zoom interaction.
+ """
+
+ class Idle(ClickOrDrag.Idle):
+ def __init__(self, *args, **kw):
+ super(ItemsInteraction.Idle, self).__init__(*args, **kw)
+ self._hoverMarker = None
+
+ def onWheel(self, x, y, angle):
+ scaleF = 1.1 if angle > 0 else 1. / 1.1
+ applyZoomToPlot(self.machine.plot, scaleF, (x, y))
+
+ def onMove(self, x, y):
+ marker = self.machine.plot._getMarkerAt(x, y)
+
+ if marker is not None:
+ dataPos = self.machine.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareHoverSignal(
+ marker.getName(), 'marker',
+ dataPos, (x, y),
+ marker.isDraggable(),
+ marker.isSelectable())
+ self.machine.plot.notify(**eventDict)
+
+ if marker != self._hoverMarker:
+ self._hoverMarker = marker
+
+ if marker is None:
+ self.machine.plot.setGraphCursorShape()
+
+ elif marker.isDraggable():
+ if isinstance(marker, items.YMarker):
+ self.machine.plot.setGraphCursorShape(CURSOR_SIZE_VER)
+ elif isinstance(marker, items.XMarker):
+ self.machine.plot.setGraphCursorShape(CURSOR_SIZE_HOR)
+ else:
+ self.machine.plot.setGraphCursorShape(CURSOR_SIZE_ALL)
+
+ elif marker.isSelectable():
+ self.machine.plot.setGraphCursorShape(CURSOR_POINTING)
+ else:
+ self.machine.plot.setGraphCursorShape()
+
+ return True
+
+ def __init__(self, plot):
+ self._pan = Pan(plot)
+
+ _PlotInteraction.__init__(self, plot)
+ ClickOrDrag.__init__(self,
+ clickButtons=(LEFT_BTN, RIGHT_BTN),
+ dragButtons=(LEFT_BTN, MIDDLE_BTN))
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ eventDict = prepareMouseSignal('mouseClicked', btn,
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**eventDict)
+
+ eventDict = self._handleClick(x, y, btn)
+ if eventDict is not None:
+ self.plot.notify(**eventDict)
+
+ def _handleClick(self, x, y, btn):
+ """Perform picking and prepare event if click is handled here
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: event description to send of None if not handling event.
+ :rtype: dict or None
+ """
+
+ if btn == LEFT_BTN:
+ result = self.plot._pickTopMost(x, y, lambda i: i.isSelectable())
+ if result is None:
+ return None
+
+ item = result.getItem()
+
+ if isinstance(item, items.MarkerBase):
+ xData, yData = item.getPosition()
+ if xData is None:
+ xData = [0, 1]
+ if yData is None:
+ yData = [0, 1]
+
+ eventDict = prepareMarkerSignal('markerClicked',
+ 'left',
+ item.getName(),
+ 'marker',
+ item.isDraggable(),
+ item.isSelectable(),
+ (xData, yData),
+ (x, y), None)
+ return eventDict
+
+ elif isinstance(item, items.Curve):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ xData = item.getXData(copy=False)
+ yData = item.getYData(copy=False)
+
+ indices = result.getIndices(copy=False)
+ eventDict = prepareCurveSignal('left',
+ item.getName(),
+ 'curve',
+ xData[indices],
+ yData[indices],
+ dataPos[0], dataPos[1],
+ x, y)
+ return eventDict
+
+ elif isinstance(item, items.ImageBase):
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ indices = result.getIndices(copy=False)
+ row, column = indices[0][0], indices[1][0]
+ eventDict = prepareImageSignal('left',
+ item.getName(),
+ 'image',
+ column, row,
+ dataPos[0], dataPos[1],
+ x, y)
+ return eventDict
+
+ return None
+
+ def _signalMarkerMovingEvent(self, eventType, marker, x, y):
+ assert marker is not None
+
+ xData, yData = marker.getPosition()
+ if xData is None:
+ xData = [0, 1]
+ if yData is None:
+ yData = [0, 1]
+
+ posDataCursor = self.plot.pixelToData(x, y)
+ assert posDataCursor is not None
+
+ eventDict = prepareMarkerSignal(eventType,
+ 'left',
+ marker.getName(),
+ 'marker',
+ marker.isDraggable(),
+ marker.isSelectable(),
+ (xData, yData),
+ (x, y),
+ posDataCursor)
+ self.plot.notify(**eventDict)
+
+ @staticmethod
+ def __isDraggableItem(item):
+ return isinstance(item, items.DraggableMixIn) and item.isDraggable()
+
+ def __terminateDrag(self):
+ """Finalize a drag operation by reseting to initial state"""
+ self.plot.setGraphCursorShape()
+ self.draggedItemRef = None
+
+ def beginDrag(self, x, y, btn):
+ """Handle begining of drag interaction
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param str btn: The mouse button for which a drag is starting.
+ :return: True if drag is catched by an item, False otherwise
+ """
+ if btn == LEFT_BTN:
+ self._lastPos = self.plot.pixelToData(x, y)
+ assert self._lastPos is not None
+
+ result = self.plot._pickTopMost(x, y, self.__isDraggableItem)
+ item = result.getItem() if result is not None else None
+
+ self.draggedItemRef = None if item is None else weakref.ref(item)
+
+ if item is None:
+ self.__terminateDrag()
+ return False
+
+ if isinstance(item, items.MarkerBase):
+ self._signalMarkerMovingEvent('markerMoving', item, x, y)
+ item._startDrag()
+
+ return True
+ elif btn == MIDDLE_BTN:
+ self._pan.beginDrag(x, y, btn)
+ return True
+
+ def drag(self, x, y, btn):
+ if btn == LEFT_BTN:
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+
+ item = None if self.draggedItemRef is None else self.draggedItemRef()
+ if item is not None:
+ item.drag(self._lastPos, dataPos)
+
+ if isinstance(item, items.MarkerBase):
+ self._signalMarkerMovingEvent('markerMoving', item, x, y)
+
+ self._lastPos = dataPos
+ elif btn == MIDDLE_BTN:
+ self._pan.drag(x, y, btn)
+
+ def endDrag(self, startPos, endPos, btn):
+ if btn == LEFT_BTN:
+ item = None if self.draggedItemRef is None else self.draggedItemRef()
+ if isinstance(item, items.MarkerBase):
+ posData = list(item.getPosition())
+ if posData[0] is None:
+ posData[0] = 1.
+ if posData[1] is None:
+ posData[1] = 1.
+
+ eventDict = prepareMarkerSignal(
+ 'markerMoved',
+ 'left',
+ item.getLegend(),
+ 'marker',
+ item.isDraggable(),
+ item.isSelectable(),
+ posData)
+ self.plot.notify(**eventDict)
+ item._endDrag()
+
+ self.__terminateDrag()
+ elif btn == MIDDLE_BTN:
+ self._pan.endDrag(startPos, endPos, btn)
+
+ def cancel(self):
+ self._pan.cancel()
+ self.__terminateDrag()
+
+
+class ItemsInteractionForCombo(ItemsInteraction):
+ """Interaction with items to combine through :class:`FocusManager`.
+ """
+
+ class Idle(ItemsInteraction.Idle):
+ @staticmethod
+ def __isItemSelectableOrDraggable(item):
+ return (item.isSelectable() or (
+ isinstance(item, items.DraggableMixIn) and item.isDraggable()))
+
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ result = self.machine.plot._pickTopMost(
+ x, y, self.__isItemSelectableOrDraggable)
+ if result is not None: # Request focus and handle interaction
+ self.goto('clickOrDrag', x, y, btn)
+ return True
+ else: # Do not request focus
+ return False
+ else:
+ return super().onPress(x, y, btn)
+
+
+# FocusManager ################################################################
+
+class FocusManager(StateMachine):
+ """Manages focus across multiple event handlers
+
+ On press an event handler can acquire focus.
+ By default it looses focus when all buttons are released.
+ """
+ class Idle(State):
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ for eventHandler in self.machine.eventHandlers:
+ requestFocus = eventHandler.handleEvent('press', x, y, btn)
+ if requestFocus:
+ self.goto('focus', eventHandler, btn)
+ break
+
+ def _processEvent(self, *args):
+ for eventHandler in self.machine.eventHandlers:
+ consumeEvent = eventHandler.handleEvent(*args)
+ if consumeEvent:
+ break
+
+ def onMove(self, x, y):
+ self._processEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self._processEvent('release', x, y, btn)
+
+ def onWheel(self, x, y, angle):
+ self._processEvent('wheel', x, y, angle)
+
+ class Focus(State):
+ def enterState(self, eventHandler, btn):
+ self.eventHandler = eventHandler
+ self.focusBtns = {btn}
+
+ def validate(self):
+ self.eventHandler.validate()
+ self.goto('idle')
+
+ def onPress(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.focusBtns.add(btn)
+ self.eventHandler.handleEvent('press', x, y, btn)
+
+ def onMove(self, x, y):
+ self.eventHandler.handleEvent('move', x, y)
+
+ def onRelease(self, x, y, btn):
+ if btn == LEFT_BTN:
+ self.focusBtns.discard(btn)
+ requestFocus = self.eventHandler.handleEvent('release', x, y, btn)
+ if len(self.focusBtns) == 0 and not requestFocus:
+ self.goto('idle')
+
+ def onWheel(self, x, y, angleInDegrees):
+ self.eventHandler.handleEvent('wheel', x, y, angleInDegrees)
+
+ def __init__(self, eventHandlers=()):
+ self.eventHandlers = list(eventHandlers)
+
+ states = {
+ 'idle': FocusManager.Idle,
+ 'focus': FocusManager.Focus
+ }
+ super(FocusManager, self).__init__(states, 'idle')
+
+ def cancel(self):
+ for handler in self.eventHandlers:
+ handler.cancel()
+
+
+class ZoomAndSelect(ItemsInteraction):
+ """Combine Zoom and ItemInteraction state machine.
+
+ :param plot: The Plot to which this interaction is attached
+ :param color: The color to use for the zoom area bounding box
+ """
+
+ def __init__(self, plot, color):
+ super(ZoomAndSelect, self).__init__(plot)
+ self._zoom = Zoom(plot, color)
+ self._doZoom = False
+
+ @property
+ def color(self):
+ """Color of the zoom area"""
+ return self._zoom.color
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ eventDict = self._handleClick(x, y, btn)
+
+ if eventDict is not None:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ clickedEventDict = prepareMouseSignal('mouseClicked', btn,
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**clickedEventDict)
+
+ self.plot.notify(**eventDict)
+
+ else:
+ self._zoom.click(x, y, btn)
+
+ def beginDrag(self, x, y, btn):
+ """Handle start drag and switching between zoom and item drag.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is starting.
+ """
+ self._doZoom = not super(ZoomAndSelect, self).beginDrag(x, y, btn)
+ if self._doZoom:
+ self._zoom.beginDrag(x, y, btn)
+
+ def drag(self, x, y, btn):
+ """Handle drag, eventually forwarding to zoom.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is in progress.
+ """
+ if self._doZoom:
+ return self._zoom.drag(x, y, btn)
+ else:
+ return super(ZoomAndSelect, self).drag(x, y, btn)
+
+ def endDrag(self, startPos, endPos, btn):
+ """Handle end of drag, eventually forwarding to zoom.
+
+ :param startPos: (x, y) position at the beginning of the drag
+ :param endPos: (x, y) position at the end of the drag
+ :param str btn: The mouse button for which a drag is done.
+ """
+ if self._doZoom:
+ return self._zoom.endDrag(startPos, endPos, btn)
+ else:
+ return super(ZoomAndSelect, self).endDrag(startPos, endPos, btn)
+
+
+class PanAndSelect(ItemsInteraction):
+ """Combine Pan and ItemInteraction state machine.
+
+ :param plot: The Plot to which this interaction is attached
+ """
+
+ def __init__(self, plot):
+ super(PanAndSelect, self).__init__(plot)
+ self._pan = Pan(plot)
+ self._doPan = False
+
+ def click(self, x, y, btn):
+ """Handle mouse click
+
+ :param x: X position of the mouse in pixels
+ :param y: Y position of the mouse in pixels
+ :param btn: Pressed button id
+ :return: True if click is catched by an item, False otherwise
+ """
+ eventDict = self._handleClick(x, y, btn)
+
+ if eventDict is not None:
+ # Signal mouse clicked event
+ dataPos = self.plot.pixelToData(x, y)
+ assert dataPos is not None
+ clickedEventDict = prepareMouseSignal('mouseClicked', btn,
+ dataPos[0], dataPos[1],
+ x, y)
+ self.plot.notify(**clickedEventDict)
+
+ self.plot.notify(**eventDict)
+
+ else:
+ self._pan.click(x, y, btn)
+
+ def beginDrag(self, x, y, btn):
+ """Handle start drag and switching between zoom and item drag.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is starting.
+ """
+ self._doPan = not super(PanAndSelect, self).beginDrag(x, y, btn)
+ if self._doPan:
+ self._pan.beginDrag(x, y, btn)
+
+ def drag(self, x, y, btn):
+ """Handle drag, eventually forwarding to zoom.
+
+ :param x: X position in pixels
+ :param y: Y position in pixels
+ :param str btn: The mouse button for which a drag is in progress.
+ """
+ if self._doPan:
+ return self._pan.drag(x, y, btn)
+ else:
+ return super(PanAndSelect, self).drag(x, y, btn)
+
+ def endDrag(self, startPos, endPos, btn):
+ """Handle end of drag, eventually forwarding to zoom.
+
+ :param startPos: (x, y) position at the beginning of the drag
+ :param endPos: (x, y) position at the end of the drag
+ :param str btn: The mouse button for which a drag is done.
+ """
+ if self._doPan:
+ return self._pan.endDrag(startPos, endPos, btn)
+ else:
+ return super(PanAndSelect, self).endDrag(startPos, endPos, btn)
+
+
+# Interaction mode control ####################################################
+
+# Mapping of draw modes: event handler
+_DRAW_MODES = {
+ 'polygon': SelectPolygon,
+ 'rectangle': SelectRectangle,
+ 'ellipse': SelectEllipse,
+ 'line': SelectLine,
+ 'vline': SelectVLine,
+ 'hline': SelectHLine,
+ 'polylines': SelectFreeLine,
+ 'pencil': DrawFreeHand,
+ }
+
+
+class DrawMode(FocusManager):
+ """Interactive mode for draw and select"""
+
+ def __init__(self, plot, shape, label, color, width):
+ eventHandlerClass = _DRAW_MODES[shape]
+ parameters = {
+ 'shape': shape,
+ 'label': label,
+ 'color': color,
+ 'width': width,
+ }
+ super().__init__((
+ Pan(plot, clickButtons=(), dragButtons=(MIDDLE_BTN,)),
+ eventHandlerClass(plot, parameters)))
+
+ def getDescription(self):
+ """Returns the dict describing this interactive mode"""
+ params = self.eventHandlers[1].parameters.copy()
+ params['mode'] = 'draw'
+ return params
+
+
+class DrawSelectMode(FocusManager):
+ """Interactive mode for draw and select"""
+
+ def __init__(self, plot, shape, label, color, width):
+ eventHandlerClass = _DRAW_MODES[shape]
+ self._pan = Pan(plot)
+ self._panStart = None
+ parameters = {
+ 'shape': shape,
+ 'label': label,
+ 'color': color,
+ 'width': width,
+ }
+ super().__init__((
+ ItemsInteractionForCombo(plot),
+ eventHandlerClass(plot, parameters)))
+
+ def handleEvent(self, eventName, *args, **kwargs):
+ # Hack to add pan interaction to select-draw
+ # See issue Refactor PlotWidget interaction #3292
+ if eventName == 'press' and args[2] == MIDDLE_BTN:
+ self._panStart = args[:2]
+ self._pan.beginDrag(*args)
+ return # Consume middle click events
+ elif eventName == 'release' and args[2] == MIDDLE_BTN:
+ self._panStart = None
+ self._pan.endDrag(self._panStart, args[:2], MIDDLE_BTN)
+ return # Consume middle click events
+ elif self._panStart is not None and eventName == 'move':
+ x, y = args[:2]
+ self._pan.drag(x, y, MIDDLE_BTN)
+
+ super().handleEvent(eventName, *args, **kwargs)
+
+ def getDescription(self):
+ """Returns the dict describing this interactive mode"""
+ params = self.eventHandlers[1].parameters.copy()
+ params['mode'] = 'select-draw'
+ return params
+
+
+class PlotInteraction(object):
+ """Proxy to currently use state machine for interaction.
+
+ This allows to switch interactive mode.
+
+ :param plot: The :class:`Plot` to apply interaction to
+ """
+
+ _DRAW_MODES = {
+ 'polygon': SelectPolygon,
+ 'rectangle': SelectRectangle,
+ 'ellipse': SelectEllipse,
+ 'line': SelectLine,
+ 'vline': SelectVLine,
+ 'hline': SelectHLine,
+ 'polylines': SelectFreeLine,
+ 'pencil': DrawFreeHand,
+ }
+
+ def __init__(self, plot):
+ self._plot = weakref.ref(plot) # Avoid cyclic-ref
+
+ self.zoomOnWheel = True
+ """True to enable zoom on wheel, False otherwise."""
+
+ # Default event handler
+ self._eventHandler = ItemsInteraction(plot)
+
+ def getInteractiveMode(self):
+ """Returns the current interactive mode as a dict.
+
+ The returned dict contains at least the key 'mode'.
+ Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ It can also contains extra keys (e.g., 'color') specific to a mode
+ as provided to :meth:`setInteractiveMode`.
+ """
+ if isinstance(self._eventHandler, ZoomAndSelect):
+ return {'mode': 'zoom', 'color': self._eventHandler.color}
+
+ elif isinstance(self._eventHandler, (DrawMode, DrawSelectMode)):
+ return self._eventHandler.getDescription()
+
+ elif isinstance(self._eventHandler, PanAndSelect):
+ return {'mode': 'pan'}
+
+ else:
+ return {'mode': 'select'}
+
+ def validate(self):
+ """Validate the current interaction if possible
+
+ If was designed to close the polygon interaction.
+ """
+ self._eventHandler.validate()
+
+ def setInteractiveMode(self, mode, color='black',
+ shape='polygon', label=None, width=None):
+ """Switch the interactive mode.
+
+ :param str mode: The name of the interactive mode.
+ In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ :param color: Only for 'draw' and 'zoom' modes.
+ Color to use for drawing selection area. Default black.
+ If None, selection area is not drawn.
+ :type color: Color description: The name as a str or
+ a tuple of 4 floats or None.
+ :param str shape: Only for 'draw' mode. The kind of shape to draw.
+ In 'polygon', 'rectangle', 'line', 'vline', 'hline',
+ 'polylines'.
+ Default is 'polygon'.
+ :param str label: Only for 'draw' mode.
+ :param float width: Width of the pencil. Only for draw pencil mode.
+ """
+ assert mode in ('draw', 'pan', 'select', 'select-draw', 'zoom')
+
+ plot = self._plot()
+ assert plot is not None
+
+ if isinstance(color, numpy.ndarray) or color not in (None, 'video inverted'):
+ color = colors.rgba(color)
+
+ if mode in ('draw', 'select-draw'):
+ self._eventHandler.cancel()
+ handlerClass = DrawMode if mode == 'draw' else DrawSelectMode
+ self._eventHandler = handlerClass(plot, shape, label, color, width)
+
+ elif mode == 'pan':
+ # Ignores color, shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = PanAndSelect(plot)
+
+ elif mode == 'zoom':
+ # Ignores shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = ZoomAndSelect(plot, color)
+
+ else: # Default mode: interaction with plot objects
+ # Ignores color, shape and label
+ self._eventHandler.cancel()
+ self._eventHandler = ItemsInteraction(plot)
+
+ def handleEvent(self, event, *args, **kwargs):
+ """Forward event to current interactive mode state machine."""
+ if not self.zoomOnWheel and event == 'wheel':
+ return # Discard wheel events
+ self._eventHandler.handleEvent(event, *args, **kwargs)
diff --git a/silx/gui/plot/PlotToolButtons.py b/src/silx/gui/plot/PlotToolButtons.py
index 3970896..3970896 100644
--- a/silx/gui/plot/PlotToolButtons.py
+++ b/src/silx/gui/plot/PlotToolButtons.py
diff --git a/silx/gui/plot/PlotTools.py b/src/silx/gui/plot/PlotTools.py
index 5929473..5929473 100644
--- a/silx/gui/plot/PlotTools.py
+++ b/src/silx/gui/plot/PlotTools.py
diff --git a/src/silx/gui/plot/PlotWidget.py b/src/silx/gui/plot/PlotWidget.py
new file mode 100755
index 0000000..6cb5ef5
--- /dev/null
+++ b/src/silx/gui/plot/PlotWidget.py
@@ -0,0 +1,3628 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Qt widget providing plot API for 1D and 2D data.
+
+The :class:`PlotWidget` implements the plot API initially provided in PyMca.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+from collections import OrderedDict, namedtuple
+from contextlib import contextmanager
+import datetime as dt
+import itertools
+import typing
+import warnings
+
+import numpy
+
+import silx
+from silx.utils.weakref import WeakMethodProxy
+from silx.utils.property import classproperty
+from silx.utils.deprecation import deprecated, deprecated_warning
+try:
+ # Import matplotlib now to init matplotlib our way
+ import silx.gui.utils.matplotlib # noqa
+except ImportError:
+ _logger.debug("matplotlib not available")
+
+from ..colors import Colormap
+from .. import colors
+from . import PlotInteraction
+from . import PlotEvents
+from .LimitsHistory import LimitsHistory
+from . import _utils
+
+from . import items
+from .items.curve import CurveStyle
+from .items.axis import TickMode # noqa
+
+from .. import qt
+from ._utils.panzoom import ViewConstraints
+from ...gui.plot._utils.dtime_ticklayout import timestamp
+
+
+
+_COLORDICT = colors.COLORDICT
+_COLORLIST = silx.config.DEFAULT_PLOT_CURVE_COLORS
+
+"""
+Object returned when requesting the data range.
+"""
+_PlotDataRange = namedtuple('PlotDataRange',
+ ['x', 'y', 'yright'])
+
+
+class _PlotWidgetSelection(qt.QObject):
+ """Object managing a :class:`PlotWidget` selection.
+
+ It is a wrapper over :class:`PlotWidget`'s active items API.
+
+ :param PlotWidget parent:
+ """
+
+ sigCurrentItemChanged = qt.Signal(object, object)
+ """This signal is emitted whenever the current item changes.
+
+ It provides the current and previous items.
+ """
+
+ sigSelectedItemsChanged = qt.Signal()
+ """Signal emitted whenever the list of selected items changes."""
+
+ def __init__(self, parent):
+ assert isinstance(parent, PlotWidget)
+ super(_PlotWidgetSelection, self).__init__(parent=parent)
+
+ # Init history
+ self.__history = [ # Store active items from most recent to oldest
+ item for item in (parent.getActiveCurve(),
+ parent.getActiveImage(),
+ parent.getActiveScatter())
+ if item is not None]
+
+ self.__current = self.__mostRecentActiveItem()
+
+ parent.sigActiveImageChanged.connect(self._activeImageChanged)
+ parent.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ parent.sigActiveScatterChanged.connect(self._activeScatterChanged)
+
+ def __mostRecentActiveItem(self) -> typing.Optional[items.Item]:
+ """Returns most recent active item."""
+ return self.__history[0] if len(self.__history) >= 1 else None
+
+ def getSelectedItems(self) -> typing.Tuple[items.Item]:
+ """Returns the list of currently selected items in the :class:`PlotWidget`.
+
+ The list is given from most recently current item to oldest one."""
+ plot = self.parent()
+ if plot is None:
+ return ()
+
+ active = tuple(self.__history)
+
+ current = self.getCurrentItem()
+ if current is not None and current not in active:
+ # Current might not be an active item, if so add it
+ active = (current,) + active
+
+ return active
+
+ def getCurrentItem(self) -> typing.Optional[items.Item]:
+ """Returns the current item in the :class:`PlotWidget` or None. """
+ return self.__current
+
+ def setCurrentItem(self, item: typing.Optional[items.Item]):
+ """Set the current item in the :class:`PlotWidget`.
+
+ :param item:
+ The new item to select or None to clear the selection.
+ :raise ValueError: If the item is not the :class:`PlotWidget`
+ """
+ previous = self.getCurrentItem()
+ if previous is item:
+ return
+
+ previousSelected = self.getSelectedItems()
+
+ if item is None:
+ self.__current = None
+
+ # Reset all PlotWidget active items
+ plot = self.parent()
+ if plot is not None:
+ for kind in PlotWidget._ACTIVE_ITEM_KINDS:
+ if plot._getActiveItem(kind) is not None:
+ plot._setActiveItem(kind, None)
+
+ elif isinstance(item, items.Item):
+ plot = self.parent()
+ if plot is None or item.getPlot() is not plot:
+ raise ValueError(
+ "Item is not in the PlotWidget: %s" % str(item))
+ self.__current = item
+
+ kind = plot._itemKind(item)
+
+ # Clean-up history to be safe
+ self.__history = [item for item in self.__history
+ if PlotWidget._itemKind(item) != kind]
+
+ # Sync active item if needed
+ if (kind in plot._ACTIVE_ITEM_KINDS and
+ item is not plot._getActiveItem(kind)):
+ plot._setActiveItem(kind, item.getName())
+ else:
+ raise ValueError("Not an Item: %s" % str(item))
+
+ self.sigCurrentItemChanged.emit(previous, item)
+
+ if previousSelected != self.getSelectedItems():
+ self.sigSelectedItemsChanged.emit()
+
+ def __activeItemChanged(self,
+ kind: str,
+ previous: typing.Optional[str],
+ legend: typing.Optional[str]):
+ """Set current item from kind and legend"""
+ if previous == legend:
+ return # No-op for update of item
+
+ plot = self.parent()
+ if plot is None:
+ return
+
+ previousSelected = self.getSelectedItems()
+
+ # Remove items of this kind from the history
+ self.__history = [item for item in self.__history
+ if PlotWidget._itemKind(item) != kind]
+
+ # Retrieve current item
+ if legend is None: # Use most recent active item
+ currentItem = self.__mostRecentActiveItem()
+ else:
+ currentItem = plot._getItem(kind=kind, legend=legend)
+ if currentItem is None: # Fallback in case something went wrong
+ currentItem = self.__mostRecentActiveItem()
+
+ # Update history
+ if currentItem is not None:
+ while currentItem in self.__history:
+ self.__history.remove(currentItem)
+ self.__history.insert(0, currentItem)
+
+ if currentItem != self.__current:
+ previousItem = self.__current
+ self.__current = currentItem
+ self.sigCurrentItemChanged.emit(previousItem, currentItem)
+
+ if previousSelected != self.getSelectedItems():
+ self.sigSelectedItemsChanged.emit()
+
+ def _activeImageChanged(self, previous, current):
+ """Handle active image change"""
+ self.__activeItemChanged('image', previous, current)
+
+ def _activeCurveChanged(self, previous, current):
+ """Handle active curve change"""
+ self.__activeItemChanged('curve', previous, current)
+
+ def _activeScatterChanged(self, previous, current):
+ """Handle active scatter change"""
+ self.__activeItemChanged('scatter', previous, current)
+
+
+class PlotWidget(qt.QMainWindow):
+ """Qt Widget providing a 1D/2D plot.
+
+ This widget is a QMainWindow.
+ This class implements the plot API initially provided in PyMca.
+
+ Supported backends:
+
+ - 'matplotlib' and 'mpl': Matplotlib with Qt.
+ - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
+ - 'none': No backend, to run headless for testing purpose.
+
+ :param parent: The parent of this widget or None (default).
+ :param backend: The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ # TODO: Can be removed for silx 0.10
+ @classproperty
+ @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
+ def DEFAULT_BACKEND(self):
+ """Class attribute setting the default backend for all instances."""
+ return silx.config.DEFAULT_PLOT_BACKEND
+
+ colorList = _COLORLIST
+ colorDict = _COLORDICT
+
+ sigPlotSignal = qt.Signal(object)
+ """Signal for all events of the plot.
+
+ The signal information is provided as a dict.
+ See the :ref:`plot signal documentation page <plot_signal>` for
+ information about the content of the dict
+ """
+
+ sigSetKeepDataAspectRatio = qt.Signal(bool)
+ """Signal emitted when plot keep aspect ratio has changed"""
+
+ sigSetGraphGrid = qt.Signal(str)
+ """Signal emitted when plot grid has changed"""
+
+ sigSetGraphCursor = qt.Signal(bool)
+ """Signal emitted when plot crosshair cursor has changed"""
+
+ sigSetPanWithArrowKeys = qt.Signal(bool)
+ """Signal emitted when pan with arrow keys has changed"""
+
+ _sigAxesVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the axes visibility changed"""
+
+ sigContentChanged = qt.Signal(str, str, str)
+ """Signal emitted when the content of the plot is changed.
+
+ It provides the following information:
+
+ - action: The change of the plot: 'add' or 'remove'
+ - kind: The kind of primitive changed:
+ 'curve', 'image', 'scatter', 'histogram', 'item' or 'marker'
+ - legend: The legend of the primitive changed.
+ """
+
+ sigActiveCurveChanged = qt.Signal(object, object)
+ """Signal emitted when the active curve has changed.
+
+ It provides the following information:
+
+ - previous: The legend of the previous active curve or None
+ - legend: The legend of the new active curve or None if no curve is active
+ """
+
+ sigActiveImageChanged = qt.Signal(object, object)
+ """Signal emitted when the active image has changed.
+
+ It provides the following information:
+
+ - previous: The legend of the previous active image or None
+ - legend: The legend of the new active image or None if no image is active
+ """
+
+ sigActiveScatterChanged = qt.Signal(object, object)
+ """Signal emitted when the active Scatter has changed.
+
+ It provides the following information:
+
+ - previous: The legend of the previous active scatter or None
+ - legend: The legend of the new active image or None if no image is active
+ """
+
+ sigInteractiveModeChanged = qt.Signal(object)
+ """Signal emitted when the interactive mode has changed
+
+ It provides the source as passed to :meth:`setInteractiveMode`.
+ """
+
+ sigItemAdded = qt.Signal(items.Item)
+ """Signal emitted when an item was just added to the plot
+
+ It provides the added item.
+ """
+
+ sigItemAboutToBeRemoved = qt.Signal(items.Item)
+ """Signal emitted right before an item is removed from the plot.
+
+ It provides the item that will be removed.
+ """
+
+ sigItemRemoved = qt.Signal(items.Item)
+ """Signal emitted right after an item was removed from the plot.
+
+ It provides the item that was removed.
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the widget becomes visible (or invisible).
+ This happens when the widget is hidden or shown.
+
+ It provides the visible state.
+ """
+
+ _sigDefaultContextMenu = qt.Signal(qt.QMenu)
+ """Signal emitted when the default context menu of the plot is feed.
+
+ It provides the menu which will be displayed.
+ """
+
+ def __init__(self, parent=None, backend=None):
+ self._autoreplot = False
+ self._dirty = False
+ self._cursorInPlot = False
+ self.__muteActiveItemChanged = False
+
+ self._panWithArrowKeys = True
+ self._viewConstrains = None
+
+ super(PlotWidget, self).__init__(parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle('PlotWidget')
+
+ # Init the backend
+ self._backend = self.__getBackendClass(backend)(self, self)
+
+ self.setCallback() # set _callback
+
+ # Items handling
+ self._content = OrderedDict()
+ self._contentToUpdate = [] # Used as an OrderedSet
+
+ self._dataRange = None
+
+ # line types
+ self._styleList = ['-', '--', '-.', ':']
+ self._colorIndex = 0
+ self._styleIndex = 0
+
+ self._activeCurveSelectionMode = "atmostone"
+ self._activeCurveStyle = CurveStyle(color='#000000')
+ self._activeLegend = {'curve': None, 'image': None,
+ 'scatter': None}
+
+ # plot colors (updated later to sync backend)
+ self._foregroundColor = 0., 0., 0., 1.
+ self._gridColor = .7, .7, .7, 1.
+ self._backgroundColor = 1., 1., 1., 1.
+ self._dataBackgroundColor = None
+
+ # default properties
+ self._cursorConfiguration = None
+
+ self._xAxis = items.XAxis(self)
+ self._yAxis = items.YAxis(self)
+ self._yRightAxis = items.YRightAxis(self, self._yAxis)
+
+ self._grid = None
+ self._graphTitle = ''
+ self.__graphCursorShape = 'default'
+
+ # Set axes margins
+ self.__axesDisplayed = True
+ self.__axesMargins = 0., 0., 0., 0.
+ self.setAxesMargins(.15, .1, .1, .15)
+
+ self.setGraphTitle()
+ self.setGraphXLabel()
+ self.setGraphYLabel()
+ self.setGraphYLabel('', axis='right')
+
+ self.setDefaultColormap() # Init default colormap
+
+ self.setDefaultPlotPoints(silx.config.DEFAULT_PLOT_CURVE_SYMBOL_MODE)
+ self.setDefaultPlotLines(True)
+
+ self._limitsHistory = LimitsHistory(self)
+
+ self._eventHandler = PlotInteraction.PlotInteraction(self)
+ self._eventHandler.setInteractiveMode('zoom', color=(0., 0., 0., 1.))
+ self._previousDefaultMode = "zoom", True
+
+ self._pressedButtons = [] # Currently pressed mouse buttons
+
+ self._defaultDataMargins = (0., 0., 0., 0.)
+
+ # Only activate autoreplot at the end
+ # This avoids errors when loaded in Qt designer
+ self._dirty = False
+ self._autoreplot = True
+
+ widget = self.getWidgetHandle()
+ if widget is not None:
+ self.setCentralWidget(widget)
+ else:
+ _logger.info("PlotWidget backend does not support widget")
+
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self.setFocus(qt.Qt.OtherFocusReason)
+
+ # Set default limits
+ self.setGraphXLimits(0., 100.)
+ self.setGraphYLimits(0., 100., axis='right')
+ self.setGraphYLimits(0., 100., axis='left')
+
+ # Sync backend colors with default ones
+ self._foregroundColorsUpdated()
+ self._backgroundColorsUpdated()
+
+ # selection handling
+ self.__selection = None
+
+ def __getBackendClass(self, backend):
+ """Returns backend class corresponding to backend.
+
+ If multiple backends are provided, the first available one is used.
+
+ :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend:
+ The name of the backend or its class or an iterable of those.
+ :rtype: BackendBase
+ :raise ValueError: In case the backend is not supported
+ :raise RuntimeError: If a backend is not available
+ """
+ if backend is None:
+ backend = silx.config.DEFAULT_PLOT_BACKEND
+
+ if callable(backend):
+ return backend
+
+ elif isinstance(backend, str):
+ backend = backend.lower()
+ if backend in ('matplotlib', 'mpl'):
+ try:
+ from .backends.BackendMatplotlib import \
+ BackendMatplotlibQt as backendClass
+ except ImportError:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("matplotlib backend is not available")
+
+ elif backend in ('gl', 'opengl'):
+ from ..utils.glutils import isOpenGLAvailable
+ checkOpenGL = isOpenGLAvailable(version=(2, 1), runtimeCheck=False)
+ if not checkOpenGL:
+ _logger.debug("OpenGL check failed")
+ raise RuntimeError(
+ "OpenGL backend is not available: %s" % checkOpenGL.error)
+
+ try:
+ from .backends.BackendOpenGL import \
+ BackendOpenGL as backendClass
+ except ImportError:
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError("OpenGL backend is not available")
+
+ elif backend == 'none':
+ from .backends.BackendBase import BackendBase as backendClass
+
+ else:
+ raise ValueError("Backend not supported %s" % backend)
+
+ return backendClass
+
+ elif isinstance(backend, (tuple, list)):
+ for b in backend:
+ try:
+ return self.__getBackendClass(b)
+ except RuntimeError:
+ pass
+ else: # No backend was found
+ raise RuntimeError("None of the request backends are available")
+
+ raise ValueError("Backend not supported %s" % str(backend))
+
+ def selection(self):
+ """Returns the selection hander"""
+ if self.__selection is None: # Lazy initialization
+ self.__selection = _PlotWidgetSelection(parent=self)
+ return self.__selection
+
+ # TODO: Can be removed for silx 0.10
+ @staticmethod
+ @deprecated(replacement="silx.config.DEFAULT_PLOT_BACKEND", since_version="0.8", skip_backtrace_count=2)
+ def setDefaultBackend(backend):
+ """Set system wide default plot backend.
+
+ .. versionadded:: 0.6
+
+ :param backend: The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none'
+ or a :class:`BackendBase.BackendBase` class
+ """
+ silx.config.DEFAULT_PLOT_BACKEND = backend
+
+ def setBackend(self, backend):
+ """Set the backend to use for rendering.
+
+ Supported backends:
+
+ - 'matplotlib' and 'mpl': Matplotlib with Qt.
+ - 'opengl' and 'gl': OpenGL backend (requires PyOpenGL and OpenGL >= 2.1)
+ - 'none': No backend, to run headless for testing purpose.
+
+ :param Union[str,BackendBase,List[Union[str,BackendBase]]] backend:
+ The backend to use, in:
+ 'matplotlib' (default), 'mpl', 'opengl', 'gl', 'none',
+ a :class:`BackendBase.BackendBase` class.
+ If multiple backends are provided, the first available one is used.
+ :raises ValueError: Unsupported backend descriptor
+ :raises RuntimeError: Error while loading a backend
+ """
+ backend = self.__getBackendClass(backend)(self, self)
+
+ # First save state that is stored in the backend
+ xaxis = self.getXAxis()
+ xmin, xmax = xaxis.getLimits()
+ ymin, ymax = self.getYAxis(axis='left').getLimits()
+ y2min, y2max = self.getYAxis(axis='right').getLimits()
+ isKeepDataAspectRatio = self.isKeepDataAspectRatio()
+ xTimeZone = xaxis.getTimeZone()
+ isXAxisTimeSeries = xaxis.getTickMode() == TickMode.TIME_SERIES
+
+ isYAxisInverted = self.getYAxis().isInverted()
+
+ # Remove all items from previous backend
+ for item in self.getItems():
+ item._removeBackendRenderer(self._backend)
+
+ # Switch backend
+ self._backend = backend
+ widget = self._backend.getWidgetHandle()
+ self.setCentralWidget(widget)
+ if widget is None:
+ _logger.info("PlotWidget backend does not support widget")
+
+ # Mark as newly dirty
+ self._dirty = False
+ self._setDirtyPlot()
+
+ # Synchronize/restore state
+ self._foregroundColorsUpdated()
+ self._backgroundColorsUpdated()
+
+ self._backend.setGraphCursorShape(self.getGraphCursorShape())
+ crosshairConfig = self.getGraphCursor()
+ if crosshairConfig is None:
+ self._backend.setGraphCursor(False, 'black', 1, '-')
+ else:
+ self._backend.setGraphCursor(True, *crosshairConfig)
+
+ self._backend.setGraphTitle(self.getGraphTitle())
+ self._backend.setGraphGrid(self.getGraphGrid())
+ if self.isAxesDisplayed():
+ self._backend.setAxesMargins(*self.getAxesMargins())
+ else:
+ self._backend.setAxesMargins(0., 0., 0., 0.)
+
+ # Set axes
+ xaxis = self.getXAxis()
+ self._backend.setGraphXLabel(xaxis.getLabel())
+ self._backend.setXAxisTimeZone(xTimeZone)
+ self._backend.setXAxisTimeSeries(isXAxisTimeSeries)
+ self._backend.setXAxisLogarithmic(
+ xaxis.getScale() == items.Axis.LOGARITHMIC)
+
+ for axis in ('left', 'right'):
+ self._backend.setGraphYLabel(self.getYAxis(axis).getLabel(), axis)
+ self._backend.setYAxisInverted(isYAxisInverted)
+ self._backend.setYAxisLogarithmic(
+ self.getYAxis().getScale() == items.Axis.LOGARITHMIC)
+
+ # Finally restore aspect ratio and limits
+ self._backend.setKeepDataAspectRatio(isKeepDataAspectRatio)
+ self.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
+
+ # Mark all items for update with new backend
+ for item in self.getItems():
+ item._updated()
+
+ def getBackend(self):
+ """Returns the backend currently used by :class:`PlotWidget`.
+
+ :rtype: ~silx.gui.plot.backend.BackendBase.BackendBase
+ """
+ return self._backend
+
+ def _getDirtyPlot(self):
+ """Return the plot dirty flag.
+
+ If False, the plot has not changed since last replot.
+ If True, the full plot need to be redrawn.
+ If 'overlay', only the overlay has changed since last replot.
+
+ It can be accessed by backend to check the dirty state.
+
+ :return: False, True, 'overlay'
+ """
+ return self._dirty
+
+ # Default Qt context menu
+
+ def contextMenuEvent(self, event):
+ """Override QWidget.contextMenuEvent to implement the context menu"""
+ menu = qt.QMenu(self)
+ from .actions.control import ZoomBackAction # Avoid cyclic import
+ zoomBackAction = ZoomBackAction(plot=self, parent=menu)
+ menu.addAction(zoomBackAction)
+
+ mode = self.getInteractiveMode()
+ if "shape" in mode and mode["shape"] == "polygon":
+ from .actions.control import ClosePolygonInteractionAction # Avoid cyclic import
+ action = ClosePolygonInteractionAction(plot=self, parent=menu)
+ menu.addAction(action)
+
+ self._sigDefaultContextMenu.emit(menu)
+
+ # Make sure the plot is updated, especially when the plot is in
+ # draw interaction mode
+ menu.aboutToHide.connect(self.__simulateMouseMove)
+
+ menu.exec(event.globalPos())
+
+ def _setDirtyPlot(self, overlayOnly=False):
+ """Mark the plot as needing redraw
+
+ :param bool overlayOnly: True to redraw only the overlay,
+ False to redraw everything
+ """
+ wasDirty = self._dirty
+
+ if not self._dirty and overlayOnly:
+ self._dirty = 'overlay'
+ else:
+ self._dirty = True
+
+ if self._autoreplot and not wasDirty and self.isVisible():
+ self._backend.postRedisplay()
+
+ def _foregroundColorsUpdated(self):
+ """Handle change of foreground/grid color"""
+ if self._gridColor is None:
+ gridColor = self._foregroundColor
+ else:
+ gridColor = self._gridColor
+ self._backend.setForegroundColors(
+ self._foregroundColor, gridColor)
+ self._setDirtyPlot()
+
+ def getForegroundColor(self):
+ """Returns the RGBA colors used to display the foreground of this widget
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._foregroundColor)
+
+ def setForegroundColor(self, color):
+ """Set the foreground color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._foregroundColorsUpdated()
+
+ def getGridColor(self):
+ """Returns the RGBA colors used to display the grid lines
+
+ An invalid QColor is returned if there is no grid color,
+ in which case the foreground color is used.
+
+ :rtype: qt.QColor
+ """
+ if self._gridColor is None:
+ return qt.QColor() # An invalid color
+ else:
+ return qt.QColor.fromRgbF(*self._gridColor)
+
+ def setGridColor(self, color):
+ """Set the grid lines color
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._gridColor != color:
+ self._gridColor = color
+ self._foregroundColorsUpdated()
+
+ def _backgroundColorsUpdated(self):
+ """Handle change of background/data background color"""
+ if self._dataBackgroundColor is None:
+ dataBGColor = self._backgroundColor
+ else:
+ dataBGColor = self._dataBackgroundColor
+ self._backend.setBackgroundColors(
+ self._backgroundColor, dataBGColor)
+ self._setDirtyPlot()
+
+ def getBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of this widget.
+
+ :rtype: qt.QColor
+ """
+ return qt.QColor.fromRgbF(*self._backgroundColor)
+
+ def setBackgroundColor(self, color):
+ """Set the background color of this widget.
+
+ :param Union[List[int],List[float],QColor] color:
+ The new RGB(A) color.
+ """
+ color = colors.rgba(color)
+ if self._backgroundColor != color:
+ self._backgroundColor = color
+ self._backgroundColorsUpdated()
+
+ def getDataBackgroundColor(self):
+ """Returns the RGBA colors used to display the background of the plot
+ view displaying the data.
+
+ An invalid QColor is returned if there is no data background color.
+
+ :rtype: qt.QColor
+ """
+ if self._dataBackgroundColor is None:
+ # An invalid color
+ return qt.QColor()
+ else:
+ return qt.QColor.fromRgbF(*self._dataBackgroundColor)
+
+ def setDataBackgroundColor(self, color):
+ """Set the background color of the plot area.
+
+ Set to None or an invalid QColor to use the background color.
+
+ :param Union[List[int],List[float],QColor,None] color:
+ The new RGB(A) color.
+ """
+ if isinstance(color, qt.QColor) and not color.isValid():
+ color = None
+ if color is not None:
+ color = colors.rgba(color)
+ if self._dataBackgroundColor != color:
+ self._dataBackgroundColor = color
+ self._backgroundColorsUpdated()
+
+ dataBackgroundColor = qt.Property(
+ qt.QColor, getDataBackgroundColor, setDataBackgroundColor
+ )
+
+ backgroundColor = qt.Property(qt.QColor, getBackgroundColor, setBackgroundColor)
+
+ foregroundColor = qt.Property(qt.QColor, getForegroundColor, setForegroundColor)
+
+ gridColor = qt.Property(qt.QColor, getGridColor, setGridColor)
+
+ def showEvent(self, event):
+ if self._autoreplot and self._dirty:
+ self._backend.postRedisplay()
+ super(PlotWidget, self).showEvent(event)
+ self.sigVisibilityChanged.emit(True)
+
+ def hideEvent(self, event):
+ super(PlotWidget, self).hideEvent(event)
+ self.sigVisibilityChanged.emit(False)
+
+ def _invalidateDataRange(self):
+ """
+ Notifies this PlotWidget instance that the range has changed
+ and will have to be recomputed.
+ """
+ self._dataRange = None
+
+ def _updateDataRange(self):
+ """
+ Recomputes the range of the data displayed on this PlotWidget.
+ """
+ xMin = yMinLeft = yMinRight = float('nan')
+ xMax = yMaxLeft = yMaxRight = float('nan')
+
+ for item in self.getItems():
+ if item.isVisible():
+ bounds = item.getBounds()
+ if bounds is not None:
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore All-NaN slice encountered
+ xMin = numpy.nanmin([xMin, bounds[0]])
+ xMax = numpy.nanmax([xMax, bounds[1]])
+ # Take care of right axis
+ if (isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right'):
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore All-NaN slice encountered
+ yMinRight = numpy.nanmin([yMinRight, bounds[2]])
+ yMaxRight = numpy.nanmax([yMaxRight, bounds[3]])
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore All-NaN slice encountered
+ yMinLeft = numpy.nanmin([yMinLeft, bounds[2]])
+ yMaxLeft = numpy.nanmax([yMaxLeft, bounds[3]])
+
+ def lGetRange(x, y):
+ return None if numpy.isnan(x) and numpy.isnan(y) else (x, y)
+ xRange = lGetRange(xMin, xMax)
+ yLeftRange = lGetRange(yMinLeft, yMaxLeft)
+ yRightRange = lGetRange(yMinRight, yMaxRight)
+
+ self._dataRange = _PlotDataRange(x=xRange,
+ y=yLeftRange,
+ yright=yRightRange)
+
+ def getDataRange(self):
+ """
+ Returns this PlotWidget's data range.
+
+ :return: a namedtuple with the following members :
+ x, y (left y axis), yright. Each member is a tuple (min, max)
+ or None if no data is associated with the axis.
+ :rtype: namedtuple
+ """
+ if self._dataRange is None:
+ self._updateDataRange()
+ return self._dataRange
+
+ # Content management
+
+ _KIND_TO_CLASSES = {
+ 'curve': (items.Curve,),
+ 'image': (items.ImageBase,),
+ 'scatter': (items.Scatter,),
+ 'marker': (items.MarkerBase,),
+ 'item': (items.Shape,
+ items.BoundingRect,
+ items.XAxisExtent,
+ items.YAxisExtent),
+ 'histogram': (items.Histogram,),
+ }
+ """Mapping kind to item classes of this kind"""
+
+ @classmethod
+ def _itemKind(cls, item):
+ """Returns the "kind" of a given item
+
+ :param Item item: The item get the kind
+ :rtype: str
+ """
+ for kind, itemClasses in cls._KIND_TO_CLASSES.items():
+ if isinstance(item, itemClasses):
+ return kind
+ raise ValueError('Unsupported item type %s' % type(item))
+
+ def _notifyContentChanged(self, item):
+ self.notify('contentChanged', action='add',
+ kind=self._itemKind(item), legend=item.getName())
+
+ def _itemRequiresUpdate(self, item):
+ """Called by items in the plot for asynchronous update
+
+ :param Item item: The item that required update
+ """
+ assert item.getPlot() == self
+ # Put item at the end of the list
+ if item in self._contentToUpdate:
+ self._contentToUpdate.remove(item)
+ self._contentToUpdate.append(item)
+ self._setDirtyPlot(overlayOnly=item.isOverlay())
+
+ def addItem(self, item=None, *args, **kwargs):
+ """Add an item to the plot content.
+
+ :param ~silx.gui.plot.items.Item item: The item to add.
+ :raises ValueError: If item is already in the plot.
+ """
+ if not isinstance(item, items.Item):
+ deprecated_warning(
+ 'Function',
+ 'addItem',
+ replacement='addShape',
+ since_version='0.13')
+ if item is None and not args: # Only kwargs
+ return self.addShape(**kwargs)
+ else:
+ return self.addShape(item, *args, **kwargs)
+
+ assert not args and not kwargs
+ if item in self.getItems():
+ raise ValueError('Item already in the plot')
+
+ # Add item to plot
+ self._content[(item.getName(), self._itemKind(item))] = item
+ item._setPlot(self)
+ self._itemRequiresUpdate(item)
+ if isinstance(item, items.DATA_ITEMS):
+ self._invalidateDataRange() # TODO handle this automatically
+
+ self._notifyContentChanged(item)
+ self.sigItemAdded.emit(item)
+
+ def removeItem(self, item):
+ """Remove the item from the plot.
+
+ :param ~silx.gui.plot.items.Item item: Item to remove from the plot.
+ :raises ValueError: If item is not in the plot.
+ """
+ if not isinstance(item, items.Item): # Previous method usage
+ deprecated_warning(
+ 'Function',
+ 'removeItem',
+ replacement='remove(legend, kind="item")',
+ since_version='0.13')
+ if item is None:
+ return
+ self.remove(item, kind='item')
+ return
+
+ if item not in self.getItems():
+ raise ValueError('Item not in the plot')
+
+ self.sigItemAboutToBeRemoved.emit(item)
+
+ kind = self._itemKind(item)
+
+ if kind in self._ACTIVE_ITEM_KINDS:
+ if self._getActiveItem(kind) == item:
+ # Reset active item
+ self._setActiveItem(kind, None)
+
+ # Remove item from plot
+ self._content.pop((item.getName(), kind))
+ if item in self._contentToUpdate:
+ self._contentToUpdate.remove(item)
+ if item.isVisible():
+ self._setDirtyPlot(overlayOnly=item.isOverlay())
+ if item.getBounds() is not None:
+ self._invalidateDataRange()
+ item._removeBackendRenderer(self._backend)
+ item._setPlot(None)
+
+ if (kind == 'curve' and not self.getAllCurves(just_legend=True,
+ withhidden=True)):
+ self._resetColorAndStyle()
+
+ self.sigItemRemoved.emit(item)
+
+ self.notify('contentChanged', action='remove',
+ kind=kind, legend=item.getName())
+
+ def discardItem(self, item) -> bool:
+ """Remove the item from the plot.
+
+ Same as :meth:`removeItem` but do not raise an exception.
+
+ :param ~silx.gui.plot.items.Item item: Item to remove from the plot.
+ :returns: True if the item was present, False otherwise.
+ """
+ try:
+ self.removeItem(item)
+ except ValueError:
+ return False
+ else:
+ return True
+
+ @deprecated(replacement='addItem', since_version='0.13')
+ def _add(self, item):
+ return self.addItem(item)
+
+ @deprecated(replacement='removeItem', since_version='0.13')
+ def _remove(self, item):
+ return self.removeItem(item)
+
+ def getItems(self):
+ """Returns the list of items in the plot
+
+ :rtype: List[silx.gui.plot.items.Item]
+ """
+ return tuple(self._content.values())
+
+ @contextmanager
+ def _muteActiveItemChangedSignal(self):
+ self.__muteActiveItemChanged = True
+ yield
+ self.__muteActiveItemChanged = False
+
+ # Add
+
+ # add * input arguments management:
+ # If an arg is set, then use it.
+ # Else:
+ # If a curve with the same legend exists, then use its arg value
+ # Else, use a default value.
+ # Store used value.
+ # This value is used when curve is updated either internally or by user.
+
+ def addCurve(self, x, y, legend=None, info=None,
+ replace=False,
+ color=None, symbol=None,
+ linewidth=None, linestyle=None,
+ xlabel=None, ylabel=None, yaxis=None,
+ xerror=None, yerror=None, z=None, selectable=None,
+ fill=None, resetzoom=True,
+ histogram=None, copy=True,
+ baseline=None):
+ """Add a 1D curve given by x an y to the graph.
+
+ Curves are uniquely identified by their legend.
+ To add multiple curves, call :meth:`addCurve` multiple times with
+ different legend argument.
+ To replace an existing curve, call :meth:`addCurve` with the
+ existing curve legend.
+ If you want to display the curve values as an histogram see the
+ histogram parameter or :meth:`addHistogram`.
+
+ When curve parameters are not provided, if a curve with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ If you attempt to plot an histogram you can set edges values in x.
+ In this case len(x) = len(y) + 1.
+ If x contains datetime objects the XAxis tickMode is set to
+ TickMode.TIME_SERIES.
+ :param numpy.ndarray y: The data corresponding to the y coordinates
+ :param str legend: The legend to be associated to the curve (or None)
+ :param info: User-defined information associated to the curve
+ :param bool replace: True to delete already existing curves
+ (the default is False)
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ - None (the default) to use default symbol
+
+ :param float linewidth: The width of the curve in pixels (Default: 1).
+ :param str linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ - None (the default) to use default line style
+
+ :param str xlabel: Label to show on the X axis when the curve is active
+ or None to keep default axis label.
+ :param str ylabel: Label to show on the Y axis when the curve is active
+ or None to keep default axis label.
+ :param str yaxis: The Y axis this curve is attached to.
+ Either 'left' (the default) or 'right'
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param int z: Layer on which to draw the curve (default: 1)
+ This allows to control the overlay.
+ :param bool selectable: Indicate if the curve can be selected.
+ (Default: True)
+ :param bool fill: True to fill the curve, False otherwise (default).
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param str histogram: if not None then the curve will be draw as an
+ histogram. The step for each values of the curve can be set to the
+ left, center or right of the original x curve values.
+ If histogram is not None and len(x) == len(y)+1 then x is directly
+ take as edges of the histogram.
+ Type of histogram::
+
+ - None (default)
+ - 'left'
+ - 'right'
+ - 'center'
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :param baseline: curve baseline
+ :type: Union[None,float,numpy.ndarray]
+ :returns: The key string identify this curve
+ """
+ # This is an histogram, use addHistogram
+ if histogram is not None:
+ histoLegend = self.addHistogram(histogram=y,
+ edges=x,
+ legend=legend,
+ color=color,
+ fill=fill,
+ align=histogram,
+ copy=copy)
+ histo = self.getHistogram(histoLegend)
+
+ histo.setInfo(info)
+ if linewidth is not None:
+ histo.setLineWidth(linewidth)
+ if linestyle is not None:
+ histo.setLineStyle(linestyle)
+ if xlabel is not None:
+ _logger.warning(
+ 'addCurve: Histogram does not support xlabel argument')
+ if ylabel is not None:
+ _logger.warning(
+ 'addCurve: Histogram does not support ylabel argument')
+ if yaxis is not None:
+ histo.setYAxis(yaxis)
+ if z is not None:
+ histo.setZValue(z)
+ if selectable is not None:
+ _logger.warning(
+ 'addCurve: Histogram does not support selectable argument')
+
+ return
+
+ legend = 'Unnamed curve 1.1' if legend is None else str(legend)
+
+ # Check if curve was previously active
+ wasActive = self.getActiveCurve(just_legend=True) == legend
+
+ if replace:
+ self._resetColorAndStyle()
+
+ # Create/Update curve object
+ curve = self.getCurve(legend)
+ mustBeAdded = curve is None
+ if curve is None:
+ # No previous curve, create a default one and add it to the plot
+ curve = items.Curve() if histogram is None else items.Histogram()
+ curve.setName(legend)
+ # Set default color, linestyle and symbol
+ default_color, default_linestyle = self._getColorAndStyle()
+ curve.setColor(default_color)
+ curve.setLineStyle(default_linestyle)
+ curve.setSymbol(self._defaultPlotPoints)
+ curve._setBaseline(baseline=baseline)
+
+ # Do not emit sigActiveCurveChanged,
+ # it will be sent once with _setActiveItem
+ with self._muteActiveItemChangedSignal():
+ # Override previous/default values with provided ones
+ curve.setInfo(info)
+ if color is not None:
+ curve.setColor(color)
+ if symbol is not None:
+ curve.setSymbol(symbol)
+ if linewidth is not None:
+ curve.setLineWidth(linewidth)
+ if linestyle is not None:
+ curve.setLineStyle(linestyle)
+ if xlabel is not None:
+ curve._setXLabel(xlabel)
+ if ylabel is not None:
+ curve._setYLabel(ylabel)
+ if yaxis is not None:
+ curve.setYAxis(yaxis)
+ if z is not None:
+ curve.setZValue(z)
+ if selectable is not None:
+ curve._setSelectable(selectable)
+ if fill is not None:
+ curve.setFill(fill)
+
+ # Set curve data
+ # If errors not provided, reuse previous ones
+ # TODO: Issue if size of data change but not that of errors
+ if xerror is None:
+ xerror = curve.getXErrorData(copy=False)
+ if yerror is None:
+ yerror = curve.getYErrorData(copy=False)
+
+ # Convert x to timestamps so that the internal representation
+ # remains floating points. The user is expected to set the axis'
+ # tickMode to TickMode.TIME_SERIES and, if necessary, set the axis
+ # to the correct time zone.
+ if len(x) > 0 and isinstance(x[0], dt.datetime):
+ x = [timestamp(d) for d in x]
+
+ curve.setData(x, y, xerror, yerror, baseline=baseline, copy=copy)
+
+ if replace: # Then remove all other curves
+ for c in self.getAllCurves(withhidden=True):
+ if c is not curve:
+ self.removeItem(c)
+
+ if mustBeAdded:
+ self.addItem(curve)
+ else:
+ self._notifyContentChanged(curve)
+
+ if wasActive:
+ self.setActiveCurve(curve.getName())
+ elif self.getActiveCurveSelectionMode() == "legacy":
+ if self.getActiveCurve(just_legend=True) is None:
+ if len(self.getAllCurves(just_legend=True,
+ withhidden=False)) == 1:
+ if curve.isVisible():
+ self.setActiveCurve(curve.getName())
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return legend
+
+ def addHistogram(self,
+ histogram,
+ edges,
+ legend=None,
+ color=None,
+ fill=None,
+ align='center',
+ resetzoom=True,
+ copy=True,
+ z=None,
+ baseline=None):
+ """Add an histogram to the graph.
+
+ This is NOT computing the histogram, this method takes as parameter
+ already computed histogram values.
+
+ Histogram are uniquely identified by their legend.
+ To add multiple histograms, call :meth:`addHistogram` multiple times
+ with different legend argument.
+
+ When histogram parameters are not provided, if an histogram with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray histogram: The values of the histogram.
+ :param numpy.ndarray edges:
+ The bin edges of the histogram.
+ If histogram and edges have the same length, the bin edges
+ are computed according to the align parameter.
+ :param str legend:
+ The legend to be associated to the histogram (or None)
+ :param color: color to be used
+ :type color: str ("#RRGGBB") or RGB unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool fill: True to fill the curve, False otherwise (default).
+ :param str align:
+ In case histogram values and edges have the same length N,
+ the N+1 bin edges are computed according to the alignment in:
+ 'center' (default), 'left', 'right'.
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :param int z: Layer on which to draw the histogram
+ :param baseline: histogram baseline
+ :type: Union[None,float,numpy.ndarray]
+ :returns: The key string identify this histogram
+ """
+ legend = 'Unnamed histogram' if legend is None else str(legend)
+
+ # Create/Update histogram object
+ histo = self.getHistogram(legend)
+ mustBeAdded = histo is None
+ if histo is None:
+ # No previous histogram, create a default one and
+ # add it to the plot
+ histo = items.Histogram()
+ histo.setName(legend)
+ histo.setColor(self._getColorAndStyle()[0])
+
+ # Override previous/default values with provided ones
+ if color is not None:
+ histo.setColor(color)
+ if fill is not None:
+ histo.setFill(fill)
+ if z is not None:
+ histo.setZValue(z=z)
+
+ # Set histogram data
+ histo.setData(histogram=histogram, edges=edges, baseline=baseline,
+ align=align, copy=copy)
+
+ if mustBeAdded:
+ self.addItem(histo)
+ else:
+ self._notifyContentChanged(histo)
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return legend
+
+ def addImage(self, data, legend=None, info=None,
+ replace=False,
+ z=None,
+ selectable=None, draggable=None,
+ colormap=None, pixmap=None,
+ xlabel=None, ylabel=None,
+ origin=None, scale=None,
+ resetzoom=True, copy=True):
+ """Add a 2D dataset or an image to the plot.
+
+ It displays either an array of data using a colormap or a RGB(A) image.
+
+ Images are uniquely identified by their legend.
+ To add multiple images, call :meth:`addImage` multiple times with
+ different legend argument.
+ To replace/update an existing image, call :meth:`addImage` with the
+ existing image legend.
+
+ When image parameters are not provided, if an image with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray data:
+ (nrows, ncolumns) data or
+ (nrows, ncolumns, RGBA) ubyte array
+ Note: boolean values are converted to int8.
+ :param str legend: The legend to be associated to the image (or None)
+ :param info: User-defined information associated to the image
+ :param bool replace:
+ True to delete already existing images (Default: False).
+ :param int z: Layer on which to draw the image (default: 0)
+ This allows to control the overlay.
+ :param bool selectable: Indicate if the image can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the image can be moved.
+ (default: False)
+ :param colormap: Colormap object to use (or None).
+ This is ignored if data is a RGB(A) image.
+ :type colormap: Union[~silx.gui.colors.Colormap, dict]
+ :param pixmap: Pixmap representation of the data (if any)
+ :type pixmap: (nrows, ncolumns, RGBA) ubyte array or None (default)
+ :param str xlabel: X axis label to show when this curve is active,
+ or None to keep default axis label.
+ :param str ylabel: Y axis label to show when this curve is active,
+ or None to keep default axis label.
+ :param origin: (origin X, origin Y) of the data.
+ It is possible to pass a single float if both
+ coordinates are equal.
+ Default: (0., 0.)
+ :type origin: float or 2-tuple of float
+ :param scale: (scale X, scale Y) of the data.
+ It is possible to pass a single float if both
+ coordinates are equal.
+ Default: (1., 1.)
+ :type scale: float or 2-tuple of float
+ :param bool resetzoom: True (the default) to reset the zoom.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The key string identify this image
+ """
+ legend = "Unnamed Image 1.1" if legend is None else str(legend)
+
+ # Check if image was previously active
+ wasActive = self.getActiveImage(just_legend=True) == legend
+
+ data = numpy.array(data, copy=False)
+ assert data.ndim in (2, 3)
+
+ image = self.getImage(legend)
+ if image is not None and image.getData(copy=False).ndim != data.ndim:
+ # Update a data image with RGBA image or the other way around:
+ # Remove previous image
+ # In this case, we don't retrieve defaults from the previous image
+ self.removeItem(image)
+ image = None
+
+ mustBeAdded = image is None
+ if image is None:
+ # No previous image, create a default one and add it to the plot
+ if data.ndim == 2:
+ image = items.ImageData()
+ image.setColormap(self.getDefaultColormap())
+ else:
+ image = items.ImageRgba()
+ image.setName(legend)
+
+ # Do not emit sigActiveImageChanged,
+ # it will be sent once with _setActiveItem
+ with self._muteActiveItemChangedSignal():
+ # Override previous/default values with provided ones
+ image.setInfo(info)
+ if origin is not None:
+ image.setOrigin(origin)
+ if scale is not None:
+ image.setScale(scale)
+ if z is not None:
+ image.setZValue(z)
+ if selectable is not None:
+ image._setSelectable(selectable)
+ if draggable is not None:
+ image._setDraggable(draggable)
+ if colormap is not None and isinstance(image, items.ColormapMixIn):
+ if isinstance(colormap, dict):
+ image.setColormap(Colormap._fromDict(colormap))
+ else:
+ assert isinstance(colormap, Colormap)
+ image.setColormap(colormap)
+ if xlabel is not None:
+ image._setXLabel(xlabel)
+ if ylabel is not None:
+ image._setYLabel(ylabel)
+
+ if data.ndim == 2:
+ image.setData(data, alternative=pixmap, copy=copy)
+ else: # RGB(A) image
+ if pixmap is not None:
+ _logger.warning(
+ 'addImage: pixmap argument ignored when data is RGB(A)')
+ image.setData(data, copy=copy)
+
+ if replace:
+ for img in self.getAllImages():
+ if img is not image:
+ self.removeItem(img)
+
+ if mustBeAdded:
+ self.addItem(image)
+ else:
+ self._notifyContentChanged(image)
+
+ if len(self.getAllImages()) == 1 or wasActive:
+ self.setActiveImage(legend)
+
+ if resetzoom:
+ # We ask for a zoom reset in order to handle the plot scaling
+ # if the user does not want that, autoscale of the different
+ # axes has to be set to off.
+ self.resetZoom()
+
+ return legend
+
+ def addScatter(self, x, y, value, legend=None, colormap=None,
+ info=None, symbol=None, xerror=None, yerror=None,
+ z=None, copy=True):
+ """Add a (x, y, value) scatter to the graph.
+
+ Scatters are uniquely identified by their legend.
+ To add multiple scatters, call :meth:`addScatter` multiple times with
+ different legend argument.
+ To replace/update an existing scatter, call :meth:`addScatter` with the
+ existing scatter legend.
+
+ When scatter parameters are not provided, if a scatter with the
+ same legend is displayed in the plot, its parameters are used.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates
+ :param numpy.ndarray value: The data value associated with each point
+ :param str legend: The legend to be associated to the scatter (or None)
+ :param ~silx.gui.colors.Colormap colormap:
+ Colormap object to be used for the scatter (or None)
+ :param info: User-defined information associated to the curve
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ - None (the default) to use default symbol
+
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param int z: Layer on which to draw the scatter (default: 1)
+ This allows to control the overlay.
+
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ :returns: The key string identify this scatter
+ """
+ legend = 'Unnamed scatter 1.1' if legend is None else str(legend)
+
+ # Check if scatter was previously active
+ wasActive = self._getActiveItem(kind='scatter',
+ just_legend=True) == legend
+
+ # Create/Update curve object
+ scatter = self._getItem(kind='scatter', legend=legend)
+ mustBeAdded = scatter is None
+ if scatter is None:
+ # No previous scatter, create a default one and add it to the plot
+ scatter = items.Scatter()
+ scatter.setName(legend)
+ scatter.setColormap(self.getDefaultColormap())
+
+ # Do not emit sigActiveScatterChanged,
+ # it will be sent once with _setActiveItem
+ with self._muteActiveItemChangedSignal():
+ # Override previous/default values with provided ones
+ scatter.setInfo(info)
+ if symbol is not None:
+ scatter.setSymbol(symbol)
+ if z is not None:
+ scatter.setZValue(z)
+ if colormap is not None:
+ if isinstance(colormap, dict):
+ scatter.setColormap(Colormap._fromDict(colormap))
+ else:
+ assert isinstance(colormap, Colormap)
+ scatter.setColormap(colormap)
+
+ # Set scatter data
+ # If errors not provided, reuse previous ones
+ if xerror is None:
+ xerror = scatter.getXErrorData(copy=False)
+ if xerror is not None and len(xerror) != len(x):
+ xerror = None
+ if yerror is None:
+ yerror = scatter.getYErrorData(copy=False)
+ if yerror is not None and len(yerror) != len(y):
+ yerror = None
+
+ scatter.setData(x, y, value, xerror, yerror, copy=copy)
+
+ if mustBeAdded:
+ self.addItem(scatter)
+ else:
+ self._notifyContentChanged(scatter)
+
+ scatters = [item for item in self.getItems()
+ if isinstance(item, items.Scatter) and item.isVisible()]
+ if len(scatters) == 1 or wasActive:
+ self._setActiveItem('scatter', scatter.getName())
+
+ return legend
+
+ def addShape(self, xdata, ydata, legend=None, info=None,
+ replace=False,
+ shape="polygon", color='black', fill=True,
+ overlay=False, z=None, linestyle="-", linewidth=1.0,
+ linebgcolor=None):
+ """Add an item (i.e. a shape) to the plot.
+
+ Items are uniquely identified by their legend.
+ To add multiple items, call :meth:`addItem` multiple times with
+ different legend argument.
+ To replace/update an existing item, call :meth:`addItem` with the
+ existing item legend.
+
+ :param numpy.ndarray xdata: The X coords of the points of the shape
+ :param numpy.ndarray ydata: The Y coords of the points of the shape
+ :param str legend: The legend to be associated to the item
+ :param info: User-defined information associated to the item
+ :param bool replace: True (default) to delete already existing images
+ :param str shape: Type of item to be drawn in
+ hline, polygon (the default), rectangle, vline,
+ polylines
+ :param str color: Color of the item, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool fill: True (the default) to fill the shape
+ :param bool overlay: True if item is an overlay (Default: False).
+ This allows for rendering optimization if this
+ item is changed often.
+ :param int z: Layer on which to draw the item (default: 2)
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
+ :returns: The key string identify this item
+ """
+ # expected to receive the same parameters as the signal
+
+ legend = "Unnamed Item 1.1" if legend is None else str(legend)
+
+ z = int(z) if z is not None else 2
+
+ if replace:
+ self.remove(kind='item')
+ else:
+ self.remove(legend, kind='item')
+
+ item = items.Shape(shape)
+ item.setName(legend)
+ item.setInfo(info)
+ item.setColor(color)
+ item.setFill(fill)
+ item.setOverlay(overlay)
+ item.setZValue(z)
+ item.setPoints(numpy.array((xdata, ydata)).T)
+ item.setLineStyle(linestyle)
+ item.setLineWidth(linewidth)
+ item.setLineBgColor(linebgcolor)
+
+ self.addItem(item)
+
+ return legend
+
+ def addXMarker(self, x, legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ yaxis='left'):
+ """Add a vertical line marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addXMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param x: Position of the marker on the X axis in data coordinates
+ :type x: Union[None, float]
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display on the marker.
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: The key string identify this marker
+ """
+ return self._addMarker(x=x, y=None, legend=legend,
+ text=text, color=color,
+ selectable=selectable, draggable=draggable,
+ symbol=None, constraint=constraint,
+ yaxis=yaxis)
+
+ def addYMarker(self, y,
+ legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ constraint=None,
+ yaxis='left'):
+ """Add a horizontal line marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addYMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param float y: Position of the marker on the Y axis in data
+ coordinates
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display next to the marker.
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: The key string identify this marker
+ """
+ return self._addMarker(x=None, y=y, legend=legend,
+ text=text, color=color,
+ selectable=selectable, draggable=draggable,
+ symbol=None, constraint=constraint,
+ yaxis=yaxis)
+
+ def addMarker(self, x, y, legend=None,
+ text=None,
+ color=None,
+ selectable=False,
+ draggable=False,
+ symbol='+',
+ constraint=None,
+ yaxis='left'):
+ """Add a point marker to the plot.
+
+ Markers are uniquely identified by their legend.
+ As opposed to curves, images and items, two calls to
+ :meth:`addMarker` without legend argument adds two markers with
+ different identifying legends.
+
+ :param float x: Position of the marker on the X axis in data
+ coordinates
+ :param float y: Position of the marker on the Y axis in data
+ coordinates
+ :param str legend: Legend associated to the marker to identify it
+ :param str text: Text to display next to the marker
+ :param str color: Color of the marker, e.g., 'blue', 'b', '#FF0000'
+ (Default: 'black')
+ :param bool selectable: Indicate if the marker can be selected.
+ (default: False)
+ :param bool draggable: Indicate if the marker can be moved.
+ (default: False)
+ :param str symbol: Symbol representing the marker in::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross (the default)
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ This parameter is only used if draggable is True.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: The key string identify this marker
+ """
+ if x is None:
+ xmin, xmax = self._xAxis.getLimits()
+ x = 0.5 * (xmax + xmin)
+
+ if y is None:
+ ymin, ymax = self._yAxis.getLimits()
+ y = 0.5 * (ymax + ymin)
+
+ return self._addMarker(x=x, y=y, legend=legend,
+ text=text, color=color,
+ selectable=selectable, draggable=draggable,
+ symbol=symbol, constraint=constraint,
+ yaxis=yaxis)
+
+ def _addMarker(self, x, y, legend,
+ text, color,
+ selectable, draggable,
+ symbol, constraint,
+ yaxis=None):
+ """Common method for adding point, vline and hline marker.
+
+ See :meth:`addMarker` for argument documentation.
+ """
+ assert (x, y) != (None, None)
+
+ if legend is None: # Find an unused legend
+ markerLegends = [item.getName() for item in self.getItems()
+ if isinstance(item, items.MarkerBase)]
+ for index in itertools.count():
+ legend = "Unnamed Marker %d" % index
+ if legend not in markerLegends:
+ break # Keep this legend
+ legend = str(legend)
+
+ if x is None:
+ markerClass = items.YMarker
+ elif y is None:
+ markerClass = items.XMarker
+ else:
+ markerClass = items.Marker
+
+ # Create/Update marker object
+ marker = self._getMarker(legend)
+ if marker is not None and not isinstance(marker, markerClass):
+ _logger.warning('Adding marker with same legend'
+ ' but different type replaces it')
+ self.removeItem(marker)
+ marker = None
+
+ mustBeAdded = marker is None
+ if marker is None:
+ # No previous marker, create one
+ marker = markerClass()
+ marker.setName(legend)
+
+ if text is not None:
+ marker.setText(text)
+ if color is not None:
+ marker.setColor(color)
+ if selectable is not None:
+ marker._setSelectable(selectable)
+ if draggable is not None:
+ marker._setDraggable(draggable)
+ if symbol is not None:
+ marker.setSymbol(symbol)
+ marker.setYAxis(yaxis)
+
+ # TODO to improve, but this ensure constraint is applied
+ marker.setPosition(x, y)
+ if constraint is not None:
+ marker._setConstraint(constraint)
+ marker.setPosition(x, y)
+
+ if mustBeAdded:
+ self.addItem(marker)
+ else:
+ self._notifyContentChanged(marker)
+
+ return legend
+
+ # Hide
+
+ def isCurveHidden(self, legend):
+ """Returns True if the curve associated to legend is hidden, else False
+
+ :param str legend: The legend key identifying the curve
+ :return: True if the associated curve is hidden, False otherwise
+ """
+ curve = self._getItem('curve', legend)
+ return curve is not None and not curve.isVisible()
+
+ def hideCurve(self, legend, flag=True):
+ """Show/Hide the curve associated to legend.
+
+ Even when hidden, the curve is kept in the list of curves.
+
+ :param str legend: The legend associated to the curve to be hidden
+ :param bool flag: True (default) to hide the curve, False to show it
+ """
+ curve = self._getItem('curve', legend)
+ if curve is None:
+ _logger.warning('Curve not in plot: %s', legend)
+ return
+
+ isVisible = not flag
+ if isVisible != curve.isVisible():
+ curve.setVisible(isVisible)
+
+ # Remove
+
+ ITEM_KINDS = 'curve', 'image', 'scatter', 'item', 'marker', 'histogram'
+ """List of supported kind of items in the plot."""
+
+ _ACTIVE_ITEM_KINDS = 'curve', 'scatter', 'image'
+ """List of item's kind which have a active item."""
+
+ def remove(self, legend=None, kind=ITEM_KINDS):
+ """Remove one or all element(s) of the given legend and kind.
+
+ Examples:
+
+ - ``remove()`` clears the plot
+ - ``remove(kind='curve')`` removes all curves from the plot
+ - ``remove('myCurve', kind='curve')`` removes the curve with
+ legend 'myCurve' from the plot.
+ - ``remove('myImage, kind='image')`` removes the image with
+ legend 'myImage' from the plot.
+ - ``remove('myImage')`` removes elements (for instance curve, image,
+ item and marker) with legend 'myImage'.
+
+ :param str legend: The legend associated to the element to remove,
+ or None to remove
+ :param kind: The kind of elements to remove from the plot.
+ See :attr:`ITEM_KINDS`.
+ By default, it removes all kind of elements.
+ :type kind: str or tuple of str to specify multiple kinds.
+ """
+ if kind == 'all': # Replace all by tuple of all kinds
+ kind = self.ITEM_KINDS
+
+ if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
+ kind = (kind,)
+
+ for aKind in kind:
+ assert aKind in self.ITEM_KINDS
+
+ if legend is None: # This is a clear
+ # Clear each given kind
+ for aKind in kind:
+ for item in self.getItems():
+ if (isinstance(item, self._KIND_TO_CLASSES[aKind]) and
+ item.getPlot() is self): # Make sure item is still in the plot
+ self.removeItem(item)
+
+ else: # This is removing a single element
+ # Remove each given kind
+ for aKind in kind:
+ item = self._getItem(aKind, legend)
+ if item is not None:
+ self.removeItem(item)
+
+ def removeCurve(self, legend):
+ """Remove the curve associated to legend from the graph.
+
+ :param str legend: The legend associated to the curve to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='curve')
+
+ def removeImage(self, legend):
+ """Remove the image associated to legend from the graph.
+
+ :param str legend: The legend associated to the image to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='image')
+
+ def removeMarker(self, legend):
+ """Remove the marker associated to legend from the graph.
+
+ :param str legend: The legend associated to the marker to be deleted
+ """
+ if legend is None:
+ return
+ self.remove(legend, kind='marker')
+
+ # Clear
+
+ def clear(self):
+ """Remove everything from the plot."""
+ for item in self.getItems():
+ if item.getPlot() is self: # Make sure item is still in the plot
+ self.removeItem(item)
+
+ def clearCurves(self):
+ """Remove all the curves from the plot."""
+ self.remove(kind='curve')
+
+ def clearImages(self):
+ """Remove all the images from the plot."""
+ self.remove(kind='image')
+
+ def clearItems(self):
+ """Remove all the items from the plot. """
+ self.remove(kind='item')
+
+ def clearMarkers(self):
+ """Remove all the markers from the plot."""
+ self.remove(kind='marker')
+
+ # Interaction
+
+ def getGraphCursor(self):
+ """Returns the state of the crosshair cursor.
+
+ See :meth:`setGraphCursor`.
+
+ :return: None if the crosshair cursor is not active,
+ else a tuple (color, linewidth, linestyle).
+ """
+ return self._cursorConfiguration
+
+ def setGraphCursor(self, flag=False, color='black',
+ linewidth=1, linestyle='-'):
+ """Toggle the display of a crosshair cursor and set its attributes.
+
+ :param bool flag: Toggle the display of a crosshair cursor.
+ The crosshair cursor is hidden by default.
+ :param color: The color to use for the crosshair.
+ :type color: A string (either a predefined color name in colors.py
+ or "#RRGGBB")) or a 4 columns unsigned byte array
+ (Default: black).
+ :param int linewidth: The width of the lines of the crosshair
+ (Default: 1).
+ :param str linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line (the default)
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ """
+ if flag:
+ self._cursorConfiguration = color, linewidth, linestyle
+ else:
+ self._cursorConfiguration = None
+
+ self._backend.setGraphCursor(flag=flag, color=color,
+ linewidth=linewidth, linestyle=linestyle)
+ self._setDirtyPlot()
+ self.notify('setGraphCursor',
+ state=self._cursorConfiguration is not None)
+
+ def pan(self, direction, factor=0.1):
+ """Pan the graph in the given direction by the given factor.
+
+ Warning: Pan of right Y axis not implemented!
+
+ :param str direction: One of 'up', 'down', 'left', 'right'.
+ :param float factor: Proportion of the range used to pan the graph.
+ Must be strictly positive.
+ """
+ assert direction in ('up', 'down', 'left', 'right')
+ assert factor > 0.
+
+ if direction in ('left', 'right'):
+ xFactor = factor if direction == 'right' else - factor
+ xMin, xMax = self._xAxis.getLimits()
+
+ xMin, xMax = _utils.applyPan(xMin, xMax, xFactor,
+ self._xAxis.getScale() == self._xAxis.LOGARITHMIC)
+ self._xAxis.setLimits(xMin, xMax)
+
+ else: # direction in ('up', 'down')
+ sign = -1. if self._yAxis.isInverted() else 1.
+ yFactor = sign * (factor if direction == 'up' else -factor)
+ yMin, yMax = self._yAxis.getLimits()
+ yIsLog = self._yAxis.getScale() == self._yAxis.LOGARITHMIC
+
+ yMin, yMax = _utils.applyPan(yMin, yMax, yFactor, yIsLog)
+ self._yAxis.setLimits(yMin, yMax)
+
+ y2Min, y2Max = self._yRightAxis.getLimits()
+
+ y2Min, y2Max = _utils.applyPan(y2Min, y2Max, yFactor, yIsLog)
+ self._yRightAxis.setLimits(y2Min, y2Max)
+
+ # Active Curve/Image
+
+ def isActiveCurveHandling(self):
+ """Returns True if active curve selection is enabled.
+
+ :rtype: bool
+ """
+ return self.getActiveCurveSelectionMode() != 'none'
+
+ def setActiveCurveHandling(self, flag=True):
+ """Enable/Disable active curve selection.
+
+ :param bool flag: True to enable 'atmostone' active curve selection,
+ False to disable active curve selection.
+ """
+ self.setActiveCurveSelectionMode('atmostone' if flag else 'none')
+
+ def getActiveCurveStyle(self):
+ """Returns the current style applied to active curve
+
+ :rtype: CurveStyle
+ """
+ return self._activeCurveStyle
+
+ def setActiveCurveStyle(self,
+ color=None,
+ linewidth=None,
+ linestyle=None,
+ symbol=None,
+ symbolsize=None):
+ """Set the style of active curve
+
+ :param color: Color
+ :param Union[str,None] linestyle: Style of the line
+ :param Union[float,None] linewidth: Width of the line
+ :param Union[str,None] symbol: Symbol of the markers
+ :param Union[float,None] symbolsize: Size of the symbols
+ """
+ self._activeCurveStyle = CurveStyle(color=color,
+ linewidth=linewidth,
+ linestyle=linestyle,
+ symbol=symbol,
+ symbolsize=symbolsize)
+ curve = self.getActiveCurve()
+ if curve is not None:
+ curve.setHighlightedStyle(self.getActiveCurveStyle())
+
+ @deprecated(replacement="getActiveCurveStyle", since_version="0.9")
+ def getActiveCurveColor(self):
+ """Get the color used to display the currently active curve.
+
+ See :meth:`setActiveCurveColor`.
+ """
+ return self._activeCurveStyle.getColor()
+
+ @deprecated(replacement="setActiveCurveStyle", since_version="0.9")
+ def setActiveCurveColor(self, color="#000000"):
+ """Set the color to use to display the currently active curve.
+
+ :param str color: Color of the active curve,
+ e.g., 'blue', 'b', '#FF0000' (Default: 'black')
+ """
+ if color is None:
+ color = "black"
+ if color in self.colorDict:
+ color = self.colorDict[color]
+ self.setActiveCurveStyle(color=color)
+
+ def getActiveCurve(self, just_legend=False):
+ """Return the currently active curve.
+
+ It returns None in case of not having an active curve.
+
+ :param bool just_legend: True to get the legend of the curve,
+ False (the default) to get the curve data
+ and info.
+ :return: Active curve's legend or corresponding
+ :class:`.items.Curve`
+ :rtype: str or :class:`.items.Curve` or None
+ """
+ if not self.isActiveCurveHandling():
+ return None
+
+ return self._getActiveItem(kind='curve', just_legend=just_legend)
+
+ def setActiveCurve(self, legend):
+ """Make the curve associated to legend the active curve.
+
+ :param legend: The legend associated to the curve
+ or None to have no active curve.
+ :type legend: str or None
+ """
+ if not self.isActiveCurveHandling():
+ return
+ if legend is None and self.getActiveCurveSelectionMode() == "legacy":
+ _logger.info(
+ 'setActiveCurve(None) ignored due to active curve selection mode')
+ return
+
+ return self._setActiveItem(kind='curve', legend=legend)
+
+ def setActiveCurveSelectionMode(self, mode):
+ """Sets the current selection mode.
+
+ :param str mode: The active curve selection mode to use.
+ It can be: 'legacy', 'atmostone' or 'none'.
+ """
+ assert mode in ('legacy', 'atmostone', 'none')
+
+ if mode != self._activeCurveSelectionMode:
+ self._activeCurveSelectionMode = mode
+ if mode == 'none': # reset active curve
+ self._setActiveItem(kind='curve', legend=None)
+
+ elif mode == 'legacy' and self.getActiveCurve() is None:
+ # Select an active curve
+ curves = self.getAllCurves(just_legend=False,
+ withhidden=False)
+ if len(curves) == 1:
+ if curves[0].isVisible():
+ self.setActiveCurve(curves[0].getName())
+
+ def getActiveCurveSelectionMode(self):
+ """Returns the current selection mode.
+
+ It can be "atmostone", "legacy" or "none".
+
+ :rtype: str
+ """
+ return self._activeCurveSelectionMode
+
+ def getActiveImage(self, just_legend=False):
+ """Returns the currently active image.
+
+ It returns None in case of not having an active image.
+
+ :param bool just_legend: True to get the legend of the image,
+ False (the default) to get the image data
+ and info.
+ :return: Active image's legend or corresponding image object
+ :rtype: str, :class:`.items.ImageData`, :class:`.items.ImageRgba`
+ or None
+ """
+ return self._getActiveItem(kind='image', just_legend=just_legend)
+
+ def setActiveImage(self, legend):
+ """Make the image associated to legend the active image.
+
+ :param str legend: The legend associated to the image
+ or None to have no active image.
+ """
+ return self._setActiveItem(kind='image', legend=legend)
+
+ def getActiveScatter(self, just_legend=False):
+ """Returns the currently active scatter.
+
+ It returns None in case of not having an active scatter.
+
+ :param bool just_legend: True to get the legend of the scatter,
+ False (the default) to get the scatter data
+ and info.
+ :return: Active scatter's legend or corresponding scatter object
+ :rtype: str, :class:`.items.Scatter` or None
+ """
+ return self._getActiveItem(kind='scatter', just_legend=just_legend)
+
+ def setActiveScatter(self, legend):
+ """Make the scatter associated to legend the active scatter.
+
+ :param str legend: The legend associated to the scatter
+ or None to have no active scatter.
+ """
+ return self._setActiveItem(kind='scatter', legend=legend)
+
+ def _getActiveItem(self, kind, just_legend=False):
+ """Return the currently active item of that kind if any
+
+ :param str kind: Type of item: 'curve', 'scatter' or 'image'
+ :param bool just_legend: True to get the legend,
+ False (default) to get the item
+ :return: legend or item or None if no active item
+ """
+ assert kind in self._ACTIVE_ITEM_KINDS
+
+ if self._activeLegend[kind] is None:
+ return None
+
+ item = self._getItem(kind, self._activeLegend[kind])
+ if item is None:
+ return None
+
+ return item.getName() if just_legend else item
+
+ def _setActiveItem(self, kind, legend):
+ """Make the curve associated to legend the active curve.
+
+ :param str kind: Type of item: 'curve' or 'image'
+ :param legend: The legend associated to the curve
+ or None to have no active curve.
+ :type legend: str or None
+ """
+ assert kind in self._ACTIVE_ITEM_KINDS
+
+ xLabel = None
+ yLabel = None
+ yRightLabel = None
+
+ oldActiveItem = self._getActiveItem(kind=kind)
+
+ if oldActiveItem is not None: # Stop listening previous active image
+ oldActiveItem.sigItemChanged.disconnect(self._activeItemChanged)
+
+ # Curve specific: Reset highlight of previous active curve
+ if kind == 'curve' and oldActiveItem is not None:
+ oldActiveItem.setHighlighted(False)
+
+ if legend is None:
+ self._activeLegend[kind] = None
+ else:
+ legend = str(legend)
+ item = self._getItem(kind, legend)
+ if item is None:
+ _logger.warning("This %s does not exist: %s", kind, legend)
+ self._activeLegend[kind] = None
+ else:
+ self._activeLegend[kind] = legend
+
+ # Curve specific: handle highlight
+ if kind == 'curve':
+ item.setHighlightedStyle(self.getActiveCurveStyle())
+ item.setHighlighted(True)
+
+ if isinstance(item, items.LabelsMixIn):
+ if item.getXLabel() is not None:
+ xLabel = item.getXLabel()
+ if item.getYLabel() is not None:
+ if (isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right'):
+ yRightLabel = item.getYLabel()
+ else:
+ yLabel = item.getYLabel()
+
+ # Start listening new active item
+ item.sigItemChanged.connect(self._activeItemChanged)
+
+ # Store current labels and update plot
+ self._xAxis._setCurrentLabel(xLabel)
+ self._yAxis._setCurrentLabel(yLabel)
+ self._yRightAxis._setCurrentLabel(yRightLabel)
+
+ self._setDirtyPlot()
+
+ activeLegend = self._activeLegend[kind]
+ if oldActiveItem is not None or activeLegend is not None:
+ if oldActiveItem is None:
+ oldActiveLegend = None
+ else:
+ oldActiveLegend = oldActiveItem.getName()
+ self.notify(
+ 'active' + kind[0].upper() + kind[1:] + 'Changed',
+ updated=oldActiveLegend != activeLegend,
+ previous=oldActiveLegend,
+ legend=activeLegend)
+
+ return activeLegend
+
+ def _activeItemChanged(self, type_):
+ """Listen for active item changed signal and broadcast signal
+
+ :param item.ItemChangedType type_: The type of item change
+ """
+ if not self.__muteActiveItemChanged:
+ item = self.sender()
+ if item is not None:
+ kind = self._itemKind(item)
+ self.notify(
+ 'active' + kind[0].upper() + kind[1:] + 'Changed',
+ updated=False,
+ previous=item.getName(),
+ legend=item.getName())
+
+ # Getters
+
+ def getAllCurves(self, just_legend=False, withhidden=False):
+ """Returns all curves legend or info and data.
+
+ It returns an empty list in case of not having any curve.
+
+ If just_legend is False, it returns a list of :class:`items.Curve`
+ objects describing the curves.
+ If just_legend is True, it returns a list of curves' legend.
+
+ :param bool just_legend: True to get the legend of the curves,
+ False (the default) to get the curves' data
+ and info.
+ :param bool withhidden: False (default) to skip hidden curves.
+ :return: list of curves' legend or :class:`.items.Curve`
+ :rtype: list of str or list of :class:`.items.Curve`
+ """
+ curves = [item for item in self.getItems() if
+ isinstance(item, items.Curve) and
+ (withhidden or item.isVisible())]
+ return [curve.getName() for curve in curves] if just_legend else curves
+
+ def getCurve(self, legend=None):
+ """Get the object describing a specific curve.
+
+ It returns None in case no matching curve is found.
+
+ :param str legend:
+ The legend identifying the curve.
+ If not provided or None (the default), the active curve is returned
+ or if there is no active curve, the latest updated curve that is
+ not hidden is returned if there are curves in the plot.
+ :return: None or :class:`.items.Curve` object
+ """
+ return self._getItem(kind='curve', legend=legend)
+
+ def getAllImages(self, just_legend=False):
+ """Returns all images legend or objects.
+
+ It returns an empty list in case of not having any image.
+
+ If just_legend is False, it returns a list of :class:`items.ImageBase`
+ objects describing the images.
+ If just_legend is True, it returns a list of legends.
+
+ :param bool just_legend: True to get the legend of the images,
+ False (the default) to get the images'
+ object.
+ :return: list of images' legend or :class:`.items.ImageBase`
+ :rtype: list of str or list of :class:`.items.ImageBase`
+ """
+ images = [item for item in self.getItems()
+ if isinstance(item, items.ImageBase)]
+ return [image.getName() for image in images] if just_legend else images
+
+ def getImage(self, legend=None):
+ """Get the object describing a specific image.
+
+ It returns None in case no matching image is found.
+
+ :param str legend:
+ The legend identifying the image.
+ If not provided or None (the default), the active image is returned
+ or if there is no active image, the latest updated image
+ is returned if there are images in the plot.
+ :return: None or :class:`.items.ImageBase` object
+ """
+ return self._getItem(kind='image', legend=legend)
+
+ def getScatter(self, legend=None):
+ """Get the object describing a specific scatter.
+
+ It returns None in case no matching scatter is found.
+
+ :param str legend:
+ The legend identifying the scatter.
+ If not provided or None (the default), the active scatter is
+ returned or if there is no active scatter, the latest updated
+ scatter is returned if there are scatters in the plot.
+ :return: None or :class:`.items.Scatter` object
+ """
+ return self._getItem(kind='scatter', legend=legend)
+
+ def getHistogram(self, legend=None):
+ """Get the object describing a specific histogram.
+
+ It returns None in case no matching histogram is found.
+
+ :param str legend:
+ The legend identifying the histogram.
+ If not provided or None (the default), the latest updated scatter
+ is returned if there are histograms in the plot.
+ :return: None or :class:`.items.Histogram` object
+ """
+ return self._getItem(kind='histogram', legend=legend)
+
+ @deprecated(replacement='getItems', since_version='0.13')
+ def _getItems(self, kind=ITEM_KINDS, just_legend=False, withhidden=False):
+ """Retrieve all items of a kind in the plot
+
+ :param kind: The kind of elements to retrieve from the plot.
+ See :attr:`ITEM_KINDS`.
+ By default, it removes all kind of elements.
+ :type kind: str or tuple of str to specify multiple kinds.
+ :param str kind: Type of item: 'curve' or 'image'
+ :param bool just_legend: True to get the legend of the curves,
+ False (the default) to get the curves' data
+ and info.
+ :param bool withhidden: False (default) to skip hidden curves.
+ :return: list of legends or item objects
+ """
+ if kind == 'all': # Replace all by tuple of all kinds
+ kind = self.ITEM_KINDS
+
+ if kind in self.ITEM_KINDS: # Kind is a str, make it a tuple
+ kind = (kind,)
+
+ for aKind in kind:
+ assert aKind in self.ITEM_KINDS
+
+ output = []
+ for item in self.getItems():
+ type_ = self._itemKind(item)
+ if type_ in kind and (withhidden or item.isVisible()):
+ output.append(item.getName() if just_legend else item)
+ return output
+
+ def _getItem(self, kind, legend=None):
+ """Get an item from the plot: either an image or a curve.
+
+ Returns None if no match found.
+
+ :param str kind: Type of item to retrieve,
+ see :attr:`ITEM_KINDS`.
+ :param str legend: Legend of the item or
+ None to get active or last item
+ :return: Object describing the item or None
+ """
+ assert kind in self.ITEM_KINDS
+
+ if legend is not None:
+ return self._content.get((legend, kind), None)
+ else:
+ if kind in self._ACTIVE_ITEM_KINDS:
+ item = self._getActiveItem(kind=kind)
+ if item is not None: # Return active item if available
+ return item
+ # Return last visible item if any
+ itemClasses = self._KIND_TO_CLASSES[kind]
+ allItems = [item for item in self.getItems()
+ if isinstance(item, itemClasses) and item.isVisible()]
+ return allItems[-1] if allItems else None
+
+ # Limits
+
+ def _notifyLimitsChanged(self, emitSignal=True):
+ """Send an event when plot area limits are changed."""
+ xRange = self._xAxis.getLimits()
+ yRange = self._yAxis.getLimits()
+ y2Range = self._yRightAxis.getLimits()
+ if emitSignal:
+ axes = self.getXAxis(), self.getYAxis(), self.getYAxis(axis="right")
+ ranges = xRange, yRange, y2Range
+ for axis, limits in zip(axes, ranges):
+ axis.sigLimitsChanged.emit(*limits)
+ event = PlotEvents.prepareLimitsChangedSignal(
+ id(self.getWidgetHandle()), xRange, yRange, y2Range)
+ self.notify(**event)
+
+ def getLimitsHistory(self):
+ """Returns the object handling the history of limits of the plot"""
+ return self._limitsHistory
+
+ def getGraphXLimits(self):
+ """Get the graph X (bottom) limits.
+
+ :return: Minimum and maximum values of the X axis
+ """
+ return self._backend.getGraphXLimits()
+
+ def setGraphXLimits(self, xmin, xmax):
+ """Set the graph X (bottom) limits.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ """
+ self._xAxis.setLimits(xmin, xmax)
+
+ def getGraphYLimits(self, axis='left'):
+ """Get the graph Y limits.
+
+ :param str axis: The axis for which to get the limits:
+ Either 'left' or 'right'
+ :return: Minimum and maximum values of the X axis
+ """
+ assert axis in ('left', 'right')
+ yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ return yAxis.getLimits()
+
+ def setGraphYLimits(self, ymin, ymax, axis='left'):
+ """Set the graph Y limits.
+
+ :param float ymin: minimum bottom axis value
+ :param float ymax: maximum bottom axis value
+ :param str axis: The axis for which to get the limits:
+ Either 'left' or 'right'
+ """
+ assert axis in ('left', 'right')
+ yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ return yAxis.setLimits(ymin, ymax)
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ """Set the limits of the X and Y axes at once.
+
+ If y2min or y2max is None, the right Y axis limits are not updated.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param float y2min: minimum right axis value or None (the default)
+ :param float y2max: maximum right axis value or None (the default)
+ """
+ # Deal with incorrect values
+ axis = self.getXAxis()
+ xmin, xmax = axis._checkLimits(xmin, xmax)
+ axis = self.getYAxis()
+ ymin, ymax = axis._checkLimits(ymin, ymax)
+
+ if y2min is None or y2max is None:
+ # if one limit is None, both are ignored
+ y2min, y2max = None, None
+ else:
+ axis = self.getYAxis(axis="right")
+ y2min, y2max = axis._checkLimits(y2min, y2max)
+
+ if self._viewConstrains:
+ view = self._viewConstrains.normalize(xmin, xmax, ymin, ymax)
+ xmin, xmax, ymin, ymax = view
+
+ self._backend.setLimits(xmin, xmax, ymin, ymax, y2min, y2max)
+ self._setDirtyPlot()
+ self._notifyLimitsChanged()
+
+ def _getViewConstraints(self):
+ """Return the plot object managing constaints on the plot view.
+
+ :rtype: ViewConstraints
+ """
+ if self._viewConstrains is None:
+ self._viewConstrains = ViewConstraints()
+ return self._viewConstrains
+
+ # Title and labels
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str."""
+ return self._graphTitle
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ self._graphTitle = str(title)
+ self._backend.setGraphTitle(title)
+ self._setDirtyPlot()
+
+ def getGraphXLabel(self):
+ """Return the current X axis label as a str."""
+ return self._xAxis.getLabel()
+
+ def setGraphXLabel(self, label="X"):
+ """Set the plot X axis label.
+
+ The provided label can be temporarily replaced by the X label of the
+ active curve if any.
+
+ :param str label: The X axis label (default: 'X')
+ """
+ self._xAxis.setLabel(label)
+
+ def getGraphYLabel(self, axis='left'):
+ """Return the current Y axis label as a str.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ """
+ assert axis in ('left', 'right')
+ yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ return yAxis.getLabel()
+
+ def setGraphYLabel(self, label="Y", axis='left'):
+ """Set the plot Y axis label.
+
+ The provided label can be temporarily replaced by the Y label of the
+ active curve if any.
+
+ :param str label: The Y axis label (default: 'Y')
+ :param str axis: The Y axis for which to set the label (left or right)
+ """
+ assert axis in ('left', 'right')
+ yAxis = self._yAxis if axis == 'left' else self._yRightAxis
+ return yAxis.setLabel(label)
+
+ # Axes
+
+ def getXAxis(self):
+ """Returns the X axis
+
+ .. versionadded:: 0.6
+
+ :rtype: :class:`.items.Axis`
+ """
+ return self._xAxis
+
+ def getYAxis(self, axis="left"):
+ """Returns an Y axis
+
+ .. versionadded:: 0.6
+
+ :param str axis: The Y axis to return
+ ('left' or 'right').
+ :rtype: :class:`.items.Axis`
+ """
+ assert(axis in ["left", "right"])
+ return self._yAxis if axis == "left" else self._yRightAxis
+
+ def setAxesDisplayed(self, displayed: bool):
+ """Display or not the axes.
+
+ :param bool displayed: If `True` axes are displayed. If `False` axes
+ are not anymore visible and the margin used for them is removed.
+ """
+ if displayed != self.__axesDisplayed:
+ self.__axesDisplayed = displayed
+ if displayed:
+ self._backend.setAxesMargins(*self.__axesMargins)
+ else:
+ self._backend.setAxesMargins(0., 0., 0., 0.)
+ self._setDirtyPlot()
+ self._sigAxesVisibilityChanged.emit(displayed)
+
+ def isAxesDisplayed(self) -> bool:
+ """Returns whether or not axes are currently displayed
+
+ :rtype: bool
+ """
+ return self.__axesDisplayed
+
+ def setAxesMargins(
+ self, left: float, top: float, right: float, bottom: float):
+ """Set ratios of margins surrounding data plot area.
+
+ All ratios must be within [0., 1.].
+ Sums of ratios of opposed side must be < 1.
+
+ :param float left: Left-side margin ratio.
+ :param float top: Top margin ratio
+ :param float right: Right-side margin ratio
+ :param float bottom: Bottom margin ratio
+ :raises ValueError:
+ """
+ for value in (left, top, right, bottom):
+ if value < 0. or value > 1.:
+ raise ValueError("Margin ratios must be within [0., 1.]")
+ if left + right >= 1. or top + bottom >= 1.:
+ raise ValueError("Sum of ratios of opposed sides >= 1")
+ margins = left, top, right, bottom
+
+ if margins != self.__axesMargins:
+ self.__axesMargins = margins
+ if self.isAxesDisplayed(): # Only apply if axes are displayed
+ self._backend.setAxesMargins(*margins)
+ self._setDirtyPlot()
+
+ def getAxesMargins(self):
+ """Returns ratio of margins surrounding data plot area.
+
+ :return: (left, top, right, bottom)
+ :rtype: List[float]
+ """
+ return self.__axesMargins
+
+ def setYAxisInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ self._yAxis.setInverted(flag)
+
+ def isYAxisInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise."""
+ return self._yAxis.isInverted()
+
+ def isXAxisLogarithmic(self):
+ """Return True if X axis scale is logarithmic, False if linear."""
+ return self._xAxis._isLogarithmic()
+
+ def setXAxisLogarithmic(self, flag):
+ """Set the bottom X axis scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ self._xAxis._setLogarithmic(flag)
+
+ def isYAxisLogarithmic(self):
+ """Return True if Y axis scale is logarithmic, False if linear."""
+ return self._yAxis._isLogarithmic()
+
+ def setYAxisLogarithmic(self, flag):
+ """Set the Y axes scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ self._yAxis._setLogarithmic(flag)
+
+ def isXAxisAutoScale(self):
+ """Return True if X axis is automatically adjusting its limits."""
+ return self._xAxis.isAutoScale()
+
+ def setXAxisAutoScale(self, flag=True):
+ """Set the X axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._xAxis.setAutoScale(flag)
+
+ def isYAxisAutoScale(self):
+ """Return True if Y axes are automatically adjusting its limits."""
+ return self._yAxis.isAutoScale()
+
+ def setYAxisAutoScale(self, flag=True):
+ """Set the Y axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._yAxis.setAutoScale(flag)
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self._backend.isKeepDataAspectRatio()
+
+ def setKeepDataAspectRatio(self, flag=True):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ flag = bool(flag)
+ if flag == self.isKeepDataAspectRatio():
+ return
+ self._backend.setKeepDataAspectRatio(flag=flag)
+ self._setDirtyPlot()
+ self._forceResetZoom()
+ self.notify('setKeepDataAspectRatio', state=flag)
+
+ def getGraphGrid(self):
+ """Return the current grid mode, either None, 'major' or 'both'.
+
+ See :meth:`setGraphGrid`.
+ """
+ return self._grid
+
+ def setGraphGrid(self, which=True):
+ """Set the type of grid to display.
+
+ :param which: None or False to disable the grid,
+ 'major' or True for grid on major ticks (the default),
+ 'both' for grid on both major and minor ticks.
+ :type which: str of bool
+ """
+ assert which in (None, True, False, 'both', 'major')
+ if not which:
+ which = None
+ elif which is True:
+ which = 'major'
+ self._grid = which
+ self._backend.setGraphGrid(which)
+ self._setDirtyPlot()
+ self.notify('setGraphGrid', which=str(which))
+
+ # Defaults
+
+ def isDefaultPlotPoints(self):
+ """Return True if the default Curve symbol is set and False if not."""
+ return self._defaultPlotPoints == silx.config.DEFAULT_PLOT_SYMBOL
+
+ def setDefaultPlotPoints(self, flag):
+ """Set the default symbol of all curves.
+
+ When called, this reset the symbol of all existing curves.
+
+ :param bool flag: True to use 'o' as the default curve symbol,
+ False to use no symbol.
+ """
+ self._defaultPlotPoints = silx.config.DEFAULT_PLOT_SYMBOL if flag else ''
+
+ # Reset symbol of all curves
+ curves = self.getAllCurves(just_legend=False, withhidden=True)
+
+ if curves:
+ for curve in curves:
+ curve.setSymbol(self._defaultPlotPoints)
+
+ def isDefaultPlotLines(self):
+ """Return True for line as default line style, False for no line."""
+ return self._plotLines
+
+ def setDefaultPlotLines(self, flag):
+ """Toggle the use of lines as the default curve line style.
+
+ :param bool flag: True to use a line as the default line style,
+ False to use no line as the default line style.
+ """
+ self._plotLines = bool(flag)
+
+ linestyle = '-' if self._plotLines else ' '
+
+ # Reset linestyle of all curves
+ curves = self.getAllCurves(withhidden=True)
+
+ if curves:
+ for curve in curves:
+ curve.setLineStyle(linestyle)
+
+ def getDefaultColormap(self):
+ """Return the default colormap used by :meth:`addImage`.
+
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self._defaultColormap
+
+ def setDefaultColormap(self, colormap=None):
+ """Set the default colormap used by :meth:`addImage`.
+
+ Setting the default colormap do not change any currently displayed
+ image.
+ It only affects future calls to :meth:`addImage` without the colormap
+ parameter.
+
+ :param ~silx.gui.colors.Colormap colormap:
+ The description of the default colormap, or
+ None to set the colormap to a linear
+ autoscale gray colormap.
+ """
+ if colormap is None:
+ colormap = Colormap(name=silx.config.DEFAULT_COLORMAP_NAME,
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ if isinstance(colormap, dict):
+ self._defaultColormap = Colormap._fromDict(colormap)
+ else:
+ assert isinstance(colormap, Colormap)
+ self._defaultColormap = colormap
+ self.notify('defaultColormapChanged')
+
+ @staticmethod
+ def getSupportedColormaps():
+ """Get the supported colormap names as a tuple of str.
+
+ The list contains at least:
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue',
+ 'magma', 'inferno', 'plasma', 'viridis')
+ """
+ return Colormap.getSupportedColormaps()
+
+ def _resetColorAndStyle(self):
+ self._colorIndex = 0
+ self._styleIndex = 0
+
+ def _getColorAndStyle(self):
+ color = self.colorList[self._colorIndex]
+ style = self._styleList[self._styleIndex]
+
+ # Loop over color and then styles
+ self._colorIndex += 1
+ if self._colorIndex >= len(self.colorList):
+ self._colorIndex = 0
+ self._styleIndex = (self._styleIndex + 1) % len(self._styleList)
+
+ # If color is the one of active curve, take the next one
+ if colors.rgba(color) == self.getActiveCurveStyle().getColor():
+ color, style = self._getColorAndStyle()
+
+ if not self._plotLines:
+ style = ' '
+
+ return color, style
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ """Return the widget the plot is displayed in.
+
+ This widget is owned by the backend.
+ """
+ return self._backend.getWidgetHandle()
+
+ def notify(self, event, **kwargs):
+ """Send an event to the listeners and send signals.
+
+ Event are passed to the registered callback as a dict with an 'event'
+ key for backward compatibility with PyMca.
+
+ :param str event: The type of event
+ :param kwargs: The information of the event.
+ """
+ eventDict = kwargs.copy()
+ eventDict['event'] = event
+ self.sigPlotSignal.emit(eventDict)
+
+ if event == 'setKeepDataAspectRatio':
+ self.sigSetKeepDataAspectRatio.emit(kwargs['state'])
+ elif event == 'setGraphGrid':
+ self.sigSetGraphGrid.emit(kwargs['which'])
+ elif event == 'setGraphCursor':
+ self.sigSetGraphCursor.emit(kwargs['state'])
+ elif event == 'contentChanged':
+ self.sigContentChanged.emit(
+ kwargs['action'], kwargs['kind'], kwargs['legend'])
+ elif event == 'activeCurveChanged':
+ self.sigActiveCurveChanged.emit(
+ kwargs['previous'], kwargs['legend'])
+ elif event == 'activeImageChanged':
+ self.sigActiveImageChanged.emit(
+ kwargs['previous'], kwargs['legend'])
+ elif event == 'activeScatterChanged':
+ self.sigActiveScatterChanged.emit(
+ kwargs['previous'], kwargs['legend'])
+ elif event == 'interactiveModeChanged':
+ self.sigInteractiveModeChanged.emit(kwargs['source'])
+
+ eventDict = kwargs.copy()
+ eventDict['event'] = event
+ self._callback(eventDict)
+
+ def setCallback(self, callbackFunction=None):
+ """Attach a listener to the backend.
+
+ Limitation: Only one listener at a time.
+
+ :param callbackFunction: function accepting a dictionary as input
+ to handle the graph events
+ If None (default), use a default listener.
+ """
+ # TODO allow multiple listeners
+ # allow register listener by event type
+ if callbackFunction is None:
+ callbackFunction = WeakMethodProxy(self.graphCallback)
+ self._callback = callbackFunction
+
+ def graphCallback(self, ddict=None):
+ """This callback is going to receive all the events from the plot.
+
+ Those events will consist on a dictionary and among the dictionary
+ keys the key 'event' is mandatory to describe the type of event.
+ This default implementation only handles setting the active curve.
+ """
+
+ if ddict is None:
+ ddict = {}
+ _logger.debug("Received dict keys = %s", str(ddict.keys()))
+ _logger.debug(str(ddict))
+ if ddict['event'] in ["legendClicked", "curveClicked"]:
+ if ddict['button'] == "left":
+ self.setActiveCurve(ddict['label'])
+ qt.QToolTip.showText(self.cursor().pos(), ddict['label'])
+ elif ddict['event'] == 'mouseClicked' and ddict['button'] == 'left':
+ self.setActiveCurve(None)
+
+ def saveGraph(self, filename, fileFormat=None, dpi=None):
+ """Save a snapshot of the plot.
+
+ Supported file formats depends on the backend in use.
+ The following file formats are always supported: "png", "svg".
+ The matplotlib backend supports more formats:
+ "pdf", "ps", "eps", "tiff", "jpeg", "jpg".
+
+ :param filename: Destination
+ :type filename: str, StringIO or BytesIO
+ :param str fileFormat: String specifying the format
+ :return: False if cannot save the plot, True otherwise
+ """
+ if fileFormat is None:
+ if not hasattr(filename, 'lower'):
+ _logger.warning(
+ 'saveGraph cancelled, cannot define file format.')
+ return False
+ else:
+ fileFormat = (filename.split(".")[-1]).lower()
+
+ supportedFormats = ("png", "svg", "pdf", "ps", "eps",
+ "tif", "tiff", "jpeg", "jpg")
+
+ if fileFormat not in supportedFormats:
+ _logger.warning('Unsupported format %s', fileFormat)
+ return False
+ else:
+ self._backend.saveGraph(filename,
+ fileFormat=fileFormat,
+ dpi=dpi)
+ return True
+
+ def getDataMargins(self):
+ """Get the default data margin ratios, see :meth:`setDataMargins`.
+
+ :return: The margin ratios for each side (xMin, xMax, yMin, yMax).
+ :rtype: A 4-tuple of floats.
+ """
+ return self._defaultDataMargins
+
+ def setDataMargins(self, xMinMargin=0., xMaxMargin=0.,
+ yMinMargin=0., yMaxMargin=0.):
+ """Set the default data margins to use in :meth:`resetZoom`.
+
+ Set the default ratios of margins (as floats) to add around the data
+ inside the plot area for each side.
+ """
+ self._defaultDataMargins = (xMinMargin, xMaxMargin,
+ yMinMargin, yMaxMargin)
+
+ def getAutoReplot(self):
+ """Return True if replot is automatically handled, False otherwise.
+
+ See :meth`setAutoReplot`.
+ """
+ return self._autoreplot
+
+ def setAutoReplot(self, autoreplot=True):
+ """Set automatic replot mode.
+
+ When enabled, the plot is redrawn automatically when changed.
+ When disabled, the plot is not redrawn when its content change.
+ Instead, it :meth:`replot` must be called.
+
+ :param bool autoreplot: True to enable it (default),
+ False to disable it.
+ """
+ self._autoreplot = bool(autoreplot)
+
+ # If the plot is dirty before enabling autoreplot,
+ # then _backend.postRedisplay will never be called from _setDirtyPlot
+ if self._autoreplot and self._getDirtyPlot():
+ self._backend.postRedisplay()
+
+ @contextmanager
+ def _paintContext(self):
+ """This context MUST surround backend rendering.
+
+ It is in charge of performing required PlotWidget operations
+ """
+ for item in self._contentToUpdate:
+ item._update(self._backend)
+
+ self._contentToUpdate = []
+ yield
+ self._dirty = False # reset dirty flag
+
+ def replot(self):
+ """Request to draw the plot."""
+ self._backend.replot()
+
+ def _forceResetZoom(self, dataMargins=None):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ This method forces a reset zoom and does not check axis autoscale.
+
+ Extra margins can be added around the data inside the plot area
+ (see :meth:`setDataMargins`).
+ Margins are given as one ratio of the data range per limit of the
+ data (xMin, xMax, yMin and yMax limits).
+ For log scale, extra margins are applied in log10 of the data.
+
+ :param dataMargins: Ratios of margins to add around the data inside
+ the plot area for each side (default: no margins).
+ :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax).
+ """
+ if dataMargins is None:
+ dataMargins = self._defaultDataMargins
+
+ # Get data range
+ ranges = self.getDataRange()
+ xmin, xmax = (1., 100.) if ranges.x is None else ranges.x
+ ymin, ymax = (1., 100.) if ranges.y is None else ranges.y
+ if ranges.yright is None:
+ ymin2, ymax2 = ymin, ymax
+ else:
+ ymin2, ymax2 = ranges.yright
+ if ranges.y is None:
+ ymin, ymax = ranges.yright
+
+ # Add margins around data inside the plot area
+ newLimits = list(_utils.addMarginsToLimits(
+ dataMargins,
+ self._xAxis._isLogarithmic(),
+ self._yAxis._isLogarithmic(),
+ xmin, xmax, ymin, ymax, ymin2, ymax2))
+
+ if self.isKeepDataAspectRatio():
+ # Use limits with margins to keep ratio
+ xmin, xmax, ymin, ymax = newLimits[:4]
+
+ # Compute bbox wth figure aspect ratio
+ plotWidth, plotHeight = self.getPlotBoundsInPixels()[2:]
+ if plotWidth > 0 and plotHeight > 0:
+ plotRatio = plotHeight / plotWidth
+ dataRatio = (ymax - ymin) / (xmax - xmin)
+ if dataRatio < plotRatio:
+ # Increase y range
+ ycenter = 0.5 * (ymax + ymin)
+ yrange = (xmax - xmin) * plotRatio
+ newLimits[2] = ycenter - 0.5 * yrange
+ newLimits[3] = ycenter + 0.5 * yrange
+
+ elif dataRatio > plotRatio:
+ # Increase x range
+ xcenter = 0.5 * (xmax + xmin)
+ xrange_ = (ymax - ymin) / plotRatio
+ newLimits[0] = xcenter - 0.5 * xrange_
+ newLimits[1] = xcenter + 0.5 * xrange_
+
+ self.setLimits(*newLimits)
+
+ def resetZoom(self, dataMargins=None):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ It automatically scale limits of axes that are in autoscale mode
+ (see :meth:`getXAxis`, :meth:`getYAxis` and :meth:`Axis.setAutoScale`).
+ It keeps current limits on axes that are not in autoscale mode.
+
+ Extra margins can be added around the data inside the plot area
+ (see :meth:`setDataMargins`).
+ Margins are given as one ratio of the data range per limit of the
+ data (xMin, xMax, yMin and yMax limits).
+ For log scale, extra margins are applied in log10 of the data.
+
+ :param dataMargins: Ratios of margins to add around the data inside
+ the plot area for each side (default: no margins).
+ :type dataMargins: A 4-tuple of float as (xMin, xMax, yMin, yMax).
+ """
+ xLimits = self._xAxis.getLimits()
+ yLimits = self._yAxis.getLimits()
+ y2Limits = self._yRightAxis.getLimits()
+
+ xAuto = self._xAxis.isAutoScale()
+ yAuto = self._yAxis.isAutoScale()
+
+ # With log axes, autoscale if limits are <= 0
+ # This avoids issues with toggling log scale with matplotlib 2.1.0
+ if self._xAxis.getScale() == self._xAxis.LOGARITHMIC and xLimits[0] <= 0:
+ xAuto = True
+ if self._yAxis.getScale() == self._yAxis.LOGARITHMIC and (yLimits[0] <= 0 or y2Limits[0] <= 0):
+ yAuto = True
+
+ if not xAuto and not yAuto:
+ _logger.debug("Nothing to autoscale")
+ else: # Some axes to autoscale
+ self._forceResetZoom(dataMargins=dataMargins)
+
+ # Restore limits for axis not in autoscale
+ if not xAuto and yAuto:
+ self.setGraphXLimits(*xLimits)
+ elif xAuto and not yAuto:
+ if y2Limits is not None:
+ self.setGraphYLimits(
+ y2Limits[0], y2Limits[1], axis='right')
+ if yLimits is not None:
+ self.setGraphYLimits(yLimits[0], yLimits[1], axis='left')
+
+ if (xLimits != self._xAxis.getLimits() or
+ yLimits != self._yAxis.getLimits() or
+ y2Limits != self._yRightAxis.getLimits()):
+ self._notifyLimitsChanged()
+
+ # Coord conversion
+
+ def dataToPixel(self, x=None, y=None, axis="left", check=True):
+ """Convert a position in data coordinates to a position in pixels.
+
+ :param float x: The X coordinate in data space. If None (default)
+ the middle position of the displayed data is used.
+ :param float y: The Y coordinate in data space. If None (default)
+ the middle position of the displayed data is used.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :param bool check: True to return None if outside displayed area,
+ False to convert to pixels anyway
+ :returns: The corresponding position in pixels or
+ None if the data position is not in the displayed area and
+ check is True.
+ :rtype: A tuple of 2 floats: (xPixel, yPixel) or None.
+ """
+ assert axis in ("left", "right")
+
+ xmin, xmax = self._xAxis.getLimits()
+ yAxis = self.getYAxis(axis=axis)
+ ymin, ymax = yAxis.getLimits()
+
+ if x is None:
+ x = 0.5 * (xmax + xmin)
+ if y is None:
+ y = 0.5 * (ymax + ymin)
+
+ if check:
+ if x > xmax or x < xmin:
+ return None
+
+ if y > ymax or y < ymin:
+ return None
+
+ return self._backend.dataToPixel(x, y, axis=axis)
+
+ def pixelToData(self, x, y, axis="left", check=False):
+ """Convert a position in pixels to a position in data coordinates.
+
+ :param float x: The X coordinate in pixels. If None (default)
+ the center of the widget is used.
+ :param float y: The Y coordinate in pixels. If None (default)
+ the center of the widget is used.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :param bool check: Toggle checking if pixel is in plot area.
+ If False, this method never returns None.
+ :returns: The corresponding position in data space or
+ None if the pixel position is not in the plot area.
+ :rtype: A tuple of 2 floats: (xData, yData) or None.
+ """
+ assert axis in ("left", "right")
+
+ if x is None:
+ x = self.width() // 2
+ if y is None:
+ y = self.height() // 2
+
+ if check:
+ left, top, width, height = self.getPlotBoundsInPixels()
+ if not (left <= x <= left + width and top <= y <= top + height):
+ return None
+
+ return self._backend.pixelToData(x, y, axis)
+
+ def getPlotBoundsInPixels(self):
+ """Plot area bounds in widget coordinates in pixels.
+
+ :return: bounds as a 4-tuple of int: (left, top, width, height)
+ """
+ return self._backend.getPlotBoundsInPixels()
+
+ # Interaction support
+
+ def getGraphCursorShape(self):
+ """Returns the current cursor shape.
+
+ :rtype: str
+ """
+ return self.__graphCursorShape
+
+ def setGraphCursorShape(self, cursor=None):
+ """Set the cursor shape.
+
+ :param str cursor: Name of the cursor shape
+ """
+ self.__graphCursorShape = cursor
+ self._backend.setGraphCursorShape(cursor)
+
+ @deprecated(replacement='getItems', since_version='0.13')
+ def _getAllMarkers(self, just_legend=False):
+ markers = [item for item in self.getItems() if isinstance(item, items.MarkerBase)]
+ if just_legend:
+ return [marker.getName() for marker in markers]
+ else:
+ return markers
+
+ def _getMarkerAt(self, x, y):
+ """Return the most interactive marker at a location, else None
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :rtype: None of marker object
+ """
+ def checkDraggable(item):
+ return isinstance(item, items.MarkerBase) and item.isDraggable()
+ def checkSelectable(item):
+ return isinstance(item, items.MarkerBase) and item.isSelectable()
+ def check(item):
+ return isinstance(item, items.MarkerBase)
+
+ result = self._pickTopMost(x, y, checkDraggable)
+ if not result:
+ result = self._pickTopMost(x, y, checkSelectable)
+ if not result:
+ result = self._pickTopMost(x, y, check)
+ marker = result.getItem() if result is not None else None
+ return marker
+
+ def _getMarker(self, legend=None):
+ """Get the object describing a specific marker.
+
+ It returns None in case no matching marker is found
+
+ :param str legend: The legend of the marker to retrieve
+ :rtype: None of marker object
+ """
+ return self._getItem(kind='marker', legend=legend)
+
+ def pickItems(self, x, y, condition=None):
+ """Generator of picked items in the plot at given position.
+
+ Items are returned from front to back.
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :param callable condition:
+ Callable taking an item as input and returning False for items
+ to skip during picking. If None (default) no item is skipped.
+ :return: Iterable of :class:`PickingResult` objects at picked position.
+ Items are ordered from front to back.
+ """
+ for item in reversed(self._backend.getItemsFromBackToFront(condition=condition)):
+ result = item.pick(x, y)
+ if result is not None:
+ yield result
+
+ def _pickTopMost(self, x, y, condition=None):
+ """Returns top-most picked item in the plot at given position.
+
+ Items are checked from front to back.
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :param callable condition:
+ Callable taking an item as input and returning False for items
+ to skip during picking. If None (default) no item is skipped.
+ :return: :class:`PickingResult` object at picked position.
+ If no item is picked, it returns None
+ :rtype: Union[None,PickingResult]
+ """
+ for result in self.pickItems(x, y, condition):
+ return result
+ return None
+
+ # User event handling #
+
+ def _isPositionInPlotArea(self, x, y):
+ """Project position in pixel to the closest point in the plot area
+
+ :param float x: X coordinate in widget coordinate (in pixel)
+ :param float y: Y coordinate in widget coordinate (in pixel)
+ :return: (x, y) in widget coord (in pixel) in the plot area
+ """
+ left, top, width, height = self.getPlotBoundsInPixels()
+ xPlot = numpy.clip(x, left, left + width)
+ yPlot = numpy.clip(y, top, top + height)
+ return xPlot, yPlot
+
+ def onMousePress(self, xPixel, yPixel, btn):
+ """Handle mouse press event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param str btn: Mouse button in 'left', 'middle', 'right'
+ """
+ if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
+ self._pressedButtons.append(btn)
+ self._eventHandler.handleEvent('press', xPixel, yPixel, btn)
+
+ def onMouseMove(self, xPixel, yPixel):
+ """Handle mouse move event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ """
+ inXPixel, inYPixel = self._isPositionInPlotArea(xPixel, yPixel)
+ isCursorInPlot = inXPixel == xPixel and inYPixel == yPixel
+
+ if self._cursorInPlot != isCursorInPlot:
+ self._cursorInPlot = isCursorInPlot
+ self._eventHandler.handleEvent(
+ 'enter' if self._cursorInPlot else 'leave')
+
+ if isCursorInPlot:
+ # Signal mouse move event
+ dataPos = self.pixelToData(inXPixel, inYPixel)
+ assert dataPos is not None
+
+ btn = self._pressedButtons[-1] if self._pressedButtons else None
+ event = PlotEvents.prepareMouseSignal(
+ 'mouseMoved', btn, dataPos[0], dataPos[1], xPixel, yPixel)
+ self.notify(**event)
+
+ # Either button was pressed in the plot or cursor is in the plot
+ if isCursorInPlot or self._pressedButtons:
+ self._eventHandler.handleEvent('move', inXPixel, inYPixel)
+
+ def onMouseRelease(self, xPixel, yPixel, btn):
+ """Handle mouse release event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param str btn: Mouse button in 'left', 'middle', 'right'
+ """
+ try:
+ self._pressedButtons.remove(btn)
+ except ValueError:
+ pass
+ else:
+ xPixel, yPixel = self._isPositionInPlotArea(xPixel, yPixel)
+ self._eventHandler.handleEvent('release', xPixel, yPixel, btn)
+
+ def onMouseWheel(self, xPixel, yPixel, angleInDegrees):
+ """Handle mouse wheel event.
+
+ :param float xPixel: X mouse position in pixels
+ :param float yPixel: Y mouse position in pixels
+ :param float angleInDegrees: Angle corresponding to wheel motion.
+ Positive for movement away from the user,
+ negative for movement toward the user.
+ """
+ if self._isPositionInPlotArea(xPixel, yPixel) == (xPixel, yPixel):
+ self._eventHandler.handleEvent(
+ 'wheel', xPixel, yPixel, angleInDegrees)
+
+ def onMouseLeaveWidget(self):
+ """Handle mouse leave widget event."""
+ if self._cursorInPlot:
+ self._cursorInPlot = False
+ self._eventHandler.handleEvent('leave')
+
+ # Interaction modes #
+
+ def getInteractiveMode(self):
+ """Returns the current interactive mode as a dict.
+
+ The returned dict contains at least the key 'mode'.
+ Mode can be: 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ It can also contains extra keys (e.g., 'color') specific to a mode
+ as provided to :meth:`setInteractiveMode`.
+ """
+ return self._eventHandler.getInteractiveMode()
+
+ def resetInteractiveMode(self):
+ """Reset the interactive mode to use the previous basic interactive
+ mode used.
+
+ It can be one of "zoom" or "pan".
+ """
+ mode, zoomOnWheel = self._previousDefaultMode
+ self.setInteractiveMode(mode=mode, zoomOnWheel=zoomOnWheel)
+
+ def setInteractiveMode(self, mode, color='black',
+ shape='polygon', label=None,
+ zoomOnWheel=True, source=None, width=None):
+ """Switch the interactive mode.
+
+ :param str mode: The name of the interactive mode.
+ In 'draw', 'pan', 'select', 'select-draw', 'zoom'.
+ :param color: Only for 'draw' and 'zoom' modes.
+ Color to use for drawing selection area. Default black.
+ :type color: Color description: The name as a str or
+ a tuple of 4 floats.
+ :param str shape: Only for 'draw' mode. The kind of shape to draw.
+ In 'polygon', 'rectangle', 'line', 'vline', 'hline',
+ 'freeline'.
+ Default is 'polygon'.
+ :param str label: Only for 'draw' mode, sent in drawing events.
+ :param bool zoomOnWheel: Toggle zoom on wheel support
+ :param source: A user-defined object (typically the caller object)
+ that will be send in the interactiveModeChanged event,
+ to identify which object required a mode change.
+ Default: None
+ :param float width: Width of the pencil. Only for draw pencil mode.
+ """
+ self._eventHandler.setInteractiveMode(mode, color, shape, label, width)
+ self._eventHandler.zoomOnWheel = zoomOnWheel
+ if mode in ["pan", "zoom"]:
+ self._previousDefaultMode = mode, zoomOnWheel
+
+ self.notify(
+ 'interactiveModeChanged', source=source)
+
+ # Panning with arrow keys
+
+ def isPanWithArrowKeys(self):
+ """Returns whether or not panning the graph with arrow keys is enabled.
+
+ See :meth:`setPanWithArrowKeys`.
+ """
+ return self._panWithArrowKeys
+
+ def setPanWithArrowKeys(self, pan=False):
+ """Enable/Disable panning the graph with arrow keys.
+
+ This grabs the keyboard.
+
+ :param bool pan: True to enable panning, False to disable.
+ """
+ pan = bool(pan)
+ panHasChanged = self._panWithArrowKeys != pan
+
+ self._panWithArrowKeys = pan
+ if not self._panWithArrowKeys:
+ self.setFocusPolicy(qt.Qt.NoFocus)
+ else:
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self.setFocus(qt.Qt.OtherFocusReason)
+
+ if panHasChanged:
+ self.sigSetPanWithArrowKeys.emit(pan)
+
+ # Dict to convert Qt arrow key code to direction str.
+ _ARROWS_TO_PAN_DIRECTION = {
+ qt.Qt.Key_Left: 'left',
+ qt.Qt.Key_Right: 'right',
+ qt.Qt.Key_Up: 'up',
+ qt.Qt.Key_Down: 'down'
+ }
+
+ def __simulateMouseMove(self):
+ qapp = qt.QApplication.instance()
+ event = qt.QMouseEvent(
+ qt.QEvent.MouseMove,
+ self.getWidgetHandle().mapFromGlobal(qt.QCursor.pos()),
+ qt.Qt.NoButton,
+ qapp.mouseButtons(),
+ qapp.keyboardModifiers())
+ qapp.sendEvent(self.getWidgetHandle(), event)
+
+ def keyPressEvent(self, event):
+ """Key event handler handling panning on arrow keys.
+
+ Overrides base class implementation.
+ """
+ key = event.key()
+ if self._panWithArrowKeys and key in self._ARROWS_TO_PAN_DIRECTION:
+ self.pan(self._ARROWS_TO_PAN_DIRECTION[key], factor=0.1)
+
+ # Send a mouse move event to the plot widget to take into account
+ # that even if mouse didn't move on the screen, it moved relative
+ # to the plotted data.
+ self.__simulateMouseMove()
+ else:
+ # Only call base class implementation when key is not handled.
+ # See QWidget.keyPressEvent for details.
+ super(PlotWidget, self).keyPressEvent(event)
diff --git a/src/silx/gui/plot/PlotWindow.py b/src/silx/gui/plot/PlotWindow.py
new file mode 100644
index 0000000..0349585
--- /dev/null
+++ b/src/silx/gui/plot/PlotWindow.py
@@ -0,0 +1,993 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A :class:`.PlotWidget` with additional toolbars.
+
+The :class:`PlotWindow` is a subclass of :class:`.PlotWidget`.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "12/04/2019"
+
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+import logging
+import weakref
+
+import silx
+from silx.utils.weakref import WeakMethodProxy
+from silx.utils.deprecation import deprecated
+from silx.utils.proxy import docstring
+
+from . import PlotWidget
+from . import actions
+from . import items
+from .actions import medfilt as actions_medfilt
+from .actions import fit as actions_fit
+from .actions import control as actions_control
+from .actions import histogram as actions_histogram
+from . import PlotToolButtons
+from . import tools
+from .Profile import ProfileToolBar
+from .LegendSelector import LegendsDockWidget
+from .CurvesROIWidget import CurvesROIDockWidget
+from .MaskToolsWidget import MaskToolsDockWidget
+from .StatsWidget import BasicStatsWidget
+from .ColorBar import ColorBarWidget
+try:
+ from ..console import IPythonDockWidget
+except ImportError:
+ IPythonDockWidget = None
+
+from .. import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PlotWindow(PlotWidget):
+ """Qt Widget providing a 1D/2D plot area and additional tools.
+
+ This widgets inherits from :class:`.PlotWidget` and provides its plot API.
+
+ Initialiser parameters:
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ :param bool resetzoom: Toggle visibility of reset zoom action.
+ :param bool autoScale: Toggle visibility of axes autoscale actions.
+ :param bool logScale: Toggle visibility of axes log scale actions.
+ :param bool grid: Toggle visibility of grid mode action.
+ :param bool curveStyle: Toggle visibility of curve style action.
+ :param bool colormap: Toggle visibility of colormap action.
+ :param bool aspectRatio: Toggle visibility of aspect ratio button.
+ :param bool yInverted: Toggle visibility of Y axis direction button.
+ :param bool copy: Toggle visibility of copy action.
+ :param bool save: Toggle visibility of save action.
+ :param bool print_: Toggle visibility of print action.
+ :param bool control: True to display an Options button with a sub-menu
+ to show legends, toggle crosshair and pan with arrows.
+ (Default: False)
+ :param position: True to display widget with (x, y) mouse position
+ (Default: False).
+ It also supports a list of (name, funct(x, y)->value)
+ to customize the displayed values.
+ See :class:`~silx.gui.plot.tools.PositionInfo`.
+ :param bool roi: Toggle visibilty of ROI action.
+ :param bool mask: Toggle visibilty of mask action.
+ :param bool fit: Toggle visibilty of fit action.
+ """
+
+ def __init__(self, parent=None, backend=None,
+ resetzoom=True, autoScale=True, logScale=True, grid=True,
+ curveStyle=True, colormap=True,
+ aspectRatio=True, yInverted=True,
+ copy=True, save=True, print_=True,
+ control=False, position=False,
+ roi=True, mask=True, fit=False):
+ super(PlotWindow, self).__init__(parent=parent, backend=backend)
+ if parent is None:
+ self.setWindowTitle('PlotWindow')
+
+ self._dockWidgets = []
+
+ # lazy loaded dock widgets
+ self._legendsDockWidget = None
+ self._curvesROIDockWidget = None
+ self._maskToolsDockWidget = None
+ self._consoleDockWidget = None
+ self._statsDockWidget = None
+
+ # Create color bar, hidden by default for backward compatibility
+ self._colorbar = ColorBarWidget(parent=self, plot=self)
+
+ # Init actions
+ self.group = qt.QActionGroup(self)
+ self.group.setExclusive(False)
+
+ self.resetZoomAction = self.group.addAction(
+ actions.control.ResetZoomAction(self, parent=self))
+ self.resetZoomAction.setVisible(resetzoom)
+ self.addAction(self.resetZoomAction)
+
+ self.zoomInAction = actions.control.ZoomInAction(self, parent=self)
+ self.addAction(self.zoomInAction)
+
+ self.zoomOutAction = actions.control.ZoomOutAction(self, parent=self)
+ self.addAction(self.zoomOutAction)
+
+ self.xAxisAutoScaleAction = self.group.addAction(
+ actions.control.XAxisAutoScaleAction(self, parent=self))
+ self.xAxisAutoScaleAction.setVisible(autoScale)
+ self.addAction(self.xAxisAutoScaleAction)
+
+ self.yAxisAutoScaleAction = self.group.addAction(
+ actions.control.YAxisAutoScaleAction(self, parent=self))
+ self.yAxisAutoScaleAction.setVisible(autoScale)
+ self.addAction(self.yAxisAutoScaleAction)
+
+ self.xAxisLogarithmicAction = self.group.addAction(
+ actions.control.XAxisLogarithmicAction(self, parent=self))
+ self.xAxisLogarithmicAction.setVisible(logScale)
+ self.addAction(self.xAxisLogarithmicAction)
+
+ self.yAxisLogarithmicAction = self.group.addAction(
+ actions.control.YAxisLogarithmicAction(self, parent=self))
+ self.yAxisLogarithmicAction.setVisible(logScale)
+ self.addAction(self.yAxisLogarithmicAction)
+
+ self.gridAction = self.group.addAction(
+ actions.control.GridAction(self, gridMode='both', parent=self))
+ self.gridAction.setVisible(grid)
+ self.addAction(self.gridAction)
+
+ self.curveStyleAction = self.group.addAction(
+ actions.control.CurveStyleAction(self, parent=self))
+ self.curveStyleAction.setVisible(curveStyle)
+ self.addAction(self.curveStyleAction)
+
+ self.colormapAction = self.group.addAction(
+ actions.control.ColormapAction(self, parent=self))
+ self.colormapAction.setVisible(colormap)
+ self.addAction(self.colormapAction)
+
+ self.colorbarAction = self.group.addAction(
+ actions_control.ColorBarAction(self, parent=self))
+ self.colorbarAction.setVisible(False)
+ self.addAction(self.colorbarAction)
+ self._colorbar.setVisible(False)
+
+ self.keepDataAspectRatioButton = PlotToolButtons.AspectToolButton(
+ parent=self, plot=self)
+ self.keepDataAspectRatioButton.setVisible(aspectRatio)
+
+ self.yAxisInvertedButton = PlotToolButtons.YAxisOriginToolButton(
+ parent=self, plot=self)
+ self.yAxisInvertedButton.setVisible(yInverted)
+
+ self.group.addAction(self.getRoiAction())
+ self.getRoiAction().setVisible(roi)
+
+ self.group.addAction(self.getMaskAction())
+ self.getMaskAction().setVisible(mask)
+
+ self._intensityHistoAction = self.group.addAction(
+ actions_histogram.PixelIntensitiesHistoAction(self, parent=self))
+ self._intensityHistoAction.setVisible(False)
+
+ self._medianFilter2DAction = self.group.addAction(
+ actions_medfilt.MedianFilter2DAction(self, parent=self))
+ self._medianFilter2DAction.setVisible(False)
+
+ self._medianFilter1DAction = self.group.addAction(
+ actions_medfilt.MedianFilter1DAction(self, parent=self))
+ self._medianFilter1DAction.setVisible(False)
+
+ self.fitAction = self.group.addAction(actions_fit.FitAction(self, parent=self))
+ self.fitAction.setVisible(fit)
+ self.addAction(self.fitAction)
+
+ # lazy loaded actions needed by the controlButton menu
+ self._consoleAction = None
+ self._statsAction = None
+ self._panWithArrowKeysAction = None
+ self._crosshairAction = None
+
+ # Make colorbar background white
+ self._colorbar.setAutoFillBackground(True)
+ self._sigAxesVisibilityChanged.connect(self._updateColorBarBackground)
+ self._updateColorBarBackground()
+
+ if control: # Create control button only if requested
+ self.controlButton = qt.QToolButton()
+ self.controlButton.setText("Options")
+ self.controlButton.setToolButtonStyle(qt.Qt.ToolButtonTextBesideIcon)
+ self.controlButton.setAutoRaise(True)
+ self.controlButton.setPopupMode(qt.QToolButton.InstantPopup)
+ menu = qt.QMenu(self)
+ menu.aboutToShow.connect(self._customControlButtonMenu)
+ self.controlButton.setMenu(menu)
+
+ self._positionWidget = None
+ if position: # Add PositionInfo widget to the bottom of the plot
+ if isinstance(position, abc.Iterable):
+ # Use position as a set of converters
+ converters = position
+ else:
+ converters = None
+ self._positionWidget = tools.PositionInfo(
+ plot=self, converters=converters)
+ # Set a snapping mode that is consistent with legacy one
+ self._positionWidget.setSnappingMode(
+ tools.PositionInfo.SNAPPING_CROSSHAIR |
+ tools.PositionInfo.SNAPPING_ACTIVE_ONLY |
+ tools.PositionInfo.SNAPPING_SYMBOLS_ONLY |
+ tools.PositionInfo.SNAPPING_CURVE |
+ tools.PositionInfo.SNAPPING_SCATTER)
+
+ self.__setCentralWidget()
+
+ # Creating the toolbar also create actions for toolbuttons
+ self._interactiveModeToolBar = tools.InteractiveModeToolBar(
+ parent=self, plot=self)
+ self.addToolBar(self._interactiveModeToolBar)
+
+ self._toolbar = self._createToolBar(title='Plot', parent=self)
+ self.addToolBar(self._toolbar)
+
+ self._outputToolBar = tools.OutputToolBar(parent=self, plot=self)
+ self._outputToolBar.getCopyAction().setVisible(copy)
+ self._outputToolBar.getSaveAction().setVisible(save)
+ self._outputToolBar.getPrintAction().setVisible(print_)
+ self.addToolBar(self._outputToolBar)
+
+ # Activate shortcuts in PlotWindow widget:
+ for toolbar in (self._interactiveModeToolBar, self._outputToolBar):
+ for action in toolbar.actions():
+ self.addAction(action)
+
+ def __setCentralWidget(self):
+ """Set central widget to host plot backend, colorbar, and bottom bar"""
+ gridLayout = qt.QGridLayout()
+ gridLayout.setSpacing(0)
+ gridLayout.setContentsMargins(0, 0, 0, 0)
+ gridLayout.addWidget(self.getWidgetHandle(), 0, 0)
+ gridLayout.addWidget(self._colorbar, 0, 1)
+ gridLayout.setRowStretch(0, 1)
+ gridLayout.setColumnStretch(0, 1)
+ centralWidget = qt.QWidget(self)
+ centralWidget.setLayout(gridLayout)
+
+ if hasattr(self, "controlButton") or self._positionWidget is not None:
+ hbox = qt.QHBoxLayout()
+ hbox.setContentsMargins(0, 0, 0, 0)
+
+ if hasattr(self, "controlButton"):
+ hbox.addWidget(self.controlButton)
+
+ if self._positionWidget is not None:
+ hbox.addWidget(self._positionWidget)
+
+ hbox.addStretch(1)
+ bottomBar = qt.QWidget(centralWidget)
+ bottomBar.setLayout(hbox)
+
+ gridLayout.addWidget(bottomBar, 1, 0, 1, -1)
+
+ self.setCentralWidget(centralWidget)
+
+ @docstring(PlotWidget)
+ def setBackend(self, backend):
+ super(PlotWindow, self).setBackend(backend)
+ self.__setCentralWidget() # Recreate PlotWindow's central widget
+
+ @docstring(PlotWidget)
+ def setBackgroundColor(self, color):
+ super(PlotWindow, self).setBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ @docstring(PlotWidget)
+ def setDataBackgroundColor(self, color):
+ super(PlotWindow, self).setDataBackgroundColor(color)
+ self._updateColorBarBackground()
+
+ @docstring(PlotWidget)
+ def setForegroundColor(self, color):
+ super(PlotWindow, self).setForegroundColor(color)
+ self._updateColorBarBackground()
+
+ def _updateColorBarBackground(self):
+ """Update the colorbar background according to the state of the plot"""
+ if self.isAxesDisplayed():
+ color = self.getBackgroundColor()
+ else:
+ color = self.getDataBackgroundColor()
+ if not color.isValid():
+ # If no color defined, use the background one
+ color = self.getBackgroundColor()
+
+ foreground = self.getForegroundColor()
+
+ palette = self._colorbar.palette()
+ palette.setColor(qt.QPalette.Window, color)
+ palette.setColor(qt.QPalette.WindowText, foreground)
+ palette.setColor(qt.QPalette.Text, foreground)
+ self._colorbar.setPalette(palette)
+
+ def getInteractiveModeToolBar(self):
+ """Returns QToolBar controlling interactive mode.
+
+ :rtype: QToolBar
+ """
+ return self._interactiveModeToolBar
+
+ def getOutputToolBar(self):
+ """Returns QToolBar containing save, copy and print actions
+
+ :rtype: QToolBar
+ """
+ return self._outputToolBar
+
+ @property
+ @deprecated(replacement="getPositionInfoWidget()", since_version="0.8.0")
+ def positionWidget(self):
+ return self.getPositionInfoWidget()
+
+ def getPositionInfoWidget(self):
+ """Returns the widget displaying current cursor position information
+
+ :rtype: ~silx.gui.plot.tools.PositionInfo
+ """
+ return self._positionWidget
+
+ def getSelectionMask(self):
+ """Return the current mask handled by :attr:`maskToolsDockWidget`.
+
+ :return: The array of the mask with dimension of the 'active' image.
+ If there is no active image, an empty array is returned.
+ :rtype: 2D numpy.ndarray of uint8
+ """
+ return self.getMaskToolsDockWidget().getSelectionMask()
+
+ def setSelectionMask(self, mask):
+ """Set the mask handled by :attr:`maskToolsDockWidget`.
+
+ If the provided mask has not the same dimension as the 'active'
+ image, it will by cropped or padded.
+
+ :param mask: The array to use for the mask.
+ :type mask: numpy.ndarray of uint8 of dimension 2, C-contiguous.
+ Array of other types are converted.
+ :return: True if success, False if failed
+ """
+ return bool(self.getMaskToolsDockWidget().setSelectionMask(mask))
+
+ def _toggleConsoleVisibility(self, isChecked=False):
+ """Create IPythonDockWidget if needed,
+ show it or hide it."""
+ # create widget if needed (first call)
+ if self._consoleDockWidget is None:
+ available_vars = {"plt": weakref.proxy(self)}
+ banner = "The variable 'plt' is available. Use the 'whos' "
+ banner += "and 'help(plt)' commands for more information.\n\n"
+ self._consoleDockWidget = IPythonDockWidget(
+ available_vars=available_vars,
+ custom_banner=banner,
+ parent=self)
+ self.addTabbedDockWidget(self._consoleDockWidget)
+ # self._consoleDockWidget.setVisible(True)
+ self._consoleDockWidget.toggleViewAction().toggled.connect(
+ self.getConsoleAction().setChecked)
+
+ self._consoleDockWidget.setVisible(isChecked)
+
+ def _toggleStatsVisibility(self, isChecked=False):
+ self.getStatsWidget().parent().setVisible(isChecked)
+
+ def _createToolBar(self, title, parent):
+ """Create a QToolBar from the QAction of the PlotWindow.
+
+ :param str title: The title of the QMenu
+ :param qt.QWidget parent: See :class:`QToolBar`
+ """
+ toolbar = qt.QToolBar(title, parent)
+
+ # Order widgets with actions
+ objects = self.group.actions()
+
+ # Add push buttons to list
+ index = objects.index(self.colormapAction)
+ objects.insert(index + 1, self.keepDataAspectRatioButton)
+ objects.insert(index + 2, self.yAxisInvertedButton)
+
+ for obj in objects:
+ if isinstance(obj, qt.QAction):
+ toolbar.addAction(obj)
+ else:
+ # Add action for toolbutton in order to allow changing
+ # visibility (see doc QToolBar.addWidget doc)
+ if obj is self.keepDataAspectRatioButton:
+ self.keepDataAspectRatioAction = toolbar.addWidget(obj)
+ elif obj is self.yAxisInvertedButton:
+ self.yAxisInvertedAction = toolbar.addWidget(obj)
+ else:
+ raise RuntimeError()
+ return toolbar
+
+ def toolBar(self):
+ """Return a QToolBar from the QAction of the PlotWindow.
+ """
+ return self._toolbar
+
+ def menu(self, title='Plot', parent=None):
+ """Return a QMenu from the QAction of the PlotWindow.
+
+ :param str title: The title of the QMenu
+ :param parent: See :class:`QMenu`
+ """
+ menu = qt.QMenu(title, parent)
+ for action in self.group.actions():
+ menu.addAction(action)
+ return menu
+
+ def _customControlButtonMenu(self):
+ """Display Options button sub-menu."""
+ controlMenu = self.controlButton.menu()
+ controlMenu.clear()
+ controlMenu.addAction(self.getLegendsDockWidget().toggleViewAction())
+ controlMenu.addAction(self.getRoiAction())
+ controlMenu.addAction(self.getStatsAction())
+ controlMenu.addAction(self.getMaskAction())
+ controlMenu.addAction(self.getConsoleAction())
+
+ controlMenu.addSeparator()
+ controlMenu.addAction(self.getCrosshairAction())
+ controlMenu.addAction(self.getPanWithArrowKeysAction())
+
+ def addTabbedDockWidget(self, dock_widget):
+ """Add a dock widget as a new tab if there are already dock widgets
+ in the plot. When the first tab is added, the area is chosen
+ depending on the plot geometry:
+ if the window is much wider than it is high, the right dock area
+ is used, else the bottom dock area is used.
+
+ :param dock_widget: Instance of :class:`QDockWidget` to be added.
+ """
+ if dock_widget not in self._dockWidgets:
+ self._dockWidgets.append(dock_widget)
+ if len(self._dockWidgets) == 1:
+ # The first created dock widget must be added to a Widget area
+ width = self.centralWidget().width()
+ height = self.centralWidget().height()
+ if width > (1.25 * height):
+ area = qt.Qt.RightDockWidgetArea
+ else:
+ area = qt.Qt.BottomDockWidgetArea
+ self.addDockWidget(area, dock_widget)
+ else:
+ # Other dock widgets are added as tabs to the same widget area
+ self.tabifyDockWidget(self._dockWidgets[0],
+ dock_widget)
+
+ def removeDockWidget(self, dockwidget):
+ """Removes the *dockwidget* from the main window layout and hides it.
+
+ Note that the *dockwidget* is *not* deleted.
+
+ :param QDockWidget dockwidget:
+ """
+ if dockwidget in self._dockWidgets:
+ self._dockWidgets.remove(dockwidget)
+ super(PlotWindow, self).removeDockWidget(dockwidget)
+
+ def _handleFirstDockWidgetShow(self, visible):
+ """Handle QDockWidget.visibilityChanged
+
+ It calls :meth:`addTabbedDockWidget` for the `sender` widget.
+ This allows to call `addTabbedDockWidget` lazily.
+
+ It disconnect itself from the signal once done.
+
+ :param bool visible:
+ """
+ if visible:
+ dockWidget = self.sender()
+ dockWidget.visibilityChanged.disconnect(
+ self._handleFirstDockWidgetShow)
+ self.addTabbedDockWidget(dockWidget)
+
+ def getColorBarWidget(self):
+ """Returns the embedded :class:`ColorBarWidget` widget.
+
+ :rtype: ColorBarWidget
+ """
+ return self._colorbar
+
+ # getters for dock widgets
+
+ def getLegendsDockWidget(self):
+ """DockWidget with Legend panel"""
+ if self._legendsDockWidget is None:
+ self._legendsDockWidget = LegendsDockWidget(plot=self)
+ self._legendsDockWidget.hide()
+ self._legendsDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow)
+ return self._legendsDockWidget
+
+ def getCurvesRoiDockWidget(self):
+ # Undocumented for a "soft deprecation" in version 0.7.0
+ # (still used internally for lazy loading)
+ if self._curvesROIDockWidget is None:
+ self._curvesROIDockWidget = CurvesROIDockWidget(
+ plot=self, name='Regions Of Interest')
+ self._curvesROIDockWidget.hide()
+ self._curvesROIDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow)
+ return self._curvesROIDockWidget
+
+ def getCurvesRoiWidget(self):
+ """Return the :class:`CurvesROIWidget`.
+
+ :class:`silx.gui.plot.CurvesROIWidget.CurvesROIWidget` offers a getter
+ and a setter for the ROI data:
+
+ - :meth:`CurvesROIWidget.getRois`
+ - :meth:`CurvesROIWidget.setRois`
+ """
+ return self.getCurvesRoiDockWidget().roiWidget
+
+ def getMaskToolsDockWidget(self):
+ """DockWidget with image mask panel (lazy-loaded)."""
+ if self._maskToolsDockWidget is None:
+ self._maskToolsDockWidget = MaskToolsDockWidget(
+ plot=self, name='Mask')
+ self._maskToolsDockWidget.hide()
+ self._maskToolsDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow)
+ return self._maskToolsDockWidget
+
+ def getStatsWidget(self):
+ """Returns a BasicStatsWidget connected to this plot
+
+ :rtype: BasicStatsWidget
+ """
+ if self._statsDockWidget is None:
+ self._statsDockWidget = qt.QDockWidget()
+ self._statsDockWidget.setWindowTitle("Curves stats")
+ self._statsDockWidget.layout().setContentsMargins(0, 0, 0, 0)
+ statsWidget = BasicStatsWidget(parent=self, plot=self)
+ self._statsDockWidget.setWidget(statsWidget)
+ statsWidget.sigVisibilityChanged.connect(
+ self.getStatsAction().setChecked)
+ self._statsDockWidget.hide()
+ self._statsDockWidget.visibilityChanged.connect(
+ self._handleFirstDockWidgetShow)
+ return self._statsDockWidget.widget()
+
+ # getters for actions
+ @property
+ @deprecated(replacement="getInteractiveModeToolBar().getZoomModeAction()",
+ since_version="0.8.0")
+ def zoomModeAction(self):
+ return self.getInteractiveModeToolBar().getZoomModeAction()
+
+ @property
+ @deprecated(replacement="getInteractiveModeToolBar().getPanModeAction()",
+ since_version="0.8.0")
+ def panModeAction(self):
+ return self.getInteractiveModeToolBar().getPanModeAction()
+
+ def getConsoleAction(self):
+ """QAction handling the IPython console activation.
+
+ By default, it is connected to a method that initializes the
+ console widget the first time the user clicks the "Console" menu
+ button. The following clicks, after initialization is done,
+ will toggle the visibility of the console widget.
+
+ :rtype: QAction
+ """
+ if self._consoleAction is None:
+ self._consoleAction = qt.QAction('Console', self)
+ self._consoleAction.setCheckable(True)
+ if IPythonDockWidget is not None:
+ self._consoleAction.toggled.connect(self._toggleConsoleVisibility)
+ else:
+ self._consoleAction.setEnabled(False)
+ return self._consoleAction
+
+ def getCrosshairAction(self):
+ """Action toggling crosshair cursor mode.
+
+ :rtype: actions.PlotAction
+ """
+ if self._crosshairAction is None:
+ self._crosshairAction = actions.control.CrosshairAction(self, color='red')
+ return self._crosshairAction
+
+ def getMaskAction(self):
+ """QAction toggling image mask dock widget
+
+ :rtype: QAction
+ """
+ return self.getMaskToolsDockWidget().toggleViewAction()
+
+ def getPanWithArrowKeysAction(self):
+ """Action toggling pan with arrow keys.
+
+ :rtype: actions.PlotAction
+ """
+ if self._panWithArrowKeysAction is None:
+ self._panWithArrowKeysAction = actions.control.PanWithArrowKeysAction(self)
+ return self._panWithArrowKeysAction
+
+ def getStatsAction(self):
+ if self._statsAction is None:
+ self._statsAction = qt.QAction('Curves stats', self)
+ self._statsAction.setCheckable(True)
+ self._statsAction.setChecked(self.getStatsWidget().parent().isVisible())
+ self._statsAction.toggled.connect(self._toggleStatsVisibility)
+ return self._statsAction
+
+ def getRoiAction(self):
+ """QAction toggling curve ROI dock widget
+
+ :rtype: QAction
+ """
+ return self.getCurvesRoiDockWidget().toggleViewAction()
+
+ def getResetZoomAction(self):
+ """Action resetting the zoom
+
+ :rtype: actions.PlotAction
+ """
+ return self.resetZoomAction
+
+ def getZoomInAction(self):
+ """Action to zoom in
+
+ :rtype: actions.PlotAction
+ """
+ return self.zoomInAction
+
+ def getZoomOutAction(self):
+ """Action to zoom out
+
+ :rtype: actions.PlotAction
+ """
+ return self.zoomOutAction
+
+ def getXAxisAutoScaleAction(self):
+ """Action to toggle the X axis autoscale on zoom reset
+
+ :rtype: actions.PlotAction
+ """
+ return self.xAxisAutoScaleAction
+
+ def getYAxisAutoScaleAction(self):
+ """Action to toggle the Y axis autoscale on zoom reset
+
+ :rtype: actions.PlotAction
+ """
+ return self.yAxisAutoScaleAction
+
+ def getXAxisLogarithmicAction(self):
+ """Action to toggle logarithmic X axis
+
+ :rtype: actions.PlotAction
+ """
+ return self.xAxisLogarithmicAction
+
+ def getYAxisLogarithmicAction(self):
+ """Action to toggle logarithmic Y axis
+
+ :rtype: actions.PlotAction
+ """
+ return self.yAxisLogarithmicAction
+
+ def getGridAction(self):
+ """Action to toggle the grid visibility in the plot
+
+ :rtype: actions.PlotAction
+ """
+ return self.gridAction
+
+ def getCurveStyleAction(self):
+ """Action to change curve line and markers styles
+
+ :rtype: actions.PlotAction
+ """
+ return self.curveStyleAction
+
+ def getColormapAction(self):
+ """Action open a colormap dialog to change active image
+ and default colormap.
+
+ :rtype: actions.PlotAction
+ """
+ return self.colormapAction
+
+ def getKeepDataAspectRatioButton(self):
+ """Button to toggle aspect ratio preservation
+
+ :rtype: PlotToolButtons.AspectToolButton
+ """
+ return self.keepDataAspectRatioButton
+
+ def getKeepDataAspectRatioAction(self):
+ """Action associated to keepDataAspectRatioButton.
+ Use this to change the visibility of keepDataAspectRatioButton in the
+ toolbar (See :meth:`QToolBar.addWidget` documentation).
+
+ :rtype: actions.PlotAction
+ """
+ return self.keepDataAspectRatioAction
+
+ def getYAxisInvertedButton(self):
+ """Button to switch the Y axis orientation
+
+ :rtype: PlotToolButtons.YAxisOriginToolButton
+ """
+ return self.yAxisInvertedButton
+
+ def getYAxisInvertedAction(self):
+ """Action associated to yAxisInvertedButton.
+ Use this to change the visibility yAxisInvertedButton in the toolbar.
+ (See :meth:`QToolBar.addWidget` documentation).
+
+ :rtype: actions.PlotAction
+ """
+ return self.yAxisInvertedAction
+
+ def getIntensityHistogramAction(self):
+ """Action toggling the histogram intensity Plot widget
+
+ :rtype: actions.PlotAction
+ """
+ return self._intensityHistoAction
+
+ def getCopyAction(self):
+ """Action to copy plot snapshot to clipboard
+
+ :rtype: actions.PlotAction
+ """
+ return self.getOutputToolBar().getCopyAction()
+
+ def getSaveAction(self):
+ """Action to save plot
+
+ :rtype: actions.PlotAction
+ """
+ return self.getOutputToolBar().getSaveAction()
+
+ def getPrintAction(self):
+ """Action to print plot
+
+ :rtype: actions.PlotAction
+ """
+ return self.getOutputToolBar().getPrintAction()
+
+ def getFitAction(self):
+ """Action to fit selected curve
+
+ :rtype: actions.PlotAction
+ """
+ return self.fitAction
+
+ def getMedianFilter1DAction(self):
+ """Action toggling the 1D median filter
+
+ :rtype: actions.PlotAction
+ """
+ return self._medianFilter1DAction
+
+ def getMedianFilter2DAction(self):
+ """Action toggling the 2D median filter
+
+ :rtype: actions.PlotAction
+ """
+ return self._medianFilter2DAction
+
+ def getColorBarAction(self):
+ """Action toggling the colorbar show/hide action
+
+ .. warning:: to show/hide the plot colorbar call directly the ColorBar
+ widget using getColorBarWidget()
+
+ :rtype: actions.PlotAction
+ """
+ return self.colorbarAction
+
+
+class Plot1D(PlotWindow):
+ """PlotWindow with tools specific for curves.
+
+ This widgets provides the plot API of :class:`.PlotWidget`.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ def __init__(self, parent=None, backend=None):
+ super(Plot1D, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=True,
+ logScale=True, grid=True,
+ curveStyle=True, colormap=False,
+ aspectRatio=False, yInverted=False,
+ copy=True, save=True, print_=True,
+ control=True, position=True,
+ roi=True, mask=False, fit=True)
+ if parent is None:
+ self.setWindowTitle('Plot1D')
+ self.getXAxis().setLabel('X')
+ self.getYAxis().setLabel('Y')
+ action = self.getFitAction()
+ action.setXRangeUpdatedOnZoom(True)
+ action.setFittedItemUpdatedFromActiveCurve(True)
+
+
+class Plot2D(PlotWindow):
+ """PlotWindow with a toolbar specific for images.
+
+ This widgets provides the plot API of :~:`.PlotWidget`.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ """
+
+ def __init__(self, parent=None, backend=None):
+ # List of information to display at the bottom of the plot
+ posInfo = [
+ ('X', lambda x, y: x),
+ ('Y', lambda x, y: y),
+ ('Data', WeakMethodProxy(self._getImageValue)),
+ ('Dims', WeakMethodProxy(self._getImageDims)),
+ ]
+
+ super(Plot2D, self).__init__(parent=parent, backend=backend,
+ resetzoom=True, autoScale=False,
+ logScale=False, grid=False,
+ curveStyle=False, colormap=True,
+ aspectRatio=True, yInverted=True,
+ copy=True, save=True, print_=True,
+ control=False, position=posInfo,
+ roi=False, mask=True)
+ if parent is None:
+ self.setWindowTitle('Plot2D')
+ self.getXAxis().setLabel('Columns')
+ self.getYAxis().setLabel('Rows')
+
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ self.getYAxis().setInverted(True)
+
+ self.profile = ProfileToolBar(plot=self)
+ self.addToolBar(self.profile)
+
+ self.colorbarAction.setVisible(True)
+ self.getColorBarWidget().setVisible(True)
+
+ # Put colorbar action after colormap action
+ actions = self.toolBar().actions()
+ for action in actions:
+ if action is self.getColormapAction():
+ break
+
+ self.sigActiveImageChanged.connect(self.__activeImageChanged)
+
+ def __activeImageChanged(self, previous, legend):
+ """Handle change of active image
+
+ :param Union[str,None] previous: Legend of previous active image
+ :param Union[str,None] legend: Legend of current active image
+ """
+ if previous is not None:
+ item = self.getImage(previous)
+ if item is not None:
+ item.sigItemChanged.disconnect(self.__imageChanged)
+
+ if legend is not None:
+ item = self.getImage(legend)
+ item.sigItemChanged.connect(self.__imageChanged)
+
+ positionInfo = self.getPositionInfoWidget()
+ if positionInfo is not None:
+ positionInfo.updateInfo()
+
+ def __imageChanged(self, event):
+ """Handle update of active image item
+
+ :param event: Type of changed event
+ """
+ if event == items.ItemChangedType.DATA:
+ positionInfo = self.getPositionInfoWidget()
+ if positionInfo is not None:
+ positionInfo.updateInfo()
+
+ def _getImageValue(self, x, y):
+ """Get status bar value of top most image at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The value at that point or '-'
+ """
+ pickedMask = None
+ for picked in self.pickItems(
+ *self.dataToPixel(x, y, check=False),
+ lambda item: isinstance(item, items.ImageBase)):
+ if isinstance(picked.getItem(), items.MaskImageData):
+ if pickedMask is None: # Use top-most if many masks
+ pickedMask = picked
+ else:
+ image = picked.getItem()
+
+ indices = picked.getIndices(copy=False)
+ if indices is not None:
+ row, col = indices[0][0], indices[1][0]
+ value = image.getData(copy=False)[row, col]
+
+ if pickedMask is not None: # Check if masked
+ maskItem = pickedMask.getItem()
+ indices = pickedMask.getIndices()
+ row, col = indices[0][0], indices[1][0]
+ if maskItem.getData(copy=False)[row, col] != 0:
+ return value, "Masked"
+ return value
+
+ return '-' # No image picked
+
+ def _getImageDims(self, *args):
+ activeImage = self.getActiveImage()
+ if (activeImage is not None and
+ activeImage.getData(copy=False) is not None):
+ dims = activeImage.getData(copy=False).shape[1::-1]
+ return 'x'.join(str(dim) for dim in dims)
+ else:
+ return '-'
+
+ def getProfileToolbar(self):
+ """Profile tools attached to this plot
+
+ See :class:`silx.gui.plot.Profile.ProfileToolBar`
+ """
+ return self.profile
+
+ @deprecated(replacement="getProfilePlot", since_version="0.5.0")
+ def getProfileWindow(self):
+ return self.getProfilePlot()
+
+ def getProfilePlot(self):
+ """Return plot window used to display profile curve.
+
+ :return: :class:`Plot1D`
+ """
+ return self.profile.getProfilePlot()
diff --git a/src/silx/gui/plot/PrintPreviewToolButton.py b/src/silx/gui/plot/PrintPreviewToolButton.py
new file mode 100644
index 0000000..30967e4
--- /dev/null
+++ b/src/silx/gui/plot/PrintPreviewToolButton.py
@@ -0,0 +1,388 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This modules provides tool buttons to send the content of a plot to a
+print preview page.
+The plot content can then be moved on the page and resized prior to printing.
+
+Classes
+-------
+
+- :class:`PrintPreviewToolButton`
+- :class:`SingletonPrintPreviewToolButton`
+
+Examples
+--------
+
+Simple example
+++++++++++++++
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.PrintPreviewToolButton import PrintPreviewToolButton
+ import numpy
+
+ app = qt.QApplication([])
+
+ pw = PlotWidget()
+ toolbar = qt.QToolBar(pw)
+ toolbutton = PrintPreviewToolButton(parent=toolbar, plot=pw)
+ pw.addToolBar(toolbar)
+ toolbar.addWidget(toolbutton)
+ pw.show()
+
+ x = numpy.arange(1000)
+ y = x / numpy.sin(x)
+ pw.addCurve(x, y)
+
+ app.exec()
+
+Singleton example
++++++++++++++++++
+
+This example illustrates how to print the content of several different
+plots on the same page. The plots all instantiate a
+:class:`SingletonPrintPreviewToolButton`, which relies on a singleton widget
+(:class:`silx.gui.widgets.PrintPreview.SingletonPrintPreviewDialog`).
+
+.. image:: img/printPreviewMultiPlot.png
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.plot import PlotWidget
+ from silx.gui.plot.PrintPreviewToolButton import SingletonPrintPreviewToolButton
+ import numpy
+
+ app = qt.QApplication([])
+
+ plot_widgets = []
+
+ for i in range(3):
+ pw = PlotWidget()
+ toolbar = qt.QToolBar(pw)
+ toolbutton = SingletonPrintPreviewToolButton(parent=toolbar,
+ plot=pw)
+ pw.addToolBar(toolbar)
+ toolbar.addWidget(toolbutton)
+ pw.show()
+ plot_widgets.append(pw)
+
+ x = numpy.arange(1000)
+
+ plot_widgets[0].addCurve(x, numpy.sin(x * 2 * numpy.pi / 1000))
+ plot_widgets[1].addCurve(x, numpy.cos(x * 2 * numpy.pi / 1000))
+ plot_widgets[2].addCurve(x, numpy.tan(x * 2 * numpy.pi / 1000))
+
+ app.exec()
+
+"""
+from __future__ import absolute_import
+
+import logging
+from io import StringIO
+
+from .. import qt
+from .. import icons
+from . import PlotWidget
+from ..widgets.PrintPreview import PrintPreviewDialog, SingletonPrintPreviewDialog
+from ..widgets.PrintGeometryDialog import PrintGeometryDialog
+from silx.utils.deprecation import deprecated
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/12/2018"
+
+_logger = logging.getLogger(__name__)
+# _logger.setLevel(logging.DEBUG)
+
+
+class PrintPreviewToolButton(qt.QToolButton):
+ """QToolButton to open a :class:`PrintPreviewDialog` (if not already open)
+ and add the current plot to its page to be printed.
+
+ :param parent: See :class:`QAction`
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ """
+ def __init__(self, parent=None, plot=None):
+ super(PrintPreviewToolButton, self).__init__(parent)
+
+ if not isinstance(plot, PlotWidget):
+ raise TypeError("plot parameter must be a PlotWidget")
+ self._plot = plot
+
+ self.setIcon(icons.getQIcon('document-print'))
+
+ printGeomAction = qt.QAction("Print geometry", self)
+ printGeomAction.setToolTip("Define a print geometry prior to sending "
+ "the plot to the print preview dialog")
+ printGeomAction.setIcon(icons.getQIcon('shape-rectangle'))
+ printGeomAction.triggered.connect(self._setPrintConfiguration)
+
+ printPreviewAction = qt.QAction("Print preview", self)
+ printPreviewAction.setToolTip("Send plot to the print preview dialog")
+ printPreviewAction.setIcon(icons.getQIcon('document-print'))
+ printPreviewAction.triggered.connect(self._plotToPrintPreview)
+
+ menu = qt.QMenu(self)
+ menu.addAction(printGeomAction)
+ menu.addAction(printPreviewAction)
+ self.setMenu(menu)
+ self.setPopupMode(qt.QToolButton.InstantPopup)
+
+ self._printPreviewDialog = None
+ self._printConfigurationDialog = None
+
+ self._printGeometry = {"xOffset": 0.1,
+ "yOffset": 0.1,
+ "width": 0.9,
+ "height": 0.9,
+ "units": "page",
+ "keepAspectRatio": True}
+
+ @property
+ def printPreviewDialog(self):
+ """Lazy loaded :class:`PrintPreviewDialog`"""
+ # if changes are made here, don't forget making them in
+ # SingletonPrintPreviewToolButton.printPreviewDialog as well
+ if self._printPreviewDialog is None:
+ self._printPreviewDialog = PrintPreviewDialog(self.parent())
+ return self._printPreviewDialog
+
+ def getTitle(self):
+ """Implement this method to fetch the title in the plot.
+
+ :return: Title to be printed above the plot, or None (no title added)
+ :rtype: str or None
+ """
+ return None
+
+ def getCommentAndPosition(self):
+ """Implement this method to fetch the legend to be printed below the
+ figure and its position.
+
+ :return: Legend to be printed below the figure and its position:
+ "CENTER", "LEFT" or "RIGHT"
+ :rtype: (str, str) or (None, None)
+ """
+ return None, None
+
+ @property
+ @deprecated(since_version="0.10",
+ replacement="getPlot()")
+ def plot(self):
+ return self._plot
+
+ def getPlot(self):
+ """Return the :class:`.PlotWidget` associated with this tool button.
+
+ :rtype: :class:`.PlotWidget`
+ """
+ return self._plot
+
+ def _plotToPrintPreview(self):
+ """Grab the plot widget and send it to the print preview dialog.
+ Make sure the print preview dialog is shown and raised."""
+ if not self.printPreviewDialog.ensurePrinterIsSet():
+ return
+
+ comment, commentPosition = self.getCommentAndPosition()
+
+ if qt.HAS_SVG:
+ svgRenderer, viewBox = self._getSvgRendererAndViewbox()
+ self.printPreviewDialog.addSvgItem(svgRenderer,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition,
+ viewBox=viewBox,
+ keepRatio=self._printGeometry["keepAspectRatio"])
+ else:
+ _logger.warning("Missing QtSvg library, using a raster image")
+ pixmap = self._plot.centralWidget().grab()
+ self.printPreviewDialog.addPixmap(pixmap,
+ title=self.getTitle(),
+ comment=comment,
+ commentPosition=commentPosition)
+ self.printPreviewDialog.show()
+ self.printPreviewDialog.raise_()
+
+ def _getSvgRendererAndViewbox(self):
+ """Return a SVG renderer displaying the plot and its viewbox
+ (interactively specified by the user the first time this is called).
+
+ The size of the renderer is adjusted to the printer configuration
+ and to the geometry configuration (width, height, ratio) specified
+ by the user."""
+ imgData = StringIO()
+ assert self._plot.saveGraph(imgData, fileFormat="svg"), \
+ "Unable to save graph"
+ imgData.flush()
+ imgData.seek(0)
+ svgData = imgData.read()
+
+ svgRenderer = qt.QSvgRenderer()
+
+ viewbox = self._getViewBox()
+
+ svgRenderer.setViewBox(viewbox)
+
+ xml_stream = qt.QXmlStreamReader(svgData.encode(errors="replace"))
+
+ # This is for PyMca compatibility, to share a print preview with PyMca plots
+ svgRenderer._viewBox = viewbox
+ svgRenderer._svgRawData = svgData.encode(errors="replace")
+ svgRenderer._svgRendererData = xml_stream
+
+ if not svgRenderer.load(xml_stream):
+ raise RuntimeError("Cannot interpret svg data")
+
+ return svgRenderer, viewbox
+
+ def _getViewBox(self):
+ """
+ """
+ printer = self.printPreviewDialog.printer
+ dpix = printer.logicalDpiX()
+ dpiy = printer.logicalDpiY()
+ availableWidth = printer.width()
+ availableHeight = printer.height()
+
+ config = self._printGeometry
+ width = config['width']
+ height = config['height']
+ xOffset = config['xOffset']
+ yOffset = config['yOffset']
+ units = config['units']
+ keepAspectRatio = config['keepAspectRatio']
+ aspectRatio = self._getPlotAspectRatio()
+
+ # convert the offsets to dots
+ if units.lower() in ['inch', 'inches']:
+ xOffset = xOffset * dpix
+ yOffset = yOffset * dpiy
+ if width is not None:
+ width = width * dpix
+ if height is not None:
+ height = height * dpiy
+ elif units.lower() in ['cm', 'centimeters']:
+ xOffset = (xOffset / 2.54) * dpix
+ yOffset = (yOffset / 2.54) * dpiy
+ if width is not None:
+ width = (width / 2.54) * dpix
+ if height is not None:
+ height = (height / 2.54) * dpiy
+ else:
+ # page units
+ xOffset = availableWidth * xOffset
+ yOffset = availableHeight * yOffset
+ if width is not None:
+ width = availableWidth * width
+ if height is not None:
+ height = availableHeight * height
+
+ availableWidth -= xOffset
+ availableHeight -= yOffset
+
+ if width is not None:
+ if (availableWidth + 0.1) < width:
+ txt = "Available width %f is less than requested width %f" % \
+ (availableWidth, width)
+ raise ValueError(txt)
+ if height is not None:
+ if (availableHeight + 0.1) < height:
+ txt = "Available height %f is less than requested height %f" % \
+ (availableHeight, height)
+ raise ValueError(txt)
+
+ if keepAspectRatio:
+ bodyWidth = width or availableWidth
+ bodyHeight = bodyWidth * aspectRatio
+
+ if bodyHeight > availableHeight:
+ bodyHeight = availableHeight
+ bodyWidth = bodyHeight / aspectRatio
+
+ else:
+ bodyWidth = width or availableWidth
+ bodyHeight = height or availableHeight
+
+ return qt.QRectF(xOffset,
+ yOffset,
+ bodyWidth,
+ bodyHeight)
+
+ def _setPrintConfiguration(self):
+ """Open a dialog to prompt the user to adjust print
+ geometry parameters."""
+ self.printPreviewDialog.ensurePrinterIsSet()
+ if self._printConfigurationDialog is None:
+ self._printConfigurationDialog = PrintGeometryDialog(self.parent())
+
+ self._printConfigurationDialog.setPrintGeometry(self._printGeometry)
+ if self._printConfigurationDialog.exec():
+ self._printGeometry = self._printConfigurationDialog.getPrintGeometry()
+
+ def _getPlotAspectRatio(self):
+ widget = self._plot.centralWidget()
+ graphWidth = float(widget.width())
+ graphHeight = float(widget.height())
+ return graphHeight / graphWidth
+
+
+class SingletonPrintPreviewToolButton(PrintPreviewToolButton):
+ """This class is similar to its parent class :class:`PrintPreviewToolButton`
+ but it uses a singleton print preview widget.
+
+ This allows for several plots to send their content to the
+ same print page, and for users to arrange them."""
+ def __init__(self, parent=None, plot=None):
+ PrintPreviewToolButton.__init__(self, parent, plot)
+
+ @property
+ def printPreviewDialog(self):
+ if self._printPreviewDialog is None:
+ self._printPreviewDialog = SingletonPrintPreviewDialog(self.parent())
+ return self._printPreviewDialog
+
+
+if __name__ == '__main__':
+ import numpy
+ app = qt.QApplication([])
+
+ pw = PlotWidget()
+ toolbar = qt.QToolBar(pw)
+ toolbutton = PrintPreviewToolButton(parent=toolbar,
+ plot=pw)
+ pw.addToolBar(toolbar)
+ toolbar.addWidget(toolbutton)
+ pw.show()
+
+ x = numpy.arange(1000)
+ y = x / numpy.sin(x)
+ pw.addCurve(x, y)
+
+ app.exec()
diff --git a/silx/gui/plot/Profile.py b/src/silx/gui/plot/Profile.py
index 7565155..7565155 100644
--- a/silx/gui/plot/Profile.py
+++ b/src/silx/gui/plot/Profile.py
diff --git a/silx/gui/plot/ProfileMainWindow.py b/src/silx/gui/plot/ProfileMainWindow.py
index ce56cfd..ce56cfd 100644
--- a/silx/gui/plot/ProfileMainWindow.py
+++ b/src/silx/gui/plot/ProfileMainWindow.py
diff --git a/src/silx/gui/plot/ROIStatsWidget.py b/src/silx/gui/plot/ROIStatsWidget.py
new file mode 100644
index 0000000..32a1395
--- /dev/null
+++ b/src/silx/gui/plot/ROIStatsWidget.py
@@ -0,0 +1,780 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides widget for displaying statistics relative to a
+Region of interest and an item
+"""
+
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "22/07/2019"
+
+
+from contextlib import contextmanager
+from silx.gui import qt
+from silx.gui import icons
+from silx.gui.plot.StatsWidget import _StatsWidgetBase, StatsTable, _Container
+from silx.gui.plot.StatsWidget import UpdateModeWidget, UpdateMode
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.gui.plot.items.roi import RegionOfInterest
+from silx.gui.plot import items as plotitems
+from silx.gui.plot.items.core import ItemChangedType
+from silx.gui.plot3d import items as plot3ditems
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.gui.plot import stats as statsmdl
+from collections import OrderedDict
+from silx.utils.proxy import docstring
+import silx.gui.plot.items.marker
+import silx.gui.plot.items.shape
+import functools
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+class _GetROIItemCoupleDialog(qt.QDialog):
+ """
+ Dialog used to know which plot item and which roi he wants
+ """
+ _COMPATIBLE_KINDS = ('curve', 'image', 'scatter', 'histogram')
+
+ def __init__(self, parent=None, plot=None, rois=None):
+ qt.QDialog.__init__(self, parent=parent)
+ assert plot is not None
+ assert rois is not None
+ self._plot = plot
+ self._rois = rois
+
+ self.setLayout(qt.QVBoxLayout())
+
+ # define the selection widget
+ self._selection_widget = qt.QWidget()
+ self._selection_widget.setLayout(qt.QHBoxLayout())
+ self._kindCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._kindCB)
+ self._itemCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._itemCB)
+ self._roiCB = qt.QComboBox(parent=self)
+ self._selection_widget.layout().addWidget(self._roiCB)
+ self.layout().addWidget(self._selection_widget)
+
+ # define modal buttons
+ types = qt.QDialogButtonBox.Ok | qt.QDialogButtonBox.Cancel
+ self._buttonsModal = qt.QDialogButtonBox(parent=self)
+ self._buttonsModal.setStandardButtons(types)
+ self.layout().addWidget(self._buttonsModal)
+ self._buttonsModal.accepted.connect(self.accept)
+ self._buttonsModal.rejected.connect(self.reject)
+
+ # connect signal / slot
+ self._kindCB.currentIndexChanged.connect(self._updateValidItemAndRoi)
+
+ def _getCompatibleRois(self, kind):
+ """Return compatible rois for the given item kind"""
+ def is_compatible(roi, kind):
+ if isinstance(roi, RegionOfInterest):
+ return kind in ('image', 'scatter')
+ elif isinstance(roi, ROI):
+ return kind in ('curve', 'histogram')
+ else:
+ raise ValueError('kind not managed')
+ return list(filter(lambda x: is_compatible(x, kind), self._rois))
+
+ def exec(self):
+ self._kindCB.clear()
+ self._itemCB.clear()
+ # filter kind without any items
+ self._valid_kinds = {}
+ # key is item type, value kinds
+ self._valid_rois = {}
+ # key is item type, value rois
+ self._kind_name_to_roi = {}
+ # key is (kind, roi name) value is roi
+ self._kind_name_to_item = {}
+ # key is (kind, legend name) value is item
+ for kind in _GetROIItemCoupleDialog._COMPATIBLE_KINDS:
+ def getItems(kind):
+ output = []
+ for item in self._plot.getItems():
+ type_ = self._plot._itemKind(item)
+ if type_ in kind and item.isVisible():
+ output.append(item)
+ return output
+
+ items = getItems(kind=kind)
+ rois = self._getCompatibleRois(kind=kind)
+ if len(items) > 0 and len(rois) > 0:
+ self._valid_kinds[kind] = items
+ self._valid_rois[kind] = rois
+ for roi in rois:
+ name = roi.getName()
+ self._kind_name_to_roi[(kind, name)] = roi
+ for item in items:
+ self._kind_name_to_item[(kind, item.getLegend())] = item
+
+ # filter roi according to kinds
+ if len(self._valid_kinds) == 0:
+ _logger.warning('no couple item/roi detected for displaying stats')
+ return self.reject()
+
+ for kind in self._valid_kinds:
+ self._kindCB.addItem(kind)
+ self._updateValidItemAndRoi()
+
+ return qt.QDialog.exec(self)
+
+ def exec_(self): # Qt5 compatibility
+ return self.exec()
+
+ def _updateValidItemAndRoi(self, *args, **kwargs):
+ self._itemCB.clear()
+ self._roiCB.clear()
+ kind = self._kindCB.currentText()
+ for roi in self._valid_rois[kind]:
+ self._roiCB.addItem(roi.getName())
+ for item in self._valid_kinds[kind]:
+ self._itemCB.addItem(item.getLegend())
+
+ def getROI(self):
+ kind = self._kindCB.currentText()
+ roi_name = self._roiCB.currentText()
+ return self._kind_name_to_roi[(kind, roi_name)]
+
+ def getItem(self):
+ kind = self._kindCB.currentText()
+ item_name = self._itemCB.currentText()
+ return self._kind_name_to_item[(kind, item_name)]
+
+
+class ROIStatsItemHelper(object):
+ """Item utils to associate a plot item and a roi
+
+ Display on one row statistics regarding the couple
+ (Item (plot item) / roi).
+
+ :param Item plot_item: item for which we want statistics
+ :param Union[ROI,RegionOfInterest]: region of interest to use for
+ statistics.
+ """
+ def __init__(self, plot_item, roi):
+ self._plot_item = plot_item
+ self._roi = roi
+
+ @property
+ def roi(self):
+ """roi"""
+ return self._roi
+
+ def roi_name(self):
+ if isinstance(self._roi, ROI):
+ return self._roi.getName()
+ elif isinstance(self._roi, RegionOfInterest):
+ return self._roi.getName()
+ else:
+ raise TypeError('Unmanaged roi type')
+
+ @property
+ def roi_kind(self):
+ """roi class"""
+ return self._roi.__class__
+
+ # TODO: should call a util function from the wrapper ?
+ def item_kind(self):
+ """item kind"""
+ if isinstance(self._plot_item, plotitems.Curve):
+ return 'curve'
+ elif isinstance(self._plot_item, plotitems.ImageData):
+ return 'image'
+ elif isinstance(self._plot_item, plotitems.Scatter):
+ return 'scatter'
+ elif isinstance(self._plot_item, plotitems.Histogram):
+ return 'histogram'
+ elif isinstance(self._plot_item, (plot3ditems.ImageData,
+ plot3ditems.ScalarField3D)):
+ return 'image'
+ elif isinstance(self._plot_item, (plot3ditems.Scatter2D,
+ plot3ditems.Scatter3D)):
+ return 'scatter'
+
+ @property
+ def item_legend(self):
+ """legend of the plot Item"""
+ return self._plot_item.getLegend()
+
+ def id_key(self):
+ """unique key to represent the couple (item, roi)"""
+ return (self.item_kind(), self.item_legend, self.roi_kind,
+ self.roi_name())
+
+
+class _StatsROITable(_StatsWidgetBase, TableWidget):
+ """
+ Table sued to display some statistics regarding a couple (item/roi)
+ """
+ _LEGEND_HEADER_DATA = 'legend'
+
+ _KIND_HEADER_DATA = 'kind'
+
+ _ROI_HEADER_DATA = 'roi'
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(self, parent, plot):
+ TableWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
+ displayOnlyActItem=False)
+ self.__region_edition_callback = {}
+ """We need to keep trace of the roi signals connection because
+ the roi emits the sigChanged during roi edition"""
+ self._items = {}
+ self.setRowCount(0)
+ self.setColumnCount(3)
+
+ # Init headers
+ headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
+ self.setHorizontalHeaderItem(0, headerItem)
+ headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
+ self.setHorizontalHeaderItem(1, headerItem)
+ headerItem = qt.QTableWidgetItem(self._ROI_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._ROI_HEADER_DATA)
+ self.setHorizontalHeaderItem(2, headerItem)
+
+ self.setSortingEnabled(True)
+ self.setPlot(plot)
+
+ self.__plotItemToItems = {}
+ """Key is plotItem, values is list of __RoiStatsItemWidget"""
+ self.__roiToItems = {}
+ """Key is roi, values is list of __RoiStatsItemWidget"""
+ self.__roisKeyToRoi = {}
+
+ def add(self, item):
+ assert isinstance(item, ROIStatsItemHelper)
+ if item.id_key() in self._items:
+ _logger.warning("Item %s is already present", item.id_key())
+ return None
+ self._items[item.id_key()] = item
+ self._addItem(item)
+ return item
+
+ def _addItem(self, item):
+ """
+ Add a _RoiStatsItemWidget item to the table.
+
+ :param item:
+ :return: True if successfully added.
+ """
+ if not isinstance(item, ROIStatsItemHelper):
+ # skipped because also receive all new plot item (Marker...) that
+ # we don't want to manage in this case.
+ return
+ # plotItem = item.getItem()
+ # roi = item.getROI()
+ kind = item.item_kind()
+ if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.info("Item has not a supported type: %s", item)
+ return False
+
+ # register the roi and the kind
+ self._registerPlotItem(item)
+ self._registerROI(item)
+
+ # Prepare table items
+ tableItems = [
+ qt.QTableWidgetItem(), # Legend
+ qt.QTableWidgetItem(), # Kind
+ qt.QTableWidgetItem()] # roi
+
+ for column in range(3, self.columnCount()):
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+
+ formatter = self._statsHandler.formatters[name]
+ if formatter:
+ tableItem = formatter.tabWidgetItemClass()
+ else:
+ tableItem = qt.QTableWidgetItem()
+
+ tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
+ if tooltip is not None:
+ tableItem.setToolTip(tooltip)
+
+ tableItems.append(tableItem)
+
+ # Disable sorting while adding table items
+ with self._disableSorting():
+ # Add a row to the table
+ self.setRowCount(self.rowCount() + 1)
+
+ # Add table items to the last row
+ row = self.rowCount() - 1
+ for column, tableItem in enumerate(tableItems):
+ tableItem.setData(qt.Qt.UserRole, _Container(item))
+ tableItem.setFlags(
+ qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, column, tableItem)
+
+ # Update table items content
+ self._updateStats(item, data_changed=True)
+
+ # Listen for item changes
+ # Using queued connection to avoid issue with sender
+ # being that of the signal calling the signal
+ item._plot_item.sigItemChanged.connect(self._plotItemChanged,
+ qt.Qt.QueuedConnection)
+ return True
+
+ def _removeAllItems(self):
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ # item = self._tableItemToItem(tableItem)
+ # item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.clearContents()
+ self.setRowCount(0)
+
+ def clear(self):
+ self._removeAllItems()
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ self._removeAllItems()
+ _StatsWidgetBase.setStats(self, statsHandler)
+
+ self.setRowCount(0)
+ self.setColumnCount(len(self._statsHandler.stats) + 3) # + legend, kind and roi # noqa
+
+ for index, stat in enumerate(self._statsHandler.stats.values()):
+ headerItem = qt.QTableWidgetItem(stat.name.capitalize())
+ headerItem.setData(qt.Qt.UserRole, stat.name)
+ if stat.description is not None:
+ headerItem.setToolTip(stat.description)
+ self.setHorizontalHeaderItem(3 + index, headerItem)
+
+ horizontalHeader = self.horizontalHeader()
+ horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ self._updateItemObserve()
+
+ def _updateItemObserve(self, *args):
+ pass
+
+ def _dataChanged(self, item):
+ pass
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ assert isinstance(item, ROIStatsItemHelper)
+ plotItem = item._plot_item
+ roi = item._roi
+ if item is None:
+ return
+ plot = self.getPlot()
+ if plot is None:
+ _logger.info("Plot not available")
+ return
+
+ row = self._itemToRow(item)
+ if row is None:
+ _logger.error("This item is not in the table: %s", str(item))
+ return
+
+ statsHandler = self.getStatsHandler()
+ if statsHandler is not None:
+ stats = statsHandler.calculate(plotItem, plot,
+ onlimits=self._statsOnVisibleData,
+ roi=roi, data_changed=data_changed,
+ roi_changed=roi_changed)
+ else:
+ stats = {}
+
+ with self._disableSorting():
+ for name, tableItem in self._itemToTableItems(item).items():
+ if name == self._LEGEND_HEADER_DATA:
+ text = self._plotWrapper.getLabel(plotItem)
+ tableItem.setText(text)
+ elif name == self._KIND_HEADER_DATA:
+ tableItem.setText(self._plotWrapper.getKind(plotItem))
+ elif name == self._ROI_HEADER_DATA:
+ name = roi.getName()
+ tableItem.setText(name)
+ else:
+ value = stats.get(name)
+ if value is None:
+ _logger.error("Value not found for: %s", name)
+ tableItem.setText('-')
+ else:
+ tableItem.setText(str(value))
+
+ @contextmanager
+ def _disableSorting(self):
+ """Context manager that disables table sorting
+
+ Previous state is restored when leaving
+ """
+ sorting = self.isSortingEnabled()
+ if sorting:
+ self.setSortingEnabled(False)
+ yield
+ if sorting:
+ self.setSortingEnabled(sorting)
+
+ def _itemToRow(self, item):
+ """Find the row corresponding to a plot item
+
+ :param item: The plot item
+ :return: The corresponding row index
+ :rtype: Union[int,None]
+ """
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ if self._tableItemToItem(tableItem) == item:
+ return row
+ return None
+
+ def _tableItemToItem(self, tableItem):
+ """Find the plot item corresponding to a table item
+
+ :param QTableWidgetItem tableItem:
+ :rtype: QObject
+ """
+ container = tableItem.data(qt.Qt.UserRole)
+ return container()
+
+ def _itemToTableItems(self, item):
+ """Find all table items corresponding to a plot item
+
+ :param item: The plot item
+ :return: An ordered dict of column name to QTableWidgetItem mapping
+ for the given plot item.
+ :rtype: OrderedDict
+ """
+ result = OrderedDict()
+ row = self._itemToRow(item)
+ if row is not None:
+ for column in range(self.columnCount()):
+ tableItem = self.item(row, column)
+ if self._tableItemToItem(tableItem) != item:
+ _logger.error("Table item/plot item mismatch")
+ else:
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+ result[name] = tableItem
+ return result
+
+ def _plotItemToItems(self, plotItem):
+ """Return all _RoiStatsItemWidget associated to the plotItem
+ Needed for updating on itemChanged signal
+ """
+ if plotItem in self.__plotItemToItems:
+ return []
+ else:
+ return self.__plotItemToItems[plotItem]
+
+ def _registerPlotItem(self, item):
+ if item._plot_item not in self.__plotItemToItems:
+ self.__plotItemToItems[item._plot_item] = set()
+ self.__plotItemToItems[item._plot_item].add(item)
+
+ def _roiToItems(self, roi):
+ """Return all _RoiStatsItemWidget associated to the roi
+ Needed for updating on roiChanged signal
+ """
+ if roi in self.__roiToItems:
+ return []
+ else:
+ return self.__roiToItems[roi]
+
+ def _registerROI(self, item):
+ if item._roi not in self.__roiToItems:
+ self.__roiToItems[item._roi] = set()
+ # TODO: normalize also sig name
+ if isinstance(item._roi, RegionOfInterest):
+ # item connection within sigRegionChanged should only be
+ # stopped during the region edition
+ self.__region_edition_callback[item._roi] = functools.partial(
+ self._updateAllStats, False, True)
+ item._roi.sigRegionChanged.connect(self.__region_edition_callback[item._roi])
+ item._roi.sigEditingStarted.connect(functools.partial(
+ self._startFiltering, item._roi))
+ item._roi.sigEditingFinished.connect(functools.partial(
+ self._endFiltering, item._roi))
+ else:
+ item._roi.sigChanged.connect(functools.partial(
+ self._updateAllStats, False, True))
+ self.__roiToItems[item._roi].add(item)
+
+ def _startFiltering(self, roi):
+ roi.sigRegionChanged.disconnect(self.__region_edition_callback[roi])
+
+ def _endFiltering(self, roi):
+ roi.sigRegionChanged.connect(self.__region_edition_callback[roi])
+ self._updateAllStats(roi_changed=True)
+
+ def unregisterROI(self, roi):
+ if roi in self.__roiToItems:
+ del self.__roiToItems[roi]
+ if isinstance(roi, RegionOfInterest):
+ roi.sigRegionEditionStarted.disconnect(functools.partial(
+ self._startFiltering, roi))
+ roi.sigRegionEditionFinished.disconnect(functools.partial(
+ self._startFiltering, roi))
+ try:
+ roi.sigRegionChanged.disconnect(self._updateAllStats)
+ except:
+ pass
+ else:
+ roi.sigChanged.disconnect(self._updateAllStats)
+
+ def _plotItemChanged(self, event):
+ """Handle modifications of the items.
+
+ :param event:
+ """
+ if event is ItemChangedType.DATA:
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if self._skipPlotItemChangedEvent(event) is True:
+ return
+ else:
+ sender = self.sender()
+ for item in self.__plotItemToItems[sender]:
+ # TODO: get all concerned items
+ self._updateStats(item, data_changed=True)
+ # deal with stat items visibility
+ if event is ItemChangedType.VISIBLE:
+ if len(self._itemToTableItems(item).items()) > 0:
+ item_0 = list(self._itemToTableItems(item).values())[0]
+ row_index = item_0.row()
+ self.setRowHidden(row_index, not item.isVisible())
+
+ def _removeItem(self, itemKey):
+ if isinstance(itemKey, (silx.gui.plot.items.marker.Marker,
+ silx.gui.plot.items.shape.Shape)):
+ return
+ if itemKey not in self._items:
+ _logger.warning('key not recognized. Won\'t remove any item')
+ return
+ item = self._items[itemKey]
+ row = self._itemToRow(item)
+ if row is None:
+ kind = self._plotWrapper.getKind(item)
+ if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.error("Removing item that is not in table: %s", str(item))
+ return
+ item._plot_item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.removeRow(row)
+ del self._items[itemKey]
+
+ def _updateAllStats(self, is_request=False, roi_changed=False):
+ """Update stats for all rows in the table
+
+ :param bool is_request: True if come from a manual request
+ """
+ if (self.getUpdateMode() is UpdateMode.MANUAL and
+ not is_request and not roi_changed):
+ return
+
+ with self._disableSorting():
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ self._updateStats(item, roi_changed=roi_changed,
+ data_changed=is_request)
+
+ def _plotCurrentChanged(self, *args):
+ pass
+
+ def _getRoi(self, kind, name):
+ """return the roi fitting the requirement kind, name. This information
+ is enough to be sure it is unique (in the widget)"""
+ for roi in self.__roiToItems:
+ roiName = roi.getName()
+ if isinstance(roi, kind) and name == roiName:
+ return roi
+ return None
+
+ def _getPlotItem(self, kind, legend):
+ """return the plotItem fitting the requirement kind, legend.
+ This information is enough to be sure it is unique (in the widget)"""
+ for plotItem in self.__plotItemToItems:
+ if legend == plotItem.getLegend() and self._plotWrapper.getKind(plotItem) == kind:
+ return plotItem
+ return None
+
+
+class ROIStatsWidget(qt.QMainWindow):
+ """
+ Widget used to define stats item for a couple(roi, plotItem).
+ Stats will be computing on a given item (curve, image...) in the given
+ region of interest.
+
+ It also provide an interface for adding and removing items.
+
+ .. snapshotqt:: img/ROIStatsWidget.png
+ :width: 300px
+ :align: center
+
+ from silx.gui import qt
+ from silx.gui.plot import Plot2D
+ from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
+ from silx.gui.plot.items.roi import RectangleROI
+ import numpy
+ plot = Plot2D()
+ plot.addImage(numpy.arange(10000).reshape(100, 100), legend='img')
+ plot.show()
+ rectangleROI = RectangleROI()
+ rectangleROI.setGeometry(origin=(0, 100), size=(20, 20))
+ rectangleROI.setName('Initial ROI')
+ widget = ROIStatsWidget(plot=plot)
+ widget.setStats([('sum', numpy.sum), ('mean', numpy.mean)])
+ widget.registerROI(rectangleROI)
+ widget.addItem(roi=rectangleROI, plotItem=plot.getImage('img'))
+ widget.show()
+
+ :param Union[qt.QWidget,None] parent: parent qWidget
+ :param PlotWindow plot: plot widget containing the items
+ :param stats: stats to display
+ :param tuple rois: tuple of rois to manage
+ """
+
+ def __init__(self, parent=None, plot=None, stats=None, rois=None):
+ qt.QMainWindow.__init__(self, parent)
+
+ toolbar = qt.QToolBar(self)
+ icon = icons.getQIcon('add')
+ self._rois = list(rois) if rois is not None else []
+ self._addAction = qt.QAction(icon, 'add item/roi', toolbar)
+ self._addAction.triggered.connect(self._addRoiStatsItem)
+ icon = icons.getQIcon('rm')
+ self._removeAction = qt.QAction(icon, 'remove item/roi', toolbar)
+ self._removeAction.triggered.connect(self._removeCurrentRow)
+
+ toolbar.addAction(self._addAction)
+ toolbar.addAction(self._removeAction)
+ self.addToolBar(toolbar)
+
+ self._plot = plot
+ self._statsROITable = _StatsROITable(parent=self, plot=self._plot)
+ self.setStats(stats=stats)
+ self.setCentralWidget(self._statsROITable)
+ self.setWindowFlags(qt.Qt.Widget)
+
+ # expose API
+ self._setUpdateMode = self._statsROITable.setUpdateMode
+ self._updateAllStats = self._statsROITable._updateAllStats
+
+ # setup
+ self._statsROITable.setSelectionBehavior(qt.QTableWidget.SelectRows)
+
+ def registerROI(self, roi):
+ """For now there is no direct link between roi and plot. That is why
+ we need to add/register them to be able to associate them"""
+ self._rois.append(roi)
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ self._plot = plot
+
+ def getPlot(self):
+ return self._plot
+
+ @docstring(_StatsROITable)
+ def setStats(self, stats):
+ if stats is not None:
+ self._statsROITable.setStats(statsHandler=stats)
+
+ @docstring(_StatsROITable)
+ def getStatsHandler(self):
+ """
+
+ :return:
+ """
+ return self._statsROITable.getStatsHandler()
+
+ def _addRoiStatsItem(self):
+ """Ask the user what couple ROI / item he want to display"""
+ dialog = _GetROIItemCoupleDialog(parent=self, plot=self._plot,
+ rois=self._rois)
+ if dialog.exec():
+ self.addItem(roi=dialog.getROI(), plotItem=dialog.getItem())
+
+ def addItem(self, plotItem, roi):
+ """
+ Add a row of statitstic regarding the couple (plotItem, roi)
+
+ :param Item plotItem: item to use for statistics
+ :param roi: region of interest to limit the statistic.
+ :type: Union[ROI, RegionOfInterest]
+ :return: None of failed to add the item
+ :rtype: Union[None,ROIStatsItemHelper]
+ """
+ statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
+ return self._statsROITable.add(item=statsItem)
+
+ def removeItem(self, plotItem, roi):
+ """
+ Remove the row associated to the couple (plotItem, roi)
+
+ :param Item plotItem: item to use for statistics
+ :param roi: region of interest to limit the statistic.
+ :type: Union[ROI,RegionOfInterest]
+ """
+ statsItem = ROIStatsItemHelper(roi=roi, plot_item=plotItem)
+ self._statsROITable._removeItem(itemKey=statsItem.id_key())
+
+ def _removeCurrentRow(self):
+ def is1DKind(kind):
+ if kind in ('curve', 'histogram', 'scatter'):
+ return True
+ else:
+ return False
+
+ currentRow = self._statsROITable.currentRow()
+ item_kind = self._statsROITable.item(currentRow, 1).text()
+ item_legend = self._statsROITable.item(currentRow, 0).text()
+
+ roi_name = self._statsROITable.item(currentRow, 2).text()
+ roi_kind = ROI if is1DKind(item_kind) else RegionOfInterest
+ roi = self._statsROITable._getRoi(kind=roi_kind, name=roi_name)
+ if roi is None:
+ _logger.warning('failed to retrieve the roi you want to remove')
+ return False
+ plot_item = self._statsROITable._getPlotItem(kind=item_kind,
+ legend=item_legend)
+ if plot_item is None:
+ _logger.warning('failed to retrieve the plot item you want to'
+ 'remove')
+ return False
+ return self.removeItem(plotItem=plot_item, roi=roi)
diff --git a/src/silx/gui/plot/ScatterMaskToolsWidget.py b/src/silx/gui/plot/ScatterMaskToolsWidget.py
new file mode 100644
index 0000000..c242dfc
--- /dev/null
+++ b/src/silx/gui/plot/ScatterMaskToolsWidget.py
@@ -0,0 +1,621 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Widget providing a set of tools to draw masks on a PlotWidget.
+
+This widget is meant to work with a modified :class:`silx.gui.plot.PlotWidget`
+
+- :class:`ScatterMask`: Handle scatter mask update and history
+- :class:`ScatterMaskToolsWidget`: GUI for :class:`ScatterMask`
+- :class:`ScatterMaskToolsDockWidget`: DockWidget to integrate in :class:`PlotWindow`
+"""
+
+from __future__ import division
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "15/02/2019"
+
+
+import math
+import logging
+import os
+import numpy
+import sys
+
+from .. import qt
+from ...math.combo import min_max
+from ...image import shapes
+
+from .items import ItemChangedType, Scatter
+from ._BaseMaskToolsWidget import BaseMask, BaseMaskToolsWidget, BaseMaskToolsDockWidget
+from ..colors import cursorColorForColormap, rgba
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ScatterMask(BaseMask):
+ """A 1D mask for scatter data.
+ """
+ def __init__(self, scatter=None):
+ """
+
+ :param scatter: :class:`silx.gui.plot.items.Scatter` instance
+ """
+ BaseMask.__init__(self, scatter)
+
+ def _getXY(self):
+ x = self._dataItem.getXData(copy=False)
+ y = self._dataItem.getYData(copy=False)
+ return x, y
+
+ def getDataValues(self):
+ """Return scatter data values as a 1D array.
+
+ :rtype: 1D numpy.ndarray
+ """
+ return self._dataItem.getValueData(copy=False)
+
+ def save(self, filename, kind):
+ if kind == 'npy':
+ try:
+ numpy.save(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+ elif kind in ["csv", "txt"]:
+ try:
+ numpy.savetxt(filename, self.getMask(copy=False))
+ except IOError:
+ raise RuntimeError("Mask file can't be written")
+
+ def updatePoints(self, level, indices, mask=True):
+ """Mask/Unmask points with given indices.
+
+ :param int level: Mask level to update.
+ :param indices: Sequence or 1D array of indices of points to be
+ updated
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ if mask:
+ self._mask[indices] = level
+ else:
+ # unmask only where mask level is the specified value
+ indices_stencil = numpy.zeros_like(self._mask, dtype=bool)
+ indices_stencil[indices] = True
+ self._mask[numpy.logical_and(self._mask == level, indices_stencil)] = 0
+ self._notify()
+
+ # update shapes
+ def updatePolygon(self, level, vertices, mask=True):
+ """Mask/Unmask a polygon of the given mask level.
+
+ :param int level: Mask level to update.
+ :param vertices: Nx2 array of polygon corners as (y, x) or (row, col)
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ polygon = shapes.Polygon(vertices)
+ x, y = self._getXY()
+
+ # TODO: this could be optimized if necessary
+ indices_in_polygon = [idx for idx in range(len(x)) if
+ polygon.is_inside(y[idx], x[idx])]
+
+ self.updatePoints(level, indices_in_polygon, mask)
+
+ def updateRectangle(self, level, y, x, height, width, mask=True):
+ """Mask/Unmask data inside a rectangle
+
+ :param int level: Mask level to update.
+ :param float y: Y coordinate of bottom left corner of the rectangle
+ :param float x: X coordinate of bottom left corner of the rectangle
+ :param float height:
+ :param float width:
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ vertices = [(y, x),
+ (y + height, x),
+ (y + height, x + width),
+ (y, x + width)]
+ self.updatePolygon(level, vertices, mask)
+
+ def updateDisk(self, level, cy, cx, radius, mask=True):
+ """Mask/Unmask a disk of the given mask level.
+
+ :param int level: Mask level to update.
+ :param float cy: Disk center (y).
+ :param float cx: Disk center (x).
+ :param float radius: Radius of the disk in mask array unit
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ x, y = self._getXY()
+ stencil = (y - cy)**2 + (x - cx)**2 < radius**2
+ self.updateStencil(level, stencil, mask)
+
+ def updateEllipse(self, level, crow, ccol, radius_r, radius_c, mask=True):
+ """Mask/Unmask an ellipse of the given mask level.
+
+ :param int level: Mask level to update.
+ :param int crow: Row of the center of the ellipse
+ :param int ccol: Column of the center of the ellipse
+ :param float radius_r: Radius of the ellipse in the row
+ :param float radius_c: Radius of the ellipse in the column
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ def is_inside(px, py):
+ return (px - ccol)**2 / radius_c**2 + (py - crow)**2 / radius_r**2 <= 1.0
+ x, y = self._getXY()
+ indices_inside = [idx for idx in range(len(x)) if is_inside(x[idx], y[idx])]
+ self.updatePoints(level, indices_inside, mask)
+
+ def updateLine(self, level, y0, x0, y1, x1, width, mask=True):
+ """Mask/Unmask points inside a rectangle defined by a line (two
+ end points) and a width.
+
+ :param int level: Mask level to update.
+ :param float y0: Row of the starting point.
+ :param float x0: Column of the starting point.
+ :param float row1: Row of the end point.
+ :param float col1: Column of the end point.
+ :param float width: Width of the line.
+ :param bool mask: True to mask (default), False to unmask.
+ """
+ # theta is the angle between the horizontal and the line
+ theta = math.atan((y1 - y0) / (x1 - x0)) if x1 - x0 else 0
+ w_over_2_sin_theta = width / 2. * math.sin(theta)
+ w_over_2_cos_theta = width / 2. * math.cos(theta)
+
+ vertices = [(y0 - w_over_2_cos_theta, x0 + w_over_2_sin_theta),
+ (y0 + w_over_2_cos_theta, x0 - w_over_2_sin_theta),
+ (y1 + w_over_2_cos_theta, x1 - w_over_2_sin_theta),
+ (y1 - w_over_2_cos_theta, x1 + w_over_2_sin_theta)]
+
+ self.updatePolygon(level, vertices, mask)
+
+
+class ScatterMaskToolsWidget(BaseMaskToolsWidget):
+ """Widget with tools for masking data points on a scatter in a
+ :class:`PlotWidget`."""
+
+ def __init__(self, parent=None, plot=None):
+ super(ScatterMaskToolsWidget, self).__init__(parent, plot,
+ mask=ScatterMask())
+ self._z = 2 # Mask layer in plot
+ self._data_scatter = None
+ """plot Scatter item for data"""
+
+ self._data_extent = None
+ """Maximum extent of the data i.e., max(xMax-xMin, yMax-yMin)"""
+
+ self._mask_scatter = None
+ """plot Scatter item for representing the mask"""
+
+ def setSelectionMask(self, mask, copy=True):
+ """Set the mask to a new array.
+
+ :param numpy.ndarray mask:
+ The array to use for the mask or None to reset the mask.
+ :type mask: numpy.ndarray of uint8, C-contiguous.
+ Array of other types are converted.
+ :param bool copy: True (the default) to copy the array,
+ False to use it as is if possible.
+ :return: None if failed, shape of mask as 1-tuple if successful.
+ The mask can be cropped or padded to fit active scatter,
+ the returned shape is that of the scatter data.
+ """
+ if self._data_scatter is None:
+ # this can happen if the mask tools widget has never been shown
+ self._data_scatter = self.plot._getActiveItem(kind="scatter")
+ if self._data_scatter is None:
+ return None
+ self._adjustColorAndBrushSize(self._data_scatter)
+
+ if mask is None:
+ self.resetSelectionMask()
+ return self._data_scatter.getXData(copy=False).shape
+
+ mask = numpy.array(mask, copy=False, dtype=numpy.uint8)
+
+ if self._data_scatter.getXData(copy=False).shape == (0,) \
+ or mask.shape == self._data_scatter.getXData(copy=False).shape:
+ self._mask.setMask(mask, copy=copy)
+ self._mask.commit()
+ return mask.shape
+ else:
+ raise ValueError("Mask does not have the same shape as the data")
+
+ # Handle mask refresh on the plot
+
+ def _updatePlotMask(self):
+ """Update mask image in plot"""
+ mask = self.getSelectionMask(copy=False)
+ if mask is not None:
+ self.plot.addScatter(self._data_scatter.getXData(),
+ self._data_scatter.getYData(),
+ mask,
+ legend=self._maskName,
+ colormap=self._colormap,
+ z=self._z)
+ self._mask_scatter = self.plot._getItem(kind="scatter",
+ legend=self._maskName)
+ self._mask_scatter.setSymbolSize(
+ self._data_scatter.getSymbolSize() + 2.0)
+ self._mask_scatter.sigItemChanged.connect(self.__maskScatterChanged)
+ elif self.plot._getItem(kind="scatter",
+ legend=self._maskName) is not None:
+ self.plot.remove(self._maskName, kind='scatter')
+
+ def __maskScatterChanged(self, event):
+ """Handles update of mask scatter"""
+ if (event is ItemChangedType.VISUALIZATION_MODE and
+ self._mask_scatter is not None):
+ self._mask_scatter.setVisualization(Scatter.Visualization.POINTS)
+
+ # track widget visibility and plot active image changes
+
+ def showEvent(self, event):
+ try:
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare)
+ except (RuntimeError, TypeError):
+ pass
+ self._activeScatterChanged(None, None) # Init mask + enable/disable widget
+ self.plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+
+ def hideEvent(self, event):
+ try:
+ # if the method is not connected this raises a TypeError and there is no way
+ # to know the connected slots
+ self.plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
+ except (RuntimeError, TypeError):
+ _logger.info(sys.exc_info()[1])
+ if not self.browseAction.isChecked():
+ self.browseAction.trigger() # Disable drawing tool
+
+ if self.getSelectionMask(copy=False) is not None:
+ self.plot.sigActiveScatterChanged.connect(
+ self._activeScatterChangedAfterCare)
+
+ def _adjustColorAndBrushSize(self, activeScatter):
+ colormap = activeScatter.getColormap()
+ self._defaultOverlayColor = rgba(cursorColorForColormap(colormap['name']))
+ self._setMaskColors(self.levelSpinBox.value(),
+ self.transparencySlider.value() /
+ self.transparencySlider.maximum())
+ self._z = activeScatter.getZValue() + 1
+ self._data_scatter = activeScatter
+
+ # Adjust brush size to data range
+ xData = self._data_scatter.getXData(copy=False)
+ yData = self._data_scatter.getYData(copy=False)
+ # Adjust brush size to data range
+ if xData.size > 0 and yData.size > 0:
+ xMin, xMax = min_max(xData)
+ yMin, yMax = min_max(yData)
+ self._data_extent = max(xMax - xMin, yMax - yMin)
+ else:
+ self._data_extent = None
+
+ def _activeScatterChangedAfterCare(self, previous, next):
+ """Check synchro of active scatter and mask when mask widget is hidden.
+
+ If active image has no more the same size as the mask, the mask is
+ removed, otherwise it is adjusted to z.
+ """
+ # check that content changed was the active scatter
+ activeScatter = self.plot._getActiveItem(kind="scatter")
+
+ if activeScatter is None or activeScatter.getName() == self._maskName:
+ # No active scatter or active scatter is the mask...
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare)
+ self._data_extent = None
+ self._data_scatter = None
+
+ else:
+ self._adjustColorAndBrushSize(activeScatter)
+
+ if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape:
+ # scatter has not the same size, remove mask and stop listening
+ if self.plot._getItem(kind="scatter", legend=self._maskName):
+ self.plot.remove(self._maskName, kind='scatter')
+
+ self.plot.sigActiveScatterChanged.disconnect(
+ self._activeScatterChangedAfterCare)
+ self._data_extent = None
+ self._data_scatter = None
+
+ else:
+ # Refresh in case z changed
+ self._mask.setDataItem(self._data_scatter)
+ self._updatePlotMask()
+
+ def _activeScatterChanged(self, previous, next):
+ """Update widget and mask according to active scatter changes"""
+ activeScatter = self.plot._getActiveItem(kind="scatter")
+
+ if activeScatter is None or activeScatter.getName() == self._maskName:
+ # No active scatter or active scatter is the mask...
+ self.setEnabled(False)
+
+ self._data_scatter = None
+ self._data_extent = None
+ self._mask.reset()
+ self._mask.commit()
+
+ else: # There is an active scatter
+ self.setEnabled(True)
+ self._adjustColorAndBrushSize(activeScatter)
+
+ self._mask.setDataItem(self._data_scatter)
+ if self._data_scatter.getXData(copy=False).shape != self._mask.getMask(copy=False).shape:
+ self._mask.reset(self._data_scatter.getXData(copy=False).shape)
+ self._mask.commit()
+ else:
+ # Refresh in case z changed
+ self._updatePlotMask()
+
+ self._updateInteractiveMode()
+
+ # Handle whole mask operations
+
+ def load(self, filename):
+ """Load a mask from an image file.
+
+ :param str filename: File name from which to load the mask
+ :raise Exception: An exception in case of failure
+ :raise RuntimeWarning: In case the mask was applied but with some
+ import changes to notice
+ """
+ _, extension = os.path.splitext(filename)
+ extension = extension.lower()[1:]
+ if extension == "npy":
+ try:
+ mask = numpy.load(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy file.',
+ filename)
+ elif extension in ["txt", "csv"]:
+ try:
+ mask = numpy.loadtxt(filename)
+ except IOError:
+ _logger.error("Can't load filename '%s'", filename)
+ _logger.debug("Backtrace", exc_info=True)
+ raise RuntimeError('File "%s" is not a numpy txt file.',
+ filename)
+ else:
+ msg = "Extension '%s' is not supported."
+ raise RuntimeError(msg % extension)
+
+ self.setSelectionMask(mask, copy=False)
+
+ def _loadMask(self):
+ """Open load mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Load Mask")
+ dialog.setModal(1)
+ filters = [
+ 'NumPy binary file (*.npy)',
+ 'CSV text file (*.csv)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.ExistingFile)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.load(filename)
+ # except RuntimeWarning as e:
+ # message = e.args[0]
+ # msg = qt.QMessageBox(self)
+ # msg.setIcon(qt.QMessageBox.Warning)
+ # msg.setText("Mask loaded but an operation was applied.\n" + message)
+ # msg.exec()
+ except Exception as e:
+ message = e.args[0]
+ msg = qt.QMessageBox(self)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setText("Cannot load mask from file. " + message)
+ msg.exec()
+
+ def _saveMask(self):
+ """Open Save mask dialog"""
+ dialog = qt.QFileDialog(self)
+ dialog.setWindowTitle("Save Mask")
+ dialog.setModal(1)
+ filters = [
+ 'NumPy binary file (*.npy)',
+ 'CSV text file (*.csv)',
+ ]
+ dialog.setNameFilters(filters)
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+ dialog.setDirectory(self.maskFileDir)
+ if not dialog.exec():
+ dialog.close()
+ return
+
+ # convert filter name to extension name with the .
+ extension = dialog.selectedNameFilter().split()[-1][2:-1]
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if not filename.lower().endswith(extension):
+ filename += extension
+
+ if os.path.exists(filename):
+ try:
+ os.remove(filename)
+ except IOError as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Removing existing file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save.\n"
+ "Input Output Error: %s" % strerror)
+ msg.exec()
+ return
+
+ # Update the directory according to the user selection
+ self.maskFileDir = os.path.dirname(filename)
+
+ try:
+ self.save(filename, extension[1:])
+ except Exception as e:
+ msg = qt.QMessageBox(self)
+ msg.setWindowTitle("Saving mask file")
+ msg.setIcon(qt.QMessageBox.Critical)
+
+ if hasattr(e, "strerror"):
+ strerror = e.strerror
+ else:
+ strerror = sys.exc_info()[1]
+ msg.setText("Cannot save file %s\n%s" % (filename, strerror))
+ msg.exec()
+
+ def resetSelectionMask(self):
+ """Reset the mask"""
+ self._mask.reset(
+ shape=self._data_scatter.getXData(copy=False).shape)
+ self._mask.commit()
+
+ def _getPencilWidth(self):
+ """Returns the width of the pencil to use in data coordinates`
+
+ :rtype: float
+ """
+ width = super(ScatterMaskToolsWidget, self)._getPencilWidth()
+ if self._data_extent is not None:
+ width *= 0.01 * self._data_extent
+ return width
+
+ def _plotDrawEvent(self, event):
+ """Handle draw events from the plot"""
+ if (self._drawingMode is None or
+ event['event'] not in ('drawingProgress', 'drawingFinished')):
+ return
+
+ if not len(self._data_scatter.getXData(copy=False)):
+ return
+
+ level = self.levelSpinBox.value()
+
+ if self._drawingMode == 'rectangle':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+
+ self._mask.updateRectangle(
+ level,
+ y=event['y'],
+ x=event['x'],
+ height=abs(event['height']),
+ width=abs(event['width']),
+ mask=doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'ellipse':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ center = event['points'][0]
+ size = event['points'][1]
+ self._mask.updateEllipse(level, center[1], center[0],
+ size[1], size[0], doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'polygon':
+ if event['event'] == 'drawingFinished':
+ doMask = self._isMasking()
+ vertices = event['points']
+ vertices = vertices[:, (1, 0)] # (y, x)
+ self._mask.updatePolygon(level, vertices, doMask)
+ self._mask.commit()
+
+ elif self._drawingMode == 'pencil':
+ doMask = self._isMasking()
+ # convert from plot to array coords
+ x, y = event['points'][-1]
+
+ brushSize = self._getPencilWidth()
+
+ if self._lastPencilPos != (y, x):
+ if self._lastPencilPos is not None:
+ # Draw the line
+ self._mask.updateLine(
+ level,
+ self._lastPencilPos[0], self._lastPencilPos[1],
+ y, x,
+ brushSize,
+ doMask)
+
+ # Draw the very first, or last point
+ self._mask.updateDisk(level, y, x, brushSize / 2., doMask)
+
+ if event['event'] == 'drawingFinished':
+ self._mask.commit()
+ self._lastPencilPos = None
+ else:
+ self._lastPencilPos = y, x
+ else:
+ _logger.error("Drawing mode %s unsupported", self._drawingMode)
+
+ def _loadRangeFromColormapTriggered(self):
+ """Set range from active scatter colormap range"""
+ if self._data_scatter is not None:
+ # Update thresholds according to colormap
+ colormap = self._data_scatter.getColormap()
+ if colormap['autoscale']:
+ min_ = numpy.nanmin(self._data_scatter.getValueData(copy=False))
+ max_ = numpy.nanmax(self._data_scatter.getValueData(copy=False))
+ else:
+ min_, max_ = colormap['vmin'], colormap['vmax']
+ self.minLineEdit.setText(str(min_))
+ self.maxLineEdit.setText(str(max_))
+
+
+class ScatterMaskToolsDockWidget(BaseMaskToolsDockWidget):
+ """:class:`ScatterMaskToolsWidget` embedded in a QDockWidget.
+
+ For integration in a :class:`PlotWindow`.
+
+ :param parent: See :class:`QDockWidget`
+ :param plot: The PlotWidget this widget is operating on
+ :paran str name: The title of this widget
+ """
+ def __init__(self, parent=None, plot=None, name='Mask'):
+ widget = ScatterMaskToolsWidget(plot=plot)
+ super(ScatterMaskToolsDockWidget, self).__init__(parent, name, widget)
diff --git a/src/silx/gui/plot/ScatterView.py b/src/silx/gui/plot/ScatterView.py
new file mode 100644
index 0000000..d3fd2e0
--- /dev/null
+++ b/src/silx/gui/plot/ScatterView.py
@@ -0,0 +1,404 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""A widget dedicated to display scatter plots
+
+It is based on a :class:`~silx.gui.plot.PlotWidget` with additional tools
+for scatter plots.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "14/06/2018"
+
+
+import logging
+import weakref
+
+import numpy
+
+from . import items
+from . import PlotWidget
+from . import tools
+from .actions import histogram as actions_histogram
+from .tools.profile import ScatterProfileToolBar
+from .ColorBar import ColorBarWidget
+from .ScatterMaskToolsWidget import ScatterMaskToolsWidget
+
+from ..widgets.BoxLayoutDockWidget import BoxLayoutDockWidget
+from .. import qt, icons
+from ...utils.proxy import docstring
+from ...utils.weakref import WeakMethodProxy
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ScatterView(qt.QMainWindow):
+ """Main window with a PlotWidget and tools specific for scatter plots.
+
+ :param parent: The parent of this widget
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`~silx.gui.plot.PlotWidget` for the list of supported backend.
+ :type backend: Union[str,~silx.gui.plot.backends.BackendBase.BackendBase]
+ """
+
+ _SCATTER_LEGEND = ' '
+ """Legend used for the scatter item"""
+
+ def __init__(self, parent=None, backend=None):
+ super(ScatterView, self).__init__(parent=parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle('ScatterView')
+
+ # Create plot widget
+ plot = PlotWidget(parent=self, backend=backend)
+ self._plot = weakref.ref(plot)
+
+ # Add an empty scatter
+ self.__createEmptyScatter()
+
+ # Create colorbar widget with white background
+ self._colorbar = ColorBarWidget(parent=self, plot=plot)
+ self._colorbar.setAutoFillBackground(True)
+ palette = self._colorbar.palette()
+ palette.setColor(qt.QPalette.Window, qt.Qt.white)
+ self._colorbar.setPalette(palette)
+
+ # Create PositionInfo widget
+ self.__lastPickingPos = None
+ self.__pickingCache = None
+ self._positionInfo = tools.PositionInfo(
+ plot=plot,
+ converters=(('X', WeakMethodProxy(self._getPickedX)),
+ ('Y', WeakMethodProxy(self._getPickedY)),
+ ('Data', WeakMethodProxy(self._getPickedValue)),
+ ('Index', WeakMethodProxy(self._getPickedIndex))))
+
+ # Combine plot, position info and colorbar into central widget
+ gridLayout = qt.QGridLayout()
+ gridLayout.setSpacing(0)
+ gridLayout.setContentsMargins(0, 0, 0, 0)
+ gridLayout.addWidget(plot, 0, 0)
+ gridLayout.addWidget(self._colorbar, 0, 1)
+ gridLayout.addWidget(self._positionInfo, 1, 0, 1, -1)
+ gridLayout.setRowStretch(0, 1)
+ gridLayout.setColumnStretch(0, 1)
+ centralWidget = qt.QWidget(self)
+ centralWidget.setLayout(gridLayout)
+ self.setCentralWidget(centralWidget)
+
+ # Create mask tool dock widget
+ self._maskToolsWidget = ScatterMaskToolsWidget(parent=self, plot=plot)
+ self._maskDock = BoxLayoutDockWidget()
+ self._maskDock.setWindowTitle('Scatter Mask')
+ self._maskDock.setWidget(self._maskToolsWidget)
+ self._maskDock.setVisible(False)
+ self.addDockWidget(qt.Qt.BottomDockWidgetArea, self._maskDock)
+
+ self._maskAction = self._maskDock.toggleViewAction()
+ self._maskAction.setIcon(icons.getQIcon('image-mask'))
+ self._maskAction.setToolTip("Display/hide mask tools")
+
+ self._intensityHistoAction = actions_histogram.PixelIntensitiesHistoAction(plot=plot, parent=self)
+
+ # Create toolbars
+ self._interactiveModeToolBar = tools.InteractiveModeToolBar(
+ parent=self, plot=plot)
+
+ self._scatterToolBar = tools.ScatterToolBar(
+ parent=self, plot=plot)
+ self._scatterToolBar.addAction(self._maskAction)
+ self._scatterToolBar.addAction(self._intensityHistoAction)
+
+ self._profileToolBar = ScatterProfileToolBar(parent=self, plot=plot)
+
+ self._outputToolBar = tools.OutputToolBar(parent=self, plot=plot)
+
+ # Activate shortcuts in PlotWindow widget:
+ for toolbar in (self._interactiveModeToolBar,
+ self._scatterToolBar,
+ self._profileToolBar,
+ self._outputToolBar):
+ self.addToolBar(toolbar)
+ for action in toolbar.actions():
+ self.addAction(action)
+
+
+ def __createEmptyScatter(self):
+ """Create an empty scatter item that is used to display the data
+
+ :rtype: ~silx.gui.plot.items.Scatter
+ """
+ plot = self.getPlotWidget()
+ plot.addScatter(x=(), y=(), value=(), legend=self._SCATTER_LEGEND)
+ scatter = plot._getItem(
+ kind='scatter', legend=self._SCATTER_LEGEND)
+ # Profile is not selectable,
+ # so it does not interfere with profile interaction
+ scatter._setSelectable(False)
+ return scatter
+
+ def _pickScatterData(self, x, y):
+ """Get data and index and value of top most scatter plot at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The data index and value at that point or None
+ """
+ pickingPos = x, y
+ if self.__lastPickingPos != pickingPos:
+ self.__pickingCache = None
+ self.__lastPickingPos = pickingPos
+
+ plot = self.getPlotWidget()
+ if plot is not None:
+ pixelPos = plot.dataToPixel(x, y)
+ if pixelPos is not None:
+ # Start from top-most item
+ result = plot._pickTopMost(
+ pixelPos[0], pixelPos[1],
+ lambda item: isinstance(item, items.Scatter))
+ if result is not None:
+ item = result.getItem()
+ if item.getVisualization() is items.Scatter.Visualization.BINNED_STATISTIC:
+ # Get highest index of closest points
+ selected = result.getIndices(copy=False)[::-1]
+ dataIndex = selected[numpy.argmin(
+ (item.getXData(copy=False)[selected] - x)**2 +
+ (item.getYData(copy=False)[selected] - y)**2)]
+ else:
+ # Get last index
+ # with matplotlib it should be the top-most point
+ dataIndex = result.getIndices(copy=False)[-1]
+ self.__pickingCache = (
+ dataIndex,
+ item.getXData(copy=False)[dataIndex],
+ item.getYData(copy=False)[dataIndex],
+ item.getValueData(copy=False)[dataIndex])
+
+ return self.__pickingCache
+
+ def _getPickedIndex(self, x, y):
+ """Get data index of top most scatter plot at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The data index at that point or '-'
+ """
+ picking = self._pickScatterData(x, y)
+ return '-' if picking is None else picking[0]
+
+ def _getPickedX(self, x, y):
+ """Returns X position snapped to scatter plot when close enough
+
+ :param float x:
+ :param float y:
+ :rtype: float
+ """
+ picking = self._pickScatterData(x, y)
+ return x if picking is None else picking[1]
+
+ def _getPickedY(self, x, y):
+ """Returns Y position snapped to scatter plot when close enough
+
+ :param float x:
+ :param float y:
+ :rtype: float
+ """
+ picking = self._pickScatterData(x, y)
+ return y if picking is None else picking[2]
+
+ def _getPickedValue(self, x, y):
+ """Get data value of top most scatter plot at position (x, y)
+
+ :param float x: X position in plot coordinates
+ :param float y: Y position in plot coordinates
+ :return: The data value at that point or '-'
+ """
+ picking = self._pickScatterData(x, y)
+ return '-' if picking is None else picking[3]
+
+ def _mouseInPlotArea(self, x, y):
+ """Clip mouse coordinates to plot area coordinates
+
+ :param float x: X position in pixels
+ :param float y: Y position in pixels
+ :return: (x, y) in data coordinates
+ """
+ plot = self.getPlotWidget()
+ left, top, width, height = plot.getPlotBoundsInPixels()
+ xPlot = numpy.clip(x, left, left + width - 1)
+ yPlot = numpy.clip(y, top, top + height - 1)
+ return xPlot, yPlot
+
+ def getPlotWidget(self):
+ """Returns the :class:`~silx.gui.plot.PlotWidget` this window is based on.
+
+ :rtype: ~silx.gui.plot.PlotWidget
+ """
+ return self._plot()
+
+ def getPositionInfoWidget(self):
+ """Returns the widget display mouse coordinates information.
+
+ :rtype: ~silx.gui.plot.tools.PositionInfo
+ """
+ return self._positionInfo
+
+ def getMaskToolsWidget(self):
+ """Returns the widget controlling mask drawing
+
+ :rtype: ~silx.gui.plot.ScatterMaskToolsWidget
+ """
+ return self._maskToolsWidget
+
+ def getInteractiveModeToolBar(self):
+ """Returns QToolBar controlling interactive mode.
+
+ :rtype: ~silx.gui.plot.tools.InteractiveModeToolBar
+ """
+ return self._interactiveModeToolBar
+
+ def getScatterToolBar(self):
+ """Returns QToolBar providing scatter plot tools.
+
+ :rtype: ~silx.gui.plot.tools.ScatterToolBar
+ """
+ return self._scatterToolBar
+
+ def getScatterProfileToolBar(self):
+ """Returns QToolBar providing scatter profile tools.
+
+ :rtype: ~silx.gui.plot.tools.profile.ScatterProfileToolBar
+ """
+ return self._profileToolBar
+
+ def getOutputToolBar(self):
+ """Returns QToolBar containing save, copy and print actions
+
+ :rtype: ~silx.gui.plot.tools.OutputToolBar
+ """
+ return self._outputToolBar
+
+ def setColormap(self, colormap=None):
+ """Set the colormap for the displayed scatter and the
+ default plot colormap.
+
+ :param ~silx.gui.colors.Colormap colormap:
+ The description of the colormap.
+ """
+ self.getScatterItem().setColormap(colormap)
+ # Resilient to call to PlotWidget API (e.g., clear)
+ self.getPlotWidget().setDefaultColormap(colormap)
+
+ def getColormap(self):
+ """Return the colormap object in use.
+
+ :return: Colormap currently in use
+ :rtype: ~silx.gui.colors.Colormap
+ """
+ return self.getScatterItem().getColormap()
+
+ # Control displayed scatter plot
+
+ def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
+ """Set the data of the scatter plot.
+
+ To reset the scatter plot, set x, y and value to None.
+
+ :param Union[numpy.ndarray,None] x: X coordinates.
+ :param Union[numpy.ndarray,None] y: Y coordinates.
+ :param Union[numpy.ndarray,None] value:
+ The data corresponding to the value of the data points.
+ :param xerror: Values with the uncertainties on the x values.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :type xerror: A float, or a numpy.ndarray of float32.
+
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param alpha: Values with the transparency (between 0 and 1)
+ :type alpha: A float, or a numpy.ndarray of float32
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ x = () if x is None else x
+ y = () if y is None else y
+ value = () if value is None else value
+
+ self.getScatterItem().setData(
+ x=x, y=y, value=value, xerror=xerror, yerror=yerror, alpha=alpha, copy=copy)
+
+ @docstring(items.Scatter)
+ def getData(self, *args, **kwargs):
+ return self.getScatterItem().getData(*args, **kwargs)
+
+ def getScatterItem(self):
+ """Returns the plot item displaying the scatter data.
+
+ This allows to set the style of the displayed scatter.
+
+ :rtype: ~silx.gui.plot.items.Scatter
+ """
+ plot = self.getPlotWidget()
+ scatter = plot._getItem(kind='scatter', legend=self._SCATTER_LEGEND)
+ if scatter is None: # Resilient to call to PlotWidget API (e.g., clear)
+ scatter = self.__createEmptyScatter()
+ return scatter
+
+ # Convenient proxies
+
+ @docstring(PlotWidget)
+ def getXAxis(self, *args, **kwargs):
+ return self.getPlotWidget().getXAxis(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def getYAxis(self, *args, **kwargs):
+ return self.getPlotWidget().getYAxis(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def setGraphTitle(self, *args, **kwargs):
+ return self.getPlotWidget().setGraphTitle(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def getGraphTitle(self, *args, **kwargs):
+ return self.getPlotWidget().getGraphTitle(*args, **kwargs)
+
+ @docstring(PlotWidget)
+ def resetZoom(self, *args, **kwargs):
+ return self.getPlotWidget().resetZoom(*args, **kwargs)
+
+ @docstring(ScatterMaskToolsWidget)
+ def getSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().getSelectionMask(*args, **kwargs)
+
+ @docstring(ScatterMaskToolsWidget)
+ def setSelectionMask(self, *args, **kwargs):
+ return self.getMaskToolsWidget().setSelectionMask(*args, **kwargs)
diff --git a/src/silx/gui/plot/StackView.py b/src/silx/gui/plot/StackView.py
new file mode 100644
index 0000000..56793d7
--- /dev/null
+++ b/src/silx/gui/plot/StackView.py
@@ -0,0 +1,1254 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""QWidget displaying a 3D volume as a stack of 2D images.
+
+The :class:`StackView` class implements this widget.
+
+Basic usage of :class:`StackView` is through the following methods:
+
+- :meth:`StackView.getColormap`, :meth:`StackView.setColormap` to update the
+ default colormap to use and update the currently displayed image.
+- :meth:`StackView.setStack` to update the displayed image.
+
+The :class:`StackView` uses :class:`PlotWindow` and also
+exposes a subset of the :class:`silx.gui.plot.Plot` API for further control
+(plot title, axes labels, ...).
+
+The :class:`StackViewMainWindow` class implements a widget that adds a status
+bar displaying the 3D index and the value under the mouse cursor.
+
+Example::
+
+ import numpy
+ import sys
+ from silx.gui import qt
+ from silx.gui.plot.StackView import StackViewMainWindow
+
+
+ app = qt.QApplication(sys.argv[1:])
+
+ # synthetic data, stack of 100 images of size 200x300
+ mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (100, 200, 300)
+ )
+
+
+ sv = StackViewMainWindow()
+ sv.setColormap("jet", autoscale=True)
+ sv.setStack(mystack)
+ sv.setLabels(["1st dim (0-99)", "2nd dim (0-199)",
+ "3rd dim (0-299)"])
+ sv.show()
+
+ app.exec()
+
+"""
+
+__authors__ = ["P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "10/10/2018"
+
+import numpy
+import logging
+
+import silx
+from silx.gui import qt
+from .. import icons
+from . import items, PlotWindow, actions
+from .items.image import ImageStack
+from ..colors import Colormap
+from ..colors import cursorColorForColormap
+from .tools import LimitsToolBar
+from .Profile import Profile3DToolBar
+from ..widgets.FrameBrowser import HorizontalSliderWithBrowser
+
+from silx.gui.plot.actions import control as actions_control
+from silx.gui.plot.actions import io as silx_io
+from silx.io.nxdata import save_NXdata
+from silx.utils.array_like import DatasetView, ListOfImages
+from silx.math import calibration
+from silx.utils.deprecation import deprecated_warning
+from silx.utils.deprecation import deprecated
+
+import h5py
+from silx.io.utils import is_dataset
+
+_logger = logging.getLogger(__name__)
+
+
+class StackView(qt.QMainWindow):
+ """Stack view widget, to display and browse through stack of
+ images.
+
+ The profile tool can be switched to "3D" mode, to compute the profile
+ on each image of the stack (not only the active image currently displayed)
+ and display the result as a slice.
+
+ :param QWidget parent: the Qt parent, or None
+ :param backend: The backend to use for the plot (default: matplotlib).
+ See :class:`.PlotWidget` for the list of supported backend.
+ :type backend: str or :class:`BackendBase.BackendBase`
+ :param bool resetzoom: Toggle visibility of reset zoom action.
+ :param bool autoScale: Toggle visibility of axes autoscale actions.
+ :param bool logScale: Toggle visibility of axes log scale actions.
+ :param bool grid: Toggle visibility of grid mode action.
+ :param bool colormap: Toggle visibility of colormap action.
+ :param bool aspectRatio: Toggle visibility of aspect ratio button.
+ :param bool yInverted: Toggle visibility of Y axis direction button.
+ :param bool copy: Toggle visibility of copy action.
+ :param bool save: Toggle visibility of save action.
+ :param bool print_: Toggle visibility of print action.
+ :param bool control: True to display an Options button with a sub-menu
+ to show legends, toggle crosshair and pan with arrows.
+ (Default: False)
+ :param position: True to display widget with (x, y) mouse position
+ (Default: False).
+ It also supports a list of (name, funct(x, y)->value)
+ to customize the displayed values.
+ See :class:`silx.gui.plot.PlotTools.PositionInfo`.
+ :param bool mask: Toggle visibilty of mask action.
+ """
+ # Qt signals
+ valueChanged = qt.Signal(object, object, object)
+ """Signals that the data value under the cursor has changed.
+
+ It provides: row, column, data value.
+ """
+
+ sigPlaneSelectionChanged = qt.Signal(int)
+ """Signal emitted when there is a change is perspective/displayed axes.
+
+ It provides the perspective as an integer, with the following meaning:
+
+ - 0: axis Y is the 2nd dimension, axis X is the 3rd dimension
+ - 1: axis Y is the 1st dimension, axis X is the 3rd dimension
+ - 2: axis Y is the 1st dimension, axis X is the 2nd dimension
+ """
+
+ sigStackChanged = qt.Signal(int)
+ """Signal emitted when the stack is changed.
+ This happens when a new volume is loaded, or when the current volume
+ is transposed (change in perspective).
+
+ The signal provides the size (number of pixels) of the stack.
+ This will be 0 if the stack is cleared, else it will be a positive
+ integer.
+ """
+
+ sigFrameChanged = qt.Signal(int)
+ """Signal emitter when the frame number has changed.
+
+ This signal provides the current frame number.
+ """
+
+ IMAGE_STACK_FILTER_NXDATA = 'Stack of images as NXdata (%s)' % silx_io._NEXUS_HDF5_EXT_STR
+
+
+ def __init__(self, parent=None, resetzoom=True, backend=None,
+ autoScale=False, logScale=False, grid=False,
+ colormap=True, aspectRatio=True, yinverted=True,
+ copy=True, save=True, print_=True, control=False,
+ position=None, mask=True):
+ qt.QMainWindow.__init__(self, parent)
+ if parent is not None:
+ # behave as a widget
+ self.setWindowFlags(qt.Qt.Widget)
+ else:
+ self.setWindowTitle('StackView')
+
+ self._stack = None
+ """Loaded stack, as a 3D array, a 3D dataset or a list of 2D arrays."""
+ self.__transposed_view = None
+ """View on :attr:`_stack` with the axes sorted, to have
+ the orthogonal dimension first"""
+ self._perspective = 0
+ """Orthogonal dimension (depth) in :attr:`_stack`"""
+
+ self._stackItem = ImageStack()
+ """Hold the item displaying the stack"""
+ imageLegend = '__StackView__image' + str(id(self))
+ self._stackItem.setName(imageLegend)
+
+ self.__autoscaleCmap = False
+ """Flag to disable/enable colormap auto-scaling
+ based on the min/max values of the entire 3D volume"""
+ self.__dimensionsLabels = ["Dimension 0", "Dimension 1",
+ "Dimension 2"]
+ """These labels are displayed on the X and Y axes.
+ :meth:`setLabels` updates this attribute."""
+
+ self._first_stack_dimension = 0
+ """Used for dimension labels and combobox"""
+
+ self._titleCallback = self._defaultTitleCallback
+ """Function returning the plot title based on the frame index.
+ It can be set to a custom function using :meth:`setTitleCallback`"""
+
+ self.calibrations3D = (calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ calibration.NoCalibration())
+
+ central_widget = qt.QWidget(self)
+
+ self._plot = PlotWindow(parent=central_widget, backend=backend,
+ resetzoom=resetzoom, autoScale=autoScale,
+ logScale=logScale, grid=grid,
+ curveStyle=False, colormap=colormap,
+ aspectRatio=aspectRatio, yInverted=yinverted,
+ copy=copy, save=save, print_=print_,
+ control=control, position=position,
+ roi=False, mask=mask)
+ self._plot.addItem(self._stackItem)
+ self._plot.getIntensityHistogramAction().setVisible(True)
+ self.sigInteractiveModeChanged = self._plot.sigInteractiveModeChanged
+ self.sigActiveImageChanged = self._plot.sigActiveImageChanged
+ self.sigPlotSignal = self._plot.sigPlotSignal
+
+ if silx.config.DEFAULT_PLOT_IMAGE_Y_AXIS_ORIENTATION == 'downward':
+ self._plot.getYAxis().setInverted(True)
+
+ self._addColorBarAction()
+
+ self._profileToolBar = Profile3DToolBar(parent=self._plot,
+ stackview=self)
+ self._plot.addToolBar(self._profileToolBar)
+ self._plot.getXAxis().setLabel('Columns')
+ self._plot.getYAxis().setLabel('Rows')
+ self._plot.sigPlotSignal.connect(self._plotCallback)
+ self._plot.getSaveAction().setFileFilter('image', self.IMAGE_STACK_FILTER_NXDATA, func=self._saveImageStack, appendToFile=True)
+
+ self.__planeSelection = PlanesWidget(self._plot)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
+
+ self._browser_label = qt.QLabel("Image index (Dim0):")
+
+ self._browser = HorizontalSliderWithBrowser(central_widget)
+ self._browser.setRange(0, 0)
+ self._browser.valueChanged[int].connect(self.__updateFrameNumber)
+ self._browser.setEnabled(False)
+
+ layout = qt.QGridLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._plot, 0, 0, 1, 3)
+ layout.addWidget(self.__planeSelection, 1, 0)
+ layout.addWidget(self._browser_label, 1, 1)
+ layout.addWidget(self._browser, 1, 2)
+
+ central_widget.setLayout(layout)
+ self.setCentralWidget(central_widget)
+
+ # clear profile lines when the perspective changes (plane browsed changed)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(
+ self._profileToolBar.clearProfile)
+
+ def _saveImageStack(self, plot, filename, nameFilter):
+ """Save all images from the stack into a volume.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ :raises: ValueError if nameFilter is invalid
+ """
+ if not nameFilter == self.IMAGE_STACK_FILTER_NXDATA:
+ raise ValueError('Wrong callback')
+ entryPath = silx_io.SaveAction._selectWriteableOutputGroup(filename, parent=self)
+ if entryPath is None:
+ return False
+ return save_NXdata(filename,
+ nxentry_name=entryPath,
+ signal=self.getStack(copy=False, returnNumpyArray=True)[0],
+ signal_name="image_stack")
+
+ def _addColorBarAction(self):
+ self._plot.getColorBarWidget().setVisible(True)
+ actions = self._plot.toolBar().actions()
+ for index, action in enumerate(actions):
+ if action is self._plot.getColormapAction():
+ break
+ self._colorbarAction = actions_control.ColorBarAction(self._plot, self._plot)
+ self._plot.toolBar().insertAction(actions[index + 1], self._colorbarAction)
+
+ def _plotCallback(self, eventDict):
+ """Callback for plot events.
+
+ Emit :attr:`valueChanged` signal, with (x, y, value) tuple of the
+ cursor location in the plot."""
+ if eventDict['event'] == 'mouseMoved':
+ activeImage = self.getActiveImage()
+ if activeImage is not None:
+ data = activeImage.getData()
+ height, width = data.shape
+
+ # Get corresponding coordinate in image
+ origin = activeImage.getOrigin()
+ scale = activeImage.getScale()
+ x = int((eventDict['x'] - origin[0]) / scale[0])
+ y = int((eventDict['y'] - origin[1]) / scale[1])
+
+ if 0 <= x < width and 0 <= y < height:
+ self.valueChanged.emit(float(x), float(y),
+ data[y][x])
+ else:
+ self.valueChanged.emit(float(x), float(y),
+ None)
+
+ def getPerspective(self):
+ """Returns the index of the dimension the stack is browsed with
+
+ Possible values are: 0, 1, or 2.
+
+ :rtype: int
+ """
+ return self._perspective
+
+ def setPerspective(self, perspective):
+ """Set the index of the dimension the stack is browsed with:
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+
+ :param int perspective: Orthogonal dimension number (0, 1, or 2)
+ """
+ if perspective == self._perspective:
+ return
+ else:
+ if perspective > 2 or perspective < 0:
+ raise ValueError(
+ "Perspective must be 0, 1 or 2, not %s" % perspective)
+
+ self._perspective = int(perspective)
+ self.__createTransposedView()
+ self.__updateFrameNumber(self._browser.value())
+ self._plot.resetZoom()
+ self.__updatePlotLabels()
+ self._updateTitle()
+ self._browser_label.setText("Image index (Dim%d):" %
+ (self._first_stack_dimension + perspective))
+
+ self.sigPlaneSelectionChanged.emit(perspective)
+ self.sigStackChanged.emit(self._stack.size if
+ self._stack is not None else 0)
+ self.__planeSelection.sigPlaneSelectionChanged.disconnect(self.setPerspective)
+ self.__planeSelection.setPerspective(self._perspective)
+ self.__planeSelection.sigPlaneSelectionChanged.connect(self.setPerspective)
+
+ def __updatePlotLabels(self):
+ """Update plot axes labels depending on perspective"""
+ y, x = (1, 2) if self._perspective == 0 else \
+ (0, 2) if self._perspective == 1 else (0, 1)
+ self.setGraphXLabel(self.__dimensionsLabels[x])
+ self.setGraphYLabel(self.__dimensionsLabels[y])
+
+ def __createTransposedView(self):
+ """Create the new view on the stack depending on the perspective
+ (set orthogonal axis browsed on the viewer as first dimension)
+ """
+ assert self._stack is not None
+ assert 0 <= self._perspective < 3
+
+ # ensure we have the stack encapsulated in an array-like object
+ # having a transpose() method
+ if isinstance(self._stack, numpy.ndarray):
+ self.__transposed_view = self._stack
+
+ elif is_dataset(self._stack) or isinstance(self._stack, DatasetView):
+ self.__transposed_view = DatasetView(self._stack)
+
+ elif isinstance(self._stack, ListOfImages):
+ self.__transposed_view = ListOfImages(self._stack)
+
+ # transpose the array-like object if necessary
+ if self._perspective == 1:
+ self.__transposed_view = self.__transposed_view.transpose((1, 0, 2))
+ elif self._perspective == 2:
+ self.__transposed_view = self.__transposed_view.transpose((2, 0, 1))
+
+ self._browser.setRange(0, self.__transposed_view.shape[0] - 1)
+ self._browser.setValue(0)
+
+ # Update the item structure
+ self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
+ self._stackItem.setColormap(self.getColormap())
+ self._stackItem.setOrigin(self._getImageOrigin())
+ self._stackItem.setScale(self._getImageScale())
+
+ def __updateFrameNumber(self, index):
+ """Update the current image.
+
+ :param index: index of the frame to be displayed
+ """
+ if self.__transposed_view is None:
+ # no data set
+ return
+
+ self._stackItem.setStackPosition(index)
+
+ self._updateTitle()
+ self.sigFrameChanged.emit(index)
+
+ def _set3DScaleAndOrigin(self, calibrations):
+ """Set scale and origin for all 3 axes, to be used when plotting
+ an image.
+
+ See setStack for parameter documentation
+ """
+ if calibrations is None:
+ self.calibrations3D = (calibration.NoCalibration(),
+ calibration.NoCalibration(),
+ calibration.NoCalibration())
+ else:
+ self.calibrations3D = []
+ for i, calib in enumerate(calibrations):
+ if hasattr(calib, "__len__") and len(calib) == 2:
+ calib = calibration.LinearCalibration(calib[0], calib[1])
+ elif calib is None:
+ calib = calibration.NoCalibration()
+ elif not isinstance(calib, calibration.AbstractCalibration):
+ raise TypeError("calibration must be a 2-tuple, None or" +
+ " an instance of an AbstractCalibration " +
+ "subclass")
+ elif not calib.is_affine():
+ _logger.warning(
+ "Calibration for dimension %d is not linear, "
+ "it will be ignored for scaling the graph axes.",
+ i)
+ self.calibrations3D.append(calib)
+
+ def getCalibrations(self, order='array'):
+ """Returns currently used calibrations for each axis
+
+ Returned calibrations might differ from the ones that were set as
+ non-linear calibrations used for image axes are temporarily ignored.
+
+ :param str order:
+ 'array' to sort calibrations as data array (dim0, dim1, dim2),
+ 'axes' to sort calibrations as currently selected x, y and z axes.
+ :return: Calibrations ordered depending on order
+ :rtype: List[~silx.math.calibration.AbstractCalibration]
+ """
+ assert order in ('array', 'axes')
+ calibs = []
+
+ # filter out non-linear calibration for graph axes
+ for index, calib in enumerate(self.calibrations3D):
+ if index != self._perspective and not calib.is_affine():
+ calib = calibration.NoCalibration()
+ calibs.append(calib)
+
+ if order == 'axes': # Move 'z' axis to the end
+ xy_dims = [d for d in (0, 1, 2) if d != self._perspective]
+ calibs = [calibs[max(xy_dims)],
+ calibs[min(xy_dims)],
+ calibs[self._perspective]]
+
+ return tuple(calibs)
+
+ def _getImageScale(self):
+ """
+ :return: 2-tuple (XScale, YScale) for current image view
+ """
+ xcalib, ycalib, _zcalib = self.getCalibrations(order='axes')
+ return xcalib.get_slope(), ycalib.get_slope()
+
+ def _getImageOrigin(self):
+ """
+ :return: 2-tuple (XOrigin, YOrigin) for current image view
+ """
+ xcalib, ycalib, _zcalib = self.getCalibrations(order='axes')
+ return xcalib(0), ycalib(0)
+
+ def _getImageZ(self, index):
+ """
+ :param idx: 0-based image index in the stack
+ :return: calibrated Z value corresponding to the image idx
+ """
+ _xcalib, _ycalib, zcalib = self.getCalibrations(order='axes')
+ return zcalib(index)
+
+ def _updateTitle(self):
+ frame_idx = self._browser.value()
+ self._plot.setGraphTitle(self._titleCallback(frame_idx))
+
+ def _defaultTitleCallback(self, index):
+ return "Image z=%g" % self._getImageZ(index)
+
+ # public API, stack specific methods
+ def setStack(self, stack, perspective=None, reset=True, calibrations=None):
+ """Set the 3D stack.
+
+ The perspective parameter is used to define which dimension of the 3D
+ array is to be used as frame index. The lowest remaining dimension
+ number is the row index of the displayed image (Y axis), and the highest
+ remaining dimension is the column index (X axis).
+
+ :param stack: 3D stack, or `None` to clear plot.
+ :type stack: 3D numpy.ndarray, or 3D h5py.Dataset, or list/tuple of 2D
+ numpy arrays, or None.
+ :param int perspective: Dimension for the frame index: 0, 1 or 2.
+ Use ``None`` to keep the current perspective (default).
+ :param bool reset: Whether to reset zoom or not.
+ :param calibrations: Sequence of 3 calibration objects for each axis.
+ These objects can be a subclass of :class:`AbstractCalibration`,
+ or 2-tuples *(a, b)* where *a* is the y-intercept and *b* is the
+ slope of a linear calibration (:math:`x \\mapsto a + b x`)
+ """
+ if stack is None:
+ self.clear()
+ self.sigStackChanged.emit(0)
+ return
+
+ self._set3DScaleAndOrigin(calibrations)
+
+ # stack as list of 2D arrays: must be converted into an array_like
+ if not isinstance(stack, numpy.ndarray):
+ if not is_dataset(stack):
+ try:
+ assert hasattr(stack, "__len__")
+ for img in stack:
+ assert hasattr(img, "shape")
+ assert len(img.shape) == 2
+ except AssertionError:
+ raise ValueError(
+ "Stack must be a 3D array/dataset or a list of " +
+ "2D arrays.")
+ stack = ListOfImages(stack)
+
+ assert len(stack.shape) == 3, "data must be 3D"
+
+ self._stack = stack
+ self.__createTransposedView()
+
+ perspective_changed = False
+ if perspective not in [None, self._perspective]:
+ perspective_changed = True
+ self.setPerspective(perspective)
+
+ if self.__autoscaleCmap:
+ self.scaleColormapRangeToStack()
+
+ # init plot
+ self._stackItem.setStackData(self.__transposed_view, 0, copy=False)
+ self._stackItem.setColormap(self.getColormap())
+ self._stackItem.setOrigin(self._getImageOrigin())
+ self._stackItem.setScale(self._getImageScale())
+ self._stackItem.setVisible(True)
+
+ # Put back the item in the plot in case it was cleared
+ exists = self._plot.getImage(self._stackItem.getName())
+ if exists is None:
+ self._plot.addItem(self._stackItem)
+
+ self._plot.setActiveImage(self._stackItem.getName())
+ self.__updatePlotLabels()
+ self._updateTitle()
+
+ if reset:
+ self._plot.resetZoom()
+
+ # enable and init browser
+ self._browser.setEnabled(True)
+
+ if not perspective_changed: # avoid double signal (see self.setPerspective)
+ self.sigStackChanged.emit(stack.size)
+
+ def getStack(self, copy=True, returnNumpyArray=False):
+ """Get the original stack, as a 3D array or dataset.
+
+ The output has the form: [data, params]
+ where params is a dictionary containing display parameters.
+
+ :param bool copy: If True (default), then the object is copied
+ and returned as a numpy array.
+ Else, a reference to original data is returned, if possible.
+ If the original data is not a numpy array and parameter
+ returnNumpyArray is True, a copy will be made anyway.
+ :param bool returnNumpyArray: If True, the returned object is
+ guaranteed to be a numpy array.
+ :return: 3D stack and parameters.
+ :rtype: (numpy.ndarray, dict)
+ """
+ if self._stack is None:
+ return None
+
+ image = self._stackItem
+ colormap = image.getColormap()
+
+ params = {
+ 'info': image.getInfo(),
+ 'origin': image.getOrigin(),
+ 'scale': image.getScale(),
+ 'z': image.getZValue(),
+ 'selectable': image.isSelectable(),
+ 'draggable': image.isDraggable(),
+ 'colormap': colormap,
+ 'xlabel': image.getXLabel(),
+ 'ylabel': image.getYLabel(),
+ }
+ if returnNumpyArray or copy:
+ return numpy.array(self._stack, copy=copy), params
+
+ # if a list of 2D arrays was cast into a ListOfImages,
+ # return the original list
+ if isinstance(self._stack, ListOfImages):
+ return self._stack.images, params
+
+ return self._stack, params
+
+ def getCurrentView(self, copy=True, returnNumpyArray=False):
+ """Get the stack, as it is currently displayed.
+
+ The first index of the returned stack is always the frame
+ index. If the perspective has been changed in the widget since the
+ data was first loaded, this will be reflected in the order of the
+ dimensions of the returned object.
+
+ The output has the form: [data, params]
+ where params is a dictionary containing display parameters.
+
+ :param bool copy: If True (default), then the object is copied
+ and returned as a numpy array.
+ Else, a reference to original data is returned, if possible.
+ If the original data is not a numpy array and parameter
+ `returnNumpyArray` is `True`, a copy will be made anyway.
+ :param bool returnNumpyArray: If `True`, the returned object is
+ guaranteed to be a numpy array.
+ :return: 3D stack and parameters.
+ :rtype: (numpy.ndarray, dict)
+ """
+ image = self.getActiveImage()
+ if image is None:
+ return None
+
+ if isinstance(image, items.ColormapMixIn):
+ colormap = image.getColormap()
+ else:
+ colormap = None
+
+ params = {
+ 'info': image.getInfo(),
+ 'origin': image.getOrigin(),
+ 'scale': image.getScale(),
+ 'z': image.getZValue(),
+ 'selectable': image.isSelectable(),
+ 'draggable': image.isDraggable(),
+ 'colormap': colormap,
+ 'xlabel': image.getXLabel(),
+ 'ylabel': image.getYLabel(),
+ }
+ if returnNumpyArray or copy:
+ return numpy.array(self.__transposed_view, copy=copy), params
+ return self.__transposed_view, params
+
+ def setFrameNumber(self, number):
+ """Set the frame selection to a specific value
+
+ :param int number: Number of the frame
+ """
+ self._browser.setValue(number)
+
+ def getFrameNumber(self):
+ """Set the frame selection to a specific value
+
+ :return: Index of currently displayed frame
+ :rtype: int
+ """
+ return self._browser.value()
+
+ def setFirstStackDimension(self, first_stack_dimension):
+ """When viewing the last 3 dimensions of an n-D array (n>3), you can
+ use this method to change the text in the combobox.
+
+ For instance, for a 7-D array, first stack dim is 4, so the default
+ "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
+ numbers are 0-based).
+
+ :param int first_stack_dim: First stack dimension (n-3) when viewing the
+ last 3 dimensions of an n-D array.
+ """
+ old_state = self.__planeSelection.blockSignals(True)
+ self.__planeSelection.setFirstStackDimension(first_stack_dimension)
+ self.__planeSelection.blockSignals(old_state)
+ self._first_stack_dimension = first_stack_dimension
+ self._browser_label.setText("Image index (Dim%d):" % first_stack_dimension)
+
+ def setTitleCallback(self, callback):
+ """Set a user defined function to generate the plot title based on the
+ image/frame index.
+
+ The callback function must accept an integer as a its first positional
+ parameter and must not require any other mandatory parameter.
+ It must return a string.
+
+ To switch back the default behavior, you can pass ``None``::
+
+ mystackview.setTitleCallback(None)
+
+ To have no title, pass a function that returns an empty string::
+
+ mystackview.setTitleCallback(lambda idx: "")
+
+ :param callback: Callback function generating the stack title based
+ on the frame number.
+ """
+
+ if callback is None:
+ self._titleCallback = self._defaultTitleCallback
+ elif callable(callback):
+ self._titleCallback = callback
+ else:
+ raise TypeError("Provided callback is not callable")
+ self._updateTitle()
+
+ def clear(self):
+ """Clear the widget:
+
+ - clear the plot
+ - clear the loaded data volume
+ """
+ self._stack = None
+ self.__transposed_view = None
+ self._perspective = 0
+ self._browser.setEnabled(False)
+ # reset browser range
+ self._browser.setRange(0, 0)
+ self._plot.clear()
+
+ def setLabels(self, labels=None):
+ """Set the labels to be displayed on the plot axes.
+
+ You must provide a sequence of 3 strings, corresponding to the 3
+ dimensions of the original data volume.
+ The proper label will automatically be selected for each plot axis
+ when the volume is rotated (when different axes are selected as the
+ X and Y axes).
+
+ :param List[str] labels: 3 labels corresponding to the 3 dimensions
+ of the data volumes.
+ """
+
+ default_labels = ["Dimension %d" % self._first_stack_dimension,
+ "Dimension %d" % (self._first_stack_dimension + 1),
+ "Dimension %d" % (self._first_stack_dimension + 2)]
+ if labels is None:
+ new_labels = default_labels
+ else:
+ # filter-out None
+ new_labels = []
+ for i, label in enumerate(labels):
+ new_labels.append(label or default_labels[i])
+
+ self.__dimensionsLabels = new_labels
+ self.__updatePlotLabels()
+
+ def getLabels(self):
+ """Return dimension labels displayed on the plot axes
+
+ :return: List of three strings corresponding to the 3 dimensions
+ of the stack: (name_dim0, name_dim1, name_dim2)
+ """
+ return self.__dimensionsLabels
+
+ def getColormap(self):
+ """Get the current colormap description.
+
+ :return: A description of the current colormap.
+ See :meth:`setColormap` for details.
+ :rtype: dict
+ """
+ # "default" colormap used by addImage when image is added without
+ # specifying a special colormap
+ return self._plot.getDefaultColormap()
+
+ def scaleColormapRangeToStack(self):
+ """Scale colormap range according to current stack data.
+
+ If no stack has been set through :meth:`setStack`, this has no effect.
+
+ The range scaling mode is given by current :class:`Colormap`'s
+ :meth:`Colormap.getAutoscaleMode`.
+ """
+ stack = self.getStack(copy=False, returnNumpyArray=True)
+ if stack is None:
+ return # No-op
+
+ colormap = self.getColormap()
+ vmin, vmax = colormap.getColormapRange(data=stack[0])
+ colormap.setVRange(vmin=vmin, vmax=vmax)
+
+ def setColormap(self, colormap=None, normalization=None,
+ autoscale=None, vmin=None, vmax=None, colors=None):
+ """Set the colormap and update active image.
+
+ Parameters that are not provided are taken from the current colormap.
+
+ The colormap parameter can also be a dict with the following keys:
+
+ - *name*: string. The colormap to use:
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ - *normalization*: string. The mapping to use for the colormap:
+ either 'linear' or 'log'.
+ - *autoscale*: bool. Whether to use autoscale (True) or range
+ provided by keys
+ 'vmin' and 'vmax' (False).
+ - *vmin*: float. The minimum value of the range to use if 'autoscale'
+ is False.
+ - *vmax*: float. The maximum value of the range to use if 'autoscale'
+ is False.
+ - *colors*: optional. Nx3 or Nx4 array of float in [0, 1] or uint8.
+ List of RGB or RGBA colors to use (only if name is None)
+
+ :param colormap: Name of the colormap in
+ 'gray', 'reversed gray', 'temperature', 'red', 'green', 'blue'.
+ Or a :class`.Colormap` object.
+ :type colormap: dict or str.
+ :param str normalization: Colormap mapping: 'linear' or 'log'.
+ :param bool autoscale: Whether to use autoscale or [vmin, vmax] range.
+ Default value of autoscale is False. This option is not compatible
+ with h5py datasets.
+ :param float vmin: The minimum value of the range to use if
+ 'autoscale' is False.
+ :param float vmax: The maximum value of the range to use if
+ 'autoscale' is False.
+ :param numpy.ndarray colors: Only used if name is None.
+ Custom colormap colors as Nx3 or Nx4 RGB or RGBA arrays
+ """
+ # if is a colormap object or a dictionary
+ if isinstance(colormap, Colormap) or isinstance(colormap, dict):
+ # Support colormap parameter as a dict
+ errmsg = "If colormap is provided as a Colormap object, all other parameters"
+ errmsg += " must not be specified when calling setColormap"
+ assert normalization is None, errmsg
+ assert autoscale is None, errmsg
+ assert vmin is None, errmsg
+ assert vmax is None, errmsg
+ assert colors is None, errmsg
+
+ if isinstance(colormap, dict):
+ reason = 'colormap parameter should now be an object'
+ replacement = 'Colormap()'
+ since_version = '0.6'
+ deprecated_warning(type_='function',
+ name='setColormap',
+ reason=reason,
+ replacement=replacement,
+ since_version=since_version)
+ _colormap = Colormap._fromDict(colormap)
+ else:
+ _colormap = colormap
+ else:
+ norm = normalization if normalization is not None else 'linear'
+ name = colormap if colormap is not None else 'gray'
+ _colormap = Colormap(name=name,
+ normalization=norm,
+ vmin=vmin,
+ vmax=vmax,
+ colors=colors)
+
+ if autoscale is not None:
+ deprecated_warning(
+ type_='function',
+ name='setColormap',
+ reason='autoscale argument is replaced by a method',
+ replacement='scaleColormapRangeToStack',
+ since_version='0.14')
+ self.__autoscaleCmap = bool(autoscale)
+
+ cursorColor = cursorColorForColormap(_colormap.getName())
+ self._plot.setInteractiveMode('zoom', color=cursorColor)
+
+ self._plot.setDefaultColormap(_colormap)
+
+ # Update active image colormap
+ activeImage = self.getActiveImage()
+ if isinstance(activeImage, items.ColormapMixIn):
+ activeImage.setColormap(self.getColormap())
+
+ if self.__autoscaleCmap:
+ # scaleColormapRangeToStack needs to be called **after**
+ # setDefaultColormap so getColormap returns the right colormap
+ self.scaleColormapRangeToStack()
+
+
+ @deprecated(replacement="getPlotWidget", since_version="0.13")
+ def getPlot(self):
+ return self.getPlotWidget()
+
+ def getPlotWidget(self):
+ """Return the :class:`PlotWidget`.
+
+ This gives access to advanced plot configuration options.
+ Be warned that modifying the plot can cause issues, and some changes
+ you make to the plot could be overwritten by the :class:`StackView`
+ widget's internal methods and callbacks.
+
+ :return: instance of :class:`PlotWidget` used in widget
+ """
+ return self._plot
+
+ def setOptionVisible(self, isVisible):
+ """
+ Set the visibility of the browsing options.
+
+ :param bool isVisible: True to have the options visible, else False
+ """
+ self._browser.setVisible(isVisible)
+ self.__planeSelection.setVisible(isVisible)
+
+ # proxies to PlotWidget or PlotWindow methods
+ def getProfileToolbar(self):
+ """Profile tools attached to this plot
+ """
+ return self._profileToolBar
+
+ def getGraphTitle(self):
+ """Return the plot main title as a str.
+ """
+ return self._plot.getGraphTitle()
+
+ def setGraphTitle(self, title=""):
+ """Set the plot main title.
+
+ :param str title: Main title of the plot (default: '')
+ """
+ return self._plot.setGraphTitle(title)
+
+ def getGraphXLabel(self):
+ """Return the current horizontal axis label as a str.
+ """
+ return self._plot.getXAxis().getLabel()
+
+ def setGraphXLabel(self, label=None):
+ """Set the plot horizontal axis label.
+
+ :param str label: The horizontal axis label
+ """
+ if label is None:
+ label = self.__dimensionsLabels[1 if self._perspective == 2 else 2]
+ self._plot.getXAxis().setLabel(label)
+
+ def getGraphYLabel(self, axis='left'):
+ """Return the current vertical axis label as a str.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ """
+ return self._plot.getYAxis().getLabel(axis)
+
+ def setGraphYLabel(self, label=None, axis='left'):
+ """Set the vertical axis label on the plot.
+
+ :param str label: The Y axis label
+ :param str axis: The Y axis for which to set the label (left or right)
+ """
+ if label is None:
+ label = self.__dimensionsLabels[1 if self._perspective == 0 else 0]
+ self._plot.getYAxis(axis=axis).setLabel(label)
+
+ def resetZoom(self):
+ """Reset the plot limits to the bounds of the data and redraw the plot.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().resetZoom()
+ """
+ self._plot.resetZoom()
+
+ def setYAxisInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().setYAxisInverted(flag)
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ self._plot.setYAxisInverted(flag)
+
+ def isYAxisInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().isYAxisInverted()"""
+ return self._plot.isYAxisInverted()
+
+ def getSupportedColormaps(self):
+ """Get the supported colormap names as a tuple of str.
+
+ The list should at least contain and start by:
+ ('gray', 'reversed gray', 'temperature', 'red', 'green', 'blue')
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().getSupportedColormaps()
+ """
+ return self._plot.getSupportedColormaps()
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().isKeepDataAspectRatio()"""
+ return self._plot.isKeepDataAspectRatio()
+
+ def setKeepDataAspectRatio(self, flag=True):
+ """Set whether the plot keeps data aspect ratio or not.
+
+ This method is a simple proxy to the legacy :class:`PlotWidget` method
+ of the same name. Using the object oriented approach is now
+ preferred::
+
+ stackview.getPlot().setKeepDataAspectRatio(flag)
+
+ :param bool flag: True to respect data aspect ratio
+ """
+ self._plot.setKeepDataAspectRatio(flag)
+
+ # kind of private methods, but needed by Profile
+ def getActiveImage(self, just_legend=False):
+ """Returns the stack image object.
+ """
+ if just_legend:
+ return self._stackItem.getName()
+ return self._stackItem
+
+ def getColorBarAction(self):
+ """Returns the action managing the visibility of the colorbar.
+
+ .. warning:: to show/hide the plot colorbar call directly the ColorBar
+ widget using getColorBarWidget()
+
+ :rtype: QAction
+ """
+ return self._colorbarAction
+
+ def remove(self, legend=None,
+ kind=('curve', 'image', 'item', 'marker')):
+ """See :meth:`Plot.Plot.remove`"""
+ self._plot.remove(legend, kind)
+
+ def setInteractiveMode(self, *args, **kwargs):
+ """
+ See :meth:`Plot.Plot.setInteractiveMode`
+ """
+ self._plot.setInteractiveMode(*args, **kwargs)
+
+ @deprecated(replacement="addShape", since_version="0.13")
+ def addItem(self, *args, **kwargs):
+ self.addShape(*args, **kwargs)
+
+ def addShape(self, *args, **kwargs):
+ """
+ See :meth:`Plot.Plot.addShape`
+ """
+ self._plot.addShape(*args, **kwargs)
+
+
+class PlanesWidget(qt.QWidget):
+ """Widget for the plane/perspective selection
+
+ :param parent: the parent QWidget
+ """
+ sigPlaneSelectionChanged = qt.Signal(int)
+
+ def __init__(self, parent):
+ super(PlanesWidget, self).__init__(parent)
+
+ self.setSizePolicy(qt.QSizePolicy.Minimum, qt.QSizePolicy.Minimum)
+ layout0 = qt.QHBoxLayout()
+ self.setLayout(layout0)
+ layout0.setContentsMargins(0, 0, 0, 0)
+
+ layout0.addWidget(qt.QLabel("Axes selection:"))
+
+ # By default, the first dimension (dim0) is the frame index/depth/z,
+ # the second dimension is the image row number/y axis
+ # and the third dimension is the image column index/x axis
+
+ # 1
+ # | 0
+ # |/__2
+ self.qcbAxisSelection = qt.QComboBox(self)
+ self._setCBChoices(first_stack_dimension=0)
+ self.qcbAxisSelection.currentIndexChanged[int].connect(
+ self.__planeSelectionChanged)
+
+ layout0.addWidget(self.qcbAxisSelection)
+
+ def __planeSelectionChanged(self, idx):
+ """Callback function when the combobox selection changes
+
+ idx is the dimension number orthogonal to the slice plane,
+ following the convention:
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+ """
+ self.sigPlaneSelectionChanged.emit(idx)
+
+ def _setCBChoices(self, first_stack_dimension):
+ self.qcbAxisSelection.clear()
+
+ dim1dim2 = 'Dim%d-Dim%d' % (first_stack_dimension + 1,
+ first_stack_dimension + 2)
+ dim0dim2 = 'Dim%d-Dim%d' % (first_stack_dimension,
+ first_stack_dimension + 2)
+ dim0dim1 = 'Dim%d-Dim%d' % (first_stack_dimension,
+ first_stack_dimension + 1)
+
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-front"), dim1dim2)
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-bottom"), dim0dim2)
+ self.qcbAxisSelection.addItem(icons.getQIcon("cube-left"), dim0dim1)
+
+ def setFirstStackDimension(self, first_stack_dim):
+ """When viewing the last 3 dimensions of an n-D array (n>3), you can
+ use this method to change the text in the combobox.
+
+ For instance, for a 7-D array, first stack dim is 4, so the default
+ "Dim1-Dim2" text should be replaced with "Dim5-Dim6" (dimensions
+ numbers are 0-based).
+
+ :param int first_stack_dim: First stack dimension (n-3) when viewing the
+ last 3 dimensions of an n-D array.
+ """
+ self._setCBChoices(first_stack_dim)
+
+ def setPerspective(self, perspective):
+ """Update the combobox selection.
+
+ - slice plane Dim1-Dim2: perspective 0
+ - slice plane Dim0-Dim2: perspective 1
+ - slice plane Dim0-Dim1: perspective 2
+
+ :param perspective: Orthogonal dimension number (0, 1, or 2)
+ """
+ self.qcbAxisSelection.setCurrentIndex(perspective)
+
+
+class StackViewMainWindow(StackView):
+ """This class is a :class:`StackView` with a menu, an additional toolbar
+ to set the plot limits, and a status bar to display the value and 3D
+ index of the data samples hovered by the mouse cursor.
+
+ :param QWidget parent: Parent widget, or None
+ """
+ def __init__(self, parent=None):
+ self._dataInfo = None
+ super(StackViewMainWindow, self).__init__(parent)
+ self.setWindowFlags(qt.Qt.Window)
+
+ # Add toolbars and status bar
+ self.addToolBar(qt.Qt.BottomToolBarArea,
+ LimitsToolBar(plot=self._plot))
+
+ self.statusBar()
+
+ menu = self.menuBar().addMenu('File')
+ menu.addAction(self._plot.getOutputToolBar().getSaveAction())
+ menu.addAction(self._plot.getOutputToolBar().getPrintAction())
+ menu.addSeparator()
+ action = menu.addAction('Quit')
+ action.triggered[bool].connect(qt.QApplication.instance().quit)
+
+ menu = self.menuBar().addMenu('Edit')
+ menu.addAction(self._plot.getOutputToolBar().getCopyAction())
+ menu.addSeparator()
+ menu.addAction(self._plot.getResetZoomAction())
+ menu.addAction(self._plot.getColormapAction())
+ menu.addAction(self.getColorBarAction())
+
+ menu.addAction(actions.control.KeepAspectRatioAction(self._plot, self))
+ menu.addAction(actions.control.YAxisInvertedAction(self._plot, self))
+
+ menu = self.menuBar().addMenu('Profile')
+ profileToolBar = self._profileToolBar
+ menu.addAction(profileToolBar.hLineAction)
+ menu.addAction(profileToolBar.vLineAction)
+ menu.addAction(profileToolBar.lineAction)
+ menu.addAction(profileToolBar.crossAction)
+ menu.addSeparator()
+ menu.addAction(profileToolBar._editor)
+ menu.addSeparator()
+ menu.addAction(profileToolBar.clearAction)
+
+ # Connect to StackView's signal
+ self.valueChanged.connect(self._statusBarSlot)
+
+ def _statusBarSlot(self, x, y, value):
+ """Update status bar with coordinates/value from plots."""
+ # todo (after implementing calibration):
+ # - use floats for (x, y, z)
+ # - display both indices (dim0, dim1, dim2) and (x, y, z)
+ msg = "Cursor out of range"
+ if x is not None and y is not None:
+ img_idx = self._browser.value()
+
+ if self._perspective == 0:
+ dim0, dim1, dim2 = img_idx, int(y), int(x)
+ elif self._perspective == 1:
+ dim0, dim1, dim2 = int(y), img_idx, int(x)
+ elif self._perspective == 2:
+ dim0, dim1, dim2 = int(y), int(x), img_idx
+
+ msg = 'Position: (%d, %d, %d)' % (dim0, dim1, dim2)
+ if value is not None:
+ msg += ', Value: %g' % value
+ if self._dataInfo is not None:
+ msg = self._dataInfo + ', ' + msg
+
+ self.statusBar().showMessage(msg)
+
+ def setStack(self, stack, *args, **kwargs):
+ """Set the displayed stack.
+
+ See :meth:`StackView.setStack` for details.
+ """
+ if hasattr(stack, 'dtype') and hasattr(stack, 'shape'):
+ assert len(stack.shape) == 3
+ nframes, height, width = stack.shape
+ self._dataInfo = 'Data: %dx%dx%d (%s)' % (nframes, height, width,
+ str(stack.dtype))
+ self.statusBar().showMessage(self._dataInfo)
+ else:
+ self._dataInfo = None
+
+ # Set the new stack in StackView widget
+ super(StackViewMainWindow, self).setStack(stack, *args, **kwargs)
+ self.setStatusBar(None)
diff --git a/src/silx/gui/plot/StatsWidget.py b/src/silx/gui/plot/StatsWidget.py
new file mode 100644
index 0000000..00f78d0
--- /dev/null
+++ b/src/silx/gui/plot/StatsWidget.py
@@ -0,0 +1,1658 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+Module containing widgets displaying stats from items of a plot.
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "24/07/2018"
+
+
+from collections import OrderedDict
+from contextlib import contextmanager
+import logging
+import weakref
+import functools
+import numpy
+import enum
+from silx.utils.proxy import docstring
+from silx.utils.enum import Enum as _Enum
+from silx.gui import qt
+from silx.gui import icons
+from silx.gui.plot import stats as statsmdl
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.gui.plot.stats.statshandler import StatsHandler, StatFormatter
+from silx.gui.plot.items.core import ItemChangedType
+from silx.gui.widgets.FlowLayout import FlowLayout
+from . import PlotWidget
+from . import items as plotitems
+
+
+_logger = logging.getLogger(__name__)
+
+
+@enum.unique
+class UpdateMode(_Enum):
+ AUTO = 'auto'
+ MANUAL = 'manual'
+
+
+# Helper class to handle specific calls to PlotWidget and SceneWidget
+
+
+class _Wrapper(qt.QObject):
+ """Base class for connection with PlotWidget and SceneWidget.
+
+ This class is used when no PlotWidget or SceneWidget is connected.
+
+ :param plot: The plot to be used
+ """
+
+ sigItemAdded = qt.Signal(object)
+ """Signal emitted when a new item is added.
+
+ It provides the added item.
+ """
+
+ sigItemRemoved = qt.Signal(object)
+ """Signal emitted when an item is (about to be) removed.
+
+ It provides the removed item.
+ """
+
+ sigCurrentChanged = qt.Signal(object)
+ """Signal emitted when the current item has changed.
+
+ It provides the current item.
+ """
+
+ sigVisibleDataChanged = qt.Signal()
+ """Signal emitted when the visible data area has changed"""
+
+ def __init__(self, plot=None):
+ super(_Wrapper, self).__init__(parent=None)
+ self._plotRef = None if plot is None else weakref.ref(plot)
+
+ def getPlot(self):
+ """Returns the plot attached to this widget"""
+ return None if self._plotRef is None else self._plotRef()
+
+ def getItems(self):
+ """Returns the list of items in the plot
+
+ :rtype: List[object]
+ """
+ return ()
+
+ def getSelectedItems(self):
+ """Returns the list of selected items in the plot
+
+ :rtype: List[object]
+ """
+ return ()
+
+ def setCurrentItem(self, item):
+ """Set the current/active item in the plot
+
+ :param item: The plot item to set as active/current
+ """
+ pass
+
+ def getLabel(self, item):
+ """Returns the label of the given item.
+
+ :param item:
+ :rtype: str
+ """
+ return ''
+
+ def getKind(self, item):
+ """Returns the kind of an item or None if not supported
+
+ :param item:
+ :rtype: Union[str,None]
+ """
+ return None
+
+
+class _PlotWidgetWrapper(_Wrapper):
+ """Class handling PlotWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param PlotWidget plot:
+ """
+
+ def __init__(self, plot):
+ assert isinstance(plot, PlotWidget)
+ super(_PlotWidgetWrapper, self).__init__(plot)
+ plot.sigItemAdded.connect(self.sigItemAdded.emit)
+ plot.sigItemAboutToBeRemoved.connect(self.sigItemRemoved.emit)
+ plot.sigActiveCurveChanged.connect(self._activeCurveChanged)
+ plot.sigActiveImageChanged.connect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self._activeScatterChanged)
+ plot.sigPlotSignal.connect(self._limitsChanged)
+
+ def _activeChanged(self, kind):
+ """Handle change of active curve/image/scatter"""
+ plot = self.getPlot()
+ if plot is not None:
+ item = plot._getActiveItem(kind=kind)
+ if item is None or self.getKind(item) is not None:
+ self.sigCurrentChanged.emit(item)
+
+ def _activeCurveChanged(self, previous, current):
+ self._activeChanged(kind='curve')
+
+ def _activeImageChanged(self, previous, current):
+ self._activeChanged(kind='image')
+
+ def _activeScatterChanged(self, previous, current):
+ self._activeChanged(kind='scatter')
+
+ def _limitsChanged(self, event):
+ """Handle change of plot area limits."""
+ if event['event'] == 'limitsChanged':
+ self.sigVisibleDataChanged.emit()
+
+ def getItems(self):
+ plot = self.getPlot()
+ if plot is None:
+ return ()
+ else:
+ return [item for item in plot.getItems() if item.isVisible()]
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ items = []
+ if plot is not None:
+ for kind in plot._ACTIVE_ITEM_KINDS:
+ item = plot._getActiveItem(kind=kind)
+ if item is not None:
+ items.append(item)
+ return tuple(items)
+
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ kind = self.getKind(item)
+ if kind in plot._ACTIVE_ITEM_KINDS:
+ if plot._getActiveItem(kind) != item:
+ plot._setActiveItem(kind, item.getName())
+
+ def getLabel(self, item):
+ return item.getName()
+
+ def getKind(self, item):
+ if isinstance(item, plotitems.Curve):
+ return 'curve'
+ elif isinstance(item, plotitems.ImageData):
+ return 'image'
+ elif isinstance(item, plotitems.Scatter):
+ return 'scatter'
+ elif isinstance(item, plotitems.Histogram):
+ return 'histogram'
+ else:
+ return None
+
+
+class _SceneWidgetWrapper(_Wrapper):
+ """Class handling SceneWidget specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
+ """
+
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.SceneWidget import SceneWidget
+
+ assert isinstance(plot, SceneWidget)
+ super(_SceneWidgetWrapper, self).__init__(plot)
+ plot.getSceneGroup().sigItemAdded.connect(self.sigItemAdded)
+ plot.getSceneGroup().sigItemRemoved.connect(self.sigItemRemoved)
+ plot.selection().sigCurrentChanged.connect(self._currentChanged)
+ # sigVisibleDataChanged is never emitted
+
+ def _currentChanged(self, current, previous):
+ self.sigCurrentChanged.emit(current)
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else tuple(plot.getSceneGroup().visit())
+
+ def getSelectedItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (plot.selection().getCurrentItem(),)
+
+ def setCurrentItem(self, item):
+ plot = self.getPlot()
+ if plot is not None:
+ plot.selection().setCurrentItem(item)
+
+ def getLabel(self, item):
+ return item.getLabel()
+
+ def getKind(self, item):
+ from ..plot3d import items as plot3ditems
+
+ if isinstance(item, (plot3ditems.ImageData,
+ plot3ditems.ScalarField3D)):
+ return 'image'
+ elif isinstance(item, (plot3ditems.Scatter2D,
+ plot3ditems.Scatter3D)):
+ return 'scatter'
+ else:
+ return None
+
+
+class _ScalarFieldViewWrapper(_Wrapper):
+ """Class handling ScalarFieldView specific calls and signal connections
+
+ See :class:`._Wrapper` for documentation
+
+ :param SceneWidget plot:
+ """
+
+ def __init__(self, plot):
+ # Lazy-import to avoid circular imports
+ from ..plot3d.ScalarFieldView import ScalarFieldView
+ from ..plot3d.items import ScalarField3D
+
+ assert isinstance(plot, ScalarFieldView)
+ super(_ScalarFieldViewWrapper, self).__init__(plot)
+ self._item = ScalarField3D()
+ self._dataChanged()
+ plot.sigDataChanged.connect(self._dataChanged)
+ # sigItemAdded, sigItemRemoved, sigVisibleDataChanged are never emitted
+
+ def _dataChanged(self):
+ plot = self.getPlot()
+ if plot is not None:
+ self._item.setData(plot.getData(copy=False), copy=False)
+ self.sigCurrentChanged.emit(self._item)
+
+ def getItems(self):
+ plot = self.getPlot()
+ return () if plot is None else (self._item,)
+
+ def getSelectedItems(self):
+ return self.getItems()
+
+ def setCurrentItem(self, item):
+ pass
+
+ def getLabel(self, item):
+ return 'Data'
+
+ def getKind(self, item):
+ return 'image'
+
+
+class _Container(object):
+ """Class to contain a plot item.
+
+ This is apparently needed for compatibility with PySide2,
+
+ :param QObject obj:
+ """
+ def __init__(self, obj):
+ self._obj = obj
+
+ def __call__(self):
+ return self._obj
+
+
+class _StatsWidgetBase(object):
+ """
+ Base class for all widgets which want to display statistics
+ """
+
+ def __init__(self, statsOnVisibleData, displayOnlyActItem):
+ self._displayOnlyActItem = displayOnlyActItem
+ self._statsOnVisibleData = statsOnVisibleData
+ self._statsHandler = None
+ self._updateMode = UpdateMode.AUTO
+
+ self.__default_skipped_events = (
+ ItemChangedType.ALPHA,
+ ItemChangedType.COLOR,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL_SIZE,
+ ItemChangedType.LINE_WIDTH,
+ ItemChangedType.LINE_STYLE,
+ ItemChangedType.LINE_BG_COLOR,
+ ItemChangedType.FILL,
+ ItemChangedType.HIGHLIGHTED_COLOR,
+ ItemChangedType.HIGHLIGHTED_STYLE,
+ ItemChangedType.TEXT,
+ ItemChangedType.OVERLAY,
+ ItemChangedType.VISUALIZATION_MODE,
+ )
+
+ self._plotWrapper = _Wrapper()
+ self._dealWithPlotConnection(create=True)
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ try:
+ import OpenGL
+ except ImportError:
+ has_opengl = False
+ else:
+ has_opengl = True
+ from ..plot3d.SceneWidget import SceneWidget # Lazy import
+ self._dealWithPlotConnection(create=False)
+ self.clear()
+ if plot is None:
+ self._plotWrapper = _Wrapper()
+ elif isinstance(plot, PlotWidget):
+ self._plotWrapper = _PlotWidgetWrapper(plot)
+ else:
+ if has_opengl is True:
+ if isinstance(plot, SceneWidget):
+ self._plotWrapper = _SceneWidgetWrapper(plot)
+ else: # Expect a ScalarFieldView
+ self._plotWrapper = _ScalarFieldViewWrapper(plot)
+ else:
+ _logger.warning('OpenGL not installed, %s not managed' % ('SceneWidget qnd ScalarFieldView'))
+ self._dealWithPlotConnection(create=True)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ if statsHandler is None:
+ statsHandler = StatsHandler(statFormatters=())
+ elif isinstance(statsHandler, (list, tuple)):
+ statsHandler = StatsHandler(statsHandler)
+ assert isinstance(statsHandler, StatsHandler)
+
+ self._statsHandler = statsHandler
+
+ def getStatsHandler(self):
+ """Returns the :class:`StatsHandler` in use.
+
+ :rtype: StatsHandler
+ """
+ return self._statsHandler
+
+ def getPlot(self):
+ """Returns the plot attached to this widget
+
+ :rtype: Union[PlotWidget,SceneWidget,None]
+ """
+ return self._plotWrapper.getPlot()
+
+ def _dealWithPlotConnection(self, create=True):
+ """Manage connection to plot signals
+
+ Note: connection on Item are managed by _addItem and _removeItem methods
+ """
+ connections = [] # List of (signal, slot) to connect/disconnect
+ if self._statsOnVisibleData:
+ connections.append(
+ (self._plotWrapper.sigVisibleDataChanged, self._updateAllStats))
+
+ if self._displayOnlyActItem:
+ connections.append(
+ (self._plotWrapper.sigCurrentChanged, self._updateCurrentItem))
+ else:
+ connections += [
+ (self._plotWrapper.sigItemAdded, self._addItem),
+ (self._plotWrapper.sigItemRemoved, self._removeItem),
+ (self._plotWrapper.sigCurrentChanged, self._plotCurrentChanged)]
+
+ for signal, slot in connections:
+ if create:
+ signal.connect(slot)
+ else:
+ signal.disconnect(slot)
+
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ raise NotImplementedError('Base class')
+
+ def _updateCurrentItem(self, *args):
+ """specific callback for the sigCurrentChanged and with the
+ _displayOnlyActItem option."""
+ raise NotImplementedError('Base class')
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ """Update displayed information for given plot item
+
+ :param item: The plot item
+ :param bool data_changed: is the item data changed.
+ :param bool roi_changed: is the associated roi changed.
+ """
+ raise NotImplementedError('Base class')
+
+ def _updateAllStats(self):
+ """Update stats for all rows in the table"""
+ raise NotImplementedError('Base class')
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
+ """
+ self._displayOnlyActItem = displayOnlyActItem
+
+ def setStatsOnVisibleData(self, b):
+ """Toggle computation of statistics on whole data or only visible ones.
+
+ .. warning:: When visible data is activated we will process to a simple
+ filtering of visible data by the user. The filtering is a
+ simple data sub-sampling. No interpolation is made to fit
+ data to boundaries.
+
+ :param bool b: True if we want to apply statistics only on visible data
+ """
+ if self._statsOnVisibleData != b:
+ self._dealWithPlotConnection(create=False)
+ self._statsOnVisibleData = b
+ self._dealWithPlotConnection(create=True)
+ self._updateAllStats()
+
+ def _addItem(self, item):
+ """Add a plot item to the table
+
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
+ """
+ raise NotImplementedError('Base class')
+
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
+ """
+ raise NotImplementedError('Base class')
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
+
+ :param current:
+ """
+ raise NotImplementedError('Base class')
+
+ def clear(self):
+ """clear GUI"""
+ pass
+
+ def _skipPlotItemChangedEvent(self, event):
+ """
+
+ :param ItemChangedtype event: event to filter or not
+ :return: True if we want to ignore this ItemChangedtype
+ :rtype: bool
+ """
+ return event in self.__default_skipped_events
+
+ def setUpdateMode(self, mode):
+ """Set the way to update the displayed statistics.
+
+ :param mode: mode requested for update
+ :type mode: Union[str,UpdateMode]
+ """
+ mode = UpdateMode.from_value(mode)
+ if mode != self._updateMode:
+ self._updateMode = mode
+ self._updateModeHasChanged()
+
+ def getUpdateMode(self):
+ """Returns update mode (See :meth:`setUpdateMode`).
+
+ :return: update mode
+ :rtype: UpdateMode
+ """
+ return self._updateMode
+
+ def _updateModeHasChanged(self):
+ """callback when the update mode has changed"""
+ pass
+
+
+class StatsTable(_StatsWidgetBase, TableWidget):
+ """
+ TableWidget displaying for each items contained by the Plot some
+ information:
+
+ * legend
+ * minimal value
+ * maximal value
+ * standard deviation (std)
+
+ :param QWidget parent: The widget's parent.
+ :param Union[PlotWidget,SceneWidget] plot:
+ :class:`PlotWidget` or :class:`SceneWidget` instance on which to operate
+ """
+
+ _LEGEND_HEADER_DATA = 'legend'
+ _KIND_HEADER_DATA = 'kind'
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(self, parent=None, plot=None):
+ TableWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self, statsOnVisibleData=False,
+ displayOnlyActItem=False)
+
+ # Init for _displayOnlyActItem == False
+ assert self._displayOnlyActItem is False
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ self.currentItemChanged.connect(self._currentItemChanged)
+
+ self.setRowCount(0)
+ self.setColumnCount(2)
+
+ # Init headers
+ headerItem = qt.QTableWidgetItem(self._LEGEND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._LEGEND_HEADER_DATA)
+ self.setHorizontalHeaderItem(0, headerItem)
+ headerItem = qt.QTableWidgetItem(self._KIND_HEADER_DATA.title())
+ headerItem.setData(qt.Qt.UserRole, self._KIND_HEADER_DATA)
+ self.setHorizontalHeaderItem(1, headerItem)
+
+ self.setSortingEnabled(True)
+ self.setPlot(plot)
+
+ @contextmanager
+ def _disableSorting(self):
+ """Context manager that disables table sorting
+
+ Previous state is restored when leaving
+ """
+ sorting = self.isSortingEnabled()
+ if sorting:
+ self.setSortingEnabled(False)
+ yield
+ if sorting:
+ self.setSortingEnabled(sorting)
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ self._removeAllItems()
+ _StatsWidgetBase.setStats(self, statsHandler)
+
+ self.setRowCount(0)
+ self.setColumnCount(len(self._statsHandler.stats) + 2) # + legend and kind
+
+ for index, stat in enumerate(self._statsHandler.stats.values()):
+ headerItem = qt.QTableWidgetItem(stat.name.capitalize())
+ headerItem.setData(qt.Qt.UserRole, stat.name)
+ if stat.description is not None:
+ headerItem.setToolTip(stat.description)
+ self.setHorizontalHeaderItem(2 + index, headerItem)
+
+ horizontalHeader = self.horizontalHeader()
+ horizontalHeader.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ self._updateItemObserve()
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateItemObserve()
+
+ def clear(self):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ self._removeAllItems()
+
+ def _updateItemObserve(self, *args):
+ """Reload table depending on mode"""
+ self._removeAllItems()
+
+ # Get selected or all items from the plot
+ if self._displayOnlyActItem: # Only selected
+ items = self._plotWrapper.getSelectedItems()
+ else: # All items
+ items = self._plotWrapper.getItems()
+
+ # Add items to the plot
+ for item in items:
+ self._addItem(item)
+
+ def _updateCurrentItem(self, *args):
+ """specific callback for the sigCurrentChanged and with the
+ _displayOnlyActItem option.
+
+ Behavior: create the tableItems if does not exists.
+ If exists, update it only when we are in 'auto' mode"""
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ # when sigCurrentChanged is giving the current item
+ if len(args) > 0 and isinstance(args[0], (plotitems.Curve, plotitems.Histogram, plotitems.ImageData, plotitems.Scatter)):
+ item = args[0]
+ tableItems = self._itemToTableItems(item)
+ # if the table does not exists yet
+ if len(tableItems) == 0:
+ self._updateItemObserve()
+ else:
+ # in this case no current item
+ self._updateItemObserve(args)
+ else:
+ # auto mode
+ self._updateItemObserve(args)
+
+ def _plotCurrentChanged(self, current):
+ """Handle change of current item and update selection in table
+
+ :param current:
+ """
+ row = self._itemToRow(current)
+ if row is None:
+ if self.currentRow() >= 0:
+ self.setCurrentCell(-1, -1)
+ elif row != self.currentRow():
+ self.setCurrentCell(row, 0)
+
+ def _tableItemToItem(self, tableItem):
+ """Find the plot item corresponding to a table item
+
+ :param QTableWidgetItem tableItem:
+ :rtype: QObject
+ """
+ container = tableItem.data(qt.Qt.UserRole)
+ return container()
+
+ def _itemToRow(self, item):
+ """Find the row corresponding to a plot item
+
+ :param item: The plot item
+ :return: The corresponding row index
+ :rtype: Union[int,None]
+ """
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ if self._tableItemToItem(tableItem) == item:
+ return row
+ return None
+
+ def _itemToTableItems(self, item):
+ """Find all table items corresponding to a plot item
+
+ :param item: The plot item
+ :return: An ordered dict of column name to QTableWidgetItem mapping
+ for the given plot item.
+ :rtype: OrderedDict
+ """
+ result = OrderedDict()
+ row = self._itemToRow(item)
+ if row is not None:
+ for column in range(self.columnCount()):
+ tableItem = self.item(row, column)
+ if self._tableItemToItem(tableItem) != item:
+ _logger.error("Table item/plot item mismatch")
+ else:
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+ result[name] = tableItem
+ return result
+
+ def _plotItemChanged(self, event):
+ """Handle modifications of the items.
+
+ :param event:
+ """
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if self._skipPlotItemChangedEvent(event) is True:
+ return
+ else:
+ item = self.sender()
+ self._updateStats(item, data_changed=True)
+ # deal with stat items visibility
+ if event is ItemChangedType.VISIBLE:
+ if len(self._itemToTableItems(item).items()) > 0:
+ item_0 = list(self._itemToTableItems(item).values())[0]
+ row_index = item_0.row()
+ self.setRowHidden(row_index, not item.isVisible())
+
+ def _addItem(self, item):
+ """Add a plot item to the table
+
+ If item is not supported, it is ignored.
+
+ :param item: The plot item
+ :returns: True if the item is added to the widget.
+ :rtype: bool
+ """
+ if self._itemToRow(item) is not None:
+ _logger.info("Item already present in the table")
+ self._updateStats(item)
+ return True
+
+ kind = self._plotWrapper.getKind(item)
+ if kind not in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.info("Item has not a supported type: %s", item)
+ return False
+
+ # Prepare table items
+ tableItems = [
+ qt.QTableWidgetItem(), # Legend
+ qt.QTableWidgetItem()] # Kind
+
+ for column in range(2, self.columnCount()):
+ header = self.horizontalHeaderItem(column)
+ name = header.data(qt.Qt.UserRole)
+
+ formatter = self._statsHandler.formatters[name]
+ if formatter:
+ tableItem = formatter.tabWidgetItemClass()
+ else:
+ tableItem = qt.QTableWidgetItem()
+
+ tooltip = self._statsHandler.stats[name].getToolTip(kind=kind)
+ if tooltip is not None:
+ tableItem.setToolTip(tooltip)
+
+ tableItems.append(tableItem)
+
+ # Disable sorting while adding table items
+ with self._disableSorting():
+ # Add a row to the table
+ self.setRowCount(self.rowCount() + 1)
+
+ # Add table items to the last row
+ row = self.rowCount() - 1
+ for column, tableItem in enumerate(tableItems):
+ tableItem.setData(qt.Qt.UserRole, _Container(item))
+ tableItem.setFlags(
+ qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, column, tableItem)
+
+ # Update table items content
+ self._updateStats(item, data_changed=True)
+
+ # Listen for item changes
+ # Using queued connection to avoid issue with sender
+ # being that of the signal calling the signal
+ item.sigItemChanged.connect(self._plotItemChanged,
+ qt.Qt.QueuedConnection)
+
+ return True
+
+ def _removeItem(self, item):
+ """Remove table items corresponding to given plot item from the table.
+
+ :param item: The plot item
+ """
+ row = self._itemToRow(item)
+ if row is None:
+ kind = self._plotWrapper.getKind(item)
+ if kind in statsmdl.BASIC_COMPATIBLE_KINDS:
+ _logger.error("Removing item that is not in table: %s", str(item))
+ return
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.removeRow(row)
+
+ def _removeAllItems(self):
+ """Remove content of the table"""
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ item.sigItemChanged.disconnect(self._plotItemChanged)
+ self.clearContents()
+ self.setRowCount(0)
+
+ def _updateStats(self, item, data_changed=False, roi_changed=False):
+ """Update displayed information for given plot item
+
+ :param item: The plot item
+ :param bool data_changed: is the item data changed.
+ :param bool roi_changed: is the associated roi changed.
+ """
+ if item is None:
+ return
+ plot = self.getPlot()
+ if plot is None:
+ _logger.info("Plot not available")
+ return
+
+ row = self._itemToRow(item)
+ if row is None:
+ _logger.error("This item is not in the table: %s", str(item))
+ return
+
+ statsHandler = self.getStatsHandler()
+ if statsHandler is not None:
+ # _updateStats is call when the plot visible area change.
+ # to force stats update we consider roi changed
+ if self._statsOnVisibleData:
+ roi_changed = True
+ else:
+ roi_changed = False
+ stats = statsHandler.calculate(
+ item, plot, self._statsOnVisibleData,
+ data_changed=data_changed, roi_changed=roi_changed)
+ else:
+ stats = {}
+
+ with self._disableSorting():
+ for name, tableItem in self._itemToTableItems(item).items():
+ if name == self._LEGEND_HEADER_DATA:
+ text = self._plotWrapper.getLabel(item)
+ tableItem.setText(text)
+ elif name == self._KIND_HEADER_DATA:
+ tableItem.setText(self._plotWrapper.getKind(item))
+ else:
+ value = stats.get(name)
+ if value is None:
+ _logger.error("Value not found for: %s", name)
+ tableItem.setText('-')
+ else:
+ tableItem.setText(str(value))
+
+ def _updateAllStats(self, is_request=False):
+ """Update stats for all rows in the table
+
+ :param bool is_request: True if come from a manual request
+ """
+ if self.getUpdateMode() is UpdateMode.MANUAL and not is_request:
+ return
+ with self._disableSorting():
+ for row in range(self.rowCount()):
+ tableItem = self.item(row, 0)
+ item = self._tableItemToItem(tableItem)
+ self._updateStats(item, data_changed=is_request)
+
+ def _currentItemChanged(self, current, previous):
+ """Handle change of selection in table and sync plot selection
+
+ :param QTableWidgetItem current:
+ :param QTableWidgetItem previous:
+ """
+ if current and current.row() >= 0:
+ item = self._tableItemToItem(current)
+ self._plotWrapper.setCurrentItem(item)
+
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ """Toggle display off all items or only the active/selected one
+
+ :param bool displayOnlyActItem:
+ True if we want to only show active item
+ """
+ if self._displayOnlyActItem == displayOnlyActItem:
+ return
+ self._dealWithPlotConnection(create=False)
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.disconnect(self._currentItemChanged)
+
+ _StatsWidgetBase.setDisplayOnlyActiveItem(self, displayOnlyActItem)
+
+ self._updateItemObserve()
+ self._dealWithPlotConnection(create=True)
+
+ if not self._displayOnlyActItem:
+ self.currentItemChanged.connect(self._currentItemChanged)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+ else:
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+
+ def _updateModeHasChanged(self):
+ self.sigUpdateModeChanged.emit(self._updateMode)
+
+
+class UpdateModeWidget(qt.QWidget):
+ """Widget used to select the mode of update"""
+ sigUpdateModeChanged = qt.Signal(object)
+ """signal emitted when the mode for update changed"""
+ sigUpdateRequested = qt.Signal()
+ """signal emitted when an manual request for example is activate"""
+
+ def __init__(self, parent=None):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QHBoxLayout())
+ self._buttonGrp = qt.QButtonGroup(parent=self)
+ self._buttonGrp.setExclusive(True)
+
+ spacer = qt.QSpacerItem(20, 20,
+ qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Minimum)
+ self.layout().addItem(spacer)
+
+ self._autoRB = qt.QRadioButton('auto', parent=self)
+ self.layout().addWidget(self._autoRB)
+ self._buttonGrp.addButton(self._autoRB)
+
+ self._manualRB = qt.QRadioButton('manual', parent=self)
+ self.layout().addWidget(self._manualRB)
+ self._buttonGrp.addButton(self._manualRB)
+ self._manualRB.setChecked(True)
+
+ refresh_icon = icons.getQIcon('view-refresh')
+ self._updatePB = qt.QPushButton(refresh_icon, '', parent=self)
+ self.layout().addWidget(self._updatePB)
+
+ # connect signal / SLOT
+ self._updatePB.clicked.connect(self._updateRequested)
+ self._manualRB.toggled.connect(self._manualButtonToggled)
+ self._autoRB.toggled.connect(self._autoButtonToggled)
+
+ def _manualButtonToggled(self, checked):
+ if checked:
+ self.setUpdateMode(UpdateMode.MANUAL)
+ self.sigUpdateModeChanged.emit(self.getUpdateMode())
+
+ def _autoButtonToggled(self, checked):
+ if checked:
+ self.setUpdateMode(UpdateMode.AUTO)
+ self.sigUpdateModeChanged.emit(self.getUpdateMode())
+
+ def _updateRequested(self):
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ self.sigUpdateRequested.emit()
+
+ def setUpdateMode(self, mode):
+ """Set the way to update the displayed statistics.
+
+ :param mode: mode requested for update
+ :type mode: Union[str,UpdateMode]
+ """
+ mode = UpdateMode.from_value(mode)
+
+ if mode is UpdateMode.AUTO:
+ if not self._autoRB.isChecked():
+ self._autoRB.setChecked(True)
+ elif mode is UpdateMode.MANUAL:
+ if not self._manualRB.isChecked():
+ self._manualRB.setChecked(True)
+ else:
+ raise ValueError('mode', mode, 'is not recognized')
+
+ def getUpdateMode(self):
+ """Returns update mode (See :meth:`setUpdateMode`).
+
+ :return: the active update mode
+ :rtype: UpdateMode
+ """
+ if self._manualRB.isChecked():
+ return UpdateMode.MANUAL
+ elif self._autoRB.isChecked():
+ return UpdateMode.AUTO
+ else:
+ raise RuntimeError("No mode selected")
+
+ def showRadioButtons(self, show):
+ """show / hide the QRadioButtons
+
+ :param bool show: if True make RadioButton visible
+ """
+ self._autoRB.setVisible(show)
+ self._manualRB.setVisible(show)
+
+
+class _OptionsWidget(qt.QToolBar):
+
+ def __init__(self, parent=None, updateMode=None, displayOnlyActItem=False):
+ assert updateMode is not None
+ qt.QToolBar.__init__(self, parent)
+ self.setIconSize(qt.QSize(16, 16))
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-active-items"))
+ action.setText("Active items only")
+ action.setToolTip("Display stats for active items only.")
+ action.setCheckable(True)
+ action.setChecked(displayOnlyActItem)
+ self.__displayActiveItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-items"))
+ action.setText("All items")
+ action.setToolTip("Display stats for all available items.")
+ action.setCheckable(True)
+ self.__displayWholeItems = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-visible-data"))
+ action.setText("Use the visible data range")
+ action.setToolTip("Use the visible data range.<br/>"
+ "If activated the data is filtered to only use"
+ "visible data of the plot."
+ "The filtering is a data sub-sampling."
+ "No interpolation is made to fit data to"
+ "boundaries.")
+ action.setCheckable(True)
+ self.__useVisibleData = action
+
+ action = qt.QAction(self)
+ action.setIcon(icons.getQIcon("stats-whole-data"))
+ action.setText("Use the full data range")
+ action.setToolTip("Use the full data range.")
+ action.setCheckable(True)
+ action.setChecked(True)
+ self.__useWholeData = action
+
+ self.addAction(self.__displayWholeItems)
+ self.addAction(self.__displayActiveItems)
+ self.addSeparator()
+ self.addAction(self.__useVisibleData)
+ self.addAction(self.__useWholeData)
+
+ self.itemSelection = qt.QActionGroup(self)
+ self.itemSelection.setExclusive(True)
+ self.itemSelection.addAction(self.__displayActiveItems)
+ self.itemSelection.addAction(self.__displayWholeItems)
+
+ self.dataRangeSelection = qt.QActionGroup(self)
+ self.dataRangeSelection.setExclusive(True)
+ self.dataRangeSelection.addAction(self.__useWholeData)
+ self.dataRangeSelection.addAction(self.__useVisibleData)
+
+ self.__updateStatsAction = qt.QAction(self)
+ self.__updateStatsAction.setIcon(icons.getQIcon("view-refresh"))
+ self.__updateStatsAction.setText("update statistics")
+ self.__updateStatsAction.setToolTip("update statistics")
+ self.__updateStatsAction.setCheckable(False)
+ self._updateStatsSep = self.addSeparator()
+ self.addAction(self.__updateStatsAction)
+
+ self._setUpdateMode(mode=updateMode)
+
+ # expose API
+ self.sigUpdateStats = self.__updateStatsAction.triggered
+
+ def isActiveItemMode(self):
+ return self.itemSelection.checkedAction() is self.__displayActiveItems
+
+ def setDisplayActiveItems(self, only_active):
+ self.__displayActiveItems.setChecked(only_active)
+ self.__displayWholeItems.setChecked(not only_active)
+
+ def isVisibleDataRangeMode(self):
+ return self.dataRangeSelection.checkedAction() is self.__useVisibleData
+
+ def setVisibleDataRangeModeEnabled(self, enabled):
+ """Enable/Disable the visible data range mode
+
+ :param bool enabled: True to allow user to choose
+ stats on visible data
+ """
+ self.__useVisibleData.setEnabled(enabled)
+ if not enabled:
+ self.__useWholeData.setChecked(True)
+
+ def _setUpdateMode(self, mode):
+ self.__updateStatsAction.setVisible(mode == UpdateMode.MANUAL)
+ self._updateStatsSep.setVisible(mode == UpdateMode.MANUAL)
+
+ def getUpdateStatsAction(self):
+ """
+
+ :return: the action for the automatic mode
+ :rtype: QAction
+ """
+ return self.__updateStatsAction
+
+
+class StatsWidget(qt.QWidget):
+ """
+ Widget displaying a set of :class:`Stat` to be displayed on a
+ :class:`StatsTable` and to be apply on items contained in the :class:`Plot`
+ Also contains options to:
+
+ * compute statistics on all the data or on visible data only
+ * show statistics of all items or only the active one
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ """
+
+ sigVisibilityChanged = qt.Signal(bool)
+ """Signal emitted when the visibility of this widget changes.
+
+ It Provides the visibility of the widget.
+ """
+
+ NUMBER_FORMAT = '{0:.3f}'
+
+ def __init__(self, parent=None, plot=None, stats=None):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._options = _OptionsWidget(parent=self, updateMode=UpdateMode.MANUAL)
+ self.layout().addWidget(self._options)
+ self._statsTable = StatsTable(parent=self, plot=plot)
+ self._statsTable.setDisplayOnlyActiveItem(self._options.isActiveItemMode())
+ self._options._setUpdateMode(mode=self._statsTable.getUpdateMode())
+ self.setStats(stats)
+
+ self.layout().addWidget(self._statsTable)
+
+ old = self._statsTable.blockSignals(True)
+ self._options.itemSelection.triggered.connect(
+ self._optSelectionChanged)
+ self._options.dataRangeSelection.triggered.connect(
+ self._optDataRangeChanged)
+ self._optDataRangeChanged()
+ self._statsTable.blockSignals(old)
+
+ self._statsTable.sigUpdateModeChanged.connect(self._options._setUpdateMode)
+ callback = functools.partial(self._getStatsTable()._updateAllStats, is_request=True)
+ self._options.sigUpdateStats.connect(callback)
+
+ def _getStatsTable(self):
+ """Returns the :class:`StatsTable` used by this widget.
+
+ :rtype: StatsTable
+ """
+ return self._statsTable
+
+ def showEvent(self, event):
+ self.sigVisibilityChanged.emit(True)
+ qt.QWidget.showEvent(self, event)
+
+ def hideEvent(self, event):
+ self.sigVisibilityChanged.emit(False)
+ qt.QWidget.hideEvent(self, event)
+
+ def _optSelectionChanged(self, action=None):
+ self._getStatsTable().setDisplayOnlyActiveItem(
+ self._options.isActiveItemMode())
+
+ def _optDataRangeChanged(self, action=None):
+ self._getStatsTable().setStatsOnVisibleData(
+ self._options.isVisibleDataRangeMode())
+
+ # Proxy methods
+
+ @docstring(StatsTable)
+ def setStats(self, statsHandler):
+ return self._getStatsTable().setStats(statsHandler=statsHandler)
+
+ @docstring(StatsTable)
+ def setPlot(self, plot):
+ self._options.setVisibleDataRangeModeEnabled(
+ plot is None or isinstance(plot, PlotWidget))
+ return self._getStatsTable().setPlot(plot=plot)
+
+ @docstring(StatsTable)
+ def getPlot(self):
+ return self._getStatsTable().getPlot()
+
+ @docstring(StatsTable)
+ def setDisplayOnlyActiveItem(self, displayOnlyActItem):
+ old = self._options.blockSignals(True)
+ # update the options
+ self._options.setDisplayActiveItems(displayOnlyActItem)
+ self._options.blockSignals(old)
+ return self._getStatsTable().setDisplayOnlyActiveItem(
+ displayOnlyActItem=displayOnlyActItem)
+
+ @docstring(StatsTable)
+ def setStatsOnVisibleData(self, b):
+ return self._getStatsTable().setStatsOnVisibleData(b=b)
+
+ @docstring(StatsTable)
+ def getUpdateMode(self):
+ return self._statsTable.getUpdateMode()
+
+ @docstring(StatsTable)
+ def setUpdateMode(self, mode):
+ self._statsTable.setUpdateMode(mode)
+
+
+DEFAULT_STATS = StatsHandler((
+ (statsmdl.StatMin(), StatFormatter()),
+ statsmdl.StatCoordMin(),
+ (statsmdl.StatMax(), StatFormatter()),
+ statsmdl.StatCoordMax(),
+ statsmdl.StatCOM(),
+ (('mean', numpy.mean), StatFormatter()),
+ (('std', numpy.std), StatFormatter()),
+))
+
+
+class BasicStatsWidget(StatsWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`StatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param PlotWidget plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+
+ .. snapshotqt:: img/BasicStatsWidget.png
+ :width: 300px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicStatsWidget(plot=plot)
+ widget.show()
+ """
+ def __init__(self, parent=None, plot=None):
+ StatsWidget.__init__(self, parent=parent, plot=plot,
+ stats=DEFAULT_STATS)
+
+
+class _BaseLineStatsWidget(_StatsWidgetBase, qt.QWidget):
+ """
+ Widget made to display stats into a QLayout with couple (QLabel, QLineEdit)
+ created for each stats.
+ The layout can be defined prior of adding any statistic.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+
+ sigUpdateModeChanged = qt.Signal(object)
+ """Signal emitted when the update mode changed"""
+
+ def __init__(self, parent=None, plot=None, kind='curve', stats=None,
+ statsOnVisibleData=False):
+ self._item_kind = kind
+ """The item displayed"""
+ self._statQlineEdit = {}
+ """list of legends actually displayed"""
+ self._n_statistics_per_line = 4
+ """number of statistics displayed per line in the grid layout"""
+ qt.QWidget.__init__(self, parent)
+ _StatsWidgetBase.__init__(self,
+ statsOnVisibleData=statsOnVisibleData,
+ displayOnlyActItem=True)
+ self.setLayout(self._createLayout())
+ self.setPlot(plot)
+ if stats is not None:
+ self.setStats(stats)
+
+ def _addItemForStatistic(self, statistic):
+ assert isinstance(statistic, statsmdl.StatBase)
+ assert statistic.name in self._statsHandler.stats
+
+ self.layout().setSpacing(2)
+ self.layout().setContentsMargins(2, 2, 2, 2)
+
+ if isinstance(self.layout(), qt.QGridLayout):
+ parent = self
+ else:
+ widget = qt.QWidget(parent=self)
+ parent = widget
+
+ qLabel = qt.QLabel(statistic.name + ':', parent=parent)
+ qLineEdit = qt.QLineEdit('', parent=parent)
+ qLineEdit.setReadOnly(True)
+
+ self._addStatsWidgetsToLayout(qLabel=qLabel, qLineEdit=qLineEdit)
+ self._statQlineEdit[statistic.name] = qLineEdit
+
+ def setPlot(self, plot):
+ """Define the plot to interact with
+
+ :param Union[PlotWidget,SceneWidget,None] plot:
+ The plot containing the items on which statistics are applied
+ """
+ _StatsWidgetBase.setPlot(self, plot)
+ self._updateAllStats()
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ raise NotImplementedError('Base class')
+
+ def setStats(self, statsHandler):
+ """Set which stats to display and the associated formatting.
+
+ :param StatsHandler statsHandler:
+ Set the statistics to be displayed and how to format them using
+ """
+ _StatsWidgetBase.setStats(self, statsHandler)
+ for statName, stat in list(self._statsHandler.stats.items()):
+ self._addItemForStatistic(stat)
+ self._updateAllStats()
+
+ def _activeItemChanged(self, kind, previous, current):
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ if kind == self._item_kind:
+ self._updateAllStats()
+
+ def _updateAllStats(self):
+ plot = self.getPlot()
+ if plot is not None:
+ _items = self._plotWrapper.getSelectedItems()
+
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ if len(items) == 1:
+ self._setItem(items[0])
+
+ def setKind(self, kind):
+ """Change the kind of active item to display
+ :param str kind: kind of item to display information for ('curve' ...)
+ """
+ if self._item_kind != kind:
+ self._item_kind = kind
+ self._updateItemObserve()
+
+ def getKind(self):
+ """
+ :return: kind of item we want to compute statistic for
+ :rtype: str
+ """
+ return self._item_kind
+
+ def _setItem(self, item, data_changed=True):
+ if item is None:
+ for stat_name, stat_widget in self._statQlineEdit.items():
+ stat_widget.setText('')
+ elif (self._statsHandler is not None and len(
+ self._statsHandler.stats) > 0):
+ plot = self.getPlot()
+ if plot is not None:
+ statsValDict = self._statsHandler.calculate(item,
+ plot,
+ self._statsOnVisibleData,
+ data_changed=data_changed)
+ for statName, statVal in list(statsValDict.items()):
+ self._statQlineEdit[statName].setText(statVal)
+
+ def _updateItemObserve(self, *argv):
+ if self.getUpdateMode() is UpdateMode.MANUAL:
+ return
+ assert self._displayOnlyActItem
+ _items = self._plotWrapper.getSelectedItems()
+
+ def kind_filter(_item):
+ return self._plotWrapper.getKind(_item) == self.getKind()
+ items = list(filter(kind_filter, _items))
+ assert len(items) in (0, 1)
+ _item = items[0] if len(items) == 1 else None
+ self._setItem(_item, data_changed=True)
+
+ def _updateCurrentItem(self):
+ self._updateItemObserve()
+
+ def _createLayout(self):
+ """create an instance of the main QLayout"""
+ raise NotImplementedError('Base class')
+
+ def _addItem(self, item):
+ raise NotImplementedError('Display only the active item')
+
+ def _removeItem(self, item):
+ raise NotImplementedError('Display only the active item')
+
+ def _plotCurrentChanged(self, current):
+ raise NotImplementedError('Display only the active item')
+
+ def _updateModeHasChanged(self):
+ self.sigUpdateModeChanged.emit(self._updateMode)
+
+
+class _BasicLineStatsWidget(_BaseLineStatsWidget):
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False):
+ _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
+ plot=plot, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+
+ def _createLayout(self):
+ return FlowLayout()
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ # create a mother widget to make sure both qLabel & qLineEdit will
+ # always be displayed side by side
+ widget = qt.QWidget(parent=self)
+ widget.setLayout(qt.QHBoxLayout())
+ widget.layout().setSpacing(0)
+ widget.layout().setContentsMargins(0, 0, 0, 0)
+
+ widget.layout().addWidget(qLabel)
+ widget.layout().addWidget(qLineEdit)
+
+ self.layout().addWidget(widget)
+
+ def _addOptionsWidget(self, widget):
+ self.layout().addWidget(widget)
+
+
+class BasicLineStatsWidget(qt.QWidget):
+ """
+ Widget defining a simple set of :class:`Stat` to be displayed on a
+ :class:`LineStatsWidget`.
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param str kind: the kind of plotitems we want to display
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ """
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QHBoxLayout())
+ self.layout().setSpacing(0)
+ self.layout().setContentsMargins(0, 0, 0, 0)
+ self._lineStatsWidget = _BasicLineStatsWidget(parent=self, plot=plot,
+ kind=kind, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+ self.layout().addWidget(self._lineStatsWidget)
+
+ self._options = UpdateModeWidget()
+ self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode())
+ self._options.showRadioButtons(False)
+ self.layout().addWidget(self._options)
+
+ # connect Signal ? SLOT
+ self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode)
+ self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode)
+ self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats)
+
+ def showControl(self, visible):
+ self._options.setVisible(visible)
+
+ # Proxy methods
+
+ @docstring(_BasicLineStatsWidget)
+ def setUpdateMode(self, mode):
+ self._lineStatsWidget.setUpdateMode(mode=mode)
+
+ @docstring(_BasicLineStatsWidget)
+ def getUpdateMode(self):
+ return self._lineStatsWidget.getUpdateMode()
+
+ @docstring(_BasicLineStatsWidget)
+ def setPlot(self, plot):
+ self._lineStatsWidget.setPlot(plot=plot)
+
+ @docstring(_BasicLineStatsWidget)
+ def setStats(self, statsHandler):
+ self._lineStatsWidget.setStats(statsHandler=statsHandler)
+
+ @docstring(_BasicLineStatsWidget)
+ def setKind(self, kind):
+ self._lineStatsWidget.setKind(kind=kind)
+
+ @docstring(_BasicLineStatsWidget)
+ def getKind(self):
+ return self._lineStatsWidget.getKind()
+
+ @docstring(_BasicLineStatsWidget)
+ def setStatsOnVisibleData(self, b):
+ self._lineStatsWidget.setStatsOnVisibleData(b)
+
+ @docstring(UpdateModeWidget)
+ def showRadioButtons(self, show):
+ self._options.showRadioButtons(show=show)
+
+
+class _BasicGridStatsWidget(_BaseLineStatsWidget):
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False,
+ statsPerLine=4):
+ _BaseLineStatsWidget.__init__(self, parent=parent, kind=kind,
+ plot=plot, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+ self._n_statistics_per_line = statsPerLine
+
+ def _addStatsWidgetsToLayout(self, qLabel, qLineEdit):
+ column = len(self._statQlineEdit) % self._n_statistics_per_line
+ row = len(self._statQlineEdit) // self._n_statistics_per_line
+ self.layout().addWidget(qLabel, row, column * 2)
+ self.layout().addWidget(qLineEdit, row, column * 2 + 1)
+
+ def _createLayout(self):
+ return qt.QGridLayout()
+
+
+class BasicGridStatsWidget(qt.QWidget):
+ """
+ pymca design like widget
+
+ :param QWidget parent: Qt parent
+ :param Union[PlotWidget,SceneWidget] plot:
+ The plot containing items on which we want statistics.
+ :param StatsHandler stats:
+ Set the statistics to be displayed and how to format them using
+ :param str kind: the kind of plotitems we want to display
+ :param bool statsOnVisibleData: compute statistics for the whole data or
+ only visible ones.
+ :param int statsPerLine: number of statistic to be displayed per line
+
+ .. snapshotqt:: img/BasicGridStatsWidget.png
+ :width: 600px
+ :align: center
+
+ from silx.gui.plot import Plot1D
+ from silx.gui.plot.StatsWidget import BasicGridStatsWidget
+
+ plot = Plot1D()
+ x = range(100)
+ y = x
+ plot.addCurve(x, y, legend='curve_0')
+ plot.setActiveCurve('curve_0')
+
+ widget = BasicGridStatsWidget(plot=plot, kind='curve')
+ widget.show()
+ """
+
+ def __init__(self, parent=None, plot=None, kind='curve',
+ stats=DEFAULT_STATS, statsOnVisibleData=False):
+ qt.QWidget.__init__(self, parent)
+ self.setLayout(qt.QVBoxLayout())
+ self.layout().setSpacing(0)
+ self.layout().setContentsMargins(0, 0, 0, 0)
+
+ self._options = UpdateModeWidget()
+ self._options.showRadioButtons(False)
+ self.layout().addWidget(self._options)
+
+ self._lineStatsWidget = _BasicGridStatsWidget(parent=self, plot=plot,
+ kind=kind, stats=stats,
+ statsOnVisibleData=statsOnVisibleData)
+ self.layout().addWidget(self._lineStatsWidget)
+
+ # tune options
+ self._options.setUpdateMode(self._lineStatsWidget.getUpdateMode())
+
+ # connect Signal ? SLOT
+ self._lineStatsWidget.sigUpdateModeChanged.connect(self._options.setUpdateMode)
+ self._options.sigUpdateModeChanged.connect(self._lineStatsWidget.setUpdateMode)
+ self._options.sigUpdateRequested.connect(self._lineStatsWidget._updateAllStats)
+
+ def showControl(self, visible):
+ self._options.setVisible(visible)
+
+ @docstring(_BasicGridStatsWidget)
+ def setUpdateMode(self, mode):
+ self._lineStatsWidget.setUpdateMode(mode=mode)
+
+ @docstring(_BasicGridStatsWidget)
+ def getUpdateMode(self):
+ return self._lineStatsWidget.getUpdateMode()
+
+ @docstring(_BasicGridStatsWidget)
+ def setPlot(self, plot):
+ self._lineStatsWidget.setPlot(plot=plot)
+
+ @docstring(_BasicGridStatsWidget)
+ def setStats(self, statsHandler):
+ self._lineStatsWidget.setStats(statsHandler=statsHandler)
+
+ @docstring(_BasicGridStatsWidget)
+ def setKind(self, kind):
+ self._lineStatsWidget.setKind(kind=kind)
+
+ @docstring(_BasicGridStatsWidget)
+ def getKind(self):
+ return self._lineStatsWidget.getKind()
+
+ @docstring(_BasicGridStatsWidget)
+ def setStatsOnVisibleData(self, b):
+ self._lineStatsWidget.setStatsOnVisibleData(b)
+
+ @docstring(UpdateModeWidget)
+ def showRadioButtons(self, show):
+ self._options.showRadioButtons(show=show)
diff --git a/silx/gui/plot/_BaseMaskToolsWidget.py b/src/silx/gui/plot/_BaseMaskToolsWidget.py
index 407ab11..407ab11 100644
--- a/silx/gui/plot/_BaseMaskToolsWidget.py
+++ b/src/silx/gui/plot/_BaseMaskToolsWidget.py
diff --git a/silx/gui/plot/__init__.py b/src/silx/gui/plot/__init__.py
index 3a141b3..3a141b3 100644
--- a/silx/gui/plot/__init__.py
+++ b/src/silx/gui/plot/__init__.py
diff --git a/src/silx/gui/plot/_utils/__init__.py b/src/silx/gui/plot/_utils/__init__.py
new file mode 100644
index 0000000..ed87b18
--- /dev/null
+++ b/src/silx/gui/plot/_utils/__init__.py
@@ -0,0 +1,92 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Miscellaneous utility functions for the Plot"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/03/2017"
+
+
+import numpy
+
+from .panzoom import FLOAT32_SAFE_MIN, FLOAT32_MINPOS, FLOAT32_SAFE_MAX
+from .panzoom import applyZoomToPlot, applyPan, checkAxisLimits
+
+
+def addMarginsToLimits(margins, isXLog, isYLog,
+ xMin, xMax, yMin, yMax, y2Min=None, y2Max=None):
+ """Returns updated limits by extending them with margins.
+
+ :param margins: The ratio of the margins to add or None for no margins.
+ :type margins: A 4-tuple of floats as
+ (xMinMargin, xMaxMargin, yMinMargin, yMaxMargin)
+
+ :return: The updated limits
+ :rtype: tuple of 4 or 6 floats: Either (xMin, xMax, yMin, yMax) or
+ (xMin, xMax, yMin, yMax, y2Min, y2Max) if y2Min and y2Max
+ are provided.
+ """
+ if margins is not None:
+ xMinMargin, xMaxMargin, yMinMargin, yMaxMargin = margins
+
+ if not isXLog:
+ xRange = xMax - xMin
+ xMin -= xMinMargin * xRange
+ xMax += xMaxMargin * xRange
+
+ elif xMin > 0. and xMax > 0.: # Log scale
+ # Do not apply margins if limits < 0
+ xMinLog, xMaxLog = numpy.log10(xMin), numpy.log10(xMax)
+ xRangeLog = xMaxLog - xMinLog
+ xMin = pow(10., xMinLog - xMinMargin * xRangeLog)
+ xMax = pow(10., xMaxLog + xMaxMargin * xRangeLog)
+
+ if not isYLog:
+ yRange = yMax - yMin
+ yMin -= yMinMargin * yRange
+ yMax += yMaxMargin * yRange
+ elif yMin > 0. and yMax > 0.: # Log scale
+ # Do not apply margins if limits < 0
+ yMinLog, yMaxLog = numpy.log10(yMin), numpy.log10(yMax)
+ yRangeLog = yMaxLog - yMinLog
+ yMin = pow(10., yMinLog - yMinMargin * yRangeLog)
+ yMax = pow(10., yMaxLog + yMaxMargin * yRangeLog)
+
+ if y2Min is not None and y2Max is not None:
+ if not isYLog:
+ yRange = y2Max - y2Min
+ y2Min -= yMinMargin * yRange
+ y2Max += yMaxMargin * yRange
+ elif y2Min > 0. and y2Max > 0.: # Log scale
+ # Do not apply margins if limits < 0
+ yMinLog, yMaxLog = numpy.log10(y2Min), numpy.log10(y2Max)
+ yRangeLog = yMaxLog - yMinLog
+ y2Min = pow(10., yMinLog - yMinMargin * yRangeLog)
+ y2Max = pow(10., yMaxLog + yMaxMargin * yRangeLog)
+
+ if y2Min is None or y2Max is None:
+ return xMin, xMax, yMin, yMax
+ else:
+ return xMin, xMax, yMin, yMax, y2Min, y2Max
diff --git a/silx/gui/plot/_utils/delaunay.py b/src/silx/gui/plot/_utils/delaunay.py
index 49ad05f..49ad05f 100644
--- a/silx/gui/plot/_utils/delaunay.py
+++ b/src/silx/gui/plot/_utils/delaunay.py
diff --git a/silx/gui/plot/_utils/dtime_ticklayout.py b/src/silx/gui/plot/_utils/dtime_ticklayout.py
index ebf775b..ebf775b 100644
--- a/silx/gui/plot/_utils/dtime_ticklayout.py
+++ b/src/silx/gui/plot/_utils/dtime_ticklayout.py
diff --git a/src/silx/gui/plot/_utils/panzoom.py b/src/silx/gui/plot/_utils/panzoom.py
new file mode 100644
index 0000000..77efd10
--- /dev/null
+++ b/src/silx/gui/plot/_utils/panzoom.py
@@ -0,0 +1,325 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Functions to apply pan and zoom on a Plot"""
+
+__authors__ = ["T. Vincent", "V. Valls"]
+__license__ = "MIT"
+__date__ = "08/08/2017"
+
+
+import logging
+import math
+import numpy
+
+
+_logger = logging.getLogger(__name__)
+
+
+# Float 32 info ###############################################################
+# Using min/max value below limits of float32
+# so operation with such value (e.g., max - min) do not overflow
+
+FLOAT32_SAFE_MIN = -1e37
+FLOAT32_MINPOS = numpy.finfo(numpy.float32).tiny
+FLOAT32_SAFE_MAX = 1e37
+# TODO double support
+
+
+def checkAxisLimits(vmin, vmax, isLog: bool=False, name: str=""):
+ """Makes sure axis range is not empty and within supported range.
+
+ :param float vmin: Min axis value
+ :param float vmax: Max axis value
+ :return: (min, max) making sure min < max
+ :rtype: 2-tuple of float
+ """
+ min_ = FLOAT32_MINPOS if isLog else FLOAT32_SAFE_MIN
+ vmax = numpy.clip(vmax, min_, FLOAT32_SAFE_MAX)
+ vmin = numpy.clip(vmin, min_, FLOAT32_SAFE_MAX)
+
+ if vmax < vmin:
+ _logger.debug('%s axis: max < min, inverting limits.', name)
+ vmin, vmax = vmax, vmin
+ elif vmax == vmin:
+ _logger.debug('%s axis: max == min, expanding limits.', name)
+ if vmin == 0.:
+ vmin, vmax = -0.1, 0.1
+ elif vmin < 0:
+ vmax *= 0.9
+ vmin = max(vmin * 1.1, FLOAT32_SAFE_MIN) # Clip to range
+ else: # vmin > 0
+ vmax = min(vmin * 1.1, FLOAT32_SAFE_MAX) # Clip to range
+ vmin *= 0.9
+
+ return vmin, vmax
+
+
+def scale1DRange(min_, max_, center, scale, isLog):
+ """Scale a 1D range given a scale factor and an center point.
+
+ Keeps the values in a smaller range than float32.
+
+ :param float min_: The current min value of the range.
+ :param float max_: The current max value of the range.
+ :param float center: The center of the zoom (i.e., invariant point).
+ :param float scale: The scale to use for zoom
+ :param bool isLog: Whether using log scale or not.
+ :return: The zoomed range.
+ :rtype: tuple of 2 floats: (min, max)
+ """
+ if isLog:
+ # Min and center can be < 0 when
+ # autoscale is off and switch to log scale
+ # max_ < 0 should not happen
+ min_ = numpy.log10(min_) if min_ > 0. else FLOAT32_MINPOS
+ center = numpy.log10(center) if center > 0. else FLOAT32_MINPOS
+ max_ = numpy.log10(max_) if max_ > 0. else FLOAT32_MINPOS
+
+ if min_ == max_:
+ return min_, max_
+
+ offset = (center - min_) / (max_ - min_)
+ range_ = (max_ - min_) / scale
+ newMin = center - offset * range_
+ newMax = center + (1. - offset) * range_
+
+ if isLog:
+ # No overflow as exponent is log10 of a float32
+ newMin = pow(10., newMin)
+ newMax = pow(10., newMax)
+ newMin = numpy.clip(newMin, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
+ newMax = numpy.clip(newMax, FLOAT32_MINPOS, FLOAT32_SAFE_MAX)
+ else:
+ newMin = numpy.clip(newMin, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
+ newMax = numpy.clip(newMax, FLOAT32_SAFE_MIN, FLOAT32_SAFE_MAX)
+ return newMin, newMax
+
+
+def applyZoomToPlot(plot, scaleF, center=None):
+ """Zoom in/out plot given a scale and a center point.
+
+ :param plot: The plot on which to apply zoom.
+ :param float scaleF: Scale factor of zoom.
+ :param center: (x, y) coords in pixel coordinates of the zoom center.
+ :type center: 2-tuple of float
+ """
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+
+ if center is None:
+ left, top, width, height = plot.getPlotBoundsInPixels()
+ cx, cy = left + width // 2, top + height // 2
+ else:
+ cx, cy = center
+
+ dataCenterPos = plot.pixelToData(cx, cy)
+ assert dataCenterPos is not None
+
+ xMin, xMax = scale1DRange(xMin, xMax, dataCenterPos[0], scaleF,
+ plot.getXAxis()._isLogarithmic())
+
+ yMin, yMax = scale1DRange(yMin, yMax, dataCenterPos[1], scaleF,
+ plot.getYAxis()._isLogarithmic())
+
+ dataPos = plot.pixelToData(cx, cy, axis="right")
+ assert dataPos is not None
+ y2Center = dataPos[1]
+ y2Min, y2Max = plot.getYAxis(axis="right").getLimits()
+ y2Min, y2Max = scale1DRange(y2Min, y2Max, y2Center, scaleF,
+ plot.getYAxis()._isLogarithmic())
+
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+
+def applyPan(min_, max_, panFactor, isLog10):
+ """Returns a new range with applied panning.
+
+ Moves the range according to panFactor.
+ If isLog10 is True, converts to log10 before moving.
+
+ :param float min_: Min value of the data range to pan.
+ :param float max_: Max value of the data range to pan.
+ Must be >= min.
+ :param float panFactor: Signed proportion of the range to use for pan.
+ :param bool isLog10: True if log10 scale, False if linear scale.
+ :return: New min and max value with pan applied.
+ :rtype: 2-tuple of float.
+ """
+ if isLog10 and min_ > 0.:
+ # Negative range and log scale can happen with matplotlib
+ logMin, logMax = math.log10(min_), math.log10(max_)
+ logOffset = panFactor * (logMax - logMin)
+ newMin = pow(10., logMin + logOffset)
+ newMax = pow(10., logMax + logOffset)
+
+ # Takes care of out-of-range values
+ if newMin > 0. and newMax < float('inf'):
+ min_, max_ = newMin, newMax
+
+ else:
+ offset = panFactor * (max_ - min_)
+ newMin, newMax = min_ + offset, max_ + offset
+
+ # Takes care of out-of-range values
+ if newMin > - float('inf') and newMax < float('inf'):
+ min_, max_ = newMin, newMax
+ return min_, max_
+
+
+class _Unset(object):
+ """To be able to have distinction between None and unset"""
+ pass
+
+
+class ViewConstraints(object):
+ """
+ Store constraints applied on the view box and compute the resulting view box.
+ """
+
+ def __init__(self):
+ self._min = [None, None]
+ self._max = [None, None]
+ self._minRange = [None, None]
+ self._maxRange = [None, None]
+
+ def update(self, xMin=_Unset, xMax=_Unset,
+ yMin=_Unset, yMax=_Unset,
+ minXRange=_Unset, maxXRange=_Unset,
+ minYRange=_Unset, maxYRange=_Unset):
+ """
+ Update the constraints managed by the object
+
+ The constraints are the same as the ones provided by PyQtGraph.
+
+ :param float xMin: Minimum allowed x-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float xMax: Maximum allowed x-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float yMin: Minimum allowed y-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float yMax: Maximum allowed y-axis value.
+ (default do not change the stat, None remove the constraint)
+ :param float minXRange: Minimum allowed left-to-right span across the
+ view (default do not change the stat, None remove the constraint)
+ :param float maxXRange: Maximum allowed left-to-right span across the
+ view (default do not change the stat, None remove the constraint)
+ :param float minYRange: Minimum allowed top-to-bottom span across the
+ view (default do not change the stat, None remove the constraint)
+ :param float maxYRange: Maximum allowed top-to-bottom span across the
+ view (default do not change the stat, None remove the constraint)
+ :return: True if the constraints was changed
+ """
+ updated = False
+
+ minRange = [minXRange, minYRange]
+ maxRange = [maxXRange, maxYRange]
+ minPos = [xMin, yMin]
+ maxPos = [xMax, yMax]
+
+ for axis in range(2):
+
+ value = minPos[axis]
+ if value is not _Unset and value != self._min[axis]:
+ self._min[axis] = value
+ updated = True
+
+ value = maxPos[axis]
+ if value is not _Unset and value != self._max[axis]:
+ self._max[axis] = value
+ updated = True
+
+ value = minRange[axis]
+ if value is not _Unset and value != self._minRange[axis]:
+ self._minRange[axis] = value
+ updated = True
+
+ value = maxRange[axis]
+ if value is not _Unset and value != self._maxRange[axis]:
+ self._maxRange[axis] = value
+ updated = True
+
+ # Sanity checks
+
+ for axis in range(2):
+ if self._maxRange[axis] is not None and self._min[axis] is not None and self._max[axis] is not None:
+ # max range cannot be larger than bounds
+ diff = self._max[axis] - self._min[axis]
+ self._maxRange[axis] = min(self._maxRange[axis], diff)
+ updated = True
+
+ return updated
+
+ def normalize(self, xMin, xMax, yMin, yMax, allow_scaling=True):
+ """Normalize a view range defined by x and y corners using predefined
+ containts.
+
+ :param float xMin: Min position of the x-axis
+ :param float xMax: Max position of the x-axis
+ :param float yMin: Min position of the y-axis
+ :param float yMax: Max position of the y-axis
+ :param bool allow_scaling: Allow or not to apply scaling for the
+ normalization. Used according to the interaction mode.
+ :return: A normalized tuple of (xMin, xMax, yMin, yMax)
+ """
+ viewRange = [[xMin, xMax], [yMin, yMax]]
+
+ for axis in range(2):
+ # clamp xRange and yRange
+ if allow_scaling:
+ diff = viewRange[axis][1] - viewRange[axis][0]
+ delta = None
+ if self._maxRange[axis] is not None and diff > self._maxRange[axis]:
+ delta = self._maxRange[axis] - diff
+ elif self._minRange[axis] is not None and diff < self._minRange[axis]:
+ delta = self._minRange[axis] - diff
+ if delta is not None:
+ viewRange[axis][0] -= delta * 0.5
+ viewRange[axis][1] += delta * 0.5
+
+ # clamp min and max positions
+ outMin = self._min[axis] is not None and viewRange[axis][0] < self._min[axis]
+ outMax = self._max[axis] is not None and viewRange[axis][1] > self._max[axis]
+
+ if outMin and outMax:
+ if allow_scaling:
+ # we can clamp both sides
+ viewRange[axis][0] = self._min[axis]
+ viewRange[axis][1] = self._max[axis]
+ else:
+ # center the result
+ delta = viewRange[axis][1] - viewRange[axis][0]
+ mid = self._min[axis] + self._max[axis] - self._min[axis]
+ viewRange[axis][0] = mid - delta
+ viewRange[axis][1] = mid + delta
+ elif outMin:
+ delta = self._min[axis] - viewRange[axis][0]
+ viewRange[axis][0] += delta
+ viewRange[axis][1] += delta
+ elif outMax:
+ delta = self._max[axis] - viewRange[axis][1]
+ viewRange[axis][0] += delta
+ viewRange[axis][1] += delta
+
+ return viewRange[0][0], viewRange[0][1], viewRange[1][0], viewRange[1][1]
diff --git a/silx/gui/plot/_utils/setup.py b/src/silx/gui/plot/_utils/setup.py
index 0271745..0271745 100644
--- a/silx/gui/plot/_utils/setup.py
+++ b/src/silx/gui/plot/_utils/setup.py
diff --git a/src/silx/gui/plot/_utils/test/__init__.py b/src/silx/gui/plot/_utils/test/__init__.py
new file mode 100644
index 0000000..3ad225d
--- /dev/null
+++ b/src/silx/gui/plot/_utils/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
new file mode 100644
index 0000000..8d35acf
--- /dev/null
+++ b/src/silx/gui/plot/_utils/test/test_dtime_ticklayout.py
@@ -0,0 +1,79 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["P. Kenter"]
+__license__ = "MIT"
+__date__ = "06/04/2018"
+
+
+import datetime as dt
+import unittest
+
+
+from silx.gui.plot._utils.dtime_ticklayout import (
+ calcTicks, DtUnit, SECONDS_PER_YEAR)
+
+
+class TestTickLayout(unittest.TestCase):
+ """Test ticks layout algorithms"""
+
+ def testSmallMonthlySpacing(self):
+ """ Tests a range that did result in a spacing of less than 1 month.
+ It is impossible to add fractional month so the unit must be in days
+ """
+ from dateutil import parser
+ d1 = parser.parse("2017-01-03 13:15:06.000044")
+ d2 = parser.parse("2017-03-08 09:16:16.307584")
+ _ticks, _units, spacing = calcTicks(d1, d2, nTicks=4)
+
+ self.assertEqual(spacing, DtUnit.DAYS)
+
+
+ def testNoCrash(self):
+ """ Creates many combinations of and number-of-ticks and end-dates;
+ tests that it doesn't give an exception and returns a reasonable number
+ of ticks.
+ """
+ d1 = dt.datetime(2017, 1, 3, 13, 15, 6, 44)
+
+ value = 100e-6 # Start at 100 micro sec range.
+
+ while value <= 200 * SECONDS_PER_YEAR:
+
+ d2 = d1 + dt.timedelta(microseconds=value*1e6) # end date range
+
+ for numTicks in range(2, 12):
+ ticks, _, _ = calcTicks(d1, d2, numTicks)
+
+ margin = 2.5
+ self.assertTrue(
+ numTicks/margin <= len(ticks) <= numTicks*margin,
+ "Condition {} <= {} <= {} failed for # ticks={} and d2={}:"
+ .format(numTicks/margin, len(ticks), numTicks * margin,
+ numTicks, d2))
+
+ value = value * 1.5 # let date period grow exponentially
diff --git a/src/silx/gui/plot/_utils/test/test_ticklayout.py b/src/silx/gui/plot/_utils/test/test_ticklayout.py
new file mode 100644
index 0000000..884b71b
--- /dev/null
+++ b/src/silx/gui/plot/_utils/test/test_ticklayout.py
@@ -0,0 +1,81 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+
+from silx.gui.plot._utils import ticklayout
+
+
+class TestTickLayout(ParametricTestCase):
+ """Test ticks layout algorithms"""
+
+ def testTicks(self):
+ """Test of :func:`ticks`"""
+ tests = { # (vmin, vmax): ref_ticks
+ (1., 1.): (1.,),
+ (0.5, 10.5): (2.0, 4.0, 6.0, 8.0, 10.0),
+ (0.001, 0.005): (0.001, 0.002, 0.003, 0.004, 0.005)
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks, labels = ticklayout.ticks(vmin, vmax)
+ self.assertTrue(numpy.allclose(ticks, ref_ticks))
+
+ def testNiceNumbers(self):
+ """Minimalistic tests of :func:`niceNumbers`"""
+ tests = { # (vmin, vmax): ref_ticks
+ (0.5, 10.5): (0.0, 12.0, 2.0, 0),
+ (10000., 10000.5): (10000.0, 10000.5, 0.1, 1),
+ (0.001, 0.005): (0.001, 0.005, 0.001, 3)
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks = ticklayout.niceNumbers(vmin, vmax)
+ self.assertEqual(ticks, ref_ticks)
+
+ def testNiceNumbersLog(self):
+ """Minimalistic tests of :func:`niceNumbersForLog10`"""
+ tests = { # (log10(min), log10(max): ref_ticks
+ (0., 3.): (0, 3, 1, 0),
+ (-3., 3): (-3, 3, 1, 0),
+ (-32., 0.): (-36, 0, 6, 0)
+ }
+
+ for (vmin, vmax), ref_ticks in tests.items():
+ with self.subTest(vmin=vmin, vmax=vmax):
+ ticks = ticklayout.niceNumbersForLog10(vmin, vmax)
+ self.assertEqual(ticks, ref_ticks)
diff --git a/silx/gui/plot/_utils/ticklayout.py b/src/silx/gui/plot/_utils/ticklayout.py
index c9fd3e6..c9fd3e6 100644
--- a/silx/gui/plot/_utils/ticklayout.py
+++ b/src/silx/gui/plot/_utils/ticklayout.py
diff --git a/silx/gui/plot/actions/PlotAction.py b/src/silx/gui/plot/actions/PlotAction.py
index 2983775..2983775 100644
--- a/silx/gui/plot/actions/PlotAction.py
+++ b/src/silx/gui/plot/actions/PlotAction.py
diff --git a/silx/gui/plot/actions/PlotToolAction.py b/src/silx/gui/plot/actions/PlotToolAction.py
index fbb0b0f..fbb0b0f 100644
--- a/silx/gui/plot/actions/PlotToolAction.py
+++ b/src/silx/gui/plot/actions/PlotToolAction.py
diff --git a/silx/gui/plot/actions/__init__.py b/src/silx/gui/plot/actions/__init__.py
index 930c728..930c728 100644
--- a/silx/gui/plot/actions/__init__.py
+++ b/src/silx/gui/plot/actions/__init__.py
diff --git a/silx/gui/plot/actions/control.py b/src/silx/gui/plot/actions/control.py
index 439985e..439985e 100755
--- a/silx/gui/plot/actions/control.py
+++ b/src/silx/gui/plot/actions/control.py
diff --git a/src/silx/gui/plot/actions/fit.py b/src/silx/gui/plot/actions/fit.py
new file mode 100644
index 0000000..e130b24
--- /dev/null
+++ b/src/silx/gui/plot/actions/fit.py
@@ -0,0 +1,485 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.fit` module provides actions relative to fit.
+
+The following QAction are available:
+
+- :class:`.FitAction`
+
+.. autoclass:`.FitAction`
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "10/10/2018"
+
+import logging
+import sys
+import weakref
+import numpy
+
+from .PlotToolAction import PlotToolAction
+from .. import items
+from ....utils.deprecation import deprecated
+from silx.gui import qt
+from silx.gui.plot.ItemsSelectionDialog import ItemsSelectionDialog
+
+_logger = logging.getLogger(__name__)
+
+
+def _getUniqueCurveOrHistogram(plot):
+ """Returns unique :class:`Curve` or :class:`Histogram` in a `PlotWidget`.
+
+ If there is an active curve, returns it, else return curve or histogram
+ only if alone in the plot.
+
+ :param PlotWidget plot:
+ :rtype: Union[None,~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram]
+ """
+ curve = plot.getActiveCurve()
+ if curve is not None:
+ return curve
+
+ visibleItems = [item for item in plot.getItems() if item.isVisible()]
+ histograms = [item for item in visibleItems
+ if isinstance(item, items.Histogram)]
+ curves = [item for item in visibleItems
+ if isinstance(item, items.Curve)]
+
+ if len(histograms) == 1 and len(curves) == 0:
+ return histograms[0]
+ elif len(curves) == 1 and len(histograms) == 0:
+ return curves[0]
+ else:
+ return None
+
+
+class _FitItemSelector(qt.QObject):
+ """
+ :class:`PlotWidget` observer that emits signal when fit selection changes.
+
+ Track active curve or unique curve or histogram.
+ """
+
+ sigCurrentItemChanged = qt.Signal(object)
+ """Signal emitted when the item to fit has changed"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__plotWidgetRef = None
+ self.__currentItem = None
+
+ def getCurrentItem(self):
+ """Return currently selected item
+
+ :rtype: Union[Item,None]
+ """
+ return self.__currentItem
+
+ def getPlotWidget(self):
+ """Return currently attached :class:`PlotWidget`
+
+ :rtype: Union[PlotWidget,None]
+ """
+ return None if self.__plotWidgetRef is None else self.__plotWidgetRef()
+
+ def setPlotWidget(self, plotWidget):
+ """Set the :class:`PlotWidget` for which to track changes
+
+ :param Union[PlotWidget,None] plotWidget:
+ The :class:`PlotWidget` to observe
+ """
+ # disconnect from previous plot
+ previousPlotWidget = self.getPlotWidget()
+ if previousPlotWidget is not None:
+ previousPlotWidget.sigItemAdded.disconnect(
+ self.__plotWidgetUpdated)
+ previousPlotWidget.sigItemRemoved.disconnect(
+ self.__plotWidgetUpdated)
+ previousPlotWidget.sigActiveCurveChanged.disconnect(
+ self.__plotWidgetUpdated)
+
+ if plotWidget is None:
+ self.__plotWidgetRef = None
+ self.__setCurrentItem(None)
+ return
+ self.__plotWidgetRef = weakref.ref(plotWidget, self.__plotDeleted)
+
+ # connect to new plot
+ plotWidget.sigItemAdded.connect(self.__plotWidgetUpdated)
+ plotWidget.sigItemRemoved.connect(self.__plotWidgetUpdated)
+ plotWidget.sigActiveCurveChanged.connect(self.__plotWidgetUpdated)
+ self.__plotWidgetUpdated()
+
+ def __plotDeleted(self):
+ """Handle deletion of PlotWidget"""
+ self.__setCurrentItem(None)
+
+ def __plotWidgetUpdated(self, *args, **kwargs):
+ """Handle updates of PlotWidget content"""
+ plotWidget = self.getPlotWidget()
+ if plotWidget is None:
+ return
+ self.__setCurrentItem(_getUniqueCurveOrHistogram(plotWidget))
+
+ def __setCurrentItem(self, item):
+ """Handle change of current item"""
+ if sys.is_finalizing():
+ return
+
+ previousItem = self.getCurrentItem()
+ if item != previousItem:
+ if previousItem is not None:
+ previousItem.sigItemChanged.disconnect(self.__itemUpdated)
+
+ self.__currentItem = item
+
+ if self.__currentItem is not None:
+ self.__currentItem.sigItemChanged.connect(self.__itemUpdated)
+ self.sigCurrentItemChanged.emit(self.__currentItem)
+
+ def __itemUpdated(self, event):
+ """Handle change on current item"""
+ if event == items.ItemChangedType.DATA:
+ self.sigCurrentItemChanged.emit(self.__currentItem)
+
+
+class FitAction(PlotToolAction):
+ """QAction to open a :class:`FitWidget` and set its data to the
+ active curve if any, or to the first curve.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ self.__item = None
+ self.__activeCurveSynchroEnabled = False
+ self.__range = 0, 1
+ self.__rangeAutoUpdate = False
+ self.__x, self.__y = None, None # Data to fit
+ self.__curveParams = {} # Store curve parameters to use for fit result
+ self.__legend = None
+
+ super(FitAction, self).__init__(
+ plot, icon='math-fit', text='Fit curve',
+ tooltip='Open a fit dialog',
+ parent=parent)
+
+ self.__fitItemSelector = _FitItemSelector()
+ self.__fitItemSelector.sigCurrentItemChanged.connect(
+ self._setFittedItem)
+
+
+ @property
+ @deprecated(replacement='getXRange()[0]', since_version='0.13.0')
+ def xmin(self):
+ return self.getXRange()[0]
+
+ @property
+ @deprecated(replacement='getXRange()[1]', since_version='0.13.0')
+ def xmax(self):
+ return self.getXRange()[1]
+
+ @property
+ @deprecated(replacement='getXData()', since_version='0.13.0')
+ def x(self):
+ return self.getXData()
+
+ @property
+ @deprecated(replacement='getYData()', since_version='0.13.0')
+ def y(self):
+ return self.getYData()
+
+ @property
+ @deprecated(since_version='0.13.0')
+ def xlabel(self):
+ return self.__curveParams.get('xlabel', None)
+
+ @property
+ @deprecated(since_version='0.13.0')
+ def ylabel(self):
+ return self.__curveParams.get('ylabel', None)
+
+ @property
+ @deprecated(since_version='0.13.0')
+ def legend(self):
+ return self.__legend
+
+ def _createToolWindow(self):
+ # import done here rather than at module level to avoid circular import
+ # FitWidget -> BackgroundWidget -> PlotWindow -> actions -> fit -> FitWidget
+ from ...fit.FitWidget import FitWidget
+
+ window = FitWidget(parent=self.plot)
+ window.setWindowFlags(qt.Qt.Dialog)
+ window.sigFitWidgetSignal.connect(self.handle_signal)
+ return window
+
+ def _connectPlot(self, window):
+ if self.isXRangeUpdatedOnZoom():
+ self.__setAutoXRangeEnabled(True)
+ else:
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+ self._setXRange(*plot.getXAxis().getLimits())
+
+ if self.isFittedItemUpdatedFromActiveCurve():
+ self.__setFittedItemAutoUpdateEnabled(True)
+ else:
+ # Wait for the next iteration, else the plot is not yet initialized
+ # No curve available
+ qt.QTimer.singleShot(10, self._initFit)
+
+ def _disconnectPlot(self, window):
+ if self.isXRangeUpdatedOnZoom():
+ self.__setAutoXRangeEnabled(False)
+
+ if self.isFittedItemUpdatedFromActiveCurve():
+ self.__setFittedItemAutoUpdateEnabled(False)
+
+ def _initFit(self):
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+
+ item = _getUniqueCurveOrHistogram(plot)
+ if item is None:
+ # ambiguous case, we need to ask which plot item to fit
+ isd = ItemsSelectionDialog(parent=plot, plot=plot)
+ isd.setWindowTitle("Select item to be fitted")
+ isd.setItemsSelectionMode(qt.QTableWidget.SingleSelection)
+ isd.setAvailableKinds(["curve", "histogram"])
+ isd.selectAllKinds()
+
+ if not isd.exec(): # Cancel
+ self._getToolWindow().setVisible(False)
+ else:
+ selectedItems = isd.getSelectedItems()
+ item = selectedItems[0] if len(selectedItems) == 1 else None
+
+ self._setXRange(*plot.getXAxis().getLimits())
+ self._setFittedItem(item)
+
+ def __updateFitWidget(self):
+ """Update the data/range used by the FitWidget"""
+ fitWidget = self._getToolWindow()
+
+ item = self._getFittedItem()
+ xdata = self.getXData(copy=False)
+ ydata = self.getYData(copy=False)
+ if item is None or xdata is None or ydata is None:
+ fitWidget.setData(y=None)
+ fitWidget.setWindowTitle("No curve selected")
+
+ else:
+ xmin, xmax = self.getXRange()
+ fitWidget.setData(
+ xdata, ydata, xmin=xmin, xmax=xmax)
+ fitWidget.setWindowTitle(
+ "Fitting " + item.getName() +
+ " on x range %f-%f" % (xmin, xmax))
+
+ # X Range management
+
+ def getXRange(self):
+ """Returns the range on the X axis on which to perform the fit."""
+ return self.__range
+
+ def _setXRange(self, xmin, xmax):
+ """Set the range on which the fit is done.
+
+ :param float xmin:
+ :param float xmax:
+ """
+ range_ = float(xmin), float(xmax)
+ if self.__range != range_:
+ self.__range = range_
+ self.__updateFitWidget()
+
+ def __setAutoXRangeEnabled(self, enabled):
+ """Implement the change of update mode of the X range.
+
+ :param bool enabled:
+ """
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+
+ if enabled:
+ self._setXRange(*plot.getXAxis().getLimits())
+ plot.getXAxis().sigLimitsChanged.connect(self._setXRange)
+ else:
+ plot.getXAxis().sigLimitsChanged.disconnect(self._setXRange)
+
+ def setXRangeUpdatedOnZoom(self, enabled):
+ """Set whether or not to update the X range on zoom change.
+
+ :param bool enabled:
+ """
+ if enabled != self.__rangeAutoUpdate:
+ self.__rangeAutoUpdate = enabled
+ if self._getToolWindow().isVisible():
+ self.__setAutoXRangeEnabled(enabled)
+
+ def isXRangeUpdatedOnZoom(self):
+ """Returns the current mode of fitted data X range update.
+
+ :rtype: bool
+ """
+ return self.__rangeAutoUpdate
+
+ # Fitted item update
+
+ def getXData(self, copy=True):
+ """Returns the X data used for the fit or None if undefined.
+
+ :param bool copy:
+ True to get a copy of the data, False to get the internal data.
+ :rtype: Union[numpy.ndarray,None]
+ """
+ return None if self.__x is None else numpy.array(self.__x, copy=copy)
+
+ def getYData(self, copy=True):
+ """Returns the Y data used for the fit or None if undefined.
+
+ :param bool copy:
+ True to get a copy of the data, False to get the internal data.
+ :rtype: Union[numpy.ndarray,None]
+ """
+ return None if self.__y is None else numpy.array(self.__y, copy=copy)
+
+ def _getFittedItem(self):
+ """Returns the current item used for the fit
+
+ :rtype: Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None]
+ """
+ return self.__item
+
+ def _setFittedItem(self, item):
+ """Set the curve to use for fitting.
+
+ :param Union[~silx.gui.plot.items.Curve,~silx.gui.plot.items.Histogram,None] item:
+ """
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+
+ if plot is None or item is None:
+ self.__item = None
+ self.__curveParams = {}
+ self.__updateFitWidget()
+ return
+
+ axis = item.getYAxis() if isinstance(item, items.YAxisMixIn) else 'left'
+ self.__curveParams = {
+ 'yaxis': axis,
+ 'xlabel': plot.getXAxis().getLabel(),
+ 'ylabel': plot.getYAxis(axis).getLabel(),
+ }
+ self.__legend = item.getName()
+
+ if isinstance(item, items.Histogram):
+ bin_edges = item.getBinEdgesData(copy=False)
+ # take the middle coordinate between adjacent bin edges
+ self.__x = (bin_edges[1:] + bin_edges[:-1]) / 2
+ self.__y = item.getValueData(copy=False)
+ # else take the active curve, or else the unique curve
+ elif isinstance(item, items.Curve):
+ self.__x = item.getXData(copy=False)
+ self.__y = item.getYData(copy=False)
+
+ self.__item = item
+ self.__updateFitWidget()
+
+ def __setFittedItemAutoUpdateEnabled(self, enabled):
+ """Implement the change of fitted item update mode
+
+ :param bool enabled:
+ """
+ plot = self.plot
+ if plot is None:
+ _logger.error("No associated PlotWidget")
+ return
+
+ self.__fitItemSelector.setPlotWidget(self.plot if enabled else None)
+
+ def setFittedItemUpdatedFromActiveCurve(self, enabled):
+ """Toggle fitted data synchronization with plot active curve.
+
+ :param bool enabled:
+ """
+ enabled = bool(enabled)
+ if enabled != self.__activeCurveSynchroEnabled:
+ self.__activeCurveSynchroEnabled = enabled
+ if self._getToolWindow().isVisible():
+ self.__setFittedItemAutoUpdateEnabled(enabled)
+
+ def isFittedItemUpdatedFromActiveCurve(self):
+ """Returns True if fitted data is synchronized with plot.
+
+ :rtype: bool
+ """
+ return self.__activeCurveSynchroEnabled
+
+ # Handle fit completed
+
+ def handle_signal(self, ddict):
+ xdata = self.getXData(copy=False)
+ if xdata is None:
+ _logger.error("No reference data to display fit result for")
+ return
+
+ xmin, xmax = self.getXRange()
+ x_fit = xdata[xmin <= xdata]
+ x_fit = x_fit[x_fit <= xmax]
+ fit_legend = "Fit <%s>" % self.__legend
+ fit_curve = self.plot.getCurve(fit_legend)
+
+ if ddict["event"] == "FitFinished":
+ fit_widget = self._getToolWindow()
+ if fit_widget is None:
+ return
+ y_fit = fit_widget.fitmanager.gendata()
+ if fit_curve is None:
+ self.plot.addCurve(x_fit, y_fit,
+ fit_legend,
+ resetzoom=False,
+ **self.__curveParams)
+ else:
+ fit_curve.setData(x_fit, y_fit)
+ fit_curve.setVisible(True)
+ fit_curve.setYAxis(self.__curveParams.get('yaxis', 'left'))
+
+ if ddict["event"] in ["FitStarted", "FitFailed"]:
+ if fit_curve is not None:
+ fit_curve.setVisible(False)
diff --git a/src/silx/gui/plot/actions/histogram.py b/src/silx/gui/plot/actions/histogram.py
new file mode 100644
index 0000000..be9f5a7
--- /dev/null
+++ b/src/silx/gui/plot/actions/histogram.py
@@ -0,0 +1,542 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.histogram` provides actions relative to histograms
+for :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`PixelIntensitiesHistoAction`
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__date__ = "01/12/2020"
+__license__ = "MIT"
+
+from typing import Optional, Tuple
+import numpy
+import logging
+import weakref
+
+from .PlotToolAction import PlotToolAction
+
+from silx.math.histogram import Histogramnd
+from silx.math.combo import min_max
+from silx.gui import qt
+from silx.gui.plot import items
+from silx.gui.widgets.ElidedLabel import ElidedLabel
+from silx.gui.widgets.RangeSlider import RangeSlider
+from silx.utils.deprecation import deprecated
+
+_logger = logging.getLogger(__name__)
+
+
+class _ElidedLabel(ElidedLabel):
+ """QLabel with a default size larger than what is displayed."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ def sizeHint(self):
+ hint = super().sizeHint()
+ nbchar = max(len(self.getText()), 12)
+ width = self.fontMetrics().boundingRect('#' * nbchar).width()
+ return qt.QSize(max(hint.width(), width), hint.height())
+
+
+class _StatWidget(qt.QWidget):
+ """Widget displaying a name and a value
+
+ :param parent:
+ :param name:
+ """
+
+ def __init__(self, parent=None, name: str=''):
+ super().__init__(parent)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ keyWidget = qt.QLabel(parent=self)
+ keyWidget.setText("<b>" + name.capitalize() + ":<b>")
+ layout.addWidget(keyWidget)
+ self.__valueWidget = _ElidedLabel(parent=self)
+ self.__valueWidget.setText("-")
+ self.__valueWidget.setTextInteractionFlags(
+ qt.Qt.TextSelectableByMouse | qt.Qt.TextSelectableByKeyboard)
+ layout.addWidget(self.__valueWidget)
+
+ def setValue(self, value: Optional[float]):
+ """Set the displayed value
+
+ :param value:
+ """
+ self.__valueWidget.setText(
+ "-" if value is None else "{:.5g}".format(value))
+
+
+class _IntEdit(qt.QLineEdit):
+ """QLineEdit for integers with a default value and update on validation.
+
+ :param QWidget parent:
+ """
+
+ sigValueChanged = qt.Signal(int)
+ """Signal emitted when the value has changed (on editing finished)"""
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.__value = None
+ self.setAlignment(qt.Qt.AlignRight)
+ validator = qt.QIntValidator()
+ self.setValidator(validator)
+ validator.bottomChanged.connect(self.__updateSize)
+ validator.topChanged.connect(self.__updateSize)
+ self.__updateSize()
+
+ self.textEdited.connect(self.__textEdited)
+
+ def __updateSize(self, *args):
+ """Update widget's maximum size according to bounds"""
+ bottom, top = self.getRange()
+ nbchar = max(len(str(bottom)), len(str(top)))
+ font = self.font()
+ font.setStyle(qt.QFont.StyleItalic)
+ fontMetrics = qt.QFontMetrics(font)
+ self.setMaximumWidth(
+ fontMetrics.boundingRect('0' * (nbchar + 1)).width()
+ )
+ self.setMaxLength(nbchar)
+
+ def __textEdited(self, _):
+ if self.font().style() != qt.QFont.StyleItalic:
+ font = self.font()
+ font.setStyle(qt.QFont.StyleItalic)
+ self.setFont(font)
+
+ # Use events rather than editingFinished to also trigger with empty text
+
+ def focusOutEvent(self, event):
+ self.__commitValue()
+ return super().focusOutEvent(event)
+
+ def keyPressEvent(self, event):
+ if event.key() in (qt.Qt.Key_Enter, qt.Qt.Key_Return):
+ self.__commitValue()
+ return super().keyPressEvent(event)
+
+ def __commitValue(self):
+ """Update the value returned by :meth:`getValue`"""
+ value = self.getCurrentValue()
+ if value is None:
+ value = self.getDefaultValue()
+ if value is None:
+ return # No value, keep previous one
+
+ if self.font().style() != qt.QFont.StyleNormal:
+ font = self.font()
+ font.setStyle(qt.QFont.StyleNormal)
+ self.setFont(font)
+
+ if value != self.__value:
+ self.__value = value
+ self.sigValueChanged.emit(value)
+
+ def getValue(self) -> Optional[int]:
+ """Return current value (None if never set)."""
+ return self.__value
+
+ def setRange(self, bottom: int, top: int):
+ """Set the range of valid values"""
+ self.validator().setRange(bottom, top)
+
+ def getRange(self) -> Tuple[int, int]:
+ """Returns the current range of valid values
+
+ :returns: (bottom, top)
+ """
+ return self.validator().bottom(), self.validator().top()
+
+ def __validate(self, value: int, extend_range: bool):
+ """Ensure value is in range
+
+ :param int value:
+ :param bool extend_range:
+ True to extend range if needed.
+ False to clip value if needed.
+ """
+ if extend_range:
+ bottom, top = self.getRange()
+ self.setRange(min(value, bottom), max(value, top))
+ return numpy.clip(value, *self.getRange())
+
+ def setDefaultValue(self, value: int, extend_range: bool=False):
+ """Set default value when QLineEdit is empty
+
+ :param int value:
+ :param bool extend_range:
+ True to extend range if needed.
+ False to clip value if needed
+ """
+ self.setPlaceholderText(str(self.__validate(value, extend_range)))
+ if self.getCurrentValue() is None:
+ self.__commitValue()
+
+ def getDefaultValue(self) -> Optional[int]:
+ """Return the default value or the bottom one if not set"""
+ try:
+ return int(self.placeholderText())
+ except ValueError:
+ return None
+
+ def setCurrentValue(self, value: int, extend_range: bool=False):
+ """Set the currently displayed value
+
+ :param int value:
+ :param bool extend_range:
+ True to extend range if needed.
+ False to clip value if needed
+ """
+ self.setText(str(self.__validate(value, extend_range)))
+ self.__commitValue()
+
+ def getCurrentValue(self) -> Optional[int]:
+ """Returns the displayed value or None if not correct"""
+ try:
+ return int(self.text())
+ except ValueError:
+ return None
+
+
+class HistogramWidget(qt.QWidget):
+ """Widget displaying a histogram and some statistic indicators"""
+
+ _SUPPORTED_ITEM_CLASS = items.ImageBase, items.Scatter
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setWindowTitle('Histogram')
+
+ self.__itemRef = None # weakref on the item to track
+
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ # Plot
+ # Lazy import to avoid circular dependencies
+ from silx.gui.plot.PlotWindow import Plot1D
+ self.__plot = Plot1D(self)
+ layout.addWidget(self.__plot)
+
+ self.__plot.setDataMargins(0.1, 0.1, 0.1, 0.1)
+ self.__plot.getXAxis().setLabel("Value")
+ self.__plot.getYAxis().setLabel("Count")
+ posInfo = self.__plot.getPositionInfoWidget()
+ posInfo.setSnappingMode(posInfo.SNAPPING_CURVE)
+
+ # Histogram controls
+ controlsWidget = qt.QWidget(self)
+ layout.addWidget(controlsWidget)
+ controlsLayout = qt.QHBoxLayout(controlsWidget)
+ controlsLayout.setContentsMargins(4, 4, 4, 4)
+
+ controlsLayout.addWidget(qt.QLabel("<b>Histogram:<b>"))
+ controlsLayout.addWidget(qt.QLabel("N. bins:"))
+ self.__nbinsLineEdit = _IntEdit(self)
+ self.__nbinsLineEdit.setRange(2, 9999)
+ self.__nbinsLineEdit.sigValueChanged.connect(
+ self.__updateHistogramFromControls)
+ controlsLayout.addWidget(self.__nbinsLineEdit)
+ self.__rangeLabel = qt.QLabel("Range:")
+ controlsLayout.addWidget(self.__rangeLabel)
+ self.__rangeSlider = RangeSlider(parent=self)
+ self.__rangeSlider.sigValueChanged.connect(
+ self.__updateHistogramFromControls)
+ self.__rangeSlider.sigValueChanged.connect(self.__rangeChanged)
+ controlsLayout.addWidget(self.__rangeSlider)
+ controlsLayout.addStretch(1)
+
+ # Stats display
+ statsWidget = qt.QWidget(self)
+ layout.addWidget(statsWidget)
+ statsLayout = qt.QHBoxLayout(statsWidget)
+ statsLayout.setContentsMargins(4, 4, 4, 4)
+
+ self.__statsWidgets = dict(
+ (name, _StatWidget(parent=statsWidget, name=name))
+ for name in ("min", "max", "mean", "std", "sum"))
+
+ for widget in self.__statsWidgets.values():
+ statsLayout.addWidget(widget)
+ statsLayout.addStretch(1)
+
+ def getPlotWidget(self):
+ """Returns :class:`PlotWidget` use to display the histogram"""
+ return self.__plot
+
+ def resetZoom(self):
+ """Reset PlotWidget zoom"""
+ self.getPlotWidget().resetZoom()
+
+ def reset(self):
+ """Clear displayed information"""
+ self.getPlotWidget().clear()
+ self.setStatistics()
+
+ def getItem(self) -> Optional[items.Item]:
+ """Returns item used to display histogram and statistics."""
+ return None if self.__itemRef is None else self.__itemRef()
+
+ def setItem(self, item: Optional[items.Item]):
+ """Set item from which to display histogram and statistics.
+
+ :param item:
+ """
+ previous = self.getItem()
+ if previous is not None:
+ previous.sigItemChanged.disconnect(self.__itemChanged)
+
+ self.__itemRef = None if item is None else weakref.ref(item)
+ if item is not None:
+ if isinstance(item, self._SUPPORTED_ITEM_CLASS):
+ # Only listen signal for supported items
+ item.sigItemChanged.connect(self.__itemChanged)
+ self._updateFromItem()
+
+ def __itemChanged(self, event):
+ """Handle update of the item"""
+ if event in (items.ItemChangedType.DATA, items.ItemChangedType.MASK):
+ self._updateFromItem()
+
+ def __updateHistogramFromControls(self, *args):
+ """Handle udates coming from histogram control widgets"""
+
+ hist = self.getHistogram(copy=False)
+ if hist is not None:
+ count, edges = hist
+ if (len(count) == self.__nbinsLineEdit.getValue() and
+ (edges[0], edges[-1]) == self.__rangeSlider.getValues()):
+ return # Nothing has changed
+
+ self._updateFromItem()
+
+ def __rangeChanged(self, first, second):
+ """Handle change of histogram range from the range slider"""
+ tooltip = "Histogram range:\n[%g, %g]" % (first, second)
+ self.__rangeSlider.setToolTip(tooltip)
+ self.__rangeLabel.setToolTip(tooltip)
+
+ def _updateFromItem(self):
+ """Update histogram and stats from the item"""
+ item = self.getItem()
+
+ if item is None:
+ self.reset()
+ return
+
+ if not isinstance(item, self._SUPPORTED_ITEM_CLASS):
+ _logger.error("Unsupported item", item)
+ self.reset()
+ return
+
+ # Compute histogram and stats
+ array = item.getValueData(copy=False)
+
+ if array.size == 0:
+ self.reset()
+ return
+
+ xmin, xmax = min_max(array, min_positive=False, finite=True)
+ if xmin is None or xmax is None: # All not finite data
+ self.reset()
+ return
+ guessed_nbins = min(1024, int(numpy.sqrt(array.size)))
+
+ # bad hack: get 256 bins in the case we have a B&W
+ if numpy.issubdtype(array.dtype, numpy.integer):
+ if guessed_nbins > xmax - xmin:
+ guessed_nbins = xmax - xmin
+ guessed_nbins = max(2, guessed_nbins)
+
+ # Set default nbins
+ self.__nbinsLineEdit.setDefaultValue(guessed_nbins, extend_range=True)
+ # Set slider range: do not keep the range value, but the relative pos.
+ previousPositions = self.__rangeSlider.getPositions()
+ if xmin == xmax: # Enlarge range is none
+ if xmin == 0:
+ range_ = -0.01, 0.01
+ else:
+ range_ = sorted((xmin * .99, xmin * 1.01))
+ else:
+ range_ = xmin, xmax
+
+ self.__rangeSlider.setRange(*range_)
+ self.__rangeSlider.setPositions(*previousPositions)
+
+ histogram = Histogramnd(
+ array.ravel().astype(numpy.float32),
+ n_bins=max(2, self.__nbinsLineEdit.getValue()),
+ histo_range=self.__rangeSlider.getValues(),
+ )
+ if len(histogram.edges) != 1:
+ _logger.error("Error while computing the histogram")
+ self.reset()
+ return
+
+ self.setHistogram(histogram.histo, histogram.edges[0])
+ self.resetZoom()
+ self.setStatistics(
+ min_=xmin,
+ max_=xmax,
+ mean=numpy.nanmean(array),
+ std=numpy.nanstd(array),
+ sum_=numpy.nansum(array))
+
+ def setHistogram(self, histogram, edges):
+ """Set displayed histogram
+
+ :param histogram: Bin values (N)
+ :param edges: Bin edges (N+1)
+ """
+ # Only useful if setHistogram is called directly
+ # TODO
+ #nbins = len(histogram)
+ #if nbins != self.__nbinsLineEdit.getDefaultValue():
+ # self.__nbinsLineEdit.setValue(nbins, extend_range=True)
+ #self.__rangeSlider.setValues(edges[0], edges[-1])
+
+ self.getPlotWidget().addHistogram(
+ histogram=histogram,
+ edges=edges,
+ legend='histogram',
+ fill=True,
+ color='#66aad7',
+ resetzoom=False)
+
+ def getHistogram(self, copy: bool=True):
+ """Returns currently displayed histogram.
+
+ :param copy: True to get a copy,
+ False to get internal representation (Do not modify!)
+ :return: (histogram, edges) or None
+ """
+ for item in self.getPlotWidget().getItems():
+ if item.getName() == 'histogram':
+ return (item.getValueData(copy=copy),
+ item.getBinEdgesData(copy=copy))
+ else:
+ return None
+
+ def setStatistics(self,
+ min_: Optional[float] = None,
+ max_: Optional[float] = None,
+ mean: Optional[float] = None,
+ std: Optional[float] = None,
+ sum_: Optional[float] = None):
+ """Set displayed statistic indicators."""
+ self.__statsWidgets['min'].setValue(min_)
+ self.__statsWidgets['max'].setValue(max_)
+ self.__statsWidgets['mean'].setValue(mean)
+ self.__statsWidgets['std'].setValue(std)
+ self.__statsWidgets['sum'].setValue(sum_)
+
+
+class PixelIntensitiesHistoAction(PlotToolAction):
+ """QAction to plot the pixels intensities diagram
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ PlotToolAction.__init__(self,
+ plot,
+ icon='pixel-intensities',
+ text='pixels intensity',
+ tooltip='Compute image intensity distribution',
+ parent=parent)
+
+ def _connectPlot(self, window):
+ plot = self.plot
+ if plot is not None:
+ selection = plot.selection()
+ selection.sigSelectedItemsChanged.connect(self._selectedItemsChanged)
+ self._updateSelectedItem()
+
+ PlotToolAction._connectPlot(self, window)
+
+ def _disconnectPlot(self, window):
+ plot = self.plot
+ if plot is not None:
+ selection = self.plot.selection()
+ selection.sigSelectedItemsChanged.disconnect(self._selectedItemsChanged)
+
+ PlotToolAction._disconnectPlot(self, window)
+ self.getHistogramWidget().setItem(None)
+
+ def _updateSelectedItem(self):
+ """Synchronises selected item with plot widget."""
+ plot = self.plot
+ if plot is not None:
+ selected = plot.selection().getSelectedItems()
+ # Give priority to image over scatter
+ for klass in (items.ImageBase, items.Scatter):
+ for item in selected:
+ if isinstance(item, klass):
+ # Found a matching item, use it
+ self.getHistogramWidget().setItem(item)
+ return
+ self.getHistogramWidget().setItem(None)
+
+ def _selectedItemsChanged(self):
+ if self._isWindowInUse():
+ self._updateSelectedItem()
+
+ @deprecated(since_version='0.15.0')
+ def computeIntensityDistribution(self):
+ self.getHistogramWidget()._updateFromItem()
+
+ def getHistogramWidget(self):
+ """Returns the widget displaying the histogram"""
+ return self._getToolWindow()
+
+ @deprecated(since_version='0.15.0',
+ replacement='getHistogramWidget().getPlotWidget()')
+ def getHistogramPlotWidget(self):
+ return self._getToolWindow().getPlotWidget()
+
+ def _createToolWindow(self):
+ return HistogramWidget(self.plot, qt.Qt.Window)
+
+ def getHistogram(self) -> Optional[numpy.ndarray]:
+ """Return the last computed histogram
+
+ :return: the histogram displayed in the HistogramWidget
+ """
+ histogram = self.getHistogramWidget().getHistogram()
+ return None if histogram is None else histogram[0]
diff --git a/src/silx/gui/plot/actions/io.py b/src/silx/gui/plot/actions/io.py
new file mode 100644
index 0000000..7f4edd3
--- /dev/null
+++ b/src/silx/gui/plot/actions/io.py
@@ -0,0 +1,819 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+:mod:`silx.gui.plot.actions.io` provides a set of QAction relative of inputs
+and outputs for a :class:`.PlotWidget`.
+
+The following QAction are available:
+
+- :class:`CopyAction`
+- :class:`PrintAction`
+- :class:`SaveAction`
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "25/09/2020"
+
+from . import PlotAction
+from silx.io.utils import save1D, savespec, NEXUS_HDF5_EXT
+from silx.io.nxdata import save_NXdata
+import logging
+import sys
+import os.path
+from collections import OrderedDict
+import traceback
+import numpy
+from silx.utils.deprecation import deprecated
+from silx.gui import qt, printer
+from silx.gui.dialog.GroupDialog import GroupDialog
+from silx.third_party.EdfFile import EdfFile
+from silx.third_party.TiffIO import TiffIO
+from ...utils.image import convertArrayToQImage
+if sys.version_info[0] == 3:
+ from io import BytesIO
+else:
+ import cStringIO as _StringIO
+ BytesIO = _StringIO.StringIO
+
+_logger = logging.getLogger(__name__)
+
+_NEXUS_HDF5_EXT_STR = ' '.join(['*' + ext for ext in NEXUS_HDF5_EXT])
+
+
+def selectOutputGroup(h5filename):
+ """Open a dialog to prompt the user to select a group in
+ which to output data.
+
+ :param str h5filename: name of an existing HDF5 file
+ :rtype: str
+ :return: Name of output group, or None if the dialog was cancelled
+ """
+ dialog = GroupDialog()
+ dialog.addFile(h5filename)
+ dialog.setWindowTitle("Select an output group")
+ if not dialog.exec():
+ return None
+ return dialog.getSelectedDataUrl().data_path()
+
+
+class SaveAction(PlotAction):
+ """QAction for saving Plot content.
+
+ It opens a Save as... dialog.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param parent: See :class:`QAction`.
+ """
+
+ SNAPSHOT_FILTER_SVG = 'Plot Snapshot as SVG (*.svg)'
+ SNAPSHOT_FILTER_PNG = 'Plot Snapshot as PNG (*.png)'
+
+ DEFAULT_ALL_FILTERS = (SNAPSHOT_FILTER_PNG, SNAPSHOT_FILTER_SVG)
+
+ # Dict of curve filters with CSV-like format
+ # Using ordered dict to guarantee filters order
+ # Note: '%.18e' is numpy.savetxt default format
+ CURVE_FILTERS_TXT = OrderedDict((
+ ('Curve as Raw ASCII (*.txt)',
+ {'fmt': '%.18e', 'delimiter': ' ', 'header': False}),
+ ('Curve as ";"-separated CSV (*.csv)',
+ {'fmt': '%.18e', 'delimiter': ';', 'header': True}),
+ ('Curve as ","-separated CSV (*.csv)',
+ {'fmt': '%.18e', 'delimiter': ',', 'header': True}),
+ ('Curve as tab-separated CSV (*.csv)',
+ {'fmt': '%.18e', 'delimiter': '\t', 'header': True}),
+ ('Curve as OMNIC CSV (*.csv)',
+ {'fmt': '%.7E', 'delimiter': ',', 'header': False}),
+ ('Curve as SpecFile (*.dat)',
+ {'fmt': '%.10g', 'delimiter': '', 'header': False})
+ ))
+
+ CURVE_FILTER_NPY = 'Curve as NumPy binary file (*.npy)'
+
+ CURVE_FILTER_NXDATA = 'Curve as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
+
+ DEFAULT_CURVE_FILTERS = list(CURVE_FILTERS_TXT.keys()) + [
+ CURVE_FILTER_NPY, CURVE_FILTER_NXDATA]
+
+ DEFAULT_ALL_CURVES_FILTERS = ("All curves as SpecFile (*.dat)",)
+
+ IMAGE_FILTER_EDF = 'Image data as EDF (*.edf)'
+ IMAGE_FILTER_TIFF = 'Image data as TIFF (*.tif)'
+ IMAGE_FILTER_NUMPY = 'Image data as NumPy binary file (*.npy)'
+ IMAGE_FILTER_ASCII = 'Image data as ASCII (*.dat)'
+ IMAGE_FILTER_CSV_COMMA = 'Image data as ,-separated CSV (*.csv)'
+ IMAGE_FILTER_CSV_SEMICOLON = 'Image data as ;-separated CSV (*.csv)'
+ IMAGE_FILTER_CSV_TAB = 'Image data as tab-separated CSV (*.csv)'
+ IMAGE_FILTER_RGB_PNG = 'Image as PNG (*.png)'
+ IMAGE_FILTER_NXDATA = 'Image as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
+
+ DEFAULT_IMAGE_FILTERS = (IMAGE_FILTER_EDF,
+ IMAGE_FILTER_TIFF,
+ IMAGE_FILTER_NUMPY,
+ IMAGE_FILTER_ASCII,
+ IMAGE_FILTER_CSV_COMMA,
+ IMAGE_FILTER_CSV_SEMICOLON,
+ IMAGE_FILTER_CSV_TAB,
+ IMAGE_FILTER_RGB_PNG,
+ IMAGE_FILTER_NXDATA)
+
+ SCATTER_FILTER_NXDATA = 'Scatter as NXdata (%s)' % _NEXUS_HDF5_EXT_STR
+ DEFAULT_SCATTER_FILTERS = (SCATTER_FILTER_NXDATA,)
+
+ # filters for which we don't want an "overwrite existing file" warning
+ DEFAULT_APPEND_FILTERS = (CURVE_FILTER_NXDATA, IMAGE_FILTER_NXDATA,
+ SCATTER_FILTER_NXDATA)
+
+ def __init__(self, plot, parent=None):
+ self._filters = {
+ 'all': OrderedDict(),
+ 'curve': OrderedDict(),
+ 'curves': OrderedDict(),
+ 'image': OrderedDict(),
+ 'scatter': OrderedDict()}
+
+ self._appendFilters = list(self.DEFAULT_APPEND_FILTERS)
+
+ # Initialize filters
+ for nameFilter in self.DEFAULT_ALL_FILTERS:
+ self.setFileFilter(
+ dataKind='all', nameFilter=nameFilter, func=self._saveSnapshot)
+
+ for nameFilter in self.DEFAULT_CURVE_FILTERS:
+ self.setFileFilter(
+ dataKind='curve', nameFilter=nameFilter, func=self._saveCurve)
+
+ for nameFilter in self.DEFAULT_ALL_CURVES_FILTERS:
+ self.setFileFilter(
+ dataKind='curves', nameFilter=nameFilter, func=self._saveCurves)
+
+ for nameFilter in self.DEFAULT_IMAGE_FILTERS:
+ self.setFileFilter(
+ dataKind='image', nameFilter=nameFilter, func=self._saveImage)
+
+ for nameFilter in self.DEFAULT_SCATTER_FILTERS:
+ self.setFileFilter(
+ dataKind='scatter', nameFilter=nameFilter, func=self._saveScatter)
+
+ super(SaveAction, self).__init__(
+ plot, icon='document-save', text='Save as...',
+ tooltip='Save curve/image/plot snapshot dialog',
+ triggered=self._actionTriggered,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.Save)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ @staticmethod
+ def _errorMessage(informativeText='', parent=None):
+ """Display an error message."""
+ # TODO issue with QMessageBox size fixed and too small
+ msg = qt.QMessageBox(parent)
+ msg.setIcon(qt.QMessageBox.Critical)
+ msg.setInformativeText(informativeText + ' ' + str(sys.exc_info()[1]))
+ msg.setDetailedText(traceback.format_exc())
+ msg.exec()
+
+ def _saveSnapshot(self, plot, filename, nameFilter):
+ """Save a snapshot of the :class:`PlotWindow` widget.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter == self.SNAPSHOT_FILTER_PNG:
+ fileFormat = 'png'
+ elif nameFilter == self.SNAPSHOT_FILTER_SVG:
+ fileFormat = 'svg'
+ else: # Format not supported
+ _logger.error(
+ 'Saving plot snapshot failed: format not supported')
+ return False
+
+ plot.saveGraph(filename, fileFormat=fileFormat)
+ return True
+
+ def _getAxesLabels(self, item):
+ # If curve has no associated label, get the default from the plot
+ xlabel = item.getXLabel() or self.plot.getXAxis().getLabel()
+ ylabel = item.getYLabel() or self.plot.getYAxis().getLabel()
+ return xlabel, ylabel
+
+ def _get1dData(self, item):
+ "provide xdata, [ydata], xlabel, [ylabel] and manages error bars"
+ xlabel, ylabel = self._getAxesLabels(item)
+ x_data = item.getXData(copy=False)
+ y_data = item.getYData(copy=False)
+ x_err = item.getXErrorData(copy=False)
+ y_err = item.getYErrorData(copy=False)
+ labels = [ylabel]
+ data = [y_data]
+
+ if x_err is not None:
+ if numpy.isscalar(x_err):
+ data.append(numpy.zeros_like(y_data) + x_err)
+ labels.append(xlabel + "_errors")
+ elif x_err.ndim == 1:
+ data.append(x_err)
+ labels.append(xlabel + "_errors")
+ elif x_err.ndim == 2:
+ data.append(x_err[0])
+ labels.append(xlabel + "_errors_below")
+ data.append(x_err[1])
+ labels.append(xlabel + "_errors_above")
+
+ if y_err is not None:
+ if numpy.isscalar(y_err):
+ data.append(numpy.zeros_like(y_data) + y_err)
+ labels.append(ylabel + "_errors")
+ elif y_err.ndim == 1:
+ data.append(y_err)
+ labels.append(ylabel + "_errors")
+ elif y_err.ndim == 2:
+ data.append(y_err[0])
+ labels.append(ylabel + "_errors_below")
+ data.append(y_err[1])
+ labels.append(ylabel + "_errors_above")
+ return x_data, data, xlabel, labels
+
+ @staticmethod
+ def _selectWriteableOutputGroup(filename, parent):
+ if os.path.exists(filename) and os.path.isfile(filename) \
+ and os.access(filename, os.W_OK):
+ entryPath = selectOutputGroup(filename)
+ if entryPath is None:
+ _logger.info("Save operation cancelled")
+ return None
+ return entryPath
+ elif not os.path.exists(filename):
+ # create new entry in new file
+ return "/entry"
+ else:
+ SaveAction._errorMessage('Save failed (file access issue)\n', parent=parent)
+ return None
+
+ def _saveCurveAsNXdata(self, curve, filename):
+ entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
+ if entryPath is None:
+ return False
+
+ xlabel, ylabel = self._getAxesLabels(curve)
+
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=curve.getYData(copy=False),
+ axes=[curve.getXData(copy=False)],
+ signal_name="y",
+ axes_names=["x"],
+ signal_long_name=ylabel,
+ axes_long_names=[xlabel],
+ signal_errors=curve.getYErrorData(copy=False),
+ axes_errors=[curve.getXErrorData(copy=True)],
+ title=self.plot.getGraphTitle())
+
+ def _saveCurve(self, plot, filename, nameFilter):
+ """Save a curve from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_CURVE_FILTERS:
+ return False
+
+ # Check if a curve is to be saved
+ curve = plot.getActiveCurve()
+ # before calling _saveCurve, if there is no selected curve, we
+ # make sure there is only one curve on the graph
+ if curve is None:
+ curves = plot.getAllCurves()
+ if not curves:
+ self._errorMessage("No curve to be saved", parent=self.plot)
+ return False
+ curve = curves[0]
+
+ if nameFilter in self.CURVE_FILTERS_TXT:
+ filter_ = self.CURVE_FILTERS_TXT[nameFilter]
+ fmt = filter_['fmt']
+ csvdelim = filter_['delimiter']
+ autoheader = filter_['header']
+ else:
+ # .npy or nxdata
+ fmt, csvdelim, autoheader = ("", "", False)
+
+ if nameFilter == self.CURVE_FILTER_NXDATA:
+ return self._saveCurveAsNXdata(curve, filename)
+
+ xdata, data, xlabel, labels = self._get1dData(curve)
+
+ try:
+ save1D(filename,
+ xdata, data,
+ xlabel, labels,
+ fmt=fmt, csvdelim=csvdelim,
+ autoheader=autoheader)
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+
+ return True
+
+ def _saveCurves(self, plot, filename, nameFilter):
+ """Save all curves from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_ALL_CURVES_FILTERS:
+ return False
+
+ curves = plot.getAllCurves()
+ if not curves:
+ self._errorMessage("No curves to be saved", parent=self.plot)
+ return False
+
+ curve = curves[0]
+ scanno = 1
+ try:
+ xdata, data, xlabel, labels = self._get1dData(curve)
+
+ specfile = savespec(filename,
+ xdata, data,
+ xlabel, labels,
+ fmt="%.7g", scan_number=1, mode="w",
+ write_file_header=True,
+ close_file=False)
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+
+ for curve in curves[1:]:
+ try:
+ scanno += 1
+ xdata, data, xlabel, labels = self._get1dData(curve)
+ specfile = savespec(specfile,
+ xdata, data,
+ xlabel, labels,
+ fmt="%.7g", scan_number=scanno,
+ write_file_header=False,
+ close_file=False)
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+ specfile.close()
+
+ return True
+
+ def _saveImage(self, plot, filename, nameFilter):
+ """Save an image from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_IMAGE_FILTERS:
+ return False
+
+ image = plot.getActiveImage()
+ if image is None:
+ qt.QMessageBox.warning(
+ plot, "No Data", "No image to be saved")
+ return False
+
+ data = image.getData(copy=False)
+
+ # TODO Use silx.io for writing files
+ if nameFilter == self.IMAGE_FILTER_EDF:
+ edfFile = EdfFile(filename, access="w+")
+ edfFile.WriteImage({}, data, Append=0)
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_TIFF:
+ tiffFile = TiffIO(filename, mode='w')
+ tiffFile.writeImage(data, software='silx')
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_NUMPY:
+ try:
+ numpy.save(filename, data)
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_NXDATA:
+ entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
+ if entryPath is None:
+ return False
+ xorigin, yorigin = image.getOrigin()
+ xscale, yscale = image.getScale()
+ xaxis = xorigin + xscale * numpy.arange(data.shape[1])
+ yaxis = yorigin + yscale * numpy.arange(data.shape[0])
+ xlabel, ylabel = self._getAxesLabels(image)
+ interpretation = "image" if len(data.shape) == 2 else "rgba-image"
+
+ return save_NXdata(filename,
+ nxentry_name=entryPath,
+ signal=data,
+ axes=[yaxis, xaxis],
+ signal_name="image",
+ axes_names=["y", "x"],
+ axes_long_names=[ylabel, xlabel],
+ title=plot.getGraphTitle(),
+ interpretation=interpretation)
+
+ elif nameFilter in (self.IMAGE_FILTER_ASCII,
+ self.IMAGE_FILTER_CSV_COMMA,
+ self.IMAGE_FILTER_CSV_SEMICOLON,
+ self.IMAGE_FILTER_CSV_TAB):
+ csvdelim, filetype = {
+ self.IMAGE_FILTER_ASCII: (' ', 'txt'),
+ self.IMAGE_FILTER_CSV_COMMA: (',', 'csv'),
+ self.IMAGE_FILTER_CSV_SEMICOLON: (';', 'csv'),
+ self.IMAGE_FILTER_CSV_TAB: ('\t', 'csv'),
+ }[nameFilter]
+
+ height, width = data.shape
+ rows, cols = numpy.mgrid[0:height, 0:width]
+ try:
+ save1D(filename, rows.ravel(), (cols.ravel(), data.ravel()),
+ filetype=filetype,
+ xlabel='row',
+ ylabels=['column', 'value'],
+ csvdelim=csvdelim,
+ autoheader=True)
+
+ except IOError:
+ self._errorMessage('Save failed\n', parent=self.plot)
+ return False
+ return True
+
+ elif nameFilter == self.IMAGE_FILTER_RGB_PNG:
+ # Get displayed image
+ rgbaImage = image.getRgbaImageData(copy=False)
+ # Convert RGB QImage
+ qimage = convertArrayToQImage(rgbaImage[:, :, :3])
+
+ if qimage.save(filename, 'PNG'):
+ return True
+ else:
+ _logger.error('Failed to save image as %s', filename)
+ qt.QMessageBox.critical(
+ self.parent(),
+ 'Save image as',
+ 'Failed to save image')
+
+ return False
+
+ def _saveScatter(self, plot, filename, nameFilter):
+ """Save an image from the plot.
+
+ :param str filename: The name of the file to write
+ :param str nameFilter: The selected name filter
+ :return: False if format is not supported or save failed,
+ True otherwise.
+ """
+ if nameFilter not in self.DEFAULT_SCATTER_FILTERS:
+ return False
+
+ if nameFilter == self.SCATTER_FILTER_NXDATA:
+ entryPath = self._selectWriteableOutputGroup(filename, parent=self.plot)
+ if entryPath is None:
+ return False
+ scatter = plot.getScatter()
+
+ x = scatter.getXData(copy=False)
+ y = scatter.getYData(copy=False)
+ z = scatter.getValueData(copy=False)
+
+ xerror = scatter.getXErrorData(copy=False)
+ if isinstance(xerror, float):
+ xerror = xerror * numpy.ones(x.shape, dtype=numpy.float32)
+
+ yerror = scatter.getYErrorData(copy=False)
+ if isinstance(yerror, float):
+ yerror = yerror * numpy.ones(x.shape, dtype=numpy.float32)
+
+ xlabel = plot.getGraphXLabel()
+ ylabel = plot.getGraphYLabel()
+
+ return save_NXdata(
+ filename,
+ nxentry_name=entryPath,
+ signal=z,
+ axes=[x, y],
+ signal_name="values",
+ axes_names=["x", "y"],
+ axes_long_names=[xlabel, ylabel],
+ axes_errors=[xerror, yerror],
+ title=plot.getGraphTitle())
+
+ def setFileFilter(self, dataKind, nameFilter, func, index=None, appendToFile=False):
+ """Set a name filter to add/replace a file format support
+
+ :param str dataKind:
+ The kind of data for which the provided filter is valid.
+ One of: 'all', 'curve', 'curves', 'image', 'scatter'
+ :param str nameFilter: The name filter in the QFileDialog.
+ See :meth:`QFileDialog.setNameFilters`.
+ :param callable func: The function to call to perform saving.
+ Expected signature is:
+ bool func(PlotWidget plot, str filename, str nameFilter)
+ :param bool appendToFile: True to append the data into the selected
+ file.
+ :param integer index: Index of the filter in the final list (or None)
+ """
+ assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
+
+ if appendToFile:
+ self._appendFilters.append(nameFilter)
+
+ # first append or replace the new filter to prevent colissions
+ self._filters[dataKind][nameFilter] = func
+ if index is None:
+ # we are already done
+ return
+
+ # get the current ordered list of keys
+ keyList = list(self._filters[dataKind].keys())
+
+ # deal with negative indices
+ if index < 0:
+ index = len(keyList) + index
+ if index < 0:
+ index = 0
+
+ if index >= len(keyList):
+ # nothing to be done, already at the end
+ txt = 'Requested index %d impossible, already at the end' % index
+ _logger.info(txt)
+ return
+
+ # get the new ordered list
+ oldIndex = keyList.index(nameFilter)
+ del keyList[oldIndex]
+ keyList.insert(index, nameFilter)
+
+ # build the new filters
+ newFilters = OrderedDict()
+ for key in keyList:
+ newFilters[key] = self._filters[dataKind][key]
+
+ # and update the filters
+ self._filters[dataKind] = newFilters
+ return
+
+ def getFileFilters(self, dataKind):
+ """Returns the nameFilter and associated function for a kind of data.
+
+ :param str dataKind:
+ The kind of data for which the provided filter is valid.
+ On of: 'all', 'curve', 'curves', 'image', 'scatter'
+ :return: {nameFilter: function} associations.
+ :rtype: collections.OrderedDict
+ """
+ assert dataKind in ('all', 'curve', 'curves', 'image', 'scatter')
+
+ return self._filters[dataKind].copy()
+
+ def _actionTriggered(self, checked=False):
+ """Handle save action."""
+ # Set-up filters
+ filters = OrderedDict()
+
+ # Add image filters if there is an active image
+ if self.plot.getActiveImage() is not None:
+ filters.update(self._filters['image'].items())
+
+ # Add curve filters if there is a curve to save
+ if (self.plot.getActiveCurve() is not None or
+ len(self.plot.getAllCurves()) == 1):
+ filters.update(self._filters['curve'].items())
+ if len(self.plot.getAllCurves()) >= 1:
+ filters.update(self._filters['curves'].items())
+
+ # Add scatter filters if there is a scatter
+ # todo: CSV
+ if self.plot.getScatter() is not None:
+ filters.update(self._filters['scatter'].items())
+
+ filters.update(self._filters['all'].items())
+
+ # Create and run File dialog
+ dialog = qt.QFileDialog(self.plot)
+ dialog.setOption(dialog.DontUseNativeDialog)
+ dialog.setWindowTitle("Output File Selection")
+ dialog.setModal(1)
+ dialog.setNameFilters(list(filters.keys()))
+
+ dialog.setFileMode(dialog.AnyFile)
+ dialog.setAcceptMode(dialog.AcceptSave)
+
+ def onFilterSelection(filt_):
+ # disable overwrite confirmation for NXdata types,
+ # because we append the data to existing files
+ if filt_ in self._appendFilters:
+ dialog.setOption(dialog.DontConfirmOverwrite)
+ else:
+ dialog.setOption(dialog.DontConfirmOverwrite, False)
+
+ dialog.filterSelected.connect(onFilterSelection)
+
+ if not dialog.exec():
+ return False
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ if '(' in nameFilter and ')' == nameFilter.strip()[-1]:
+ # Check for correct file extension
+ # Extract file extensions as .something
+ extensions = [ext[ext.find('.'):] for ext in
+ nameFilter[nameFilter.find('(') + 1:-1].split()]
+ for ext in extensions:
+ if (len(filename) > len(ext) and
+ filename[-len(ext):].lower() == ext.lower()):
+ break
+ else: # filename has no extension supported in nameFilter, add one
+ if len(extensions) >= 1:
+ filename += extensions[0]
+
+ # Handle save
+ func = filters.get(nameFilter, None)
+ if func is not None:
+ return func(self.plot, filename, nameFilter)
+ else:
+ _logger.error('Unsupported file filter: %s', nameFilter)
+ return False
+
+
+def _plotAsPNG(plot):
+ """Save a :class:`Plot` as PNG and return the payload.
+
+ :param plot: The :class:`Plot` to save
+ """
+ pngFile = BytesIO()
+ plot.saveGraph(pngFile, fileFormat='png')
+ pngFile.flush()
+ pngFile.seek(0)
+ data = pngFile.read()
+ pngFile.close()
+ return data
+
+
+class PrintAction(PlotAction):
+ """QAction for printing the plot.
+
+ It opens a Print dialog.
+
+ Current implementation print a bitmap of the plot area and not vector
+ graphics, so printing quality is not great.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate.
+ :param parent: See :class:`QAction`.
+ """
+
+ def __init__(self, plot, parent=None):
+ super(PrintAction, self).__init__(
+ plot, icon='document-print', text='Print...',
+ tooltip='Open print dialog',
+ triggered=self.printPlot,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.Print)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def getPrinter(self):
+ """The QPrinter instance used by the PrintAction.
+
+ :rtype: QPrinter
+ """
+ return printer.getDefaultPrinter()
+
+ @property
+ @deprecated(replacement="getPrinter()", since_version="0.8.0")
+ def printer(self):
+ return self.getPrinter()
+
+ def printPlotAsWidget(self):
+ """Open the print dialog and print the plot.
+
+ Use :meth:`QWidget.render` to print the plot
+
+ :return: True if successful
+ """
+ dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
+ dialog.setWindowTitle('Print Plot')
+ if not dialog.exec():
+ return False
+
+ # Print a snapshot of the plot widget at the top of the page
+ widget = self.plot.centralWidget()
+
+ painter = qt.QPainter()
+ if not painter.begin(self.getPrinter()):
+ return False
+
+ pageRect = self.getPrinter().pageRect(qt.QPrinter.DevicePixel)
+ xScale = pageRect.width() / widget.width()
+ yScale = pageRect.height() / widget.height()
+ scale = min(xScale, yScale)
+
+ painter.translate(pageRect.width() / 2., 0.)
+ painter.scale(scale, scale)
+ painter.translate(-widget.width() / 2., 0.)
+ widget.render(painter)
+ painter.end()
+
+ return True
+
+ def printPlot(self):
+ """Open the print dialog and print the plot.
+
+ Use :meth:`Plot.saveGraph` to print the plot.
+
+ :return: True if successful
+ """
+ # Init printer and start printer dialog
+ dialog = qt.QPrintDialog(self.getPrinter(), self.plot)
+ dialog.setWindowTitle('Print Plot')
+ if not dialog.exec():
+ return False
+
+ # Save Plot as PNG and make a pixmap from it with default dpi
+ pngData = _plotAsPNG(self.plot)
+
+ pixmap = qt.QPixmap()
+ pixmap.loadFromData(pngData, 'png')
+
+ pageRect = self.getPrinter().pageRect(qt.QPrinter.DevicePixel)
+ xScale = pageRect.width() / pixmap.width()
+ yScale = pageRect.height() / pixmap.height()
+ scale = min(xScale, yScale)
+
+ # Draw pixmap with painter
+ painter = qt.QPainter()
+ if not painter.begin(self.getPrinter()):
+ return False
+
+ painter.drawPixmap(0, 0,
+ pixmap.width() * scale,
+ pixmap.height() * scale,
+ pixmap)
+ painter.end()
+
+ return True
+
+
+class CopyAction(PlotAction):
+ """QAction to copy :class:`.PlotWidget` content to clipboard.
+
+ :param plot: :class:`.PlotWidget` instance on which to operate
+ :param parent: See :class:`QAction`
+ """
+
+ def __init__(self, plot, parent=None):
+ super(CopyAction, self).__init__(
+ plot, icon='edit-copy', text='Copy plot',
+ tooltip='Copy a snapshot of the plot into the clipboard',
+ triggered=self.copyPlot,
+ checkable=False, parent=parent)
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+
+ def copyPlot(self):
+ """Copy plot content to the clipboard as a bitmap."""
+ # Save Plot as PNG and make a QImage from it with default dpi
+ pngData = _plotAsPNG(self.plot)
+ image = qt.QImage.fromData(pngData, 'png')
+ qt.QApplication.clipboard().setImage(image)
diff --git a/silx/gui/plot/actions/medfilt.py b/src/silx/gui/plot/actions/medfilt.py
index f86a377..f86a377 100644
--- a/silx/gui/plot/actions/medfilt.py
+++ b/src/silx/gui/plot/actions/medfilt.py
diff --git a/silx/gui/plot/actions/mode.py b/src/silx/gui/plot/actions/mode.py
index ee05256..ee05256 100644
--- a/silx/gui/plot/actions/mode.py
+++ b/src/silx/gui/plot/actions/mode.py
diff --git a/src/silx/gui/plot/backends/BackendBase.py b/src/silx/gui/plot/backends/BackendBase.py
new file mode 100755
index 0000000..1e86807
--- /dev/null
+++ b/src/silx/gui/plot/backends/BackendBase.py
@@ -0,0 +1,568 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Base class for Plot backends.
+
+It documents the Plot backend API.
+
+This API is a simplified version of PyMca PlotBackend API.
+"""
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+import weakref
+from ... import qt
+
+
+# Names for setCursor
+CURSOR_DEFAULT = 'default'
+CURSOR_POINTING = 'pointing'
+CURSOR_SIZE_HOR = 'size horizontal'
+CURSOR_SIZE_VER = 'size vertical'
+CURSOR_SIZE_ALL = 'size all'
+
+
+class BackendBase(object):
+ """Class defining the API a backend of the Plot should provide."""
+
+ def __init__(self, plot, parent=None):
+ """Init.
+
+ :param Plot plot: The Plot this backend is attached to
+ :param parent: The parent widget of the plot widget.
+ """
+ self.__xLimits = 1., 100.
+ self.__yLimits = {'left': (1., 100.), 'right': (1., 100.)}
+ self.__yAxisInverted = False
+ self.__keepDataAspectRatio = False
+ self.__xAxisTimeSeries = False
+ self._xAxisTimeZone = None
+ # Store a weakref to get access to the plot state.
+ self._setPlot(plot)
+
+ @property
+ def _plot(self):
+ """The plot this backend is attached to."""
+ if self._plotRef is None:
+ raise RuntimeError('This backend is not attached to a Plot')
+
+ plot = self._plotRef()
+ if plot is None:
+ raise RuntimeError('This backend is no more attached to a Plot')
+ return plot
+
+ def _setPlot(self, plot):
+ """Allow to set plot after init.
+
+ Use with caution, basically **immediately** after init.
+ """
+ self._plotRef = weakref.ref(plot)
+
+ # Add methods
+
+ def addCurve(self, x, y,
+ color, symbol, linewidth, linestyle,
+ yaxis,
+ xerror, yerror,
+ fill, alpha, symbolsize, baseline):
+ """Add a 1D curve given by x an y to the graph.
+
+ :param numpy.ndarray x: The data corresponding to the x axis
+ :param numpy.ndarray y: The data corresponding to the y axis
+ :param color: color(s) to be used
+ :type color: string ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param str symbol: Symbol to be drawn at each (x, y) position::
+
+ - ' ' or '' no symbol
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :param float linewidth: The width of the curve in pixels
+ :param str linestyle: Type of line::
+
+ - ' ' or '' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :param str yaxis: The Y axis this curve belongs to in: 'left', 'right'
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: numpy.ndarray or None
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: numpy.ndarray or None
+ :param bool fill: True to fill the curve, False otherwise
+ :param float alpha: Curve opacity, as a float in [0., 1.]
+ :param float symbolsize: Size of the symbol (if any) drawn
+ at each (x, y) position.
+ :returns: The handle used by the backend to univocally access the curve
+ """
+ return object()
+
+ def addImage(self, data,
+ origin, scale,
+ colormap, alpha):
+ """Add an image to the plot.
+
+ :param numpy.ndarray data: (nrows, ncolumns) data or
+ (nrows, ncolumns, RGBA) ubyte array
+ :param origin: (origin X, origin Y) of the data.
+ Default: (0., 0.)
+ :type origin: 2-tuple of float
+ :param scale: (scale X, scale Y) of the data.
+ Default: (1., 1.)
+ :type scale: 2-tuple of float
+ :param ~silx.gui.colors.Colormap colormap: Colormap object to use.
+ Ignored if data is RGB(A).
+ :param float alpha: Opacity of the image, as a float in range [0, 1].
+ :returns: The handle used by the backend to univocally access the image
+ """
+ return object()
+
+ def addTriangles(self, x, y, triangles,
+ color, alpha):
+ """Add a set of triangles.
+
+ :param numpy.ndarray x: The data corresponding to the x axis
+ :param numpy.ndarray y: The data corresponding to the y axis
+ :param numpy.ndarray triangles: The indices to make triangles
+ as a (Ntriangle, 3) array
+ :param numpy.ndarray color: color(s) as (npoints, 4) array
+ :param float alpha: Opacity as a float in [0., 1.]
+ :returns: The triangles' unique identifier used by the backend
+ """
+ return object()
+
+ def addShape(self, x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor):
+ """Add an item (i.e. a shape) to the plot.
+
+ :param numpy.ndarray x: The X coords of the points of the shape
+ :param numpy.ndarray y: The Y coords of the points of the shape
+ :param str shape: Type of item to be drawn in
+ hline, polygon, rectangle, vline, polylines
+ :param str color: Color of the item
+ :param bool fill: True to fill the shape
+ :param bool overlay: True if item is an overlay, False otherwise
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param str linebgcolor: Background color of the line, e.g., 'blue', 'b',
+ '#FF0000'. It is used to draw dotted line using a second color.
+ :returns: The handle used by the backend to univocally access the item
+ """
+ return object()
+
+ def addMarker(self, x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis):
+ """Add a point, vertical line or horizontal line marker to the plot.
+
+ :param float x: Horizontal position of the marker in graph coordinates.
+ If None, the marker is a horizontal line.
+ :param float y: Vertical position of the marker in graph coordinates.
+ If None, the marker is a vertical line.
+ :param str text: Text associated to the marker (or None for no text)
+ :param str color: Color to be used for instance 'blue', 'b', '#FF0000'
+ :param str symbol: Symbol representing the marker.
+ Only relevant for point markers where X and Y are not None.
+ Value in:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+ :param str linestyle: Style of the line.
+ Only relevant for line markers where X or Y is None.
+ Value in:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+ :param float linewidth: Width of the line.
+ Only relevant for line markers where X or Y is None.
+ :param constraint: A function filtering marker displacement by
+ dragging operations or None for no filter.
+ This function is called each time a marker is
+ moved.
+ :type constraint: None or a callable that takes the coordinates of
+ the current cursor position in the plot as input
+ and that returns the filtered coordinates.
+ :param str yaxis: The Y axis this marker belongs to in: 'left', 'right'
+ :return: Handle used by the backend to univocally access the marker
+ """
+ return object()
+
+ # Remove methods
+
+ def remove(self, item):
+ """Remove an existing item from the plot.
+
+ :param item: A backend specific item handle returned by a add* method
+ """
+ pass
+
+ # Interaction methods
+
+ def setGraphCursorShape(self, cursor):
+ """Set the cursor shape.
+
+ To override in interactive backends.
+
+ :param str cursor: Name of the cursor shape or None
+ """
+ pass
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ """Toggle the display of a crosshair cursor and set its attributes.
+
+ To override in interactive backends.
+
+ :param bool flag: Toggle the display of a crosshair cursor.
+ :param color: The color to use for the crosshair.
+ :type color: A string (either a predefined color name in colors.py
+ or "#RRGGBB")) or a 4 columns unsigned byte array.
+ :param int linewidth: The width of the lines of the crosshair.
+ :param linestyle: Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :type linestyle: None or one of the predefined styles.
+ """
+ pass
+
+ def getItemsFromBackToFront(self, condition=None):
+ """Returns the list of plot items order as rendered by the backend.
+
+ This is the order used for rendering.
+ By default, it takes into account overlays, z value and order of addition of items,
+ but backends can override it.
+
+ :param callable condition:
+ Callable taking an item as input and returning False for items to skip.
+ If None (default), no item is skipped.
+ :rtype: List[~silx.gui.plot.items.Item]
+ """
+ # Sort items: Overlays first, then others
+ # and in each category ordered by z and then by order of addition
+ # as content keeps this order.
+ content = self._plot.getItems()
+ if condition is not None:
+ content = [item for item in content if condition(item)]
+
+ return sorted(
+ content,
+ key=lambda i: ((1 if i.isOverlay() else 0), i.getZValue()))
+
+ def pickItem(self, x, y, item):
+ """Return picked indices if any, or None.
+
+ :param float x: The x pixel coord where to pick.
+ :param float y: The y pixel coord where to pick.
+ :param item: A backend item created with add* methods.
+ :return: None if item was not picked, else returns
+ picked indices information.
+ :rtype: Union[None,List]
+ """
+ return None
+
+ # Update curve
+
+ def setCurveColor(self, curve, color):
+ """Set the color of a curve.
+
+ :param curve: The curve handle
+ :param str color: The color to use.
+ """
+ pass
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ """Return the widget this backend is drawing to."""
+ return None
+
+ def postRedisplay(self):
+ """Trigger backend update and repaint."""
+ self.replot()
+
+ def replot(self):
+ """Redraw the plot."""
+ with self._plot._paintContext():
+ pass
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ """Save the graph to a file (or a StringIO)
+
+ At least "png", "svg" are supported.
+
+ :param fileName: Destination
+ :type fileName: String or StringIO or BytesIO
+ :param str fileFormat: String specifying the format
+ :param int dpi: The resolution to use or None.
+ """
+ pass
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ """Set the main title of the plot.
+
+ :param str title: Title associated to the plot
+ """
+ pass
+
+ def setGraphXLabel(self, label):
+ """Set the X axis label.
+
+ :param str label: label associated to the plot bottom X axis
+ """
+ pass
+
+ def setGraphYLabel(self, label, axis):
+ """Set the left Y axis label.
+
+ :param str label: label associated to the plot left Y axis
+ :param str axis: The axis for which to get the limits: left or right
+ """
+ pass
+
+ # Graph limits
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ """Set the limits of the X and Y axes at once.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param float y2min: minimum right axis value
+ :param float y2max: maximum right axis value
+ """
+ self.__xLimits = xmin, xmax
+ self.__yLimits['left'] = ymin, ymax
+ if y2min is not None and y2max is not None:
+ self.__yLimits['right'] = y2min, y2max
+
+ def getGraphXLimits(self):
+ """Get the graph X (bottom) limits.
+
+ :return: Minimum and maximum values of the X axis
+ """
+ return self.__xLimits
+
+ def setGraphXLimits(self, xmin, xmax):
+ """Set the limits of X axis.
+
+ :param float xmin: minimum bottom axis value
+ :param float xmax: maximum bottom axis value
+ """
+ self.__xLimits = xmin, xmax
+
+ def getGraphYLimits(self, axis):
+ """Get the graph Y (left) limits.
+
+ :param str axis: The axis for which to get the limits: left or right
+ :return: Minimum and maximum values of the Y axis
+ """
+ return self.__yLimits[axis]
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ """Set the limits of the Y axis.
+
+ :param float ymin: minimum left axis value
+ :param float ymax: maximum left axis value
+ :param str axis: The axis for which to get the limits: left or right
+ """
+ self.__yLimits[axis] = ymin, ymax
+
+ # Graph axes
+
+
+ def getXAxisTimeZone(self):
+ """Returns tzinfo that is used if the X-Axis plots date-times.
+
+ None means the datetimes are interpreted as local time.
+
+ :rtype: datetime.tzinfo of None.
+ """
+ return self._xAxisTimeZone
+
+ def setXAxisTimeZone(self, tz):
+ """Sets tzinfo that is used if the X-Axis plots date-times.
+
+ Use None to let the datetimes be interpreted as local time.
+
+ :rtype: datetime.tzinfo of None.
+ """
+ self._xAxisTimeZone = tz
+
+ def isXAxisTimeSeries(self):
+ """Return True if the X-axis scale shows datetime objects.
+
+ :rtype: bool
+ """
+ return self.__xAxisTimeSeries
+
+ def setXAxisTimeSeries(self, isTimeSeries):
+ """Set whether the X-axis is a time series
+
+ :param bool flag: True to switch to time series, False for regular axis.
+ """
+ self.__xAxisTimeSeries = bool(isTimeSeries)
+
+ def setXAxisLogarithmic(self, flag):
+ """Set the X axis scale between linear and log.
+
+ :param bool flag: If True, the bottom axis will use a log scale
+ """
+ pass
+
+ def setYAxisLogarithmic(self, flag):
+ """Set the Y axis scale between linear and log.
+
+ :param bool flag: If True, the left axis will use a log scale
+ """
+ pass
+
+ def setYAxisInverted(self, flag):
+ """Invert the Y axis.
+
+ :param bool flag: If True, put the vertical axis origin on the top
+ """
+ self.__yAxisInverted = bool(flag)
+
+ def isYAxisInverted(self):
+ """Return True if left Y axis is inverted, False otherwise."""
+ return self.__yAxisInverted
+
+ def isKeepDataAspectRatio(self):
+ """Returns whether the plot is keeping data aspect ratio or not."""
+ return self.__keepDataAspectRatio
+
+ def setKeepDataAspectRatio(self, flag):
+ """Set whether to keep data aspect ratio or not.
+
+ :param flag: True to respect data aspect ratio
+ :type flag: Boolean, default True
+ """
+ self.__keepDataAspectRatio = bool(flag)
+
+ def setGraphGrid(self, which):
+ """Set grid.
+
+ :param which: None to disable grid, 'major' for major grid,
+ 'both' for major and minor grid
+ """
+ pass
+
+ # Data <-> Pixel coordinates conversion
+
+ def dataToPixel(self, x, y, axis):
+ """Convert a position in data space to a position in pixels
+ in the widget.
+
+ :param float x: The X coordinate in data space.
+ :param float y: The Y coordinate in data space.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :returns: The corresponding position in pixels or
+ None if the data position is not in the displayed area.
+ :rtype: A tuple of 2 floats: (xPixel, yPixel) or None.
+ """
+ raise NotImplementedError()
+
+ def pixelToData(self, x, y, axis):
+ """Convert a position in pixels in the widget to a position in
+ the data space.
+
+ :param float x: The X coordinate in pixels.
+ :param float y: The Y coordinate in pixels.
+ :param str axis: The Y axis to use for the conversion
+ ('left' or 'right').
+ :returns: The corresponding position in data space or
+ None if the pixel position is not in the plot area.
+ :rtype: A tuple of 2 floats: (xData, yData) or None.
+ """
+ raise NotImplementedError()
+
+ def getPlotBoundsInPixels(self):
+ """Plot area bounds in widget coordinates in pixels.
+
+ :return: bounds as a 4-tuple of int: (left, top, width, height)
+ """
+ raise NotImplementedError()
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ """Set the size of plot margins as ratios.
+
+ Values are expected in [0., 1.]
+
+ :param float left:
+ :param float top:
+ :param float right:
+ :param float bottom:
+ """
+ pass
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ """Set foreground and grid colors used to display this widget.
+
+ :param List[float] foregroundColor: RGBA foreground color of the widget
+ :param List[float] gridColor: RGBA grid color of the data view
+ """
+ pass
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ """Set background colors used to display this widget.
+
+ :param List[float] backgroundColor: RGBA background color of the widget
+ :param Union[Tuple[float],None] dataBackgroundColor:
+ RGBA background color of the data view
+ """
+ pass
diff --git a/src/silx/gui/plot/backends/BackendMatplotlib.py b/src/silx/gui/plot/backends/BackendMatplotlib.py
new file mode 100755
index 0000000..7fe4ec0
--- /dev/null
+++ b/src/silx/gui/plot/backends/BackendMatplotlib.py
@@ -0,0 +1,1557 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Matplotlib Plot backend."""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent, H. Payno"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+
+import logging
+import datetime as dt
+from typing import Tuple
+import numpy
+
+from pkg_resources import parse_version as _parse_version
+
+
+_logger = logging.getLogger(__name__)
+
+
+from ... import qt
+
+# First of all init matplotlib and set its backend
+from ...utils.matplotlib import FigureCanvasQTAgg
+import matplotlib
+from matplotlib.container import Container
+from matplotlib.figure import Figure
+from matplotlib.patches import Rectangle, Polygon
+from matplotlib.image import AxesImage
+from matplotlib.backend_bases import MouseEvent
+from matplotlib.lines import Line2D
+from matplotlib.text import Text
+from matplotlib.collections import PathCollection, LineCollection
+from matplotlib.ticker import Formatter, ScalarFormatter, Locator
+from matplotlib.tri import Triangulation
+from matplotlib.collections import TriMesh
+from matplotlib import path as mpath
+
+from . import BackendBase
+from .. import items
+from .._utils import FLOAT32_MINPOS
+from .._utils.dtime_ticklayout import calcTicks, bestFormatString, timestamp
+
+_PATCH_LINESTYLE = {
+ "-": 'solid',
+ "--": 'dashed',
+ '-.': 'dashdot',
+ ':': 'dotted',
+ '': "solid",
+ None: "solid",
+}
+"""Patches do not uses the same matplotlib syntax"""
+
+_MARKER_PATHS = {}
+"""Store cached extra marker paths"""
+
+_SPECIAL_MARKERS = {
+ 'tickleft': 0,
+ 'tickright': 1,
+ 'tickup': 2,
+ 'tickdown': 3,
+ 'caretleft': 4,
+ 'caretright': 5,
+ 'caretup': 6,
+ 'caretdown': 7,
+}
+
+
+def normalize_linestyle(linestyle):
+ """Normalize known old-style linestyle, else return the provided value."""
+ return _PATCH_LINESTYLE.get(linestyle, linestyle)
+
+def get_path_from_symbol(symbol):
+ """Get the path representation of a symbol, else None if
+ it is not provided.
+
+ :param str symbol: Symbol description used by silx
+ :rtype: Union[None,matplotlib.path.Path]
+ """
+ if symbol == u'\u2665':
+ path = _MARKER_PATHS.get(symbol, None)
+ if path is not None:
+ return path
+ vertices = numpy.array([
+ [0,-99],
+ [31,-73], [47,-55], [55,-46],
+ [63,-37], [94,-2], [94,33],
+ [94,69], [71,89], [47,89],
+ [24,89], [8,74], [0,58],
+ [-8,74], [-24,89], [-47,89],
+ [-71,89], [-94,69], [-94,33],
+ [-94,-2], [-63,-37], [-55,-46],
+ [-47,-55], [-31,-73], [0,-99],
+ [0,-99]])
+ codes = [mpath.Path.CURVE4] * len(vertices)
+ codes[0] = mpath.Path.MOVETO
+ codes[-1] = mpath.Path.CLOSEPOLY
+ path = mpath.Path(vertices, codes)
+ _MARKER_PATHS[symbol] = path
+ return path
+ return None
+
+class NiceDateLocator(Locator):
+ """
+ Matplotlib Locator that uses Nice Numbers algorithm (adapted to dates)
+ to find the tick locations. This results in the same number behaviour
+ as when using the silx Open GL backend.
+
+ Expects the data to be posix timestampes (i.e. seconds since 1970)
+ """
+ def __init__(self, numTicks=5, tz=None):
+ """
+ :param numTicks: target number of ticks
+ :param datetime.tzinfo tz: optional time zone. None is local time.
+ """
+ super(NiceDateLocator, self).__init__()
+ self.numTicks = numTicks
+
+ self._spacing = None
+ self._unit = None
+ self.tz = tz
+
+ @property
+ def spacing(self):
+ """ The current spacing. Will be updated when new tick value are made"""
+ return self._spacing
+
+ @property
+ def unit(self):
+ """ The current DtUnit. Will be updated when new tick value are made"""
+ return self._unit
+
+ def __call__(self):
+ """Return the locations of the ticks"""
+ vmin, vmax = self.axis.get_view_interval()
+ return self.tick_values(vmin, vmax)
+
+ def tick_values(self, vmin, vmax):
+ """ Calculates tick values
+ """
+ if vmax < vmin:
+ vmin, vmax = vmax, vmin
+
+ # vmin and vmax should be timestamps (i.e. seconds since 1 Jan 1970)
+ dtMin = dt.datetime.fromtimestamp(vmin, tz=self.tz)
+ dtMax = dt.datetime.fromtimestamp(vmax, tz=self.tz)
+ dtTicks, self._spacing, self._unit = \
+ calcTicks(dtMin, dtMax, self.numTicks)
+
+ # Convert datetime back to time stamps.
+ ticks = [timestamp(dtTick) for dtTick in dtTicks]
+ return ticks
+
+
+class NiceAutoDateFormatter(Formatter):
+ """
+ Matplotlib FuncFormatter that is linked to a NiceDateLocator and gives the
+ best possible formats given the locators current spacing an date unit.
+ """
+
+ def __init__(self, locator, tz=None):
+ """
+ :param niceDateLocator: a NiceDateLocator object
+ :param datetime.tzinfo tz: optional time zone. None is local time.
+ """
+ super(NiceAutoDateFormatter, self).__init__()
+ self.locator = locator
+ self.tz = tz
+
+ @property
+ def formatString(self):
+ if self.locator.spacing is None or self.locator.unit is None:
+ # Locator has no spacing or units yet. Return elaborate fmtString
+ return "Y-%m-%d %H:%M:%S"
+ else:
+ return bestFormatString(self.locator.spacing, self.locator.unit)
+
+ def __call__(self, x, pos=None):
+ """Return the format for tick val *x* at position *pos*
+ Expects x to be a POSIX timestamp (seconds since 1 Jan 1970)
+ """
+ dateTime = dt.datetime.fromtimestamp(x, tz=self.tz)
+ tickStr = dateTime.strftime(self.formatString)
+ return tickStr
+
+
+class _PickableContainer(Container):
+ """Artists container with a :meth:`contains` method"""
+
+ def __init__(self, *args, **kwargs):
+ Container.__init__(self, *args, **kwargs)
+ self.__zorder = None
+
+ @property
+ def axes(self):
+ """Mimin Artist.axes"""
+ for child in self.get_children():
+ if hasattr(child, 'axes'):
+ return child.axes
+ return None
+
+ def draw(self, *args, **kwargs):
+ """artist-like draw to broadcast draw to children"""
+ for child in self.get_children():
+ child.draw(*args, **kwargs)
+
+ def get_zorder(self):
+ """Mimic Artist.get_zorder"""
+ return self.__zorder
+
+ def set_zorder(self, z):
+ """Mimic Artist.set_zorder to broadcast to children"""
+ if z != self.__zorder:
+ self.__zorder = z
+ for child in self.get_children():
+ child.set_zorder(z)
+
+ def contains(self, mouseevent):
+ """Mimic Artist.contains, and call it on all children.
+
+ :param mouseevent:
+ :return: Picking status and associated information as a dict
+ :rtype: (bool,dict)
+ """
+ # Goes through children from front to back and return first picked one.
+ for child in reversed(self.get_children()):
+ picked, info = child.contains(mouseevent)
+ if picked:
+ return picked, info
+ return False, {}
+
+
+class _TextWithOffset(Text):
+ """Text object which can be displayed at a specific position
+ of the plot, but with a pixel offset"""
+
+ def __init__(self, *args, **kwargs):
+ Text.__init__(self, *args, **kwargs)
+ self.pixel_offset = (0, 0)
+ self.__cache = None
+
+ def draw(self, renderer):
+ self.__cache = None
+ return Text.draw(self, renderer)
+
+ def __get_xy(self):
+ if self.__cache is not None:
+ return self.__cache
+
+ align = self.get_horizontalalignment()
+ if align == "left":
+ xoffset = self.pixel_offset[0]
+ elif align == "right":
+ xoffset = -self.pixel_offset[0]
+ else:
+ xoffset = 0
+
+ align = self.get_verticalalignment()
+ if align == "top":
+ yoffset = -self.pixel_offset[1]
+ elif align == "bottom":
+ yoffset = self.pixel_offset[1]
+ else:
+ yoffset = 0
+
+ trans = self.get_transform()
+ x = super(_TextWithOffset, self).convert_xunits(self._x)
+ y = super(_TextWithOffset, self).convert_xunits(self._y)
+ pos = x, y
+
+ try:
+ invtrans = trans.inverted()
+ except numpy.linalg.LinAlgError:
+ # Cannot inverse transform, fallback: pos without offset
+ self.__cache = None
+ return pos
+
+ proj = trans.transform_point(pos)
+ proj = proj + numpy.array((xoffset, yoffset))
+ pos = invtrans.transform_point(proj)
+ self.__cache = pos
+ return pos
+
+ def convert_xunits(self, x):
+ """Return the pixel position of the annotated point."""
+ return self.__get_xy()[0]
+
+ def convert_yunits(self, y):
+ """Return the pixel position of the annotated point."""
+ return self.__get_xy()[1]
+
+
+class _MarkerContainer(_PickableContainer):
+ """Marker artists container supporting draw/remove and text position update
+
+ :param artists:
+ Iterable with either one Line2D or a Line2D and a Text.
+ The use of an iterable if enforced by Container being
+ a subclass of tuple that defines a specific __new__.
+ :param x: X coordinate of the marker (None for horizontal lines)
+ :param y: Y coordinate of the marker (None for vertical lines)
+ """
+
+ def __init__(self, artists, symbol, x, y, yAxis):
+ self.line = artists[0]
+ self.text = artists[1] if len(artists) > 1 else None
+ self.symbol = symbol
+ self.x = x
+ self.y = y
+ self.yAxis = yAxis
+
+ _PickableContainer.__init__(self, artists)
+
+ def draw(self, *args, **kwargs):
+ """artist-like draw to broadcast draw to line and text"""
+ self.line.draw(*args, **kwargs)
+ if self.text is not None:
+ self.text.draw(*args, **kwargs)
+
+ def updateMarkerText(self, xmin, xmax, ymin, ymax, yinverted):
+ """Update marker text position and visibility according to plot limits
+
+ :param xmin: X axis lower limit
+ :param xmax: X axis upper limit
+ :param ymin: Y axis lower limit
+ :param ymax: Y axis upper limit
+ :param yinverted: True if the y axis is inverted
+ """
+ if self.text is not None:
+ visible = ((self.x is None or xmin <= self.x <= xmax) and
+ (self.y is None or ymin <= self.y <= ymax))
+ self.text.set_visible(visible)
+
+ if self.x is not None and self.y is not None:
+ if self.symbol is None:
+ valign = 'baseline'
+ else:
+ if yinverted:
+ valign = 'bottom'
+ else:
+ valign = 'top'
+ self.text.set_verticalalignment(valign)
+
+ elif self.y is None: # vertical line
+ # Always display it on top
+ center = (ymax + ymin) * 0.5
+ pos = (ymax - ymin) * 0.5 * 0.99
+ if yinverted:
+ pos = -pos
+ self.text.set_y(center + pos)
+
+ elif self.x is None: # Horizontal line
+ delta = abs(xmax - xmin)
+ if xmin > xmax:
+ xmax = xmin
+ xmax -= 0.005 * delta
+ self.text.set_x(xmax)
+
+ def contains(self, mouseevent):
+ """Mimic Artist.contains, and call it on the line Artist.
+
+ :param mouseevent:
+ :return: Picking status and associated information as a dict
+ :rtype: (bool,dict)
+ """
+ return self.line.contains(mouseevent)
+
+
+class _DoubleColoredLinePatch(matplotlib.patches.Patch):
+ """Matplotlib patch to display any patch using double color."""
+
+ def __init__(self, patch):
+ super(_DoubleColoredLinePatch, self).__init__()
+ self.__patch = patch
+ self.linebgcolor = None
+
+ def __getattr__(self, name):
+ return getattr(self.__patch, name)
+
+ def draw(self, renderer):
+ oldLineStype = self.__patch.get_linestyle()
+ if self.linebgcolor is not None and oldLineStype != "solid":
+ oldLineColor = self.__patch.get_edgecolor()
+ oldHatch = self.__patch.get_hatch()
+ self.__patch.set_linestyle("solid")
+ self.__patch.set_edgecolor(self.linebgcolor)
+ self.__patch.set_hatch(None)
+ self.__patch.draw(renderer)
+ self.__patch.set_linestyle(oldLineStype)
+ self.__patch.set_edgecolor(oldLineColor)
+ self.__patch.set_hatch(oldHatch)
+ self.__patch.draw(renderer)
+
+ def set_transform(self, transform):
+ self.__patch.set_transform(transform)
+
+ def get_path(self):
+ return self.__patch.get_path()
+
+ def contains(self, mouseevent, radius=None):
+ return self.__patch.contains(mouseevent, radius)
+
+ def contains_point(self, point, radius=None):
+ return self.__patch.contains_point(point, radius)
+
+
+class Image(AxesImage):
+ """An AxesImage with a fast path for uint8 RGBA images.
+
+ :param List[float] silx_origin: (ox, oy) Offset of the image.
+ :param List[float] silx_scale: (sx, sy) Scale of the image.
+ """
+
+ def __init__(self, *args,
+ silx_origin=(0., 0.),
+ silx_scale=(1., 1.),
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__silx_origin = silx_origin
+ self.__silx_scale = silx_scale
+
+ def contains(self, mouseevent):
+ """Overridden to fill 'ind' with row and column"""
+ inside, info = super().contains(mouseevent)
+ if inside:
+ x, y = mouseevent.xdata, mouseevent.ydata
+ ox, oy = self.__silx_origin
+ sx, sy = self.__silx_scale
+ height, width = self.get_size()
+ column = numpy.clip(int((x - ox) / sx), 0, width - 1)
+ row = numpy.clip(int((y - oy) / sy), 0, height - 1)
+ info['ind'] = (row,), (column,)
+ return inside, info
+
+ def set_data(self, A):
+ """Overridden to add a fast path for RGBA unit8 images"""
+ A = numpy.array(A, copy=False)
+ if A.ndim != 3 or A.shape[2] != 4 or A.dtype != numpy.uint8:
+ super(Image, self).set_data(A)
+ else:
+ # Call AxesImage.set_data with small data to set attributes
+ super(Image, self).set_data(numpy.zeros((2, 2, 4), dtype=A.dtype))
+ self._A = A # Override stored data
+
+
+class BackendMatplotlib(BackendBase.BackendBase):
+ """Base class for Matplotlib backend without a FigureCanvas.
+
+ For interactive on screen plot, see :class:`BackendMatplotlibQt`.
+
+ See :class:`BackendBase.BackendBase` for public API documentation.
+ """
+
+ def __init__(self, plot, parent=None):
+ super(BackendMatplotlib, self).__init__(plot, parent)
+
+ # matplotlib is handling keep aspect ratio at draw time
+ # When keep aspect ratio is on, and one changes the limits and
+ # ask them *before* next draw has been performed he will get the
+ # limits without applying keep aspect ratio.
+ # This attribute is used to ensure consistent values returned
+ # when getting the limits at the expense of a replot
+ self._dirtyLimits = True
+ self._axesDisplayed = True
+ self._matplotlibVersion = _parse_version(matplotlib.__version__)
+
+ self.fig = Figure()
+ self.fig.set_facecolor("w")
+
+ self.ax = self.fig.add_axes([.15, .15, .75, .75], label="left")
+ self.ax2 = self.ax.twinx()
+ self.ax2.set_label("right")
+ # Make sure background of Axes is displayed
+ self.ax2.patch.set_visible(False)
+ self.ax.patch.set_visible(True)
+
+ # Set axis zorder=0.5 so grid is displayed at 0.5
+ self.ax.set_axisbelow(True)
+
+ # disable the use of offsets
+ try:
+ axes = [
+ self.ax.get_yaxis().get_major_formatter(),
+ self.ax.get_xaxis().get_major_formatter(),
+ self.ax2.get_yaxis().get_major_formatter(),
+ self.ax2.get_xaxis().get_major_formatter(),
+ ]
+ for axis in axes:
+ axis.set_useOffset(False)
+ axis.set_scientific(False)
+ except:
+ _logger.warning('Cannot disabled axes offsets in %s '
+ % matplotlib.__version__)
+
+ self.ax2.set_autoscaley_on(True)
+
+ # this works but the figure color is left
+ if self._matplotlibVersion < _parse_version('2'):
+ self.ax.set_axis_bgcolor('none')
+ else:
+ self.ax.set_facecolor('none')
+ self.fig.sca(self.ax)
+
+ self._background = None
+
+ self._colormaps = {}
+
+ self._graphCursor = tuple()
+
+ self._enableAxis('right', False)
+ self._isXAxisTimeSeries = False
+
+ def getItemsFromBackToFront(self, condition=None):
+ """Order as BackendBase + take into account matplotlib Axes structure"""
+ def axesOrder(item):
+ if item.isOverlay():
+ return 2
+ elif isinstance(item, items.YAxisMixIn) and item.getYAxis() == 'right':
+ return 1
+ else:
+ return 0
+
+ return sorted(
+ BackendBase.BackendBase.getItemsFromBackToFront(
+ self, condition=condition),
+ key=axesOrder)
+
+ def _overlayItems(self):
+ """Generator of backend renderer for overlay items"""
+ for item in self._plot.getItems():
+ if (item.isOverlay() and
+ item.isVisible() and
+ item._backendRenderer is not None):
+ yield item._backendRenderer
+
+ def _hasOverlays(self):
+ """Returns whether there is an overlay layer or not.
+
+ The overlay layers contains overlay items and the crosshair.
+
+ :rtype: bool
+ """
+ if self._graphCursor:
+ return True # There is the crosshair
+
+ for item in self._overlayItems():
+ return True # There is at least one overlay item
+ return False
+
+ # Add methods
+
+ def _getMarkerFromSymbol(self, symbol):
+ """Returns a marker that can be displayed by matplotlib.
+
+ :param str symbol: A symbol description used by silx
+ :rtype: Union[str,int,matplotlib.path.Path]
+ """
+ path = get_path_from_symbol(symbol)
+ if path is not None:
+ return path
+ num = _SPECIAL_MARKERS.get(symbol, None)
+ if num is not None:
+ return num
+ # This symbol must be supported by matplotlib
+ return symbol
+
+ def addCurve(self, x, y,
+ color, symbol, linewidth, linestyle,
+ yaxis,
+ xerror, yerror,
+ fill, alpha, symbolsize, baseline):
+ for parameter in (x, y, color, symbol, linewidth, linestyle,
+ yaxis, fill, alpha, symbolsize):
+ assert parameter is not None
+ assert yaxis in ('left', 'right')
+
+ if (len(color) == 4 and
+ type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
+ color = numpy.array(color, dtype=numpy.float64) / 255.
+
+ if yaxis == "right":
+ axes = self.ax2
+ self._enableAxis("right", True)
+ else:
+ axes = self.ax
+
+ pickradius = 3
+
+ artists = [] # All the artists composing the curve
+
+ # First add errorbars if any so they are behind the curve
+ if xerror is not None or yerror is not None:
+ if hasattr(color, 'dtype') and len(color) == len(x):
+ errorbarColor = 'k'
+ else:
+ errorbarColor = color
+
+ # Nx1 error array deprecated in matplotlib >=3.1 (removed in 3.3)
+ if (isinstance(xerror, numpy.ndarray) and xerror.ndim == 2 and
+ xerror.shape[1] == 1):
+ xerror = numpy.ravel(xerror)
+ if (isinstance(yerror, numpy.ndarray) and yerror.ndim == 2 and
+ yerror.shape[1] == 1):
+ yerror = numpy.ravel(yerror)
+
+ errorbars = axes.errorbar(x, y,
+ xerr=xerror, yerr=yerror,
+ linestyle=' ', color=errorbarColor)
+ artists += list(errorbars.get_children())
+
+ if hasattr(color, 'dtype') and len(color) == len(x):
+ # scatter plot
+ if color.dtype not in [numpy.float32, numpy.float64]:
+ actualColor = color / 255.
+ else:
+ actualColor = color
+
+ if linestyle not in ["", " ", None]:
+ # scatter plot with an actual line ...
+ # we need to assign a color ...
+ curveList = axes.plot(x, y,
+ linestyle=linestyle,
+ color=actualColor[0],
+ linewidth=linewidth,
+ picker=True,
+ pickradius=pickradius,
+ marker=None)
+ artists += list(curveList)
+
+ marker = self._getMarkerFromSymbol(symbol)
+ scatter = axes.scatter(x, y,
+ color=actualColor,
+ marker=marker,
+ picker=True,
+ pickradius=pickradius,
+ s=symbolsize**2)
+ artists.append(scatter)
+
+ if fill:
+ if baseline is None:
+ _baseline = FLOAT32_MINPOS
+ else:
+ _baseline = baseline
+ artists.append(axes.fill_between(
+ x, _baseline, y, facecolor=actualColor[0], linestyle=''))
+
+ else: # Curve
+ curveList = axes.plot(x, y,
+ linestyle=linestyle,
+ color=color,
+ linewidth=linewidth,
+ marker=symbol,
+ picker=True,
+ pickradius=pickradius,
+ markersize=symbolsize)
+ artists += list(curveList)
+
+ if fill:
+ if baseline is None:
+ _baseline = FLOAT32_MINPOS
+ else:
+ _baseline = baseline
+ artists.append(
+ axes.fill_between(x, _baseline, y, facecolor=color))
+
+ for artist in artists:
+ if alpha < 1:
+ artist.set_alpha(alpha)
+
+ return _PickableContainer(artists)
+
+ def addImage(self, data, origin, scale, colormap, alpha):
+ # Non-uniform image
+ # http://wiki.scipy.org/Cookbook/Histograms
+ # Non-linear axes
+ # http://stackoverflow.com/questions/11488800/non-linear-axes-for-imshow-in-matplotlib
+ for parameter in (data, origin, scale):
+ assert parameter is not None
+
+ origin = float(origin[0]), float(origin[1])
+ scale = float(scale[0]), float(scale[1])
+ height, width = data.shape[0:2]
+
+ # All image are shown as RGBA image
+ image = Image(self.ax,
+ interpolation='nearest',
+ picker=True,
+ origin='lower',
+ silx_origin=origin,
+ silx_scale=scale)
+
+ if alpha < 1:
+ image.set_alpha(alpha)
+
+ # Set image extent
+ xmin = origin[0]
+ xmax = xmin + scale[0] * width
+ if scale[0] < 0.:
+ xmin, xmax = xmax, xmin
+
+ ymin = origin[1]
+ ymax = ymin + scale[1] * height
+ if scale[1] < 0.:
+ ymin, ymax = ymax, ymin
+
+ image.set_extent((xmin, xmax, ymin, ymax))
+
+ # Set image data
+ if scale[0] < 0. or scale[1] < 0.:
+ # For negative scale, step by -1
+ xstep = 1 if scale[0] >= 0. else -1
+ ystep = 1 if scale[1] >= 0. else -1
+ data = data[::ystep, ::xstep]
+
+ if data.ndim == 2: # Data image, convert to RGBA image
+ data = colormap.applyToData(data)
+ elif data.dtype == numpy.uint16:
+ # Normalize uint16 data to have a similar behavior as opengl backend
+ data = data.astype(numpy.float32)
+ data /= 65535
+
+ image.set_data(data)
+ self.ax.add_artist(image)
+ return image
+
+ def addTriangles(self, x, y, triangles, color, alpha):
+ for parameter in (x, y, triangles, color, alpha):
+ assert parameter is not None
+
+ color = numpy.array(color, copy=False)
+ assert color.ndim == 2 and len(color) == len(x)
+
+ if color.dtype not in [numpy.float32, numpy.float64]:
+ color = color.astype(numpy.float32) / 255.
+
+ collection = TriMesh(
+ Triangulation(x, y, triangles),
+ alpha=alpha,
+ pickradius=0) # 0 enables picking on filled triangle
+ collection.set_color(color)
+ self.ax.add_collection(collection)
+
+ return collection
+
+ def addShape(self, x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor):
+ if (linebgcolor is not None and
+ shape not in ('rectangle', 'polygon', 'polylines')):
+ _logger.warning(
+ 'linebgcolor not implemented for %s with matplotlib backend',
+ shape)
+ xView = numpy.array(x, copy=False)
+ yView = numpy.array(y, copy=False)
+
+ linestyle = normalize_linestyle(linestyle)
+
+ if shape == "line":
+ item = self.ax.plot(x, y, color=color,
+ linestyle=linestyle, linewidth=linewidth,
+ marker=None)[0]
+
+ elif shape == "hline":
+ if hasattr(y, "__len__"):
+ y = y[-1]
+ item = self.ax.axhline(y, color=color,
+ linestyle=linestyle, linewidth=linewidth)
+
+ elif shape == "vline":
+ if hasattr(x, "__len__"):
+ x = x[-1]
+ item = self.ax.axvline(x, color=color,
+ linestyle=linestyle, linewidth=linewidth)
+
+ elif shape == 'rectangle':
+ xMin = numpy.nanmin(xView)
+ xMax = numpy.nanmax(xView)
+ yMin = numpy.nanmin(yView)
+ yMax = numpy.nanmax(yView)
+ w = xMax - xMin
+ h = yMax - yMin
+ item = Rectangle(xy=(xMin, yMin),
+ width=w,
+ height=h,
+ fill=False,
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth)
+ if fill:
+ item.set_hatch('.')
+
+ if linestyle != "solid" and linebgcolor is not None:
+ item = _DoubleColoredLinePatch(item)
+ item.linebgcolor = linebgcolor
+
+ self.ax.add_patch(item)
+
+ elif shape in ('polygon', 'polylines'):
+ points = numpy.array((xView, yView)).T
+ if shape == 'polygon':
+ closed = True
+ else: # shape == 'polylines'
+ closed = numpy.all(numpy.equal(points[0], points[-1]))
+ item = Polygon(points,
+ closed=closed,
+ fill=False,
+ color=color,
+ linestyle=linestyle,
+ linewidth=linewidth)
+ if fill and shape == 'polygon':
+ item.set_hatch('/')
+
+ if linestyle != "solid" and linebgcolor is not None:
+ item = _DoubleColoredLinePatch(item)
+ item.linebgcolor = linebgcolor
+
+ self.ax.add_patch(item)
+
+ else:
+ raise NotImplementedError("Unsupported item shape %s" % shape)
+
+ if overlay:
+ item.set_animated(True)
+
+ return item
+
+ def addMarker(self, x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis):
+ textArtist = None
+
+ xmin, xmax = self.getGraphXLimits()
+ ymin, ymax = self.getGraphYLimits(axis=yaxis)
+
+ if yaxis == 'left':
+ ax = self.ax
+ elif yaxis == 'right':
+ ax = self.ax2
+ else:
+ assert(False)
+
+ marker = self._getMarkerFromSymbol(symbol)
+ if x is not None and y is not None:
+ line = ax.plot(x, y,
+ linestyle=" ",
+ color=color,
+ marker=marker,
+ markersize=10.)[-1]
+
+ if text is not None:
+ textArtist = _TextWithOffset(x, y, text,
+ color=color,
+ horizontalalignment='left')
+ if symbol is not None:
+ textArtist.pixel_offset = 10, 3
+ elif x is not None:
+ line = ax.axvline(x,
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle)
+ if text is not None:
+ # Y position will be updated in updateMarkerText call
+ textArtist = _TextWithOffset(x, 1., text,
+ color=color,
+ horizontalalignment='left',
+ verticalalignment='top')
+ textArtist.pixel_offset = 5, 3
+ elif y is not None:
+ line = ax.axhline(y,
+ color=color,
+ linewidth=linewidth,
+ linestyle=linestyle)
+
+ if text is not None:
+ # X position will be updated in updateMarkerText call
+ textArtist = _TextWithOffset(1., y, text,
+ color=color,
+ horizontalalignment='right',
+ verticalalignment='top')
+ textArtist.pixel_offset = 5, 3
+ else:
+ raise RuntimeError('A marker must at least have one coordinate')
+
+ line.set_picker(True)
+ line.set_pickradius(5)
+
+ # All markers are overlays
+ line.set_animated(True)
+ if textArtist is not None:
+ ax.add_artist(textArtist)
+ textArtist.set_animated(True)
+
+ artists = [line] if textArtist is None else [line, textArtist]
+ container = _MarkerContainer(artists, symbol, x, y, yaxis)
+ container.updateMarkerText(xmin, xmax, ymin, ymax, self.isYAxisInverted())
+
+ return container
+
+ def _updateMarkers(self):
+ xmin, xmax = self.ax.get_xbound()
+ ymin1, ymax1 = self.ax.get_ybound()
+ ymin2, ymax2 = self.ax2.get_ybound()
+ yinverted = self.isYAxisInverted()
+ for item in self._overlayItems():
+ if isinstance(item, _MarkerContainer):
+ if item.yAxis == 'left':
+ item.updateMarkerText(xmin, xmax, ymin1, ymax1, yinverted)
+ else:
+ item.updateMarkerText(xmin, xmax, ymin2, ymax2, yinverted)
+
+ # Remove methods
+
+ def remove(self, item):
+ try:
+ item.remove()
+ except ValueError:
+ pass # Already removed e.g., in set[X|Y]AxisLogarithmic
+
+ # Interaction methods
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ if flag:
+ lineh = self.ax.axhline(
+ self.ax.get_ybound()[0], visible=False, color=color,
+ linewidth=linewidth, linestyle=linestyle)
+ lineh.set_animated(True)
+
+ linev = self.ax.axvline(
+ self.ax.get_xbound()[0], visible=False, color=color,
+ linewidth=linewidth, linestyle=linestyle)
+ linev.set_animated(True)
+
+ self._graphCursor = lineh, linev
+ else:
+ if self._graphCursor:
+ lineh, linev = self._graphCursor
+ lineh.remove()
+ linev.remove()
+ self._graphCursor = tuple()
+
+ # Active curve
+
+ def setCurveColor(self, curve, color):
+ # Store Line2D and PathCollection
+ for artist in curve.get_children():
+ if isinstance(artist, (Line2D, LineCollection)):
+ artist.set_color(color)
+ elif isinstance(artist, PathCollection):
+ artist.set_facecolors(color)
+ artist.set_edgecolors(color)
+ else:
+ _logger.warning(
+ 'setActiveCurve ignoring artist %s', str(artist))
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ return self.fig.canvas
+
+ def _enableAxis(self, axis, flag=True):
+ """Show/hide Y axis
+
+ :param str axis: Axis name: 'left' or 'right'
+ :param bool flag: Default, True
+ """
+ assert axis in ('right', 'left')
+ axes = self.ax2 if axis == 'right' else self.ax
+ axes.get_yaxis().set_visible(flag)
+
+ def replot(self):
+ """Do not perform rendering.
+
+ Override in subclass to actually draw something.
+ """
+ with self._plot._paintContext():
+ self._replot()
+
+ def _replot(self):
+ """Call from subclass :meth:`replot` to handle updates"""
+ # TODO images, markers? scatter plot? move in remove?
+ # Right Y axis only support curve for now
+ # Hide right Y axis if no line is present
+ self._dirtyLimits = False
+ if not self.ax2.lines:
+ self._enableAxis('right', False)
+
+ def _drawOverlays(self):
+ """Draw overlays if any."""
+ def condition(item):
+ return (item.isVisible() and
+ item._backendRenderer is not None and
+ item.isOverlay())
+
+ for item in self.getItemsFromBackToFront(condition=condition):
+ if (isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right'):
+ axes = self.ax2
+ else:
+ axes = self.ax
+ axes.draw_artist(item._backendRenderer)
+
+ for item in self._graphCursor:
+ self.ax.draw_artist(item)
+
+ def updateZOrder(self):
+ """Reorder all items with z order from 0 to 1"""
+ items = self.getItemsFromBackToFront(
+ lambda item: item.isVisible() and item._backendRenderer is not None)
+ count = len(items)
+ for index, item in enumerate(items):
+ if item.getZValue() < 0.5:
+ # Make sure matplotlib z order is below the grid (with z=0.5)
+ zorder = 0.5 * index / count
+ else: # Make sure matplotlib z order is above the grid (> 0.5)
+ zorder = 1. + index / count
+ if zorder != item._backendRenderer.get_zorder():
+ item._backendRenderer.set_zorder(zorder)
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ self.updateZOrder()
+
+ # fileName can be also a StringIO or file instance
+ if dpi is not None:
+ self.fig.savefig(fileName, format=fileFormat, dpi=dpi)
+ else:
+ self.fig.savefig(fileName, format=fileFormat)
+ self._plot._setDirtyPlot()
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ self.ax.set_title(title)
+
+ def setGraphXLabel(self, label):
+ self.ax.set_xlabel(label)
+
+ def setGraphYLabel(self, label, axis):
+ axes = self.ax if axis == 'left' else self.ax2
+ axes.set_ylabel(label)
+
+ # Graph limits
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ # Let matplotlib taking care of keep aspect ratio if any
+ self._dirtyLimits = True
+ self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
+
+ if y2min is not None and y2max is not None:
+ if not self.isYAxisInverted():
+ self.ax2.set_ylim(min(y2min, y2max), max(y2min, y2max))
+ else:
+ self.ax2.set_ylim(max(y2min, y2max), min(y2min, y2max))
+
+ if not self.isYAxisInverted():
+ self.ax.set_ylim(min(ymin, ymax), max(ymin, ymax))
+ else:
+ self.ax.set_ylim(max(ymin, ymax), min(ymin, ymax))
+
+ self._updateMarkers()
+
+ def getGraphXLimits(self):
+ if self._dirtyLimits and self.isKeepDataAspectRatio():
+ self.ax.apply_aspect()
+ self.ax2.apply_aspect()
+ self._dirtyLimits = False
+ return self.ax.get_xbound()
+
+ def setGraphXLimits(self, xmin, xmax):
+ self._dirtyLimits = True
+ self.ax.set_xlim(min(xmin, xmax), max(xmin, xmax))
+ self._updateMarkers()
+
+ def getGraphYLimits(self, axis):
+ assert axis in ('left', 'right')
+ ax = self.ax2 if axis == 'right' else self.ax
+
+ if not ax.get_visible():
+ return None
+
+ if self._dirtyLimits and self.isKeepDataAspectRatio():
+ self.ax.apply_aspect()
+ self.ax2.apply_aspect()
+ self._dirtyLimits = False
+
+ return ax.get_ybound()
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ ax = self.ax2 if axis == 'right' else self.ax
+ if ymax < ymin:
+ ymin, ymax = ymax, ymin
+ self._dirtyLimits = True
+
+ if self.isKeepDataAspectRatio():
+ # matplotlib keeps limits of shared axis when keeping aspect ratio
+ # So x limits are kept when changing y limits....
+ # Change x limits first by taking into account aspect ratio
+ # and then change y limits.. so matplotlib does not need
+ # to make change (to y) to keep aspect ratio
+ xmin, xmax = ax.get_xbound()
+ curYMin, curYMax = ax.get_ybound()
+
+ newXRange = (xmax - xmin) * (ymax - ymin) / (curYMax - curYMin)
+ xcenter = 0.5 * (xmin + xmax)
+ ax.set_xlim(xcenter - 0.5 * newXRange, xcenter + 0.5 * newXRange)
+
+ if not self.isYAxisInverted():
+ ax.set_ylim(ymin, ymax)
+ else:
+ ax.set_ylim(ymax, ymin)
+
+ self._updateMarkers()
+
+ # Graph axes
+
+ def setXAxisTimeZone(self, tz):
+ super(BackendMatplotlib, self).setXAxisTimeZone(tz)
+
+ # Make new formatter and locator with the time zone.
+ self.setXAxisTimeSeries(self.isXAxisTimeSeries())
+
+ def isXAxisTimeSeries(self):
+ return self._isXAxisTimeSeries
+
+ def setXAxisTimeSeries(self, isTimeSeries):
+ self._isXAxisTimeSeries = isTimeSeries
+ if self._isXAxisTimeSeries:
+ # We can't use a matplotlib.dates.DateFormatter because it expects
+ # the data to be in datetimes. Silx works internally with
+ # timestamps (floats).
+ locator = NiceDateLocator(tz=self.getXAxisTimeZone())
+ self.ax.xaxis.set_major_locator(locator)
+ self.ax.xaxis.set_major_formatter(
+ NiceAutoDateFormatter(locator, tz=self.getXAxisTimeZone()))
+ else:
+ try:
+ scalarFormatter = ScalarFormatter(useOffset=False)
+ except:
+ _logger.warning('Cannot disabled axes offsets in %s ' %
+ matplotlib.__version__)
+ scalarFormatter = ScalarFormatter()
+ self.ax.xaxis.set_major_formatter(scalarFormatter)
+
+ def setXAxisLogarithmic(self, flag):
+ # Workaround for matplotlib 2.1.0 when one tries to set an axis
+ # to log scale with both limits <= 0
+ # In this case a draw with positive limits is needed first
+ if flag and self._matplotlibVersion >= _parse_version('2.1.0'):
+ xlim = self.ax.get_xlim()
+ if xlim[0] <= 0 and xlim[1] <= 0:
+ self.ax.set_xlim(1, 10)
+ self.draw()
+
+ self.ax2.set_xscale('log' if flag else 'linear')
+ self.ax.set_xscale('log' if flag else 'linear')
+
+ def setYAxisLogarithmic(self, flag):
+ # Workaround for matplotlib 2.0 issue with negative bounds
+ # before switching to log scale
+ if flag and self._matplotlibVersion >= _parse_version('2.0.0'):
+ redraw = False
+ for axis, dataRangeIndex in ((self.ax, 1), (self.ax2, 2)):
+ ylim = axis.get_ylim()
+ if ylim[0] <= 0 or ylim[1] <= 0:
+ dataRange = self._plot.getDataRange()[dataRangeIndex]
+ if dataRange is None:
+ dataRange = 1, 100 # Fallback
+ axis.set_ylim(*dataRange)
+ redraw = True
+ if redraw:
+ self.draw()
+
+ self.ax2.set_yscale('log' if flag else 'linear')
+ self.ax.set_yscale('log' if flag else 'linear')
+
+ def setYAxisInverted(self, flag):
+ if self.ax.yaxis_inverted() != bool(flag):
+ self.ax.invert_yaxis()
+ self._updateMarkers()
+
+ def isYAxisInverted(self):
+ return self.ax.yaxis_inverted()
+
+ def isKeepDataAspectRatio(self):
+ return self.ax.get_aspect() in (1.0, 'equal')
+
+ def setKeepDataAspectRatio(self, flag):
+ self.ax.set_aspect(1.0 if flag else 'auto')
+ self.ax2.set_aspect(1.0 if flag else 'auto')
+
+ def setGraphGrid(self, which):
+ self.ax.grid(False, which='both') # Disable all grid first
+ if which is not None:
+ self.ax.grid(True, which=which)
+
+ # Data <-> Pixel coordinates conversion
+
+ def _getDevicePixelRatio(self) -> float:
+ """Compatibility wrapper for devicePixelRatioF"""
+ return 1.
+
+ def _mplToQtPosition(self, x: float, y: float) -> Tuple[float, float]:
+ """Convert matplotlib "display" space coord to Qt widget logical pixel
+ """
+ ratio = self._getDevicePixelRatio()
+ # Convert from matplotlib origin (bottom) to Qt origin (top)
+ # and apply device pixel ratio
+ return x / ratio, (self.fig.get_window_extent().height - y) / ratio
+
+ def _qtToMplPosition(self, x: float, y: float) -> Tuple[float, float]:
+ """Convert Qt widget logical pixel to matplotlib "display" space coord
+ """
+ ratio = self._getDevicePixelRatio()
+ # Apply device pixel ration and
+ # convert from Qt origin (top) to matplotlib origin (bottom)
+ return x * ratio, self.fig.get_window_extent().height - (y * ratio)
+
+ def dataToPixel(self, x, y, axis):
+ ax = self.ax2 if axis == "right" else self.ax
+ displayPos = ax.transData.transform_point((x, y)).transpose()
+ return self._mplToQtPosition(*displayPos)
+
+ def pixelToData(self, x, y, axis):
+ ax = self.ax2 if axis == "right" else self.ax
+ displayPos = self._qtToMplPosition(x, y)
+ return tuple(ax.transData.inverted().transform_point(displayPos))
+
+ def getPlotBoundsInPixels(self):
+ bbox = self.ax.get_window_extent()
+ # Warning this is not returning int...
+ ratio = self._getDevicePixelRatio()
+ return tuple(int(value / ratio) for value in (
+ bbox.xmin,
+ self.fig.get_window_extent().height - bbox.ymax,
+ bbox.width,
+ bbox.height))
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ width, height = 1. - left - right, 1. - top - bottom
+ position = left, bottom, width, height
+
+ # Toggle display of axes and viewbox rect
+ isFrameOn = position != (0., 0., 1., 1.)
+ self.ax.set_frame_on(isFrameOn)
+ self.ax2.set_frame_on(isFrameOn)
+
+ self.ax.set_position(position)
+ self.ax2.set_position(position)
+
+ self._synchronizeBackgroundColors()
+ self._synchronizeForegroundColors()
+ self._plot._setDirtyPlot()
+
+ def _synchronizeBackgroundColors(self):
+ backgroundColor = self._plot.getBackgroundColor().getRgbF()
+
+ dataBackgroundColor = self._plot.getDataBackgroundColor()
+ if dataBackgroundColor.isValid():
+ dataBackgroundColor = dataBackgroundColor.getRgbF()
+ else:
+ dataBackgroundColor = backgroundColor
+
+ if self.ax.get_frame_on():
+ self.fig.patch.set_facecolor(backgroundColor)
+ if self._matplotlibVersion < _parse_version('2'):
+ self.ax.set_axis_bgcolor(dataBackgroundColor)
+ else:
+ self.ax.set_facecolor(dataBackgroundColor)
+ else:
+ self.fig.patch.set_facecolor(dataBackgroundColor)
+
+ def _synchronizeForegroundColors(self):
+ foregroundColor = self._plot.getForegroundColor().getRgbF()
+
+ gridColor = self._plot.getGridColor()
+ if gridColor.isValid():
+ gridColor = gridColor.getRgbF()
+ else:
+ gridColor = foregroundColor
+
+ for axes in (self.ax, self.ax2):
+ if axes.get_frame_on():
+ axes.spines['bottom'].set_color(foregroundColor)
+ axes.spines['top'].set_color(foregroundColor)
+ axes.spines['right'].set_color(foregroundColor)
+ axes.spines['left'].set_color(foregroundColor)
+ axes.tick_params(axis='x', colors=foregroundColor)
+ axes.tick_params(axis='y', colors=foregroundColor)
+ axes.yaxis.label.set_color(foregroundColor)
+ axes.xaxis.label.set_color(foregroundColor)
+ axes.title.set_color(foregroundColor)
+
+ for line in axes.get_xgridlines():
+ line.set_color(gridColor)
+
+ for line in axes.get_ygridlines():
+ line.set_color(gridColor)
+ # axes.grid().set_markeredgecolor(gridColor)
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._synchronizeBackgroundColors()
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._synchronizeForegroundColors()
+
+
+class BackendMatplotlibQt(FigureCanvasQTAgg, BackendMatplotlib):
+ """QWidget matplotlib backend using a QtAgg canvas.
+
+ It adds fast overlay drawing and mouse event management.
+ """
+
+ _sigPostRedisplay = qt.Signal()
+ """Signal handling automatic asynchronous replot"""
+
+ def __init__(self, plot, parent=None):
+ BackendMatplotlib.__init__(self, plot, parent)
+ FigureCanvasQTAgg.__init__(self, self.fig)
+ self.setParent(parent)
+
+ self._limitsBeforeResize = None
+
+ FigureCanvasQTAgg.setSizePolicy(
+ self, qt.QSizePolicy.Expanding, qt.QSizePolicy.Expanding)
+ FigureCanvasQTAgg.updateGeometry(self)
+
+ # Make postRedisplay asynchronous using Qt signal
+ self._sigPostRedisplay.connect(
+ self.__deferredReplot, qt.Qt.QueuedConnection)
+
+ self._picked = None
+
+ self.mpl_connect('button_press_event', self._onMousePress)
+ self.mpl_connect('button_release_event', self._onMouseRelease)
+ self.mpl_connect('motion_notify_event', self._onMouseMove)
+ self.mpl_connect('scroll_event', self._onMouseWheel)
+
+ def postRedisplay(self):
+ self._sigPostRedisplay.emit()
+
+ def __deferredReplot(self):
+ # Since this is deferred, makes sure it is still needed
+ plot = self._plotRef()
+ if (plot is not None and
+ plot._getDirtyPlot() and
+ plot.getBackend() is self):
+ self.replot()
+
+ def _getDevicePixelRatio(self) -> float:
+ """Compatibility wrapper for devicePixelRatioF"""
+ if hasattr(self, 'devicePixelRatioF'):
+ ratio = self.devicePixelRatioF()
+ else: # Qt < 5.6 compatibility
+ ratio = float(self.devicePixelRatio())
+ # Safety net: avoid returning 0
+ return ratio if ratio != 0. else 1.
+
+ # Mouse event forwarding
+
+ _MPL_TO_PLOT_BUTTONS = {1: 'left', 2: 'middle', 3: 'right'}
+
+ def _onMousePress(self, event):
+ button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
+ if button is not None:
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMousePress(int(x), int(y), button)
+
+ def _onMouseMove(self, event):
+ x, y = self._mplToQtPosition(event.x, event.y)
+ if self._graphCursor:
+ position = self._plot.pixelToData(
+ x, y, axis='left', check=True)
+ lineh, linev = self._graphCursor
+ if position is not None:
+ linev.set_visible(True)
+ linev.set_xdata((position[0], position[0]))
+ lineh.set_visible(True)
+ lineh.set_ydata((position[1], position[1]))
+ self._plot._setDirtyPlot(overlayOnly=True)
+ elif lineh.get_visible():
+ lineh.set_visible(False)
+ linev.set_visible(False)
+ self._plot._setDirtyPlot(overlayOnly=True)
+ # onMouseMove must trigger replot if dirty flag is raised
+
+ self._plot.onMouseMove(int(x), int(y))
+
+ def _onMouseRelease(self, event):
+ button = self._MPL_TO_PLOT_BUTTONS.get(event.button, None)
+ if button is not None:
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMouseRelease(int(x), int(y), button)
+
+ def _onMouseWheel(self, event):
+ x, y = self._mplToQtPosition(event.x, event.y)
+ self._plot.onMouseWheel(int(x), int(y), event.step)
+
+ def leaveEvent(self, event):
+ """QWidget event handler"""
+ try:
+ plot = self._plot
+ except RuntimeError:
+ pass
+ else:
+ plot.onMouseLeaveWidget()
+
+ # picking
+
+ def pickItem(self, x, y, item):
+ xDisplay, yDisplay = self._qtToMplPosition(x, y)
+ mouseEvent = MouseEvent(
+ 'button_press_event', self, int(xDisplay), int(yDisplay))
+ # Override axes and data position with the axes
+ mouseEvent.inaxes = item.axes
+ mouseEvent.xdata, mouseEvent.ydata = self.pixelToData(
+ x, y, axis='left' if item.axes is self.ax else 'right')
+ picked, info = item.contains(mouseEvent)
+
+ if not picked:
+ return None
+
+ elif isinstance(item, TriMesh):
+ # Convert selected triangle to data point indices
+ triangulation = item._triangulation
+ indices = triangulation.get_masked_triangles()[info['ind'][0]]
+
+ # Sort picked triangle points by distance to mouse
+ # from furthest to closest to put closest point last
+ # This is to be somewhat consistent with last scatter point
+ # being the top one.
+ xdata, ydata = self.pixelToData(x, y, axis='left')
+ dists = ((triangulation.x[indices] - xdata) ** 2 +
+ (triangulation.y[indices] - ydata) ** 2)
+ return indices[numpy.flip(numpy.argsort(dists), axis=0)]
+
+ else: # Returns indices if any
+ return info.get('ind', ())
+
+ # replot control
+
+ def resizeEvent(self, event):
+ # Store current limits
+ self._limitsBeforeResize = (
+ self.ax.get_xbound(), self.ax.get_ybound(), self.ax2.get_ybound())
+
+ FigureCanvasQTAgg.resizeEvent(self, event)
+ if self.isKeepDataAspectRatio() or self._hasOverlays():
+ # This is needed with matplotlib 1.5.x and 2.0.x
+ self._plot._setDirtyPlot()
+
+ def draw(self):
+ """Overload draw
+
+ It performs a full redraw (including overlays) of the plot.
+ It also resets background and emit limits changed signal.
+
+ This is directly called by matplotlib for widget resize.
+ """
+ self.updateZOrder()
+
+ # Starting with mpl 2.1.0, toggling autoscale raises a ValueError
+ # in some situations. See #1081, #1136, #1163,
+ if self._matplotlibVersion >= _parse_version("2.0.0"):
+ try:
+ FigureCanvasQTAgg.draw(self)
+ except ValueError as err:
+ _logger.debug(
+ "ValueError caught while calling FigureCanvasQTAgg.draw: "
+ "'%s'", err)
+ else:
+ FigureCanvasQTAgg.draw(self)
+
+ if self._hasOverlays():
+ # Save background
+ self._background = self.copy_from_bbox(self.fig.bbox)
+ else:
+ self._background = None # Reset background
+
+ # Check if limits changed due to a resize of the widget
+ if self._limitsBeforeResize is not None:
+ xLimits, yLimits, yRightLimits = self._limitsBeforeResize
+ self._limitsBeforeResize = None
+
+ if (xLimits != self.ax.get_xbound() or
+ yLimits != self.ax.get_ybound()):
+ self._updateMarkers()
+
+ if xLimits != self.ax.get_xbound():
+ self._plot.getXAxis()._emitLimitsChanged()
+ if yLimits != self.ax.get_ybound():
+ self._plot.getYAxis(axis='left')._emitLimitsChanged()
+ if yRightLimits != self.ax2.get_ybound():
+ self._plot.getYAxis(axis='right')._emitLimitsChanged()
+
+ self._drawOverlays()
+
+ def replot(self):
+ with self._plot._paintContext():
+ BackendMatplotlib._replot(self)
+
+ dirtyFlag = self._plot._getDirtyPlot()
+
+ if dirtyFlag == 'overlay':
+ # Only redraw overlays using fast rendering path
+ if self._background is None:
+ self._background = self.copy_from_bbox(self.fig.bbox)
+ self.restore_region(self._background)
+ self._drawOverlays()
+ self.blit(self.fig.bbox)
+
+ elif dirtyFlag: # Need full redraw
+ self.draw()
+
+ # Workaround issue of rendering overlays with some matplotlib versions
+ if (_parse_version('1.5') <= self._matplotlibVersion < _parse_version('2.1') and
+ not hasattr(self, '_firstReplot')):
+ self._firstReplot = False
+ if self._hasOverlays():
+ qt.QTimer.singleShot(0, self.draw) # Request async draw
+
+ # cursor
+
+ _QT_CURSORS = {
+ BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
+ BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
+ BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
+ BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
+ }
+
+ def setGraphCursorShape(self, cursor):
+ if cursor is None:
+ FigureCanvasQTAgg.unsetCursor(self)
+ else:
+ cursor = self._QT_CURSORS[cursor]
+ FigureCanvasQTAgg.setCursor(self, qt.QCursor(cursor))
diff --git a/src/silx/gui/plot/backends/BackendOpenGL.py b/src/silx/gui/plot/backends/BackendOpenGL.py
new file mode 100755
index 0000000..f1a12af
--- /dev/null
+++ b/src/silx/gui/plot/backends/BackendOpenGL.py
@@ -0,0 +1,1420 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""OpenGL Plot backend."""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+import logging
+import weakref
+
+import numpy
+
+from .. import items
+from .._utils import FLOAT32_MINPOS
+from . import BackendBase
+from ... import colors
+from ... import qt
+
+from ..._glutils import gl
+from ... import _glutils as glu
+from . import glutils
+from .glutils.PlotImageFile import saveImageToFile
+
+_logger = logging.getLogger(__name__)
+
+
+# TODO idea: BackendQtMixIn class to share code between mpl and gl
+# TODO check if OpenGL is available
+# TODO make an off-screen mesa backend
+
+# Content #####################################################################
+
+class _ShapeItem(dict):
+ def __init__(self, x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor):
+ super(_ShapeItem, self).__init__()
+
+ if shape not in ('polygon', 'rectangle', 'line',
+ 'vline', 'hline', 'polylines'):
+ raise NotImplementedError("Unsupported shape {0}".format(shape))
+
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ if shape == 'rectangle':
+ xMin, xMax = x
+ x = numpy.array((xMin, xMin, xMax, xMax))
+ yMin, yMax = y
+ y = numpy.array((yMin, yMax, yMax, yMin))
+
+ # Ignore fill for polylines to mimic matplotlib
+ fill = fill if shape != 'polylines' else False
+
+ self.update({
+ 'shape': shape,
+ 'color': colors.rgba(color),
+ 'fill': 'hatch' if fill else None,
+ 'x': x,
+ 'y': y,
+ 'linestyle': linestyle,
+ 'linewidth': linewidth,
+ 'linebgcolor': linebgcolor,
+ })
+
+
+class _MarkerItem(dict):
+ def __init__(self, x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis):
+ super(_MarkerItem, self).__init__()
+
+ if symbol is None:
+ symbol = '+'
+
+ # Apply constraint to provided position
+ isConstraint = (constraint is not None and
+ x is not None and y is not None)
+ if isConstraint:
+ x, y = constraint(x, y)
+
+ self.update({
+ 'x': x,
+ 'y': y,
+ 'text': text,
+ 'color': colors.rgba(color),
+ 'constraint': constraint if isConstraint else None,
+ 'symbol': symbol,
+ 'linestyle': linestyle,
+ 'linewidth': linewidth,
+ 'yaxis': yaxis,
+ })
+
+
+# shaders #####################################################################
+
+_baseVertShd = """
+ attribute vec2 position;
+ uniform mat4 matrix;
+ uniform bvec2 isLog;
+
+ const float oneOverLog10 = 0.43429448190325176;
+
+ void main(void) {
+ vec2 posTransformed = position;
+ if (isLog.x) {
+ posTransformed.x = oneOverLog10 * log(position.x);
+ }
+ if (isLog.y) {
+ posTransformed.y = oneOverLog10 * log(position.y);
+ }
+ gl_Position = matrix * vec4(posTransformed, 0.0, 1.0);
+ }
+ """
+
+_baseFragShd = """
+ uniform vec4 color;
+ uniform int hatchStep;
+ uniform float tickLen;
+
+ void main(void) {
+ if (tickLen != 0.) {
+ if (mod((gl_FragCoord.x + gl_FragCoord.y) / tickLen, 2.) < 1.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ } else if (hatchStep == 0 ||
+ mod(gl_FragCoord.x - gl_FragCoord.y, float(hatchStep)) == 0.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ }
+ """
+
+_texVertShd = """
+ attribute vec2 position;
+ attribute vec2 texCoords;
+ uniform mat4 matrix;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ coords = texCoords;
+ }
+ """
+
+_texFragShd = """
+ uniform sampler2D tex;
+
+ varying vec2 coords;
+
+ void main(void) {
+ gl_FragColor = texture2D(tex, coords);
+ gl_FragColor.a = 1.0;
+ }
+ """
+
+# BackendOpenGL ###############################################################
+
+
+class BackendOpenGL(BackendBase.BackendBase, glu.OpenGLWidget):
+ """OpenGL-based Plot backend.
+
+ WARNINGS:
+ Unless stated otherwise, this API is NOT thread-safe and MUST be
+ called from the main thread.
+ When numpy arrays are passed as arguments to the API (through
+ :func:`addCurve` and :func:`addImage`), they are copied only if
+ required.
+ So, the caller should not modify these arrays afterwards.
+ """
+
+ def __init__(self, plot, parent=None, f=qt.Qt.WindowFlags()):
+ glu.OpenGLWidget.__init__(self, parent,
+ alphaBufferSize=8,
+ depthBufferSize=0,
+ stencilBufferSize=0,
+ version=(2, 1),
+ f=f)
+ BackendBase.BackendBase.__init__(self, plot, parent)
+
+ self._backgroundColor = 1., 1., 1., 1.
+ self._dataBackgroundColor = 1., 1., 1., 1.
+
+ self.matScreenProj = glutils.mat4Identity()
+
+ self._progBase = glu.Program(
+ _baseVertShd, _baseFragShd, attrib0='position')
+ self._progTex = glu.Program(
+ _texVertShd, _texFragShd, attrib0='position')
+ self._plotFBOs = weakref.WeakKeyDictionary()
+
+ self._keepDataAspectRatio = False
+
+ self._crosshairCursor = None
+ self._mousePosInPixels = None
+
+ self._glGarbageCollector = []
+
+ self._plotFrame = glutils.GLPlotFrame2D(
+ foregroundColor=(0., 0., 0., 1.),
+ gridColor=(.7, .7, .7, 1.),
+ marginRatios=(.15, .1, .1, .15))
+ self._plotFrame.size = ( # Init size with size int
+ int(self.getDevicePixelRatio() * 640),
+ int(self.getDevicePixelRatio() * 480))
+
+ self.setAutoFillBackground(False)
+ self.setMouseTracking(True)
+
+ # QWidget
+
+ _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'}
+
+ def sizeHint(self):
+ return qt.QSize(8 * 80, 6 * 80) # Mimic MatplotlibBackend
+
+ def mousePressEvent(self, event):
+ if event.button() not in self._MOUSE_BTNS:
+ return super(BackendOpenGL, self).mousePressEvent(event)
+ self._plot.onMousePress(
+ event.x(), event.y(), self._MOUSE_BTNS[event.button()])
+ event.accept()
+
+ def mouseMoveEvent(self, event):
+ qtPos = event.x(), event.y()
+
+ previousMousePosInPixels = self._mousePosInPixels
+ if qtPos == self._mouseInPlotArea(*qtPos):
+ devicePixelRatio = self.getDevicePixelRatio()
+ devicePos = qtPos[0] * devicePixelRatio, qtPos[1] * devicePixelRatio
+ self._mousePosInPixels = devicePos # Mouse in plot area
+ else:
+ self._mousePosInPixels = None # Mouse outside plot area
+
+ if (self._crosshairCursor is not None and
+ previousMousePosInPixels != self._mousePosInPixels):
+ # Avoid replot when cursor remains outside plot area
+ self._plot._setDirtyPlot(overlayOnly=True)
+
+ self._plot.onMouseMove(*qtPos)
+ event.accept()
+
+ def mouseReleaseEvent(self, event):
+ if event.button() not in self._MOUSE_BTNS:
+ return super(BackendOpenGL, self).mouseReleaseEvent(event)
+ self._plot.onMouseRelease(
+ event.x(), event.y(), self._MOUSE_BTNS[event.button()])
+ event.accept()
+
+ def wheelEvent(self, event):
+ delta = event.angleDelta().y()
+ angleInDegrees = delta / 8.
+ if qt.BINDING == "PySide6":
+ x, y = event.position().x(), event.position().y()
+ else:
+ x, y = event.x(), event.y()
+ self._plot.onMouseWheel(x, y, angleInDegrees)
+ event.accept()
+
+ def leaveEvent(self, _):
+ self._plot.onMouseLeaveWidget()
+
+ # OpenGLWidget API
+
+ def initializeGL(self):
+ gl.testGL()
+
+ gl.glClearStencil(0)
+
+ gl.glEnable(gl.GL_BLEND)
+ # gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)
+ gl.glBlendFuncSeparate(gl.GL_SRC_ALPHA,
+ gl.GL_ONE_MINUS_SRC_ALPHA,
+ gl.GL_ONE,
+ gl.GL_ONE)
+
+ # For lines
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+
+ # For points
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ # gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ def _paintDirectGL(self):
+ self._renderPlotAreaGL()
+ self._plotFrame.render()
+ self._renderOverlayGL()
+
+ def _paintFBOGL(self):
+ context = glu.Context.getCurrent()
+ plotFBOTex = self._plotFBOs.get(context)
+ if (self._plot._getDirtyPlot() or self._plotFrame.isDirty or
+ plotFBOTex is None):
+ self._plotVertices = (
+ # Vertex coordinates
+ numpy.array(((-1., -1.), (1., -1.), (-1., 1.), (1., 1.)),
+ dtype=numpy.float32),
+ # Texture coordinates
+ numpy.array(((0., 0.), (1., 0.), (0., 1.), (1., 1.)),
+ dtype=numpy.float32))
+ if plotFBOTex is None or \
+ plotFBOTex.shape[1] != self._plotFrame.size[0] or \
+ plotFBOTex.shape[0] != self._plotFrame.size[1]:
+ if plotFBOTex is not None:
+ plotFBOTex.discard()
+ plotFBOTex = glu.FramebufferTexture(
+ gl.GL_RGBA,
+ shape=(self._plotFrame.size[1],
+ self._plotFrame.size[0]),
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=(gl.GL_CLAMP_TO_EDGE,
+ gl.GL_CLAMP_TO_EDGE))
+ self._plotFBOs[context] = plotFBOTex
+
+ with plotFBOTex:
+ gl.glClearColor(*self._backgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
+ self._renderPlotAreaGL()
+ self._plotFrame.render()
+
+ # Render plot in screen coords
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ self._progTex.use()
+ texUnit = 0
+
+ gl.glUniform1i(self._progTex.uniforms['tex'], texUnit)
+ gl.glUniformMatrix4fv(self._progTex.uniforms['matrix'], 1, gl.GL_TRUE,
+ glutils.mat4Identity().astype(numpy.float32))
+
+ gl.glEnableVertexAttribArray(self._progTex.attributes['position'])
+ gl.glVertexAttribPointer(self._progTex.attributes['position'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._plotVertices[0])
+
+ gl.glEnableVertexAttribArray(self._progTex.attributes['texCoords'])
+ gl.glVertexAttribPointer(self._progTex.attributes['texCoords'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._plotVertices[1])
+
+ with plotFBOTex.texture:
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._plotVertices[0]))
+
+ self._renderOverlayGL()
+
+ def paintGL(self):
+ plot = self._plotRef()
+ if plot is None:
+ return
+
+ with plot._paintContext():
+ with glu.Context.current(self.context()):
+ # Release OpenGL resources
+ for item in self._glGarbageCollector:
+ item.discard()
+ self._glGarbageCollector = []
+
+ gl.glClearColor(*self._backgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT)
+
+ # Check if window is large enough
+ if self._plotFrame.plotSize <= (2, 2):
+ return
+
+ # Sync plot frame with window
+ self._plotFrame.devicePixelRatio = self.getDevicePixelRatio()
+ # self._paintDirectGL()
+ self._paintFBOGL()
+
+ def _renderItems(self, overlay=False):
+ """Render items according to :class:`PlotWidget` order
+
+ Note: Scissor test should already be set.
+
+ :param bool overlay:
+ False (the default) to render item that are not overlays.
+ True to render items that are overlays.
+ """
+ # Values that are often used
+ plotWidth, plotHeight = self._plotFrame.plotSize
+ isXLog = self._plotFrame.xAxis.isLog
+ isYLog = self._plotFrame.yAxis.isLog
+ isYInverted = self._plotFrame.isYAxisInverted
+
+ # Used by marker rendering
+ labels = []
+ pixelOffset = 3
+
+ context = glutils.RenderContext(
+ isXLog=isXLog, isYLog=isYLog, dpi=self.getDotsPerInch())
+
+ for plotItem in self.getItemsFromBackToFront(
+ condition=lambda i: i.isVisible() and i.isOverlay() == overlay):
+ if plotItem._backendRenderer is None:
+ continue
+
+ item = plotItem._backendRenderer
+
+ if isinstance(item, glutils.GLPlotItem): # Render data items
+ gl.glViewport(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+ # Set matrix
+ if item.yaxis == 'right':
+ context.matrix = self._plotFrame.transformedDataY2ProjMat
+ else:
+ context.matrix = self._plotFrame.transformedDataProjMat
+ item.render(context)
+
+ elif isinstance(item, _ShapeItem): # Render shape items
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ if ((isXLog and numpy.min(item['x']) < FLOAT32_MINPOS) or
+ (isYLog and numpy.min(item['y']) < FLOAT32_MINPOS)):
+ # Ignore items <= 0. on log axes
+ continue
+
+ if item['shape'] == 'hline':
+ width = self._plotFrame.size[0]
+ _, yPixel = self._plotFrame.dataToPixel(
+ 0.5 * sum(self._plotFrame.dataRanges[0]),
+ item['y'],
+ axis='left')
+ subShapes = [numpy.array(((0., yPixel), (width, yPixel)),
+ dtype=numpy.float32)]
+
+ elif item['shape'] == 'vline':
+ xPixel, _ = self._plotFrame.dataToPixel(
+ item['x'],
+ 0.5 * sum(self._plotFrame.dataRanges[1]),
+ axis='left')
+ height = self._plotFrame.size[1]
+ subShapes = [numpy.array(((xPixel, 0), (xPixel, height)),
+ dtype=numpy.float32)]
+
+ else:
+ # Split sub-shapes at not finite values
+ splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
+ numpy.isfinite(item['x']), numpy.isfinite(item['y']))))[0]
+ splits = numpy.concatenate(([-1], splits, [len(item['x'])]))
+ subShapes = []
+ for begin, end in zip(splits[:-1] + 1, splits[1:]):
+ if end > begin:
+ subShapes.append(numpy.array([
+ self._plotFrame.dataToPixel(x, y, axis='left')
+ for (x, y) in zip(item['x'][begin:end], item['y'][begin:end])]))
+
+ for points in subShapes: # Draw each sub-shape
+ # Draw the fill
+ if (item['fill'] is not None and
+ item['shape'] not in ('hline', 'vline')):
+ self._progBase.use()
+ gl.glUniformMatrix4fv(
+ self._progBase.uniforms['matrix'], 1, gl.GL_TRUE,
+ self.matScreenProj.astype(numpy.float32))
+ gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+
+ shape2D = glutils.FilledShape2D(
+ points, style=item['fill'], color=item['color'])
+ shape2D.render(
+ posAttrib=self._progBase.attributes['position'],
+ colorUnif=self._progBase.uniforms['color'],
+ hatchStepUnif=self._progBase.uniforms['hatchStep'])
+
+ # Draw the stroke
+ if item['linestyle'] not in ('', ' ', None):
+ if item['shape'] != 'polylines':
+ # close the polyline
+ points = numpy.append(points,
+ numpy.atleast_2d(points[0]), axis=0)
+
+ lines = glutils.GLLines2D(
+ points[:, 0], points[:, 1],
+ style=item['linestyle'],
+ color=item['color'],
+ dash2ndColor=item['linebgcolor'],
+ width=item['linewidth'])
+ context.matrix = self.matScreenProj
+ lines.render(context)
+
+ elif isinstance(item, _MarkerItem):
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ xCoord, yCoord, yAxis = item['x'], item['y'], item['yaxis']
+
+ if ((isXLog and xCoord is not None and xCoord <= 0) or
+ (isYLog and yCoord is not None and yCoord <= 0)):
+ # Do not render markers with negative coords on log axis
+ continue
+
+ color = item['color']
+ intensity = color[0] * 0.299 + color[1] * 0.587 + color[2] * 0.114
+ bgColor = (1., 1., 1., 0.5) if intensity <= 0.5 else (0., 0., 0., 0.5)
+ if xCoord is None or yCoord is None:
+ if xCoord is None: # Horizontal line in data space
+ pixelPos = self._plotFrame.dataToPixel(
+ 0.5 * sum(self._plotFrame.dataRanges[0]),
+ yCoord,
+ axis=yAxis)
+
+ if item['text'] is not None:
+ x = self._plotFrame.size[0] - \
+ self._plotFrame.margins.right - pixelOffset
+ y = pixelPos[1] - pixelOffset
+ label = glutils.Text2D(
+ item['text'], x, y,
+ color=item['color'],
+ bgColor=bgColor,
+ align=glutils.RIGHT,
+ valign=glutils.BOTTOM,
+ devicePixelRatio=self.getDevicePixelRatio())
+ labels.append(label)
+
+ width = self._plotFrame.size[0]
+ lines = glutils.GLLines2D(
+ (0, width), (pixelPos[1], pixelPos[1]),
+ style=item['linestyle'],
+ color=item['color'],
+ width=item['linewidth'])
+ context.matrix = self.matScreenProj
+ lines.render(context)
+
+ else: # yCoord is None: vertical line in data space
+ yRange = self._plotFrame.dataRanges[1 if yAxis == 'left' else 2]
+ pixelPos = self._plotFrame.dataToPixel(
+ xCoord, 0.5 * sum(yRange), axis=yAxis)
+
+ if item['text'] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = self._plotFrame.margins.top + pixelOffset
+ label = glutils.Text2D(
+ item['text'], x, y,
+ color=item['color'],
+ bgColor=bgColor,
+ align=glutils.LEFT,
+ valign=glutils.TOP,
+ devicePixelRatio=self.getDevicePixelRatio())
+ labels.append(label)
+
+ height = self._plotFrame.size[1]
+ lines = glutils.GLLines2D(
+ (pixelPos[0], pixelPos[0]), (0, height),
+ style=item['linestyle'],
+ color=item['color'],
+ width=item['linewidth'])
+ context.matrix = self.matScreenProj
+ lines.render(context)
+
+ else:
+ xmin, xmax = self._plot.getXAxis().getLimits()
+ ymin, ymax = self._plot.getYAxis(axis=yAxis).getLimits()
+ if not xmin < xCoord < xmax or not ymin < yCoord < ymax:
+ # Do not render markers outside visible plot area
+ continue
+ pixelPos = self._plotFrame.dataToPixel(
+ xCoord, yCoord, axis=yAxis)
+
+ if isYInverted:
+ valign = glutils.BOTTOM
+ vPixelOffset = -pixelOffset
+ else:
+ valign = glutils.TOP
+ vPixelOffset = pixelOffset
+
+ if item['text'] is not None:
+ x = pixelPos[0] + pixelOffset
+ y = pixelPos[1] + vPixelOffset
+ label = glutils.Text2D(
+ item['text'], x, y,
+ color=item['color'],
+ bgColor=bgColor,
+ align=glutils.LEFT,
+ valign=valign,
+ devicePixelRatio=self.getDevicePixelRatio())
+ labels.append(label)
+
+ # For now simple implementation: using a curve for each marker
+ # Should pack all markers to a single set of points
+ markerCurve = glutils.GLPlotCurve2D(
+ numpy.array((pixelPos[0],), dtype=numpy.float64),
+ numpy.array((pixelPos[1],), dtype=numpy.float64),
+ marker=item['symbol'],
+ markerColor=item['color'],
+ markerSize=11)
+
+ context = glutils.RenderContext(
+ matrix=self.matScreenProj,
+ isXLog=False,
+ isYLog=False,
+ dpi=self.getDotsPerInch())
+ markerCurve.render(context)
+ markerCurve.discard()
+
+ else:
+ _logger.error('Unsupported item: %s', str(item))
+ continue
+
+ # Render marker labels
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+ for label in labels:
+ label.render(self.matScreenProj)
+
+ def _renderOverlayGL(self):
+ """Render overlay layer: overlay items and crosshair."""
+ plotWidth, plotHeight = self._plotFrame.plotSize
+
+ # Scissor to plot area
+ gl.glScissor(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+
+ self._renderItems(overlay=True)
+
+ # Render crosshair cursor
+ if self._crosshairCursor is not None and self._mousePosInPixels is not None:
+ self._progBase.use()
+ gl.glUniform2i(self._progBase.uniforms['isLog'], False, False)
+ gl.glUniform1f(self._progBase.uniforms['tickLen'], 0.)
+ posAttrib = self._progBase.attributes['position']
+ matrixUnif = self._progBase.uniforms['matrix']
+ colorUnif = self._progBase.uniforms['color']
+ hatchStepUnif = self._progBase.uniforms['hatchStep']
+
+ gl.glViewport(0, 0, self._plotFrame.size[0], self._plotFrame.size[1])
+
+ gl.glUniformMatrix4fv(matrixUnif, 1, gl.GL_TRUE,
+ self.matScreenProj.astype(numpy.float32))
+
+ color, lineWidth = self._crosshairCursor
+ gl.glUniform4f(colorUnif, *color)
+ gl.glUniform1i(hatchStepUnif, 0)
+
+ xPixel, yPixel = self._mousePosInPixels
+ xPixel, yPixel = xPixel + 0.5, yPixel + 0.5
+ vertices = numpy.array(((0., yPixel),
+ (self._plotFrame.size[0], yPixel),
+ (xPixel, 0.),
+ (xPixel, self._plotFrame.size[1])),
+ dtype=numpy.float32)
+
+ gl.glEnableVertexAttribArray(posAttrib)
+ gl.glVertexAttribPointer(posAttrib,
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, vertices)
+ gl.glLineWidth(lineWidth)
+ gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+
+ def _renderPlotAreaGL(self):
+ """Render base layer of plot area.
+
+ It renders the background, grid and items except overlays
+ """
+ plotWidth, plotHeight = self._plotFrame.plotSize
+
+ gl.glScissor(self._plotFrame.margins.left,
+ self._plotFrame.margins.bottom,
+ plotWidth, plotHeight)
+ gl.glEnable(gl.GL_SCISSOR_TEST)
+
+ if self._dataBackgroundColor != self._backgroundColor:
+ gl.glClearColor(*self._dataBackgroundColor)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+ self._plotFrame.renderGrid()
+
+ # Matrix
+ trBounds = self._plotFrame.transformedDataRanges
+ if trBounds.x[0] != trBounds.x[1] and trBounds.y[0] != trBounds.y[1]:
+ # Do rendering of items
+ self._renderItems(overlay=False)
+
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+
+ def resizeGL(self, width, height):
+ if width == 0 or height == 0: # Do not resize
+ return
+
+ self._plotFrame.size = (
+ int(self.getDevicePixelRatio() * width),
+ int(self.getDevicePixelRatio() * height))
+
+ self.matScreenProj = glutils.mat4Ortho(
+ 0, self._plotFrame.size[0],
+ self._plotFrame.size[1], 0,
+ 1, -1)
+
+ # Store current ranges
+ previousXRange = self.getGraphXLimits()
+ previousYRange = self.getGraphYLimits(axis='left')
+ previousYRightRange = self.getGraphYLimits(axis='right')
+
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
+ self._plotFrame.dataRanges
+ self.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+
+ # If plot range has changed, then emit signal
+ if previousXRange != self.getGraphXLimits():
+ self._plot.getXAxis()._emitLimitsChanged()
+ if previousYRange != self.getGraphYLimits(axis='left'):
+ self._plot.getYAxis(axis='left')._emitLimitsChanged()
+ if previousYRightRange != self.getGraphYLimits(axis='right'):
+ self._plot.getYAxis(axis='right')._emitLimitsChanged()
+
+ # Add methods
+
+ @staticmethod
+ def _castArrayTo(v):
+ """Returns best floating type to cast the array to.
+
+ :param numpy.ndarray v: Array to cast
+ :rtype: numpy.dtype
+ :raise ValueError: If dtype is not supported
+ """
+ if numpy.issubdtype(v.dtype, numpy.floating):
+ return numpy.float32 if v.itemsize <= 4 else numpy.float64
+ elif numpy.issubdtype(v.dtype, numpy.integer):
+ return numpy.float32 if v.itemsize <= 2 else numpy.float64
+ else:
+ raise ValueError('Unsupported data type')
+
+ def addCurve(self, x, y,
+ color, symbol, linewidth, linestyle,
+ yaxis,
+ xerror, yerror,
+ fill, alpha, symbolsize, baseline):
+ for parameter in (x, y, color, symbol, linewidth, linestyle,
+ yaxis, fill, symbolsize):
+ assert parameter is not None
+ assert yaxis in ('left', 'right')
+
+ # Convert input data
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ # Check if float32 is enough
+ if (self._castArrayTo(x) is numpy.float32 and
+ self._castArrayTo(y) is numpy.float32):
+ dtype = numpy.float32
+ else:
+ dtype = numpy.float64
+
+ x = numpy.array(x, dtype=dtype, copy=False, order='C')
+ y = numpy.array(y, dtype=dtype, copy=False, order='C')
+
+ # Convert errors to float32
+ if xerror is not None:
+ xerror = numpy.array(
+ xerror, dtype=numpy.float32, copy=False, order='C')
+ if yerror is not None:
+ yerror = numpy.array(
+ yerror, dtype=numpy.float32, copy=False, order='C')
+
+ # Handle axes log scale: convert data
+
+ if self._plotFrame.xAxis.isLog:
+ logX = numpy.log10(x)
+
+ if xerror is not None:
+ # Transform xerror so that
+ # log10(x) +/- xerror' = log10(x +/- xerror)
+ if hasattr(xerror, 'shape') and len(xerror.shape) == 2:
+ xErrorMinus, xErrorPlus = xerror[0], xerror[1]
+ else:
+ xErrorMinus, xErrorPlus = xerror, xerror
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ # Ignore divide by zero, invalid value encountered in log10
+ xErrorMinus = logX - numpy.log10(x - xErrorMinus)
+ xErrorPlus = numpy.log10(x + xErrorPlus) - logX
+ xerror = numpy.array((xErrorMinus, xErrorPlus),
+ dtype=numpy.float32)
+
+ x = logX
+
+ isYLog = (yaxis == 'left' and self._plotFrame.yAxis.isLog) or (
+ yaxis == 'right' and self._plotFrame.y2Axis.isLog)
+
+ if isYLog:
+ logY = numpy.log10(y)
+
+ if yerror is not None:
+ # Transform yerror so that
+ # log10(y) +/- yerror' = log10(y +/- yerror)
+ if hasattr(yerror, 'shape') and len(yerror.shape) == 2:
+ yErrorMinus, yErrorPlus = yerror[0], yerror[1]
+ else:
+ yErrorMinus, yErrorPlus = yerror, yerror
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ # Ignore divide by zero, invalid value encountered in log10
+ yErrorMinus = logY - numpy.log10(y - yErrorMinus)
+ yErrorPlus = numpy.log10(y + yErrorPlus) - logY
+ yerror = numpy.array((yErrorMinus, yErrorPlus),
+ dtype=numpy.float32)
+
+ y = logY
+
+ # TODO check if need more filtering of error (e.g., clip to positive)
+
+ # TODO check and improve this
+ if (len(color) == 4 and
+ type(color[3]) in [type(1), numpy.uint8, numpy.int8]):
+ color = numpy.array(color, dtype=numpy.float32) / 255.
+
+ if isinstance(color, numpy.ndarray) and color.ndim == 2:
+ colorArray = color
+ color = None
+ else:
+ colorArray = None
+ color = colors.rgba(color)
+
+ if alpha < 1.: # Apply image transparency
+ if colorArray is not None and colorArray.shape[1] == 4:
+ # multiply alpha channel
+ colorArray[:, 3] = colorArray[:, 3] * alpha
+ if color is not None:
+ color = color[0], color[1], color[2], color[3] * alpha
+
+ fillColor = None
+ if fill is True:
+ fillColor = color
+ curve = glutils.GLPlotCurve2D(
+ x, y, colorArray,
+ xError=xerror,
+ yError=yerror,
+ lineStyle=linestyle,
+ lineColor=color,
+ lineWidth=linewidth,
+ marker=symbol,
+ markerColor=color,
+ markerSize=symbolsize,
+ fillColor=fillColor,
+ baseline=baseline,
+ isYLog=isYLog)
+ curve.yaxis = 'left' if yaxis is None else yaxis
+
+ if yaxis == "right":
+ self._plotFrame.isY2Axis = True
+
+ return curve
+
+ def addImage(self, data,
+ origin, scale,
+ colormap, alpha):
+ for parameter in (data, origin, scale):
+ assert parameter is not None
+
+ if data.ndim == 2:
+ # Ensure array is contiguous and eventually convert its type
+ dtypes = [dtype for dtype in (
+ numpy.float32, numpy.float16, numpy.uint8, numpy.uint16)
+ if glu.isSupportedGLType(dtype)]
+ if data.dtype in dtypes:
+ data = numpy.array(data, copy=False, order='C')
+ else:
+ _logger.info(
+ 'addImage: Convert %s data to float32', str(data.dtype))
+ data = numpy.array(data, dtype=numpy.float32, order='C')
+
+ normalization = colormap.getNormalization()
+ if normalization in glutils.GLPlotColormap.SUPPORTED_NORMALIZATIONS:
+ # Fast path applying colormap on the GPU
+ cmapRange = colormap.getColormapRange(data=data)
+ colormapLut = colormap.getNColors(nbColors=256)
+ gamma = colormap.getGammaNormalizationParameter()
+ nanColor = colors.rgba(colormap.getNaNColor())
+
+ image = glutils.GLPlotColormap(
+ data,
+ origin,
+ scale,
+ colormapLut,
+ normalization,
+ gamma,
+ cmapRange,
+ alpha,
+ nanColor)
+
+ else: # Fallback applying colormap on CPU
+ rgba = colormap.applyToData(data)
+ image = glutils.GLPlotRGBAImage(rgba, origin, scale, alpha)
+
+ elif len(data.shape) == 3:
+ # For RGB, RGBA data
+ assert data.shape[2] in (3, 4)
+
+ if numpy.issubdtype(data.dtype, numpy.floating):
+ data = numpy.array(data, dtype=numpy.float32, copy=False)
+ elif data.dtype in [numpy.uint8, numpy.uint16]:
+ pass
+ elif numpy.issubdtype(data.dtype, numpy.integer):
+ data = numpy.array(data, dtype=numpy.uint8, copy=False)
+ else:
+ raise ValueError('Unsupported data type')
+
+ image = glutils.GLPlotRGBAImage(data, origin, scale, alpha)
+
+ else:
+ raise RuntimeError("Unsupported data shape {0}".format(data.shape))
+
+ # TODO is this needed?
+ if self._plotFrame.xAxis.isLog and image.xMin <= 0.:
+ raise RuntimeError(
+ 'Cannot add image with X <= 0 with X axis log scale')
+ if self._plotFrame.yAxis.isLog and image.yMin <= 0.:
+ raise RuntimeError(
+ 'Cannot add image with Y <= 0 with Y axis log scale')
+
+ return image
+
+ def addTriangles(self, x, y, triangles,
+ color, alpha):
+ # Handle axes log scale: convert data
+ if self._plotFrame.xAxis.isLog:
+ x = numpy.log10(x)
+ if self._plotFrame.yAxis.isLog:
+ y = numpy.log10(y)
+
+ triangles = glutils.GLPlotTriangles(x, y, color, triangles, alpha)
+
+ return triangles
+
+ def addShape(self, x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor):
+ x = numpy.array(x, copy=False)
+ y = numpy.array(y, copy=False)
+
+ # TODO is this needed?
+ if self._plotFrame.xAxis.isLog and x.min() <= 0.:
+ raise RuntimeError(
+ 'Cannot add item with X <= 0 with X axis log scale')
+ if self._plotFrame.yAxis.isLog and y.min() <= 0.:
+ raise RuntimeError(
+ 'Cannot add item with Y <= 0 with Y axis log scale')
+
+ return _ShapeItem(x, y, shape, color, fill, overlay,
+ linestyle, linewidth, linebgcolor)
+
+ def addMarker(self, x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis):
+ return _MarkerItem(x, y, text, color,
+ symbol, linestyle, linewidth, constraint, yaxis)
+
+ # Remove methods
+
+ def remove(self, item):
+ if isinstance(item, glutils.GLPlotItem):
+ if item.yaxis == 'right':
+ # Check if some curves remains on the right Y axis
+ y2AxisItems = (item for item in self._plot.getItems()
+ if isinstance(item, items.YAxisMixIn) and
+ item.getYAxis() == 'right')
+ self._plotFrame.isY2Axis = next(y2AxisItems, None) is not None
+
+ if item.isInitialized():
+ self._glGarbageCollector.append(item)
+
+ elif isinstance(item, (_MarkerItem, _ShapeItem)):
+ pass # No-op
+
+ else:
+ _logger.error('Unsupported item: %s', str(item))
+
+ # Interaction methods
+
+ _QT_CURSORS = {
+ BackendBase.CURSOR_DEFAULT: qt.Qt.ArrowCursor,
+ BackendBase.CURSOR_POINTING: qt.Qt.PointingHandCursor,
+ BackendBase.CURSOR_SIZE_HOR: qt.Qt.SizeHorCursor,
+ BackendBase.CURSOR_SIZE_VER: qt.Qt.SizeVerCursor,
+ BackendBase.CURSOR_SIZE_ALL: qt.Qt.SizeAllCursor,
+ }
+
+ def setGraphCursorShape(self, cursor):
+ if cursor is None:
+ super(BackendOpenGL, self).unsetCursor()
+ else:
+ cursor = self._QT_CURSORS[cursor]
+ super(BackendOpenGL, self).setCursor(qt.QCursor(cursor))
+
+ def setGraphCursor(self, flag, color, linewidth, linestyle):
+ if linestyle != '-':
+ _logger.warning(
+ "BackendOpenGL.setGraphCursor linestyle parameter ignored")
+
+ if flag:
+ color = colors.rgba(color)
+ crosshairCursor = color, linewidth
+ else:
+ crosshairCursor = None
+
+ if crosshairCursor != self._crosshairCursor:
+ self._crosshairCursor = crosshairCursor
+
+ _PICK_OFFSET = 3 # Offset in pixel used for picking
+
+ def _mouseInPlotArea(self, x, y):
+ """Returns closest visible position in the plot.
+
+ This is performed in Qt widget pixel, not device pixel.
+
+ :param float x: X coordinate in Qt widget pixel
+ :param float y: Y coordinate in Qt widget pixel
+ :return: (x, y) closest point in the plot.
+ :rtype: List[float]
+ """
+ left, top, width, height = self.getPlotBoundsInPixels()
+ return (numpy.clip(x, left, left + width - 1), # TODO -1?
+ numpy.clip(y, top, top + height - 1))
+
+ def __pickCurves(self, item, x, y):
+ """Perform picking on a curve item.
+
+ :param GLPlotCurve2D item:
+ :param float x: X position of the mouse in widget coordinates
+ :param float y: Y position of the mouse in widget coordinates
+ :return: List of indices of picked points or None if not picked
+ :rtype: Union[List[int],None]
+ """
+ offset = self._PICK_OFFSET
+ if item.marker is not None:
+ # Convert markerSize from points to qt pixels
+ qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
+ size = item.markerSize / 72. * qtDpi
+ offset = max(size / 2., offset)
+ if item.lineStyle is not None:
+ # Convert line width from points to qt pixels
+ qtDpi = self.getDotsPerInch() / self.getDevicePixelRatio()
+ lineWidth = item.lineWidth / 72. * qtDpi
+ offset = max(lineWidth / 2., offset)
+
+ inAreaPos = self._mouseInPlotArea(x - offset, y - offset)
+ dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
+ axis=item.yaxis, check=True)
+ if dataPos is None:
+ return None
+ xPick0, yPick0 = dataPos
+
+ inAreaPos = self._mouseInPlotArea(x + offset, y + offset)
+ dataPos = self._plot.pixelToData(inAreaPos[0], inAreaPos[1],
+ axis=item.yaxis, check=True)
+ if dataPos is None:
+ return None
+ xPick1, yPick1 = dataPos
+
+ if xPick0 < xPick1:
+ xPickMin, xPickMax = xPick0, xPick1
+ else:
+ xPickMin, xPickMax = xPick1, xPick0
+
+ if yPick0 < yPick1:
+ yPickMin, yPickMax = yPick0, yPick1
+ else:
+ yPickMin, yPickMax = yPick1, yPick0
+
+ # Apply log scale if axis is log
+ if self._plotFrame.xAxis.isLog:
+ xPickMin = numpy.log10(xPickMin)
+ xPickMax = numpy.log10(xPickMax)
+
+ if (item.yaxis == 'left' and self._plotFrame.yAxis.isLog) or (
+ item.yaxis == 'right' and self._plotFrame.y2Axis.isLog):
+ yPickMin = numpy.log10(yPickMin)
+ yPickMax = numpy.log10(yPickMax)
+
+ return item.pick(xPickMin, yPickMin,
+ xPickMax, yPickMax)
+
+ def pickItem(self, x, y, item):
+ # Picking is performed in Qt widget pixels not device pixels
+ dataPos = self._plot.pixelToData(x, y, axis='left', check=True)
+ if dataPos is None:
+ return None # Outside plot area
+
+ if item is None:
+ _logger.error("No item provided for picking")
+ return None
+
+ # Pick markers
+ if isinstance(item, _MarkerItem):
+ yaxis = item['yaxis']
+ pixelPos = self._plot.dataToPixel(
+ item['x'], item['y'], axis=yaxis, check=False)
+ if pixelPos is None:
+ return None # negative coord on a log axis
+
+ if item['x'] is None: # Horizontal line
+ pt1 = self._plot.pixelToData(
+ x, y - self._PICK_OFFSET, axis=yaxis, check=False)
+ pt2 = self._plot.pixelToData(
+ x, y + self._PICK_OFFSET, axis=yaxis, check=False)
+ isPicked = (min(pt1[1], pt2[1]) <= item['y'] <=
+ max(pt1[1], pt2[1]))
+
+ elif item['y'] is None: # Vertical line
+ pt1 = self._plot.pixelToData(
+ x - self._PICK_OFFSET, y, axis=yaxis, check=False)
+ pt2 = self._plot.pixelToData(
+ x + self._PICK_OFFSET, y, axis=yaxis, check=False)
+ isPicked = (min(pt1[0], pt2[0]) <= item['x'] <=
+ max(pt1[0], pt2[0]))
+
+ else:
+ isPicked = (
+ numpy.fabs(x - pixelPos[0]) <= self._PICK_OFFSET and
+ numpy.fabs(y - pixelPos[1]) <= self._PICK_OFFSET)
+
+ return (0,) if isPicked else None
+
+ # Pick image, curve, triangles
+ elif isinstance(item, glutils.GLPlotItem):
+ if isinstance(item, glutils.GLPlotCurve2D):
+ return self.__pickCurves(item, x, y)
+ else:
+ return item.pick(*dataPos) # Might be None
+
+ # Update curve
+
+ def setCurveColor(self, curve, color):
+ pass # TODO
+
+ # Misc.
+
+ def getWidgetHandle(self):
+ return self
+
+ def postRedisplay(self):
+ self.update()
+
+ def replot(self):
+ self.update() # async redraw
+
+ def saveGraph(self, fileName, fileFormat, dpi):
+ if dpi is not None:
+ _logger.warning("saveGraph ignores dpi parameter")
+
+ if fileFormat not in ['png', 'ppm', 'svg', 'tiff']:
+ raise NotImplementedError('Unsupported format: %s' % fileFormat)
+
+ if not self.isValid():
+ _logger.error('OpenGL 2.1 not available, cannot save OpenGL image')
+ width, height = self._plotFrame.size
+ data = numpy.zeros((height, width, 3), dtype=numpy.uint8)
+ else:
+ self.makeCurrent()
+
+ data = numpy.empty(
+ (self._plotFrame.size[1], self._plotFrame.size[0], 3),
+ dtype=numpy.uint8, order='C')
+
+ context = self.context()
+ framebufferTexture = self._plotFBOs.get(context)
+ if framebufferTexture is None:
+ # Fallback, supports direct rendering mode: _paintDirectGL
+ # might have issues as it can read on-screen framebuffer
+ fboName = self.defaultFramebufferObject()
+ width, height = self._plotFrame.size
+ else:
+ fboName = framebufferTexture.name
+ height, width = framebufferTexture.shape
+
+ previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fboName)
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
+ gl.glReadPixels(0, 0, width, height,
+ gl.GL_RGB, gl.GL_UNSIGNED_BYTE, data)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer)
+
+ # glReadPixels gives bottom to top,
+ # while images are stored as top to bottom
+ data = numpy.flipud(data)
+
+ # fileName is either a file-like object or a str
+ saveImageToFile(data, fileName, fileFormat)
+
+ # Graph labels
+
+ def setGraphTitle(self, title):
+ self._plotFrame.title = title
+
+ def setGraphXLabel(self, label):
+ self._plotFrame.xAxis.title = label
+
+ def setGraphYLabel(self, label, axis):
+ if axis == 'left':
+ self._plotFrame.yAxis.title = label
+ else: # right axis
+ self._plotFrame.y2Axis.title = label
+
+ # Graph limits
+
+ def _setDataRanges(self, xlim=None, ylim=None, y2lim=None):
+ """Set the visible range of data in the plot frame.
+
+ This clips the ranges to possible values (takes care of float32
+ range + positive range for log).
+ This also takes care of non-orthogonal axes.
+
+ This should be moved to PlotFrame.
+ """
+ # Update axes range with a clipped range if too wide
+ self._plotFrame.setDataRanges(xlim, ylim, y2lim)
+
+ def _ensureAspectRatio(self, keepDim=None):
+ """Update plot bounds in order to keep aspect ratio.
+
+ Warning: keepDim on right Y axis is not implemented !
+
+ :param str keepDim: The dimension to maintain: 'x', 'y' or None.
+ If None (the default), the dimension with the largest range.
+ """
+ plotWidth, plotHeight = self._plotFrame.plotSize
+ if plotWidth <= 2 or plotHeight <= 2:
+ return
+
+ if keepDim is None:
+ ranges = self._plot.getDataRange()
+ if (ranges.y is not None and
+ ranges.x is not None and
+ (ranges.y[1] - ranges.y[0]) != 0.):
+ dataRatio = (ranges.x[1] - ranges.x[0]) / float(ranges.y[1] - ranges.y[0])
+ plotRatio = plotWidth / float(plotHeight) # Test != 0 before
+
+ keepDim = 'x' if dataRatio > plotRatio else 'y'
+ else: # Limit case
+ keepDim = 'x'
+
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = \
+ self._plotFrame.dataRanges
+ if keepDim == 'y':
+ dataW = (yMax - yMin) * plotWidth / float(plotHeight)
+ xCenter = 0.5 * (xMin + xMax)
+ xMin = xCenter - 0.5 * dataW
+ xMax = xCenter + 0.5 * dataW
+ elif keepDim == 'x':
+ dataH = (xMax - xMin) * plotHeight / float(plotWidth)
+ yCenter = 0.5 * (yMin + yMax)
+ yMin = yCenter - 0.5 * dataH
+ yMax = yCenter + 0.5 * dataH
+ y2Center = 0.5 * (y2Min + y2Max)
+ y2Min = y2Center - 0.5 * dataH
+ y2Max = y2Center + 0.5 * dataH
+ else:
+ raise RuntimeError('Unsupported dimension to keep: %s' % keepDim)
+
+ # Update plot frame bounds
+ self._setDataRanges(xlim=(xMin, xMax),
+ ylim=(yMin, yMax),
+ y2lim=(y2Min, y2Max))
+
+ def _setPlotBounds(self, xRange=None, yRange=None, y2Range=None,
+ keepDim=None):
+ # Update axes range with a clipped range if too wide
+ self._setDataRanges(xlim=xRange,
+ ylim=yRange,
+ y2lim=y2Range)
+
+ # Keep data aspect ratio
+ if self.isKeepDataAspectRatio():
+ self._ensureAspectRatio(keepDim)
+
+ def setLimits(self, xmin, xmax, ymin, ymax, y2min=None, y2max=None):
+ assert xmin < xmax
+ assert ymin < ymax
+
+ if y2min is None or y2max is None:
+ y2Range = None
+ else:
+ assert y2min < y2max
+ y2Range = y2min, y2max
+ self._setPlotBounds((xmin, xmax), (ymin, ymax), y2Range)
+
+ def getGraphXLimits(self):
+ return self._plotFrame.dataRanges.x
+
+ def setGraphXLimits(self, xmin, xmax):
+ assert xmin < xmax
+ self._setPlotBounds(xRange=(xmin, xmax), keepDim='x')
+
+ def getGraphYLimits(self, axis):
+ assert axis in ("left", "right")
+ if axis == "left":
+ return self._plotFrame.dataRanges.y
+ else:
+ return self._plotFrame.dataRanges.y2
+
+ def setGraphYLimits(self, ymin, ymax, axis):
+ assert ymin < ymax
+ assert axis in ("left", "right")
+
+ if axis == "left":
+ self._setPlotBounds(yRange=(ymin, ymax), keepDim='y')
+ else:
+ self._setPlotBounds(y2Range=(ymin, ymax), keepDim='y')
+
+ # Graph axes
+
+ def getXAxisTimeZone(self):
+ return self._plotFrame.xAxis.timeZone
+
+ def setXAxisTimeZone(self, tz):
+ self._plotFrame.xAxis.timeZone = tz
+
+ def isXAxisTimeSeries(self):
+ return self._plotFrame.xAxis.isTimeSeries
+
+ def setXAxisTimeSeries(self, isTimeSeries):
+ self._plotFrame.xAxis.isTimeSeries = isTimeSeries
+
+ def setXAxisLogarithmic(self, flag):
+ if flag != self._plotFrame.xAxis.isLog:
+ if flag and self._keepDataAspectRatio:
+ _logger.warning(
+ "KeepDataAspectRatio is ignored with log axes")
+
+ self._plotFrame.xAxis.isLog = flag
+
+ def setYAxisLogarithmic(self, flag):
+ if (flag != self._plotFrame.yAxis.isLog or
+ flag != self._plotFrame.y2Axis.isLog):
+ if flag and self._keepDataAspectRatio:
+ _logger.warning(
+ "KeepDataAspectRatio is ignored with log axes")
+
+ self._plotFrame.yAxis.isLog = flag
+ self._plotFrame.y2Axis.isLog = flag
+
+ def setYAxisInverted(self, flag):
+ if flag != self._plotFrame.isYAxisInverted:
+ self._plotFrame.isYAxisInverted = flag
+
+ def isYAxisInverted(self):
+ return self._plotFrame.isYAxisInverted
+
+ def isKeepDataAspectRatio(self):
+ if self._plotFrame.xAxis.isLog or self._plotFrame.yAxis.isLog:
+ return False
+ else:
+ return self._keepDataAspectRatio
+
+ def setKeepDataAspectRatio(self, flag):
+ if flag and (self._plotFrame.xAxis.isLog or
+ self._plotFrame.yAxis.isLog):
+ _logger.warning("KeepDataAspectRatio is ignored with log axes")
+
+ self._keepDataAspectRatio = flag
+
+ def setGraphGrid(self, which):
+ assert which in (None, 'major', 'both')
+ self._plotFrame.grid = which is not None # TODO True grid support
+
+ # Data <-> Pixel coordinates conversion
+
+ def dataToPixel(self, x, y, axis):
+ result = self._plotFrame.dataToPixel(x, y, axis)
+ if result is None:
+ return None
+ else:
+ devicePixelRatio = self.getDevicePixelRatio()
+ return tuple(value/devicePixelRatio for value in result)
+
+ def pixelToData(self, x, y, axis):
+ devicePixelRatio = self.getDevicePixelRatio()
+ return self._plotFrame.pixelToData(
+ x * devicePixelRatio, y * devicePixelRatio, axis)
+
+ def getPlotBoundsInPixels(self):
+ devicePixelRatio = self.getDevicePixelRatio()
+ return tuple(int(value / devicePixelRatio)
+ for value in self._plotFrame.plotOrigin + self._plotFrame.plotSize)
+
+ def setAxesMargins(self, left: float, top: float, right: float, bottom: float):
+ self._plotFrame.marginRatios = left, top, right, bottom
+
+ def setForegroundColors(self, foregroundColor, gridColor):
+ self._plotFrame.foregroundColor = foregroundColor
+ self._plotFrame.gridColor = gridColor
+
+ def setBackgroundColors(self, backgroundColor, dataBackgroundColor):
+ self._backgroundColor = backgroundColor
+ self._dataBackgroundColor = dataBackgroundColor
diff --git a/silx/gui/plot/backends/__init__.py b/src/silx/gui/plot/backends/__init__.py
index 966d9df..966d9df 100644
--- a/silx/gui/plot/backends/__init__.py
+++ b/src/silx/gui/plot/backends/__init__.py
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotCurve.py b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py
new file mode 100644
index 0000000..e4667b4
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotCurve.py
@@ -0,0 +1,1380 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides classes to render 2D lines and scatter plots
+"""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+import math
+import logging
+
+import numpy
+
+from silx.math.combo import min_max
+
+from ...._glutils import gl
+from ...._glutils import Program, vertexBuffer, VertexBufferAttrib
+from .GLSupport import buildFillMaskIndices, mat4Identity, mat4Translate
+from .GLPlotImage import GLPlotItem
+
+
+_logger = logging.getLogger(__name__)
+
+
+_MPL_NONES = None, 'None', '', ' '
+"""Possible values for None"""
+
+
+def _notNaNSlices(array, length=1):
+ """Returns slices of none NaN values in the array.
+
+ :param numpy.ndarray array: 1D array from which to get slices
+ :param int length: Slices shorter than length gets discarded
+ :return: Array of (start, end) slice indices
+ :rtype: numpy.ndarray
+ """
+ isnan = numpy.isnan(numpy.array(array, copy=False).reshape(-1))
+ notnan = numpy.logical_not(isnan)
+ start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1
+ if notnan[0]:
+ start = numpy.append(0, start)
+ end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1
+ if notnan[-1]:
+ end = numpy.append(end, len(array))
+ slices = numpy.transpose((start, end))
+ if length > 1:
+ # discard slices with less than length values
+ slices = slices[numpy.diff(slices, axis=1).ravel() >= length]
+ return slices
+
+
+# fill ########################################################################
+
+class _Fill2D(object):
+ """Object rendering curve filling as polygons
+
+ :param numpy.ndarray xData: X coordinates of points
+ :param numpy.ndarray yData: Y coordinates of points
+ :param float baseline: Y value of the 'bottom' of the fill.
+ 0 for linear Y scale, -38 for log Y scale
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ _PROGRAM = Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0.0, 1.0);
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ uniform vec4 color;
+
+ void main(void) {
+ gl_FragColor = color;
+ }
+ """,
+ attrib0='xPos')
+
+ def __init__(self, xData=None, yData=None,
+ baseline=0,
+ color=(0., 0., 0., 1.),
+ offset=(0., 0.)):
+ self.xData = xData
+ self.yData = yData
+ self._xFillVboData = None
+ self._yFillVboData = None
+ self.color = color
+ self.offset = offset
+
+ # Offset baseline
+ self.baseline = baseline - self.offset[1]
+
+ def prepare(self):
+ """Rendering preparation: build indices and bounding box vertices"""
+ if (self._xFillVboData is None and
+ self.xData is not None and self.yData is not None):
+
+ # Get slices of not NaN values longer than 1 element
+ isnan = numpy.logical_or(numpy.isnan(self.xData), numpy.isnan(self.yData))
+ notnan = numpy.logical_not(isnan)
+ start = numpy.where(numpy.logical_and(isnan[:-1], notnan[1:]))[0] + 1
+ if notnan[0]:
+ start = numpy.append(0, start)
+ end = numpy.where(numpy.logical_and(notnan[:-1], isnan[1:]))[0] + 1
+ if notnan[-1]:
+ end = numpy.append(end, len(isnan))
+ slices = numpy.transpose((start, end))
+ # discard slices with less than length values
+ slices = slices[numpy.diff(slices, axis=1).reshape(-1) >= 2]
+
+ # Number of points: slice + 2 * leading and trailing points
+ # Twice leading and trailing points to produce degenerated triangles
+ nbPoints = numpy.sum(numpy.diff(slices, axis=1)) * 2 + 4 * len(slices)
+ points = numpy.empty((nbPoints, 2), dtype=numpy.float32)
+
+ offset = 0
+ # invert baseline for filling
+ new_y_data = numpy.append(self.yData, self.baseline)
+ for start, end in slices:
+ # Duplicate first point for connecting degenerated triangle
+ points[offset:offset+2] = self.xData[start], new_y_data[start]
+
+ # 2nd point of the polygon is last point
+ points[offset+2] = self.xData[start], self.baseline[start]
+
+ indices = numpy.append(numpy.arange(start, end),
+ numpy.arange(len(self.xData) + end-1, len(self.xData) + start-1, -1))
+ indices = indices[buildFillMaskIndices(len(indices))]
+
+ points[offset+3:offset+3+len(indices), 0] = self.xData[indices % len(self.xData)]
+ points[offset+3:offset+3+len(indices), 1] = new_y_data[indices]
+
+ # Duplicate last point for connecting degenerated triangle
+ points[offset+3+len(indices)] = points[offset+3+len(indices)-1]
+
+ offset += len(indices) + 4
+
+ self._xFillVboData, self._yFillVboData = vertexBuffer(points.T)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ self.prepare()
+
+ if self._xFillVboData is None:
+ return # Nothing to display
+
+ self._PROGRAM.use()
+
+ gl.glUniformMatrix4fv(
+ self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE,
+ numpy.dot(context.matrix,
+ mat4Translate(*self.offset)).astype(numpy.float32))
+
+ gl.glUniform4f(self._PROGRAM.uniforms['color'], *self.color)
+
+ xPosAttrib = self._PROGRAM.attributes['xPos']
+ yPosAttrib = self._PROGRAM.attributes['yPos']
+
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ self._xFillVboData.setVertexAttrib(xPosAttrib)
+
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ self._yFillVboData.setVertexAttrib(yPosAttrib)
+
+ # Prepare fill mask
+ gl.glEnable(gl.GL_STENCIL_TEST)
+ gl.glStencilMask(1)
+ gl.glStencilFunc(gl.GL_ALWAYS, 1, 1)
+ gl.glStencilOp(gl.GL_INVERT, gl.GL_INVERT, gl.GL_INVERT)
+ gl.glColorMask(gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE, gl.GL_FALSE)
+ gl.glDepthMask(gl.GL_FALSE)
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, self._xFillVboData.size)
+
+ gl.glStencilFunc(gl.GL_EQUAL, 1, 1)
+ # Reset stencil while drawing
+ gl.glStencilOp(gl.GL_ZERO, gl.GL_ZERO, gl.GL_ZERO)
+ gl.glColorMask(gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE, gl.GL_TRUE)
+ gl.glDepthMask(gl.GL_TRUE)
+
+ # Draw directly in NDC
+ gl.glUniformMatrix4fv(self._PROGRAM.uniforms['matrix'], 1, gl.GL_TRUE,
+ mat4Identity().astype(numpy.float32))
+
+ # NDC vertices
+ gl.glVertexAttribPointer(
+ xPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
+ numpy.array((-1., -1., 1., 1.), dtype=numpy.float32))
+ gl.glVertexAttribPointer(
+ yPosAttrib, 1, gl.GL_FLOAT, gl.GL_FALSE, 0,
+ numpy.array((-1., 1., -1., 1.), dtype=numpy.float32))
+
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, 4)
+
+ gl.glDisable(gl.GL_STENCIL_TEST)
+
+ def discard(self):
+ """Release VBOs"""
+ if self.isInitialized():
+ self._xFillVboData.vbo.discard()
+
+ self._xFillVboData = None
+ self._yFillVboData = None
+
+ def isInitialized(self):
+ return self._xFillVboData is not None
+
+
+# line ########################################################################
+
+SOLID, DASHED, DASHDOT, DOTTED = '-', '--', '-.', ':'
+
+
+class GLLines2D(object):
+ """Object rendering curve as a polyline
+
+ :param xVboData: X coordinates VBO
+ :param yVboData: Y coordinates VBO
+ :param colorVboData: VBO of colors
+ :param distVboData: VBO of distance along the polyline
+ :param str style: Line style in: '-', '--', '-.', ':'
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param float width: Line width
+ :param float dashPeriod: Period of dashes
+ :param drawMode: OpenGL drawing mode
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ STYLES = SOLID, DASHED, DASHDOT, DOTTED
+ """Supported line styles"""
+
+ _SOLID_PROGRAM = Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0., 1.) ;
+ vColor = color;
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_FragColor = vColor;
+ }
+ """,
+ attrib0='xPos')
+
+ # Limitation: Dash using an estimate of distance in screen coord
+ # to avoid computing distance when viewport is resized
+ # results in inequal dashes when viewport aspect ratio is far from 1
+ _DASH_PROGRAM = Program(
+ vertexShader="""
+ #version 120
+
+ uniform mat4 matrix;
+ uniform vec2 halfViewportSize;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+ attribute float distance;
+
+ varying float vDist;
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0., 1.);
+ //Estimate distance in pixels
+ vec2 probe = vec2(matrix * vec4(1., 1., 0., 0.)) *
+ halfViewportSize;
+ float pixelPerDataEstimate = length(probe)/sqrt(2.);
+ vDist = distance * pixelPerDataEstimate;
+ vColor = color;
+ }
+ """,
+ fragmentShader="""
+ #version 120
+
+ /* Dashes: [0, x], [y, z]
+ Dash period: w */
+ uniform vec4 dash;
+ uniform vec4 dash2ndColor;
+
+ varying float vDist;
+ varying vec4 vColor;
+
+ void main(void) {
+ float dist = mod(vDist, dash.w);
+ if ((dist > dash.x && dist < dash.y) || dist > dash.z) {
+ if (dash2ndColor.a == 0.) {
+ discard; // Discard full transparent bg color
+ } else {
+ gl_FragColor = dash2ndColor;
+ }
+ } else {
+ gl_FragColor = vColor;
+ }
+ }
+ """,
+ attrib0='xPos')
+
+ def __init__(self, xVboData=None, yVboData=None,
+ colorVboData=None, distVboData=None,
+ style=SOLID, color=(0., 0., 0., 1.), dash2ndColor=None,
+ width=1, dashPeriod=10., drawMode=None,
+ offset=(0., 0.)):
+ if (xVboData is not None and
+ not isinstance(xVboData, VertexBufferAttrib)):
+ xVboData = numpy.array(xVboData, copy=False, dtype=numpy.float32)
+ self.xVboData = xVboData
+
+ if (yVboData is not None and
+ not isinstance(yVboData, VertexBufferAttrib)):
+ yVboData = numpy.array(yVboData, copy=False, dtype=numpy.float32)
+ self.yVboData = yVboData
+
+ # Compute distances if not given while providing numpy array coordinates
+ if (isinstance(self.xVboData, numpy.ndarray) and
+ isinstance(self.yVboData, numpy.ndarray) and
+ distVboData is None):
+ distVboData = distancesFromArrays(self.xVboData, self.yVboData)
+
+ if (distVboData is not None and
+ not isinstance(distVboData, VertexBufferAttrib)):
+ distVboData = numpy.array(
+ distVboData, copy=False, dtype=numpy.float32)
+ self.distVboData = distVboData
+
+ if colorVboData is not None:
+ assert isinstance(colorVboData, VertexBufferAttrib)
+ self.colorVboData = colorVboData
+ self.useColorVboData = colorVboData is not None
+
+ self.color = color
+ self.dash2ndColor = dash2ndColor
+ self.width = width
+ self._style = None
+ self.style = style
+ self.dashPeriod = dashPeriod
+ self.offset = offset
+
+ self._drawMode = drawMode if drawMode is not None else gl.GL_LINE_STRIP
+
+ @property
+ def style(self):
+ """Line style (Union[str,None])"""
+ return self._style
+
+ @style.setter
+ def style(self, style):
+ if style in _MPL_NONES:
+ self._style = None
+ else:
+ assert style in self.STYLES
+ self._style = style
+
+ @classmethod
+ def init(cls):
+ """OpenGL context initialization"""
+ gl.glHint(gl.GL_LINE_SMOOTH_HINT, gl.GL_NICEST)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ width = self.width / 72. * context.dpi
+
+ style = self.style
+ if style is None:
+ return
+
+ elif style == SOLID:
+ program = self._SOLID_PROGRAM
+ program.use()
+
+ else: # DASHED, DASHDOT, DOTTED
+ program = self._DASH_PROGRAM
+ program.use()
+
+ x, y, viewWidth, viewHeight = gl.glGetFloatv(gl.GL_VIEWPORT)
+ gl.glUniform2f(program.uniforms['halfViewportSize'],
+ 0.5 * viewWidth, 0.5 * viewHeight)
+
+ dashPeriod = self.dashPeriod * width
+ if self.style == DOTTED:
+ dash = (0.2 * dashPeriod,
+ 0.5 * dashPeriod,
+ 0.7 * dashPeriod,
+ dashPeriod)
+ elif self.style == DASHDOT:
+ dash = (0.3 * dashPeriod,
+ 0.5 * dashPeriod,
+ 0.6 * dashPeriod,
+ dashPeriod)
+ else:
+ dash = (0.5 * dashPeriod,
+ dashPeriod,
+ dashPeriod,
+ dashPeriod)
+
+ gl.glUniform4f(program.uniforms['dash'], *dash)
+
+ if self.dash2ndColor is None:
+ # Use fully transparent color which gets discarded in shader
+ dash2ndColor = (0., 0., 0., 0.)
+ else:
+ dash2ndColor = self.dash2ndColor
+ gl.glUniform4f(program.uniforms['dash2ndColor'], *dash2ndColor)
+
+ distAttrib = program.attributes['distance']
+ gl.glEnableVertexAttribArray(distAttrib)
+ if isinstance(self.distVboData, VertexBufferAttrib):
+ self.distVboData.setVertexAttrib(distAttrib)
+ else:
+ gl.glVertexAttribPointer(distAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.distVboData)
+
+ if width != 1:
+ gl.glEnable(gl.GL_LINE_SMOOTH)
+
+ matrix = numpy.dot(context.matrix,
+ mat4Translate(*self.offset)).astype(numpy.float32)
+ gl.glUniformMatrix4fv(program.uniforms['matrix'],
+ 1, gl.GL_TRUE, matrix)
+
+ colorAttrib = program.attributes['color']
+ if self.useColorVboData and self.colorVboData is not None:
+ gl.glEnableVertexAttribArray(colorAttrib)
+ self.colorVboData.setVertexAttrib(colorAttrib)
+ else:
+ gl.glDisableVertexAttribArray(colorAttrib)
+ gl.glVertexAttrib4f(colorAttrib, *self.color)
+
+ xPosAttrib = program.attributes['xPos']
+ gl.glEnableVertexAttribArray(xPosAttrib)
+ if isinstance(self.xVboData, VertexBufferAttrib):
+ self.xVboData.setVertexAttrib(xPosAttrib)
+ else:
+ gl.glVertexAttribPointer(xPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.xVboData)
+
+ yPosAttrib = program.attributes['yPos']
+ gl.glEnableVertexAttribArray(yPosAttrib)
+ if isinstance(self.yVboData, VertexBufferAttrib):
+ self.yVboData.setVertexAttrib(yPosAttrib)
+ else:
+ gl.glVertexAttribPointer(yPosAttrib,
+ 1,
+ gl.GL_FLOAT,
+ False,
+ 0,
+ self.yVboData)
+
+ gl.glLineWidth(width)
+ gl.glDrawArrays(self._drawMode, 0, self.xVboData.size)
+
+ gl.glDisable(gl.GL_LINE_SMOOTH)
+
+
+def distancesFromArrays(xData, yData):
+ """Returns distances between each points
+
+ :param numpy.ndarray xData: X coordinate of points
+ :param numpy.ndarray yData: Y coordinate of points
+ :rtype: numpy.ndarray
+ """
+ # Split array into sub-shapes at not finite points
+ splits = numpy.nonzero(numpy.logical_not(numpy.logical_and(
+ numpy.isfinite(xData), numpy.isfinite(yData))))[0]
+ splits = numpy.concatenate(([-1], splits, [len(xData) - 1]))
+
+ # Compute distance independently for each sub-shapes,
+ # putting not finite points as last points of sub-shapes
+ distances = []
+ for begin, end in zip(splits[:-1] + 1, splits[1:] + 1):
+ if begin == end: # Empty shape
+ continue
+ elif end - begin == 1: # Single element
+ distances.append([0])
+ else:
+ deltas = numpy.dstack((
+ numpy.ediff1d(xData[begin:end], to_begin=numpy.float32(0.)),
+ numpy.ediff1d(yData[begin:end], to_begin=numpy.float32(0.))))[0]
+ distances.append(
+ numpy.cumsum(numpy.sqrt(numpy.sum(deltas ** 2, axis=1))))
+ return numpy.concatenate(distances)
+
+
+# points ######################################################################
+
+DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK = \
+ 'd', 'o', 's', '+', 'x', '.', ',', '*'
+
+H_LINE, V_LINE, HEART = '_', '|', u'\u2665'
+
+TICK_LEFT = "tickleft"
+TICK_RIGHT = "tickright"
+TICK_UP = "tickup"
+TICK_DOWN = "tickdown"
+CARET_LEFT = "caretleft"
+CARET_RIGHT = "caretright"
+CARET_UP = "caretup"
+CARET_DOWN = "caretdown"
+
+
+class _Points2D(object):
+ """Object rendering curve markers
+
+ :param xVboData: X coordinates VBO
+ :param yVboData: Y coordinates VBO
+ :param colorVboData: VBO of colors
+ :param str marker: Kind of symbol to use, see :attr:`MARKERS`.
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param float size: Marker size
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ MARKERS = (DIAMOND, CIRCLE, SQUARE, PLUS, X_MARKER, POINT, PIXEL, ASTERISK,
+ H_LINE, V_LINE, HEART, TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN,
+ CARET_LEFT, CARET_RIGHT, CARET_UP, CARET_DOWN)
+ """List of supported markers"""
+
+ _VERTEX_SHADER = """
+ #version 120
+
+ uniform mat4 matrix;
+ uniform int transform;
+ uniform float size;
+ attribute float xPos;
+ attribute float yPos;
+ attribute vec4 color;
+
+ varying vec4 vColor;
+
+ void main(void) {
+ gl_Position = matrix * vec4(xPos, yPos, 0., 1.);
+ vColor = color;
+ gl_PointSize = size;
+ }
+ """
+
+ _FRAGMENT_SHADER_SYMBOLS = {
+ DIAMOND: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 centerCoord = abs(coord - vec2(0.5, 0.5));
+ float f = centerCoord.x + centerCoord.y;
+ return clamp(size * (0.5 - f), 0.0, 1.0);
+ }
+ """,
+ CIRCLE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float radius = 0.5;
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (radius - r), 0.0, 1.0);
+ }
+ """,
+ SQUARE: """
+ float alphaSymbol(vec2 coord, float size) {
+ return 1.0;
+ }
+ """,
+ PLUS: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 d = abs(size * (coord - vec2(0.5, 0.5)));
+ if (min(d.x, d.y) < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ X_MARKER: """
+ float alphaSymbol(vec2 coord, float size) {
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ if (min(d_x.x, d_x.y) <= 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ ASTERISK: """
+ float alphaSymbol(vec2 coord, float size) {
+ /* Combining +, x and circle */
+ vec2 d_plus = abs(size * (coord - vec2(0.5, 0.5)));
+ vec2 pos = floor(size * coord) + 0.5;
+ vec2 d_x = abs(pos.x + vec2(- pos.y, pos.y - size));
+ if (min(d_plus.x, d_plus.y) < 0.5) {
+ return 1.0;
+ } else if (min(d_x.x, d_x.y) <= 0.5) {
+ float r = distance(coord, vec2(0.5, 0.5));
+ return clamp(size * (0.5 - r), 0.0, 1.0);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ H_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float dy = abs(size * (coord.y - 0.5));
+ if (dy < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ V_LINE: """
+ float alphaSymbol(vec2 coord, float size) {
+ float dx = abs(size * (coord.x - 0.5));
+ if (dx < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ HEART: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = (coord - 0.5) * 2.;
+ coord *= 0.75;
+ coord.y += 0.25;
+ float a = atan(coord.x,-coord.y)/3.141593;
+ float r = length(coord);
+ float h = abs(a);
+ float d = (13.0*h - 22.0*h*h + 10.0*h*h*h)/(6.0-5.0*h);
+ float res = clamp(r-d, 0., 1.);
+ // antialiasing
+ res = smoothstep(0.1, 0.001, res);
+ return res;
+ }
+ """,
+ TICK_LEFT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dy = abs(coord.y);
+ if (dy < 0.5 && coord.x < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ TICK_RIGHT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dy = abs(coord.y);
+ if (dy < 0.5 && coord.x > -0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ TICK_UP: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dx = abs(coord.x);
+ if (dx < 0.5 && coord.y < 0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ TICK_DOWN: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float dx = abs(coord.x);
+ if (dx < 0.5 && coord.y > -0.5) {
+ return 1.0;
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_LEFT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.x) - abs(coord.y);
+ if (d >= -0.1 && coord.x > 0.5) {
+ return smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_RIGHT: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.x) - abs(coord.y);
+ if (d >= -0.1 && coord.x < 0.5) {
+ return smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_UP: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.y) - abs(coord.x);
+ if (d >= -0.1 && coord.y > 0.5) {
+ return smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ CARET_DOWN: """
+ float alphaSymbol(vec2 coord, float size) {
+ coord = size * (coord - 0.5);
+ float d = abs(coord.y) - abs(coord.x);
+ if (d >= -0.1 && coord.y < 0.5) {
+ return smoothstep(-0.1, 0.1, d);
+ } else {
+ return 0.0;
+ }
+ }
+ """,
+ }
+
+ _FRAGMENT_SHADER_TEMPLATE = """
+ #version 120
+
+ uniform float size;
+
+ varying vec4 vColor;
+
+ %s
+
+ void main(void) {
+ float alpha = alphaSymbol(gl_PointCoord, size);
+ if (alpha <= 0.0) {
+ discard;
+ } else {
+ gl_FragColor = vec4(vColor.rgb, alpha * clamp(vColor.a, 0.0, 1.0));
+ }
+ }
+ """
+
+ _PROGRAMS = {}
+
+ def __init__(self, xVboData=None, yVboData=None, colorVboData=None,
+ marker=SQUARE, color=(0., 0., 0., 1.), size=7,
+ offset=(0., 0.)):
+ self.color = color
+ self._marker = None
+ self.marker = marker
+ self.size = size
+ self.offset = offset
+
+ self.xVboData = xVboData
+ self.yVboData = yVboData
+ self.colorVboData = colorVboData
+ self.useColorVboData = colorVboData is not None
+
+ @property
+ def marker(self):
+ """Symbol used to display markers (str)"""
+ return self._marker
+
+ @marker.setter
+ def marker(self, marker):
+ if marker in _MPL_NONES:
+ self._marker = None
+ else:
+ assert marker in self.MARKERS
+ self._marker = marker
+
+ @classmethod
+ def _getProgram(cls, marker):
+ """On-demand shader program creation."""
+ if marker == PIXEL:
+ marker = SQUARE
+ elif marker == POINT:
+ marker = CIRCLE
+
+ if marker not in cls._PROGRAMS:
+ cls._PROGRAMS[marker] = Program(
+ vertexShader=cls._VERTEX_SHADER,
+ fragmentShader=(cls._FRAGMENT_SHADER_TEMPLATE %
+ cls._FRAGMENT_SHADER_SYMBOLS[marker]),
+ attrib0='xPos')
+
+ return cls._PROGRAMS[marker]
+
+ @classmethod
+ def init(cls):
+ """OpenGL context initialization"""
+ version = gl.glGetString(gl.GL_VERSION)
+ majorVersion = int(version[0])
+ assert majorVersion >= 2
+ gl.glEnable(gl.GL_VERTEX_PROGRAM_POINT_SIZE) # OpenGL 2
+ gl.glEnable(gl.GL_POINT_SPRITE) # OpenGL 2
+ if majorVersion >= 3: # OpenGL 3
+ gl.glEnable(gl.GL_PROGRAM_POINT_SIZE)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ if self.marker is None:
+ return
+
+ program = self._getProgram(self.marker)
+ program.use()
+
+ matrix = numpy.dot(context.matrix,
+ mat4Translate(*self.offset)).astype(numpy.float32)
+ gl.glUniformMatrix4fv(program.uniforms['matrix'], 1, gl.GL_TRUE, matrix)
+
+ if self.marker == PIXEL:
+ size = 1
+ elif self.marker == POINT:
+ size = math.ceil(0.5 * self.size) + 1 # Mimic Matplotlib point
+ else:
+ size = self.size
+ size = size / 72. * context.dpi
+
+ if self.marker in (PLUS, H_LINE, V_LINE,
+ TICK_LEFT, TICK_RIGHT, TICK_UP, TICK_DOWN):
+ # Convert to nearest odd number
+ size = size // 2 * 2 + 1.
+
+ gl.glUniform1f(program.uniforms['size'], size)
+ # gl.glPointSize(self.size)
+
+ cAttrib = program.attributes['color']
+ if self.useColorVboData and self.colorVboData is not None:
+ gl.glEnableVertexAttribArray(cAttrib)
+ self.colorVboData.setVertexAttrib(cAttrib)
+ else:
+ gl.glDisableVertexAttribArray(cAttrib)
+ gl.glVertexAttrib4f(cAttrib, *self.color)
+
+ xAttrib = program.attributes['xPos']
+ gl.glEnableVertexAttribArray(xAttrib)
+ self.xVboData.setVertexAttrib(xAttrib)
+
+ yAttrib = program.attributes['yPos']
+ gl.glEnableVertexAttribArray(yAttrib)
+ self.yVboData.setVertexAttrib(yAttrib)
+
+ gl.glDrawArrays(gl.GL_POINTS, 0, self.xVboData.size)
+
+ gl.glUseProgram(0)
+
+
+# error bars ##################################################################
+
+class _ErrorBars(object):
+ """Display errors bars.
+
+ This is using its own VBO as opposed to fill/points/lines.
+ There is no picking on error bars.
+
+ It uses 2 vertices per error bars and uses :class:`GLLines2D` to
+ render error bars and :class:`_Points2D` to render the ends.
+
+ :param numpy.ndarray xData: X coordinates of the data.
+ :param numpy.ndarray yData: Y coordinates of the data.
+ :param xError: The absolute error on the X axis.
+ :type xError: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for negative errors,
+ row 1 for positive errors.
+ :param yError: The absolute error on the Y axis.
+ :type yError: A float, or a numpy.ndarray of float32. See xError.
+ :param float xMin: The min X value already computed by GLPlotCurve2D.
+ :param float yMin: The min Y value already computed by GLPlotCurve2D.
+ :param List[float] color: RGBA color as 4 float in [0, 1]
+ :param List[float] offset: Translation of coordinates (ox, oy)
+ """
+
+ def __init__(self, xData, yData, xError, yError,
+ xMin, yMin,
+ color=(0., 0., 0., 1.),
+ offset=(0., 0.)):
+ self._attribs = None
+ self._xMin, self._yMin = xMin, yMin
+ self.offset = offset
+
+ if xError is not None or yError is not None:
+ self._xData = numpy.array(
+ xData, order='C', dtype=numpy.float32, copy=False)
+ self._yData = numpy.array(
+ yData, order='C', dtype=numpy.float32, copy=False)
+
+ # This also works if xError, yError is a float/int
+ self._xError = numpy.array(
+ xError, order='C', dtype=numpy.float32, copy=False)
+ self._yError = numpy.array(
+ yError, order='C', dtype=numpy.float32, copy=False)
+ else:
+ self._xData, self._yData = None, None
+ self._xError, self._yError = None, None
+
+ self._lines = GLLines2D(
+ None, None, color=color, drawMode=gl.GL_LINES, offset=offset)
+ self._xErrPoints = _Points2D(
+ None, None, color=color, marker=V_LINE, offset=offset)
+ self._yErrPoints = _Points2D(
+ None, None, color=color, marker=H_LINE, offset=offset)
+
+ def _buildVertices(self):
+ """Generates error bars vertices"""
+ nbLinesPerDataPts = (0 if self._xError is None else 2) + \
+ (0 if self._yError is None else 2)
+
+ nbDataPts = len(self._xData)
+
+ # interleave coord+error, coord-error.
+ # xError vertices first if any, then yError vertices if any.
+ xCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
+ dtype=numpy.float32)
+ yCoords = numpy.empty(nbDataPts * nbLinesPerDataPts * 2,
+ dtype=numpy.float32)
+
+ if self._xError is not None: # errors on the X axis
+ if len(self._xError.shape) == 2:
+ xErrorMinus, xErrorPlus = self._xError[0], self._xError[1]
+ else:
+ # numpy arrays of len 1 or len(xData)
+ xErrorMinus, xErrorPlus = self._xError, self._xError
+
+ # Interleave vertices for xError
+ endXError = 4 * nbDataPts
+ with numpy.errstate(invalid="ignore"):
+ xCoords[0:endXError-3:4] = self._xData + xErrorPlus
+ xCoords[1:endXError-2:4] = self._xData
+ xCoords[2:endXError-1:4] = self._xData
+ with numpy.errstate(invalid="ignore"):
+ xCoords[3:endXError:4] = self._xData - xErrorMinus
+
+ yCoords[0:endXError-3:4] = self._yData
+ yCoords[1:endXError-2:4] = self._yData
+ yCoords[2:endXError-1:4] = self._yData
+ yCoords[3:endXError:4] = self._yData
+
+ else:
+ endXError = 0
+
+ if self._yError is not None: # errors on the Y axis
+ if len(self._yError.shape) == 2:
+ yErrorMinus, yErrorPlus = self._yError[0], self._yError[1]
+ else:
+ # numpy arrays of len 1 or len(yData)
+ yErrorMinus, yErrorPlus = self._yError, self._yError
+
+ # Interleave vertices for yError
+ xCoords[endXError::4] = self._xData
+ xCoords[endXError+1::4] = self._xData
+ xCoords[endXError+2::4] = self._xData
+ xCoords[endXError+3::4] = self._xData
+
+ with numpy.errstate(invalid="ignore"):
+ yCoords[endXError::4] = self._yData + yErrorPlus
+ yCoords[endXError+1::4] = self._yData
+ yCoords[endXError+2::4] = self._yData
+ with numpy.errstate(invalid="ignore"):
+ yCoords[endXError+3::4] = self._yData - yErrorMinus
+
+ return xCoords, yCoords
+
+ def prepare(self):
+ """Rendering preparation: build indices and bounding box vertices"""
+ if self._xData is None:
+ return
+
+ if self._attribs is None:
+ xCoords, yCoords = self._buildVertices()
+
+ xAttrib, yAttrib = vertexBuffer((xCoords, yCoords))
+ self._attribs = xAttrib, yAttrib
+
+ self._lines.xVboData = xAttrib
+ self._lines.yVboData = yAttrib
+
+ # Set xError points using the same VBO as lines
+ self._xErrPoints.xVboData = xAttrib.copy()
+ self._xErrPoints.xVboData.size //= 2
+ self._xErrPoints.yVboData = yAttrib.copy()
+ self._xErrPoints.yVboData.size //= 2
+
+ # Set yError points using the same VBO as lines
+ self._yErrPoints.xVboData = xAttrib.copy()
+ self._yErrPoints.xVboData.size //= 2
+ self._yErrPoints.xVboData.offset += (xAttrib.itemsize *
+ xAttrib.size // 2)
+ self._yErrPoints.yVboData = yAttrib.copy()
+ self._yErrPoints.yVboData.size //= 2
+ self._yErrPoints.yVboData.offset += (yAttrib.itemsize *
+ yAttrib.size // 2)
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context:
+ """
+ self.prepare()
+
+ if self._attribs is not None:
+ self._lines.render(context)
+ self._xErrPoints.render(context)
+ self._yErrPoints.render(context)
+
+ def discard(self):
+ """Release VBOs"""
+ if self.isInitialized():
+ self._lines.xVboData, self._lines.yVboData = None, None
+ self._xErrPoints.xVboData, self._xErrPoints.yVboData = None, None
+ self._yErrPoints.xVboData, self._yErrPoints.yVboData = None, None
+ self._attribs[0].vbo.discard()
+ self._attribs = None
+
+ def isInitialized(self):
+ return self._attribs is not None
+
+
+# curves ######################################################################
+
+def _proxyProperty(*componentsAttributes):
+ """Create a property to access an attribute of attribute(s).
+ Useful for composition.
+ Supports multiple components this way:
+ getter returns the first found, setter sets all
+ """
+ def getter(self):
+ for compName, attrName in componentsAttributes:
+ try:
+ component = getattr(self, compName)
+ except AttributeError:
+ pass
+ else:
+ return getattr(component, attrName)
+
+ def setter(self, value):
+ for compName, attrName in componentsAttributes:
+ component = getattr(self, compName)
+ setattr(component, attrName, value)
+ return property(getter, setter)
+
+
+class GLPlotCurve2D(GLPlotItem):
+ def __init__(self, xData, yData, colorData=None,
+ xError=None, yError=None,
+ lineStyle=SOLID,
+ lineColor=(0., 0., 0., 1.),
+ lineWidth=1,
+ lineDashPeriod=20,
+ marker=SQUARE,
+ markerColor=(0., 0., 0., 1.),
+ markerSize=7,
+ fillColor=None,
+ baseline=None,
+ isYLog=False):
+ super().__init__()
+ self.colorData = colorData
+
+ # Compute x bounds
+ if xError is None:
+ self.xMin, self.xMax = min_max(xData, min_positive=False)
+ else:
+ # Takes the error into account
+ if hasattr(xError, 'shape') and len(xError.shape) == 2:
+ xErrorMinus, xErrorPlus = xError[0], xError[1]
+ else:
+ xErrorMinus, xErrorPlus = xError, xError
+ self.xMin = numpy.nanmin(xData - xErrorMinus)
+ self.xMax = numpy.nanmax(xData + xErrorPlus)
+
+ # Compute y bounds
+ if yError is None:
+ self.yMin, self.yMax = min_max(yData, min_positive=False)
+ else:
+ # Takes the error into account
+ if hasattr(yError, 'shape') and len(yError.shape) == 2:
+ yErrorMinus, yErrorPlus = yError[0], yError[1]
+ else:
+ yErrorMinus, yErrorPlus = yError, yError
+ self.yMin = numpy.nanmin(yData - yErrorMinus)
+ self.yMax = numpy.nanmax(yData + yErrorPlus)
+
+ # Handle data offset
+ if xData.itemsize > 4 or yData.itemsize > 4: # Use normalization
+ # offset data, do not offset error as it is relative
+ self.offset = self.xMin, self.yMin
+ with numpy.errstate(invalid="ignore"):
+ self.xData = (xData - self.offset[0]).astype(numpy.float32)
+ self.yData = (yData - self.offset[1]).astype(numpy.float32)
+
+ else: # float32
+ self.offset = 0., 0.
+ self.xData = xData
+ self.yData = yData
+ if fillColor is not None:
+ def deduce_baseline(baseline):
+ if baseline is None:
+ _baseline = 0
+ else:
+ _baseline = baseline
+ if not isinstance(_baseline, numpy.ndarray):
+ _baseline = numpy.repeat(_baseline,
+ len(self.xData))
+ if isYLog is True:
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ log_val = numpy.log10(_baseline)
+ _baseline = numpy.where(_baseline>0.0, log_val, -38)
+ return _baseline
+
+ _baseline = deduce_baseline(baseline)
+
+ # Use different baseline depending of Y log scale
+ self.fill = _Fill2D(self.xData, self.yData,
+ baseline=_baseline,
+ color=fillColor,
+ offset=self.offset)
+ else:
+ self.fill = None
+
+ self._errorBars = _ErrorBars(self.xData, self.yData,
+ xError, yError,
+ self.xMin, self.yMin,
+ offset=self.offset)
+
+ self.lines = GLLines2D()
+ self.lines.style = lineStyle
+ self.lines.color = lineColor
+ self.lines.width = lineWidth
+ self.lines.dashPeriod = lineDashPeriod
+ self.lines.offset = self.offset
+
+ self.points = _Points2D()
+ self.points.marker = marker
+ self.points.color = markerColor
+ self.points.size = markerSize
+ self.points.offset = self.offset
+
+ xVboData = _proxyProperty(('lines', 'xVboData'), ('points', 'xVboData'))
+
+ yVboData = _proxyProperty(('lines', 'yVboData'), ('points', 'yVboData'))
+
+ colorVboData = _proxyProperty(('lines', 'colorVboData'),
+ ('points', 'colorVboData'))
+
+ useColorVboData = _proxyProperty(('lines', 'useColorVboData'),
+ ('points', 'useColorVboData'))
+
+ distVboData = _proxyProperty(('lines', 'distVboData'))
+
+ lineStyle = _proxyProperty(('lines', 'style'))
+
+ lineColor = _proxyProperty(('lines', 'color'))
+
+ lineWidth = _proxyProperty(('lines', 'width'))
+
+ lineDashPeriod = _proxyProperty(('lines', 'dashPeriod'))
+
+ marker = _proxyProperty(('points', 'marker'))
+
+ markerColor = _proxyProperty(('points', 'color'))
+
+ markerSize = _proxyProperty(('points', 'size'))
+
+ @classmethod
+ def init(cls):
+ """OpenGL context initialization"""
+ GLLines2D.init()
+ _Points2D.init()
+
+ def prepare(self):
+ """Rendering preparation: build indices and bounding box vertices"""
+ if self.xVboData is None:
+ xAttrib, yAttrib, cAttrib, dAttrib = None, None, None, None
+ if self.lineStyle in (DASHED, DASHDOT, DOTTED):
+ dists = distancesFromArrays(self.xData, self.yData)
+ if self.colorData is None:
+ xAttrib, yAttrib, dAttrib = vertexBuffer(
+ (self.xData, self.yData, dists))
+ else:
+ xAttrib, yAttrib, cAttrib, dAttrib = vertexBuffer(
+ (self.xData, self.yData, self.colorData, dists))
+ elif self.colorData is None:
+ xAttrib, yAttrib = vertexBuffer((self.xData, self.yData))
+ else:
+ xAttrib, yAttrib, cAttrib = vertexBuffer(
+ (self.xData, self.yData, self.colorData))
+
+ self.xVboData = xAttrib
+ self.yVboData = yAttrib
+ self.distVboData = dAttrib
+
+ if cAttrib is not None and self.colorData.dtype.kind == 'u':
+ cAttrib.normalization = True # Normalize uint to [0, 1]
+ self.colorVboData = cAttrib
+ self.useColorVboData = cAttrib is not None
+
+ def render(self, context):
+ """Perform rendering
+
+ :param RenderContext context: Rendering information
+ """
+ self.prepare()
+ if self.fill is not None:
+ self.fill.render(context)
+ self._errorBars.render(context)
+ self.lines.render(context)
+ self.points.render(context)
+
+ def discard(self):
+ """Release VBOs"""
+ if self.xVboData is not None:
+ self.xVboData.vbo.discard()
+
+ self.xVboData = None
+ self.yVboData = None
+ self.colorVboData = None
+ self.distVboData = None
+
+ self._errorBars.discard()
+ if self.fill is not None:
+ self.fill.discard()
+
+ def isInitialized(self):
+ return (self.xVboData is not None or
+ self._errorBars.isInitialized() or
+ (self.fill is not None and self.fill.isInitialized()))
+
+ def pick(self, xPickMin, yPickMin, xPickMax, yPickMax):
+ """Perform picking on the curve according to its rendering.
+
+ The picking area is [xPickMin, xPickMax], [yPickMin, yPickMax].
+
+ In case a segment between 2 points with indices i, i+1 is picked,
+ only its lower index end point (i.e., i) is added to the result.
+ In case an end point with index i is picked it is added to the result,
+ and the segment [i-1, i] is not tested for picking.
+
+ :return: The indices of the picked data
+ :rtype: Union[List[int],None]
+ """
+ if (self.marker is None and self.lineStyle is None) or \
+ self.xMin > xPickMax or xPickMin > self.xMax or \
+ self.yMin > yPickMax or yPickMin > self.yMax:
+ return None
+
+ # offset picking bounds
+ xPickMin = xPickMin - self.offset[0]
+ xPickMax = xPickMax - self.offset[0]
+ yPickMin = yPickMin - self.offset[1]
+ yPickMax = yPickMax - self.offset[1]
+
+ if self.lineStyle is not None:
+ # Using Cohen-Sutherland algorithm for line clipping
+ with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings
+ codes = ((self.yData > yPickMax) << 3) | \
+ ((self.yData < yPickMin) << 2) | \
+ ((self.xData > xPickMax) << 1) | \
+ (self.xData < xPickMin)
+
+ notNaN = numpy.logical_not(numpy.logical_or(
+ numpy.isnan(self.xData), numpy.isnan(self.yData)))
+
+ # Add all points that are inside the picking area
+ indices = numpy.nonzero(
+ numpy.logical_and(codes == 0, notNaN))[0].tolist()
+
+ # Segment that might cross the area with no end point inside it
+ segToTestIdx = numpy.nonzero((codes[:-1] != 0) &
+ (codes[1:] != 0) &
+ ((codes[:-1] & codes[1:]) == 0))[0]
+
+ TOP, BOTTOM, RIGHT, LEFT = (1 << 3), (1 << 2), (1 << 1), (1 << 0)
+
+ for index in segToTestIdx:
+ if index not in indices:
+ x0, y0 = self.xData[index], self.yData[index]
+ x1, y1 = self.xData[index + 1], self.yData[index + 1]
+ code1 = codes[index + 1]
+
+ # check for crossing with horizontal bounds
+ # y0 == y1 is a never event:
+ # => pt0 and pt1 in same vertical area are not in segToTest
+ if code1 & TOP:
+ x = x0 + (x1 - x0) * (yPickMax - y0) / (y1 - y0)
+ elif code1 & BOTTOM:
+ x = x0 + (x1 - x0) * (yPickMin - y0) / (y1 - y0)
+ else:
+ x = None # No horizontal bounds intersection test
+
+ if x is not None and xPickMin <= x <= xPickMax:
+ # Intersection
+ indices.append(index)
+
+ else:
+ # check for crossing with vertical bounds
+ # x0 == x1 is a never event (see remark for y)
+ if code1 & RIGHT:
+ y = y0 + (y1 - y0) * (xPickMax - x0) / (x1 - x0)
+ elif code1 & LEFT:
+ y = y0 + (y1 - y0) * (xPickMin - x0) / (x1 - x0)
+ else:
+ y = None # No vertical bounds intersection test
+
+ if y is not None and yPickMin <= y <= yPickMax:
+ # Intersection
+ indices.append(index)
+
+ indices.sort()
+
+ else:
+ with numpy.errstate(invalid='ignore'): # Ignore NaN comparison warnings
+ indices = numpy.nonzero((self.xData >= xPickMin) &
+ (self.xData <= xPickMax) &
+ (self.yData >= yPickMin) &
+ (self.yData <= yPickMax))[0].tolist()
+
+ return tuple(indices) if len(indices) > 0 else None
diff --git a/src/silx/gui/plot/backends/glutils/GLPlotFrame.py b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py
new file mode 100644
index 0000000..1fccb02
--- /dev/null
+++ b/src/silx/gui/plot/backends/glutils/GLPlotFrame.py
@@ -0,0 +1,1210 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2014-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This modules provides the rendering of plot titles, axes and grid.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/04/2017"
+
+
+# TODO
+# keep aspect ratio managed here?
+# smarter dirty flag handling?
+
+import datetime as dt
+import math
+import weakref
+import logging
+from collections import namedtuple
+
+import numpy
+
+from ...._glutils import gl, Program
+from ..._utils import checkAxisLimits, FLOAT32_MINPOS
+from .GLSupport import mat4Ortho
+from .GLText import Text2D, CENTER, BOTTOM, TOP, LEFT, RIGHT, ROTATE_270
+from ..._utils.ticklayout import niceNumbersAdaptative, niceNumbersForLog10
+from ..._utils.dtime_ticklayout import calcTicksAdaptive, bestFormatString
+from ..._utils.dtime_ticklayout import timestamp
+
+_logger = logging.getLogger(__name__)
+
+
+# PlotAxis ####################################################################
+
+class PlotAxis(object):
+ """Represents a 1D axis of the plot.
+ This class is intended to be used with :class:`GLPlotFrame`.
+ """
+
+ def __init__(self, plotFrame,
+ tickLength=(0., 0.),
+ foregroundColor=(0., 0., 0., 1.0),
+ labelAlign=CENTER, labelVAlign=CENTER,
+ titleAlign=CENTER, titleVAlign=CENTER,
+ titleRotate=0, titleOffset=(0., 0.)):
+ self._ticks = None
+
+ self._plotFrameRef = weakref.ref(plotFrame)
+
+ self._isDateTime = False
+ self._timeZone = None
+ self._isLog = False
+ self._dataRange = 1., 100.
+ self._displayCoords = (0., 0.), (1., 0.)
+ self._title = ''
+
+ self._tickLength = tickLength
+ self._foregroundColor = foregroundColor
+ self._labelAlign = labelAlign
+ self._labelVAlign = labelVAlign
+ self._titleAlign = titleAlign
+ self._titleVAlign = titleVAlign
+ self._titleRotate = titleRotate
+ self._titleOffset = titleOffset
+
+ @property
+ def dataRange(self):
+ """The range of the data represented on the axis as a tuple
+ of 2 floats: (min, max)."""
+ return self._dataRange
+
+ @dataRange.setter
+ def dataRange(self, dataRange):
+ assert len(dataRange) == 2
+ assert dataRange[0] <= dataRange[1]
+ dataRange = float(dataRange[0]), float(dataRange[1])
+
+ if dataRange != self._dataRange:
+ self._dataRange = dataRange
+ self._dirtyTicks()
+
+ @property
+ def isLog(self):
+ """Whether the axis is using a log10 scale or not as a bool."""
+ return self._isLog
+
+ @isLog.setter
+ def isLog(self, isLog):
+ isLog = bool(isLog)
+ if isLog != self._isLog:
+ self._isLog = isLog
+ self._dirtyTicks()
+
+ @property
+ def timeZone(self):
+ """Returnss datetime.tzinfo that is used if this axis plots date times."""
+ return self._timeZone
+
+ @timeZone.setter
+ def timeZone(self, tz):
+ """Sets dateetime.tzinfo that is used if this axis plots date times."""
+ self._timeZone = tz
+ self._dirtyTicks()
+
+ @property
+ def isTimeSeries(self):
+ """Whether the axis is showing floats as datetime objects"""
+ return self._isDateTime
+
+ @isTimeSeries.setter
+ def isTimeSeries(self, isTimeSeries):
+ isTimeSeries = bool(isTimeSeries)
+ if isTimeSeries != self._isDateTime:
+ self._isDateTime = isTimeSeries
+ self._dirtyTicks()
+
+ @property
+ def displayCoords(self):
+ """The coordinates of the start and end points of the axis
+ in display space (i.e., in pixels) as a tuple of 2 tuples of
+ 2 floats: ((x0, y0), (x1, y1)).
+ """
+ return self._displayCoords
+
+ @displayCoords.setter
+ def displayCoords(self, displayCoords):
+ assert len(displayCoords) == 2
+ assert len(displayCoords[0]) == 2
+ assert len(displayCoords[1]) == 2
+ displayCoords = tuple(displayCoords[0]), tuple(displayCoords[1])
+ if displayCoords != self._displayCoords:
+ self._displayCoords = displayCoords
+ self._dirtyTicks()
+
+ @property
+ def devicePixelRatio(self):
+ """Returns the ratio between qt pixels and device pixels."""
+ plotFrame = self._plotFrameRef()
+ return plotFrame.devicePixelRatio if plotFrame is not None else 1.
+
+ @property
+ def title(self):
+ """The text label associated with this axis as a str in latin-1."""
+ return self._title
+
+ @title.setter
+ def title(self, title):
+ if title != self._title:
+ self._title = title
+ self._dirtyPlotFrame()
+
+ @property
+ def titleOffset(self):
+ """Title offset in pixels (x: int, y: int)"""
+ return self._titleOffset
+
+ @titleOffset.setter
+ def titleOffset(self, offset):
+ if offset != self._titleOffset:
+ self._titleOffset = offset
+ self._dirtyTicks()
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ self._dirtyTicks()
+
+ @property
+ def ticks(self):
+ """Ticks as tuples: ((x, y) in display, dataPos, textLabel)."""
+ if self._ticks is None:
+ self._ticks = tuple(self._ticksGenerator())
+ return self._ticks
+
+ def getVerticesAndLabels(self):
+ """Create the list of vertices for axis and associated text labels.
+
+ :returns: A tuple: List of 2D line vertices, List of Text2D labels.
+ """
+ vertices = list(self.displayCoords) # Add start and end points
+ labels = []
+ tickLabelsSize = [0., 0.]
+
+ xTickLength, yTickLength = self._tickLength
+ xTickLength *= self.devicePixelRatio
+ yTickLength *= self.devicePixelRatio
+ for (xPixel, yPixel), dataPos, text in self.ticks:
+ if text is None:
+ tickScale = 0.5
+ else:
+ tickScale = 1.
+
+ label = Text2D(text=text,
+ color=self._foregroundColor,
+ x=xPixel - xTickLength,
+ y=yPixel - yTickLength,
+ align=self._labelAlign,
+ valign=self._labelVAlign,
+ devicePixelRatio=self.devicePixelRatio)
+
+ width, height = label.size
+ if width > tickLabelsSize[0]:
+ tickLabelsSize[0] = width
+ if height > tickLabelsSize[1]:
+ tickLabelsSize[1] = height
+
+ labels.append(label)
+
+ vertices.append((xPixel, yPixel))
+ vertices.append((xPixel + tickScale * xTickLength,
+ yPixel + tickScale * yTickLength))
+
+ (x0, y0), (x1, y1) = self.displayCoords
+ xAxisCenter = 0.5 * (x0 + x1)
+ yAxisCenter = 0.5 * (y0 + y1)
+
+ xOffset, yOffset = self.titleOffset
+
+ # Adaptative title positioning:
+ # tickNorm = math.sqrt(xTickLength ** 2 + yTickLength ** 2)
+ # xOffset = -tickLabelsSize[0] * xTickLength / tickNorm
+ # xOffset -= 3 * xTickLength
+ # yOffset = -tickLabelsSize[1] * yTickLength / tickNorm
+ # yOffset -= 3 * yTickLength
+
+ axisTitle = Text2D(text=self.title,
+ color=self._foregroundColor,
+ x=xAxisCenter + xOffset,
+ y=yAxisCenter + yOffset,
+ align=self._titleAlign,
+ valign=self._titleVAlign,
+ rotate=self._titleRotate,
+ devicePixelRatio=self.devicePixelRatio)
+ labels.append(axisTitle)
+
+ return vertices, labels
+
+ def _dirtyPlotFrame(self):
+ """Dirty parent GLPlotFrame"""
+ plotFrame = self._plotFrameRef()
+ if plotFrame is not None:
+ plotFrame._dirty()
+
+ def _dirtyTicks(self):
+ """Mark ticks as dirty and notify listener (i.e., background)."""
+ self._ticks = None
+ self._dirtyPlotFrame()
+
+ @staticmethod
+ def _frange(start, stop, step):
+ """range for float (including stop)."""
+ while start <= stop:
+ yield start
+ start += step
+
+ def _ticksGenerator(self):
+ """Generator of ticks as tuples:
+ ((x, y) in display, dataPos, textLabel).
+ """
+ dataMin, dataMax = self.dataRange
+ if self.isLog and dataMin <= 0.:
+ _logger.warning(
+ 'Getting ticks while isLog=True and dataRange[0]<=0.')
+ dataMin = 1.
+ if dataMax < dataMin:
+ dataMax = 1.
+
+ if dataMin != dataMax: # data range is not null
+ (x0, y0), (x1, y1) = self.displayCoords
+
+ if self.isLog:
+
+ if self.isTimeSeries:
+ _logger.warning("Time series not implemented for log-scale")
+
+ logMin, logMax = math.log10(dataMin), math.log10(dataMax)
+ tickMin, tickMax, step, _ = niceNumbersForLog10(logMin, logMax)
+
+ xScale = (x1 - x0) / (logMax - logMin)
+ yScale = (y1 - y0) / (logMax - logMin)
+
+ for logPos in self._frange(tickMin, tickMax, step):
+ if logMin <= logPos <= logMax:
+ dataPos = 10 ** logPos
+ xPixel = x0 + (logPos - logMin) * xScale
+ yPixel = y0 + (logPos - logMin) * yScale
+ text = '1e%+03d' % logPos
+ yield ((xPixel, yPixel), dataPos, text)
+
+ if step == 1:
+ ticks = list(self._frange(tickMin, tickMax, step))[:-1]
+ for logPos in ticks:
+ dataOrigPos = 10 ** logPos
+ for index in range(2, 10):
+ dataPos = dataOrigPos * index
+ if dataMin <= dataPos <= dataMax:
+ logSubPos = math.log10(dataPos)
+ xPixel = x0 + (logSubPos - logMin) * xScale
+ yPixel = y0 + (logSubPos - logMin) * yScale
+ yield ((xPixel, yPixel), dataPos, None)
+
+ else:
+ xScale = (x1 - x0) / (dataMax - dataMin)
+ yScale = (y1 - y0) / (dataMax - dataMin)
+
+ nbPixels = math.sqrt(pow(x1 - x0, 2) + pow(y1 - y0, 2)) / self.devicePixelRatio
+
+ # Density of 1.3 label per 92 pixels
+ # i.e., 1.3 label per inch on a 92 dpi screen
+ tickDensity = 1.3 / 92
+
+ if not self.isTimeSeries:
+ tickMin, tickMax, step, nbFrac = niceNumbersAdaptative(
+ dataMin, dataMax, nbPixels, tickDensity)
+
+ for dataPos in self._frange(tickMin, tickMax, step):
+ if dataMin <= dataPos <= dataMax:
+ xPixel = x0 + (dataPos - dataMin) * xScale
+ yPixel = y0 + (dataPos - dataMin) * yScale
+
+ if nbFrac == 0:
+ text = '%g' % dataPos
+ else:
+ text = ('%.' + str(nbFrac) + 'f') % dataPos
+ yield ((xPixel, yPixel), dataPos, text)
+ else:
+ # Time series
+ dtMin = dt.datetime.fromtimestamp(dataMin, tz=self.timeZone)
+ dtMax = dt.datetime.fromtimestamp(dataMax, tz=self.timeZone)
+
+ tickDateTimes, spacing, unit = calcTicksAdaptive(
+ dtMin, dtMax, nbPixels, tickDensity)
+
+ for tickDateTime in tickDateTimes:
+ if dtMin <= tickDateTime <= dtMax:
+
+ dataPos = timestamp(tickDateTime)
+ xPixel = x0 + (dataPos - dataMin) * xScale
+ yPixel = y0 + (dataPos - dataMin) * yScale
+
+ fmtStr = bestFormatString(spacing, unit)
+ text = tickDateTime.strftime(fmtStr)
+
+ yield ((xPixel, yPixel), dataPos, text)
+
+
+# GLPlotFrame #################################################################
+
+class GLPlotFrame(object):
+ """Base class for rendering a 2D frame surrounded by axes."""
+
+ _TICK_LENGTH_IN_PIXELS = 5
+ _LINE_WIDTH = 1
+
+ _SHADERS = {
+ 'vertex': """
+ attribute vec2 position;
+ uniform mat4 matrix;
+
+ void main(void) {
+ gl_Position = matrix * vec4(position, 0.0, 1.0);
+ }
+ """,
+ 'fragment': """
+ uniform vec4 color;
+ uniform float tickFactor; /* = 1./tickLength or 0. for solid line */
+
+ void main(void) {
+ if (mod(tickFactor * (gl_FragCoord.x + gl_FragCoord.y), 2.) < 1.) {
+ gl_FragColor = color;
+ } else {
+ discard;
+ }
+ }
+ """
+ }
+
+ _Margins = namedtuple('Margins', ('left', 'right', 'top', 'bottom'))
+
+ # Margins used when plot frame is not displayed
+ _NoDisplayMargins = _Margins(0, 0, 0, 0)
+
+ def __init__(self, marginRatios, foregroundColor, gridColor):
+ """
+ :param List[float] marginRatios:
+ The ratios of margins around plot area for axis and labels.
+ (left, top, right, bottom) as float in [0., 1.]
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
+ """
+ self._renderResources = None
+
+ self.__marginRatios = marginRatios
+ self.__marginsCache = None
+
+ self._foregroundColor = foregroundColor
+ self._gridColor = gridColor
+
+ self.axes = [] # List of PlotAxis to be updated by subclasses
+
+ self._grid = False
+ self._size = 0., 0.
+ self._title = ''
+
+ self._devicePixelRatio = 1.
+
+ @property
+ def isDirty(self):
+ """True if it need to refresh graphic rendering, False otherwise."""
+ return self._renderResources is None
+
+ GRID_NONE = 0
+ GRID_MAIN_TICKS = 1
+ GRID_SUB_TICKS = 2
+ GRID_ALL_TICKS = (GRID_MAIN_TICKS + GRID_SUB_TICKS)
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._foregroundColor = color
+ for axis in self.axes:
+ axis.foregroundColor = color
+ self._dirty()
+
+ @property
+ def gridColor(self):
+ """Color used for frame and labels"""
+ return self._gridColor
+
+ @gridColor.setter
+ def gridColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "gridColor must have length 4, got {}".format(len(self._gridColor))
+ if self._gridColor != color:
+ self._gridColor = color
+ self._dirty()
+
+ @property
+ def marginRatios(self):
+ """Plot margin ratios: (left, top, right, bottom) as 4 float in [0, 1].
+ """
+ return self.__marginRatios
+
+ @marginRatios.setter
+ def marginRatios(self, ratios):
+ ratios = tuple(float(v) for v in ratios)
+ assert len(ratios) == 4
+ for value in ratios:
+ assert 0. <= value <= 1.
+ assert ratios[0] + ratios[2] < 1.
+ assert ratios[1] + ratios[3] < 1.
+
+ if self.__marginRatios != ratios:
+ self.__marginRatios = ratios
+ self.__marginsCache = None # Clear cached margins
+ self._dirty()
+
+ @property
+ def margins(self):
+ """Margins in pixels around the plot."""
+ if self.__marginsCache is None:
+ width, height = self.size
+ left, top, right, bottom = self.marginRatios
+ self.__marginsCache = self._Margins(
+ left=int(left*width),
+ right=int(right*width),
+ top=int(top*height),
+ bottom=int(bottom*height))
+ return self.__marginsCache
+
+ @property
+ def devicePixelRatio(self):
+ return self._devicePixelRatio
+
+ @devicePixelRatio.setter
+ def devicePixelRatio(self, ratio):
+ if ratio != self._devicePixelRatio:
+ self._devicePixelRatio = ratio
+ self._dirty()
+
+ @property
+ def grid(self):
+ """Grid display mode:
+ - 0: No grid.
+ - 1: Grid on main ticks.
+ - 2: Grid on sub-ticks for log scale axes.
+ - 3: Grid on main and sub ticks."""
+ return self._grid
+
+ @grid.setter
+ def grid(self, grid):
+ assert grid in (self.GRID_NONE, self.GRID_MAIN_TICKS,
+ self.GRID_SUB_TICKS, self.GRID_ALL_TICKS)
+ if grid != self._grid:
+ self._grid = grid
+ self._dirty()
+
+ @property
+ def size(self):
+ """Size in device pixels of the plot area including margins."""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ assert len(size) == 2
+ size = tuple(size)
+ if size != self._size:
+ self._size = size
+ self.__marginsCache = None # Clear cached margins
+ self._dirty()
+
+ @property
+ def plotOrigin(self):
+ """Plot area origin (left, top) in widget coordinates in pixels."""
+ return self.margins.left, self.margins.top
+
+ @property
+ def plotSize(self):
+ """Plot area size (width, height) in pixels."""
+ w, h = self.size
+ w -= self.margins.left + self.margins.right
+ h -= self.margins.top + self.margins.bottom
+ return w, h
+
+ @property
+ def title(self):
+ """Main title as a str in latin-1."""
+ return self._title
+
+ @title.setter
+ def title(self, title):
+ if title != self._title:
+ self._title = title
+ self._dirty()
+
+ # In-place update
+ # if self._renderResources is not None:
+ # self._renderResources[-1][-1].text = title
+
+ def _dirty(self):
+ # When Text2D require discard we need to handle it
+ self._renderResources = None
+
+ def _buildGridVertices(self):
+ if self._grid == self.GRID_NONE:
+ return []
+
+ elif self._grid == self.GRID_MAIN_TICKS:
+ def test(text):
+ return text is not None
+ elif self._grid == self.GRID_SUB_TICKS:
+ def test(text):
+ return text is None
+ elif self._grid == self.GRID_ALL_TICKS:
+ def test(_):
+ return True
+ else:
+ logging.warning('Wrong grid mode: %d' % self._grid)
+ return []
+
+ return self._buildGridVerticesWithTest(test)
+
+ def _buildGridVerticesWithTest(self, test):
+ """Override in subclass to generate grid vertices"""
+ return []
+
+ def _buildVerticesAndLabels(self):
+ # To fill with copy of axes lists
+ vertices = []
+ labels = []
+
+ for axis in self.axes:
+ axisVertices, axisLabels = axis.getVerticesAndLabels()
+ vertices += axisVertices
+ labels += axisLabels
+
+ vertices = numpy.array(vertices, dtype=numpy.float32)
+
+ # Add main title
+ xTitle = (self.size[0] + self.margins.left -
+ self.margins.right) // 2
+ yTitle = self.margins.top - self._TICK_LENGTH_IN_PIXELS
+ labels.append(Text2D(text=self.title,
+ color=self._foregroundColor,
+ x=xTitle,
+ y=yTitle,
+ align=CENTER,
+ valign=BOTTOM,
+ devicePixelRatio=self.devicePixelRatio))
+
+ # grid
+ gridVertices = numpy.array(self._buildGridVertices(),
+ dtype=numpy.float32)
+
+ self._renderResources = (vertices, gridVertices, labels)
+
+ _program = Program(
+ _SHADERS['vertex'], _SHADERS['fragment'], attrib0='position')
+
+ def render(self):
+ if self.margins == self._NoDisplayMargins:
+ return
+
+ if self._renderResources is None:
+ self._buildVerticesAndLabels()
+ vertices, gridVertices, labels = self._renderResources
+
+ width, height = self.size
+ matProj = mat4Ortho(0, width, height, 0, 1, -1)
+
+ gl.glViewport(0, 0, width, height)
+
+ prog = self._program
+ prog.use()
+
+ gl.glLineWidth(self._LINE_WIDTH)
+
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ matProj.astype(numpy.float32))
+ gl.glUniform4f(prog.uniforms['color'], *self._foregroundColor)
+ gl.glUniform1f(prog.uniforms['tickFactor'], 0.)
+
+ gl.glEnableVertexAttribArray(prog.attributes['position'])
+ gl.glVertexAttribPointer(prog.attributes['position'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, vertices)
+
+ gl.glDrawArrays(gl.GL_LINES, 0, len(vertices))
+
+ for label in labels:
+ label.render(matProj)
+
+ def renderGrid(self):
+ if self._grid == self.GRID_NONE:
+ return
+
+ if self._renderResources is None:
+ self._buildVerticesAndLabels()
+ vertices, gridVertices, labels = self._renderResources
+
+ width, height = self.size
+ matProj = mat4Ortho(0, width, height, 0, 1, -1)
+
+ gl.glViewport(0, 0, width, height)
+
+ prog = self._program
+ prog.use()
+
+ gl.glLineWidth(self._LINE_WIDTH)
+ gl.glUniformMatrix4fv(prog.uniforms['matrix'], 1, gl.GL_TRUE,
+ matProj.astype(numpy.float32))
+ gl.glUniform4f(prog.uniforms['color'], *self._gridColor)
+ gl.glUniform1f(prog.uniforms['tickFactor'], 0.) # 1/2.) # 1/tickLen
+
+ gl.glEnableVertexAttribArray(prog.attributes['position'])
+ gl.glVertexAttribPointer(prog.attributes['position'],
+ 2,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0, gridVertices)
+
+ gl.glDrawArrays(gl.GL_LINES, 0, len(gridVertices))
+
+
+# GLPlotFrame2D ###############################################################
+
+class GLPlotFrame2D(GLPlotFrame):
+ def __init__(self, marginRatios, foregroundColor, gridColor):
+ """
+ :param List[float] marginRatios:
+ The ratios of margins around plot area for axis and labels.
+ (left, top, right, bottom) as float in [0., 1.]
+ :param foregroundColor: color used for the frame and labels.
+ :type foregroundColor: tuple with RGBA values ranging from 0.0 to 1.0
+ :param gridColor: color used for grid lines.
+ :type gridColor: tuple RGBA with RGBA values ranging from 0.0 to 1.0
+
+ """
+ super(GLPlotFrame2D, self).__init__(marginRatios, foregroundColor, gridColor)
+ self.axes.append(PlotAxis(self,
+ tickLength=(0., -5.),
+ foregroundColor=self._foregroundColor,
+ labelAlign=CENTER, labelVAlign=TOP,
+ titleAlign=CENTER, titleVAlign=TOP,
+ titleRotate=0))
+
+ self._x2AxisCoords = ()
+
+ self.axes.append(PlotAxis(self,
+ tickLength=(5., 0.),
+ foregroundColor=self._foregroundColor,
+ labelAlign=RIGHT, labelVAlign=CENTER,
+ titleAlign=CENTER, titleVAlign=BOTTOM,
+ titleRotate=ROTATE_270))
+
+ self._y2Axis = PlotAxis(self,
+ tickLength=(-5., 0.),
+ foregroundColor=self._foregroundColor,
+ labelAlign=LEFT, labelVAlign=CENTER,
+ titleAlign=CENTER, titleVAlign=TOP,
+ titleRotate=ROTATE_270)
+
+ self._isYAxisInverted = False
+
+ self._dataRanges = {
+ 'x': (1., 100.), 'y': (1., 100.), 'y2': (1., 100.)}
+
+ self._baseVectors = (1., 0.), (0., 1.)
+
+ self._transformedDataRanges = None
+ self._transformedDataProjMat = None
+ self._transformedDataY2ProjMat = None
+
+ def _dirty(self):
+ super(GLPlotFrame2D, self)._dirty()
+ self._transformedDataRanges = None
+ self._transformedDataProjMat = None
+ self._transformedDataY2ProjMat = None
+
+ @property
+ def isDirty(self):
+ """True if it need to refresh graphic rendering, False otherwise."""
+ return (super(GLPlotFrame2D, self).isDirty or
+ self._transformedDataRanges is None or
+ self._transformedDataProjMat is None or
+ self._transformedDataY2ProjMat is None)
+
+ @property
+ def xAxis(self):
+ return self.axes[0]
+
+ @property
+ def yAxis(self):
+ return self.axes[1]
+
+ @property
+ def y2Axis(self):
+ return self._y2Axis
+
+ @property
+ def isY2Axis(self):
+ """Whether to display the left Y axis or not."""
+ return len(self.axes) == 3
+
+ @isY2Axis.setter
+ def isY2Axis(self, isY2Axis):
+ if isY2Axis != self.isY2Axis:
+ if isY2Axis:
+ self.axes.append(self._y2Axis)
+ else:
+ self.axes = self.axes[:2]
+
+ self._dirty()
+
+ @property
+ def isYAxisInverted(self):
+ """Whether Y axes are inverted or not as a bool."""
+ return self._isYAxisInverted
+
+ @isYAxisInverted.setter
+ def isYAxisInverted(self, value):
+ value = bool(value)
+ if value != self._isYAxisInverted:
+ self._isYAxisInverted = value
+ self._dirty()
+
+ DEFAULT_BASE_VECTORS = (1., 0.), (0., 1.)
+ """Values of baseVectors for orthogonal axes."""
+
+ @property
+ def baseVectors(self):
+ """Coordinates of the X and Y axes in the orthogonal plot coords.
+
+ Raises ValueError if corresponding matrix is singular.
+
+ 2 tuples of 2 floats: (xx, xy), (yx, yy)
+ """
+ return self._baseVectors
+
+ @baseVectors.setter
+ def baseVectors(self, baseVectors):
+ self._dirty()
+
+ (xx, xy), (yx, yy) = baseVectors
+ vectors = (float(xx), float(xy)), (float(yx), float(yy))
+
+ det = (vectors[0][0] * vectors[1][1] - vectors[1][0] * vectors[0][1])
+ if det == 0.:
+ raise ValueError("Singular matrix for base vectors: " +
+ str(vectors))
+
+ if vectors != self._baseVectors:
+ self._baseVectors = vectors
+ self._dirty()
+
+ def _updateTitleOffset(self):
+ """Update axes title offset according to margins"""
+ margins = self.margins
+ self.xAxis.titleOffset = 0, margins.bottom // 2
+ self.yAxis.titleOffset = -3 * margins.left // 4, 0
+ self.y2Axis.titleOffset = 3 * margins.right // 4, 0
+
+ # Override size and marginRatios setters to update titleOffsets
+ @GLPlotFrame.size.setter
+ def size(self, size):
+ GLPlotFrame.size.fset(self, size)
+ self._updateTitleOffset()
+
+ @GLPlotFrame.marginRatios.setter
+ def marginRatios(self, ratios):
+ GLPlotFrame.marginRatios.fset(self, ratios)
+ self._updateTitleOffset()
+
+ @property
+ def dataRanges(self):
+ """Ranges of data visible in the plot on x, y and y2 axes.
+
+ This is different to the axes range when axes are not orthogonal.
+
+ Type: ((xMin, xMax), (yMin, yMax), (y2Min, y2Max))
+ """
+ return self._DataRanges(self._dataRanges['x'],
+ self._dataRanges['y'],
+ self._dataRanges['y2'])
+
+ def setDataRanges(self, x=None, y=None, y2=None):
+ """Set data range over each axes.
+
+ The provided ranges are clipped to possible values
+ (i.e., 32 float range + positive range for log scale).
+
+ :param x: (min, max) data range over X axis
+ :param y: (min, max) data range over Y axis
+ :param y2: (min, max) data range over Y2 axis
+ """
+ if x is not None:
+ self._dataRanges['x'] = checkAxisLimits(
+ x[0], x[1], self.xAxis.isLog, name='x')
+
+ if y is not None:
+ self._dataRanges['y'] = checkAxisLimits(
+ y[0], y[1], self.yAxis.isLog, name='y')
+
+ if y2 is not None:
+ self._dataRanges['y2'] = checkAxisLimits(
+ y2[0], y2[1], self.y2Axis.isLog, name='y2')
+
+ self.xAxis.dataRange = self._dataRanges['x']
+ self.yAxis.dataRange = self._dataRanges['y']
+ self.y2Axis.dataRange = self._dataRanges['y2']
+
+ _DataRanges = namedtuple('dataRanges', ('x', 'y', 'y2'))
+
+ @property
+ def transformedDataRanges(self):
+ """Bounds of the displayed area in transformed data coordinates
+ (i.e., log scale applied if any as well as skew)
+
+ 3-tuple of 2-tuple (min, max) for each axis: x, y, y2.
+ """
+ if self._transformedDataRanges is None:
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max) = self.dataRanges
+
+ if self.xAxis.isLog:
+ try:
+ xMin = math.log10(xMin)
+ except ValueError:
+ _logger.info('xMin: warning log10(%f)', xMin)
+ xMin = 0.
+ try:
+ xMax = math.log10(xMax)
+ except ValueError:
+ _logger.info('xMax: warning log10(%f)', xMax)
+ xMax = 0.
+
+ if self.yAxis.isLog:
+ try:
+ yMin = math.log10(yMin)
+ except ValueError:
+ _logger.info('yMin: warning log10(%f)', yMin)
+ yMin = 0.
+ try:
+ yMax = math.log10(yMax)
+ except ValueError:
+ _logger.info('yMax: warning log10(%f)', yMax)
+ yMax = 0.
+
+ try:
+ y2Min = math.log10(y2Min)
+ except ValueError:
+ _logger.info('yMin: warning log10(%f)', y2Min)
+ y2Min = 0.
+ try:
+ y2Max = math.log10(y2Max)
+ except ValueError:
+ _logger.info('yMax: warning log10(%f)', y2Max)
+ y2Max = 0.
+
+ self._transformedDataRanges = self._DataRanges(
+ (xMin, xMax), (yMin, yMax), (y2Min, y2Max))
+
+ return self._transformedDataRanges
+
+ @property
+ def transformedDataProjMat(self):
+ """Orthographic projection matrix for rendering transformed data
+
+ :type: numpy.matrix
+ """
+ if self._transformedDataProjMat is None:
+ xMin, xMax = self.transformedDataRanges.x
+ yMin, yMax = self.transformedDataRanges.y
+
+ if self.isYAxisInverted:
+ mat = mat4Ortho(xMin, xMax, yMax, yMin, 1, -1)
+ else:
+ mat = mat4Ortho(xMin, xMax, yMin, yMax, 1, -1)
+ self._transformedDataProjMat = mat
+
+ return self._transformedDataProjMat
+
+ @property
+ def transformedDataY2ProjMat(self):
+ """Orthographic projection matrix for rendering transformed data
+ for the 2nd Y axis
+
+ :type: numpy.matrix
+ """
+ if self._transformedDataY2ProjMat is None:
+ xMin, xMax = self.transformedDataRanges.x
+ y2Min, y2Max = self.transformedDataRanges.y2
+
+ if self.isYAxisInverted:
+ mat = mat4Ortho(xMin, xMax, y2Max, y2Min, 1, -1)
+ else:
+ mat = mat4Ortho(xMin, xMax, y2Min, y2Max, 1, -1)
+ self._transformedDataY2ProjMat = mat
+
+ return self._transformedDataY2ProjMat
+
+ def dataToPixel(self, x, y, axis='left'):
+ """Convert data coordinate to widget pixel coordinate.
+ """
+ assert axis in ('left', 'right')
+
+ trBounds = self.transformedDataRanges
+
+ if self.xAxis.isLog:
+ if x < FLOAT32_MINPOS:
+ return None
+ xDataTr = math.log10(x)
+ else:
+ xDataTr = x
+
+ if self.yAxis.isLog:
+ if y < FLOAT32_MINPOS:
+ return None
+ yDataTr = math.log10(y)
+ else:
+ yDataTr = y
+
+ # Non-orthogonal axes
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ skew_mat = numpy.array(((xx, yx), (xy, yy)))
+
+ coords = numpy.dot(skew_mat, numpy.array((xDataTr, yDataTr)))
+ xDataTr, yDataTr = coords
+
+ plotWidth, plotHeight = self.plotSize
+
+ xPixel = int(self.margins.left +
+ plotWidth * (xDataTr - trBounds.x[0]) /
+ (trBounds.x[1] - trBounds.x[0]))
+
+ usedAxis = trBounds.y if axis == "left" else trBounds.y2
+ yOffset = (plotHeight * (yDataTr - usedAxis[0]) /
+ (usedAxis[1] - usedAxis[0]))
+
+ if self.isYAxisInverted:
+ yPixel = int(self.margins.top + yOffset)
+ else:
+ yPixel = int(self.size[1] - self.margins.bottom - yOffset)
+
+ return xPixel, yPixel
+
+ def pixelToData(self, x, y, axis="left"):
+ """Convert pixel position to data coordinates.
+
+ :param float x: X coord
+ :param float y: Y coord
+ :param str axis: Y axis to use in ('left', 'right')
+ :return: (x, y) position in data coords
+ """
+ assert axis in ("left", "right")
+
+ plotWidth, plotHeight = self.plotSize
+
+ trBounds = self.transformedDataRanges
+
+ xData = (x - self.margins.left + 0.5) / float(plotWidth)
+ xData = trBounds.x[0] + xData * (trBounds.x[1] - trBounds.x[0])
+
+ usedAxis = trBounds.y if axis == "left" else trBounds.y2
+ if self.isYAxisInverted:
+ yData = (y - self.margins.top + 0.5) / float(plotHeight)
+ yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
+ else:
+ yData = self.size[1] - self.margins.bottom - y - 0.5
+ yData /= float(plotHeight)
+ yData = usedAxis[0] + yData * (usedAxis[1] - usedAxis[0])
+
+ # non-orthogonal axis
+ if self.baseVectors != self.DEFAULT_BASE_VECTORS:
+ (xx, xy), (yx, yy) = self.baseVectors
+ skew_mat = numpy.array(((xx, yx), (xy, yy)))
+ skew_mat = numpy.linalg.inv(skew_mat)
+
+ coords = numpy.dot(skew_mat, numpy.array((xData, yData)))
+ xData, yData = coords
+
+ if self.xAxis.isLog:
+ xData = pow(10, xData)
+ if self.yAxis.isLog:
+ yData = pow(10, yData)
+
+ return xData, yData
+
+ def _buildGridVerticesWithTest(self, test):
+ vertices = []
+
+ if self.baseVectors == self.DEFAULT_BASE_VECTORS:
+ for axis in self.axes:
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ vertices.append((xPixel, yPixel))
+ if axis == self.xAxis:
+ vertices.append((xPixel, self.margins.top))
+ elif axis == self.yAxis:
+ vertices.append((self.size[0] - self.margins.right,
+ yPixel))
+ else: # axis == self.y2Axis
+ vertices.append((self.margins.left, yPixel))
+
+ else:
+ # Get plot corners in data coords
+ plotLeft, plotTop = self.plotOrigin
+ plotWidth, plotHeight = self.plotSize
+
+ corners = [(plotLeft, plotTop),
+ (plotLeft, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop + plotHeight),
+ (plotLeft + plotWidth, plotTop)]
+
+ for axis in self.axes:
+ if axis == self.xAxis:
+ cornersInData = numpy.array([
+ self.pixelToData(x, y) for (x, y) in corners])
+ borders = ((cornersInData[0], cornersInData[3]), # top
+ (cornersInData[1], cornersInData[0]), # left
+ (cornersInData[3], cornersInData[2])) # right
+
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ for (x0, y0), (x1, y1) in borders:
+ if min(x0, x1) <= data < max(x0, x1):
+ yIntersect = (data - x0) * \
+ (y1 - y0) / (x1 - x0) + y0
+
+ pixelPos = self.dataToPixel(
+ data, yIntersect)
+ if pixelPos is not None:
+ vertices.append((xPixel, yPixel))
+ vertices.append(pixelPos)
+ break # Stop at first intersection
+
+ else: # y or y2 axes
+ if axis == self.yAxis:
+ axis_name = 'left'
+ cornersInData = numpy.array([
+ self.pixelToData(x, y) for (x, y) in corners])
+ borders = (
+ (cornersInData[3], cornersInData[2]), # right
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[2], cornersInData[1])) # bottom
+
+ else: # axis == self.y2Axis
+ axis_name = 'right'
+ corners = numpy.array([self.pixelToData(
+ x, y, axis='right') for (x, y) in corners])
+ borders = (
+ (cornersInData[1], cornersInData[0]), # left
+ (cornersInData[0], cornersInData[3]), # top
+ (cornersInData[2], cornersInData[1])) # bottom
+
+ for (xPixel, yPixel), data, text in axis.ticks:
+ if test(text):
+ for (x0, y0), (x1, y1) in borders:
+ if min(y0, y1) <= data < max(y0, y1):
+ xIntersect = (data - y0) * \
+ (x1 - x0) / (y1 - y0) + x0
+
+ pixelPos = self.dataToPixel(
+ xIntersect, data, axis=axis_name)
+ if pixelPos is not None:
+ vertices.append((xPixel, yPixel))
+ vertices.append(pixelPos)
+ break # Stop at first intersection
+
+ return vertices
+
+ def _buildVerticesAndLabels(self):
+ width, height = self.size
+
+ xCoords = (self.margins.left - 0.5,
+ width - self.margins.right + 0.5)
+ yCoords = (height - self.margins.bottom + 0.5,
+ self.margins.top - 0.5)
+
+ self.axes[0].displayCoords = ((xCoords[0], yCoords[0]),
+ (xCoords[1], yCoords[0]))
+
+ self._x2AxisCoords = ((xCoords[0], yCoords[1]),
+ (xCoords[1], yCoords[1]))
+
+ if self.isYAxisInverted:
+ # Y axes are inverted, axes coordinates are inverted
+ yCoords = yCoords[1], yCoords[0]
+
+ self.axes[1].displayCoords = ((xCoords[0], yCoords[0]),
+ (xCoords[0], yCoords[1]))
+
+ self._y2Axis.displayCoords = ((xCoords[1], yCoords[0]),
+ (xCoords[1], yCoords[1]))
+
+ super(GLPlotFrame2D, self)._buildVerticesAndLabels()
+
+ vertices, gridVertices, labels = self._renderResources
+
+ # Adds vertices for borders without axis
+ extraVertices = []
+ extraVertices += self._x2AxisCoords
+ if not self.isY2Axis:
+ extraVertices += self._y2Axis.displayCoords
+
+ extraVertices = numpy.array(
+ extraVertices, copy=False, dtype=numpy.float32)
+ vertices = numpy.append(vertices, extraVertices, axis=0)
+
+ self._renderResources = (vertices, gridVertices, labels)
+
+ @property
+ def foregroundColor(self):
+ """Color used for frame and labels"""
+ return self._foregroundColor
+
+ @foregroundColor.setter
+ def foregroundColor(self, color):
+ """Color used for frame and labels"""
+ assert len(color) == 4, \
+ "foregroundColor must have length 4, got {}".format(len(self._foregroundColor))
+ if self._foregroundColor != color:
+ self._y2Axis.foregroundColor = color
+ GLPlotFrame.foregroundColor.fset(self, color) # call parent property
diff --git a/silx/gui/plot/backends/glutils/GLPlotImage.py b/src/silx/gui/plot/backends/glutils/GLPlotImage.py
index 3ad94b9..3ad94b9 100644
--- a/silx/gui/plot/backends/glutils/GLPlotImage.py
+++ b/src/silx/gui/plot/backends/glutils/GLPlotImage.py
diff --git a/silx/gui/plot/backends/glutils/GLPlotItem.py b/src/silx/gui/plot/backends/glutils/GLPlotItem.py
index ae13091..ae13091 100644
--- a/silx/gui/plot/backends/glutils/GLPlotItem.py
+++ b/src/silx/gui/plot/backends/glutils/GLPlotItem.py
diff --git a/silx/gui/plot/backends/glutils/GLPlotTriangles.py b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
index fbe9e02..fbe9e02 100644
--- a/silx/gui/plot/backends/glutils/GLPlotTriangles.py
+++ b/src/silx/gui/plot/backends/glutils/GLPlotTriangles.py
diff --git a/silx/gui/plot/backends/glutils/GLSupport.py b/src/silx/gui/plot/backends/glutils/GLSupport.py
index da6dffa..da6dffa 100644
--- a/silx/gui/plot/backends/glutils/GLSupport.py
+++ b/src/silx/gui/plot/backends/glutils/GLSupport.py
diff --git a/silx/gui/plot/backends/glutils/GLText.py b/src/silx/gui/plot/backends/glutils/GLText.py
index d6ae6fa..d6ae6fa 100644
--- a/silx/gui/plot/backends/glutils/GLText.py
+++ b/src/silx/gui/plot/backends/glutils/GLText.py
diff --git a/silx/gui/plot/backends/glutils/GLTexture.py b/src/silx/gui/plot/backends/glutils/GLTexture.py
index 37fbdd0..37fbdd0 100644
--- a/silx/gui/plot/backends/glutils/GLTexture.py
+++ b/src/silx/gui/plot/backends/glutils/GLTexture.py
diff --git a/silx/gui/plot/backends/glutils/PlotImageFile.py b/src/silx/gui/plot/backends/glutils/PlotImageFile.py
index 5fb6853..5fb6853 100644
--- a/silx/gui/plot/backends/glutils/PlotImageFile.py
+++ b/src/silx/gui/plot/backends/glutils/PlotImageFile.py
diff --git a/silx/gui/plot/backends/glutils/__init__.py b/src/silx/gui/plot/backends/glutils/__init__.py
index f87d7c1..f87d7c1 100644
--- a/silx/gui/plot/backends/glutils/__init__.py
+++ b/src/silx/gui/plot/backends/glutils/__init__.py
diff --git a/src/silx/gui/plot/items/__init__.py b/src/silx/gui/plot/items/__init__.py
new file mode 100644
index 0000000..0fe29c2
--- /dev/null
+++ b/src/silx/gui/plot/items/__init__.py
@@ -0,0 +1,53 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides classes that describes :class:`.PlotWidget` content.
+
+Instances of those classes are returned by :class:`.PlotWidget` methods that give
+access to its content such as :meth:`.PlotWidget.getCurve`, :meth:`.PlotWidget.getImage`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/06/2017"
+
+from .core import (Item, DataItem, # noqa
+ LabelsMixIn, DraggableMixIn, ColormapMixIn, # noqa
+ SymbolMixIn, ColorMixIn, YAxisMixIn, FillMixIn, # noqa
+ AlphaMixIn, LineMixIn, ScatterVisualizationMixIn, # noqa
+ ComplexMixIn, ItemChangedType, PointsBase) # noqa
+from .complex import ImageComplexData # noqa
+from .curve import Curve, CurveStyle # noqa
+from .histogram import Histogram # noqa
+from .image import ImageBase, ImageData, ImageDataBase, ImageRgba, ImageStack, MaskImageData # noqa
+from .image_aggregated import ImageDataAggregated # noqa
+from .shape import Shape, BoundingRect, XAxisExtent, YAxisExtent # noqa
+from .scatter import Scatter # noqa
+from .marker import MarkerBase, Marker, XMarker, YMarker # noqa
+from .axis import Axis, XAxis, YAxis, YRightAxis
+
+DATA_ITEMS = (ImageComplexData, Curve, Histogram, ImageBase, Scatter,
+ BoundingRect, XAxisExtent, YAxisExtent)
+"""Classes of items representing data and to consider to compute data bounds.
+"""
diff --git a/silx/gui/plot/items/_arc_roi.py b/src/silx/gui/plot/items/_arc_roi.py
index 23416ec..23416ec 100644
--- a/silx/gui/plot/items/_arc_roi.py
+++ b/src/silx/gui/plot/items/_arc_roi.py
diff --git a/silx/gui/plot/items/_pick.py b/src/silx/gui/plot/items/_pick.py
index 8c8e781..8c8e781 100644
--- a/silx/gui/plot/items/_pick.py
+++ b/src/silx/gui/plot/items/_pick.py
diff --git a/silx/gui/plot/items/_roi_base.py b/src/silx/gui/plot/items/_roi_base.py
index 3eb6cf4..3eb6cf4 100644
--- a/silx/gui/plot/items/_roi_base.py
+++ b/src/silx/gui/plot/items/_roi_base.py
diff --git a/src/silx/gui/plot/items/axis.py b/src/silx/gui/plot/items/axis.py
new file mode 100644
index 0000000..c73323e
--- /dev/null
+++ b/src/silx/gui/plot/items/axis.py
@@ -0,0 +1,560 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the class for axes of the :class:`PlotWidget`.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "22/11/2018"
+
+import datetime as dt
+import enum
+import logging
+
+import dateutil.tz
+import numpy
+
+from ... import qt
+from .. import _utils
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TickMode(enum.Enum):
+ """Determines if ticks are regular number or datetimes."""
+ DEFAULT = 0 # Ticks are regular numbers
+ TIME_SERIES = 1 # Ticks are datetime objects
+
+
+class Axis(qt.QObject):
+ """This class describes and controls a plot axis.
+
+ Note: This is an abstract class.
+ """
+ # States are half-stored on the backend of the plot, and half-stored on this
+ # object.
+ # TODO It would be good to store all the states of an axis in this object.
+ # i.e. vmin and vmax
+
+ LINEAR = "linear"
+ """Constant defining a linear scale"""
+
+ LOGARITHMIC = "log"
+ """Constant defining a logarithmic scale"""
+
+ _SCALES = set([LINEAR, LOGARITHMIC])
+
+ sigInvertedChanged = qt.Signal(bool)
+ """Signal emitted when axis orientation has changed"""
+
+ sigScaleChanged = qt.Signal(str)
+ """Signal emitted when axis scale has changed"""
+
+ _sigLogarithmicChanged = qt.Signal(bool)
+ """Signal emitted when axis scale has changed to or from logarithmic"""
+
+ sigAutoScaleChanged = qt.Signal(bool)
+ """Signal emitted when axis autoscale has changed"""
+
+ sigLimitsChanged = qt.Signal(float, float)
+ """Signal emitted when axis limits have changed"""
+
+ def __init__(self, plot):
+ """Constructor
+
+ :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this
+ axis
+ """
+ qt.QObject.__init__(self, parent=plot)
+ self._scale = self.LINEAR
+ self._isAutoScale = True
+ # Store default labels provided to setGraph[X|Y]Label
+ self._defaultLabel = ''
+ # Store currently displayed labels
+ # Current label can differ from input one with active curve handling
+ self._currentLabel = ''
+
+ def _getPlot(self):
+ """Returns the PlotWidget this Axis belongs to.
+
+ :rtype: PlotWidget
+ """
+ plot = self.parent()
+ if plot is None:
+ raise RuntimeError("Axis no longer attached to a PlotWidget")
+ return plot
+
+ def _getBackend(self):
+ """Returns the backend
+
+ :rtype: BackendBase
+ """
+ return self._getPlot()._backend
+
+ def getLimits(self):
+ """Get the limits of this axis.
+
+ :return: Minimum and maximum values of this axis as tuple
+ """
+ return self._internalGetLimits()
+
+ def setLimits(self, vmin, vmax):
+ """Set this axis limits.
+
+ :param float vmin: minimum axis value
+ :param float vmax: maximum axis value
+ """
+ vmin, vmax = self._checkLimits(vmin, vmax)
+ if self.getLimits() == (vmin, vmax):
+ return
+
+ self._internalSetLimits(vmin, vmax)
+ self._getPlot()._setDirtyPlot()
+
+ self._emitLimitsChanged()
+
+ def _emitLimitsChanged(self):
+ """Emit axis sigLimitsChanged and PlotWidget limitsChanged event"""
+ vmin, vmax = self.getLimits()
+ self.sigLimitsChanged.emit(vmin, vmax)
+ self._getPlot()._notifyLimitsChanged(emitSignal=False)
+
+ def _checkLimits(self, vmin, vmax):
+ """Makes sure axis range is not empty and within supported range.
+
+ :param float vmin: Min axis value
+ :param float vmax: Max axis value
+ :return: (min, max) making sure min < max
+ :rtype: 2-tuple of float
+ """
+ return _utils.checkAxisLimits(
+ vmin, vmax, isLog=self._isLogarithmic(), name=self._defaultLabel)
+
+ def isInverted(self):
+ """Return True if the axis is inverted (top to bottom for the y-axis),
+ False otherwise. It is always False for the X axis.
+
+ :rtype: bool
+ """
+ return False
+
+ def setInverted(self, isInverted):
+ """Set the axis orientation.
+
+ This is only available for the Y axis.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ if isInverted == self.isInverted():
+ return
+ raise NotImplementedError()
+
+ def getLabel(self):
+ """Return the current displayed label of this axis.
+
+ :param str axis: The Y axis for which to get the label (left or right)
+ :rtype: str
+ """
+ return self._currentLabel
+
+ def setLabel(self, label):
+ """Set the label displayed on the plot for this axis.
+
+ The provided label can be temporarily replaced by the label of the
+ active curve if any.
+
+ :param str label: The axis label
+ """
+ self._defaultLabel = label
+ self._setCurrentLabel(label)
+ self._getPlot()._setDirtyPlot()
+
+ def _setCurrentLabel(self, label):
+ """Define the label currently displayed.
+
+ If the label is None or empty the default label is used.
+
+ :param str label: Currently displayed label
+ """
+ if label is None or label == '':
+ label = self._defaultLabel
+ if label is None:
+ label = ''
+ self._currentLabel = label
+ self._internalSetCurrentLabel(label)
+
+ def getScale(self):
+ """Return the name of the scale used by this axis.
+
+ :rtype: str
+ """
+ return self._scale
+
+ def setScale(self, scale):
+ """Set the scale to be used by this axis.
+
+ :param str scale: Name of the scale ("log", or "linear")
+ """
+ assert(scale in self._SCALES)
+ if self._scale == scale:
+ return
+
+ # For the backward compatibility signal
+ emitLog = self._scale == self.LOGARITHMIC or scale == self.LOGARITHMIC
+
+ self._scale = scale
+
+ # TODO hackish way of forcing update of curves and images
+ plot = self._getPlot()
+ for item in plot.getItems():
+ item._updated()
+ plot._invalidateDataRange()
+
+ if scale == self.LOGARITHMIC:
+ self._internalSetLogarithmic(True)
+ elif scale == self.LINEAR:
+ self._internalSetLogarithmic(False)
+ else:
+ raise ValueError("Scale %s unsupported" % scale)
+
+ plot._forceResetZoom()
+
+ self.sigScaleChanged.emit(self._scale)
+ if emitLog:
+ self._sigLogarithmicChanged.emit(self._scale == self.LOGARITHMIC)
+
+ def _isLogarithmic(self):
+ """Return True if this axis scale is logarithmic, False if linear.
+
+ :rtype: bool
+ """
+ return self._scale == self.LOGARITHMIC
+
+ def _setLogarithmic(self, flag):
+ """Set the scale of this axes (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ flag = bool(flag)
+ self.setScale(self.LOGARITHMIC if flag else self.LINEAR)
+
+ def getTimeZone(self):
+ """Sets tzinfo that is used if this axis plots date times.
+
+ None means the datetimes are interpreted as local time.
+
+ :rtype: datetime.tzinfo of None.
+ """
+ raise NotImplementedError()
+
+ def setTimeZone(self, tz):
+ """Sets tzinfo that is used if this axis' tickMode is TIME_SERIES
+
+ The tz must be a descendant of the datetime.tzinfo class, "UTC" or None.
+ Use None to let the datetimes be interpreted as local time.
+ Use the string "UTC" to let the date datetimes be in UTC time.
+
+ :param tz: datetime.tzinfo, "UTC" or None.
+ """
+ raise NotImplementedError()
+
+ def getTickMode(self):
+ """Determines if axis ticks are number or datetimes.
+
+ :rtype: TickMode enum.
+ """
+ raise NotImplementedError()
+
+ def setTickMode(self, tickMode):
+ """Determines if axis ticks are number or datetimes.
+
+ :param TickMode tickMode: tick mode enum.
+ """
+ raise NotImplementedError()
+
+ def isAutoScale(self):
+ """Return True if axis is automatically adjusting its limits.
+
+ :rtype: bool
+ """
+ return self._isAutoScale
+
+ def setAutoScale(self, flag=True):
+ """Set the axis limits adjusting behavior of :meth:`resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ self._isAutoScale = bool(flag)
+ self.sigAutoScaleChanged.emit(self._isAutoScale)
+
+ def _setLimitsConstraints(self, minPos=None, maxPos=None):
+ raise NotImplementedError()
+
+ def setLimitsConstraints(self, minPos=None, maxPos=None):
+ """
+ Set a constraint on the position of the axes.
+
+ :param float minPos: Minimum allowed axis value.
+ :param float maxPos: Maximum allowed axis value.
+ :return: True if the constaints was updated
+ :rtype: bool
+ """
+ updated = self._setLimitsConstraints(minPos, maxPos)
+ if updated:
+ plot = self._getPlot()
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ y2Min, y2Max = plot.getYAxis('right').getLimits()
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+ return updated
+
+ def _setRangeConstraints(self, minRange=None, maxRange=None):
+ raise NotImplementedError()
+
+ def setRangeConstraints(self, minRange=None, maxRange=None):
+ """
+ Set a constraint on the position of the axes.
+
+ :param float minRange: Minimum allowed left-to-right span across the
+ view
+ :param float maxRange: Maximum allowed left-to-right span across the
+ view
+ :return: True if the constaints was updated
+ :rtype: bool
+ """
+ updated = self._setRangeConstraints(minRange, maxRange)
+ if updated:
+ plot = self._getPlot()
+ xMin, xMax = plot.getXAxis().getLimits()
+ yMin, yMax = plot.getYAxis().getLimits()
+ y2Min, y2Max = plot.getYAxis('right').getLimits()
+ plot.setLimits(xMin, xMax, yMin, yMax, y2Min, y2Max)
+ return updated
+
+
+class XAxis(Axis):
+ """Axis class defining primitives for the X axis"""
+
+ # TODO With some changes on the backend, it will be able to remove all this
+ # specialised implementations (prefixel by '_internal')
+
+ def getTimeZone(self):
+ return self._getBackend().getXAxisTimeZone()
+
+ def setTimeZone(self, tz):
+ if isinstance(tz, str) and tz.upper() == "UTC":
+ tz = dateutil.tz.tzutc()
+ elif not(tz is None or isinstance(tz, dt.tzinfo)):
+ raise TypeError("tz must be a dt.tzinfo object, None or 'UTC'.")
+
+ self._getBackend().setXAxisTimeZone(tz)
+ self._getPlot()._setDirtyPlot()
+
+ def getTickMode(self):
+ if self._getBackend().isXAxisTimeSeries():
+ return TickMode.TIME_SERIES
+ else:
+ return TickMode.DEFAULT
+
+ def setTickMode(self, tickMode):
+ if tickMode == TickMode.DEFAULT:
+ self._getBackend().setXAxisTimeSeries(False)
+ elif tickMode == TickMode.TIME_SERIES:
+ self._getBackend().setXAxisTimeSeries(True)
+ else:
+ raise ValueError("Unexpected TickMode: {}".format(tickMode))
+
+ def _internalSetCurrentLabel(self, label):
+ self._getBackend().setGraphXLabel(label)
+
+ def _internalGetLimits(self):
+ return self._getBackend().getGraphXLimits()
+
+ def _internalSetLimits(self, xmin, xmax):
+ self._getBackend().setGraphXLimits(xmin, xmax)
+
+ def _internalSetLogarithmic(self, flag):
+ self._getBackend().setXAxisLogarithmic(flag)
+
+ def _setLimitsConstraints(self, minPos=None, maxPos=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(xMin=minPos, xMax=maxPos)
+ return updated
+
+ def _setRangeConstraints(self, minRange=None, maxRange=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(minXRange=minRange, maxXRange=maxRange)
+ return updated
+
+
+class YAxis(Axis):
+ """Axis class defining primitives for the Y axis"""
+
+ # TODO With some changes on the backend, it will be able to remove all this
+ # specialised implementations (prefixel by '_internal')
+
+ def _internalSetCurrentLabel(self, label):
+ self._getBackend().setGraphYLabel(label, axis='left')
+
+ def _internalGetLimits(self):
+ return self._getBackend().getGraphYLimits(axis='left')
+
+ def _internalSetLimits(self, ymin, ymax):
+ self._getBackend().setGraphYLimits(ymin, ymax, axis='left')
+
+ def _internalSetLogarithmic(self, flag):
+ self._getBackend().setYAxisLogarithmic(flag)
+
+ def setInverted(self, flag=True):
+ """Set the axis orientation.
+
+ This is only available for the Y axis.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ flag = bool(flag)
+ if self.isInverted() == flag:
+ return
+ self._getBackend().setYAxisInverted(flag)
+ self._getPlot()._setDirtyPlot()
+ self.sigInvertedChanged.emit(flag)
+
+ def isInverted(self):
+ """Return True if the axis is inverted (top to bottom for the y-axis),
+ False otherwise. It is always False for the X axis.
+
+ :rtype: bool
+ """
+ return self._getBackend().isYAxisInverted()
+
+ def _setLimitsConstraints(self, minPos=None, maxPos=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(yMin=minPos, yMax=maxPos)
+ return updated
+
+ def _setRangeConstraints(self, minRange=None, maxRange=None):
+ constrains = self._getPlot()._getViewConstraints()
+ updated = constrains.update(minYRange=minRange, maxYRange=maxRange)
+ return updated
+
+
+class YRightAxis(Axis):
+ """Proxy axis for the secondary Y axes. It manages it own label and limit
+ but share the some state like scale and direction with the main axis."""
+
+ # TODO With some changes on the backend, it will be able to remove all this
+ # specialised implementations (prefixel by '_internal')
+
+ def __init__(self, plot, mainAxis):
+ """Constructor
+
+ :param silx.gui.plot.PlotWidget.PlotWidget plot: Parent plot of this
+ axis
+ :param Axis mainAxis: Axis which sharing state with this axis
+ """
+ Axis.__init__(self, plot)
+ self.__mainAxis = mainAxis
+
+ @property
+ def sigInvertedChanged(self):
+ """Signal emitted when axis orientation has changed"""
+ return self.__mainAxis.sigInvertedChanged
+
+ @property
+ def sigScaleChanged(self):
+ """Signal emitted when axis scale has changed"""
+ return self.__mainAxis.sigScaleChanged
+
+ @property
+ def _sigLogarithmicChanged(self):
+ """Signal emitted when axis scale has changed to or from logarithmic"""
+ return self.__mainAxis._sigLogarithmicChanged
+
+ @property
+ def sigAutoScaleChanged(self):
+ """Signal emitted when axis autoscale has changed"""
+ return self.__mainAxis.sigAutoScaleChanged
+
+ def _internalSetCurrentLabel(self, label):
+ self._getBackend().setGraphYLabel(label, axis='right')
+
+ def _internalGetLimits(self):
+ return self._getBackend().getGraphYLimits(axis='right')
+
+ def _internalSetLimits(self, ymin, ymax):
+ self._getBackend().setGraphYLimits(ymin, ymax, axis='right')
+
+ def setInverted(self, flag=True):
+ """Set the Y axis orientation.
+
+ :param bool flag: True for Y axis going from top to bottom,
+ False for Y axis going from bottom to top
+ """
+ return self.__mainAxis.setInverted(flag)
+
+ def isInverted(self):
+ """Return True if Y axis goes from top to bottom, False otherwise."""
+ return self.__mainAxis.isInverted()
+
+ def getScale(self):
+ """Return the name of the scale used by this axis.
+
+ :rtype: str
+ """
+ return self.__mainAxis.getScale()
+
+ def setScale(self, scale):
+ """Set the scale to be used by this axis.
+
+ :param str scale: Name of the scale ("log", or "linear")
+ """
+ self.__mainAxis.setScale(scale)
+
+ def _isLogarithmic(self):
+ """Return True if Y axis scale is logarithmic, False if linear."""
+ return self.__mainAxis._isLogarithmic()
+
+ def _setLogarithmic(self, flag):
+ """Set the Y axes scale (either linear or logarithmic).
+
+ :param bool flag: True to use a logarithmic scale, False for linear.
+ """
+ return self.__mainAxis._setLogarithmic(flag)
+
+ def isAutoScale(self):
+ """Return True if Y axes are automatically adjusting its limits."""
+ return self.__mainAxis.isAutoScale()
+
+ def setAutoScale(self, flag=True):
+ """Set the Y axis limits adjusting behavior of :meth:`PlotWidget.resetZoom`.
+
+ :param bool flag: True to resize limits automatically,
+ False to disable it.
+ """
+ return self.__mainAxis.setAutoScale(flag)
diff --git a/silx/gui/plot/items/complex.py b/src/silx/gui/plot/items/complex.py
index abb64ad..abb64ad 100644
--- a/silx/gui/plot/items/complex.py
+++ b/src/silx/gui/plot/items/complex.py
diff --git a/src/silx/gui/plot/items/core.py b/src/silx/gui/plot/items/core.py
new file mode 100644
index 0000000..fa3b8cf
--- /dev/null
+++ b/src/silx/gui/plot/items/core.py
@@ -0,0 +1,1733 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the base class for items of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+import collections
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+from copy import deepcopy
+import logging
+import enum
+from typing import Optional, Tuple
+import warnings
+import weakref
+
+import numpy
+
+from ....utils.deprecation import deprecated
+from ....utils.proxy import docstring
+from ....utils.enum import Enum as _Enum
+from ....math.combo import min_max
+from ... import qt
+from ... import colors
+from ...colors import Colormap
+from ._pick import PickingResult
+
+from silx import config
+
+_logger = logging.getLogger(__name__)
+
+
+@enum.unique
+class ItemChangedType(enum.Enum):
+ """Type of modification provided by :attr:`Item.sigItemChanged` signal."""
+ # Private setters and setInfo are not emitting sigItemChanged signal.
+ # Signals to consider:
+ # COLORMAP_SET emitted when setColormap is called but not forward colormap object signal
+ # CURRENT_COLOR_CHANGED emitted current color changed because highlight changed,
+ # highlighted color changed or color changed depending on hightlight state.
+
+ VISIBLE = 'visibleChanged'
+ """Item's visibility changed flag."""
+
+ ZVALUE = 'zValueChanged'
+ """Item's Z value changed flag."""
+
+ COLORMAP = 'colormapChanged' # Emitted when set + forward events from the colormap object
+ """Item's colormap changed flag.
+
+ This is emitted both when setting a new colormap and
+ when the current colormap object is updated.
+ """
+
+ SYMBOL = 'symbolChanged'
+ """Item's symbol changed flag."""
+
+ SYMBOL_SIZE = 'symbolSizeChanged'
+ """Item's symbol size changed flag."""
+
+ LINE_WIDTH = 'lineWidthChanged'
+ """Item's line width changed flag."""
+
+ LINE_STYLE = 'lineStyleChanged'
+ """Item's line style changed flag."""
+
+ COLOR = 'colorChanged'
+ """Item's color changed flag."""
+
+ LINE_BG_COLOR = 'lineBgColorChanged'
+ """Item's line background color changed flag."""
+
+ YAXIS = 'yAxisChanged'
+ """Item's Y axis binding changed flag."""
+
+ FILL = 'fillChanged'
+ """Item's fill changed flag."""
+
+ ALPHA = 'alphaChanged'
+ """Item's transparency alpha changed flag."""
+
+ DATA = 'dataChanged'
+ """Item's data changed flag"""
+
+ MASK = 'maskChanged'
+ """Item's mask changed flag"""
+
+ HIGHLIGHTED = 'highlightedChanged'
+ """Item's highlight state changed flag."""
+
+ HIGHLIGHTED_COLOR = 'highlightedColorChanged'
+ """Deprecated, use HIGHLIGHTED_STYLE instead."""
+
+ HIGHLIGHTED_STYLE = 'highlightedStyleChanged'
+ """Item's highlighted style changed flag."""
+
+ SCALE = 'scaleChanged'
+ """Item's scale changed flag."""
+
+ TEXT = 'textChanged'
+ """Item's text changed flag."""
+
+ POSITION = 'positionChanged'
+ """Item's position changed flag.
+
+ This is emitted when a marker position changed and
+ when an image origin changed.
+ """
+
+ OVERLAY = 'overlayChanged'
+ """Item's overlay state changed flag."""
+
+ VISUALIZATION_MODE = 'visualizationModeChanged'
+ """Item's visualization mode changed flag."""
+
+ COMPLEX_MODE = 'complexModeChanged'
+ """Item's complex data visualization mode changed flag."""
+
+ NAME = 'nameChanged'
+ """Item's name changed flag."""
+
+ EDITABLE = 'editableChanged'
+ """Item's editable state changed flags."""
+
+ SELECTABLE = 'selectableChanged'
+ """Item's selectable state changed flags."""
+
+
+class Item(qt.QObject):
+ """Description of an item of the plot"""
+
+ _DEFAULT_Z_LAYER = 0
+ """Default layer for overlay rendering"""
+
+ _DEFAULT_SELECTABLE = False
+ """Default selectable state of items"""
+
+ sigItemChanged = qt.Signal(object)
+ """Signal emitted when the item has changed.
+
+ It provides a flag describing which property of the item has changed.
+ See :class:`ItemChangedType` for flags description.
+ """
+
+ _sigVisibleBoundsChanged = qt.Signal()
+ """Signal emitted when the visible extent of the item in the plot has changed.
+
+ This signal is emitted only if visible extent tracking is enabled
+ (see :meth:`_setVisibleBoundsTracking`).
+ """
+
+ def __init__(self):
+ qt.QObject.__init__(self)
+ self._dirty = True
+ self._plotRef = None
+ self._visible = True
+ self._selectable = self._DEFAULT_SELECTABLE
+ self._z = self._DEFAULT_Z_LAYER
+ self._info = None
+ self._xlabel = None
+ self._ylabel = None
+ self.__name = ''
+
+ self.__visibleBoundsTracking = False
+ self.__previousVisibleBounds = None
+
+ self._backendRenderer = None
+
+ def getPlot(self):
+ """Returns the ~silx.gui.plot.PlotWidget this item belongs to.
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return None if self._plotRef is None else self._plotRef()
+
+ def _setPlot(self, plot):
+ """Set the plot this item belongs to.
+
+ WARNING: This should only be called from the Plot.
+
+ :param Union[~silx.gui.plot.PlotWidget,None] plot: The Plot instance.
+ """
+ if plot is not None and self._plotRef is not None:
+ raise RuntimeError('Trying to add a node at two places.')
+ self.__disconnectFromPlotWidget()
+ self._plotRef = None if plot is None else weakref.ref(plot)
+ self.__connectToPlotWidget()
+ self._updated()
+
+ def getBounds(self): # TODO return a Bounds object rather than a tuple
+ """Returns the bounding box of this item in data coordinates
+
+ :returns: (xmin, xmax, ymin, ymax) or None
+ :rtype: 4-tuple of float or None
+ """
+ return self._getBounds()
+
+ def _getBounds(self):
+ """:meth:`getBounds` implementation to override by sub-class"""
+ return None
+
+ def isVisible(self):
+ """True if item is visible, False otherwise
+
+ :rtype: bool
+ """
+ return self._visible
+
+ def setVisible(self, visible):
+ """Set visibility of item.
+
+ :param bool visible: True to display it, False otherwise
+ """
+ visible = bool(visible)
+ if visible != self._visible:
+ self._visible = visible
+ # When visibility has changed, always mark as dirty
+ self._updated(ItemChangedType.VISIBLE,
+ checkVisibility=False)
+
+ def isOverlay(self):
+ """Return true if item is drawn as an overlay.
+
+ :rtype: bool
+ """
+ return False
+
+ def getName(self):
+ """Returns the name of the item which is used as legend.
+
+ :rtype: str
+ """
+ return self.__name
+
+ def setName(self, name):
+ """Set the name of the item which is used as legend.
+
+ :param str name: New name of the item
+ :raises RuntimeError: If item belongs to a PlotWidget.
+ """
+ name = str(name)
+ if self.__name != name:
+ if self.getPlot() is not None:
+ raise RuntimeError(
+ "Cannot change name while item is in a PlotWidget")
+
+ self.__name = name
+ self._updated(ItemChangedType.NAME)
+
+ def getLegend(self): # Replaced by getName for API consistency
+ return self.getName()
+
+ @deprecated(replacement='setName', since_version='0.13')
+ def _setLegend(self, legend):
+ legend = str(legend) if legend is not None else ''
+ self.setName(legend)
+
+ def isSelectable(self):
+ """Returns true if item is selectable (bool)"""
+ return self._selectable
+
+ def _setSelectable(self, selectable): # TODO support update
+ """Set whether item is selectable or not.
+
+ This is private for now as change is not handled.
+
+ :param bool selectable: True to make item selectable
+ """
+ self._selectable = bool(selectable)
+
+ def getZValue(self):
+ """Returns the layer on which to draw this item (int)"""
+ return self._z
+
+ def setZValue(self, z):
+ z = int(z) if z is not None else self._DEFAULT_Z_LAYER
+ if z != self._z:
+ self._z = z
+ self._updated(ItemChangedType.ZVALUE)
+
+ def getInfo(self, copy=True):
+ """Returns the info associated to this item
+
+ :param bool copy: True to get a deepcopy, False otherwise.
+ """
+ return deepcopy(self._info) if copy else self._info
+
+ def setInfo(self, info, copy=True):
+ if copy:
+ info = deepcopy(info)
+ self._info = info
+
+ def getVisibleBounds(self) -> Optional[Tuple[float, float, float, float]]:
+ """Returns visible bounds of the item bounding box in the plot area.
+
+ :returns:
+ (xmin, xmax, ymin, ymax) in data coordinates of the visible area or
+ None if item is not visible in the plot area.
+ :rtype: Union[List[float],None]
+ """
+ plot = self.getPlot()
+ bounds = self.getBounds()
+ if plot is None or bounds is None or not self.isVisible():
+ return None
+
+ xmin, xmax = numpy.clip(bounds[:2], *plot.getXAxis().getLimits())
+ ymin, ymax = numpy.clip(
+ bounds[2:], *plot.getYAxis(self.__getYAxis()).getLimits())
+
+ if xmin == xmax or ymin == ymax: # Outside the plot area
+ return None
+ else:
+ return xmin, xmax, ymin, ymax
+
+ def _isVisibleBoundsTracking(self) -> bool:
+ """Returns True if visible bounds changes are tracked.
+
+ When enabled, :attr:`_sigVisibleBoundsChanged` is emitted upon changes.
+ :rtype: bool
+ """
+ return self.__visibleBoundsTracking
+
+ def _setVisibleBoundsTracking(self, enable: bool) -> None:
+ """Set whether or not to track visible bounds changes.
+
+ :param bool enable:
+ """
+ if enable != self.__visibleBoundsTracking:
+ self.__disconnectFromPlotWidget()
+ self.__previousVisibleBounds = None
+ self.__visibleBoundsTracking = enable
+ self.__connectToPlotWidget()
+
+ def __getYAxis(self) -> str:
+ """Returns current Y axis ('left' or 'right')"""
+ return self.getYAxis() if isinstance(self, YAxisMixIn) else 'left'
+
+ def __connectToPlotWidget(self) -> None:
+ """Connect to PlotWidget signals and install event filter"""
+ if not self._isVisibleBoundsTracking():
+ return
+
+ plot = self.getPlot()
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
+ axis.sigLimitsChanged.connect(self._visibleBoundsChanged)
+
+ plot.installEventFilter(self)
+
+ self._visibleBoundsChanged()
+
+ def __disconnectFromPlotWidget(self) -> None:
+ """Disconnect from PlotWidget signals and remove event filter"""
+ if not self._isVisibleBoundsTracking():
+ return
+
+ plot = self.getPlot()
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis(self.__getYAxis())):
+ axis.sigLimitsChanged.disconnect(self._visibleBoundsChanged)
+
+ plot.removeEventFilter(self)
+
+ def _visibleBoundsChanged(self, *args) -> None:
+ """Check if visible extent actually changed and emit signal"""
+ if not self._isVisibleBoundsTracking():
+ return # No visible extent tracking
+
+ plot = self.getPlot()
+ if plot is None or not plot.isVisible():
+ return # No plot or plot not visible
+
+ extent = self.getVisibleBounds()
+ if extent != self.__previousVisibleBounds:
+ self.__previousVisibleBounds = extent
+ self._sigVisibleBoundsChanged.emit()
+
+ def eventFilter(self, watched, event):
+ """Event filter to handle PlotWidget show events"""
+ if watched is self.getPlot() and event.type() == qt.QEvent.Show:
+ self._visibleBoundsChanged()
+ return super().eventFilter(watched, event)
+
+ def _updated(self, event=None, checkVisibility=True):
+ """Mark the item as dirty (i.e., needing update).
+
+ This also triggers Plot.replot.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ :param bool checkVisibility: True to only mark as dirty if visible,
+ False to always mark as dirty.
+ """
+ if not checkVisibility or self.isVisible():
+ if not self._dirty:
+ self._dirty = True
+ # TODO: send event instead of explicit call
+ plot = self.getPlot()
+ if plot is not None:
+ plot._itemRequiresUpdate(self)
+ if event is not None:
+ self.sigItemChanged.emit(event)
+
+ def _update(self, backend):
+ """Called by Plot to update the backend for this item.
+
+ This is meant to be called asynchronously from _updated.
+ This optimizes the number of call to _update.
+
+ :param backend: The backend to update
+ """
+ if self._dirty:
+ # Remove previous renderer from backend if any
+ self._removeBackendRenderer(backend)
+
+ # If not visible, do not add renderer to backend
+ if self.isVisible():
+ self._backendRenderer = self._addBackendRenderer(backend)
+
+ self._dirty = False
+
+ def _addBackendRenderer(self, backend):
+ """Override in subclass to add specific backend renderer.
+
+ :param BackendBase backend: The backend to update
+ :return: The renderer handle to store or None if no renderer in backend
+ """
+ return None
+
+ def _removeBackendRenderer(self, backend):
+ """Override in subclass to remove specific backend renderer.
+
+ :param BackendBase backend: The backend to update
+ """
+ if self._backendRenderer is not None:
+ backend.remove(self._backendRenderer)
+ self._backendRenderer = None
+
+ def pick(self, x, y):
+ """Run picking test on this item
+
+ :param float x: The x pixel coord where to pick.
+ :param float y: The y pixel coord where to pick.
+ :return: None if not picked, else the picked position information
+ :rtype: Union[None,PickingResult]
+ """
+ if not self.isVisible() or self._backendRenderer is None:
+ return None
+ plot = self.getPlot()
+ if plot is None:
+ return None
+
+ indices = plot._backend.pickItem(x, y, self._backendRenderer)
+ if indices is None:
+ return None
+ else:
+ return PickingResult(self, indices)
+
+
+class DataItem(Item):
+ """Item with a data extent in the plot"""
+
+ def _boundsChanged(self, checkVisibility: bool=True) -> None:
+ """Call this method in subclass when data bounds has changed.
+
+ :param bool checkVisibility:
+ """
+ if not checkVisibility or self.isVisible():
+ self._visibleBoundsChanged()
+
+ # TODO hackish data range implementation
+ plot = self.getPlot()
+ if plot is not None:
+ plot._invalidateDataRange()
+
+ @docstring(Item)
+ def setVisible(self, visible: bool):
+ if visible != self.isVisible():
+ self._boundsChanged(checkVisibility=False)
+ super().setVisible(visible)
+
+# Mix-in classes ##############################################################
+
+
+class ItemMixInBase(object):
+ """Base class for Item mix-in"""
+
+ def _updated(self, event=None, checkVisibility=True):
+ """This is implemented in :class:`Item`.
+
+ Mark the item as dirty (i.e., needing update).
+ This also triggers Plot.replot.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ :param bool checkVisibility: True to only mark as dirty if visible,
+ False to always mark as dirty.
+ """
+ raise RuntimeError(
+ "Issue with Mix-In class inheritance order")
+
+
+class LabelsMixIn(ItemMixInBase):
+ """Mix-in class for items with x and y labels
+
+ Setters are private, otherwise it needs to check the plot
+ current active curve and access the internal current labels.
+ """
+
+ def __init__(self):
+ self._xlabel = None
+ self._ylabel = None
+
+ def getXLabel(self):
+ """Return the X axis label associated to this curve
+
+ :rtype: str or None
+ """
+ return self._xlabel
+
+ def _setXLabel(self, label):
+ """Set the X axis label associated with this curve
+
+ :param str label: The X axis label
+ """
+ self._xlabel = str(label)
+
+ def getYLabel(self):
+ """Return the Y axis label associated to this curve
+
+ :rtype: str or None
+ """
+ return self._ylabel
+
+ def _setYLabel(self, label):
+ """Set the Y axis label associated with this curve
+
+ :param str label: The Y axis label
+ """
+ self._ylabel = str(label)
+
+
+class DraggableMixIn(ItemMixInBase):
+ """Mix-in class for draggable items"""
+
+ def __init__(self):
+ self._draggable = False
+
+ def isDraggable(self):
+ """Returns true if image is draggable
+
+ :rtype: bool
+ """
+ return self._draggable
+
+ def _setDraggable(self, draggable): # TODO support update
+ """Set if image is draggable or not.
+
+ This is private for not as it does not support update.
+
+ :param bool draggable:
+ """
+ self._draggable = bool(draggable)
+
+ def drag(self, from_, to):
+ """Perform a drag of the item.
+
+ :param List[float] from_: (x, y) previous position in data coordinates
+ :param List[float] to: (x, y) current position in data coordinates
+ """
+ raise NotImplementedError("Must be implemented in subclass")
+
+
+class ColormapMixIn(ItemMixInBase):
+ """Mix-in class for items with colormap"""
+
+ def __init__(self):
+ self._colormap = Colormap()
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ self.__data = None
+ self.__cacheColormapRange = {} # Store {normalization: range}
+
+ def getColormap(self):
+ """Return the used colormap"""
+ return self._colormap
+
+ def setColormap(self, colormap):
+ """Set the colormap of this item
+
+ :param silx.gui.colors.Colormap colormap: colormap description
+ """
+ if self._colormap is colormap:
+ return
+ if isinstance(colormap, dict):
+ colormap = Colormap._fromDict(colormap)
+
+ if self._colormap is not None:
+ self._colormap.sigChanged.disconnect(self._colormapChanged)
+ self._colormap = colormap
+ if self._colormap is not None:
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ self._colormapChanged()
+
+ def _colormapChanged(self):
+ """Handle updates of the colormap"""
+ self._updated(ItemChangedType.COLORMAP)
+
+ def _setColormappedData(self, data, copy=True,
+ min_=None, minPositive=None, max_=None):
+ """Set the data used to compute the colormapped display.
+
+ It also resets the cache of data ranges.
+
+ This method MUST be called by inheriting classes when data is updated.
+
+ :param Union[None,numpy.ndarray] data:
+ :param Union[None,float] min_: Minimum value of the data
+ :param Union[None,float] minPositive:
+ Minimum of strictly positive values of the data
+ :param Union[None,float] max_: Maximum value of the data
+ """
+ self.__data = None if data is None else numpy.array(data, copy=copy)
+ self.__cacheColormapRange = {} # Reset cache
+
+ # Fill-up colormap range cache if values are provided
+ if max_ is not None and numpy.isfinite(max_):
+ if min_ is not None and numpy.isfinite(min_):
+ self.__cacheColormapRange[Colormap.LINEAR, Colormap.MINMAX] = min_, max_
+ if minPositive is not None and numpy.isfinite(minPositive):
+ self.__cacheColormapRange[Colormap.LOGARITHM, Colormap.MINMAX] = minPositive, max_
+
+ colormap = self.getColormap()
+ if None in (colormap.getVMin(), colormap.getVMax()):
+ self._colormapChanged()
+
+ def getColormappedData(self, copy=True):
+ """Returns the data used to compute the displayed colors
+
+ :param bool copy: True to get a copy,
+ False to get internal data (do not modify!).
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self.__data is None:
+ return None
+ else:
+ return numpy.array(self.__data, copy=copy)
+
+ def _getColormapAutoscaleRange(self, colormap=None):
+ """Returns the autoscale range for current data and colormap.
+
+ :param Union[None,~silx.gui.colors.Colormap] colormap:
+ The colormap for which to compute the autoscale range.
+ If None, the default, the colormap of the item is used
+ :return: (vmin, vmax) range (vmin and /or vmax might be `None`)
+ """
+ if colormap is None:
+ colormap = self.getColormap()
+
+ data = self.getColormappedData(copy=False)
+ if colormap is None or data is None:
+ return None, None
+
+ normalization = colormap.getNormalization()
+ autoscaleMode = colormap.getAutoscaleMode()
+ key = normalization, autoscaleMode
+ vRange = self.__cacheColormapRange.get(key, None)
+ if vRange is None:
+ vRange = colormap._computeAutoscaleRange(data)
+ self.__cacheColormapRange[key] = vRange
+ return vRange
+
+
+class SymbolMixIn(ItemMixInBase):
+ """Mix-in class for items with symbol type"""
+
+ _DEFAULT_SYMBOL = None
+ """Default marker of the item"""
+
+ _DEFAULT_SYMBOL_SIZE = config.DEFAULT_PLOT_SYMBOL_SIZE
+ """Default marker size of the item"""
+
+ _SUPPORTED_SYMBOLS = collections.OrderedDict((
+ ('o', 'Circle'),
+ ('d', 'Diamond'),
+ ('s', 'Square'),
+ ('+', 'Plus'),
+ ('x', 'Cross'),
+ ('.', 'Point'),
+ (',', 'Pixel'),
+ ('|', 'Vertical line'),
+ ('_', 'Horizontal line'),
+ ('tickleft', 'Tick left'),
+ ('tickright', 'Tick right'),
+ ('tickup', 'Tick up'),
+ ('tickdown', 'Tick down'),
+ ('caretleft', 'Caret left'),
+ ('caretright', 'Caret right'),
+ ('caretup', 'Caret up'),
+ ('caretdown', 'Caret down'),
+ (u'\u2665', 'Heart'),
+ ('', 'None')))
+ """Dict of supported symbols"""
+
+ def __init__(self):
+ if self._DEFAULT_SYMBOL is None: # Use default from config
+ self._symbol = config.DEFAULT_PLOT_SYMBOL
+ else:
+ self._symbol = self._DEFAULT_SYMBOL
+
+ if self._DEFAULT_SYMBOL_SIZE is None: # Use default from config
+ self._symbol_size = config.DEFAULT_PLOT_SYMBOL_SIZE
+ else:
+ self._symbol_size = self._DEFAULT_SYMBOL_SIZE
+
+ @classmethod
+ def getSupportedSymbols(cls):
+ """Returns the list of supported symbol names.
+
+ :rtype: tuple of str
+ """
+ return tuple(cls._SUPPORTED_SYMBOLS.keys())
+
+ @classmethod
+ def getSupportedSymbolNames(cls):
+ """Returns the list of supported symbol human-readable names.
+
+ :rtype: tuple of str
+ """
+ return tuple(cls._SUPPORTED_SYMBOLS.values())
+
+ def getSymbolName(self, symbol=None):
+ """Returns human-readable name for a symbol.
+
+ :param str symbol: The symbol from which to get the name.
+ Default: current symbol.
+ :rtype: str
+ :raise KeyError: if symbol is not in :meth:`getSupportedSymbols`.
+ """
+ if symbol is None:
+ symbol = self.getSymbol()
+ return self._SUPPORTED_SYMBOLS[symbol]
+
+ def getSymbol(self):
+ """Return the point marker type.
+
+ Marker type::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :rtype: str
+ """
+ return self._symbol
+
+ def setSymbol(self, symbol):
+ """Set the marker type
+
+ See :meth:`getSymbol`.
+
+ :param str symbol: Marker type or marker name
+ """
+ if symbol is None:
+ symbol = self._DEFAULT_SYMBOL
+
+ elif symbol not in self.getSupportedSymbols():
+ for symbolCode, name in self._SUPPORTED_SYMBOLS.items():
+ if name.lower() == symbol.lower():
+ symbol = symbolCode
+ break
+ else:
+ raise ValueError('Unsupported symbol %s' % str(symbol))
+
+ if symbol != self._symbol:
+ self._symbol = symbol
+ self._updated(ItemChangedType.SYMBOL)
+
+ def getSymbolSize(self):
+ """Return the point marker size in points.
+
+ :rtype: float
+ """
+ return self._symbol_size
+
+ def setSymbolSize(self, size):
+ """Set the point marker size in points.
+
+ See :meth:`getSymbolSize`.
+
+ :param str symbol: Marker type
+ """
+ if size is None:
+ size = self._DEFAULT_SYMBOL_SIZE
+ if size != self._symbol_size:
+ self._symbol_size = size
+ self._updated(ItemChangedType.SYMBOL_SIZE)
+
+
+class LineMixIn(ItemMixInBase):
+ """Mix-in class for item with line"""
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style"""
+
+ _SUPPORTED_LINESTYLE = '', ' ', '-', '--', '-.', ':', None
+ """Supported line styles"""
+
+ def __init__(self):
+ self._linewidth = self._DEFAULT_LINEWIDTH
+ self._linestyle = self._DEFAULT_LINESTYLE
+
+ @classmethod
+ def getSupportedLineStyles(cls):
+ """Returns list of supported line styles.
+
+ :rtype: List[str,None]
+ """
+ return cls._SUPPORTED_LINESTYLE
+
+ def getLineWidth(self):
+ """Return the curve line width in pixels
+
+ :rtype: float
+ """
+ return self._linewidth
+
+ def setLineWidth(self, width):
+ """Set the width in pixel of the curve line
+
+ See :meth:`getLineWidth`.
+
+ :param float width: Width in pixels
+ """
+ width = float(width)
+ if width != self._linewidth:
+ self._linewidth = width
+ self._updated(ItemChangedType.LINE_WIDTH)
+
+ def getLineStyle(self):
+ """Return the type of the line
+
+ Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :rtype: str
+ """
+ return self._linestyle
+
+ def setLineStyle(self, style):
+ """Set the style of the curve line.
+
+ See :meth:`getLineStyle`.
+
+ :param str style: Line style
+ """
+ style = str(style)
+ assert style in self.getSupportedLineStyles()
+ if style is None:
+ style = self._DEFAULT_LINESTYLE
+ if style != self._linestyle:
+ self._linestyle = style
+ self._updated(ItemChangedType.LINE_STYLE)
+
+
+class ColorMixIn(ItemMixInBase):
+ """Mix-in class for item with color"""
+
+ _DEFAULT_COLOR = (0., 0., 0., 1.)
+ """Default color of the item"""
+
+ def __init__(self):
+ self._color = self._DEFAULT_COLOR
+
+ def getColor(self):
+ """Returns the RGBA color of the item
+
+ :rtype: 4-tuple of float in [0, 1] or array of colors
+ """
+ return self._color
+
+ def setColor(self, color, copy=True):
+ """Set item color
+
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if isinstance(color, str):
+ color = colors.rgba(color)
+ elif isinstance(color, qt.QColor):
+ color = colors.rgba(color)
+ else:
+ color = numpy.array(color, copy=copy)
+ # TODO more checks + improve color array support
+ if color.ndim == 1: # Single RGBA color
+ color = colors.rgba(color)
+ else: # Array of colors
+ assert color.ndim == 2
+
+ self._color = color
+ self._updated(ItemChangedType.COLOR)
+
+
+class YAxisMixIn(ItemMixInBase):
+ """Mix-in class for item with yaxis"""
+
+ _DEFAULT_YAXIS = 'left'
+ """Default Y axis the item belongs to"""
+
+ def __init__(self):
+ self._yaxis = self._DEFAULT_YAXIS
+
+ def getYAxis(self):
+ """Returns the Y axis this curve belongs to.
+
+ Either 'left' or 'right'.
+
+ :rtype: str
+ """
+ return self._yaxis
+
+ def setYAxis(self, yaxis):
+ """Set the Y axis this curve belongs to.
+
+ :param str yaxis: 'left' or 'right'
+ """
+ yaxis = str(yaxis)
+ assert yaxis in ('left', 'right')
+ if yaxis != self._yaxis:
+ self._yaxis = yaxis
+ # Handle data extent changed for DataItem
+ if isinstance(self, DataItem):
+ self._boundsChanged()
+
+ # Handle visible extent changed
+ if self._isVisibleBoundsTracking():
+ # Switch Y axis signal connection
+ plot = self.getPlot()
+ if plot is not None:
+ previousYAxis = 'left' if self.getXAxis() == 'right' else 'right'
+ plot.getYAxis(previousYAxis).sigLimitsChanged.disconnect(
+ self._visibleBoundsChanged)
+ plot.getYAxis(self.getYAxis()).sigLimitsChanged.connect(
+ self._visibleBoundsChanged)
+ self._visibleBoundsChanged()
+
+ self._updated(ItemChangedType.YAXIS)
+
+
+class FillMixIn(ItemMixInBase):
+ """Mix-in class for item with fill"""
+
+ def __init__(self):
+ self._fill = False
+
+ def isFill(self):
+ """Returns whether the item is filled or not.
+
+ :rtype: bool
+ """
+ return self._fill
+
+ def setFill(self, fill):
+ """Set whether to fill the item or not.
+
+ :param bool fill:
+ """
+ fill = bool(fill)
+ if fill != self._fill:
+ self._fill = fill
+ self._updated(ItemChangedType.FILL)
+
+
+class AlphaMixIn(ItemMixInBase):
+ """Mix-in class for item with opacity"""
+
+ def __init__(self):
+ self._alpha = 1.
+
+ def getAlpha(self):
+ """Returns the opacity of the item
+
+ :rtype: float in [0, 1.]
+ """
+ return self._alpha
+
+ def setAlpha(self, alpha):
+ """Set the opacity of the item
+
+ .. note::
+
+ If the colormap already has some transparency, this alpha
+ adds additional transparency. The alpha channel of the colormap
+ is multiplied by this value.
+
+ :param alpha: Opacity of the item, between 0 (full transparency)
+ and 1. (full opacity)
+ :type alpha: float
+ """
+ alpha = float(alpha)
+ alpha = max(0., min(alpha, 1.)) # Clip alpha to [0., 1.] range
+ if alpha != self._alpha:
+ self._alpha = alpha
+ self._updated(ItemChangedType.ALPHA)
+
+
+class ComplexMixIn(ItemMixInBase):
+ """Mix-in class for complex data mode"""
+
+ _SUPPORTED_COMPLEX_MODES = None
+ """Override to only support a subset of all ComplexMode"""
+
+ class ComplexMode(_Enum):
+ """Identify available display mode for complex"""
+ NONE = 'none'
+ ABSOLUTE = 'amplitude'
+ PHASE = 'phase'
+ REAL = 'real'
+ IMAGINARY = 'imaginary'
+ AMPLITUDE_PHASE = 'amplitude_phase'
+ LOG10_AMPLITUDE_PHASE = 'log10_amplitude_phase'
+ SQUARE_AMPLITUDE = 'square_amplitude'
+
+ def __init__(self):
+ self.__complex_mode = self.ComplexMode.ABSOLUTE
+
+ def getComplexMode(self):
+ """Returns the current complex visualization mode.
+
+ :rtype: ComplexMode
+ """
+ return self.__complex_mode
+
+ def setComplexMode(self, mode):
+ """Set the complex visualization mode.
+
+ :param ComplexMode mode: The visualization mode in:
+ 'real', 'imaginary', 'phase', 'amplitude'
+ :return: True if value was set, False if is was already set
+ :rtype: bool
+ """
+ mode = self.ComplexMode.from_value(mode)
+ assert mode in self.supportedComplexModes()
+
+ if mode != self.__complex_mode:
+ self.__complex_mode = mode
+ self._updated(ItemChangedType.COMPLEX_MODE)
+ return True
+ else:
+ return False
+
+ def _convertComplexData(self, data, mode=None):
+ """Convert complex data to the specific mode.
+
+ :param Union[ComplexMode,None] mode:
+ The kind of value to compute.
+ If None (the default), the current complex mode is used.
+ :return: The converted dataset
+ :rtype: Union[numpy.ndarray[float],None]
+ """
+ if data is None:
+ return None
+
+ if mode is None:
+ mode = self.getComplexMode()
+
+ if mode is self.ComplexMode.REAL:
+ return numpy.real(data)
+ elif mode is self.ComplexMode.IMAGINARY:
+ return numpy.imag(data)
+ elif mode is self.ComplexMode.ABSOLUTE:
+ return numpy.absolute(data)
+ elif mode is self.ComplexMode.PHASE:
+ return numpy.angle(data)
+ elif mode is self.ComplexMode.SQUARE_AMPLITUDE:
+ return numpy.absolute(data) ** 2
+ else:
+ raise ValueError('Unsupported conversion mode: %s', str(mode))
+
+ @classmethod
+ def supportedComplexModes(cls):
+ """Returns the list of supported complex visualization modes.
+
+ See :class:`ComplexMode` and :meth:`setComplexMode`.
+
+ :rtype: List[ComplexMode]
+ """
+ if cls._SUPPORTED_COMPLEX_MODES is None:
+ return cls.ComplexMode.members()
+ else:
+ return cls._SUPPORTED_COMPLEX_MODES
+
+
+class ScatterVisualizationMixIn(ItemMixInBase):
+ """Mix-in class for scatter plot visualization modes"""
+
+ _SUPPORTED_SCATTER_VISUALIZATION = None
+ """Allows to override supported Visualizations"""
+
+ @enum.unique
+ class Visualization(_Enum):
+ """Different modes of scatter plot visualizations"""
+
+ POINTS = 'points'
+ """Display scatter plot as a point cloud"""
+
+ LINES = 'lines'
+ """Display scatter plot as a wireframe.
+
+ This is based on Delaunay triangulation
+ """
+
+ SOLID = 'solid'
+ """Display scatter plot as a set of filled triangles.
+
+ This is based on Delaunay triangulation
+ """
+
+ REGULAR_GRID = 'regular_grid'
+ """Display scatter plot as an image.
+
+ It expects the points to be the intersection of a regular grid,
+ and the order of points following that of an image.
+ First line, then second one, and always in the same direction
+ (either all lines from left to right or all from right to left).
+ """
+
+ IRREGULAR_GRID = 'irregular_grid'
+ """Display scatter plot as contiguous quadrilaterals.
+
+ It expects the points to be the intersection of an irregular grid,
+ and the order of points following that of an image.
+ First line, then second one, and always in the same direction
+ (either all lines from left to right or all from right to left).
+ """
+
+ BINNED_STATISTIC = 'binned_statistic'
+ """Display scatter plot as 2D binned statistic (i.e., generalized histogram).
+ """
+
+ @enum.unique
+ class VisualizationParameter(_Enum):
+ """Different parameter names for scatter plot visualizations"""
+
+ GRID_MAJOR_ORDER = 'grid_major_order'
+ """The major order of points in the regular grid.
+
+ Either 'row' (row-major, fast X) or 'column' (column-major, fast Y).
+ """
+
+ GRID_BOUNDS = 'grid_bounds'
+ """The expected range in data coordinates of the regular grid.
+
+ A 2-tuple of 2-tuple: (begin (x, y), end (x, y)).
+ This provides the data coordinates of the first point and the expected
+ last on.
+ As for `GRID_SHAPE`, this can be wider than the current data.
+ """
+
+ GRID_SHAPE = 'grid_shape'
+ """The expected size of the regular grid (height, width).
+
+ The given shape can be wider than the number of points,
+ in which case the grid is not fully filled.
+ """
+
+ BINNED_STATISTIC_SHAPE = 'binned_statistic_shape'
+ """The number of bins in each dimension (height, width).
+ """
+
+ BINNED_STATISTIC_FUNCTION = 'binned_statistic_function'
+ """The reduction function to apply to each bin (str).
+
+ Available reduction functions are: 'mean' (default), 'count', 'sum'.
+ """
+
+ DATA_BOUNDS_HINT = 'data_bounds_hint'
+ """The expected bounds of the data in data coordinates.
+
+ A 2-tuple of 2-tuple: ((ymin, ymax), (xmin, xmax)).
+ This provides a hint for the data ranges in both dimensions.
+ It is eventually enlarged with actually data ranges.
+
+ WARNING: dimension 0 i.e., Y first.
+ """
+
+ _SUPPORTED_VISUALIZATION_PARAMETER_VALUES = {
+ VisualizationParameter.GRID_MAJOR_ORDER: ('row', 'column'),
+ VisualizationParameter.BINNED_STATISTIC_FUNCTION: ('mean', 'count', 'sum'),
+ }
+ """Supported visualization parameter values.
+
+ Defined for parameters with a set of acceptable values.
+ """
+
+ def __init__(self):
+ self.__visualization = self.Visualization.POINTS
+ self.__parameters = dict(# Init parameters to None
+ (parameter, None) for parameter in self.VisualizationParameter)
+ self.__parameters[self.VisualizationParameter.BINNED_STATISTIC_FUNCTION] = 'mean'
+
+ @classmethod
+ def supportedVisualizations(cls):
+ """Returns the list of supported scatter visualization modes.
+
+ See :meth:`setVisualization`
+
+ :rtype: List[Visualization]
+ """
+ if cls._SUPPORTED_SCATTER_VISUALIZATION is None:
+ return cls.Visualization.members()
+ else:
+ return cls._SUPPORTED_SCATTER_VISUALIZATION
+
+ @classmethod
+ def supportedVisualizationParameterValues(cls, parameter):
+ """Returns the list of supported scatter visualization modes.
+
+ See :meth:`VisualizationParameters`
+
+ :param VisualizationParameter parameter:
+ This parameter for which to retrieve the supported values.
+ :returns: tuple of supported of values or None if not defined.
+ """
+ parameter = cls.VisualizationParameter(parameter)
+ return cls._SUPPORTED_VISUALIZATION_PARAMETER_VALUES.get(
+ parameter, None)
+
+ def setVisualization(self, mode):
+ """Set the scatter plot visualization mode to use.
+
+ See :class:`Visualization` for all possible values,
+ and :meth:`supportedVisualizations` for supported ones.
+
+ :param Union[str,Visualization] mode:
+ The visualization mode to use.
+ :return: True if value was set, False if is was already set
+ :rtype: bool
+ """
+ mode = self.Visualization.from_value(mode)
+ assert mode in self.supportedVisualizations()
+
+ if mode != self.__visualization:
+ self.__visualization = mode
+
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+ return True
+ else:
+ return False
+
+ def getVisualization(self):
+ """Returns the scatter plot visualization mode in use.
+
+ :rtype: Visualization
+ """
+ return self.__visualization
+
+ def setVisualizationParameter(self, parameter, value=None):
+ """Set the given visualization parameter.
+
+ :param Union[str,VisualizationParameter] parameter:
+ The name of the parameter to set
+ :param value: The value to use for this parameter
+ Set to None to automatically set the parameter
+ :raises ValueError: If parameter is not supported
+ :return: True if parameter was set, False if is was already set
+ :rtype: bool
+ :raise ValueError: If value is not supported
+ """
+ parameter = self.VisualizationParameter.from_value(parameter)
+
+ if self.__parameters[parameter] != value:
+ validValues = self.supportedVisualizationParameterValues(parameter)
+ if validValues is not None and value not in validValues:
+ raise ValueError("Unsupported parameter value: %s" % str(value))
+
+ self.__parameters[parameter] = value
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+ return True
+ return False
+
+ def getVisualizationParameter(self, parameter):
+ """Returns the value of the given visualization parameter.
+
+ This method returns the parameter as set by
+ :meth:`setVisualizationParameter`.
+
+ :param parameter: The name of the parameter to retrieve
+ :returns: The value previously set or None if automatically set
+ :raises ValueError: If parameter is not supported
+ """
+ if parameter not in self.VisualizationParameter:
+ raise ValueError("parameter not supported: %s", parameter)
+
+ return self.__parameters[parameter]
+
+ def getCurrentVisualizationParameter(self, parameter):
+ """Returns the current value of the given visualization parameter.
+
+ If the parameter was set by :meth:`setVisualizationParameter` to
+ a value that is not None, this value is returned;
+ else the current value that is automatically computed is returned.
+
+ :param parameter: The name of the parameter to retrieve
+ :returns: The current value (either set or automatically computed)
+ :raises ValueError: If parameter is not supported
+ """
+ # Override in subclass to provide automatically computed parameters
+ return self.getVisualizationParameter(parameter)
+
+
+class PointsBase(DataItem, SymbolMixIn, AlphaMixIn):
+ """Base class for :class:`Curve` and :class:`Scatter`"""
+ # note: _logFilterData must be overloaded if you overload
+ # getData to change its signature
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for points,
+ on top of images."""
+
+ def __init__(self):
+ DataItem.__init__(self)
+ SymbolMixIn.__init__(self)
+ AlphaMixIn.__init__(self)
+ self._x = ()
+ self._y = ()
+ self._xerror = None
+ self._yerror = None
+
+ # Store filtered data for x > 0 and/or y > 0
+ self._filteredCache = {}
+ self._clippedCache = {}
+
+ # Store bounds depending on axes filtering >0:
+ # key is (isXPositiveFilter, isYPositiveFilter)
+ self._boundsCache = {}
+
+ @staticmethod
+ def _logFilterError(value, error):
+ """Filter/convert error values if they go <= 0.
+
+ Replace error leading to negative values by nan
+
+ :param numpy.ndarray value: 1D array of values
+ :param numpy.ndarray error:
+ Array of errors: scalar, N, Nx1 or 2xN or None.
+ :return: Filtered error so error bars are never negative
+ """
+ if error is not None:
+ # Convert Nx1 to N
+ if error.ndim == 2 and error.shape[1] == 1 and len(value) != 1:
+ error = numpy.ravel(error)
+
+ # Supports error being scalar, N or 2xN array
+ valueMinusError = value - numpy.atleast_2d(error)[0]
+ errorClipped = numpy.isnan(valueMinusError)
+ mask = numpy.logical_not(errorClipped)
+ errorClipped[mask] = valueMinusError[mask] <= 0
+
+ if numpy.any(errorClipped): # Need filtering
+
+ # expand errorbars to 2xN
+ if error.size == 1: # Scalar
+ error = numpy.full(
+ (2, len(value)), error, dtype=numpy.float64)
+
+ elif error.ndim == 1: # N array
+ newError = numpy.empty((2, len(value)),
+ dtype=numpy.float64)
+ newError[0,:] = error
+ newError[1,:] = error
+ error = newError
+
+ elif error.size == 2 * len(value): # 2xN array
+ error = numpy.array(
+ error, copy=True, dtype=numpy.float64)
+
+ else:
+ _logger.error("Unhandled error array")
+ return error
+
+ error[0, errorClipped] = numpy.nan
+
+ return error
+
+ def _getClippingBoolArray(self, xPositive, yPositive):
+ """Compute a boolean array to filter out points with negative
+ coordinates on log axes.
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :rtype: boolean numpy.ndarray
+ """
+ assert xPositive or yPositive
+ if (xPositive, yPositive) not in self._clippedCache:
+ xclipped, yclipped = False, False
+
+ if xPositive:
+ x = self.getXData(copy=False)
+ with numpy.errstate(invalid='ignore'): # Ignore NaN warnings
+ xclipped = x <= 0
+
+ if yPositive:
+ y = self.getYData(copy=False)
+ with numpy.errstate(invalid='ignore'): # Ignore NaN warnings
+ yclipped = y <= 0
+
+ self._clippedCache[(xPositive, yPositive)] = \
+ numpy.logical_or(xclipped, yclipped)
+ return self._clippedCache[(xPositive, yPositive)]
+
+ def _logFilterData(self, xPositive, yPositive):
+ """Filter out values with x or y <= 0 on log axes
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :return: The filter arrays or unchanged object if filtering not needed
+ :rtype: (x, y, xerror, yerror)
+ """
+ x = self.getXData(copy=False)
+ y = self.getYData(copy=False)
+ xerror = self.getXErrorData(copy=False)
+ yerror = self.getYErrorData(copy=False)
+
+ if xPositive or yPositive:
+ clipped = self._getClippingBoolArray(xPositive, yPositive)
+
+ if numpy.any(clipped):
+ # copy to keep original array and convert to float
+ x = numpy.array(x, copy=True, dtype=numpy.float64)
+ x[clipped] = numpy.nan
+ y = numpy.array(y, copy=True, dtype=numpy.float64)
+ y[clipped] = numpy.nan
+
+ if xPositive and xerror is not None:
+ xerror = self._logFilterError(x, xerror)
+
+ if yPositive and yerror is not None:
+ yerror = self._logFilterError(y, yerror)
+
+ return x, y, xerror, yerror
+
+ def _getBounds(self):
+ if self.getXData(copy=False).size == 0: # Empty data
+ return None
+
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ else:
+ xPositive = False
+ yPositive = False
+
+ # TODO bounds do not take error bars into account
+ if (xPositive, yPositive) not in self._boundsCache:
+ # use the getData class method because instance method can be
+ # overloaded to return additional arrays
+ data = PointsBase.getData(self, copy=False, displayed=True)
+ if len(data) == 5:
+ # hack to avoid duplicating caching mechanism in Scatter
+ # (happens when cached data is used, caching done using
+ # Scatter._logFilterData)
+ x, y, _xerror, _yerror = data[0], data[1], data[3], data[4]
+ else:
+ x, y, _xerror, _yerror = data
+
+ xmin, xmax = min_max(x, finite=True)
+ ymin, ymax = min_max(y, finite=True)
+ self._boundsCache[(xPositive, yPositive)] = tuple([
+ (bound if bound is not None else numpy.nan)
+ for bound in (xmin, xmax, ymin, ymax)])
+ return self._boundsCache[(xPositive, yPositive)]
+
+ def _getCachedData(self):
+ """Return cached filtered data if applicable,
+ i.e. if any axis is in log scale.
+ Return None if caching is not applicable."""
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ if xPositive or yPositive:
+ # At least one axis has log scale, filter data
+ if (xPositive, yPositive) not in self._filteredCache:
+ self._filteredCache[(xPositive, yPositive)] = \
+ self._logFilterData(xPositive, yPositive)
+ return self._filteredCache[(xPositive, yPositive)]
+ return None
+
+ def getData(self, copy=True, displayed=False):
+ """Returns the x, y values of the curve points and xerror, yerror
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param bool displayed: True to only get curve points that are displayed
+ in the plot. Default: False
+ Note: If plot has log scale, negative points
+ are not displayed.
+ :returns: (x, y, xerror, yerror)
+ :rtype: 4-tuple of numpy.ndarray
+ """
+ if displayed: # filter data according to plot state
+ cached_data = self._getCachedData()
+ if cached_data is not None:
+ return cached_data
+
+ return (self.getXData(copy),
+ self.getYData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy))
+
+ def getXData(self, copy=True):
+ """Returns the x coordinates of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._x, copy=copy)
+
+ def getYData(self, copy=True):
+ """Returns the y coordinates of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._y, copy=copy)
+
+ def getXErrorData(self, copy=True):
+ """Returns the x error of the points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray, float or None
+ """
+ if isinstance(self._xerror, numpy.ndarray):
+ return numpy.array(self._xerror, copy=copy)
+ else:
+ return self._xerror # float or None
+
+ def getYErrorData(self, copy=True):
+ """Returns the y error of the points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray, float or None
+ """
+ if isinstance(self._yerror, numpy.ndarray):
+ return numpy.array(self._yerror, copy=copy)
+ else:
+ return self._yerror # float or None
+
+ def setData(self, x, y, xerror=None, yerror=None, copy=True):
+ """Set the data of the curve.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values.
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ x = numpy.array(x, copy=copy)
+ y = numpy.array(y, copy=copy)
+ assert len(x) == len(y)
+ assert x.ndim == y.ndim == 1
+
+ # Convert complex data
+ if numpy.iscomplexobj(x):
+ _logger.warning(
+ 'Converting x data to absolute value to plot it.')
+ x = numpy.absolute(x)
+ if numpy.iscomplexobj(y):
+ _logger.warning(
+ 'Converting y data to absolute value to plot it.')
+ y = numpy.absolute(y)
+
+ if xerror is not None:
+ if isinstance(xerror, abc.Iterable):
+ xerror = numpy.array(xerror, copy=copy)
+ if numpy.iscomplexobj(xerror):
+ _logger.warning(
+ 'Converting xerror data to absolute value to plot it.')
+ xerror = numpy.absolute(xerror)
+ else:
+ xerror = float(xerror)
+ if yerror is not None:
+ if isinstance(yerror, abc.Iterable):
+ yerror = numpy.array(yerror, copy=copy)
+ if numpy.iscomplexobj(yerror):
+ _logger.warning(
+ 'Converting yerror data to absolute value to plot it.')
+ yerror = numpy.absolute(yerror)
+ else:
+ yerror = float(yerror)
+ # TODO checks on xerror, yerror
+ self._x, self._y = x, y
+ self._xerror, self._yerror = xerror, yerror
+
+ self._boundsCache = {} # Reset cached bounds
+ self._filteredCache = {} # Reset cached filtered data
+ self._clippedCache = {} # Reset cached clipped bool array
+
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+
+class BaselineMixIn(object):
+ """Base class for Baseline mix-in"""
+
+ def __init__(self, baseline=None):
+ self._baseline = baseline
+
+ def _setBaseline(self, baseline):
+ """
+ Set baseline value
+
+ :param baseline: baseline value(s)
+ :type: Union[None,float,numpy.ndarray]
+ """
+ if (isinstance(baseline, abc.Iterable)):
+ baseline = numpy.array(baseline)
+ self._baseline = baseline
+
+ def getBaseline(self, copy=True):
+ """
+
+ :param bool copy:
+ :return: histogram baseline
+ :rtype: Union[None,float,numpy.ndarray]
+ """
+ if isinstance(self._baseline, numpy.ndarray):
+ return numpy.array(self._baseline, copy=True)
+ else:
+ return self._baseline
+
+
+class _Style:
+ """Object which store styles"""
+
+
+class HighlightedMixIn(ItemMixInBase):
+
+ def __init__(self):
+ self._highlightStyle = self._DEFAULT_HIGHLIGHT_STYLE
+ self._highlighted = False
+
+ def isHighlighted(self):
+ """Returns True if curve is highlighted.
+
+ :rtype: bool
+ """
+ return self._highlighted
+
+ def setHighlighted(self, highlighted):
+ """Set the highlight state of the curve
+
+ :param bool highlighted:
+ """
+ highlighted = bool(highlighted)
+ if highlighted != self._highlighted:
+ self._highlighted = highlighted
+ # TODO inefficient: better to use backend's setCurveColor
+ self._updated(ItemChangedType.HIGHLIGHTED)
+
+ def getHighlightedStyle(self):
+ """Returns the highlighted style in use
+
+ :rtype: CurveStyle
+ """
+ return self._highlightStyle
+
+ def setHighlightedStyle(self, style):
+ """Set the style to use for highlighting
+
+ :param CurveStyle style: New style to use
+ """
+ previous = self.getHighlightedStyle()
+ if style != previous:
+ assert isinstance(style, _Style)
+ self._highlightStyle = style
+ self._updated(ItemChangedType.HIGHLIGHTED_STYLE)
+
+ # Backward compatibility event
+ if previous.getColor() != style.getColor():
+ self._updated(ItemChangedType.HIGHLIGHTED_COLOR)
diff --git a/src/silx/gui/plot/items/curve.py b/src/silx/gui/plot/items/curve.py
new file mode 100644
index 0000000..7cbe26e
--- /dev/null
+++ b/src/silx/gui/plot/items/curve.py
@@ -0,0 +1,325 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Curve` item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import logging
+
+import numpy
+
+from ....utils.deprecation import deprecated
+from ... import colors
+from .core import (PointsBase, LabelsMixIn, ColorMixIn, YAxisMixIn,
+ FillMixIn, LineMixIn, SymbolMixIn, ItemChangedType,
+ BaselineMixIn, HighlightedMixIn, _Style)
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CurveStyle(_Style):
+ """Object storing the style of a curve.
+
+ Set a value to None to use the default
+
+ :param color: Color
+ :param Union[str,None] linestyle: Style of the line
+ :param Union[float,None] linewidth: Width of the line
+ :param Union[str,None] symbol: Symbol for markers
+ :param Union[float,None] symbolsize: Size of the markers
+ """
+
+ def __init__(self, color=None, linestyle=None, linewidth=None,
+ symbol=None, symbolsize=None):
+ if color is None:
+ self._color = None
+ else:
+ if isinstance(color, str):
+ color = colors.rgba(color)
+ else: # array-like expected
+ color = numpy.array(color, copy=False)
+ if color.ndim == 1: # Array is 1D, this is a single color
+ color = colors.rgba(color)
+ self._color = color
+
+ if linestyle is not None:
+ assert linestyle in LineMixIn.getSupportedLineStyles()
+ self._linestyle = linestyle
+
+ self._linewidth = None if linewidth is None else float(linewidth)
+
+ if symbol is not None:
+ assert symbol in SymbolMixIn.getSupportedSymbols()
+ self._symbol = symbol
+
+ self._symbolsize = None if symbolsize is None else float(symbolsize)
+
+ def getColor(self, copy=True):
+ """Returns the color or None if not set.
+
+ :param bool copy: True to get a copy (default),
+ False to get internal representation (do not modify!)
+
+ :rtype: Union[List[float],None]
+ """
+ if isinstance(self._color, numpy.ndarray):
+ return numpy.array(self._color, copy=copy)
+ else:
+ return self._color
+
+ def getLineStyle(self):
+ """Return the type of the line or None if not set.
+
+ Type of line::
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ :rtype: Union[str,None]
+ """
+ return self._linestyle
+
+ def getLineWidth(self):
+ """Return the curve line width in pixels or None if not set.
+
+ :rtype: Union[float,None]
+ """
+ return self._linewidth
+
+ def getSymbol(self):
+ """Return the point marker type.
+
+ Marker type::
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ :rtype: Union[str,None]
+ """
+ return self._symbol
+
+ def getSymbolSize(self):
+ """Return the point marker size in points.
+
+ :rtype: Union[float,None]
+ """
+ return self._symbolsize
+
+ def __eq__(self, other):
+ if isinstance(other, CurveStyle):
+ return (numpy.array_equal(self.getColor(), other.getColor()) and
+ self.getLineStyle() == other.getLineStyle() and
+ self.getLineWidth() == other.getLineWidth() and
+ self.getSymbol() == other.getSymbol() and
+ self.getSymbolSize() == other.getSymbolSize())
+ else:
+ return False
+
+
+class Curve(PointsBase, ColorMixIn, YAxisMixIn, FillMixIn, LabelsMixIn,
+ LineMixIn, BaselineMixIn, HighlightedMixIn):
+ """Description of a curve"""
+
+ _DEFAULT_Z_LAYER = 1
+ """Default overlay layer for curves"""
+
+ _DEFAULT_SELECTABLE = True
+ """Default selectable state for curves"""
+
+ _DEFAULT_LINEWIDTH = 1.
+ """Default line width of the curve"""
+
+ _DEFAULT_LINESTYLE = '-'
+ """Default line style of the curve"""
+
+ _DEFAULT_HIGHLIGHT_STYLE = CurveStyle(color='black')
+ """Default highlight style of the item"""
+
+ _DEFAULT_BASELINE = None
+
+ def __init__(self):
+ PointsBase.__init__(self)
+ ColorMixIn.__init__(self)
+ YAxisMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ LabelsMixIn.__init__(self)
+ LineMixIn.__init__(self)
+ BaselineMixIn.__init__(self)
+ HighlightedMixIn.__init__(self)
+
+ self._setBaseline(Curve._DEFAULT_BASELINE)
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ # Filter-out values <= 0
+ xFiltered, yFiltered, xerror, yerror = self.getData(
+ copy=False, displayed=True)
+
+ if len(xFiltered) == 0 or not numpy.any(numpy.isfinite(xFiltered)):
+ return None # No data to display, do not add renderer to backend
+
+ style = self.getCurrentStyle()
+
+ return backend.addCurve(xFiltered, yFiltered,
+ color=style.getColor(),
+ symbol=style.getSymbol(),
+ linestyle=style.getLineStyle(),
+ linewidth=style.getLineWidth(),
+ yaxis=self.getYAxis(),
+ xerror=xerror,
+ yerror=yerror,
+ fill=self.isFill(),
+ alpha=self.getAlpha(),
+ symbolsize=style.getSymbolSize(),
+ baseline=self.getBaseline(copy=False))
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ if isinstance(item, slice):
+ return [self[index] for index in range(*item.indices(5))]
+ elif item == 0:
+ return self.getXData(copy=False)
+ elif item == 1:
+ return self.getYData(copy=False)
+ elif item == 2:
+ return self.getName()
+ elif item == 3:
+ info = self.getInfo(copy=False)
+ return {} if info is None else info
+ elif item == 4:
+ params = {
+ 'info': self.getInfo(),
+ 'color': self.getColor(),
+ 'symbol': self.getSymbol(),
+ 'linewidth': self.getLineWidth(),
+ 'linestyle': self.getLineStyle(),
+ 'xlabel': self.getXLabel(),
+ 'ylabel': self.getYLabel(),
+ 'yaxis': self.getYAxis(),
+ 'xerror': self.getXErrorData(copy=False),
+ 'yerror': self.getYErrorData(copy=False),
+ 'z': self.getZValue(),
+ 'selectable': self.isSelectable(),
+ 'fill': self.isFill(),
+ }
+ return params
+ else:
+ raise IndexError("Index out of range: %s", str(item))
+
+ @deprecated(replacement='Curve.getHighlightedStyle().getColor()',
+ since_version='0.9.0')
+ def getHighlightedColor(self):
+ """Returns the RGBA highlight color of the item
+
+ :rtype: 4-tuple of float in [0, 1]
+ """
+ return self.getHighlightedStyle().getColor()
+
+ @deprecated(replacement='Curve.setHighlightedStyle()',
+ since_version='0.9.0')
+ def setHighlightedColor(self, color):
+ """Set the color to use when highlighted
+
+ :param color: color(s) to be used for highlight
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ """
+ self.setHighlightedStyle(CurveStyle(color))
+
+ def getCurrentStyle(self):
+ """Returns the current curve style.
+
+ Curve style depends on curve highlighting
+
+ :rtype: CurveStyle
+ """
+ if self.isHighlighted():
+ style = self.getHighlightedStyle()
+ color = style.getColor()
+ linestyle = style.getLineStyle()
+ linewidth = style.getLineWidth()
+ symbol = style.getSymbol()
+ symbolsize = style.getSymbolSize()
+
+ return CurveStyle(
+ color=self.getColor() if color is None else color,
+ linestyle=self.getLineStyle() if linestyle is None else linestyle,
+ linewidth=self.getLineWidth() if linewidth is None else linewidth,
+ symbol=self.getSymbol() if symbol is None else symbol,
+ symbolsize=self.getSymbolSize() if symbolsize is None else symbolsize)
+
+ else:
+ return CurveStyle(color=self.getColor(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ symbol=self.getSymbol(),
+ symbolsize=self.getSymbolSize())
+
+ @deprecated(replacement='Curve.getCurrentStyle()',
+ since_version='0.9.0')
+ def getCurrentColor(self):
+ """Returns the current color of the curve.
+
+ This color is either the color of the curve or the highlighted color,
+ depending on the highlight state.
+
+ :rtype: 4-tuple of float in [0, 1]
+ """
+ return self.getCurrentStyle().getColor()
+
+ def setData(self, x, y, xerror=None, yerror=None, baseline=None, copy=True):
+ """Set the data of the curve.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values.
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param baseline: curve baseline
+ :type baseline: Union[None,float,numpy.ndarray]
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ PointsBase.setData(self, x=x, y=y, xerror=xerror, yerror=yerror,
+ copy=copy)
+ self._setBaseline(baseline=baseline)
diff --git a/silx/gui/plot/items/histogram.py b/src/silx/gui/plot/items/histogram.py
index 16bbefa..16bbefa 100644
--- a/silx/gui/plot/items/histogram.py
+++ b/src/silx/gui/plot/items/histogram.py
diff --git a/src/silx/gui/plot/items/image.py b/src/silx/gui/plot/items/image.py
new file mode 100644
index 0000000..5cc719b
--- /dev/null
+++ b/src/silx/gui/plot/items/image.py
@@ -0,0 +1,641 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`ImageData` and :class:`ImageRgba` items
+of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/12/2020"
+
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+import logging
+
+import numpy
+
+from ....utils.proxy import docstring
+from .core import (DataItem, LabelsMixIn, DraggableMixIn, ColormapMixIn,
+ AlphaMixIn, ItemChangedType)
+
+_logger = logging.getLogger(__name__)
+
+
+def _convertImageToRgba32(image, copy=True):
+ """Convert an RGB or RGBA image to RGBA32.
+
+ It converts from floats in [0, 1], bool, integer and uint in [0, 255]
+
+ If the input image is already an RGBA32 image,
+ the returned image shares the same data.
+
+ :param image: Image to convert to
+ :type image: numpy.ndarray with 3 dimensions: height, width, color channels
+ :param bool copy: True (Default) to get a copy, False, avoid copy if possible
+ :return: The image converted to RGBA32 with dimension: (height, width, 4)
+ :rtype: numpy.ndarray of uint8
+ """
+ assert image.ndim == 3
+ assert image.shape[-1] in (3, 4)
+
+ # Convert type to uint8
+ if image.dtype.name != 'uint8':
+ if image.dtype.kind == 'f': # Float in [0, 1]
+ image = (numpy.clip(image, 0., 1.) * 255).astype(numpy.uint8)
+ elif image.dtype.kind == 'b': # boolean
+ image = image.astype(numpy.uint8) * 255
+ elif image.dtype.kind in ('i', 'u'): # int, uint
+ image = numpy.clip(image, 0, 255).astype(numpy.uint8)
+ else:
+ raise ValueError('Unsupported image dtype: %s', image.dtype.name)
+ copy = False # A copy as already been done, avoid next one
+
+ # Convert RGB to RGBA
+ if image.shape[-1] == 3:
+ new_image = numpy.empty((image.shape[0], image.shape[1], 4),
+ dtype=numpy.uint8)
+ new_image[:,:,:3] = image
+ new_image[:,:, 3] = 255
+ return new_image # This is a copy anyway
+ else:
+ return numpy.array(image, copy=copy)
+
+
+class ImageBase(DataItem, LabelsMixIn, DraggableMixIn, AlphaMixIn):
+ """Description of an image
+
+ :param numpy.ndarray data: Initial image data
+ """
+
+ def __init__(self, data=None, mask=None):
+ DataItem.__init__(self)
+ LabelsMixIn.__init__(self)
+ DraggableMixIn.__init__(self)
+ AlphaMixIn.__init__(self)
+ if data is None:
+ data = numpy.zeros((0, 0, 4), dtype=numpy.uint8)
+ self._data = data
+ self._mask = mask
+ self.__valueDataCache = None # Store default data
+ self._origin = (0., 0.)
+ self._scale = (1., 1.)
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ if isinstance(item, slice):
+ return [self[index] for index in range(*item.indices(5))]
+ elif item == 0:
+ return self.getData(copy=False)
+ elif item == 1:
+ return self.getName()
+ elif item == 2:
+ info = self.getInfo(copy=False)
+ return {} if info is None else info
+ elif item == 3:
+ return None
+ elif item == 4:
+ params = {
+ 'info': self.getInfo(),
+ 'origin': self.getOrigin(),
+ 'scale': self.getScale(),
+ 'z': self.getZValue(),
+ 'selectable': self.isSelectable(),
+ 'draggable': self.isDraggable(),
+ 'colormap': None,
+ 'xlabel': self.getXLabel(),
+ 'ylabel': self.getYLabel(),
+ }
+ return params
+ else:
+ raise IndexError("Index out of range: %s" % str(item))
+
+ def _isPlotLinear(self, plot):
+ """Return True if plot only uses linear scale for both of x and y
+ axes."""
+ linear = plot.getXAxis().LINEAR
+ if plot.getXAxis().getScale() != linear:
+ return False
+ if plot.getYAxis().getScale() != linear:
+ return False
+ return True
+
+ def _getBounds(self):
+ if self.getData(copy=False).size == 0: # Empty data
+ return None
+
+ height, width = self.getData(copy=False).shape[:2]
+ origin = self.getOrigin()
+ scale = self.getScale()
+ # Taking care of scale might be < 0
+ xmin, xmax = origin[0], origin[0] + width * scale[0]
+ if xmin > xmax:
+ xmin, xmax = xmax, xmin
+ # Taking care of scale might be < 0
+ ymin, ymax = origin[1], origin[1] + height * scale[1]
+ if ymin > ymax:
+ ymin, ymax = ymax, ymin
+
+ plot = self.getPlot()
+ if plot is not None and not self._isPlotLinear(plot):
+ return None
+ else:
+ return xmin, xmax, ymin, ymax
+
+ @docstring(DraggableMixIn)
+ def drag(self, from_, to):
+ origin = self.getOrigin()
+ self.setOrigin((origin[0] + to[0] - from_[0],
+ origin[1] + to[1] - from_[1]))
+
+ def getData(self, copy=True):
+ """Returns the image data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._data, copy=copy)
+
+ def setData(self, data):
+ """Set the image data
+
+ :param numpy.ndarray data:
+ """
+ previousShape = self._data.shape
+ self._data = data
+ self._valueDataChanged()
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ if (self.getMaskData(copy=False) is not None and
+ previousShape != self._data.shape):
+ # Data shape changed, so mask shape changes.
+ # Send event, mask is lazily updated in getMaskData
+ self._updated(ItemChangedType.MASK)
+
+ def getMaskData(self, copy=True):
+ """Returns the mask data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self._mask is None:
+ return None
+
+ # Update mask if it does not match data shape
+ shape = self.getData(copy=False).shape[:2]
+ if self._mask.shape != shape:
+ # Clip/extend mask to match data
+ newMask = numpy.zeros(shape, dtype=self._mask.dtype)
+ newMask[:self._mask.shape[0], :self._mask.shape[1]] = self._mask[:shape[0], :shape[1]]
+ self._mask = newMask
+
+ return numpy.array(self._mask, copy=copy)
+
+ def setMaskData(self, mask, copy=True):
+ """Set the image data
+
+ :param numpy.ndarray data:
+ :param bool copy: True (Default) to make a copy,
+ False to use as is (do not modify!)
+ """
+ if mask is not None:
+ mask = numpy.array(mask, copy=copy)
+
+ shape = self.getData(copy=False).shape[:2]
+ if mask.shape != shape:
+ _logger.warning("Inconsistent shape between mask and data %s, %s", mask.shape, shape)
+ # Clip/extent is done lazily in getMaskData
+ elif self._mask is None:
+ return # No update
+
+ self._mask = mask
+ self._valueDataChanged()
+ self._updated(ItemChangedType.MASK)
+
+ def _valueDataChanged(self):
+ """Clear cache of default data array"""
+ self.__valueDataCache = None
+
+ def _getValueData(self, copy=True):
+ """Return data used by :meth:`getValueData`
+
+ :param bool copy:
+ :rtype: numpy.ndarray
+ """
+ return self.getData(copy=copy)
+
+ def getValueData(self, copy=True):
+ """Return data (converted to int or float) with mask applied.
+
+ Masked values are set to Not-A-Number.
+ It returns a 2D array of values (int or float).
+
+ :param bool copy:
+ :rtype: numpy.ndarray
+ """
+ if self.__valueDataCache is None:
+ data = self._getValueData(copy=False)
+ mask = self.getMaskData(copy=False)
+ if mask is not None:
+ if numpy.issubdtype(data.dtype, numpy.floating):
+ dtype = data.dtype
+ else:
+ dtype = numpy.float64
+ data = numpy.array(data, dtype=dtype, copy=True)
+ data[mask != 0] = numpy.NaN
+ self.__valueDataCache = data
+ return numpy.array(self.__valueDataCache, copy=copy)
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :returns: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ raise NotImplementedError('This MUST be implemented in sub-class')
+
+ def getOrigin(self):
+ """Returns the offset from origin at which to display the image.
+
+ :rtype: 2-tuple of float
+ """
+ return self._origin
+
+ def setOrigin(self, origin):
+ """Set the offset from origin at which to display the image.
+
+ :param origin: (ox, oy) Offset from origin
+ :type origin: float or 2-tuple of float
+ """
+ if isinstance(origin, abc.Sequence):
+ origin = float(origin[0]), float(origin[1])
+ else: # single value origin
+ origin = float(origin), float(origin)
+ if origin != self._origin:
+ self._origin = origin
+ self._boundsChanged()
+ self._updated(ItemChangedType.POSITION)
+
+ def getScale(self):
+ """Returns the scale of the image in data coordinates.
+
+ :rtype: 2-tuple of float
+ """
+ return self._scale
+
+ def setScale(self, scale):
+ """Set the scale of the image
+
+ :param scale: (sx, sy) Scale of the image
+ :type scale: float or 2-tuple of float
+ """
+ if isinstance(scale, abc.Sequence):
+ scale = float(scale[0]), float(scale[1])
+ else: # single value scale
+ scale = float(scale), float(scale)
+
+ if scale != self._scale:
+ self._scale = scale
+ self._boundsChanged()
+ self._updated(ItemChangedType.SCALE)
+
+
+class ImageDataBase(ImageBase, ColormapMixIn):
+ """Base class for colormapped 2D data image"""
+
+ def __init__(self):
+ ImageBase.__init__(self, numpy.zeros((0, 0), dtype=numpy.float32))
+ ColormapMixIn.__init__(self)
+
+ def _getColormapForRendering(self):
+ colormap = self.getColormap()
+ if colormap.isAutoscale():
+ # Avoid backend to compute autoscale: use item cache
+ colormap = colormap.copy()
+ colormap.setVRange(*colormap.getColormapRange(self))
+ return colormap
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: Array of uint8 of shape (height, width, 4)
+ :rtype: numpy.ndarray
+ """
+ return self.getColormap().applyToData(self)
+
+ def setData(self, data, copy=True):
+ """"Set the image data
+
+ :param numpy.ndarray data: Data array with 2 dimensions (h, w)
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+ if data.dtype.kind == 'b':
+ _logger.warning(
+ 'Converting boolean image to int8 to plot it.')
+ data = numpy.array(data, copy=False, dtype=numpy.int8)
+ elif numpy.iscomplexobj(data):
+ _logger.warning(
+ 'Converting complex image to absolute value to plot it.')
+ data = numpy.absolute(data)
+ super().setData(data)
+
+ def _updated(self, event=None, checkVisibility=True):
+ # Synchronizes colormapped data if changed
+ if event in (ItemChangedType.DATA, ItemChangedType.MASK):
+ self._setColormappedData(self.getValueData(copy=False), copy=False)
+ super()._updated(event=event, checkVisibility=checkVisibility)
+
+
+class ImageData(ImageDataBase):
+ """Description of a data image with a colormap"""
+
+ def __init__(self):
+ ImageDataBase.__init__(self)
+ self._alternativeImage = None
+ self.__alpha = None
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ if (self.getAlternativeImageData(copy=False) is not None or
+ self.getAlphaData(copy=False) is not None):
+ dataToUse = self.getRgbaImageData(copy=False)
+ else:
+ dataToUse = self.getData(copy=False)
+
+ if dataToUse.size == 0:
+ return None # No data to display
+
+ return backend.addImage(dataToUse,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=self._getColormapForRendering(),
+ alpha=self.getAlpha())
+
+ def __getitem__(self, item):
+ """Compatibility with PyMca and silx <= 0.4.0"""
+ if item == 3:
+ return self.getAlternativeImageData(copy=False)
+
+ params = ImageBase.__getitem__(self, item)
+ if item == 4:
+ params['colormap'] = self.getColormap()
+
+ return params
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: Array of uint8 of shape (height, width, 4)
+ :rtype: numpy.ndarray
+ """
+ alternative = self.getAlternativeImageData(copy=False)
+ if alternative is not None:
+ return _convertImageToRgba32(alternative, copy=copy)
+ else:
+ image = super().getRgbaImageData(copy=copy)
+ alphaImage = self.getAlphaData(copy=False)
+ if alphaImage is not None:
+ # Apply transparency
+ image[:,:, 3] = image[:,:, 3] * alphaImage
+ return image
+
+ def getAlternativeImageData(self, copy=True):
+ """Get the optional RGBA image that is displayed instead of the data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self._alternativeImage is None:
+ return None
+ else:
+ return numpy.array(self._alternativeImage, copy=copy)
+
+ def getAlphaData(self, copy=True):
+ """Get the optional transparency image applied on the data
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: Union[None,numpy.ndarray]
+ """
+ if self.__alpha is None:
+ return None
+ else:
+ return numpy.array(self.__alpha, copy=copy)
+
+ def setData(self, data, alternative=None, alpha=None, copy=True):
+ """"Set the image data and optionally an alternative RGB(A) representation
+
+ :param numpy.ndarray data: Data array with 2 dimensions (h, w)
+ :param alternative: RGB(A) image to display instead of data,
+ shape: (h, w, 3 or 4)
+ :type alternative: Union[None,numpy.ndarray]
+ :param alpha: An array of transparency value in [0, 1] to use for
+ display with shape: (h, w)
+ :type alpha: Union[None,numpy.ndarray]
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+
+ if alternative is not None:
+ alternative = numpy.array(alternative, copy=copy)
+ assert alternative.ndim == 3
+ assert alternative.shape[2] in (3, 4)
+ assert alternative.shape[:2] == data.shape[:2]
+ self._alternativeImage = alternative
+
+ if alpha is not None:
+ alpha = numpy.array(alpha, copy=copy)
+ assert alpha.shape == data.shape
+ if alpha.dtype.kind != 'f':
+ alpha = alpha.astype(numpy.float32)
+ if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)):
+ alpha = numpy.clip(alpha, 0., 1.)
+ self.__alpha = alpha
+
+ super().setData(data)
+
+
+class ImageRgba(ImageBase):
+ """Description of an RGB(A) image"""
+
+ def __init__(self):
+ ImageBase.__init__(self, numpy.zeros((0, 0, 4), dtype=numpy.uint8))
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ data = self.getData(copy=False)
+
+ if data.size == 0:
+ return None # No data to display
+
+ return backend.addImage(data,
+ origin=self.getOrigin(),
+ scale=self.getScale(),
+ colormap=None,
+ alpha=self.getAlpha())
+
+ def getRgbaImageData(self, copy=True):
+ """Get the displayed RGB(A) image
+
+ :returns: numpy.ndarray of uint8 of shape (height, width, 4)
+ """
+ return _convertImageToRgba32(self.getData(copy=False), copy=copy)
+
+ def setData(self, data, copy=True):
+ """Set the image data
+
+ :param data: RGB(A) image data to set
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 3
+ assert data.shape[-1] in (3, 4)
+ super().setData(data)
+
+ def _getValueData(self, copy=True):
+ """Compute the intensity of the RGBA image as default data.
+
+ Conversion: https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion
+
+ :param bool copy:
+ """
+ rgba = self.getRgbaImageData(copy=False).astype(numpy.float32)
+ intensity = (rgba[:, :, 0] * 0.299 +
+ rgba[:, :, 1] * 0.587 +
+ rgba[:, :, 2] * 0.114)
+ intensity *= rgba[:, :, 3] / 255.
+ return intensity
+
+
+class MaskImageData(ImageData):
+ """Description of an image used as a mask.
+
+ This class is used to flag mask items. This information is used to improve
+ internal silx widgets.
+ """
+ pass
+
+
+class ImageStack(ImageData):
+ """Item to store a stack of images and to show it in the plot as one
+ of the images of the stack.
+
+ The stack is a 3D array ordered this way: `frame id, y, x`.
+ So the first image of the stack can be reached this way: `stack[0, :, :]`
+ """
+
+ def __init__(self):
+ ImageData.__init__(self)
+ self.__stack = None
+ """A 3D numpy array (or a mimic one, see ListOfImages)"""
+ self.__stackPosition = None
+ """Displayed position in the cube"""
+
+ def setStackData(self, stack, position=None, copy=True):
+ """Set the stack data
+
+ :param stack: A 3D numpy array like
+ :param int position: The position of the displayed image in the stack
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if self.__stack is stack:
+ return
+ if copy:
+ stack = numpy.array(stack)
+ assert stack.ndim == 3
+ self.__stack = stack
+ if position is not None:
+ self.__stackPosition = position
+ if self.__stackPosition is None:
+ self.__stackPosition = 0
+ self.__updateDisplayedData()
+
+ def getStackData(self, copy=True):
+ """Get the stored stack array.
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: A 3D numpy array, or numpy array like
+ """
+ if copy:
+ return numpy.array(self.__stack)
+ else:
+ return self.__stack
+
+ def setStackPosition(self, pos):
+ """Set the displayed position on the stack.
+
+ This function will clamp the stack position according to
+ the real size of the first axis of the stack.
+
+ :param int pos: A position on the first axis of the stack.
+ """
+ if self.__stackPosition == pos:
+ return
+ self.__stackPosition = pos
+ self.__updateDisplayedData()
+
+ def getStackPosition(self):
+ """Get the displayed position of the stack.
+
+ :rtype: int
+ """
+ return self.__stackPosition
+
+ def __updateDisplayedData(self):
+ """Update the displayed frame whenever the stack or the stack
+ position are updated."""
+ if self.__stack is None or self.__stackPosition is None:
+ empty = numpy.array([]).reshape(0, 0)
+ self.setData(empty, copy=False)
+ return
+ size = len(self.__stack)
+ self.__stackPosition = numpy.clip(self.__stackPosition, 0, size)
+ self.setData(self.__stack[self.__stackPosition], copy=False)
diff --git a/src/silx/gui/plot/items/image_aggregated.py b/src/silx/gui/plot/items/image_aggregated.py
new file mode 100644
index 0000000..75fdd59
--- /dev/null
+++ b/src/silx/gui/plot/items/image_aggregated.py
@@ -0,0 +1,229 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`ImageDataAggregated` items of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "07/07/2021"
+
+import enum
+import logging
+from typing import Tuple, Union
+
+import numpy
+
+from ....utils.enum import Enum as _Enum
+from ....utils.proxy import docstring
+from .axis import Axis
+from .core import ItemChangedType
+from .image import ImageDataBase
+from ._pick import PickingResult
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ImageDataAggregated(ImageDataBase):
+ """Item displaying an image as a density map."""
+
+ @enum.unique
+ class Aggregation(_Enum):
+ NONE = "none"
+ "Do not aggregate data, display as is (default)"
+
+ MAX = "max"
+ "Aggregates elements with max (ignore NaNs)"
+
+ MEAN = "mean"
+ "Aggregates elements with mean (ignore NaNs)"
+
+ MIN = "min"
+ "Aggregates elements with min (ignore NaNs)"
+
+ def __init__(self):
+ super().__init__()
+ self.__cacheLODData = {}
+ self.__currentLOD = 0, 0
+ self.__aggregationMode = self.Aggregation.NONE
+
+ def setAggregationMode(self, mode: Union[str,Aggregation]):
+ """Set the aggregation method used to reduce the data to screen resolution.
+
+ :param Aggregation mode: The aggregation method
+ """
+ aggregationMode = self.Aggregation.from_value(mode)
+ if aggregationMode != self.__aggregationMode:
+ self.__aggregationMode = aggregationMode
+ self.__cacheLODData = {} # Clear cache
+ self._updated(ItemChangedType.VISUALIZATION_MODE)
+
+ def getAggregationMode(self) -> Aggregation:
+ """Returns the currently used aggregation method."""
+ return self.__aggregationMode
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ plot = self.getPlot()
+ assert plot is not None
+ if not self._isPlotLinear(plot):
+ # Do not render with non linear scales
+ return None
+
+ data = self.getData(copy=False)
+ if data.size == 0:
+ return None # No data to display
+
+ aggregationMode = self.getAggregationMode()
+ if aggregationMode == self.Aggregation.NONE: # Pass data as it is
+ displayedData = data
+ scale = self.getScale()
+
+ else: # Aggregate data according to level of details
+ if aggregationMode == self.Aggregation.MAX:
+ aggregator = numpy.nanmax
+ elif aggregationMode == self.Aggregation.MEAN:
+ aggregator = numpy.nanmean
+ elif aggregationMode == self.Aggregation.MIN:
+ aggregator = numpy.nanmin
+ else:
+ _logger.error("Unsupported aggregation mode")
+ return None
+
+ lodx, lody = self._getLevelOfDetails()
+
+ if (lodx, lody) not in self.__cacheLODData:
+ height, width = data.shape
+ self.__cacheLODData[(lodx, lody)] = aggregator(
+ data[: (height // lody) * lody, : (width // lodx) * lodx].reshape(
+ height // lody, lody, width // lodx, lodx
+ ),
+ axis=(1, 3),
+ )
+
+ self.__currentLOD = lodx, lody
+ displayedData = self.__cacheLODData[self.__currentLOD]
+
+ sx, sy = self.getScale()
+ scale = sx * lodx, sy * lody
+
+ return backend.addImage(
+ displayedData,
+ origin=self.getOrigin(),
+ scale=scale,
+ colormap=self._getColormapForRendering(),
+ alpha=self.getAlpha(),
+ )
+
+ def _getPixelSizeInData(self, axis="left"):
+ """Returns the size of a pixel in plot data coordinates
+
+ :param str axis: Y axis to use in: 'left' (default), 'right'
+ :return:
+ Size (width, height) of a Qt pixel in data coordinates.
+ Size is None if it cannot be computed
+ :rtype: Union[List[float],None]
+ """
+ assert axis in ("left", "right")
+ plot = self.getPlot()
+ if plot is None:
+ return None
+
+ xaxis = plot.getXAxis()
+ yaxis = plot.getYAxis(axis)
+
+ if (
+ xaxis.getScale() != Axis.LINEAR
+ or yaxis.getScale() != Axis.LINEAR
+ ):
+ raise RuntimeError("Only available with linear axes")
+
+ xmin, xmax = xaxis.getLimits()
+ ymin, ymax = yaxis.getLimits()
+ width, height = plot.getPlotBoundsInPixels()[2:]
+ if width == 0 or height == 0:
+ return None
+ else:
+ return (xmax - xmin) / width, (ymax - ymin) / height
+
+ def _getLevelOfDetails(self) -> Tuple[int, int]:
+ """Return current level of details the image is displayed with."""
+ plot = self.getPlot()
+ if plot is None or not self._isPlotLinear(plot):
+ return 1, 1 # Fallback to bas LOD
+
+ sx, sy = self.getScale()
+ xUnitPerPixel, yUnitPerPixel = self._getPixelSizeInData()
+ lodx = max(1, int(numpy.ceil(xUnitPerPixel / sx)))
+ lody = max(1, int(numpy.ceil(yUnitPerPixel / sy)))
+ return lodx, lody
+
+ @docstring(ImageDataBase)
+ def setData(self, data, copy=True):
+ self.__cacheLODData = {} # Reset cache
+ super().setData(data)
+
+ @docstring(ImageDataBase)
+ def _setPlot(self, plot):
+ """Refresh image when plot limits change"""
+ previousPlot = self.getPlot()
+ if previousPlot is not None:
+ for axis in (previousPlot.getXAxis(), previousPlot.getYAxis()):
+ axis.sigLimitsChanged.disconnect(self.__plotLimitsChanged)
+
+ super()._setPlot(plot)
+
+ if plot is not None:
+ for axis in (plot.getXAxis(), plot.getYAxis()):
+ axis.sigLimitsChanged.connect(self.__plotLimitsChanged)
+
+ def __plotLimitsChanged(self):
+ """Trigger update if level of details has changed"""
+ if (self.getAggregationMode() != self.Aggregation.NONE and
+ self.__currentLOD != self._getLevelOfDetails()):
+ self._updated()
+
+ @docstring(ImageDataBase)
+ def pick(self, x, y):
+ result = super().pick(x, y)
+ if result is None:
+ return None
+
+ # Compute indices in initial data
+ plot = self.getPlot()
+ if plot is None:
+ return None
+ dataPos = plot.pixelToData(x, y, axis="left", check=True)
+ if dataPos is None:
+ return None # Outside plot area
+
+ ox, oy = self.getOrigin()
+ sx, sy = self.getScale()
+ col = int((dataPos[0] - ox) / sx)
+ row = int((dataPos[1] - oy) / sy)
+ height, width = self.getData(copy=False).shape[:2]
+ if 0 <= col < width and 0 <= row < height:
+ return PickingResult(self, ((row,), (col,)))
+ return None
diff --git a/silx/gui/plot/items/marker.py b/src/silx/gui/plot/items/marker.py
index 50d070c..50d070c 100755
--- a/silx/gui/plot/items/marker.py
+++ b/src/silx/gui/plot/items/marker.py
diff --git a/silx/gui/plot/items/roi.py b/src/silx/gui/plot/items/roi.py
index 38a1424..38a1424 100644
--- a/silx/gui/plot/items/roi.py
+++ b/src/silx/gui/plot/items/roi.py
diff --git a/src/silx/gui/plot/items/scatter.py b/src/silx/gui/plot/items/scatter.py
new file mode 100644
index 0000000..fdc66f7
--- /dev/null
+++ b/src/silx/gui/plot/items/scatter.py
@@ -0,0 +1,1002 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Scatter` item of the :class:`Plot`.
+"""
+
+from __future__ import division
+
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "29/03/2017"
+
+
+from collections import namedtuple
+import logging
+import threading
+import numpy
+
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, CancelledError
+
+from ....utils.proxy import docstring
+from ....math.combo import min_max
+from ....math.histogram import Histogramnd
+from ....utils.weakref import WeakList
+from .._utils.delaunay import delaunay
+from .core import PointsBase, ColormapMixIn, ScatterVisualizationMixIn
+from .axis import Axis
+from ._pick import PickingResult
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _GreedyThreadPoolExecutor(ThreadPoolExecutor):
+ """:class:`ThreadPoolExecutor` with an extra :meth:`submit_greedy` method.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(_GreedyThreadPoolExecutor, self).__init__(*args, **kwargs)
+ self.__futures = defaultdict(WeakList)
+ self.__lock = threading.RLock()
+
+ def submit_greedy(self, queue, fn, *args, **kwargs):
+ """Same as :meth:`submit` but cancel previous tasks in given queue.
+
+ This means that when a new task is submitted for a given queue,
+ all other pending tasks of that queue are cancelled.
+
+ :param queue: Identifier of the queue. This must be hashable.
+ :param callable fn: The callable to call with provided extra arguments
+ :return: Future corresponding to this task
+ :rtype: concurrent.futures.Future
+ """
+ with self.__lock:
+ # Cancel previous tasks in given queue
+ for future in self.__futures.pop(queue, []):
+ if not future.done():
+ future.cancel()
+
+ future = super(_GreedyThreadPoolExecutor, self).submit(
+ fn, *args, **kwargs)
+ self.__futures[queue].append(future)
+
+ return future
+
+
+# Functions to guess grid shape from coordinates
+
+def _get_z_line_length(array):
+ """Return length of line if array is a Z-like 2D regular grid.
+
+ :param numpy.ndarray array: The 1D array of coordinates to check
+ :return: 0 if no line length could be found,
+ else the number of element per line.
+ :rtype: int
+ """
+ sign = numpy.sign(numpy.diff(array))
+ if len(sign) == 0 or sign[0] == 0: # We don't handle that
+ return 0
+ # Check this way to account for 0 sign (i.e., diff == 0)
+ beginnings = numpy.where(sign == - sign[0])[0] + 1
+ if len(beginnings) == 0:
+ return 0
+ length = beginnings[0]
+ if numpy.all(numpy.equal(numpy.diff(beginnings), length)):
+ return length
+ return 0
+
+
+def _guess_z_grid_shape(x, y):
+ """Guess the shape of a grid from (x, y) coordinates.
+
+ The grid might contain more elements than x and y,
+ as the last line might be partly filled.
+
+ :param numpy.ndarray x:
+ :paran numpy.ndarray y:
+ :returns: (order, (height, width)) of the regular grid,
+ or None if could not guess one.
+ 'order' is 'row' if X (i.e., column) is the fast dimension, else 'column'.
+ :rtype: Union[List(str,int),None]
+ """
+ width = _get_z_line_length(x)
+ if width != 0:
+ return 'row', (int(numpy.ceil(len(x) / width)), width)
+ else:
+ height = _get_z_line_length(y)
+ if height != 0:
+ return 'column', (height, int(numpy.ceil(len(y) / height)))
+ return None
+
+
+def is_monotonic(array):
+ """Returns whether array is monotonic (increasing or decreasing).
+
+ :param numpy.ndarray array: 1D array-like container.
+ :returns: 1 if array is monotonically increasing,
+ -1 if array is monotonically decreasing,
+ 0 if array is not monotonic
+ :rtype: int
+ """
+ diff = numpy.diff(numpy.ravel(array))
+ with numpy.errstate(invalid='ignore'):
+ if numpy.all(diff >= 0):
+ return 1
+ elif numpy.all(diff <= 0):
+ return -1
+ else:
+ return 0
+
+
+def _guess_grid(x, y):
+ """Guess a regular grid from the points.
+
+ Result convention is (x, y)
+
+ :param numpy.ndarray x: X coordinates of the points
+ :param numpy.ndarray y: Y coordinates of the points
+ :returns: (order, (height, width)
+ order is 'row' or 'column'
+ :rtype: Union[List[str,List[int]],None]
+ """
+ x, y = numpy.ravel(x), numpy.ravel(y)
+
+ guess = _guess_z_grid_shape(x, y)
+ if guess is not None:
+ return guess
+
+ else:
+ # Cannot guess a regular grid
+ # Let's assume it's a single line
+ order = 'row' # or 'column' doesn't matter for a single line
+ y_monotonic = is_monotonic(y)
+ if is_monotonic(x) or y_monotonic: # we can guess a line
+ x_min, x_max = min_max(x)
+ y_min, y_max = min_max(y)
+
+ if not y_monotonic or x_max - x_min >= y_max - y_min:
+ # x only is monotonic or both are and X varies more
+ # line along X
+ shape = 1, len(x)
+ else:
+ # y only is monotonic or both are and Y varies more
+ # line along Y
+ shape = len(y), 1
+
+ else: # Cannot guess a line from the points
+ return None
+
+ return order, shape
+
+
+def _quadrilateral_grid_coords(points):
+ """Compute an irregular grid of quadrilaterals from a set of points
+
+ The input points are expected to lie on a grid.
+
+ :param numpy.ndarray points:
+ 3D data set of 2D input coordinates (height, width, 2)
+ height and width must be at least 2.
+ :return: 3D dataset of 2D coordinates of the grid (height+1, width+1, 2)
+ """
+ assert points.ndim == 3
+ assert points.shape[0] >= 2
+ assert points.shape[1] >= 2
+ assert points.shape[2] == 2
+
+ dim0, dim1 = points.shape[:2]
+ grid_points = numpy.zeros((dim0 + 1, dim1 + 1, 2), dtype=numpy.float64)
+
+ # Compute inner points as mean of 4 neighbours
+ neighbour_view = numpy.lib.stride_tricks.as_strided(
+ points,
+ shape=(dim0 - 1, dim1 - 1, 2, 2, points.shape[2]),
+ strides=points.strides[:2] + points.strides[:2] + points.strides[-1:], writeable=False)
+ inner_points = numpy.mean(neighbour_view, axis=(2, 3))
+ grid_points[1:-1, 1:-1] = inner_points
+
+ # Compute 'vertical' sides
+ # Alternative: grid_points[1:-1, [0, -1]] = points[:-1, [0, -1]] + points[1:, [0, -1]] - inner_points[:, [0, -1]]
+ grid_points[1:-1, [0, -1], 0] = points[:-1, [0, -1], 0] + points[1:, [0, -1], 0] - inner_points[:, [0, -1], 0]
+ grid_points[1:-1, [0, -1], 1] = inner_points[:, [0, -1], 1]
+
+ # Compute 'horizontal' sides
+ grid_points[[0, -1], 1:-1, 0] = inner_points[[0, -1], :, 0]
+ grid_points[[0, -1], 1:-1, 1] = points[[0, -1], :-1, 1] + points[[0, -1], 1:, 1] - inner_points[[0, -1], :, 1]
+
+ # Compute corners
+ d0, d1 = [0, 0, -1, -1], [0, -1, -1, 0]
+ grid_points[d0, d1] = 2 * points[d0, d1] - inner_points[d0, d1]
+ return grid_points
+
+
+def _quadrilateral_grid_as_triangles(points):
+ """Returns the points and indices to make a grid of quadirlaterals
+
+ :param numpy.ndarray points:
+ 3D array of points (height, width, 2)
+ :return: triangle corners (4 * N, 2), triangle indices (2 * N, 3)
+ With N = height * width, the number of input points
+ """
+ nbpoints = numpy.prod(points.shape[:2])
+
+ grid = _quadrilateral_grid_coords(points)
+ coords = numpy.empty((4 * nbpoints, 2), dtype=grid.dtype)
+ coords[::4] = grid[:-1, :-1].reshape(-1, 2)
+ coords[1::4] = grid[1:, :-1].reshape(-1, 2)
+ coords[2::4] = grid[:-1, 1:].reshape(-1, 2)
+ coords[3::4] = grid[1:, 1:].reshape(-1, 2)
+
+ indices = numpy.empty((2 * nbpoints, 3), dtype=numpy.uint32)
+ indices[::2, 0] = numpy.arange(0, 4 * nbpoints, 4)
+ indices[::2, 1] = numpy.arange(1, 4 * nbpoints, 4)
+ indices[::2, 2] = numpy.arange(2, 4 * nbpoints, 4)
+ indices[1::2, 0] = indices[::2, 1]
+ indices[1::2, 1] = indices[::2, 2]
+ indices[1::2, 2] = numpy.arange(3, 4 * nbpoints, 4)
+
+ return coords, indices
+
+
+_RegularGridInfo = namedtuple(
+ '_RegularGridInfo', ['bounds', 'origin', 'scale', 'shape', 'order'])
+
+
+_HistogramInfo = namedtuple(
+ '_HistogramInfo', ['mean', 'count', 'sum', 'origin', 'scale', 'shape'])
+
+
+class Scatter(PointsBase, ColormapMixIn, ScatterVisualizationMixIn):
+ """Description of a scatter"""
+
+ _DEFAULT_SELECTABLE = True
+ """Default selectable state for scatter plots"""
+
+ _SUPPORTED_SCATTER_VISUALIZATION = (
+ ScatterVisualizationMixIn.Visualization.POINTS,
+ ScatterVisualizationMixIn.Visualization.SOLID,
+ ScatterVisualizationMixIn.Visualization.REGULAR_GRID,
+ ScatterVisualizationMixIn.Visualization.IRREGULAR_GRID,
+ ScatterVisualizationMixIn.Visualization.BINNED_STATISTIC,
+ )
+ """Overrides supported Visualizations"""
+
+ def __init__(self):
+ PointsBase.__init__(self)
+ ColormapMixIn.__init__(self)
+ ScatterVisualizationMixIn.__init__(self)
+ self._value = ()
+ self.__alpha = None
+ # Cache Delaunay triangulation future object
+ self.__delaunayFuture = None
+ # Cache interpolator future object
+ self.__interpolatorFuture = None
+ self.__executor = None
+
+ # Cache triangles: x, y, indices
+ self.__cacheTriangles = None, None, None
+
+ # Cache regular grid and histogram info
+ self.__cacheRegularGridInfo = None
+ self.__cacheHistogramInfo = None
+
+ def _updateColormappedData(self):
+ """Update the colormapped data, to be called when changed"""
+ if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ data = None
+ else:
+ data = getattr(
+ histoInfo,
+ self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
+ else:
+ data = self.getValueData(copy=False)
+ self._setColormappedData(data, copy=False)
+
+ @docstring(ScatterVisualizationMixIn)
+ def setVisualization(self, mode):
+ previous = self.getVisualization()
+ if super().setVisualization(mode):
+ if (bool(mode is self.Visualization.BINNED_STATISTIC) ^
+ bool(previous is self.Visualization.BINNED_STATISTIC)):
+ self._updateColormappedData()
+ return True
+ else:
+ return False
+
+ @docstring(ScatterVisualizationMixIn)
+ def setVisualizationParameter(self, parameter, value):
+ parameter = self.VisualizationParameter.from_value(parameter)
+
+ if super(Scatter, self).setVisualizationParameter(parameter, value):
+ if parameter in (self.VisualizationParameter.GRID_BOUNDS,
+ self.VisualizationParameter.GRID_MAJOR_ORDER,
+ self.VisualizationParameter.GRID_SHAPE):
+ self.__cacheRegularGridInfo = None
+
+ if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
+ self.VisualizationParameter.DATA_BOUNDS_HINT):
+ if parameter in (self.VisualizationParameter.BINNED_STATISTIC_SHAPE,
+ self.VisualizationParameter.DATA_BOUNDS_HINT):
+ self.__cacheHistogramInfo = None # Clean-up cache
+ if self.getVisualization() is self.Visualization.BINNED_STATISTIC:
+ self._updateColormappedData()
+ return True
+ else:
+ return False
+
+ @docstring(ScatterVisualizationMixIn)
+ def getCurrentVisualizationParameter(self, parameter):
+ value = self.getVisualizationParameter(parameter)
+ if (parameter is self.VisualizationParameter.DATA_BOUNDS_HINT or
+ value is not None):
+ return value # Value has been set, return it
+
+ elif parameter is self.VisualizationParameter.GRID_BOUNDS:
+ grid = self.__getRegularGridInfo()
+ return None if grid is None else grid.bounds
+
+ elif parameter is self.VisualizationParameter.GRID_MAJOR_ORDER:
+ grid = self.__getRegularGridInfo()
+ return None if grid is None else grid.order
+
+ elif parameter is self.VisualizationParameter.GRID_SHAPE:
+ grid = self.__getRegularGridInfo()
+ return None if grid is None else grid.shape
+
+ elif parameter is self.VisualizationParameter.BINNED_STATISTIC_SHAPE:
+ info = self.__getHistogramInfo()
+ return None if info is None else info.shape
+
+ else:
+ raise NotImplementedError()
+
+ def __getRegularGridInfo(self):
+ """Get grid info"""
+ if self.__cacheRegularGridInfo is None:
+ shape = self.getVisualizationParameter(
+ self.VisualizationParameter.GRID_SHAPE)
+ order = self.getVisualizationParameter(
+ self.VisualizationParameter.GRID_MAJOR_ORDER)
+ if shape is None or order is None:
+ guess = _guess_grid(self.getXData(copy=False),
+ self.getYData(copy=False))
+ if guess is None:
+ _logger.warning(
+ 'Cannot guess a grid: Cannot display as regular grid image')
+ return None
+ if shape is None:
+ shape = guess[1]
+ if order is None:
+ order = guess[0]
+
+ nbpoints = len(self.getXData(copy=False))
+ if nbpoints > shape[0] * shape[1]:
+ # More data points that provided grid shape: enlarge grid
+ _logger.warning(
+ "More data points than provided grid shape size: extends grid")
+ dim0, dim1 = shape
+ if order == 'row': # keep dim1, enlarge dim0
+ dim0 = nbpoints // dim1 + (1 if nbpoints % dim1 else 0)
+ else: # keep dim0, enlarge dim1
+ dim1 = nbpoints // dim0 + (1 if nbpoints % dim0 else 0)
+ shape = dim0, dim1
+
+ bounds = self.getVisualizationParameter(
+ self.VisualizationParameter.GRID_BOUNDS)
+ if bounds is None:
+ x, y = self.getXData(copy=False), self.getYData(copy=False)
+ min_, max_ = min_max(x)
+ xRange = (min_, max_) if (x[0] - min_) < (max_ - x[0]) else (max_, min_)
+ min_, max_ = min_max(y)
+ yRange = (min_, max_) if (y[0] - min_) < (max_ - y[0]) else (max_, min_)
+ bounds = (xRange[0], yRange[0]), (xRange[1], yRange[1])
+
+ begin, end = bounds
+ scale = ((end[0] - begin[0]) / max(1, shape[1] - 1),
+ (end[1] - begin[1]) / max(1, shape[0] - 1))
+ if scale[0] == 0 and scale[1] == 0:
+ scale = 1., 1.
+ elif scale[0] == 0:
+ scale = scale[1], scale[1]
+ elif scale[1] == 0:
+ scale = scale[0], scale[0]
+
+ origin = begin[0] - 0.5 * scale[0], begin[1] - 0.5 * scale[1]
+
+ self.__cacheRegularGridInfo = _RegularGridInfo(
+ bounds=bounds, origin=origin, scale=scale, shape=shape, order=order)
+
+ return self.__cacheRegularGridInfo
+
+ def __getHistogramInfo(self):
+ """Get histogram info"""
+ if self.__cacheHistogramInfo is None:
+ shape = self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_SHAPE)
+ if shape is None:
+ shape = 100, 100 # TODO compute auto shape
+
+ x, y, values = self.getData(copy=False)[:3]
+ if len(x) == 0: # No histogram
+ return None
+
+ if not numpy.issubdtype(x.dtype, numpy.floating):
+ x = x.astype(numpy.float64)
+ if not numpy.issubdtype(y.dtype, numpy.floating):
+ y = y.astype(numpy.float64)
+ if not numpy.issubdtype(values.dtype, numpy.floating):
+ values = values.astype(numpy.float64)
+
+ ranges = (tuple(min_max(y, finite=True)),
+ tuple(min_max(x, finite=True)))
+ rangesHint = self.getVisualizationParameter(
+ self.VisualizationParameter.DATA_BOUNDS_HINT)
+ if rangesHint is not None:
+ ranges = tuple((min(dataMin, hintMin), max(dataMax, hintMax))
+ for (dataMin, dataMax), (hintMin, hintMax) in zip(ranges, rangesHint))
+
+ points = numpy.transpose(numpy.array((y, x)))
+ counts, sums, bin_edges = Histogramnd(
+ points,
+ histo_range=ranges,
+ n_bins=shape,
+ weights=values)
+ yEdges, xEdges = bin_edges
+ origin = xEdges[0], yEdges[0]
+ scale = ((xEdges[-1] - xEdges[0]) / (len(xEdges) - 1),
+ (yEdges[-1] - yEdges[0]) / (len(yEdges) - 1))
+
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ histo = sums / counts
+
+ self.__cacheHistogramInfo = _HistogramInfo(
+ mean=histo, count=counts, sum=sums,
+ origin=origin, scale=scale, shape=shape)
+
+ return self.__cacheHistogramInfo
+
+ def __applyColormapToData(self):
+ """Compute colors by applying colormap to values.
+
+ :returns: Array of RGBA colors
+ """
+ cmap = self.getColormap()
+ rgbacolors = cmap.applyToData(self)
+
+ if self.__alpha is not None:
+ rgbacolors[:, -1] = (rgbacolors[:, -1] * self.__alpha).astype(numpy.uint8)
+ return rgbacolors
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ # Filter-out values <= 0
+ xFiltered, yFiltered, valueFiltered, xerror, yerror = self.getData(
+ copy=False, displayed=True)
+
+ # Remove not finite numbers (this includes filtered out x, y <= 0)
+ mask = numpy.logical_and(numpy.isfinite(xFiltered), numpy.isfinite(yFiltered))
+ xFiltered = xFiltered[mask]
+ yFiltered = yFiltered[mask]
+
+ if len(xFiltered) == 0:
+ return None # No data to display, do not add renderer to backend
+
+ visualization = self.getVisualization()
+
+ if visualization is self.Visualization.BINNED_STATISTIC:
+ plot = self.getPlot()
+ if (plot is None or
+ plot.getXAxis().getScale() != Axis.LINEAR or
+ plot.getYAxis().getScale() != Axis.LINEAR):
+ # Those visualizations are not available with log scaled axes
+ return None
+
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ return None
+ data = getattr(histoInfo, self.getVisualizationParameter(
+ self.VisualizationParameter.BINNED_STATISTIC_FUNCTION))
+
+ return backend.addImage(
+ data=data,
+ origin=histoInfo.origin,
+ scale=histoInfo.scale,
+ colormap=self.getColormap(),
+ alpha=self.getAlpha())
+
+ elif visualization is self.Visualization.POINTS:
+ rgbacolors = self.__applyColormapToData()
+ return backend.addCurve(xFiltered, yFiltered,
+ color=rgbacolors[mask],
+ symbol=self.getSymbol(),
+ linewidth=0,
+ linestyle="",
+ yaxis='left',
+ xerror=xerror,
+ yerror=yerror,
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=self.getSymbolSize(),
+ baseline=None)
+
+ else:
+ plot = self.getPlot()
+ if (plot is None or
+ plot.getXAxis().getScale() != Axis.LINEAR or
+ plot.getYAxis().getScale() != Axis.LINEAR):
+ # Those visualizations are not available with log scaled axes
+ return None
+
+ if visualization is self.Visualization.SOLID:
+ triangulation = self._getDelaunay().result()
+ if triangulation is None:
+ _logger.warning(
+ 'Cannot get a triangulation: Cannot display as solid surface')
+ return None
+ else:
+ rgbacolors = self.__applyColormapToData()
+ triangles = triangulation.simplices.astype(numpy.int32)
+ return backend.addTriangles(xFiltered,
+ yFiltered,
+ triangles,
+ color=rgbacolors[mask],
+ alpha=self.getAlpha())
+
+ elif visualization is self.Visualization.REGULAR_GRID:
+ gridInfo = self.__getRegularGridInfo()
+ if gridInfo is None:
+ return None
+
+ dim0, dim1 = gridInfo.shape
+ if gridInfo.order == 'column': # transposition needed
+ dim0, dim1 = dim1, dim0
+
+ values = self.getValueData(copy=False)
+ if self.__alpha is None and len(values) == dim0 * dim1:
+ image = values.reshape(dim0, dim1)
+ else:
+ # The points do not fill the whole image
+ if (self.__alpha is None and
+ numpy.issubdtype(values.dtype, numpy.floating)):
+ image = numpy.empty(dim0 * dim1, dtype=values.dtype)
+ image[:len(values)] = values
+ image[len(values):] = float('nan') # Transparent pixels
+ image.shape = dim0, dim1
+ else: # Per value alpha or no NaN, so convert to RGBA
+ rgbacolors = self.__applyColormapToData()
+ image = numpy.empty((dim0 * dim1, 4), dtype=numpy.uint8)
+ image[:len(rgbacolors)] = rgbacolors
+ image[len(rgbacolors):] = (0, 0, 0, 0) # Transparent pixels
+ image.shape = dim0, dim1, 4
+
+ if gridInfo.order == 'column':
+ if image.ndim == 2:
+ image = numpy.transpose(image)
+ else:
+ image = numpy.transpose(image, axes=(1, 0, 2))
+
+ if image.ndim == 2:
+ colormap = self.getColormap()
+ if colormap.isAutoscale():
+ # Avoid backend to compute autoscale: use item cache
+ colormap = colormap.copy()
+ colormap.setVRange(*colormap.getColormapRange(self))
+ else:
+ colormap = None
+
+ return backend.addImage(
+ data=image,
+ origin=gridInfo.origin,
+ scale=gridInfo.scale,
+ colormap=colormap,
+ alpha=self.getAlpha())
+
+ elif visualization is self.Visualization.IRREGULAR_GRID:
+ gridInfo = self.__getRegularGridInfo()
+ if gridInfo is None:
+ return None
+
+ shape = gridInfo.shape
+ if shape is None: # No shape, no display
+ return None
+
+ rgbacolors = self.__applyColormapToData()
+
+ nbpoints = len(xFiltered)
+ if nbpoints == 1:
+ # single point, render as a square points
+ return backend.addCurve(xFiltered, yFiltered,
+ color=rgbacolors[mask],
+ symbol='s',
+ linewidth=0,
+ linestyle="",
+ yaxis='left',
+ xerror=None,
+ yerror=None,
+ fill=False,
+ alpha=self.getAlpha(),
+ symbolsize=7,
+ baseline=None)
+
+ # Make shape include all points
+ gridOrder = gridInfo.order
+ if nbpoints != numpy.prod(shape):
+ if gridOrder == 'row':
+ shape = int(numpy.ceil(nbpoints / shape[1])), shape[1]
+ else: # column-major order
+ shape = shape[0], int(numpy.ceil(nbpoints / shape[0]))
+
+ if shape[0] < 2 or shape[1] < 2: # Single line, at least 2 points
+ points = numpy.ones((2, nbpoints, 2), dtype=numpy.float64)
+ # Use row/column major depending on shape, not on info value
+ gridOrder = 'row' if shape[0] == 1 else 'column'
+
+ if gridOrder == 'row':
+ points[0, :, 0] = xFiltered
+ points[0, :, 1] = yFiltered
+ else: # column-major order
+ points[0, :, 0] = yFiltered
+ points[0, :, 1] = xFiltered
+
+ # Add a second line that will be clipped in the end
+ points[1, :-1] = points[0, :-1] + numpy.cross(
+ points[0, 1:] - points[0, :-1], (0., 0., 1.))[:, :2]
+ points[1, -1] = points[0, -1] + numpy.cross(
+ points[0, -1] - points[0, -2], (0., 0., 1.))[:2]
+
+ points.shape = 2, nbpoints, 2 # Use same shape for both orders
+ coords, indices = _quadrilateral_grid_as_triangles(points)
+
+ elif gridOrder == 'row': # row-major order
+ if nbpoints != numpy.prod(shape):
+ points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
+ points[:nbpoints, 0] = xFiltered
+ points[:nbpoints, 1] = yFiltered
+ # Index of last element of last fully filled row
+ index = (nbpoints // shape[1]) * shape[1]
+ points[nbpoints:, 0] = xFiltered[index - (numpy.prod(shape) - nbpoints):index]
+ points[nbpoints:, 1] = yFiltered[-1]
+ else:
+ points = numpy.transpose((xFiltered, yFiltered))
+ points.shape = shape[0], shape[1], 2
+
+ else: # column-major order
+ if nbpoints != numpy.prod(shape):
+ points = numpy.empty((numpy.prod(shape), 2), dtype=numpy.float64)
+ points[:nbpoints, 0] = yFiltered
+ points[:nbpoints, 1] = xFiltered
+ # Index of last element of last fully filled column
+ index = (nbpoints // shape[0]) * shape[0]
+ points[nbpoints:, 0] = yFiltered[index - (numpy.prod(shape) - nbpoints):index]
+ points[nbpoints:, 1] = xFiltered[-1]
+ else:
+ points = numpy.transpose((yFiltered, xFiltered))
+ points.shape = shape[1], shape[0], 2
+
+ coords, indices = _quadrilateral_grid_as_triangles(points)
+
+ # Remove unused extra triangles
+ coords = coords[:4*nbpoints]
+ indices = indices[:2*nbpoints]
+
+ if gridOrder == 'row':
+ x, y = coords[:, 0], coords[:, 1]
+ else: # column-major order
+ y, x = coords[:, 0], coords[:, 1]
+
+ rgbacolors = rgbacolors[mask] # Filter-out not finite points
+ gridcolors = numpy.empty(
+ (4 * nbpoints, rgbacolors.shape[-1]), dtype=rgbacolors.dtype)
+ for first in range(4):
+ gridcolors[first::4] = rgbacolors[:nbpoints]
+
+ return backend.addTriangles(x,
+ y,
+ indices,
+ color=gridcolors,
+ alpha=self.getAlpha())
+
+ else:
+ _logger.error("Unhandled visualization %s", visualization)
+ return None
+
+ @docstring(PointsBase)
+ def pick(self, x, y):
+ result = super(Scatter, self).pick(x, y)
+
+ if result is not None:
+ visualization = self.getVisualization()
+
+ if visualization is self.Visualization.IRREGULAR_GRID:
+ # Specific handling of picking for the irregular grid mode
+ index = result.getIndices(copy=False)[0] // 4
+ result = PickingResult(self, (index,))
+
+ elif visualization is self.Visualization.REGULAR_GRID:
+ # Specific handling of picking for the regular grid mode
+ picked = result.getIndices(copy=False)
+ if picked is None:
+ return None
+ row, column = picked[0][0], picked[1][0]
+
+ gridInfo = self.__getRegularGridInfo()
+ if gridInfo is None:
+ return None
+
+ if gridInfo.order == 'row':
+ index = row * gridInfo.shape[1] + column
+ else:
+ index = row + column * gridInfo.shape[0]
+ if index >= len(self.getXData(copy=False)): # OK as long as not log scale
+ return None # Image can be larger than scatter
+
+ result = PickingResult(self, (index,))
+
+ elif visualization is self.Visualization.BINNED_STATISTIC:
+ picked = result.getIndices(copy=False)
+ if picked is None or len(picked) == 0 or len(picked[0]) == 0:
+ return None
+ row, col = picked[0][0], picked[1][0]
+ histoInfo = self.__getHistogramInfo()
+ if histoInfo is None:
+ return None
+ sx, sy = histoInfo.scale
+ ox, oy = histoInfo.origin
+ xdata = self.getXData(copy=False)
+ ydata = self.getYData(copy=False)
+ indices = numpy.nonzero(numpy.logical_and(
+ numpy.logical_and(xdata >= ox + sx * col, xdata < ox + sx * (col + 1)),
+ numpy.logical_and(ydata >= oy + sy * row, ydata < oy + sy * (row + 1))))[0]
+ result = None if len(indices) == 0 else PickingResult(self, indices)
+
+ return result
+
+ def __getExecutor(self):
+ """Returns async greedy executor
+
+ :rtype: _GreedyThreadPoolExecutor
+ """
+ if self.__executor is None:
+ self.__executor = _GreedyThreadPoolExecutor(max_workers=2)
+ return self.__executor
+
+ def _getDelaunay(self):
+ """Returns a :class:`Future` which result is the Delaunay object.
+
+ :rtype: concurrent.futures.Future
+ """
+ if self.__delaunayFuture is None or self.__delaunayFuture.cancelled():
+ # Need to init a new delaunay
+ x, y = self.getData(copy=False)[:2]
+ # Remove not finite points
+ mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
+
+ self.__delaunayFuture = self.__getExecutor().submit_greedy(
+ 'delaunay', delaunay, x[mask], y[mask])
+
+ return self.__delaunayFuture
+
+ @staticmethod
+ def __initInterpolator(delaunayFuture, values):
+ """Returns an interpolator for the given data points
+
+ :param concurrent.futures.Future delaunayFuture:
+ Future object which result is a Delaunay object
+ :param numpy.ndarray values: The data value of valid points.
+ :rtype: Union[callable,None]
+ """
+ # Wait for Delaunay to complete
+ try:
+ triangulation = delaunayFuture.result()
+ except CancelledError:
+ triangulation = None
+
+ if triangulation is None:
+ interpolator = None # Error case
+ else:
+ # Lazy-loading of interpolator
+ try:
+ from scipy.interpolate import LinearNDInterpolator
+ except ImportError:
+ LinearNDInterpolator = None
+
+ if LinearNDInterpolator is not None:
+ interpolator = LinearNDInterpolator(triangulation, values)
+
+ # First call takes a while, do it here
+ interpolator([(0., 0.)])
+
+ else:
+ # Fallback using matplotlib interpolator
+ import matplotlib.tri
+
+ x, y = triangulation.points.T
+ tri = matplotlib.tri.Triangulation(
+ x, y, triangles=triangulation.simplices)
+ mplInterpolator = matplotlib.tri.LinearTriInterpolator(
+ tri, values)
+
+ # Wrap interpolator to have same API as scipy's one
+ def interpolator(points):
+ return mplInterpolator(*points.T)
+
+ return interpolator
+
+ def _getInterpolator(self):
+ """Returns a :class:`Future` which result is the interpolator.
+
+ The interpolator is a callable taking an array Nx2 of points
+ as a single argument.
+ The :class:`Future` result is None in case the interpolator cannot
+ be initialized.
+
+ :rtype: concurrent.futures.Future
+ """
+ if (self.__interpolatorFuture is None or
+ self.__interpolatorFuture.cancelled()):
+ # Need to init a new interpolator
+ x, y, values = self.getData(copy=False)[:3]
+ # Remove not finite points
+ mask = numpy.logical_and(numpy.isfinite(x), numpy.isfinite(y))
+ x, y, values = x[mask], y[mask], values[mask]
+
+ self.__interpolatorFuture = self.__getExecutor().submit_greedy(
+ 'interpolator',
+ self.__initInterpolator, self._getDelaunay(), values)
+ return self.__interpolatorFuture
+
+ def _logFilterData(self, xPositive, yPositive):
+ """Filter out values with x or y <= 0 on log axes
+
+ :param bool xPositive: True to filter arrays according to X coords.
+ :param bool yPositive: True to filter arrays according to Y coords.
+ :return: The filtered arrays or unchanged object if not filtering needed
+ :rtype: (x, y, value, xerror, yerror)
+ """
+ # overloaded from PointsBase to filter also value.
+ value = self.getValueData(copy=False)
+
+ if xPositive or yPositive:
+ clipped = self._getClippingBoolArray(xPositive, yPositive)
+
+ if numpy.any(clipped):
+ # copy to keep original array and convert to float
+ value = numpy.array(value, copy=True, dtype=numpy.float64)
+ value[clipped] = numpy.nan
+
+ x, y, xerror, yerror = PointsBase._logFilterData(self, xPositive, yPositive)
+
+ return x, y, value, xerror, yerror
+
+ def getValueData(self, copy=True):
+ """Returns the value assigned to the scatter data points.
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._value, copy=copy)
+
+ def getAlphaData(self, copy=True):
+ """Returns the alpha (transparency) assigned to the scatter data points.
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self.__alpha, copy=copy)
+
+ def getData(self, copy=True, displayed=False):
+ """Returns the x, y coordinates and the value of the data points
+
+ :param copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :param bool displayed: True to only get curve points that are displayed
+ in the plot. Default: False.
+ Note: If plot has log scale, negative points
+ are not displayed.
+ :returns: (x, y, value, xerror, yerror)
+ :rtype: 5-tuple of numpy.ndarray
+ """
+ if displayed:
+ data = self._getCachedData()
+ if data is not None:
+ assert len(data) == 5
+ return data
+
+ return (self.getXData(copy),
+ self.getYData(copy),
+ self.getValueData(copy),
+ self.getXErrorData(copy),
+ self.getYErrorData(copy))
+
+ # reimplemented from PointsBase to handle `value`
+ def setData(self, x, y, value, xerror=None, yerror=None, alpha=None, copy=True):
+ """Set the data of the scatter.
+
+ :param numpy.ndarray x: The data corresponding to the x coordinates.
+ :param numpy.ndarray y: The data corresponding to the y coordinates.
+ :param numpy.ndarray value: The data corresponding to the value of
+ the data points.
+ :param xerror: Values with the uncertainties on the x values
+ :type xerror: A float, or a numpy.ndarray of float32.
+ If it is an array, it can either be a 1D array of
+ same length as the data or a 2D array with 2 rows
+ of same length as the data: row 0 for positive errors,
+ row 1 for negative errors.
+ :param yerror: Values with the uncertainties on the y values
+ :type yerror: A float, or a numpy.ndarray of float32. See xerror.
+ :param alpha: Values with the transparency (between 0 and 1)
+ :type alpha: A float, or a numpy.ndarray of float32
+ :param bool copy: True make a copy of the data (default),
+ False to use provided arrays.
+ """
+ value = numpy.array(value, copy=copy)
+ assert value.ndim == 1
+ assert len(x) == len(value)
+
+ # Convert complex data
+ if numpy.iscomplexobj(value):
+ _logger.warning(
+ 'Converting value data to absolute value to plot it.')
+ value = numpy.absolute(value)
+
+ # Reset triangulation and interpolator
+ if self.__delaunayFuture is not None:
+ self.__delaunayFuture.cancel()
+ self.__delaunayFuture = None
+ if self.__interpolatorFuture is not None:
+ self.__interpolatorFuture.cancel()
+ self.__interpolatorFuture = None
+
+ # Data changed, this needs update
+ self.__cacheRegularGridInfo = None
+ self.__cacheHistogramInfo = None
+
+ self._value = value
+
+ if alpha is not None:
+ # Make sure alpha is an array of float in [0, 1]
+ alpha = numpy.array(alpha, copy=copy)
+ assert alpha.ndim == 1
+ assert len(x) == len(alpha)
+ if alpha.dtype.kind != 'f':
+ alpha = alpha.astype(numpy.float32)
+ if numpy.any(numpy.logical_or(alpha < 0., alpha > 1.)):
+ alpha = numpy.clip(alpha, 0., 1.)
+ self.__alpha = alpha
+
+ # set x, y, xerror, yerror
+
+ # call self._updated + plot._invalidateDataRange()
+ PointsBase.setData(self, x, y, xerror, yerror, copy)
+
+ self._updateColormappedData()
diff --git a/src/silx/gui/plot/items/shape.py b/src/silx/gui/plot/items/shape.py
new file mode 100644
index 0000000..00ac5f5
--- /dev/null
+++ b/src/silx/gui/plot/items/shape.py
@@ -0,0 +1,287 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :class:`Shape` item of the :class:`Plot`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "21/12/2018"
+
+
+import logging
+
+import numpy
+
+from ... import colors
+from .core import (
+ Item, DataItem,
+ ColorMixIn, FillMixIn, ItemChangedType, LineMixIn, YAxisMixIn)
+
+
+_logger = logging.getLogger(__name__)
+
+
+# TODO probably make one class for each kind of shape
+# TODO check fill:polygon/polyline + fill = duplicated
+class Shape(Item, ColorMixIn, FillMixIn, LineMixIn):
+ """Description of a shape item
+
+ :param str type_: The type of shape in:
+ 'hline', 'polygon', 'rectangle', 'vline', 'polylines'
+ """
+
+ def __init__(self, type_):
+ Item.__init__(self)
+ ColorMixIn.__init__(self)
+ FillMixIn.__init__(self)
+ LineMixIn.__init__(self)
+ self._overlay = False
+ assert type_ in ('hline', 'polygon', 'rectangle', 'vline', 'polylines')
+ self._type = type_
+ self._points = ()
+ self._lineBgColor = None
+
+ self._handle = None
+
+ def _addBackendRenderer(self, backend):
+ """Update backend renderer"""
+ points = self.getPoints(copy=False)
+ x, y = points.T[0], points.T[1]
+ return backend.addShape(x,
+ y,
+ shape=self.getType(),
+ color=self.getColor(),
+ fill=self.isFill(),
+ overlay=self.isOverlay(),
+ linestyle=self.getLineStyle(),
+ linewidth=self.getLineWidth(),
+ linebgcolor=self.getLineBgColor())
+
+ def isOverlay(self):
+ """Return true if shape is drawn as an overlay
+
+ :rtype: bool
+ """
+ return self._overlay
+
+ def setOverlay(self, overlay):
+ """Set the overlay state of the shape
+
+ :param bool overlay: True to make it an overlay
+ """
+ overlay = bool(overlay)
+ if overlay != self._overlay:
+ self._overlay = overlay
+ self._updated(ItemChangedType.OVERLAY)
+
+ def getType(self):
+ """Returns the type of shape to draw.
+
+ One of: 'hline', 'polygon', 'rectangle', 'vline', 'polylines'
+
+ :rtype: str
+ """
+ return self._type
+
+ def getPoints(self, copy=True):
+ """Get the control points of the shape.
+
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :return: Array of point coordinates
+ :rtype: numpy.ndarray with 2 dimensions
+ """
+ return numpy.array(self._points, copy=copy)
+
+ def setPoints(self, points, copy=True):
+ """Set the point coordinates
+
+ :param numpy.ndarray points: Array of point coordinates
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ :return:
+ """
+ self._points = numpy.array(points, copy=copy)
+ self._updated(ItemChangedType.DATA)
+
+ def getLineBgColor(self):
+ """Returns the RGBA color of the item
+ :rtype: 4-tuple of float in [0, 1] or array of colors
+ """
+ return self._lineBgColor
+
+ def setLineBgColor(self, color, copy=True):
+ """Set item color
+ :param color: color(s) to be used
+ :type color: str ("#RRGGBB") or (npoints, 4) unsigned byte array or
+ one of the predefined color names defined in colors.py
+ :param bool copy: True (Default) to get a copy,
+ False to use internal representation (do not modify!)
+ """
+ if color is not None:
+ if isinstance(color, str):
+ color = colors.rgba(color)
+ else:
+ color = numpy.array(color, copy=copy)
+ # TODO more checks + improve color array support
+ if color.ndim == 1: # Single RGBA color
+ color = colors.rgba(color)
+ else: # Array of colors
+ assert color.ndim == 2
+
+ self._lineBgColor = color
+ self._updated(ItemChangedType.LINE_BG_COLOR)
+
+
+class BoundingRect(DataItem, YAxisMixIn):
+ """An invisible shape which enforce the plot view to display the defined
+ space on autoscale.
+
+ This item do not display anything. But if the visible property is true,
+ this bounding box is used by the plot, if not, the bounding box is
+ ignored. That's the default behaviour for plot items.
+
+ It can be applied on the "left" or "right" axes. Not both at the same time.
+ """
+
+ def __init__(self):
+ DataItem.__init__(self)
+ YAxisMixIn.__init__(self)
+ self.__bounds = None
+
+ def setBounds(self, rect):
+ """Set the bounding box of this item in data coordinates
+
+ :param Union[None,List[float]] rect: (xmin, xmax, ymin, ymax) or None
+ """
+ if rect is not None:
+ rect = float(rect[0]), float(rect[1]), float(rect[2]), float(rect[3])
+ assert rect[0] <= rect[1]
+ assert rect[2] <= rect[3]
+
+ if rect != self.__bounds:
+ self.__bounds = rect
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ def _getBounds(self):
+ if self.__bounds is None:
+ return None
+ plot = self.getPlot()
+ if plot is not None:
+ xPositive = plot.getXAxis()._isLogarithmic()
+ yPositive = plot.getYAxis()._isLogarithmic()
+ if xPositive or yPositive:
+ bounds = list(self.__bounds)
+ if xPositive and bounds[1] <= 0:
+ return None
+ if xPositive and bounds[0] <= 0:
+ bounds[0] = bounds[1]
+ if yPositive and bounds[3] <= 0:
+ return None
+ if yPositive and bounds[2] <= 0:
+ bounds[2] = bounds[3]
+ return tuple(bounds)
+
+ return self.__bounds
+
+
+class _BaseExtent(DataItem):
+ """Base class for :class:`XAxisExtent` and :class:`YAxisExtent`.
+
+ :param str axis: Either 'x' or 'y'.
+ """
+
+ def __init__(self, axis='x'):
+ assert axis in ('x', 'y')
+ DataItem.__init__(self)
+ self.__axis = axis
+ self.__range = 1., 100.
+
+ def setRange(self, min_, max_):
+ """Set the range of the extent of this item in data coordinates.
+
+ :param float min_: Lower bound of the extent
+ :param float max_: Upper bound of the extent
+ :raises ValueError: If min > max or not finite bounds
+ """
+ range_ = float(min_), float(max_)
+ if not numpy.all(numpy.isfinite(range_)):
+ raise ValueError("min_ and max_ must be finite numbers.")
+ if range_[0] > range_[1]:
+ raise ValueError("min_ must be lesser or equal to max_")
+
+ if range_ != self.__range:
+ self.__range = range_
+ self._boundsChanged()
+ self._updated(ItemChangedType.DATA)
+
+ def getRange(self):
+ """Returns the range (min, max) of the extent in data coordinates.
+
+ :rtype: List[float]
+ """
+ return self.__range
+
+ def _getBounds(self):
+ min_, max_ = self.getRange()
+
+ plot = self.getPlot()
+ if plot is not None:
+ axis = plot.getXAxis() if self.__axis == 'x' else plot.getYAxis()
+ if axis._isLogarithmic():
+ if max_ <= 0:
+ return None
+ if min_ <= 0:
+ min_ = max_
+
+ if self.__axis == 'x':
+ return min_, max_, float('nan'), float('nan')
+ else:
+ return float('nan'), float('nan'), min_, max_
+
+
+class XAxisExtent(_BaseExtent):
+ """Invisible item with a settable horizontal data extent.
+
+ This item do not display anything, but it behaves as a data
+ item with a horizontal extent regarding plot data bounds, i.e.,
+ :meth:`PlotWidget.resetZoom` will take this horizontal extent into account.
+ """
+ def __init__(self):
+ _BaseExtent.__init__(self, axis='x')
+
+
+class YAxisExtent(_BaseExtent, YAxisMixIn):
+ """Invisible item with a settable vertical data extent.
+
+ This item do not display anything, but it behaves as a data
+ item with a vertical extent regarding plot data bounds, i.e.,
+ :meth:`PlotWidget.resetZoom` will take this vertical extent into account.
+ """
+
+ def __init__(self):
+ _BaseExtent.__init__(self, axis='y')
+ YAxisMixIn.__init__(self)
diff --git a/silx/gui/plot/matplotlib/Colormap.py b/src/silx/gui/plot/matplotlib/Colormap.py
index dc432b2..dc432b2 100644
--- a/silx/gui/plot/matplotlib/Colormap.py
+++ b/src/silx/gui/plot/matplotlib/Colormap.py
diff --git a/silx/gui/plot/matplotlib/__init__.py b/src/silx/gui/plot/matplotlib/__init__.py
index e787240..e787240 100644
--- a/silx/gui/plot/matplotlib/__init__.py
+++ b/src/silx/gui/plot/matplotlib/__init__.py
diff --git a/silx/gui/plot/setup.py b/src/silx/gui/plot/setup.py
index e0b2c91..e0b2c91 100644
--- a/silx/gui/plot/setup.py
+++ b/src/silx/gui/plot/setup.py
diff --git a/silx/gui/plot/stats/__init__.py b/src/silx/gui/plot/stats/__init__.py
index 04a5327..04a5327 100644
--- a/silx/gui/plot/stats/__init__.py
+++ b/src/silx/gui/plot/stats/__init__.py
diff --git a/silx/gui/plot/stats/stats.py b/src/silx/gui/plot/stats/stats.py
index a81f7bb..a81f7bb 100644
--- a/silx/gui/plot/stats/stats.py
+++ b/src/silx/gui/plot/stats/stats.py
diff --git a/silx/gui/plot/stats/statshandler.py b/src/silx/gui/plot/stats/statshandler.py
index 17578d8..17578d8 100644
--- a/silx/gui/plot/stats/statshandler.py
+++ b/src/silx/gui/plot/stats/statshandler.py
diff --git a/src/silx/gui/plot/test/__init__.py b/src/silx/gui/plot/test/__init__.py
new file mode 100644
index 0000000..3ad225d
--- /dev/null
+++ b/src/silx/gui/plot/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot/test/testAlphaSlider.py b/src/silx/gui/plot/test/testAlphaSlider.py
new file mode 100644
index 0000000..ca57bf5
--- /dev/null
+++ b/src/silx/gui/plot/test/testAlphaSlider.py
@@ -0,0 +1,204 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ImageAlphaSlider"""
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/03/2017"
+
+import numpy
+import unittest
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import PlotWidget
+from silx.gui.plot import AlphaSlider
+
+
+class TestActiveImageAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestActiveImageAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.ActiveImageAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestActiveImageAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no active image initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
+ # now we have an active image
+ self.assertTrue(self.aslider.isEnabled())
+
+ self.plot.setActiveImage(None)
+ self.assertFalse(self.aslider.isEnabled())
+
+ def testGetImage(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]))
+ self.assertEqual(self.plot.getActiveImage(),
+ self.aslider.getItem())
+
+ self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
+ self.plot.setActiveImage("2")
+ self.assertEqual(self.plot.getImage("2"),
+ self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setValue(137)
+ self.assertAlmostEqual(self.aslider.getAlpha(),
+ 137. / 255)
+
+
+class TestNamedImageAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestNamedImageAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.NamedImageAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestNamedImageAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no image set initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setLegend("1")
+ # now we have an image set
+ self.assertTrue(self.aslider.isEnabled())
+
+ def testGetImage(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.plot.addImage(numpy.array([[0, 1, 3], [2, 4, 6]]), legend="2")
+ self.aslider.setLegend("1")
+ self.assertEqual(self.plot.getImage("1"),
+ self.aslider.getItem())
+
+ self.aslider.setLegend("2")
+ self.assertEqual(self.plot.getImage("2"),
+ self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addImage(numpy.array([[0, 1, 2], [3, 4, 5]]), legend="1")
+ self.aslider.setLegend("1")
+ self.aslider.setValue(128)
+ self.assertAlmostEqual(self.aslider.getAlpha(),
+ 128. / 255)
+
+
+class TestNamedScatterAlphaSlider(TestCaseQt):
+ def setUp(self):
+ super(TestNamedScatterAlphaSlider, self).setUp()
+ self.plot = PlotWidget()
+ self.aslider = AlphaSlider.NamedScatterAlphaSlider(plot=self.plot)
+ self.aslider.setOrientation(qt.Qt.Horizontal)
+
+ toolbar = qt.QToolBar("plot", self.plot)
+ toolbar.addWidget(self.aslider)
+ self.plot.addToolBar(toolbar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.aslider
+
+ super(TestNamedScatterAlphaSlider, self).tearDown()
+
+ def testWidgetEnabled(self):
+ # no Scatter set initially, slider must be deactivate
+ self.assertFalse(self.aslider.isEnabled())
+
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
+ legend="1")
+ self.aslider.setLegend("1")
+ # now we have an image set
+ self.assertTrue(self.aslider.isEnabled())
+
+ def testGetScatter(self):
+ self.plot.addScatter([0, 1, 2], [2, 3, 4], [5, 6, 7],
+ legend="1")
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
+ legend="2")
+ self.aslider.setLegend("1")
+ self.assertEqual(self.plot.getScatter("1"),
+ self.aslider.getItem())
+
+ self.aslider.setLegend("2")
+ self.assertEqual(self.plot.getScatter("2"),
+ self.aslider.getItem())
+
+ def testGetAlpha(self):
+ self.plot.addScatter([0, 10, 20], [20, 30, 40], [50, 60, 70],
+ legend="1")
+ self.aslider.setLegend("1")
+ self.aslider.setValue(128)
+ self.assertAlmostEqual(self.aslider.getAlpha(),
+ 128. / 255)
diff --git a/src/silx/gui/plot/test/testColorBar.py b/src/silx/gui/plot/test/testColorBar.py
new file mode 100644
index 0000000..3dc8ff1
--- /dev/null
+++ b/src/silx/gui/plot/test/testColorBar.py
@@ -0,0 +1,340 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ColorBar featues and sub widgets of Colorbar module"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+import unittest
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.ColorBar import _ColorScale
+from silx.gui.plot.ColorBar import ColorBarWidget
+from silx.gui.colors import Colormap
+from silx.math.colormap import LinearNormalization, LogarithmicNormalization
+from silx.gui.plot import Plot2D
+from silx.gui import qt
+import numpy
+
+
+class TestColorScale(TestCaseQt):
+ """Test that interaction with the colorScale is correct"""
+ def setUp(self):
+ super(TestColorScale, self).setUp()
+ self.colorScaleWidget = _ColorScale(colormap=None, parent=None)
+ self.colorScaleWidget.show()
+ self.qWaitForWindowExposed(self.colorScaleWidget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.colorScaleWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.colorScaleWidget.close()
+ del self.colorScaleWidget
+ super(TestColorScale, self).tearDown()
+
+ def testNoColormap(self):
+ """Test _ColorScale without a colormap"""
+ colormap = self.colorScaleWidget.getColormap()
+ self.assertIsNone(colormap)
+
+ def testRelativePositionLinear(self):
+ self.colorMapLin1 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=0.0,
+ vmax=1.0)
+ self.colorScaleWidget.setColormap(self.colorMapLin1)
+
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == 0.25)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.5) == 0.5)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(1.0) == 1.0)
+
+ self.colorMapLin2 = Colormap(name='viridis',
+ normalization=Colormap.LINEAR,
+ vmin=-10,
+ vmax=0)
+ self.colorScaleWidget.setColormap(self.colorMapLin2)
+
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.25) == -7.5)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(0.5) == -5.0)
+ self.assertTrue(
+ self.colorScaleWidget.getValueFromRelativePosition(1.0) == 0.0)
+
+ def testRelativePositionLog(self):
+ self.colorMapLog1 = Colormap(name='temperature',
+ normalization=Colormap.LOGARITHM,
+ vmin=1.0,
+ vmax=100.0)
+
+ self.colorScaleWidget.setColormap(self.colorMapLog1)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(1.0)
+ self.assertAlmostEqual(val, 100.0)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(0.5)
+ self.assertAlmostEqual(val, 10.0)
+
+ val = self.colorScaleWidget.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == 1.0)
+
+
+class TestNoAutoscale(TestCaseQt):
+ """Test that ticks and color displayed are correct in the case of a colormap
+ with no autoscale
+ """
+
+ def setUp(self):
+ super(TestNoAutoscale, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = self.plot.getColorBarWidget()
+ self.colorBar.setVisible(True) # Makes sure the colormap is visible
+ self.tickBar = self.colorBar.getColorScaleBar().getTickBar()
+ self.colorScale = self.colorBar.getColorScaleBar().getColorScale()
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.tickBar = None
+ self.colorScale = None
+ del self.colorBar
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestNoAutoscale, self).tearDown()
+
+ def testLogNormNoAutoscale(self):
+ colormapLog = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=1.0,
+ vmax=100.0)
+
+ data = numpy.linspace(10, 1e10, 9).reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # test Ticks
+ self.tickBar.setTicksNumber(10)
+ self.tickBar.computeTicks()
+
+ ticksTh = numpy.linspace(1.0, 100.0, 10)
+ ticksTh = 10**ticksTh
+ numpy.array_equal(self.tickBar.ticks, ticksTh)
+
+ # test ColorScale
+ val = self.colorScale.getValueFromRelativePosition(1.0)
+ self.assertAlmostEqual(val, 100.0)
+
+ val = self.colorScale.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == 1.0)
+
+ def testLinearNormNoAutoscale(self):
+ colormapLog = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=-4,
+ vmax=5)
+
+ data = numpy.linspace(1, 9, 9).reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # test Ticks
+ self.tickBar.setTicksNumber(10)
+ self.tickBar.computeTicks()
+
+ numpy.array_equal(self.tickBar.ticks, numpy.linspace(-4, 5, 10))
+
+ # test ColorScale
+ val = self.colorScale.getValueFromRelativePosition(1.0)
+ self.assertTrue(val == 5.0)
+
+ val = self.colorScale.getValueFromRelativePosition(0.0)
+ self.assertTrue(val == -4.0)
+
+
+class TestColorBarWidget(TestCaseQt):
+ """Test interaction with the ColorBarWidget"""
+
+ def setUp(self):
+ super(TestColorBarWidget, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = self.plot.getColorBarWidget()
+ self.colorBar.setVisible(True) # Makes sure the colormap is visible
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ del self.colorBar
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestColorBarWidget, self).tearDown()
+
+ def testEmptyColorBar(self):
+ colorBar = ColorBarWidget(parent=None)
+ colorBar.show()
+ self.qWaitForWindowExposed(colorBar)
+
+ def testNegativeColormaps(self):
+ """test the behavior of the ColorBarWidget in the case of negative
+ values
+
+ Note : colorbar is modified by the Plot directly not ColorBarWidget
+ """
+ colormapLog = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=None)
+
+ data = numpy.array([-5, -4, 0, 2, 3, 5, 10, 20, 30])
+ data = data.reshape(3, 3)
+ self.plot.addImage(data=data, colormap=colormapLog, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # default behavior when with log and negative values: should set vmin
+ # to 1 and vmax to 10
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == 2)
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 30)
+
+ # if data is positive
+ data[data < 1] = data.max()
+ self.plot.addImage(data=data,
+ colormap=colormapLog,
+ legend='toto',
+ replace=True)
+ self.plot.setActiveImage('toto')
+
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == data.min())
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == data.max())
+
+ def testPlotAssocation(self):
+ """Make sure the ColorBarWidget is properly connected with the plot"""
+ colormap = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None)
+
+ # make sure that default settings are the same (but a copy of the
+ self.colorBar.setPlot(self.plot)
+ self.assertTrue(
+ self.colorBar.getColormap() is self.plot.getDefaultColormap())
+
+ data = numpy.linspace(0, 10, 100).reshape(10, 10)
+ self.plot.addImage(data=data, colormap=colormap, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ # make sure the modification of the colormap has been done
+ self.assertFalse(
+ self.colorBar.getColormap() is self.plot.getDefaultColormap())
+ self.assertTrue(
+ self.colorBar.getColormap() is colormap)
+
+ # test that colorbar is updated when default plot colormap changes
+ self.plot.clear()
+ plotColormap = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=None)
+ self.plot.setDefaultColormap(plotColormap)
+ self.assertTrue(self.colorBar.getColormap() is plotColormap)
+
+ def testColormapWithoutRange(self):
+ """Test with a colormap with vmin==vmax"""
+ colormap = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=1.0,
+ vmax=1.0)
+ self.colorBar.setColormap(colormap)
+
+
+class TestColorBarUpdate(TestCaseQt):
+ """Test that the ColorBar is correctly updated when the signal 'sigChanged'
+ of the colormap is emitted
+ """
+
+ def setUp(self):
+ super(TestColorBarUpdate, self).setUp()
+ self.plot = Plot2D()
+ self.colorBar = self.plot.getColorBarWidget()
+ self.colorBar.setVisible(True) # Makes sure the colormap is visible
+ self.colorBar.setPlot(self.plot)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ self.data = numpy.random.rand(9).reshape(3, 3)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ del self.colorBar
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestColorBarUpdate, self).tearDown()
+
+ def testUpdateColorMap(self):
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=0,
+ vmax=1)
+
+ # check inital state
+ self.plot.addImage(data=self.data, colormap=colormap, legend='toto')
+ self.plot.setActiveImage('toto')
+
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0)
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 1)
+ self.assertTrue(
+ self.colorBar.getColorScaleBar().getTickBar()._vmin == 0)
+ self.assertTrue(
+ self.colorBar.getColorScaleBar().getTickBar()._vmax == 1)
+ self.assertIsInstance(
+ self.colorBar.getColorScaleBar().getTickBar()._normalizer,
+ LinearNormalization)
+
+ # update colormap
+ colormap.setVMin(0.5)
+ self.assertTrue(self.colorBar.getColorScaleBar().minVal == 0.5)
+ self.assertTrue(
+ self.colorBar.getColorScaleBar().getTickBar()._vmin == 0.5)
+
+ colormap.setVMax(0.8)
+ self.assertTrue(self.colorBar.getColorScaleBar().maxVal == 0.8)
+ self.assertTrue(
+ self.colorBar.getColorScaleBar().getTickBar()._vmax == 0.8)
+
+ colormap.setNormalization('log')
+ self.assertIsInstance(
+ self.colorBar.getColorScaleBar().getTickBar()._normalizer,
+ LogarithmicNormalization)
+
+ # TODO : should also check that if the colormap is changing then values (especially in log scale)
+ # should be coherent if in autoscale
diff --git a/src/silx/gui/plot/test/testCompareImages.py b/src/silx/gui/plot/test/testCompareImages.py
new file mode 100644
index 0000000..cf54b99
--- /dev/null
+++ b/src/silx/gui/plot/test/testCompareImages.py
@@ -0,0 +1,106 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for CompareImages widget"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "23/07/2018"
+
+import unittest
+import numpy
+import weakref
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.CompareImages import CompareImages
+
+
+class TestCompareImages(TestCaseQt):
+ """Test that CompareImages widget is working in some cases"""
+
+ def setUp(self):
+ super(TestCompareImages, self).setUp()
+ self.widget = CompareImages()
+
+ def tearDown(self):
+ ref = weakref.ref(self.widget)
+ self.widget = None
+ self.qWaitForDestroy(ref)
+ super(TestCompareImages, self).tearDown()
+
+ def testIntensityImage(self):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(10, 10)
+ self.widget.setData(image1, image2)
+
+ def testRgbImage(self):
+ image1 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 3))
+ self.widget.setData(image1, image2)
+
+ def testRgbaImage(self):
+ image1 = numpy.random.randint(0, 255, size=(10, 10, 4))
+ image2 = numpy.random.randint(0, 255, size=(10, 10, 4))
+ self.widget.setData(image1, image2)
+
+ def testVizualisations(self):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(10, 10)
+ self.widget.setData(image1, image2)
+ for mode in CompareImages.VisualizationMode:
+ self.widget.setVisualizationMode(mode)
+
+ def testAlignemnt(self):
+ image1 = numpy.random.rand(10, 10)
+ image2 = numpy.random.rand(5, 5)
+ self.widget.setData(image1, image2)
+ for mode in CompareImages.AlignmentMode:
+ self.widget.setAlignmentMode(mode)
+
+ def testGetPixel(self):
+ image1 = numpy.random.rand(11, 11)
+ image2 = numpy.random.rand(5, 5)
+ image1[5, 5] = 111.111
+ image2[2, 2] = 222.222
+ self.widget.setData(image1, image2)
+ expectedValue = {}
+ expectedValue[CompareImages.AlignmentMode.CENTER] = 222.222
+ expectedValue[CompareImages.AlignmentMode.STRETCH] = 222.222
+ expectedValue[CompareImages.AlignmentMode.ORIGIN] = None
+ for mode in expectedValue.keys():
+ self.widget.setAlignmentMode(mode)
+ data = self.widget.getRawPixelData(11 / 2.0, 11 / 2.0)
+ data1, data2 = data
+ self.assertEqual(data1, 111.111)
+ self.assertEqual(data2, expectedValue[mode])
+
+ def testImageEmpty(self):
+ self.widget.setData(image1=None, image2=None)
+ self.assertTrue(self.widget.getRawPixelData(11 / 2.0, 11 / 2.0) == (None, None))
+
+ def testSetImageSeparately(self):
+ self.widget.setImage1(numpy.random.rand(10, 10))
+ self.widget.setImage2(numpy.random.rand(10, 10))
+ for mode in CompareImages.VisualizationMode:
+ self.widget.setVisualizationMode(mode)
diff --git a/src/silx/gui/plot/test/testComplexImageView.py b/src/silx/gui/plot/test/testComplexImageView.py
new file mode 100644
index 0000000..46025b9
--- /dev/null
+++ b/src/silx/gui/plot/test/testComplexImageView.py
@@ -0,0 +1,84 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test suite for :class:`ComplexImageView`"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+import logging
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.plot import ComplexImageView
+
+from .utils import PlotWidgetTestCase
+
+
+logger = logging.getLogger(__name__)
+
+
+class TestComplexImageView(PlotWidgetTestCase, ParametricTestCase):
+ """Test suite of ComplexImageView widget"""
+
+ def _createPlot(self):
+ return ComplexImageView.ComplexImageView()
+
+ def testPlot2DComplex(self):
+ """Test API of ComplexImageView widget"""
+ data = numpy.array(((0, 1j), (1, 1 + 1j)), dtype=numpy.complex64)
+ self.plot.setData(data)
+ self.plot.setKeepDataAspectRatio(True)
+ self.plot.getPlot().resetZoom()
+ self.qWait(100)
+
+ # Test colormap API
+ colormap = self.plot.getColormap().copy()
+ colormap.setName('magma')
+ self.plot.setColormap(colormap)
+ self.qWait(100)
+
+ # Test all modes
+ modes = self.plot.supportedComplexModes()
+ for mode in modes:
+ with self.subTest(mode=mode):
+ self.plot.setComplexMode(mode)
+ self.qWait(100)
+
+ # Test origin and scale API
+ self.plot.setScale((2, 1))
+ self.qWait(100)
+ self.plot.setOrigin((1, 1))
+ self.qWait(100)
+
+ # Test no data
+ self.plot.setData(numpy.zeros((0, 0), dtype=numpy.complex64))
+ self.qWait(100)
+
+ # Test float data
+ self.plot.setData(numpy.arange(100, dtype=numpy.float64).reshape(10, 10))
+ self.qWait(100)
diff --git a/src/silx/gui/plot/test/testCurvesROIWidget.py b/src/silx/gui/plot/test/testCurvesROIWidget.py
new file mode 100644
index 0000000..d7dfafd
--- /dev/null
+++ b/src/silx/gui/plot/test/testCurvesROIWidget.py
@@ -0,0 +1,465 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for CurvesROIWidget"""
+
+__authors__ = ["T. Vincent", "P. Knobel", "H. Payno"]
+__license__ = "MIT"
+__date__ = "16/11/2017"
+
+
+import logging
+import os.path
+import pytest
+from collections import OrderedDict
+import numpy
+
+from silx.gui import qt
+from silx.gui.plot import items
+from silx.gui.plot import Plot1D
+from silx.test.utils import temp_dir
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+from silx.gui.plot import PlotWindow, CurvesROIWidget
+from silx.gui.plot.CurvesROIWidget import ROITable
+from silx.gui.utils.testutils import getQToolButtonFromAction
+from silx.gui.plot.PlotInteraction import ItemsInteraction
+
+_logger = logging.getLogger(__name__)
+
+
+class TestCurvesROIWidget(TestCaseQt):
+ """Basic test for CurvesROIWidget"""
+
+ def setUp(self):
+ super(TestCurvesROIWidget, self).setUp()
+ self.plot = PlotWindow()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.widget = self.plot.getCurvesRoiDockWidget()
+
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+
+ super(TestCurvesROIWidget, self).tearDown()
+
+ def testDummyAPI(self):
+ """Simple test of the getRois and setRois API"""
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ rois_defs = self.widget.roiWidget.getRois()
+ self.widget.roiWidget.setRois(rois=rois_defs)
+
+ def testWithCurves(self):
+ """Plot with curves: test all ROI widget buttons"""
+ for offset in range(2):
+ self.plot.addCurve(numpy.arange(1000),
+ offset + numpy.random.random(1000),
+ legend=str(offset))
+
+ # Add two ROI
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+ self.qWait(200)
+
+ # Change active curve
+ self.plot.setActiveCurve(str(1))
+
+ # Delete a ROI
+ self.mouseClick(self.widget.roiWidget.delButton, qt.Qt.LeftButton)
+ self.qWait(200)
+
+ with temp_dir() as tmpDir:
+ self.tmpFile = os.path.join(tmpDir, 'test.ini')
+
+ # Save ROIs
+ self.widget.roiWidget.save(self.tmpFile)
+ self.assertTrue(os.path.isfile(self.tmpFile))
+ self.assertEqual(len(self.widget.getRois()), 2)
+
+ # Reset ROIs
+ self.mouseClick(self.widget.roiWidget.resetButton,
+ qt.Qt.LeftButton)
+ self.qWait(200)
+ rois = self.widget.getRois()
+ self.assertEqual(len(rois), 1)
+ roiID = list(rois.keys())[0]
+ self.assertEqual(rois[roiID].getName(), 'ICR')
+
+ # Load ROIs
+ self.widget.roiWidget.load(self.tmpFile)
+ self.assertEqual(len(self.widget.getRois()), 2)
+
+ del self.tmpFile
+
+ def testMiddleMarker(self):
+ """Test with middle marker enabled"""
+ self.widget.roiWidget.roiTable.setMiddleROIMarkerFlag(True)
+
+ # Add a ROI
+ self.mouseClick(self.widget.roiWidget.addButton, qt.Qt.LeftButton)
+
+ for roiID in self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers:
+ handler = self.widget.roiWidget.roiTable._markersHandler._roiMarkerHandlers[roiID]
+ assert handler.getMarker('min')
+ xleftMarker = handler.getMarker('min').getXPosition()
+ xMiddleMarker = handler.getMarker('middle').getXPosition()
+ xRightMarker = handler.getMarker('max').getXPosition()
+ thValue = xleftMarker + (xRightMarker - xleftMarker) / 2.
+ self.assertAlmostEqual(xMiddleMarker, thValue)
+
+ def testAreaCalculation(self):
+ """Test result of area calculation"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+
+ # Add two curves
+ self.plot.addCurve(x, y, legend="positive")
+ self.plot.addCurve(-x, y, legend="negative")
+
+ # Make sure there is an active curve and it is the positive one
+ self.plot.setActiveCurve("positive")
+
+ # Add two ROIs
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ posCurve = self.plot.getCurve('positive')
+ negCurve = self.plot.getCurve('negative')
+
+ self.assertEqual(roi_pos.computeRawAndNetArea(posCurve),
+ (numpy.trapz(y=[10, 20], x=[10, 20]),
+ 0.0))
+ self.assertEqual(roi_pos.computeRawAndNetArea(negCurve),
+ (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(posCurve),
+ ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetArea(negCurve),
+ ((-150.0), 0.0))
+
+ def testCountsCalculation(self):
+ """Test result of count calculation"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+
+ # Add two curves
+ self.plot.addCurve(x, y, legend="positive")
+ self.plot.addCurve(-x, y, legend="negative")
+
+ # Make sure there is an active curve and it is the positive one
+ self.plot.setActiveCurve("positive")
+
+ # Add two ROIs
+ roi_neg = CurvesROIWidget.ROI(name='negative', fromdata=-20,
+ todata=-10, type_='X')
+ roi_pos = CurvesROIWidget.ROI(name='positive', fromdata=10,
+ todata=20, type_='X')
+
+ self.widget.roiWidget.setRois((roi_pos, roi_neg))
+
+ posCurve = self.plot.getCurve('positive')
+ negCurve = self.plot.getCurve('negative')
+
+ self.assertEqual(roi_pos.computeRawAndNetCounts(posCurve),
+ (y[10:21].sum(), 0.0))
+ self.assertEqual(roi_pos.computeRawAndNetCounts(negCurve),
+ (0.0, 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(posCurve),
+ ((0.0), 0.0))
+ self.assertEqual(roi_neg.computeRawAndNetCounts(negCurve),
+ (y[10:21].sum(), 0.0))
+
+ def testDeferedInit(self):
+ """Test behavior of the deferedInit"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+ self.plot.addCurve(x=x, y=y, legend="name", replace="True")
+ roisDefs = OrderedDict([
+ ["range1",
+ OrderedDict([["from", 20], ["to", 200], ["type", "energy"]])],
+ ["range2",
+ OrderedDict([["from", 300], ["to", 500], ["type", "energy"]])]
+ ])
+
+ roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
+ self.plot.getCurvesRoiDockWidget().setRois(roisDefs)
+ self.assertEqual(len(roiWidget.getRois()), len(roisDefs))
+ self.plot.getCurvesRoiDockWidget().setVisible(True)
+ self.assertEqual(len(roiWidget.getRois()), len(roisDefs))
+
+ def testDictCompatibility(self):
+ """Test that ROI api is valid with dict and not information is lost"""
+ roiDict = {'from': 20, 'to': 200, 'type': 'energy', 'comment': 'no',
+ 'name': 'myROI', 'calibration': [1, 2, 3]}
+ roi = CurvesROIWidget.ROI._fromDict(roiDict)
+ self.assertEqual(roi.toDict(), roiDict)
+
+ def testShowAllROI(self):
+ """Test the show allROI action"""
+ x = numpy.arange(100.)
+ y = numpy.arange(100.)
+ self.plot.addCurve(x=x, y=y, legend="name", replace="True")
+
+ roisDefsDict = {
+ "range1": {"from": 20, "to": 200,"type": "energy"},
+ "range2": {"from": 300, "to": 500, "type": "energy"}
+ }
+
+ roisDefsObj = (
+ CurvesROIWidget.ROI(name='range3', fromdata=20, todata=200,
+ type_='energy'),
+ CurvesROIWidget.ROI(name='range4', fromdata=300, todata=500,
+ type_='energy')
+ )
+ self.widget.roiWidget.showAllMarkers(True)
+ roiWidget = self.plot.getCurvesRoiDockWidget().roiWidget
+ roiWidget.setRois(roisDefsDict)
+ markers = [item for item in self.plot.getItems()
+ if isinstance(item, items.MarkerBase)]
+ self.assertEqual(len(markers), 2*3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 1)
+
+ roiWidget.setRois(roisDefsObj)
+ self.qapp.processEvents()
+ markers = [item for item in self.plot.getItems()
+ if isinstance(item, items.MarkerBase)]
+ self.assertEqual(len(markers), 2*3)
+
+ markersHandler = self.widget.roiWidget.roiTable._markersHandler
+ roiWidget.showAllMarkers(True)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 2)
+
+ roiWidget.showAllMarkers(False)
+ ICRROI = markersHandler.getVisibleRois()
+ self.assertEqual(len(ICRROI), 1)
+
+ def testRoiEdition(self):
+ """Make sure if the ROI object is edited the ROITable will be updated
+ """
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi, ))
+
+ x = (0, 1, 1, 2, 2, 3)
+ y = (1, 1, 2, 2, 1, 1)
+ self.plot.addCurve(x=x, y=y, legend='linearCurve')
+ self.plot.setActiveCurve(legend='linearCurve')
+ self.widget.calculateROIs()
+
+ roiTable = self.widget.roiWidget.roiTable
+ indexesColumns = CurvesROIWidget.ROITable.COLUMNS_INDEX
+ itemRawCounts = roiTable.item(0, indexesColumns['Raw Counts'])
+ itemNetCounts = roiTable.item(0, indexesColumns['Net Counts'])
+
+ self.assertTrue(itemRawCounts.text() == '8.0')
+ self.assertTrue(itemNetCounts.text() == '2.0')
+
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ itemNetArea = roiTable.item(0, indexesColumns['Net Area'])
+
+ self.assertTrue(itemRawArea.text() == '4.0')
+ self.assertTrue(itemNetArea.text() == '1.0')
+
+ roi.setTo(2)
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ self.assertTrue(itemRawArea.text() == '3.0')
+ roi.setFrom(1)
+ itemRawArea = roiTable.item(0, indexesColumns['Raw Area'])
+ self.assertTrue(itemRawArea.text() == '2.0')
+
+ def testRemoveActiveROI(self):
+ """Test widget behavior when removing the active ROI"""
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+
+ self.widget.roiWidget.roiTable.setActiveRoi(None)
+ self.assertEqual(len(self.widget.roiWidget.roiTable.selectedItems()), 0)
+ self.widget.roiWidget.setRois((roi,))
+ self.plot.setActiveCurve(legend='linearCurve')
+ self.widget.calculateROIs()
+
+ def testEmitCurrentROI(self):
+ """Test behavior of the CurvesROIWidget.sigROISignal"""
+ roi = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.widget.roiWidget.setRois((roi,))
+ signalListener = SignalListener()
+ self.widget.roiWidget.sigROISignal.connect(signalListener.partial())
+ self.widget.show()
+ self.qapp.processEvents()
+ self.assertEqual(signalListener.callCount(), 0)
+ self.assertIs(self.widget.roiWidget.roiTable.activeRoi, roi)
+ roi.setFrom(0.0)
+ self.qapp.processEvents()
+ self.assertEqual(signalListener.callCount(), 0)
+ roi.setFrom(0.3)
+ self.qapp.processEvents()
+ self.assertEqual(signalListener.callCount(), 1)
+
+
+class TestRoiWidgetSignals(TestCaseQt):
+ """Test Signals emitted by the RoiWidgetSignals"""
+
+ def setUp(self):
+ self.plot = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot.addCurve(x, y, legend='curve0')
+ self.listener = SignalListener()
+ self.curves_roi_widget = self.plot.getCurvesRoiWidget()
+ self.curves_roi_widget.sigROISignal.connect(self.listener)
+ assert self.curves_roi_widget.isVisible() is False
+ assert self.listener.callCount() == 0
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ toolButton = getQToolButtonFromAction(self.plot.getRoiAction())
+ self.qapp.processEvents()
+ self.mouseClick(widget=toolButton, button=qt.Qt.LeftButton)
+
+ self.curves_roi_widget.show()
+ self.qWaitForWindowExposed(self.curves_roi_widget)
+
+ def tearDown(self):
+ self.plot = None
+ self.curves_roi_widget = None
+
+ def testSigROISignalAddRmRois(self):
+ """Test SigROISignal when adding and removing ROIS"""
+ self.listener.clear()
+
+ roi1 = CurvesROIWidget.ROI(name='linear', fromdata=0, todata=5)
+ self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
+ self.listener.clear()
+
+ roi2 = CurvesROIWidget.ROI(name='linear2', fromdata=0, todata=5)
+ self.curves_roi_widget.roiTable.addRoi(roi2)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear2')
+ self.listener.clear()
+
+ self.curves_roi_widget.roiTable.removeROI(roi2)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
+ self.listener.clear()
+
+ self.curves_roi_widget.roiTable.deleteActiveRoi()
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.curves_roi_widget.roiTable.activeRoi is None)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] is None)
+ self.listener.clear()
+
+ self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'linear')
+ self.assertTrue(self.curves_roi_widget.roiTable.activeRoi == roi1)
+ self.listener.clear()
+ self.qapp.processEvents()
+
+ self.curves_roi_widget.roiTable.removeROI(roi1)
+ self.qapp.processEvents()
+ self.assertEqual(self.listener.callCount(), 1)
+ self.assertTrue(self.listener.arguments()[0][0]['current'] == 'ICR')
+ self.listener.clear()
+
+ def testSigROISignalModifyROI(self):
+ """Test SigROISignal when modifying it"""
+ self.curves_roi_widget.roiTable.setMiddleROIMarkerFlag(True)
+ roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
+ self.curves_roi_widget.roiTable.addRoi(roi1)
+ self.curves_roi_widget.roiTable.setActiveRoi(roi1)
+
+ # test modify the roi2 object
+ self.listener.clear()
+ roi1.setFrom(0.56)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.listener.clear()
+ roi1.setTo(2.56)
+ self.assertEqual(self.listener.callCount(), 1)
+ self.listener.clear()
+ roi1.setName('linear2')
+ self.assertEqual(self.listener.callCount(), 1)
+ self.listener.clear()
+ roi1.setType('new type')
+ self.assertEqual(self.listener.callCount(), 1)
+
+ widget = self.plot.getWidgetHandle()
+ widget.setFocus(qt.Qt.OtherFocusReason)
+ self.plot.raise_()
+ self.qapp.processEvents()
+
+ # modify roi limits (from the gui)
+ roi_marker_handler = self.curves_roi_widget.roiTable._markersHandler.getMarkerHandler(roi1.getID())
+ for marker_type in ('min', 'max', 'middle'):
+ with self.subTest(marker_type=marker_type):
+ self.listener.clear()
+ marker = roi_marker_handler.getMarker(marker_type)
+ x_pix, y_pix = self.plot.dataToPixel(marker.getXPosition(), marker.getYPosition())
+ self.mouseMove(widget, pos=(x_pix, y_pix))
+ self.qWait(100)
+ self.mousePress(widget, qt.Qt.LeftButton, pos=(x_pix, y_pix))
+ self.mouseMove(widget, pos=(x_pix+20, y_pix))
+ self.qWait(100)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=(x_pix+20, y_pix))
+ self.qWait(100)
+ self.mouseMove(widget, pos=(x_pix, y_pix))
+ self.qapp.processEvents()
+ self.assertEqual(self.listener.callCount(), 1)
+
+ def testSetActiveCurve(self):
+ """Test sigRoiSignal when set an active curve"""
+ roi1 = CurvesROIWidget.ROI(name='linear', fromdata=2, todata=5)
+ self.curves_roi_widget.roiTable.setActiveRoi(roi1)
+ self.listener.clear()
+ self.plot.setActiveCurve('curve0')
+ self.assertEqual(self.listener.callCount(), 0)
diff --git a/src/silx/gui/plot/test/testImageStack.py b/src/silx/gui/plot/test/testImageStack.py
new file mode 100644
index 0000000..5c44691
--- /dev/null
+++ b/src/silx/gui/plot/test/testImageStack.py
@@ -0,0 +1,186 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ImageStack"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "15/01/2020"
+
+
+import unittest
+import tempfile
+import numpy
+import h5py
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.io.url import DataUrl
+from silx.gui.plot.ImageStack import ImageStack
+from silx.gui.utils.testutils import SignalListener
+from collections import OrderedDict
+import os
+import time
+import shutil
+
+
+class TestImageStack(TestCaseQt):
+ """Simple test of the Image stack"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.urls = OrderedDict()
+ self._raw_data = {}
+ self._folder = tempfile.mkdtemp()
+ self._n_urls = 10
+ file_name = os.path.join(self._folder, 'test_inage_stack_file.h5')
+ with h5py.File(file_name, 'w') as h5f:
+ for i in range(self._n_urls):
+ width = numpy.random.randint(10, 40)
+ height = numpy.random.randint(10, 40)
+ raw_data = numpy.random.random((width, height))
+ self._raw_data[i] = raw_data
+ h5f[str(i)] = raw_data
+ self.urls[i] = DataUrl(file_path=file_name,
+ data_path=str(i),
+ scheme='silx')
+ self.widget = ImageStack()
+
+ self.urlLoadedListener = SignalListener()
+ self.widget.sigLoaded.connect(self.urlLoadedListener)
+
+ self.currentUrlChangedListener = SignalListener()
+ self.widget.sigCurrentUrlChanged.connect(self.currentUrlChangedListener)
+
+ def tearDown(self):
+ shutil.rmtree(self._folder)
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.widget.close()
+ TestCaseQt.setUp(self)
+
+ def testControls(self):
+ """Test that selection using the url table and the slider are working
+ """
+ self.widget.show()
+ self.assertEqual(self.widget.getCurrentUrl(), None)
+ self.assertEqual(self.widget.getCurrentUrlIndex(), None)
+ self.widget.setUrls(list(self.urls.values()))
+
+ # wait for image to be loaded
+ self._waitUntilUrlLoaded()
+
+ self.assertEqual(self.widget.getCurrentUrl(), self.urls[0])
+
+ # make sure all image are loaded
+ self.assertEqual(self.urlLoadedListener.callCount(), self._n_urls)
+ numpy.testing.assert_array_equal(
+ self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
+ self._raw_data[0])
+ self.assertEqual(self.widget._slider.value(), 0)
+
+ self.widget._urlsTable.setUrl(self.urls[4])
+ numpy.testing.assert_array_equal(
+ self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
+ self._raw_data[4])
+ self.assertEqual(self.widget._slider.value(), 4)
+ self.assertEqual(self.widget.getCurrentUrl(), self.urls[4])
+ self.assertEqual(self.widget.getCurrentUrlIndex(), 4)
+
+ self.widget._slider.setUrlIndex(6)
+ numpy.testing.assert_array_equal(
+ self.widget.getPlotWidget().getActiveImage(just_legend=False).getData(),
+ self._raw_data[6])
+ self.assertEqual(self.widget._urlsTable.currentItem().text(),
+ self.urls[6].path())
+
+ def testCurrentUrlSignals(self):
+ """Test emission of 'currentUrlChangedListener'"""
+ # check initialization
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 0)
+ self.widget.setUrls(list(self.urls.values()))
+ self.qapp.processEvents()
+ time.sleep(0.5)
+ self.qapp.processEvents()
+ # once loaded the two signals should have been sended
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 1)
+ # if the slider is stuck to the same position no signal should be
+ # emitted
+ self.qapp.processEvents()
+ time.sleep(0.5)
+ self.qapp.processEvents()
+ self.assertEqual(self.widget._slider.value(), 0)
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 1)
+ # if slider position is changed, one of each signal should have been
+ # emitted
+ self.widget._urlsTable.setUrl(self.urls[4])
+ self.qapp.processEvents()
+ time.sleep(1.5)
+ self.qapp.processEvents()
+ self.assertEqual(self.currentUrlChangedListener.callCount(), 2)
+
+ def testUtils(self):
+ """Test that some utils functions are working"""
+ self.widget.show()
+ self.widget.setUrls(list(self.urls.values()))
+ self.assertEqual(len(self.widget.getUrls()), len(self.urls))
+
+ # wait for image to be loaded
+ self._waitUntilUrlLoaded()
+
+ urls_values = list(self.urls.values())
+ self.assertEqual(urls_values[0], self.urls[0])
+ self.assertEqual(urls_values[7], self.urls[7])
+
+ self.assertEqual(self.widget._getNextUrl(urls_values[2]).path(),
+ urls_values[3].path())
+ self.assertEqual(self.widget._getPreviousUrl(urls_values[0]), None)
+ self.assertEqual(self.widget._getPreviousUrl(urls_values[6]).path(),
+ urls_values[5].path())
+
+ self.assertEqual(self.widget._getNNextUrls(2, urls_values[0]),
+ urls_values[1:3])
+ self.assertEqual(self.widget._getNNextUrls(5, urls_values[7]),
+ urls_values[8:])
+ self.assertEqual(self.widget._getNPreviousUrls(3, urls_values[2]),
+ urls_values[:2])
+ self.assertEqual(self.widget._getNPreviousUrls(5, urls_values[8]),
+ urls_values[3:8])
+
+ def _waitUntilUrlLoaded(self, timeout=2.0):
+ """Wait until all image urls are loaded"""
+ loop_duration = 0.2
+ remaining_duration = timeout
+ while(len(self.widget._loadingThreads) > 0 and remaining_duration > 0):
+ remaining_duration -= loop_duration
+ time.sleep(loop_duration)
+ self.qapp.processEvents()
+
+ if remaining_duration <= 0.0:
+ remaining_urls = []
+ for thread_ in self.widget._loadingThreads:
+ remaining_urls.append(thread_.url.path())
+ mess = 'All images are not loaded after the time out. ' \
+ 'Remaining urls are: ' + str(remaining_urls)
+ raise TimeoutError(mess)
+ return True
diff --git a/src/silx/gui/plot/test/testImageView.py b/src/silx/gui/plot/test/testImageView.py
new file mode 100644
index 0000000..7c1355f
--- /dev/null
+++ b/src/silx/gui/plot/test/testImageView.py
@@ -0,0 +1,194 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWindow"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import items
+
+from silx.gui.plot.ImageView import ImageView
+from silx.gui.colors import Colormap
+
+
+class TestImageView(TestCaseQt):
+ """Tests of ImageView widget."""
+
+ def setUp(self):
+ super(TestImageView, self).setUp()
+ self.plot = ImageView()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ self.qapp.processEvents()
+ super(TestImageView, self).tearDown()
+
+ def testSetImage(self):
+ """Test setImage"""
+ image = numpy.arange(100).reshape(10, 10)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ # With reset=False
+ self.plot.setImage(image[::2, ::2], reset=False)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ self.plot.setImage(image, origin=(10, 20), scale=(2, 4), reset=False)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ # With reset=True
+ self.plot.setImage(image, origin=(1, 2), scale=(1, 0.5), reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (1, 11))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (2, 7))
+
+ self.plot.setImage(image[::2, ::2], reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 5))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 5))
+
+ def testColormap(self):
+ """Test get|setColormap"""
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image)
+
+ # Colormap as dict
+ self.plot.setColormap({'name': 'viridis',
+ 'normalization': 'log',
+ 'autoscale': False,
+ 'vmin': 0,
+ 'vmax': 1})
+ colormap = self.plot.getColormap()
+ self.assertEqual(colormap.getName(), 'viridis')
+ self.assertEqual(colormap.getNormalization(), 'log')
+ self.assertEqual(colormap.getVMin(), 0)
+ self.assertEqual(colormap.getVMax(), 1)
+
+ # Colormap as keyword arguments
+ self.plot.setColormap(colormap='magma',
+ normalization='linear',
+ autoscale=True,
+ vmin=1,
+ vmax=2)
+ self.assertEqual(colormap.getName(), 'magma')
+ self.assertEqual(colormap.getNormalization(), 'linear')
+ self.assertEqual(colormap.getVMin(), None)
+ self.assertEqual(colormap.getVMax(), None)
+
+ # Update colormap with keyword argument
+ self.plot.setColormap(normalization='log')
+ self.assertEqual(colormap.getNormalization(), 'log')
+
+ # Colormap as Colormap object
+ cmap = Colormap()
+ self.plot.setColormap(cmap)
+ self.assertIs(self.plot.getColormap(), cmap)
+
+ def testSetProfileWindowBehavior(self):
+ """Test change of profile window display behavior"""
+ self.assertIs(
+ self.plot.getProfileWindowBehavior(),
+ ImageView.ProfileWindowBehavior.POPUP,
+ )
+
+ self.plot.setProfileWindowBehavior('embedded')
+ self.assertIs(
+ self.plot.getProfileWindowBehavior(),
+ ImageView.ProfileWindowBehavior.EMBEDDED,
+ )
+
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image)
+
+ self.plot.setProfileWindowBehavior(
+ ImageView.ProfileWindowBehavior.POPUP
+ )
+ self.assertIs(
+ self.plot.getProfileWindowBehavior(),
+ ImageView.ProfileWindowBehavior.POPUP,
+ )
+
+ def testRGBImage(self):
+ """Test setImage"""
+ image = numpy.arange(100 * 3, dtype=numpy.uint8).reshape(10, 10, 3)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ def testRGBAImage(self):
+ """Test setImage"""
+ image = numpy.arange(100 * 4, dtype=numpy.uint8).reshape(10, 10, 4)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 10))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 10))
+
+ def testImageAggregationMode(self):
+ """Test setImage"""
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX)
+ self.qWait(100)
+
+ def testImageAggregationModeBackToNormalMode(self):
+ """Test setImage"""
+ image = numpy.arange(100).reshape(10, 10)
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.NONE)
+ self.qWait(100)
+
+ def testRGBAInAggregationMode(self):
+ """Test setImage"""
+ image = numpy.arange(100 * 3, dtype=numpy.uint8).reshape(10, 10, 3)
+
+ self.plot.setImage(image, reset=True)
+ self.qWait(100)
+ self.plot.getAggregationModeAction().setAggregationMode(items.ImageDataAggregated.Aggregation.MAX)
+ self.qWait(100)
diff --git a/src/silx/gui/plot/test/testInteraction.py b/src/silx/gui/plot/test/testInteraction.py
new file mode 100644
index 0000000..d136b21
--- /dev/null
+++ b/src/silx/gui/plot/test/testInteraction.py
@@ -0,0 +1,78 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests from interaction state machines"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "18/02/2016"
+
+
+import unittest
+
+from silx.gui.plot import Interaction
+
+
+class TestInteraction(unittest.TestCase):
+ def testClickOrDrag(self):
+ """Minimalistic test for click or drag state machine."""
+ events = []
+
+ class TestClickOrDrag(Interaction.ClickOrDrag):
+ def click(self, x, y, btn):
+ events.append(('click', x, y, btn))
+
+ def beginDrag(self, x, y, btn):
+ events.append(('beginDrag', x, y, btn))
+
+ def drag(self, x, y, btn):
+ events.append(('drag', x, y, btn))
+
+ def endDrag(self, start, end, btn):
+ events.append(('endDrag', start, end, btn))
+
+ clickOrDrag = TestClickOrDrag()
+
+ # click
+ clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 0)
+
+ clickOrDrag.handleEvent('release', 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 1)
+ self.assertEqual(events[0], ('click', 10, 10, Interaction.LEFT_BTN))
+
+ # drag
+ events = []
+ clickOrDrag.handleEvent('press', 10, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 0)
+ clickOrDrag.handleEvent('move', 15, 10)
+ self.assertEqual(len(events), 2) # Received beginDrag and drag
+ self.assertEqual(events[0], ('beginDrag', 10, 10, Interaction.LEFT_BTN))
+ self.assertEqual(events[1], ('drag', 15, 10, Interaction.LEFT_BTN))
+ clickOrDrag.handleEvent('move', 20, 10)
+ self.assertEqual(len(events), 3)
+ self.assertEqual(events[-1], ('drag', 20, 10, Interaction.LEFT_BTN))
+ clickOrDrag.handleEvent('release', 20, 10, Interaction.LEFT_BTN)
+ self.assertEqual(len(events), 4)
+ self.assertEqual(events[-1], ('endDrag', (10, 10), (20, 10), Interaction.LEFT_BTN))
diff --git a/src/silx/gui/plot/test/testItem.py b/src/silx/gui/plot/test/testItem.py
new file mode 100644
index 0000000..0b15dc3
--- /dev/null
+++ b/src/silx/gui/plot/test/testItem.py
@@ -0,0 +1,360 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for PlotWidget items."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/09/2017"
+
+
+import unittest
+
+import numpy
+
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.plot.items import ItemChangedType
+from silx.gui.plot import items
+from .utils import PlotWidgetTestCase
+
+
+class TestSigItemChangedSignal(PlotWidgetTestCase):
+ """Test item's sigItemChanged signal"""
+
+ def testCurveChanged(self):
+ """Test sigItemChanged for curve"""
+ self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test')
+ curve = self.plot.getCurve('test')
+
+ listener = SignalListener()
+ curve.sigItemChanged.connect(listener)
+
+ # Test for signal in Item class
+ curve.setVisible(False)
+ curve.setVisible(True)
+ curve.setZValue(100)
+
+ # Test for signals in PointsBase class
+ curve.setData(numpy.arange(100), numpy.arange(100))
+
+ # SymbolMixIn
+ curve.setSymbol('Circle')
+ curve.setSymbol('d')
+ curve.setSymbolSize(20)
+
+ # AlphaMixIn
+ curve.setAlpha(0.5)
+
+ # Test for signals in Curve class
+ # ColorMixIn
+ curve.setColor('yellow')
+ # YAxisMixIn
+ curve.setYAxis('right')
+ # FillMixIn
+ curve.setFill(True)
+ # LineMixIn
+ curve.setLineStyle(':')
+ curve.setLineStyle(':') # Not sending event
+ curve.setLineWidth(2)
+
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.VISIBLE,
+ ItemChangedType.VISIBLE,
+ ItemChangedType.ZVALUE,
+ ItemChangedType.DATA,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL,
+ ItemChangedType.SYMBOL_SIZE,
+ ItemChangedType.ALPHA,
+ ItemChangedType.COLOR,
+ ItemChangedType.YAXIS,
+ ItemChangedType.FILL,
+ ItemChangedType.LINE_STYLE,
+ ItemChangedType.LINE_WIDTH])
+
+ def testHistogramChanged(self):
+ """Test sigItemChanged for Histogram"""
+ self.plot.addHistogram(
+ numpy.arange(10), edges=numpy.arange(11), legend='test')
+ histogram = self.plot.getHistogram('test')
+ listener = SignalListener()
+ histogram.sigItemChanged.connect(listener)
+
+ # Test signals in Histogram class
+ histogram.setData(numpy.zeros(10), numpy.arange(11))
+
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.DATA])
+
+ def testImageDataChanged(self):
+ """Test sigItemChanged for ImageData"""
+ self.plot.addImage(numpy.arange(100).reshape(10, 10), legend='test')
+ image = self.plot.getImage('test')
+
+ listener = SignalListener()
+ image.sigItemChanged.connect(listener)
+
+ # ColormapMixIn
+ colormap = self.plot.getDefaultColormap().copy()
+ image.setColormap(colormap)
+ image.getColormap().setName('viridis')
+
+ # Test of signals in ImageBase class
+ image.setOrigin(10)
+ image.setScale(2)
+
+ # Test of signals in ImageData class
+ image.setData(numpy.ones((10, 10)))
+
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.COLORMAP,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.POSITION,
+ ItemChangedType.SCALE,
+ ItemChangedType.COLORMAP,
+ ItemChangedType.DATA])
+
+ def testImageRgbaChanged(self):
+ """Test sigItemChanged for ImageRgba"""
+ self.plot.addImage(numpy.ones((10, 10, 3)), legend='rgb')
+ image = self.plot.getImage('rgb')
+
+ listener = SignalListener()
+ image.sigItemChanged.connect(listener)
+
+ # Test of signals in ImageRgba class
+ image.setData(numpy.zeros((10, 10, 3)))
+
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.DATA])
+
+ def testMarkerChanged(self):
+ """Test sigItemChanged for markers"""
+ self.plot.addMarker(10, 20, legend='test')
+ marker = self.plot._getMarker('test')
+
+ listener = SignalListener()
+ marker.sigItemChanged.connect(listener)
+
+ # Test signals in _BaseMarker
+ marker.setPosition(10, 10)
+ marker.setPosition(10, 10) # Not sending event
+ marker.setText('toto')
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.POSITION,
+ ItemChangedType.TEXT])
+
+ # XMarker
+ self.plot.addXMarker(10, legend='x')
+ marker = self.plot._getMarker('x')
+
+ listener = SignalListener()
+ marker.sigItemChanged.connect(listener)
+ marker.setPosition(20, 20)
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.POSITION])
+
+ # YMarker
+ self.plot.addYMarker(10, legend='x')
+ marker = self.plot._getMarker('x')
+
+ listener = SignalListener()
+ marker.sigItemChanged.connect(listener)
+ marker.setPosition(20, 20)
+ self.assertEqual(listener.arguments(argumentIndex=0),
+ [ItemChangedType.POSITION])
+
+ def testScatterChanged(self):
+ """Test sigItemChanged for scatter"""
+ data = numpy.arange(10)
+ self.plot.addScatter(data, data, data, legend='test')
+ scatter = self.plot.getScatter('test')
+
+ listener = SignalListener()
+ scatter.sigItemChanged.connect(listener)
+
+ # ColormapMixIn
+ scatter.getColormap().setName('viridis')
+
+ # Test of signals in Scatter class
+ scatter.setData((0, 1, 2), (1, 0, 2), (0, 1, 2))
+
+ # Visualization mode changed
+ scatter.setVisualization(scatter.Visualization.SOLID)
+
+ self.assertEqual(listener.arguments(),
+ [(ItemChangedType.COLORMAP,),
+ (ItemChangedType.DATA,),
+ (ItemChangedType.COLORMAP,),
+ (ItemChangedType.VISUALIZATION_MODE,)])
+
+ def testShapeChanged(self):
+ """Test sigItemChanged for shape"""
+ data = numpy.array((1., 10.))
+ self.plot.addShape(data, data, legend='test', shape='rectangle')
+ shape = self.plot._getItem(kind='item', legend='test')
+
+ listener = SignalListener()
+ shape.sigItemChanged.connect(listener)
+
+ shape.setOverlay(True)
+ shape.setPoints(((2., 2.), (3., 3.)))
+
+ self.assertEqual(listener.arguments(),
+ [(ItemChangedType.OVERLAY,),
+ (ItemChangedType.DATA,)])
+
+
+class TestSymbol(PlotWidgetTestCase):
+ """Test item's symbol """
+
+ def test(self):
+ """Test sigItemChanged for curve"""
+ self.plot.addCurve(numpy.arange(10), numpy.arange(10), legend='test')
+ curve = self.plot.getCurve('test')
+
+ # SymbolMixIn
+ curve.setSymbol('o')
+ name = curve.getSymbolName()
+ self.assertEqual('Circle', name)
+
+ name = curve.getSymbolName('d')
+ self.assertEqual('Diamond', name)
+
+
+class TestVisibleExtent(PlotWidgetTestCase):
+ """Test item's visible extent feature"""
+
+ def testGetVisibleBounds(self):
+ """Test Item.getVisibleBounds"""
+
+ # Create test items (with a bounding box of x: [1,3], y: [0,2])
+ curve = items.Curve()
+ curve.setData((1, 2, 3), (0, 1, 2))
+
+ histogram = items.Histogram()
+ histogram.setData((0, 1, 2), (1, 5/3, 7/3, 3))
+
+ image = items.ImageData()
+ image.setOrigin((1, 0))
+ image.setData(numpy.arange(4).reshape(2, 2))
+
+ scatter = items.Scatter()
+ scatter.setData((1, 2, 3), (0, 1, 2), (1, 2, 3))
+
+ bbox = items.BoundingRect()
+ bbox.setBounds((1, 3, 0, 2))
+
+ xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis()
+ for item in (curve, histogram, image, scatter, bbox):
+ with self.subTest(item=item):
+ xaxis.setLimits(0, 100)
+ yaxis.setLimits(0, 100)
+ self.plot.addItem(item)
+ self.assertEqual(item.getVisibleBounds(), (1., 3., 0., 2.))
+
+ xaxis.setLimits(0.5, 2.5)
+ self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0., 2.))
+
+ yaxis.setLimits(0.5, 1.5)
+ self.assertEqual(item.getVisibleBounds(), (1, 2.5, 0.5, 1.5))
+
+ item.setVisible(False)
+ self.assertIsNone(item.getVisibleBounds())
+
+ self.plot.clear()
+
+ def testVisibleExtentTracking(self):
+ """Test Item's visible extent tracking"""
+ image = items.ImageData()
+ image.setData(numpy.arange(6).reshape(2, 3))
+
+ listener = SignalListener()
+ image._sigVisibleBoundsChanged.connect(listener)
+ image._setVisibleBoundsTracking(True)
+ self.assertTrue(image._isVisibleBoundsTracking())
+
+ self.plot.addItem(image)
+ self.assertEqual(listener.callCount(), 1)
+
+ self.plot.getXAxis().setLimits(0, 1)
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.hide()
+ self.qapp.processEvents()
+ # No event here
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.getXAxis().setLimits(1, 2)
+ # No event since PlotWidget is hidden, delayed to PlotWidget show
+ self.assertEqual(listener.callCount(), 2)
+
+ self.plot.show()
+ self.qapp.processEvents()
+ # Receives delayed event now
+ self.assertEqual(listener.callCount(), 3)
+
+ image.setOrigin((-1, -1))
+ self.assertEqual(listener.callCount(), 4)
+
+ image.setVisible(False)
+ image.setOrigin((0, 0))
+ # No event since item is not visible
+ self.assertEqual(listener.callCount(), 4)
+
+ image.setVisible(True)
+ # Receives delayed event now
+ self.assertEqual(listener.callCount(), 5)
+
+
+class TestImageDataAggregated(PlotWidgetTestCase):
+ """Test ImageDataAggregated item"""
+
+ def test(self):
+ data = numpy.random.random(1024**2).reshape(1024, 1024)
+
+ item = items.ImageDataAggregated()
+ item.setData(data)
+ self.assertEqual(item.getAggregationMode(), item.Aggregation.NONE)
+ self.plot.addItem(item)
+
+ for mode in item.Aggregation.members():
+ with self.subTest(mode=mode):
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ item.setAggregationMode(mode)
+ self.qapp.processEvents()
+
+ # Zoom-out
+ for i in range(4):
+ xmin, xmax = self.plot.getXAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
+ self.plot.setLimits(
+ xmin - (xmax - xmin)/2,
+ xmax + (xmax - xmin)/2,
+ ymin - (ymax - ymin)/2,
+ ymax + (ymax - ymin)/2,
+ )
+ self.qapp.processEvents()
diff --git a/src/silx/gui/plot/test/testLegendSelector.py b/src/silx/gui/plot/test/testLegendSelector.py
new file mode 100644
index 0000000..c40875d
--- /dev/null
+++ b/src/silx/gui/plot/test/testLegendSelector.py
@@ -0,0 +1,130 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Rueter", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/05/2017"
+
+
+import logging
+import unittest
+
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import LegendSelector
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestLegendSelector(TestCaseQt):
+ """Basic test for LegendSelector"""
+
+ def testLegendSelector(self):
+ """Test copied from __main__ of LegendSelector in PyMca"""
+ class Notifier(qt.QObject):
+ def __init__(self):
+ qt.QObject.__init__(self)
+ self.chk = True
+
+ def signalReceived(self, **kw):
+ obj = self.sender()
+ _logger.info('NOTIFIER -- signal received\n\tsender: %s',
+ str(obj))
+
+ notifier = Notifier()
+
+ legends = ['Legend0',
+ 'Legend1',
+ 'Long Legend 2',
+ 'Foo Legend 3',
+ 'Even Longer Legend 4',
+ 'Short Leg 5',
+ 'Dot symbol 6',
+ 'Comma symbol 7']
+ colors = [qt.Qt.darkRed, qt.Qt.green, qt.Qt.yellow, qt.Qt.darkCyan,
+ qt.Qt.blue, qt.Qt.darkBlue, qt.Qt.red, qt.Qt.darkYellow]
+ symbols = ['o', 't', '+', 'x', 's', 'd', '.', ',']
+
+ win = LegendSelector.LegendListView()
+ # win = LegendListContextMenu()
+ # win = qt.QWidget()
+ # layout = qt.QVBoxLayout()
+ # layout.setContentsMargins(0,0,0,0)
+ llist = []
+
+ for _idx, (l, c, s) in enumerate(zip(legends, colors, symbols)):
+ ddict = {
+ 'color': qt.QColor(c),
+ 'linewidth': 4,
+ 'symbol': s,
+ }
+ legend = l
+ llist.append((legend, ddict))
+ # item = qt.QListWidgetItem(win)
+ # legendWidget = LegendListItemWidget(l)
+ # legendWidget.icon.setSymbol(s)
+ # legendWidget.icon.setColor(qt.QColor(c))
+ # layout.addWidget(legendWidget)
+ # win.setItemWidget(item, legendWidget)
+
+ # win = LegendListItemWidget('Some Legend 1')
+ # print(llist)
+ model = LegendSelector.LegendModel(legendList=llist)
+ win.setModel(model)
+ win.setSelectionModel(qt.QItemSelectionModel(model))
+ win.setContextMenu()
+ # print('Edit triggers: %d'%win.editTriggers())
+
+ # win = LegendListWidget(None, legends)
+ # win[0].updateItem(ddict)
+ # win.setLayout(layout)
+ win.sigLegendSignal.connect(notifier.signalReceived)
+ win.show()
+
+ win.clear()
+ win.setLegendList(llist)
+
+ self.qWaitForWindowExposed(win)
+
+
+class TestRenameCurveDialog(TestCaseQt):
+ """Basic test for RenameCurveDialog"""
+
+ def testDialog(self):
+ """Create dialog, change name and press OK"""
+ self.dialog = LegendSelector.RenameCurveDialog(
+ None, 'curve1', ['curve1', 'curve2', 'curve3'])
+ self.dialog.open()
+ self.qWaitForWindowExposed(self.dialog)
+ self.keyClicks(self.dialog.lineEdit, 'changed')
+ self.mouseClick(self.dialog.okButton, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ ret = self.dialog.result()
+ self.assertEqual(ret, qt.QDialog.Accepted)
+ newName = self.dialog.getText()
+ self.assertEqual(newName, 'curve1changed')
+ del self.dialog
diff --git a/src/silx/gui/plot/test/testLimitConstraints.py b/src/silx/gui/plot/test/testLimitConstraints.py
new file mode 100644
index 0000000..0bd8e50
--- /dev/null
+++ b/src/silx/gui/plot/test/testLimitConstraints.py
@@ -0,0 +1,114 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test setLimitConstaints on the PlotWidget"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "30/08/2017"
+
+
+import unittest
+from silx.gui.plot import PlotWidget
+
+
+class TestLimitConstaints(unittest.TestCase):
+ """Tests setLimitConstaints class"""
+
+ def setUp(self):
+ self.plot = PlotWidget()
+
+ def tearDown(self):
+ self.plot = None
+
+ def testApi(self):
+ """Test availability of the API"""
+ self.plot.getXAxis().setLimitsConstraints(minPos=1, maxPos=10)
+ self.plot.getXAxis().setRangeConstraints(minRange=1, maxRange=1)
+ self.plot.getYAxis().setLimitsConstraints(minPos=1, maxPos=10)
+ self.plot.getYAxis().setRangeConstraints(minRange=1, maxRange=1)
+
+ def testXMinMax(self):
+ """Test limit constains on x-axis"""
+ self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (-1, 101))
+
+ def testYMinMax(self):
+ """Test limit constains on y-axis"""
+ self.plot.getYAxis().setLimitsConstraints(minPos=0, maxPos=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101))
+ self.assertEqual(self.plot.getYAxis().getLimits(), (0, 100))
+
+ def testMinXRange(self):
+ """Test min range constains on x-axis"""
+ self.plot.getXAxis().setRangeConstraints(minRange=100)
+ self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+
+ def testMaxXRange(self):
+ """Test max range constains on x-axis"""
+ self.plot.getXAxis().setRangeConstraints(maxRange=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+
+ def testMinYRange(self):
+ """Test min range constains on y-axis"""
+ self.plot.getYAxis().setRangeConstraints(minRange=100)
+ self.plot.setLimits(xmin=1, xmax=99, ymin=1, ymax=99)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+
+ def testMaxYRange(self):
+ """Test max range constains on y-axis"""
+ self.plot.getYAxis().setRangeConstraints(maxRange=100)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ limits = self.plot.getXAxis().getLimits()
+ self.assertNotEqual(limits[1] - limits[0], 100)
+ limits = self.plot.getYAxis().getLimits()
+ self.assertEqual(limits[1] - limits[0], 100)
+
+ def testChangeOfConstraints(self):
+ """Test changing of the constraints"""
+ self.plot.getXAxis().setRangeConstraints(minRange=10, maxRange=10)
+ # There is no more constraints on the range
+ self.plot.getXAxis().setRangeConstraints(minRange=None, maxRange=None)
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (-1, 101))
+
+ def testSettingConstraints(self):
+ """Test setting a constaint (setLimits first then the constaint)"""
+ self.plot.setLimits(xmin=-1, xmax=101, ymin=-1, ymax=101)
+ self.plot.getXAxis().setLimitsConstraints(minPos=0, maxPos=100)
+ self.assertEqual(self.plot.getXAxis().getLimits(), (0, 100))
diff --git a/src/silx/gui/plot/test/testMaskToolsWidget.py b/src/silx/gui/plot/test/testMaskToolsWidget.py
new file mode 100644
index 0000000..522ca51
--- /dev/null
+++ b/src/silx/gui/plot/test/testMaskToolsWidget.py
@@ -0,0 +1,306 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for MaskToolsWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import os.path
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.test.utils import temp_dir
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import getQToolButtonFromAction
+from silx.gui.plot import PlotWindow, MaskToolsWidget
+from .utils import PlotWidgetTestCase
+
+import fabio
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
+ """Basic test for MaskToolsWidget"""
+
+ def _createPlot(self):
+ return PlotWindow()
+
+ def setUp(self):
+ super(TestMaskToolsWidget, self).setUp()
+ self.widget = MaskToolsWidget.MaskToolsDockWidget(plot=self.plot, name='TEST')
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
+ self.maskWidget = self.widget.widget()
+
+ def tearDown(self):
+ del self.maskWidget
+ del self.widget
+ super(TestMaskToolsWidget, self).tearDown()
+
+ def testEmptyPlot(self):
+ """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
+ self.maskWidget.setMultipleMasks('single')
+ self.qapp.processEvents()
+
+ self.maskWidget.setMultipleMasks('exclusive')
+ self.qapp.processEvents()
+
+ def _drag(self):
+ """Drag from plot center to offset position"""
+ plot = self.plot.getWidgetHandle()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ pos0 = xCenter, yCenter
+ pos1 = xCenter + offset, yCenter + offset
+
+ self.mouseMove(plot, pos=(0, 0))
+ self.mouseMove(plot, pos=pos0)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
+ self.mouseMove(plot, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(0, 0))
+
+ def _drawPolygon(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ (x, y + offset)] # Close polygon
+
+ self.mouseMove(plot, pos=(0, 0))
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+
+ def _drawPencil(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset)]
+
+ self.mouseMove(plot, pos=(0, 0))
+ for start, end in zip(star[:-1], star[1:]):
+ self.mouseMove(plot, pos=start)
+ self.mousePress(plot, qt.Qt.LeftButton, pos=start)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=end)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=end)
+ self.qapp.processEvents()
+
+ def _isMaskItemSync(self):
+ """Check if masks from item and tools are sync or not"""
+ if self.maskWidget.isItemMaskUpdated():
+ return numpy.all(numpy.equal(
+ self.maskWidget.getSelectionMask(),
+ self.plot.getActiveImage().getMaskData(copy=False)))
+ else:
+ return True
+
+ def testWithAnImage(self):
+ """Plot with an image: test MaskToolsWidget interactions"""
+
+ # Add and remove a image (this should enable/disable GUI + change mask)
+ self.plot.addImage(numpy.random.random(1024**2).reshape(1024, 1024),
+ legend='test')
+ self.qapp.processEvents()
+
+ self.plot.remove('test', kind='image')
+ self.qapp.processEvents()
+
+ tests = [((0, 0), (1, 1)),
+ ((1000, 1000), (1, 1)),
+ ((0, 0), (-1, -1)),
+ ((1000, 1000), (-1, -1))]
+
+ for itemMaskUpdated in (False, True):
+ for origin, scale in tests:
+ with self.subTest(origin=origin, scale=scale):
+ self.maskWidget.setItemMaskUpdated(itemMaskUpdated)
+ self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
+ legend='test',
+ origin=origin,
+ scale=scale)
+ self.qapp.processEvents()
+
+ self.assertEqual(
+ self.maskWidget.isItemMaskUpdated(), itemMaskUpdated)
+
+ # Test draw rectangle #
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test draw polygon #
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test draw pencil #
+ toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.maskWidget.pencilSpinBox.setValue(30)
+ self.qapp.processEvents()
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+ self.assertTrue(self._isMaskItemSync())
+
+ # Test no draw tool #
+ toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.plot.clear()
+
+ def __loadSave(self, file_format):
+ """Plot with an image: test MaskToolsWidget operations"""
+ self.plot.addImage(numpy.arange(1024**2).reshape(1024, 1024),
+ legend='test')
+ self.qapp.processEvents()
+
+ # Draw a polygon mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self._drawPolygon()
+
+ ref_mask = self.maskWidget.getSelectionMask()
+ self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
+
+ with temp_dir() as tmp:
+ mask_filename = os.path.join(tmp, 'mask.' + file_format)
+ self.maskWidget.save(mask_filename, file_format)
+
+ self.maskWidget.resetSelectionMask()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ self.maskWidget.load(mask_filename)
+ self.assertTrue(numpy.all(numpy.equal(
+ self.maskWidget.getSelectionMask(), ref_mask)))
+
+ def testLoadSaveNpy(self):
+ self.__loadSave("npy")
+
+ def testLoadSaveFit2D(self):
+ self.__loadSave("msk")
+
+ def testSigMaskChangedEmitted(self):
+ self.plot.addImage(numpy.arange(512**2).reshape(512, 512),
+ legend='test')
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ l = []
+
+ def slot():
+ l.append(1)
+
+ self.maskWidget.sigMaskChanged.connect(slot)
+
+ # rectangle mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertGreater(len(l), 0)
diff --git a/src/silx/gui/plot/test/testPixelIntensityHistoAction.py b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py
new file mode 100644
index 0000000..14a467d
--- /dev/null
+++ b/src/silx/gui/plot/test/testPixelIntensityHistoAction.py
@@ -0,0 +1,145 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PixelIntensitiesHistoAction"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/03/2018"
+
+
+import numpy
+import unittest
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
+from silx.gui import qt
+from silx.gui.plot import Plot2D
+
+
+class TestPixelIntensitiesHisto(TestCaseQt, ParametricTestCase):
+ """Tests for PixelIntensitiesHistoAction widget."""
+
+ def setUp(self):
+ super(TestPixelIntensitiesHisto, self).setUp()
+ self.image = numpy.random.rand(10, 10)
+ self.plotImage = Plot2D()
+ self.plotImage.getIntensityHistogramAction().setVisible(True)
+
+ def tearDown(self):
+ del self.plotImage
+ super(TestPixelIntensitiesHisto, self).tearDown()
+
+ def testShowAndHide(self):
+ """Simple test that the plot is showing and hiding when activating the
+ action"""
+ self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ self.assertTrue(histoAction.getHistogramWidget().isVisible())
+
+ # test the pixel intensity diagram is hiding
+ self.qapp.setActiveWindow(self.plotImage)
+ self.qapp.processEvents()
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ self.assertFalse(histoAction.getHistogramWidget().isVisible())
+
+ def testImageFormatInput(self):
+ """Test multiple type as image input"""
+ typesToTest = [numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
+ numpy.float32, numpy.float64]
+ self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.plotImage.show()
+ button = getQToolButtonFromAction(
+ self.plotImage.getIntensityHistogramAction())
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+ for typeToTest in typesToTest:
+ with self.subTest(typeToTest=typeToTest):
+ self.plotImage.addImage(self.image.astype(typeToTest),
+ origin=(0, 0), legend='sino')
+
+ def testScatter(self):
+ """Test that an histogram from a scatter is displayed"""
+ xx = numpy.arange(10)
+ yy = numpy.arange(10)
+ value = numpy.sin(xx)
+ self.plotImage.addScatter(xx, yy, value)
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+
+ widget = histoAction.getHistogramWidget()
+ self.assertTrue(widget.isVisible())
+ items = widget.getPlotWidget().getItems()
+ self.assertEqual(len(items), 1)
+
+ def testChangeItem(self):
+ """Test that histogram changes it the item changes"""
+ xx = numpy.arange(10)
+ yy = numpy.arange(10)
+ value = numpy.sin(xx)
+ self.plotImage.addScatter(xx, yy, value)
+ self.plotImage.show()
+
+ histoAction = self.plotImage.getIntensityHistogramAction()
+
+ # test the pixel intensity diagram is showing
+ button = getQToolButtonFromAction(histoAction)
+ self.assertIsNot(button, None)
+ self.mouseMove(button)
+ self.mouseClick(button, qt.Qt.LeftButton)
+ self.qapp.processEvents()
+
+ # Reach histogram from the first item
+ widget = histoAction.getHistogramWidget()
+ self.assertTrue(widget.isVisible())
+ items = widget.getPlotWidget().getItems()
+ data1 = items[0].getValueData(copy=False)
+
+ # Set another item to the plot
+ self.plotImage.addImage(self.image, origin=(0, 0), legend='sino')
+ self.qapp.processEvents()
+ data2 = items[0].getValueData(copy=False)
+
+ # Histogram is not the same
+ self.assertFalse(numpy.array_equal(data1, data2))
diff --git a/src/silx/gui/plot/test/testPlotActions.py b/src/silx/gui/plot/test/testPlotActions.py
new file mode 100644
index 0000000..f38e05b
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotActions.py
@@ -0,0 +1,110 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of actions integrated in the plot window"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "09/11/2018"
+
+
+import pytest
+import weakref
+
+from silx.gui import qt
+from silx.gui.colors import Colormap
+from silx.gui.plot.PlotWindow import PlotWindow
+
+import numpy
+
+
+@pytest.fixture
+def colormap1():
+ colormap = Colormap(name='gray',
+ vmin=10.0, vmax=20.0,
+ normalization='linear')
+ yield colormap
+
+
+@pytest.fixture
+def colormap2():
+ colormap = Colormap(name='red',
+ vmin=10.0, vmax=20.0,
+ normalization='linear')
+ yield colormap
+
+
+@pytest.fixture
+def plot(qapp):
+ plot = PlotWindow()
+ plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ yield weakref.proxy(plot)
+ plot.close()
+ qapp.processEvents()
+
+
+def test_action_active_colormap(qapp_utils, plot, colormap1, colormap2):
+ plot.getColormapAction()._actionTriggered(checked=True)
+ colormapDialog = plot.getColormapAction()._dialog
+
+ defaultColormap = plot.getDefaultColormap()
+ assert colormapDialog.getColormap() is defaultColormap
+
+ plot.addImage(data=numpy.random.rand(10, 10), legend='img1',
+ origin=(0, 0),
+ colormap=colormap1)
+ plot.setActiveImage('img1')
+ assert colormapDialog.getColormap() is colormap1
+
+ plot.addImage(data=numpy.random.rand(10, 10), legend='img2',
+ origin=(0, 0), colormap=colormap2)
+ plot.addImage(data=numpy.random.rand(10, 10), legend='img3',
+ origin=(0, 0))
+
+ plot.setActiveImage('img3')
+ assert colormapDialog.getColormap() is defaultColormap
+ plot.getActiveImage().setColormap(colormap2)
+ assert colormapDialog.getColormap() is colormap2
+
+ plot.remove('img2')
+ plot.remove('img3')
+ plot.remove('img1')
+ assert colormapDialog.getColormap() is defaultColormap
+
+
+def test_action_show_hide_colormap_dialog(qapp_utils, plot, colormap1):
+ plot.getColormapAction()._actionTriggered(checked=True)
+ colormapDialog = plot.getColormapAction()._dialog
+
+ plot.getColormapAction()._actionTriggered(checked=False)
+ assert not plot.getColormapAction().isChecked()
+ plot.getColormapAction()._actionTriggered(checked=True)
+ assert plot.getColormapAction().isChecked()
+ plot.addImage(data=numpy.random.rand(10, 10), legend='img1',
+ origin=(0, 0), colormap=colormap1)
+ colormap1.setName('red')
+ plot.getColormapAction()._actionTriggered()
+ colormap1.setName('blue')
+ colormapDialog.close()
+ assert not plot.getColormapAction().isChecked()
diff --git a/src/silx/gui/plot/test/testPlotInteraction.py b/src/silx/gui/plot/test/testPlotInteraction.py
new file mode 100644
index 0000000..fba364e
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotInteraction.py
@@ -0,0 +1,160 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016=2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests of plot interaction, through a PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/09/2017"
+
+
+import unittest
+from silx.gui import qt
+from .utils import PlotWidgetTestCase
+
+
+class _SignalDump(object):
+ """Callable object that store passed arguments in a list"""
+
+ def __init__(self):
+ self._received = []
+
+ def __call__(self, *args):
+ self._received.append(args)
+
+ @property
+ def received(self):
+ """Return a shallow copy of the list of received arguments"""
+ return list(self._received)
+
+
+class TestSelectPolygon(PlotWidgetTestCase):
+ """Test polygon selection interaction"""
+
+ def _interactionModeChanged(self, source):
+ """Check that source received in event is the correct one"""
+ self.assertEqual(source, self)
+
+ def _draw(self, polygon):
+ """Draw a polygon in the plot
+
+ :param polygon: List of points (x, y) of the polygon (closed)
+ """
+ plot = self.plot.getWidgetHandle()
+
+ dump = _SignalDump()
+ self.plot.sigPlotSignal.connect(dump)
+
+ for pos in polygon:
+ self.mouseMove(plot, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+
+ self.plot.sigPlotSignal.disconnect(dump)
+ return [args[0] for args in dump.received]
+
+ def test(self):
+ """Test draw polygons + events"""
+ self.plot.sigInteractiveModeChanged.connect(
+ self._interactionModeChanged)
+
+ self.plot.setInteractiveMode(
+ 'draw', shape='polygon', label='test', source=self)
+ interaction = self.plot.getInteractiveMode()
+
+ self.assertEqual(interaction['mode'], 'draw')
+ self.assertEqual(interaction['shape'], 'polygon')
+
+ self.plot.sigInteractiveModeChanged.disconnect(
+ self._interactionModeChanged)
+
+ plot = self.plot.getWidgetHandle()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ # Star polygon
+ star = [(xCenter, yCenter + offset),
+ (xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter),
+ (xCenter - offset, yCenter),
+ (xCenter + offset, yCenter - offset),
+ (xCenter, yCenter + offset)] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(star)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 6)
+
+ # Large square
+ largeSquare = [(xCenter - offset, yCenter - offset),
+ (xCenter + offset, yCenter - offset),
+ (xCenter + offset, yCenter + offset),
+ (xCenter - offset, yCenter + offset),
+ (xCenter - offset, yCenter - offset)] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(largeSquare)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 5)
+
+ # Rectangle too thin along X: Some points are ignored
+ thinRectX = [(xCenter, yCenter - offset),
+ (xCenter, yCenter + offset),
+ (xCenter + 1, yCenter + offset),
+ (xCenter + 1, yCenter - offset)] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(thinRectX)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 3)
+
+ # Rectangle too thin along Y: Some points are ignored
+ thinRectY = [(xCenter - offset, yCenter),
+ (xCenter + offset, yCenter),
+ (xCenter + offset, yCenter + 1),
+ (xCenter - offset, yCenter + 1)] # Close polygon
+
+ # Draw while dumping signals
+ events = self._draw(thinRectY)
+
+ # Test last event
+ drawEvents = [event for event in events
+ if event['event'].startswith('drawing')]
+ self.assertEqual(drawEvents[-1]['event'], 'drawingFinished')
+ self.assertEqual(len(drawEvents[-1]['points']), 3)
diff --git a/src/silx/gui/plot/test/testPlotWidget.py b/src/silx/gui/plot/test/testPlotWidget.py
new file mode 100755
index 0000000..f6e108d
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidget.py
@@ -0,0 +1,2113 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/01/2019"
+
+
+import unittest
+import logging
+import numpy
+import pytest
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+from silx.gui.plot.items.curve import CurveStyle
+from silx.gui.plot.items import BoundingRect, XAxisExtent, YAxisExtent, Axis
+from silx.gui.colors import Colormap
+
+from .utils import PlotWidgetTestCase
+
+
+SIZE = 1024
+"""Size of the test image"""
+
+DATA_2D = numpy.arange(SIZE ** 2).reshape(SIZE, SIZE)
+"""Image data set"""
+
+
+logger = logging.getLogger(__name__)
+
+
+class TestSpecialBackend(PlotWidgetTestCase, ParametricTestCase):
+
+ def __init__(self, methodName='runTest', backend=None):
+ TestCaseQt.__init__(self, methodName=methodName)
+ self.__backend = backend
+
+ def _createPlot(self):
+ return PlotWidget(backend=self.__backend)
+
+ def testPlot(self):
+ self.assertIsNotNone(self.plot)
+
+
+class TestPlotWidget(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for PlotWidget"""
+
+ def testShow(self):
+ """Most basic test"""
+ pass
+
+ def testSetTitleLabels(self):
+ """Set title and axes labels"""
+
+ title, xlabel, ylabel = 'the title', 'x label', 'y label'
+ self.plot.setGraphTitle(title)
+ self.plot.getXAxis().setLabel(xlabel)
+ self.plot.getYAxis().setLabel(ylabel)
+ self.qapp.processEvents()
+
+ self.assertEqual(self.plot.getGraphTitle(), title)
+ self.assertEqual(self.plot.getXAxis().getLabel(), xlabel)
+ self.assertEqual(self.plot.getYAxis().getLabel(), ylabel)
+
+ def _checkLimits(self,
+ expectedXLim=None,
+ expectedYLim=None,
+ expectedRatio=None):
+ """Assert that limits are as expected"""
+ xlim = self.plot.getXAxis().getLimits()
+ ylim = self.plot.getYAxis().getLimits()
+ ratio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
+
+ if expectedXLim is not None:
+ self.assertEqual(expectedXLim, xlim)
+
+ if expectedYLim is not None:
+ self.assertEqual(expectedYLim, ylim)
+
+ if expectedRatio is not None:
+ self.assertTrue(
+ numpy.allclose(expectedRatio, ratio, atol=0.01))
+
+ def testChangeLimitsWithAspectRatio(self):
+ self.plot.setKeepDataAspectRatio()
+ self.qapp.processEvents()
+ xlim = self.plot.getXAxis().getLimits()
+ ylim = self.plot.getYAxis().getLimits()
+ defaultRatio = abs(xlim[1] - xlim[0]) / abs(ylim[1] - ylim[0])
+
+ self.plot.getXAxis().setLimits(1., 10.)
+ self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
+ self.qapp.processEvents()
+ self._checkLimits(expectedXLim=(1., 10.), expectedRatio=defaultRatio)
+
+ self.plot.getYAxis().setLimits(1., 10.)
+ self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
+ self.qapp.processEvents()
+ self._checkLimits(expectedYLim=(1., 10.), expectedRatio=defaultRatio)
+
+ def testResizeWidget(self):
+ """Test resizing the widget and receiving limitsChanged events"""
+ self.plot.resize(200, 200)
+ self.qapp.processEvents()
+ self.qWait(100)
+
+ xlim = self.plot.getXAxis().getLimits()
+ ylim = self.plot.getYAxis().getLimits()
+
+ listener = SignalListener()
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial('x'))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial('y'))
+
+ # Resize without aspect ratio
+ self.plot.resize(200, 300)
+ self.qapp.processEvents()
+ self.qWait(100)
+ self._checkLimits(expectedXLim=xlim, expectedYLim=ylim)
+ self.assertEqual(listener.callCount(), 0)
+
+ # Resize with aspect ratio
+ self.plot.setKeepDataAspectRatio(True)
+ self.qapp.processEvents()
+ self.qWait(1000)
+ listener.clear() # Clean-up received signal
+
+ self.plot.resize(200, 200)
+ self.qapp.processEvents()
+ self.qWait(100)
+ self.assertNotEqual(listener.callCount(), 0)
+
+ def testAddRemoveItemSignals(self):
+ """Test sigItemAdded and sigItemAboutToBeRemoved"""
+ listener = SignalListener()
+ self.plot.sigItemAdded.connect(listener.partial('add'))
+ self.plot.sigItemAboutToBeRemoved.connect(listener.partial('remove'))
+
+ self.plot.addCurve((1, 2, 3), (3, 2, 1), legend='curve')
+ self.assertEqual(listener.callCount(), 1)
+
+ curve = self.plot.getCurve('curve')
+ self.plot.remove('curve')
+ self.assertEqual(listener.callCount(), 2)
+ self.assertEqual(listener.arguments(callIndex=0), ('add', curve))
+ self.assertEqual(listener.arguments(callIndex=1), ('remove', curve))
+
+ def testGetItems(self):
+ """Test getItems method"""
+ curve_x = 1, 2
+ self.plot.addCurve(curve_x, (3, 4))
+ image = (0, 1), (2, 3)
+ self.plot.addImage(image)
+ scatter_x = 10, 11
+ self.plot.addScatter(scatter_x, (12, 13), (0, 1))
+ marker_pos = 5, 5
+ self.plot.addMarker(*marker_pos)
+ marker_x = 6
+ self.plot.addXMarker(marker_x)
+ self.plot.addShape((0, 5), (2, 10), shape='rectangle')
+
+ items = self.plot.getItems()
+ self.assertEqual(len(items), 6)
+ self.assertTrue(numpy.all(numpy.equal(items[0].getXData(), curve_x)))
+ self.assertTrue(numpy.all(numpy.equal(items[1].getData(), image)))
+ self.assertTrue(numpy.all(numpy.equal(items[2].getXData(), scatter_x)))
+ self.assertTrue(numpy.all(numpy.equal(items[3].getPosition(), marker_pos)))
+ self.assertTrue(numpy.all(numpy.equal(items[4].getPosition()[0], marker_x)))
+ self.assertEqual(items[5].getType(), 'rectangle')
+
+ def testRemoveDiscardItem(self):
+ """Test removeItem and discardItem"""
+ self.plot.addCurve((1, 2, 3), (1, 2, 3))
+ curve = self.plot.getItems()[0]
+ self.plot.removeItem(curve)
+ with self.assertRaises(ValueError):
+ self.plot.removeItem(curve)
+
+ self.plot.addCurve((1, 2, 3), (1, 2, 3))
+ curve = self.plot.getItems()[0]
+ result = self.plot.discardItem(curve)
+ self.assertTrue(result)
+ result = self.plot.discardItem(curve)
+ self.assertFalse(result)
+
+ def testBackGroundColors(self):
+ self.plot.setVisible(True)
+ self.qWaitForWindowExposed(self.plot)
+ self.qapp.processEvents()
+
+ # Custom the full background
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ self.plot.setBackgroundColor("red")
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Custom the data background
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.plot.setDataBackgroundColor("red")
+ color = self.plot.getDataBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.qapp.processEvents()
+
+ # Back to default
+ self.plot.setBackgroundColor('white')
+ self.plot.setDataBackgroundColor(None)
+ color = self.plot.getBackgroundColor()
+ self.assertTrue(color.isValid())
+ self.assertEqual(color, qt.QColor(255, 255, 255))
+ color = self.plot.getDataBackgroundColor()
+ self.assertFalse(color.isValid())
+ self.qapp.processEvents()
+
+
+class TestPlotImage(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for addImage"""
+
+ def setUp(self):
+ super(TestPlotImage, self).setUp()
+
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+
+ def testPlotColormapTemperature(self):
+ self.plot.setGraphTitle('Temp. Linear')
+
+ colormap = Colormap(name='temperature',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotColormapGray(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('Gray Linear')
+
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotColormapTemperatureLog(self):
+ self.plot.setGraphTitle('Temp. Log')
+
+ colormap = Colormap(name='temperature',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotRgbRgba(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('RGB + RGBA')
+
+ rgb = numpy.array(
+ (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 255))),
+ dtype=numpy.uint8)
+
+ self.plot.addImage(rgb, legend="rgb_uint8",
+ origin=(0, 0), scale=(1, 1),
+ resetzoom=False)
+
+ rgb = numpy.array(
+ (((0, 0, 0), (32768, 0, 0), (65535, 0, 0)),
+ ((0, 32768, 0), (0, 32768, 32768), (0, 32768, 65535))),
+ dtype=numpy.uint16)
+
+ self.plot.addImage(rgb, legend="rgb_uint16",
+ origin=(3, 2), scale=(2, 2),
+ resetzoom=False)
+
+ rgba = numpy.array(
+ (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
+ ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
+ dtype=numpy.float32)
+
+ self.plot.addImage(rgba, legend="rgba_float32",
+ origin=(9, 6), scale=(1, 1),
+ resetzoom=False)
+
+ self.plot.resetZoom()
+
+ def testPlotColormapCustom(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('Custom colormap')
+
+ colormap = Colormap(name=None,
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+ colors=((0., 0., 0.), (1., 0., 0.),
+ (0., 1., 0.), (0., 0., 1.)))
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap,
+ resetzoom=False)
+
+ colormap = Colormap(name=None,
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None,
+ colors=numpy.array(
+ ((0, 0, 0, 0), (0, 0, 0, 128),
+ (128, 128, 128, 128), (255, 255, 255, 255)),
+ dtype=numpy.uint8))
+ self.plot.addImage(DATA_2D, legend="image 2", colormap=colormap,
+ origin=(DATA_2D.shape[0], 0),
+ resetzoom=False)
+ self.plot.resetZoom()
+
+ def testPlotColormapNaNColor(self):
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setGraphTitle('Colormap with NaN color')
+
+ colormap = Colormap()
+ colormap.setNaNColor('red')
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
+ data = DATA_2D.astype(numpy.float32)
+ data[len(data)//2:] = numpy.nan
+ self.plot.addImage(data, legend="image 1", colormap=colormap,
+ resetzoom=False)
+ self.plot.resetZoom()
+
+ colormap.setNaNColor((0., 1., 0., 1.))
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(0, 255, 0))
+ self.qapp.processEvents()
+
+ def testImageOriginScale(self):
+ """Test of image with different origin and scale"""
+ self.plot.setGraphTitle('origin and scale')
+
+ tests = [ # (origin, scale)
+ ((10, 20), (1, 1)),
+ ((10, 20), (-1, -1)),
+ ((-10, 20), (2, 1)),
+ ((10, -20), (-1, -2)),
+ (100, 2),
+ (-100, (1, 1)),
+ ((10, 20), 2),
+ ]
+
+ for origin, scale in tests:
+ with self.subTest(origin=origin, scale=scale):
+ self.plot.addImage(DATA_2D, origin=origin, scale=scale)
+
+ try:
+ ox, oy = origin
+ except TypeError:
+ ox, oy = origin, origin
+ try:
+ sx, sy = scale
+ except TypeError:
+ sx, sy = scale, scale
+ xbounds = ox, ox + DATA_2D.shape[1] * sx
+ ybounds = oy, oy + DATA_2D.shape[0] * sy
+
+ # Check limits without aspect ratio
+ xmin, xmax = self.plot.getXAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
+ self.assertEqual(xmin, min(xbounds))
+ self.assertEqual(xmax, max(xbounds))
+ self.assertEqual(ymin, min(ybounds))
+ self.assertEqual(ymax, max(ybounds))
+
+ # Check limits with aspect ratio
+ self.plot.setKeepDataAspectRatio(True)
+ xmin, xmax = self.plot.getXAxis().getLimits()
+ ymin, ymax = self.plot.getYAxis().getLimits()
+ self.assertTrue(round(xmin, 7) <= min(xbounds))
+ self.assertTrue(round(xmax, 7) >= max(xbounds))
+ self.assertTrue(round(ymin, 7) <= min(ybounds))
+ self.assertTrue(round(ymax, 7) >= max(ybounds))
+
+ self.plot.setKeepDataAspectRatio(False) # Reset aspect ratio
+ self.plot.clear()
+ self.plot.resetZoom()
+
+ def testPlotColormapDictAPI(self):
+ """Test that the addImage API using a colormap dictionary is still
+ working"""
+ self.plot.setGraphTitle('Temp. Log')
+
+ colormap = {
+ 'name': 'temperature',
+ 'normalization': 'log',
+ 'vmin': None,
+ 'vmax': None
+ }
+ self.plot.addImage(DATA_2D, legend="image 1", colormap=colormap)
+
+ def testPlotComplexImage(self):
+ """Test that a complex image is displayed as its absolute value."""
+ data = numpy.linspace(1, 1j, 100).reshape(10, 10)
+ self.plot.addImage(data, legend='complex')
+
+ image = self.plot.getActiveImage()
+ retrievedData = image.getData(copy=False)
+ self.assertTrue(
+ numpy.all(numpy.equal(retrievedData, numpy.absolute(data))))
+
+ def testPlotBooleanImage(self):
+ """Test that a boolean image is displayed and converted to int8."""
+ data = numpy.zeros((10, 10), dtype=bool)
+ data[::2, ::2] = True
+ self.plot.addImage(data, legend='boolean')
+
+ image = self.plot.getActiveImage()
+ retrievedData = image.getData(copy=False)
+ self.assertTrue(numpy.all(numpy.equal(retrievedData, data)))
+ self.assertIs(retrievedData.dtype.type, numpy.int8)
+
+ def testPlotAlphaImage(self):
+ """Test with an alpha image layer"""
+ data = numpy.random.random((10, 10))
+ alpha = numpy.linspace(0, 1, 100).reshape(10, 10)
+ self.plot.addImage(data, legend='image')
+ image = self.plot.getActiveImage()
+ image.setData(data, alpha=alpha)
+ self.qapp.processEvents()
+ self.assertTrue(numpy.array_equal(alpha, image.getAlphaData()))
+
+
+class TestPlotCurve(PlotWidgetTestCase):
+ """Basic tests for addCurve."""
+
+ # Test data sets
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ xData2 = xData + 1000
+ yData2 = xData - 1000 + 200 * numpy.random.random(1000)
+
+ def setUp(self):
+ super(TestPlotCurve, self).setUp()
+ self.plot.setGraphTitle('Curve')
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+
+ self.plot.setActiveCurveHandling(False)
+
+ def testPlotCurveInfinite(self):
+ """Test plot curves with not finite data"""
+ tests = {
+ 'y all not finite': ([0, 1, 2], [numpy.inf, numpy.nan, -numpy.inf]),
+ 'x all not finite': ([numpy.inf, numpy.nan, -numpy.inf], [0, 1, 2]),
+ 'x some inf': ([0, numpy.inf, 2], [0, 1, 2]),
+ 'y some inf': ([0, 1, 2], [0, numpy.inf, 2])
+ }
+ for name, args in tests.items():
+ with self.subTest(name):
+ self.plot.addCurve(*args)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+ self.plot.clear()
+
+ def testPlotCurveColorFloat(self):
+ color = numpy.array(numpy.random.random(3 * 1000),
+ dtype=numpy.float32).reshape(1000, 3)
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 1",
+ replace=False, resetzoom=False,
+ color=color,
+ linestyle="", symbol="s")
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ self.plot.resetZoom()
+
+ def testPlotCurveColorByte(self):
+ color = numpy.array(255 * numpy.random.random(3 * 1000),
+ dtype=numpy.uint8).reshape(1000, 3)
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 1",
+ replace=False, resetzoom=False,
+ color=color,
+ linestyle="", symbol="s")
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ self.plot.resetZoom()
+
+ def testPlotCurveColors(self):
+ color = numpy.array(numpy.random.random(3 * 1000),
+ dtype=numpy.float32).reshape(1000, 3)
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color=color, linestyle="-", symbol='o')
+ self.plot.resetZoom()
+
+ # Test updating color array
+
+ # From array to array
+ newColors = numpy.ones((len(self.xData), 3), dtype=numpy.float32)
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color=newColors, symbol='o')
+
+ # Array to single color
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color='green', symbol='o')
+
+ # single color to array
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve 2",
+ replace=False, resetzoom=False,
+ color=color, symbol='o')
+
+ def testPlotBaselineNumpyArray(self):
+ """simple test of the API with baseline as a numpy array"""
+ x = numpy.arange(0, 10, step=0.1)
+ my_sin = numpy.sin(x)
+ y = numpy.arange(-4, 6, step=0.1) + my_sin
+ baseline = y - 1.0
+
+ self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
+ baseline=baseline)
+
+ def testPlotBaselineScalar(self):
+ """simple test of the API with baseline as an int"""
+ x = numpy.arange(0, 10, step=0.1)
+ my_sin = numpy.sin(x)
+ y = numpy.arange(-4, 6, step=0.1) + my_sin
+
+ self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
+ baseline=0)
+
+ def testPlotBaselineList(self):
+ """simple test of the API with baseline as an int"""
+ x = numpy.arange(0, 10, step=0.1)
+ my_sin = numpy.sin(x)
+ y = numpy.arange(-4, 6, step=0.1) + my_sin
+
+ self.plot.addCurve(x=x, y=y, color='grey', legend='curve1', fill=True,
+ baseline=list(range(0, 100, 1)))
+
+ def testPlotCurveComplexData(self):
+ """Test curve with complex data"""
+ data = numpy.arange(100.) + 1j
+ self.plot.addCurve(x=data, y=data, xerror=data, yerror=data)
+
+
+class TestPlotHistogram(PlotWidgetTestCase):
+ """Basic tests for add Histogram"""
+ def setUp(self):
+ super(TestPlotHistogram, self).setUp()
+ self.edges = numpy.arange(0, 10, step=1)
+ self.histogram = numpy.random.random(len(self.edges))
+
+ def testPlot(self):
+ self.plot.addHistogram(histogram=self.histogram,
+ edges=self.edges,
+ legend='histogram1')
+
+ def testPlotBaseline(self):
+ self.plot.addHistogram(histogram=self.histogram,
+ edges=self.edges,
+ legend='histogram1',
+ color='blue',
+ baseline=-2,
+ z=2,
+ fill=True)
+
+
+class TestPlotScatter(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for addScatter"""
+
+ def testScatter(self):
+ x = numpy.arange(100)
+ y = numpy.arange(100)
+ value = numpy.arange(100)
+ self.plot.addScatter(x, y, value)
+ self.plot.resetZoom()
+
+ def testScatterComplexData(self):
+ """Test scatter item with complex data"""
+ data = numpy.arange(100.) + 1j
+ self.plot.addScatter(
+ x=data, y=data, value=data, xerror=data, yerror=data)
+ self.plot.resetZoom()
+
+ def testScatterVisualization(self):
+ self.plot.addScatter((0, 1, 0, 1), (0, 0, 2, 2), (0, 1, 2, 3))
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ scatter = self.plot.getItems()[0]
+
+ for visualization in ('solid',
+ 'points',
+ 'regular_grid',
+ 'irregular_grid',
+ 'binned_statistic',
+ scatter.Visualization.SOLID,
+ scatter.Visualization.POINTS,
+ scatter.Visualization.REGULAR_GRID,
+ scatter.Visualization.IRREGULAR_GRID,
+ scatter.Visualization.BINNED_STATISTIC):
+ with self.subTest(visualization=visualization):
+ scatter.setVisualization(visualization)
+ self.qapp.processEvents()
+
+ def testGridVisualization(self):
+ """Test regular and irregular grid mode with different points"""
+ points = { # name: (x, y, order)
+ 'single point': ((1.,), (1.,), 'row'),
+ 'horizontal line': ((0, 1, 2), (0, 0, 0), 'row'),
+ 'horizontal line backward': ((2, 1, 0), (0, 0, 0), 'row'),
+ 'vertical line': ((0, 0, 0), (0, 1, 2), 'row'),
+ 'vertical line backward': ((0, 0, 0), (2, 1, 0), 'row'),
+ 'grid fast x, +x +y': ((0, 1, 2, 0, 1, 2), (0, 0, 0, 1, 1, 1), 'row'),
+ 'grid fast x, +x -y': ((0, 1, 2, 0, 1, 2), (1, 1, 1, 0, 0, 0), 'row'),
+ 'grid fast x, -x -y': ((2, 1, 0, 2, 1, 0), (1, 1, 1, 0, 0, 0), 'row'),
+ 'grid fast x, -x +y': ((2, 1, 0, 2, 1, 0), (0, 0, 0, 1, 1, 1), 'row'),
+ 'grid fast y, +x +y': ((0, 0, 0, 1, 1, 1), (0, 1, 2, 0, 1, 2), 'column'),
+ 'grid fast y, +x -y': ((0, 0, 0, 1, 1, 1), (2, 1, 0, 2, 1, 0), 'column'),
+ 'grid fast y, -x -y': ((1, 1, 1, 0, 0, 0), (2, 1, 0, 2, 1, 0), 'column'),
+ 'grid fast y, -x +y': ((1, 1, 1, 0, 0, 0), (0, 1, 2, 0, 1, 2), 'column'),
+ }
+
+ self.plot.addScatter((), (), ())
+ scatter = self.plot.getItems()[0]
+
+ self.qapp.processEvents()
+
+ for visualization in (scatter.Visualization.REGULAR_GRID,
+ scatter.Visualization.IRREGULAR_GRID):
+ scatter.setVisualization(visualization)
+ self.assertIs(scatter.getVisualization(), visualization)
+
+ for name, (x, y, ref_order) in points.items():
+ with self.subTest(name=name, visualization=visualization.name):
+ scatter.setData(x, y, numpy.arange(len(x)))
+ self.plot.setGraphTitle(name)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ order = scatter.getCurrentVisualizationParameter(
+ scatter.VisualizationParameter.GRID_MAJOR_ORDER)
+ self.assertEqual(ref_order, order)
+
+ ref_bounds = (x[0], y[0]), (x[-1], y[-1])
+ bounds = scatter.getCurrentVisualizationParameter(
+ scatter.VisualizationParameter.GRID_BOUNDS)
+ self.assertEqual(ref_bounds, bounds)
+
+ shape = scatter.getCurrentVisualizationParameter(
+ scatter.VisualizationParameter.GRID_SHAPE)
+
+ self.plot.getXAxis().setLimits(numpy.min(x) - 1, numpy.max(x) + 1)
+ self.plot.getYAxis().setLimits(numpy.min(y) - 1, numpy.max(y) + 1)
+ self.qapp.processEvents()
+
+ for index, position in enumerate(zip(x, y)):
+ xpixel, ypixel = self.plot.dataToPixel(*position)
+ result = scatter.pick(xpixel, ypixel)
+ self.assertIsNotNone(result)
+ self.assertIs(result.getItem(), scatter)
+ self.assertEqual(result.getIndices(), (index,))
+
+ def testBinnedStatisticVisualization(self):
+ """Test binned display"""
+ self.plot.addScatter((), (), ())
+ scatter = self.plot.getItems()[0]
+ scatter.setVisualization(scatter.Visualization.BINNED_STATISTIC)
+ self.assertIs(scatter.getVisualization(),
+ scatter.Visualization.BINNED_STATISTIC)
+ self.assertEqual(
+ scatter.getVisualizationParameter(
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION),
+ 'mean')
+
+ self.qapp.processEvents()
+
+ scatter.setData(*numpy.random.random(300).reshape(3, -1))
+ self.qapp.processEvents()
+
+ # Update data
+ scatter.setData(*numpy.random.random(3000).reshape(3, -1))
+ self.qapp.processEvents()
+
+ for reduction in ('count', 'sum', 'mean'):
+ with self.subTest(reduction=reduction):
+ scatter.setVisualizationParameter(
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION,
+ reduction)
+ self.assertEqual(
+ scatter.getVisualizationParameter(
+ scatter.VisualizationParameter.BINNED_STATISTIC_FUNCTION),
+ reduction)
+
+ self.qapp.processEvents()
+
+
+class TestPlotMarker(PlotWidgetTestCase):
+ """Basic tests for add*Marker"""
+
+ def setUp(self):
+ super(TestPlotMarker, self).setUp()
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+
+ self.plot.getXAxis().setAutoScale(False)
+ self.plot.getYAxis().setAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(0., 100., -100., 100.)
+
+ def testPlotMarkerX(self):
+ self.plot.setGraphTitle('Markers X')
+
+ markers = [
+ (10., 'blue', False, False),
+ (20., 'red', False, False),
+ (40., 'green', True, False),
+ (60., 'gray', True, True),
+ (80., 'black', False, True),
+ ]
+
+ for x, color, select, drag in markers:
+ name = str(x)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addXMarker(x, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerY(self):
+ self.plot.setGraphTitle('Markers Y')
+
+ markers = [
+ (-50., 'blue', False, False),
+ (-30., 'red', False, False),
+ (0., 'green', True, False),
+ (10., 'gray', True, True),
+ (80., 'black', False, True),
+ ]
+
+ for y, color, select, drag in markers:
+ name = str(y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addYMarker(y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerPt(self):
+ self.plot.setGraphTitle('Markers Pt')
+
+ markers = [
+ (10., -50., 'blue', False, False),
+ (40., -30., 'red', False, False),
+ (50., 0., 'green', True, False),
+ (50., 20., 'gray', True, True),
+ (70., 50., 'black', False, True),
+ ]
+ for x, y, color, select, drag in markers:
+ name = "{0},{1}".format(x, y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addMarker(x, y, name, name, color, select, drag)
+
+ self.plot.resetZoom()
+
+ def testPlotMarkerWithoutLegend(self):
+ self.plot.setGraphTitle('Markers without legend')
+ self.plot.getYAxis().setInverted(True)
+
+ # Markers without legend
+ self.plot.addMarker(10, 10)
+ self.plot.addMarker(10, 20)
+ self.plot.addMarker(40, 50, text='test', symbol=None)
+ self.plot.addMarker(40, 50, text='test', symbol='+')
+ self.plot.addXMarker(25)
+ self.plot.addXMarker(35)
+ self.plot.addXMarker(45, text='test')
+ self.plot.addYMarker(55)
+ self.plot.addYMarker(65)
+ self.plot.addYMarker(75, text='test')
+
+ self.plot.resetZoom()
+
+ def testPlotMarkerYAxis(self):
+ # Check only the API
+
+ legend = self.plot.addMarker(10, 10)
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "left")
+
+ legend = self.plot.addMarker(10, 10, yaxis="right")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "right")
+
+ legend = self.plot.addMarker(10, 10, yaxis="left")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "left")
+
+ legend = self.plot.addXMarker(10, yaxis="right")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "right")
+
+ legend = self.plot.addXMarker(10, yaxis="left")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "left")
+
+ legend = self.plot.addYMarker(10, yaxis="right")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "right")
+
+ legend = self.plot.addYMarker(10, yaxis="left")
+ item = self.plot._getMarker(legend)
+ self.assertEqual(item.getYAxis(), "left")
+
+ self.plot.resetZoom()
+
+
+# TestPlotItem ################################################################
+
+class TestPlotItem(PlotWidgetTestCase):
+ """Basic tests for addItem."""
+
+ # Polygon coordinates and color
+ POLYGONS = [ # legend, x coords, y coords, color
+ ('triangle', numpy.array((10, 30, 50)),
+ numpy.array((55, 70, 55)), 'red'),
+ ('square', numpy.array((10, 10, 50, 50)),
+ numpy.array((10, 50, 50, 10)), 'green'),
+ ('star', numpy.array((60, 70, 80, 60, 80)),
+ numpy.array((25, 50, 25, 40, 40)), 'blue'),
+ ('2 triangles-simple',
+ numpy.array((90., 95., 100., numpy.nan, 90., 95., 100.)),
+ numpy.array((25., 5., 25., numpy.nan, 30., 50., 30.)),
+ 'pink'),
+ ('2 triangles-extra NaN',
+ numpy.array((numpy.nan, 90., 95., 100., numpy.nan, 0., 90., 95., 100., numpy.nan)),
+ numpy.array((0., 55., 70., 55., numpy.nan, numpy.nan, 75., 90., 75., numpy.nan)),
+ 'black'),
+ ]
+
+ # Rectangle coordinantes and color
+ RECTANGLES = [ # legend, x coords, y coords, color
+ ('square 1', numpy.array((1., 10.)),
+ numpy.array((1., 10.)), 'red'),
+ ('square 2', numpy.array((10., 20.)),
+ numpy.array((10., 20.)), 'green'),
+ ('square 3', numpy.array((20., 30.)),
+ numpy.array((20., 30.)), 'blue'),
+ ('rect 1', numpy.array((1., 30.)),
+ numpy.array((35., 40.)), 'black'),
+ ('line h', numpy.array((1., 30.)),
+ numpy.array((45., 45.)), 'darkRed'),
+ ]
+
+ SCALES = Axis.LINEAR, Axis.LOGARITHMIC
+
+ def setUp(self):
+ super(TestPlotItem, self).setUp()
+
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+ self.plot.getXAxis().setAutoScale(False)
+ self.plot.getYAxis().setAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(0., 100., -100., 100.)
+
+ def testPlotItemPolygonFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Item Fill %s' % scale)
+
+ for legend, xList, yList, color in self.POLYGONS:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False, linestyle='--',
+ shape="polygon", fill=True, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemPolygonNoFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Item No Fill %s' % scale)
+
+ for legend, xList, yList, color in self.POLYGONS:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False, linestyle='--',
+ shape="polygon", fill=False, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Rectangle Fill %s' % scale)
+
+ for legend, xList, yList, color in self.RECTANGLES:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=True, color=color)
+ self.plot.resetZoom()
+
+ def testPlotItemRectangleNoFill(self):
+ for scale in self.SCALES:
+ with self.subTest(scale=scale):
+ self.plot.clear()
+ self.plot.getXAxis().setScale(scale)
+ self.plot.getYAxis().setScale(scale)
+ self.plot.setGraphTitle('Rectangle No Fill %s' % scale)
+
+ for legend, xList, yList, color in self.RECTANGLES:
+ self.plot.addShape(xList, yList, legend=legend,
+ replace=False,
+ shape="rectangle", fill=False, color=color)
+ self.plot.resetZoom()
+
+
+class TestPlotActiveCurveImage(PlotWidgetTestCase):
+ """Basic tests for active curve and image handling"""
+ xData = numpy.arange(1000)
+ yData = -500 + 100 * numpy.sin(xData)
+ xData2 = xData + 1000
+ yData2 = xData - 1000 + 200 * numpy.random.random(1000)
+
+ def tearDown(self):
+ self.plot.setActiveCurveHandling(False)
+ super(TestPlotActiveCurveImage, self).tearDown()
+
+ def testActiveCurveAndLabels(self):
+ # Active curve handling off, no label change
+ self.plot.setActiveCurveHandling(False)
+ self.plot.getXAxis().setLabel('XLabel')
+ self.plot.getYAxis().setLabel('YLabel')
+ self.plot.addCurve((1, 2), (1, 2))
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ self.plot.addCurve((1, 2), (2, 3), xlabel='x1', ylabel='y1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ self.plot.clear()
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ # Active curve handling on, label changes
+ self.plot.setActiveCurveHandling(True)
+ self.plot.getXAxis().setLabel('XLabel')
+ self.plot.getYAxis().setLabel('YLabel')
+
+ # labels changed as active curve
+ self.plot.addCurve((1, 2), (1, 2), legend='1',
+ xlabel='x1', ylabel='y1')
+ self.plot.setActiveCurve('1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ # labels not changed as not active curve
+ self.plot.addCurve((1, 2), (2, 3), legend='2')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ # labels changed
+ self.plot.setActiveCurve('2')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ self.plot.setActiveCurve('1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ self.plot.clear()
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ def testPlotActiveCurveSelectionMode(self):
+ self.plot.clear()
+ self.plot.setActiveCurveHandling(True)
+ legend = "curve 1"
+ self.plot.addCurve(self.xData, self.yData,
+ legend=legend,
+ color="green")
+
+ # active curve should be None
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
+
+ # active curve should be None when None is set as active curve
+ self.plot.setActiveCurve(legend)
+ current = self.plot.getActiveCurve(just_legend=True)
+ self.assertEqual(current, legend)
+ self.plot.setActiveCurve(None)
+ current = self.plot.getActiveCurve(just_legend=True)
+ self.assertEqual(current, None)
+
+ # testing it automatically toggles if there is only one
+ self.plot.setActiveCurveSelectionMode("legacy")
+ current = self.plot.getActiveCurve(just_legend=True)
+ self.assertEqual(current, legend)
+
+ # active curve should not change when None set as active curve
+ self.assertEqual(self.plot.getActiveCurveSelectionMode(), "legacy")
+ self.plot.setActiveCurve(None)
+ current = self.plot.getActiveCurve(just_legend=True)
+ self.assertEqual(current, legend)
+
+ # situation where no curve is active
+ self.plot.clear()
+ self.plot.setActiveCurveHandling(True)
+ self.assertEqual(self.plot.getActiveCurveSelectionMode(), "atmostone")
+ self.plot.addCurve(self.xData, self.yData,
+ legend=legend,
+ color="green")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ color="red")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
+ self.plot.setActiveCurveSelectionMode("legacy")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), None)
+
+ # the first curve added should be active
+ self.plot.clear()
+ self.plot.addCurve(self.xData, self.yData,
+ legend=legend,
+ color="green")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend)
+ self.plot.addCurve(self.xData2, self.yData2,
+ legend="curve 2",
+ color="red")
+ self.assertEqual(self.plot.getActiveCurve(just_legend=True), legend)
+
+ def testActiveCurveStyle(self):
+ """Test change of active curve style"""
+ self.plot.setActiveCurveHandling(True)
+ self.plot.setActiveCurveStyle(color='black')
+ style = self.plot.getActiveCurveStyle()
+ self.assertEqual(style.getColor(), (0., 0., 0., 1.))
+ self.assertIsNone(style.getLineStyle())
+ self.assertIsNone(style.getLineWidth())
+ self.assertIsNone(style.getSymbol())
+ self.assertIsNone(style.getSymbolSize())
+
+ self.plot.addCurve(x=self.xData, y=self.yData, legend="curve1")
+ curve = self.plot.getCurve("curve1")
+ curve.setColor('blue')
+ curve.setLineStyle('-')
+ curve.setLineWidth(1)
+ curve.setSymbol('o')
+ curve.setSymbolSize(5)
+
+ # Check default current style
+ defaultStyle = curve.getCurrentStyle()
+ self.assertEqual(defaultStyle, CurveStyle(color='blue',
+ linestyle='-',
+ linewidth=1,
+ symbol='o',
+ symbolsize=5))
+
+ # Activate curve with highlight color=black
+ self.plot.setActiveCurve("curve1")
+ style = curve.getCurrentStyle()
+ self.assertEqual(style.getColor(), (0., 0., 0., 1.))
+ self.assertEqual(style.getLineStyle(), '-')
+ self.assertEqual(style.getLineWidth(), 1)
+ self.assertEqual(style.getSymbol(), 'o')
+ self.assertEqual(style.getSymbolSize(), 5)
+
+ # Change highlight to linewidth=2
+ self.plot.setActiveCurveStyle(linewidth=2)
+ style = curve.getCurrentStyle()
+ self.assertEqual(style.getColor(), (0., 0., 1., 1.))
+ self.assertEqual(style.getLineStyle(), '-')
+ self.assertEqual(style.getLineWidth(), 2)
+ self.assertEqual(style.getSymbol(), 'o')
+ self.assertEqual(style.getSymbolSize(), 5)
+
+ self.plot.setActiveCurve(None)
+ self.assertEqual(curve.getCurrentStyle(), defaultStyle)
+
+ def testActiveImageAndLabels(self):
+ # Active image handling always on, no API for toggling it
+ self.plot.getXAxis().setLabel('XLabel')
+ self.plot.getYAxis().setLabel('YLabel')
+
+ # labels changed as active curve
+ self.plot.addImage(numpy.arange(100).reshape(10, 10),
+ legend='1', xlabel='x1', ylabel='y1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ # labels not changed as not active curve
+ self.plot.addImage(numpy.arange(100).reshape(10, 10),
+ legend='2')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ # labels changed
+ self.plot.setActiveImage('2')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+ self.plot.setActiveImage('1')
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'x1')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'y1')
+
+ self.plot.clear()
+ self.assertEqual(self.plot.getXAxis().getLabel(), 'XLabel')
+ self.assertEqual(self.plot.getYAxis().getLabel(), 'YLabel')
+
+
+##############################################################################
+# Log
+##############################################################################
+
+class TestPlotEmptyLog(PlotWidgetTestCase):
+ """Basic tests for log plot"""
+ def testEmptyPlotTitleLabelsLog(self):
+ self.plot.setGraphTitle('Empty Log Log')
+ self.plot.getXAxis().setLabel('X')
+ self.plot.getYAxis().setLabel('Y')
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.resetZoom()
+
+
+class TestPlotAxes(TestCaseQt, ParametricTestCase):
+
+ # Test data
+ xData = numpy.arange(1, 10)
+ yData = xData ** 2
+
+ def __init__(self, methodName='runTest', backend=None):
+ unittest.TestCase.__init__(self, methodName)
+ self.__backend = backend
+
+ def setUp(self):
+ super(TestPlotAxes, self).setUp()
+ self.plot = PlotWidget(backend=self.__backend)
+ # It is not needed to display the plot
+ # It saves a lot of time
+ # self.plot.show()
+ # self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestPlotAxes, self).tearDown()
+
+ def testDefaultAxes(self):
+ axis = self.plot.getXAxis()
+ self.assertEqual(axis.getScale(), axis.LINEAR)
+ axis = self.plot.getYAxis()
+ self.assertEqual(axis.getScale(), axis.LINEAR)
+ axis = self.plot.getYAxis(axis="right")
+ self.assertEqual(axis.getScale(), axis.LINEAR)
+
+ def testOldPlotAxis_getterSetter(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ p = self.plot
+
+ tests = [
+ # setters
+ (p.setGraphXLimits, (10, 20), x.getLimits, (10, 20)),
+ (p.setGraphYLimits, (10, 20), y.getLimits, (10, 20)),
+ (p.setGraphXLabel, "foox", x.getLabel, "foox"),
+ (p.setGraphYLabel, "fooy", y.getLabel, "fooy"),
+ (p.setYAxisInverted, True, y.isInverted, True),
+ (p.setXAxisLogarithmic, True, x.getScale, x.LOGARITHMIC),
+ (p.setYAxisLogarithmic, True, y.getScale, y.LOGARITHMIC),
+ (p.setXAxisAutoScale, False, x.isAutoScale, False),
+ (p.setYAxisAutoScale, False, y.isAutoScale, False),
+ # getters
+ (x.setLimits, (11, 20), p.getGraphXLimits, (11, 20)),
+ (y.setLimits, (11, 20), p.getGraphYLimits, (11, 20)),
+ (x.setLabel, "fooxx", p.getGraphXLabel, "fooxx"),
+ (y.setLabel, "fooyy", p.getGraphYLabel, "fooyy"),
+ (y.setInverted, False, p.isYAxisInverted, False),
+ (x.setScale, x.LINEAR, p.isXAxisLogarithmic, False),
+ (y.setScale, y.LINEAR, p.isYAxisLogarithmic, False),
+ (x.setAutoScale, True, p.isXAxisAutoScale, True),
+ (y.setAutoScale, True, p.isYAxisAutoScale, True),
+ ]
+ for testCase in tests:
+ setter, value, getter, expected = testCase
+ with self.subTest():
+ if setter is not None:
+ if not isinstance(value, tuple):
+ value = (value, )
+ setter(*value)
+ if getter is not None:
+ self.assertEqual(getter(), expected)
+
+ def testOldPlotAxis_Logarithmic(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+
+ self.assertEqual(x.getScale(), x.LINEAR)
+ self.assertEqual(y.getScale(), x.LINEAR)
+ self.assertEqual(yright.getScale(), x.LINEAR)
+
+ self.plot.setXAxisLogarithmic(True)
+ self.assertEqual(x.getScale(), x.LOGARITHMIC)
+ self.assertEqual(y.getScale(), x.LINEAR)
+ self.assertEqual(yright.getScale(), x.LINEAR)
+ self.assertEqual(self.plot.isXAxisLogarithmic(), True)
+ self.assertEqual(self.plot.isYAxisLogarithmic(), False)
+
+ self.plot.setYAxisLogarithmic(True)
+ self.assertEqual(x.getScale(), x.LOGARITHMIC)
+ self.assertEqual(y.getScale(), x.LOGARITHMIC)
+ self.assertEqual(yright.getScale(), x.LOGARITHMIC)
+ self.assertEqual(self.plot.isXAxisLogarithmic(), True)
+ self.assertEqual(self.plot.isYAxisLogarithmic(), True)
+
+ yright.setScale(yright.LINEAR)
+ self.assertEqual(x.getScale(), x.LOGARITHMIC)
+ self.assertEqual(y.getScale(), x.LINEAR)
+ self.assertEqual(yright.getScale(), x.LINEAR)
+ self.assertEqual(self.plot.isXAxisLogarithmic(), True)
+ self.assertEqual(self.plot.isYAxisLogarithmic(), False)
+
+ def testOldPlotAxis_AutoScale(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+
+ self.assertEqual(x.isAutoScale(), True)
+ self.assertEqual(y.isAutoScale(), True)
+ self.assertEqual(yright.isAutoScale(), True)
+
+ self.plot.setXAxisAutoScale(False)
+ self.assertEqual(x.isAutoScale(), False)
+ self.assertEqual(y.isAutoScale(), True)
+ self.assertEqual(yright.isAutoScale(), True)
+ self.assertEqual(self.plot.isXAxisAutoScale(), False)
+ self.assertEqual(self.plot.isYAxisAutoScale(), True)
+
+ self.plot.setYAxisAutoScale(False)
+ self.assertEqual(x.isAutoScale(), False)
+ self.assertEqual(y.isAutoScale(), False)
+ self.assertEqual(yright.isAutoScale(), False)
+ self.assertEqual(self.plot.isXAxisAutoScale(), False)
+ self.assertEqual(self.plot.isYAxisAutoScale(), False)
+
+ yright.setAutoScale(True)
+ self.assertEqual(x.isAutoScale(), False)
+ self.assertEqual(y.isAutoScale(), True)
+ self.assertEqual(yright.isAutoScale(), True)
+ self.assertEqual(self.plot.isXAxisAutoScale(), False)
+ self.assertEqual(self.plot.isYAxisAutoScale(), True)
+
+ def testOldPlotAxis_Inverted(self):
+ """Test silx API prior to silx 0.6"""
+ x = self.plot.getXAxis()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+
+ self.assertEqual(x.isInverted(), False)
+ self.assertEqual(y.isInverted(), False)
+ self.assertEqual(yright.isInverted(), False)
+
+ self.plot.setYAxisInverted(True)
+ self.assertEqual(x.isInverted(), False)
+ self.assertEqual(y.isInverted(), True)
+ self.assertEqual(yright.isInverted(), True)
+ self.assertEqual(self.plot.isYAxisInverted(), True)
+
+ yright.setInverted(False)
+ self.assertEqual(x.isInverted(), False)
+ self.assertEqual(y.isInverted(), False)
+ self.assertEqual(yright.isInverted(), False)
+ self.assertEqual(self.plot.isYAxisInverted(), False)
+
+ def testLogXWithData(self):
+ self.plot.setGraphTitle('Curve X: Log Y: Linear')
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+ axis = self.plot.getXAxis()
+ axis.setScale(axis.LOGARITHMIC)
+
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+
+ def testLogYWithData(self):
+ self.plot.setGraphTitle('Curve X: Linear Y: Log')
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+ axis = self.plot.getYAxis()
+ axis.setScale(axis.LOGARITHMIC)
+
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+ axis = self.plot.getYAxis(axis="right")
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+
+ def testLogYRightWithData(self):
+ self.plot.setGraphTitle('Curve X: Linear Y: Log')
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+ axis = self.plot.getYAxis(axis="right")
+ axis.setScale(axis.LOGARITHMIC)
+
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+ axis = self.plot.getYAxis()
+ self.assertEqual(axis.getScale(), axis.LOGARITHMIC)
+
+ def testLimitsChanged_setLimits(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
+ self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2"))
+ self.plot.setLimits(0, 1, 0, 1, 0, 1)
+ # at least one event per axis
+ self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
+
+ def testLimitsChanged_resetZoom(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ self.plot.getXAxis().sigLimitsChanged.connect(listener.partial(axis="x"))
+ self.plot.getYAxis().sigLimitsChanged.connect(listener.partial(axis="y"))
+ self.plot.getYAxis(axis="right").sigLimitsChanged.connect(listener.partial(axis="y2"))
+ self.plot.resetZoom()
+ # at least one event per axis
+ self.assertEqual(len(set(listener.karguments(argumentName="axis"))), 3)
+
+ def testLimitsChanged_setXLimit(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ axis = self.plot.getXAxis()
+ axis.sigLimitsChanged.connect(listener)
+ axis.setLimits(20, 30)
+ # at least one event per axis
+ self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
+ self.assertEqual(axis.getLimits(), (20.0, 30.0))
+
+ def testLimitsChanged_setYLimit(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ axis = self.plot.getYAxis()
+ axis.sigLimitsChanged.connect(listener)
+ axis.setLimits(20, 30)
+ # at least one event per axis
+ self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
+ self.assertEqual(axis.getLimits(), (20.0, 30.0))
+
+ def testLimitsChanged_setYRightLimit(self):
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=False,
+ color='green', linestyle="-", symbol='o')
+ listener = SignalListener()
+ axis = self.plot.getYAxis(axis="right")
+ axis.sigLimitsChanged.connect(listener)
+ axis.setLimits(20, 30)
+ # at least one event per axis
+ self.assertEqual(listener.arguments(callIndex=-1), (20.0, 30.0))
+ self.assertEqual(axis.getLimits(), (20.0, 30.0))
+
+ def testScaleProxy(self):
+ listener = SignalListener()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+ y.sigScaleChanged.connect(listener.partial("left"))
+ yright.sigScaleChanged.connect(listener.partial("right"))
+ yright.setScale(yright.LOGARITHMIC)
+
+ self.assertEqual(y.getScale(), y.LOGARITHMIC)
+ events = listener.arguments()
+ self.assertEqual(len(events), 2)
+ self.assertIn(("left", y.LOGARITHMIC), events)
+ self.assertIn(("right", y.LOGARITHMIC), events)
+
+ def testAutoScaleProxy(self):
+ listener = SignalListener()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+ y.sigAutoScaleChanged.connect(listener.partial("left"))
+ yright.sigAutoScaleChanged.connect(listener.partial("right"))
+ yright.setAutoScale(False)
+
+ self.assertEqual(y.isAutoScale(), False)
+ events = listener.arguments()
+ self.assertEqual(len(events), 2)
+ self.assertIn(("left", False), events)
+ self.assertIn(("right", False), events)
+
+ def testInvertedProxy(self):
+ listener = SignalListener()
+ y = self.plot.getYAxis()
+ yright = self.plot.getYAxis(axis="right")
+ y.sigInvertedChanged.connect(listener.partial("left"))
+ yright.sigInvertedChanged.connect(listener.partial("right"))
+ yright.setInverted(True)
+
+ self.assertEqual(y.isInverted(), True)
+ events = listener.arguments()
+ self.assertEqual(len(events), 2)
+ self.assertIn(("left", True), events)
+ self.assertIn(("right", True), events)
+
+ def testAxesDisplayedFalse(self):
+ """Test coverage on setAxesDisplayed(False)"""
+ self.plot.setAxesDisplayed(False)
+
+ def testAxesDisplayedTrue(self):
+ """Test coverage on setAxesDisplayed(True)"""
+ self.plot.setAxesDisplayed(True)
+
+ def testAxesMargins(self):
+ """Test PlotWidget's getAxesMargins and setAxesMargins"""
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ margins = self.plot.getAxesMargins()
+ self.assertEqual(margins, (.15, .1, .1, .15))
+
+ for margins in ((0., 0., 0., 0.), (.15, .1, .1, .15)):
+ with self.subTest(margins=margins):
+ self.plot.setAxesMargins(*margins)
+ self.qapp.processEvents()
+ self.assertEqual(self.plot.getAxesMargins(), margins)
+
+ def testBoundingRectItem(self):
+ item = BoundingRect()
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.addItem(item)
+ self.plot.resetZoom()
+ limits = numpy.array(self.plot.getXAxis().getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000]))
+ limits = numpy.array(self.plot.getYAxis().getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000]))
+
+ def testBoundingRectRightItem(self):
+ item = BoundingRect()
+ item.setYAxis("right")
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.addItem(item)
+ self.plot.resetZoom()
+ limits = numpy.array(self.plot.getXAxis().getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-1000, 1000]))
+ limits = numpy.array(self.plot.getYAxis("right").getLimits())
+ numpy.testing.assert_almost_equal(limits, numpy.array([-2000, 2000]))
+
+ def testBoundingRectArguments(self):
+ item = BoundingRect()
+ with self.assertRaises(Exception):
+ item.setBounds((1000, -1000, -2000, 2000))
+ with self.assertRaises(Exception):
+ item.setBounds((-1000, 1000, 2000, -2000))
+
+ def testBoundingRectWithLog(self):
+ item = BoundingRect()
+ self.plot.addItem(item)
+
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(False)
+ self.assertEqual(item.getBounds(), (1000, 1000, -2000, 2000))
+
+ item.setBounds((-1000, 1000, -2000, 2000))
+ self.plot.getXAxis()._setLogarithmic(False)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.assertEqual(item.getBounds(), (-1000, 1000, 2000, 2000))
+
+ item.setBounds((-1000, 0, -2000, 2000))
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(False)
+ self.assertIsNone(item.getBounds())
+
+ def testAxisExtent(self):
+ """Test XAxisExtent and yAxisExtent"""
+ for cls, axis in ((XAxisExtent, self.plot.getXAxis()),
+ (YAxisExtent, self.plot.getYAxis())):
+ for range_, logRange in (((2, 3), (2, 3)),
+ ((-2, -1), (1, 100)),
+ ((-1, 3), (3. * 0.9, 3. * 1.1))):
+ extent = cls()
+ extent.setRange(*range_)
+ self.plot.addItem(extent)
+
+ for isLog, plotRange in ((False, range_), (True, logRange)):
+ with self.subTest(
+ cls=cls.__name__, range=range_, isLog=isLog):
+ axis._setLogarithmic(isLog)
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+ self.assertEqual(axis.getLimits(), plotRange)
+
+ axis._setLogarithmic(False)
+ self.plot.clear()
+
+ def testAxisLimitOverflow(self):
+ """Test setting limis beyond supported range"""
+ xaxis, yaxis = self.plot.getXAxis(), self.plot.getYAxis()
+ for scale in ("linear", "log"):
+ xaxis.setScale(scale)
+ yaxis.setScale(scale)
+ for limits in ((1e300, 1e308),
+ (-1e308, 1e308),
+ (1e-300, 2e-300)):
+ with self.subTest(scale=scale, limits=limits):
+ xaxis.setLimits(*limits)
+ self.qapp.processEvents()
+ self.assertNotEqual(xaxis.getLimits(), limits)
+ yaxis.setLimits(*limits)
+ self.qapp.processEvents()
+ self.assertNotEqual(yaxis.getLimits(), limits)
+
+
+class TestPlotCurveLog(PlotWidgetTestCase, ParametricTestCase):
+ """Basic tests for addCurve with log scale axes"""
+
+ # Test data
+ xData = numpy.arange(1000) + 1
+ yData = xData ** 2
+
+ def _setLabels(self):
+ self.plot.getXAxis().setLabel('X')
+ self.plot.getYAxis().setLabel('X * X')
+
+ def testPlotCurveLogX(self):
+ self._setLabels()
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('Curve X: Log Y: Linear')
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ def testPlotCurveLogY(self):
+ self._setLabels()
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ self.plot.setGraphTitle('Curve X: Linear Y: Log')
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ def testPlotCurveLogXY(self):
+ self._setLabels()
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ self.plot.setGraphTitle('Curve X: Log Y: Log')
+
+ self.plot.addCurve(self.xData, self.yData,
+ legend="curve",
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ def testPlotCurveErrorLogXY(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ # Every second error leads to negative number
+ errors = numpy.ones_like(self.xData)
+ errors[::2] = self.xData[::2] + 1
+
+ tests = [ # name, xerror, yerror
+ ('xerror=3', 3, None),
+ ('xerror=N array', errors, None),
+ ('xerror=Nx1 array', errors.reshape(len(errors), 1), None),
+ ('xerror=2xN array', numpy.array((errors, errors)), None),
+ ('yerror=6', None, 6),
+ ('yerror=N array', None, errors ** 2),
+ ('yerror=Nx1 array', None, (errors ** 2).reshape(len(errors), 1)),
+ ('yerror=2xN array', None, numpy.array((errors, errors)) ** 2),
+ ]
+
+ for name, xError, yError in tests:
+ with self.subTest(name):
+ self.plot.setGraphTitle(name)
+ self.plot.addCurve(self.xData, self.yData,
+ legend=name,
+ xerror=xError, yerror=yError,
+ replace=False, resetzoom=True,
+ color='green', linestyle="-", symbol='o')
+
+ self.qapp.processEvents()
+
+ self.plot.clear()
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ def testPlotCurveToggleLog(self):
+ """Add a curve with negative data and toggle log axis"""
+ arange = numpy.arange(1000) + 1
+ tests = [ # name, xData, yData
+ ('x>0, some negative y', arange, arange - 500),
+ ('x>0, y<0', arange, -arange),
+ ('some negative x, y>0', arange - 500, arange),
+ ('x<0, y>0', -arange, arange),
+ ('some negative x and y', arange - 500, arange - 500),
+ ('x<0, y<0', -arange, -arange),
+ ]
+
+ for name, xData, yData in tests:
+ with self.subTest(name):
+ self.plot.addCurve(xData, yData, resetzoom=True)
+ self.qapp.processEvents()
+
+ # no log axis
+ xLim = self.plot.getXAxis().getLimits()
+ self.assertEqual(xLim, (min(xData), max(xData)))
+ yLim = self.plot.getYAxis().getLimits()
+ self.assertEqual(yLim, (min(yData), max(yData)))
+
+ # x axis log
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ yLim = self.plot.getYAxis().getLimits()
+ positives = xData > 0
+ if numpy.any(positives):
+ self.assertTrue(numpy.allclose(
+ xLim, (min(xData[positives]), max(xData[positives]))))
+ self.assertEqual(
+ yLim, (min(yData[positives]), max(yData[positives])))
+ else: # No positive x in the curve
+ self.assertEqual(xLim, (1., 100.))
+ self.assertEqual(yLim, (1., 100.))
+
+ # x axis and y axis log
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ yLim = self.plot.getYAxis().getLimits()
+ positives = numpy.logical_and(xData > 0, yData > 0)
+ if numpy.any(positives):
+ self.assertTrue(numpy.allclose(
+ xLim, (min(xData[positives]), max(xData[positives]))))
+ self.assertTrue(numpy.allclose(
+ yLim, (min(yData[positives]), max(yData[positives]))))
+ else: # No positive x and y in the curve
+ self.assertEqual(xLim, (1., 100.))
+ self.assertEqual(yLim, (1., 100.))
+
+ # y axis log
+ self.plot.getXAxis()._setLogarithmic(False)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ yLim = self.plot.getYAxis().getLimits()
+ positives = yData > 0
+ if numpy.any(positives):
+ self.assertEqual(
+ xLim, (min(xData[positives]), max(xData[positives])))
+ self.assertTrue(numpy.allclose(
+ yLim, (min(yData[positives]), max(yData[positives]))))
+ else: # No positive y in the curve
+ self.assertEqual(xLim, (1., 100.))
+ self.assertEqual(yLim, (1., 100.))
+
+ # no log axis
+ self.plot.getYAxis()._setLogarithmic(False)
+ self.qapp.processEvents()
+
+ xLim = self.plot.getXAxis().getLimits()
+ self.assertEqual(xLim, (min(xData), max(xData)))
+ yLim = self.plot.getYAxis().getLimits()
+ self.assertEqual(yLim, (min(yData), max(yData)))
+
+ self.plot.clear()
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+
+class TestPlotImageLog(PlotWidgetTestCase):
+ """Basic tests for addImage with log scale axes."""
+
+ def setUp(self):
+ super(TestPlotImageLog, self).setUp()
+
+ self.plot.getXAxis().setLabel('Columns')
+ self.plot.getYAxis().setLabel('Rows')
+
+ def testPlotColormapGrayLogX(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('CMap X: Log Y: Linear')
+
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1",
+ origin=(1., 1.), scale=(1., 1.),
+ resetzoom=False, colormap=colormap)
+ self.plot.resetZoom()
+
+ def testPlotColormapGrayLogY(self):
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('CMap X: Linear Y: Log')
+
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1",
+ origin=(1., 1.), scale=(1., 1.),
+ resetzoom=False, colormap=colormap)
+ self.plot.resetZoom()
+
+ def testPlotColormapGrayLogXY(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('CMap X: Log Y: Log')
+
+ colormap = Colormap(name='gray',
+ normalization='linear',
+ vmin=None,
+ vmax=None)
+ self.plot.addImage(DATA_2D, legend="image 1",
+ origin=(1., 1.), scale=(1., 1.),
+ resetzoom=False, colormap=colormap)
+ self.plot.resetZoom()
+
+ def testPlotRgbRgbaLogXY(self):
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+ self.plot.setGraphTitle('RGB + RGBA X: Log Y: Log')
+
+ rgb = numpy.array(
+ (((0, 0, 0), (128, 0, 0), (255, 0, 0)),
+ ((0, 128, 0), (0, 128, 128), (0, 128, 256))),
+ dtype=numpy.uint8)
+
+ self.plot.addImage(rgb, legend="rgb",
+ origin=(1, 1), scale=(10, 10),
+ resetzoom=False)
+
+ rgba = numpy.array(
+ (((0, 0, 0, .5), (.5, 0, 0, 1), (1, 0, 0, .5)),
+ ((0, .5, 0, 1), (0, .5, .5, 1), (0, 1, 1, .5))),
+ dtype=numpy.float32)
+
+ self.plot.addImage(rgba, legend="rgba",
+ origin=(5., 5.), scale=(10., 10.),
+ resetzoom=False)
+ self.plot.resetZoom()
+
+
+class TestPlotMarkerLog(PlotWidgetTestCase):
+ """Basic tests for markers on log scales"""
+
+ # Test marker parameters
+ markers = [ # x, y, color, selectable, draggable
+ (10., 10., 'blue', False, False),
+ (20., 20., 'red', False, False),
+ (40., 100., 'green', True, False),
+ (40., 500., 'gray', True, True),
+ (60., 800., 'black', False, True),
+ ]
+
+ def setUp(self):
+ super(TestPlotMarkerLog, self).setUp()
+
+ self.plot.getYAxis().setLabel('Rows')
+ self.plot.getXAxis().setLabel('Columns')
+ self.plot.getXAxis().setAutoScale(False)
+ self.plot.getYAxis().setAutoScale(False)
+ self.plot.setKeepDataAspectRatio(False)
+ self.plot.setLimits(1., 100., 1., 1000.)
+ self.plot.getXAxis()._setLogarithmic(True)
+ self.plot.getYAxis()._setLogarithmic(True)
+
+ def testPlotMarkerXLog(self):
+ self.plot.setGraphTitle('Markers X, Log axes')
+
+ for x, _, color, select, drag in self.markers:
+ name = str(x)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addXMarker(x, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerYLog(self):
+ self.plot.setGraphTitle('Markers Y, Log axes')
+
+ for _, y, color, select, drag in self.markers:
+ name = str(y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addYMarker(y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+ def testPlotMarkerPtLog(self):
+ self.plot.setGraphTitle('Markers Pt, Log axes')
+
+ for x, y, color, select, drag in self.markers:
+ name = "{0},{1}".format(x, y)
+ if select:
+ name += " sel."
+ if drag:
+ name += " drag"
+ self.plot.addMarker(x, y, name, name, color, select, drag)
+ self.plot.resetZoom()
+
+
+@pytest.mark.usefixtures("test_options_class_attr")
+class TestPlotWidgetSwitchBackend(PlotWidgetTestCase):
+ """Test [get|set]Backend to switch backend"""
+
+ @pytest.mark.usefixtures("test_options")
+ def testSwitchBackend(self):
+ """Test switching a plot with a few items"""
+ backends = {'none': 'BackendBase', 'mpl': 'BackendMatplotlibQt'}
+ if self.test_options.WITH_GL_TEST:
+ backends['gl'] = 'BackendOpenGL'
+
+ self.plot.addImage(numpy.arange(100).reshape(10, 10))
+ self.plot.addCurve((-3, -2, -1), (1, 2, 3))
+ self.plot.resetZoom()
+ xlimits = self.plot.getXAxis().getLimits()
+ ylimits = self.plot.getYAxis().getLimits()
+ items = self.plot.getItems()
+ self.assertEqual(len(items), 2)
+
+ for backend, className in backends.items():
+ with self.subTest(backend=backend):
+ self.plot.setBackend(backend)
+ self.plot.replot()
+
+ retrievedBackend = self.plot.getBackend()
+ self.assertEqual(type(retrievedBackend).__name__, className)
+ self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
+ self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
+ self.assertEqual(self.plot.getItems(), items)
+
+
+class TestPlotWidgetSelection(PlotWidgetTestCase):
+ """Test PlotWidget.selection and active items handling"""
+
+ def _checkSelection(self, selection, current=None, selected=()):
+ """Check current item and selected items."""
+ self.assertIs(selection.getCurrentItem(), current)
+ self.assertEqual(selection.getSelectedItems(), selected)
+
+ def testSyncWithActiveItems(self):
+ """Test update of PlotWidgetSelection according to active items"""
+ listener = SignalListener()
+
+ selection = self.plot.selection()
+ selection.sigCurrentItemChanged.connect(listener)
+ self._checkSelection(selection)
+
+ # Active item is current
+ self.plot.addImage(((0, 1), (2, 3)), legend='image')
+ image = self.plot.getActiveImage()
+ self.assertEqual(listener.callCount(), 1)
+ self._checkSelection(selection, image, (image,))
+
+ # No active = no current
+ self.plot.setActiveImage(None)
+ self.assertEqual(listener.callCount(), 2)
+ self._checkSelection(selection)
+
+ # Active item is current
+ self.plot.setActiveImage('image')
+ self.assertEqual(listener.callCount(), 3)
+ self._checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
+ scatter = self.plot.getActiveScatter()
+ self.assertEqual(listener.callCount(), 4)
+ self._checkSelection(selection, scatter, (scatter, image))
+
+ # Previously mosted recently "actived" item is current
+ self.plot.setActiveScatter(None)
+ self.assertEqual(listener.callCount(), 5)
+ self._checkSelection(selection, image, (image,))
+
+ # Mosted recently "actived" item is current
+ self.plot.setActiveScatter('scatter')
+ self.assertEqual(listener.callCount(), 6)
+ self._checkSelection(selection, scatter, (scatter, image))
+
+ # No active = no current
+ self.plot.setActiveImage(None)
+ self.plot.setActiveScatter(None)
+ self.assertEqual(listener.callCount(), 7)
+ self._checkSelection(selection)
+
+ # Mosted recently "actived" item is current
+ self.plot.setActiveScatter('scatter')
+ self.assertEqual(listener.callCount(), 8)
+ self.plot.setActiveImage('image')
+ self.assertEqual(listener.callCount(), 9)
+ self._checkSelection(selection, image, (image, scatter))
+
+ # Add a curve which is not active by default
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
+ curve = self.plot.getCurve('curve')
+ self.assertEqual(listener.callCount(), 9)
+ self._checkSelection(selection, image, (image, scatter))
+
+ # Mosted recently "actived" item is current
+ self.plot.setActiveCurve('curve')
+ self.assertEqual(listener.callCount(), 10)
+ self._checkSelection(selection, curve, (curve, image, scatter))
+
+ # Add a curve which is not active by default
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve2')
+ curve2 = self.plot.getCurve('curve2')
+ self.assertEqual(listener.callCount(), 10)
+ self._checkSelection(selection, curve, (curve, image, scatter))
+
+ # Mosted recently "actived" item is current, previous curve is removed
+ self.plot.setActiveCurve('curve2')
+ self.assertEqual(listener.callCount(), 11)
+ self._checkSelection(selection, curve2, (curve2, image, scatter))
+
+ # No items = no current
+ self.plot.clear()
+ self.assertEqual(listener.callCount(), 12)
+ self._checkSelection(selection)
+
+ def testPlotWidgetWithItems(self):
+ """Test init of selection on a plot with items"""
+ self.plot.addImage(((0, 1), (2, 3)), legend='image')
+ self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
+ self.plot.setActiveCurve('curve')
+
+ selection = self.plot.selection()
+ self.assertIsNotNone(selection.getCurrentItem())
+ selected = selection.getSelectedItems()
+ self.assertEqual(len(selected), 3)
+ self.assertIn(self.plot.getActiveCurve(), selected)
+ self.assertIn(self.plot.getActiveImage(), selected)
+ self.assertIn(self.plot.getActiveScatter(), selected)
+
+ def testSetCurrentItem(self):
+ """Test setCurrentItem"""
+ # Add items to the plot
+ self.plot.addImage(((0, 1), (2, 3)), legend='image')
+ image = self.plot.getActiveImage()
+ self.plot.addScatter((3, 2, 1), (0, 1, 2), (0, 1, 2), legend='scatter')
+ scatter = self.plot.getActiveScatter()
+ self.plot.addCurve((0, 1, 2), (0, 1, 2), legend='curve')
+ self.plot.setActiveCurve('curve')
+ curve = self.plot.getActiveCurve()
+
+ selection = self.plot.selection()
+ self.assertIsNotNone(selection.getCurrentItem())
+ self.assertEqual(len(selection.getSelectedItems()), 3)
+
+ # Set current to None reset all active items
+ selection.setCurrentItem(None)
+ self._checkSelection(selection)
+ self.assertIsNone(self.plot.getActiveCurve())
+ self.assertIsNone(self.plot.getActiveImage())
+ self.assertIsNone(self.plot.getActiveScatter())
+
+ # Set current to an item makes it active
+ selection.setCurrentItem(image)
+ self._checkSelection(selection, image, (image,))
+ self.assertIsNone(self.plot.getActiveCurve())
+ self.assertIs(self.plot.getActiveImage(), image)
+ self.assertIsNone(self.plot.getActiveScatter())
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(curve)
+ self._checkSelection(selection, curve, (curve, image))
+ self.assertIs(self.plot.getActiveCurve(), curve)
+ self.assertIs(self.plot.getActiveImage(), image)
+ self.assertIsNone(self.plot.getActiveScatter())
+
+ # Set current to an item makes it active and keeps other active
+ selection.setCurrentItem(scatter)
+ self._checkSelection(selection, scatter, (scatter, curve, image))
+ self.assertIs(self.plot.getActiveCurve(), curve)
+ self.assertIs(self.plot.getActiveImage(), image)
+ self.assertIs(self.plot.getActiveScatter(), scatter)
+
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotWidget_Gl(TestPlotWidget):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotImage_Gl(TestPlotImage):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotCurve_Gl(TestPlotCurve):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotHistogram_Gl(TestPlotHistogram):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotScatter_Gl(TestPlotScatter):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotMarker_Gl(TestPlotMarker):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotItem_Gl(TestPlotItem):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotAxes_Gl(TestPlotAxes):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotActiveCurveImage_Gl(TestPlotActiveCurveImage):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotEmptyLog_Gl(TestPlotEmptyLog):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotCurveLog_Gl(TestPlotCurveLog):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotImageLog_Gl(TestPlotImageLog):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotMarkerLog_Gl(TestPlotMarkerLog):
+ backend="gl"
+
+@pytest.mark.usefixtures("use_opengl")
+class TestPlotWidgetSelection_Gl(TestPlotWidgetSelection):
+ backend="gl"
+
+class TestSpecial_ExplicitMplBackend(TestSpecialBackend):
+ backend="mpl"
diff --git a/src/silx/gui/plot/test/testPlotWidgetNoBackend.py b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py
new file mode 100644
index 0000000..4914929
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWidgetNoBackend.py
@@ -0,0 +1,618 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget with 'none' backend"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+from functools import reduce
+from silx.utils.testutils import ParametricTestCase
+
+import numpy
+
+from silx.gui.plot.PlotWidget import PlotWidget
+from silx.gui.plot.items.histogram import _getHistogramCurve, _computeEdges
+
+
+class TestPlot(unittest.TestCase):
+ """Basic tests of Plot without backend"""
+
+ def testPlotTitleLabels(self):
+ """Create a Plot and set the labels"""
+
+ plot = PlotWidget(backend='none')
+
+ title, xlabel, ylabel = 'the title', 'x label', 'y label'
+ plot.setGraphTitle(title)
+ plot.getXAxis().setLabel(xlabel)
+ plot.getYAxis().setLabel(ylabel)
+
+ self.assertEqual(plot.getGraphTitle(), title)
+ self.assertEqual(plot.getXAxis().getLabel(), xlabel)
+ self.assertEqual(plot.getYAxis().getLabel(), ylabel)
+
+ def testAddNoRemove(self):
+ """add objects to the Plot"""
+
+ plot = PlotWidget(backend='none')
+ plot.addCurve(x=(1, 2, 3), y=(3, 2, 1))
+ plot.addImage(numpy.arange(100.).reshape(10, -1))
+ plot.addShape(numpy.array((1., 10.)),
+ numpy.array((10., 10.)),
+ shape="rectangle")
+ plot.addXMarker(10.)
+
+
+class TestPlotRanges(ParametricTestCase):
+ """Basic tests of Plot data ranges without backend"""
+
+ _getValidValues = {True: lambda ar: ar > 0,
+ False: lambda ar: numpy.ones(shape=ar.shape,
+ dtype=bool)}
+
+ @staticmethod
+ def _getRanges(arrays, are_logs):
+ gen = (TestPlotRanges._getValidValues[is_log](ar)
+ for (ar, is_log) in zip(arrays, are_logs))
+ indices = numpy.where(reduce(numpy.logical_and, gen))[0]
+ if len(indices) > 0:
+ ranges = [(ar[indices[0]], ar[indices[-1]]) for ar in arrays]
+ else:
+ ranges = [None] * len(arrays)
+
+ return ranges
+
+ @staticmethod
+ def _getRangesMinmax(ranges):
+ # TODO : error if None in ranges.
+ rangeMin = numpy.min([rng[0] for rng in ranges])
+ rangeMax = numpy.max([rng[1] for rng in ranges])
+ return rangeMin, rangeMax
+
+ def testDataRangeNoPlot(self):
+ """empty plot data range"""
+
+ plot = PlotWidget(backend='none')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ self.assertIsNone(dataRange.x)
+ self.assertIsNone(dataRange.y)
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeLeft(self):
+ """left axis range"""
+
+ plot = PlotWidget(backend='none')
+
+ xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+
+ plot.addCurve(x=xData,
+ y=yData,
+ legend='plot_0',
+ yaxis='left')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = self._getRanges([xData, yData],
+ [logX, logY])
+ self.assertSequenceEqual(dataRange.x, xRange)
+ self.assertSequenceEqual(dataRange.y, yRange)
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeRight(self):
+ """right axis range"""
+
+ plot = PlotWidget(backend='none')
+ xData = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+ plot.addCurve(x=xData,
+ y=yData,
+ legend='plot_0',
+ yaxis='right')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = self._getRanges([xData, yData],
+ [logX, logY])
+ self.assertSequenceEqual(dataRange.x, xRange)
+ self.assertIsNone(dataRange.y)
+ self.assertSequenceEqual(dataRange.yright, yRange)
+
+ def testDataRangeImage(self):
+ """image data range"""
+
+ origin = (-10, 25)
+ scale = (3., 8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot = PlotWidget(backend='none')
+ plot.addImage(image,
+ origin=origin, scale=scale)
+
+ xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {(False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None)}
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(numpy.array_equal(dataRange.x, xRange),
+ msg='{0} != {1}'.format(dataRange.x, xRange))
+ self.assertTrue(numpy.array_equal(dataRange.y, yRange),
+ msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeLeftRight(self):
+ """right+left axis range"""
+
+ plot = PlotWidget(backend='none')
+
+ xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
+ plot.addCurve(x=xData_l,
+ y=yData_l,
+ legend='plot_l',
+ yaxis='left')
+
+ xData_r = numpy.arange(10) - 4.9 # range : -4.9 , 4.1
+ yData_r = numpy.arange(10) - 6.9 # range : -6.9 , 2.1
+ plot.addCurve(x=xData_r,
+ y=yData_r,
+ legend='plot_r',
+ yaxis='right')
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
+ [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
+ [logX, logY])
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
+ self.assertSequenceEqual(dataRange.x, xRangeLR)
+ self.assertSequenceEqual(dataRange.y, yRangeL)
+ self.assertSequenceEqual(dataRange.yright, yRangeR)
+
+ def testDataRangeCurveImage(self):
+ """right+left+image axis range"""
+
+ # overlapping ranges :
+ # image sets x min and y max
+ # plot_left sets y min
+ # plot_right sets x max (and yright)
+ plot = PlotWidget(backend='none')
+
+ origin = (-10, 5)
+ scale = (3., 8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot.addImage(image,
+ origin=origin, scale=scale, legend='image')
+
+ xData_l = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ yData_l = numpy.arange(10) - 1.9 # range : -1.9 , 7.1
+ plot.addCurve(x=xData_l,
+ y=yData_l,
+ legend='plot_l',
+ yaxis='left')
+
+ xData_r = numpy.arange(10) + 4.1 # range : 4.1 , 13.1
+ yData_r = numpy.arange(10) - 0.9 # range : -0.9 , 8.1
+ plot.addCurve(x=xData_r,
+ y=yData_r,
+ legend='plot_r',
+ yaxis='right')
+
+ imgXRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ imgYRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRangeL, yRangeL = self._getRanges([xData_l, yData_l],
+ [logX, logY])
+ xRangeR, yRangeR = self._getRanges([xData_r, yData_r],
+ [logX, logY])
+ if logX or logY:
+ xRangeLR = self._getRangesMinmax([xRangeL, xRangeR])
+ else:
+ xRangeLR = self._getRangesMinmax([xRangeL,
+ xRangeR,
+ imgXRange])
+ yRangeL = self._getRangesMinmax([yRangeL, imgYRange])
+ self.assertSequenceEqual(dataRange.x, xRangeLR)
+ self.assertSequenceEqual(dataRange.y, yRangeL)
+ self.assertSequenceEqual(dataRange.yright, yRangeR)
+
+ def testDataRangeImageNegativeScaleX(self):
+ """image data range, negative scale"""
+
+ origin = (-10, 25)
+ scale = (-3., 8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot = PlotWidget(backend='none')
+ plot.addImage(image,
+ origin=origin, scale=scale)
+
+ xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ xRange.sort() # negative scale!
+ yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+
+ ranges = {(False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None)}
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(numpy.array_equal(dataRange.x, xRange),
+ msg='{0} != {1}'.format(dataRange.x, xRange))
+ self.assertTrue(numpy.array_equal(dataRange.y, yRange),
+ msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeImageNegativeScaleY(self):
+ """image data range, negative scale"""
+
+ origin = (-10, 25)
+ scale = (3., -8.)
+ image = numpy.arange(100.).reshape(20, 5)
+
+ plot = PlotWidget(backend='none')
+ plot.addImage(image,
+ origin=origin, scale=scale)
+
+ xRange = numpy.array([0., image.shape[1] * scale[0]]) + origin[0]
+ yRange = numpy.array([0., image.shape[0] * scale[1]]) + origin[1]
+ yRange.sort() # negative scale!
+
+ ranges = {(False, False): (xRange, yRange),
+ (True, False): (None, None),
+ (True, True): (None, None),
+ (False, True): (None, None)}
+
+ for logX, logY in ((False, False),
+ (True, False),
+ (True, True),
+ (False, True),
+ (False, False)):
+ with self.subTest(logX=logX, logY=logY):
+ plot.getXAxis()._setLogarithmic(logX)
+ plot.getYAxis()._setLogarithmic(logY)
+ dataRange = plot.getDataRange()
+ xRange, yRange = ranges[logX, logY]
+ self.assertTrue(numpy.array_equal(dataRange.x, xRange),
+ msg='{0} != {1}'.format(dataRange.x, xRange))
+ self.assertTrue(numpy.array_equal(dataRange.y, yRange),
+ msg='{0} != {1}'.format(dataRange.y, yRange))
+ self.assertIsNone(dataRange.yright)
+
+ def testDataRangeHiddenCurve(self):
+ """curves with a hidden curve"""
+ plot = PlotWidget(backend='none')
+ plot.addCurve((0, 1), (0, 1), legend='shown')
+ plot.addCurve((0, 1, 2), (5, 5, 5), legend='hidden')
+ range1 = plot.getDataRange()
+ self.assertEqual(range1.x, (0, 2))
+ self.assertEqual(range1.y, (0, 5))
+ plot.hideCurve('hidden')
+ range2 = plot.getDataRange()
+ self.assertEqual(range2.x, (0, 1))
+ self.assertEqual(range2.y, (0, 1))
+
+
+class TestPlotGetCurveImage(unittest.TestCase):
+ """Test of plot getCurve and getImage methods"""
+
+ def testGetCurve(self):
+ """PlotWidget.getCurve and Plot.getActiveCurve tests"""
+
+ plot = PlotWidget(backend='none')
+
+ # No curve
+ curve = plot.getCurve()
+ self.assertIsNone(curve) # No curve
+
+ plot.setActiveCurveHandling(True)
+ plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 0')
+ plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 1')
+ plot.addCurve(x=(0, 1), y=(0, 1), legend='curve 2')
+ plot.setActiveCurve('curve 0')
+
+ # Active curve
+ active = plot.getActiveCurve()
+ self.assertEqual(active.getName(), 'curve 0')
+ curve = plot.getCurve()
+ self.assertEqual(curve.getName(), 'curve 0')
+
+ # No active curve and curves
+ plot.setActiveCurveHandling(False)
+ active = plot.getActiveCurve()
+ self.assertIsNone(active) # No active curve
+ curve = plot.getCurve()
+ self.assertEqual(curve.getName(), 'curve 2') # Last added curve
+
+ # Last curve hidden
+ plot.hideCurve('curve 2', True)
+ curve = plot.getCurve()
+ self.assertEqual(curve.getName(), 'curve 1') # Last added curve
+
+ # All curves hidden
+ plot.hideCurve('curve 1', True)
+ plot.hideCurve('curve 0', True)
+ curve = plot.getCurve()
+ self.assertIsNone(curve)
+
+ def testGetCurveOldApi(self):
+ """old API PlotWidget.getCurve and Plot.getActiveCurve tests"""
+
+ plot = PlotWidget(backend='none')
+
+ # No curve
+ curve = plot.getCurve()
+ self.assertIsNone(curve) # No curve
+
+ plot.setActiveCurveHandling(True)
+ x = numpy.arange(10.).astype(numpy.float32)
+ y = x * x
+ plot.addCurve(x=x, y=y, legend='curve 0', info=["whatever"])
+ plot.addCurve(x=x, y=2*x, legend='curve 1', info="anything")
+ plot.setActiveCurve('curve 0')
+
+ # Active curve (4 elements)
+ xOut, yOut, legend, info = plot.getActiveCurve()[:4]
+ self.assertEqual(legend, 'curve 0')
+ self.assertTrue(numpy.allclose(xOut, x), 'curve 0 wrong x data')
+ self.assertTrue(numpy.allclose(yOut, y), 'curve 0 wrong y data')
+
+ # Active curve (5 elements)
+ xOut, yOut, legend, info, params = plot.getCurve("curve 1")
+ self.assertEqual(legend, 'curve 1')
+ self.assertEqual(info, 'anything')
+ self.assertTrue(numpy.allclose(xOut, x), 'curve 1 wrong x data')
+ self.assertTrue(numpy.allclose(yOut, 2 * x), 'curve 1 wrong y data')
+
+ def testGetImage(self):
+ """PlotWidget.getImage and PlotWidget.getActiveImage tests"""
+
+ plot = PlotWidget(backend='none')
+
+ # No image
+ image = plot.getImage()
+ self.assertIsNone(image)
+
+ plot.addImage(((0, 1), (2, 3)), legend='image 0')
+ plot.addImage(((0, 1), (2, 3)), legend='image 1')
+
+ # Active image
+ active = plot.getActiveImage()
+ self.assertEqual(active.getName(), 'image 0')
+ image = plot.getImage()
+ self.assertEqual(image.getName(), 'image 0')
+
+ # No active image
+ plot.addImage(((0, 1), (2, 3)), legend='image 2')
+ plot.setActiveImage(None)
+ active = plot.getActiveImage()
+ self.assertIsNone(active)
+ image = plot.getImage()
+ self.assertEqual(image.getName(), 'image 2')
+
+ # Active image
+ plot.setActiveImage('image 1')
+ active = plot.getActiveImage()
+ self.assertEqual(active.getName(), 'image 1')
+ image = plot.getImage()
+ self.assertEqual(image.getName(), 'image 1')
+
+ def testGetImageOldApi(self):
+ """PlotWidget.getImage and PlotWidget.getActiveImage old API tests"""
+
+ plot = PlotWidget(backend='none')
+
+ # No image
+ image = plot.getImage()
+ self.assertIsNone(image)
+
+ image = numpy.arange(10).astype(numpy.float32)
+ image.shape = 5, 2
+
+ plot.addImage(image, legend='image 0', info=["Hi!"])
+
+ # Active image
+ data, legend, info, something, params = plot.getActiveImage()
+ self.assertEqual(legend, 'image 0')
+ self.assertEqual(info, ["Hi!"])
+ self.assertTrue(numpy.allclose(data, image), "image 0 data not correct")
+
+ def testGetAllImages(self):
+ """PlotWidget.getAllImages test"""
+
+ plot = PlotWidget(backend='none')
+
+ # No image
+ images = plot.getAllImages()
+ self.assertEqual(len(images), 0)
+
+ # 2 images
+ data = numpy.arange(100).reshape(10, 10)
+ plot.addImage(data, legend='1')
+ plot.addImage(data, origin=(10, 10), legend='2')
+ images = plot.getAllImages(just_legend=True)
+ self.assertEqual(list(images), ['1', '2'])
+ images = plot.getAllImages(just_legend=False)
+ self.assertEqual(len(images), 2)
+ self.assertEqual(images[0].getName(), '1')
+ self.assertEqual(images[1].getName(), '2')
+
+
+class TestPlotAddScatter(unittest.TestCase):
+ """Test of plot addScatter"""
+
+ def testAddGetScatter(self):
+
+ plot = PlotWidget(backend='none')
+
+ # No curve
+ scatter = plot._getItem(kind="scatter")
+ self.assertIsNone(scatter) # No curve
+
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
+ plot._setActiveItem('scatter', 'scatter 0')
+
+ # Active scatter
+ active = plot._getActiveItem(kind='scatter')
+ self.assertEqual(active.getName(), 'scatter 0')
+
+ # check default values
+ self.assertAlmostEqual(active.getSymbolSize(), active._DEFAULT_SYMBOL_SIZE)
+ self.assertEqual(active.getSymbol(), "o")
+ self.assertAlmostEqual(active.getAlpha(), 1.0)
+
+ # modify parameters
+ active.setSymbolSize(20.5)
+ active.setSymbol("d")
+ active.setAlpha(0.777)
+
+ s0 = plot.getScatter("scatter 0")
+
+ self.assertAlmostEqual(s0.getSymbolSize(), 20.5)
+ self.assertEqual(s0.getSymbol(), "d")
+ self.assertAlmostEqual(s0.getAlpha(), 0.777)
+
+ scatter1 = plot._getItem(kind='scatter', legend='scatter 1')
+ self.assertEqual(scatter1.getName(), 'scatter 1')
+
+ def testGetAllScatters(self):
+ """PlotWidget.getAllImages test"""
+
+ plot = PlotWidget(backend='none')
+
+ items = plot.getItems()
+ self.assertEqual(len(items), 0)
+
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 0')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 1')
+ plot.addScatter(x=(0, 1), y=(0, 1), value=(0, 1), legend='scatter 2')
+
+ items = plot.getItems()
+ self.assertEqual(len(items), 3)
+ self.assertEqual(items[0].getName(), 'scatter 0')
+ self.assertEqual(items[1].getName(), 'scatter 1')
+ self.assertEqual(items[2].getName(), 'scatter 2')
+
+
+class TestPlotHistogram(unittest.TestCase):
+ """Basic tests for histogram."""
+
+ def testEdges(self):
+ x = numpy.array([0, 1, 2])
+ edgesRight = numpy.array([0, 1, 2, 3])
+ edgesLeft = numpy.array([-1, 0, 1, 2])
+ edgesCenter = numpy.array([-0.5, 0.5, 1.5, 2.5])
+
+ # testing x values for right
+ edges = _computeEdges(x, 'right')
+ numpy.testing.assert_array_equal(edges, edgesRight)
+
+ edges = _computeEdges(x, 'center')
+ numpy.testing.assert_array_equal(edges, edgesCenter)
+
+ edges = _computeEdges(x, 'left')
+ numpy.testing.assert_array_equal(edges, edgesLeft)
+
+ def testHistogramCurve(self):
+ y = numpy.array([3, 2, 5])
+ edges = numpy.array([0, 1, 2, 3])
+
+ xHisto, yHisto = _getHistogramCurve(y, edges)
+ numpy.testing.assert_array_equal(
+ yHisto, numpy.array([3, 3, 2, 2, 5, 5]))
+
+ y = numpy.array([-3, 2, 5, 0])
+ edges = numpy.array([-2, -1, 0, 1, 2])
+ xHisto, yHisto = _getHistogramCurve(y, edges)
+ numpy.testing.assert_array_equal(
+ yHisto, numpy.array([-3, -3, 2, 2, 5, 5, 0, 0]))
diff --git a/src/silx/gui/plot/test/testPlotWindow.py b/src/silx/gui/plot/test/testPlotWindow.py
new file mode 100644
index 0000000..9e1497f
--- /dev/null
+++ b/src/silx/gui/plot/test/testPlotWindow.py
@@ -0,0 +1,174 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWindow"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "27/06/2017"
+
+
+import unittest
+import numpy
+import pytest
+
+from silx.gui.utils.testutils import TestCaseQt, getQToolButtonFromAction
+
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.colors import Colormap
+
+
+class TestPlotWindow(TestCaseQt):
+ """Base class for tests of PlotWindow."""
+
+ def setUp(self):
+ super(TestPlotWindow, self).setUp()
+ self.plot = PlotWindow()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestPlotWindow, self).tearDown()
+
+ def testActions(self):
+ """Test the actions QToolButtons"""
+ self.plot.setLimits(1, 100, 1, 100)
+
+ checkList = [ # QAction, Plot state getter
+ (self.plot.xAxisAutoScaleAction, self.plot.getXAxis().isAutoScale),
+ (self.plot.yAxisAutoScaleAction, self.plot.getYAxis().isAutoScale),
+ (self.plot.xAxisLogarithmicAction, self.plot.getXAxis()._isLogarithmic),
+ (self.plot.yAxisLogarithmicAction, self.plot.getYAxis()._isLogarithmic),
+ (self.plot.gridAction, self.plot.getGraphGrid),
+ ]
+
+ for action, getter in checkList:
+ self.mouseMove(self.plot)
+ initialState = getter()
+ toolButton = getQToolButtonFromAction(action)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.assertNotEqual(getter(), initialState,
+ msg='"%s" state not changed' % action.text())
+
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.assertEqual(getter(), initialState,
+ msg='"%s" state not changed' % action.text())
+
+ # Trigger a zoom reset
+ self.mouseMove(self.plot)
+ resetZoomAction = self.plot.resetZoomAction
+ toolButton = getQToolButtonFromAction(resetZoomAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ def testDockWidgets(self):
+ """Test add/remove dock widgets"""
+ dock1 = qt.QDockWidget('Test 1')
+ dock1.setWidget(qt.QLabel('Test 1'))
+
+ self.plot.addTabbedDockWidget(dock1)
+ self.qapp.processEvents()
+
+ self.plot.removeDockWidget(dock1)
+ self.qapp.processEvents()
+
+ dock2 = qt.QDockWidget('Test 2')
+ dock2.setWidget(qt.QLabel('Test 2'))
+
+ self.plot.addTabbedDockWidget(dock2)
+ self.qapp.processEvents()
+
+ if qt.BINDING != 'PySide2':
+ # Weird bug with PySide2 later upon gc.collect() when getting the layout
+ self.assertNotEqual(self.plot.layout().indexOf(dock2),
+ -1,
+ "dock2 not properly displayed")
+
+ def testToolAspectRatio(self):
+ self.plot.toolBar()
+ self.plot.keepDataAspectRatioButton.keepDataAspectRatio()
+ self.assertTrue(self.plot.isKeepDataAspectRatio())
+ self.plot.keepDataAspectRatioButton.dontKeepDataAspectRatio()
+ self.assertFalse(self.plot.isKeepDataAspectRatio())
+
+ def testToolYAxisOrigin(self):
+ self.plot.toolBar()
+ self.plot.yAxisInvertedButton.setYAxisUpward()
+ self.assertFalse(self.plot.getYAxis().isInverted())
+ self.plot.yAxisInvertedButton.setYAxisDownward()
+ self.assertTrue(self.plot.getYAxis().isInverted())
+
+ def testColormapAutoscaleCache(self):
+ # Test that the min/max cache is not computed twice
+
+ old = Colormap._computeAutoscaleRange
+ self._count = 0
+ def _computeAutoscaleRange(colormap, data):
+ self._count = self._count + 1
+ return 10, 20
+ Colormap._computeAutoscaleRange = _computeAutoscaleRange
+ try:
+ colormap = Colormap(name='red')
+ self.plot.setVisible(True)
+
+ # Add an image
+ data = numpy.arange(8**2).reshape(8, 8)
+ self.plot.addImage(data, legend="foo", colormap=colormap)
+ self.plot.setActiveImage("foo")
+
+ # Use the colorbar
+ self.plot.getColorBarWidget().setVisible(True)
+ self.qWait(50)
+
+ # Remove and add again the same item
+ image = self.plot.getImage("foo")
+ self.plot.removeImage("foo")
+ self.plot.addItem(image)
+ self.qWait(50)
+ finally:
+ Colormap._computeAutoscaleRange = old
+ self.assertEqual(self._count, 1)
+ del self._count
+
+ @pytest.mark.usefixtures("use_opengl")
+ def testSwitchBackend(self):
+ """Test switching an empty plot"""
+ self.plot.resetZoom()
+ xlimits = self.plot.getXAxis().getLimits()
+ ylimits = self.plot.getYAxis().getLimits()
+ isKeepAspectRatio = self.plot.isKeepDataAspectRatio()
+
+ for backend in ('gl', 'mpl'):
+ with self.subTest():
+ self.plot.setBackend(backend)
+ self.plot.replot()
+ self.assertEqual(self.plot.getXAxis().getLimits(), xlimits)
+ self.assertEqual(self.plot.getYAxis().getLimits(), ylimits)
+ self.assertEqual(
+ self.plot.isKeepDataAspectRatio(), isKeepAspectRatio)
diff --git a/src/silx/gui/plot/test/testRoiStatsWidget.py b/src/silx/gui/plot/test/testRoiStatsWidget.py
new file mode 100644
index 0000000..eb29267
--- /dev/null
+++ b/src/silx/gui/plot/test/testRoiStatsWidget.py
@@ -0,0 +1,277 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ROIStatsWidget"""
+
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.stats.stats import Stats
+from silx.gui.plot.ROIStatsWidget import ROIStatsWidget
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.gui.plot.items.roi import RectangleROI, PolygonROI
+from silx.gui.plot.StatsWidget import UpdateMode
+import unittest
+import numpy
+
+
+
+class _TestRoiStatsBase(TestCaseQt):
+ """Base class for several unittest relative to ROIStatsWidget"""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ # define plot
+ self.plot = PlotWindow()
+ self.plot.addImage(numpy.arange(10000).reshape(100, 100),
+ legend='img1')
+ self.img_item = self.plot.getImage('img1')
+ self.plot.addCurve(x=numpy.linspace(0, 10, 56), y=numpy.arange(56),
+ legend='curve1')
+ self.curve_item = self.plot.getCurve('curve1')
+ self.plot.addHistogram(edges=numpy.linspace(0, 10, 56),
+ histogram=numpy.arange(56), legend='histo1')
+ self.histogram_item = self.plot.getHistogram(legend='histo1')
+ self.plot.addScatter(x=numpy.linspace(0, 10, 56),
+ y=numpy.linspace(0, 10, 56),
+ value=numpy.arange(56),
+ legend='scatter1')
+ self.scatter_item = self.plot.getScatter(legend='scatter1')
+
+ # stats widget
+ self.statsWidget = ROIStatsWidget(plot=self.plot)
+
+ # define stats
+ stats = [
+ ('sum', numpy.sum),
+ ('mean', numpy.mean),
+ ]
+ self.statsWidget.setStats(stats=stats)
+
+ # define rois
+ self.roi1D = ROI(name='range1', fromdata=0, todata=4, type_='energy')
+ self.rectangle_roi = RectangleROI()
+ self.rectangle_roi.setGeometry(origin=(0, 0), size=(20, 20))
+ self.rectangle_roi.setName('Initial ROI')
+ self.polygon_roi = PolygonROI()
+ points = numpy.array([[0, 5], [5, 0], [10, 5], [5, 10]])
+ self.polygon_roi.setPoints(points)
+
+ def statsTable(self):
+ return self.statsWidget._statsROITable
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.statsWidget.close()
+ self.statsWidget = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose, True)
+ self.plot.close()
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+
+class TestRoiStatsCouple(_TestRoiStatsBase):
+ """
+ Test different possible couple (roi, plotItem).
+ Check that:
+
+ * computation is correct if couple is valid
+ * raise an error if couple is invalid
+ """
+ def testROICurve(self):
+ """
+ Test that the couple (ROI, curveItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.curve_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '253')
+ self.assertEqual(tableItems['mean'].text(), '11.0')
+
+ def testRectangleImage(self):
+ """
+ Test that the couple (RectangleROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+ assert item is not None
+ self.plot.addImage(numpy.ones(10000).reshape(100, 100),
+ legend='img1')
+ self.qapp.processEvents()
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), str(float(21*21)))
+ self.assertEqual(tableItems['mean'].text(), '1.0')
+
+ def testPolygonImage(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.polygon_roi,
+ plotItem=self.img_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '22750')
+ self.assertEqual(tableItems['mean'].text(), '455.0')
+
+ def testROIImage(self):
+ """
+ Test that the couple (ROI, imageItem) is raising an error
+ """
+ with self.assertRaises(TypeError):
+ self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.img_item)
+
+ def testRectangleCurve(self):
+ """
+ Test that the couple (rectangleROI, curveItem) is raising an error
+ """
+ with self.assertRaises(TypeError):
+ self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.curve_item)
+
+ def testROIHistogram(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.histogram_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '253')
+ self.assertEqual(tableItems['mean'].text(), '11.0')
+
+ def testROIScatter(self):
+ """
+ Test that the couple (PolygonROI, imageItem) can be used for stats
+ """
+ item = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.scatter_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '253')
+ self.assertEqual(tableItems['mean'].text(), '11.0')
+
+
+class TestRoiStatsAddRemoveItem(_TestRoiStatsBase):
+ """Test adding and removing (roi, plotItem) items"""
+ def testAddRemoveItems(self):
+ item1 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.scatter_item)
+ self.assertTrue(item1 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 1)
+ item2 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.histogram_item)
+ self.assertTrue(item2 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ # try to add twice the same item
+ item3 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.histogram_item)
+ self.assertTrue(item3 is None)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ item4 = self.statsWidget.addItem(roi=self.roi1D,
+ plotItem=self.curve_item)
+ self.assertTrue(item4 is not None)
+ self.assertEqual(self.statsTable().rowCount(), 3)
+
+ self.statsWidget.removeItem(plotItem=item4._plot_item,
+ roi=item4._roi)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ # try to remove twice the same item
+ self.statsWidget.removeItem(plotItem=item4._plot_item,
+ roi=item4._roi)
+ self.assertEqual(self.statsTable().rowCount(), 2)
+ self.statsWidget.removeItem(plotItem=item2._plot_item,
+ roi=item2._roi)
+ self.statsWidget.removeItem(plotItem=item1._plot_item,
+ roi=item1._roi)
+ self.assertEqual(self.statsTable().rowCount(), 0)
+
+
+class TestRoiStatsRoiUpdate(_TestRoiStatsBase):
+ """Test that the stats will be updated if the roi is updated"""
+ def testChangeRoi(self):
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '445410')
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+
+ # update roi
+ self.rectangle_roi.setOrigin(position=(10, 10))
+ self.assertNotEqual(tableItems['sum'].text(), '445410')
+ self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+
+ def testUpdateModeScenario(self):
+ """Test update according to a simple scenario"""
+ self.statsWidget._setUpdateMode(UpdateMode.AUTO)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['sum'].text(), '445410')
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
+ self.rectangle_roi.setOrigin(position=(10, 10))
+ self.qapp.processEvents()
+ self.assertNotEqual(tableItems['sum'].text(), '445410')
+ self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+ self.statsWidget._updateAllStats(is_request=True)
+ self.assertNotEqual(tableItems['sum'].text(), '445410')
+ self.assertNotEqual(tableItems['mean'].text(), '1010.0')
+
+
+class TestRoiStatsPlotItemUpdate(_TestRoiStatsBase):
+ """Test that the stats will be updated if the plot item is updated"""
+ def testChangeImage(self):
+ self.statsWidget._setUpdateMode(UpdateMode.AUTO)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+
+ # update plot
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
+ legend='img1')
+ self.assertNotEqual(tableItems['mean'].text(), '1059.5')
+
+ def testUpdateModeScenario(self):
+ """Test update according to a simple scenario"""
+ self.statsWidget._setUpdateMode(UpdateMode.MANUAL)
+ item = self.statsWidget.addItem(roi=self.rectangle_roi,
+ plotItem=self.img_item)
+
+ assert item is not None
+ tableItems = self.statsTable()._itemToTableItems(item)
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.plot.addImage(numpy.arange(100, 10100).reshape(100, 100),
+ legend='img1')
+ self.assertEqual(tableItems['mean'].text(), '1010.0')
+ self.statsWidget._updateAllStats(is_request=True)
+ self.assertEqual(tableItems['mean'].text(), '1110.0')
diff --git a/src/silx/gui/plot/test/testSaveAction.py b/src/silx/gui/plot/test/testSaveAction.py
new file mode 100644
index 0000000..9280fb6
--- /dev/null
+++ b/src/silx/gui/plot/test/testSaveAction.py
@@ -0,0 +1,132 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test the plot's save action (consistency of output)"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "28/11/2017"
+
+
+import unittest
+import tempfile
+import os
+
+from silx.gui.plot.test.utils import PlotWidgetTestCase
+
+from silx.gui.plot import PlotWidget
+from silx.gui.plot.actions.io import SaveAction
+
+
+class TestSaveActionSaveCurvesAsSpec(unittest.TestCase):
+
+ def setUp(self):
+ self.plot = PlotWidget(backend='none')
+ self.saveAction = SaveAction(plot=self.plot)
+
+ self.tempdir = tempfile.mkdtemp()
+ self.out_fname = os.path.join(self.tempdir, "out.dat")
+
+ def tearDown(self):
+ os.unlink(self.out_fname)
+ os.rmdir(self.tempdir)
+
+ def testSaveMultipleCurvesAsSpec(self):
+ """Test that labels are properly used."""
+ self.plot.setGraphXLabel("graph x label")
+ self.plot.setGraphYLabel("graph y label")
+
+ self.plot.addCurve([0, 1], [1, 2], "curve with labels",
+ xlabel="curve0 X", ylabel="curve0 Y")
+ self.plot.addCurve([-1, 3], [-6, 2], "curve with X label",
+ xlabel="curve1 X")
+ self.plot.addCurve([-2, 0], [8, 12], "curve with Y label",
+ ylabel="curve2 Y")
+ self.plot.addCurve([3, 1], [7, 6], "curve with no labels")
+
+ self.saveAction._saveCurves(self.plot,
+ self.out_fname,
+ SaveAction.DEFAULT_ALL_CURVES_FILTERS[0]) # "All curves as SpecFile (*.dat)"
+
+ with open(self.out_fname, "rb") as f:
+ file_content = f.read()
+ if hasattr(file_content, "decode"):
+ file_content = file_content.decode()
+
+ # case with all curve labels specified
+ self.assertIn("#S 1 curve0 Y", file_content)
+ self.assertIn("#L curve0 X curve0 Y", file_content)
+
+ # graph X&Y labels are used when no curve label is specified
+ self.assertIn("#S 2 graph y label", file_content)
+ self.assertIn("#L curve1 X graph y label", file_content)
+
+ self.assertIn("#S 3 curve2 Y", file_content)
+ self.assertIn("#L graph x label curve2 Y", file_content)
+
+ self.assertIn("#S 4 graph y label", file_content)
+ self.assertIn("#L graph x label graph y label", file_content)
+
+
+class TestSaveActionExtension(PlotWidgetTestCase):
+ """Test SaveAction file filter API"""
+
+ def _dummySaveFunction(self, plot, filename, nameFilter):
+ pass
+
+ def testFileFilterAPI(self):
+ """Test addition/update of a file filter"""
+ saveAction = SaveAction(plot=self.plot, parent=self.plot)
+
+ # Add a new file filter
+ nameFilter = 'Dummy file (*.dummy)'
+ saveAction.setFileFilter('all', nameFilter, self._dummySaveFunction)
+ self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
+ self.assertEqual(saveAction.getFileFilters('all')[nameFilter],
+ self._dummySaveFunction)
+
+ # Add a new file filter at a particular position
+ nameFilter = 'Dummy file2 (*.dummy)'
+ saveAction.setFileFilter('all', nameFilter,
+ self._dummySaveFunction, index=3)
+ self.assertTrue(nameFilter in saveAction.getFileFilters('all'))
+ filters = saveAction.getFileFilters('all')
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter),3)
+
+ # Update an existing file filter
+ nameFilter = SaveAction.IMAGE_FILTER_EDF
+ saveAction.setFileFilter('image', nameFilter, self._dummySaveFunction)
+ self.assertEqual(saveAction.getFileFilters('image')[nameFilter],
+ self._dummySaveFunction)
+
+ # Change the position of an existing file filter
+ nameFilter = 'Dummy file2 (*.dummy)'
+ oldIndex = list(saveAction.getFileFilters('all')).index(nameFilter)
+ newIndex = oldIndex - 1
+ saveAction.setFileFilter('all', nameFilter,
+ self._dummySaveFunction, index=newIndex)
+ filters = saveAction.getFileFilters('all')
+ self.assertEqual(filters[nameFilter], self._dummySaveFunction)
+ self.assertEqual(list(filters.keys()).index(nameFilter), newIndex)
diff --git a/src/silx/gui/plot/test/testScatterMaskToolsWidget.py b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py
new file mode 100644
index 0000000..447ee58
--- /dev/null
+++ b/src/silx/gui/plot/test/testScatterMaskToolsWidget.py
@@ -0,0 +1,306 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for MaskToolsWidget"""
+
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import os.path
+import unittest
+
+import numpy
+
+from silx.gui import qt
+from silx.test.utils import temp_dir
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import getQToolButtonFromAction
+from silx.gui.plot import PlotWindow, ScatterMaskToolsWidget
+from .utils import PlotWidgetTestCase
+
+import fabio
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestScatterMaskToolsWidget(PlotWidgetTestCase, ParametricTestCase):
+ """Basic test for MaskToolsWidget"""
+
+ def _createPlot(self):
+ return PlotWindow()
+
+ def setUp(self):
+ super(TestScatterMaskToolsWidget, self).setUp()
+ self.widget = ScatterMaskToolsWidget.ScatterMaskToolsDockWidget(
+ plot=self.plot, name='TEST')
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, self.widget)
+
+ self.maskWidget = self.widget.widget()
+
+ def tearDown(self):
+ del self.maskWidget
+ del self.widget
+ super(TestScatterMaskToolsWidget, self).tearDown()
+
+ def testEmptyPlot(self):
+ """Empty plot, display MaskToolsDockWidget, toggle multiple masks"""
+ self.maskWidget.setMultipleMasks('single')
+ self.qapp.processEvents()
+
+ self.maskWidget.setMultipleMasks('exclusive')
+ self.qapp.processEvents()
+
+ def _drag(self):
+ """Drag from plot center to offset position"""
+ plot = self.plot.getWidgetHandle()
+ xCenter, yCenter = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ pos0 = xCenter, yCenter
+ pos1 = xCenter + offset, yCenter + offset
+
+ self.mouseMove(plot, pos=(0, 0))
+ self.mouseMove(plot, pos=pos0)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos0)
+ self.qapp.processEvents()
+
+ self.mouseMove(plot, pos=(pos0[0] + offset // 2, pos0[1] + offset // 2))
+ self.mouseMove(plot, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos1)
+ self.qapp.processEvents()
+ self.mouseMove(plot, pos=(0, 0))
+
+ def _drawPolygon(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset),
+ (x, y + offset)] # Close polygon
+
+ self.mouseMove(plot, pos=[0, 0])
+ for pos in star:
+ self.mouseMove(plot, pos=pos)
+ self.qapp.processEvents()
+ self.mousePress(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+ self.mouseRelease(plot, qt.Qt.LeftButton, pos=pos)
+ self.qapp.processEvents()
+
+ def _drawPencil(self):
+ """Draw a star polygon in the plot"""
+ plot = self.plot.getWidgetHandle()
+ x, y = plot.width() // 2, plot.height() // 2
+ offset = min(plot.width(), plot.height()) // 10
+
+ star = [(x, y + offset),
+ (x - offset, y - offset),
+ (x + offset, y),
+ (x - offset, y),
+ (x + offset, y - offset)]
+
+ self.mouseMove(plot, pos=[0, 0])
+ self.mouseMove(plot, pos=star[0])
+ self.mousePress(plot, qt.Qt.LeftButton, pos=star[0])
+ for pos in star[1:]:
+ self.mouseMove(plot, pos=pos)
+ self.mouseRelease(
+ plot, qt.Qt.LeftButton, pos=star[-1])
+
+ def testWithAScatter(self):
+ """Plot with a Scatter: test MaskToolsWidget interactions"""
+
+ # Add and remove a scatter (this should enable/disable GUI + change mask)
+ self.plot.addScatter(
+ x=numpy.arange(256),
+ y=numpy.arange(256),
+ value=numpy.random.random(256),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.qapp.processEvents()
+
+ self.plot.remove('test', kind='scatter')
+ self.qapp.processEvents()
+
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ # Test draw rectangle #
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drag()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw polygon #
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPolygon()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test draw pencil #
+ toolButton = getQToolButtonFromAction(self.maskWidget.pencilAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.maskWidget.pencilSpinBox.setValue(30)
+ self.qapp.processEvents()
+
+ # mask
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertFalse(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # unmask same region
+ self.maskWidget.maskStateGroup.button(0).click()
+ self.qapp.processEvents()
+ self._drawPencil()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ # Test no draw tool #
+ toolButton = getQToolButtonFromAction(self.maskWidget.browseAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+
+ self.plot.clear()
+
+ def __loadSave(self, file_format):
+ self.plot.addScatter(
+ x=numpy.arange(256),
+ y=25 * (numpy.arange(256) % 10),
+ value=numpy.random.random(256),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ # Draw a polygon mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.polygonAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self._drawPolygon()
+
+ ref_mask = self.maskWidget.getSelectionMask()
+ self.assertFalse(numpy.all(numpy.equal(ref_mask, 0)))
+
+ with temp_dir() as tmp:
+ mask_filename = os.path.join(tmp, 'mask.' + file_format)
+ self.maskWidget.save(mask_filename, file_format)
+
+ self.maskWidget.resetSelectionMask()
+ self.assertTrue(
+ numpy.all(numpy.equal(self.maskWidget.getSelectionMask(), 0)))
+
+ self.maskWidget.load(mask_filename)
+ self.assertTrue(numpy.all(numpy.equal(
+ self.maskWidget.getSelectionMask(), ref_mask)))
+
+ def testLoadSaveNpy(self):
+ self.__loadSave("npy")
+
+ def testLoadSaveCsv(self):
+ self.__loadSave("csv")
+
+ def testSigMaskChangedEmitted(self):
+ self.qapp.processEvents()
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.ones((1000,)),
+ legend='test')
+ self.plot._setActiveItem(kind="scatter", legend="test")
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ self.plot.remove('test', kind='scatter')
+ self.qapp.processEvents()
+
+ self.plot.addScatter(
+ x=numpy.arange(1000),
+ y=1000 * (numpy.arange(1000) % 20),
+ value=numpy.random.random(1000),
+ legend='test')
+
+ l = []
+
+ def slot():
+ l.append(1)
+
+ self.maskWidget.sigMaskChanged.connect(slot)
+
+ # rectangle mask
+ toolButton = getQToolButtonFromAction(self.maskWidget.rectAction)
+ self.assertIsNot(toolButton, None)
+ self.mouseClick(toolButton, qt.Qt.LeftButton)
+ self.maskWidget.maskStateGroup.button(1).click()
+ self.qapp.processEvents()
+ self._drag()
+
+ self.assertGreater(len(l), 0)
diff --git a/src/silx/gui/plot/test/testScatterView.py b/src/silx/gui/plot/test/testScatterView.py
new file mode 100644
index 0000000..d11d4d8
--- /dev/null
+++ b/src/silx/gui/plot/test/testScatterView.py
@@ -0,0 +1,123 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for ScatterView"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2018"
+
+
+import unittest
+
+import numpy
+
+from silx.gui.plot.items import Axis, Scatter
+from silx.gui.plot import ScatterView
+from silx.gui.plot.test.utils import PlotWidgetTestCase
+
+
+class TestScatterView(PlotWidgetTestCase):
+ """Test of ScatterView widget"""
+
+ def _createPlot(self):
+ return ScatterView()
+
+ def test(self):
+ """Simple tests"""
+ x = numpy.arange(100)
+ y = numpy.arange(100)
+ value = numpy.arange(100)
+ self.plot.setData(x, y, value)
+ self.qapp.processEvents()
+
+ data = self.plot.getData()
+ self.assertEqual(len(data), 5)
+ self.assertTrue(numpy.all(numpy.equal(x, data[0])))
+ self.assertTrue(numpy.all(numpy.equal(y, data[1])))
+ self.assertTrue(numpy.all(numpy.equal(value, data[2])))
+ self.assertIsNone(data[3]) # xerror
+ self.assertIsNone(data[4]) # yerror
+
+ # Test access to scatter item
+ self.assertIsInstance(self.plot.getScatterItem(), Scatter)
+
+ # Test toolbar actions
+
+ action = self.plot.getScatterToolBar().getXAxisLogarithmicAction()
+ action.trigger()
+ self.qapp.processEvents()
+
+ maskAction = self.plot.getScatterToolBar().actions()[-1]
+ maskAction.trigger()
+ self.qapp.processEvents()
+
+ # Test proxy API
+
+ self.plot.resetZoom()
+ self.qapp.processEvents()
+
+ scale = self.plot.getXAxis().getScale()
+ self.assertEqual(scale, Axis.LOGARITHMIC)
+
+ scale = self.plot.getYAxis().getScale()
+ self.assertEqual(scale, Axis.LINEAR)
+
+ title = 'Test ScatterView'
+ self.plot.setGraphTitle(title)
+ self.assertEqual(self.plot.getGraphTitle(), title)
+
+ self.qapp.processEvents()
+
+ # Reset scatter data
+
+ self.plot.setData(None, None, None)
+ self.qapp.processEvents()
+
+ data = self.plot.getData()
+ self.assertEqual(len(data), 5)
+ self.assertEqual(len(data[0]), 0) # x
+ self.assertEqual(len(data[1]), 0) # y
+ self.assertEqual(len(data[2]), 0) # value
+ self.assertIsNone(data[3]) # xerror
+ self.assertIsNone(data[4]) # yerror
+
+ def testAlpha(self):
+ """Test alpha transparency in setData"""
+ _pts = 100
+ _levels = 100
+ _fwhm = 50
+ x = numpy.random.rand(_pts)*_levels
+ y = numpy.random.rand(_pts)*_levels
+ value = numpy.random.rand(_pts)*_levels
+ x0 = x[int(_pts/2)]
+ y0 = x[int(_pts/2)]
+ #2D Gaussian kernel
+ alpha = numpy.exp(-4*numpy.log(2) * ((x-x0)**2 + (y-y0)**2) / _fwhm**2)
+
+ self.plot.setData(x, y, value, alpha=alpha)
+ self.qapp.processEvents()
+
+ alphaData = self.plot.getScatterItem().getAlphaData()
+ self.assertTrue(numpy.all(numpy.equal(alpha, alphaData)))
diff --git a/src/silx/gui/plot/test/testStackView.py b/src/silx/gui/plot/test/testStackView.py
new file mode 100644
index 0000000..0d18113
--- /dev/null
+++ b/src/silx/gui/plot/test/testStackView.py
@@ -0,0 +1,248 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for StackView"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "20/03/2017"
+
+
+import unittest
+import numpy
+
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+
+from silx.gui import qt
+from silx.gui.plot import StackView
+from silx.gui.plot.StackView import StackViewMainWindow
+
+from silx.utils.array_like import ListOfImages
+
+
+class TestStackView(TestCaseQt):
+ """Base class for tests of StackView."""
+
+ def setUp(self):
+ super(TestStackView, self).setUp()
+ self.stackview = StackView()
+ self.stackview.show()
+ self.qWaitForWindowExposed(self.stackview)
+ self.mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (10, 20, 30)
+ )
+
+ def tearDown(self):
+ self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.stackview.close()
+ del self.stackview
+ super(TestStackView, self).tearDown()
+
+ def testScaleColormapRangeToStack(self):
+ """Test scaleColormapRangeToStack"""
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis")
+ colormap = self.stackview.getColormap()
+
+ # Colormap autoscale to image
+ self.assertEqual(colormap.getVRange(), (None, None))
+ self.stackview.scaleColormapRangeToStack()
+
+ # Colormap range set according to stack range
+ self.assertEqual(colormap.getVRange(), (self.mystack.min(), self.mystack.max()))
+
+ def testSetStack(self):
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis", autoscale=True)
+ my_trans_stack, params = self.stackview.getStack()
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_trans_stack))
+ self.assertEqual(params["colormap"]["name"],
+ "viridis")
+
+ def testSetStackPerspective(self):
+ self.stackview.setStack(self.mystack, perspective=1)
+ # my_orig_stack, params = self.stackview.getStack()
+ my_trans_stack, params = self.stackview.getCurrentView()
+
+ # get stack returns the transposed data, depending on the perspective
+ self.assertEqual(my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
+ self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
+ my_trans_stack))
+
+ def testSetStackListOfImages(self):
+ loi = [self.mystack[i] for i in range(self.mystack.shape[0])]
+
+ self.stackview.setStack(loi)
+ my_orig_stack, params = self.stackview.getStack(returnNumpyArray=True)
+ my_trans_stack, params = self.stackview.getStack(returnNumpyArray=True)
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_trans_stack))
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_orig_stack))
+ self.assertIsInstance(my_trans_stack, numpy.ndarray)
+
+ self.stackview.setStack(loi, perspective=2)
+ my_orig_stack, params = self.stackview.getStack(copy=False)
+ my_trans_stack, params = self.stackview.getCurrentView(copy=False)
+ # getStack(copy=False) must return the object set in setStack
+ self.assertIs(my_orig_stack, loi)
+ # getCurrentView(copy=False) returns a ListOfImages whose .images
+ # attr is the original data
+ self.assertEqual(my_trans_stack.shape,
+ (self.mystack.shape[2], self.mystack.shape[0], self.mystack.shape[1]))
+ self.assertTrue(numpy.array_equal(numpy.array(my_trans_stack),
+ numpy.transpose(self.mystack, axes=(2, 0, 1))))
+ self.assertIsInstance(my_trans_stack,
+ ListOfImages) # returnNumpyArray=False by default in getStack
+ self.assertIs(my_trans_stack.images, loi)
+
+ def testPerspective(self):
+ self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)))
+ self.assertEqual(self.stackview._perspective, 0,
+ "Default perspective is not 0 (dim1-dim2).")
+
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.assertEqual(self.stackview._perspective, 1,
+ "Plane selection combobox not updating perspective")
+
+ self.stackview.setStack(numpy.arange(6).reshape((1, 2, 3)))
+ self.assertEqual(self.stackview._perspective, 1,
+ "Perspective not preserved when calling setStack "
+ "without specifying the perspective parameter.")
+
+ self.stackview.setStack(numpy.arange(24).reshape((2, 3, 4)), perspective=2)
+ self.assertEqual(self.stackview._perspective, 2,
+ "Perspective not set in setStack(..., perspective=2).")
+
+ def testDefaultTitle(self):
+ """Test that the plot title contains the proper Z information"""
+ self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)),
+ calibrations=[(0, 1), (-10, 10), (3.14, 3.14)])
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=2")
+
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=-10")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=10")
+
+ self.stackview._StackView__planeSelection.setPerspective(2)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=3.14")
+ self.stackview.setFrameNumber(1)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Image z=6.28")
+
+ def testCustomTitle(self):
+ """Test setting the plot title with a user defined callback"""
+ self.stackview.setStack(numpy.arange(24).reshape((4, 3, 2)),
+ calibrations=[(0, 1), (-10, 10), (3.14, 3.14)])
+
+ def title_callback(frame_idx):
+ return "Cubed index title %d" % (frame_idx**3)
+
+ self.stackview.setTitleCallback(title_callback)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Cubed index title 0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Cubed index title 8")
+
+ # perspective should not matter, only frame index
+ self.stackview._StackView__planeSelection.setPerspective(1)
+ self.stackview.setFrameNumber(0)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Cubed index title 0")
+ self.stackview.setFrameNumber(2)
+ self.assertEqual(self.stackview._plot.getGraphTitle(),
+ "Cubed index title 8")
+
+ with self.assertRaises(TypeError):
+ # setTitleCallback should not accept non-callable objects like strings
+ self.stackview.setTitleCallback(
+ "Là, vous faites sirop de vingt-et-un et vous dites : "
+ "beau sirop, mi-sirop, siroté, gagne-sirop, sirop-grelot,"
+ " passe-montagne, sirop au bon goût.")
+
+ def testStackFrameNumber(self):
+ self.stackview.setStack(self.mystack)
+ self.assertEqual(self.stackview.getFrameNumber(), 0)
+
+ listener = SignalListener()
+ self.stackview.sigFrameChanged.connect(listener)
+
+ self.stackview.setFrameNumber(1)
+ self.assertEqual(self.stackview.getFrameNumber(), 1)
+ self.assertEqual(listener.arguments(), [(1,)])
+
+
+class TestStackViewMainWindow(TestCaseQt):
+ """Base class for tests of StackView."""
+
+ def setUp(self):
+ super(TestStackViewMainWindow, self).setUp()
+ self.stackview = StackViewMainWindow()
+ self.stackview.show()
+ self.qWaitForWindowExposed(self.stackview)
+ self.mystack = numpy.fromfunction(
+ lambda i, j, k: numpy.sin(i/15.) + numpy.cos(j/4.) + 2 * numpy.sin(k/6.),
+ (10, 20, 30)
+ )
+
+ def tearDown(self):
+ self.stackview.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.stackview.close()
+ del self.stackview
+ super(TestStackViewMainWindow, self).tearDown()
+
+ def testSetStack(self):
+ self.stackview.setStack(self.mystack)
+ self.stackview.setColormap("viridis", autoscale=True)
+ my_trans_stack, params = self.stackview.getStack()
+ self.assertEqual(my_trans_stack.shape, self.mystack.shape)
+ self.assertTrue(numpy.array_equal(self.mystack,
+ my_trans_stack))
+ self.assertEqual(params["colormap"]["name"],
+ "viridis")
+
+ def testSetStackPerspective(self):
+ self.stackview.setStack(self.mystack, perspective=1)
+ my_trans_stack, params = self.stackview.getCurrentView()
+ # get stack returns the transposed data, depending on the perspective
+ self.assertEqual(my_trans_stack.shape,
+ (self.mystack.shape[1], self.mystack.shape[0], self.mystack.shape[2]))
+ self.assertTrue(numpy.array_equal(numpy.transpose(self.mystack, axes=(1, 0, 2)),
+ my_trans_stack))
diff --git a/src/silx/gui/plot/test/testStats.py b/src/silx/gui/plot/test/testStats.py
new file mode 100644
index 0000000..0a792a4
--- /dev/null
+++ b/src/silx/gui/plot/test/testStats.py
@@ -0,0 +1,1047 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for CurvesROIWidget"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "07/03/2018"
+
+
+from silx.gui import qt
+from silx.gui.plot.stats import stats
+from silx.gui.plot import StatsWidget
+from silx.gui.plot.stats import statshandler
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+from silx.gui.plot import Plot1D, Plot2D
+from silx.gui.plot3d.SceneWidget import SceneWidget
+from silx.gui.plot.items.roi import RectangleROI, PolygonROI
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.stats.stats import Stats
+from silx.gui.plot.CurvesROIWidget import ROI
+from silx.utils.testutils import ParametricTestCase
+import unittest
+import logging
+import numpy
+
+_logger = logging.getLogger(__name__)
+
+
+class TestStatsBase(object):
+ """Base class for stats TestCase"""
+ def setUp(self):
+ self.createCurveContext()
+ self.createImageContext()
+ self.createScatterContext()
+
+ def tearDown(self):
+ self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot1d.close()
+ del self.plot1d
+ self.plot2d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot2d.close()
+ del self.plot2d
+ self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.scatterPlot.close()
+ del self.scatterPlot
+
+ def createCurveContext(self):
+ self.plot1d = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot1d.addCurve(x, y, legend='curve0')
+
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=None)
+
+ def createScatterContext(self):
+ self.scatterPlot = Plot2D()
+ lgd = 'scatter plot'
+ self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36])
+ self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18])
+ self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5])
+ self.scatterPlot.addScatter(self.xScatterData, self.yScatterData,
+ self.valuesScatterData, legend=lgd)
+ self.scatterContext = stats._ScatterContext(
+ item=self.scatterPlot.getScatter(lgd),
+ plot=self.scatterPlot,
+ onlimits=False,
+ roi=None
+ )
+
+ def createImageContext(self):
+ self.plot2d = Plot2D()
+ self._imgLgd = 'test image'
+ self.imageData = numpy.arange(32*128).reshape(32, 128)
+ self.plot2d.addImage(data=self.imageData,
+ legend=self._imgLgd, replace=False)
+ self.imageContext = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=None
+ )
+
+ def getBasicStats(self):
+ return {
+ 'min': stats.StatMin(),
+ 'minCoords': stats.StatCoordMin(),
+ 'max': stats.StatMax(),
+ 'maxCoords': stats.StatCoordMax(),
+ 'std': stats.Stat(name='std', fct=numpy.std),
+ 'mean': stats.Stat(name='mean', fct=numpy.mean),
+ 'com': stats.StatCOM()
+ }
+
+
+class TestStats(TestStatsBase, TestCaseQt):
+ """
+ Test :class:`BaseClass` class and inheriting classes
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ TestStatsBase.setUp(self)
+
+ def tearDown(self):
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
+ def testBasicStatsCurve(self):
+ """Test result for simple stats on a curve"""
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(20))
+ self.assertEqual(_stats['min'].calculate(self.curveContext), 0)
+ self.assertEqual(_stats['max'].calculate(self.curveContext), 19)
+ self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (0,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (19,))
+ self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData))
+ self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData))
+ com = numpy.sum(xData * yData) / numpy.sum(yData)
+ self.assertEqual(_stats['com'].calculate(self.curveContext), com)
+
+ def testBasicStatsImage(self):
+ """Test result for simple stats on an image"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext), 0)
+ self.assertEqual(_stats['max'].calculate(self.imageContext), 128 * 32 - 1)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (0, 0))
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (127, 31))
+ self.assertEqual(_stats['std'].calculate(self.imageContext), numpy.std(self.imageData))
+ self.assertEqual(_stats['mean'].calculate(self.imageContext), numpy.mean(self.imageData))
+
+ yData = numpy.sum(self.imageData.astype(numpy.float64), axis=1)
+ xData = numpy.sum(self.imageData.astype(numpy.float64), axis=0)
+ dataXRange = range(self.imageData.shape[1])
+ dataYRange = range(self.imageData.shape[0])
+
+ ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
+
+ self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
+
+ def testStatsImageAdv(self):
+ """Test that scale and origin are taking into account for images"""
+
+ image2Data = numpy.arange(32 * 128).reshape(32, 128)
+ self.plot2d.addImage(data=image2Data, legend=self._imgLgd,
+ replace=True, origin=(100, 10), scale=(2, 0.5))
+ image2Context = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=None,
+ )
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(image2Context), 0)
+ self.assertEqual(
+ _stats['max'].calculate(image2Context), 128 * 32 - 1)
+ self.assertEqual(
+ _stats['minCoords'].calculate(image2Context), (100, 10))
+ self.assertEqual(
+ _stats['maxCoords'].calculate(image2Context), (127*2. + 100,
+ 31 * 0.5 + 10))
+ self.assertEqual(_stats['std'].calculate(image2Context),
+ numpy.std(self.imageData))
+ self.assertEqual(_stats['mean'].calculate(image2Context),
+ numpy.mean(self.imageData))
+
+ yData = numpy.sum(self.imageData, axis=1)
+ xData = numpy.sum(self.imageData, axis=0)
+ dataXRange = numpy.arange(self.imageData.shape[1], dtype=numpy.float64)
+ dataYRange = numpy.arange(self.imageData.shape[0], dtype=numpy.float64)
+
+ ycom = numpy.sum(yData * dataYRange) / numpy.sum(yData)
+ ycom = (ycom * 0.5) + 10
+ xcom = numpy.sum(xData * dataXRange) / numpy.sum(xData)
+ xcom = (xcom * 2.) + 100
+ self.assertTrue(numpy.allclose(
+ _stats['com'].calculate(image2Context), (xcom, ycom)))
+
+ def testBasicStatsScatter(self):
+ """Test result for simple stats on a scatter"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.scatterContext), 5)
+ self.assertEqual(_stats['max'].calculate(self.scatterContext), 90)
+ self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (0, 2))
+ self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (50, 69))
+ self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std(self.valuesScatterData))
+ self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean(self.valuesScatterData))
+
+ data = self.valuesScatterData.astype(numpy.float64)
+ comx = numpy.sum(self.xScatterData * data) / numpy.sum(data)
+ comy = numpy.sum(self.yScatterData * data) / numpy.sum(data)
+ self.assertEqual(_stats['com'].calculate(self.scatterContext),
+ (comx, comy))
+
+ def testKindNotManagedByStat(self):
+ """Make sure an exception is raised if we try to execute calculate
+ of the base class"""
+ b = stats.StatBase(name='toto', compatibleKinds='curve')
+ with self.assertRaises(NotImplementedError):
+ b.calculate(self.imageContext)
+
+ def testKindNotManagedByContext(self):
+ """
+ Make sure an error is raised if we try to calculate a statistic with
+ a context not managed
+ """
+ myStat = stats.Stat(name='toto', fct=numpy.std, kinds=('curve'))
+ myStat.calculate(self.curveContext)
+ with self.assertRaises(ValueError):
+ myStat.calculate(self.scatterContext)
+ with self.assertRaises(ValueError):
+ myStat.calculate(self.imageContext)
+
+ def testOnLimits(self):
+ stat = stats.StatMin()
+
+ self.plot1d.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
+ curveContextOnLimits = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=True,
+ roi=None)
+ self.assertEqual(stat.calculate(curveContextOnLimits), 2)
+
+ self.plot2d.getXAxis().setLimitsConstraints(minPos=32)
+ imageContextOnLimits = stats._ImageContext(
+ item=self.plot2d.getImage('test image'),
+ plot=self.plot2d,
+ onlimits=True,
+ roi=None)
+ self.assertEqual(stat.calculate(imageContextOnLimits), 32)
+
+ self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40)
+ scatterContextOnLimits = stats._ScatterContext(
+ item=self.scatterPlot.getScatter('scatter plot'),
+ plot=self.scatterPlot,
+ onlimits=True,
+ roi=None)
+ self.assertEqual(stat.calculate(scatterContextOnLimits), 20)
+
+
+class TestStatsFormatter(TestCaseQt):
+ """Simple test to check usage of the :class:`StatsFormatter`"""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot1d = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot1d.addCurve(x, y, legend='curve0')
+
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=None)
+
+ self.stat = stats.StatMin()
+
+ def tearDown(self):
+ self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot1d.close()
+ del self.plot1d
+ TestCaseQt.tearDown(self)
+
+ def testEmptyFormatter(self):
+ """Make sure a formatter with no formatter definition will return a
+ simple cast to str"""
+ emptyFormatter = statshandler.StatFormatter()
+ self.assertEqual(
+ emptyFormatter.format(self.stat.calculate(self.curveContext)), '0.000')
+
+ def testSettedFormatter(self):
+ """Make sure a formatter with no formatter definition will return a
+ simple cast to str"""
+ formatter= statshandler.StatFormatter(formatter='{0:.3f}')
+ self.assertEqual(
+ formatter.format(self.stat.calculate(self.curveContext)), '0.000')
+
+
+class TestStatsHandler(TestCaseQt):
+ """Make sure the StatHandler is correctly making the link between
+ :class:`StatBase` and :class:`StatFormatter` and checking the API is valid
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot1d = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plot1d.addCurve(x, y, legend='curve0')
+ self.curveItem = self.plot1d.getCurve('curve0')
+
+ self.stat = stats.StatMin()
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot1d.close()
+ self.plot1d = None
+ TestCaseQt.tearDown(self)
+
+ def testConstructor(self):
+ """Make sure the constructor can deal will all possible arguments:
+
+ * tuple of :class:`StatBase` derivated classes
+ * tuple of tuples (:class:`StatBase`, :class:`StatFormatter`)
+ * tuple of tuples (str, pointer to function, kind)
+ """
+ handler0 = statshandler.StatsHandler(
+ (stats.StatMin(), stats.StatMax())
+ )
+
+ res = handler0.calculate(item=self.curveItem, plot=self.plot1d,
+ onlimits=False)
+ self.assertTrue('min' in res)
+ self.assertEqual(res['min'], '0')
+ self.assertTrue('max' in res)
+ self.assertEqual(res['max'], '19')
+
+ handler1 = statshandler.StatsHandler(
+ (
+ (stats.StatMin(), statshandler.StatFormatter(formatter=None)),
+ (stats.StatMax(), statshandler.StatFormatter())
+ )
+ )
+
+ res = handler1.calculate(item=self.curveItem, plot=self.plot1d,
+ onlimits=False)
+ self.assertTrue('min' in res)
+ self.assertEqual(res['min'], '0')
+ self.assertTrue('max' in res)
+ self.assertEqual(res['max'], '19.000')
+
+ handler2 = statshandler.StatsHandler(
+ (
+ (stats.StatMin(), None),
+ (stats.StatMax(), statshandler.StatFormatter())
+ ))
+
+ res = handler2.calculate(item=self.curveItem, plot=self.plot1d,
+ onlimits=False)
+ self.assertTrue('min' in res)
+ self.assertEqual(res['min'], '0')
+ self.assertTrue('max' in res)
+ self.assertEqual(res['max'], '19.000')
+
+ handler3 = statshandler.StatsHandler((
+ (('amin', numpy.argmin), statshandler.StatFormatter()),
+ ('amax', numpy.argmax)
+ ))
+
+ res = handler3.calculate(item=self.curveItem, plot=self.plot1d,
+ onlimits=False)
+ self.assertTrue('amin' in res)
+ self.assertEqual(res['amin'], '0.000')
+ self.assertTrue('amax' in res)
+ self.assertEqual(res['amax'], '19')
+
+ with self.assertRaises(ValueError):
+ statshandler.StatsHandler(('name'))
+
+
+class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
+ """Basic test for StatsWidget with curves"""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot = Plot1D()
+ self.plot.show()
+ x = range(20)
+ y = range(20)
+ self.plot.addCurve(x, y, legend='curve0')
+ y = range(12, 32)
+ self.plot.addCurve(x, y, legend='curve1')
+ y = range(-2, 18)
+ self.plot.addCurve(x, y, legend='curve2')
+ self.widget = StatsWidget.StatsWidget(plot=self.plot)
+ self.statsTable = self.widget._statsTable
+
+ mystats = statshandler.StatsHandler((
+ stats.StatMin(),
+ (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ stats.StatMax(),
+ (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ stats.StatDelta(),
+ ('std', numpy.std),
+ ('mean', numpy.mean),
+ stats.StatCOM()
+ ))
+
+ self.statsTable.setStats(mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.statsTable = None
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def testDisplayActiveItemsSyncOptions(self):
+ """
+ Test that the several option of the sync options are well
+ synchronized between the different object"""
+ widget = StatsWidget.StatsWidget(plot=self.plot)
+ table = StatsWidget.StatsTable(plot=self.plot)
+
+ def check_display_only_active_item(only_active):
+ # check internal value
+ self.assertIs(widget._statsTable._displayOnlyActItem, only_active)
+ # self.assertTrue(table._displayOnlyActItem is only_active)
+ # check gui display
+ self.assertEqual(widget._options.isActiveItemMode(), only_active)
+
+ for displayOnlyActiveItems in (True, False):
+ with self.subTest(displayOnlyActiveItems=displayOnlyActiveItems):
+ widget.setDisplayOnlyActiveItem(displayOnlyActiveItems)
+ # table.setDisplayOnlyActiveItem(displayOnlyActiveItems)
+ check_display_only_active_item(displayOnlyActiveItems)
+
+ check_display_only_active_item(only_active=False)
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ table.setAttribute(qt.Qt.WA_DeleteOnClose)
+ widget.close()
+ table.close()
+
+ def testInit(self):
+ """Make sure all the curves are registred on initialization"""
+ self.assertEqual(self.statsTable.rowCount(), 3)
+
+ def testRemoveCurve(self):
+ """Make sure the Curves stats take into account the curve removal from
+ plot"""
+ self.plot.removeCurve('curve2')
+ self.assertEqual(self.statsTable.rowCount(), 2)
+ for iRow in range(2):
+ self.assertTrue(self.statsTable.item(iRow, 0).text() in ('curve0', 'curve1'))
+
+ self.plot.removeCurve('curve0')
+ self.assertEqual(self.statsTable.rowCount(), 1)
+ self.plot.removeCurve('curve1')
+ self.assertEqual(self.statsTable.rowCount(), 0)
+
+ def testAddCurve(self):
+ """Make sure the Curves stats take into account the add curve action"""
+ self.plot.addCurve(legend='curve3', x=range(10), y=range(10))
+ self.assertEqual(self.statsTable.rowCount(), 4)
+
+ def testUpdateCurveFromAddCurve(self):
+ """Make sure the stats of the cuve will be removed after updating a
+ curve"""
+ self.plot.addCurve(legend='curve0', x=range(10), y=range(10))
+ self.qapp.processEvents()
+ self.assertEqual(self.statsTable.rowCount(), 3)
+ curve = self.plot._getItem(kind='curve', legend='curve0')
+ tableItems = self.statsTable._itemToTableItems(curve)
+ self.assertEqual(tableItems['max'].text(), '9')
+
+ def testUpdateCurveFromCurveObj(self):
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
+ self.qapp.processEvents()
+ self.assertEqual(self.statsTable.rowCount(), 3)
+ curve = self.plot._getItem(kind='curve', legend='curve0')
+ tableItems = self.statsTable._itemToTableItems(curve)
+ self.assertEqual(tableItems['max'].text(), '3')
+
+ def testSetAnotherPlot(self):
+ plot2 = Plot1D()
+ plot2.addCurve(x=range(26), y=range(26), legend='new curve')
+ self.statsTable.setPlot(plot2)
+ self.assertEqual(self.statsTable.rowCount(), 1)
+ self.qapp.processEvents()
+ plot2.setAttribute(qt.Qt.WA_DeleteOnClose)
+ plot2.close()
+ plot2 = None
+
+ def testUpdateMode(self):
+ """Make sure the update modes are well take into account"""
+ self.plot.setActiveCurve('curve0')
+ for display_only_active in (True, False):
+ with self.subTest(display_only_active=display_only_active):
+ self.widget.setDisplayOnlyActiveItem(display_only_active)
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(4))
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ update_stats_action = self.widget._options.getUpdateStatsAction()
+ # test from api
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO)
+ self.widget.show()
+ # check stats change in auto mode
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(-1, 3))
+ self.qapp.processEvents()
+ tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
+ curve0_min = tableItems['min'].text()
+ self.assertTrue(float(curve0_min) == -1.)
+
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5))
+ self.qapp.processEvents()
+ tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
+ curve0_min = tableItems['min'].text()
+ self.assertTrue(float(curve0_min) == 1.)
+
+ # check stats change in manual mode only if requested
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL)
+
+ self.plot.getCurve('curve0').setData(x=range(4), y=range(2, 6))
+ self.qapp.processEvents()
+ tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
+ curve0_min = tableItems['min'].text()
+ self.assertTrue(float(curve0_min) == 1.)
+
+ update_stats_action.trigger()
+ tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
+ curve0_min = tableItems['min'].text()
+ self.assertTrue(float(curve0_min) == 2.)
+
+ def testItemHidden(self):
+ """Test if an item is hide, then the associated stats item is also
+ hide"""
+ curve0 = self.plot.getCurve('curve0')
+ curve1 = self.plot.getCurve('curve1')
+ curve2 = self.plot.getCurve('curve2')
+
+ self.plot.show()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+ self.assertFalse(self.statsTable.isRowHidden(0))
+ self.assertFalse(self.statsTable.isRowHidden(1))
+ self.assertFalse(self.statsTable.isRowHidden(2))
+
+ curve0.setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.statsTable.isRowHidden(0))
+ curve0.setVisible(True)
+ self.qapp.processEvents()
+ self.assertFalse(self.statsTable.isRowHidden(0))
+ curve1.setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.statsTable.isRowHidden(1))
+ tableItems = self.statsTable._itemToTableItems(curve2)
+ curve2_min = tableItems['min'].text()
+ self.assertTrue(float(curve2_min) == -2.)
+
+ curve0.setVisible(False)
+ curve1.setVisible(False)
+ curve2.setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.statsTable.isRowHidden(0))
+ self.assertTrue(self.statsTable.isRowHidden(1))
+ self.assertTrue(self.statsTable.isRowHidden(2))
+
+
+class TestStatsWidgetWithImages(TestCaseQt):
+ """Basic test for StatsWidget with images"""
+
+ IMAGE_LEGEND = 'test image'
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot = Plot2D()
+
+ self.plot.addImage(data=numpy.arange(128*128).reshape(128, 128),
+ legend=self.IMAGE_LEGEND, replace=False)
+
+ self.widget = StatsWidget.StatsTable(plot=self.plot)
+
+ mystats = statshandler.StatsHandler((
+ (stats.StatMin(), statshandler.StatFormatter()),
+ (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ (stats.StatMax(), statshandler.StatFormatter()),
+ (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ (stats.StatDelta(), statshandler.StatFormatter()),
+ ('std', numpy.std),
+ ('mean', numpy.mean),
+ (stats.StatCOM(), statshandler.StatFormatter(None))
+ ))
+
+ self.widget.setStats(mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def test(self):
+ image = self.plot._getItem(
+ kind='image', legend=self.IMAGE_LEGEND)
+ tableItems = self.widget._itemToTableItems(image)
+
+ maxText = '{0:.3f}'.format((128 * 128) - 1)
+ self.assertEqual(tableItems['legend'].text(), self.IMAGE_LEGEND)
+ self.assertEqual(tableItems['min'].text(), '0.000')
+ self.assertEqual(tableItems['max'].text(), maxText)
+ self.assertEqual(tableItems['delta'].text(), maxText)
+ self.assertEqual(tableItems['coords min'].text(), '0.0, 0.0')
+ self.assertEqual(tableItems['coords max'].text(), '127.0, 127.0')
+
+ def testItemHidden(self):
+ """Test if an item is hide, then the associated stats item is also
+ hide"""
+ self.widget.show()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.widget)
+ self.assertFalse(self.widget.isRowHidden(0))
+ self.plot.getImage(self.IMAGE_LEGEND).setVisible(False)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget.isRowHidden(0))
+
+
+class TestStatsWidgetWithScatters(TestCaseQt):
+
+ SCATTER_LEGEND = 'scatter plot'
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.scatterPlot = Plot2D()
+ self.scatterPlot.addScatter([0, 1, 2, 20, 50, 60],
+ [2, 3, 4, 26, 69, 6],
+ [5, 6, 7, 10, 90, 20],
+ legend=self.SCATTER_LEGEND)
+ self.widget = StatsWidget.StatsTable(plot=self.scatterPlot)
+
+ mystats = statshandler.StatsHandler((
+ stats.StatMin(),
+ (stats.StatCoordMin(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ stats.StatMax(),
+ (stats.StatCoordMax(), statshandler.StatFormatter(None, qt.QTableWidgetItem)),
+ stats.StatDelta(),
+ ('std', numpy.std),
+ ('mean', numpy.mean),
+ stats.StatCOM()
+ ))
+
+ self.widget.setStats(mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.scatterPlot.close()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.scatterPlot = None
+ TestCaseQt.tearDown(self)
+
+ def testStats(self):
+ scatter = self.scatterPlot._getItem(
+ kind='scatter', legend=self.SCATTER_LEGEND)
+ tableItems = self.widget._itemToTableItems(scatter)
+ self.assertEqual(tableItems['legend'].text(), self.SCATTER_LEGEND)
+ self.assertEqual(tableItems['min'].text(), '5')
+ self.assertEqual(tableItems['coords min'].text(), '0, 2')
+ self.assertEqual(tableItems['max'].text(), '90')
+ self.assertEqual(tableItems['coords max'].text(), '50, 69')
+ self.assertEqual(tableItems['delta'].text(), '85')
+
+
+class TestEmptyStatsWidget(TestCaseQt):
+ def test(self):
+ widget = StatsWidget.StatsWidget()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+
+class TestLineWidget(TestCaseQt):
+ """Some test for the StatsLineWidget."""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+
+ mystats = statshandler.StatsHandler((
+ (stats.StatMin(), statshandler.StatFormatter()),
+ ))
+
+ self.plot = Plot1D()
+ self.plot.show()
+ self.x = range(20)
+ self.y0 = range(20)
+ self.curve0 = self.plot.addCurve(self.x, self.y0, legend='curve0')
+ self.y1 = range(12, 32)
+ self.plot.addCurve(self.x, self.y1, legend='curve1')
+ self.y2 = range(-2, 18)
+ self.plot.addCurve(self.x, self.y2, legend='curve2')
+ self.widget = StatsWidget.BasicGridStatsWidget(plot=self.plot,
+ kind='curve',
+ stats=mystats)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.widget.setPlot(None)
+ self.widget._lineStatsWidget._statQlineEdit.clear()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def testProcessing(self):
+ self.widget._lineStatsWidget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.plot.setActiveCurve(legend='curve0')
+ self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.000')
+ self.plot.setActiveCurve(legend='curve1')
+ self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '12.000')
+ self.plot.getXAxis().setLimitsConstraints(minPos=2, maxPos=5)
+ self.widget.setStatsOnVisibleData(True)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000')
+ self.plot.setActiveCurve(None)
+ self.assertIsNone(self.plot.getActiveCurve())
+ self.widget.setStatsOnVisibleData(False)
+ self.qapp.processEvents()
+ self.assertFalse(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '14.000')
+ self.widget.setKind('image')
+ self.plot.addImage(numpy.arange(100*100).reshape(100, 100) + 0.312)
+ self.qapp.processEvents()
+ self.assertTrue(self.widget._lineStatsWidget._statQlineEdit['min'].text() == '0.312')
+
+ def testUpdateMode(self):
+ """Make sure the update modes are well take into account"""
+ self.plot.setActiveCurve(self.curve0)
+ _autoRB = self.widget._options._autoRB
+ _manualRB = self.widget._options._manualRB
+ # test from api
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ self.assertTrue(_autoRB.isChecked())
+ self.assertFalse(_manualRB.isChecked())
+
+ # check stats change in auto mode
+ curve0_min = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ new_y = numpy.array(self.y0) - 2.56
+ self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0)
+ curve0_min2 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ self.assertTrue(curve0_min != curve0_min2)
+
+ # check stats change in manual mode only if requested
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
+ self.assertFalse(_autoRB.isChecked())
+ self.assertTrue(_manualRB.isChecked())
+
+ new_y = numpy.array(self.y0) - 1.2
+ self.plot.addCurve(x=self.x, y=new_y, legend=self.curve0)
+ curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ self.assertTrue(curve0_min3 == curve0_min2)
+ self.widget._options._updateRequested()
+ curve0_min3 = self.widget._lineStatsWidget._statQlineEdit['min'].text()
+ self.assertTrue(curve0_min3 != curve0_min2)
+
+ # test from gui
+ self.widget.showRadioButtons(True)
+ self.widget._options._autoRB.toggle()
+ self.assertTrue(_autoRB.isChecked())
+ self.assertFalse(_manualRB.isChecked())
+
+ self.widget._options._manualRB.toggle()
+ self.assertFalse(_autoRB.isChecked())
+ self.assertTrue(_manualRB.isChecked())
+
+
+class TestUpdateModeWidget(TestCaseQt):
+ """Test UpdateModeWidget"""
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.widget = StatsWidget.UpdateModeWidget(parent=None)
+
+ def tearDown(self):
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ self.widget = None
+ TestCaseQt.tearDown(self)
+
+ def testSignals(self):
+ """Test the signal emission of the widget"""
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ modeChangedListener = SignalListener()
+ manualUpdateListener = SignalListener()
+ self.widget.sigUpdateModeChanged.connect(modeChangedListener)
+ self.widget.sigUpdateRequested.connect(manualUpdateListener)
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.AUTO)
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.AUTO)
+ self.assertEqual(modeChangedListener.callCount(), 0)
+ self.qapp.processEvents()
+
+ self.widget.setUpdateMode(StatsWidget.UpdateMode.MANUAL)
+ self.assertEqual(self.widget.getUpdateMode(), StatsWidget.UpdateMode.MANUAL)
+ self.qapp.processEvents()
+ self.assertEqual(modeChangedListener.callCount(), 1)
+ self.assertEqual(manualUpdateListener.callCount(), 0)
+ self.widget._updatePB.click()
+ self.widget._updatePB.click()
+ self.assertEqual(manualUpdateListener.callCount(), 2)
+
+ self.widget._autoRB.setChecked(True)
+ self.assertEqual(modeChangedListener.callCount(), 2)
+ self.widget._updatePB.click()
+ self.assertEqual(manualUpdateListener.callCount(), 2)
+
+
+class TestStatsROI(TestStatsBase, TestCaseQt):
+ """
+ Test stats based on ROI
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.createRois()
+ TestStatsBase.setUp(self)
+ self.createHistogramContext()
+
+ self.roiManager = RegionOfInterestManager(self.plot2d)
+ self.roiManager.addRoi(self._2Droi_rect)
+ self.roiManager.addRoi(self._2Droi_poly)
+
+ def tearDown(self):
+ self.roiManager.clear()
+ self.roiManager = None
+ self._1Droi = None
+ self._2Droi_rect = None
+ self._2Droi_poly = None
+ self.plotHisto.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plotHisto.close()
+ self.plotHisto = None
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
+ def createRois(self):
+ self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0)
+ self._2Droi_rect = RectangleROI()
+ self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0))
+ self._2Droi_poly = PolygonROI()
+ points = numpy.array(((0, 20), (0, 0), (10, 0)))
+ self._2Droi_poly.setPoints(points=points)
+
+ def createCurveContext(self):
+ TestStatsBase.createCurveContext(self)
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._1Droi)
+
+ def createHistogramContext(self):
+ self.plotHisto = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plotHisto.addHistogram(x, y, legend='histo0')
+
+ self.histoContext = stats._HistogramContext(
+ item=self.plotHisto.getHistogram('histo0'),
+ plot=self.plotHisto,
+ onlimits=False,
+ roi=self._1Droi)
+
+ def createScatterContext(self):
+ TestStatsBase.createScatterContext(self)
+ self.scatterContext = stats._ScatterContext(
+ item=self.scatterPlot.getScatter('scatter plot'),
+ plot=self.scatterPlot,
+ onlimits=False,
+ roi=self._1Droi
+ )
+
+ def createImageContext(self):
+ TestStatsBase.createImageContext(self)
+
+ self.imageContext = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_rect
+ )
+
+ self.imageContext_2 = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_poly
+ )
+
+ def testErrors(self):
+ # test if onlimits is True and give also a roi
+ with self.assertRaises(ValueError):
+ stats._CurveContext(item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=True,
+ roi=self._1Droi)
+
+ # test if is a curve context and give an invalid 2D roi
+ with self.assertRaises(TypeError):
+ stats._CurveContext(item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._2Droi_rect)
+
+ def testBasicStatsCurve(self):
+ """Test result for simple stats on a curve"""
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(0, 10))
+ self.assertEqual(_stats['min'].calculate(self.curveContext), 2)
+ self.assertEqual(_stats['max'].calculate(self.curveContext), 5)
+ self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,))
+ self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6]))
+ self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6]))
+ com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6])
+ self.assertEqual(_stats['com'].calculate(self.curveContext), com)
+
+ def testBasicStatsImageRectRoi(self):
+ """Test result for simple stats on an image"""
+ self.assertEqual(self.imageContext.values.compressed().size, 121)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext), 10)
+ self.assertEqual(_stats['max'].calculate(self.imageContext), 1300)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0))
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0))
+ self.assertAlmostEqual(_stats['std'].calculate(self.imageContext),
+ numpy.std(self.imageData[0:11, 10:21]))
+ self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext),
+ numpy.mean(self.imageData[0:11, 10:21]))
+
+ compressed_values = self.imageContext.values.compressed()
+ compressed_values = compressed_values.reshape(11, 11)
+ yData = numpy.sum(compressed_values.astype(numpy.float64), axis=1)
+ xData = numpy.sum(compressed_values.astype(numpy.float64), axis=0)
+
+ dataYRange = range(11)
+ dataXRange = range(10, 21)
+
+ ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
+ self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
+
+ def testBasicStatsImagePolyRoi(self):
+ """Test a simple rectangle ROI"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0)
+ self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0))
+ # not 0.0, 19.0 because not fully in. Should all pixel have a weight,
+ # on to manage them in stats. For now 0 if the center is not in, else 1
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0))
+
+ def testBasicStatsScatter(self):
+ self.assertEqual(self.scatterContext.values.compressed().size, 2)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.scatterContext), 6)
+ self.assertEqual(_stats['max'].calculate(self.scatterContext), 7)
+ self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3))
+ self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4))
+ self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7]))
+ self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7]))
+
+ def testBasicHistogram(self):
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(2, 6))
+ self.assertEqual(_stats['min'].calculate(self.histoContext), 2)
+ self.assertEqual(_stats['max'].calculate(self.histoContext), 5)
+ self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,))
+ self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData))
+ self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData))
+ com = numpy.sum(xData * yData) / numpy.sum(yData)
+ self.assertEqual(_stats['com'].calculate(self.histoContext), com)
+
+
+class TestAdvancedROIImageContext(TestCaseQt):
+ """Test stats result on an image context with different scale and
+ origins"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.data_dims = (100, 100)
+ self.data = numpy.random.rand(*self.data_dims)
+ self.plot = Plot2D()
+
+ def tearDown(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+ TestCaseQt.tearDown(self)
+
+ def test(self):
+ """Test stats result on an image context with different scale and
+ origins"""
+ roi_origins = [(0, 0), (2, 10), (14, 20)]
+ img_origins = [(0, 0), (14, 20), (2, 10)]
+ img_scales = [1.0, 0.5, 2.0]
+ _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), }
+ for roi_origin in roi_origins:
+ for img_origin in img_origins:
+ for img_scale in img_scales:
+ with self.subTest(roi_origin=roi_origin,
+ img_origin=img_origin,
+ img_scale=img_scale):
+ self.plot.addImage(self.data, legend='img',
+ origin=img_origin,
+ scale=img_scale)
+ roi = RectangleROI()
+ roi.setGeometry(origin=roi_origin, size=(20, 20))
+ context = stats._ImageContext(
+ item=self.plot.getImage('img'),
+ plot=self.plot,
+ onlimits=False,
+ roi=roi)
+ x_start = int((roi_origin[0] - img_origin[0]) / img_scale)
+ x_end = int(x_start + (20 / img_scale)) + 1
+ y_start = int((roi_origin[1] - img_origin[1])/ img_scale)
+ y_end = int(y_start + (20 / img_scale)) + 1
+ x_start = max(x_start, 0)
+ x_end = min(max(x_end, 0), self.data_dims[1])
+ y_start = max(y_start, 0)
+ y_end = min(max(y_end, 0), self.data_dims[0])
+ th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end])
+ self.assertAlmostEqual(_stats['sum'].calculate(context),
+ th_sum)
diff --git a/src/silx/gui/plot/test/testUtilsAxis.py b/src/silx/gui/plot/test/testUtilsAxis.py
new file mode 100644
index 0000000..dd4a689
--- /dev/null
+++ b/src/silx/gui/plot/test/testUtilsAxis.py
@@ -0,0 +1,203 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "20/11/2018"
+
+
+import unittest
+from silx.gui.plot import PlotWidget
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.utils.axis import SyncAxes
+
+
+class TestAxisSync(TestCaseQt):
+ """Tests AxisSync class"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.plot1 = PlotWidget()
+ self.plot2 = PlotWidget()
+ self.plot3 = PlotWidget()
+
+ def tearDown(self):
+ self.plot1 = None
+ self.plot2 = None
+ self.plot3 = None
+ TestCaseQt.tearDown(self)
+
+ def testMoveFirstAxis(self):
+ """Test synchronization after construction"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testMoveSecondAxis(self):
+ """Test synchronization after construction"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+
+ self.plot2.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testMoveTwoAxes(self):
+ """Test synchronization after construction"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+
+ self.plot1.getXAxis().setLimits(1, 50)
+ self.plot2.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testDestruction(self):
+ """Test synchronization when sync object is destroyed"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ del sync
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testAxisDestruction(self):
+ """Test synchronization when an axis disappear"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+
+ # Destroy the plot is possible
+ import weakref
+ plot = weakref.ref(self.plot2)
+ self.plot2 = None
+ result = self.qWaitForDestroy(plot)
+ if not result:
+ # We can't test
+ self.skipTest("Object not destroyed")
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testStop(self):
+ """Test synchronization after calling stop"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.stop()
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testStopMovingStart(self):
+ """Test synchronization after calling stop, moving an axis, then start again"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.stop()
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.plot2.getXAxis().setLimits(1, 50)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ sync.start()
+
+ # The first axis is the reference
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testDoubleStop(self):
+ """Test double stop"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.stop()
+ self.assertRaises(RuntimeError, sync.stop)
+
+ def testDoubleStart(self):
+ """Test double stop"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ self.assertRaises(RuntimeError, sync.start)
+
+ def testScale(self):
+ """Test scale change"""
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ self.plot1.getXAxis().setScale(self.plot1.getXAxis().LOGARITHMIC)
+ self.assertEqual(self.plot1.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
+ self.assertEqual(self.plot2.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
+ self.assertEqual(self.plot3.getXAxis().getScale(), self.plot1.getXAxis().LOGARITHMIC)
+
+ def testDirection(self):
+ """Test direction change"""
+ _sync = SyncAxes([self.plot1.getYAxis(), self.plot2.getYAxis(), self.plot3.getYAxis()])
+ self.plot1.getYAxis().setInverted(True)
+ self.assertEqual(self.plot1.getYAxis().isInverted(), True)
+ self.assertEqual(self.plot2.getYAxis().isInverted(), True)
+ self.assertEqual(self.plot3.getYAxis().isInverted(), True)
+
+ def testSyncCenter(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False, syncCenter=True)
+
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (100 - 10, 100 + 10))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (100 - 1, 100 + 1))
+
+ def testSyncCenterAndZoom(self):
+ """Test direction change"""
+ # Not the same scale
+ self.plot1.getXAxis().setLimits(0, 200)
+ self.plot2.getXAxis().setLimits(0, 20)
+ self.plot3.getXAxis().setLimits(0, 2)
+ _sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()],
+ syncLimits=False, syncCenter=True, syncZoom=True)
+
+ # Supposing all the plots use the same size
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (0, 200))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (0, 200))
+
+ def testAddAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis()])
+ sync.addAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot3.getXAxis().getLimits(), (10, 500))
+
+ def testRemoveAxis(self):
+ """Test synchronization after construction"""
+ sync = SyncAxes([self.plot1.getXAxis(), self.plot2.getXAxis(), self.plot3.getXAxis()])
+ sync.removeAxis(self.plot3.getXAxis())
+
+ self.plot1.getXAxis().setLimits(10, 500)
+ self.assertEqual(self.plot1.getXAxis().getLimits(), (10, 500))
+ self.assertEqual(self.plot2.getXAxis().getLimits(), (10, 500))
+ self.assertNotEqual(self.plot3.getXAxis().getLimits(), (10, 500))
diff --git a/src/silx/gui/plot/test/utils.py b/src/silx/gui/plot/test/utils.py
new file mode 100644
index 0000000..64fca56
--- /dev/null
+++ b/src/silx/gui/plot/test/utils.py
@@ -0,0 +1,93 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for PlotWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/01/2018"
+
+
+import logging
+import pytest
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui import qt
+from silx.gui.plot import PlotWidget
+
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.usefixtures("test_options_class_attr")
+class PlotWidgetTestCase(TestCaseQt):
+ """Base class for tests of PlotWidget, not a TestCase in itself.
+
+ plot attribute is the PlotWidget created for the test.
+ """
+ __screenshot_already_taken = False
+ backend = None
+
+ def _createPlot(self):
+ return PlotWidget(backend=self.backend)
+
+ def setUp(self):
+ super(PlotWidgetTestCase, self).setUp()
+ self.plot = self._createPlot()
+ self.plot.show()
+ self.plotAlive = True
+ self.qWaitForWindowExposed(self.plot)
+ TestCaseQt.mouseClick(self, self.plot, button=qt.Qt.LeftButton, pos=(0, 0))
+
+ def __onPlotDestroyed(self):
+ self.plotAlive = False
+
+ def _waitForPlotClosed(self):
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.destroyed.connect(self.__onPlotDestroyed)
+ self.plot.close()
+ del self.plot
+ for _ in range(100):
+ if not self.plotAlive:
+ break
+ self.qWait(10)
+ else:
+ logger.error("Plot is still alive")
+
+ def tearDown(self):
+ if not self._currentTestSucceeded():
+ # MPL is the only widget which uses the real system mouse.
+ # In case of a the windows is outside of the screen, minimzed,
+ # overlapped by a system popup, the MPL widget will not receive the
+ # mouse event.
+ # Taking a screenshot help debuging this cases in the continuous
+ # integration environement.
+ if not PlotWidgetTestCase.__screenshot_already_taken:
+ PlotWidgetTestCase.__screenshot_already_taken = True
+ self.logScreenShot()
+ self.qapp.processEvents()
+ self._waitForPlotClosed()
+ super(PlotWidgetTestCase, self).tearDown()
diff --git a/silx/gui/plot/tools/CurveLegendsWidget.py b/src/silx/gui/plot/tools/CurveLegendsWidget.py
index 4a517dd..4a517dd 100644
--- a/silx/gui/plot/tools/CurveLegendsWidget.py
+++ b/src/silx/gui/plot/tools/CurveLegendsWidget.py
diff --git a/silx/gui/plot/tools/LimitsToolBar.py b/src/silx/gui/plot/tools/LimitsToolBar.py
index fc192a6..fc192a6 100644
--- a/silx/gui/plot/tools/LimitsToolBar.py
+++ b/src/silx/gui/plot/tools/LimitsToolBar.py
diff --git a/src/silx/gui/plot/tools/PositionInfo.py b/src/silx/gui/plot/tools/PositionInfo.py
new file mode 100644
index 0000000..8b95fbc
--- /dev/null
+++ b/src/silx/gui/plot/tools/PositionInfo.py
@@ -0,0 +1,373 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget displaying mouse coordinates in a PlotWidget.
+
+It can be configured to provide more information.
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/10/2017"
+
+
+import logging
+import numbers
+import traceback
+import weakref
+
+import numpy
+
+from ....utils.deprecation import deprecated
+from ... import qt
+from .. import items
+from ...widgets.ElidedLabel import ElidedLabel
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _PositionInfoLabel(ElidedLabel):
+ """QLabel with a default size larger than what is displayed."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ def sizeHint(self):
+ hint = super().sizeHint()
+ width = self.fontMetrics().boundingRect('##############').width()
+ return qt.QSize(max(hint.width(), width), hint.height())
+
+
+# PositionInfo ################################################################
+
+class PositionInfo(qt.QWidget):
+ """QWidget displaying coords converted from data coords of the mouse.
+
+ Provide this widget with a list of couple:
+
+ - A name to display before the data
+ - A function that takes (x, y) as arguments and returns something that
+ gets converted to a string.
+ If the result is a float it is converted with '%.7g' format.
+
+ To run the following sample code, a QApplication must be initialized.
+ First, create a PlotWindow and add a QToolBar where to place the
+ PositionInfo widget.
+
+ >>> from silx.gui.plot import PlotWindow
+ >>> from silx.gui import qt
+
+ >>> plot = PlotWindow() # Create a PlotWindow to add the widget to
+ >>> toolBar = qt.QToolBar() # Create a toolbar to place the widget in
+ >>> plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar) # Add it to plot
+
+ Then, create the PositionInfo widget and add it to the toolbar.
+ The PositionInfo widget is created with a list of converters, here
+ to display polar coordinates of the mouse position.
+
+ >>> import numpy
+ >>> from silx.gui.plot.tools import PositionInfo
+
+ >>> position = PositionInfo(plot=plot, converters=[
+ ... ('Radius', lambda x, y: numpy.sqrt(x*x + y*y)),
+ ... ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))])
+ >>> toolBar.addWidget(position) # Add the widget to the toolbar
+ <...>
+ >>> plot.show() # To display the PlotWindow with the position widget
+
+ :param plot: The PlotWidget this widget is displaying data coords from.
+ :param converters:
+ List of 2-tuple: name to display and conversion function from (x, y)
+ in data coords to displayed value.
+ If None, the default, it displays X and Y.
+ :param parent: Parent widget
+ """
+
+ SNAP_THRESHOLD_DIST = 5
+
+ def __init__(self, parent=None, plot=None, converters=None):
+ assert plot is not None
+ self._plotRef = weakref.ref(plot)
+ self._snappingMode = self.SNAPPING_DISABLED
+
+ super(PositionInfo, self).__init__(parent)
+
+ if converters is None:
+ converters = (('X', lambda x, y: x), ('Y', lambda x, y: y))
+
+ self._fields = [] # To store (QLineEdit, name, function (x, y)->v)
+
+ # Create a new layout with new widgets
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ # layout.setSpacing(0)
+
+ # Create all QLabel and store them with the corresponding converter
+ for name, func in converters:
+ layout.addWidget(qt.QLabel('<b>' + name + ':</b>'))
+
+ contentWidget = _PositionInfoLabel(self)
+ contentWidget.setText('------')
+ layout.addWidget(contentWidget)
+ self._fields.append((contentWidget, name, func))
+
+ layout.addStretch(1)
+ self.setLayout(layout)
+
+ # Connect to Plot events
+ plot.sigPlotSignal.connect(self._plotEvent)
+
+ def getPlotWidget(self):
+ """Returns the PlotWidget this widget is attached to or None.
+
+ :rtype: Union[~silx.gui.plot.PlotWidget,None]
+ """
+ return self._plotRef()
+
+ @property
+ @deprecated(replacement='getPlotWidget', since_version='0.8.0')
+ def plot(self):
+ return self.getPlotWidget()
+
+ def getConverters(self):
+ """Return the list of converters as 2-tuple (name, function)."""
+ return [(name, func) for _label, name, func in self._fields]
+
+ def _plotEvent(self, event):
+ """Handle events from the Plot.
+
+ :param dict event: Plot event
+ """
+ if event['event'] == 'mouseMoved':
+ x, y = event['x'], event['y']
+ xPixel, yPixel = event['xpixel'], event['ypixel']
+ self._updateStatusBar(x, y, xPixel, yPixel)
+
+ def updateInfo(self):
+ """Update displayed information"""
+ plot = self.getPlotWidget()
+ if plot is None:
+ _logger.error("Trying to update PositionInfo "
+ "while PlotWidget no longer exists")
+ return
+
+ widget = plot.getWidgetHandle()
+ position = widget.mapFromGlobal(qt.QCursor.pos())
+ xPixel, yPixel = position.x(), position.y()
+ dataPos = plot.pixelToData(xPixel, yPixel, check=True)
+ if dataPos is not None: # Inside plot area
+ x, y = dataPos
+ self._updateStatusBar(x, y, xPixel, yPixel)
+
+ def _updateStatusBar(self, x, y, xPixel, yPixel):
+ """Update information from the status bar using the definitions.
+
+ :param float x: Position-x in data
+ :param float y: Position-y in data
+ :param float xPixel: Position-x in pixels
+ :param float yPixel: Position-y in pixels
+ """
+ plot = self.getPlotWidget()
+ if plot is None:
+ return
+
+ styleSheet = "color: rgb(0, 0, 0);" # Default style
+ xData, yData = x, y
+
+ snappingMode = self.getSnappingMode()
+
+ # Snapping when crosshair either not requested or active
+ if (snappingMode & (self.SNAPPING_CURVE | self.SNAPPING_SCATTER) and
+ (not (snappingMode & self.SNAPPING_CROSSHAIR) or
+ plot.getGraphCursor())):
+ styleSheet = "color: rgb(255, 0, 0);" # Style far from item
+
+ if snappingMode & self.SNAPPING_ACTIVE_ONLY:
+ selectedItems = []
+
+ if snappingMode & self.SNAPPING_CURVE:
+ activeCurve = plot.getActiveCurve()
+ if activeCurve:
+ selectedItems.append(activeCurve)
+
+ if snappingMode & self.SNAPPING_SCATTER:
+ activeScatter = plot._getActiveItem(kind='scatter')
+ if activeScatter:
+ selectedItems.append(activeScatter)
+
+ else:
+ kinds = []
+ if snappingMode & self.SNAPPING_CURVE:
+ kinds.append(items.Curve)
+ kinds.append(items.Histogram)
+ if snappingMode & self.SNAPPING_SCATTER:
+ kinds.append(items.Scatter)
+ selectedItems = [item for item in plot.getItems()
+ if isinstance(item, tuple(kinds)) and item.isVisible()]
+
+ # Compute distance threshold
+ window = plot.window()
+ windowHandle = window.windowHandle()
+ if windowHandle is not None:
+ ratio = windowHandle.devicePixelRatio()
+ else:
+ ratio = qt.QGuiApplication.primaryScreen().devicePixelRatio()
+
+ # Baseline squared distance threshold
+ distInPixels = (self.SNAP_THRESHOLD_DIST * ratio)**2
+
+ for item in selectedItems:
+ if (snappingMode & self.SNAPPING_SYMBOLS_ONLY and (
+ not isinstance(item, items.SymbolMixIn) or
+ not item.getSymbol())):
+ # Only handled if item symbols are visible
+ continue
+
+ if isinstance(item, items.Histogram):
+ result = item.pick(xPixel, yPixel)
+ if result is not None: # Histogram picked
+ index = result.getIndices()[0]
+ edges = item.getBinEdgesData(copy=False)
+
+ # Snap to bin center and value
+ xData = 0.5 * (edges[index] + edges[index + 1])
+ yData = item.getValueData(copy=False)[index]
+
+ # Update label style sheet
+ styleSheet = "color: rgb(0, 0, 0);"
+ break
+
+ else: # Curve, Scatter
+ xArray = item.getXData(copy=False)
+ yArray = item.getYData(copy=False)
+ closestIndex = numpy.argmin(
+ pow(xArray - x, 2) + pow(yArray - y, 2))
+
+ xClosest = xArray[closestIndex]
+ yClosest = yArray[closestIndex]
+
+ if isinstance(item, items.YAxisMixIn):
+ axis = item.getYAxis()
+ else:
+ axis = 'left'
+
+ closestInPixels = plot.dataToPixel(
+ xClosest, yClosest, axis=axis)
+ if closestInPixels is not None:
+ curveDistInPixels = (
+ (closestInPixels[0] - xPixel)**2 +
+ (closestInPixels[1] - yPixel)**2)
+
+ if curveDistInPixels <= distInPixels:
+ # Update label style sheet
+ styleSheet = "color: rgb(0, 0, 0);"
+
+ # if close enough, snap to data point coord
+ xData, yData = xClosest, yClosest
+ distInPixels = curveDistInPixels
+
+ for label, name, func in self._fields:
+ label.setStyleSheet(styleSheet)
+
+ try:
+ value = func(xData, yData)
+ text = self.valueToString(value)
+ label.setText(text)
+ except:
+ label.setText('Error')
+ _logger.error(
+ "Error while converting coordinates (%f, %f)"
+ "with converter '%s'" % (xPixel, yPixel, name))
+ _logger.error(traceback.format_exc())
+
+ def valueToString(self, value):
+ if isinstance(value, (tuple, list)):
+ value = [self.valueToString(v) for v in value]
+ return ", ".join(value)
+ elif isinstance(value, numbers.Real):
+ # Use this for floats and int
+ return '%.7g' % value
+ else:
+ # Fallback for other types
+ return str(value)
+
+ # Snapping mode
+
+ SNAPPING_DISABLED = 0
+ """No snapping occurs"""
+
+ SNAPPING_CROSSHAIR = 1 << 0
+ """Snapping only enabled when crosshair cursor is enabled"""
+
+ SNAPPING_ACTIVE_ONLY = 1 << 1
+ """Snapping only enabled for active item"""
+
+ SNAPPING_SYMBOLS_ONLY = 1 << 2
+ """Snapping only when symbols are visible"""
+
+ SNAPPING_CURVE = 1 << 3
+ """Snapping on curves"""
+
+ SNAPPING_SCATTER = 1 << 4
+ """Snapping on scatter"""
+
+ def setSnappingMode(self, mode):
+ """Set the snapping mode.
+
+ The mode is a mask.
+
+ :param int mode: The mode to use
+ """
+ if mode != self._snappingMode:
+ self._snappingMode = mode
+ self.updateInfo()
+
+ def getSnappingMode(self):
+ """Returns the snapping mode as a mask
+
+ :rtype: int
+ """
+ return self._snappingMode
+
+ _SNAPPING_LEGACY = (SNAPPING_CROSSHAIR |
+ SNAPPING_ACTIVE_ONLY |
+ SNAPPING_SYMBOLS_ONLY |
+ SNAPPING_CURVE |
+ SNAPPING_SCATTER)
+ """Legacy snapping mode"""
+
+ @property
+ @deprecated(replacement="getSnappingMode", since_version="0.8")
+ def autoSnapToActiveCurve(self):
+ return self.getSnappingMode() == self._SNAPPING_LEGACY
+
+ @autoSnapToActiveCurve.setter
+ @deprecated(replacement="setSnappingMode", since_version="0.8")
+ def autoSnapToActiveCurve(self, flag):
+ self.setSnappingMode(
+ self._SNAPPING_LEGACY if flag else self.SNAPPING_DISABLED)
diff --git a/silx/gui/plot/tools/RadarView.py b/src/silx/gui/plot/tools/RadarView.py
index 7076835..7076835 100644
--- a/silx/gui/plot/tools/RadarView.py
+++ b/src/silx/gui/plot/tools/RadarView.py
diff --git a/silx/gui/plot/tools/__init__.py b/src/silx/gui/plot/tools/__init__.py
index 09f468c..09f468c 100644
--- a/silx/gui/plot/tools/__init__.py
+++ b/src/silx/gui/plot/tools/__init__.py
diff --git a/silx/gui/plot/tools/profile/ScatterProfileToolBar.py b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
index 44187ef..44187ef 100644
--- a/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
+++ b/src/silx/gui/plot/tools/profile/ScatterProfileToolBar.py
diff --git a/silx/gui/plot/tools/profile/__init__.py b/src/silx/gui/plot/tools/profile/__init__.py
index d91191e..d91191e 100644
--- a/silx/gui/plot/tools/profile/__init__.py
+++ b/src/silx/gui/plot/tools/profile/__init__.py
diff --git a/silx/gui/plot/tools/profile/core.py b/src/silx/gui/plot/tools/profile/core.py
index 200f5cf..200f5cf 100644
--- a/silx/gui/plot/tools/profile/core.py
+++ b/src/silx/gui/plot/tools/profile/core.py
diff --git a/silx/gui/plot/tools/profile/editors.py b/src/silx/gui/plot/tools/profile/editors.py
index 80e0452..80e0452 100644
--- a/silx/gui/plot/tools/profile/editors.py
+++ b/src/silx/gui/plot/tools/profile/editors.py
diff --git a/src/silx/gui/plot/tools/profile/manager.py b/src/silx/gui/plot/tools/profile/manager.py
new file mode 100644
index 0000000..4a22bc0
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/manager.py
@@ -0,0 +1,1079 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a manager to compute and display profiles.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+import logging
+import weakref
+
+from silx.gui import qt
+from silx.gui import colors
+from silx.gui import utils
+
+from silx.utils.weakref import WeakMethodProxy
+from silx.gui import icons
+from silx.gui.plot import PlotWidget
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.tools.roi import CreateRoiModeAction
+from silx.gui.plot import items
+from silx.gui.qt import silxGlobalThreadPool
+from silx.gui.qt import inspect
+from . import rois
+from . import core
+from . import editors
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _RunnableComputeProfile(qt.QRunnable):
+ """Runner to process profiles
+
+ :param qt.QThreadPool threadPool: The thread which will be used to
+ execute this runner. It is used to update the used signals
+ :param ~silx.gui.plot.items.Item item: Item in which the profile is
+ computed
+ :param ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn roi: ROI
+ defining the profile shape and other characteristics
+ """
+
+ class _Signals(qt.QObject):
+ """Signal holder"""
+ resultReady = qt.Signal(object, object)
+ runnerFinished = qt.Signal(object)
+
+ def __init__(self, threadPool, item, roi):
+ """Constructor
+ """
+ super(_RunnableComputeProfile, self).__init__()
+ self._signals = self._Signals()
+ self._signals.moveToThread(threadPool.thread())
+ self._item = item
+ self._roi = roi
+ self._cancelled = False
+
+ def _lazyCancel(self):
+ """Cancel the runner if it is not yet started.
+
+ The threadpool will still execute the runner, but this will process
+ nothing.
+
+ This is only used with Qt<5.9 where QThreadPool.tryTake is not available.
+ """
+ self._cancelled = True
+
+ def autoDelete(self):
+ return False
+
+ def getRoi(self):
+ """Returns the ROI in which the runner will compute a profile.
+
+ :rtype: ~silx.gui.plot.tools.profile.core.ProfileRoiMixIn
+ """
+ return self._roi
+
+ @property
+ def resultReady(self):
+ """Signal emitted when the result of the computation is available.
+
+ This signal provides 2 values: The ROI, and the computation result.
+ """
+ return self._signals.resultReady
+
+ @property
+ def runnerFinished(self):
+ """Signal emitted when runner have finished.
+
+ This signal provides a single value: the runner itself.
+ """
+ return self._signals.runnerFinished
+
+ def run(self):
+ """Process the profile computation.
+ """
+ if not self._cancelled:
+ try:
+ profileData = self._roi.computeProfile(self._item)
+ except Exception:
+ _logger.error("Error while computing profile", exc_info=True)
+ else:
+ self.resultReady.emit(self._roi, profileData)
+ self.runnerFinished.emit(self)
+
+
+class ProfileWindow(qt.QMainWindow):
+ """
+ Display a computed profile.
+
+ The content can be described using :meth:`setRoiProfile` if the source of
+ the profile is a profile ROI, and :meth:`setProfile` for the data content.
+ """
+
+ sigClose = qt.Signal()
+ """Emitted by :meth:`closeEvent` (e.g. when the window is closed
+ through the window manager's close icon)."""
+
+ def __init__(self, parent=None, backend=None):
+ qt.QMainWindow.__init__(self, parent=parent, flags=qt.Qt.Dialog)
+
+ self.setWindowTitle('Profile window')
+ self._plot1D = None
+ self._plot2D = None
+ self._backend = backend
+ self._data = None
+
+ widget = qt.QWidget()
+ self._layout = qt.QStackedLayout(widget)
+ self._layout.setContentsMargins(0, 0, 0, 0)
+ self.setCentralWidget(widget)
+
+ def prepareWidget(self, roi):
+ """Called before the show to prepare the window to use with
+ a specific ROI."""
+ if isinstance(roi, rois._DefaultImageStackProfileRoiMixIn):
+ profileType = roi.getProfileType()
+ else:
+ profileType = "1D"
+ if profileType == "1D":
+ self.getPlot1D()
+ elif profileType == "2D":
+ self.getPlot2D()
+
+ def createPlot1D(self, parent, backend):
+ """Inherit this function to create your own plot to render 1D
+ profiles. The default value is a `Plot1D`.
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot.
+ See :class:`PlotWidget` for the list of supported backend.
+ :rtype: PlotWidget
+ """
+ # import here to avoid circular import
+ from ...PlotWindow import Plot1D
+ plot = Plot1D(parent=parent, backend=backend)
+ plot.setDataMargins(yMinMargin=0.1, yMaxMargin=0.1)
+ plot.setGraphYLabel('Profile')
+ plot.setGraphXLabel('')
+ return plot
+
+ def createPlot2D(self, parent, backend):
+ """Inherit this function to create your own plot to render 2D
+ profiles. The default value is a `Plot2D`.
+
+ :param parent: The parent of this widget or None.
+ :param backend: The backend to use for the plot.
+ See :class:`PlotWidget` for the list of supported backend.
+ :rtype: PlotWidget
+ """
+ # import here to avoid circular import
+ from ...PlotWindow import Plot2D
+ return Plot2D(parent=parent, backend=backend)
+
+ def getPlot1D(self, init=True):
+ """Return the current plot used to display curves and create it if it
+ does not yet exists and `init` is True. Else returns None."""
+ if not init:
+ return self._plot1D
+ if self._plot1D is None:
+ self._plot1D = self.createPlot1D(self, self._backend)
+ self._layout.addWidget(self._plot1D)
+ return self._plot1D
+
+ def _showPlot1D(self):
+ plot = self.getPlot1D()
+ self._layout.setCurrentWidget(plot)
+
+ def getPlot2D(self, init=True):
+ """Return the current plot used to display image and create it if it
+ does not yet exists and `init` is True. Else returns None."""
+ if not init:
+ return self._plot2D
+ if self._plot2D is None:
+ self._plot2D = self.createPlot2D(parent=self, backend=self._backend)
+ self._layout.addWidget(self._plot2D)
+ return self._plot2D
+
+ def _showPlot2D(self):
+ plot = self.getPlot2D()
+ self._layout.setCurrentWidget(plot)
+
+ def getCurrentPlotWidget(self):
+ return self._layout.currentWidget()
+
+ def closeEvent(self, qCloseEvent):
+ self.sigClose.emit()
+ qCloseEvent.accept()
+
+ def setRoiProfile(self, roi):
+ """Set the profile ROI which it the source of the following data
+ to display.
+
+ :param ProfileRoiMixIn roi: The profile ROI data source
+ """
+ if roi is None:
+ return
+ self.__color = colors.rgba(roi.getColor())
+
+ def _setImageProfile(self, data):
+ """
+ Setup the window to display a new profile data which is represented
+ by an image.
+
+ :param core.ImageProfileData data: Computed data profile
+ """
+ plot = self.getPlot2D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+
+
+ coords = data.coords
+ colormap = data.colormap
+ profileScale = (coords[-1] - coords[0]) / data.profile.shape[1], 1
+ plot.addImage(data.profile,
+ legend="profile",
+ colormap=colormap,
+ origin=(coords[0], 0),
+ scale=profileScale)
+ plot.getYAxis().setLabel("Frame index (depth)")
+
+ self._showPlot2D()
+
+ def _setCurveProfile(self, data):
+ """
+ Setup the window to display a new profile data which is represented
+ by a curve.
+
+ :param core.CurveProfileData data: Computed data profile
+ """
+ plot = self.getPlot1D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+ plot.getYAxis().setLabel(data.yLabel)
+
+ plot.addCurve(data.coords,
+ data.profile,
+ legend="level",
+ color=self.__color)
+
+ self._showPlot1D()
+
+ def _setRgbaProfile(self, data):
+ """
+ Setup the window to display a new profile data which is represented
+ by a curve.
+
+ :param core.RgbaProfileData data: Computed data profile
+ """
+ plot = self.getPlot1D()
+
+ plot.clear()
+ plot.setGraphTitle(data.title)
+ plot.getXAxis().setLabel(data.xLabel)
+ plot.getYAxis().setLabel(data.yLabel)
+
+ self._showPlot1D()
+
+ plot.addCurve(data.coords, data.profile,
+ legend="level", color="black")
+ plot.addCurve(data.coords, data.profile_r,
+ legend="red", color="red")
+ plot.addCurve(data.coords, data.profile_g,
+ legend="green", color="green")
+ plot.addCurve(data.coords, data.profile_b,
+ legend="blue", color="blue")
+ if data.profile_a is not None:
+ plot.addCurve(data.coords, data.profile_a, legend="alpha", color="gray")
+
+ def clear(self):
+ """Clear the window profile"""
+ plot = self.getPlot1D(init=False)
+ if plot is not None:
+ plot.clear()
+ plot = self.getPlot2D(init=False)
+ if plot is not None:
+ plot.clear()
+
+ def getProfile(self):
+ """Returns the profile data which is displayed"""
+ return self.__data
+
+ def setProfile(self, data):
+ """
+ Setup the window to display a new profile data.
+
+ This method dispatch the result to a specific method according to the
+ data type.
+
+ :param data: Computed data profile
+ """
+ self.__data = data
+ if data is None:
+ self.clear()
+ elif isinstance(data, core.ImageProfileData):
+ self._setImageProfile(data)
+ elif isinstance(data, core.RgbaProfileData):
+ self._setRgbaProfile(data)
+ elif isinstance(data, core.CurveProfileData):
+ self._setCurveProfile(data)
+ else:
+ raise TypeError("Unsupported type %s" % type(data))
+
+
+class _ClearAction(qt.QAction):
+ """Action to clear the profile manager
+
+ The action is only enabled if something can be cleaned up.
+ """
+
+ def __init__(self, parent, profileManager):
+ super(_ClearAction, self).__init__(parent)
+ self.__profileManager = weakref.ref(profileManager)
+ icon = icons.getQIcon('profile-clear')
+ self.setIcon(icon)
+ self.setText('Clear profile')
+ self.setToolTip('Clear the profiles')
+ self.setCheckable(False)
+ self.setEnabled(False)
+ self.triggered.connect(profileManager.clearProfile)
+ plot = profileManager.getPlotWidget()
+ roiManager = profileManager.getRoiManager()
+ plot.sigInteractiveModeChanged.connect(self.__modeUpdated)
+ roiManager.sigRoiChanged.connect(self.__roiListUpdated)
+
+ def getProfileManager(self):
+ return self.__profileManager()
+
+ def __roiListUpdated(self):
+ self.__update()
+
+ def __modeUpdated(self, source):
+ self.__update()
+
+ def __update(self):
+ profileManager = self.getProfileManager()
+ if profileManager is None:
+ return
+ roiManager = profileManager.getRoiManager()
+ if roiManager is None:
+ return
+ enabled = roiManager.isStarted() or len(roiManager.getRois()) > 0
+ self.setEnabled(enabled)
+
+
+class _StoreLastParamBehavior(qt.QObject):
+ """This object allow to store and restore the properties of the ROI
+ profiles"""
+
+ def __init__(self, parent):
+ assert isinstance(parent, ProfileManager)
+ super(_StoreLastParamBehavior, self).__init__(parent=parent)
+ self.__properties = {}
+ self.__profileRoi = None
+ self.__filter = utils.LockReentrant()
+
+ def _roi(self):
+ """Return the spied ROI"""
+ if self.__profileRoi is None:
+ return None
+ roi = self.__profileRoi()
+ if roi is None:
+ self.__profileRoi = None
+ return roi
+
+ def setProfileRoi(self, roi):
+ """Set a profile ROI to spy.
+
+ :param ProfileRoiMixIn roi: A profile ROI
+ """
+ previousRoi = self._roi()
+ if previousRoi is roi:
+ return
+ if previousRoi is not None:
+ previousRoi.sigProfilePropertyChanged.disconnect(self._profilePropertyChanged)
+ self.__profileRoi = None if roi is None else weakref.ref(roi)
+ if roi is not None:
+ roi.sigProfilePropertyChanged.connect(self._profilePropertyChanged)
+
+ def _profilePropertyChanged(self):
+ """Handle changes on the properties defining the profile ROI.
+ """
+ if self.__filter.locked():
+ return
+ roi = self.sender()
+ self.storeProperties(roi)
+
+ def storeProperties(self, roi):
+ if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
+ rois.ProfileImageStackCrossROI)):
+ self.__properties["method"] = roi.getProfileMethod()
+ self.__properties["line-width"] = roi.getProfileLineWidth()
+ self.__properties["type"] = roi.getProfileType()
+ elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
+ rois.ProfileImageCrossROI)):
+ self.__properties["method"] = roi.getProfileMethod()
+ self.__properties["line-width"] = roi.getProfileLineWidth()
+ elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
+ rois.ProfileScatterCrossROI)):
+ self.__properties["npoints"] = roi.getNPoints()
+
+ def restoreProperties(self, roi):
+ with self.__filter:
+ if isinstance(roi, (rois._DefaultImageStackProfileRoiMixIn,
+ rois.ProfileImageStackCrossROI)):
+ value = self.__properties.get("method", None)
+ if value is not None:
+ roi.setProfileMethod(value)
+ value = self.__properties.get("line-width", None)
+ if value is not None:
+ roi.setProfileLineWidth(value)
+ value = self.__properties.get("type", None)
+ if value is not None:
+ roi.setProfileType(value)
+ elif isinstance(roi, (rois._DefaultImageProfileRoiMixIn,
+ rois.ProfileImageCrossROI)):
+ value = self.__properties.get("method", None)
+ if value is not None:
+ roi.setProfileMethod(value)
+ value = self.__properties.get("line-width", None)
+ if value is not None:
+ roi.setProfileLineWidth(value)
+ elif isinstance(roi, (rois._DefaultScatterProfileRoiMixIn,
+ rois.ProfileScatterCrossROI)):
+ value = self.__properties.get("npoints", None)
+ if value is not None:
+ roi.setNPoints(value)
+
+
+class ProfileManager(qt.QObject):
+ """Base class for profile management tools
+
+ :param plot: :class:`~silx.gui.plot.PlotWidget` on which to operate.
+ :param plot: :class:`~silx.gui.plot.tools.roi.RegionOfInterestManager`
+ on which to operate.
+ """
+ def __init__(self, parent=None, plot=None, roiManager=None):
+ super(ProfileManager, self).__init__(parent)
+
+ assert isinstance(plot, PlotWidget)
+ self._plotRef = weakref.ref(
+ plot, WeakMethodProxy(self.__plotDestroyed))
+
+ # Set-up interaction manager
+ if roiManager is None:
+ roiManager = RegionOfInterestManager(plot)
+
+ self._roiManagerRef = weakref.ref(roiManager)
+ self._rois = []
+ self._pendingRunners = []
+ """List of ROIs which have to be updated"""
+
+ self.__reentrantResults = {}
+ """Store reentrant result to avoid to skip some of them
+ cause the implementation uses a QEventLoop."""
+
+ self._profileWindowClass = ProfileWindow
+ """Class used to display the profile results"""
+
+ self._computedProfiles = 0
+ """Statistics for tests"""
+
+ self.__itemTypes = []
+ """Kind of items to use"""
+
+ self.__tracking = False
+ """Is the plot active items are tracked"""
+
+ self.__useColorFromCursor = True
+ """If true, force the ROI color with the colormap marker color"""
+
+ self._item = None
+ """The selected item"""
+
+ self.__singleProfileAtATime = True
+ """When it's true, only a single profile is displayed at a time."""
+
+ self._previousWindowGeometry = []
+
+ self._storeProperties = _StoreLastParamBehavior(self)
+ """If defined the profile properties of the last ROI are reused to the
+ new created ones"""
+
+ # Listen to plot limits changed
+ plot.getXAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile)
+ plot.getYAxis().sigLimitsChanged.connect(self.requestUpdateAllProfile)
+
+ roiManager.sigInteractiveModeFinished.connect(self.__interactionFinished)
+ roiManager.sigInteractiveRoiCreated.connect(self.__roiCreated)
+ roiManager.sigRoiAdded.connect(self.__roiAdded)
+ roiManager.sigRoiAboutToBeRemoved.connect(self.__roiRemoved)
+
+ def setSingleProfile(self, enable):
+ """
+ Enable or disable the single profile mode.
+
+ In single mode, the manager enforce a single ROI at the same
+ time. A new one will remove the previous one.
+
+ If this mode is not enabled, many ROIs can be created, and many
+ profile windows will be displayed.
+ """
+ self.__singleProfileAtATime = enable
+
+ def isSingleProfile(self):
+ """
+ Returns true if the manager is in a single profile mode.
+
+ :rtype: bool
+ """
+ return self.__singleProfileAtATime
+
+ def __interactionFinished(self):
+ """Handle end of interactive mode"""
+ pass
+
+ def __roiAdded(self, roi):
+ """Handle new ROI"""
+ # Filter out non profile ROIs
+ if not isinstance(roi, core.ProfileRoiMixIn):
+ return
+ self.__addProfile(roi)
+
+ def __roiRemoved(self, roi):
+ """Handle removed ROI"""
+ # Filter out non profile ROIs
+ if not isinstance(roi, core.ProfileRoiMixIn):
+ return
+ self.__removeProfile(roi)
+
+ def createProfileAction(self, profileRoiClass, parent=None):
+ """Create an action from a class of ProfileRoi
+
+ :param core.ProfileRoiMixIn profileRoiClass: A class of a profile ROI
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: qt.QAction
+ """
+ if not issubclass(profileRoiClass, core.ProfileRoiMixIn):
+ raise TypeError("Type %s not expected" % type(profileRoiClass))
+ roiManager = self.getRoiManager()
+ action = CreateRoiModeAction(parent, roiManager, profileRoiClass)
+ if hasattr(profileRoiClass, "ICON"):
+ action.setIcon(icons.getQIcon(profileRoiClass.ICON))
+ if hasattr(profileRoiClass, "NAME"):
+ def articulify(word):
+ """Add an an/a article in the front of the word"""
+ first = word[1] if word[0] == 'h' else word[0]
+ if first in "aeiou":
+ return "an " + word
+ return "a " + word
+ action.setText('Define %s' % articulify(profileRoiClass.NAME))
+ action.setToolTip('Enables %s selection mode' % profileRoiClass.NAME)
+ action.setSingleShot(True)
+ return action
+
+ def createClearAction(self, parent):
+ """Create an action to clean up the plot from the profile ROIs.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: qt.QAction
+ """
+ action = _ClearAction(parent, self)
+ return action
+
+ def createImageActions(self, parent):
+ """Create actions designed for image items. This actions created
+ new ROIs.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileImageHorizontalLineROI,
+ rois.ProfileImageVerticalLineROI,
+ rois.ProfileImageLineROI,
+ rois.ProfileImageDirectedLineROI,
+ rois.ProfileImageCrossROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createScatterActions(self, parent):
+ """Create actions designed for scatter items. This actions created
+ new ROIs.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileScatterHorizontalLineROI,
+ rois.ProfileScatterVerticalLineROI,
+ rois.ProfileScatterLineROI,
+ rois.ProfileScatterCrossROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createScatterSliceActions(self, parent):
+ """Create actions designed for regular scatter items. This actions
+ created new ROIs.
+
+ This ROIs was designed to use the input data without interpolation,
+ like you could do with an image.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileScatterHorizontalSliceROI,
+ rois.ProfileScatterVerticalSliceROI,
+ rois.ProfileScatterCrossSliceROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createImageStackActions(self, parent):
+ """Create actions designed for stack image items. This actions
+ created new ROIs.
+
+ This ROIs was designed to create both profile on the displayed image
+ and profile on the full stack (2D result).
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: List[qt.QAction]
+ """
+ profileClasses = [
+ rois.ProfileImageStackHorizontalLineROI,
+ rois.ProfileImageStackVerticalLineROI,
+ rois.ProfileImageStackLineROI,
+ rois.ProfileImageStackCrossROI,
+ ]
+ return [self.createProfileAction(pc, parent=parent) for pc in profileClasses]
+
+ def createEditorAction(self, parent):
+ """Create an action containing GUI to edit the selected profile ROI.
+
+ :param qt.QObject parent: The parent of the created action.
+ :rtype: qt.QAction
+ """
+ action = editors.ProfileRoiEditorAction(parent)
+ action.setRoiManager(self.getRoiManager())
+ return action
+
+ def setItemType(self, image=False, scatter=False):
+ """Set the item type to use and select the active one.
+
+ :param bool image: Image item are allowed
+ :param bool scatter: Scatter item are allowed
+ """
+ self.__itemTypes = []
+ plot = self.getPlotWidget()
+ item = None
+ if image:
+ self.__itemTypes.append("image")
+ item = plot.getActiveImage()
+ if scatter:
+ self.__itemTypes.append("scatter")
+ if item is None:
+ item = plot.getActiveScatter()
+ self.setPlotItem(item)
+
+ def setProfileWindowClass(self, profileWindowClass):
+ """Set the class which will be instantiated to display profile result.
+ """
+ self._profileWindowClass = profileWindowClass
+
+ def setActiveItemTracking(self, tracking):
+ """Enable/disable the tracking of the active item of the plot.
+
+ :param bool tracking: Tracking mode
+ """
+ if self.__tracking == tracking:
+ return
+ plot = self.getPlotWidget()
+ if self.__tracking:
+ plot.sigActiveImageChanged.disconnect(self._activeImageChanged)
+ plot.sigActiveScatterChanged.disconnect(self._activeScatterChanged)
+ self.__tracking = tracking
+ if self.__tracking:
+ plot.sigActiveImageChanged.connect(self.__activeImageChanged)
+ plot.sigActiveScatterChanged.connect(self.__activeScatterChanged)
+
+ def setDefaultColorFromCursorColor(self, enabled):
+ """Enabled/disable the use of the colormap cursor color to display the
+ ROIs.
+
+ If set, the manager will update the color of the profile ROIs using the
+ current colormap cursor color from the selected item.
+ """
+ self.__useColorFromCursor = enabled
+
+ def __activeImageChanged(self, previous, legend):
+ """Handle plot item selection"""
+ if "image" in self.__itemTypes:
+ plot = self.getPlotWidget()
+ item = plot.getImage(legend)
+ self.setPlotItem(item)
+
+ def __activeScatterChanged(self, previous, legend):
+ """Handle plot item selection"""
+ if "scatter" in self.__itemTypes:
+ plot = self.getPlotWidget()
+ item = plot.getScatter(legend)
+ self.setPlotItem(item)
+
+ def __roiCreated(self, roi):
+ """Handle ROI creation"""
+ # Filter out non profile ROIs
+ if isinstance(roi, core.ProfileRoiMixIn):
+ if self._storeProperties is not None:
+ # Initialize the properties with the previous ones
+ self._storeProperties.restoreProperties(roi)
+
+ def __addProfile(self, profileRoi):
+ """Add a new ROI to the manager."""
+ if profileRoi.getFocusProxy() is None:
+ if self._storeProperties is not None:
+ # Follow changes on properties
+ self._storeProperties.setProfileRoi(profileRoi)
+ if self.__singleProfileAtATime:
+ # FIXME: It would be good to reuse the windows to avoid blinking
+ self.clearProfile()
+
+ profileRoi._setProfileManager(self)
+ self._updateRoiColor(profileRoi)
+ self._rois.append(profileRoi)
+ self.requestUpdateProfile(profileRoi)
+
+ def __removeProfile(self, profileRoi):
+ """Remove a ROI from the manager."""
+ window = self._disconnectProfileWindow(profileRoi)
+ if window is not None:
+ geometry = window.geometry()
+ if not geometry.isEmpty():
+ self._previousWindowGeometry.append(geometry)
+ self.clearProfileWindow(window)
+ if profileRoi in self._rois:
+ self._rois.remove(profileRoi)
+
+ def _disconnectProfileWindow(self, profileRoi):
+ """Handle profile window close."""
+ window = profileRoi.getProfileWindow()
+ profileRoi.setProfileWindow(None)
+ return window
+
+ def clearProfile(self):
+ """Clear the associated ROI profile"""
+ roiManager = self.getRoiManager()
+ for roi in list(self._rois):
+ if roi.getFocusProxy() is not None:
+ # Skip sub ROIs, it will be removed by their parents
+ continue
+ roiManager.removeRoi(roi)
+
+ if not roiManager.isDrawing():
+ # Clean the selected mode
+ roiManager.stop()
+
+ def hasPendingOperations(self):
+ """Returns true if a thread is still computing or displaying a profile.
+
+ :rtype: bool
+ """
+ return len(self.__reentrantResults) > 0 or len(self._pendingRunners) > 0
+
+ def requestUpdateAllProfile(self):
+ """Request to update the profile of all the managed ROIs.
+ """
+ for roi in self._rois:
+ self.requestUpdateProfile(roi)
+
+ def requestUpdateProfile(self, profileRoi):
+ """Request to update a specific profile ROI.
+
+ :param ~core.ProfileRoiMixIn profileRoi:
+ """
+ if profileRoi.computeProfile is None:
+ return
+ threadPool = silxGlobalThreadPool()
+
+ # Clean up deprecated runners
+ for runner in list(self._pendingRunners):
+ if not inspect.isValid(runner):
+ self._pendingRunners.remove(runner)
+ continue
+ if runner.getRoi() is profileRoi:
+ if hasattr(threadPool, "tryTake"):
+ if threadPool.tryTake(runner):
+ self._pendingRunners.remove(runner)
+ else: # Support Qt<5.9
+ runner._lazyCancel()
+
+ item = self.getPlotItem()
+ if item is None or not isinstance(item, profileRoi.ITEM_KIND):
+ # This item is not compatible with this profile
+ profileRoi._setPlotItem(None)
+ profileWindow = profileRoi.getProfileWindow()
+ if profileWindow is not None:
+ profileWindow.setProfile(None)
+ return
+
+ profileRoi._setPlotItem(item)
+ runner = _RunnableComputeProfile(threadPool, item, profileRoi)
+ runner.runnerFinished.connect(self.__cleanUpRunner)
+ runner.resultReady.connect(self.__displayResult)
+ self._pendingRunners.append(runner)
+ threadPool.start(runner)
+
+ def __cleanUpRunner(self, runner):
+ """Remove a thread pool runner from the list of hold tasks.
+
+ Called at the termination of the runner.
+ """
+ if runner in self._pendingRunners:
+ self._pendingRunners.remove(runner)
+
+ def __displayResult(self, roi, profileData):
+ """Display the result of a ROI.
+
+ :param ~core.ProfileRoiMixIn profileRoi: A managed ROI
+ :param ~core.CurveProfileData profileData: Computed data profile
+ """
+ if roi in self.__reentrantResults:
+ # Store the data to process it in the main loop
+ # And not a sub loop created by initProfileWindow
+ # This also remove the duplicated requested
+ self.__reentrantResults[roi] = profileData
+ return
+
+ self.__reentrantResults[roi] = profileData
+ self._computedProfiles = self._computedProfiles + 1
+ window = roi.getProfileWindow()
+ if window is None:
+ plot = self.getPlotWidget()
+ window = self.createProfileWindow(plot, roi)
+ # roi.profileWindow have to be set before initializing the window
+ # Cause the initialization is using QEventLoop
+ roi.setProfileWindow(window)
+ self.initProfileWindow(window, roi)
+ window.show()
+
+ lastData = self.__reentrantResults.pop(roi)
+ window.setProfile(lastData)
+
+ def __plotDestroyed(self, ref):
+ """Handle finalization of PlotWidget
+
+ :param ref: weakref to the plot
+ """
+ self._plotRef = None
+ self._roiManagerRef = None
+ self._pendingRunners = []
+
+ def setPlotItem(self, item):
+ """Set the plot item focused by the profile manager.
+
+ :param ~silx.gui.plot.items.Item item: A plot item
+ """
+ previous = self.getPlotItem()
+ if previous is item:
+ return
+ if item is None:
+ self._item = None
+ else:
+ item.sigItemChanged.connect(self.__itemChanged)
+ self._item = weakref.ref(item)
+ self._updateRoiColors()
+ self.requestUpdateAllProfile()
+
+ def getDefaultColor(self, item):
+ """Returns the default ROI color to use according to the given item.
+
+ :param ~silx.gui.plot.items.item.Item item: AN item
+ :rtype: qt.QColor
+ """
+ color = 'pink'
+ if isinstance(item, items.ColormapMixIn):
+ colormap = item.getColormap()
+ name = colormap.getName()
+ if name is not None:
+ color = colors.cursorColorForColormap(name)
+ color = colors.asQColor(color)
+ return color
+
+ def _updateRoiColors(self):
+ """Update ROI color according to the item selection"""
+ if not self.__useColorFromCursor:
+ return
+ item = self.getPlotItem()
+ color = self.getDefaultColor(item)
+ for roi in self._rois:
+ roi.setColor(color)
+
+ def _updateRoiColor(self, roi):
+ """Update a specific ROI according to the current selected item.
+
+ :param RegionOfInterest roi: The ROI to update
+ """
+ if not self.__useColorFromCursor:
+ return
+ item = self.getPlotItem()
+ color = self.getDefaultColor(item)
+ roi.setColor(color)
+
+ def __itemChanged(self, changeType):
+ """Handle item changes.
+ """
+ if changeType in (items.ItemChangedType.DATA,
+ items.ItemChangedType.MASK,
+ items.ItemChangedType.POSITION,
+ items.ItemChangedType.SCALE):
+ self.requestUpdateAllProfile()
+ elif changeType == (items.ItemChangedType.COLORMAP):
+ self._updateRoiColors()
+
+ def getPlotItem(self):
+ """Returns the item focused by the profile manager.
+
+ :rtype: ~silx.gui.plot.items.Item
+ """
+ if self._item is None:
+ return None
+ item = self._item()
+ if item is None:
+ self._item = None
+ return item
+
+ def getPlotWidget(self):
+ """The plot associated to the profile manager.
+
+ :rtype: ~silx.gui.plot.PlotWidget
+ """
+ if self._plotRef is None:
+ return None
+ plot = self._plotRef()
+ if plot is None:
+ self._plotRef = None
+ return plot
+
+ def getCurrentRoi(self):
+ """Returns the currently selected ROI, else None.
+
+ :rtype: core.ProfileRoiMixIn
+ """
+ roiManager = self.getRoiManager()
+ if roiManager is None:
+ return None
+ roi = roiManager.getCurrentRoi()
+ if not isinstance(roi, core.ProfileRoiMixIn):
+ return None
+ return roi
+
+ def getRoiManager(self):
+ """Returns the used ROI manager
+
+ :rtype: RegionOfInterestManager
+ """
+ return self._roiManagerRef()
+
+ def createProfileWindow(self, plot, roi):
+ """Create a new profile window.
+
+ :param ~core.ProfileRoiMixIn roi: The plot containing the raw data
+ :param ~core.ProfileRoiMixIn roi: A managed ROI
+ :rtype: ~ProfileWindow
+ """
+ return self._profileWindowClass(plot)
+
+ def initProfileWindow(self, profileWindow, roi):
+ """This function is called just after the profile window creation in
+ order to initialize the window location.
+
+ :param ~ProfileWindow profileWindow:
+ The profile window to initialize.
+ """
+ # Enforce the use of one of the widgets
+ # To have the correct window size
+ profileWindow.prepareWidget(roi)
+ profileWindow.adjustSize()
+
+ # Trick to avoid blinking while retrieving the right window size
+ # Display the window, hide it and wait for some event loops
+ profileWindow.show()
+ profileWindow.hide()
+ eventLoop = qt.QEventLoop(self)
+ for _ in range(10):
+ if not eventLoop.processEvents():
+ break
+
+ profileWindow.show()
+ if len(self._previousWindowGeometry) > 0:
+ geometry = self._previousWindowGeometry.pop()
+ profileWindow.setGeometry(geometry)
+ return
+
+ window = self.getPlotWidget().window()
+ winGeom = window.frameGeometry()
+ if qt.BINDING in ("PySide2", "PyQt5"):
+ qapp = qt.QApplication.instance()
+ desktop = qapp.desktop()
+ screenGeom = desktop.availableGeometry(window)
+ else: # Qt6 (and also Qt>=5.14)
+ screenGeom = window.screen().availableGeometry()
+ spaceOnLeftSide = winGeom.left()
+ spaceOnRightSide = screenGeom.width() - winGeom.right()
+
+ profileGeom = profileWindow.frameGeometry()
+ profileWidth = profileGeom.width()
+
+ # Align vertically to the center of the window
+ top = winGeom.top() + (winGeom.height() - profileGeom.height()) // 2
+
+ margin = 5
+ if profileWidth < spaceOnRightSide:
+ # Place profile on the right
+ left = winGeom.right() + margin
+ elif profileWidth < spaceOnLeftSide:
+ # Place profile on the left
+ left = max(0, winGeom.left() - profileWidth - margin)
+ else:
+ # Move it as much as possible where there is more space
+ if spaceOnLeftSide > spaceOnRightSide:
+ left = 0
+ else:
+ left = screenGeom.width() - profileGeom.width()
+ profileWindow.move(left, top)
+
+
+ def clearProfileWindow(self, profileWindow):
+ """Called when a profile window is not anymore needed.
+
+ By default the window will be closed. But it can be
+ inherited to change this behavior.
+ """
+ profileWindow.deleteLater()
diff --git a/src/silx/gui/plot/tools/profile/rois.py b/src/silx/gui/plot/tools/profile/rois.py
new file mode 100644
index 0000000..9eef622
--- /dev/null
+++ b/src/silx/gui/plot/tools/profile/rois.py
@@ -0,0 +1,1156 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module define ROIs for profile tools.
+
+.. inheritance-diagram::
+ silx.gui.plot.tools.profile.rois
+ :top-classes: silx.gui.plot.tools.profile.core.ProfileRoiMixIn, silx.gui.plot.items.roi.RegionOfInterest
+ :parts: 1
+ :private-bases:
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "01/12/2020"
+
+import numpy
+import weakref
+from concurrent.futures import CancelledError
+
+from silx.gui import colors
+
+from silx.gui.plot import items
+from silx.gui.plot.items import roi as roi_items
+from . import core
+from silx.gui import utils
+from .....utils.proxy import docstring
+
+
+def _relabelAxes(plot, text):
+ """Relabel {xlabel} and {ylabel} from this text using the corresponding
+ plot axis label. If the axis label is empty, label it with "X" and "Y".
+
+ :rtype: str
+ """
+ xLabel = plot.getXAxis().getLabel()
+ if not xLabel:
+ xLabel = "X"
+ yLabel = plot.getYAxis().getLabel()
+ if not yLabel:
+ yLabel = "Y"
+ return text.format(xlabel=xLabel, ylabel=yLabel)
+
+
+def _lineProfileTitle(x0, y0, x1, y1):
+ """Compute corresponding plot title
+
+ This can be overridden to change title behavior.
+
+ :param float x0: Profile start point X coord
+ :param float y0: Profile start point Y coord
+ :param float x1: Profile end point X coord
+ :param float y1: Profile end point Y coord
+ :return: Title to use
+ :rtype: str
+ """
+ if x0 == x1:
+ title = '{xlabel} = %g; {ylabel} = [%g, %g]' % (x0, y0, y1)
+ elif y0 == y1:
+ title = '{ylabel} = %g; {xlabel} = [%g, %g]' % (y0, x0, x1)
+ else:
+ m = (y1 - y0) / (x1 - x0)
+ b = y0 - m * x0
+ title = '{ylabel} = %g * {xlabel} %+g' % (m, b)
+
+ return title
+
+
+class _ImageProfileArea(items.Shape):
+ """This shape displays the location of pixels used to compute the
+ profile."""
+
+ def __init__(self, parentRoi):
+ items.Shape.__init__(self, "polygon")
+ color = colors.rgba(parentRoi.getColor())
+ self.setColor(color)
+ self.setFill(True)
+ self.setOverlay(True)
+ self.setPoints([[0, 0], [0, 0]]) # Else it segfault
+
+ self.__parentRoi = weakref.ref(parentRoi)
+ parentRoi.sigItemChanged.connect(self._updateAreaProperty)
+ parentRoi.sigRegionChanged.connect(self._updateArea)
+ parentRoi.sigProfilePropertyChanged.connect(self._updateArea)
+ parentRoi.sigPlotItemChanged.connect(self._updateArea)
+
+ def getParentRoi(self):
+ if self.__parentRoi is None:
+ return None
+ parentRoi = self.__parentRoi()
+ if parentRoi is None:
+ self.__parentRoi = None
+ return parentRoi
+
+ def _updateAreaProperty(self, event=None, checkVisibility=True):
+ parentRoi = self.sender()
+ if event == items.ItemChangedType.COLOR:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+ elif event == items.ItemChangedType.VISIBLE:
+ if self.getPlotItem() is not None:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+
+ def _updateArea(self):
+ roi = self.getParentRoi()
+ item = roi.getPlotItem()
+ if item is None:
+ self.setVisible(False)
+ return
+ polygon = self._computePolygon(item)
+ self.setVisible(True)
+ polygon = numpy.array(polygon).T
+ self.setLineStyle("--")
+ self.setPoints(polygon, copy=False)
+
+ def _computePolygon(self, item):
+ if not isinstance(item, items.ImageBase):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ currentData = item.getValueData(copy=False)
+
+ roi = self.getParentRoi()
+ origin = item.getOrigin()
+ scale = item.getScale()
+ _coords, _profile, area, _profileName, _xLabel = core.createProfile(
+ roiInfo=roi._getRoiInfo(),
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=roi.getProfileLineWidth(),
+ method="none")
+ return area
+
+
+class _SliceProfileArea(items.Shape):
+ """This shape displays the location a profile in a scatter.
+
+ Each point used to compute the slice are linked together.
+ """
+
+ def __init__(self, parentRoi):
+ items.Shape.__init__(self, "polygon")
+ color = colors.rgba(parentRoi.getColor())
+ self.setColor(color)
+ self.setFill(True)
+ self.setOverlay(True)
+ self.setPoints([[0, 0], [0, 0]]) # Else it segfault
+
+ self.__parentRoi = weakref.ref(parentRoi)
+ parentRoi.sigItemChanged.connect(self._updateAreaProperty)
+ parentRoi.sigRegionChanged.connect(self._updateArea)
+ parentRoi.sigProfilePropertyChanged.connect(self._updateArea)
+ parentRoi.sigPlotItemChanged.connect(self._updateArea)
+
+ def getParentRoi(self):
+ if self.__parentRoi is None:
+ return None
+ parentRoi = self.__parentRoi()
+ if parentRoi is None:
+ self.__parentRoi = None
+ return parentRoi
+
+ def _updateAreaProperty(self, event=None, checkVisibility=True):
+ parentRoi = self.sender()
+ if event == items.ItemChangedType.COLOR:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+ elif event == items.ItemChangedType.VISIBLE:
+ if self.getPlotItem() is not None:
+ parentRoi._updateItemProperty(event, parentRoi, self)
+
+ def _updateArea(self):
+ roi = self.getParentRoi()
+ item = roi.getPlotItem()
+ if item is None:
+ self.setVisible(False)
+ return
+ polylines = self._computePolylines(roi, item)
+ if polylines is None:
+ self.setVisible(False)
+ return
+ self.setVisible(True)
+ self.setLineStyle("--")
+ self.setPoints(polylines, copy=False)
+
+ def _computePolylines(self, roi, item):
+ slicing = roi._getSlice(item)
+ if slicing is None:
+ return None
+ xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False)
+ xx, yy = xx[slicing], yy[slicing]
+ polylines = numpy.array((xx, yy)).T
+ if len(polylines) == 0:
+ return None
+ return polylines
+
+
+class _DefaultImageProfileRoiMixIn(core.ProfileRoiMixIn):
+ """Provide common behavior for silx default image profile ROI.
+ """
+
+ ITEM_KIND = items.ImageBase
+
+ def __init__(self, parent=None):
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.__method = "mean"
+ self.__width = 1
+ self.sigRegionChanged.connect(self.__regionChanged)
+ self.sigPlotItemChanged.connect(self.__updateArea)
+ self.__area = _ImageProfileArea(self)
+ self.addItem(self.__area)
+
+ def __regionChanged(self):
+ self.invalidateProfile()
+ self.__updateArea()
+
+ def setProfileMethod(self, method):
+ """
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ """
+ if self.__method == method:
+ return
+ self.__method = method
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def getProfileMethod(self):
+ return self.__method
+
+ def setProfileLineWidth(self, width):
+ if self.__width == width:
+ return
+ self.__width = width
+ self.__updateArea()
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def getProfileLineWidth(self):
+ return self.__width
+
+ def __updateArea(self):
+ plotItem = self.getPlotItem()
+ if plotItem is None:
+ self.setLineStyle("-")
+ else:
+ self.setLineStyle("--")
+
+ def _getRoiInfo(self):
+ """Wrapper to allow to reuse the previous Profile code.
+
+ It would be good to remove it at one point.
+ """
+ if isinstance(self, roi_items.HorizontalLineROI):
+ lineProjectionMode = 'X'
+ y = self.getPosition()
+ roiStart = (0, y)
+ roiEnd = (1, y)
+ elif isinstance(self, roi_items.VerticalLineROI):
+ lineProjectionMode = 'Y'
+ x = self.getPosition()
+ roiStart = (x, 0)
+ roiEnd = (x, 1)
+ elif isinstance(self, roi_items.LineROI):
+ lineProjectionMode = 'D'
+ roiStart, roiEnd = self.getEndPoints()
+ else:
+ assert False
+
+ return roiStart, roiEnd, lineProjectionMode
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.ImageBase):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+
+ def createProfile2(currentData):
+ coords, profile, _area, profileName, xLabel = core.createProfile(
+ roiInfo=self._getRoiInfo(),
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=lineWidth,
+ method=method)
+ return coords, profile, profileName, xLabel
+
+ currentData = item.getValueData(copy=False)
+
+ yLabel = "%s" % str(method).capitalize()
+ coords, profile, title, xLabel = createProfile2(currentData)
+ title = title + "; width = %d" % lineWidth
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ title = _relabelAxes(plot, title)
+ xLabel = _relabelAxes(plot, xLabel)
+
+ if isinstance(item, items.ImageRgba):
+ rgba = item.getData(copy=False)
+ _coords, r, _profileName, _xLabel = createProfile2(rgba[..., 0])
+ _coords, g, _profileName, _xLabel = createProfile2(rgba[..., 1])
+ _coords, b, _profileName, _xLabel = createProfile2(rgba[..., 2])
+ if rgba.shape[-1] == 4:
+ _coords, a, _profileName, _xLabel = createProfile2(rgba[..., 3])
+ else:
+ a = [None]
+ data = core.RgbaProfileData(
+ coords=coords,
+ profile=profile[0],
+ profile_r=r[0],
+ profile_g=g[0],
+ profile_b=b[0],
+ profile_a=a[0],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ else:
+ data = core.CurveProfileData(
+ coords=coords,
+ profile=profile[0],
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
+
+
+class ProfileImageHorizontalLineROI(roi_items.HorizontalLineROI,
+ _DefaultImageProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of an image"""
+
+ ICON = 'shape-horizontal'
+ NAME = 'horizontal line profile'
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageVerticalLineROI(roi_items.VerticalLineROI,
+ _DefaultImageProfileRoiMixIn):
+ """ROI for a vertical profile at a location of an image"""
+
+ ICON = 'shape-vertical'
+ NAME = 'vertical line profile'
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageLineROI(roi_items.LineROI,
+ _DefaultImageProfileRoiMixIn):
+ """ROI for an image profile between 2 points.
+
+ The X profile of this ROI is the projecting into one of the x/y axes,
+ using its scale and its orientation.
+ """
+
+ ICON = 'shape-diagonal'
+ NAME = 'line profile'
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageDirectedLineROI(roi_items.LineROI,
+ _DefaultImageProfileRoiMixIn):
+ """ROI for an image profile between 2 points.
+
+ The X profile of the line is displayed projected into the line itself,
+ using its scale and its orientation. It's the distance from the origin.
+ """
+
+ ICON = 'shape-diagonal-directed'
+ NAME = 'directed line profile'
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultImageProfileRoiMixIn.__init__(self, parent=parent)
+ self._handleStart.setSymbol('o')
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.ImageBase):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ from silx.image.bilinear import BilinearImage
+
+ origin = item.getOrigin()
+ scale = item.getScale()
+ method = self.getProfileMethod()
+ lineWidth = self.getProfileLineWidth()
+ currentData = item.getValueData(copy=False)
+
+ roiInfo = self._getRoiInfo()
+ roiStart, roiEnd, _lineProjectionMode = roiInfo
+
+ startPt = ((roiStart[1] - origin[1]) / scale[1],
+ (roiStart[0] - origin[0]) / scale[0])
+ endPt = ((roiEnd[1] - origin[1]) / scale[1],
+ (roiEnd[0] - origin[0]) / scale[0])
+
+ if numpy.array_equal(startPt, endPt):
+ return None
+
+ bilinear = BilinearImage(currentData)
+ profile = bilinear.profile_line(
+ (startPt[0] - 0.5, startPt[1] - 0.5),
+ (endPt[0] - 0.5, endPt[1] - 0.5),
+ lineWidth,
+ method=method)
+
+ # Compute the line size
+ lineSize = numpy.sqrt((roiEnd[1] - roiStart[1]) ** 2 +
+ (roiEnd[0] - roiStart[0]) ** 2)
+ coords = numpy.linspace(0, lineSize, len(profile),
+ endpoint=True,
+ dtype=numpy.float32)
+
+ title = _lineProfileTitle(*roiStart, *roiEnd)
+ title = title + "; width = %d" % lineWidth
+ xLabel = "√({xlabel}²+{ylabel}²)"
+ yLabel = str(method).capitalize()
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ xLabel = _relabelAxes(plot, xLabel)
+ title = _relabelAxes(plot, title)
+
+ data = core.CurveProfileData(
+ coords=coords,
+ profile=profile,
+ title=title,
+ xLabel=xLabel,
+ yLabel=yLabel,
+ )
+ return data
+
+
+class _ProfileCrossROI(roi_items.HandleBasedROI, core.ProfileRoiMixIn):
+
+ """ROI to manage a cross of profiles
+
+ It is managed using 2 sub ROIs for vertical and horizontal.
+ """
+
+ _kind = "Cross"
+ """Label for this kind of ROI"""
+
+ _plotShape = "point"
+ """Plot shape which is used for the first interaction"""
+
+ def __init__(self, parent=None):
+ roi_items.HandleBasedROI.__init__(self, parent=parent)
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.sigRegionChanged.connect(self.__regionChanged)
+ self.sigAboutToBeRemoved.connect(self.__aboutToBeRemoved)
+ self.__position = 0, 0
+ self.__vline = None
+ self.__hline = None
+ self.__handle = self.addHandle()
+ self.__handleLabel = self.addLabelHandle()
+ self.__handleLabel.setText(self.getName())
+ self.__inhibitReentance = utils.LockReentrant()
+ self.computeProfile = None
+ self.sigItemChanged.connect(self.__updateLineProperty)
+
+ # Make sure the marker is over the ROIs
+ self.__handle.setZValue(1)
+ # Create the vline and the hline
+ self._createSubRois()
+
+ @docstring(roi_items.HandleBasedROI)
+ def contains(self, position):
+ roiPos = self.getPosition()
+ return position[0] == roiPos[0] or position[1] == roiPos[1]
+
+ def setFirstShapePoints(self, points):
+ pos = points[0]
+ self.setPosition(pos)
+
+ def getPosition(self):
+ """Returns the position of this ROI
+
+ :rtype: numpy.ndarray
+ """
+ return self.__position
+
+ def setPosition(self, pos):
+ """Set the position of this ROI
+
+ :param numpy.ndarray pos: 2d-coordinate of this point
+ """
+ self.__position = pos
+ with utils.blockSignals(self.__handle):
+ self.__handle.setPosition(*pos)
+ with utils.blockSignals(self.__handleLabel):
+ self.__handleLabel.setPosition(*pos)
+ self.sigRegionChanged.emit()
+
+ def handleDragUpdated(self, handle, origin, previous, current):
+ if handle is self.__handle:
+ self.setPosition(current)
+
+ def __updateLineProperty(self, event=None, checkVisibility=True):
+ if event == items.ItemChangedType.NAME:
+ self.__handleLabel.setText(self.getName())
+ elif event in [items.ItemChangedType.COLOR,
+ items.ItemChangedType.VISIBLE]:
+ lines = []
+ if self.__vline:
+ lines.append(self.__vline)
+ if self.__hline:
+ lines.append(self.__hline)
+ self._updateItemProperty(event, self, lines)
+
+ def _createLines(self, parent):
+ """Inherit this function to return 2 ROI objects for respectivly
+ the horizontal, and the vertical lines."""
+ raise NotImplementedError()
+
+ def _setProfileManager(self, profileManager):
+ core.ProfileRoiMixIn._setProfileManager(self, profileManager)
+ # Connecting the vline and the hline
+ roiManager = profileManager.getRoiManager()
+ roiManager.addRoi(self.__vline)
+ roiManager.addRoi(self.__hline)
+
+ def _createSubRois(self):
+ hline, vline = self._createLines(parent=None)
+ for i, line in enumerate([vline, hline]):
+ line.setPosition(self.__position[i])
+ line.setEditable(True)
+ line.setSelectable(True)
+ line.setFocusProxy(self)
+ line.setName("")
+ self.__vline = vline
+ self.__hline = hline
+ vline.sigAboutToBeRemoved.connect(self.__vlineRemoved)
+ vline.sigRegionChanged.connect(self.__vlineRegionChanged)
+ hline.sigAboutToBeRemoved.connect(self.__hlineRemoved)
+ hline.sigRegionChanged.connect(self.__hlineRegionChanged)
+
+ def _getLines(self):
+ return self.__hline, self.__vline
+
+ def __regionChanged(self):
+ if self.__inhibitReentance.locked():
+ return
+ x, y = self.getPosition()
+ hline, vline = self._getLines()
+ if hline is None:
+ return
+ with self.__inhibitReentance:
+ hline.setPosition(y)
+ vline.setPosition(x)
+
+ def __vlineRegionChanged(self):
+ if self.__inhibitReentance.locked():
+ return
+ pos = self.getPosition()
+ vline = self.__vline
+ pos = vline.getPosition(), pos[1]
+ with self.__inhibitReentance:
+ self.setPosition(pos)
+
+ def __hlineRegionChanged(self):
+ if self.__inhibitReentance.locked():
+ return
+ pos = self.getPosition()
+ hline = self.__hline
+ pos = pos[0], hline.getPosition()
+ with self.__inhibitReentance:
+ self.setPosition(pos)
+
+ def __aboutToBeRemoved(self):
+ vline = self.__vline
+ hline = self.__hline
+ # Avoid side remove signals
+ if hline is not None:
+ hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved)
+ hline.sigRegionChanged.disconnect(self.__hlineRegionChanged)
+ if vline is not None:
+ vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved)
+ vline.sigRegionChanged.disconnect(self.__vlineRegionChanged)
+ # Clean up the child
+ profileManager = self.getProfileManager()
+ roiManager = profileManager.getRoiManager()
+ if hline is not None:
+ roiManager.removeRoi(hline)
+ self.__hline = None
+ if vline is not None:
+ roiManager.removeRoi(vline)
+ self.__vline = None
+
+ def __hlineRemoved(self):
+ self.__lineRemoved(isHline=True)
+
+ def __vlineRemoved(self):
+ self.__lineRemoved(isHline=False)
+
+ def __lineRemoved(self, isHline):
+ """If any of the lines is removed: disconnect this objects, and let the
+ other one persist"""
+ hline, vline = self._getLines()
+
+ hline.sigAboutToBeRemoved.disconnect(self.__hlineRemoved)
+ vline.sigAboutToBeRemoved.disconnect(self.__vlineRemoved)
+ hline.sigRegionChanged.disconnect(self.__hlineRegionChanged)
+ vline.sigRegionChanged.disconnect(self.__vlineRegionChanged)
+
+ self.__hline = None
+ self.__vline = None
+ profileManager = self.getProfileManager()
+ roiManager = profileManager.getRoiManager()
+ if isHline:
+ self.__releaseLine(vline)
+ else:
+ self.__releaseLine(hline)
+ roiManager.removeRoi(self)
+
+ def __releaseLine(self, line):
+ """Release the line in order to make it independent"""
+ line.setFocusProxy(None)
+ line.setName(self.getName())
+ line.setEditable(self.isEditable())
+ line.setSelectable(self.isSelectable())
+
+
+class ProfileImageCrossROI(_ProfileCrossROI):
+ """ROI to manage a cross of profiles
+
+ It is managed using 2 sub ROIs for vertical and horizontal.
+ """
+
+ ICON = 'shape-cross'
+ NAME = 'cross profile'
+ ITEM_KIND = items.ImageBase
+
+ def _createLines(self, parent):
+ vline = ProfileImageVerticalLineROI(parent=parent)
+ hline = ProfileImageHorizontalLineROI(parent=parent)
+ return hline, vline
+
+ def setProfileMethod(self, method):
+ """
+ :param str method: method to compute the profile. Can be 'mean' or 'sum'
+ """
+ hline, vline = self._getLines()
+ hline.setProfileMethod(method)
+ vline.setProfileMethod(method)
+ self.invalidateProperties()
+
+ def getProfileMethod(self):
+ hline, _vline = self._getLines()
+ return hline.getProfileMethod()
+
+ def setProfileLineWidth(self, width):
+ hline, vline = self._getLines()
+ hline.setProfileLineWidth(width)
+ vline.setProfileLineWidth(width)
+ self.invalidateProperties()
+
+ def getProfileLineWidth(self):
+ hline, _vline = self._getLines()
+ return hline.getProfileLineWidth()
+
+
+class _DefaultScatterProfileRoiMixIn(core.ProfileRoiMixIn):
+ """Provide common behavior for silx default scatter profile ROI.
+ """
+
+ ITEM_KIND = items.Scatter
+
+ def __init__(self, parent=None):
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.__nPoints = 1024
+ self.sigRegionChanged.connect(self.__regionChanged)
+
+ def __regionChanged(self):
+ self.invalidateProfile()
+
+ # Number of points
+
+ def getNPoints(self):
+ """Returns the number of points of the profiles
+
+ :rtype: int
+ """
+ return self.__nPoints
+
+ def setNPoints(self, npoints):
+ """Set the number of points of the profiles
+
+ :param int npoints:
+ """
+ npoints = int(npoints)
+ if npoints < 1:
+ raise ValueError("Unsupported number of points: %d" % npoints)
+ elif npoints != self.__nPoints:
+ self.__nPoints = npoints
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def _computeProfile(self, scatter, x0, y0, x1, y1):
+ """Compute corresponding profile
+
+ :param float x0: Profile start point X coord
+ :param float y0: Profile start point Y coord
+ :param float x1: Profile end point X coord
+ :param float y1: Profile end point Y coord
+ :return: (points, values) profile data or None
+ """
+ future = scatter._getInterpolator()
+ try:
+ interpolator = future.result()
+ except CancelledError:
+ return None
+ if interpolator is None:
+ return None # Cannot init an interpolator
+
+ nPoints = self.getNPoints()
+ points = numpy.transpose((
+ numpy.linspace(x0, x1, nPoints, endpoint=True),
+ numpy.linspace(y0, y1, nPoints, endpoint=True)))
+
+ values = interpolator(points)
+
+ if not numpy.any(numpy.isfinite(values)):
+ return None # Profile outside convex hull
+
+ return points, values
+
+ def computeProfile(self, item):
+ """Update profile according to current ROI"""
+ if not isinstance(item, items.Scatter):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ # Get end points
+ if isinstance(self, roi_items.LineROI):
+ points = self.getEndPoints()
+ x0, y0 = points[0]
+ x1, y1 = points[1]
+ elif isinstance(self, (roi_items.VerticalLineROI, roi_items.HorizontalLineROI)):
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+
+ if isinstance(self, roi_items.HorizontalLineROI):
+ x0, x1 = plot.getXAxis().getLimits()
+ y0 = y1 = self.getPosition()
+
+ elif isinstance(self, roi_items.VerticalLineROI):
+ x0 = x1 = self.getPosition()
+ y0, y1 = plot.getYAxis().getLimits()
+ else:
+ raise RuntimeError('Unsupported ROI for profile: {}'.format(self.__class__))
+
+ if x1 < x0 or (x1 == x0 and y1 < y0):
+ # Invert points
+ x0, y0, x1, y1 = x1, y1, x0, y0
+
+ profile = self._computeProfile(item, x0, y0, x1, y1)
+ if profile is None:
+ return None
+
+ title = _lineProfileTitle(x0, y0, x1, y1)
+ points = profile[0]
+ values = profile[1]
+
+ if (numpy.abs(points[-1, 0] - points[0, 0]) >
+ numpy.abs(points[-1, 1] - points[0, 1])):
+ xProfile = points[:, 0]
+ xLabel = '{xlabel}'
+ else:
+ xProfile = points[:, 1]
+ xLabel = '{ylabel}'
+
+ # Use the axis names from the original
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ title = _relabelAxes(plot, title)
+ xLabel = _relabelAxes(plot, xLabel)
+
+ data = core.CurveProfileData(
+ coords=xProfile,
+ profile=values,
+ title=title,
+ xLabel=xLabel,
+ yLabel='Profile',
+ )
+ return data
+
+
+class ProfileScatterHorizontalLineROI(roi_items.HorizontalLineROI,
+ _DefaultScatterProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of a scatter"""
+
+ ICON = 'shape-horizontal'
+ NAME = 'horizontal line profile'
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterVerticalLineROI(roi_items.VerticalLineROI,
+ _DefaultScatterProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of a scatter"""
+
+ ICON = 'shape-vertical'
+ NAME = 'vertical line profile'
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterLineROI(roi_items.LineROI,
+ _DefaultScatterProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of a scatter"""
+
+ ICON = 'shape-diagonal'
+ NAME = 'line profile'
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterCrossROI(_ProfileCrossROI):
+ """ROI to manage a cross of profiles for scatters.
+ """
+
+ ICON = 'shape-cross'
+ NAME = 'cross profile'
+ ITEM_KIND = items.Scatter
+
+ def _createLines(self, parent):
+ vline = ProfileScatterVerticalLineROI(parent=parent)
+ hline = ProfileScatterHorizontalLineROI(parent=parent)
+ return hline, vline
+
+ def getNPoints(self):
+ """Returns the number of points of the profiles
+
+ :rtype: int
+ """
+ hline, _vline = self._getLines()
+ return hline.getNPoints()
+
+ def setNPoints(self, npoints):
+ """Set the number of points of the profiles
+
+ :param int npoints:
+ """
+ hline, vline = self._getLines()
+ hline.setNPoints(npoints)
+ vline.setNPoints(npoints)
+ self.invalidateProperties()
+
+
+class _DefaultScatterProfileSliceRoiMixIn(core.ProfileRoiMixIn):
+ """Default ROI to allow to slice in the scatter data."""
+
+ ITEM_KIND = items.Scatter
+
+ def __init__(self, parent=None):
+ core.ProfileRoiMixIn.__init__(self, parent=parent)
+ self.__area = _SliceProfileArea(self)
+ self.addItem(self.__area)
+ self.sigRegionChanged.connect(self._regionChanged)
+ self.sigPlotItemChanged.connect(self._updateArea)
+
+ def _regionChanged(self):
+ self.invalidateProfile()
+ self._updateArea()
+
+ def _updateArea(self):
+ plotItem = self.getPlotItem()
+ if plotItem is None:
+ self.setLineStyle("-")
+ else:
+ self.setLineStyle("--")
+
+ def _getSlice(self, item):
+ position = self.getPosition()
+ bounds = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_BOUNDS)
+ if isinstance(self, roi_items.HorizontalLineROI):
+ axis = 1
+ elif isinstance(self, roi_items.VerticalLineROI):
+ axis = 0
+ else:
+ assert False
+ if bounds is None or position < bounds[0][axis] or position > bounds[1][axis]:
+ # ROI outside of the scatter bound
+ return None
+
+ major_order = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_MAJOR_ORDER)
+ assert major_order == 'row'
+ max_grid_yy, max_grid_xx = item.getCurrentVisualizationParameter(items.Scatter.VisualizationParameter.GRID_SHAPE)
+
+ xx, yy, _values, _xx_error, _yy_error = item.getData(copy=False)
+ if isinstance(self, roi_items.HorizontalLineROI):
+ axis = yy
+ max_grid_first = max_grid_yy
+ max_grid_second = max_grid_xx
+ major_axis = major_order == 'column'
+ elif isinstance(self, roi_items.VerticalLineROI):
+ axis = xx
+ max_grid_first = max_grid_xx
+ max_grid_second = max_grid_yy
+ major_axis = major_order == 'row'
+ else:
+ assert False
+
+ def argnearest(array, value):
+ array = numpy.abs(array - value)
+ return numpy.argmin(array)
+
+ if major_axis:
+ # slice in the middle of the scatter
+ start = max_grid_second // 2 * max_grid_first
+ vslice = axis[start:start + max_grid_second]
+ index = argnearest(vslice, position)
+ slicing = slice(index, None, max_grid_first)
+ else:
+ # slice in the middle of the scatter
+ vslice = axis[max_grid_second // 2::max_grid_second]
+ index = argnearest(vslice, position)
+ start = index * max_grid_second
+ slicing = slice(start, start + max_grid_second)
+
+ return slicing
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.Scatter):
+ raise TypeError("Unsupported %s item" % type(item))
+
+ slicing = self._getSlice(item)
+ if slicing is None:
+ # ROI out of bounds
+ return None
+
+ _xx, _yy, values, _xx_error, _yy_error = item.getData(copy=False)
+ profile = values[slicing]
+
+ if isinstance(self, roi_items.HorizontalLineROI):
+ title = "Horizontal slice"
+ xLabel = "{xlabel} index"
+ elif isinstance(self, roi_items.VerticalLineROI):
+ title = "Vertical slice"
+ xLabel = "{ylabel} index"
+ else:
+ assert False
+
+ # Use the axis names from the original plot
+ profileManager = self.getProfileManager()
+ plot = profileManager.getPlotWidget()
+ xLabel = _relabelAxes(plot, xLabel)
+
+ data = core.CurveProfileData(
+ coords=numpy.arange(len(profile)),
+ profile=profile,
+ title=title,
+ xLabel=xLabel,
+ yLabel="Profile",
+ )
+ return data
+
+
+class ProfileScatterHorizontalSliceROI(roi_items.HorizontalLineROI,
+ _DefaultScatterProfileSliceRoiMixIn):
+ """ROI for an horizontal profile at a location of a scatter
+ using data slicing.
+ """
+
+ ICON = 'slice-horizontal'
+ NAME = 'horizontal data slice profile'
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterVerticalSliceROI(roi_items.VerticalLineROI,
+ _DefaultScatterProfileSliceRoiMixIn):
+ """ROI for a vertical profile at a location of a scatter
+ using data slicing.
+ """
+
+ ICON = 'slice-vertical'
+ NAME = 'vertical data slice profile'
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultScatterProfileSliceRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileScatterCrossSliceROI(_ProfileCrossROI):
+ """ROI to manage a cross of slicing profiles on scatters.
+ """
+
+ ICON = 'slice-cross'
+ NAME = 'cross data slice profile'
+ ITEM_KIND = items.Scatter
+
+ def _createLines(self, parent):
+ vline = ProfileScatterVerticalSliceROI(parent=parent)
+ hline = ProfileScatterHorizontalSliceROI(parent=parent)
+ return hline, vline
+
+
+class _DefaultImageStackProfileRoiMixIn(_DefaultImageProfileRoiMixIn):
+
+ ITEM_KIND = items.ImageStack
+
+ def __init__(self, parent=None):
+ super(_DefaultImageStackProfileRoiMixIn, self).__init__(parent=parent)
+ self.__profileType = "1D"
+ """Kind of profile"""
+
+ def getProfileType(self):
+ return self.__profileType
+
+ def setProfileType(self, kind):
+ assert kind in ["1D", "2D"]
+ if self.__profileType == kind:
+ return
+ self.__profileType = kind
+ self.invalidateProperties()
+ self.invalidateProfile()
+
+ def computeProfile(self, item):
+ if not isinstance(item, items.ImageStack):
+ raise TypeError("Unexpected class %s" % type(item))
+
+ kind = self.getProfileType()
+ if kind == "1D":
+ result = _DefaultImageProfileRoiMixIn.computeProfile(self, item)
+ # z = item.getStackPosition()
+ return result
+
+ assert kind == "2D"
+
+ def createProfile2(currentData):
+ coords, profile, _area, profileName, xLabel = core.createProfile(
+ roiInfo=self._getRoiInfo(),
+ currentData=currentData,
+ origin=origin,
+ scale=scale,
+ lineWidth=self.getProfileLineWidth(),
+ method=method)
+ return coords, profile, profileName, xLabel
+
+ currentData = numpy.array(item.getStackData(copy=False))
+ origin = item.getOrigin()
+ scale = item.getScale()
+ colormap = item.getColormap()
+ method = self.getProfileMethod()
+
+ coords, profile, profileName, xLabel = createProfile2(currentData)
+
+ data = core.ImageProfileData(
+ coords=coords,
+ profile=profile,
+ title=profileName,
+ xLabel=xLabel,
+ yLabel="Profile",
+ colormap=colormap,
+ )
+ return data
+
+
+class ProfileImageStackHorizontalLineROI(roi_items.HorizontalLineROI,
+ _DefaultImageStackProfileRoiMixIn):
+ """ROI for an horizontal profile at a location of a stack of images"""
+
+ ICON = 'shape-horizontal'
+ NAME = 'horizontal line profile'
+
+ def __init__(self, parent=None):
+ roi_items.HorizontalLineROI.__init__(self, parent=parent)
+ _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageStackVerticalLineROI(roi_items.VerticalLineROI,
+ _DefaultImageStackProfileRoiMixIn):
+ """ROI for an vertical profile at a location of a stack of images"""
+
+ ICON = 'shape-vertical'
+ NAME = 'vertical line profile'
+
+ def __init__(self, parent=None):
+ roi_items.VerticalLineROI.__init__(self, parent=parent)
+ _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageStackLineROI(roi_items.LineROI,
+ _DefaultImageStackProfileRoiMixIn):
+ """ROI for an vertical profile at a location of a stack of images"""
+
+ ICON = 'shape-diagonal'
+ NAME = 'line profile'
+
+ def __init__(self, parent=None):
+ roi_items.LineROI.__init__(self, parent=parent)
+ _DefaultImageStackProfileRoiMixIn.__init__(self, parent=parent)
+
+
+class ProfileImageStackCrossROI(ProfileImageCrossROI):
+ """ROI for an vertical profile at a location of a stack of images"""
+
+ ICON = 'shape-cross'
+ NAME = 'cross profile'
+ ITEM_KIND = items.ImageStack
+
+ def _createLines(self, parent):
+ vline = ProfileImageStackVerticalLineROI(parent=parent)
+ hline = ProfileImageStackHorizontalLineROI(parent=parent)
+ return hline, vline
+
+ def getProfileType(self):
+ hline, _vline = self._getLines()
+ return hline.getProfileType()
+
+ def setProfileType(self, kind):
+ hline, vline = self._getLines()
+ hline.setProfileType(kind)
+ vline.setProfileType(kind)
+ self.invalidateProperties()
diff --git a/silx/gui/plot/tools/profile/toolbar.py b/src/silx/gui/plot/tools/profile/toolbar.py
index 4a9a195..4a9a195 100644
--- a/silx/gui/plot/tools/profile/toolbar.py
+++ b/src/silx/gui/plot/tools/profile/toolbar.py
diff --git a/src/silx/gui/plot/tools/roi.py b/src/silx/gui/plot/tools/roi.py
new file mode 100644
index 0000000..e4be6a7
--- /dev/null
+++ b/src/silx/gui/plot/tools/roi.py
@@ -0,0 +1,1417 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides ROI interaction for :class:`~silx.gui.plot.PlotWidget`.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import enum
+import logging
+import time
+import weakref
+import functools
+
+import numpy
+
+from ... import qt, icons
+from ...utils import blockSignals
+from ...utils import LockReentrant
+from .. import PlotWidget
+from ..items import roi as roi_items
+
+from ...colors import rgba
+
+
+logger = logging.getLogger(__name__)
+
+
+class CreateRoiModeAction(qt.QAction):
+ """
+ This action is a plot mode which allows to create new ROIs using a ROI
+ manager.
+
+ A ROI is created using a specific `roiClass`. `initRoi` and `finalizeRoi`
+ can be inherited to custom the ROI initialization.
+
+ :param class roiClass: The ROI class which will be created by this action.
+ :param qt.QObject parent: The action parent
+ :param RegionOfInterestManager roiManager: The ROI manager
+ """
+
+ def __init__(self, parent, roiManager, roiClass):
+ assert roiManager is not None
+ assert roiClass is not None
+ qt.QAction.__init__(self, parent=parent)
+ self._roiManager = weakref.ref(roiManager)
+ self._roiClass = roiClass
+ self._singleShot = False
+ self._initAction()
+ self.triggered[bool].connect(self._actionTriggered)
+
+ def _initAction(self):
+ """Default initialization of the action"""
+ roiClass = self._roiClass
+
+ name = None
+ iconName = None
+ if hasattr(roiClass, "NAME"):
+ name = roiClass.NAME
+ if hasattr(roiClass, "ICON"):
+ iconName = roiClass.ICON
+
+ if iconName is None:
+ iconName = "add-shape-unknown"
+ if name is None:
+ name = roiClass.__name__
+ text = 'Add %s' % name
+ self.setIcon(icons.getQIcon(iconName))
+ self.setText(text)
+ self.setCheckable(True)
+ self.setToolTip(text)
+
+ def getRoiClass(self):
+ """Return the ROI class used by this action to create ROIs"""
+ return self._roiClass
+
+ def getRoiManager(self):
+ return self._roiManager()
+
+ def setSingleShot(self, singleShot):
+ """Set it to True to deactivate the action after the first creation
+ of a ROI.
+
+ :param bool singleShot: New single short state
+ """
+ self._singleShot = singleShot
+
+ def getSingleShot(self):
+ """If True, after the first creation of a ROI with this mode,
+ the mode is deactivated.
+
+ :rtype: bool
+ """
+ return self._singleShot
+
+ def _actionTriggered(self, checked):
+ """Handle mode actions being checked by the user
+
+ :param bool checked:
+ :param str kind: Corresponding shape kind
+ """
+ roiManager = self.getRoiManager()
+ if roiManager is None:
+ return
+
+ if checked:
+ roiManager.start(self._roiClass, self)
+ self.__interactiveModeStarted(roiManager)
+ else:
+ source = roiManager.getInteractionSource()
+ if source is self:
+ roiManager.stop()
+
+ def __interactiveModeStarted(self, roiManager):
+ roiManager.sigInteractiveRoiCreated.connect(self.initRoi)
+ roiManager.sigInteractiveRoiFinalized.connect(self.__finalizeRoi)
+ roiManager.sigInteractiveModeFinished.connect(self.__interactiveModeFinished)
+
+ def __interactiveModeFinished(self):
+ roiManager = self.getRoiManager()
+ if roiManager is not None:
+ roiManager.sigInteractiveRoiCreated.disconnect(self.initRoi)
+ roiManager.sigInteractiveRoiFinalized.disconnect(self.__finalizeRoi)
+ roiManager.sigInteractiveModeFinished.disconnect(self.__interactiveModeFinished)
+ self.setChecked(False)
+
+ def initRoi(self, roi):
+ """Inherit it to custom the new ROI at it's creation during the
+ interaction."""
+ pass
+
+ def __finalizeRoi(self, roi):
+ self.finalizeRoi(roi)
+ if self._singleShot:
+ roiManager = self.getRoiManager()
+ if roiManager is not None:
+ roiManager.stop()
+
+ def finalizeRoi(self, roi):
+ """Inherit it to custom the new ROI after it's creation when the
+ interaction is finalized."""
+ pass
+
+
+class RoiModeSelector(qt.QWidget):
+ def __init__(self, parent=None):
+ super(RoiModeSelector, self).__init__(parent=parent)
+ self.__roi = None
+ self.__reentrant = LockReentrant()
+
+ layout = qt.QHBoxLayout(self)
+ if isinstance(parent, qt.QMenu):
+ margins = layout.contentsMargins()
+ layout.setContentsMargins(margins.left(), 0, margins.right(), 0)
+ else:
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ self._label = qt.QLabel(self)
+ self._label.setText("Mode:")
+ self._label.setToolTip("Select a specific interaction to edit the ROI")
+ self._combo = qt.QComboBox(self)
+ self._combo.currentIndexChanged.connect(self._modeSelected)
+ layout.addWidget(self._label)
+ layout.addWidget(self._combo)
+ self._updateAvailableModes()
+
+ def getRoi(self):
+ """Returns the edited ROI.
+
+ :rtype: roi_items.RegionOfInterest
+ """
+ return self.__roi
+
+ def setRoi(self, roi):
+ """Returns the edited ROI.
+
+ :rtype: roi_items.RegionOfInterest
+ """
+ if self.__roi is roi:
+ return
+ if not isinstance(roi, roi_items.InteractionModeMixIn):
+ self.__roi = None
+ self._updateAvailableModes()
+ return
+
+ if self.__roi is not None:
+ self.__roi.sigInteractionModeChanged.disconnect(self._modeChanged)
+ self.__roi = roi
+ if self.__roi is not None:
+ self.__roi.sigInteractionModeChanged.connect(self._modeChanged)
+ self._updateAvailableModes()
+
+ def isEmpty(self):
+ return not self._label.isVisibleTo(self)
+
+ def _updateAvailableModes(self):
+ roi = self.getRoi()
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ modes = roi.availableInteractionModes()
+ else:
+ modes = []
+ if len(modes) <= 1:
+ self._label.setVisible(False)
+ self._combo.setVisible(False)
+ else:
+ self._label.setVisible(True)
+ self._combo.setVisible(True)
+ with blockSignals(self._combo):
+ self._combo.clear()
+ for im, m in enumerate(modes):
+ self._combo.addItem(m.label, m)
+ self._combo.setItemData(im, m.description, qt.Qt.ToolTipRole)
+ mode = roi.getInteractionMode()
+ self._modeChanged(mode)
+ index = modes.index(mode)
+ self._combo.setCurrentIndex(index)
+
+ def _modeChanged(self, mode):
+ """Triggered when the ROI interaction mode was changed externally"""
+ if self.__reentrant.locked():
+ # This event was initialised by the widget
+ return
+ roi = self.__roi
+ modes = roi.availableInteractionModes()
+ index = modes.index(mode)
+ with blockSignals(self._combo):
+ self._combo.setCurrentIndex(index)
+
+ def _modeSelected(self):
+ """Triggered when the ROI interaction mode was selected in the widget"""
+ index = self._combo.currentIndex()
+ if index == -1:
+ return
+ roi = self.getRoi()
+ if roi is not None:
+ mode = self._combo.itemData(index, qt.Qt.UserRole)
+ with self.__reentrant:
+ roi.setInteractionMode(mode)
+
+
+class RoiModeSelectorAction(qt.QWidgetAction):
+ """Display the selected mode of a ROI and allow to change it"""
+
+ def __init__(self, parent=None):
+ super(RoiModeSelectorAction, self).__init__(parent)
+ self.__roiManager = None
+
+ def createWidget(self, parent):
+ """Inherit the method to create a new widget"""
+ widget = RoiModeSelector(parent)
+ manager = self.__roiManager
+ if manager is not None:
+ roi = manager.getCurrentRoi()
+ widget.setRoi(roi)
+ self.setVisible(not widget.isEmpty())
+ return widget
+
+ def deleteWidget(self, widget):
+ """Inherit the method to delete a widget"""
+ widget.setRoi(None)
+ return qt.QWidgetAction.deleteWidget(self, widget)
+
+ def setRoiManager(self, roiManager):
+ """
+ Connect this action to a ROI manager.
+
+ :param RegionOfInterestManager roiManager: A ROI manager
+ """
+ if self.__roiManager is roiManager:
+ return
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.disconnect(self.__currentRoiChanged)
+ self.__roiManager = roiManager
+ if self.__roiManager is not None:
+ self.__roiManager.sigCurrentRoiChanged.connect(self.__currentRoiChanged)
+ self.__currentRoiChanged(roiManager.getCurrentRoi())
+
+ def __currentRoiChanged(self, roi):
+ """Handle changes of the selected ROI"""
+ self.setRoi(roi)
+
+ def setRoi(self, roi):
+ """Set a profile ROI to edit.
+
+ :param ProfileRoiMixIn roi: A profile ROI
+ """
+ widget = None
+ for widget in self.createdWidgets():
+ widget.setRoi(roi)
+ if widget is not None:
+ self.setVisible(not widget.isEmpty())
+
+
+class RegionOfInterestManager(qt.QObject):
+ """Class handling ROI interaction on a PlotWidget.
+
+ It supports the multiple ROIs: points, rectangles, polygons,
+ lines, horizontal and vertical lines.
+
+ See ``plotInteractiveImageROI.py`` sample code (:ref:`sample-code`).
+
+ :param silx.gui.plot.PlotWidget parent:
+ The plot widget in which to control the ROIs.
+ """
+
+ sigRoiAdded = qt.Signal(roi_items.RegionOfInterest)
+ """Signal emitted when a new ROI has been added.
+
+ It provides the newly add :class:`RegionOfInterest` object.
+ """
+
+ sigRoiAboutToBeRemoved = qt.Signal(roi_items.RegionOfInterest)
+ """Signal emitted just before a ROI is removed.
+
+ It provides the :class:`RegionOfInterest` object that is about to be removed.
+ """
+
+ sigRoiChanged = qt.Signal()
+ """Signal emitted whenever the ROIs have changed."""
+
+ sigCurrentRoiChanged = qt.Signal(object)
+ """Signal emitted whenever a ROI is selected."""
+
+ sigInteractiveModeStarted = qt.Signal(object)
+ """Signal emitted when switching to ROI drawing interactive mode.
+
+ It provides the class of the ROI which will be created by the interactive
+ mode.
+ """
+
+ sigInteractiveRoiCreated = qt.Signal(object)
+ """Signal emitted when a ROI is created during the interaction.
+ The interaction is still incomplete and can be aborted.
+
+ It provides the ROI object which was just been created.
+ """
+
+ sigInteractiveRoiFinalized = qt.Signal(object)
+ """Signal emitted when a ROI creation is complet.
+
+ It provides the ROI object which was just been created.
+ """
+
+ sigInteractiveModeFinished = qt.Signal()
+ """Signal emitted when leaving interactive ROI drawing mode.
+ """
+
+ ROI_CLASSES = (
+ roi_items.PointROI,
+ roi_items.CrossROI,
+ roi_items.RectangleROI,
+ roi_items.CircleROI,
+ roi_items.EllipseROI,
+ roi_items.PolygonROI,
+ roi_items.LineROI,
+ roi_items.HorizontalLineROI,
+ roi_items.VerticalLineROI,
+ roi_items.ArcROI,
+ roi_items.HorizontalRangeROI,
+ )
+
+ def __init__(self, parent):
+ assert isinstance(parent, PlotWidget)
+ super(RegionOfInterestManager, self).__init__(parent)
+ self._rois = [] # List of ROIs
+ self._drawnROI = None # New ROI being currently drawn
+
+ self._roiClass = None
+ self._source = None
+ self._color = rgba('red')
+
+ self._label = "__RegionOfInterestManager__%d" % id(self)
+
+ self._currentRoi = None
+ """Hold currently selected ROI"""
+
+ self._eventLoop = None
+
+ self._modeActions = {}
+
+ parent.sigPlotSignal.connect(self._plotSignals)
+
+ parent.sigInteractiveModeChanged.connect(
+ self._plotInteractiveModeChanged)
+
+ parent.sigItemRemoved.connect(self._itemRemoved)
+
+ parent._sigDefaultContextMenu.connect(self._feedContextMenu)
+
+ @classmethod
+ def getSupportedRoiClasses(cls):
+ """Returns the default available ROI classes
+
+ :rtype: List[class]
+ """
+ return tuple(cls.ROI_CLASSES)
+
+ # Associated QActions
+
+ def getInteractionModeAction(self, roiClass):
+ """Returns the QAction corresponding to a kind of ROI
+
+ The QAction allows to enable the corresponding drawing
+ interactive mode.
+
+ :param class roiClass: The ROI class which will be created by this action.
+ :rtype: QAction
+ :raise ValueError: If kind is not supported
+ """
+ if not issubclass(roiClass, roi_items.RegionOfInterest):
+ raise ValueError('Unsupported ROI class %s' % roiClass)
+
+ action = self._modeActions.get(roiClass, None)
+ if action is None: # Lazy-loading
+ action = CreateRoiModeAction(self, self, roiClass)
+ self._modeActions[roiClass] = action
+ return action
+
+ # PlotWidget eventFilter and listeners
+
+ def _plotInteractiveModeChanged(self, source):
+ """Handle change of interactive mode in the plot"""
+ if source is not self:
+ self.__roiInteractiveModeEnded()
+
+ def _getRoiFromItem(self, item):
+ """Returns the ROI which own this item, else None
+ if this manager do not have knowledge of this ROI."""
+ for roi in self._rois:
+ if isinstance(roi, roi_items.RegionOfInterest):
+ for child in roi.getItems():
+ if child is item:
+ return roi
+ return None
+
+ def _itemRemoved(self, item):
+ """Called after an item was removed from the plot."""
+ if not hasattr(item, "_roiGroup"):
+ # Early break to avoid to use _getRoiFromItem
+ # And to avoid reentrant signal when the ROI remove the item itself
+ return
+ roi = self._getRoiFromItem(item)
+ if roi is not None:
+ self.removeRoi(roi)
+
+ # Handle ROI interaction
+
+ def _handleInteraction(self, event):
+ """Handle mouse interaction for ROI addition"""
+ roiClass = self.getCurrentInteractionModeRoiClass()
+ if roiClass is None:
+ return # Should not happen
+
+ kind = roiClass.getFirstInteractionShape()
+ if kind == 'point':
+ if event['event'] == 'mouseClicked' and event['button'] == 'left':
+ points = numpy.array([(event['x'], event['y'])],
+ dtype=numpy.float64)
+ # Not an interactive creation
+ roi = self._createInteractiveRoi(roiClass, points=points)
+ roi.creationFinalized()
+ self.sigInteractiveRoiFinalized.emit(roi)
+ else: # other shapes
+ if (event['event'] in ('drawingProgress', 'drawingFinished') and
+ event['parameters']['label'] == self._label):
+ points = numpy.array((event['xdata'], event['ydata']),
+ dtype=numpy.float64).T
+
+ if self._drawnROI is None: # Create new ROI
+ # NOTE: Set something before createRoi, so isDrawing is True
+ self._drawnROI = object()
+ self._drawnROI = self._createInteractiveRoi(roiClass, points=points)
+ else:
+ self._drawnROI.setFirstShapePoints(points)
+
+ if event['event'] == 'drawingFinished':
+ if kind == 'polygon' and len(points) > 1:
+ self._drawnROI.setFirstShapePoints(points[:-1])
+ roi = self._drawnROI
+ self._drawnROI = None # Stop drawing
+ roi.creationFinalized()
+ self.sigInteractiveRoiFinalized.emit(roi)
+
+ # RegionOfInterest selection
+
+ def __getRoiFromMarker(self, marker):
+ """Returns a ROI from a marker, else None"""
+ # This should be speed up
+ for roi in self._rois:
+ if isinstance(roi, roi_items.HandleBasedROI):
+ for m in roi.getHandles():
+ if m is marker:
+ return roi
+ else:
+ for m in roi.getItems():
+ if m is marker:
+ return roi
+ return None
+
+ def setCurrentRoi(self, roi):
+ """Set the currently selected ROI, and emit a signal.
+
+ :param Union[RegionOfInterest,None] roi: The ROI to select
+ """
+ if self._currentRoi is roi:
+ return
+ if roi is not None:
+ # Note: Fixed range to avoid infinite loops
+ for _ in range(10):
+ target = roi.getFocusProxy()
+ if target is None:
+ break
+ roi = target
+ else:
+ raise RuntimeError("Max selection proxy depth (10) reached.")
+
+ if self._currentRoi is not None:
+ self._currentRoi.setHighlighted(False)
+ self._currentRoi = roi
+ if self._currentRoi is not None:
+ self._currentRoi.setHighlighted(True)
+ self.sigCurrentRoiChanged.emit(roi)
+
+ def getCurrentRoi(self):
+ """Returns the currently selected ROI, else None.
+
+ :rtype: Union[RegionOfInterest,None]
+ """
+ return self._currentRoi
+
+ def _plotSignals(self, event):
+ """Handle mouse interaction for ROI addition"""
+ clicked = False
+ roi = None
+ if event["event"] in ("markerClicked", "markerMoving"):
+ plot = self.parent()
+ legend = event["label"]
+ marker = plot._getMarker(legend=legend)
+ roi = self.__getRoiFromMarker(marker)
+ elif event["event"] == "mouseClicked" and event["button"] == "left":
+ # Marker click is only for dnd
+ # This also can click on a marker
+ clicked = True
+ plot = self.parent()
+ marker = plot._getMarkerAt(event["xpixel"], event["ypixel"])
+ roi = self.__getRoiFromMarker(marker)
+ else:
+ return
+
+ if roi not in self._rois:
+ # The ROI is not own by this manager
+ return
+
+ if roi is not None:
+ currentRoi = self.getCurrentRoi()
+ if currentRoi is roi:
+ if clicked:
+ self.__updateMode(roi)
+ elif roi.isSelectable():
+ self.setCurrentRoi(roi)
+ else:
+ self.setCurrentRoi(None)
+
+ def __updateMode(self, roi):
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ available = roi.availableInteractionModes()
+ mode = roi.getInteractionMode()
+ imode = available.index(mode)
+ mode = available[(imode + 1) % len(available)]
+ roi.setInteractionMode(mode)
+
+ def _feedContextMenu(self, menu):
+ """Called when the default plot context menu is about to be displayed"""
+ roi = self.getCurrentRoi()
+ if roi is not None:
+ if roi.isEditable():
+ # Filter by data position
+ # FIXME: It would be better to use GUI coords for it
+ plot = self.parent()
+ pos = plot.getWidgetHandle().mapFromGlobal(qt.QCursor.pos())
+ data = plot.pixelToData(pos.x(), pos.y())
+ if roi.contains(data):
+ if isinstance(roi, roi_items.InteractionModeMixIn):
+ self._contextMenuForInteractionMode(menu, roi)
+
+ removeAction = qt.QAction(menu)
+ removeAction.setText("Remove %s" % roi.getName())
+ callback = functools.partial(self.removeRoi, roi)
+ removeAction.triggered.connect(callback)
+ menu.addAction(removeAction)
+
+ def _contextMenuForInteractionMode(self, menu, roi):
+ availableModes = roi.availableInteractionModes()
+ currentMode = roi.getInteractionMode()
+ submenu = qt.QMenu(menu)
+ modeGroup = qt.QActionGroup(menu)
+ modeGroup.setExclusive(True)
+ for mode in availableModes:
+ action = qt.QAction(menu)
+ action.setText(mode.label)
+ action.setToolTip(mode.description)
+ action.setCheckable(True)
+ if mode is currentMode:
+ action.setChecked(True)
+ else:
+ callback = functools.partial(roi.setInteractionMode, mode)
+ action.triggered.connect(callback)
+ modeGroup.addAction(action)
+ submenu.addAction(action)
+ submenu.setTitle("%s interaction mode" % roi.getName())
+ menu.addMenu(submenu)
+
+ # RegionOfInterest API
+
+ def getRois(self):
+ """Returns the list of ROIs.
+
+ It returns an empty tuple if there is currently no ROI.
+
+ :return: Tuple of arrays of objects describing the ROIs
+ :rtype: List[RegionOfInterest]
+ """
+ return tuple(self._rois)
+
+ def clear(self):
+ """Reset current ROIs
+
+ :return: True if ROIs were reset.
+ :rtype: bool
+ """
+ if self.getRois(): # Something to reset
+ for roi in self._rois:
+ roi.sigRegionChanged.disconnect(
+ self._regionOfInterestChanged)
+ roi.setParent(None)
+ self._rois = []
+ self._roisUpdated()
+ return True
+
+ else:
+ return False
+
+ def _regionOfInterestChanged(self, event=None):
+ """Handle ROI object changed"""
+ self.sigRoiChanged.emit()
+
+ def _createInteractiveRoi(self, roiClass, points, label=None, index=None):
+ """Create a new ROI with interactive creation.
+
+ :param class roiClass: The class of the ROI to create
+ :param numpy.ndarray points: The first shape used to create the ROI
+ :param str label: The label to display along with the ROI.
+ :param int index: The position where to insert the ROI.
+ By default it is appended to the end of the list.
+ :return: The created ROI object
+ :rtype: roi_items.RegionOfInterest
+ :raise RuntimeError: When ROI cannot be added because the maximum
+ number of ROIs has been reached.
+ """
+ roi = roiClass(parent=None)
+ if label is not None:
+ roi.setName(str(label))
+ roi.creationStarted()
+ roi.setFirstShapePoints(points)
+
+ self.addRoi(roi, index)
+ if roi.isSelectable():
+ self.setCurrentRoi(roi)
+ self.sigInteractiveRoiCreated.emit(roi)
+ return roi
+
+ def containsRoi(self, roi):
+ """Returns true if the ROI is part of this manager.
+
+ :param roi_items.RegionOfInterest roi: The ROI to add
+ :rtype: bool
+ """
+ return roi in self._rois
+
+ def addRoi(self, roi, index=None, useManagerColor=True):
+ """Add the ROI to the list of ROIs.
+
+ :param roi_items.RegionOfInterest roi: The ROI to add
+ :param int index: The position where to insert the ROI,
+ By default it is appended to the end of the list of ROIs
+ :param bool useManagerColor:
+ Whether to set the ROI color to the default one of the manager or not.
+ (Default: True).
+ :raise RuntimeError: When ROI cannot be added because the maximum
+ number of ROIs has been reached.
+ """
+ plot = self.parent()
+ if plot is None:
+ raise RuntimeError(
+ 'Cannot add ROI: PlotWidget no more available')
+
+ roi.setParent(self)
+
+ if useManagerColor:
+ roi.setColor(self.getColor())
+
+ roi.sigRegionChanged.connect(self._regionOfInterestChanged)
+ roi.sigItemChanged.connect(self._regionOfInterestChanged)
+
+ if index is None:
+ self._rois.append(roi)
+ else:
+ self._rois.insert(index, roi)
+ self.sigRoiAdded.emit(roi)
+ self._roisUpdated()
+
+ def removeRoi(self, roi):
+ """Remove a ROI from the list of ROIs.
+
+ :param roi_items.RegionOfInterest roi: The ROI to remove
+ :raise ValueError: When ROI does not belong to this object
+ """
+ if not (isinstance(roi, roi_items.RegionOfInterest) and
+ roi.parent() is self and
+ roi in self._rois):
+ raise ValueError(
+ 'RegionOfInterest does not belong to this instance')
+
+ roi.sigAboutToBeRemoved.emit()
+ self.sigRoiAboutToBeRemoved.emit(roi)
+
+ if roi is self._currentRoi:
+ self.setCurrentRoi(None)
+
+ mustRestart = False
+ if roi is self._drawnROI:
+ self._drawnROI = None
+ mustRestart = True
+ self._rois.remove(roi)
+ roi.sigRegionChanged.disconnect(self._regionOfInterestChanged)
+ roi.sigItemChanged.disconnect(self._regionOfInterestChanged)
+ roi.setParent(None)
+ self._roisUpdated()
+
+ if mustRestart:
+ self._restart()
+
+ def _roisUpdated(self):
+ """Handle update of the ROI list"""
+ self.sigRoiChanged.emit()
+
+ # RegionOfInterest parameters
+
+ def getColor(self):
+ """Return the default color of created ROIs
+
+ :rtype: QColor
+ """
+ return qt.QColor.fromRgbF(*self._color)
+
+ def setColor(self, color):
+ """Set the default color to use when creating ROIs.
+
+ Existing ROIs are not affected.
+
+ :param color: The color to use for displaying ROIs as
+ either a color name, a QColor, a list of uint8 or float in [0, 1].
+ """
+ self._color = rgba(color)
+
+ # Control ROI
+
+ def getCurrentInteractionModeRoiClass(self):
+ """Returns the current ROI class used by the interactive drawing mode.
+
+ Returns None if the ROI manager is not in an interactive mode.
+
+ :rtype: Union[class,None]
+ """
+ return self._roiClass
+
+ def getInteractionSource(self):
+ """Returns the object which have requested the ROI creation.
+
+ Returns None if the ROI manager is not in an interactive mode.
+
+ :rtype: Union[object,None]
+ """
+ return self._source
+
+ def isStarted(self):
+ """Returns True if an interactive ROI drawing mode is active.
+
+ :rtype: bool
+ """
+ return self._roiClass is not None
+
+ def isDrawing(self):
+ """Returns True if an interactive ROI is drawing.
+
+ :rtype: bool
+ """
+ return self._drawnROI is not None
+
+ def start(self, roiClass, source=None):
+ """Start an interactive ROI drawing mode.
+
+ :param class roiClass: The ROI class to create. It have to inherite from
+ `roi_items.RegionOfInterest`.
+ :param object source: SOurce of the ROI interaction.
+ :return: True if interactive ROI drawing was started, False otherwise
+ :rtype: bool
+ :raise ValueError: If roiClass is not supported
+ """
+ self.stop()
+
+ if not issubclass(roiClass, roi_items.RegionOfInterest):
+ raise ValueError('Unsupported ROI class %s' % roiClass)
+
+ plot = self.parent()
+ if plot is None:
+ return False
+
+ self._roiClass = roiClass
+ self._source = source
+
+ self._restart()
+
+ plot.sigPlotSignal.connect(self._handleInteraction)
+
+ self.sigInteractiveModeStarted.emit(roiClass)
+
+ return True
+
+ def _restart(self):
+ """Restart the plot interaction without changing the
+ source or the ROI class.
+ """
+ roiClass = self._roiClass
+ plot = self.parent()
+ firstInteractionShapeKind = roiClass.getFirstInteractionShape()
+
+ if firstInteractionShapeKind == 'point':
+ plot.setInteractiveMode(mode='select', source=self)
+ else:
+ if roiClass.showFirstInteractionShape():
+ color = rgba(self.getColor())
+ else:
+ color = None
+ plot.setInteractiveMode(mode='select-draw',
+ source=self,
+ shape=firstInteractionShapeKind,
+ color=color,
+ label=self._label)
+
+ def __roiInteractiveModeEnded(self):
+ """Handle end of ROI draw interactive mode"""
+ if self.isStarted():
+ self._roiClass = None
+ self._source = None
+
+ if self._drawnROI is not None:
+ # Cancel ROI create
+ roi = self._drawnROI
+ self._drawnROI = None
+ self.removeRoi(roi)
+
+ plot = self.parent()
+ if plot is not None:
+ plot.sigPlotSignal.disconnect(self._handleInteraction)
+
+ self.sigInteractiveModeFinished.emit()
+
+ def stop(self):
+ """Stop interactive ROI drawing mode.
+
+ :return: True if an interactive ROI drawing mode was actually stopped
+ :rtype: bool
+ """
+ if not self.isStarted():
+ return False
+
+ plot = self.parent()
+ if plot is not None:
+ # This leads to call __roiInteractiveModeEnded through
+ # interactive mode changed signal
+ plot.resetInteractiveMode()
+ else: # Fallback
+ self.__roiInteractiveModeEnded()
+
+ return True
+
+ def exec(self, roiClass):
+ """Block until :meth:`quit` is called.
+
+ :param class kind: The class of the ROI which have to be created.
+ See `silx.gui.plot.items.roi`.
+ :return: The list of ROIs
+ :rtype: tuple
+ """
+ self.start(roiClass)
+
+ plot = self.parent()
+ plot.show()
+ plot.raise_()
+
+ self._eventLoop = qt.QEventLoop()
+ self._eventLoop.exec()
+ self._eventLoop = None
+
+ self.stop()
+
+ rois = self.getRois()
+ self.clear()
+ return rois
+
+ def exec_(self, roiClass): # Qt5-like compatibility
+ return self.exec(roiClass)
+
+ def quit(self):
+ """Stop a blocking :meth:`exec` and call :meth:`stop`"""
+ if self._eventLoop is not None:
+ self._eventLoop.quit()
+ self._eventLoop = None
+ self.stop()
+
+
+class InteractiveRegionOfInterestManager(RegionOfInterestManager):
+ """RegionOfInterestManager with features for use from interpreter.
+
+ It is meant to be used through the :meth:`exec`.
+ It provides some messages to display in a status bar and
+ different modes to end blocking calls to :meth:`exec`.
+
+ :param parent: See QObject
+ """
+
+ sigMessageChanged = qt.Signal(str)
+ """Signal emitted when a new message should be displayed to the user
+
+ It provides the message as a str.
+ """
+
+ def __init__(self, parent):
+ super(InteractiveRegionOfInterestManager, self).__init__(parent)
+ self._maxROI = None
+ self.__timeoutEndTime = None
+ self.__message = ''
+ self.__validationMode = self.ValidationMode.ENTER
+ self.__execClass = None
+
+ self.sigRoiAdded.connect(self.__added)
+ self.sigRoiAboutToBeRemoved.connect(self.__aboutToBeRemoved)
+ self.sigInteractiveModeStarted.connect(self.__started)
+ self.sigInteractiveModeFinished.connect(self.__finished)
+
+ # Max ROI
+
+ def getMaxRois(self):
+ """Returns the maximum number of ROIs or None if no limit.
+
+ :rtype: Union[int,None]
+ """
+ return self._maxROI
+
+ def setMaxRois(self, max_):
+ """Set the maximum number of ROIs.
+
+ :param Union[int,None] max_: The max limit or None for no limit.
+ :raise ValueError: If there is more ROIs than max value
+ """
+ if max_ is not None:
+ max_ = int(max_)
+ if max_ <= 0:
+ raise ValueError('Max limit must be strictly positive')
+
+ if len(self.getRois()) > max_:
+ raise ValueError(
+ 'Cannot set max limit: Already too many ROIs')
+
+ self._maxROI = max_
+
+ def isMaxRois(self):
+ """Returns True if the maximum number of ROIs is reached.
+
+ :rtype: bool
+ """
+ max_ = self.getMaxRois()
+ return max_ is not None and len(self.getRois()) >= max_
+
+ # Validation mode
+
+ @enum.unique
+ class ValidationMode(enum.Enum):
+ """Mode of validation to leave blocking :meth:`exec`"""
+
+ AUTO = 'auto'
+ """Automatically ends the interactive mode once
+ the user terminates the last ROI shape."""
+
+ ENTER = 'enter'
+ """Ends the interactive mode when the *Enter* key is pressed."""
+
+ AUTO_ENTER = 'auto_enter'
+ """Ends the interactive mode when reaching max ROIs or
+ when the *Enter* key is pressed.
+ """
+
+ NONE = 'none'
+ """Do not provide the user a way to end the interactive mode.
+
+ The end of :meth:`exec` is done through :meth:`quit` or timeout.
+ """
+
+ def getValidationMode(self):
+ """Returns the interactive mode validation in use.
+
+ :rtype: ValidationMode
+ """
+ return self.__validationMode
+
+ def setValidationMode(self, mode):
+ """Set the way to perform interactive mode validation.
+
+ See :class:`ValidationMode` enumeration for the supported
+ validation modes.
+
+ :param ValidationMode mode: The interactive mode validation to use.
+ """
+ assert isinstance(mode, self.ValidationMode)
+ if mode != self.__validationMode:
+ self.__validationMode = mode
+
+ if self.isExec():
+ if (self.isMaxRois() and self.getValidationMode() in
+ (self.ValidationMode.AUTO,
+ self.ValidationMode.AUTO_ENTER)):
+ self.quit()
+
+ self.__updateMessage()
+
+ def eventFilter(self, obj, event):
+ if event.type() == qt.QEvent.Hide:
+ self.quit()
+
+ if event.type() == qt.QEvent.KeyPress:
+ key = event.key()
+ if (key in (qt.Qt.Key_Return, qt.Qt.Key_Enter) and
+ self.getValidationMode() in (
+ self.ValidationMode.ENTER,
+ self.ValidationMode.AUTO_ENTER)):
+ # Stop on return key pressed
+ self.quit()
+ return True # Stop further handling of this keys
+
+ if (key in (qt.Qt.Key_Delete, qt.Qt.Key_Backspace) or (
+ key == qt.Qt.Key_Z and
+ event.modifiers() & qt.Qt.ControlModifier)):
+ rois = self.getRois()
+ if rois: # Something to undo
+ self.removeRoi(rois[-1])
+ # Stop further handling of keys if something was undone
+ return True
+
+ return super(InteractiveRegionOfInterestManager, self).eventFilter(obj, event)
+
+ # Message API
+
+ def getMessage(self):
+ """Returns the current status message.
+
+ This message is meant to be displayed in a status bar.
+
+ :rtype: str
+ """
+ if self.__timeoutEndTime is None:
+ return self.__message
+ else:
+ remaining = self.__timeoutEndTime - time.time()
+ return self.__message + (' - %d seconds remaining' %
+ max(1, int(remaining)))
+
+ # Listen to ROI updates
+
+ def __added(self, *args, **kwargs):
+ """Handle new ROI added"""
+ max_ = self.getMaxRois()
+ if max_ is not None:
+ # When reaching max number of ROIs, redo last one
+ while len(self.getRois()) > max_:
+ self.removeRoi(self.getRois()[-2])
+
+ self.__updateMessage()
+ if (self.isMaxRois() and
+ self.getValidationMode() in (self.ValidationMode.AUTO,
+ self.ValidationMode.AUTO_ENTER)):
+ self.quit()
+
+ def __aboutToBeRemoved(self, *args, **kwargs):
+ """Handle removal of a ROI"""
+ # RegionOfInterest not removed yet
+ self.__updateMessage(nbrois=len(self.getRois()) - 1)
+
+ def __started(self, roiKind):
+ """Handle interactive mode started"""
+ self.__updateMessage()
+
+ def __finished(self):
+ """Handle interactive mode finished"""
+ self.__updateMessage()
+
+ def __updateMessage(self, nbrois=None):
+ """Update message"""
+ if not self.isExec():
+ message = 'Done'
+
+ elif not self.isStarted():
+ message = 'Use %s ROI edition mode' % self.__execClass
+
+ else:
+ if nbrois is None:
+ nbrois = len(self.getRois())
+
+ name = self.__execClass._getShortName()
+
+ max_ = self.getMaxRois()
+ if max_ is None:
+ message = 'Select %ss (%d selected)' % (name, nbrois)
+
+ elif max_ <= 1:
+ message = 'Select a %s' % name
+ else:
+ message = 'Select %d/%d %ss' % (nbrois, max_, name)
+
+ if (self.getValidationMode() == self.ValidationMode.ENTER and
+ self.isMaxRois()):
+ message += ' - Press Enter to confirm'
+
+ if message != self.__message:
+ self.__message = message
+ # Use getMessage to add timeout message
+ self.sigMessageChanged.emit(self.getMessage())
+
+ # Handle blocking call
+
+ def __timeoutUpdate(self):
+ """Handle update of timeout"""
+ if (self.__timeoutEndTime is not None and
+ (self.__timeoutEndTime - time.time()) > 0):
+ self.sigMessageChanged.emit(self.getMessage())
+ else: # Stop interactive mode and message timer
+ timer = self.sender()
+ if timer is not None:
+ timer.stop()
+ self.__timeoutEndTime = None
+ self.quit()
+
+ def isExec(self):
+ """Returns True if :meth:`exec` is currently running.
+
+ :rtype: bool"""
+ return self.__execClass is not None
+
+ def exec(self, roiClass, timeout=0):
+ """Block until ROI selection is done or timeout is elapsed.
+
+ :meth:`quit` also ends this blocking call.
+
+ :param class roiClass: The class of the ROI which have to be created.
+ See `silx.gui.plot.items.roi`.
+ :param int timeout: Maximum duration in seconds to block.
+ Default: No timeout
+ :return: The list of ROIs
+ :rtype: List[RegionOfInterest]
+ """
+ plot = self.parent()
+ if plot is None:
+ return
+
+ self.__execClass = roiClass
+
+ plot.installEventFilter(self)
+
+ if timeout > 0:
+ self.__timeoutEndTime = time.time() + timeout
+ timer = qt.QTimer(self)
+ timer.timeout.connect(self.__timeoutUpdate)
+ timer.start(1000)
+
+ rois = super(InteractiveRegionOfInterestManager, self).exec(roiClass)
+
+ timer.stop()
+ self.__timeoutEndTime = None
+
+ else:
+ rois = super(InteractiveRegionOfInterestManager, self).exec(roiClass)
+
+ plot.removeEventFilter(self)
+
+ self.__execClass = None
+ self.__updateMessage()
+
+ return rois
+
+ def exec_(self, roiClass, timeout=0): # Qt5-like compatibility
+ return self.exec(roiClass, timeout)
+
+
+class _DeleteRegionOfInterestToolButton(qt.QToolButton):
+ """Tool button deleting a ROI object
+
+ :param parent: See QWidget
+ :param RegionOfInterest roi: The ROI to delete
+ """
+
+ def __init__(self, parent, roi):
+ super(_DeleteRegionOfInterestToolButton, self).__init__(parent)
+ self.setIcon(icons.getQIcon('remove'))
+ self.setToolTip("Remove this ROI")
+ self.__roiRef = roi if roi is None else weakref.ref(roi)
+ self.clicked.connect(self.__clicked)
+
+ def __clicked(self, checked):
+ """Handle button clicked"""
+ roi = None if self.__roiRef is None else self.__roiRef()
+ if roi is not None:
+ manager = roi.parent()
+ if manager is not None:
+ manager.removeRoi(roi)
+ self.__roiRef = None
+
+
+class RegionOfInterestTableWidget(qt.QTableWidget):
+ """Widget displaying the ROIs of a :class:`RegionOfInterestManager`"""
+
+ def __init__(self, parent=None):
+ super(RegionOfInterestTableWidget, self).__init__(parent)
+ self._roiManagerRef = None
+
+ headers = ['Label', 'Edit', 'Kind', 'Coordinates', '']
+ self.setColumnCount(len(headers))
+ self.setHorizontalHeaderLabels(headers)
+
+ horizontalHeader = self.horizontalHeader()
+ horizontalHeader.setDefaultAlignment(qt.Qt.AlignLeft)
+
+ horizontalHeader.setSectionResizeMode(0, qt.QHeaderView.Interactive)
+ horizontalHeader.setSectionResizeMode(1, qt.QHeaderView.ResizeToContents)
+ horizontalHeader.setSectionResizeMode(2, qt.QHeaderView.ResizeToContents)
+ horizontalHeader.setSectionResizeMode(3, qt.QHeaderView.Stretch)
+ horizontalHeader.setSectionResizeMode(4, qt.QHeaderView.ResizeToContents)
+
+ verticalHeader = self.verticalHeader()
+ verticalHeader.setVisible(False)
+
+ self.setSelectionMode(qt.QAbstractItemView.NoSelection)
+ self.setFocusPolicy(qt.Qt.NoFocus)
+
+ self.itemChanged.connect(self.__itemChanged)
+
+ def __itemChanged(self, item):
+ """Handle item updates"""
+ column = item.column()
+ index = item.data(qt.Qt.UserRole)
+
+ if index is not None:
+ manager = self.getRegionOfInterestManager()
+ roi = manager.getRois()[index]
+ else:
+ return
+
+ if column == 0:
+ # First collect information from item, then update ROI
+ # Otherwise, this causes issues issues
+ checked = item.checkState() == qt.Qt.Checked
+ text= item.text()
+ roi.setVisible(checked)
+ roi.setName(text)
+ elif column == 1:
+ roi.setEditable(item.checkState() == qt.Qt.Checked)
+ elif column in (2, 3, 4):
+ pass # TODO
+ else:
+ logger.error('Unhandled column %d', column)
+
+ def setRegionOfInterestManager(self, manager):
+ """Set the :class:`RegionOfInterestManager` object to sync with
+
+ :param RegionOfInterestManager manager:
+ """
+ assert manager is None or isinstance(manager, RegionOfInterestManager)
+
+ previousManager = self.getRegionOfInterestManager()
+
+ if previousManager is not None:
+ previousManager.sigRoiChanged.disconnect(self._sync)
+ self.setRowCount(0)
+
+ self._roiManagerRef = weakref.ref(manager)
+
+ self._sync()
+
+ if manager is not None:
+ manager.sigRoiChanged.connect(self._sync)
+
+ def _getReadableRoiDescription(self, roi):
+ """Returns modelisation of a ROI as a readable sequence of values.
+
+ :rtype: str
+ """
+ text = str(roi)
+ try:
+ # Extract the params from syntax "CLASSNAME(PARAMS)"
+ elements = text.split("(", 1)
+ if len(elements) != 2:
+ return text
+ result = elements[1]
+ result = result.strip()
+ if not result.endswith(")"):
+ return text
+ result = result[0:-1]
+ # Capitalize each words
+ result = result.title()
+ return result
+ except Exception:
+ logger.debug("Backtrace", exc_info=True)
+ return text
+
+ def _sync(self):
+ """Update widget content according to ROI manger"""
+ manager = self.getRegionOfInterestManager()
+
+ if manager is None:
+ self.setRowCount(0)
+ return
+
+ rois = manager.getRois()
+
+ self.setRowCount(len(rois))
+ for index, roi in enumerate(rois):
+ baseFlags = qt.Qt.ItemIsSelectable | qt.Qt.ItemIsEnabled
+
+ # Label and visible
+ label = roi.getName()
+ item = qt.QTableWidgetItem(label)
+ item.setFlags(baseFlags | qt.Qt.ItemIsEditable | qt.Qt.ItemIsUserCheckable)
+ item.setData(qt.Qt.UserRole, index)
+ item.setCheckState(
+ qt.Qt.Checked if roi.isVisible() else qt.Qt.Unchecked)
+ self.setItem(index, 0, item)
+
+ # Editable
+ item = qt.QTableWidgetItem()
+ item.setFlags(baseFlags | qt.Qt.ItemIsUserCheckable)
+ item.setData(qt.Qt.UserRole, index)
+ item.setCheckState(
+ qt.Qt.Checked if roi.isEditable() else qt.Qt.Unchecked)
+ self.setItem(index, 1, item)
+ item.setTextAlignment(qt.Qt.AlignCenter)
+ item.setText(None)
+
+ # Kind
+ label = roi._getShortName()
+ if label is None:
+ # Default value if kind is not overrided
+ label = roi.__class__.__name__
+ item = qt.QTableWidgetItem(label.capitalize())
+ item.setFlags(baseFlags)
+ self.setItem(index, 2, item)
+
+ item = qt.QTableWidgetItem()
+ item.setFlags(baseFlags)
+
+ # Coordinates
+ text = self._getReadableRoiDescription(roi)
+ item.setText(text)
+ self.setItem(index, 3, item)
+
+ # Delete
+ delBtn = _DeleteRegionOfInterestToolButton(None, roi)
+ widget = qt.QWidget(self)
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(2, 2, 2, 2)
+ layout.setSpacing(0)
+ widget.setLayout(layout)
+ layout.addStretch(1)
+ layout.addWidget(delBtn)
+ layout.addStretch(1)
+ self.setCellWidget(index, 4, widget)
+
+ def getRegionOfInterestManager(self):
+ """Returns the :class:`RegionOfInterestManager` this widget supervise.
+
+ It returns None if not sync with an :class:`RegionOfInterestManager`.
+
+ :rtype: RegionOfInterestManager
+ """
+ return None if self._roiManagerRef is None else self._roiManagerRef()
diff --git a/src/silx/gui/plot/tools/test/__init__.py b/src/silx/gui/plot/tools/test/__init__.py
new file mode 100644
index 0000000..aa4a601
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
new file mode 100644
index 0000000..37af10e
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testCurveLegendsWidget.py
@@ -0,0 +1,113 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/08/2018"
+
+
+import unittest
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.tools import CurveLegendsWidget
+
+
+class TestCurveLegendsWidget(TestCaseQt, ParametricTestCase):
+ """Tests for CurveLegendsWidget class"""
+
+ def setUp(self):
+ super(TestCurveLegendsWidget, self).setUp()
+ self.plot = PlotWindow()
+
+ self.legends = CurveLegendsWidget.CurveLegendsWidget()
+ self.legends.setPlotWidget(self.plot)
+
+ dock = qt.QDockWidget()
+ dock.setWindowTitle('Curve Legends')
+ dock.setWidget(self.legends)
+ self.plot.addTabbedDockWidget(dock)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ del self.legends
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestCurveLegendsWidget, self).tearDown()
+
+ def _assertNbLegends(self, count):
+ """Check the number of legends in the CurveLegendsWidget"""
+ children = self.legends.findChildren(CurveLegendsWidget._LegendWidget)
+ self.assertEqual(len(children), count)
+
+ def testAddRemoveCurves(self):
+ """Test CurveLegendsWidget while adding/removing curves"""
+ self.plot.addCurve((0, 1), (1, 2), legend='a')
+ self._assertNbLegends(1)
+ self.plot.addCurve((0, 1), (2, 3), legend='b')
+ self._assertNbLegends(2)
+
+ # Detached/attach
+ self.legends.setPlotWidget(None)
+ self._assertNbLegends(0)
+
+ self.legends.setPlotWidget(self.plot)
+ self._assertNbLegends(2)
+
+ self.plot.clear()
+ self._assertNbLegends(0)
+
+ def testUpdateCurves(self):
+ """Test CurveLegendsWidget while updating curves """
+ self.plot.addCurve((0, 1), (1, 2), legend='a')
+ self._assertNbLegends(1)
+ self.plot.addCurve((0, 1), (2, 3), legend='b')
+ self._assertNbLegends(2)
+
+ # Activate curve
+ self.plot.setActiveCurve('a')
+ self.qapp.processEvents()
+ self.plot.setActiveCurve('b')
+ self.qapp.processEvents()
+
+ # Change curve style
+ curve = self.plot.getCurve('a')
+ curve.setLineWidth(2)
+ for linestyle in (':', '', '--', '-'):
+ with self.subTest(linestyle=linestyle):
+ curve.setLineStyle(linestyle)
+ self.qapp.processEvents()
+ self.qWait(1000)
+
+ for symbol in ('o', 'd', '', 's'):
+ with self.subTest(symbol=symbol):
+ curve.setSymbol(symbol)
+ self.qapp.processEvents()
+ self.qWait(1000)
diff --git a/src/silx/gui/plot/tools/test/testProfile.py b/src/silx/gui/plot/tools/test/testProfile.py
new file mode 100644
index 0000000..829f49e
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testProfile.py
@@ -0,0 +1,654 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import unittest
+import contextlib
+import numpy
+import logging
+
+from silx.gui import qt
+from silx.utils import deprecation
+from silx.utils import testutils
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.plot import PlotWindow, Plot1D, Plot2D, Profile
+from silx.gui.plot.StackView import StackView
+from silx.gui.plot.tools.profile import rois
+from silx.gui.plot.tools.profile import editors
+from silx.gui.plot.items import roi as roi_items
+from silx.gui.plot.tools.profile import manager
+from silx.gui import plot as silx_plot
+
+_logger = logging.getLogger(__name__)
+
+
+class TestRois(TestCaseQt):
+
+ def test_init(self):
+ """Check that the constructor is not called twice"""
+ roi = rois.ProfileImageVerticalLineROI()
+ if qt.BINDING == "PyQt5":
+ # the profile ROI + the shape
+ self.assertEqual(roi.receivers(roi.sigRegionChanged), 2)
+
+
+class TestInteractions(TestCaseQt):
+
+ @contextlib.contextmanager
+ def defaultPlot(self):
+ try:
+ widget = silx_plot.PlotWidget()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ @contextlib.contextmanager
+ def imagePlot(self):
+ try:
+ widget = silx_plot.Plot2D()
+ image = numpy.arange(10 * 10).reshape(10, -1)
+ widget.addImage(image)
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ @contextlib.contextmanager
+ def scatterPlot(self):
+ try:
+ widget = silx_plot.ScatterView()
+
+ nbX, nbY = 7, 5
+ yy = numpy.atleast_2d(numpy.ones(nbY)).T
+ xx = numpy.atleast_2d(numpy.ones(nbX))
+ positionX = numpy.linspace(10, 50, nbX) * yy
+ positionX = positionX.reshape(nbX * nbY)
+ positionY = numpy.atleast_2d(numpy.linspace(20, 60, nbY)).T * xx
+ positionY = positionY.reshape(nbX * nbY)
+ values = numpy.arange(nbX * nbY)
+
+ widget.setData(positionX, positionY, values)
+ widget.resetZoom()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget.getPlotWidget()
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ @contextlib.contextmanager
+ def stackPlot(self):
+ try:
+ widget = silx_plot.StackView()
+ image = numpy.arange(10 * 10).reshape(10, -1)
+ cube = numpy.array([image, image, image])
+ widget.setStack(cube)
+ widget.resetZoom()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+ yield widget.getPlotWidget()
+ finally:
+ widget.close()
+ widget = None
+ self.qWait()
+
+ def waitPendingOperations(self, proflie):
+ for _ in range(10):
+ if not proflie.hasPendingOperations():
+ return
+ self.qWait(100)
+ _logger.error("The profile manager still have pending operations")
+
+ def genericRoiTest(self, plot, roiClass):
+ profileManager = manager.ProfileManager(plot, plot)
+ profileManager.setItemType(image=True, scatter=True)
+
+ try:
+ action = profileManager.createProfileAction(roiClass, plot)
+ action.triggered[bool].emit(True)
+ widget = plot.getWidgetHandle()
+
+ # Do the mouse interaction
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ self.mouseMove(widget, pos=pos1)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
+
+ if issubclass(roiClass, roi_items.LineROI):
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+ self.mouseMove(widget, pos=pos2)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos2)
+
+ self.waitPendingOperations(profileManager)
+
+ # Test that something was computed
+ if issubclass(roiClass, rois._ProfileCrossROI):
+ self.assertEqual(profileManager._computedProfiles, 2)
+ elif issubclass(roiClass, roi_items.LineROI):
+ self.assertGreaterEqual(profileManager._computedProfiles, 1)
+ else:
+ self.assertEqual(profileManager._computedProfiles, 1)
+
+ # Test the created ROIs
+ profileRois = profileManager.getRoiManager().getRois()
+ if issubclass(roiClass, rois._ProfileCrossROI):
+ self.assertEqual(len(profileRois), 3)
+ else:
+ self.assertEqual(len(profileRois), 1)
+ # The first one should be the expected one
+ roi = profileRois[0]
+
+ # Test that something was displayed
+ if issubclass(roiClass, rois._ProfileCrossROI):
+ profiles = roi._getLines()
+ window = profiles[0].getProfileWindow()
+ self.assertIsNotNone(window)
+ window = profiles[1].getProfileWindow()
+ self.assertIsNotNone(window)
+ else:
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ finally:
+ profileManager.clearProfile()
+
+ def testImageActions(self):
+ roiClasses = [
+ rois.ProfileImageHorizontalLineROI,
+ rois.ProfileImageVerticalLineROI,
+ rois.ProfileImageLineROI,
+ rois.ProfileImageCrossROI,
+ ]
+ with self.imagePlot() as plot:
+ for roiClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ self.genericRoiTest(plot, roiClass)
+
+ def testScatterActions(self):
+ roiClasses = [
+ rois.ProfileScatterHorizontalLineROI,
+ rois.ProfileScatterVerticalLineROI,
+ rois.ProfileScatterLineROI,
+ rois.ProfileScatterCrossROI,
+ rois.ProfileScatterHorizontalSliceROI,
+ rois.ProfileScatterVerticalSliceROI,
+ rois.ProfileScatterCrossSliceROI,
+ ]
+ with self.scatterPlot() as plot:
+ for roiClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ self.genericRoiTest(plot, roiClass)
+
+ def testStackActions(self):
+ roiClasses = [
+ rois.ProfileImageStackHorizontalLineROI,
+ rois.ProfileImageStackVerticalLineROI,
+ rois.ProfileImageStackLineROI,
+ rois.ProfileImageStackCrossROI,
+ ]
+ with self.stackPlot() as plot:
+ for roiClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ self.genericRoiTest(plot, roiClass)
+
+ def genericEditorTest(self, plot, roi, editor):
+ if isinstance(editor, editors._NoProfileRoiEditor):
+ pass
+ elif isinstance(editor, editors._DefaultImageStackProfileRoiEditor):
+ # GUI to ROI
+ editor._lineWidth.setValue(2)
+ self.assertEqual(roi.getProfileLineWidth(), 2)
+ editor._methodsButton.setMethod("sum")
+ self.assertEqual(roi.getProfileMethod(), "sum")
+ editor._profileDim.setDimension(1)
+ self.assertEqual(roi.getProfileType(), "1D")
+ # ROI to GUI
+ roi.setProfileLineWidth(3)
+ self.assertEqual(editor._lineWidth.value(), 3)
+ roi.setProfileMethod("mean")
+ self.assertEqual(editor._methodsButton.getMethod(), "mean")
+ roi.setProfileType("2D")
+ self.assertEqual(editor._profileDim.getDimension(), 2)
+ elif isinstance(editor, editors._DefaultImageProfileRoiEditor):
+ # GUI to ROI
+ editor._lineWidth.setValue(2)
+ self.assertEqual(roi.getProfileLineWidth(), 2)
+ editor._methodsButton.setMethod("sum")
+ self.assertEqual(roi.getProfileMethod(), "sum")
+ # ROI to GUI
+ roi.setProfileLineWidth(3)
+ self.assertEqual(editor._lineWidth.value(), 3)
+ roi.setProfileMethod("mean")
+ self.assertEqual(editor._methodsButton.getMethod(), "mean")
+ elif isinstance(editor, editors._DefaultScatterProfileRoiEditor):
+ # GUI to ROI
+ editor._nPoints.setValue(100)
+ self.assertEqual(roi.getNPoints(), 100)
+ # ROI to GUI
+ roi.setNPoints(200)
+ self.assertEqual(editor._nPoints.value(), 200)
+ else:
+ assert False
+
+ def testEditors(self):
+ roiClasses = [
+ (rois.ProfileImageHorizontalLineROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileImageVerticalLineROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileImageLineROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileImageCrossROI, editors._DefaultImageProfileRoiEditor),
+ (rois.ProfileScatterHorizontalLineROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterVerticalLineROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterLineROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterCrossROI, editors._DefaultScatterProfileRoiEditor),
+ (rois.ProfileScatterHorizontalSliceROI, editors._NoProfileRoiEditor),
+ (rois.ProfileScatterVerticalSliceROI, editors._NoProfileRoiEditor),
+ (rois.ProfileScatterCrossSliceROI, editors._NoProfileRoiEditor),
+ (rois.ProfileImageStackHorizontalLineROI, editors._DefaultImageStackProfileRoiEditor),
+ (rois.ProfileImageStackVerticalLineROI, editors._DefaultImageStackProfileRoiEditor),
+ (rois.ProfileImageStackLineROI, editors._DefaultImageStackProfileRoiEditor),
+ (rois.ProfileImageStackCrossROI, editors._DefaultImageStackProfileRoiEditor),
+ ]
+ with self.defaultPlot() as plot:
+ profileManager = manager.ProfileManager(plot, plot)
+ editorAction = profileManager.createEditorAction(parent=plot)
+ for roiClass, editorClass in roiClasses:
+ with self.subTest(roiClass=roiClass):
+ roi = roiClass()
+ roi._setProfileManager(profileManager)
+ try:
+ # Force widget creation
+ menu = qt.QMenu(plot)
+ menu.addAction(editorAction)
+ widgets = editorAction.createdWidgets()
+ self.assertGreater(len(widgets), 0)
+
+ editorAction.setProfileRoi(roi)
+ editorWidget = editorAction._getEditor(widgets[0])
+ self.assertIsInstance(editorWidget, editorClass)
+ self.genericEditorTest(plot, roi, editorWidget)
+ finally:
+ editorAction.setProfileRoi(None)
+ menu.deleteLater()
+ menu = None
+ self.qapp.processEvents()
+
+
+class TestProfileToolBar(TestCaseQt, ParametricTestCase):
+ """Tests for ProfileToolBar widget."""
+
+ def setUp(self):
+ super(TestProfileToolBar, self).setUp()
+ self.plot = PlotWindow()
+ self.toolBar = Profile.ProfileToolBar(plot=self.plot)
+ self.plot.addToolBar(self.toolBar)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.mouseMove(self.plot) # Move to center
+ self.qapp.processEvents()
+ deprecation.FORCE = True
+
+ def tearDown(self):
+ deprecation.FORCE = False
+ self.qapp.processEvents()
+ profileManager = self.toolBar.getProfileManager()
+ profileManager.clearProfile()
+ profileManager = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ del self.toolBar
+
+ super(TestProfileToolBar, self).tearDown()
+
+ def testAlignedProfile(self):
+ """Test horizontal and vertical profile, without and with image"""
+ # Use Plot backend widget to submit mouse events
+ widget = self.plot.getWidgetHandle()
+ for method in ('sum', 'mean'):
+ with self.subTest(method=method):
+ # 2 positions to use for mouse events
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+
+ for action in (self.toolBar.hLineAction, self.toolBar.vLineAction):
+ with self.subTest(mode=action.text()):
+ # Trigger tool button for mode
+ action.trigger()
+ # Without image
+ self.mouseMove(widget, pos=pos1)
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=pos1)
+
+ # with image
+ self.plot.addImage(
+ numpy.arange(100 * 100).reshape(100, -1))
+ self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
+ self.mouseMove(widget, pos=pos2)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
+
+ self.mouseMove(widget)
+ self.mouseClick(widget, qt.Qt.LeftButton)
+
+ manager = self.toolBar.getProfileManager()
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=4)
+ def testDiagonalProfile(self):
+ """Test diagonal profile, without and with image"""
+ # Use Plot backend widget to submit mouse events
+ widget = self.plot.getWidgetHandle()
+
+ self.plot.addImage(
+ numpy.arange(100 * 100).reshape(100, -1))
+
+ for method in ('sum', 'mean'):
+ with self.subTest(method=method):
+ # 2 positions to use for mouse events
+ pos1 = widget.width() * 0.4, widget.height() * 0.4
+ pos2 = widget.width() * 0.6, widget.height() * 0.6
+
+ # Trigger tool button for diagonal profile mode
+ self.toolBar.lineAction.trigger()
+
+ # draw profile line
+ widget.setFocus(qt.Qt.OtherFocusReason)
+ self.mouseMove(widget, pos=pos1)
+ self.qWait(100)
+ self.mousePress(widget, qt.Qt.LeftButton, pos=pos1)
+ self.qWait(100)
+ self.mouseMove(widget, pos=pos2)
+ self.qWait(100)
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=pos2)
+ self.qWait(100)
+
+ manager = self.toolBar.getProfileManager()
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ roi = manager.getCurrentRoi()
+ self.assertIsNotNone(roi)
+ roi.setProfileLineWidth(3)
+ roi.setProfileMethod(method)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ curveItem = self.toolBar.getProfilePlot().getAllCurves()[0]
+ if method == 'sum':
+ self.assertTrue(curveItem.getData()[1].max() > 10000)
+ elif method == 'mean':
+ self.assertTrue(curveItem.getData()[1].max() < 10000)
+
+ # Remove the ROI so the profile window is also removed
+ roiManager = manager.getRoiManager()
+ roiManager.removeRoi(roi)
+ self.qWait(100)
+
+
+class TestDeprecatedProfileToolBar(TestCaseQt):
+ """Tests old features of the ProfileToolBar widget."""
+
+ def setUp(self):
+ self.plot = None
+ super(TestDeprecatedProfileToolBar, self).setUp()
+
+ def tearDown(self):
+ if self.plot is not None:
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+ self.qWait()
+
+ super(TestDeprecatedProfileToolBar, self).tearDown()
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=2)
+ def testCustomProfileWindow(self):
+ from silx.gui.plot import ProfileMainWindow
+
+ self.plot = PlotWindow()
+ profileWindow = ProfileMainWindow.ProfileMainWindow(self.plot)
+ toolBar = Profile.ProfileToolBar(parent=self.plot,
+ plot=self.plot,
+ profileWindow=profileWindow)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ profileWindow.show()
+ self.qWaitForWindowExposed(profileWindow)
+ self.qapp.processEvents()
+
+ self.plot.addImage(numpy.arange(10 * 10).reshape(10, -1))
+ profile = rois.ProfileImageHorizontalLineROI()
+ profile.setPosition(5)
+ toolBar.getProfileManager().getRoiManager().addRoi(profile)
+ toolBar.getProfileManager().getRoiManager().setCurrentRoi(profile)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not toolBar.getProfileManager().hasPendingOperations():
+ break
+
+ # There is a displayed profile
+ self.assertIsNotNone(profileWindow.getProfile())
+ self.assertIs(toolBar.getProfileMainWindow(), profileWindow)
+
+ # There is nothing anymore but the window is still there
+ toolBar.getProfileManager().clearProfile()
+ self.qapp.processEvents()
+ self.assertIsNone(profileWindow.getProfile())
+
+
+class TestProfile3DToolBar(TestCaseQt):
+ """Tests for Profile3DToolBar widget.
+ """
+ def setUp(self):
+ super(TestProfile3DToolBar, self).setUp()
+ self.plot = StackView()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.plot.setStack(numpy.array([
+ [[0, 1, 2], [3, 4, 5]],
+ [[6, 7, 8], [9, 10, 11]],
+ [[12, 13, 14], [15, 16, 17]]
+ ]))
+ deprecation.FORCE = True
+
+ def tearDown(self):
+ deprecation.FORCE = False
+ profileManager = self.plot.getProfileToolbar().getProfileManager()
+ profileManager.clearProfile()
+ profileManager = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+
+ super(TestProfile3DToolBar, self).tearDown()
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=2)
+ def testMethodProfile2D(self):
+ """Test that the profile can have a different method if we want to
+ compute then in 1D or in 2D"""
+
+ toolBar = self.plot.getProfileToolbar()
+
+ toolBar.vLineAction.trigger()
+ plot2D = self.plot.getPlotWidget().getWidgetHandle()
+ pos1 = plot2D.width() * 0.5, plot2D.height() * 0.5
+ self.mouseClick(plot2D, qt.Qt.LeftButton, pos=pos1)
+
+ manager = toolBar.getProfileManager()
+ roi = manager.getCurrentRoi()
+ roi.setProfileMethod("mean")
+ roi.setProfileType("2D")
+ roi.setProfileLineWidth(3)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ # check 2D 'mean' profile
+ profilePlot = toolBar.getProfilePlot()
+ data = profilePlot.getAllImages()[0].getData()
+ expected = numpy.array([[1, 4], [7, 10], [13, 16]])
+ numpy.testing.assert_almost_equal(data, expected)
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=2)
+ def testMethodSumLine(self):
+ """Simple interaction test to make sure the sum is correctly computed
+ """
+ toolBar = self.plot.getProfileToolbar()
+
+ toolBar.lineAction.trigger()
+ plot2D = self.plot.getPlotWidget().getWidgetHandle()
+ pos1 = plot2D.width() * 0.5, plot2D.height() * 0.2
+ pos2 = plot2D.width() * 0.5, plot2D.height() * 0.8
+
+ self.mouseMove(plot2D, pos=pos1)
+ self.mousePress(plot2D, qt.Qt.LeftButton, pos=pos1)
+ self.mouseMove(plot2D, pos=pos2)
+ self.mouseRelease(plot2D, qt.Qt.LeftButton, pos=pos2)
+
+ manager = toolBar.getProfileManager()
+ roi = manager.getCurrentRoi()
+ roi.setProfileMethod("sum")
+ roi.setProfileType("2D")
+ roi.setProfileLineWidth(3)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ # check 2D 'sum' profile
+ profilePlot = toolBar.getProfilePlot()
+ data = profilePlot.getAllImages()[0].getData()
+ expected = numpy.array([[3, 12], [21, 30], [39, 48]])
+ numpy.testing.assert_almost_equal(data, expected)
+
+
+class TestGetProfilePlot(TestCaseQt):
+
+ def setUp(self):
+ self.plot = None
+ super(TestGetProfilePlot, self).setUp()
+
+ def tearDown(self):
+ if self.plot is not None:
+ manager = self.plot.getProfileToolbar().getProfileManager()
+ manager.clearProfile()
+ manager = None
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ self.plot = None
+
+ super(TestGetProfilePlot, self).tearDown()
+
+ def testProfile1D(self):
+ self.plot = Plot2D()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+ self.plot.addImage([[0, 1], [2, 3]])
+
+ toolBar = self.plot.getProfileToolbar()
+
+ manager = toolBar.getProfileManager()
+ roiManager = manager.getRoiManager()
+
+ roi = rois.ProfileImageHorizontalLineROI()
+ roi.setPosition(0.5)
+ roiManager.addRoi(roi)
+ roiManager.setCurrentRoi(roi)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ profileWindow = roi.getProfileWindow()
+ self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
+ self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D)
+
+ def testProfile2D(self):
+ """Test that the profile plot associated to a stack view is either a
+ Plot1D or a plot 2D instance."""
+ self.plot = StackView()
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ self.plot.setStack(numpy.array([[[0, 1], [2, 3]],
+ [[4, 5], [6, 7]]]))
+
+ toolBar = self.plot.getProfileToolbar()
+
+ manager = toolBar.getProfileManager()
+ roiManager = manager.getRoiManager()
+
+ roi = rois.ProfileImageStackHorizontalLineROI()
+ roi.setPosition(0.5)
+ roi.setProfileType("2D")
+ roiManager.addRoi(roi)
+ roiManager.setCurrentRoi(roi)
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ profileWindow = roi.getProfileWindow()
+ self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
+ self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot2D)
+
+ roi.setProfileType("1D")
+
+ for _ in range(20):
+ self.qWait(200)
+ if not manager.hasPendingOperations():
+ break
+
+ profileWindow = roi.getProfileWindow()
+ self.assertIsInstance(roi.getProfileWindow(), qt.QMainWindow)
+ self.assertIsInstance(profileWindow.getCurrentPlotWidget(), Plot1D)
diff --git a/src/silx/gui/plot/tools/test/testROI.py b/src/silx/gui/plot/tools/test/testROI.py
new file mode 100644
index 0000000..21697d1
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testROI.py
@@ -0,0 +1,682 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import unittest
+import numpy.testing
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+from silx.gui.plot import PlotWindow
+import silx.gui.plot.items.roi as roi_items
+from silx.gui.plot.tools import roi
+
+
+class TestRoiItems(TestCaseQt):
+
+ def testLine_geometry(self):
+ item = roi_items.LineROI()
+ startPoint = numpy.array([1, 2])
+ endPoint = numpy.array([3, 4])
+ item.setEndPoints(startPoint, endPoint)
+ numpy.testing.assert_allclose(item.getEndPoints()[0], startPoint)
+ numpy.testing.assert_allclose(item.getEndPoints()[1], endPoint)
+
+ def testHLine_geometry(self):
+ item = roi_items.HorizontalLineROI()
+ item.setPosition(15)
+ self.assertEqual(item.getPosition(), 15)
+
+ def testVLine_geometry(self):
+ item = roi_items.VerticalLineROI()
+ item.setPosition(15)
+ self.assertEqual(item.getPosition(), 15)
+
+ def testPoint_geometry(self):
+ point = numpy.array([1, 2])
+ item = roi_items.PointROI()
+ item.setPosition(point)
+ numpy.testing.assert_allclose(item.getPosition(), point)
+
+ def testRectangle_originGeometry(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ center = numpy.array([5, 10])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ numpy.testing.assert_allclose(item.getOrigin(), origin)
+ numpy.testing.assert_allclose(item.getSize(), size)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+
+ def testRectangle_centerGeometry(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ center = numpy.array([5, 10])
+ item = roi_items.RectangleROI()
+ item.setGeometry(center=center, size=size)
+ numpy.testing.assert_allclose(item.getOrigin(), origin)
+ numpy.testing.assert_allclose(item.getSize(), size)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+
+ def testRectangle_setCenterGeometry(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ newCenter = numpy.array([0, 0])
+ item.setCenter(newCenter)
+ expectedOrigin = numpy.array([-5, -10])
+ numpy.testing.assert_allclose(item.getOrigin(), expectedOrigin)
+ numpy.testing.assert_allclose(item.getCenter(), newCenter)
+ numpy.testing.assert_allclose(item.getSize(), size)
+
+ def testRectangle_setOriginGeometry(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ newOrigin = numpy.array([10, 10])
+ item.setOrigin(newOrigin)
+ expectedCenter = numpy.array([15, 20])
+ numpy.testing.assert_allclose(item.getOrigin(), newOrigin)
+ numpy.testing.assert_allclose(item.getCenter(), expectedCenter)
+ numpy.testing.assert_allclose(item.getSize(), size)
+
+ def testCircle_geometry(self):
+ center = numpy.array([0, 0])
+ radius = 10.
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ numpy.testing.assert_allclose(item.getRadius(), radius)
+
+ def testCircle_setCenter(self):
+ center = numpy.array([0, 0])
+ radius = 10.
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ newCenter = numpy.array([-10, 0])
+ item.setCenter(newCenter)
+ numpy.testing.assert_allclose(item.getCenter(), newCenter)
+ numpy.testing.assert_allclose(item.getRadius(), radius)
+
+ def testCircle_setRadius(self):
+ center = numpy.array([0, 0])
+ radius = 10.
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ newRadius = 5.1
+ item.setRadius(newRadius)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ numpy.testing.assert_allclose(item.getRadius(), newRadius)
+
+ def testCircle_contains(self):
+ center = numpy.array([2, -1])
+ radius = 1.
+ item = roi_items.CircleROI()
+ item.setGeometry(center=center, radius=radius)
+ self.assertTrue(item.contains([1, -1]))
+ self.assertFalse(item.contains([0, 0]))
+ self.assertTrue(item.contains([2, 0]))
+ self.assertFalse(item.contains([3.01, -1]))
+
+ def testEllipse_contains(self):
+ center = numpy.array([-2, 0])
+ item = roi_items.EllipseROI()
+ item.setCenter(center)
+ item.setOrientation(numpy.pi / 4.0)
+ item.setMajorRadius(2)
+ item.setMinorRadius(1)
+ print(item.getMinorRadius(), item.getMajorRadius())
+ self.assertFalse(item.contains([0, 0]))
+ self.assertTrue(item.contains([-1, 1]))
+ self.assertTrue(item.contains([-3, 0]))
+ self.assertTrue(item.contains([-2, 0]))
+ self.assertTrue(item.contains([-2, 1]))
+ self.assertFalse(item.contains([-4, 1]))
+
+ def testRectangle_isIn(self):
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin, size=size)
+ self.assertTrue(item.contains(position=(0, 0)))
+ self.assertTrue(item.contains(position=(2, 14)))
+ self.assertFalse(item.contains(position=(14, 12)))
+
+ def testPolygon_emptyGeometry(self):
+ points = numpy.empty((0, 2))
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ numpy.testing.assert_allclose(item.getPoints(), points)
+
+ def testPolygon_geometry(self):
+ points = numpy.array([[10, 10], [12, 10], [50, 1]])
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ numpy.testing.assert_allclose(item.getPoints(), points)
+
+ def testPolygon_isIn(self):
+ points = numpy.array([[0, 0], [0, 10], [5, 10]])
+ item = roi_items.PolygonROI()
+ item.setPoints(points)
+ self.assertTrue(item.contains((0, 0)))
+ self.assertFalse(item.contains((6, 2)))
+ self.assertFalse(item.contains((-2, 5)))
+ self.assertFalse(item.contains((2, -1)))
+ self.assertFalse(item.contains((8, 1)))
+ self.assertTrue(item.contains((1, 8)))
+
+ def testArc_getToSetGeometry(self):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ item.setFirstShapePoints(numpy.array([[5, 10], [50, 100]]))
+ item.setGeometry(*item.getGeometry())
+
+ def testArc_degenerated_point(self):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+
+ def testArc_degenerated_line(self):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+
+ def testArc_special_circle(self):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
+ self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
+ self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0)
+ self.assertTrue(item.isClosed())
+
+ def testArc_special_donut(self):
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
+ self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
+ self.assertAlmostEqual(item.getStartAngle(), item.getEndAngle() - numpy.pi * 2.0)
+ self.assertTrue(item.isClosed())
+
+ def testArc_clockwiseGeometry(self):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
+ self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
+ self.assertAlmostEqual(item.getStartAngle(), startAngle)
+ self.assertAlmostEqual(item.getEndAngle(), endAngle)
+ self.assertAlmostEqual(item.isClosed(), False)
+
+ def testArc_anticlockwiseGeometry(self):
+ """Test that we can use getGeometry as input to setGeometry"""
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, -numpy.pi * 0.5
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ numpy.testing.assert_allclose(item.getCenter(), center)
+ self.assertAlmostEqual(item.getInnerRadius(), innerRadius)
+ self.assertAlmostEqual(item.getOuterRadius(), outerRadius)
+ self.assertAlmostEqual(item.getStartAngle(), startAngle)
+ self.assertAlmostEqual(item.getEndAngle(), endAngle)
+ self.assertAlmostEqual(item.isClosed(), False)
+
+ def testHRange_geometry(self):
+ item = roi_items.HorizontalRangeROI()
+ vmin = 1
+ vmax = 3
+ item.setRange(vmin, vmax)
+ self.assertAlmostEqual(item.getMin(), vmin)
+ self.assertAlmostEqual(item.getMax(), vmax)
+ self.assertAlmostEqual(item.getCenter(), 2)
+
+
+class TestRegionOfInterestManager(TestCaseQt, ParametricTestCase):
+ """Tests for RegionOfInterestManager class"""
+
+ def setUp(self):
+ super(TestRegionOfInterestManager, self).setUp()
+ self.plot = PlotWindow()
+
+ self.roiTableWidget = roi.RegionOfInterestTableWidget()
+ dock = qt.QDockWidget()
+ dock.setWidget(self.roiTableWidget)
+ self.plot.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ del self.roiTableWidget
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestRegionOfInterestManager, self).tearDown()
+
+ def test(self):
+ """Test ROI of different shapes"""
+ tests = ( # shape, points=[list of (x, y), list of (x, y)]
+ (roi_items.PointROI, numpy.array(([(10., 15.)], [(20., 25.)]))),
+ (roi_items.RectangleROI,
+ numpy.array((((1., 10.), (11., 20.)),
+ ((2., 3.), (12., 13.))))),
+ (roi_items.PolygonROI,
+ numpy.array((((0., 1.), (0., 10.), (10., 0.)),
+ ((5., 6.), (5., 16.), (15., 6.))))),
+ (roi_items.LineROI,
+ numpy.array((((10., 20.), (10., 30.)),
+ ((30., 40.), (30., 50.))))),
+ (roi_items.HorizontalLineROI,
+ numpy.array((((10., 20.), (10., 30.)),
+ ((30., 40.), (30., 50.))))),
+ (roi_items.VerticalLineROI,
+ numpy.array((((10., 20.), (10., 30.)),
+ ((30., 40.), (30., 50.))))),
+ (roi_items.HorizontalLineROI,
+ numpy.array((((10., 20.), (10., 30.)),
+ ((30., 40.), (30., 50.))))),
+ )
+
+ for roiClass, points in tests:
+ with self.subTest(roiClass=roiClass):
+ manager = roi.RegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ manager.start(roiClass)
+
+ self.assertEqual(manager.getRois(), ())
+
+ finishListener = SignalListener()
+ manager.sigInteractiveModeFinished.connect(finishListener)
+
+ changedListener = SignalListener()
+ manager.sigRoiChanged.connect(changedListener)
+
+ # Add a point
+ r = roiClass()
+ r.setFirstShapePoints(points[0])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ self.assertTrue(len(manager.getRois()), 1)
+ self.assertEqual(changedListener.callCount(), 1)
+
+ # Remove it
+ manager.removeRoi(manager.getRois()[0])
+ self.assertEqual(manager.getRois(), ())
+ self.assertEqual(changedListener.callCount(), 2)
+
+ # Add two point
+ r = roiClass()
+ r.setFirstShapePoints(points[0])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ r = roiClass()
+ r.setFirstShapePoints(points[1])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ self.assertTrue(len(manager.getRois()), 2)
+ self.assertEqual(changedListener.callCount(), 4)
+
+ # Reset it
+ result = manager.clear()
+ self.assertTrue(result)
+ self.assertEqual(manager.getRois(), ())
+ self.assertEqual(changedListener.callCount(), 5)
+
+ changedListener.clear()
+
+ # Add two point
+ r = roiClass()
+ r.setFirstShapePoints(points[0])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ r = roiClass()
+ r.setFirstShapePoints(points[1])
+ manager.addRoi(r)
+ self.qapp.processEvents()
+ self.assertTrue(len(manager.getRois()), 2)
+ self.assertEqual(changedListener.callCount(), 2)
+
+ # stop
+ result = manager.stop()
+ self.assertTrue(result)
+ self.assertTrue(len(manager.getRois()), 1)
+ self.qapp.processEvents()
+ self.assertEqual(finishListener.callCount(), 1)
+
+ manager.clear()
+
+ def testRoiDisplay(self):
+ rois = []
+
+ # Line
+ item = roi_items.LineROI()
+ startPoint = numpy.array([1, 2])
+ endPoint = numpy.array([3, 4])
+ item.setEndPoints(startPoint, endPoint)
+ rois.append(item)
+ # Horizontal line
+ item = roi_items.HorizontalLineROI()
+ item.setPosition(15)
+ rois.append(item)
+ # Vertical line
+ item = roi_items.VerticalLineROI()
+ item.setPosition(15)
+ rois.append(item)
+ # Point
+ item = roi_items.PointROI()
+ point = numpy.array([1, 2])
+ item.setPosition(point)
+ rois.append(item)
+ # Rectangle
+ item = roi_items.RectangleROI()
+ origin = numpy.array([0, 0])
+ size = numpy.array([10, 20])
+ item.setGeometry(origin=origin, size=size)
+ rois.append(item)
+ # Polygon
+ item = roi_items.PolygonROI()
+ points = numpy.array([[10, 10], [12, 10], [50, 1]])
+ item.setPoints(points)
+ rois.append(item)
+ # Degenerated polygon: No points
+ item = roi_items.PolygonROI()
+ points = numpy.empty((0, 2))
+ item.setPoints(points)
+ rois.append(item)
+ # Degenerated polygon: A single point
+ item = roi_items.PolygonROI()
+ points = numpy.array([[5, 10]])
+ item.setPoints(points)
+ rois.append(item)
+ # Degenerated arc: it's a point
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 0, 0, 0
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Degenerated arc: it's a line
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 0, 100, numpy.pi, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Special arc: it's a donut
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi, 3 * numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Arc
+ item = roi_items.ArcROI()
+ center = numpy.array([10, 20])
+ innerRadius, outerRadius, startAngle, endAngle = 1, 100, numpy.pi * 0.5, numpy.pi
+ item.setGeometry(center, innerRadius, outerRadius, startAngle, endAngle)
+ rois.append(item)
+ # Horizontal Range
+ item = roi_items.HorizontalRangeROI()
+ item.setRange(-1, 3)
+ rois.append(item)
+
+ manager = roi.RegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ for item in rois:
+ with self.subTest(roi=str(item)):
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ item.setEditable(True)
+ self.qapp.processEvents()
+ item.setEditable(False)
+ self.qapp.processEvents()
+ manager.removeRoi(item)
+ self.qapp.processEvents()
+
+ def testSelectionProxy(self):
+ item1 = roi_items.PointROI()
+ item1.setSelectable(True)
+ item2 = roi_items.PointROI()
+ item2.setSelectable(True)
+ item1.setFocusProxy(item2)
+ manager = roi.RegionOfInterestManager(self.plot)
+ manager.setCurrentRoi(item1)
+ self.assertIs(manager.getCurrentRoi(), item2)
+
+ def testRemovedSelection(self):
+ item1 = roi_items.PointROI()
+ item1.setSelectable(True)
+ manager = roi.RegionOfInterestManager(self.plot)
+ manager.addRoi(item1)
+ manager.setCurrentRoi(item1)
+ manager.removeRoi(item1)
+ self.assertIs(manager.getCurrentRoi(), None)
+
+ def testMaxROI(self):
+ """Test Max ROI"""
+ origin1 = numpy.array([1., 10.])
+ size1 = numpy.array([10., 10.])
+ origin2 = numpy.array([2., 3.])
+ size2 = numpy.array([10., 10.])
+
+ manager = roi.InteractiveRegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ self.assertEqual(manager.getRois(), ())
+
+ changedListener = SignalListener()
+ manager.sigRoiChanged.connect(changedListener)
+
+ # Add two point
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin1, size=size1)
+ manager.addRoi(item)
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin2, size=size2)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ self.assertEqual(changedListener.callCount(), 2)
+ self.assertEqual(len(manager.getRois()), 2)
+
+ # Try to set max ROI to 1 while there is 2 ROIs
+ with self.assertRaises(ValueError):
+ manager.setMaxRois(1)
+
+ manager.clear()
+ self.assertEqual(len(manager.getRois()), 0)
+ self.assertEqual(changedListener.callCount(), 3)
+
+ # Set max limit to 1
+ manager.setMaxRois(1)
+
+ # Add a point
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin1, size=size1)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ self.assertEqual(changedListener.callCount(), 4)
+
+ # Add a 2nd point while max ROI is 1
+ item = roi_items.RectangleROI()
+ item.setGeometry(origin=origin1, size=size1)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+ self.assertEqual(changedListener.callCount(), 6)
+ self.assertEqual(len(manager.getRois()), 1)
+
+ def testChangeInteractionMode(self):
+ """Test change of interaction mode"""
+ manager = roi.RegionOfInterestManager(self.plot)
+ self.roiTableWidget.setRegionOfInterestManager(manager)
+ manager.start(roi_items.PointROI)
+
+ interactiveModeToolBar = self.plot.getInteractiveModeToolBar()
+ panAction = interactiveModeToolBar.getPanModeAction()
+
+ for roiClass in manager.getSupportedRoiClasses():
+ with self.subTest(roiClass=roiClass):
+ # Change to pan mode
+ panAction.trigger()
+
+ # Change to interactive ROI mode
+ action = manager.getInteractionModeAction(roiClass)
+ action.trigger()
+
+ self.assertEqual(roiClass, manager.getCurrentInteractionModeRoiClass())
+
+ manager.clear()
+
+ def testLineInteraction(self):
+ """This test make sure that a ROI based on handles can be edited with
+ the mouse."""
+ xlimit = self.plot.getXAxis().getLimits()
+ ylimit = self.plot.getYAxis().getLimits()
+ points = numpy.array([xlimit, ylimit]).T
+ center = numpy.mean(points, axis=0)
+
+ # Create the line
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.LineROI()
+ item.setEndPoints(points[0], points[1])
+ item.setEditable(True)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+
+ # Drag the center
+ widget = self.plot.getWidgetHandle()
+ mx, my = self.plot.dataToPixel(*center)
+ self.mouseMove(widget, pos=(mx, my))
+ self.mousePress(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.mouseMove(widget, pos=(mx, my+25))
+ self.mouseMove(widget, pos=(mx, my+50))
+ self.mouseRelease(widget, qt.Qt.LeftButton, pos=(mx, my+50))
+
+ result = numpy.array(item.getEndPoints())
+ # x location is still the same
+ numpy.testing.assert_allclose(points[:, 0], result[:, 0], atol=0.5)
+ # size is still the same
+ numpy.testing.assert_allclose(points[1] - points[0],
+ result[1] - result[0], atol=0.5)
+ # But Y is not the same
+ self.assertNotEqual(points[0, 1], result[0, 1])
+ self.assertNotEqual(points[1, 1], result[1, 1])
+ item = None
+ manager.clear()
+ self.qapp.processEvents()
+
+ def testPlotWhenCleared(self):
+ """PlotWidget.clear should clean up the available ROIs"""
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.LineROI()
+ item.setEndPoints((0, 0), (1, 1))
+ item.setEditable(True)
+ manager.addRoi(item)
+ self.qWait()
+ try:
+ # Make sure the test setup is fine
+ self.assertNotEqual(len(manager.getRois()), 0)
+ self.assertNotEqual(len(self.plot.getItems()), 0)
+
+ # Call clear and test the expected state
+ self.plot.clear()
+ self.assertEqual(len(manager.getRois()), 0)
+ self.assertEqual(len(self.plot.getItems()), 0)
+ finally:
+ # Clean up
+ manager.clear()
+
+ def testPlotWhenRoiRemoved(self):
+ """Make sure there is no remaining items in the plot when a ROI is removed"""
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.LineROI()
+ item.setEndPoints((0, 0), (1, 1))
+ item.setEditable(True)
+ manager.addRoi(item)
+ self.qWait()
+ try:
+ # Make sure the test setup is fine
+ self.assertNotEqual(len(manager.getRois()), 0)
+ self.assertNotEqual(len(self.plot.getItems()), 0)
+
+ # Call clear and test the expected state
+ manager.removeRoi(item)
+ self.assertEqual(len(manager.getRois()), 0)
+ self.assertEqual(len(self.plot.getItems()), 0)
+ finally:
+ # Clean up
+ manager.clear()
+
+ def testArcRoiSwitchMode(self):
+ """Make sure we can switch mode by clicking on the ROI"""
+ xlimit = self.plot.getXAxis().getLimits()
+ ylimit = self.plot.getYAxis().getLimits()
+ points = numpy.array([xlimit, ylimit]).T
+ center = numpy.mean(points, axis=0)
+ size = numpy.abs(points[1] - points[0])
+
+ # Create the line
+ manager = roi.RegionOfInterestManager(self.plot)
+ item = roi_items.ArcROI()
+ item.setGeometry(center, size[1] / 10, size[1] / 2, 0, 3)
+ item.setEditable(True)
+ item.setSelectable(True)
+ manager.addRoi(item)
+ self.qapp.processEvents()
+
+ # Initial state
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
+ self.qWait(500)
+
+ # Click on the center
+ widget = self.plot.getWidgetHandle()
+ mx, my = self.plot.dataToPixel(*center)
+
+ # Select the ROI
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.ThreePointMode)
+
+ # Change the mode
+ self.mouseMove(widget, pos=(mx, my))
+ self.mouseClick(widget, qt.Qt.LeftButton, pos=(mx, my))
+ self.qWait(500)
+ self.assertIs(item.getInteractionMode(), roi_items.ArcROI.PolarMode)
+
+ manager.clear()
+ self.qapp.processEvents()
diff --git a/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
new file mode 100644
index 0000000..582a276
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testScatterProfileToolBar.py
@@ -0,0 +1,184 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "28/06/2018"
+
+
+import unittest
+import numpy
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot.tools.profile import manager
+from silx.gui.plot.tools.profile import core
+from silx.gui.plot.tools.profile import rois
+
+
+class TestScatterProfileToolBar(TestCaseQt, ParametricTestCase):
+ """Tests for ScatterProfileToolBar class"""
+
+ def setUp(self):
+ super(TestScatterProfileToolBar, self).setUp()
+ self.plot = PlotWindow()
+
+ self.manager = manager.ProfileManager(plot=self.plot)
+ self.manager.setItemType(scatter=True)
+ self.manager.setActiveItemTracking(True)
+
+ self.plot.show()
+ self.qWaitForWindowExposed(self.plot)
+
+ def tearDown(self):
+ del self.manager
+ self.qapp.processEvents()
+ self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plot.close()
+ del self.plot
+ super(TestScatterProfileToolBar, self).tearDown()
+
+ def testHorizontalProfile(self):
+ """Test ScatterProfileToolBar horizontal profile"""
+
+ roiManager = self.manager.getRoiManager()
+
+ # Add a scatter plot
+ self.plot.addScatter(
+ x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
+ self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
+ self.qapp.processEvents()
+
+ # Set a ROI profile
+ roi = rois.ProfileScatterHorizontalLineROI()
+ roi.setPosition(0.5)
+ roi.setNPoints(8)
+ roiManager.addRoi(roi)
+
+ # Wait for async interpolator init
+ for _ in range(20):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+ self.qapp.processEvents()
+
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ data = window.getProfile()
+ self.assertIsInstance(data, core.CurveProfileData)
+ self.assertEqual(len(data.coords), 8)
+
+ # Check that profile has same limits than Plot
+ xLimits = self.plot.getXAxis().getLimits()
+ self.assertEqual(data.coords[0], xLimits[0])
+ self.assertEqual(data.coords[-1], xLimits[1])
+
+ # Clear the profile
+ self.manager.clearProfile()
+ self.qapp.processEvents()
+ self.assertIsNone(roi.getProfileWindow())
+
+ def testVerticalProfile(self):
+ """Test ScatterProfileToolBar vertical profile"""
+
+ roiManager = self.manager.getRoiManager()
+
+ # Add a scatter plot
+ self.plot.addScatter(
+ x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
+ self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
+ self.qapp.processEvents()
+
+ # Set a ROI profile
+ roi = rois.ProfileScatterVerticalLineROI()
+ roi.setPosition(0.5)
+ roi.setNPoints(8)
+ roiManager.addRoi(roi)
+
+ # Wait for async interpolator init
+ for _ in range(10):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ data = window.getProfile()
+ self.assertIsInstance(data, core.CurveProfileData)
+ self.assertEqual(len(data.coords), 8)
+
+ # Check that profile has same limits than Plot
+ yLimits = self.plot.getYAxis().getLimits()
+ self.assertEqual(data.coords[0], yLimits[0])
+ self.assertEqual(data.coords[-1], yLimits[1])
+
+ # Check that profile limits are updated when changing limits
+ self.plot.getYAxis().setLimits(yLimits[0] + 1, yLimits[1] + 10)
+
+ # Wait for async interpolator init
+ for _ in range(10):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+
+ yLimits = self.plot.getYAxis().getLimits()
+ data = window.getProfile()
+ self.assertEqual(data.coords[0], yLimits[0])
+ self.assertEqual(data.coords[-1], yLimits[1])
+
+ # Clear the profile
+ self.manager.clearProfile()
+ self.qapp.processEvents()
+ self.assertIsNone(roi.getProfileWindow())
+
+ def testLineProfile(self):
+ """Test ScatterProfileToolBar line profile"""
+
+ roiManager = self.manager.getRoiManager()
+
+ # Add a scatter plot
+ self.plot.addScatter(
+ x=(0., 1., 1., 0.), y=(0., 0., 1., 1.), value=(0., 1., 2., 3.))
+ self.plot.resetZoom(dataMargins=(.1, .1, .1, .1))
+ self.qapp.processEvents()
+
+ # Set a ROI profile
+ roi = rois.ProfileScatterLineROI()
+ roi.setEndPoints(numpy.array([0., 0.]), numpy.array([1., 1.]))
+ roi.setNPoints(8)
+ roiManager.addRoi(roi)
+
+ # Wait for async interpolator init
+ for _ in range(10):
+ self.qWait(200)
+ if not self.manager.hasPendingOperations():
+ break
+
+ window = roi.getProfileWindow()
+ self.assertIsNotNone(window)
+ data = window.getProfile()
+ self.assertIsInstance(data, core.CurveProfileData)
+ self.assertEqual(len(data.coords), 8)
diff --git a/src/silx/gui/plot/tools/test/testTools.py b/src/silx/gui/plot/tools/test/testTools.py
new file mode 100644
index 0000000..846f641
--- /dev/null
+++ b/src/silx/gui/plot/tools/test/testTools.py
@@ -0,0 +1,135 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for silx.gui.plot.tools package"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/03/2018"
+
+
+import functools
+import unittest
+import numpy
+
+from silx.utils.testutils import LoggingValidator
+from silx.gui.utils.testutils import qWaitForWindowExposedAndActivate
+from silx.gui import qt
+from silx.gui.plot import PlotWindow
+from silx.gui.plot import tools
+from silx.gui.plot.test.utils import PlotWidgetTestCase
+
+
+class TestPositionInfo(PlotWidgetTestCase):
+ """Tests for PositionInfo widget."""
+
+ def _createPlot(self):
+ return PlotWindow()
+
+ def setUp(self):
+ super(TestPositionInfo, self).setUp()
+ self.mouseMove(self.plot, pos=(0, 0))
+ self.qapp.processEvents()
+ self.qWait(100)
+
+ def tearDown(self):
+ super(TestPositionInfo, self).tearDown()
+
+ def _test(self, positionWidget, converterNames, **kwargs):
+ """General test of PositionInfo.
+
+ - Add it to a toolbar and
+ - Move mouse around the center of the PlotWindow.
+ """
+ toolBar = qt.QToolBar()
+ self.plot.addToolBar(qt.Qt.BottomToolBarArea, toolBar)
+
+ toolBar.addWidget(positionWidget)
+
+ converters = positionWidget.getConverters()
+ self.assertEqual(len(converters), len(converterNames))
+ for index, name in enumerate(converterNames):
+ self.assertEqual(converters[index][0], name)
+
+ self.qapp.processEvents()
+ with LoggingValidator(tools.__name__, **kwargs):
+ # Move mouse to center
+ center = self.plot.size() / 2
+ self.mouseMove(self.plot, pos=(center.width(), center.height()))
+ # Move out
+ self.mouseMove(self.plot, pos=(1, 1))
+
+ def testDefaultConverters(self):
+ """Test PositionInfo with default converters"""
+ positionWidget = tools.PositionInfo(plot=self.plot)
+ self._test(positionWidget, ('X', 'Y'))
+
+ def testCustomConverters(self):
+ """Test PositionInfo with custom converters"""
+ converters = [
+ ('Coords', lambda x, y: (int(x), int(y))),
+ ('Radius', lambda x, y: numpy.sqrt(x * x + y * y)),
+ ('Angle', lambda x, y: numpy.degrees(numpy.arctan2(y, x)))
+ ]
+ positionWidget = tools.PositionInfo(plot=self.plot,
+ converters=converters)
+ self._test(positionWidget, ('Coords', 'Radius', 'Angle'))
+
+ def testFailingConverters(self):
+ """Test PositionInfo with failing custom converters"""
+ def raiseException(x, y):
+ raise RuntimeError()
+
+ positionWidget = tools.PositionInfo(
+ plot=self.plot,
+ converters=[('Exception', raiseException)])
+ self._test(positionWidget, ['Exception'], error=2)
+
+ def testUpdate(self):
+ """Test :meth:`PositionInfo.updateInfo`"""
+ calls = []
+
+ def update(calls, x, y): # Get number of calls
+ calls.append((x, y))
+ return len(calls)
+
+ positionWidget = tools.PositionInfo(
+ plot=self.plot,
+ converters=[('Call count', functools.partial(update, calls))])
+
+ positionWidget.updateInfo()
+ self.assertEqual(len(calls), 1)
+
+
+class TestPlotToolsToolbars(PlotWidgetTestCase):
+ """Tests toolbars from silx.gui.plot.tools"""
+
+ def test(self):
+ """"Add all toolbars"""
+ for tbClass in (tools.InteractiveModeToolBar,
+ tools.ImageToolBar,
+ tools.CurveToolBar,
+ tools.OutputToolBar):
+ tb = tbClass(parent=self.plot, plot=self.plot)
+ self.plot.addToolBar(tb)
diff --git a/silx/gui/plot/tools/toolbars.py b/src/silx/gui/plot/tools/toolbars.py
index 3df7d06..3df7d06 100644
--- a/silx/gui/plot/tools/toolbars.py
+++ b/src/silx/gui/plot/tools/toolbars.py
diff --git a/silx/gui/plot/utils/__init__.py b/src/silx/gui/plot/utils/__init__.py
index 3187f6b..3187f6b 100644
--- a/silx/gui/plot/utils/__init__.py
+++ b/src/silx/gui/plot/utils/__init__.py
diff --git a/src/silx/gui/plot/utils/axis.py b/src/silx/gui/plot/utils/axis.py
new file mode 100644
index 0000000..5cf8ad9
--- /dev/null
+++ b/src/silx/gui/plot/utils/axis.py
@@ -0,0 +1,398 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module contains utils class for axes management.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "20/11/2018"
+
+import functools
+import logging
+from contextlib import contextmanager
+import weakref
+import silx.utils.weakref as silxWeakref
+from silx.gui.plot.items.axis import Axis, XAxis, YAxis
+from ...qt.inspect import isValid as _isQObjectValid
+
+
+_logger = logging.getLogger(__name__)
+
+
+class SyncAxes(object):
+ """Synchronize a set of plot axes together.
+
+ It is created with the expected axes and starts to synchronize them.
+
+ It can be customized to synchronize limits, scale, and direction of axes
+ together. By default everything is synchronized.
+
+ The API :meth:`start` and :meth:`stop` can be used to enable/disable the
+ synchronization while this object is still alive.
+
+ If this object is destroyed the synchronization stop.
+
+ .. versionadded:: 0.6
+ """
+
+ def __init__(self, axes,
+ syncLimits=True,
+ syncScale=True,
+ syncDirection=True,
+ syncCenter=False,
+ syncZoom=False,
+ filterHiddenPlots=False
+ ):
+ """
+ Constructor
+
+ :param list(Axis) axes: A list of axes to synchronize together
+ :param bool syncLimits: Synchronize axes limits
+ :param bool syncScale: Synchronize axes scale
+ :param bool syncDirection: Synchronize axes direction
+ :param bool syncCenter: Synchronize the center of the axes in the center
+ of the plots
+ :param bool syncZoom: Synchronize the zoom of the plot
+ :param bool filterHiddenPlots: True to avoid updating hidden plots.
+ Default: False.
+ """
+ object.__init__(self)
+
+ def implies(x, y): return bool(y ** x)
+
+ assert(implies(syncZoom, not syncLimits))
+ assert(implies(syncCenter, not syncLimits))
+ assert(implies(syncLimits, not syncCenter))
+ assert(implies(syncLimits, not syncZoom))
+
+ self.__filterHiddenPlots = filterHiddenPlots
+ self.__locked = False
+ self.__axisRefs = []
+ self.__syncLimits = syncLimits
+ self.__syncScale = syncScale
+ self.__syncDirection = syncDirection
+ self.__syncCenter = syncCenter
+ self.__syncZoom = syncZoom
+ self.__callbacks = None
+ self.__lastMainAxis = None
+
+ for axis in axes:
+ self.addAxis(axis)
+
+ self.start()
+
+ def start(self):
+ """Start synchronizing axes together.
+
+ The first axis is used as the reference for the first synchronization.
+ After that, any changes to any axes will be used to synchronize other
+ axes.
+ """
+ if self.isSynchronizing():
+ raise RuntimeError("Axes already synchronized")
+ self.__callbacks = {}
+
+ axes = self.__getAxes()
+
+ # register callback for further sync
+ for axis in axes:
+ self.__connectAxes(axis)
+ self.synchronize()
+
+ def isSynchronizing(self):
+ """Returns true if events are connected to the axes to synchronize them
+ all together
+
+ :rtype: bool
+ """
+ return self.__callbacks is not None
+
+ def __connectAxes(self, axis):
+ refAxis = weakref.ref(axis)
+ callbacks = []
+ if self.__syncLimits:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisLimitsChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncCenter and self.__syncZoom:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterAndZoomChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ elif self.__syncZoom:
+ raise NotImplementedError()
+ elif self.__syncCenter:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisCenterChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigLimitsChanged
+ sig.connect(callback)
+ callbacks.append(("sigLimitsChanged", callback))
+ if self.__syncScale:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisScaleChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigScaleChanged
+ sig.connect(callback)
+ callbacks.append(("sigScaleChanged", callback))
+ if self.__syncDirection:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisInvertedChanged)
+ callback = functools.partial(callback, refAxis)
+ sig = axis.sigInvertedChanged
+ sig.connect(callback)
+ callbacks.append(("sigInvertedChanged", callback))
+
+ if self.__filterHiddenPlots:
+ # the weakref is needed to be able ignore self references
+ callback = silxWeakref.WeakMethodProxy(self.__axisVisibilityChanged)
+ callback = functools.partial(callback, refAxis)
+ plot = axis._getPlot()
+ plot.sigVisibilityChanged.connect(callback)
+ callbacks.append(("sigVisibilityChanged", callback))
+
+ self.__callbacks[refAxis] = callbacks
+
+ def __disconnectAxes(self, axis):
+ if axis is not None and _isQObjectValid(axis):
+ ref = weakref.ref(axis)
+ callbacks = self.__callbacks.pop(ref)
+ for sigName, callback in callbacks:
+ if sigName == "sigVisibilityChanged":
+ obj = axis._getPlot()
+ else:
+ obj = axis
+ if obj is not None:
+ sig = getattr(obj, sigName)
+ sig.disconnect(callback)
+
+ def addAxis(self, axis):
+ """Add a new axes to synchronize.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to synchronize
+ """
+ self.__axisRefs.append(weakref.ref(axis))
+ if self.isSynchronizing():
+ self.__connectAxes(axis)
+ # This could be done faster as only this axis have to be fixed
+ self.synchronize()
+
+ def removeAxis(self, axis):
+ """Remove an axis from the synchronized axes.
+
+ :param ~silx.gui.plot.items.Axis axis: The axis to remove
+ """
+ ref = weakref.ref(axis)
+ self.__axisRefs.remove(ref)
+ if self.isSynchronizing():
+ self.__disconnectAxes(axis)
+
+ def synchronize(self, mainAxis=None):
+ """Synchronize programatically all the axes.
+
+ :param ~silx.gui.plot.items.Axis mainAxis:
+ The axis to take as reference (Default: the first axis).
+ """
+ # sync the current state
+ axes = self.__getAxes()
+ if len(axes) == 0:
+ return
+
+ if mainAxis is None:
+ mainAxis = axes[0]
+
+ refMainAxis = weakref.ref(mainAxis)
+ if self.__syncLimits:
+ self.__axisLimitsChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter and self.__syncZoom:
+ self.__axisCenterAndZoomChanged(refMainAxis, *mainAxis.getLimits())
+ elif self.__syncCenter:
+ self.__axisCenterChanged(refMainAxis, *mainAxis.getLimits())
+ if self.__syncScale:
+ self.__axisScaleChanged(refMainAxis, mainAxis.getScale())
+ if self.__syncDirection:
+ self.__axisInvertedChanged(refMainAxis, mainAxis.isInverted())
+
+ def stop(self):
+ """Stop the synchronization of the axes"""
+ if not self.isSynchronizing():
+ raise RuntimeError("Axes not synchronized")
+ for ref in list(self.__callbacks.keys()):
+ axis = ref()
+ self.__disconnectAxes(axis)
+ self.__callbacks = None
+
+ def __del__(self):
+ """Destructor"""
+ # clean up references
+ if self.__callbacks is not None:
+ self.stop()
+
+ def __getAxes(self):
+ """Returns list of existing axes.
+
+ :rtype: List[Axis]
+ """
+ axes = [ref() for ref in self.__axisRefs]
+ return [axis for axis in axes if axis is not None]
+
+ @contextmanager
+ def __inhibitSignals(self):
+ self.__locked = True
+ yield
+ self.__locked = False
+
+ def __axesToUpdate(self, changedAxis):
+ for axis in self.__getAxes():
+ if axis is changedAxis:
+ continue
+ if self.__filterHiddenPlots:
+ plot = axis._getPlot()
+ if not plot.isVisible():
+ continue
+ yield axis
+
+ def __axisVisibilityChanged(self, changedAxis, isVisible):
+ if not isVisible:
+ return
+ if self.__locked:
+ return
+ changedAxis = changedAxis()
+ if self.__lastMainAxis is None:
+ self.__lastMainAxis = self.__axisRefs[0]
+ mainAxis = self.__lastMainAxis
+ mainAxis = mainAxis()
+ self.synchronize(mainAxis=mainAxis)
+ # force back the main axis
+ self.__lastMainAxis = weakref.ref(mainAxis)
+
+ def __getAxesCenter(self, axis, vmin, vmax):
+ """Returns the value displayed in the center of this axis range.
+
+ :rtype: float
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ center = (vmin + vmax) * 0.5
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ return center
+
+ def __getRangeInPixel(self, axis):
+ """Returns the size of the axis in pixel"""
+ bounds = axis._getPlot().getPlotBoundsInPixels()
+ # bounds: left, top, width, height
+ if isinstance(axis, XAxis):
+ return bounds[2]
+ elif isinstance(axis, YAxis):
+ return bounds[3]
+ else:
+ assert(False)
+
+ def __getLimitsFromCenter(self, axis, pos, pixelSize=None):
+ """Returns the limits to apply to this axis to move the `pos` into the
+ center of this axis.
+
+ :param Axis axis:
+ :param float pos: Position in the center of the computed limits
+ :param Union[None,float] pixelSize: Pixel size to apply to compute the
+ limits. If `None` the current pixel size is applyed.
+ """
+ scale = axis.getScale()
+ if scale == Axis.LINEAR:
+ if pixelSize is None:
+ # Use the current pixel size of the axis
+ limits = axis.getLimits()
+ valueRange = limits[0] - limits[1]
+ a = pos - valueRange * 0.5
+ b = pos + valueRange * 0.5
+ else:
+ pixelRange = self.__getRangeInPixel(axis)
+ a = pos - pixelRange * 0.5 * pixelSize
+ b = pos + pixelRange * 0.5 * pixelSize
+
+ else:
+ raise NotImplementedError("Log scale not implemented")
+ if a > b:
+ return b, a
+ return a, b
+
+ def __axisLimitsChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterAndZoomChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ pixelRange = self.__getRangeInPixel(changedAxis)
+ if pixelRange == 0:
+ return
+ pixelSize = (vmax - vmin) / pixelRange
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center, pixelSize)
+ axis.setLimits(vmin, vmax)
+
+ def __axisCenterChanged(self, changedAxis, vmin, vmax):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ center = self.__getAxesCenter(changedAxis, vmin, vmax)
+ for axis in self.__axesToUpdate(changedAxis):
+ vmin, vmax = self.__getLimitsFromCenter(axis, center)
+ axis.setLimits(vmin, vmax)
+
+ def __axisScaleChanged(self, changedAxis, scale):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setScale(scale)
+
+ def __axisInvertedChanged(self, changedAxis, isInverted):
+ if self.__locked:
+ return
+ self.__lastMainAxis = changedAxis
+ changedAxis = changedAxis()
+ with self.__inhibitSignals():
+ for axis in self.__axesToUpdate(changedAxis):
+ axis.setInverted(isInverted)
diff --git a/silx/gui/plot/utils/intersections.py b/src/silx/gui/plot/utils/intersections.py
index 53f2546..53f2546 100644
--- a/silx/gui/plot/utils/intersections.py
+++ b/src/silx/gui/plot/utils/intersections.py
diff --git a/src/silx/gui/plot3d/ParamTreeView.py b/src/silx/gui/plot3d/ParamTreeView.py
new file mode 100644
index 0000000..2593860
--- /dev/null
+++ b/src/silx/gui/plot3d/ParamTreeView.py
@@ -0,0 +1,522 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides a :class:`QTreeView` dedicated to display plot3d models.
+
+This module contains:
+- :class:`ParamTreeView`: A QTreeView specific for plot3d parameters and scene.
+- :class:`ParameterTreeDelegate`: The delegate for :class:`ParamTreeView`.
+- A set of specific editors used by :class:`ParameterTreeDelegate`:
+ :class:`FloatEditor`, :class:`Vector3DEditor`,
+ :class:`Vector4DEditor`, :class:`IntSliderEditor`, :class:`BooleanEditor`
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2017"
+
+
+import numbers
+import sys
+
+from .. import qt
+from ..widgets.FloatEdit import FloatEdit as _FloatEdit
+from ._model import visitQAbstractItemModel
+
+
+class FloatEditor(_FloatEdit):
+ """Editor widget for float.
+
+ :param parent: The widget's parent
+ :param float value: The initial editor value
+ """
+
+ valueChanged = qt.Signal(float)
+ """Signal emitted when the float value has changed"""
+
+ def __init__(self, parent=None, value=None):
+ super(FloatEditor, self).__init__(parent, value)
+ self.setAlignment(qt.Qt.AlignLeft)
+ self.editingFinished.connect(self._emit)
+
+ def _emit(self):
+ self.valueChanged.emit(self.value)
+
+ value = qt.Property(float,
+ fget=_FloatEdit.value,
+ fset=_FloatEdit.setValue,
+ user=True,
+ notify=valueChanged)
+ """Qt user property of the float value this widget edits"""
+
+
+class Vector3DEditor(qt.QWidget):
+ """Editor widget for QVector3D.
+
+ :param parent: The widget's parent
+ :param flags: The widgets's flags
+ """
+
+ valueChanged = qt.Signal(qt.QVector3D)
+ """Signal emitted when the QVector3D value has changed"""
+
+ def __init__(self, parent=None, flags=qt.Qt.Widget):
+ super(Vector3DEditor, self).__init__(parent, flags)
+ layout = qt.QHBoxLayout(self)
+ # layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.setLayout(layout)
+ self._xEdit = _FloatEdit(parent=self, value=0.)
+ self._xEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._xEdit.editingFinished.connect(self._emit)
+ self._yEdit = _FloatEdit(parent=self, value=0.)
+ self._yEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._yEdit.editingFinished.connect(self._emit)
+ self._zEdit = _FloatEdit(parent=self, value=0.)
+ self._zEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._zEdit.editingFinished.connect(self._emit)
+ layout.addWidget(qt.QLabel('x:'))
+ layout.addWidget(self._xEdit)
+ layout.addWidget(qt.QLabel('y:'))
+ layout.addWidget(self._yEdit)
+ layout.addWidget(qt.QLabel('z:'))
+ layout.addWidget(self._zEdit)
+ layout.addStretch(1)
+
+ def _emit(self):
+ vector = self.value
+ self.valueChanged.emit(vector)
+
+ def getValue(self):
+ """Returns the QVector3D value of this widget
+
+ :rtype: QVector3D
+ """
+ return qt.QVector3D(
+ self._xEdit.value(), self._yEdit.value(), self._zEdit.value())
+
+ def setValue(self, value):
+ """Set the QVector3D value
+
+ :param QVector3D value: The new value
+ """
+ self._xEdit.setValue(value.x())
+ self._yEdit.setValue(value.y())
+ self._zEdit.setValue(value.z())
+ self.valueChanged.emit(value)
+
+ value = qt.Property(qt.QVector3D,
+ fget=getValue,
+ fset=setValue,
+ user=True,
+ notify=valueChanged)
+ """Qt user property of the QVector3D value this widget edits"""
+
+
+class Vector4DEditor(qt.QWidget):
+ """Editor widget for QVector4D.
+
+ :param parent: The widget's parent
+ :param flags: The widgets's flags
+ """
+
+ valueChanged = qt.Signal(qt.QVector4D)
+ """Signal emitted when the QVector4D value has changed"""
+
+ def __init__(self, parent=None, flags=qt.Qt.Widget):
+ super(Vector4DEditor, self).__init__(parent, flags)
+ layout = qt.QHBoxLayout(self)
+ # layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.setLayout(layout)
+ self._xEdit = _FloatEdit(parent=self, value=0.)
+ self._xEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._xEdit.editingFinished.connect(self._emit)
+ self._yEdit = _FloatEdit(parent=self, value=0.)
+ self._yEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._yEdit.editingFinished.connect(self._emit)
+ self._zEdit = _FloatEdit(parent=self, value=0.)
+ self._zEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._zEdit.editingFinished.connect(self._emit)
+ self._wEdit = _FloatEdit(parent=self, value=0.)
+ self._wEdit.setAlignment(qt.Qt.AlignLeft)
+ # self._wEdit.editingFinished.connect(self._emit)
+ layout.addWidget(qt.QLabel('x:'))
+ layout.addWidget(self._xEdit)
+ layout.addWidget(qt.QLabel('y:'))
+ layout.addWidget(self._yEdit)
+ layout.addWidget(qt.QLabel('z:'))
+ layout.addWidget(self._zEdit)
+ layout.addWidget(qt.QLabel('w:'))
+ layout.addWidget(self._wEdit)
+ layout.addStretch(1)
+
+ def _emit(self):
+ vector = self.value
+ self.valueChanged.emit(vector)
+
+ def getValue(self):
+ """Returns the QVector4D value of this widget
+
+ :rtype: QVector4D
+ """
+ return qt.QVector4D(self._xEdit.value(), self._yEdit.value(),
+ self._zEdit.value(), self._wEdit.value())
+
+ def setValue(self, value):
+ """Set the QVector4D value
+
+ :param QVector4D value: The new value
+ """
+ self._xEdit.setValue(value.x())
+ self._yEdit.setValue(value.y())
+ self._zEdit.setValue(value.z())
+ self._wEdit.setValue(value.w())
+ self.valueChanged.emit(value)
+
+ value = qt.Property(qt.QVector4D,
+ fget=getValue,
+ fset=setValue,
+ user=True,
+ notify=valueChanged)
+ """Qt user property of the QVector4D value this widget edits"""
+
+
+class IntSliderEditor(qt.QSlider):
+ """Slider editor widget for integer.
+
+ Note: Tracking is disabled.
+
+ :param parent: The widget's parent
+ """
+
+ def __init__(self, parent=None):
+ super(IntSliderEditor, self).__init__(parent)
+ self.setOrientation(qt.Qt.Horizontal)
+ self.setSingleStep(1)
+ self.setRange(0, 255)
+ self.setValue(0)
+
+
+class BooleanEditor(qt.QCheckBox):
+ """Checkbox editor for bool.
+
+ This is a QCheckBox with white background.
+
+ :param parent: The widget's parent
+ """
+
+ def __init__(self, parent=None):
+ super(BooleanEditor, self).__init__(parent)
+ self.setStyleSheet("background: white;")
+
+
+class ParameterTreeDelegate(qt.QStyledItemDelegate):
+ """TreeView delegate specific to plot3d scene and object parameter tree.
+
+ It provides additional editors.
+
+ :param parent: Delegate's parent
+ """
+
+ EDITORS = {
+ bool: BooleanEditor,
+ float: FloatEditor,
+ qt.QVector3D: Vector3DEditor,
+ qt.QVector4D: Vector4DEditor,
+ }
+ """Specific editors for different type of data"""
+
+ def __init__(self, parent=None):
+ super(ParameterTreeDelegate, self).__init__(parent)
+
+ def paint(self, painter, option, index):
+ """See :meth:`QStyledItemDelegate.paint`"""
+ data = index.data(qt.Qt.DisplayRole)
+
+ if isinstance(data, (qt.QVector3D, qt.QVector4D)):
+ if isinstance(data, qt.QVector3D):
+ text = '(x: %g; y: %g; z: %g)' % (data.x(), data.y(), data.z())
+ elif isinstance(data, qt.QVector4D):
+ text = '(%g; %g; %g; %g)' % (data.x(), data.y(), data.z(), data.w())
+ else:
+ text = ''
+
+ painter.save()
+ painter.setRenderHint(qt.QPainter.Antialiasing, True)
+
+ # Select palette color group
+ colorGroup = qt.QPalette.Inactive
+ if option.state & qt.QStyle.State_Active:
+ colorGroup = qt.QPalette.Active
+ if not option.state & qt.QStyle.State_Enabled:
+ colorGroup = qt.QPalette.Disabled
+
+ # Draw background if selected
+ if option.state & qt.QStyle.State_Selected:
+ brush = option.palette.brush(colorGroup,
+ qt.QPalette.Highlight)
+ painter.fillRect(option.rect, brush)
+
+ # Draw text
+ if option.state & qt.QStyle.State_Selected:
+ colorRole = qt.QPalette.HighlightedText
+ else:
+ colorRole = qt.QPalette.WindowText
+ color = option.palette.color(colorGroup, colorRole)
+ painter.setPen(qt.QPen(color))
+ painter.drawText(option.rect, qt.Qt.AlignLeft, text)
+
+ painter.restore()
+
+ # The following commented code does the same as QPainter based code
+ # but it does not work with PySide
+ # self.initStyleOption(option, index)
+ # option.text = text
+ # widget = option.widget
+ # style = qt.QApplication.style() if not widget else widget.style()
+ # style.drawControl(qt.QStyle.CE_ItemViewItem, option, painter, widget)
+
+ else:
+ super(ParameterTreeDelegate, self).paint(painter, option, index)
+
+ def _commit(self, *args):
+ """Commit data to the model from editors"""
+ sender = self.sender()
+ self.commitData.emit(sender)
+
+ def editorEvent(self, event, model, option, index):
+ """See :meth:`QStyledItemDelegate.editorEvent`"""
+ if (event.type() == qt.QEvent.MouseButtonPress and
+ isinstance(index.data(qt.Qt.EditRole), qt.QColor)):
+ initialColor = index.data(qt.Qt.EditRole)
+
+ def callback(color):
+ theModel = index.model()
+ theModel.setData(index, color, qt.Qt.EditRole)
+
+ dialog = qt.QColorDialog(self.parent())
+ # dialog.setOption(qt.QColorDialog.ShowAlphaChannel, True)
+ if sys.platform == 'darwin':
+ # Use of native color dialog on macos might cause problems
+ dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True)
+ dialog.setCurrentColor(initialColor)
+ dialog.currentColorChanged.connect(callback)
+ if dialog.exec() == qt.QDialog.Rejected:
+ # Reset color
+ dialog.setCurrentColor(initialColor)
+
+ return True
+ else:
+ return super(ParameterTreeDelegate, self).editorEvent(
+ event, model, option, index)
+
+ def createEditor(self, parent, option, index):
+ """See :meth:`QStyledItemDelegate.createEditor`"""
+ data = index.data(qt.Qt.EditRole)
+ editorHint = index.data(qt.Qt.UserRole)
+
+ if callable(editorHint):
+ editor = editorHint()
+ assert isinstance(editor, qt.QWidget)
+ editor.setParent(parent)
+
+ elif isinstance(data, numbers.Number) and editorHint is not None:
+ # Use a slider
+ editor = IntSliderEditor(parent)
+ range_ = editorHint
+ editor.setRange(*range_)
+ editor.sliderReleased.connect(self._commit)
+
+ elif isinstance(data, str) and editorHint is not None:
+ # Use a combo box
+ editor = qt.QComboBox(parent)
+ if data not in editorHint:
+ editor.addItem(data)
+ editor.addItems(editorHint)
+
+ index = editor.findText(data)
+ editor.setCurrentIndex(index)
+
+ editor.currentIndexChanged.connect(self._commit)
+
+ else:
+ # Handle overridden editors from Python
+ # Mimic Qt C++ implementation
+ for type_, editorClass in self.EDITORS.items():
+ if isinstance(data, type_):
+ editor = editorClass(parent)
+ metaObject = editor.metaObject()
+ userProperty = metaObject.userProperty()
+ if userProperty.isValid() and userProperty.hasNotifySignal():
+ notifySignal = userProperty.notifySignal()
+ signature = notifySignal.methodSignature()
+ if qt.BINDING == 'PySide2':
+ signature = signature.data()
+ else:
+ signature = bytes(signature)
+
+ if hasattr(signature, 'decode'): # For PySide with python3
+ signature = signature.decode('ascii')
+ signalName = signature.split('(')[0]
+
+ signal = getattr(editor, signalName)
+ signal.connect(self._commit)
+ break
+
+ else: # Default handling for default types
+ return super(ParameterTreeDelegate, self).createEditor(
+ parent, option, index)
+
+ editor.setAutoFillBackground(True)
+ return editor
+
+ def setModelData(self, editor, model, index):
+ """See :meth:`QStyledItemDelegate.setModelData`"""
+ if isinstance(editor, tuple(self.EDITORS.values())):
+ # Special handling of Python classes
+ # Translation of QStyledItemDelegate::setModelData to Python
+ # To make it work with Python QVariant wrapping/unwrapping
+ name = editor.metaObject().userProperty().name()
+ if not name:
+ pass # TODO handle the case of missing user property
+ if name:
+ if hasattr(editor, name):
+ value = getattr(editor, name)
+ else:
+ value = editor.property(name)
+ model.setData(index, value, qt.Qt.EditRole)
+
+ else:
+ super(ParameterTreeDelegate, self).setModelData(editor, model, index)
+
+
+class ParamTreeView(qt.QTreeView):
+ """QTreeView specific to handle plot3d scene and object parameters.
+
+ It provides additional editors and specific creation of persistent editors.
+
+ :param parent: The widget's parent.
+ """
+
+ def __init__(self, parent=None):
+ super(ParamTreeView, self).__init__(parent)
+
+ header = self.header()
+ header.setMinimumSectionSize(128) # For colormap pixmaps
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ delegate = ParameterTreeDelegate()
+ self.setItemDelegate(delegate)
+
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+
+ self.expanded.connect(self._expanded)
+
+ self.setEditTriggers(qt.QAbstractItemView.CurrentChanged |
+ qt.QAbstractItemView.DoubleClicked)
+
+ self.__persistentEditors = set()
+
+ def _openEditorForIndex(self, index):
+ """Check if it has to open a persistent editor for a specific cell.
+
+ :param QModelIndex index: The cell index
+ """
+ if index.flags() & qt.Qt.ItemIsEditable:
+ data = index.data(qt.Qt.EditRole)
+ editorHint = index.data(qt.Qt.UserRole)
+ if (isinstance(data, bool) or
+ callable(editorHint) or
+ (isinstance(data, numbers.Number) and editorHint)):
+ self.openPersistentEditor(index)
+ self.__persistentEditors.add(index)
+
+ def _openEditors(self, parent=qt.QModelIndex()):
+ """Open persistent editors in a subtree starting at parent.
+
+ :param QModelIndex parent: The root of the subtree to process.
+ """
+ model = self.model()
+ if model is not None:
+ for index in visitQAbstractItemModel(model, parent):
+ self._openEditorForIndex(index)
+
+ def setModel(self, model):
+ """Set the model this TreeView is displaying
+
+ :param QAbstractItemModel model:
+ """
+ super(ParamTreeView, self).setModel(model)
+ self._openEditors()
+
+ def rowsInserted(self, parent, start, end):
+ """See :meth:`QTreeView.rowsInserted`"""
+ super(ParamTreeView, self).rowsInserted(parent, start, end)
+ model = self.model()
+ if model is not None:
+ for row in range(start, end+1):
+ self._openEditorForIndex(model.index(row, 1, parent))
+ self._openEditors(model.index(row, 0, parent))
+
+ def _expanded(self, index):
+ """Handle QTreeView expanded signal"""
+ name = index.data(qt.Qt.DisplayRole)
+ if name == 'Transform':
+ rotateIndex = self.model().index(1, 0, index)
+ self.setExpanded(rotateIndex, True)
+
+ def dataChanged(self, topLeft, bottomRight, roles=()):
+ """Handle model dataChanged signal eventually closing editors"""
+ if roles: # Qt 5
+ super(ParamTreeView, self).dataChanged(topLeft, bottomRight, roles)
+ else: # Qt4 compatibility
+ super(ParamTreeView, self).dataChanged(topLeft, bottomRight)
+ if not roles or qt.Qt.UserRole in roles: # Check editorHint update
+ for row in range(topLeft.row(), bottomRight.row() + 1):
+ for column in range(topLeft.column(), bottomRight.column() + 1):
+ index = topLeft.sibling(row, column)
+ if index.isValid():
+ if self._isPersistentEditorOpen(index):
+ self.closePersistentEditor(index)
+ self._openEditorForIndex(index)
+
+ def _isPersistentEditorOpen(self, index):
+ """Returns True if a persistent editor is opened for index
+
+ :param QModelIndex index:
+ :rtype: bool
+ """
+ return index in self.__persistentEditors
+
+ def selectionCommand(self, index, event=None):
+ """Filter out selection of not selectable items"""
+ if index.flags() & qt.Qt.ItemIsSelectable:
+ return super(ParamTreeView, self).selectionCommand(index, event)
+ else:
+ return qt.QItemSelectionModel.NoUpdate
diff --git a/src/silx/gui/plot3d/Plot3DWidget.py b/src/silx/gui/plot3d/Plot3DWidget.py
new file mode 100644
index 0000000..a90d34c
--- /dev/null
+++ b/src/silx/gui/plot3d/Plot3DWidget.py
@@ -0,0 +1,463 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a Qt widget embedding an OpenGL scene."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+import enum
+import logging
+
+from silx.gui import qt
+from silx.gui.colors import rgba
+from . import actions
+
+from ...utils.enum import Enum as _Enum
+from ..utils.image import convertArrayToQImage
+
+from .. import _glutils as glu
+from .scene import interaction, primitives, transform
+from . import scene
+
+import numpy
+
+
+_logger = logging.getLogger(__name__)
+
+
+class _OverviewViewport(scene.Viewport):
+ """A scene displaying the orientation of the data in another scene.
+
+ :param Camera camera: The camera to track.
+ """
+
+ _SIZE = 100
+ """Size in pixels of the overview square"""
+
+ def __init__(self, camera=None):
+ super(_OverviewViewport, self).__init__()
+ self.size = self._SIZE, self._SIZE
+ self.background = None # Disable clear
+
+ self.scene.transforms = [transform.Scale(2.5, 2.5, 2.5)]
+
+ # Add a point to draw the background (in a group with depth mask)
+ backgroundPoint = primitives.ColorPoints(
+ x=0., y=0., z=0.,
+ color=(1., 1., 1., 0.5),
+ size=self._SIZE)
+ backgroundPoint.marker = 'o'
+ noDepthGroup = primitives.GroupNoDepth(mask=True, notest=True)
+ noDepthGroup.children.append(backgroundPoint)
+ self.scene.children.append(noDepthGroup)
+
+ axes = primitives.Axes()
+ self.scene.children.append(axes)
+
+ if camera is not None:
+ camera.addListener(self._cameraChanged)
+
+ def _cameraChanged(self, source):
+ """Listen to camera in other scene for transformation updates.
+
+ Sync the overview camera to point in the same direction
+ but from a sphere centered on origin.
+ """
+ position = -12. * source.extrinsic.direction
+ self.camera.extrinsic.position = position
+
+ self.camera.extrinsic.setOrientation(
+ source.extrinsic.direction, source.extrinsic.up)
+
+
+class Plot3DWidget(glu.OpenGLWidget):
+ """OpenGL widget with a 3D viewport and an overview."""
+
+ sigInteractiveModeChanged = qt.Signal()
+ """Signal emitted when the interactive mode has changed
+ """
+
+ sigStyleChanged = qt.Signal(str)
+ """Signal emitted when the style of the scene has changed
+
+ It provides the updated property.
+ """
+
+ sigSceneClicked = qt.Signal(float, float)
+ """Signal emitted when the scene is clicked with the left mouse button.
+
+ It provides the (x, y) clicked mouse position in logical widget pixel coordinates.
+ """
+
+ @enum.unique
+ class FogMode(_Enum):
+ """Different mode to render the scene with fog"""
+
+ NONE = 'none'
+ """No fog effect"""
+
+ LINEAR = 'linear'
+ """Linear fog through the whole scene"""
+
+ def __init__(self, parent=None, f=qt.Qt.WindowFlags()):
+ self._firstRender = True
+
+ super(Plot3DWidget, self).__init__(
+ parent,
+ alphaBufferSize=8,
+ depthBufferSize=0,
+ stencilBufferSize=0,
+ version=(2, 1),
+ f=f)
+
+ self.setAutoFillBackground(False)
+ self.setMouseTracking(True)
+
+ self.setFocusPolicy(qt.Qt.StrongFocus)
+ self._copyAction = actions.io.CopyAction(parent=self, plot3d=self)
+ self.addAction(self._copyAction)
+
+ self._updating = False # True if an update is requested
+
+ # Main viewport
+ self.viewport = scene.Viewport()
+
+ self._sceneScale = transform.Scale(1., 1., 1.)
+ self.viewport.scene.transforms = [self._sceneScale,
+ transform.Translate(0., 0., 0.)]
+
+ # Overview area
+ self.overview = _OverviewViewport(self.viewport.camera)
+
+ self.setBackgroundColor((0.2, 0.2, 0.2, 1.))
+
+ # Window describing on screen area to render
+ self._window = scene.Window(mode='framebuffer')
+ self._window.viewports = [self.viewport, self.overview]
+ self._window.addListener(self._redraw)
+
+ self.eventHandler = None
+ self.setInteractiveMode('rotate')
+
+ def __clickHandler(self, *args):
+ """Handle interaction state machine click"""
+ x, y = args[0][:2]
+ # Convert from device pixel to logical pixel unit
+ devicePixelRatio = self.getDevicePixelRatio()
+ self.sigSceneClicked.emit(x / devicePixelRatio, y / devicePixelRatio)
+
+ def setInteractiveMode(self, mode):
+ """Set the interactive mode.
+
+ :param str mode: The interactive mode: 'rotate', 'pan' or None
+ """
+ if mode == self.getInteractiveMode():
+ return
+
+ if mode is None:
+ self.eventHandler = None
+
+ elif mode == 'rotate':
+ self.eventHandler = interaction.RotateCameraControl(
+ self.viewport,
+ orbitAroundCenter=False,
+ mode='position',
+ scaleTransform=self._sceneScale,
+ selectCB=self.__clickHandler)
+
+ elif mode == 'pan':
+ self.eventHandler = interaction.PanCameraControl(
+ self.viewport,
+ orbitAroundCenter=False,
+ mode='position',
+ scaleTransform=self._sceneScale,
+ selectCB=self.__clickHandler)
+
+ elif isinstance(mode, interaction.StateMachine):
+ self.eventHandler = mode
+
+ else:
+ raise ValueError('Unsupported interactive mode %s', str(mode))
+
+ if (self.eventHandler is not None and
+ qt.QApplication.keyboardModifiers() & qt.Qt.ControlModifier):
+ self.eventHandler.handleEvent('keyPress', qt.Qt.Key_Control)
+
+ self.sigInteractiveModeChanged.emit()
+
+ def getInteractiveMode(self):
+ """Returns the interactive mode in use.
+
+ :rtype: str
+ """
+ if self.eventHandler is None:
+ return None
+ if isinstance(self.eventHandler, interaction.RotateCameraControl):
+ return 'rotate'
+ elif isinstance(self.eventHandler, interaction.PanCameraControl):
+ return 'pan'
+ else:
+ return None
+
+ def setProjection(self, projection):
+ """Change the projection in use.
+
+ :param str projection: In 'perspective', 'orthographic'.
+ """
+ if projection == 'orthographic':
+ projection = transform.Orthographic(size=self.viewport.size)
+ elif projection == 'perspective':
+ projection = transform.Perspective(fovy=30.,
+ size=self.viewport.size)
+ else:
+ raise RuntimeError('Unsupported projection: %s' % projection)
+
+ self.viewport.camera.intrinsic = projection
+ self.viewport.resetCamera()
+
+ def getProjection(self):
+ """Return the current camera projection mode as a str.
+
+ See :meth:`setProjection`
+ """
+ projection = self.viewport.camera.intrinsic
+ if isinstance(projection, transform.Orthographic):
+ return 'orthographic'
+ elif isinstance(projection, transform.Perspective):
+ return 'perspective'
+ else:
+ raise RuntimeError('Unknown projection in use')
+
+ def setBackgroundColor(self, color):
+ """Set the background color of the OpenGL view.
+
+ :param color: RGB color of the isosurface: name, #RRGGBB or RGB values
+ :type color:
+ QColor, str or array-like of 3 or 4 float in [0., 1.] or uint8
+ """
+ color = rgba(color)
+ if color != self.viewport.background:
+ self.viewport.background = color
+ self.sigStyleChanged.emit('backgroundColor')
+
+ def getBackgroundColor(self):
+ """Returns the RGBA background color (QColor)."""
+ return qt.QColor.fromRgbF(*self.viewport.background)
+
+ def setFogMode(self, mode):
+ """Set the kind of fog to use for the whole scene.
+
+ :param Union[str,FogMode] mode: The mode to use
+ :raise ValueError: If mode is not supported
+ """
+ mode = self.FogMode.from_value(mode)
+ if mode != self.getFogMode():
+ self.viewport.fog.isOn = mode is self.FogMode.LINEAR
+ self.sigStyleChanged.emit('fogMode')
+
+ def getFogMode(self):
+ """Returns the kind of fog in use
+
+ :return: The kind of fog in use
+ :rtype: FogMode
+ """
+ if self.viewport.fog.isOn:
+ return self.FogMode.LINEAR
+ else:
+ return self.FogMode.NONE
+
+ def isOrientationIndicatorVisible(self):
+ """Returns True if the orientation indicator is displayed.
+
+ :rtype: bool
+ """
+ return self.overview in self._window.viewports
+
+ def setOrientationIndicatorVisible(self, visible):
+ """Set the orientation indicator visibility.
+
+ :param bool visible: True to show
+ """
+ visible = bool(visible)
+ if visible != self.isOrientationIndicatorVisible():
+ if visible:
+ self._window.viewports = [self.viewport, self.overview]
+ else:
+ self._window.viewports = [self.viewport]
+ self.sigStyleChanged.emit('orientationIndicatorVisible')
+
+ def centerScene(self):
+ """Position the center of the scene at the center of rotation."""
+ self.viewport.resetCamera()
+
+ def resetZoom(self, face='front'):
+ """Reset the camera position to a default.
+
+ :param str face: The direction the camera is looking at:
+ side, front, back, top, bottom, right, left.
+ Default: front.
+ """
+ self.viewport.camera.extrinsic.reset(face=face)
+ self.centerScene()
+
+ def _redraw(self, source=None):
+ """Viewport listener to require repaint"""
+ if not self._updating:
+ self._updating = True # Mark that an update is requested
+ self.update() # Queued repaint (i.e., asynchronous)
+
+ def sizeHint(self):
+ return qt.QSize(400, 300)
+
+ def initializeGL(self):
+ pass
+
+ def paintGL(self):
+ # In case paintGL is called by the system and not through _redraw,
+ # Mark as updating.
+ self._updating = True
+
+ # Update near and far planes only if viewport needs refresh
+ if self.viewport.dirty:
+ self.viewport.adjustCameraDepthExtent()
+
+ self._window.render(self.context(), self.getDevicePixelRatio())
+
+ if self._firstRender: # TODO remove this ugly hack
+ self._firstRender = False
+ self.centerScene()
+ self._updating = False
+
+ def resizeGL(self, width, height):
+ width *= self.getDevicePixelRatio()
+ height *= self.getDevicePixelRatio()
+ self._window.size = width, height
+ self.viewport.size = self._window.size
+ overviewWidth, overviewHeight = self.overview.size
+ self.overview.origin = width - overviewWidth, height - overviewHeight
+
+ def grabGL(self):
+ """Renders the OpenGL scene into a numpy array
+
+ :returns: OpenGL scene RGB rasterization
+ :rtype: QImage
+ """
+ if not self.isValid():
+ _logger.error('OpenGL 2.1 not available, cannot save OpenGL image')
+ height, width = self._window.shape
+ image = numpy.zeros((height, width, 3), dtype=numpy.uint8)
+
+ else:
+ self.makeCurrent()
+ image = self._window.grab(self.context())
+
+ return convertArrayToQImage(image)
+
+ def wheelEvent(self, event):
+ if qt.BINDING == "PySide6":
+ x, y = event.position().x(), event.position().y()
+ else:
+ x, y = event.x(), event.y()
+ xpixel = x * self.getDevicePixelRatio()
+ ypixel = y * self.getDevicePixelRatio()
+ angle = event.angleDelta().y() / 8.
+ event.accept()
+
+ if self.eventHandler is not None and angle != 0 and self.isValid():
+ self.makeCurrent()
+ self.eventHandler.handleEvent('wheel', xpixel, ypixel, angle)
+
+ def keyPressEvent(self, event):
+ keyCode = event.key()
+ # No need to accept QKeyEvent
+
+ converter = {
+ qt.Qt.Key_Left: 'left',
+ qt.Qt.Key_Right: 'right',
+ qt.Qt.Key_Up: 'up',
+ qt.Qt.Key_Down: 'down'
+ }
+ direction = converter.get(keyCode, None)
+ if direction is not None:
+ if event.modifiers() == qt.Qt.ControlModifier:
+ self.viewport.camera.rotate(direction)
+ elif event.modifiers() == qt.Qt.ShiftModifier:
+ self.viewport.moveCamera(direction)
+ else:
+ self.viewport.orbitCamera(direction)
+
+ else:
+ if (keyCode == qt.Qt.Key_Control and
+ self.eventHandler is not None and
+ self.isValid()):
+ self.eventHandler.handleEvent('keyPress', keyCode)
+
+ # Key not handled, call base class implementation
+ super(Plot3DWidget, self).keyPressEvent(event)
+
+ def keyReleaseEvent(self, event):
+ """Catch Ctrl key release"""
+ keyCode = event.key()
+ if (keyCode == qt.Qt.Key_Control and
+ self.eventHandler is not None and
+ self.isValid()):
+ self.eventHandler.handleEvent('keyRelease', keyCode)
+ super(Plot3DWidget, self).keyReleaseEvent(event)
+
+ # Mouse events #
+ _MOUSE_BTNS = {1: 'left', 2: 'right', 4: 'middle'}
+
+ def mousePressEvent(self, event):
+ xpixel = event.x() * self.getDevicePixelRatio()
+ ypixel = event.y() * self.getDevicePixelRatio()
+ btn = self._MOUSE_BTNS[event.button()]
+ event.accept()
+
+ if self.eventHandler is not None and self.isValid():
+ self.makeCurrent()
+ self.eventHandler.handleEvent('press', xpixel, ypixel, btn)
+
+ def mouseMoveEvent(self, event):
+ xpixel = event.x() * self.getDevicePixelRatio()
+ ypixel = event.y() * self.getDevicePixelRatio()
+ event.accept()
+
+ if self.eventHandler is not None and self.isValid():
+ self.makeCurrent()
+ self.eventHandler.handleEvent('move', xpixel, ypixel)
+
+ def mouseReleaseEvent(self, event):
+ xpixel = event.x() * self.getDevicePixelRatio()
+ ypixel = event.y() * self.getDevicePixelRatio()
+ btn = self._MOUSE_BTNS[event.button()]
+ event.accept()
+
+ if self.eventHandler is not None and self.isValid():
+ self.makeCurrent()
+ self.eventHandler.handleEvent('release', xpixel, ypixel, btn)
diff --git a/silx/gui/plot3d/Plot3DWindow.py b/src/silx/gui/plot3d/Plot3DWindow.py
index 470b966..470b966 100644
--- a/silx/gui/plot3d/Plot3DWindow.py
+++ b/src/silx/gui/plot3d/Plot3DWindow.py
diff --git a/src/silx/gui/plot3d/SFViewParamTree.py b/src/silx/gui/plot3d/SFViewParamTree.py
new file mode 100644
index 0000000..b269a6a
--- /dev/null
+++ b/src/silx/gui/plot3d/SFViewParamTree.py
@@ -0,0 +1,1814 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides a tree widget to set/view parameters of a ScalarFieldView.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["D. N."]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+import logging
+import sys
+import weakref
+
+import numpy
+
+from silx.gui import qt
+from silx.gui.icons import getQIcon
+from silx.gui.colors import Colormap
+from silx.gui.widgets.FloatEdit import FloatEdit
+
+from .ScalarFieldView import Isosurface
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ModelColumns(object):
+ NameColumn, ValueColumn, ColumnMax = range(3)
+ ColumnNames = ['Name', 'Value']
+
+
+class SubjectItem(qt.QStandardItem):
+ """
+ Base class for observers items.
+
+ Subclassing:
+ ------------
+ The following method can/should be reimplemented:
+ - _init
+ - _pullData
+ - _pushData
+ - _setModelData
+ - _subjectChanged
+ - getEditor
+ - getSignals
+ - leftClicked
+ - queryRemove
+ - setEditorData
+
+ Also the following attributes are available:
+ - editable
+ - persistent
+
+ :param subject: object that this item will be observing.
+ """
+
+ editable = False
+ """ boolean: set to True to make the item editable. """
+
+ persistent = False
+ """
+ boolean: set to True to make the editor persistent.
+ See : Qt.QAbstractItemView.openPersistentEditor
+ """
+
+ def __init__(self, subject, *args):
+
+ super(SubjectItem, self).__init__(*args)
+
+ self.setEditable(self.editable)
+
+ self.__subject = None
+ self.subject = subject
+
+ def setData(self, value, role=qt.Qt.UserRole, pushData=True):
+ """
+ Overloaded method from QStandardItem. The pushData keyword tells
+ the item to push data to the subject if the role is equal to EditRole.
+ This is useful to let this method know if the setData method was called
+ internally or from the view.
+
+ :param value: the value ti set to data
+ :param role: role in the item
+ :param pushData: if True push value in the existing data.
+ """
+ if role == qt.Qt.EditRole and pushData:
+ setValue = self._pushData(value, role)
+ if setValue != value:
+ value = setValue
+ super(SubjectItem, self).setData(value, role)
+
+ @property
+ def subject(self):
+ """The subject this item is observing"""
+ return None if self.__subject is None else self.__subject()
+
+ @subject.setter
+ def subject(self, subject):
+ if self.__subject is not None:
+ raise ValueError('Subject already set '
+ ' (subject change not supported).')
+ if subject is None:
+ self.__subject = None
+ else:
+ self.__subject = weakref.ref(subject)
+ if subject is not None:
+ self._init()
+ self._connectSignals()
+
+ def _connectSignals(self):
+ """
+ Connects the signals. Called when the subject is set.
+ """
+
+ def gen_slot(_sigIdx):
+ def slotfn(*args, **kwargs):
+ self._subjectChanged(signalIdx=_sigIdx,
+ args=args,
+ kwargs=kwargs)
+ return slotfn
+
+ if self.__subject is not None:
+ self.__slots = slots = []
+
+ signals = self.getSignals()
+
+ if signals:
+ if not isinstance(signals, (list, tuple)):
+ signals = [signals]
+ for sigIdx, signal in enumerate(signals):
+ slot = gen_slot(sigIdx)
+ signal.connect(slot)
+ slots.append((signal, slot))
+
+ def _disconnectSignals(self):
+ """
+ Disconnects all subject's signal
+ """
+ if self.__slots:
+ for signal, slot in self.__slots:
+ try:
+ signal.disconnect(slot)
+ except TypeError:
+ pass
+
+ def _enableRow(self, enable):
+ """
+ Set the enabled state for this cell, or for the whole row
+ if this item has a parent.
+
+ :param bool enable: True if we wan't to enable the cell
+ """
+ parent = self.parent()
+ model = self.model()
+ if model is None or parent is None:
+ # no parent -> no siblings
+ self.setEnabled(enable)
+ return
+
+ for col in range(model.columnCount()):
+ sibling = parent.child(self.row(), col)
+ sibling.setEnabled(enable)
+
+ #################################################################
+ # Overloadable methods
+ #################################################################
+
+ def getSignals(self):
+ """
+ Returns the list of this items subject's signals that
+ this item will be listening to.
+
+ :return: list.
+ """
+ return None
+
+ def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
+ """
+ Called when one of the signals is triggered. Default implementation
+ just calls _pullData, compares the result to the current value stored
+ as Qt.EditRole, and stores the new value if it is different. It also
+ stores its str representation as Qt.DisplayRole
+
+ :param signalIdx: index of the triggered signal. The value passed
+ is the same as the signal position in the list returned by
+ SubjectItem.getSignals.
+ :param args: arguments received from the signal
+ :param kwargs: keyword arguments received from the signal
+ """
+ data = self._pullData()
+ if data == self.data(qt.Qt.EditRole):
+ return
+ self.setData(data, role=qt.Qt.DisplayRole, pushData=False)
+ self.setData(data, role=qt.Qt.EditRole, pushData=False)
+
+ def _pullData(self):
+ """
+ Pulls data from the subject.
+
+ :return: subject data
+ """
+ return None
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ """
+ Pushes data to the subject and returns the actual value that was stored
+
+ :return: the value that was stored
+ """
+ return value
+
+ def _init(self):
+ """
+ Called when the subject is set.
+ :return:
+ """
+ self._subjectChanged()
+
+ def getEditor(self, parent, option, index):
+ """
+ Returns the editor widget used to edit this item's data. The arguments
+ are the one passed to the QStyledItemDelegate.createEditor method.
+
+ :param parent: the Qt parent of the editor
+ :param option:
+ :param index:
+ :return:
+ """
+ return None
+
+ def setEditorData(self, editor):
+ """
+ This is called by the View's delegate just before the editor is shown,
+ its purpose it to setup the editors contents. Return False to use
+ the delegate's default behaviour.
+
+ :param editor:
+ :return:
+ """
+ return True
+
+ def _setModelData(self, editor):
+ """
+ This is called by the View's delegate just before the editor is closed,
+ its allows this item to update itself with data from the editor.
+
+ :param editor:
+ :return:
+ """
+ return False
+
+ def queryRemove(self, view=None):
+ """
+ This is called by the view to ask this items if it (the view) can
+ remove it. Return True to let the view know that the item can be
+ removed.
+
+ :param view:
+ :return:
+ """
+ return False
+
+ def leftClicked(self):
+ """
+ This method is called by the view when the item's cell if left clicked.
+
+ :return:
+ """
+ pass
+
+
+# View settings ###############################################################
+
+class ColorItem(SubjectItem):
+ """color item."""
+ editable = True
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ editor = QColorEditor(parent)
+ editor.color = self.getColor()
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.sigColorChanged.connect(
+ lambda color: self._editorSlot(color))
+ return editor
+
+ def _editorSlot(self, color):
+ self.setData(color, qt.Qt.EditRole)
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.setColor(value)
+ return self.getColor()
+
+ def _pullData(self):
+ self.getColor()
+
+ def setColor(self, color):
+ """Override to implement actual color setter"""
+ pass
+
+
+class BackgroundColorItem(ColorItem):
+ itemName = 'Background'
+
+ def setColor(self, color):
+ self.subject.setBackgroundColor(color)
+
+ def getColor(self):
+ return self.subject.getBackgroundColor()
+
+
+class ForegroundColorItem(ColorItem):
+ itemName = 'Foreground'
+
+ def setColor(self, color):
+ self.subject.setForegroundColor(color)
+
+ def getColor(self):
+ return self.subject.getForegroundColor()
+
+
+class HighlightColorItem(ColorItem):
+ itemName = 'Highlight'
+
+ def setColor(self, color):
+ self.subject.setHighlightColor(color)
+
+ def getColor(self):
+ return self.subject.getHighlightColor()
+
+
+class _LightDirectionAngleBaseItem(SubjectItem):
+ """Base class for directional light angle item."""
+ editable = True
+ persistent = True
+
+ def _init(self):
+ pass
+
+ def getSignals(self):
+ """Override to provide signals to listen"""
+ raise NotImplementedError("MUST be implemented in subclass")
+
+ def _pullData(self):
+ """Override in subclass to get current angle"""
+ raise NotImplementedError("MUST be implemented in subclass")
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ """Override in subclass to set the angle"""
+ raise NotImplementedError("MUST be implemented in subclass")
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QSlider(parent)
+ editor.setOrientation(qt.Qt.Horizontal)
+ editor.setMinimum(-90)
+ editor.setMaximum(90)
+ editor.setValue(int(self._pullData()))
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.valueChanged.connect(
+ lambda value: self._pushData(value))
+
+ return editor
+
+ def setEditorData(self, editor):
+ editor.setValue(int(self._pullData()))
+ return True
+
+ def _setModelData(self, editor):
+ value = editor.value()
+ self._pushData(value)
+ return True
+
+
+class LightAzimuthAngleItem(_LightDirectionAngleBaseItem):
+ """Light direction azimuth angle item."""
+
+ def getSignals(self):
+ return self.subject.sigAzimuthAngleChanged
+
+ def _pullData(self):
+ return self.subject.getAzimuthAngle()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setAzimuthAngle(value)
+
+
+class LightAltitudeAngleItem(_LightDirectionAngleBaseItem):
+ """Light direction altitude angle item."""
+
+ def getSignals(self):
+ return self.subject.sigAltitudeAngleChanged
+
+ def _pullData(self):
+ return self.subject.getAltitudeAngle()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setAltitudeAngle(value)
+
+
+class _DirectionalLightProxy(qt.QObject):
+ """Proxy to handle directional light with angles rather than vector.
+ """
+
+ sigAzimuthAngleChanged = qt.Signal()
+ """Signal sent when the azimuth angle has changed."""
+
+ sigAltitudeAngleChanged = qt.Signal()
+ """Signal sent when altitude angle has changed."""
+
+ def __init__(self, light):
+ super(_DirectionalLightProxy, self).__init__()
+ self._light = light
+ light.addListener(self._directionUpdated)
+ self._azimuth = 0.
+ self._altitude = 0.
+
+ def getAzimuthAngle(self):
+ """Returns the signed angle in the horizontal plane.
+
+ Unit: degrees.
+ The 0 angle corresponds to the axis perpendicular to the screen.
+
+ :rtype: float
+ """
+ return self._azimuth
+
+ def getAltitudeAngle(self):
+ """Returns the signed vertical angle from the horizontal plane.
+
+ Unit: degrees.
+ Range: [-90, +90]
+
+ :rtype: float
+ """
+ return self._altitude
+
+ def setAzimuthAngle(self, angle):
+ """Set the horizontal angle.
+
+ :param float angle: Angle from -z axis in zx plane in degrees.
+ """
+ if angle != self._azimuth:
+ self._azimuth = angle
+ self._updateLight()
+ self.sigAzimuthAngleChanged.emit()
+
+ def setAltitudeAngle(self, angle):
+ """Set the horizontal angle.
+
+ :param float angle: Angle from -z axis in zy plane in degrees.
+ """
+ if angle != self._altitude:
+ self._altitude = angle
+ self._updateLight()
+ self.sigAltitudeAngleChanged.emit()
+
+ def _directionUpdated(self, *args, **kwargs):
+ """Handle light direction update in the scene"""
+ # Invert direction to manipulate the 'source' pointing to
+ # the center of the viewport
+ x, y, z = - self._light.direction
+
+ # Horizontal plane is plane xz
+ azimuth = numpy.degrees(numpy.arctan2(x, z))
+ altitude = numpy.degrees(numpy.pi/2. - numpy.arccos(y))
+
+ if (abs(azimuth - self.getAzimuthAngle()) > 0.01 and
+ abs(abs(altitude) - 90.) >= 0.001): # Do not update when at zenith
+ self.setAzimuthAngle(azimuth)
+
+ if abs(altitude - self.getAltitudeAngle()) > 0.01:
+ self.setAltitudeAngle(altitude)
+
+ def _updateLight(self):
+ """Update light direction in the scene"""
+ azimuth = numpy.radians(self._azimuth)
+ delta = numpy.pi/2. - numpy.radians(self._altitude)
+ z = - numpy.sin(delta) * numpy.cos(azimuth)
+ x = - numpy.sin(delta) * numpy.sin(azimuth)
+ y = - numpy.cos(delta)
+ self._light.direction = x, y, z
+
+
+class DirectionalLightGroup(SubjectItem):
+ """
+ Root Item for the directional light
+ """
+
+ def __init__(self,subject, *args):
+ self._light = _DirectionalLightProxy(
+ subject.getPlot3DWidget().viewport.light)
+
+ super(DirectionalLightGroup, self).__init__(subject, *args)
+
+ def _init(self):
+
+ nameItem = qt.QStandardItem('Azimuth')
+ nameItem.setEditable(False)
+ valueItem = LightAzimuthAngleItem(self._light)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Altitude')
+ nameItem.setEditable(False)
+ valueItem = LightAltitudeAngleItem(self._light)
+ self.appendRow([nameItem, valueItem])
+
+
+class BoundingBoxItem(SubjectItem):
+ """Bounding box, axes labels and grid visibility item.
+
+ Item is checkable.
+ """
+ itemName = 'Bounding Box'
+
+ def _init(self):
+ visible = self.subject.isBoundingBoxVisible()
+ self.setCheckable(True)
+ self.setCheckState(qt.Qt.Checked if visible else qt.Qt.Unchecked)
+
+ def leftClicked(self):
+ checked = (self.checkState() == qt.Qt.Checked)
+ if checked != self.subject.isBoundingBoxVisible():
+ self.subject.setBoundingBoxVisible(checked)
+
+
+class OrientationIndicatorItem(SubjectItem):
+ """Orientation indicator visibility item.
+
+ Item is checkable.
+ """
+ itemName = 'Axes indicator'
+
+ def _init(self):
+ plot3d = self.subject.getPlot3DWidget()
+ visible = plot3d.isOrientationIndicatorVisible()
+ self.setCheckable(True)
+ self.setCheckState(qt.Qt.Checked if visible else qt.Qt.Unchecked)
+
+ def leftClicked(self):
+ plot3d = self.subject.getPlot3DWidget()
+ checked = (self.checkState() == qt.Qt.Checked)
+ if checked != plot3d.isOrientationIndicatorVisible():
+ plot3d.setOrientationIndicatorVisible(checked)
+
+
+class ViewSettingsItem(qt.QStandardItem):
+ """Viewport settings"""
+
+ def __init__(self, subject, *args):
+
+ super(ViewSettingsItem, self).__init__(*args)
+
+ self.setEditable(False)
+
+ classes = (BackgroundColorItem,
+ ForegroundColorItem,
+ HighlightColorItem,
+ BoundingBoxItem,
+ OrientationIndicatorItem)
+ for cls in classes:
+ titleItem = qt.QStandardItem(cls.itemName)
+ titleItem.setEditable(False)
+ self.appendRow([titleItem, cls(subject)])
+
+ nameItem = DirectionalLightGroup(subject, 'Light Direction')
+ valueItem = qt.QStandardItem()
+ self.appendRow([nameItem, valueItem])
+
+
+# Data information ############################################################
+
+class DataChangedItem(SubjectItem):
+ """
+ Base class for items listening to ScalarFieldView.sigDataChanged
+ """
+
+ def getSignals(self):
+ subject = self.subject
+ if subject:
+ return subject.sigDataChanged, subject.sigTransformChanged
+ return None
+
+ def _init(self):
+ self._subjectChanged()
+
+
+class DataTypeItem(DataChangedItem):
+ itemName = 'dtype'
+
+ def _pullData(self):
+ data = self.subject.getData(copy=False)
+ return ((data is not None) and str(data.dtype)) or 'N/A'
+
+
+class DataShapeItem(DataChangedItem):
+ itemName = 'size'
+
+ def _pullData(self):
+ data = self.subject.getData(copy=False)
+ if data is None:
+ return 'N/A'
+ else:
+ return str(list(reversed(data.shape)))
+
+
+class OffsetItem(DataChangedItem):
+ itemName = 'offset'
+
+ def _pullData(self):
+ offset = self.subject.getTranslation()
+ return ((offset is not None) and str(offset)) or 'N/A'
+
+
+class ScaleItem(DataChangedItem):
+ itemName = 'scale'
+
+ def _pullData(self):
+ scale = self.subject.getScale()
+ return ((scale is not None) and str(scale)) or 'N/A'
+
+
+class MatrixItem(DataChangedItem):
+
+ def __init__(self, subject, row, *args):
+ self.__row = row
+ super(MatrixItem, self).__init__(subject, *args)
+
+ def _pullData(self):
+ matrix = self.subject.getTransformMatrix()
+ return str(matrix[self.__row])
+
+
+class DataSetItem(qt.QStandardItem):
+
+ def __init__(self, subject, *args):
+
+ super(DataSetItem, self).__init__(*args)
+
+ self.setEditable(False)
+
+ klasses = [DataTypeItem, DataShapeItem, OffsetItem]
+ for klass in klasses:
+ titleItem = qt.QStandardItem(klass.itemName)
+ titleItem.setEditable(False)
+ self.appendRow([titleItem, klass(subject)])
+
+ matrixItem = qt.QStandardItem('matrix')
+ matrixItem.setEditable(False)
+ valueItem = qt.QStandardItem()
+ self.appendRow([matrixItem, valueItem])
+
+ for row in range(3):
+ titleItem = qt.QStandardItem()
+ titleItem.setEditable(False)
+ valueItem = MatrixItem(subject, row)
+ matrixItem.appendRow([titleItem, valueItem])
+
+ titleItem = qt.QStandardItem(ScaleItem.itemName)
+ titleItem.setEditable(False)
+ self.appendRow([titleItem, ScaleItem(subject)])
+
+
+# Isosurface ##################################################################
+
+class IsoSurfaceRootItem(SubjectItem):
+ """
+ Root (i.e : column index 0) Isosurface item.
+ """
+
+ def __init__(self, subject, normalization, *args):
+ self._isoLevelSliderNormalization = normalization
+ super(IsoSurfaceRootItem, self).__init__(subject, *args)
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigColorChanged,
+ subject.sigVisibilityChanged]
+
+ def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
+ if signalIdx == 0:
+ color = self.subject.getColor()
+ self.setData(color, qt.Qt.DecorationRole)
+ elif signalIdx == 1:
+ visible = args[0]
+ self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked)
+
+ def _init(self):
+ self.setCheckable(True)
+
+ isosurface = self.subject
+ color = isosurface.getColor()
+ visible = isosurface.isVisible()
+ self.setData(color, qt.Qt.DecorationRole)
+ self.setCheckState((visible and qt.Qt.Checked) or qt.Qt.Unchecked)
+
+ nameItem = qt.QStandardItem('Level')
+ sliderItem = IsoSurfaceLevelSlider(self.subject,
+ self._isoLevelSliderNormalization)
+ self.appendRow([nameItem, sliderItem])
+
+ nameItem = qt.QStandardItem('Color')
+ nameItem.setEditable(False)
+ valueItem = IsoSurfaceColorItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Opacity')
+ nameItem.setTextAlignment(qt.Qt.AlignLeft | qt.Qt.AlignTop)
+ nameItem.setEditable(False)
+ valueItem = IsoSurfaceAlphaItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem()
+ nameItem.setEditable(False)
+ valueItem = IsoSurfaceAlphaLegendItem(self.subject)
+ valueItem.setEditable(False)
+ self.appendRow([nameItem, valueItem])
+
+ def queryRemove(self, view=None):
+ buttons = qt.QMessageBox.Ok | qt.QMessageBox.Cancel
+ ans = qt.QMessageBox.question(view,
+ 'Remove isosurface',
+ 'Remove the selected iso-surface?',
+ buttons=buttons)
+ if ans == qt.QMessageBox.Ok:
+ sfview = self.subject.parent()
+ if sfview:
+ sfview.removeIsosurface(self.subject)
+ return False
+ return False
+
+ def leftClicked(self):
+ checked = (self.checkState() == qt.Qt.Checked)
+ visible = self.subject.isVisible()
+ if checked != visible:
+ self.subject.setVisible(checked)
+
+
+class IsoSurfaceLevelItem(SubjectItem):
+ """
+ Base class for the isosurface level items.
+ """
+ editable = True
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigLevelChanged,
+ subject.sigVisibilityChanged]
+
+ def getEditor(self, parent, option, index):
+ return FloatEdit(parent)
+
+ def setEditorData(self, editor):
+ editor.setValue(self._pullData())
+ return False
+
+ def _setModelData(self, editor):
+ self._pushData(editor.value())
+ return True
+
+ def _pullData(self):
+ return self.subject.getLevel()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setLevel(value)
+ return self.subject.getLevel()
+
+
+class _IsoLevelSlider(qt.QSlider):
+ """QSlider used for iso-surface level with linear scale"""
+
+ def __init__(self, parent, subject, normalization):
+ super(_IsoLevelSlider, self).__init__(parent=parent)
+ self.subject = subject
+
+ if normalization == 'arcsinh':
+ self.__norm = numpy.arcsinh
+ self.__invNorm = numpy.sinh
+ elif normalization == 'linear':
+ self.__norm = lambda x: x
+ self.__invNorm = lambda x: x
+ else:
+ raise ValueError(
+ "Unsupported normalization %s", normalization)
+
+ self.sliderReleased.connect(self.__sliderReleased)
+
+ self.subject.sigLevelChanged.connect(self.setLevel)
+ self.subject.parent().sigDataChanged.connect(self.__dataChanged)
+
+ def setLevel(self, level):
+ """Set slider from iso-surface level"""
+ dataRange = self.subject.parent().getDataRange()
+
+ if dataRange is not None:
+ min_ = self.__norm(dataRange[0])
+ max_ = self.__norm(dataRange[-1])
+
+ width = max_ - min_
+ if width > 0:
+ sliderWidth = self.maximum() - self.minimum()
+ sliderPosition = sliderWidth * (self.__norm(level) - min_) / width
+ self.setValue(int(sliderPosition))
+
+ def __dataChanged(self):
+ """Handles data update to refresh slider range if needed"""
+ self.setLevel(self.subject.getLevel())
+
+ def __sliderReleased(self):
+ value = self.value()
+ dataRange = self.subject.parent().getDataRange()
+ if dataRange is not None:
+ min_ = self.__norm(dataRange[0])
+ max_ = self.__norm(dataRange[-1])
+ width = max_ - min_
+ sliderWidth = self.maximum() - self.minimum()
+ level = min_ + width * value / sliderWidth
+ self.subject.setLevel(self.__invNorm(level))
+
+
+class IsoSurfaceLevelSlider(IsoSurfaceLevelItem):
+ """
+ Isosurface level item with a slider editor.
+ """
+ nTicks = 1000
+ persistent = True
+
+ def __init__(self, subject, normalization):
+ self.normalization = normalization
+ super(IsoSurfaceLevelSlider, self).__init__(subject)
+
+ def getEditor(self, parent, option, index):
+ editor = _IsoLevelSlider(parent, self.subject, self.normalization)
+ editor.setOrientation(qt.Qt.Horizontal)
+ editor.setMinimum(0)
+ editor.setMaximum(self.nTicks)
+
+ editor.setSingleStep(1)
+
+ editor.setLevel(self.subject.getLevel())
+ return editor
+
+ def setEditorData(self, editor):
+ return True
+
+ def _setModelData(self, editor):
+ return True
+
+
+class IsoSurfaceColorItem(SubjectItem):
+ """
+ Isosurface color item.
+ """
+ editable = True
+ persistent = True
+
+ def getSignals(self):
+ return self.subject.sigColorChanged
+
+ def getEditor(self, parent, option, index):
+ editor = QColorEditor(parent)
+ color = self.subject.getColor()
+ color.setAlpha(255)
+ editor.color = color
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.sigColorChanged.connect(
+ lambda color: self.__editorChanged(color))
+ return editor
+
+ def __editorChanged(self, color):
+ color.setAlpha(self.subject.getColor().alpha())
+ self.subject.setColor(color)
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self.subject.setColor(value)
+ return self.subject.getColor()
+
+
+class QColorEditor(qt.QWidget):
+ """
+ QColor editor.
+ """
+ sigColorChanged = qt.Signal(object)
+
+ color = property(lambda self: qt.QColor(self.__color))
+
+ @color.setter
+ def color(self, color):
+ self._setColor(color)
+ self.__previousColor = color
+
+ def __init__(self, *args, **kwargs):
+ super(QColorEditor, self).__init__(*args, **kwargs)
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ button = qt.QToolButton()
+ icon = qt.QIcon(qt.QPixmap(32, 32))
+ button.setIcon(icon)
+ layout.addWidget(button)
+ button.clicked.connect(self.__showColorDialog)
+ layout.addStretch(1)
+
+ self.__color = None
+ self.__previousColor = None
+
+ def sizeHint(self):
+ return qt.QSize(0, 0)
+
+ def _setColor(self, qColor):
+ button = self.findChild(qt.QToolButton)
+ pixmap = qt.QPixmap(32, 32)
+ pixmap.fill(qColor)
+ button.setIcon(qt.QIcon(pixmap))
+ self.__color = qColor
+
+ def __showColorDialog(self):
+ dialog = qt.QColorDialog(parent=self)
+ if sys.platform == 'darwin':
+ # Use of native color dialog on macos might cause problems
+ dialog.setOption(qt.QColorDialog.DontUseNativeDialog, True)
+
+ self.__previousColor = self.__color
+ dialog.setAttribute(qt.Qt.WA_DeleteOnClose)
+ dialog.setModal(True)
+ dialog.currentColorChanged.connect(self.__colorChanged)
+ dialog.finished.connect(self.__dialogClosed)
+ dialog.show()
+
+ def __colorChanged(self, color):
+ self.__color = color
+ self._setColor(color)
+ self.sigColorChanged.emit(color)
+
+ def __dialogClosed(self, result):
+ if result == qt.QDialog.Rejected:
+ self.__colorChanged(self.__previousColor)
+ self.__previousColor = None
+
+
+class IsoSurfaceAlphaItem(SubjectItem):
+ """
+ Isosurface alpha item.
+ """
+ editable = True
+ persistent = True
+
+ def _init(self):
+ pass
+
+ def getSignals(self):
+ return self.subject.sigColorChanged
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QSlider(parent)
+ editor.setOrientation(qt.Qt.Horizontal)
+ editor.setMinimum(0)
+ editor.setMaximum(255)
+
+ color = self.subject.getColor()
+ editor.setValue(color.alpha())
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.valueChanged.connect(
+ lambda value: self.__editorChanged(value))
+
+ return editor
+
+ def __editorChanged(self, value):
+ color = self.subject.getColor()
+ color.setAlpha(value)
+ self.subject.setColor(color)
+
+ def setEditorData(self, editor):
+ return True
+
+ def _setModelData(self, editor):
+ return True
+
+
+class IsoSurfaceAlphaLegendItem(SubjectItem):
+ """Legend to place under opacity slider"""
+
+ editable = False
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ layout = qt.QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ layout.addWidget(qt.QLabel('0'))
+ layout.addStretch(1)
+ layout.addWidget(qt.QLabel('1'))
+
+ editor = qt.QWidget(parent)
+ editor.setLayout(layout)
+ return editor
+
+
+class IsoSurfaceCount(SubjectItem):
+ """
+ Item displaying the number of isosurfaces.
+ """
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved]
+
+ def _pullData(self):
+ return len(self.subject.getIsosurfaces())
+
+
+class IsoSurfaceAddRemoveWidget(qt.QWidget):
+
+ sigViewTask = qt.Signal(str)
+ """Signal for the tree view to perform some task"""
+
+ def __init__(self, parent, item):
+ super(IsoSurfaceAddRemoveWidget, self).__init__(parent)
+ self._item = item
+ layout = qt.QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ addBtn = qt.QToolButton(self)
+ addBtn.setText('+')
+ addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(addBtn)
+ addBtn.clicked.connect(self.__addClicked)
+
+ removeBtn = qt.QToolButton(self)
+ removeBtn.setText('-')
+ removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(removeBtn)
+ removeBtn.clicked.connect(self.__removeClicked)
+
+ layout.addStretch(1)
+
+ def __addClicked(self):
+ sfview = self._item.subject
+ if not sfview:
+ return
+ dataRange = sfview.getDataRange()
+ if dataRange is None:
+ dataRange = [0, 1]
+
+ sfview.addIsosurface(
+ numpy.mean((dataRange[0], dataRange[-1])), '#0000FF')
+
+ def __removeClicked(self):
+ self.sigViewTask.emit('remove_iso')
+
+
+class IsoSurfaceAddRemoveItem(SubjectItem):
+ """
+ Item displaying a simple QToolButton allowing to add an isosurface.
+ """
+ persistent = True
+
+ def getEditor(self, parent, option, index):
+ return IsoSurfaceAddRemoveWidget(parent, self)
+
+
+class IsoSurfaceGroup(SubjectItem):
+ """
+ Root item for the list of isosurface items.
+ """
+
+ def __init__(self, subject, normalization, *args):
+ self._isoLevelSliderNormalization = normalization
+ super(IsoSurfaceGroup, self).__init__(subject, *args)
+
+ def getSignals(self):
+ subject = self.subject
+ return [subject.sigIsosurfaceAdded, subject.sigIsosurfaceRemoved]
+
+ def _subjectChanged(self, signalIdx=None, args=None, kwargs=None):
+ if signalIdx == 0:
+ if len(args) >= 1:
+ isosurface = args[0]
+ if not isinstance(isosurface, Isosurface):
+ raise ValueError('Expected an isosurface instance.')
+ self.__addIsosurface(isosurface)
+ else:
+ raise ValueError('Expected an isosurface instance.')
+ elif signalIdx == 1:
+ if len(args) >= 1:
+ isosurface = args[0]
+ if not isinstance(isosurface, Isosurface):
+ raise ValueError('Expected an isosurface instance.')
+ self.__removeIsosurface(isosurface)
+ else:
+ raise ValueError('Expected an isosurface instance.')
+
+ def __addIsosurface(self, isosurface):
+ valueItem = IsoSurfaceRootItem(
+ subject=isosurface,
+ normalization=self._isoLevelSliderNormalization)
+ nameItem = IsoSurfaceLevelItem(subject=isosurface)
+ self.insertRow(max(0, self.rowCount() - 1), [valueItem, nameItem])
+
+ def __removeIsosurface(self, isosurface):
+ for row in range(self.rowCount()):
+ child = self.child(row)
+ subject = getattr(child, 'subject', None)
+ if subject == isosurface:
+ self.takeRow(row)
+ break
+
+ def _init(self):
+ nameItem = IsoSurfaceAddRemoveItem(self.subject)
+ valueItem = qt.QStandardItem()
+ valueItem.setEditable(False)
+ self.appendRow([nameItem, valueItem])
+
+ subject = self.subject
+ isosurfaces = subject.getIsosurfaces()
+ for isosurface in isosurfaces:
+ self.__addIsosurface(isosurface)
+
+
+# Cutting Plane ###############################################################
+
+class ColormapBase(SubjectItem):
+ """
+ Mixin class for colormap items.
+ """
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigColormapChanged]
+
+
+class PlaneMinRangeItem(ColormapBase):
+ """
+ colormap minVal item.
+ Editor is a QLineEdit with a QDoubleValidator
+ """
+ editable = True
+
+ def _pullData(self):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ auto = colormap.isAutoscale()
+ if auto == self.isEnabled():
+ self._enableRow(not auto)
+ return colormap.getVMin()
+
+ def _pushData(self, value, role=qt.Qt.UserRole):
+ self._setVMin(value)
+
+ def _setVMin(self, value):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ vMin = value
+ vMax = colormap.getVMax()
+
+ if vMax is not None and value > vMax:
+ vMin = vMax
+ vMax = value
+ colormap.setVRange(vMin, vMax)
+
+ def getEditor(self, parent, option, index):
+ return FloatEdit(parent)
+
+ def setEditorData(self, editor):
+ editor.setValue(self._pullData())
+ return True
+
+ def _setModelData(self, editor):
+ value = editor.value()
+ self._setVMin(value)
+ return True
+
+
+class PlaneMaxRangeItem(ColormapBase):
+ """
+ colormap maxVal item.
+ Editor is a QLineEdit with a QDoubleValidator
+ """
+ editable = True
+
+ def _pullData(self):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ auto = colormap.isAutoscale()
+ if auto == self.isEnabled():
+ self._enableRow(not auto)
+ return self.subject.getCutPlanes()[0].getColormap().getVMax()
+
+ def _setVMax(self, value):
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ vMin = colormap.getVMin()
+ vMax = value
+ if vMin is not None and value < vMin:
+ vMax = vMin
+ vMin = value
+ colormap.setVRange(vMin, vMax)
+
+ def getEditor(self, parent, option, index):
+ return FloatEdit(parent)
+
+ def setEditorData(self, editor):
+ editor.setText(str(self._pullData()))
+ return True
+
+ def _setModelData(self, editor):
+ value = editor.value()
+ self._setVMax(value)
+ return True
+
+
+class PlaneOrientationItem(SubjectItem):
+ """
+ Plane orientation item.
+ Editor is a QComboBox.
+ """
+ editable = True
+
+ _PLANE_ACTIONS = (
+ ('3d-plane-normal-x', 'Plane 0',
+ 'Set plane perpendicular to red axis', (1., 0., 0.)),
+ ('3d-plane-normal-y', 'Plane 1',
+ 'Set plane perpendicular to green axis', (0., 1., 0.)),
+ ('3d-plane-normal-z', 'Plane 2',
+ 'Set plane perpendicular to blue axis', (0., 0., 1.)),
+ )
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigPlaneChanged]
+
+ def _pullData(self):
+ currentNormal = self.subject.getCutPlanes()[0].getNormal(
+ coordinates='scene')
+ for _, text, _, normal in self._PLANE_ACTIONS:
+ if numpy.allclose(normal, currentNormal):
+ return text
+ return ''
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QComboBox(parent)
+ for iconName, text, tooltip, normal in self._PLANE_ACTIONS:
+ editor.addItem(getQIcon(iconName), text)
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.currentIndexChanged[int].connect(
+ lambda index: self.__editorChanged(index))
+ return editor
+
+ def __editorChanged(self, index):
+ normal = self._PLANE_ACTIONS[index][3]
+ plane = self.subject.getCutPlanes()[0]
+ plane.setNormal(normal, coordinates='scene')
+ plane.moveToCenter()
+
+ def setEditorData(self, editor):
+ currentText = self._pullData()
+ index = 0
+ for normIdx, (_, text, _, _) in enumerate(self._PLANE_ACTIONS):
+ if text == currentText:
+ index = normIdx
+ break
+ editor.setCurrentIndex(index)
+ return True
+
+ def _setModelData(self, editor):
+ return True
+
+
+class PlaneInterpolationItem(SubjectItem):
+ """Toggle cut plane interpolation method: nearest or linear.
+
+ Item is checkable
+ """
+
+ def _init(self):
+ interpolation = self.subject.getCutPlanes()[0].getInterpolation()
+ self.setCheckable(True)
+ self.setCheckState(
+ qt.Qt.Checked if interpolation == 'linear' else qt.Qt.Unchecked)
+ self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigInterpolationChanged]
+
+ def leftClicked(self):
+ checked = self.checkState() == qt.Qt.Checked
+ self._setInterpolation('linear' if checked else 'nearest')
+
+ def _pullData(self):
+ interpolation = self.subject.getCutPlanes()[0].getInterpolation()
+ self._setInterpolation(interpolation)
+ return interpolation[0].upper() + interpolation[1:]
+
+ def _setInterpolation(self, interpolation):
+ self.subject.getCutPlanes()[0].setInterpolation(interpolation)
+
+
+class PlaneDisplayBelowMinItem(SubjectItem):
+ """Toggle whether to display or not values <= colormap min of the cut plane
+
+ Item is checkable
+ """
+
+ def _init(self):
+ display = self.subject.getCutPlanes()[0].getDisplayValuesBelowMin()
+ self.setCheckable(True)
+ self.setCheckState(
+ qt.Qt.Checked if display else qt.Qt.Unchecked)
+ self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
+
+ def getSignals(self):
+ return [self.subject.getCutPlanes()[0].sigTransparencyChanged]
+
+ def leftClicked(self):
+ checked = self.checkState() == qt.Qt.Checked
+ self._setDisplayValuesBelowMin(checked)
+
+ def _pullData(self):
+ display = self.subject.getCutPlanes()[0].getDisplayValuesBelowMin()
+ self._setDisplayValuesBelowMin(display)
+ return "Displayed" if display else "Hidden"
+
+ def _setDisplayValuesBelowMin(self, display):
+ self.subject.getCutPlanes()[0].setDisplayValuesBelowMin(display)
+
+
+class PlaneColormapItem(ColormapBase):
+ """
+ colormap name item.
+ Editor is a QComboBox
+ """
+ editable = True
+
+ listValues = ['gray', 'reversed gray',
+ 'temperature', 'red',
+ 'green', 'blue',
+ 'viridis', 'magma', 'inferno', 'plasma']
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QComboBox(parent)
+ editor.addItems(self.listValues)
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.currentIndexChanged[int].connect(
+ lambda index: self.__editorChanged(index))
+
+ return editor
+
+ def __editorChanged(self, index):
+ colormapName = self.listValues[index]
+ colormap = self.subject.getCutPlanes()[0].getColormap()
+ colormap.setName(colormapName)
+
+ def setEditorData(self, editor):
+ colormapName = self.subject.getCutPlanes()[0].getColormap().getName()
+ try:
+ index = self.listValues.index(colormapName)
+ except ValueError:
+ _logger.error('Unsupported colormap: %s', colormapName)
+ else:
+ editor.setCurrentIndex(index)
+ return True
+
+ def _setModelData(self, editor):
+ self.__editorChanged(editor.currentIndex())
+ return True
+
+ def _pullData(self):
+ return self.subject.getCutPlanes()[0].getColormap().getName()
+
+
+class PlaneAutoScaleItem(ColormapBase):
+ """
+ colormap autoscale item.
+ Item is checkable.
+ """
+
+ def _init(self):
+ colorMap = self.subject.getCutPlanes()[0].getColormap()
+ self.setCheckable(True)
+ self.setCheckState((colorMap.isAutoscale() and qt.Qt.Checked)
+ or qt.Qt.Unchecked)
+ self.setData(self._pullData(), role=qt.Qt.DisplayRole, pushData=False)
+
+ def leftClicked(self):
+ checked = (self.checkState() == qt.Qt.Checked)
+ self._setAutoScale(checked)
+
+ def _setAutoScale(self, auto):
+ view3d = self.subject
+ colormap = view3d.getCutPlanes()[0].getColormap()
+
+ if auto != colormap.isAutoscale():
+ if auto:
+ vMin = vMax = None
+ else:
+ dataRange = view3d.getDataRange()
+ if dataRange is None:
+ vMin = vMax = None
+ else:
+ vMin, vMax = dataRange[0], dataRange[-1]
+ colormap.setVRange(vMin, vMax)
+
+ def _pullData(self):
+ auto = self.subject.getCutPlanes()[0].getColormap().isAutoscale()
+ self._setAutoScale(auto)
+ if auto:
+ data = 'Auto'
+ else:
+ data = 'User'
+ return data
+
+
+class NormalizationNode(ColormapBase):
+ """
+ colormap normalization item.
+ Item is a QComboBox.
+ """
+ editable = True
+ listValues = list(Colormap.NORMALIZATIONS)
+
+ def getEditor(self, parent, option, index):
+ editor = qt.QComboBox(parent)
+ editor.addItems(self.listValues)
+
+ # Wrapping call in lambda is a workaround for PySide with Python 3
+ editor.currentIndexChanged[int].connect(
+ lambda index: self.__editorChanged(index))
+
+ return editor
+
+ def __editorChanged(self, index):
+ colorMap = self.subject.getCutPlanes()[0].getColormap()
+ normalization = self.listValues[index]
+ self.subject.getCutPlanes()[0].setColormap(name=colorMap.getName(),
+ norm=normalization,
+ vmin=colorMap.getVMin(),
+ vmax=colorMap.getVMax())
+
+ def setEditorData(self, editor):
+ normalization = self.subject.getCutPlanes()[0].getColormap().getNormalization()
+ index = self.listValues.index(normalization)
+ editor.setCurrentIndex(index)
+ return True
+
+ def _setModelData(self, editor):
+ self.__editorChanged(editor.currentIndex())
+ return True
+
+ def _pullData(self):
+ return self.subject.getCutPlanes()[0].getColormap().getNormalization()
+
+
+class PlaneGroup(SubjectItem):
+ """
+ Root Item for the plane items.
+ """
+ def _init(self):
+ valueItem = qt.QStandardItem()
+ valueItem.setEditable(False)
+ nameItem = PlaneVisibleItem(self.subject, 'Visible')
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Colormap')
+ nameItem.setEditable(False)
+ valueItem = PlaneColormapItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Normalization')
+ nameItem.setEditable(False)
+ valueItem = NormalizationNode(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Orientation')
+ nameItem.setEditable(False)
+ valueItem = PlaneOrientationItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Interpolation')
+ nameItem.setEditable(False)
+ valueItem = PlaneInterpolationItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Autoscale')
+ nameItem.setEditable(False)
+ valueItem = PlaneAutoScaleItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Min')
+ nameItem.setEditable(False)
+ valueItem = PlaneMinRangeItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Max')
+ nameItem.setEditable(False)
+ valueItem = PlaneMaxRangeItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+ nameItem = qt.QStandardItem('Values<=Min')
+ nameItem.setEditable(False)
+ valueItem = PlaneDisplayBelowMinItem(self.subject)
+ self.appendRow([nameItem, valueItem])
+
+
+class PlaneVisibleItem(SubjectItem):
+ """
+ Plane visibility item.
+ Item is checkable.
+ """
+ def _init(self):
+ plane = self.subject.getCutPlanes()[0]
+ self.setCheckable(True)
+ self.setCheckState((plane.isVisible() and qt.Qt.Checked)
+ or qt.Qt.Unchecked)
+
+ def leftClicked(self):
+ plane = self.subject.getCutPlanes()[0]
+ checked = (self.checkState() == qt.Qt.Checked)
+ if checked != plane.isVisible():
+ plane.setVisible(checked)
+ if plane.isVisible():
+ plane.moveToCenter()
+
+
+# Tree ########################################################################
+
+class ItemDelegate(qt.QStyledItemDelegate):
+ """
+ Delegate for the QTreeView filled with SubjectItems.
+ """
+
+ sigDelegateEvent = qt.Signal(str)
+
+ def __init__(self, parent=None):
+ super(ItemDelegate, self).__init__(parent)
+
+ def createEditor(self, parent, option, index):
+ item = index.model().itemFromIndex(index)
+ if item:
+ if isinstance(item, SubjectItem):
+ editor = item.getEditor(parent, option, index)
+ if editor:
+ editor.setAutoFillBackground(True)
+ if hasattr(editor, 'sigViewTask'):
+ editor.sigViewTask.connect(self.__viewTask)
+ return editor
+
+ editor = super(ItemDelegate, self).createEditor(parent,
+ option,
+ index)
+ return editor
+
+ def updateEditorGeometry(self, editor, option, index):
+ editor.setGeometry(option.rect)
+
+ def setEditorData(self, editor, index):
+ item = index.model().itemFromIndex(index)
+ if item:
+ if isinstance(item, SubjectItem) and item.setEditorData(editor):
+ return
+ super(ItemDelegate, self).setEditorData(editor, index)
+
+ def setModelData(self, editor, model, index):
+ item = index.model().itemFromIndex(index)
+ if isinstance(item, SubjectItem) and item._setModelData(editor):
+ return
+ super(ItemDelegate, self).setModelData(editor, model, index)
+
+ def __viewTask(self, task):
+ self.sigDelegateEvent.emit(task)
+
+
+class TreeView(qt.QTreeView):
+ """
+ TreeView displaying the SubjectItems for the ScalarFieldView.
+ """
+
+ def __init__(self, parent=None):
+ super(TreeView, self).__init__(parent)
+ self.__openedIndex = None
+ self._isoLevelSliderNormalization = 'linear'
+
+ self.setIconSize(qt.QSize(16, 16))
+
+ header = self.header()
+ header.setSectionResizeMode(qt.QHeaderView.ResizeToContents)
+
+ delegate = ItemDelegate()
+ self.setItemDelegate(delegate)
+ delegate.sigDelegateEvent.connect(self.__delegateEvent)
+ self.setSelectionBehavior(qt.QAbstractItemView.SelectRows)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection)
+
+ self.clicked.connect(self.__clicked)
+
+ def setSfView(self, sfView):
+ """
+ Sets the ScalarFieldView this view is controlling.
+
+ :param sfView: A `ScalarFieldView`
+ """
+ model = qt.QStandardItemModel()
+ model.setColumnCount(ModelColumns.ColumnMax)
+ model.setHorizontalHeaderLabels(['Name', 'Value'])
+
+ item = qt.QStandardItem()
+ item.setEditable(False)
+ model.appendRow([ViewSettingsItem(sfView, 'Style'), item])
+
+ item = qt.QStandardItem()
+ item.setEditable(False)
+ model.appendRow([DataSetItem(sfView, 'Data'), item])
+
+ item = IsoSurfaceCount(sfView)
+ item.setEditable(False)
+ model.appendRow([IsoSurfaceGroup(sfView,
+ self._isoLevelSliderNormalization,
+ 'Isosurfaces'),
+ item])
+
+ item = qt.QStandardItem()
+ item.setEditable(False)
+ model.appendRow([PlaneGroup(sfView, 'Cutting Plane'), item])
+
+ self.setModel(model)
+
+ def setModel(self, model):
+ """
+ Reimplementation of the QTreeView.setModel method. It connects the
+ rowsRemoved signal and opens the persistent editors.
+
+ :param qt.QStandardItemModel model: the model
+ """
+
+ prevModel = self.model()
+ if prevModel:
+ self.__openPersistentEditors(qt.QModelIndex(), False)
+ try:
+ prevModel.rowsRemoved.disconnect(self.rowsRemoved)
+ except TypeError:
+ pass
+
+ super(TreeView, self).setModel(model)
+ model.rowsRemoved.connect(self.rowsRemoved)
+ self.__openPersistentEditors(qt.QModelIndex())
+
+ def __openPersistentEditors(self, parent=None, openEditor=True):
+ """
+ Opens or closes the items persistent editors.
+
+ :param qt.QModelIndex parent: starting index, or None if the whole tree
+ is to be considered.
+ :param bool openEditor: True to open the editors, False to close them.
+ """
+ model = self.model()
+
+ if not model:
+ return
+
+ if not parent or not parent.isValid():
+ parent = self.model().invisibleRootItem().index()
+
+ if openEditor:
+ meth = self.openPersistentEditor
+ else:
+ meth = self.closePersistentEditor
+
+ curParent = parent
+ children = [model.index(row, 0, curParent)
+ for row in range(model.rowCount(curParent))]
+
+ columnCount = model.columnCount()
+
+ while len(children) > 0:
+ curParent = children.pop(-1)
+
+ children.extend([model.index(row, 0, curParent)
+ for row in range(model.rowCount(curParent))])
+
+ for colIdx in range(columnCount):
+ sibling = model.sibling(curParent.row(),
+ colIdx,
+ curParent)
+ item = model.itemFromIndex(sibling)
+ if isinstance(item, SubjectItem) and item.persistent:
+ meth(sibling)
+
+ def rowsAboutToBeRemoved(self, parent, start, end):
+ """
+ Reimplementation of the QTreeView.rowsAboutToBeRemoved. Closes all
+ persistent editors under parent.
+
+ :param qt.QModelIndex parent: Parent index
+ :param int start: Start index from parent index (inclusive)
+ :param int end: End index from parent index (inclusive)
+ """
+ self.__openPersistentEditors(parent, False)
+ super(TreeView, self).rowsAboutToBeRemoved(parent, start, end)
+
+ def rowsRemoved(self, parent, start, end):
+ """
+ Called when QTreeView.rowsRemoved is emitted. Opens all persistent
+ editors under parent.
+
+ :param qt.QModelIndex parent: Parent index
+ :param int start: Start index from parent index (inclusive)
+ :param int end: End index from parent index (inclusive)
+ """
+ super(TreeView, self).rowsRemoved(parent, start, end)
+ self.__openPersistentEditors(parent, True)
+
+ def rowsInserted(self, parent, start, end):
+ """
+ Reimplementation of the QTreeView.rowsInserted. Opens all persistent
+ editors under parent.
+
+ :param qt.QModelIndex parent: Parent index
+ :param int start: Start index from parent index
+ :param int end: End index from parent index
+ """
+ self.__openPersistentEditors(parent, False)
+ super(TreeView, self).rowsInserted(parent, start, end)
+ self.__openPersistentEditors(parent)
+
+ def keyReleaseEvent(self, event):
+ """
+ Reimplementation of the QTreeView.keyReleaseEvent.
+ At the moment only Key_Delete is handled. It calls the selected item's
+ queryRemove method, and deleted the item if needed.
+
+ :param qt.QKeyEvent event: A key event
+ """
+
+ # TODO : better filtering
+ key = event.key()
+ modifiers = event.modifiers()
+
+ if key == qt.Qt.Key_Delete and modifiers == qt.Qt.NoModifier:
+ self.__removeIsosurfaces()
+
+ super(TreeView, self).keyReleaseEvent(event)
+
+ def __removeIsosurfaces(self):
+ model = self.model()
+ selected = self.selectedIndexes()
+ items = []
+ # WARNING : the selection mode is set to single, so we re not
+ # supposed to have more than one item here.
+ # Multiple selection deletion has not been tested.
+ # Watch out for index invalidation
+ for index in selected:
+ leftIndex = model.sibling(index.row(), 0, index)
+ leftItem = model.itemFromIndex(leftIndex)
+ if isinstance(leftItem, SubjectItem) and leftItem not in items:
+ items.append(leftItem)
+
+ isos = [item for item in items if isinstance(item, IsoSurfaceRootItem)]
+ if isos:
+ for iso in isos:
+ if iso.queryRemove(self):
+ parentItem = iso.parent()
+ parentItem.removeRow(iso.row())
+ else:
+ qt.QMessageBox.information(
+ self,
+ 'Remove isosurface',
+ 'Select an iso-surface to remove it')
+
+ def __clicked(self, index):
+ """
+ Called when the QTreeView.clicked signal is emitted. Calls the item's
+ leftClick method.
+
+ :param qt.QIndex index: An index
+ """
+ item = self.model().itemFromIndex(index)
+ if isinstance(item, SubjectItem):
+ item.leftClicked()
+
+ def __delegateEvent(self, task):
+ if task == 'remove_iso':
+ self.__removeIsosurfaces()
+
+ def setIsoLevelSliderNormalization(self, normalization):
+ """Set the normalization for iso level slider
+
+ This MUST be called *before* :meth:`setSfView` to have an effect.
+
+ :param str normalization: Either 'linear' or 'arcsinh'
+ """
+ assert normalization in ('linear', 'arcsinh')
+ self._isoLevelSliderNormalization = normalization
diff --git a/silx/gui/plot3d/ScalarFieldView.py b/src/silx/gui/plot3d/ScalarFieldView.py
index b2bb254..b2bb254 100644
--- a/silx/gui/plot3d/ScalarFieldView.py
+++ b/src/silx/gui/plot3d/ScalarFieldView.py
diff --git a/silx/gui/plot3d/SceneWidget.py b/src/silx/gui/plot3d/SceneWidget.py
index 883f5e7..883f5e7 100644
--- a/silx/gui/plot3d/SceneWidget.py
+++ b/src/silx/gui/plot3d/SceneWidget.py
diff --git a/silx/gui/plot3d/SceneWindow.py b/src/silx/gui/plot3d/SceneWindow.py
index 052a4dc..052a4dc 100644
--- a/silx/gui/plot3d/SceneWindow.py
+++ b/src/silx/gui/plot3d/SceneWindow.py
diff --git a/silx/gui/plot3d/__init__.py b/src/silx/gui/plot3d/__init__.py
index af74613..af74613 100644
--- a/silx/gui/plot3d/__init__.py
+++ b/src/silx/gui/plot3d/__init__.py
diff --git a/silx/gui/plot3d/_model/__init__.py b/src/silx/gui/plot3d/_model/__init__.py
index 4b16e32..4b16e32 100644
--- a/silx/gui/plot3d/_model/__init__.py
+++ b/src/silx/gui/plot3d/_model/__init__.py
diff --git a/silx/gui/plot3d/_model/core.py b/src/silx/gui/plot3d/_model/core.py
index e8e0820..e8e0820 100644
--- a/silx/gui/plot3d/_model/core.py
+++ b/src/silx/gui/plot3d/_model/core.py
diff --git a/src/silx/gui/plot3d/_model/items.py b/src/silx/gui/plot3d/_model/items.py
new file mode 100644
index 0000000..492f44b
--- /dev/null
+++ b/src/silx/gui/plot3d/_model/items.py
@@ -0,0 +1,1759 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""
+This module provides base classes to implement models for 3D scene content
+"""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+
+from collections import OrderedDict
+import functools
+import logging
+import weakref
+
+import numpy
+
+from ...utils.image import convertArrayToQImage
+from ...colors import preferredColormaps
+from ... import qt, icons
+from .. import items
+from ..items.volume import Isosurface, CutPlane, ComplexIsosurface
+from ..Plot3DWidget import Plot3DWidget
+
+
+from .core import AngleDegreeRow, BaseRow, ColorProxyRow, ProxyRow, StaticRow
+
+
+_logger = logging.getLogger(__name__)
+
+
+class ItemProxyRow(ProxyRow):
+ """Provides a node to proxy a data accessible through functions.
+
+ It listens on sigItemChanged to trigger the update.
+
+ Warning: Only weak reference are kept on fget and fset.
+
+ :param Item3D item: The item to
+ :param str name: The name of this node
+ :param callable fget: A callable returning the data
+ :param callable fset:
+ An optional callable setting the data with data as a single argument.
+ :param events:
+ An optional event kind or list of event kinds to react upon.
+ :param callable toModelData:
+ An optional callable to convert from fget
+ callable to data returned by the model.
+ :param callable fromModelData:
+ An optional callable converting data provided to the model to
+ data for fset.
+ :param editorHint: Data to provide as UserRole for editor selection/setup
+ """
+
+ def __init__(self,
+ item,
+ name='',
+ fget=None,
+ fset=None,
+ events=None,
+ toModelData=None,
+ fromModelData=None,
+ editorHint=None):
+ super(ItemProxyRow, self).__init__(
+ name=name,
+ fget=fget,
+ fset=fset,
+ notify=None,
+ toModelData=toModelData,
+ fromModelData=fromModelData,
+ editorHint=editorHint)
+
+ if isinstance(events, (items.ItemChangedType,
+ items.Item3DChangedType)):
+ events = (events,)
+ self.__events = events
+ item.sigItemChanged.connect(self._itemChanged)
+
+ def _itemChanged(self, event):
+ """Handle item changed
+
+ :param Union[ItemChangedType,Item3DChangedType] event:
+ """
+ if self.__events is None or event in self.__events:
+ self._notified()
+
+
+class ItemColorProxyRow(ColorProxyRow, ItemProxyRow):
+ """Combines :class:`ColorProxyRow` and :class:`ItemProxyRow`"""
+
+ def __init__(self, *args, **kwargs):
+ ItemProxyRow.__init__(self, *args, **kwargs)
+
+
+class ItemAngleDegreeRow(AngleDegreeRow, ItemProxyRow):
+ """Combines :class:`AngleDegreeRow` and :class:`ItemProxyRow`"""
+
+ def __init__(self, *args, **kwargs):
+ ItemProxyRow.__init__(self, *args, **kwargs)
+
+
+class _DirectionalLightProxy(qt.QObject):
+ """Proxy to handle directional light with angles rather than vector.
+ """
+
+ sigAzimuthAngleChanged = qt.Signal()
+ """Signal sent when the azimuth angle has changed."""
+
+ sigAltitudeAngleChanged = qt.Signal()
+ """Signal sent when altitude angle has changed."""
+
+ def __init__(self, light):
+ super(_DirectionalLightProxy, self).__init__()
+ self._light = light
+ light.addListener(self._directionUpdated)
+ self._azimuth = 0
+ self._altitude = 0
+
+ def getAzimuthAngle(self):
+ """Returns the signed angle in the horizontal plane.
+
+ Unit: degrees.
+ The 0 angle corresponds to the axis perpendicular to the screen.
+
+ :rtype: int
+ """
+ return self._azimuth
+
+ def getAltitudeAngle(self):
+ """Returns the signed vertical angle from the horizontal plane.
+
+ Unit: degrees.
+ Range: [-90, +90]
+
+ :rtype: int
+ """
+ return self._altitude
+
+ def setAzimuthAngle(self, angle):
+ """Set the horizontal angle.
+
+ :param int angle: Angle from -z axis in zx plane in degrees.
+ """
+ angle = int(round(angle))
+ if angle != self._azimuth:
+ self._azimuth = angle
+ self._updateLight()
+ self.sigAzimuthAngleChanged.emit()
+
+ def setAltitudeAngle(self, angle):
+ """Set the horizontal angle.
+
+ :param int angle: Angle from -z axis in zy plane in degrees.
+ """
+ angle = int(round(angle))
+ if angle != self._altitude:
+ self._altitude = angle
+ self._updateLight()
+ self.sigAltitudeAngleChanged.emit()
+
+ def _directionUpdated(self, *args, **kwargs):
+ """Handle light direction update in the scene"""
+ # Invert direction to manipulate the 'source' pointing to
+ # the center of the viewport
+ x, y, z = - self._light.direction
+
+ # Horizontal plane is plane xz
+ azimuth = int(round(numpy.degrees(numpy.arctan2(x, z))))
+ altitude = int(round(numpy.degrees(numpy.pi/2. - numpy.arccos(y))))
+
+ if azimuth != self.getAzimuthAngle():
+ self.setAzimuthAngle(azimuth)
+
+ if altitude != self.getAltitudeAngle():
+ self.setAltitudeAngle(altitude)
+
+ def _updateLight(self):
+ """Update light direction in the scene"""
+ azimuth = numpy.radians(self._azimuth)
+ delta = numpy.pi/2. - numpy.radians(self._altitude)
+ if delta == 0.: # Avoids zenith position
+ delta = 0.0001
+ z = - numpy.sin(delta) * numpy.cos(azimuth)
+ x = - numpy.sin(delta) * numpy.sin(azimuth)
+ y = - numpy.cos(delta)
+ self._light.direction = x, y, z
+
+
+class Settings(StaticRow):
+ """Subtree for :class:`SceneWidget` style parameters.
+
+ :param SceneWidget sceneWidget: The widget to control
+ """
+
+ def __init__(self, sceneWidget):
+ background = ColorProxyRow(
+ name='Background',
+ fget=sceneWidget.getBackgroundColor,
+ fset=sceneWidget.setBackgroundColor,
+ notify=sceneWidget.sigStyleChanged)
+
+ foreground = ColorProxyRow(
+ name='Foreground',
+ fget=sceneWidget.getForegroundColor,
+ fset=sceneWidget.setForegroundColor,
+ notify=sceneWidget.sigStyleChanged)
+
+ text = ColorProxyRow(
+ name='Text',
+ fget=sceneWidget.getTextColor,
+ fset=sceneWidget.setTextColor,
+ notify=sceneWidget.sigStyleChanged)
+
+ highlight = ColorProxyRow(
+ name='Highlight',
+ fget=sceneWidget.getHighlightColor,
+ fset=sceneWidget.setHighlightColor,
+ notify=sceneWidget.sigStyleChanged)
+
+ axesIndicator = ProxyRow(
+ name='Axes Indicator',
+ fget=sceneWidget.isOrientationIndicatorVisible,
+ fset=sceneWidget.setOrientationIndicatorVisible,
+ notify=sceneWidget.sigStyleChanged)
+
+ # Light direction
+
+ self._lightProxy = _DirectionalLightProxy(sceneWidget.viewport.light)
+
+ azimuthNode = ProxyRow(
+ name='Azimuth',
+ fget=self._lightProxy.getAzimuthAngle,
+ fset=self._lightProxy.setAzimuthAngle,
+ notify=self._lightProxy.sigAzimuthAngleChanged,
+ editorHint=(-90, 90))
+
+ altitudeNode = ProxyRow(
+ name='Altitude',
+ fget=self._lightProxy.getAltitudeAngle,
+ fset=self._lightProxy.setAltitudeAngle,
+ notify=self._lightProxy.sigAltitudeAngleChanged,
+ editorHint=(-90, 90))
+
+ lightDirection = StaticRow(('Light Direction', None),
+ children=(azimuthNode, altitudeNode))
+
+ # Fog
+ fog = ProxyRow(
+ name='Fog',
+ fget=sceneWidget.getFogMode,
+ fset=sceneWidget.setFogMode,
+ notify=sceneWidget.sigStyleChanged,
+ toModelData=lambda mode: mode is Plot3DWidget.FogMode.LINEAR,
+ fromModelData=lambda mode: Plot3DWidget.FogMode.LINEAR if mode else Plot3DWidget.FogMode.NONE)
+
+ # Settings row
+ children = (background, foreground, text, highlight,
+ axesIndicator, lightDirection, fog)
+ super(Settings, self).__init__(('Settings', None), children=children)
+
+
+class Item3DRow(BaseRow):
+ """Represents an :class:`Item3D` with checkable visibility
+
+ :param Item3D item: The scene item to represent.
+ :param str name: The optional name of the item
+ """
+
+ _EVENTS = items.ItemChangedType.VISIBLE, items.Item3DChangedType.LABEL
+ """Events for which to update the first column in the tree"""
+
+ def __init__(self, item, name=None):
+ self.__name = None if name is None else str(name)
+ super(Item3DRow, self).__init__()
+
+ self.setFlags(
+ self.flags(0) | qt.Qt.ItemIsUserCheckable | qt.Qt.ItemIsSelectable,
+ 0)
+ self.setFlags(self.flags(1) | qt.Qt.ItemIsSelectable, 1)
+
+ self._item = weakref.ref(item)
+ item.sigItemChanged.connect(self._itemChanged)
+
+ def _itemChanged(self, event):
+ """Handle model update upon change"""
+ if event in self._EVENTS:
+ model = self.model()
+ if model is not None:
+ index = self.index(column=0)
+ model.dataChanged.emit(index, index)
+
+ def item(self):
+ """Returns the :class:`Item3D` item or None"""
+ return self._item()
+
+ def data(self, column, role):
+ if column == 0:
+ if role == qt.Qt.CheckStateRole:
+ item = self.item()
+ if item is not None and item.isVisible():
+ return qt.Qt.Checked
+ else:
+ return qt.Qt.Unchecked
+
+ elif role == qt.Qt.DecorationRole:
+ return icons.getQIcon('item-3dim')
+
+ elif role == qt.Qt.DisplayRole:
+ if self.__name is None:
+ item = self.item()
+ return '' if item is None else item.getLabel()
+ else:
+ return self.__name
+
+ return super(Item3DRow, self).data(column, role)
+
+ def setData(self, column, value, role):
+ if column == 0 and role == qt.Qt.CheckStateRole:
+ item = self.item()
+ if item is not None:
+ item.setVisible(value == qt.Qt.Checked)
+ return True
+ else:
+ return False
+ return super(Item3DRow, self).setData(column, value, role)
+
+ def columnCount(self):
+ return 2
+
+
+class DataItem3DBoundingBoxRow(ItemProxyRow):
+ """Represents :class:`DataItem3D` bounding box visibility
+
+ :param DataItem3D item: The item for which to display/control bounding box
+ """
+
+ def __init__(self, item):
+ super(DataItem3DBoundingBoxRow, self).__init__(
+ item=item,
+ name='Bounding box',
+ fget=item.isBoundingBoxVisible,
+ fset=item.setBoundingBoxVisible,
+ events=items.Item3DChangedType.BOUNDING_BOX_VISIBLE)
+
+
+class MatrixProxyRow(ItemProxyRow):
+ """Proxy for a row of a DataItem3D 3x3 matrix transform
+
+ :param DataItem3D item:
+ :param int index: Matrix row index
+ """
+
+ def __init__(self, item, index):
+ self._item = weakref.ref(item)
+ self._index = index
+
+ super(MatrixProxyRow, self).__init__(
+ item=item,
+ name='',
+ fget=self._getMatrixRow,
+ fset=self._setMatrixRow,
+ events=items.Item3DChangedType.TRANSFORM)
+
+ def _getMatrixRow(self):
+ """Returns the matrix row.
+
+ :rtype: QVector3D
+ """
+ item = self._item()
+ if item is not None:
+ matrix = item.getMatrix()
+ return qt.QVector3D(*matrix[self._index, :])
+ else:
+ return None
+
+ def _setMatrixRow(self, row):
+ """Set the row of the matrix
+
+ :param QVector3D row: Row values to set
+ """
+ item = self._item()
+ if item is not None:
+ matrix = item.getMatrix()
+ matrix[self._index, :] = row.x(), row.y(), row.z()
+ item.setMatrix(matrix)
+
+ def data(self, column, role):
+ data = super(MatrixProxyRow, self).data(column, role)
+
+ if column == 1 and role == qt.Qt.DisplayRole:
+ # Convert QVector3D to text
+ data = "%g; %g; %g" % (data.x(), data.y(), data.z())
+
+ return data
+
+
+class DataItem3DTransformRow(StaticRow):
+ """Represents :class:`DataItem3D` transform parameters
+
+ :param DataItem3D item: The item for which to display/control transform
+ """
+
+ _ROTATION_CENTER_OPTIONS = 'Origin', 'Lower', 'Center', 'Upper'
+
+ def __init__(self, item):
+ super(DataItem3DTransformRow, self).__init__(('Transform', None))
+ self._item = weakref.ref(item)
+
+ translation = ItemProxyRow(
+ item=item,
+ name='Translation',
+ fget=item.getTranslation,
+ fset=self._setTranslation,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=lambda data: qt.QVector3D(*data))
+ self.addRow(translation)
+
+ # Here to keep a reference
+ self._xSetCenter = functools.partial(self._setCenter, index=0)
+ self._ySetCenter = functools.partial(self._setCenter, index=1)
+ self._zSetCenter = functools.partial(self._setCenter, index=2)
+
+ rotateCenter = StaticRow(
+ ('Center', None),
+ children=(
+ ItemProxyRow(item=item,
+ name='X axis',
+ fget=item.getRotationCenter,
+ fset=self._xSetCenter,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=functools.partial(
+ self._centerToModelData, index=0),
+ editorHint=self._ROTATION_CENTER_OPTIONS),
+ ItemProxyRow(item=item,
+ name='Y axis',
+ fget=item.getRotationCenter,
+ fset=self._ySetCenter,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=functools.partial(
+ self._centerToModelData, index=1),
+ editorHint=self._ROTATION_CENTER_OPTIONS),
+ ItemProxyRow(item=item,
+ name='Z axis',
+ fget=item.getRotationCenter,
+ fset=self._zSetCenter,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=functools.partial(
+ self._centerToModelData, index=2),
+ editorHint=self._ROTATION_CENTER_OPTIONS),
+ ))
+
+ rotate = StaticRow(
+ ('Rotation', None),
+ children=(
+ ItemAngleDegreeRow(
+ item=item,
+ name='Angle',
+ fget=item.getRotation,
+ fset=self._setAngle,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=lambda data: data[0]),
+ ItemProxyRow(
+ item=item,
+ name='Axis',
+ fget=item.getRotation,
+ fset=self._setAxis,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=lambda data: qt.QVector3D(*data[1])),
+ rotateCenter
+ ))
+ self.addRow(rotate)
+
+ scale = ItemProxyRow(
+ item=item,
+ name='Scale',
+ fget=item.getScale,
+ fset=self._setScale,
+ events=items.Item3DChangedType.TRANSFORM,
+ toModelData=lambda data: qt.QVector3D(*data))
+ self.addRow(scale)
+
+ matrix = StaticRow(
+ ('Matrix', None),
+ children=(MatrixProxyRow(item, 0),
+ MatrixProxyRow(item, 1),
+ MatrixProxyRow(item, 2)))
+ self.addRow(matrix)
+
+ def item(self):
+ """Returns the :class:`Item3D` item or None"""
+ return self._item()
+
+ @staticmethod
+ def _centerToModelData(center, index):
+ """Convert rotation center information from scene to model.
+
+ :param center: The center info from the scene
+ :param int index: dimension to convert
+ """
+ value = center[index]
+ if isinstance(value, str):
+ return value.title()
+ elif value == 0.:
+ return 'Origin'
+ else:
+ return str(value)
+
+ def _setCenter(self, value, index):
+ """Set one dimension of the rotation center.
+
+ :param value: Value received through the model.
+ :param int index: dimension to set
+ """
+ item = self.item()
+ if item is not None:
+ if value == 'Origin':
+ value = 0.
+ elif value not in self._ROTATION_CENTER_OPTIONS:
+ value = float(value)
+ else:
+ value = value.lower()
+
+ center = list(item.getRotationCenter())
+ center[index] = value
+ item.setRotationCenter(*center)
+
+ def _setAngle(self, angle):
+ """Set rotation angle.
+
+ :param float angle:
+ """
+ item = self.item()
+ if item is not None:
+ _, axis = item.getRotation()
+ item.setRotation(angle, axis)
+
+ def _setAxis(self, axis):
+ """Set rotation axis.
+
+ :param QVector3D axis:
+ """
+ item = self.item()
+ if item is not None:
+ angle, _ = item.getRotation()
+ item.setRotation(angle, (axis.x(), axis.y(), axis.z()))
+
+ def _setTranslation(self, translation):
+ """Set translation transform.
+
+ :param QVector3D translation:
+ """
+ item = self.item()
+ if item is not None:
+ item.setTranslation(translation.x(), translation.y(), translation.z())
+
+ def _setScale(self, scale):
+ """Set scale transform.
+
+ :param QVector3D scale:
+ """
+ item = self.item()
+ if item is not None:
+ sx, sy, sz = scale.x(), scale.y(), scale.z()
+ if sx == 0. or sy == 0. or sz == 0.:
+ _logger.warning('Cannot set scale to 0: ignored')
+ else:
+ item.setScale(scale.x(), scale.y(), scale.z())
+
+
+class GroupItemRow(Item3DRow):
+ """Represents a :class:`GroupItem` with transforms and children
+
+ :param GroupItem item: The scene group to represent.
+ :param str name: The optional name of the group
+ """
+
+ _CHILDREN_ROW_OFFSET = 2
+ """Number of rows for group parameters. Children are added after"""
+
+ def __init__(self, item, name=None):
+ super(GroupItemRow, self).__init__(item, name)
+ self.addRow(DataItem3DBoundingBoxRow(item))
+ self.addRow(DataItem3DTransformRow(item))
+
+ item.sigItemAdded.connect(self._itemAdded)
+ item.sigItemRemoved.connect(self._itemRemoved)
+
+ for child in item.getItems():
+ self.addRow(nodeFromItem(child))
+
+ def _itemAdded(self, item):
+ """Handle item addition to the group and add it to the model.
+
+ :param Item3D item: added item
+ """
+ group = self.item()
+ if group is None:
+ return
+
+ row = group.getItems().index(item)
+ self.addRow(nodeFromItem(item), row + self._CHILDREN_ROW_OFFSET)
+
+ def _itemRemoved(self, item):
+ """Handle item removal from the group and remove it from the model.
+
+ :param Item3D item: removed item
+ """
+ group = self.item()
+ if group is None:
+ return
+
+ # Find item
+ for row in self.children():
+ if isinstance(row, Item3DRow) and row.item() is item:
+ self.removeRow(row)
+ break # Got it
+ else:
+ raise RuntimeError("Model does not correspond to scene content")
+
+
+class InterpolationRow(ItemProxyRow):
+ """Represents :class:`InterpolationMixIn` property.
+
+ :param Item3D item: Scene item with interpolation property
+ """
+
+ def __init__(self, item):
+ modes = [mode.title() for mode in item.INTERPOLATION_MODES]
+ super(InterpolationRow, self).__init__(
+ item=item,
+ name='Interpolation',
+ fget=item.getInterpolation,
+ fset=item.setInterpolation,
+ events=items.Item3DChangedType.INTERPOLATION,
+ toModelData=lambda mode: mode.title(),
+ fromModelData=lambda mode: mode.lower(),
+ editorHint=modes)
+
+
+class _ColormapBaseProxyRow(ProxyRow):
+ """Base class for colormap model row
+
+ This class handle synchronization and signals from the item and the colormap
+ """
+
+ _sigColormapChanged = qt.Signal()
+ """Signal used internally to notify colormap (or data) update"""
+
+ def __init__(self, item, *args, **kwargs):
+ self._item = weakref.ref(item)
+ self._colormap = item.getColormap()
+
+ ProxyRow.__init__(self, *args, **kwargs)
+
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ item.sigItemChanged.connect(self._itemChanged)
+ self._sigColormapChanged.connect(self._modelUpdated)
+
+ def item(self):
+ """Returns the :class:`ColormapMixIn` item or None"""
+ return self._item()
+
+ def _getColormapRange(self):
+ """Returns the range of the colormap for the current data.
+
+ :return: Colormap range (min, max)
+ """
+ item = self.item()
+ if item is not None and self._colormap is not None:
+ return self._colormap.getColormapRange(item)
+ else:
+ return 1, 100 # Fallback
+
+ def _modelUpdated(self, *args, **kwargs):
+ """Emit dataChanged in the model"""
+ topLeft = self.index(column=0)
+ bottomRight = self.index(column=1)
+ model = self.model()
+ if model is not None:
+ model.dataChanged.emit(topLeft, bottomRight)
+
+ def _colormapChanged(self):
+ self._sigColormapChanged.emit()
+
+ def _itemChanged(self, event):
+ """Handle change of colormap or data in the item.
+
+ :param ItemChangedType event:
+ """
+ if event == items.ItemChangedType.COLORMAP:
+ self._sigColormapChanged.emit()
+ if self._colormap is not None:
+ self._colormap.sigChanged.disconnect(self._colormapChanged)
+
+ item = self.item()
+ if item is not None:
+ self._colormap = item.getColormap()
+ self._colormap.sigChanged.connect(self._colormapChanged)
+ else:
+ self._colormap = None
+
+ elif event == items.ItemChangedType.DATA:
+ self._sigColormapChanged.emit()
+
+
+class _ColormapBoundRow(_ColormapBaseProxyRow):
+ """ProxyRow for colormap min or max
+
+ :param ColormapMixIn item: The item to handle
+ :param str name: Name of the raw
+ :param int index: 0 for Min and 1 of Max
+ """
+
+ def __init__(self, item, name, index):
+ self._index = index
+ _ColormapBaseProxyRow.__init__(
+ self,
+ item,
+ name=name,
+ fget=self._getBound,
+ fset=self._setBound)
+
+ self.setToolTip('Colormap %s bound:\n'
+ 'Check to set bound manually, '
+ 'uncheck for autoscale' % name.lower())
+
+ def _getRawBound(self):
+ """Proxy to get raw colormap bound
+
+ :rtype: float or None
+ """
+ if self._colormap is None:
+ return None
+ elif self._index == 0:
+ return self._colormap.getVMin()
+ else: # self._index == 1
+ return self._colormap.getVMax()
+
+ def _getBound(self):
+ """Proxy to get colormap effective bound value
+
+ :rtype: float
+ """
+ if self._colormap is not None:
+ bound = self._getRawBound()
+
+ if bound is None:
+ bound = self._getColormapRange()[self._index]
+ return bound
+ else:
+ return 1. # Fallback
+
+ def _setBound(self, value):
+ """Proxy to set colormap bound.
+
+ :param float value:
+ """
+ if self._colormap is not None:
+ if self._index == 0:
+ min_ = value
+ max_ = self._colormap.getVMax()
+ else: # self._index == 1
+ min_ = self._colormap.getVMin()
+ max_ = value
+
+ if max_ is not None and min_ is not None and min_ > max_:
+ min_, max_ = max_, min_
+ self._colormap.setVRange(min_, max_)
+
+ def flags(self, column):
+ if column == 0:
+ return qt.Qt.ItemIsEnabled | qt.Qt.ItemIsUserCheckable
+
+ elif column == 1:
+ if self._getRawBound() is not None:
+ flags = qt.Qt.ItemIsEditable | qt.Qt.ItemIsEnabled
+ else:
+ flags = qt.Qt.NoItemFlags # Disabled if autoscale
+ return flags
+
+ else: # Never event
+ return super(_ColormapBoundRow, self).flags(column)
+
+ def data(self, column, role):
+ if column == 0 and role == qt.Qt.CheckStateRole:
+ if self._getRawBound() is None:
+ return qt.Qt.Unchecked
+ else:
+ return qt.Qt.Checked
+
+ else:
+ return super(_ColormapBoundRow, self).data(column, role)
+
+ def setData(self, column, value, role):
+ if column == 0 and role == qt.Qt.CheckStateRole:
+ if self._colormap is not None:
+ bound = self._getBound() if value == qt.Qt.Checked else None
+ self._setBound(bound)
+ return True
+ else:
+ return False
+
+ return super(_ColormapBoundRow, self).setData(column, value, role)
+
+
+class _ColormapGammaRow(_ColormapBaseProxyRow):
+ """ProxyRow for colormap gamma normalization parameter
+
+ :param ColormapMixIn item: The item to handle
+ :param str name: Name of the raw
+ """
+
+ def __init__(self, item):
+ _ColormapBaseProxyRow.__init__(
+ self,
+ item,
+ name="Gamma",
+ fget=self._getGammaNormalizationParameter,
+ fset=self._setGammaNormalizationParameter)
+
+ self.setToolTip('Colormap gamma correction parameter:\n'
+ 'Only meaningful for gamma normalization.')
+
+ def _getGammaNormalizationParameter(self):
+ """Proxy for :meth:`Colormap.getGammaNormalizationParameter`"""
+ if self._colormap is not None:
+ return self._colormap.getGammaNormalizationParameter()
+ else:
+ return 0.0
+
+ def _setGammaNormalizationParameter(self, gamma):
+ """Proxy for :meth:`Colormap.setGammaNormalizationParameter`"""
+ if self._colormap is not None:
+ return self._colormap.setGammaNormalizationParameter(gamma)
+
+ def _getNormalization(self):
+ """Proxy for :meth:`Colormap.getNormalization`"""
+ if self._colormap is not None:
+ return self._colormap.getNormalization()
+ else:
+ return ''
+
+ def flags(self, column):
+ if column in (0, 1):
+ if self._getNormalization() == 'gamma':
+ flags = qt.Qt.ItemIsEditable | qt.Qt.ItemIsEnabled
+ else:
+ flags = qt.Qt.NoItemFlags # Disabled if not gamma correction
+ return flags
+
+ else: # Never event
+ return super(_ColormapGammaRow, self).flags(column)
+
+
+class ColormapRow(_ColormapBaseProxyRow):
+ """Represents :class:`ColormapMixIn` property.
+
+ :param Item3D item: Scene item with colormap property
+ """
+
+ def __init__(self, item):
+ super(ColormapRow, self).__init__(
+ item,
+ name='Colormap',
+ fget=self._get)
+
+ self._colormapImage = None
+
+ self._colormapsMapping = {}
+ for cmap in preferredColormaps():
+ self._colormapsMapping[cmap.title()] = cmap
+
+ self.addRow(ProxyRow(
+ name='Name',
+ fget=self._getName,
+ fset=self._setName,
+ notify=self._sigColormapChanged,
+ editorHint=list(self._colormapsMapping.keys())))
+
+ norms = [norm.title() for norm in self._colormap.NORMALIZATIONS]
+ self.addRow(ProxyRow(
+ name='Normalization',
+ fget=self._getNormalization,
+ fset=self._setNormalization,
+ notify=self._sigColormapChanged,
+ editorHint=norms))
+
+ self.addRow(_ColormapGammaRow(item))
+
+ modes = [mode.title() for mode in self._colormap.AUTOSCALE_MODES]
+ self.addRow(ProxyRow(
+ name='Autoscale Mode',
+ fget=self._getAutoscaleMode,
+ fset=self._setAutoscaleMode,
+ notify=self._sigColormapChanged,
+ editorHint=modes))
+
+ self.addRow(_ColormapBoundRow(item, name='Min.', index=0))
+ self.addRow(_ColormapBoundRow(item, name='Max.', index=1))
+
+ self._sigColormapChanged.connect(self._updateColormapImage)
+
+ def getColormapImage(self):
+ """Returns image representing the colormap or None
+
+ :rtype: Union[QImage,None]
+ """
+ if self._colormapImage is None and self._colormap is not None:
+ image = numpy.zeros((16, 130, 3), dtype=numpy.uint8)
+ image[1:-1, 1:-1] = self._colormap.getNColors(image.shape[1] - 2)[:, :3]
+ self._colormapImage = convertArrayToQImage(image)
+ return self._colormapImage
+
+ def _get(self):
+ """Getter for ProxyRow subclass"""
+ return None
+
+ def _getName(self):
+ """Proxy for :meth:`Colormap.getName`"""
+ if self._colormap is not None and self._colormap.getName() is not None:
+ return self._colormap.getName().title()
+ else:
+ return ''
+
+ def _setName(self, name):
+ """Proxy for :meth:`Colormap.setName`"""
+ # Convert back from titled to name if possible
+ if self._colormap is not None:
+ name = self._colormapsMapping.get(name, name)
+ self._colormap.setName(name)
+
+ def _getNormalization(self):
+ """Proxy for :meth:`Colormap.getNormalization`"""
+ if self._colormap is not None:
+ return self._colormap.getNormalization().title()
+ else:
+ return ''
+
+ def _setNormalization(self, normalization):
+ """Proxy for :meth:`Colormap.setNormalization`"""
+ if self._colormap is not None:
+ return self._colormap.setNormalization(normalization.lower())
+
+ def _getAutoscaleMode(self):
+ """Proxy for :meth:`Colormap.getAutoscaleMode`"""
+ if self._colormap is not None:
+ return self._colormap.getAutoscaleMode().title()
+ else:
+ return ''
+
+ def _setAutoscaleMode(self, mode):
+ """Proxy for :meth:`Colormap.setAutoscaleMode`"""
+ if self._colormap is not None:
+ return self._colormap.setAutoscaleMode(mode.lower())
+
+ def _updateColormapImage(self, *args, **kwargs):
+ """Notify colormap update to update the image in the tree"""
+ if self._colormapImage is not None:
+ self._colormapImage = None
+ model = self.model()
+ if model is not None:
+ index = self.index(column=1)
+ model.dataChanged.emit(index, index)
+
+ def data(self, column, role):
+ if column == 1 and role == qt.Qt.DecorationRole:
+ return self.getColormapImage()
+ else:
+ return super(ColormapRow, self).data(column, role)
+
+
+class SymbolRow(ItemProxyRow):
+ """Represents :class:`SymbolMixIn` symbol property.
+
+ :param Item3D item: Scene item with symbol property
+ """
+
+ def __init__(self, item):
+ names = [item.getSymbolName(s) for s in item.getSupportedSymbols()]
+ super(SymbolRow, self).__init__(
+ item=item,
+ name='Marker',
+ fget=item.getSymbolName,
+ fset=item.setSymbol,
+ events=items.ItemChangedType.SYMBOL,
+ editorHint=names)
+
+
+class SymbolSizeRow(ItemProxyRow):
+ """Represents :class:`SymbolMixIn` symbol size property.
+
+ :param Item3D item: Scene item with symbol size property
+ """
+
+ def __init__(self, item):
+ super(SymbolSizeRow, self).__init__(
+ item=item,
+ name='Marker size',
+ fget=item.getSymbolSize,
+ fset=item.setSymbolSize,
+ events=items.ItemChangedType.SYMBOL_SIZE,
+ editorHint=(1, 20)) # TODO link with OpenGL max point size
+
+
+class PlaneEquationRow(ItemProxyRow):
+ """Represents :class:`PlaneMixIn` as plane equation.
+
+ :param Item3D item: Scene item with plane equation property
+ """
+
+ def __init__(self, item):
+ super(PlaneEquationRow, self).__init__(
+ item=item,
+ name='Equation',
+ fget=item.getParameters,
+ fset=item.setParameters,
+ events=items.ItemChangedType.POSITION,
+ toModelData=lambda data: qt.QVector4D(*data),
+ fromModelData=lambda data: (data.x(), data.y(), data.z(), data.w()))
+ self._item = weakref.ref(item)
+
+ def data(self, column, role):
+ if column == 1 and role == qt.Qt.DisplayRole:
+ item = self._item()
+ if item is not None:
+ params = item.getParameters()
+ return ('%gx %+gy %+gz %+g = 0' %
+ (params[0], params[1], params[2], params[3]))
+ return super(PlaneEquationRow, self).data(column, role)
+
+
+class PlaneRow(ItemProxyRow):
+ """Represents :class:`PlaneMixIn` property.
+
+ :param Item3D item: Scene item with plane equation property
+ """
+
+ _PLANES = OrderedDict((('Plane 0', (1., 0., 0.)),
+ ('Plane 1', (0., 1., 0.)),
+ ('Plane 2', (0., 0., 1.)),
+ ('-', None)))
+ """Mapping of plane names to normals"""
+
+ _PLANE_ICONS = {'Plane 0': '3d-plane-normal-x',
+ 'Plane 1': '3d-plane-normal-y',
+ 'Plane 2': '3d-plane-normal-z',
+ '-': '3d-plane'}
+ """Mapping of plane names to normals"""
+
+ def __init__(self, item):
+ super(PlaneRow, self).__init__(
+ item=item,
+ name='Plane',
+ fget=self.__getPlaneName,
+ fset=self.__setPlaneName,
+ events=items.ItemChangedType.POSITION,
+ editorHint=tuple(self._PLANES.keys()))
+ self._item = weakref.ref(item)
+ self._lastName = None
+
+ self.addRow(PlaneEquationRow(item))
+
+ def _notified(self, *args, **kwargs):
+ """Handle notification of modification
+
+ Here only send if plane name actually changed
+ """
+ if self._lastName != self.__getPlaneName():
+ super(PlaneRow, self)._notified()
+
+ def __getPlaneName(self):
+ """Returns name of plane // to axes or '-'
+
+ :rtype: str
+ """
+ item = self._item()
+ planeNormal = item.getNormal() if item is not None else None
+
+ for name, normal in self._PLANES.items():
+ if numpy.array_equal(planeNormal, normal):
+ return name
+ return '-'
+
+ def __setPlaneName(self, data):
+ """Set plane normal according to given plane name
+
+ :param str data: Selected plane name
+ """
+ item = self._item()
+ if item is not None:
+ for name, normal in self._PLANES.items():
+ if data == name and normal is not None:
+ item.setNormal(normal)
+
+ def data(self, column, role):
+ if column == 1 and role == qt.Qt.DecorationRole:
+ return icons.getQIcon(self._PLANE_ICONS[self.__getPlaneName()])
+ data = super(PlaneRow, self).data(column, role)
+ if column == 1 and role == qt.Qt.DisplayRole:
+ self._lastName = data
+ return data
+
+
+class ComplexModeRow(ItemProxyRow):
+ """Represents :class:`items.ComplexMixIn` symbol property.
+
+ :param Item3D item: Scene item with symbol property
+ """
+
+ def __init__(self, item, name='Mode'):
+ names = [m.value.replace('_', ' ').title()
+ for m in item.supportedComplexModes()]
+ super(ComplexModeRow, self).__init__(
+ item=item,
+ name=name,
+ fget=item.getComplexMode,
+ fset=item.setComplexMode,
+ events=items.ItemChangedType.COMPLEX_MODE,
+ toModelData=lambda data: data.value.replace('_', ' ').title(),
+ fromModelData=lambda data: data.lower().replace(' ', '_'),
+ editorHint=names)
+
+
+class RemoveIsosurfaceRow(BaseRow):
+ """Class for Isosurface Delete button
+
+ :param Isosurface isosurface: The isosurface item to attach the button to.
+ """
+
+ def __init__(self, isosurface):
+ super(RemoveIsosurfaceRow, self).__init__()
+ self._isosurface = weakref.ref(isosurface)
+
+ def createEditor(self):
+ """Specific editor factory provided to the model"""
+ editor = qt.QWidget()
+ layout = qt.QHBoxLayout(editor)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ removeBtn = qt.QToolButton()
+ removeBtn.setText('Delete')
+ removeBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(removeBtn)
+ removeBtn.clicked.connect(self._removeClicked)
+
+ layout.addStretch(1)
+ return editor
+
+ def isosurface(self):
+ """Returns the controlled isosurface
+
+ :rtype: Isosurface
+ """
+ return self._isosurface()
+
+ def data(self, column, role):
+ if column == 0 and role == qt.Qt.UserRole: # editor hint
+ return self.createEditor
+
+ return super(RemoveIsosurfaceRow, self).data(column, role)
+
+ def flags(self, column):
+ flags = super(RemoveIsosurfaceRow, self).flags(column)
+ if column == 0:
+ flags |= qt.Qt.ItemIsEditable
+ return flags
+
+ def _removeClicked(self):
+ """Handle Delete button clicked"""
+ isosurface = self.isosurface()
+ if isosurface is not None:
+ volume = isosurface.parent()
+ if volume is not None:
+ volume.removeIsosurface(isosurface)
+
+
+class IsosurfaceRow(Item3DRow):
+ """Represents an :class:`Isosurface` item.
+
+ :param Isosurface item: Isosurface item
+ """
+
+ _LEVEL_SLIDER_RANGE = 0, 1000
+ """Range given as editor hint"""
+
+ _EVENTS = items.ItemChangedType.VISIBLE, items.ItemChangedType.COLOR
+ """Events for which to update the first column in the tree"""
+
+ def __init__(self, item):
+ super(IsosurfaceRow, self).__init__(item, name=item.getLevel())
+
+ self.setFlags(self.flags(1) | qt.Qt.ItemIsEditable, 1)
+
+ item.sigItemChanged.connect(self._levelChanged)
+
+ self.addRow(ItemProxyRow(
+ item=item,
+ name='Level',
+ fget=self._getValueForLevelSlider,
+ fset=self._setLevelFromSliderValue,
+ events=items.Item3DChangedType.ISO_LEVEL,
+ editorHint=self._LEVEL_SLIDER_RANGE))
+
+ self.addRow(ItemColorProxyRow(
+ item=item,
+ name='Color',
+ fget=self._rgbColor,
+ fset=self._setRgbColor,
+ events=items.ItemChangedType.COLOR))
+
+ self.addRow(ItemProxyRow(
+ item=item,
+ name='Opacity',
+ fget=self._opacity,
+ fset=self._setOpacity,
+ events=items.ItemChangedType.COLOR,
+ editorHint=(0, 255)))
+
+ self.addRow(RemoveIsosurfaceRow(item))
+
+ def _getValueForLevelSlider(self):
+ """Convert iso level to slider value.
+
+ :rtype: int
+ """
+ item = self.item()
+ if item is not None:
+ volume = item.parent()
+ if volume is not None:
+ dataRange = volume.getDataRange()
+ if dataRange is not None:
+ dataMin, dataMax = dataRange[0], dataRange[-1]
+ if dataMax != dataMin:
+ offset = (item.getLevel() - dataMin) / (dataMax - dataMin)
+ else:
+ offset = 0.
+
+ sliderMin, sliderMax = self._LEVEL_SLIDER_RANGE
+ value = sliderMin + (sliderMax - sliderMin) * offset
+ return value
+ return 0
+
+ def _setLevelFromSliderValue(self, value):
+ """Convert slider value to isolevel.
+
+ :param int value:
+ """
+ item = self.item()
+ if item is not None:
+ volume = item.parent()
+ if volume is not None:
+ dataRange = volume.getDataRange()
+ if dataRange is not None:
+ sliderMin, sliderMax = self._LEVEL_SLIDER_RANGE
+ offset = (value - sliderMin) / (sliderMax - sliderMin)
+
+ dataMin, dataMax = dataRange[0], dataRange[-1]
+ level = dataMin + (dataMax - dataMin) * offset
+ item.setLevel(level)
+
+ def _rgbColor(self):
+ """Proxy to get the isosurface's RGB color without transparency
+
+ :rtype: QColor
+ """
+ item = self.item()
+ if item is None:
+ return None
+ else:
+ color = item.getColor()
+ color.setAlpha(255)
+ return color
+
+ def _setRgbColor(self, color):
+ """Proxy to set the isosurface's RGB color without transparency
+
+ :param QColor color:
+ """
+ item = self.item()
+ if item is not None:
+ color.setAlpha(item.getColor().alpha())
+ item.setColor(color)
+
+ def _opacity(self):
+ """Proxy to get the isosurface's transparency
+
+ :rtype: int
+ """
+ item = self.item()
+ return 255 if item is None else item.getColor().alpha()
+
+ def _setOpacity(self, opacity):
+ """Proxy to set the isosurface's transparency.
+
+ :param int opacity:
+ """
+ item = self.item()
+ if item is not None:
+ color = item.getColor()
+ color.setAlpha(opacity)
+ item.setColor(color)
+
+ def _levelChanged(self, event):
+ """Handle isosurface level changed and notify model
+
+ :param ItemChangedType event:
+ """
+ if event == items.Item3DChangedType.ISO_LEVEL:
+ model = self.model()
+ if model is not None:
+ index = self.index(column=1)
+ model.dataChanged.emit(index, index)
+
+ def data(self, column, role):
+ if column == 0: # Show color as decoration, not text
+ if role == qt.Qt.DisplayRole:
+ return None
+ elif role == qt.Qt.DecorationRole:
+ return self._rgbColor()
+
+ elif column == 1 and role in (qt.Qt.DisplayRole, qt.Qt.EditRole):
+ item = self.item()
+ return None if item is None else item.getLevel()
+
+ return super(IsosurfaceRow, self).data(column, role)
+
+ def setData(self, column, value, role):
+ if column == 1 and role == qt.Qt.EditRole:
+ item = self.item()
+ if item is not None:
+ item.setLevel(value)
+ return True
+
+ return super(IsosurfaceRow, self).setData(column, value, role)
+
+
+class ComplexIsosurfaceRow(IsosurfaceRow):
+ """Represents an :class:`ComplexIsosurface` item.
+
+ :param ComplexIsosurface item:
+ """
+
+ _EVENTS = (items.ItemChangedType.VISIBLE,
+ items.ItemChangedType.COLOR,
+ items.ItemChangedType.COMPLEX_MODE)
+ """Events for which to update the first column in the tree"""
+
+ def __init__(self, item):
+ super(ComplexIsosurfaceRow, self).__init__(item)
+
+ self.addRow(ComplexModeRow(item, "Color Complex Mode"), index=1)
+ for row in self.children():
+ if isinstance(row, ColorProxyRow):
+ self._colorRow = row
+ break
+ else:
+ raise RuntimeError("Cannot retrieve Color tree row")
+ self._colormapRow = ColormapRow(item)
+
+ self.__updateRowsForItem(item)
+ item.sigItemChanged.connect(self._itemChanged)
+
+ def _itemChanged(self, event):
+ """Update enabled/disabled rows"""
+ if event == items.ItemChangedType.COMPLEX_MODE:
+ item = self.sender()
+ self.__updateRowsForItem(item)
+
+ def __updateRowsForItem(self, item):
+ """Update rows for item
+
+ :param item:
+ """
+ if not isinstance(item, ComplexIsosurface):
+ return
+
+ if item.getComplexMode() == items.ComplexMixIn.ComplexMode.NONE:
+ removed = self._colormapRow
+ added = self._colorRow
+ else:
+ removed = self._colorRow
+ added = self._colormapRow
+
+ # Remove unwanted rows
+ if removed in self.children():
+ self.removeRow(removed)
+
+ # Add required rows
+ if added not in self.children():
+ self.addRow(added, index=2)
+
+ def data(self, column, role):
+ if column == 0 and role == qt.Qt.DecorationRole:
+ item = self.item()
+ if (item is not None and
+ item.getComplexMode() != items.ComplexMixIn.ComplexMode.NONE):
+ return self._colormapRow.getColormapImage()
+
+ return super(ComplexIsosurfaceRow, self).data(column, role)
+
+
+class AddIsosurfaceRow(BaseRow):
+ """Class for Isosurface create button
+
+ :param Union[ScalarField3D,ComplexField3D] volume:
+ The volume item to attach the button to.
+ """
+
+ def __init__(self, volume):
+ super(AddIsosurfaceRow, self).__init__()
+ self._volume = weakref.ref(volume)
+
+ def createEditor(self):
+ """Specific editor factory provided to the model"""
+ editor = qt.QWidget()
+ layout = qt.QHBoxLayout(editor)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+
+ addBtn = qt.QToolButton()
+ addBtn.setText('+')
+ addBtn.setToolButtonStyle(qt.Qt.ToolButtonTextOnly)
+ layout.addWidget(addBtn)
+ addBtn.clicked.connect(self._addClicked)
+
+ layout.addStretch(1)
+ return editor
+
+ def volume(self):
+ """Returns the controlled volume item
+
+ :rtype: Union[ScalarField3D,ComplexField3D]
+ """
+ return self._volume()
+
+ def data(self, column, role):
+ if column == 0 and role == qt.Qt.UserRole: # editor hint
+ return self.createEditor
+
+ return super(AddIsosurfaceRow, self).data(column, role)
+
+ def flags(self, column):
+ flags = super(AddIsosurfaceRow, self).flags(column)
+ if column == 0:
+ flags |= qt.Qt.ItemIsEditable
+ return flags
+
+ def _addClicked(self):
+ """Handle Delete button clicked"""
+ volume = self.volume()
+ if volume is not None:
+ dataRange = volume.getDataRange()
+ if dataRange is None:
+ dataRange = 0., 1.
+
+ volume.addIsosurface(
+ numpy.mean((dataRange[0], dataRange[-1])),
+ '#0000FF')
+
+
+class VolumeIsoSurfacesRow(StaticRow):
+ """Represents :class:`ScalarFieldView`'s isosurfaces
+
+ :param Union[ScalarField3D,ComplexField3D] volume:
+ Volume item to control
+ """
+
+ def __init__(self, volume):
+ super(VolumeIsoSurfacesRow, self).__init__(
+ ('Isosurfaces', None))
+ self._volume = weakref.ref(volume)
+
+ volume.sigIsosurfaceAdded.connect(self._isosurfaceAdded)
+ volume.sigIsosurfaceRemoved.connect(self._isosurfaceRemoved)
+
+ if isinstance(volume, items.ComplexMixIn):
+ self.addRow(ComplexModeRow(volume, "Complex Mode"))
+
+ for item in volume.getIsosurfaces():
+ self.addRow(nodeFromItem(item))
+
+ self.addRow(AddIsosurfaceRow(volume))
+
+ def volume(self):
+ """Returns the controlled volume item
+
+ :rtype: Union[ScalarField3D,ComplexField3D]
+ """
+ return self._volume()
+
+ def _isosurfaceAdded(self, item):
+ """Handle isosurface addition
+
+ :param Isosurface item: added isosurface
+ """
+ volume = self.volume()
+ if volume is None:
+ return
+
+ row = volume.getIsosurfaces().index(item)
+ if isinstance(volume, items.ComplexMixIn):
+ row += 1 # Offset for the ComplexModeRow
+ self.addRow(nodeFromItem(item), row)
+
+ def _isosurfaceRemoved(self, item):
+ """Handle isosurface removal
+
+ :param Isosurface item: removed isosurface
+ """
+ volume = self.volume()
+ if volume is None:
+ return
+
+ # Find item
+ for row in self.children():
+ if isinstance(row, IsosurfaceRow) and row.item() is item:
+ self.removeRow(row)
+ break # Got it
+ else:
+ raise RuntimeError("Model does not correspond to scene content")
+
+
+class Scatter2DPropertyMixInRow(object):
+ """Mix-in class that enable/disable row according to Scatter2D mode.
+
+ :param Scatter2D item:
+ :param str propertyName: Name of the Scatter2D property of this row
+ """
+
+ def __init__(self, item, propertyName):
+ assert propertyName in ('lineWidth', 'symbol', 'symbolSize')
+ self.__propertyName = propertyName
+
+ self.__isEnabled = item.isPropertyEnabled(propertyName)
+ self.__updateFlags()
+
+ item.sigItemChanged.connect(self._itemChanged)
+
+ def data(self, column, role):
+ if column == 1 and not self.__isEnabled:
+ # Discard data and editorHint if disabled
+ return None
+ else:
+ return super(Scatter2DPropertyMixInRow, self).data(column, role)
+
+ def __updateFlags(self):
+ """Update model flags"""
+ if self.__isEnabled:
+ self.setFlags(qt.Qt.ItemIsEnabled, 0)
+ self.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsEditable, 1)
+ else:
+ self.setFlags(qt.Qt.NoItemFlags)
+
+ def _itemChanged(self, event):
+ """Set flags to enable/disable the row"""
+ if event == items.ItemChangedType.VISUALIZATION_MODE:
+ item = self.sender()
+ if item is not None: # This occurs with PySide/python2.7
+ self.__isEnabled = item.isPropertyEnabled(self.__propertyName)
+ self.__updateFlags()
+
+ # Notify model
+ model = self.model()
+ if model is not None:
+ begin = self.index(column=0)
+ end = self.index(column=1)
+ model.dataChanged.emit(begin, end)
+
+
+class Scatter2DSymbolRow(Scatter2DPropertyMixInRow, SymbolRow):
+ """Specific class for Scatter2D symbol.
+
+ It is enabled/disabled according to visualization mode.
+
+ :param Scatter2D item:
+ """
+
+ def __init__(self, item):
+ SymbolRow.__init__(self, item)
+ Scatter2DPropertyMixInRow.__init__(self, item, 'symbol')
+
+
+class Scatter2DSymbolSizeRow(Scatter2DPropertyMixInRow, SymbolSizeRow):
+ """Specific class for Scatter2D symbol size.
+
+ It is enabled/disabled according to visualization mode.
+
+ :param Scatter2D item:
+ """
+
+ def __init__(self, item):
+ SymbolSizeRow.__init__(self, item)
+ Scatter2DPropertyMixInRow.__init__(self, item, 'symbolSize')
+
+
+class Scatter2DLineWidth(Scatter2DPropertyMixInRow, ItemProxyRow):
+ """Specific class for Scatter2D symbol size.
+
+ It is enabled/disabled according to visualization mode.
+
+ :param Scatter2D item:
+ """
+
+ def __init__(self, item):
+ # TODO link editorHint with OpenGL max line width
+ ItemProxyRow.__init__(self,
+ item=item,
+ name='Line width',
+ fget=item.getLineWidth,
+ fset=item.setLineWidth,
+ events=items.ItemChangedType.LINE_WIDTH,
+ editorHint=(1, 10))
+ Scatter2DPropertyMixInRow.__init__(self, item, 'lineWidth')
+
+
+def initScatter2DNode(node, item):
+ """Specific node init for Scatter2D to set order of parameters
+
+ :param Item3DRow node: The model node to setup
+ :param Scatter2D item: The Scatter2D the node is representing
+ """
+ node.addRow(ItemProxyRow(
+ item=item,
+ name='Mode',
+ fget=item.getVisualization,
+ fset=item.setVisualization,
+ events=items.ItemChangedType.VISUALIZATION_MODE,
+ editorHint=[m.value.title() for m in item.supportedVisualizations()],
+ toModelData=lambda data: data.value.title(),
+ fromModelData=lambda data: data.lower()))
+
+ node.addRow(ItemProxyRow(
+ item=item,
+ name='Height map',
+ fget=item.isHeightMap,
+ fset=item.setHeightMap,
+ events=items.Item3DChangedType.HEIGHT_MAP))
+
+ node.addRow(ColormapRow(item))
+
+ node.addRow(Scatter2DSymbolRow(item))
+ node.addRow(Scatter2DSymbolSizeRow(item))
+
+ node.addRow(Scatter2DLineWidth(item))
+
+
+def initVolumeNode(node, item):
+ """Specific node init for volume items
+
+ :param Item3DRow node: The model node to setup
+ :param Union[ScalarField3D,ComplexField3D] item:
+ The volume item represented by the node
+ """
+ node.addRow(nodeFromItem(item.getCutPlanes()[0])) # Add cut plane
+ node.addRow(VolumeIsoSurfacesRow(item))
+
+
+def initVolumeCutPlaneNode(node, item):
+ """Specific node init for volume CutPlane
+
+ :param Item3DRow node: The model node to setup
+ :param CutPlane item: The CutPlane the node is representing
+ """
+ if isinstance(item, items.ComplexMixIn):
+ node.addRow(ComplexModeRow(item))
+
+ node.addRow(PlaneRow(item))
+
+ node.addRow(ColormapRow(item))
+
+ node.addRow(ItemProxyRow(
+ item=item,
+ name='Show <=Min',
+ fget=item.getDisplayValuesBelowMin,
+ fset=item.setDisplayValuesBelowMin,
+ events=items.ItemChangedType.ALPHA))
+
+ node.addRow(InterpolationRow(item))
+
+
+NODE_SPECIFIC_INIT = [ # class, init(node, item)
+ (items.Scatter2D, initScatter2DNode),
+ (items.ScalarField3D, initVolumeNode),
+ (CutPlane, initVolumeCutPlaneNode),
+]
+"""List of specific node init for different item class"""
+
+
+def nodeFromItem(item):
+ """Create :class:`Item3DRow` subclass corresponding to item
+
+ :param Item3D item: The item fow which to create the node
+ :rtype: Item3DRow
+ """
+ assert isinstance(item, items.Item3D)
+
+ # Item with specific model row class
+ if isinstance(item, (items.GroupItem, items.GroupWithAxesItem)):
+ return GroupItemRow(item)
+ elif isinstance(item, ComplexIsosurface):
+ return ComplexIsosurfaceRow(item)
+ elif isinstance(item, Isosurface):
+ return IsosurfaceRow(item)
+
+ # Create Item3DRow and populate it
+ node = Item3DRow(item)
+
+ if isinstance(item, items.DataItem3D):
+ node.addRow(DataItem3DBoundingBoxRow(item))
+ node.addRow(DataItem3DTransformRow(item))
+
+ # Specific extra init
+ for cls, specificInit in NODE_SPECIFIC_INIT:
+ if isinstance(item, cls):
+ specificInit(node, item)
+ break
+
+ else: # Generic case: handle mixins
+ for cls in item.__class__.__mro__:
+ if cls is items.ColormapMixIn:
+ node.addRow(ColormapRow(item))
+
+ elif cls is items.InterpolationMixIn:
+ node.addRow(InterpolationRow(item))
+
+ elif cls is items.SymbolMixIn:
+ node.addRow(SymbolRow(item))
+ node.addRow(SymbolSizeRow(item))
+
+ elif cls is items.PlaneMixIn:
+ node.addRow(PlaneRow(item))
+
+ return node
diff --git a/silx/gui/plot3d/_model/model.py b/src/silx/gui/plot3d/_model/model.py
index 186838f..186838f 100644
--- a/silx/gui/plot3d/_model/model.py
+++ b/src/silx/gui/plot3d/_model/model.py
diff --git a/silx/gui/plot3d/actions/Plot3DAction.py b/src/silx/gui/plot3d/actions/Plot3DAction.py
index 94b9572..94b9572 100644
--- a/silx/gui/plot3d/actions/Plot3DAction.py
+++ b/src/silx/gui/plot3d/actions/Plot3DAction.py
diff --git a/silx/gui/plot3d/actions/__init__.py b/src/silx/gui/plot3d/actions/__init__.py
index 26243cf..26243cf 100644
--- a/silx/gui/plot3d/actions/__init__.py
+++ b/src/silx/gui/plot3d/actions/__init__.py
diff --git a/src/silx/gui/plot3d/actions/io.py b/src/silx/gui/plot3d/actions/io.py
new file mode 100644
index 0000000..25f4ade
--- /dev/null
+++ b/src/silx/gui/plot3d/actions/io.py
@@ -0,0 +1,337 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides Plot3DAction related to input/output.
+
+It provides QAction to copy, save (snapshot and video), print a Plot3DWidget.
+"""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/09/2017"
+
+
+import logging
+import os
+
+import numpy
+
+from silx.gui import qt, printer
+from silx.gui.icons import getQIcon
+from .Plot3DAction import Plot3DAction
+from ..utils import mng
+from ...utils.image import convertQImageToArray
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CopyAction(Plot3DAction):
+ """QAction to provide copy of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(CopyAction, self).__init__(parent, plot3d)
+
+ self.setIcon(getQIcon('edit-copy'))
+ self.setText('Copy')
+ self.setToolTip('Copy a snapshot of the 3D scene to the clipboard')
+ self.setCheckable(False)
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Cannot copy widget, no associated Plot3DWidget')
+ else:
+ image = plot3d.grabGL()
+ qt.QApplication.clipboard().setImage(image)
+
+
+class SaveAction(Plot3DAction):
+ """QAction to provide save snapshot of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(SaveAction, self).__init__(parent, plot3d)
+
+ self.setIcon(getQIcon('document-save'))
+ self.setText('Save...')
+ self.setToolTip('Save a snapshot of the 3D scene')
+ self.setCheckable(False)
+ self.setShortcut(qt.QKeySequence.Save)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Cannot save widget, no associated Plot3DWidget')
+ else:
+ dialog = qt.QFileDialog(self.parent())
+ dialog.setWindowTitle('Save snapshot as')
+ dialog.setModal(True)
+ dialog.setNameFilters(('Plot3D Snapshot PNG (*.png)',
+ 'Plot3D Snapshot JPEG (*.jpg)'))
+
+ dialog.setFileMode(qt.QFileDialog.AnyFile)
+ dialog.setAcceptMode(qt.QFileDialog.AcceptSave)
+
+ if not dialog.exec():
+ return
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+ dialog.close()
+
+ # Forces the filename extension to match the chosen filter
+ extension = nameFilter.split()[-1][2:-1]
+ if (len(filename) <= len(extension) or
+ filename[-len(extension):].lower() != extension.lower()):
+ filename += extension
+
+ image = plot3d.grabGL()
+ if not image.save(filename):
+ _logger.error('Failed to save image as %s', filename)
+ qt.QMessageBox.critical(
+ self.parent(),
+ 'Save snapshot as',
+ 'Failed to save snapshot')
+
+
+class PrintAction(Plot3DAction):
+ """QAction to provide printing of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(PrintAction, self).__init__(parent, plot3d)
+
+ self.setIcon(getQIcon('document-print'))
+ self.setText('Print...')
+ self.setToolTip('Print a snapshot of the 3D scene')
+ self.setCheckable(False)
+ self.setShortcut(qt.QKeySequence.Print)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered[bool].connect(self._triggered)
+
+ def getPrinter(self):
+ """Return the QPrinter instance used for printing.
+
+ :rtype: QPrinter
+ """
+ return printer.getDefaultPrinter()
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Cannot print widget, no associated Plot3DWidget')
+ else:
+ printer = self.getPrinter()
+ dialog = qt.QPrintDialog(printer, plot3d)
+ dialog.setWindowTitle('Print Plot3D snapshot')
+ if not dialog.exec():
+ return
+
+ image = plot3d.grabGL()
+
+ # Draw pixmap with painter
+ painter = qt.QPainter()
+ if not painter.begin(printer):
+ return
+
+ pageRect = printer.pageRect(qt.QPrinter.DevicePixel)
+ if (pageRect.width() < image.width() or
+ pageRect.height() < image.height()):
+ # Downscale to page
+ xScale = pageRect.width() / image.width()
+ yScale = pageRect.height() / image.height()
+ scale = min(xScale, yScale)
+ else:
+ scale = 1.
+
+ rect = qt.QRectF(0,
+ 0,
+ scale * image.width(),
+ scale * image.height())
+ painter.drawImage(rect, image)
+ painter.end()
+
+
+class VideoAction(Plot3DAction):
+ """This action triggers the recording of a video of the scene.
+
+ The scene is rotated 360 degrees around a vertical axis.
+
+ :param parent: Action parent see :class:`QAction`.
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ PNG_SERIE_FILTER = 'Serie of PNG files (*.png)'
+ MNG_FILTER = 'Multiple-image Network Graphics file (*.mng)'
+
+ def __init__(self, parent, plot3d=None):
+ super(VideoAction, self).__init__(parent, plot3d)
+ self.setText('Record video..')
+ self.setIcon(getQIcon('camera'))
+ self.setToolTip(
+ 'Record a video of a 360 degrees rotation of the 3D scene.')
+ self.setCheckable(False)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ """Action triggered callback"""
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.warning(
+ 'Ignoring action triggered without Plot3DWidget set')
+ return
+
+ dialog = qt.QFileDialog(parent=plot3d)
+ dialog.setWindowTitle('Save video as...')
+ dialog.setModal(True)
+ dialog.setNameFilters([self.PNG_SERIE_FILTER,
+ self.MNG_FILTER])
+ dialog.setFileMode(dialog.AnyFile)
+ dialog.setAcceptMode(dialog.AcceptSave)
+
+ if not dialog.exec():
+ return
+
+ nameFilter = dialog.selectedNameFilter()
+ filename = dialog.selectedFiles()[0]
+
+ # Forces the filename extension to match the chosen filter
+ extension = nameFilter.split()[-1][2:-1]
+ if (len(filename) <= len(extension) or
+ filename[-len(extension):].lower() != extension.lower()):
+ filename += extension
+
+ nbFrames = int(4. * 25) # 4 seconds, 25 fps
+
+ if nameFilter == self.PNG_SERIE_FILTER:
+ self._saveAsPNGSerie(filename, nbFrames)
+ elif nameFilter == self.MNG_FILTER:
+ self._saveAsMNG(filename, nbFrames)
+ else:
+ _logger.error('Unsupported file filter: %s', nameFilter)
+
+ def _saveAsPNGSerie(self, filename, nbFrames):
+ """Save video as serie of PNG files.
+
+ It adds a counter to the provided filename before the extension.
+
+ :param str filename: filename to use as template
+ :param int nbFrames: Number of frames to generate
+ """
+ plot3d = self.getPlot3DWidget()
+ assert plot3d is not None
+
+ # Define filename template
+ nbDigits = int(numpy.log10(nbFrames)) + 1
+ indexFormat = '%%0%dd' % nbDigits
+ extensionIndex = filename.rfind('.')
+ filenameFormat = \
+ filename[:extensionIndex] + indexFormat + filename[extensionIndex:]
+
+ try:
+ for index, image in enumerate(self._video360(nbFrames)):
+ image.save(filenameFormat % index)
+ except GeneratorExit:
+ pass
+
+ def _saveAsMNG(self, filename, nbFrames):
+ """Save video as MNG file.
+
+ :param str filename: filename to use
+ :param int nbFrames: Number of frames to generate
+ """
+ plot3d = self.getPlot3DWidget()
+ assert plot3d is not None
+
+ frames = (convertQImageToArray(im) for im in self._video360(nbFrames))
+ try:
+ with open(filename, 'wb') as file_:
+ for chunk in mng.convert(frames, nb_images=nbFrames):
+ file_.write(chunk)
+ except GeneratorExit:
+ os.remove(filename) # Saving aborted, delete file
+
+ def _video360(self, nbFrames):
+ """Run the video and provides the images
+
+ :param int nbFrames: The number of frames to generate for
+ :return: Iterator of QImage of the video sequence
+ """
+ plot3d = self.getPlot3DWidget()
+ assert plot3d is not None
+
+ angleStep = 360. / nbFrames
+
+ # Create progress bar dialog
+ dialog = qt.QDialog(plot3d)
+ dialog.setWindowTitle('Record Video')
+ layout = qt.QVBoxLayout(dialog)
+ progress = qt.QProgressBar()
+ progress.setRange(0, nbFrames)
+ layout.addWidget(progress)
+
+ btnBox = qt.QDialogButtonBox(qt.QDialogButtonBox.Abort)
+ btnBox.rejected.connect(dialog.reject)
+ layout.addWidget(btnBox)
+
+ dialog.setModal(True)
+ dialog.show()
+
+ qapp = qt.QApplication.instance()
+
+ for frame in range(nbFrames):
+ progress.setValue(frame)
+ image = plot3d.grabGL()
+ yield image
+ plot3d.viewport.orbitCamera('left', angleStep)
+ qapp.processEvents()
+ if not dialog.isVisible():
+ break # It as been rejected by the abort button
+ else:
+ dialog.accept()
+
+ if dialog.result() == qt.QDialog.Rejected:
+ raise GeneratorExit('Aborted')
diff --git a/src/silx/gui/plot3d/actions/mode.py b/src/silx/gui/plot3d/actions/mode.py
new file mode 100644
index 0000000..b9cd7c8
--- /dev/null
+++ b/src/silx/gui/plot3d/actions/mode.py
@@ -0,0 +1,178 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides Plot3DAction related to interaction modes.
+
+It provides QAction to rotate or pan a Plot3DWidget
+as well as toggle a picking mode.
+"""
+
+from __future__ import absolute_import, division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/09/2017"
+
+
+import logging
+
+from ....utils.proxy import docstring
+from ... import qt
+from ...icons import getQIcon
+from .Plot3DAction import Plot3DAction
+
+
+_logger = logging.getLogger(__name__)
+
+
+class InteractiveModeAction(Plot3DAction):
+ """Base class for QAction changing interactive mode of a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param str interaction: The interactive mode this action controls
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, interaction, plot3d=None):
+ self._interaction = interaction
+
+ super(InteractiveModeAction, self).__init__(parent, plot3d)
+ self.setCheckable(True)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error(
+ 'Cannot set %s interaction, no associated Plot3DWidget' %
+ self._interaction)
+ else:
+ plot3d.setInteractiveMode(self._interaction)
+ self.setChecked(True)
+
+ @docstring(Plot3DAction)
+ def setPlot3DWidget(self, widget):
+ # Disconnect from previous Plot3DWidget
+ plot3d = self.getPlot3DWidget()
+ if plot3d is not None:
+ plot3d.sigInteractiveModeChanged.disconnect(
+ self._interactiveModeChanged)
+
+ super(InteractiveModeAction, self).setPlot3DWidget(widget)
+
+ # Connect to new Plot3DWidget
+ if widget is None:
+ self.setChecked(False)
+ else:
+ self.setChecked(widget.getInteractiveMode() == self._interaction)
+ widget.sigInteractiveModeChanged.connect(
+ self._interactiveModeChanged)
+
+ def _interactiveModeChanged(self):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is None:
+ _logger.error('Received a signal while there is no widget')
+ else:
+ self.setChecked(plot3d.getInteractiveMode() == self._interaction)
+
+
+class RotateArcballAction(InteractiveModeAction):
+ """QAction to set arcball rotation interaction on a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(RotateArcballAction, self).__init__(parent, 'rotate', plot3d)
+
+ self.setIcon(getQIcon('rotate-3d'))
+ self.setText('Rotate')
+ self.setToolTip('Rotate the view. Press <b>Ctrl</b> to pan.')
+
+
+class PanAction(InteractiveModeAction):
+ """QAction to set pan interaction on a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(PanAction, self).__init__(parent, 'pan', plot3d)
+
+ self.setIcon(getQIcon('pan'))
+ self.setText('Pan')
+ self.setToolTip('Pan the view. Press <b>Ctrl</b> to rotate.')
+
+
+class PickingModeAction(Plot3DAction):
+ """QAction to toggle picking moe on a Plot3DWidget
+
+ :param parent: See :class:`QAction`
+ :param ~silx.gui.plot3d.Plot3DWidget.Plot3DWidget plot3d:
+ Plot3DWidget the action is associated with
+ """
+
+ sigSceneClicked = qt.Signal(float, float)
+ """Signal emitted when the scene is clicked with the left mouse button.
+
+ This signal is only emitted when the action is checked.
+
+ It provides the (x, y) clicked mouse position in logical widget pixel coordinates
+ """
+
+ def __init__(self, parent, plot3d=None):
+ super(PickingModeAction, self).__init__(parent, plot3d)
+ self.setIcon(getQIcon('pointing-hand'))
+ self.setText('Picking')
+ self.setToolTip('Toggle picking with left button click')
+ self.setCheckable(True)
+ self.triggered[bool].connect(self._triggered)
+
+ def _triggered(self, checked=False):
+ plot3d = self.getPlot3DWidget()
+ if plot3d is not None:
+ if checked:
+ plot3d.sigSceneClicked.connect(self.sigSceneClicked)
+ else:
+ plot3d.sigSceneClicked.disconnect(self.sigSceneClicked)
+
+ @docstring(Plot3DAction)
+ def setPlot3DWidget(self, widget):
+ # Disconnect from previous Plot3DWidget
+ plot3d = self.getPlot3DWidget()
+ if plot3d is not None and self.isChecked():
+ plot3d.sigSceneClicked.disconnect(self.sigSceneClicked)
+
+ super(PickingModeAction, self).setPlot3DWidget(widget)
+
+ # Connect to new Plot3DWidget
+ if widget is None:
+ self.setChecked(False)
+ elif self.isChecked():
+ widget.sigSceneClicked.connect(self.sigSceneClicked)
diff --git a/silx/gui/plot3d/actions/viewpoint.py b/src/silx/gui/plot3d/actions/viewpoint.py
index d764c40..d764c40 100644
--- a/silx/gui/plot3d/actions/viewpoint.py
+++ b/src/silx/gui/plot3d/actions/viewpoint.py
diff --git a/src/silx/gui/plot3d/conftest.py b/src/silx/gui/plot3d/conftest.py
new file mode 100644
index 0000000..da02238
--- /dev/null
+++ b/src/silx/gui/plot3d/conftest.py
@@ -0,0 +1,5 @@
+import pytest
+
+@pytest.mark.usefixtures("use_opengl")
+def setup_module(module):
+ pass
diff --git a/silx/gui/plot3d/items/__init__.py b/src/silx/gui/plot3d/items/__init__.py
index e7c4af1..e7c4af1 100644
--- a/silx/gui/plot3d/items/__init__.py
+++ b/src/silx/gui/plot3d/items/__init__.py
diff --git a/silx/gui/plot3d/items/_pick.py b/src/silx/gui/plot3d/items/_pick.py
index 0d6a495..0d6a495 100644
--- a/silx/gui/plot3d/items/_pick.py
+++ b/src/silx/gui/plot3d/items/_pick.py
diff --git a/silx/gui/plot3d/items/clipplane.py b/src/silx/gui/plot3d/items/clipplane.py
index 3e819d0..3e819d0 100644
--- a/silx/gui/plot3d/items/clipplane.py
+++ b/src/silx/gui/plot3d/items/clipplane.py
diff --git a/src/silx/gui/plot3d/items/core.py b/src/silx/gui/plot3d/items/core.py
new file mode 100644
index 0000000..0388ce7
--- /dev/null
+++ b/src/silx/gui/plot3d/items/core.py
@@ -0,0 +1,778 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the base class for items of the :class:`.SceneWidget`.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/11/2017"
+
+from collections import defaultdict
+import enum
+
+import numpy
+
+from ... import qt
+from ...plot.items import ItemChangedType
+from .. import scene
+from ..scene import axes, primitives, transform
+from ._pick import PickContext
+
+
+@enum.unique
+class Item3DChangedType(enum.Enum):
+ """Type of modification provided by :attr:`Item3D.sigItemChanged` signal."""
+
+ INTERPOLATION = 'interpolationChanged'
+ """Item3D image interpolation changed flag."""
+
+ TRANSFORM = 'transformChanged'
+ """Item3D transform changed flag."""
+
+ HEIGHT_MAP = 'heightMapChanged'
+ """Item3D height map changed flag."""
+
+ ISO_LEVEL = 'isoLevelChanged'
+ """Isosurface level changed flag."""
+
+ LABEL = 'labelChanged'
+ """Item's label changed flag."""
+
+ BOUNDING_BOX_VISIBLE = 'boundingBoxVisibleChanged'
+ """Item's bounding box visibility changed"""
+
+ ROOT_ITEM = 'rootItemChanged'
+ """Item's root changed flag."""
+
+
+class Item3D(qt.QObject):
+ """Base class representing an item in the scene.
+
+ :param parent: The View widget this item belongs to.
+ :param primitive: An optional primitive to use as scene primitive
+ """
+
+ _LABEL_INDICES = defaultdict(int)
+ """Store per class label indices"""
+
+ sigItemChanged = qt.Signal(object)
+ """Signal emitted when an item's property has changed.
+
+ It provides a flag describing which property of the item has changed.
+ See :class:`ItemChangedType` and :class:`Item3DChangedType`
+ for flags description.
+ """
+
+ def __init__(self, parent, primitive=None):
+ qt.QObject.__init__(self, parent)
+
+ if primitive is None:
+ primitive = scene.Group()
+
+ self._primitive = primitive
+
+ self.__syncForegroundColor()
+
+ labelIndex = self._LABEL_INDICES[self.__class__]
+ self._label = str(self.__class__.__name__)
+ if labelIndex != 0:
+ self._label += u' %d' % labelIndex
+ self._LABEL_INDICES[self.__class__] += 1
+
+ if isinstance(parent, Item3D):
+ parent.sigItemChanged.connect(self.__parentItemChanged)
+
+ def setParent(self, parent):
+ """Override set parent to handle root item change"""
+ previousParent = self.parent()
+ if isinstance(previousParent, Item3D):
+ previousParent.sigItemChanged.disconnect(self.__parentItemChanged)
+
+ super(Item3D, self).setParent(parent)
+
+ if isinstance(parent, Item3D):
+ parent.sigItemChanged.connect(self.__parentItemChanged)
+
+ self._updated(Item3DChangedType.ROOT_ITEM)
+
+ def __parentItemChanged(self, event):
+ """Handle updates of the parent if it is an Item3D
+
+ :param Item3DChangedType event:
+ """
+ if event == Item3DChangedType.ROOT_ITEM:
+ self._updated(Item3DChangedType.ROOT_ITEM)
+
+ def root(self):
+ """Returns the root of the scene this item belongs to.
+
+ The root is the up-most Item3D in the scene tree hierarchy.
+
+ :rtype: Union[Item3D, None]
+ """
+ root = None
+ ancestor = self.parent()
+ while isinstance(ancestor, Item3D):
+ root = ancestor
+ ancestor = ancestor.parent()
+
+ return root
+
+ def _getScenePrimitive(self):
+ """Return the group containing the item rendering"""
+ return self._primitive
+
+ def _updated(self, event=None):
+ """Handle MixIn class updates.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ """
+ if event == Item3DChangedType.ROOT_ITEM:
+ self.__syncForegroundColor()
+
+ if event is not None:
+ self.sigItemChanged.emit(event)
+
+ # Label
+
+ def getLabel(self):
+ """Returns the label associated to this item.
+
+ :rtype: str
+ """
+ return self._label
+
+ def setLabel(self, label):
+ """Set the label associated to this item.
+
+ :param str label:
+ """
+ label = str(label)
+ if label != self._label:
+ self._label = label
+ self._updated(Item3DChangedType.LABEL)
+
+ # Visibility
+
+ def isVisible(self):
+ """Returns True if item is visible, else False
+
+ :rtype: bool
+ """
+ return self._getScenePrimitive().visible
+
+ def setVisible(self, visible=True):
+ """Set the visibility of the item in the scene.
+
+ :param bool visible: True (default) to show the item, False to hide
+ """
+ visible = bool(visible)
+ primitive = self._getScenePrimitive()
+ if visible != primitive.visible:
+ primitive.visible = visible
+ self._updated(ItemChangedType.VISIBLE)
+
+ # Foreground color
+
+ def _setForegroundColor(self, color):
+ """Set the foreground color of the item.
+
+ The default implementation does nothing, override it in subclass.
+
+ :param color: RGBA color
+ :type color: tuple of 4 float in [0., 1.]
+ """
+ if hasattr(super(Item3D, self), '_setForegroundColor'):
+ super(Item3D, self)._setForegroundColor(color)
+
+ def __syncForegroundColor(self):
+ """Retrieve foreground color from parent and update this item"""
+ # Look-up for SceneWidget to get its foreground color
+ root = self.root()
+ if root is not None:
+ widget = root.parent()
+ if isinstance(widget, qt.QWidget):
+ self._setForegroundColor(
+ widget.getForegroundColor().getRgbF())
+
+ # picking
+
+ def _pick(self, context):
+ """Implement picking on this item.
+
+ :param PickContext context: Current picking context
+ :return: Data indices at picked position or None
+ :rtype: Union[None,PickingResult]
+ """
+ if (self.isVisible() and
+ context.isEnabled() and
+ context.isItemPickable(self) and
+ self._pickFastCheck(context)):
+ return self._pickFull(context)
+ return None
+
+ def _pickFastCheck(self, context):
+ """Approximate item pick test (e.g., bounding box-based picking).
+
+ :param PickContext context: Current picking context
+ :return: True if item might be picked
+ :rtype: bool
+ """
+ primitive = self._getScenePrimitive()
+
+ positionNdc = context.getNDCPosition()
+ if positionNdc is None: # No picking outside viewport
+ return False
+
+ bounds = primitive.bounds(transformed=False, dataBounds=False)
+ if bounds is None: # primitive has no bounds
+ return False
+
+ bounds = primitive.objectToNDCTransform.transformBounds(bounds)
+
+ return (bounds[0, 0] <= positionNdc[0] <= bounds[1, 0] and
+ bounds[0, 1] <= positionNdc[1] <= bounds[1, 1])
+
+ def _pickFull(self, context):
+ """Perform precise picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ return None
+
+
+class DataItem3D(Item3D):
+ """Base class representing a data item with transform in the scene.
+
+ :param parent: The View widget this item belongs to.
+ :param Union[GroupBBox, None] group:
+ The scene group to use for rendering
+ """
+
+ def __init__(self, parent, group=None):
+ if group is None:
+ group = primitives.GroupBBox()
+
+ # Set-up bounding box
+ group.boxVisible = False
+ group.axesVisible = False
+ else:
+ assert isinstance(group, primitives.GroupBBox)
+
+ Item3D.__init__(self, parent=parent, primitive=group)
+
+ # Transformations
+ self._translate = transform.Translate()
+ self._rotateForwardTranslation = transform.Translate()
+ self._rotate = transform.Rotate()
+ self._rotateBackwardTranslation = transform.Translate()
+ self._translateFromRotationCenter = transform.Translate()
+ self._matrix = transform.Matrix()
+ self._scale = transform.Scale()
+ # Group transforms to do to data before rotation
+ # This is useful to handle rotation center relative to bbox
+ self._transformObjectToRotate = transform.TransformList(
+ [self._matrix, self._scale])
+ self._transformObjectToRotate.addListener(self._updateRotationCenter)
+
+ self._rotationCenter = 0., 0., 0.
+
+ self.__transforms = transform.TransformList([
+ self._translate,
+ self._rotateForwardTranslation,
+ self._rotate,
+ self._rotateBackwardTranslation,
+ self._transformObjectToRotate])
+
+ self._getScenePrimitive().transforms = self.__transforms
+
+ def _updated(self, event=None):
+ """Handle MixIn class updates.
+
+ :param event: The event to send to :attr:`sigItemChanged` signal.
+ """
+ if event == ItemChangedType.DATA:
+ self._updateRotationCenter()
+ super(DataItem3D, self)._updated(event)
+
+ # Transformations
+
+ def _getSceneTransforms(self):
+ """Return TransformList corresponding to current transforms
+
+ :rtype: TransformList
+ """
+ return self.__transforms
+
+ def setScale(self, sx=1., sy=1., sz=1.):
+ """Set the scale of the item in the scene.
+
+ :param float sx: Scale factor along the X axis
+ :param float sy: Scale factor along the Y axis
+ :param float sz: Scale factor along the Z axis
+ """
+ scale = numpy.array((sx, sy, sz), dtype=numpy.float32)
+ if not numpy.all(numpy.equal(scale, self.getScale())):
+ self._scale.scale = scale
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def getScale(self):
+ """Returns the scales provided by :meth:`setScale`.
+
+ :rtype: numpy.ndarray
+ """
+ return self._scale.scale
+
+ def setTranslation(self, x=0., y=0., z=0.):
+ """Set the translation of the origin of the item in the scene.
+
+ :param float x: Offset of the data origin on the X axis
+ :param float y: Offset of the data origin on the Y axis
+ :param float z: Offset of the data origin on the Z axis
+ """
+ translation = numpy.array((x, y, z), dtype=numpy.float32)
+ if not numpy.all(numpy.equal(translation, self.getTranslation())):
+ self._translate.translation = translation
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def getTranslation(self):
+ """Returns the offset set by :meth:`setTranslation`.
+
+ :rtype: numpy.ndarray
+ """
+ return self._translate.translation
+
+ _ROTATION_CENTER_TAGS = 'lower', 'center', 'upper'
+
+ def _updateRotationCenter(self, *args, **kwargs):
+ """Update rotation center relative to bounding box"""
+ center = []
+ for index, position in enumerate(self.getRotationCenter()):
+ # Patch position relative to bounding box
+ if position in self._ROTATION_CENTER_TAGS:
+ bounds = self._getScenePrimitive().bounds(
+ transformed=False, dataBounds=True)
+ bounds = self._transformObjectToRotate.transformBounds(bounds)
+
+ if bounds is None:
+ position = 0.
+ elif position == 'lower':
+ position = bounds[0, index]
+ elif position == 'center':
+ position = 0.5 * (bounds[0, index] + bounds[1, index])
+ elif position == 'upper':
+ position = bounds[1, index]
+
+ center.append(position)
+
+ if not numpy.all(numpy.equal(
+ center, self._rotateForwardTranslation.translation)):
+ self._rotateForwardTranslation.translation = center
+ self._rotateBackwardTranslation.translation = \
+ - self._rotateForwardTranslation.translation
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def setRotationCenter(self, x=0., y=0., z=0.):
+ """Set the center of rotation of the item.
+
+ Position of the rotation center is either a float
+ for an absolute position or one of the following
+ string to define a position relative to the item's bounding box:
+ 'lower', 'center', 'upper'
+
+ :param x: rotation center position on the X axis
+ :rtype: float or str
+ :param y: rotation center position on the Y axis
+ :rtype: float or str
+ :param z: rotation center position on the Z axis
+ :rtype: float or str
+ """
+ center = []
+ for position in (x, y, z):
+ if isinstance(position, str):
+ assert position in self._ROTATION_CENTER_TAGS
+ else:
+ position = float(position)
+ center.append(position)
+ center = tuple(center)
+
+ if center != self._rotationCenter:
+ self._rotationCenter = center
+ self._updateRotationCenter()
+
+ def getRotationCenter(self):
+ """Returns the rotation center set by :meth:`setRotationCenter`.
+
+ :rtype: 3-tuple of float or str
+ """
+ return self._rotationCenter
+
+ def setRotation(self, angle=0., axis=(0., 0., 1.)):
+ """Set the rotation of the item in the scene
+
+ :param float angle: The rotation angle in degrees.
+ :param axis: The (x, y, z) coordinates of the rotation axis.
+ """
+ axis = numpy.array(axis, dtype=numpy.float32)
+ assert axis.ndim == 1
+ assert axis.size == 3
+ if (self._rotate.angle != angle or
+ not numpy.all(numpy.equal(axis, self._rotate.axis))):
+ self._rotate.setAngleAxis(angle, axis)
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def getRotation(self):
+ """Returns the rotation set by :meth:`setRotation`.
+
+ :return: (angle, axis)
+ :rtype: 2-tuple (float, numpy.ndarray)
+ """
+ return self._rotate.angle, self._rotate.axis
+
+ def setMatrix(self, matrix=None):
+ """Set the transform matrix
+
+ :param numpy.ndarray matrix: 3x3 transform matrix
+ """
+ matrix4x4 = numpy.identity(4, dtype=numpy.float32)
+
+ if matrix is not None:
+ matrix = numpy.array(matrix, dtype=numpy.float32)
+ assert matrix.shape == (3, 3)
+ matrix4x4[:3, :3] = matrix
+
+ if not numpy.all(numpy.equal(matrix4x4, self._matrix.getMatrix())):
+ self._matrix.setMatrix(matrix4x4)
+ self._updated(Item3DChangedType.TRANSFORM)
+
+ def getMatrix(self):
+ """Returns the matrix set by :meth:`setMatrix`
+
+ :return: 3x3 matrix
+ :rtype: numpy.ndarray"""
+ return self._matrix.getMatrix(copy=True)[:3, :3]
+
+ # Bounding box
+
+ def _setForegroundColor(self, color):
+ """Set the color of the bounding box
+
+ :param color: RGBA color as 4 floats in [0, 1]
+ """
+ self._getScenePrimitive().color = color
+ super(DataItem3D, self)._setForegroundColor(color)
+
+ def isBoundingBoxVisible(self):
+ """Returns item's bounding box visibility.
+
+ :rtype: bool
+ """
+ return self._getScenePrimitive().boxVisible
+
+ def setBoundingBoxVisible(self, visible):
+ """Set item's bounding box visibility.
+
+ :param bool visible:
+ True to show the bounding box, False (default) to hide it
+ """
+ visible = bool(visible)
+ primitive = self._getScenePrimitive()
+ if visible != primitive.boxVisible:
+ primitive.boxVisible = visible
+ self._updated(Item3DChangedType.BOUNDING_BOX_VISIBLE)
+
+
+class BaseNodeItem(DataItem3D):
+ """Base class for data item having children (e.g., group, 3d volume)."""
+
+ def __init__(self, parent=None, group=None):
+ """Base class representing a group of items in the scene.
+
+ :param parent: The View widget this item belongs to.
+ :param Union[GroupBBox, None] group:
+ The scene group to use for rendering
+ """
+ DataItem3D.__init__(self, parent=parent, group=group)
+
+ def getItems(self):
+ """Returns the list of items currently present in the group.
+
+ :rtype: tuple
+ """
+ raise NotImplementedError('getItems must be implemented in subclass')
+
+ def visit(self, included=True):
+ """Generator visiting the group content.
+
+ It traverses the group sub-tree in a top-down left-to-right way.
+
+ :param bool included: True (default) to include self in visit
+ """
+ if included:
+ yield self
+ for child in self.getItems():
+ yield child
+ if hasattr(child, 'visit'):
+ for item in child.visit(included=False):
+ yield item
+
+ def pickItems(self, x, y, condition=None):
+ """Iterator over picked items in the group at given position.
+
+ Each picked item yield a :class:`PickingResult` object
+ holding the picking information.
+
+ It traverses the group sub-tree in a left-to-right top-down way.
+
+ :param int x: X widget device pixel coordinate
+ :param int y: Y widget device pixel coordinate
+ :param callable condition: Optional test called for each item
+ checking whether to process it or not.
+ """
+ viewport = self._getScenePrimitive().viewport
+ if viewport is None:
+ raise RuntimeError(
+ 'Cannot perform picking: Item not attached to a widget')
+
+ context = PickContext(x, y, viewport, condition)
+ for result in self._pickItems(context):
+ yield result
+
+ def _pickItems(self, context):
+ """Implement :meth:`pickItems`
+
+ :param PickContext context: Current picking context
+ """
+ if not self.isVisible() or not context.isEnabled():
+ return # empty iterator
+
+ # Use a copy to discard context changes once this returns
+ context = context.copy()
+
+ if not self._pickFastCheck(context):
+ return # empty iterator
+
+ result = self._pick(context)
+ if result is not None:
+ yield result
+
+ for child in self.getItems():
+ if isinstance(child, BaseNodeItem):
+ for result in child._pickItems(context):
+ yield result # Flatten result
+
+ else:
+ result = child._pick(context)
+ if result is not None:
+ yield result
+
+
+class _BaseGroupItem(BaseNodeItem):
+ """Base class for group of items sharing a common transform."""
+
+ sigItemAdded = qt.Signal(object)
+ """Signal emitted when a new item is added to the group.
+
+ The newly added item is provided by this signal
+ """
+
+ sigItemRemoved = qt.Signal(object)
+ """Signal emitted when an item is removed from the group.
+
+ The removed item is provided by this signal.
+ """
+
+ def __init__(self, parent=None, group=None):
+ """Base class representing a group of items in the scene.
+
+ :param parent: The View widget this item belongs to.
+ :param Union[GroupBBox, None] group:
+ The scene group to use for rendering
+ """
+ BaseNodeItem.__init__(self, parent=parent, group=group)
+ self._items = []
+
+ def _getGroupPrimitive(self):
+ """Returns the group for which to handle children.
+
+ This allows this group to be different from the primitive.
+ """
+ return self._getScenePrimitive()
+
+ def addItem(self, item, index=None):
+ """Add an item to the group
+
+ :param Item3D item: The item to add
+ :param int index: The index at which to place the item.
+ By default it is appended to the end of the list.
+ :raise ValueError: If the item is already in the group.
+ """
+ assert isinstance(item, Item3D)
+ assert item.parent() in (None, self)
+
+ if item in self.getItems():
+ raise ValueError("Item3D already in group: %s" % item)
+
+ item.setParent(self)
+ if index is None:
+ self._getGroupPrimitive().children.append(
+ item._getScenePrimitive())
+ self._items.append(item)
+ else:
+ self._getGroupPrimitive().children.insert(
+ index, item._getScenePrimitive())
+ self._items.insert(index, item)
+ self.sigItemAdded.emit(item)
+
+ def getItems(self):
+ """Returns the list of items currently present in the group.
+
+ :rtype: tuple
+ """
+ return tuple(self._items)
+
+ def removeItem(self, item):
+ """Remove an item from the scene.
+
+ :param Item3D item: The item to remove from the scene
+ :raises ValueError: If the item does not belong to the group
+ """
+ if item not in self.getItems():
+ raise ValueError("Item3D not in group: %s" % str(item))
+
+ self._getGroupPrimitive().children.remove(item._getScenePrimitive())
+ self._items.remove(item)
+ item.setParent(None)
+ self.sigItemRemoved.emit(item)
+
+ def clearItems(self):
+ """Remove all item from the group."""
+ for item in self.getItems():
+ self.removeItem(item)
+
+
+class GroupItem(_BaseGroupItem):
+ """Group of items sharing a common transform."""
+
+ def __init__(self, parent=None):
+ super(GroupItem, self).__init__(parent=parent)
+
+
+class GroupWithAxesItem(_BaseGroupItem):
+ """
+ Group of items sharing a common transform surrounded with labelled axes.
+ """
+
+ def __init__(self, parent=None):
+ """Class representing a group of items in the scene with labelled axes.
+
+ :param parent: The View widget this item belongs to.
+ """
+ super(GroupWithAxesItem, self).__init__(parent=parent,
+ group=axes.LabelledAxes())
+
+ # Axes labels
+
+ def setAxesLabels(self, xlabel=None, ylabel=None, zlabel=None):
+ """Set the text labels of the axes.
+
+ :param str xlabel: Label of the X axis, None to leave unchanged.
+ :param str ylabel: Label of the Y axis, None to leave unchanged.
+ :param str zlabel: Label of the Z axis, None to leave unchanged.
+ """
+ labelledAxes = self._getScenePrimitive()
+ if xlabel is not None:
+ labelledAxes.xlabel = xlabel
+
+ if ylabel is not None:
+ labelledAxes.ylabel = ylabel
+
+ if zlabel is not None:
+ labelledAxes.zlabel = zlabel
+
+ class _Labels(tuple):
+ """Return type of :meth:`getAxesLabels`"""
+
+ def getXLabel(self):
+ """Label of the X axis (str)"""
+ return self[0]
+
+ def getYLabel(self):
+ """Label of the Y axis (str)"""
+ return self[1]
+
+ def getZLabel(self):
+ """Label of the Z axis (str)"""
+ return self[2]
+
+ def getAxesLabels(self):
+ """Returns the text labels of the axes
+
+ >>> group = GroupWithAxesItem()
+ >>> group.setAxesLabels(xlabel='X')
+
+ You can get the labels either as a 3-tuple:
+
+ >>> xlabel, ylabel, zlabel = group.getAxesLabels()
+
+ Or as an object with methods getXLabel, getYLabel and getZLabel:
+
+ >>> labels = group.getAxesLabels()
+ >>> labels.getXLabel()
+ ... 'X'
+
+ :return: object describing the labels
+ """
+ labelledAxes = self._getScenePrimitive()
+ return self._Labels((labelledAxes.xlabel,
+ labelledAxes.ylabel,
+ labelledAxes.zlabel))
+
+
+class RootGroupWithAxesItem(GroupWithAxesItem):
+ """Special group with axes item for root of the scene.
+
+ Uses 2 groups so that axes take transforms into account.
+ """
+
+ def __init__(self, parent=None):
+ super(RootGroupWithAxesItem, self).__init__(parent)
+ self.__group = scene.Group()
+ self.__group.transforms = self._getSceneTransforms()
+
+ groupWithAxes = self._getScenePrimitive()
+ groupWithAxes.transforms = [] # Do not apply transforms here
+ groupWithAxes.children.append(self.__group)
+
+ def _getGroupPrimitive(self):
+ """Returns the group for which to handle children.
+
+ This allows this group to be different from the primitive.
+ """
+ return self.__group
diff --git a/src/silx/gui/plot3d/items/image.py b/src/silx/gui/plot3d/items/image.py
new file mode 100644
index 0000000..5a50459
--- /dev/null
+++ b/src/silx/gui/plot3d/items/image.py
@@ -0,0 +1,425 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides 2D data and RGB(A) image item class.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/11/2017"
+
+import numpy
+
+from ..scene import primitives, utils
+from .core import DataItem3D, ItemChangedType
+from .mixins import ColormapMixIn, InterpolationMixIn
+from ._pick import PickingResult
+
+
+class _Image(DataItem3D, InterpolationMixIn):
+ """Base class for images
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ DataItem3D.__init__(self, parent=parent)
+ InterpolationMixIn.__init__(self)
+
+ def _setPrimitive(self, primitive):
+ InterpolationMixIn._setPrimitive(self, primitive)
+
+ def getData(self, copy=True):
+ raise NotImplementedError()
+
+ def _pickFull(self, context):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ rayObject = context.getPickingSegment(frame=self._getScenePrimitive())
+ if rayObject is None:
+ return None
+
+ points = utils.segmentPlaneIntersect(
+ rayObject[0, :3],
+ rayObject[1, :3],
+ planeNorm=numpy.array((0., 0., 1.), dtype=numpy.float64),
+ planePt=numpy.array((0., 0., 0.), dtype=numpy.float64))
+
+ if len(points) == 1: # Single intersection
+ if points[0][0] < 0. or points[0][1] < 0.:
+ return None # Outside image
+ row, column = int(points[0][1]), int(points[0][0])
+ data = self.getData(copy=False)
+ height, width = data.shape[:2]
+ if row < height and column < width:
+ return PickingResult(
+ self,
+ positions=[(points[0][0], points[0][1], 0.)],
+ indices=([row], [column]))
+ else:
+ return None # Outside image
+ else: # Either no intersection or segment and image are coplanar
+ return None
+
+
+class ImageData(_Image, ColormapMixIn):
+ """Description of a 2D image data.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _Image.__init__(self, parent=parent)
+ ColormapMixIn.__init__(self)
+
+ self._data = numpy.zeros((0, 0), dtype=numpy.float32)
+
+ self._image = primitives.ImageData(self._data)
+ self._getScenePrimitive().children.append(self._image)
+
+ # Connect scene primitive to mix-in class
+ ColormapMixIn._setSceneColormap(self, self._image.colormap)
+ _Image._setPrimitive(self, self._image)
+
+ def setData(self, data, copy=True):
+ """Set the image data to display.
+
+ The data will be casted to float32.
+
+ :param numpy.ndarray data: The image data
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ self._image.setData(data, copy=copy)
+ self._setColormappedData(self.getData(copy=False), copy=False)
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy=True):
+ """Get the image data.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :rtype: numpy.ndarray
+ :return: The image data
+ """
+ return self._image.getData(copy=copy)
+
+
+class ImageRgba(_Image, InterpolationMixIn):
+ """Description of a 2D data RGB(A) image.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _Image.__init__(self, parent=parent)
+ InterpolationMixIn.__init__(self)
+
+ self._data = numpy.zeros((0, 0, 3), dtype=numpy.float32)
+
+ self._image = primitives.ImageRgba(self._data)
+ self._getScenePrimitive().children.append(self._image)
+
+ # Connect scene primitive to mix-in class
+ _Image._setPrimitive(self, self._image)
+
+ def setData(self, data, copy=True):
+ """Set the RGB(A) image data to display.
+
+ Supported array format: float32 in [0, 1], uint8.
+
+ :param numpy.ndarray data:
+ The RGBA image data as an array of shape (H, W, Channels)
+ :param bool copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ self._image.setData(data, copy=copy)
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy=True):
+ """Get the image data.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ :rtype: numpy.ndarray
+ :return: The image data
+ """
+ return self._image.getData(copy=copy)
+
+
+class _HeightMap(DataItem3D):
+ """Base class for 2D data array displayed as a height field.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ DataItem3D.__init__(self, parent=parent)
+ self.__data = numpy.zeros((0, 0), dtype=numpy.float32)
+
+ def _pickFull(self, context, threshold=0., sort='depth'):
+ """Perform picking in this item at given widget position.
+
+ :param PickContext context: Current picking context
+ :param float threshold: Picking threshold in pixel.
+ Perform picking in a square of size threshold x threshold.
+ :param str sort: How returned indices are sorted:
+
+ - 'index' (default): sort by the value of the indices
+ - 'depth': Sort by the depth of the points from the current
+ camera point of view.
+ :return: Object holding the results or None
+ :rtype: Union[None,PickingResult]
+ """
+ assert sort in ('index', 'depth')
+
+ rayNdc = context.getPickingSegment(frame='ndc')
+ if rayNdc is None: # No picking outside viewport
+ return None
+
+ # TODO no colormapped or color data
+ # Project data to NDC
+ heightData = self.getData(copy=False)
+ if heightData.size == 0:
+ return # Nothing displayed
+
+ height, width = heightData.shape
+ z = numpy.ravel(heightData)
+ y, x = numpy.mgrid[0:height, 0:width]
+ dataPoints = numpy.transpose((numpy.ravel(x),
+ numpy.ravel(y),
+ z,
+ numpy.ones_like(z)))
+
+ primitive = self._getScenePrimitive()
+
+ pointsNdc = primitive.objectToNDCTransform.transformPoints(
+ dataPoints, perspectiveDivide=True)
+
+ # Perform picking
+ distancesNdc = numpy.abs(pointsNdc[:, :2] - rayNdc[0, :2])
+ # TODO issue with symbol size: using pixel instead of points
+ threshold += 1. # symbol size
+ thresholdNdc = 2. * threshold / numpy.array(primitive.viewport.size)
+ picked = numpy.where(numpy.logical_and(
+ numpy.all(distancesNdc < thresholdNdc, axis=1),
+ numpy.logical_and(rayNdc[0, 2] <= pointsNdc[:, 2],
+ pointsNdc[:, 2] <= rayNdc[1, 2])))[0]
+
+ if sort == 'depth':
+ # Sort picked points from front to back
+ picked = picked[numpy.argsort(pointsNdc[picked, 2])]
+
+ if picked.size > 0:
+ # Convert indices from 1D to 2D
+ return PickingResult(self,
+ positions=dataPoints[picked, :3],
+ indices=(picked // width, picked % width),
+ fetchdata=self.getData)
+ else:
+ return None
+
+ def setData(self, data, copy: bool=True):
+ """Set the height field data.
+
+ :param data:
+ :param copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+
+ self.__data = data
+ self._updated(ItemChangedType.DATA)
+
+ def getData(self, copy: bool=True) -> numpy.ndarray:
+ """Get the height field 2D data.
+
+ :param bool copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ """
+ return numpy.array(self.__data, copy=copy)
+
+
+class HeightMapData(_HeightMap, ColormapMixIn):
+ """Description of a 2D height field associated to a colormapped dataset.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _HeightMap.__init__(self, parent=parent)
+ ColormapMixIn.__init__(self)
+
+ self.__data = numpy.zeros((0, 0), dtype=numpy.float32)
+
+ def _updated(self, event=None):
+ if event == ItemChangedType.DATA:
+ self.__updateScene()
+ super()._updated(event=event)
+
+ def __updateScene(self):
+ """Update display primitive to use"""
+ self._getScenePrimitive().children = [] # Remove previous primitives
+ ColormapMixIn._setSceneColormap(self, None)
+
+ if not self.isVisible():
+ return # Update when visible
+
+ data = self.getColormappedData(copy=False)
+ heightData = self.getData(copy=False)
+
+ if data.size == 0 or heightData.size == 0:
+ return # Nothing to display
+
+ # Display as a set of points
+ height, width = heightData.shape
+ # Generates coordinates
+ y, x = numpy.mgrid[0:height, 0:width]
+
+ if data.shape != heightData.shape: # data and height size miss-match
+ # Colormapped data is interpolated (nearest-neighbour) to match the height field
+ data = data[numpy.floor(y * data.shape[0] / height).astype(numpy.int32),
+ numpy.floor(x * data.shape[1] / height).astype(numpy.int32)]
+
+ x = numpy.ravel(x)
+ y = numpy.ravel(y)
+
+ primitive = primitives.Points(
+ x=x,
+ y=y,
+ z=numpy.ravel(heightData),
+ value=numpy.ravel(data),
+ size=1)
+ primitive.marker = 's'
+ ColormapMixIn._setSceneColormap(self, primitive.colormap)
+ self._getScenePrimitive().children = [primitive]
+
+ def setColormappedData(self, data, copy: bool=True):
+ """Set the 2D data used to compute colors.
+
+ :param data: 2D array of data
+ :param copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 2
+
+ self.__data = data
+ self._updated(ItemChangedType.DATA)
+
+ def getColormappedData(self, copy: bool=True) -> numpy.ndarray:
+ """Returns the 2D data used to compute colors.
+
+ :param copy:
+ True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ """
+ return numpy.array(self.__data, copy=copy)
+
+
+class HeightMapRGBA(_HeightMap):
+ """Description of a 2D height field associated to a RGB(A) image.
+
+ :param parent: The View widget this item belongs to.
+ """
+
+ def __init__(self, parent=None):
+ _HeightMap.__init__(self, parent=parent)
+
+ self.__rgba = numpy.zeros((0, 0, 3), dtype=numpy.float32)
+
+ def _updated(self, event=None):
+ if event == ItemChangedType.DATA:
+ self.__updateScene()
+ super()._updated(event=event)
+
+ def __updateScene(self):
+ """Update display primitive to use"""
+ self._getScenePrimitive().children = [] # Remove previous primitives
+
+ if not self.isVisible():
+ return # Update when visible
+
+ rgba = self.getColorData(copy=False)
+ heightData = self.getData(copy=False)
+ if rgba.size == 0 or heightData.size == 0:
+ return # Nothing to display
+
+ # Display as a set of points
+ height, width = heightData.shape
+ # Generates coordinates
+ y, x = numpy.mgrid[0:height, 0:width]
+
+ if rgba.shape[:2] != heightData.shape: # image and height size miss-match
+ # RGBA data is interpolated (nearest-neighbour) to match the height field
+ rgba = rgba[numpy.floor(y * rgba.shape[0] / height).astype(numpy.int32),
+ numpy.floor(x * rgba.shape[1] / height).astype(numpy.int32)]
+
+ x = numpy.ravel(x)
+ y = numpy.ravel(y)
+
+ primitive = primitives.ColorPoints(
+ x=x,
+ y=y,
+ z=numpy.ravel(heightData),
+ color=rgba.reshape(-1, rgba.shape[-1]),
+ size=1)
+ primitive.marker = 's'
+ self._getScenePrimitive().children = [primitive]
+
+ def setColorData(self, data, copy: bool=True):
+ """Set the RGB(A) image to use.
+
+ Supported array format: float32 in [0, 1], uint8.
+
+ :param data:
+ The RGBA image data as an array of shape (H, W, Channels)
+ :param copy: True (default) to copy the data,
+ False to use as is (do not modify!).
+ """
+ data = numpy.array(data, copy=copy)
+ assert data.ndim == 3
+ assert data.shape[-1] in (3, 4)
+ # TODO check type
+
+ self.__rgba = data
+ self._updated(ItemChangedType.DATA)
+
+ def getColorData(self, copy: bool=True) -> numpy.ndarray:
+ """Get the RGB(A) image data.
+
+ :param copy: True (default) to get a copy,
+ False to get internal representation (do not modify!).
+ """
+ return numpy.array(self.__rgba, copy=copy)
diff --git a/silx/gui/plot3d/items/mesh.py b/src/silx/gui/plot3d/items/mesh.py
index 4e19939..4e19939 100644
--- a/silx/gui/plot3d/items/mesh.py
+++ b/src/silx/gui/plot3d/items/mesh.py
diff --git a/silx/gui/plot3d/items/mixins.py b/src/silx/gui/plot3d/items/mixins.py
index f512365..f512365 100644
--- a/silx/gui/plot3d/items/mixins.py
+++ b/src/silx/gui/plot3d/items/mixins.py
diff --git a/silx/gui/plot3d/items/scatter.py b/src/silx/gui/plot3d/items/scatter.py
index 24abaa5..24abaa5 100644
--- a/silx/gui/plot3d/items/scatter.py
+++ b/src/silx/gui/plot3d/items/scatter.py
diff --git a/silx/gui/plot3d/items/volume.py b/src/silx/gui/plot3d/items/volume.py
index f80fea2..f80fea2 100644
--- a/silx/gui/plot3d/items/volume.py
+++ b/src/silx/gui/plot3d/items/volume.py
diff --git a/silx/gui/plot3d/scene/__init__.py b/src/silx/gui/plot3d/scene/__init__.py
index 9671725..9671725 100644
--- a/silx/gui/plot3d/scene/__init__.py
+++ b/src/silx/gui/plot3d/scene/__init__.py
diff --git a/silx/gui/plot3d/scene/axes.py b/src/silx/gui/plot3d/scene/axes.py
index e35e5e1..e35e5e1 100644
--- a/silx/gui/plot3d/scene/axes.py
+++ b/src/silx/gui/plot3d/scene/axes.py
diff --git a/silx/gui/plot3d/scene/camera.py b/src/silx/gui/plot3d/scene/camera.py
index 90de7ed..90de7ed 100644
--- a/silx/gui/plot3d/scene/camera.py
+++ b/src/silx/gui/plot3d/scene/camera.py
diff --git a/silx/gui/plot3d/scene/core.py b/src/silx/gui/plot3d/scene/core.py
index 43838fe..43838fe 100644
--- a/silx/gui/plot3d/scene/core.py
+++ b/src/silx/gui/plot3d/scene/core.py
diff --git a/silx/gui/plot3d/scene/cutplane.py b/src/silx/gui/plot3d/scene/cutplane.py
index 88147df..88147df 100644
--- a/silx/gui/plot3d/scene/cutplane.py
+++ b/src/silx/gui/plot3d/scene/cutplane.py
diff --git a/silx/gui/plot3d/scene/event.py b/src/silx/gui/plot3d/scene/event.py
index 98f8f8b..98f8f8b 100644
--- a/silx/gui/plot3d/scene/event.py
+++ b/src/silx/gui/plot3d/scene/event.py
diff --git a/silx/gui/plot3d/scene/function.py b/src/silx/gui/plot3d/scene/function.py
index 2deb785..2deb785 100644
--- a/silx/gui/plot3d/scene/function.py
+++ b/src/silx/gui/plot3d/scene/function.py
diff --git a/silx/gui/plot3d/scene/interaction.py b/src/silx/gui/plot3d/scene/interaction.py
index 14a54dc..14a54dc 100644
--- a/silx/gui/plot3d/scene/interaction.py
+++ b/src/silx/gui/plot3d/scene/interaction.py
diff --git a/silx/gui/plot3d/scene/primitives.py b/src/silx/gui/plot3d/scene/primitives.py
index 7f35c3c..7f35c3c 100644
--- a/silx/gui/plot3d/scene/primitives.py
+++ b/src/silx/gui/plot3d/scene/primitives.py
diff --git a/src/silx/gui/plot3d/scene/test/__init__.py b/src/silx/gui/plot3d/scene/test/__init__.py
new file mode 100644
index 0000000..3bb978e
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/plot3d/scene/test/test_transform.py b/src/silx/gui/plot3d/scene/test/test_transform.py
new file mode 100644
index 0000000..69e991b
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/test/test_transform.py
@@ -0,0 +1,80 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/01/2017"
+
+
+import numpy
+import unittest
+
+from silx.gui.plot3d.scene import transform
+
+
+class TestTransformList(unittest.TestCase):
+
+ def assertSameArrays(self, a, b):
+ return self.assertTrue(numpy.allclose(a, b, atol=1e-06))
+
+ def testTransformList(self):
+ """Minimalistic test of TransformList"""
+ transforms = transform.TransformList()
+ refmatrix = numpy.identity(4, dtype=numpy.float32)
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Append translate
+ transforms.append(transform.Translate(1., 1., 1.))
+ refmatrix = numpy.array(((1., 0., 0., 1.),
+ (0., 1., 0., 1.),
+ (0., 0., 1., 1.),
+ (0., 0., 0., 1.)), dtype=numpy.float32)
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Extend scale
+ transforms.extend([transform.Scale(0.1, 2., 1.)])
+ refmatrix = numpy.dot(refmatrix,
+ numpy.array(((0.1, 0., 0., 0.),
+ (0., 2., 0., 0.),
+ (0., 0., 1., 0.),
+ (0., 0., 0., 1.)),
+ dtype=numpy.float32))
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Insert rotate
+ transforms.insert(0, transform.Rotate(360.))
+ self.assertSameArrays(refmatrix, transforms.matrix)
+
+ # Update translate and check for listener called
+ self._callCount = 0
+
+ def listener(source):
+ self._callCount += 1
+ transforms.addListener(listener)
+
+ transforms[1].tx += 1
+ self.assertEqual(self._callCount, 1)
diff --git a/src/silx/gui/plot3d/scene/test/test_utils.py b/src/silx/gui/plot3d/scene/test/test_utils.py
new file mode 100644
index 0000000..65d0ce0
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/test/test_utils.py
@@ -0,0 +1,258 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+from silx.utils.testutils import ParametricTestCase
+
+import numpy
+
+from silx.gui.plot3d.scene import utils
+
+
+# angleBetweenVectors #########################################################
+
+class TestAngleBetweenVectors(ParametricTestCase):
+
+ TESTS = { # name: (refvector, vectors, norm, refangles)
+ 'single vector':
+ ((1., 0., 0.), (1., 0., 0.), (0., 0., 1.), 0.),
+ 'single vector, no norm':
+ ((1., 0., 0.), (1., 0., 0.), None, 0.),
+
+ 'with orthogonal norm':
+ ((1., 0., 0.),
+ ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
+ (0., 0., 1.),
+ (0., 90., 180., 270.)),
+
+ 'with coplanar norm': # = similar to no norm
+ ((1., 0., 0.),
+ ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
+ (1., 0., 0.),
+ (0., 90., 180., 90.)),
+
+ 'without norm':
+ ((1., 0., 0.),
+ ((1., 0., 0.), (0., 1., 0.), (-1., 0., 0.), (0., -1., 0.)),
+ None,
+ (0., 90., 180., 90.)),
+
+ 'not unit vectors':
+ ((2., 2., 0.), ((1., 1., 0.), (1., -1., 0.)), None, (0., 90.)),
+ }
+
+ def testAngleBetweenVectorsFunction(self):
+ for name, params in self.TESTS.items():
+ refvector, vectors, norm, refangles = params
+ with self.subTest(name):
+ refangles = numpy.radians(refangles)
+
+ refvector = numpy.array(refvector)
+ vectors = numpy.array(vectors)
+ if norm is not None:
+ norm = numpy.array(norm)
+
+ testangles = utils.angleBetweenVectors(
+ refvector, vectors, norm)
+
+ self.assertTrue(
+ numpy.allclose(testangles, refangles, atol=1e-5))
+
+
+# Plane #######################################################################
+
+class AssertNotificationContext(object):
+ """Context that checks if an event.Notifier is sending events."""
+
+ def __init__(self, notifier, count=1):
+ """Initializer.
+
+ :param event.Notifier notifier: The notifier to test.
+ :param int count: The expected number of calls.
+ """
+ self._notifier = notifier
+ self._callCount = None
+ self._count = count
+
+ def __enter__(self):
+ self._callCount = 0
+ self._notifier.addListener(self._callback)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ # Do not return True so exceptions are propagated
+ self._notifier.removeListener(self._callback)
+ assert self._callCount == self._count
+ self._callCount = None
+
+ def _callback(self, *args, **kwargs):
+ self._callCount += 1
+
+
+class TestPlaneParameters(ParametricTestCase):
+ """Test Plane.parameters read/write and notifications."""
+
+ PARAMETERS = {
+ 'unit normal': (1., 0., 0., 1.),
+ 'not unit normal': (1., 1., 0., 1.),
+ 'd = 0': (1., 0., 0., 0.)
+ }
+
+ def testParameters(self):
+ """Check parameters read/write and notification."""
+ plane = utils.Plane()
+
+ for name, parameters in self.PARAMETERS.items():
+ with self.subTest(name, parameters=parameters):
+ with AssertNotificationContext(plane):
+ plane.parameters = parameters
+
+ # Plane parameters are converted to have a unit normal
+ normparams = parameters / numpy.linalg.norm(parameters[:3])
+ self.assertTrue(numpy.allclose(plane.parameters, normparams))
+
+ ZEROS_PARAMETERS = (
+ (0., 0., 0., 0.),
+ (0., 0., 0., 1.)
+ )
+
+ ZEROS = 0., 0., 0., 0.
+
+ def testParametersNoPlane(self):
+ """Test Plane.parameters with ||normal|| == 0 ."""
+ plane = utils.Plane()
+ plane.parameters = self.ZEROS
+
+ for parameters in self.ZEROS_PARAMETERS:
+ with self.subTest(parameters=parameters):
+ with AssertNotificationContext(plane, count=0):
+ plane.parameters = parameters
+ self.assertTrue(
+ numpy.allclose(plane.parameters, self.ZEROS, 0., 0.))
+
+
+# unindexArrays ###############################################################
+
+class TestUnindexArrays(ParametricTestCase):
+ """Test unindexArrays function."""
+
+ def testBasicModes(self):
+ """Test for modes: points, lines and triangles"""
+ indices = numpy.array((1, 2, 0))
+ arrays = (numpy.array((0., 1., 2.)),
+ numpy.array(((0, 0), (1, 1), (2, 2))))
+ refresults = (numpy.array((1., 2., 0.)),
+ numpy.array(((1, 1), (2, 2), (0, 0))))
+
+ for mode in ('points', 'lines', 'triangles'):
+ with self.subTest(mode=mode):
+ testresults = utils.unindexArrays(mode, indices, *arrays)
+ for ref, test in zip(refresults, testresults):
+ self.assertTrue(numpy.equal(ref, test).all())
+
+ def testPackedLines(self):
+ """Test for modes: line_strip, loop"""
+ indices = numpy.array((1, 2, 0))
+ arrays = (numpy.array((0., 1., 2.)),
+ numpy.array(((0, 0), (1, 1), (2, 2))))
+ results = {
+ 'line_strip': (
+ numpy.array((1., 2., 2., 0.)),
+ numpy.array(((1, 1), (2, 2), (2, 2), (0, 0)))),
+ 'loop': (
+ numpy.array((1., 2., 2., 0., 0., 1.)),
+ numpy.array(((1, 1), (2, 2), (2, 2), (0, 0), (0, 0), (1, 1)))),
+ }
+
+ for mode, refresults in results.items():
+ with self.subTest(mode=mode):
+ testresults = utils.unindexArrays(mode, indices, *arrays)
+ for ref, test in zip(refresults, testresults):
+ self.assertTrue(numpy.equal(ref, test).all())
+
+ def testPackedTriangles(self):
+ """Test for modes: triangle_strip, fan"""
+ indices = numpy.array((1, 2, 0, 3))
+ arrays = (numpy.array((0., 1., 2., 3.)),
+ numpy.array(((0, 0), (1, 1), (2, 2), (3, 3))))
+ results = {
+ 'triangle_strip': (
+ numpy.array((1., 2., 0., 2., 0., 3.)),
+ numpy.array(((1, 1), (2, 2), (0, 0), (2, 2), (0, 0), (3, 3)))),
+ 'fan': (
+ numpy.array((1., 2., 0., 1., 0., 3.)),
+ numpy.array(((1, 1), (2, 2), (0, 0), (1, 1), (0, 0), (3, 3)))),
+ }
+
+ for mode, refresults in results.items():
+ with self.subTest(mode=mode):
+ testresults = utils.unindexArrays(mode, indices, *arrays)
+ for ref, test in zip(refresults, testresults):
+ self.assertTrue(numpy.equal(ref, test).all())
+
+ def testBadIndices(self):
+ """Test with negative indices and indices higher than array length"""
+ arrays = numpy.array((0, 1)), numpy.array((0, 1, 2))
+
+ # negative indices
+ with self.assertRaises(AssertionError):
+ utils.unindexArrays('points', (-1, 0), *arrays)
+
+ # Too high indices
+ with self.assertRaises(AssertionError):
+ utils.unindexArrays('points', (0, 10), *arrays)
+
+
+# triangleNormals #############################################################
+
+class TestTriangleNormals(ParametricTestCase):
+ """Test triangleNormals function."""
+
+ def test(self):
+ """Test for modes: points, lines and triangles"""
+ positions = numpy.array(
+ ((0., 0., 0.), (1., 0., 0.), (0., 1., 0.), # normal = Z
+ (1., 1., 1.), (1., 2., 3.), (4., 5., 6.), # Random triangle
+ # Degenerated triangles:
+ (0., 0., 0.), (1., 0., 0.), (2., 0., 0.), # Colinear points
+ (1., 1., 1.), (1., 1., 1.), (1., 1., 1.), # All same point
+ ),
+ dtype='float32')
+
+ normals = numpy.array(
+ ((0., 0., 1.),
+ (-0.40824829, 0.81649658, -0.40824829),
+ (0., 0., 0.),
+ (0., 0., 0.)),
+ dtype='float32')
+
+ testnormals = utils.trianglesNormal(positions)
+ self.assertTrue(numpy.allclose(testnormals, normals))
diff --git a/silx/gui/plot3d/scene/text.py b/src/silx/gui/plot3d/scene/text.py
index bacc2e6..bacc2e6 100644
--- a/silx/gui/plot3d/scene/text.py
+++ b/src/silx/gui/plot3d/scene/text.py
diff --git a/silx/gui/plot3d/scene/transform.py b/src/silx/gui/plot3d/scene/transform.py
index 43b739b..43b739b 100644
--- a/silx/gui/plot3d/scene/transform.py
+++ b/src/silx/gui/plot3d/scene/transform.py
diff --git a/silx/gui/plot3d/scene/utils.py b/src/silx/gui/plot3d/scene/utils.py
index c6cd129..c6cd129 100644
--- a/silx/gui/plot3d/scene/utils.py
+++ b/src/silx/gui/plot3d/scene/utils.py
diff --git a/silx/gui/plot3d/scene/viewport.py b/src/silx/gui/plot3d/scene/viewport.py
index 6de640e..6de640e 100644
--- a/silx/gui/plot3d/scene/viewport.py
+++ b/src/silx/gui/plot3d/scene/viewport.py
diff --git a/src/silx/gui/plot3d/scene/window.py b/src/silx/gui/plot3d/scene/window.py
new file mode 100644
index 0000000..b92c404
--- /dev/null
+++ b/src/silx/gui/plot3d/scene/window.py
@@ -0,0 +1,433 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a class for Viewports rendering on the screen.
+
+The :class:`Window` renders a list of Viewports in the current framebuffer.
+The rendering can be performed in an off-screen framebuffer that is only
+updated when the scene has changed and not each time Qt is requiring a repaint.
+
+The :class:`Context` and :class:`ContextGL2` represent the operating system
+OpenGL context and handle OpenGL resources.
+"""
+
+from __future__ import absolute_import, division, unicode_literals
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/01/2017"
+
+
+import weakref
+import numpy
+
+from ..._glutils import gl
+from ... import _glutils
+
+from . import event
+
+
+class Context(object):
+ """Correspond to an operating system OpenGL context.
+
+ User should NEVER use an instance of this class beyond the method
+ it is passed to as an argument (i.e., do not keep a reference to it).
+
+ :param glContextHandle: System specific OpenGL context handle.
+ """
+
+ def __init__(self, glContextHandle):
+ self._context = glContextHandle
+ self._isCurrent = False
+ self._devicePixelRatio = 1.0
+
+ @property
+ def isCurrent(self):
+ """Whether this OpenGL context is the current one or not."""
+ return self._isCurrent
+
+ def setCurrent(self, isCurrent=True):
+ """Set the state of the OpenGL context to reflect OpenGL state.
+
+ This should not be called from the scene graph, only in the
+ wrapper that handle the OpenGL context to reflect its state.
+
+ :param bool isCurrent: The state of the system OpenGL context.
+ """
+ self._isCurrent = bool(isCurrent)
+
+ @property
+ def devicePixelRatio(self):
+ """Ratio between device and device independent pixels (float)
+
+ This is useful for font rendering.
+ """
+ return self._devicePixelRatio
+
+ @devicePixelRatio.setter
+ def devicePixelRatio(self, ratio):
+ assert ratio > 0
+ self._devicePixelRatio = float(ratio)
+
+ def __enter__(self):
+ self.setCurrent(True)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.setCurrent(False)
+
+ @property
+ def glContext(self):
+ """The handle to the OpenGL context provided by the system."""
+ return self._context
+
+ def cleanGLGarbage(self):
+ """This is releasing OpenGL resource that are no longer used."""
+ pass
+
+
+class ContextGL2(Context):
+ """Handle a system GL2 context.
+
+ User should NEVER use an instance of this class beyond the method
+ it is passed to as an argument (i.e., do not keep a reference to it).
+
+ :param glContextHandle: System specific OpenGL context handle.
+ """
+ def __init__(self, glContextHandle):
+ super(ContextGL2, self).__init__(glContextHandle)
+
+ self._programs = {} # GL programs already compiled
+ self._vbos = {} # GL Vbos already set
+ self._vboGarbage = [] # Vbos waiting to be discarded
+
+ # programs
+
+ def prog(self, vertexShaderSrc, fragmentShaderSrc, attrib0='position'):
+ """Cache program within context.
+
+ WARNING: No clean-up.
+
+ :param str vertexShaderSrc: Vertex shader source code
+ :param str fragmentShaderSrc: Fragment shader source code
+ :param str attrib0:
+ Attribute's name to bind to position 0 (default: 'position').
+ On some platform, this attribute MUST be active and with an
+ array attached to it in order for the rendering to occur....
+ """
+ assert self.isCurrent
+ key = vertexShaderSrc, fragmentShaderSrc, attrib0
+ program = self._programs.get(key, None)
+ if program is None:
+ program = _glutils.Program(
+ vertexShaderSrc, fragmentShaderSrc, attrib0=attrib0)
+ self._programs[key] = program
+ return program
+
+ # VBOs
+
+ def makeVbo(self, data=None, sizeInBytes=None,
+ usage=None, target=None):
+ """Create a VBO in this context with the data.
+
+ Current limitations:
+
+ - One array per VBO
+ - Do not support sharing VertexBuffer across VboAttrib
+
+ Automatically discards the VBO when the returned
+ :class:`VertexBuffer` istance is deleted.
+
+ :param numpy.ndarray data: 2D array of data to store in VBO or None.
+ :param int sizeInBytes: Size of the VBO or None.
+ It should be <= data.nbytes if both are given.
+ :param usage: OpenGL usage define in VertexBuffer._USAGES.
+ :param target: OpenGL target in VertexBuffer._TARGETS.
+ :return: The VertexBuffer created in this context.
+ """
+ assert self.isCurrent
+ vbo = _glutils.VertexBuffer(data, sizeInBytes, usage, target)
+ vboref = weakref.ref(vbo, self._deadVbo)
+ # weakref is hashable as far as target is
+ self._vbos[vboref] = vbo.name
+ return vbo
+
+ def makeVboAttrib(self, data, usage=None, target=None):
+ """Create a VBO from data and returns the associated VBOAttrib.
+
+ Automatically discards the VBO when the returned
+ :class:`VBOAttrib` istance is deleted.
+
+ :param numpy.ndarray data: 2D array of data to store in VBO or None.
+ :param usage: OpenGL usage define in VertexBuffer._USAGES.
+ :param target: OpenGL target in VertexBuffer._TARGETS.
+ :returns: A VBOAttrib instance created in this context.
+ """
+ assert self.isCurrent
+ vbo = self.makeVbo(data, usage=usage, target=target)
+
+ assert len(data.shape) <= 2
+ dimension = 1 if len(data.shape) == 1 else data.shape[1]
+
+ return _glutils.VertexBufferAttrib(
+ vbo,
+ type_=_glutils.numpyToGLType(data.dtype),
+ size=data.shape[0],
+ dimension=dimension,
+ offset=0,
+ stride=0)
+
+ def _deadVbo(self, vboRef):
+ """Callback handling dead VBOAttribs."""
+ vboid = self._vbos.pop(vboRef)
+ if self.isCurrent:
+ # Direct delete if context is active
+ gl.glDeleteBuffers(vboid)
+ else:
+ # Deferred VBO delete if context is not active
+ self._vboGarbage.append(vboid)
+
+ def cleanGLGarbage(self):
+ """Delete OpenGL resources that are pending for destruction.
+
+ This requires the associated OpenGL context to be active.
+ This is meant to be called before rendering.
+ """
+ assert self.isCurrent
+ if self._vboGarbage:
+ vboids = self._vboGarbage
+ gl.glDeleteBuffers(vboids)
+ self._vboGarbage = []
+
+
+class Window(event.Notifier):
+ """OpenGL Framebuffer where to render viewports
+
+ :param str mode: Rendering mode to use:
+
+ - 'direct' to render everything for each render call
+ - 'framebuffer' to cache viewport rendering in a texture and
+ update the texture only when needed.
+ """
+
+ _position = numpy.array(((-1., -1., 0., 0.),
+ (1., -1., 1., 0.),
+ (-1., 1., 0., 1.),
+ (1., 1., 1., 1.)),
+ dtype=numpy.float32)
+
+ _shaders = ("""
+ attribute vec4 position;
+ varying vec2 textureCoord;
+
+ void main(void) {
+ gl_Position = vec4(position.x, position.y, 0., 1.);
+ textureCoord = position.zw;
+ }
+ """,
+ """
+ uniform sampler2D texture;
+ varying vec2 textureCoord;
+
+ void main(void) {
+ gl_FragColor = texture2D(texture, textureCoord);
+ gl_FragColor.a = 1.0;
+ }
+ """)
+
+ def __init__(self, mode='framebuffer'):
+ super(Window, self).__init__()
+ self._dirty = True
+ self._size = 0, 0
+ self._contexts = {} # To map system GL context id to Context objects
+ self._viewports = event.NotifierList()
+ self._viewports.addListener(self._updated)
+ self._framebufferid = 0
+ self._framebuffers = {} # Cache of framebuffers
+
+ assert mode in ('direct', 'framebuffer')
+ self._isframebuffer = mode == 'framebuffer'
+
+ @property
+ def dirty(self):
+ """True if this object or any attached viewports is dirty."""
+ for viewport in self._viewports:
+ if viewport.dirty:
+ return True
+ return self._dirty
+
+ @property
+ def size(self):
+ """Size (width, height) of the window in pixels"""
+ return self._size
+
+ @size.setter
+ def size(self, size):
+ w, h = size
+ size = int(w), int(h)
+ if size != self._size:
+ self._size = size
+ self._dirty = True
+ self.notify()
+
+ @property
+ def shape(self):
+ """Shape (height, width) of the window in pixels.
+
+ This is a convenient wrapper to the reverse of size.
+ """
+ return self._size[1], self._size[0]
+
+ @shape.setter
+ def shape(self, shape):
+ self.size = shape[1], shape[0]
+
+ @property
+ def viewports(self):
+ """List of viewports to render in the corresponding framebuffer"""
+ return self._viewports
+
+ @viewports.setter
+ def viewports(self, iterable):
+ self._viewports.removeListener(self._updated)
+ self._viewports = event.NotifierList(iterable)
+ self._viewports.addListener(self._updated)
+ self._updated(self)
+
+ def _updated(self, source, *args, **kwargs):
+ self._dirty = True
+ self.notify(*args, **kwargs)
+
+ framebufferid = property(lambda self: self._framebufferid,
+ doc="Framebuffer ID used to perform rendering")
+
+ def grab(self, glcontext):
+ """Returns the raster of the scene as an RGB numpy array
+
+ :returns: OpenGL scene RGB bitmap
+ as an array of dimension (height, width, 3)
+ :rtype: numpy.ndarray of uint8
+ """
+ height, width = self.shape
+ image = numpy.empty((height, width, 3), dtype=numpy.uint8)
+
+ previousFramebuffer = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.framebufferid)
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
+ gl.glReadPixels(
+ 0, 0, width, height, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, image)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, previousFramebuffer)
+
+ # glReadPixels gives bottom to top,
+ # while images are stored as top to bottom
+ image = numpy.flipud(image)
+
+ return numpy.array(image, copy=False, order='C')
+
+ def render(self, glcontext, devicePixelRatio):
+ """Perform the rendering of attached viewports
+
+ :param glcontext: System identifier of the OpenGL context
+ :param float devicePixelRatio:
+ Ratio between device and device-independent pixels
+ """
+ if self.size == (0, 0):
+ return
+
+ if glcontext not in self._contexts:
+ self._contexts[glcontext] = ContextGL2(glcontext) # New context
+
+ with self._contexts[glcontext] as context:
+ context.devicePixelRatio = devicePixelRatio
+ if self._isframebuffer:
+ self._renderWithOffscreenFramebuffer(context)
+ else:
+ self._renderDirect(context)
+
+ self._dirty = False
+
+ def _renderDirect(self, context):
+ """Perform the direct rendering of attached viewports
+
+ :param Context context: Object wrapping OpenGL context
+ """
+ for viewport in self._viewports:
+ viewport.framebuffer = self.framebufferid
+ viewport.render(context)
+ viewport.resetDirty()
+
+ def _renderWithOffscreenFramebuffer(self, context):
+ """Renders viewports in a texture and render this texture on screen.
+
+ The texture is updated only if viewport or size has changed.
+
+ :param ContextGL2 context: Object wrappign OpenGL context
+ """
+ if self.dirty or context not in self._framebuffers:
+ # Need to redraw framebuffer content
+
+ if (context not in self._framebuffers or
+ self._framebuffers[context].shape != self.shape):
+ # Need to rebuild framebuffer
+
+ if context in self._framebuffers:
+ self._framebuffers[context].discard()
+
+ fbo = _glutils.FramebufferTexture(gl.GL_RGBA,
+ shape=self.shape,
+ minFilter=gl.GL_NEAREST,
+ magFilter=gl.GL_NEAREST,
+ wrap=gl.GL_CLAMP_TO_EDGE)
+ self._framebuffers[context] = fbo
+ self._framebufferid = fbo.name
+
+ # Render in framebuffer
+ with self._framebuffers[context]:
+ self._renderDirect(context)
+
+ # Render framebuffer texture to screen
+ fbo = self._framebuffers[context]
+ height, width = fbo.shape
+
+ program = context.prog(*self._shaders)
+ program.use()
+
+ gl.glViewport(0, 0, width, height)
+ gl.glDisable(gl.GL_BLEND)
+ gl.glDisable(gl.GL_DEPTH_TEST)
+ gl.glDisable(gl.GL_SCISSOR_TEST)
+ # gl.glScissor(0, 0, width, height)
+ gl.glClearColor(0., 0., 0., 0.)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+ gl.glUniform1i(program.uniforms['texture'], fbo.texture.texUnit)
+ gl.glEnableVertexAttribArray(program.attributes['position'])
+ gl.glVertexAttribPointer(program.attributes['position'],
+ 4,
+ gl.GL_FLOAT,
+ gl.GL_FALSE,
+ 0,
+ self._position)
+ fbo.texture.bind()
+ gl.glDrawArrays(gl.GL_TRIANGLE_STRIP, 0, len(self._position))
+ gl.glBindTexture(gl.GL_TEXTURE_2D, 0)
diff --git a/silx/gui/plot3d/setup.py b/src/silx/gui/plot3d/setup.py
index 59c0230..59c0230 100644
--- a/silx/gui/plot3d/setup.py
+++ b/src/silx/gui/plot3d/setup.py
diff --git a/src/silx/gui/plot3d/test/__init__.py b/src/silx/gui/plot3d/test/__init__.py
new file mode 100644
index 0000000..83491ad
--- /dev/null
+++ b/src/silx/gui/plot3d/test/__init__.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""plot3d test suite."""
diff --git a/src/silx/gui/plot3d/test/testGL.py b/src/silx/gui/plot3d/test/testGL.py
new file mode 100644
index 0000000..a7309a9
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testGL.py
@@ -0,0 +1,73 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test OpenGL"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "10/08/2017"
+
+
+import logging
+import unittest
+
+from silx.gui._glutils import gl, OpenGLWidget
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestOpenGL(TestCaseQt):
+ """Tests of OpenGL widget."""
+
+ class OpenGLWidgetLogger(OpenGLWidget):
+ """Widget logging information of available OpenGL version"""
+
+ def __init__(self):
+ self._dump = False
+ super(TestOpenGL.OpenGLWidgetLogger, self).__init__(version=(1, 0))
+
+ def paintOpenGL(self):
+ """Perform the rendering and logging"""
+ if not self._dump:
+ self._dump = True
+ _logger.info('OpenGL info:')
+ _logger.info('\tQt OpenGL context version: %d.%d', *self.getOpenGLVersion())
+ _logger.info('\tGL_VERSION: %s' % gl.glGetString(gl.GL_VERSION))
+ _logger.info('\tGL_SHADING_LANGUAGE_VERSION: %s' %
+ gl.glGetString(gl.GL_SHADING_LANGUAGE_VERSION))
+ _logger.debug('\tGL_EXTENSIONS: %s' % gl.glGetString(gl.GL_EXTENSIONS))
+
+ gl.glClearColor(1., 1., 1., 1.)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+ def testOpenGL(self):
+ """Log OpenGL version using an OpenGLWidget"""
+ super(TestOpenGL, self).setUp()
+ widget = self.OpenGLWidgetLogger()
+ widget.show()
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.qWaitForWindowExposed(widget)
+ widget.close()
diff --git a/src/silx/gui/plot3d/test/testScalarFieldView.py b/src/silx/gui/plot3d/test/testScalarFieldView.py
new file mode 100644
index 0000000..e6535fc
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testScalarFieldView.py
@@ -0,0 +1,128 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test ScalarFieldView widget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.ScalarFieldView import ScalarFieldView
+from silx.gui.plot3d.SFViewParamTree import TreeView
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestScalarFieldView(TestCaseQt, ParametricTestCase):
+ """Tests of ScalarFieldView widget."""
+
+ def setUp(self):
+ super(TestScalarFieldView, self).setUp()
+ self.widget = ScalarFieldView()
+ self.widget.show()
+
+ paramTreeWidget = TreeView()
+ paramTreeWidget.setSfView(self.widget)
+
+ dock = qt.QDockWidget()
+ dock.setWidget(paramTreeWidget)
+ self.widget.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
+
+ # Commented as it slows down the tests
+ # self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ super(TestScalarFieldView, self).tearDown()
+
+ @staticmethod
+ def _buildData(size):
+ """Make a 3D dataset"""
+ coords = numpy.linspace(-10, 10, size)
+ z = coords.reshape(-1, 1, 1)
+ y = coords.reshape(1, -1, 1)
+ x = coords.reshape(1, 1, -1)
+ return numpy.sin(x * y * z) / (x * y * z)
+
+ def testSimple(self):
+ """Set the data and an isosurface"""
+ data = self._buildData(size=32)
+
+ self.widget.setData(data)
+ self.widget.addIsosurface(0.5, (1., 0., 0., 0.5))
+ self.widget.addIsosurface(0.7, qt.QColor('green'))
+ self.qapp.processEvents()
+
+ def testNotFinite(self):
+ """Test with NaN and inf in data set"""
+
+ # Some NaNs and inf
+ data = self._buildData(size=32)
+ data[8, :, :] = numpy.nan
+ data[16, :, :] = numpy.inf
+ data[24, :, :] = - numpy.inf
+
+ self.widget.addIsosurface(0.5, 'red')
+ self.widget.setData(data, copy=True)
+ self.qapp.processEvents()
+ self.widget.setData(None)
+
+ # All NaNs or inf
+ data = numpy.empty((4, 4, 4), dtype=numpy.float32)
+ for value in (numpy.nan, numpy.inf):
+ with self.subTest(value=str(value)):
+ data[:] = value
+ self.widget.setData(data, copy=True)
+ self.qapp.processEvents()
+
+ def testIsoSliderNormalization(self):
+ """Test set TreeView with a different isoslider normalization"""
+ data = self._buildData(size=32)
+
+ self.widget.setData(data)
+ self.widget.addIsosurface(0.5, (1., 0., 0., 0.5))
+ self.widget.addIsosurface(0.7, qt.QColor('green'))
+ self.qapp.processEvents()
+
+ # Add a second TreeView
+ paramTreeWidget = TreeView(self.widget)
+ paramTreeWidget.setIsoLevelSliderNormalization('arcsinh')
+ paramTreeWidget.setSfView(self.widget)
+
+ dock = qt.QDockWidget()
+ dock.setWidget(paramTreeWidget)
+ self.widget.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
diff --git a/src/silx/gui/plot3d/test/testSceneWidget.py b/src/silx/gui/plot3d/test/testSceneWidget.py
new file mode 100644
index 0000000..fc96781
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testSceneWidget.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test SceneWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2019"
+
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.SceneWidget import SceneWidget
+
+
+class TestSceneWidget(TestCaseQt, ParametricTestCase):
+ """Tests SceneWidget picking feature"""
+
+ def setUp(self):
+ super(TestSceneWidget, self).setUp()
+ self.widget = SceneWidget()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ super(TestSceneWidget, self).tearDown()
+
+ def testFogEffect(self):
+ """Test fog effect on scene primitive"""
+ image = self.widget.addImage(numpy.arange(100).reshape(10, 10))
+ scatter = self.widget.add3DScatter(*numpy.random.random(4000).reshape(4, -1))
+ scatter.setTranslation(10, 10)
+ scatter.setScale(10, 10, 10)
+
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ self.widget.setFogMode(self.widget.FogMode.LINEAR)
+ self.qapp.processEvents()
+
+ self.widget.setFogMode(self.widget.FogMode.NONE)
+ self.qapp.processEvents()
diff --git a/src/silx/gui/plot3d/test/testSceneWidgetPicking.py b/src/silx/gui/plot3d/test/testSceneWidgetPicking.py
new file mode 100644
index 0000000..d4d8db7
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testSceneWidgetPicking.py
@@ -0,0 +1,314 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test SceneWidget picking feature"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/10/2018"
+
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.SceneWidget import SceneWidget, items
+
+
+class TestSceneWidgetPicking(TestCaseQt, ParametricTestCase):
+ """Tests SceneWidget picking feature"""
+
+ def setUp(self):
+ super(TestSceneWidgetPicking, self).setUp()
+ self.widget = SceneWidget()
+ self.widget.resize(300, 300)
+ self.widget.show()
+ # self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ super(TestSceneWidgetPicking, self).tearDown()
+
+ def _widgetCenter(self):
+ """Returns widget center"""
+ size = self.widget.size()
+ return size.width() // 2, size.height() // 2
+
+ def testPickImage(self):
+ """Test picking of ImageData and ImageRgba items"""
+ imageData = items.ImageData()
+ imageData.setData(numpy.arange(100).reshape(10, 10))
+
+ imageRgba = items.ImageRgba()
+ imageRgba.setData(
+ numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3))
+
+ for item in (imageData, imageRgba):
+ with self.subTest(item=item.__class__.__name__):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ self.assertEqual(picking[0].getPositions('ndc').shape, (1, 3))
+ data = picking[0].getData()
+ self.assertEqual(len(data), 1)
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getData()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ def testPickScatter(self):
+ """Test picking of Scatter2D and Scatter3D items"""
+ data = numpy.arange(100)
+
+ scatter2d = items.Scatter2D()
+ scatter2d.setData(x=data, y=data, value=data)
+
+ scatter3d = items.Scatter3D()
+ scatter3d.setData(x=data, y=data, z=data, value=data)
+
+ for item in (scatter2d, scatter3d):
+ with self.subTest(item=item.__class__.__name__):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ nbPos = len(picking[0].getPositions('ndc'))
+ data = picking[0].getData()
+ self.assertEqual(nbPos, len(data))
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getValueData()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ def testPickVolume(self):
+ """Test picking of volume CutPlane and Isosurface items"""
+ for dtype in (numpy.float32, numpy.complex64):
+ with self.subTest(dtype=dtype):
+ refData = numpy.arange(10**3, dtype=dtype).reshape(10, 10, 10)
+ volume = self.widget.addVolume(refData)
+ if dtype == numpy.complex64:
+ volume.setComplexMode(volume.ComplexMode.REAL)
+ refData = numpy.real(refData)
+ self.widget.resetZoom('front')
+
+ cutplane = volume.getCutPlanes()[0]
+ if dtype == numpy.complex64:
+ cutplane.setComplexMode(volume.ComplexMode.REAL)
+ cutplane.getColormap().setVRange(0, 100)
+ cutplane.setNormal((0, 0, 1))
+
+ # Picking on data without anything displayed
+ cutplane.setVisible(False)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+ self.assertEqual(len(picking), 0)
+
+ # Picking on data with the cut plane
+ cutplane.setVisible(True)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), cutplane)
+ data = picking[0].getData()
+ self.assertEqual(len(data), 1)
+ self.assertEqual(picking[0].getPositions().shape, (1, 3))
+ self.assertTrue(numpy.array_equal(
+ data,
+ refData[picking[0].getIndices()]))
+
+ # Picking on data with an isosurface
+ isosurface = volume.addIsosurface(
+ level=500, color=(1., 0., 0., .5))
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+ self.assertEqual(len(picking), 2)
+ self.assertIs(picking[0].getItem(), cutplane)
+ self.assertIs(picking[1].getItem(), isosurface)
+ self.assertEqual(picking[1].getPositions().shape, (1, 3))
+ data = picking[1].getData()
+ self.assertEqual(len(data), 1)
+ self.assertTrue(numpy.array_equal(
+ data,
+ refData[picking[1].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ self.widget.clearItems()
+
+ def testPickMesh(self):
+ """Test picking of Mesh items"""
+
+ triangles = items.Mesh()
+ triangles.setData(
+ position=((0, 0, 0), (1, 0, 0), (1, 1, 0),
+ (0, 0, 0), (1, 1, 0), (0, 1, 0)),
+ color=(1, 0, 0, 1),
+ mode='triangles')
+ triangleStrip = items.Mesh()
+ triangleStrip.setData(
+ position=(((1, 0, 0), (0, 0, 0), (1, 1, 0), (0, 1, 0))),
+ color=(0, 1, 0, 1),
+ mode='triangle_strip')
+ triangleFan = items.Mesh()
+ triangleFan.setData(
+ position=((0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0)),
+ color=(0, 0, 1, 1),
+ mode='fan')
+
+ for item in (triangles, triangleStrip, triangleFan):
+ with self.subTest(mode=item.getDrawMode()):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ nbPos = len(picking[0].getPositions())
+ data = picking[0].getData()
+ self.assertEqual(nbPos, len(data))
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getPositionData()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ def testPickMeshWithIndices(self):
+ """Test picking of Mesh items defined by indices"""
+
+ triangles = items.Mesh()
+ triangles.setData(
+ position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
+ color=(1, 0, 0, 1),
+ indices=numpy.array( # dummy triangles and square
+ (0, 0, 1, 0, 1, 2, 1, 2, 3), dtype=numpy.uint8),
+ mode='triangles')
+ triangleStrip = items.Mesh()
+ triangleStrip.setData(
+ position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
+ color=(0, 1, 0, 1),
+ indices=numpy.array( # dummy triangles and square
+ (1, 0, 0, 1, 2, 3), dtype=numpy.uint8),
+ mode='triangle_strip')
+ triangleFan = items.Mesh()
+ triangleFan.setData(
+ position=((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)),
+ color=(0, 0, 1, 1),
+ indices=numpy.array( # dummy triangle, square, dummy
+ (1, 1, 0, 2, 3, 3), dtype=numpy.uint8),
+ mode='fan')
+
+ for item in (triangles, triangleStrip, triangleFan):
+ with self.subTest(mode=item.getDrawMode()):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ nbPos = len(picking[0].getPositions())
+ data = picking[0].getData()
+ self.assertEqual(nbPos, len(data))
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getPositionData()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
+
+ def testPickCylindricalMesh(self):
+ """Test picking of Box, Cylinder and Hexagon items"""
+
+ positions = numpy.array(((0., 0., 0.), (1., 1., 0.), (2., 2., 0.)))
+ box = items.Box()
+ box.setData(position=positions)
+ cylinder = items.Cylinder()
+ cylinder.setData(position=positions)
+ hexagon = items.Hexagon()
+ hexagon.setData(position=positions)
+
+ for item in (box, cylinder, hexagon):
+ with self.subTest(item=item.__class__.__name__):
+ # Add item
+ self.widget.clearItems()
+ self.widget.addItem(item)
+ self.widget.resetZoom('front')
+ self.qapp.processEvents()
+
+ # Picking on data (at widget center)
+ picking = list(self.widget.pickItems(*self._widgetCenter()))
+
+ self.assertEqual(len(picking), 1)
+ self.assertIs(picking[0].getItem(), item)
+ nbPos = len(picking[0].getPositions())
+ data = picking[0].getData()
+ print(item.__class__.__name__, [positions[1]], data)
+ self.assertTrue(numpy.all(numpy.equal(positions[1], data)))
+ self.assertEqual(nbPos, len(data))
+ self.assertTrue(numpy.array_equal(
+ data,
+ item.getPosition()[picking[0].getIndices()]))
+
+ # Picking outside data
+ picking = list(self.widget.pickItems(1, 1))
+ self.assertEqual(len(picking), 0)
diff --git a/src/silx/gui/plot3d/test/testSceneWindow.py b/src/silx/gui/plot3d/test/testSceneWindow.py
new file mode 100644
index 0000000..6b61335
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testSceneWindow.py
@@ -0,0 +1,233 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test SceneWindow"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "22/03/2019"
+
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.SceneWindow import SceneWindow
+from silx.gui.plot3d.items import HeightMapData, HeightMapRGBA
+
+class TestSceneWindow(TestCaseQt, ParametricTestCase):
+ """Tests SceneWidget picking feature"""
+
+ def setUp(self):
+ super(TestSceneWindow, self).setUp()
+ self.window = SceneWindow()
+ self.window.show()
+ self.qWaitForWindowExposed(self.window)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+ self.window.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.window.close()
+ del self.window
+ super(TestSceneWindow, self).tearDown()
+
+ def testAdd(self):
+ """Test add basic scene primitive"""
+ sceneWidget = self.window.getSceneWidget()
+ items = []
+
+ # RGB image
+ image = sceneWidget.addImage(numpy.random.random(
+ 10*10*3).astype(numpy.float32).reshape(10, 10, 3))
+ image.setLabel('RGB image')
+ items.append(image)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # Data image
+ image = sceneWidget.addImage(
+ numpy.arange(100, dtype=numpy.float32).reshape(10, 10))
+ image.setTranslation(10.)
+ items.append(image)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # 2D scatter
+ scatter = sceneWidget.add2DScatter(
+ *numpy.random.random(3000).astype(numpy.float32).reshape(3, -1),
+ index=0)
+ scatter.setTranslation(0, 10)
+ scatter.setScale(10, 10, 10)
+ items.insert(0, scatter)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # 3D scatter
+ scatter = sceneWidget.add3DScatter(
+ *numpy.random.random(4000).astype(numpy.float32).reshape(4, -1))
+ scatter.setTranslation(10, 10)
+ scatter.setScale(10, 10, 10)
+ items.append(scatter)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # 3D array of float
+ volume = sceneWidget.addVolume(
+ numpy.arange(10**3, dtype=numpy.float32).reshape(10, 10, 10))
+ volume.setTranslation(0, 0, 10)
+ volume.setRotation(45, (0, 0, 1))
+ volume.addIsosurface(500, 'red')
+ volume.getCutPlanes()[0].getColormap().setName('viridis')
+ items.append(volume)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # 3D array of complex
+ volume = sceneWidget.addVolume(
+ numpy.arange(10**3).reshape(10, 10, 10).astype(numpy.complex64))
+ volume.setTranslation(10, 0, 10)
+ volume.setRotation(45, (0, 0, 1))
+ volume.setComplexMode(volume.ComplexMode.REAL)
+ volume.addIsosurface(500, (1., 0., 0., .5))
+ items.append(volume)
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ sceneWidget.resetZoom('front')
+ self.qapp.processEvents()
+
+ def testHeightMap(self):
+ """Test height map items"""
+ sceneWidget = self.window.getSceneWidget()
+
+ height = numpy.arange(10000).reshape(100, 100) /100.
+
+ for shape in ((100, 100), (4, 5), (150, 20), (110, 110)):
+ with self.subTest(shape=shape):
+ items = []
+
+ # Colormapped data height map
+ data = numpy.arange(numpy.prod(shape)).astype(numpy.float32).reshape(shape)
+
+ heightmap = HeightMapData()
+ heightmap.setData(height)
+ heightmap.setColormappedData(data)
+ heightmap.getColormap().setName('viridis')
+ items.append(heightmap)
+ sceneWidget.addItem(heightmap)
+
+ # RGBA height map
+ colors = numpy.zeros(shape + (3,), dtype=numpy.float32)
+ colors[:, :, 1] = numpy.random.random(shape)
+
+ heightmap = HeightMapRGBA()
+ heightmap.setData(height)
+ heightmap.setColorData(colors)
+ heightmap.setTranslation(100., 0., 0.)
+ items.append(heightmap)
+ sceneWidget.addItem(heightmap)
+
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+ sceneWidget.resetZoom('front')
+ self.qapp.processEvents()
+ sceneWidget.clearItems()
+
+ def testChangeContent(self):
+ """Test add/remove/clear items"""
+ sceneWidget = self.window.getSceneWidget()
+ items = []
+
+ # Add 2 images
+ image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10)
+ items.append(sceneWidget.addImage(image))
+ items.append(sceneWidget.addImage(image))
+ self.qapp.processEvents()
+ self.assertEqual(sceneWidget.getItems(), tuple(items))
+
+ # Clear
+ sceneWidget.clearItems()
+ self.qapp.processEvents()
+ self.assertEqual(sceneWidget.getItems(), ())
+
+ # Add 2 images and remove first one
+ image = numpy.arange(100, dtype=numpy.float32).reshape(10, 10)
+ sceneWidget.addImage(image)
+ items = (sceneWidget.addImage(image),)
+ self.qapp.processEvents()
+
+ sceneWidget.removeItem(sceneWidget.getItems()[0])
+ self.qapp.processEvents()
+ self.assertEqual(sceneWidget.getItems(), items)
+
+ def testColors(self):
+ """Test setting scene colors"""
+ sceneWidget = self.window.getSceneWidget()
+
+ color = qt.QColor(128, 128, 128)
+ sceneWidget.setBackgroundColor(color)
+ self.assertEqual(sceneWidget.getBackgroundColor(), color)
+
+ color = qt.QColor(0, 0, 0)
+ sceneWidget.setForegroundColor(color)
+ self.assertEqual(sceneWidget.getForegroundColor(), color)
+
+ color = qt.QColor(255, 0, 0)
+ sceneWidget.setTextColor(color)
+ self.assertEqual(sceneWidget.getTextColor(), color)
+
+ color = qt.QColor(0, 255, 0)
+ sceneWidget.setHighlightColor(color)
+ self.assertEqual(sceneWidget.getHighlightColor(), color)
+
+ self.qapp.processEvents()
+
+ def testInteractiveMode(self):
+ """Test changing interactive mode"""
+ sceneWidget = self.window.getSceneWidget()
+ center = numpy.array((sceneWidget.width() //2, sceneWidget.height() // 2))
+
+ self.mouseMove(sceneWidget, pos=center)
+ self.mouseClick(sceneWidget, qt.Qt.LeftButton, pos=center)
+
+ volume = sceneWidget.addVolume(
+ numpy.arange(10**3).astype(numpy.float32).reshape(10, 10, 10))
+ sceneWidget.selection().setCurrentItem( volume.getCutPlanes()[0])
+ sceneWidget.resetZoom('side')
+
+ for mode in (None, 'rotate', 'pan', 'panSelectedPlane'):
+ with self.subTest(mode=mode):
+ sceneWidget.setInteractiveMode(mode)
+ self.qapp.processEvents()
+ self.assertEqual(sceneWidget.getInteractiveMode(), mode)
+
+ self.mouseMove(sceneWidget, pos=center)
+ self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center)
+ self.mouseMove(sceneWidget, pos=center-10)
+ self.mouseMove(sceneWidget, pos=center-20)
+ self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20)
+
+ self.keyPress(sceneWidget, qt.Qt.Key_Control)
+ self.mouseMove(sceneWidget, pos=center)
+ self.mousePress(sceneWidget, qt.Qt.LeftButton, pos=center)
+ self.mouseMove(sceneWidget, pos=center-10)
+ self.mouseMove(sceneWidget, pos=center-20)
+ self.mouseRelease(sceneWidget, qt.Qt.LeftButton, pos=center-20)
+ self.keyRelease(sceneWidget, qt.Qt.Key_Control)
diff --git a/src/silx/gui/plot3d/test/testStatsWidget.py b/src/silx/gui/plot3d/test/testStatsWidget.py
new file mode 100644
index 0000000..d452eb5
--- /dev/null
+++ b/src/silx/gui/plot3d/test/testStatsWidget.py
@@ -0,0 +1,201 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test silx.gui.plot.StatsWidget with SceneWidget and ScalarFieldView"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/01/2019"
+
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.plot.stats.stats import Stats
+from silx.gui import qt
+
+from silx.gui.plot.StatsWidget import BasicStatsWidget
+
+from silx.gui.plot3d.ScalarFieldView import ScalarFieldView
+from silx.gui.plot3d.SceneWidget import SceneWidget, items
+
+
+class TestSceneWidget(TestCaseQt, ParametricTestCase):
+ """Tests StatsWidget combined with SceneWidget"""
+
+ def setUp(self):
+ super(TestSceneWidget, self).setUp()
+ self.sceneWidget = SceneWidget()
+ self.sceneWidget.resize(300, 300)
+ self.sceneWidget.show()
+ self.statsWidget = BasicStatsWidget()
+ self.statsWidget.setPlot(self.sceneWidget)
+ # self.qWaitForWindowExposed(self.sceneWidget)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.qapp.processEvents()
+ self.sceneWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.sceneWidget.close()
+ del self.sceneWidget
+ self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.statsWidget.close()
+ del self.statsWidget
+ super(TestSceneWidget, self).tearDown()
+
+ def test(self):
+ """Test StatsWidget with SceneWidget"""
+ # Prepare scene
+
+ # Data image
+ image = self.sceneWidget.addImage(numpy.arange(100).reshape(10, 10))
+ image.setLabel('Image')
+ # RGB image
+ imageRGB = self.sceneWidget.addImage(
+ numpy.arange(300, dtype=numpy.uint8).reshape(10, 10, 3))
+ imageRGB.setLabel('RGB Image')
+ # 2D scatter
+ data = numpy.arange(100)
+ scatter2D = self.sceneWidget.add2DScatter(x=data, y=data, value=data)
+ scatter2D.setLabel('2D Scatter')
+ # 3D scatter
+ scatter3D = self.sceneWidget.add3DScatter(x=data, y=data, z=data, value=data)
+ scatter3D.setLabel('3D Scatter')
+ # Add a group
+ group = items.GroupItem()
+ self.sceneWidget.addItem(group)
+ # 3D scalar field
+ data = numpy.arange(64**3).reshape(64, 64, 64)
+ scalarField = items.ScalarField3D()
+ scalarField.setData(data, copy=False)
+ scalarField.setLabel('3D Scalar field')
+ group.addItem(scalarField)
+
+ statsTable = self.statsWidget._getStatsTable()
+
+ # Test selection only
+ self.statsWidget.setDisplayOnlyActiveItem(True)
+ self.assertEqual(statsTable.rowCount(), 0)
+
+ self.sceneWidget.selection().setCurrentItem(group)
+ self.assertEqual(statsTable.rowCount(), 0)
+
+ for item in (image, scatter2D, scatter3D, scalarField):
+ with self.subTest('selection only', item=item.getLabel()):
+ self.sceneWidget.selection().setCurrentItem(item)
+ self.assertEqual(statsTable.rowCount(), 1)
+ self._checkItem(item)
+
+ # Test all data
+ self.statsWidget.setDisplayOnlyActiveItem(False)
+ self.assertEqual(statsTable.rowCount(), 4)
+
+ for item in (image, scatter2D, scatter3D, scalarField):
+ with self.subTest('all items', item=item.getLabel()):
+ self._checkItem(item)
+
+ def _checkItem(self, item):
+ """Check that item is in StatsTable and that stats are OK
+
+ :param silx.gui.plot3d.items.Item3D item:
+ """
+ if isinstance(item, (items.Scatter2D, items.Scatter3D)):
+ data = item.getValueData(copy=False)
+ else:
+ data = item.getData(copy=False)
+
+ statsTable = self.statsWidget._getStatsTable()
+ tableItems = statsTable._itemToTableItems(item)
+ self.assertTrue(len(tableItems) > 0)
+ self.assertEqual(tableItems['legend'].text(), item.getLabel())
+ self.assertEqual(float(tableItems['min'].text()), numpy.min(data))
+ self.assertEqual(float(tableItems['max'].text()), numpy.max(data))
+ # TODO
+
+
+class TestScalarFieldView(TestCaseQt):
+ """Tests StatsWidget combined with ScalarFieldView"""
+
+ def setUp(self):
+ super(TestScalarFieldView, self).setUp()
+ self.scalarFieldView = ScalarFieldView()
+ self.scalarFieldView.resize(300, 300)
+ self.scalarFieldView.show()
+ self.statsWidget = BasicStatsWidget()
+ self.statsWidget.setPlot(self.scalarFieldView)
+ # self.qWaitForWindowExposed(self.sceneWidget)
+
+ def tearDown(self):
+ Stats._getContext.cache_clear()
+ self.qapp.processEvents()
+ self.scalarFieldView.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.scalarFieldView.close()
+ del self.scalarFieldView
+ self.statsWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.statsWidget.close()
+ del self.statsWidget
+ super(TestScalarFieldView, self).tearDown()
+
+ def _getTextFor(self, row, name):
+ """Returns text in table at given row for column name
+
+ :param int row: Row number in the table
+ :param str name: Column id
+ :rtype: Union[str,None]
+ """
+ statsTable = self.statsWidget._getStatsTable()
+
+ for column in range(statsTable.columnCount()):
+ headerItem = statsTable.horizontalHeaderItem(column)
+ if headerItem.data(qt.Qt.UserRole) == name:
+ tableItem = statsTable.item(row, column)
+ return tableItem.text()
+
+ return None
+
+ def test(self):
+ """Test StatsWidget with ScalarFieldView"""
+ data = numpy.arange(64**3, dtype=numpy.float64).reshape(64, 64, 64)
+ self.scalarFieldView.setData(data)
+
+ statsTable = self.statsWidget._getStatsTable()
+
+ # Test selection only
+ self.statsWidget.setDisplayOnlyActiveItem(True)
+ self.assertEqual(statsTable.rowCount(), 1)
+
+ # Test all data
+ self.statsWidget.setDisplayOnlyActiveItem(False)
+ self.assertEqual(statsTable.rowCount(), 1)
+
+ for column in range(statsTable.columnCount()):
+ self.assertEqual(float(self._getTextFor(0, 'min')), numpy.min(data))
+ self.assertEqual(float(self._getTextFor(0, 'max')), numpy.max(data))
+ sum_ = numpy.sum(data)
+ comz = numpy.sum(numpy.arange(data.shape[0]) * numpy.sum(data, axis=(1, 2))) / sum_
+ comy = numpy.sum(numpy.arange(data.shape[1]) * numpy.sum(data, axis=(0, 2))) / sum_
+ comx = numpy.sum(numpy.arange(data.shape[2]) * numpy.sum(data, axis=(0, 1))) / sum_
+ self.assertEqual(self._getTextFor(0, 'COM'), str((comx, comy, comz)))
diff --git a/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py b/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py
new file mode 100644
index 0000000..146c2cd
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/GroupPropertiesWidget.py
@@ -0,0 +1,202 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+""":class:`GroupPropertiesWidget` allows to reset properties in a GroupItem."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "24/04/2018"
+
+from ....gui import qt
+from ....gui.colors import Colormap
+from ....gui.dialog.ColormapDialog import ColormapDialog
+
+from ..items import SymbolMixIn, ColormapMixIn
+
+
+class GroupPropertiesWidget(qt.QWidget):
+ """Set properties of all items in a :class:`GroupItem`
+
+ :param QWidget parent:
+ """
+
+ MAX_MARKER_SIZE = 20
+ """Maximum value for marker size"""
+
+ MAX_LINE_WIDTH = 10
+ """Maximum value for line width"""
+
+ def __init__(self, parent=None):
+ super(GroupPropertiesWidget, self).__init__(parent)
+ self._group = None
+ self.setEnabled(False)
+
+ # Set widgets
+ layout = qt.QFormLayout(self)
+ self.setLayout(layout)
+
+ # Colormap
+ colormapButton = qt.QPushButton('Set...')
+ colormapButton.setToolTip("Set colormap for all items")
+ colormapButton.clicked.connect(self._colormapButtonClicked)
+ layout.addRow('Colormap', colormapButton)
+
+ self._markerComboBox = qt.QComboBox(self)
+ self._markerComboBox.addItems(SymbolMixIn.getSupportedSymbolNames())
+
+ # Marker
+ markerButton = qt.QPushButton('Set')
+ markerButton.setToolTip("Set marker for all items")
+ markerButton.clicked.connect(self._markerButtonClicked)
+
+ markerLayout = qt.QHBoxLayout()
+ markerLayout.setContentsMargins(0, 0, 0, 0)
+ markerLayout.addWidget(self._markerComboBox, 1)
+ markerLayout.addWidget(markerButton, 0)
+
+ layout.addRow('Marker', markerLayout)
+
+ # Marker size
+ self._markerSizeSlider = qt.QSlider()
+ self._markerSizeSlider.setOrientation(qt.Qt.Horizontal)
+ self._markerSizeSlider.setSingleStep(1)
+ self._markerSizeSlider.setRange(1, self.MAX_MARKER_SIZE)
+ self._markerSizeSlider.setValue(1)
+
+ markerSizeButton = qt.QPushButton('Set')
+ markerSizeButton.setToolTip("Set marker size for all items")
+ markerSizeButton.clicked.connect(self._markerSizeButtonClicked)
+
+ markerSizeLayout = qt.QHBoxLayout()
+ markerSizeLayout.setContentsMargins(0, 0, 0, 0)
+ markerSizeLayout.addWidget(qt.QLabel('1'))
+ markerSizeLayout.addWidget(self._markerSizeSlider, 1)
+ markerSizeLayout.addWidget(qt.QLabel(str(self.MAX_MARKER_SIZE)))
+ markerSizeLayout.addWidget(markerSizeButton, 0)
+
+ layout.addRow('Marker Size', markerSizeLayout)
+
+ # Line width
+ self._lineWidthSlider = qt.QSlider()
+ self._lineWidthSlider.setOrientation(qt.Qt.Horizontal)
+ self._lineWidthSlider.setSingleStep(1)
+ self._lineWidthSlider.setRange(1, self.MAX_LINE_WIDTH)
+ self._lineWidthSlider.setValue(1)
+
+ lineWidthButton = qt.QPushButton('Set')
+ lineWidthButton.setToolTip("Set line width for all items")
+ lineWidthButton.clicked.connect(self._lineWidthButtonClicked)
+
+ lineWidthLayout = qt.QHBoxLayout()
+ lineWidthLayout.setContentsMargins(0, 0, 0, 0)
+ lineWidthLayout.addWidget(qt.QLabel('1'))
+ lineWidthLayout.addWidget(self._lineWidthSlider, 1)
+ lineWidthLayout.addWidget(qt.QLabel(str(self.MAX_LINE_WIDTH)))
+ lineWidthLayout.addWidget(lineWidthButton, 0)
+
+ layout.addRow('Line Width', lineWidthLayout)
+
+ self._colormapDialog = None # To store dialog
+ self._colormap = Colormap()
+
+ def getGroup(self):
+ """Returns the :class:`GroupItem` this widget is attached to.
+
+ :rtype: Union[GroupItem, None]
+ """
+ return self._group
+
+ def setGroup(self, group):
+ """Set the :class:`GroupItem` this widget is attached to.
+
+ :param GroupItem group: GroupItem to control (or None)
+ """
+ self._group = group
+ if group is not None:
+ self.setEnabled(True)
+
+ def _colormapButtonClicked(self, checked=False):
+ """Handle colormap button clicked"""
+ group = self.getGroup()
+ if group is None:
+ return
+
+ if self._colormapDialog is None:
+ self._colormapDialog = ColormapDialog(self)
+ self._colormapDialog.setColormap(self._colormap)
+
+ previousColormap = self._colormapDialog.getColormap()
+ if self._colormapDialog.exec():
+ colormap = self._colormapDialog.getColormap()
+
+ for item in group.visit():
+ if isinstance(item, ColormapMixIn):
+ itemCmap = item.getColormap()
+ cmapName = colormap.getName()
+ if cmapName is not None:
+ itemCmap.setName(colormap.getName())
+ else:
+ itemCmap.setColormapLUT(colormap.getColormapLUT())
+ itemCmap.setNormalization(colormap.getNormalization())
+ itemCmap.setGammaNormalizationParameter(
+ colormap.getGammaNormalizationParameter())
+ itemCmap.setVRange(colormap.getVMin(), colormap.getVMax())
+ else:
+ # Reset colormap
+ self._colormapDialog.setColormap(previousColormap)
+
+ def _markerButtonClicked(self, checked=False):
+ """Handle marker set button clicked"""
+ group = self.getGroup()
+ if group is None:
+ return
+
+ marker = self._markerComboBox.currentText()
+ for item in group.visit():
+ if isinstance(item, SymbolMixIn):
+ item.setSymbol(marker)
+
+ def _markerSizeButtonClicked(self, checked=False):
+ """Handle marker size set button clicked"""
+ group = self.getGroup()
+ if group is None:
+ return
+
+ markerSize = self._markerSizeSlider.value()
+ for item in group.visit():
+ if isinstance(item, SymbolMixIn):
+ item.setSymbolSize(markerSize)
+
+ def _lineWidthButtonClicked(self, checked=False):
+ """Handle line width set button clicked"""
+ group = self.getGroup()
+ if group is None:
+ return
+
+ lineWidth = self._lineWidthSlider.value()
+ for item in group.visit():
+ if hasattr(item, 'setLineWidth'):
+ item.setLineWidth(lineWidth)
diff --git a/src/silx/gui/plot3d/tools/PositionInfoWidget.py b/src/silx/gui/plot3d/tools/PositionInfoWidget.py
new file mode 100644
index 0000000..99d6356
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/PositionInfoWidget.py
@@ -0,0 +1,225 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a widget that displays data values of a SceneWidget.
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/10/2018"
+
+
+import logging
+import weakref
+
+from ... import qt
+from .. import actions
+from .. import items
+from ..items import volume
+from ..SceneWidget import SceneWidget
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PositionInfoWidget(qt.QWidget):
+ """Widget displaying information about picked position
+
+ :param QWidget parent: See :class:`QWidget`
+ """
+
+ def __init__(self, parent=None):
+ super(PositionInfoWidget, self).__init__(parent)
+ self._sceneWidgetRef = None
+
+ self.setToolTip("Double-click on a data point to show its value")
+ layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight, self)
+
+ self._xLabel = self._addInfoField('X')
+ self._yLabel = self._addInfoField('Y')
+ self._zLabel = self._addInfoField('Z')
+ self._dataLabel = self._addInfoField('Data')
+ self._itemLabel = self._addInfoField('Item')
+
+ layout.addStretch(1)
+
+ self._action = actions.mode.PickingModeAction(parent=self)
+ self._action.setText('Selection')
+ self._action.setToolTip(
+ 'Toggle selection information update with left button click')
+ self._action.sigSceneClicked.connect(self.pick)
+ self._action.changed.connect(self.__actionChanged)
+ self._action.setChecked(False) # Disabled by default
+ self.__actionChanged() # Sync action/widget
+
+ def __actionChanged(self):
+ """Handle toggle action change signal"""
+ if self.toggleAction().isChecked() != self.isEnabled():
+ self.setEnabled(self.toggleAction().isChecked())
+
+ def toggleAction(self):
+ """The action to toggle the picking mode.
+
+ :rtype: QAction
+ """
+ return self._action
+
+ def _addInfoField(self, label):
+ """Add a description: info widget to this widget
+
+ :param str label: Description label
+ :return: The QLabel used to display the info
+ :rtype: QLabel
+ """
+ subLayout = qt.QHBoxLayout()
+ subLayout.setContentsMargins(0, 0, 0, 0)
+
+ subLayout.addWidget(qt.QLabel(label + ':'))
+
+ widget = qt.QLabel('-')
+ widget.setAlignment(qt.Qt.AlignLeft | qt.Qt.AlignVCenter)
+ widget.setTextInteractionFlags(qt.Qt.TextSelectableByMouse)
+
+ metrics = widget.fontMetrics()
+ if qt.BINDING in ('PySide2', 'PyQt5'):
+ width = metrics.width("#######")
+ else: # Qt6
+ width = metrics.horizontalAdvance("#######")
+ widget.setMinimumWidth(width)
+ subLayout.addWidget(widget)
+
+ subLayout.addStretch(1)
+
+ layout = self.layout()
+ layout.addLayout(subLayout)
+ return widget
+
+ def getSceneWidget(self):
+ """Returns the associated :class:`SceneWidget` or None.
+
+ :rtype: Union[None,~silx.gui.plot3d.SceneWidget.SceneWidget]
+ """
+ if self._sceneWidgetRef is None:
+ return None
+ else:
+ return self._sceneWidgetRef()
+
+ def setSceneWidget(self, widget):
+ """Set the associated :class:`SceneWidget`
+
+ :param ~silx.gui.plot3d.SceneWidget.SceneWidget widget:
+ 3D scene for which to display information
+ """
+ if widget is not None and not isinstance(widget, SceneWidget):
+ raise ValueError("widget must be a SceneWidget or None")
+
+ self._sceneWidgetRef = None if widget is None else weakref.ref(widget)
+
+ self.toggleAction().setPlot3DWidget(widget)
+
+ def clear(self):
+ """Clean-up displayed values"""
+ for widget in (self._xLabel, self._yLabel, self._zLabel,
+ self._dataLabel, self._itemLabel):
+ widget.setText('-')
+
+ _SUPPORTED_ITEMS = (items.Scatter3D,
+ items.Scatter2D,
+ items.ImageData,
+ items.ImageRgba,
+ items.HeightMapData,
+ items.HeightMapRGBA,
+ items.Mesh,
+ items.Box,
+ items.Cylinder,
+ items.Hexagon,
+ volume.CutPlane,
+ volume.Isosurface)
+ """Type of items that are picked"""
+
+ def _isSupportedItem(self, item):
+ """Returns True if item is of supported type
+
+ :param Item3D item: The Item3D to check
+ :rtype: bool
+ """
+ return isinstance(item, self._SUPPORTED_ITEMS)
+
+ def pick(self, x, y):
+ """Pick items in the associated SceneWidget and display result
+
+ Only the closest point is displayed.
+
+ :param int x: X coordinate in pixel in the SceneWidget
+ :param int y: Y coordinate in pixel in the SceneWidget
+ """
+ self.clear()
+
+ sceneWidget = self.getSceneWidget()
+ if sceneWidget is None: # No associated widget
+ _logger.info('Picking without associated SceneWidget')
+ return
+
+ # Find closest (and latest in the tree) supported item
+ closestNdcZ = float('inf')
+ picking = None
+ for result in sceneWidget.pickItems(x, y,
+ condition=self._isSupportedItem):
+ ndcZ = result.getPositions('ndc', copy=False)[0, 2]
+ if ndcZ <= closestNdcZ:
+ closestNdcZ = ndcZ
+ picking = result
+
+ if picking is None:
+ return # No picked item
+
+ item = picking.getItem()
+ self._itemLabel.setText(item.getLabel())
+ positions = picking.getPositions('scene', copy=False)
+ x, y, z = positions[0]
+ self._xLabel.setText("%g" % x)
+ self._yLabel.setText("%g" % y)
+ self._zLabel.setText("%g" % z)
+
+ data = picking.getData(copy=False)
+ if data is not None:
+ data = data[0]
+ if hasattr(data, '__len__'):
+ text = ' '.join(["%.3g"] * len(data)) % tuple(data)
+ else:
+ text = "%g" % data
+ self._dataLabel.setText(text)
+
+ def updateInfo(self):
+ """Update information according to cursor position"""
+ widget = self.getSceneWidget()
+ if widget is None:
+ _logger.info('Update without associated SceneWidget')
+ self.clear()
+ return
+
+ position = widget.mapFromGlobal(qt.QCursor.pos())
+ self.pick(position.x(), position.y())
diff --git a/silx/gui/plot3d/tools/ViewpointTools.py b/src/silx/gui/plot3d/tools/ViewpointTools.py
index 0607382..0607382 100644
--- a/silx/gui/plot3d/tools/ViewpointTools.py
+++ b/src/silx/gui/plot3d/tools/ViewpointTools.py
diff --git a/silx/gui/plot3d/tools/__init__.py b/src/silx/gui/plot3d/tools/__init__.py
index c8b8d21..c8b8d21 100644
--- a/silx/gui/plot3d/tools/__init__.py
+++ b/src/silx/gui/plot3d/tools/__init__.py
diff --git a/src/silx/gui/plot3d/tools/test/__init__.py b/src/silx/gui/plot3d/tools/test/__init__.py
new file mode 100644
index 0000000..86741ed
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/test/__init__.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""plot3d tools test suite."""
diff --git a/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py b/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
new file mode 100644
index 0000000..17fb3db
--- /dev/null
+++ b/src/silx/gui/plot3d/tools/test/testPositionInfoWidget.py
@@ -0,0 +1,89 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+# ###########################################################################*/
+"""Test PositionInfoWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/10/2018"
+
+
+import unittest
+
+import numpy
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+from silx.gui.plot3d.SceneWidget import SceneWidget
+from silx.gui.plot3d.tools.PositionInfoWidget import PositionInfoWidget
+
+
+class TestPositionInfoWidget(TestCaseQt):
+ """Tests PositionInfoWidget"""
+
+ def setUp(self):
+ super(TestPositionInfoWidget, self).setUp()
+ self.sceneWidget = SceneWidget()
+ self.sceneWidget.resize(300, 300)
+ self.sceneWidget.show()
+
+ self.positionInfoWidget = PositionInfoWidget()
+ self.positionInfoWidget.setSceneWidget(self.sceneWidget)
+ self.positionInfoWidget.show()
+ self.qWaitForWindowExposed(self.positionInfoWidget)
+
+ # self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.qapp.processEvents()
+
+ self.sceneWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.sceneWidget.close()
+ del self.sceneWidget
+
+ self.positionInfoWidget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.positionInfoWidget.close()
+ del self.positionInfoWidget
+ super(TestPositionInfoWidget, self).tearDown()
+
+ def test(self):
+ """Test PositionInfoWidget"""
+ self.assertIs(self.positionInfoWidget.getSceneWidget(),
+ self.sceneWidget)
+
+ data = numpy.arange(100)
+ self.sceneWidget.add2DScatter(x=data, y=data, value=data)
+ self.sceneWidget.resetZoom('front')
+
+ # Double click at the center
+ self.mouseDClick(self.sceneWidget, button=qt.Qt.LeftButton)
+
+ # Clear displayed value
+ self.positionInfoWidget.clear()
+
+ # Update info from API
+ self.positionInfoWidget.pick(x=10, y=10)
+
+ # Remove SceneWidget
+ self.positionInfoWidget.setSceneWidget(None)
diff --git a/silx/gui/plot3d/tools/toolbars.py b/src/silx/gui/plot3d/tools/toolbars.py
index d4f32db..d4f32db 100644
--- a/silx/gui/plot3d/tools/toolbars.py
+++ b/src/silx/gui/plot3d/tools/toolbars.py
diff --git a/silx/gui/plot3d/utils/__init__.py b/src/silx/gui/plot3d/utils/__init__.py
index 99d3e08..99d3e08 100644
--- a/silx/gui/plot3d/utils/__init__.py
+++ b/src/silx/gui/plot3d/utils/__init__.py
diff --git a/silx/gui/plot3d/utils/mng.py b/src/silx/gui/plot3d/utils/mng.py
index 8049a2f..8049a2f 100644
--- a/silx/gui/plot3d/utils/mng.py
+++ b/src/silx/gui/plot3d/utils/mng.py
diff --git a/silx/gui/printer.py b/src/silx/gui/printer.py
index 761fa0f..761fa0f 100644
--- a/silx/gui/printer.py
+++ b/src/silx/gui/printer.py
diff --git a/src/silx/gui/qt/__init__.py b/src/silx/gui/qt/__init__.py
new file mode 100644
index 0000000..915c89b
--- /dev/null
+++ b/src/silx/gui/qt/__init__.py
@@ -0,0 +1,54 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Common wrapper over Python Qt bindings:
+
+- `PyQt5 <http://pyqt.sourceforge.net/Docs/PyQt5/>`_
+- `PySide2 <https://pypi.org/project/PySide2/>`_
+- `PySide6 <https://pypi.org/project/PySide6/>`_
+
+If a Qt binding is already loaded, it will use it, otherwise the different
+Qt bindings are tried in this order: PyQt5, PySide2, PySide6.
+
+The name of the loaded Qt binding is stored in the BINDING variable.
+
+This module provides a flat namespace over Qt bindings by importing
+all symbols from **QtCore**, **QtGui**, **QtWidgets** and **QtPrintSupport**
+packages and if available from **QtOpenGL** and **QtSvg** packages.
+
+Example of using :mod:`silx.gui.qt` module:
+
+>>> from silx.gui import qt
+>>> app = qt.QApplication([])
+>>> widget = qt.QWidget()
+
+For an alternative solution providing a structured namespace,
+see `qtpy <https://pypi.org/project/QtPy/>`_.
+"""
+
+from ._qt import * # noqa
+if BINDING in ('PySide2', 'PySide6'):
+ # Import loadUi wrapper
+ from ._pyside_dynamic import loadUi # noqa
+from ._utils import * # noqa
diff --git a/src/silx/gui/qt/_pyside_dynamic.py b/src/silx/gui/qt/_pyside_dynamic.py
new file mode 100644
index 0000000..a841eae
--- /dev/null
+++ b/src/silx/gui/qt/_pyside_dynamic.py
@@ -0,0 +1,235 @@
+# -*- coding: utf-8 -*-
+
+# Taken from: https://gist.github.com/cpbotha/1b42a20c8f3eb9bb7cb8
+# Plus: https://github.com/spyder-ide/qtpy/commit/001a862c401d757feb63025f88dbb4601d353c84
+
+# Copyright (c) 2011 Sebastian Wiesner <lunaryorn@gmail.com>
+# Modifications by Charl Botha <cpbotha@vxlabs.com>
+# * customWidgets support (registerCustomWidget() causes segfault in
+# pyside 1.1.2 on Ubuntu 12.04 x86_64)
+# * workingDirectory support in loadUi
+
+# found this here:
+# https://github.com/lunaryorn/snippets/blob/master/qt4/designer/pyside_dynamic.py
+
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+"""
+ How to load a user interface dynamically with PySide.
+
+ .. moduleauthor:: Sebastian Wiesner <lunaryorn@gmail.com>
+"""
+
+import logging
+
+from ._qt import BINDING
+if BINDING == 'PySide2':
+ from PySide2.QtCore import QMetaObject, Property, Qt
+ from PySide2.QtWidgets import QFrame
+ from PySide2.QtUiTools import QUiLoader
+elif BINDING == 'PySide6':
+ from PySide6.QtCore import QMetaObject, Property, Qt
+ from PySide6.QtWidgets import QFrame
+ from PySide6.QtUiTools import QUiLoader
+else:
+ raise RuntimeError("Unsupported Qt binding: %s", BINDING)
+
+
+_logger = logging.getLogger(__name__)
+
+
+class UiLoader(QUiLoader):
+ """
+ Subclass :class:`~PySide.QtUiTools.QUiLoader` to create the user interface
+ in a base instance.
+
+ Unlike :class:`~PySide.QtUiTools.QUiLoader` itself this class does not
+ create a new instance of the top-level widget, but creates the user
+ interface in an existing instance of the top-level class.
+
+ This mimics the behaviour of :func:`PyQt*.uic.loadUi`.
+ """
+
+ def __init__(self, baseinstance, customWidgets=None):
+ """
+ Create a loader for the given ``baseinstance``.
+
+ The user interface is created in ``baseinstance``, which must be an
+ instance of the top-level class in the user interface to load, or a
+ subclass thereof.
+
+ ``customWidgets`` is a dictionary mapping from class name to class
+ object for widgets that you've promoted in the Qt Designer
+ interface. Usually, this should be done by calling
+ registerCustomWidget on the QUiLoader, but
+ with PySide 1.1.2 on Ubuntu 12.04 x86_64 this causes a segfault.
+
+ ``parent`` is the parent object of this loader.
+ """
+
+ QUiLoader.__init__(self, baseinstance)
+ self.baseinstance = baseinstance
+ self.customWidgets = {}
+ self.uifile = None
+ self.customWidgets.update(customWidgets)
+
+ def createWidget(self, class_name, parent=None, name=''):
+ """
+ Function that is called for each widget defined in ui file,
+ overridden here to populate baseinstance instead.
+ """
+
+ if parent is None and self.baseinstance:
+ # supposed to create the top-level widget, return the base instance
+ # instead
+ return self.baseinstance
+
+ else:
+ if class_name in self.availableWidgets():
+ # create a new widget for child widgets
+ widget = QUiLoader.createWidget(self, class_name, parent, name)
+
+ else:
+ # if not in the list of availableWidgets,
+ # must be a custom widget
+ # this will raise KeyError if the user has not supplied the
+ # relevant class_name in the dictionary, or TypeError, if
+ # customWidgets is None
+ if class_name not in self.customWidgets:
+ raise Exception('No custom widget ' + class_name +
+ ' found in customWidgets param of' +
+ 'UiFile %s.' % self.uifile)
+ try:
+ widget = self.customWidgets[class_name](parent)
+ except Exception:
+ _logger.error("Fail to instanciate widget %s from file %s", class_name, self.uifile)
+ raise
+
+ if self.baseinstance:
+ # set an attribute for the new child widget on the base
+ # instance, just like PyQt*.uic.loadUi does.
+ setattr(self.baseinstance, name, widget)
+
+ # this outputs the various widget names, e.g.
+ # sampleGraphicsView, dockWidget, samplesTableView etc.
+ # print(name)
+
+ return widget
+
+ def _parse_custom_widgets(self, ui_file):
+ """
+ This function is used to parse a ui file and look for the <customwidgets>
+ section, then automatically load all the custom widget classes.
+ """
+ import importlib
+ from xml.etree.ElementTree import ElementTree
+
+ # Parse the UI file
+ etree = ElementTree()
+ ui = etree.parse(ui_file)
+
+ # Get the customwidgets section
+ custom_widgets = ui.find('customwidgets')
+
+ if custom_widgets is None:
+ return
+
+ custom_widget_classes = {}
+
+ for custom_widget in custom_widgets.getchildren():
+
+ cw_class = custom_widget.find('class').text
+ cw_header = custom_widget.find('header').text
+
+ module = importlib.import_module(cw_header)
+
+ custom_widget_classes[cw_class] = getattr(module, cw_class)
+
+ self.customWidgets.update(custom_widget_classes)
+
+ def load(self, uifile):
+ self._parse_custom_widgets(uifile)
+ self.uifile = uifile
+ return QUiLoader.load(self, uifile)
+
+
+class _Line(QFrame):
+ """Widget to use as 'Line' Qt designer"""
+ def __init__(self, parent=None):
+ super(_Line, self).__init__(parent)
+ self.setFrameShape(QFrame.HLine)
+ self.setFrameShadow(QFrame.Sunken)
+
+ def getOrientation(self):
+ shape = self.frameShape()
+ if shape == QFrame.HLine:
+ return Qt.Horizontal
+ elif shape == QFrame.VLine:
+ return Qt.Vertical
+ else:
+ raise RuntimeError("Wrong shape: %d", shape)
+
+ def setOrientation(self, orientation):
+ if orientation == Qt.Horizontal:
+ self.setFrameShape(QFrame.HLine)
+ elif orientation == Qt.Vertical:
+ self.setFrameShape(QFrame.VLine)
+ else:
+ raise ValueError("Unsupported orientation %s" % str(orientation))
+
+ orientation = Property("Qt::Orientation", getOrientation, setOrientation)
+
+
+CUSTOM_WIDGETS = {"Line": _Line}
+"""Default custom widgets for `loadUi`"""
+
+
+def loadUi(uifile, baseinstance=None, package=None, resource_suffix=None):
+ """
+ Dynamically load a user interface from the given ``uifile``.
+
+ ``uifile`` is a string containing a file name of the UI file to load.
+
+ If ``baseinstance`` is ``None``, the a new instance of the top-level widget
+ will be created. Otherwise, the user interface is created within the given
+ ``baseinstance``. In this case ``baseinstance`` must be an instance of the
+ top-level widget class in the UI file to load, or a subclass thereof. In
+ other words, if you've created a ``QMainWindow`` interface in the designer,
+ ``baseinstance`` must be a ``QMainWindow`` or a subclass thereof, too. You
+ cannot load a ``QMainWindow`` UI file with a plain
+ :class:`~PySide.QtGui.QWidget` as ``baseinstance``.
+
+ :method:`~PySide.QtCore.QMetaObject.connectSlotsByName()` is called on the
+ created user interface, so you can implemented your slots according to its
+ conventions in your widget class.
+
+ Return ``baseinstance``, if ``baseinstance`` is not ``None``. Otherwise
+ return the newly created instance of the user interface.
+ """
+ if package is not None:
+ _logger.warning(
+ "loadUi package parameter not implemented with PySide")
+ if resource_suffix is not None:
+ _logger.warning(
+ "loadUi resource_suffix parameter not implemented with PySide")
+
+ loader = UiLoader(baseinstance, customWidgets=CUSTOM_WIDGETS)
+ widget = loader.load(uifile)
+ QMetaObject.connectSlotsByName(widget)
+ return widget
diff --git a/src/silx/gui/qt/_qt.py b/src/silx/gui/qt/_qt.py
new file mode 100644
index 0000000..f62f4c8
--- /dev/null
+++ b/src/silx/gui/qt/_qt.py
@@ -0,0 +1,232 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Load Qt binding"""
+
+__authors__ = ["V.A. Sole"]
+__license__ = "MIT"
+__date__ = "23/05/2018"
+
+
+import logging
+import sys
+import traceback
+
+
+_logger = logging.getLogger(__name__)
+
+
+BINDING = None
+"""The name of the Qt binding in use: PyQt5, PySide2, PySide6."""
+
+QtBinding = None # noqa
+"""The Qt binding module in use: PyQt5, PySide2, PySide6."""
+
+HAS_SVG = False
+"""True if Qt provides support for Scalable Vector Graphics (QtSVG)."""
+
+HAS_OPENGL = False
+"""True if Qt provides support for OpenGL (QtOpenGL)."""
+
+# First check for an already loaded wrapper
+for _binding in ('PySide2', 'PyQt5', 'PySide6'):
+ if _binding + '.QtCore' in sys.modules:
+ BINDING = _binding
+ break
+else: # Then try Qt bindings
+ try:
+ import PyQt5.QtCore # noqa
+ except ImportError:
+ if 'PyQt5' in sys.modules:
+ del sys.modules["PyQt5"]
+ try:
+ import PySide2.QtCore # noqa
+ except ImportError:
+ if 'PySide2' in sys.modules:
+ del sys.modules["PySide2"]
+ try:
+ import PySide6.QtCore # noqa
+ except ImportError:
+ if 'PySide6' in sys.modules:
+ del sys.modules["PySide6"]
+ raise ImportError(
+ 'No Qt wrapper found. Install PyQt5, PySide2, PySide6.')
+ else:
+ BINDING = 'PySide6'
+ else:
+ BINDING = 'PySide2'
+ else:
+ BINDING = 'PyQt5'
+
+
+if BINDING == 'PyQt5':
+ _logger.debug('Using PyQt5 bindings')
+
+ import PyQt5 as QtBinding # noqa
+
+ from PyQt5.QtCore import * # noqa
+ from PyQt5.QtGui import * # noqa
+ from PyQt5.QtWidgets import * # noqa
+ from PyQt5.QtPrintSupport import * # noqa
+
+ try:
+ from PyQt5.QtOpenGL import * # noqa
+ except ImportError:
+ _logger.info("PyQt5.QtOpenGL not available")
+ HAS_OPENGL = False
+ else:
+ HAS_OPENGL = True
+
+ try:
+ from PyQt5.QtSvg import * # noqa
+ except ImportError:
+ _logger.info("PyQt5.QtSvg not available")
+ HAS_SVG = False
+ else:
+ HAS_SVG = True
+
+ from PyQt5.uic import loadUi # noqa
+
+ Signal = pyqtSignal
+
+ Property = pyqtProperty
+
+ Slot = pyqtSlot
+
+ # Disable PyQt5's cooperative multi-inheritance since other bindings do not provide it.
+ # See https://www.riverbankcomputing.com/static/Docs/PyQt5/multiinheritance.html?highlight=inheritance
+ class _Foo(object): pass
+ class QObject(QObject, _Foo): pass
+
+
+elif BINDING == 'PySide2':
+ _logger.debug('Using PySide2 bindings')
+
+ import PySide2 as QtBinding # noqa
+
+ from PySide2.QtCore import * # noqa
+ from PySide2.QtGui import * # noqa
+ from PySide2.QtWidgets import * # noqa
+ from PySide2.QtPrintSupport import * # noqa
+
+ try:
+ from PySide2.QtOpenGL import * # noqa
+ except ImportError:
+ _logger.info("PySide2.QtOpenGL not available")
+ HAS_OPENGL = False
+ else:
+ HAS_OPENGL = True
+
+ try:
+ from PySide2.QtSvg import * # noqa
+ except ImportError:
+ _logger.info("PySide2.QtSvg not available")
+ HAS_SVG = False
+ else:
+ HAS_SVG = True
+
+ pyqtSignal = Signal
+
+ # Qt6 compatibility:
+ # with PySide2 `exec` method has a special behavior
+ class _ExecMixIn:
+ """Mix-in class providind `exec` compatibility"""
+ def exec(self, *args, **kwargs):
+ return super().exec_(*args, **kwargs)
+
+ # QtWidgets
+ class QApplication(_ExecMixIn, QApplication): pass
+ class QColorDialog(_ExecMixIn, QColorDialog): pass
+ class QDialog(_ExecMixIn, QDialog): pass
+ class QErrorMessage(_ExecMixIn, QErrorMessage): pass
+ class QFileDialog(_ExecMixIn, QFileDialog): pass
+ class QFontDialog(_ExecMixIn, QFontDialog): pass
+ class QInputDialog(_ExecMixIn, QInputDialog): pass
+ class QMenu(_ExecMixIn, QMenu): pass
+ class QMessageBox(_ExecMixIn, QMessageBox): pass
+ class QProgressDialog(_ExecMixIn, QProgressDialog): pass
+ #QtCore
+ class QCoreApplication(_ExecMixIn, QCoreApplication): pass
+ class QEventLoop(_ExecMixIn, QEventLoop): pass
+ if hasattr(QTextStreamManipulator, "exec_"):
+ # exec_ only wrapped in PySide2 and NOT in PyQt5
+ class QTextStreamManipulator(_ExecMixIn, QTextStreamManipulator): pass
+ class QThread(_ExecMixIn, QThread): pass
+
+
+elif BINDING == 'PySide6':
+ _logger.debug('Using PySide6 bindings')
+
+ import PySide6 as QtBinding # noqa
+
+ from PySide6.QtCore import * # noqa
+ from PySide6.QtGui import * # noqa
+ from PySide6.QtWidgets import * # noqa
+ from PySide6.QtPrintSupport import * # noqa
+
+ try:
+ from PySide6.QtOpenGL import * # noqa
+ from PySide6.QtOpenGLWidgets import QOpenGLWidget # noqa
+ except ImportError:
+ _logger.info("PySide6.QtOpenGL not available")
+ HAS_OPENGL = False
+ else:
+ HAS_OPENGL = True
+
+ try:
+ from PySide6.QtSvg import * # noqa
+ except ImportError:
+ _logger.info("PySide6.QtSvg not available")
+ HAS_SVG = False
+ else:
+ HAS_SVG = True
+
+ pyqtSignal = Signal
+
+else:
+ raise ImportError('No Qt wrapper found. Install PyQt5, PySide2 or PySide6')
+
+
+# provide a exception handler but not implement it by default
+def exceptionHandler(type_, value, trace):
+ """
+ This exception handler prevents quitting to the command line when there is
+ an unhandled exception while processing a Qt signal.
+
+ The script/application willing to use it should implement code similar to:
+
+ .. code-block:: python
+
+ if __name__ == "__main__":
+ sys.excepthook = qt.exceptionHandler
+
+ """
+ _logger.error("%s %s %s", type_, value, ''.join(traceback.format_tb(trace)))
+ msg = QMessageBox()
+ msg.setWindowTitle("Unhandled exception")
+ msg.setIcon(QMessageBox.Critical)
+ msg.setInformativeText("%s %s\nPlease report details" % (type_, value))
+ msg.setDetailedText(("%s " % value) + ''.join(traceback.format_tb(trace)))
+ msg.raise_()
+ msg.exec()
diff --git a/src/silx/gui/qt/_utils.py b/src/silx/gui/qt/_utils.py
new file mode 100644
index 0000000..5dced95
--- /dev/null
+++ b/src/silx/gui/qt/_utils.py
@@ -0,0 +1,68 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides convenient functions related to Qt.
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "30/11/2016"
+
+
+from . import _qt
+
+
+def supportedImageFormats():
+ """Return a set of string of file format extensions supported by the
+ Qt runtime."""
+ if _qt.BINDING == 'PySide2':
+ def convert(data):
+ return str(data.data(), 'ascii')
+ else:
+ convert = lambda data: str(data, 'ascii')
+ formats = _qt.QImageReader.supportedImageFormats()
+ return set([convert(data) for data in formats])
+
+
+__globalThreadPoolInstance = None
+"""Store the own silx global thread pool"""
+
+
+def silxGlobalThreadPool():
+ """"Manage an own QThreadPool to avoid issue on Qt5 Windows with the
+ default Qt global thread pool.
+
+ A thread pool is create in lazy loading. With a maximum of 4 threads.
+ Else `qt.Thread.idealThreadCount()` is used.
+
+ :rtype: qt.QThreadPool
+ """
+ global __globalThreadPoolInstance
+ if __globalThreadPoolInstance is None:
+ tp = _qt.QThreadPool()
+ # Setting maxThreadCount fixes a segfault with PyQt 5.9.1 on Windows
+ maxThreadCount = min(4, tp.maxThreadCount())
+ tp.setMaxThreadCount(maxThreadCount)
+ __globalThreadPoolInstance = tp
+ return __globalThreadPoolInstance
diff --git a/src/silx/gui/qt/inspect.py b/src/silx/gui/qt/inspect.py
new file mode 100644
index 0000000..b9a0d1d
--- /dev/null
+++ b/src/silx/gui/qt/inspect.py
@@ -0,0 +1,75 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides functions to access Qt C++ object state:
+
+- :func:`isValid` to check whether a QObject C++ pointer is valid.
+- :func:`createdByPython` to check if a QObject was created from Python.
+- :func:`ownedByPython` to check if a QObject is currently owned by Python.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/10/2018"
+
+
+from . import _qt as qt
+
+
+if qt.BINDING == 'PyQt5':
+ try:
+ from PyQt5.sip import isdeleted as _isdeleted # noqa
+ from PyQt5.sip import ispycreated as createdByPython # noqa
+ from PyQt5.sip import ispyowned as ownedByPython # noqa
+ except ImportError:
+ from sip import isdeleted as _isdeleted # noqa
+ from sip import ispycreated as createdByPython # noqa
+ from sip import ispyowned as ownedByPython # noqa
+
+
+ def isValid(obj):
+ """Returns True if underlying C++ object is valid.
+
+ :param QObject obj:
+ :rtype: bool
+ """
+ return not _isdeleted(obj)
+
+elif qt.BINDING == 'PySide2':
+ try:
+ from PySide2.shiboken2 import isValid # noqa
+ from PySide2.shiboken2 import createdByPython # noqa
+ from PySide2.shiboken2 import ownedByPython # noqa
+ except ImportError:
+ from shiboken2 import isValid # noqa
+ from shiboken2 import createdByPython # noqa
+ from shiboken2 import ownedByPython # noqa
+
+elif qt.BINDING == 'PySide6':
+ from shiboken6 import isValid, createdByPython, ownedByPython # noqa
+
+else:
+ raise ImportError("Unsupported Qt binding %s" % qt.BINDING)
+
+__all__ = ['isValid', 'createdByPython', 'ownedByPython']
diff --git a/silx/gui/setup.py b/src/silx/gui/setup.py
index 04a2bac..04a2bac 100644
--- a/silx/gui/setup.py
+++ b/src/silx/gui/setup.py
diff --git a/src/silx/gui/test/__init__.py b/src/silx/gui/test/__init__.py
new file mode 100644
index 0000000..00d6216
--- /dev/null
+++ b/src/silx/gui/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/test/test_colors.py b/src/silx/gui/test/test_colors.py
new file mode 100755
index 0000000..fa87d7d
--- /dev/null
+++ b/src/silx/gui/test/test_colors.py
@@ -0,0 +1,603 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the Colormap object
+"""
+
+from __future__ import absolute_import
+
+__authors__ = ["H.Payno"]
+__license__ = "MIT"
+__date__ = "09/11/2018"
+
+import unittest
+import numpy
+from silx.utils.testutils import ParametricTestCase
+from silx.gui import qt
+from silx.gui import colors
+from silx.gui.colors import Colormap
+from silx.gui.plot import items
+from silx.utils.exceptions import NotEditableError
+
+
+class TestColor(ParametricTestCase):
+ """Basic tests of rgba function"""
+
+ TEST_COLORS = { # name: (colors, expected values)
+ 'blue': ('blue', (0., 0., 1., 1.)),
+ '#010203': ('#010203', (1. / 255., 2. / 255., 3. / 255., 1.)),
+ '#01020304': ('#01020304', (1. / 255., 2. / 255., 3. / 255., 4. / 255.)),
+ '3 x uint8': (numpy.array((1, 255, 0), dtype=numpy.uint8),
+ (1 / 255., 1., 0., 1.)),
+ '4 x uint8': (numpy.array((1, 255, 0, 1), dtype=numpy.uint8),
+ (1 / 255., 1., 0., 1 / 255.)),
+ '3 x float overflow': ((3., 0.5, 1.), (1., 0.5, 1., 1.)),
+ }
+
+ def testRGBA(self):
+ """"Test rgba function with accepted values"""
+ for name, test in self.TEST_COLORS.items():
+ color, expected = test
+ with self.subTest(msg=name):
+ result = colors.rgba(color)
+ self.assertEqual(result, expected)
+
+ def testQColor(self):
+ """"Test getQColor function with accepted values"""
+ for name, test in self.TEST_COLORS.items():
+ color, expected = test
+ with self.subTest(msg=name):
+ result = colors.asQColor(color)
+ self.assertAlmostEqual(result.redF(), expected[0], places=4)
+ self.assertAlmostEqual(result.greenF(), expected[1], places=4)
+ self.assertAlmostEqual(result.blueF(), expected[2], places=4)
+ self.assertAlmostEqual(result.alphaF(), expected[3], places=4)
+
+
+class TestApplyColormapToData(ParametricTestCase):
+ """Tests of applyColormapToData function"""
+
+ def testApplyColormapToData(self):
+ """Simple test of applyColormapToData function"""
+ colormap = Colormap(name='gray', normalization='linear',
+ vmin=0, vmax=255)
+
+ size = 10
+ expected = numpy.empty((size, 4), dtype='uint8')
+ expected[:, 0] = numpy.arange(size, dtype='uint8')
+ expected[:, 1] = expected[:, 0]
+ expected[:, 2] = expected[:, 0]
+ expected[:, 3] = 255
+
+ for dtype in ('uint8', 'int32', 'float32', 'float64'):
+ with self.subTest(dtype=dtype):
+ array = numpy.arange(size, dtype=dtype)
+ result = colormap.applyToData(data=array)
+ self.assertTrue(numpy.all(numpy.equal(result, expected)))
+
+ def testAutoscaleFromDataReference(self):
+ colormap = Colormap(name='gray', normalization='linear')
+ data = numpy.array([50])
+ reference = numpy.array([0, 100])
+ value = colormap.applyToData(data, reference)
+ self.assertEqual(len(value), 1)
+ self.assertEqual(value[0, 0], 128)
+
+ def testAutoscaleFromItemReference(self):
+ colormap = Colormap(name='gray', normalization='linear')
+ data = numpy.array([50])
+ image = items.ImageData()
+ image.setData(numpy.array([[0, 100]]))
+ value = colormap.applyToData(data, reference=image)
+ self.assertEqual(len(value), 1)
+ self.assertEqual(value[0, 0], 128)
+
+ def testNaNColor(self):
+ """Test Colormap.applyToData with NaN values"""
+ colormap = Colormap(name='gray', normalization='linear')
+ colormap.setNaNColor('red')
+ self.assertEqual(colormap.getNaNColor(), qt.QColor(255, 0, 0))
+
+ data = numpy.array([50., numpy.nan])
+ image = items.ImageData()
+ image.setData(numpy.array([[0, 100]]))
+ value = colormap.applyToData(data, reference=image)
+ self.assertEqual(len(value), 2)
+ self.assertTrue(numpy.array_equal(value[0], (128, 128, 128, 255)))
+ self.assertTrue(numpy.array_equal(value[1], (255, 0, 0, 255)))
+
+
+class TestDictAPI(unittest.TestCase):
+ """Make sure the old dictionary API is working
+ """
+
+ def setUp(self):
+ self.vmin = -1.0
+ self.vmax = 12
+
+ def testGetItem(self):
+ """test the item getter API ([xxx])"""
+ colormap = Colormap(name='viridis',
+ normalization=Colormap.LINEAR,
+ vmin=self.vmin,
+ vmax=self.vmax)
+ self.assertTrue(colormap['name'] == 'viridis')
+ self.assertTrue(colormap['normalization'] == Colormap.LINEAR)
+ self.assertTrue(colormap['vmin'] == self.vmin)
+ self.assertTrue(colormap['vmax'] == self.vmax)
+ with self.assertRaises(KeyError):
+ colormap['toto']
+
+ def testGetDict(self):
+ """Test the getDict function API"""
+ clmObject = Colormap(name='viridis',
+ normalization=Colormap.LINEAR,
+ vmin=self.vmin,
+ vmax=self.vmax)
+ clmDict = clmObject._toDict()
+ self.assertTrue(clmDict['name'] == 'viridis')
+ self.assertTrue(clmDict['autoscale'] is False)
+ self.assertTrue(clmDict['vmin'] == self.vmin)
+ self.assertTrue(clmDict['vmax'] == self.vmax)
+ self.assertTrue(clmDict['normalization'] == Colormap.LINEAR)
+
+ clmObject.setVRange(None, None)
+ self.assertTrue(clmObject._toDict()['autoscale'] is True)
+
+ def testSetValidDict(self):
+ """Test that if a colormap is created from a dict then it is correctly
+ created and the values are copied (so if some values from the dict
+ is changing, this won't affect the Colormap object"""
+ clm_dict = {
+ 'name': 'temperature',
+ 'vmin': 1.0,
+ 'vmax': 2.0,
+ 'normalization': 'linear',
+ 'colors': None,
+ 'autoscale': False
+ }
+
+ # Test that the colormap is correctly created
+ colormapObject = Colormap._fromDict(clm_dict)
+ self.assertTrue(colormapObject.getName() == clm_dict['name'])
+ self.assertTrue(colormapObject.getColormapLUT() == clm_dict['colors'])
+ self.assertTrue(colormapObject.getVMin() == clm_dict['vmin'])
+ self.assertTrue(colormapObject.getVMax() == clm_dict['vmax'])
+ self.assertTrue(colormapObject.isAutoscale() == clm_dict['autoscale'])
+
+ # Check that the colormap has copied the values
+ clm_dict['vmin'] = None
+ clm_dict['vmax'] = None
+ clm_dict['colors'] = [1.0, 2.0]
+ clm_dict['autoscale'] = True
+ clm_dict['normalization'] = Colormap.LOGARITHM
+ clm_dict['name'] = 'viridis'
+
+ self.assertFalse(colormapObject.getName() == clm_dict['name'])
+ self.assertFalse(colormapObject.getColormapLUT() == clm_dict['colors'])
+ self.assertFalse(colormapObject.getVMin() == clm_dict['vmin'])
+ self.assertFalse(colormapObject.getVMax() == clm_dict['vmax'])
+ self.assertFalse(colormapObject.isAutoscale() == clm_dict['autoscale'])
+
+ def testMissingKeysFromDict(self):
+ """Make sure we can create a Colormap object from a dictionary even if
+ there is missing keys except if those keys are 'colors' or 'name'
+ """
+ colormap = Colormap._fromDict({'name': 'blue'})
+ self.assertTrue(colormap.getVMin() is None)
+ colormap = Colormap._fromDict({'colors': numpy.zeros((5, 3))})
+ self.assertTrue(colormap.getName() is None)
+
+ with self.assertRaises(ValueError):
+ Colormap._fromDict({})
+
+ def testUnknowNorm(self):
+ """Make sure an error is raised if the given normalization is not
+ knowed
+ """
+ clm_dict = {
+ 'name': 'temperature',
+ 'vmin': 1.0,
+ 'vmax': 2.0,
+ 'normalization': 'toto',
+ 'colors': None,
+ 'autoscale': False
+ }
+ with self.assertRaises(ValueError):
+ Colormap._fromDict(clm_dict)
+
+ def testNumericalColors(self):
+ """Make sure the old API using colors=int was supported"""
+ clm_dict = {
+ 'name': 'temperature',
+ 'vmin': 1.0,
+ 'vmax': 2.0,
+ 'colors': 256,
+ 'autoscale': False
+ }
+ Colormap._fromDict(clm_dict)
+
+
+class TestObjectAPI(ParametricTestCase):
+ """Test the new Object API of the colormap"""
+ def testVMinVMax(self):
+ """Test getter and setter associated to vmin and vmax values"""
+ vmin = 1.0
+ vmax = 2.0
+
+ colormapObject = Colormap(name='viridis',
+ vmin=vmin,
+ vmax=vmax,
+ normalization=Colormap.LINEAR)
+
+ with self.assertRaises(ValueError):
+ colormapObject.setVMin(3)
+
+ with self.assertRaises(ValueError):
+ colormapObject.setVMax(-2)
+
+ with self.assertRaises(ValueError):
+ colormapObject.setVRange(3, -2)
+
+ self.assertTrue(colormapObject.getColormapRange() == (1.0, 2.0))
+ self.assertTrue(colormapObject.isAutoscale() is False)
+ colormapObject.setVRange(None, None)
+ self.assertTrue(colormapObject.getVMin() is None)
+ self.assertTrue(colormapObject.getVMax() is None)
+ self.assertTrue(colormapObject.isAutoscale() is True)
+
+ def testCopy(self):
+ """Make sure the copy function is correctly processing
+ """
+ colormapObject = Colormap(name=None,
+ colors=numpy.array([[1., 0., 0.],
+ [0., 1., 0.],
+ [0., 0., 1.]]),
+ vmin=None,
+ vmax=None,
+ normalization=Colormap.LOGARITHM)
+
+ colormapObject2 = colormapObject.copy()
+ self.assertTrue(colormapObject == colormapObject2)
+ colormapObject.setColormapLUT([[0, 0, 0], [255, 255, 255]])
+ self.assertFalse(colormapObject == colormapObject2)
+
+ colormapObject2 = colormapObject.copy()
+ self.assertTrue(colormapObject == colormapObject2)
+ colormapObject.setNormalization(Colormap.LINEAR)
+ self.assertFalse(colormapObject == colormapObject2)
+
+ def testGetColorMapRange(self):
+ """Make sure the getColormapRange function of colormap is correctly
+ applying
+ """
+ # test linear scale
+ data = numpy.array([-1, 1, 2, 3, float('nan')])
+ cl1 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=0,
+ vmax=2)
+ cl2 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=2)
+ cl3 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=0,
+ vmax=None)
+ cl4 = Colormap(name='gray',
+ normalization=Colormap.LINEAR,
+ vmin=None,
+ vmax=None)
+
+ self.assertTrue(cl1.getColormapRange(data) == (0, 2))
+ self.assertTrue(cl2.getColormapRange(data) == (-1, 2))
+ self.assertTrue(cl3.getColormapRange(data) == (0, 3))
+ self.assertTrue(cl4.getColormapRange(data) == (-1, 3))
+
+ # test linear with annoying cases
+ self.assertEqual(cl3.getColormapRange((-1, -2)), (0, 0))
+ self.assertEqual(cl4.getColormapRange(()), (0., 1.))
+ self.assertEqual(cl4.getColormapRange(
+ (float('nan'), float('inf'), 1., -float('inf'), 2)), (1., 2.))
+ self.assertEqual(cl4.getColormapRange(
+ (float('nan'), float('inf'))), (0., 1.))
+
+ # test log scale
+ data = numpy.array([float('nan'), -1, 1, 10, 100, 1000])
+ cl1 = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=1,
+ vmax=100)
+ cl2 = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=100)
+ cl3 = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=1,
+ vmax=None)
+ cl4 = Colormap(name='gray',
+ normalization=Colormap.LOGARITHM,
+ vmin=None,
+ vmax=None)
+
+ self.assertTrue(cl1.getColormapRange(data) == (1, 100))
+ self.assertTrue(cl2.getColormapRange(data) == (1, 100))
+ self.assertTrue(cl3.getColormapRange(data) == (1, 1000))
+ self.assertTrue(cl4.getColormapRange(data) == (1, 1000))
+
+ # test log with annoying cases
+ self.assertEqual(cl3.getColormapRange((0.1, 0.2)), (1, 1))
+ self.assertEqual(cl4.getColormapRange((-2., -1.)), (1., 1.))
+ self.assertEqual(cl4.getColormapRange(()), (1., 10.))
+ self.assertEqual(cl4.getColormapRange(
+ (float('nan'), float('inf'), 1., -float('inf'), 2)), (1., 2.))
+ self.assertEqual(cl4.getColormapRange(
+ (float('nan'), float('inf'))), (1., 10.))
+
+ def testApplyToData(self):
+ """Test applyToData on different datasets"""
+ datasets = [
+ numpy.zeros((0, 0)), # Empty array
+ numpy.array((numpy.nan, numpy.inf)), # All non-finite
+ numpy.array((-numpy.inf, numpy.inf, 1.0, 2.0)), # Some infinite
+ ]
+
+ for normalization in ('linear', 'log'):
+ colormap = Colormap(name='gray',
+ normalization=normalization,
+ vmin=None,
+ vmax=None)
+
+ for data in datasets:
+ with self.subTest(data=data):
+ image = colormap.applyToData(data)
+ self.assertEqual(image.dtype, numpy.uint8)
+ self.assertEqual(image.shape[-1], 4)
+ self.assertEqual(image.shape[:-1], data.shape)
+
+ def testGetNColors(self):
+ """Test getNColors method"""
+ # specific LUT
+ colormap = Colormap(name=None,
+ colors=((0., 0., 0.), (1., 1., 1.)),
+ vmin=1000,
+ vmax=2000)
+ colors = colormap.getNColors()
+ self.assertTrue(numpy.all(numpy.equal(
+ colors,
+ ((0, 0, 0, 255), (255, 255, 255, 255)))))
+
+ def testEditableMode(self):
+ """Make sure the colormap will raise NotEditableError when try to
+ change a colormap not editable"""
+ colormap = Colormap()
+ colormap.setEditable(False)
+ with self.assertRaises(NotEditableError):
+ colormap.setVRange(0., 1.)
+ with self.assertRaises(NotEditableError):
+ colormap.setVMin(1.)
+ with self.assertRaises(NotEditableError):
+ colormap.setVMax(1.)
+ with self.assertRaises(NotEditableError):
+ colormap.setNormalization(Colormap.LOGARITHM)
+ with self.assertRaises(NotEditableError):
+ colormap.setName('magma')
+ with self.assertRaises(NotEditableError):
+ colormap.setColormapLUT([[0., 0., 0.], [1., 1., 1.]])
+ with self.assertRaises(NotEditableError):
+ colormap._setFromDict(colormap._toDict())
+ state = colormap.saveState()
+ with self.assertRaises(NotEditableError):
+ colormap.restoreState(state)
+
+ def testBadColorsType(self):
+ """Make sure colors can't be something else than an array"""
+ with self.assertRaises(TypeError):
+ Colormap(colors=256)
+
+ def testEqual(self):
+ colormap1 = Colormap()
+ colormap2 = Colormap()
+ self.assertEqual(colormap1, colormap2)
+
+ def testCompareString(self):
+ colormap = Colormap()
+ self.assertNotEqual(colormap, "a")
+
+ def testCompareNone(self):
+ colormap = Colormap()
+ self.assertNotEqual(colormap, None)
+
+ def testSet(self):
+ colormap = Colormap()
+ other = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ self.assertNotEqual(colormap, other)
+ colormap.setFromColormap(other)
+ self.assertIsNot(colormap, other)
+ self.assertEqual(colormap, other)
+
+ def testAutoscaleMode(self):
+ colormap = Colormap(autoscaleMode=Colormap.STDDEV3)
+ self.assertEqual(colormap.getAutoscaleMode(), Colormap.STDDEV3)
+ colormap.setAutoscaleMode(Colormap.MINMAX)
+ self.assertEqual(colormap.getAutoscaleMode(), Colormap.MINMAX)
+
+ def testStoreRestore(self):
+ colormaps = [
+ Colormap(name="viridis"),
+ Colormap(normalization=Colormap.SQRT)
+ ]
+ cmap = Colormap(normalization=Colormap.GAMMA)
+ cmap.setGammaNormalizationParameter(1.2)
+ cmap.setNaNColor('red')
+ colormaps.append(cmap)
+ for expected in colormaps:
+ with self.subTest(colormap=expected):
+ state = expected.saveState()
+ result = Colormap()
+ result.restoreState(state)
+ self.assertEqual(expected, result)
+
+ def testStorageV1(self):
+ state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00\x00'\
+ b'\x00\x01\x00\x00\x00\x0E\x00v\x00i\x00r\x00i\x00d\x00i\x00s\x00'\
+ b'\x00\x00\x00\x06\x00?\xF0\x00\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00'\
+ b'l\x00o\x00g'
+ state = qt.QByteArray(state)
+ colormap = Colormap()
+ colormap.restoreState(state)
+
+ expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ self.assertEqual(colormap, expected)
+
+ def testStorageV2(self):
+ state = b'\x00\x00\x00\x10\x00C\x00o\x00l\x00o\x00r\x00m\x00a\x00p\x00'\
+ b'\x00\x00\x02\x00\x00\x00\x0e\x00v\x00i\x00r\x00i\x00d\x00i\x00'\
+ b's\x00\x00\x00\x00\x06\x00?\xf0\x00\x00\x00\x00\x00\x00\x00\x00'\
+ b'\x00\x00\x06\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06'\
+ b'\x00l\x00o\x00g\x00\x00\x00\x0c\x00m\x00i\x00n\x00m\x00a\x00x'
+ state = qt.QByteArray(state)
+ colormap = Colormap()
+ colormap.restoreState(state)
+
+ expected = Colormap(name="viridis", vmin=1, vmax=2, normalization=Colormap.LOGARITHM)
+ expected.setGammaNormalizationParameter(1.5)
+ self.assertEqual(colormap, expected)
+
+
+class TestPreferredColormaps(unittest.TestCase):
+ """Test get|setPreferredColormaps functions"""
+
+ def setUp(self):
+ # Save preferred colormaps
+ self._colormaps = colors.preferredColormaps()
+
+ def tearDown(self):
+ # Restore saved preferred colormaps
+ colors.setPreferredColormaps(self._colormaps)
+
+ def test(self):
+ colormaps = 'viridis', 'magma'
+
+ colors.setPreferredColormaps(colormaps)
+ self.assertEqual(colors.preferredColormaps(), colormaps)
+
+ with self.assertRaises(ValueError):
+ colors.setPreferredColormaps(())
+
+ with self.assertRaises(ValueError):
+ colors.setPreferredColormaps(('This is not a colormap',))
+
+ colormaps = 'red', 'green'
+ colors.setPreferredColormaps(('This is not a colormap',) + colormaps)
+ self.assertEqual(colors.preferredColormaps(), colormaps)
+
+
+class TestRegisteredLut(unittest.TestCase):
+ """Test get|setPreferredColormaps functions"""
+
+ def setUp(self):
+ # Save preferred colormaps
+ lut = numpy.arange(8 * 3)
+ lut.shape = -1, 3
+ lut = lut / (8.0 * 3)
+ colors.registerLUT("test_8", colors=lut, cursor_color='blue')
+
+ def testColormap(self):
+ colormap = Colormap("test_8")
+ self.assertIsNotNone(colormap)
+
+ def testCursor(self):
+ color = colors.cursorColorForColormap("test_8")
+ self.assertEqual(color, 'blue')
+
+ def testLut(self):
+ colormap = Colormap("test_8")
+ colors = colormap.getNColors(8)
+ self.assertEqual(len(colors), 8)
+
+ def testUint8(self):
+ lut = numpy.array([[255, 0, 0], [200, 0, 0], [150, 0, 0]], dtype="uint")
+ colors.registerLUT("test_type", lut)
+ colormap = colors.Colormap(name="test_type")
+ lut = colormap.getNColors(3)
+ self.assertEqual(lut.shape, (3, 4))
+ self.assertEqual(lut[0, 0], 255)
+
+ def testFloatRGB(self):
+ lut = numpy.array([[1.0, 0, 0], [0.5, 0, 0], [0, 0, 0]], dtype="float")
+ colors.registerLUT("test_type", lut)
+ colormap = colors.Colormap(name="test_type")
+ lut = colormap.getNColors(3)
+ self.assertEqual(lut.shape, (3, 4))
+ self.assertEqual(lut[0, 0], 255)
+
+ def testFloatRGBA(self):
+ lut = numpy.array([[1.0, 0, 0, 128 / 256.0], [0.5, 0, 0, 1.0], [0.0, 0, 0, 1.0]], dtype="float")
+ colors.registerLUT("test_type", lut)
+ colormap = colors.Colormap(name="test_type")
+ lut = colormap.getNColors(3)
+ self.assertEqual(lut.shape, (3, 4))
+ self.assertEqual(lut[0, 0], 255)
+ self.assertEqual(lut[0, 3], 128)
+
+
+class TestAutoscaleRange(ParametricTestCase):
+
+ def testAutoscaleRange(self):
+ nan = numpy.nan
+ data_std_inside = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2])
+ data_std_inside_nan = numpy.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, numpy.nan])
+ data = [
+ # Positive values
+ (Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50]), (10, 50)),
+ (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100]), (10, 100)),
+ (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside, (0.026671473215424735, 1.9733285267845753)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside, (1, 1.6733506885453602)),
+ (Colormap.LINEAR, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100]), (10, 100)),
+
+ # With nan
+ (Colormap.LINEAR, Colormap.MINMAX, numpy.array([10, 20, 50, nan]), (10, 50)),
+ (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, nan]), (10, 100)),
+ (Colormap.LINEAR, Colormap.STDDEV3, data_std_inside_nan, (0.026671473215424735, 1.9733285267845753)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, data_std_inside_nan, (1, 1.6733506885453602)),
+ # With negative
+ (Colormap.LOGARITHM, Colormap.MINMAX, numpy.array([10, 50, 100, -50]), (10, 100)),
+ (Colormap.LOGARITHM, Colormap.STDDEV3, numpy.array([10, 100, -10]), (10, 100)),
+ ]
+ for norm, mode, array, expectedRange in data:
+ with self.subTest(norm=norm, mode=mode, array=array):
+ colormap = Colormap()
+ colormap.setNormalization(norm)
+ colormap.setAutoscaleMode(mode)
+ vRange = colormap._computeAutoscaleRange(array)
+ if vRange is None:
+ self.assertIsNone(expectedRange)
+ else:
+ self.assertAlmostEqual(vRange[0], expectedRange[0])
+ self.assertAlmostEqual(vRange[1], expectedRange[1])
diff --git a/src/silx/gui/test/test_console.py b/src/silx/gui/test/test_console.py
new file mode 100644
index 0000000..21f3564
--- /dev/null
+++ b/src/silx/gui/test/test_console.py
@@ -0,0 +1,75 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for IPython console widget"""
+
+from __future__ import print_function
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import pytest
+from silx.gui import qt
+
+
+# dummy objects to test pushing variables to the interactive namespace
+_a = 1
+
+
+def _f():
+ print("Hello World!")
+
+
+@pytest.fixture
+def console(qapp_utils):
+ """Create a console widget"""
+ # Console tests disabled due to corruption of python environment
+ pytest.skip("Disabled (see issue #538)")
+ try:
+ from silx.gui.console import IPythonDockWidget
+ except ImportError:
+ pytest.skip("IPythonDockWidget is not available")
+
+ console = IPythonDockWidget(
+ available_vars={"a": _a, "f": _f},
+ custom_banner="Welcome!\n")
+ console.show()
+ qapp_utils.qWaitForWindowExposed(console)
+ yield console
+ console.setAttribute(qt.Qt.WA_DeleteOnClose)
+ console.close()
+ console = None
+
+
+def testShow(console):
+ pass
+
+
+def testInteract(console, qapp_utils):
+ qapp_utils.mouseClick(console, qt.Qt.LeftButton)
+ qapp_utils.keyClicks(console, 'import silx')
+ qapp_utils.keyClick(console, qt.Qt.Key_Enter)
+ qapp_utils.qapp.processEvents()
diff --git a/src/silx/gui/test/test_icons.py b/src/silx/gui/test/test_icons.py
new file mode 100644
index 0000000..154adf6
--- /dev/null
+++ b/src/silx/gui/test/test_icons.py
@@ -0,0 +1,144 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic test of Qt icons module."""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "06/09/2017"
+
+
+import unittest
+import weakref
+import tempfile
+import shutil
+import os
+
+import silx.resources
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import icons
+
+
+class TestIcons(TestCaseQt):
+ """Test to check that icons module."""
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestIcons, cls).setUpClass()
+
+ cls.tmpDirectory = tempfile.mkdtemp(prefix="resource_")
+ os.mkdir(os.path.join(cls.tmpDirectory, "gui"))
+ destination = os.path.join(cls.tmpDirectory, "gui", "icons")
+ os.mkdir(destination)
+ shutil.copy(silx.resources.resource_filename("gui/icons/zoom-in.png"), destination)
+ shutil.copy(silx.resources.resource_filename("gui/icons/zoom-out.svg"), destination)
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestIcons, cls).tearDownClass()
+ shutil.rmtree(cls.tmpDirectory)
+
+ def setUp(self):
+ # Store the original configuration
+ self._oldResources = dict(silx.resources._RESOURCE_DIRECTORIES)
+ silx.resources.register_resource_directory("test", "foo.bar", forced_path=self.tmpDirectory)
+ unittest.TestCase.setUp(self)
+
+ def tearDown(self):
+ unittest.TestCase.tearDown(self)
+ # Restiture the original configuration
+ silx.resources._RESOURCE_DIRECTORIES = self._oldResources
+
+ def testIcon(self):
+ icon = icons.getQIcon("silx:gui/icons/zoom-out")
+ self.assertIsNotNone(icon)
+
+ def testPrefix(self):
+ icon = icons.getQIcon("silx:gui/icons/zoom-out")
+ self.assertIsNotNone(icon)
+
+ def testSvgIcon(self):
+ if "svg" not in qt.supportedImageFormats():
+ self.skipTest("SVG not supported")
+ icon = icons.getQIcon("test:gui/icons/zoom-out")
+ self.assertIsNotNone(icon)
+
+ def testPngIcon(self):
+ icon = icons.getQIcon("test:gui/icons/zoom-in")
+ self.assertIsNotNone(icon)
+
+ def testUnexistingIcon(self):
+ self.assertRaises(ValueError, icons.getQIcon, "not-exists")
+
+ def testExistingQPixmap(self):
+ icon = icons.getQPixmap("crop")
+ self.assertIsNotNone(icon)
+
+ def testUnexistingQPixmap(self):
+ self.assertRaises(ValueError, icons.getQPixmap, "not-exists")
+
+ def testCache(self):
+ icon1 = icons.getQIcon("crop")
+ icon2 = icons.getQIcon("crop")
+ self.assertIs(icon1, icon2)
+
+ def testCacheReleased(self):
+ icon = icons.getQIcon("crop")
+ icon_ref = weakref.ref(icon)
+ del icon
+ self.assertIsNone(icon_ref())
+
+
+class TestAnimatedIcons(TestCaseQt):
+ """Test to check that icons module."""
+
+ def testProcessWorking(self):
+ icon = icons.getWaitIcon()
+ self.assertIsNotNone(icon)
+
+ def testProcessWorkingCache(self):
+ icon1 = icons.getWaitIcon()
+ icon2 = icons.getWaitIcon()
+ self.assertIs(icon1, icon2)
+
+ def testMovieIconExists(self):
+ if "mng" not in qt.supportedImageFormats():
+ self.skipTest("MNG not supported")
+ icon = icons.MovieAnimatedIcon("process-working")
+ self.assertIsNotNone(icon)
+
+ def testMovieIconNotExists(self):
+ self.assertRaises(ValueError, icons.MovieAnimatedIcon, "not-exists")
+
+ def testMultiImageIconExists(self):
+ icon = icons.MultiImageAnimatedIcon("process-working")
+ self.assertIsNotNone(icon)
+
+ def testPrefixedResourceExists(self):
+ icon = icons.MultiImageAnimatedIcon("silx:gui/icons/process-working")
+ self.assertIsNotNone(icon)
+
+ def testMultiImageIconNotExists(self):
+ self.assertRaises(ValueError, icons.MultiImageAnimatedIcon, "not-exists")
diff --git a/src/silx/gui/test/test_qt.py b/src/silx/gui/test/test_qt.py
new file mode 100644
index 0000000..8554744
--- /dev/null
+++ b/src/silx/gui/test/test_qt.py
@@ -0,0 +1,212 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic test of Qt bindings wrapper."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import os.path
+import unittest
+import pytest
+
+from silx.test.utils import temp_dir
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui import qt
+try:
+ from silx.gui.qt import inspect as qt_inspect
+except ImportError:
+ qt_inspect = None
+
+
+class TestQtWrapper(unittest.TestCase):
+ """Minimalistic test to check that Qt has been loaded."""
+
+ def testQObject(self):
+ """Test that QObject is there."""
+ obj = qt.QObject()
+ self.assertTrue(obj is not None)
+
+
+class TestLoadUi(TestCaseQt):
+ """Test loadUi function"""
+
+ TEST_UI = """<?xml version="1.0" encoding="UTF-8"?>
+ <ui version="4.0">
+ <class>MainWindow</class>
+ <widget class="QMainWindow" name="MainWindow">
+ <property name="geometry">
+ <rect>
+ <x>0</x>
+ <y>0</y>
+ <width>293</width>
+ <height>296</height>
+ </rect>
+ </property>
+ <property name="windowTitle">
+ <string>Test loadUi</string>
+ </property>
+ <widget class="QWidget" name="centralwidget">
+ <widget class="QPushButton" name="pushButton">
+ <property name="geometry">
+ <rect>
+ <x>10</x>
+ <y>10</y>
+ <width>89</width>
+ <height>27</height>
+ </rect>
+ </property>
+ <property name="text">
+ <string>Button 1</string>
+ </property>
+ </widget>
+ <widget class="QPushButton" name="pushButton_2">
+ <property name="geometry">
+ <rect>
+ <x>10</x>
+ <y>50</y>
+ <width>89</width>
+ <height>27</height>
+ </rect>
+ </property>
+ <property name="text">
+ <string>Button 2</string>
+ </property>
+ </widget>
+ <widget class="Line" name="line">
+ <property name="geometry">
+ <rect>
+ <x>10</x>
+ <y>90</y>
+ <width>118</width>
+ <height>3</height>
+ </rect>
+ </property>
+ <property name="orientation">
+ <enum>Qt::Horizontal</enum>
+ </property>
+ </widget>
+ <widget class="Line" name="line_2">
+ <property name="geometry">
+ <rect>
+ <x>150</x>
+ <y>20</y>
+ <width>3</width>
+ <height>61</height>
+ </rect>
+ </property>
+ <property name="orientation">
+ <enum>Qt::Vertical</enum>
+ </property>
+ </widget>
+ </widget>
+ <widget class="QMenuBar" name="menubar">
+ <property name="geometry">
+ <rect>
+ <x>0</x>
+ <y>0</y>
+ <width>293</width>
+ <height>25</height>
+ </rect>
+ </property>
+ </widget>
+ <widget class="QStatusBar" name="statusbar"/>
+ </widget>
+ <resources/>
+ <connections/>
+ </ui>
+ """
+
+ def testLoadUi(self):
+ """Create a QMainWindow from an ui file"""
+ with temp_dir() as tmp:
+ uifile = os.path.join(tmp, "test.ui")
+
+ # write file
+ with open(uifile, mode='w') as f:
+ f.write(self.TEST_UI)
+
+ class TestMainWindow(qt.QMainWindow):
+ def __init__(self, parent=None):
+ super(TestMainWindow, self).__init__(parent)
+ qt.loadUi(uifile, self)
+
+ testMainWindow = TestMainWindow()
+ testMainWindow.show()
+ self.qWaitForWindowExposed(testMainWindow)
+
+ testMainWindow.setAttribute(qt.Qt.WA_DeleteOnClose)
+ testMainWindow.close()
+
+
+class TestQtInspect(unittest.TestCase):
+ """Test functions of silx.gui.qt.inspect module"""
+
+ def test(self):
+ """Test functions of silx.gui.qt.inspect module"""
+ self.assertIsNotNone(qt_inspect)
+
+ parent = qt.QObject()
+
+ self.assertTrue(qt_inspect.isValid(parent))
+ self.assertTrue(qt_inspect.createdByPython(parent))
+ self.assertTrue(qt_inspect.ownedByPython(parent))
+
+ obj = qt.QObject(parent)
+
+ self.assertTrue(qt_inspect.isValid(obj))
+ self.assertTrue(qt_inspect.createdByPython(obj))
+ self.assertFalse(qt_inspect.ownedByPython(obj))
+
+ del parent
+ self.assertFalse(qt_inspect.isValid(obj))
+
+
+@pytest.mark.skipif(qt.BINDING not in ("PyQt5", "PySide2"),
+ reason="PyQt5/PySide2 only test")
+def test_exec_():
+ """Test the exec_ is still useable with Qt5 bindings"""
+ klasses = [
+ #QtWidgets
+ qt.QApplication,
+ qt.QColorDialog,
+ qt.QDialog,
+ qt.QErrorMessage,
+ qt.QFileDialog,
+ qt.QFontDialog,
+ qt.QInputDialog,
+ qt.QMenu,
+ qt.QMessageBox,
+ qt.QProgressDialog,
+ #QtCore
+ qt.QCoreApplication,
+ qt.QEventLoop,
+ qt.QThread,
+ ]
+ for klass in klasses:
+ assert hasattr(klass, "exec") and callable(klass.exec), "%s.exec missing" % klass.__name__
+ assert hasattr(klass, "exec_") and callable(klass.exec_), "%s.exec_ missing" % klass.__name__
diff --git a/silx/gui/test/utils.py b/src/silx/gui/test/utils.py
index db4c0ee..db4c0ee 100644
--- a/silx/gui/test/utils.py
+++ b/src/silx/gui/test/utils.py
diff --git a/silx/gui/utils/__init__.py b/src/silx/gui/utils/__init__.py
index 726ad74..726ad74 100755
--- a/silx/gui/utils/__init__.py
+++ b/src/silx/gui/utils/__init__.py
diff --git a/silx/gui/utils/concurrent.py b/src/silx/gui/utils/concurrent.py
index c27374f..c27374f 100644
--- a/silx/gui/utils/concurrent.py
+++ b/src/silx/gui/utils/concurrent.py
diff --git a/src/silx/gui/utils/glutils/__init__.py b/src/silx/gui/utils/glutils/__init__.py
new file mode 100644
index 0000000..20e611e
--- /dev/null
+++ b/src/silx/gui/utils/glutils/__init__.py
@@ -0,0 +1,199 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides the :func:`isOpenGLAvailable` utility function.
+"""
+
+import os
+import sys
+import subprocess
+from silx.gui import qt
+
+
+class _isOpenGLAvailableResult:
+ """Store result of checking OpenGL availability.
+
+ It provides a `status` boolean attribute storing the result of the check and
+ an `error` string attribute storting the possible error message.
+ """
+
+ def __init__(self, status=True, error=''):
+ self.__status = bool(status)
+ self.__error = str(error)
+
+ status = property(lambda self: self.__status, doc="True if OpenGL is working")
+ error = property(lambda self: self.__error, doc="Error message")
+
+ def __bool__(self):
+ return self.status
+
+ def __repr__(self):
+ return '<_isOpenGLAvailableResult: %s, "%s">' % (self.status, self.error)
+
+
+def _runtimeOpenGLCheck(version):
+ """Run OpenGL check in a subprocess.
+
+ This is done by starting a subprocess that displays a Qt OpenGL widget.
+
+ :param List[int] version:
+ The minimal required OpenGL version as a 2-tuple (major, minor).
+ Default: (2, 1)
+ :return: An error string that is empty if no error occured
+ :rtype: str
+ """
+ major, minor = str(version[0]), str(version[1])
+ env = os.environ.copy()
+ env['PYTHONPATH'] = os.pathsep.join(
+ [os.path.abspath(p) for p in sys.path])
+
+ try:
+ error = subprocess.check_output(
+ [sys.executable, '-s', '-S', __file__, major, minor],
+ env=env,
+ timeout=2)
+ except subprocess.TimeoutExpired:
+ status = False
+ error = "Qt OpenGL widget hang"
+ if sys.platform.startswith('linux'):
+ error += ':\nIf connected remotely, GLX forwarding might be disabled.'
+ except subprocess.CalledProcessError as e:
+ status = False
+ error = "Qt OpenGL widget error: retcode=%d, error=%s" % (e.returncode, e.output)
+ else:
+ status = True
+ error = error.decode()
+ return _isOpenGLAvailableResult(status, error)
+
+
+_runtimeCheckCache = {} # Cache runtime check results: {version: result}
+
+
+def isOpenGLAvailable(version=(2, 1), runtimeCheck=True):
+ """Check if OpenGL is available through Qt and actually working.
+
+ After some basic tests, this is done by starting a subprocess that
+ displays a Qt OpenGL widget.
+
+ :param List[int] version:
+ The minimal required OpenGL version as a 2-tuple (major, minor).
+ Default: (2, 1)
+ :param bool runtimeCheck:
+ True (default) to run the test creating a Qt OpenGL widgt in a subprocess,
+ False to avoid this check.
+ :return: A result object that evaluates to True if successful and
+ which has a `status` boolean attribute (True if successful) and
+ an `error` string attribute that is not empty if `status` is False.
+ """
+ error = ''
+
+ if sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
+ # On Linux and no DISPLAY available (e.g., ssh without -X)
+ error = 'DISPLAY environment variable not set'
+
+ else:
+ # Check pyopengl availability
+ try:
+ import silx.gui._glutils.gl # noqa
+ except ImportError:
+ error = "Cannot import OpenGL wrapper: pyopengl is not installed"
+ else:
+ # Pre checks for Qt < 5.4
+ if not hasattr(qt, 'QOpenGLWidget'):
+ if not qt.HAS_OPENGL:
+ error = '%s.QtOpenGL not available' % qt.BINDING
+
+ elif qt.BINDING in ('PySide2', 'PyQt5') and qt.QApplication.instance() and not qt.QGLFormat.hasOpenGL():
+ # qt.QGLFormat.hasOpenGL MUST be called with a QApplication created
+ # so this is only checked if the QApplication is already created
+ error = 'Qt reports OpenGL not available'
+
+ result = _isOpenGLAvailableResult(error == '', error)
+
+ if result: # No error so far, runtime check
+ if version in _runtimeCheckCache: # Use cache
+ result = _runtimeCheckCache[version]
+ elif runtimeCheck: # Run test in subprocess
+ result = _runtimeOpenGLCheck(version)
+ _runtimeCheckCache[version] = result
+
+ return result
+
+
+if __name__ == "__main__":
+ from silx.gui._glutils import OpenGLWidget
+ from silx.gui._glutils import gl
+ import argparse
+
+ class _TestOpenGLWidget(OpenGLWidget):
+ """Widget checking that OpenGL is indeed available
+
+ :param List[int] version: (major, minor) minimum OpenGL version
+ """
+
+ def __init__(self, version):
+ super(_TestOpenGLWidget, self).__init__(
+ alphaBufferSize=0,
+ depthBufferSize=0,
+ stencilBufferSize=0,
+ version=version)
+
+ def paintEvent(self, event):
+ super(_TestOpenGLWidget, self).paintEvent(event)
+
+ # Check once paint has been done
+ app = qt.QApplication.instance()
+ if not self.isValid():
+ print("OpenGL widget is not valid")
+ app.exit(1)
+ else:
+ qt.QTimer.singleShot(100, app.quit)
+
+ def paintGL(self):
+ gl.glClearColor(1., 0., 0., 0.)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('major')
+ parser.add_argument('minor')
+
+ args = parser.parse_args(args=sys.argv[1:])
+
+ app = qt.QApplication([])
+ window = qt.QMainWindow(flags=
+ qt.Qt.Popup |
+ qt.Qt.FramelessWindowHint |
+ qt.Qt.NoDropShadowWindowHint |
+ qt.Qt.WindowStaysOnTopHint)
+ window.setAttribute(qt.Qt.WA_ShowWithoutActivating)
+ window.move(0, 0)
+ window.resize(3, 3)
+ widget = _TestOpenGLWidget(version=(args.major, args.minor))
+ window.setCentralWidget(widget)
+ window.setWindowOpacity(0.04)
+ window.show()
+
+ qt.QTimer.singleShot(1000, app.quit)
+ sys.exit(app.exec())
diff --git a/src/silx/gui/utils/image.py b/src/silx/gui/utils/image.py
new file mode 100644
index 0000000..96f50ab
--- /dev/null
+++ b/src/silx/gui/utils/image.py
@@ -0,0 +1,143 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides conversions between numpy.ndarray and QImage
+
+- :func:`convertArrayToQImage`
+- :func:`convertQImageToArray`
+"""
+
+from __future__ import division
+
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "04/09/2018"
+
+
+import sys
+import numpy
+from numpy.lib.stride_tricks import as_strided as _as_strided
+
+from .. import qt
+
+
+def convertArrayToQImage(array):
+ """Convert an array-like image to a QImage.
+
+ The created QImage is using a copy of the array data.
+
+ Limitation: Only RGB or RGBA images with 8 bits per channel are supported.
+
+ :param array: Array-like image data of shape (height, width, channels)
+ Channels are expected to be either RGB or RGBA.
+ :type array: numpy.ndarray of uint8
+ :return: Corresponding Qt image with RGB888 or ARGB32 format.
+ :rtype: QImage
+ """
+ array = numpy.array(array, copy=False, order='C', dtype=numpy.uint8)
+
+ if array.ndim != 3 or array.shape[2] not in (3, 4):
+ raise ValueError(
+ 'Image must be a 3D array with 3 or 4 channels per pixel')
+
+ if array.shape[2] == 4:
+ format_ = qt.QImage.Format_ARGB32
+ # RGBA -> ARGB + take care of endianness
+ if sys.byteorder == 'little': # RGBA -> BGRA
+ array = array[:, :, (2, 1, 0, 3)]
+ else: # big endian: RGBA -> ARGB
+ array = array[:, :, (3, 0, 1, 2)]
+
+ array = numpy.array(array, order='C') # Make a contiguous array
+
+ else: # array.shape[2] == 3
+ format_ = qt.QImage.Format_RGB888
+
+ height, width, depth = array.shape
+ qimage = qt.QImage(
+ array.data,
+ width,
+ height,
+ array.strides[0], # bytesPerLine
+ format_)
+
+ return qimage.copy() # Making a copy of the image and its data
+
+
+def convertQImageToArray(image):
+ """Convert a QImage to a numpy array.
+
+ If QImage format is not Format_RGB888, Format_RGBA8888 or Format_ARGB32,
+ it is first converted to one of this format depending on
+ the presence of an alpha channel.
+
+ The created numpy array is using a copy of the QImage data.
+
+ :param QImage image: The QImage to convert.
+ :return: The image array of RGB or RGBA channels of shape
+ (height, width, channels (3 or 4))
+ :rtype: numpy.ndarray of uint8
+ """
+ rgba8888 = getattr(qt.QImage, 'Format_RGBA8888', None) # Only in Qt5
+
+ # Convert to supported format if needed
+ if image.format() not in (qt.QImage.Format_ARGB32,
+ qt.QImage.Format_RGB888,
+ rgba8888):
+ if image.hasAlphaChannel():
+ image = image.convertToFormat(
+ rgba8888 if rgba8888 is not None else qt.QImage.Format_ARGB32)
+ else:
+ image = image.convertToFormat(qt.QImage.Format_RGB888)
+
+ format_ = image.format()
+ channels = 3 if format_ == qt.QImage.Format_RGB888 else 4
+
+ ptr = image.bits()
+ if qt.BINDING == 'PyQt5':
+ ptr.setsize(image.byteCount())
+ elif qt.BINDING in ('PySide2', 'PySide6'):
+ ptr = ptr.tobytes()
+ else:
+ raise RuntimeError("Unsupported Qt binding: %s" % qt.BINDING)
+
+ # Create an array view on QImage internal data
+ view = _as_strided(
+ numpy.frombuffer(ptr, dtype=numpy.uint8),
+ shape=(image.height(), image.width(), channels),
+ strides=(image.bytesPerLine(), channels, 1))
+
+ if format_ == qt.QImage.Format_ARGB32:
+ # Convert from ARGB to RGBA
+ # Not a byte-ordered format: do care about endianness
+ if sys.byteorder == 'little': # BGRA -> RGBA
+ view = view[:, :, (2, 1, 0, 3)]
+ else: # big endian: ARGB -> RGBA
+ view = view[:, :, (1, 2, 3, 0)]
+
+ # Format_RGB888 and Format_RGBA8888 do not need reshuffling channels:
+ # They are byte-ordered and already in the right order
+
+ return numpy.array(view, copy=True, order='C')
diff --git a/src/silx/gui/utils/matplotlib.py b/src/silx/gui/utils/matplotlib.py
new file mode 100644
index 0000000..90257f8
--- /dev/null
+++ b/src/silx/gui/utils/matplotlib.py
@@ -0,0 +1,65 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+from __future__ import absolute_import
+
+"""This module initializes matplotlib and sets-up the backend to use.
+
+It MUST be imported prior to any other import of matplotlib.
+
+It provides the matplotlib :class:`FigureCanvasQTAgg` class corresponding
+to the used backend.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/05/2018"
+
+
+from pkg_resources import parse_version
+import matplotlib
+
+from .. import qt
+
+
+def _matplotlib_use(backend, force):
+ """Wrapper of `matplotlib.use` to set-up backend.
+
+ It adds extra initialization for PySide2 with matplotlib < 2.2.
+ """
+ # This is kept for compatibility with matplotlib < 2.2
+ if (parse_version(matplotlib.__version__) < parse_version('2.2') and
+ qt.BINDING == 'PySide2'):
+ matplotlib.rcParams['backend.qt5'] = 'PySide2'
+
+ matplotlib.use(backend, force=force)
+
+
+if qt.BINDING in ('PySide6', 'PyQt5', 'PySide2'):
+ _matplotlib_use('Qt5Agg', force=False)
+ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg # noqa
+
+else:
+ raise ImportError("Unsupported Qt binding: %s" % qt.BINDING)
diff --git a/silx/gui/utils/projecturl.py b/src/silx/gui/utils/projecturl.py
index 0832c2e..0832c2e 100644
--- a/silx/gui/utils/projecturl.py
+++ b/src/silx/gui/utils/projecturl.py
diff --git a/silx/gui/utils/qtutils.py b/src/silx/gui/utils/qtutils.py
index 9682913..9682913 100755
--- a/silx/gui/utils/qtutils.py
+++ b/src/silx/gui/utils/qtutils.py
diff --git a/silx/gui/utils/signal.py b/src/silx/gui/utils/signal.py
index 359f5cc..359f5cc 100644
--- a/silx/gui/utils/signal.py
+++ b/src/silx/gui/utils/signal.py
diff --git a/src/silx/gui/utils/test/__init__.py b/src/silx/gui/utils/test/__init__.py
new file mode 100755
index 0000000..15cd186
--- /dev/null
+++ b/src/silx/gui/utils/test/__init__.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""silx.gui.utils tests"""
diff --git a/src/silx/gui/utils/test/test.py b/src/silx/gui/utils/test/test.py
new file mode 100644
index 0000000..0208d64
--- /dev/null
+++ b/src/silx/gui/utils/test/test.py
@@ -0,0 +1,63 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of functions available in silx.gui.utils module."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/08/2019"
+
+
+import unittest
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt, SignalListener
+
+from silx.gui.utils import blockSignals
+
+
+class TestBlockSignals(TestCaseQt):
+ """Test blockSignals context manager"""
+
+ def _test(self, *objs):
+ """Test for provided objects"""
+ listener = SignalListener()
+ for obj in objs:
+ obj.objectNameChanged.connect(listener)
+ obj.setObjectName("received")
+
+ with blockSignals(*objs):
+ for obj in objs:
+ obj.setObjectName("silent")
+
+ self.assertEqual(listener.arguments(), [("received",)] * len(objs))
+
+ def testManyObjects(self):
+ """Test blockSignals with 2 QObjects"""
+ self._test(qt.QObject(), qt.QObject())
+
+ def testOneObject(self):
+ """Test blockSignals context manager with a single QObject"""
+ self._test(qt.QObject())
diff --git a/src/silx/gui/utils/test/test_async.py b/src/silx/gui/utils/test/test_async.py
new file mode 100644
index 0000000..7304ca9
--- /dev/null
+++ b/src/silx/gui/utils/test/test_async.py
@@ -0,0 +1,127 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of async module."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "09/03/2018"
+
+
+import threading
+import unittest
+
+
+from concurrent.futures import wait
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+
+from silx.gui.utils import concurrent
+
+
+class TestSubmitToQtThread(TestCaseQt):
+ """Test submission of tasks to Qt main thread"""
+
+ def setUp(self):
+ # Reset executor to test lazy-loading in different conditions
+ concurrent._executor = None
+ super(TestSubmitToQtThread, self).setUp()
+
+ def _task(self, value1, value2):
+ return value1, value2
+
+ def _taskWithException(self, *args, **kwargs):
+ raise RuntimeError('task exception')
+
+ def testFromMainThread(self):
+ """Call submitToQtMainThread from the main thread"""
+ value1, value2 = 0, 1
+ future = concurrent.submitToQtMainThread(self._task, value1, value2=value2)
+ self.assertTrue(future.done())
+ self.assertEqual(future.result(1), (value1, value2))
+ self.assertIsNone(future.exception(1))
+
+ future = concurrent.submitToQtMainThread(self._taskWithException)
+ self.assertTrue(future.done())
+ with self.assertRaises(RuntimeError):
+ future.result(1)
+ self.assertIsInstance(future.exception(1), RuntimeError)
+
+ def _threadedTest(self):
+ """Function run in a thread for the tests"""
+ value1, value2 = 0, 1
+ future = concurrent.submitToQtMainThread(self._task, value1, value2=value2)
+
+ wait([future], 3)
+
+ self.assertTrue(future.done())
+ self.assertEqual(future.result(1), (value1, value2))
+ self.assertIsNone(future.exception(1))
+
+ future = concurrent.submitToQtMainThread(self._taskWithException)
+
+ wait([future], 3)
+
+ self.assertTrue(future.done())
+ with self.assertRaises(RuntimeError):
+ future.result(1)
+ self.assertIsInstance(future.exception(1), RuntimeError)
+
+ def testFromPythonThread(self):
+ """Call submitToQtMainThread from a Python thread"""
+ thread = threading.Thread(target=self._threadedTest)
+ thread.start()
+ for i in range(100): # Loop over for 10 seconds
+ self.qapp.processEvents()
+ thread.join(0.1)
+ if not thread.is_alive():
+ break
+ else:
+ self.fail(('Thread task still running'))
+
+ def testFromQtThread(self):
+ """Call submitToQtMainThread from a Qt thread pool"""
+ class Runner(qt.QRunnable):
+ def __init__(self, fn):
+ super(Runner, self).__init__()
+ self._fn = fn
+
+ def run(self):
+ self._fn()
+
+ def autoDelete(self):
+ return True
+
+ threadPool = qt.silxGlobalThreadPool()
+ runner = Runner(self._threadedTest)
+ threadPool.start(runner)
+ for i in range(100): # Loop over for 10 seconds
+ self.qapp.processEvents()
+ done = threadPool.waitForDone(100)
+ if done:
+ break
+ else:
+ self.fail('Thread pool task still running')
diff --git a/src/silx/gui/utils/test/test_glutils.py b/src/silx/gui/utils/test/test_glutils.py
new file mode 100644
index 0000000..7c9831b
--- /dev/null
+++ b/src/silx/gui/utils/test/test_glutils.py
@@ -0,0 +1,55 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for the silx.gui.utils.glutils module."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/01/2020"
+
+
+import logging
+import unittest
+from silx.gui.utils.glutils import isOpenGLAvailable
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestIsOpenGLAvailable(unittest.TestCase):
+ """Test isOpenGLAvailable"""
+
+ def test(self):
+ for version in ((2, 1), (2, 1), (1000, 1)):
+ with self.subTest(version=version):
+ result = isOpenGLAvailable(version=version)
+ _logger.info("isOpenGLAvailable returned: %s", str(result))
+ if version[0] == 1000:
+ self.assertFalse(result)
+ if not result:
+ self.assertFalse(result.status)
+ self.assertTrue(len(result.error) > 0)
+ else:
+ self.assertTrue(result.status)
+ self.assertTrue(len(result.error) == 0)
diff --git a/src/silx/gui/utils/test/test_image.py b/src/silx/gui/utils/test/test_image.py
new file mode 100644
index 0000000..62316b0
--- /dev/null
+++ b/src/silx/gui/utils/test/test_image.py
@@ -0,0 +1,79 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of utils module."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+import numpy
+import unittest
+
+from silx.gui import qt
+from silx.utils.testutils import ParametricTestCase
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.utils.image import convertArrayToQImage, convertQImageToArray
+
+
+class TestQImageConversion(TestCaseQt, ParametricTestCase):
+ """Tests conversion of QImage to/from numpy array."""
+
+ def testConvertArrayToQImage(self):
+ """Test conversion of numpy array to QImage"""
+ for format_, channels in [('Format_RGB888', 3),
+ ('Format_ARGB32', 4)]:
+ with self.subTest(format_):
+ image = numpy.arange(
+ 3*3*channels, dtype=numpy.uint8).reshape(3, 3, channels)
+ qimage = convertArrayToQImage(image)
+
+ self.assertEqual(qimage.height(), image.shape[0])
+ self.assertEqual(qimage.width(), image.shape[1])
+ self.assertEqual(qimage.format(), getattr(qt.QImage, format_))
+
+ for row in range(3):
+ for col in range(3):
+ # Qrgb has no alpha channel, not compared
+ # Qt uses x,y while array is row,col...
+ self.assertEqual(qt.QColor(qimage.pixel(col, row)),
+ qt.QColor(*image[row, col, :3]))
+
+
+ def testConvertQImageToArray(self):
+ """Test conversion of QImage to numpy array"""
+ for format_, channels in [
+ ('Format_RGB888', 3), # Native support
+ ('Format_ARGB32', 4), # Native support
+ ('Format_RGB32', 3)]: # Conversion to RGB
+ with self.subTest(format_):
+ color = numpy.arange(channels) # RGB(A) values
+ qimage = qt.QImage(3, 3, getattr(qt.QImage, format_))
+ qimage.fill(qt.QColor(*color))
+ image = convertQImageToArray(qimage)
+
+ self.assertEqual(qimage.height(), image.shape[0])
+ self.assertEqual(qimage.width(), image.shape[1])
+ self.assertEqual(image.shape[2], len(color))
+ self.assertTrue(numpy.all(numpy.equal(image, color)))
diff --git a/src/silx/gui/utils/test/test_qtutils.py b/src/silx/gui/utils/test/test_qtutils.py
new file mode 100755
index 0000000..c00280b
--- /dev/null
+++ b/src/silx/gui/utils/test/test_qtutils.py
@@ -0,0 +1,65 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of functions available in silx.gui.utils module."""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/08/2019"
+
+
+import unittest
+from silx.gui import qt
+from silx.gui import utils
+from silx.gui.utils.testutils import TestCaseQt
+
+
+class TestQEventName(TestCaseQt):
+ """Test QEvent names"""
+
+ def testNoneType(self):
+ result = utils.getQEventName(0)
+ self.assertEqual(result, "None")
+
+ def testNoneEvent(self):
+ event = qt.QEvent(qt.QEvent.Type(0))
+ result = utils.getQEventName(event)
+ self.assertEqual(result, "None")
+
+ def testUserType(self):
+ result = utils.getQEventName(1050)
+ self.assertIn("User", result)
+ self.assertIn("1050", result)
+
+ def testQtUndefinedType(self):
+ result = utils.getQEventName(900)
+ self.assertIn("Unknown", result)
+ self.assertIn("900", result)
+
+ def testUndefinedType(self):
+ result = utils.getQEventName(70000)
+ self.assertIn("Unknown", result)
+ self.assertIn("70000", result)
diff --git a/src/silx/gui/utils/test/test_testutils.py b/src/silx/gui/utils/test/test_testutils.py
new file mode 100644
index 0000000..07294a7
--- /dev/null
+++ b/src/silx/gui/utils/test/test_testutils.py
@@ -0,0 +1,44 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of testutils module."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+import unittest
+import sys
+
+from silx.gui import qt
+from ..testutils import TestCaseQt
+
+
+class TestOutcome(unittest.TestCase):
+ """Tests conversion of QImage to/from numpy array."""
+
+ @unittest.skipIf(sys.version_info.major <= 2, 'Python3 only')
+ def testNoneOutcome(self):
+ test = TestCaseQt()
+ test._currentTestSucceeded()
diff --git a/src/silx/gui/utils/testutils.py b/src/silx/gui/utils/testutils.py
new file mode 100644
index 0000000..40c8237
--- /dev/null
+++ b/src/silx/gui/utils/testutils.py
@@ -0,0 +1,508 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Helper class to write Qt widget unittests."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "05/10/2018"
+
+
+import gc
+import logging
+import unittest
+import time
+import functools
+import sys
+import os
+
+_logger = logging.getLogger(__name__)
+
+from silx.gui import qt
+from silx.gui.qt import inspect as _inspect
+
+
+if qt.BINDING == 'PySide2':
+ from PySide2.QtTest import QTest
+elif qt.BINDING == 'PyQt5':
+ from PyQt5.QtTest import QTest
+elif qt.BINDING == 'PySide6':
+ from PySide6.QtTest import QTest
+else:
+ raise ImportError('Unsupported Qt bindings')
+
+
+def qWaitForWindowExposedAndActivate(window, timeout=None):
+ """Waits until the window is shown in the screen.
+
+ It also activates the window and raises it.
+
+ See QTest.qWaitForWindowExposed for details.
+ """
+ if timeout is None:
+ result = QTest.qWaitForWindowExposed(window)
+ else:
+ result = QTest.qWaitForWindowExposed(window, timeout)
+
+ if result:
+ # Makes sure window is active and on top
+ window.activateWindow()
+ window.raise_()
+
+ return result
+
+
+class TestCaseQt(unittest.TestCase):
+ """Base class to write test for Qt stuff.
+
+ It creates a QApplication before running the tests.
+ WARNING: The QApplication is shared by all tests, which might have side
+ effects.
+
+ After each test, this class is checking for widgets remaining alive.
+ To allow some widgets to remain alive at the end of a test, set the
+ allowedLeakingWidgets attribute to the number of widgets that can remain
+ alive at the end of the test.
+ With PySide2, this test is not run for now as it seems PySide2
+ is leaking widgets internally.
+
+ All keyboard and mouse event simulation methods call qWait(20) after
+ simulating the event (as QTest does on Mac OSX).
+ This was introduced to fix issues with continuous integration tests
+ running with Xvfb on Linux.
+ """
+
+ DEFAULT_TIMEOUT_WAIT = 100
+ """Default timeout for qWait"""
+
+ TIMEOUT_WAIT = 0
+ """Extra timeout in millisecond to add to qSleep, qWait and
+ qWaitForWindowExposed.
+
+ Intended purpose is for debugging, to add extra time to waits in order to
+ allow to view the tested widgets.
+ """
+
+ _qapp = None
+ """Placeholder for QApplication"""
+
+ @classmethod
+ def exceptionHandler(cls, exceptionClass, exception, stack):
+ import traceback
+ message = (''.join(traceback.format_tb(stack)))
+ template = 'Traceback (most recent call last):\n{2}{0}: {1}'
+ message = template.format(exceptionClass.__name__, exception, message)
+ cls._exceptions.append(message)
+
+ @classmethod
+ def setUpClass(cls):
+ """Makes sure Qt is inited"""
+ cls._oldExceptionHook = sys.excepthook
+ sys.excepthook = cls.exceptionHandler
+
+ # Makes sure a QApplication exists and do it once for all
+ if not qt.QApplication.instance():
+ cls._qapp = qt.QApplication([])
+
+ @classmethod
+ def tearDownClass(cls):
+ sys.excepthook = cls._oldExceptionHook
+
+ def setUp(self):
+ """Get the list of existing widgets."""
+ self.allowedLeakingWidgets = 0
+ if qt.BINDING in ('PySide2', 'PySide6'):
+ self.__previousWidgets = None
+ else:
+ self.__previousWidgets = self.qapp.allWidgets()
+ self.__class__._exceptions = []
+
+ def _currentTestSucceeded(self):
+ if hasattr(self, '_outcome'):
+ # For Python >= 3.4
+ result = self.defaultTestResult() # these 2 methods have no side effects
+ if hasattr(self._outcome, 'errors'):
+ self._feedErrorsToResult(result, self._outcome.errors)
+ else:
+ # For Python < 3.4
+ result = getattr(self, '_outcomeForDoCleanups', self._resultForDoCleanups)
+
+ skipped = self.id() in [case.id() for case, _ in result.skipped]
+ error = self.id() in [case.id() for case, _ in result.errors]
+ failure = self.id() in [case.id() for case, _ in result.failures]
+ return not error and not failure and not skipped
+
+ def _checkForUnreleasedWidgets(self):
+ """Test fixture checking that no more widgets exists."""
+ gc.collect()
+
+ if self.__previousWidgets is None:
+ return # Do not test for leaking widgets with PySide2
+
+ widgets = [widget for widget in self.qapp.allWidgets()
+ if (widget not in self.__previousWidgets and
+ _inspect.createdByPython(widget))]
+ self.__previousWidgets = None
+
+ allowedLeakingWidgets = self.allowedLeakingWidgets
+ self.allowedLeakingWidgets = 0
+
+ if widgets and len(widgets) <= allowedLeakingWidgets:
+ _logger.info(
+ '%s: %d remaining widgets after test' % (self.id(),
+ len(widgets)))
+
+ if len(widgets) > allowedLeakingWidgets:
+ raise RuntimeError(
+ "Test ended with widgets alive: %s" % str(widgets))
+
+ def tearDown(self):
+ self.qapp.processEvents()
+
+ if len(self.__class__._exceptions) > 0:
+ messages = "\n".join(self.__class__._exceptions)
+ raise AssertionError("Exception occured in Qt thread:\n" + messages)
+
+ if self._currentTestSucceeded():
+ self._checkForUnreleasedWidgets()
+
+ @property
+ def qapp(self):
+ """The QApplication currently running."""
+ return qt.QApplication.instance()
+
+ # Proxy to QTest
+
+ Press = QTest.Press
+ """Key press action code"""
+
+ Release = QTest.Release
+ """Key release action code"""
+
+ Click = QTest.Click
+ """Key click action code"""
+
+ QTest = property(lambda self: QTest,
+ doc="""The Qt QTest class from the used Qt binding.""")
+
+ def keyClick(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
+ """Simulate clicking a key.
+
+ See QTest.keyClick for details.
+ """
+ QTest.keyClick(widget, key, modifier, delay)
+ self.qWait(20)
+
+ def keyClicks(self, widget, sequence, modifier=qt.Qt.NoModifier, delay=-1):
+ """Simulate clicking a sequence of keys.
+
+ See QTest.keyClick for details.
+ """
+ QTest.keyClicks(widget, sequence, modifier, delay)
+ self.qWait(20)
+
+ def keyEvent(self, action, widget, key,
+ modifier=qt.Qt.NoModifier, delay=-1):
+ """Sends a Qt key event.
+
+ See QTest.keyEvent for details.
+ """
+ QTest.keyEvent(action, widget, key, modifier, delay)
+ self.qWait(20)
+
+ def keyPress(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
+ """Sends a Qt key press event.
+
+ See QTest.keyPress for details.
+ """
+ QTest.keyPress(widget, key, modifier, delay)
+ self.qWait(20)
+
+ def keyRelease(self, widget, key, modifier=qt.Qt.NoModifier, delay=-1):
+ """Sends a Qt key release event.
+
+ See QTest.keyRelease for details.
+ """
+ QTest.keyRelease(widget, key, modifier, delay)
+ self.qWait(20)
+
+ def mouseClick(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate clicking a mouse button.
+
+ See QTest.mouseClick for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mouseClick(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def mouseDClick(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate double clicking a mouse button.
+
+ See QTest.mouseDClick for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mouseDClick(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def mouseMove(self, widget, pos=None, delay=-1):
+ """Simulate moving the mouse.
+
+ See QTest.mouseMove for details.
+ """
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mouseMove(widget, pos, delay)
+ self.qWait(20)
+
+ def mousePress(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate pressing a mouse button.
+
+ See QTest.mousePress for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mousePress(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def mouseRelease(self, widget, button, modifier=None, pos=None, delay=-1):
+ """Simulate releasing a mouse button.
+
+ See QTest.mouseRelease for details.
+ """
+ if modifier is None:
+ modifier = qt.Qt.KeyboardModifiers()
+ pos = qt.QPoint(int(pos[0]), int(pos[1])) if pos is not None else qt.QPoint()
+ QTest.mouseRelease(widget, button, modifier, pos, delay)
+ self.qWait(20)
+
+ def qSleep(self, ms):
+ """Sleep for ms milliseconds, blocking the execution of the test.
+
+ See QTest.qSleep for details.
+ """
+ QTest.qSleep(int(ms) + self.TIMEOUT_WAIT)
+
+ @classmethod
+ def qWait(cls, ms=None):
+ """Waits for ms milliseconds, events will be processed.
+
+ See QTest.qWait for details.
+ """
+ if ms is None:
+ ms = cls.DEFAULT_TIMEOUT_WAIT
+
+ if qt.BINDING in ('PySide2', 'PySide6'):
+ # PySide2 has no qWait, provide a replacement
+ timeout = int(ms)
+ endTimeMS = int(time.time() * 1000) + timeout
+ qapp = qt.QApplication.instance()
+ while timeout > 0:
+ qapp.processEvents(qt.QEventLoop.AllEvents,
+ timeout)
+ timeout = endTimeMS - int(time.time() * 1000)
+ else:
+ QTest.qWait(int(ms) + cls.TIMEOUT_WAIT)
+
+ def qWaitForWindowExposed(self, window, timeout=None):
+ """Waits until the window is shown in the screen.
+
+ See QTest.qWaitForWindowExposed for details.
+ """
+ result = qWaitForWindowExposedAndActivate(window, timeout)
+
+ if self.TIMEOUT_WAIT:
+ QTest.qWait(self.TIMEOUT_WAIT)
+
+ return result
+
+ def exposeAndClose(self, widget):
+ """Wait for expose a widget, flag it delete on close, and close it."""
+ self.qWaitForWindowExposed(widget)
+ self.qapp.processEvents()
+ widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ widget.close()
+
+ _qobject_destroyed = False
+
+ @classmethod
+ def _aboutToDestroy(cls):
+ cls._qobject_destroyed = True
+
+ @classmethod
+ def qWaitForDestroy(cls, ref):
+ """
+ Wait for Qt object destruction.
+
+ Use a weakref as parameter to avoid any strong references to the
+ object.
+
+ It have to be used as following. Removing the reference to the object
+ before calling the function looks to be expected, else
+ :meth:`deleteLater` will not work.
+
+ .. code-block:: python
+
+ ref = weakref.ref(self.obj)
+ self.obj = None
+ self.qWaitForDestroy(ref)
+
+ :param weakref ref: A weakref to an object to avoid any reference
+ :return: True if the object was destroyed
+ :rtype: bool
+ """
+ cls._qobject_destroyed = False
+ qobject = ref()
+ if qobject is None:
+ return True
+ qobject.destroyed.connect(cls._aboutToDestroy)
+ qobject.deleteLater()
+ qobject = None
+ for _ in range(10):
+ if cls._qobject_destroyed:
+ break
+ cls.qWait(10)
+ else:
+ _logger.debug("Object was not destroyed")
+
+ return ref() is None
+
+ def logScreenShot(self, level=logging.ERROR):
+ """Take a screenshot and log it into the logging system if the
+ logger is enabled for the expected level.
+
+ The screenshot is stored in the directory "./build/test-debug", and
+ the logging system only log the path to this file.
+
+ :param level: Logging level
+ """
+ if not _logger.isEnabledFor(level):
+ return
+ basedir = os.path.abspath(os.path.join("build", "test-debug"))
+ if not os.path.exists(basedir):
+ os.makedirs(basedir)
+ filename = "Screenshot_%s.png" % self.id()
+ filename = os.path.join(basedir, filename)
+
+ screen = self.qapp.primaryScreen()
+ pixmap = screen.grabWindow(0)
+ pixmap.save(filename)
+ _logger.log(level, "Screenshot saved at %s", filename)
+
+
+class SignalListener(object):
+ """Util to listen a Qt event and store parameters
+ """
+
+ def __init__(self):
+ self.__calls = []
+
+ def __call__(self, *args, **kargs):
+ self.__calls.append((args, kargs))
+
+ def clear(self):
+ """Clear stored data"""
+ self.__calls = []
+
+ def callCount(self):
+ """
+ Returns how many times the listener was called.
+
+ :rtype: int
+ """
+ return len(self.__calls)
+
+ def arguments(self, callIndex=None, argumentIndex=None):
+ """Returns positional arguments optionally filtered by call count id
+ or argument index.
+
+ :param int callIndex: Index of the called data
+ :param int argumentIndex: Index of the positional argument.
+ """
+ if callIndex is not None:
+ result = self.__calls[callIndex][0]
+ if argumentIndex is not None:
+ result = result[argumentIndex]
+ else:
+ result = [x[0] for x in self.__calls]
+ if argumentIndex is not None:
+ result = [x[argumentIndex] for x in result]
+ return result
+
+ def karguments(self, callIndex=None, argumentName=None):
+ """Returns positional arguments optionally filtered by call count id
+ or name of the keyword argument.
+
+ :param int callIndex: Index of the called data
+ :param int argumentName: Name of the keyword argument.
+ """
+ if callIndex is not None:
+ result = self.__calls[callIndex][1]
+ if argumentName is not None:
+ result = result[argumentName]
+ else:
+ result = [x[1] for x in self.__calls]
+ if argumentName is not None:
+ result = [x[argumentName] for x in result]
+ return result
+
+ def partial(self, *args, **kargs):
+ """Returns a new partial object which when called will behave like this
+ listener called with the positional arguments args and keyword
+ arguments keywords. If more arguments are supplied to the call, they
+ are appended to args. If additional keyword arguments are supplied,
+ they extend and override keywords.
+ """
+ return functools.partial(self, *args, **kargs)
+
+
+def getQToolButtonFromAction(action):
+ """Return a QToolButton corresponding to a QAction.
+
+ :param QAction action: The QAction from which to get QToolButton.
+ :return: A QToolButton associated to action or None.
+ """
+ if qt.BINDING == "PySide6":
+ widgets = action.associatedObjects()
+ else:
+ widgets = action.associatedWidgets()
+
+ for widget in widgets:
+ if isinstance(widget, qt.QToolButton):
+ return widget
+ return None
+
+
+def findChildren(parent, kind, name=None):
+ if qt.BINDING in ("PySide2", "PySide6") and name is not None:
+ result = []
+ for obj in parent.findChildren(kind):
+ if obj.objectName() == name:
+ result.append(obj)
+ return result
+ else:
+ return parent.findChildren(kind, name=name)
diff --git a/silx/gui/widgets/BoxLayoutDockWidget.py b/src/silx/gui/widgets/BoxLayoutDockWidget.py
index 3d2b853..3d2b853 100644
--- a/silx/gui/widgets/BoxLayoutDockWidget.py
+++ b/src/silx/gui/widgets/BoxLayoutDockWidget.py
diff --git a/silx/gui/widgets/ColormapNameComboBox.py b/src/silx/gui/widgets/ColormapNameComboBox.py
index fa8faf1..fa8faf1 100644
--- a/silx/gui/widgets/ColormapNameComboBox.py
+++ b/src/silx/gui/widgets/ColormapNameComboBox.py
diff --git a/src/silx/gui/widgets/ElidedLabel.py b/src/silx/gui/widgets/ElidedLabel.py
new file mode 100644
index 0000000..7c6dfb5
--- /dev/null
+++ b/src/silx/gui/widgets/ElidedLabel.py
@@ -0,0 +1,140 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module contains an elidable label
+"""
+
+__license__ = "MIT"
+__date__ = "07/12/2018"
+
+from silx.gui import qt
+
+
+class ElidedLabel(qt.QLabel):
+ """QLabel with an edile property.
+
+ By default if the text is too big, it is elided on the right.
+
+ This mode can be changed with :func:`setElideMode`.
+
+ In case the text is elided, the full content is displayed as part of the
+ tool tip. This behavior can be disabled with :func:`setTextAsToolTip`.
+ """
+
+ def __init__(self, parent=None):
+ super(ElidedLabel, self).__init__(parent)
+ self.__text = ""
+ self.__toolTip = ""
+ self.__textAsToolTip = True
+ self.__textIsElided = False
+ self.__elideMode = qt.Qt.ElideRight
+ self.__updateMinimumSize()
+
+ def resizeEvent(self, event):
+ self.__updateText()
+ return qt.QLabel.resizeEvent(self, event)
+
+ def setFont(self, font):
+ qt.QLabel.setFont(self, font)
+ self.__updateMinimumSize()
+ self.__updateText()
+
+ def __updateMinimumSize(self):
+ metrics = self.fontMetrics()
+ if qt.BINDING in ('PySide2', 'PyQt5'):
+ width = metrics.width("...")
+ else: # Qt6
+ width = metrics.horizontalAdvance("...")
+ self.setMinimumWidth(width)
+
+ def __updateText(self):
+ metrics = self.fontMetrics()
+ elidedText = metrics.elidedText(self.__text, self.__elideMode, self.width())
+ qt.QLabel.setText(self, elidedText)
+ wasElided = self.__textIsElided
+ self.__textIsElided = elidedText != self.__text
+ if self.__textIsElided or wasElided != self.__textIsElided:
+ self.__updateToolTip()
+
+ def __updateToolTip(self):
+ if self.__textIsElided and self.__textAsToolTip:
+ qt.QLabel.setToolTip(self, self.__text + "<br/>" + self.__toolTip)
+ else:
+ qt.QLabel.setToolTip(self, self.__toolTip)
+
+ # Properties
+
+ def setText(self, text):
+ self.__text = text
+ self.__updateText()
+
+ def getText(self):
+ return self.__text
+
+ text = qt.Property(str, getText, setText)
+
+ def setToolTip(self, toolTip):
+ self.__toolTip = toolTip
+ self.__updateToolTip()
+
+ def getToolTip(self):
+ return self.__toolTip
+
+ toolTip = qt.Property(str, getToolTip, setToolTip)
+
+ def setElideMode(self, elideMode):
+ """Set the elide mode.
+
+ :param qt.Qt.TextElideMode elidMode: Elide mode to use
+ """
+ self.__elideMode = elideMode
+ self.__updateText()
+
+ def getElideMode(self):
+ """Returns the used elide mode.
+
+ :rtype: qt.Qt.TextElideMode
+ """
+ return self.__elideMode
+
+ elideMode = qt.Property(qt.Qt.TextElideMode, getToolTip, setToolTip)
+
+ def setTextAsToolTip(self, enabled):
+ """Enable displaying text as part of the tooltip if it is elided.
+
+ :param bool enabled: Enable the behavior
+ """
+ if self.__textAsToolTip == enabled:
+ return
+ self.__textAsToolTip = enabled
+ self.__updateToolTip()
+
+ def getTextAsToolTip(self):
+ """True if an elided text is displayed as part of the tooltip.
+
+ :rtype: bool
+ """
+ return self.__textAsToolTip
+
+ textAsToolTip = qt.Property(bool, getTextAsToolTip, setTextAsToolTip)
diff --git a/src/silx/gui/widgets/FloatEdit.py b/src/silx/gui/widgets/FloatEdit.py
new file mode 100644
index 0000000..08ed67d
--- /dev/null
+++ b/src/silx/gui/widgets/FloatEdit.py
@@ -0,0 +1,71 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module contains a float editor
+"""
+
+from __future__ import division
+
+__authors__ = ["V.A. Sole", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/10/2017"
+
+from .. import qt
+
+
+class FloatEdit(qt.QLineEdit):
+ """Field to edit a float value.
+
+ :param parent: See :class:`QLineEdit`
+ :param float value: The value to set the QLineEdit to.
+ """
+ def __init__(self, parent=None, value=None):
+ qt.QLineEdit.__init__(self, parent)
+ validator = qt.QDoubleValidator(self)
+ self.setValidator(validator)
+ self.setAlignment(qt.Qt.AlignRight)
+ if value is not None:
+ self.setValue(value)
+
+ def value(self):
+ """Return the QLineEdit current value as a float."""
+ text = self.text()
+ value, validated = self.validator().locale().toDouble(text)
+ if not validated:
+ self.setValue(value)
+ return value
+
+ def setValue(self, value):
+ """Set the current value of the LineEdit
+
+ :param float value: The value to set the QLineEdit to.
+ """
+ locale = self.validator().locale()
+ if qt.BINDING == "PySide6":
+ # Fix for PySide6 not selecting the right method
+ text = locale.toString(float(value), 'g')
+ else:
+ text = locale.toString(float(value))
+
+ self.setText(text)
diff --git a/silx/gui/widgets/FlowLayout.py b/src/silx/gui/widgets/FlowLayout.py
index 3c4c9dd..3c4c9dd 100644
--- a/silx/gui/widgets/FlowLayout.py
+++ b/src/silx/gui/widgets/FlowLayout.py
diff --git a/silx/gui/widgets/FrameBrowser.py b/src/silx/gui/widgets/FrameBrowser.py
index 671991f..671991f 100644
--- a/silx/gui/widgets/FrameBrowser.py
+++ b/src/silx/gui/widgets/FrameBrowser.py
diff --git a/silx/gui/widgets/HierarchicalTableView.py b/src/silx/gui/widgets/HierarchicalTableView.py
index 3ccf4c7..3ccf4c7 100644
--- a/silx/gui/widgets/HierarchicalTableView.py
+++ b/src/silx/gui/widgets/HierarchicalTableView.py
diff --git a/silx/gui/widgets/LegendIconWidget.py b/src/silx/gui/widgets/LegendIconWidget.py
index 1c95e41..1c95e41 100755
--- a/silx/gui/widgets/LegendIconWidget.py
+++ b/src/silx/gui/widgets/LegendIconWidget.py
diff --git a/silx/gui/widgets/MedianFilterDialog.py b/src/silx/gui/widgets/MedianFilterDialog.py
index dd4a00d..dd4a00d 100644
--- a/silx/gui/widgets/MedianFilterDialog.py
+++ b/src/silx/gui/widgets/MedianFilterDialog.py
diff --git a/silx/gui/widgets/MultiModeAction.py b/src/silx/gui/widgets/MultiModeAction.py
index 502275d..502275d 100644
--- a/silx/gui/widgets/MultiModeAction.py
+++ b/src/silx/gui/widgets/MultiModeAction.py
diff --git a/src/silx/gui/widgets/PeriodicTable.py b/src/silx/gui/widgets/PeriodicTable.py
new file mode 100644
index 0000000..6fed109
--- /dev/null
+++ b/src/silx/gui/widgets/PeriodicTable.py
@@ -0,0 +1,831 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Periodic table widgets
+
+Classes
+-------
+
+Widgets:
+
+ - :class:`PeriodicTable`
+ - :class:`PeriodicList`
+ - :class:`PeriodicCombo`
+
+Data model:
+
+ - :class:`PeriodicTableItem`
+ - :class:`ColoredPeriodicTableItem`
+
+
+Example of usage
+----------------
+
+This example uses the widgets with the standard builtin elements list.
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.widgets.PeriodicTable import PeriodicTable, \
+ PeriodicCombo, PeriodicList
+
+ a = qt.QApplication([])
+
+ w = qt.QTabWidget()
+
+ ptable = PeriodicTable(w, selectable=True)
+ pcombo = PeriodicCombo(w)
+ plist = PeriodicList(w)
+
+ w.addTab(ptable, "PeriodicTable")
+ w.addTab(plist, "PeriodicList")
+ w.addTab(pcombo, "PeriodicCombo")
+
+ ptable.setSelection(['H', 'Fe', 'Si'])
+ plist.setSelectedElements(['H', 'Be', 'F'])
+ pcombo.setSelection("Li")
+
+ def change_list(items):
+ print("New list selection:", [item.symbol for item in items])
+
+ def change_combo(item):
+ print("New combo selection:", item.symbol)
+
+ def click_table(item):
+ print("New table click:", item.symbol)
+
+ def change_table(items):
+ print("New table selection:", [item.symbol for item in items])
+
+ ptable.sigElementClicked.connect(click_table)
+ ptable.sigSelectionChanged.connect(change_table)
+ plist.sigSelectionChanged.connect(change_list)
+ pcombo.sigSelectionChanged.connect(change_combo)
+
+ w.show()
+ a.exec()
+
+
+The second example explains how to define custom elements.
+
+.. code-block:: python
+
+ from silx.gui import qt
+ from silx.gui.widgets.PeriodicTable import PeriodicTable, \
+ PeriodicCombo, PeriodicList
+ from silx.gui.widgets.PeriodicTable import PeriodicTableItem
+
+ # subclass PeriodicTableItem
+ class MyPeriodicTableItem(PeriodicTableItem):
+ "New item with added mass number and number of protons"
+ def __init__(self, symbol, Z, A, col, row, name, mass,
+ subcategory=""):
+ PeriodicTableItem.__init__(
+ self, symbol, Z, col, row, name, mass,
+ subcategory)
+
+ self.A = A
+ "Mass number (neutrons + protons)"
+
+ self.num_neutrons = A - Z
+ "Number of neutrons"
+
+ # build your list of elements
+ my_elements = [MyPeriodicTableItem("H", 1, 1, 1, 1, "hydrogen",
+ 1.00800, "diatomic nonmetal"),
+ MyPeriodicTableItem("He", 2, 4, 18, 1, "helium",
+ 4.0030, "noble gas"),
+ # etc ...
+ ]
+
+ app = qt.QApplication([])
+
+ ptable = PeriodicTable(elements=my_elements, selectable=True)
+ ptable.show()
+
+ def click_table(item):
+ "Callback function printing the mass number of clicked element"
+ print("New table click, mass number:", item.A)
+
+ ptable.sigElementClicked.connect(click_table)
+ app.exec()
+
+"""
+
+__authors__ = ["E. Papillon", "V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "26/01/2017"
+
+from collections import OrderedDict
+import logging
+from silx.gui import qt
+
+_logger = logging.getLogger(__name__)
+
+# Symbol Atomic Number col row name mass subcategory
+_elements = [("H", 1, 1, 1, "hydrogen", 1.00800, "diatomic nonmetal"),
+ ("He", 2, 18, 1, "helium", 4.0030, "noble gas"),
+ ("Li", 3, 1, 2, "lithium", 6.94000, "alkali metal"),
+ ("Be", 4, 2, 2, "beryllium", 9.01200, "alkaline earth metal"),
+ ("B", 5, 13, 2, "boron", 10.8110, "metalloid"),
+ ("C", 6, 14, 2, "carbon", 12.0100, "polyatomic nonmetal"),
+ ("N", 7, 15, 2, "nitrogen", 14.0080, "diatomic nonmetal"),
+ ("O", 8, 16, 2, "oxygen", 16.0000, "diatomic nonmetal"),
+ ("F", 9, 17, 2, "fluorine", 19.0000, "diatomic nonmetal"),
+ ("Ne", 10, 18, 2, "neon", 20.1830, "noble gas"),
+ ("Na", 11, 1, 3, "sodium", 22.9970, "alkali metal"),
+ ("Mg", 12, 2, 3, "magnesium", 24.3200, "alkaline earth metal"),
+ ("Al", 13, 13, 3, "aluminium", 26.9700, "post transition metal"),
+ ("Si", 14, 14, 3, "silicon", 28.0860, "metalloid"),
+ ("P", 15, 15, 3, "phosphorus", 30.9750, "polyatomic nonmetal"),
+ ("S", 16, 16, 3, "sulphur", 32.0660, "polyatomic nonmetal"),
+ ("Cl", 17, 17, 3, "chlorine", 35.4570, "diatomic nonmetal"),
+ ("Ar", 18, 18, 3, "argon", 39.9440, "noble gas"),
+ ("K", 19, 1, 4, "potassium", 39.1020, "alkali metal"),
+ ("Ca", 20, 2, 4, "calcium", 40.0800, "alkaline earth metal"),
+ ("Sc", 21, 3, 4, "scandium", 44.9600, "transition metal"),
+ ("Ti", 22, 4, 4, "titanium", 47.9000, "transition metal"),
+ ("V", 23, 5, 4, "vanadium", 50.9420, "transition metal"),
+ ("Cr", 24, 6, 4, "chromium", 51.9960, "transition metal"),
+ ("Mn", 25, 7, 4, "manganese", 54.9400, "transition metal"),
+ ("Fe", 26, 8, 4, "iron", 55.8500, "transition metal"),
+ ("Co", 27, 9, 4, "cobalt", 58.9330, "transition metal"),
+ ("Ni", 28, 10, 4, "nickel", 58.6900, "transition metal"),
+ ("Cu", 29, 11, 4, "copper", 63.5400, "transition metal"),
+ ("Zn", 30, 12, 4, "zinc", 65.3800, "transition metal"),
+ ("Ga", 31, 13, 4, "gallium", 69.7200, "post transition metal"),
+ ("Ge", 32, 14, 4, "germanium", 72.5900, "metalloid"),
+ ("As", 33, 15, 4, "arsenic", 74.9200, "metalloid"),
+ ("Se", 34, 16, 4, "selenium", 78.9600, "polyatomic nonmetal"),
+ ("Br", 35, 17, 4, "bromine", 79.9200, "diatomic nonmetal"),
+ ("Kr", 36, 18, 4, "krypton", 83.8000, "noble gas"),
+ ("Rb", 37, 1, 5, "rubidium", 85.4800, "alkali metal"),
+ ("Sr", 38, 2, 5, "strontium", 87.6200, "alkaline earth metal"),
+ ("Y", 39, 3, 5, "yttrium", 88.9050, "transition metal"),
+ ("Zr", 40, 4, 5, "zirconium", 91.2200, "transition metal"),
+ ("Nb", 41, 5, 5, "niobium", 92.9060, "transition metal"),
+ ("Mo", 42, 6, 5, "molybdenum", 95.9500, "transition metal"),
+ ("Tc", 43, 7, 5, "technetium", 99.0000, "transition metal"),
+ ("Ru", 44, 8, 5, "ruthenium", 101.0700, "transition metal"),
+ ("Rh", 45, 9, 5, "rhodium", 102.9100, "transition metal"),
+ ("Pd", 46, 10, 5, "palladium", 106.400, "transition metal"),
+ ("Ag", 47, 11, 5, "silver", 107.880, "transition metal"),
+ ("Cd", 48, 12, 5, "cadmium", 112.410, "transition metal"),
+ ("In", 49, 13, 5, "indium", 114.820, "post transition metal"),
+ ("Sn", 50, 14, 5, "tin", 118.690, "post transition metal"),
+ ("Sb", 51, 15, 5, "antimony", 121.760, "metalloid"),
+ ("Te", 52, 16, 5, "tellurium", 127.600, "metalloid"),
+ ("I", 53, 17, 5, "iodine", 126.910, "diatomic nonmetal"),
+ ("Xe", 54, 18, 5, "xenon", 131.300, "noble gas"),
+ ("Cs", 55, 1, 6, "caesium", 132.910, "alkali metal"),
+ ("Ba", 56, 2, 6, "barium", 137.360, "alkaline earth metal"),
+ ("La", 57, 3, 6, "lanthanum", 138.920, "lanthanide"),
+ ("Ce", 58, 4, 9, "cerium", 140.130, "lanthanide"),
+ ("Pr", 59, 5, 9, "praseodymium", 140.920, "lanthanide"),
+ ("Nd", 60, 6, 9, "neodymium", 144.270, "lanthanide"),
+ ("Pm", 61, 7, 9, "promethium", 147.000, "lanthanide"),
+ ("Sm", 62, 8, 9, "samarium", 150.350, "lanthanide"),
+ ("Eu", 63, 9, 9, "europium", 152.000, "lanthanide"),
+ ("Gd", 64, 10, 9, "gadolinium", 157.260, "lanthanide"),
+ ("Tb", 65, 11, 9, "terbium", 158.930, "lanthanide"),
+ ("Dy", 66, 12, 9, "dysprosium", 162.510, "lanthanide"),
+ ("Ho", 67, 13, 9, "holmium", 164.940, "lanthanide"),
+ ("Er", 68, 14, 9, "erbium", 167.270, "lanthanide"),
+ ("Tm", 69, 15, 9, "thulium", 168.940, "lanthanide"),
+ ("Yb", 70, 16, 9, "ytterbium", 173.040, "lanthanide"),
+ ("Lu", 71, 17, 9, "lutetium", 174.990, "lanthanide"),
+ ("Hf", 72, 4, 6, "hafnium", 178.500, "transition metal"),
+ ("Ta", 73, 5, 6, "tantalum", 180.950, "transition metal"),
+ ("W", 74, 6, 6, "tungsten", 183.920, "transition metal"),
+ ("Re", 75, 7, 6, "rhenium", 186.200, "transition metal"),
+ ("Os", 76, 8, 6, "osmium", 190.200, "transition metal"),
+ ("Ir", 77, 9, 6, "iridium", 192.200, "transition metal"),
+ ("Pt", 78, 10, 6, "platinum", 195.090, "transition metal"),
+ ("Au", 79, 11, 6, "gold", 197.200, "transition metal"),
+ ("Hg", 80, 12, 6, "mercury", 200.610, "transition metal"),
+ ("Tl", 81, 13, 6, "thallium", 204.390, "post transition metal"),
+ ("Pb", 82, 14, 6, "lead", 207.210, "post transition metal"),
+ ("Bi", 83, 15, 6, "bismuth", 209.000, "post transition metal"),
+ ("Po", 84, 16, 6, "polonium", 209.000, "post transition metal"),
+ ("At", 85, 17, 6, "astatine", 210.000, "metalloid"),
+ ("Rn", 86, 18, 6, "radon", 222.000, "noble gas"),
+ ("Fr", 87, 1, 7, "francium", 223.000, "alkali metal"),
+ ("Ra", 88, 2, 7, "radium", 226.000, "alkaline earth metal"),
+ ("Ac", 89, 3, 7, "actinium", 227.000, "actinide"),
+ ("Th", 90, 4, 10, "thorium", 232.000, "actinide"),
+ ("Pa", 91, 5, 10, "proactinium", 231.03588, "actinide"),
+ ("U", 92, 6, 10, "uranium", 238.070, "actinide"),
+ ("Np", 93, 7, 10, "neptunium", 237.000, "actinide"),
+ ("Pu", 94, 8, 10, "plutonium", 239.100, "actinide"),
+ ("Am", 95, 9, 10, "americium", 243, "actinide"),
+ ("Cm", 96, 10, 10, "curium", 247, "actinide"),
+ ("Bk", 97, 11, 10, "berkelium", 247, "actinide"),
+ ("Cf", 98, 12, 10, "californium", 251, "actinide"),
+ ("Es", 99, 13, 10, "einsteinium", 252, "actinide"),
+ ("Fm", 100, 14, 10, "fermium", 257, "actinide"),
+ ("Md", 101, 15, 10, "mendelevium", 258, "actinide"),
+ ("No", 102, 16, 10, "nobelium", 259, "actinide"),
+ ("Lr", 103, 17, 10, "lawrencium", 262, "actinide"),
+ ("Rf", 104, 4, 7, "rutherfordium", 261, "transition metal"),
+ ("Db", 105, 5, 7, "dubnium", 262, "transition metal"),
+ ("Sg", 106, 6, 7, "seaborgium", 266, "transition metal"),
+ ("Bh", 107, 7, 7, "bohrium", 264, "transition metal"),
+ ("Hs", 108, 8, 7, "hassium", 269, "transition metal"),
+ ("Mt", 109, 9, 7, "meitnerium", 268)]
+
+
+class PeriodicTableItem(object):
+ """Periodic table item, used as generic item in :class:`PeriodicTable`,
+ :class:`PeriodicCombo` and :class:`PeriodicList`.
+
+ This implementation stores the minimal amount of information needed by the
+ widgets:
+
+ - atomic symbol
+ - atomic number
+ - element name
+ - atomic mass
+ - column of element in periodic table
+ - row of element in periodic table
+
+ You can subclass this class to add additional information.
+
+ :param str symbol: Atomic symbol (e.g. H, He, Li...)
+ :param int Z: Proton number
+ :param int col: 1-based column index of element in periodic table
+ :param int row: 1-based row index of element in periodic table
+ :param str name: PeriodicTableItem name ("hydrogen", ...)
+ :param float mass: Atomic mass (gram per mol)
+ :param str subcategory: Subcategory, based on physical properties
+ (e.g. "alkali metal", "noble gas"...)
+ """
+ def __init__(self, symbol, Z, col, row, name, mass,
+ subcategory=""):
+ self.symbol = symbol
+ """Atomic symbol (e.g. H, He, Li...)"""
+ self.Z = Z
+ """Atomic number (Proton number)"""
+ self.col = col
+ """1-based column index of element in periodic table"""
+ self.row = row
+ """1-based row index of element in periodic table"""
+ self.name = name
+ """PeriodicTableItem name ("hydrogen", ...)"""
+ self.mass = mass
+ """Atomic mass (gram per mol)"""
+ self.subcategory = subcategory
+ """Subcategory, based on physical properties
+ (e.g. "alkali metal", "noble gas"...)"""
+
+ # pymca compatibility (elements used to be stored as a list of lists)
+ def __getitem__(self, idx):
+ if idx == 6:
+ _logger.warning("density not implemented in silx, returning 0.")
+
+ ret = [self.symbol, self.Z,
+ self.col, self.row,
+ self.name, self.mass,
+ 0.]
+ return ret[idx]
+
+ def __len__(self):
+ return 6
+
+
+class ColoredPeriodicTableItem(PeriodicTableItem):
+ """:class:`PeriodicTableItem` with an added :attr:`bgcolor`.
+ The background color can be passed as a parameter to the constructor.
+ If it is not specified, it will be defined based on
+ :attr:`subcategory`.
+
+ :param str bgcolor: Custom background color for element in
+ periodic table, as a RGB string *#RRGGBB*"""
+ COLORS = {
+ "diatomic nonmetal": "#7FFF00", # chartreuse
+ "noble gas": "#00FFFF", # cyan
+ "alkali metal": "#FFE4B5", # Moccasin
+ "alkaline earth metal": "#FFA500", # orange
+ "polyatomic nonmetal": "#7FFFD4", # aquamarine
+ "transition metal": "#FFA07A", # light salmon
+ "metalloid": "#8FBC8F", # Dark Sea Green
+ "post transition metal": "#D3D3D3", # light gray
+ "lanthanide": "#FFB6C1", # light pink
+ "actinide": "#F08080", # Light Coral
+ "": "#FFFFFF" # white
+ }
+ """Dictionary defining RGB colors for each subcategory."""
+
+ def __init__(self, symbol, Z, col, row, name, mass,
+ subcategory="", bgcolor=None):
+ PeriodicTableItem.__init__(self, symbol, Z, col, row, name, mass,
+ subcategory)
+
+ self.bgcolor = self.COLORS.get(subcategory, "#FFFFFF")
+ """Background color of element in the periodic table,
+ based on its subcategory. This should be a string of a hexadecimal
+ RGB code, with the format *#RRGGBB*.
+ If the subcategory is unknown, use white (*#FFFFFF*)
+ """
+
+ # possible custom color
+ if bgcolor is not None:
+ self.bgcolor = bgcolor
+
+
+_defaultTableItems = [ColoredPeriodicTableItem(*info) for info in _elements]
+
+
+class _ElementButton(qt.QPushButton):
+ """Atomic element button, used as a cell in the periodic table
+ """
+ sigElementEnter = qt.pyqtSignal(object)
+ """Signal emitted as the cursor enters the widget"""
+ sigElementLeave = qt.pyqtSignal(object)
+ """Signal emitted as the cursor leaves the widget"""
+ sigElementClicked = qt.pyqtSignal(object)
+ """Signal emitted when the widget is clicked"""
+
+ def __init__(self, item, parent=None):
+ """
+
+ :param parent: Parent widget
+ :param PeriodicTableItem item: :class:`PeriodicTableItem` object
+ """
+ qt.QPushButton.__init__(self, parent)
+
+ self.item = item
+ """:class:`PeriodicTableItem` object represented by this button"""
+
+ self.setText(item.symbol)
+ self.setFlat(1)
+ self.setCheckable(0)
+
+ self.setSizePolicy(qt.QSizePolicy(qt.QSizePolicy.Expanding,
+ qt.QSizePolicy.Expanding))
+
+ self.selected = False
+ self.current = False
+
+ # selection colors
+ self.selected_color = qt.QColor(qt.Qt.yellow)
+ self.current_color = qt.QColor(qt.Qt.gray)
+ self.selected_current_color = qt.QColor(qt.Qt.darkYellow)
+
+ # element colors
+
+ if hasattr(item, "bgcolor"):
+ self.bgcolor = qt.QColor(item.bgcolor)
+ else:
+ self.bgcolor = qt.QColor("#FFFFFF")
+
+ self.brush = qt.QBrush()
+ self.__setBrush()
+
+ self.clicked.connect(self.clickedSlot)
+
+ def sizeHint(self):
+ return qt.QSize(40, 40)
+
+ def setCurrent(self, b):
+ """Set this element button as current.
+ Multiple buttons can be selected.
+
+ :param b: boolean
+ """
+ self.current = b
+ self.__setBrush()
+
+ def isCurrent(self):
+ """
+ :return: True if element button is current
+ """
+ return self.current
+
+ def isSelected(self):
+ """
+ :return: True if element button is selected
+ """
+ return self.selected
+
+ def setSelected(self, b):
+ """Set this element button as selected.
+ Only a single button can be selected.
+
+ :param b: boolean
+ """
+ self.selected = b
+ self.__setBrush()
+
+ def __setBrush(self):
+ """Selected cells are yellow when not current.
+ The current cell is dark yellow when selected or grey when not
+ selected.
+ Other cells have no bg color by default, unless specified at
+ instantiation (:attr:`bgcolor`)"""
+ palette = self.palette()
+ # if self.current and self.selected:
+ # self.brush = qt.QBrush(self.selected_current_color)
+ # el
+ if self.selected:
+ self.brush = qt.QBrush(self.selected_color)
+ # elif self.current:
+ # self.brush = qt.QBrush(self.current_color)
+ elif self.bgcolor is not None:
+ self.brush = qt.QBrush(self.bgcolor)
+ else:
+ self.brush = qt.QBrush()
+ palette.setBrush(self.backgroundRole(),
+ self.brush)
+ self.setPalette(palette)
+ self.update()
+
+ def paintEvent(self, pEvent):
+ # get button geometry
+ widgGeom = self.rect()
+ paintGeom = qt.QRect(widgGeom.left() + 1,
+ widgGeom.top() + 1,
+ widgGeom.width() - 2,
+ widgGeom.height() - 2)
+
+ # paint background color
+ painter = qt.QPainter(self)
+ if self.brush is not None:
+ painter.fillRect(paintGeom, self.brush)
+ # paint frame
+ pen = qt.QPen(qt.Qt.black)
+ pen.setWidth(1 if not self.isCurrent() else 5)
+ painter.setPen(pen)
+ painter.drawRect(paintGeom)
+ painter.end()
+ qt.QPushButton.paintEvent(self, pEvent)
+
+ def enterEvent(self, e):
+ """Emit a :attr:`sigElementEnter` signal and send a
+ :class:`PeriodicTableItem` object"""
+ self.sigElementEnter.emit(self.item)
+
+ def leaveEvent(self, e):
+ """Emit a :attr:`sigElementLeave` signal and send a
+ :class:`PeriodicTableItem` object"""
+ self.sigElementLeave.emit(self.item)
+
+ def clickedSlot(self):
+ """Emit a :attr:`sigElementClicked` signal and send a
+ :class:`PeriodicTableItem` object"""
+ self.sigElementClicked.emit(self.item)
+
+
+class PeriodicTable(qt.QWidget):
+ """Periodic Table widget
+
+ .. image:: img/PeriodicTable.png
+
+ The following example shows how to connect clicking to selection::
+
+ from silx.gui import qt
+ from silx.gui.widgets.PeriodicTable import PeriodicTable
+ app = qt.QApplication([])
+ pt = PeriodicTable()
+ pt.sigElementClicked.connect(pt.elementToggle)
+ pt.show()
+ app.exec()
+
+ To print all selected elements each time a new element is selected::
+
+ def my_slot(item):
+ pt.elementToggle(item)
+ selected_elements = pt.getSelection()
+ for e in selected_elements:
+ print(e.symbol)
+
+ pt.sigElementClicked.connect(my_slot)
+
+ """
+ sigElementClicked = qt.pyqtSignal(object)
+ """When any element is clicked in the table, the widget emits
+ this signal and sends a :class:`PeriodicTableItem` object.
+ """
+
+ sigSelectionChanged = qt.pyqtSignal(object)
+ """When any element is selected/unselected in the table, the widget emits
+ this signal and sends a list of :class:`PeriodicTableItem` objects.
+
+ .. note::
+
+ To enable selection of elements, you must set *selectable=True*
+ when you instantiate the widget. Alternatively, you can also connect
+ :attr:`sigElementClicked` to :meth:`elementToggle` manually::
+
+ pt = PeriodicTable()
+ pt.sigElementClicked.connect(pt.elementToggle)
+
+
+ :param parent: parent QWidget
+ :param str name: Widget window title
+ :param elements: List of items (:class:`PeriodicTableItem` objects) to
+ be represented in the table. By default, take elements from
+ a predefined list with minimal information (symbol, atomic number,
+ name, mass).
+ :param bool selectable: If *True*, multiple elements can be
+ selected by clicking with the mouse. If *False* (default),
+ selection is only possible with method :meth:`setSelection`.
+ """
+
+ def __init__(self, parent=None, name="PeriodicTable", elements=None,
+ selectable=False):
+ self.selectable = selectable
+ qt.QWidget.__init__(self, parent)
+ self.setWindowTitle(name)
+ self.gridLayout = qt.QGridLayout(self)
+ self.gridLayout.setContentsMargins(0, 0, 0, 0)
+ self.gridLayout.addItem(qt.QSpacerItem(0, 5), 7, 0)
+
+ for idx in range(10):
+ self.gridLayout.setRowStretch(idx, 3)
+ # row 8 (above lanthanoids is empty)
+ self.gridLayout.setRowStretch(7, 2)
+
+ # Element information displayed when cursor enters a cell
+ self.eltLabel = qt.QLabel(self)
+ f = self.eltLabel.font()
+ f.setBold(1)
+ self.eltLabel.setFont(f)
+ self.eltLabel.setAlignment(qt.Qt.AlignHCenter)
+ self.gridLayout.addWidget(self.eltLabel, 1, 1, 3, 10)
+
+ self._eltCurrent = None
+ """Current :class:`_ElementButton` (last clicked)"""
+
+ self._eltButtons = OrderedDict()
+ """Dictionary of all :class:`_ElementButton`. Keys are the symbols
+ ("H", "He", "Li"...)"""
+
+ if elements is None:
+ elements = _defaultTableItems
+ # fill cells with elements
+ for elmt in elements:
+ self.__addElement(elmt)
+
+ def __addElement(self, elmt):
+ """Add one :class:`_ElementButton` widget into the grid,
+ connect its signals to interact with the cursor"""
+ b = _ElementButton(elmt, self)
+ b.setAutoDefault(False)
+
+ self._eltButtons[elmt.symbol] = b
+ self.gridLayout.addWidget(b, elmt.row, elmt.col)
+
+ b.sigElementEnter.connect(self.elementEnter)
+ b.sigElementLeave.connect(self._elementLeave)
+ b.sigElementClicked.connect(self._elementClicked)
+
+ def elementEnter(self, item):
+ """Update label with element info (e.g. "Nb(41) - niobium")
+ when mouse cursor hovers an element.
+
+ :param PeriodicTableItem item: Element entered by cursor
+ """
+ self.eltLabel.setText("%s(%d) - %s" % (item.symbol, item.Z, item.name))
+
+ def _elementLeave(self, item):
+ """Clear label when the cursor leaves the cell
+
+ :param PeriodicTableItem item: Element left
+ """
+ self.eltLabel.setText("")
+
+ def _elementClicked(self, item):
+ """Emit :attr:`sigElementClicked`,
+ toggle selected state of element
+
+ :param PeriodicTableItem item: Element clicked
+ """
+ if self._eltCurrent is not None:
+ self._eltCurrent.setCurrent(False)
+ self._eltButtons[item.symbol].setCurrent(True)
+ self._eltCurrent = self._eltButtons[item.symbol]
+ if self.selectable:
+ self.elementToggle(item)
+ self.sigElementClicked.emit(item)
+
+ def getSelection(self):
+ """Return a list of selected elements, as a list of :class:`PeriodicTableItem`
+ objects.
+
+ :return: Selected items
+ :rtype: List[PeriodicTableItem]
+ """
+ return [b.item for b in self._eltButtons.values() if b.isSelected()]
+
+ def setSelection(self, symbols):
+ """Set selected elements.
+
+ This causes the sigSelectionChanged signal
+ to be emitted, even if the selection didn't actually change.
+
+ :param List[str] symbols: List of symbols of elements to be selected
+ (e.g. *["Fe", "Hg", "Li"]*)
+ """
+ # accept list of PeriodicTableItems as input, because getSelection
+ # returns these objects and it makes sense to have getter and setter
+ # use same type of data
+ if isinstance(symbols[0], PeriodicTableItem):
+ symbols = [elmt.symbol for elmt in symbols]
+
+ for (e, b) in self._eltButtons.items():
+ b.setSelected(e in symbols)
+ self.sigSelectionChanged.emit(self.getSelection())
+
+ def setElementSelected(self, symbol, state):
+ """Modify *selected* status of a single element (select or unselect)
+
+ :param str symbol: PeriodicTableItem symbol to be selected
+ :param bool state: *True* to select, *False* to unselect
+ """
+ self._eltButtons[symbol].setSelected(state)
+ self.sigSelectionChanged.emit(self.getSelection())
+
+ def isElementSelected(self, symbol):
+ """Return *True* if element is selected, else *False*
+
+ :param str symbol: PeriodicTableItem symbol
+ :return: *True* if element is selected, else *False*
+ """
+ return self._eltButtons[symbol].isSelected()
+
+ def elementToggle(self, item):
+ """Toggle selected/unselected state for element
+
+ :param item: PeriodicTableItem object
+ """
+ b = self._eltButtons[item.symbol]
+ b.setSelected(not b.isSelected())
+ self.sigSelectionChanged.emit(self.getSelection())
+
+
+class PeriodicCombo(qt.QComboBox):
+ """
+ Combo list with all atomic elements of the periodic table
+
+ .. image:: img/PeriodicCombo.png
+
+ :param bool detailed: True (default) display element symbol, Z and name.
+ False display only element symbol and Z.
+ :param elements: List of items (:class:`PeriodicTableItem` objects) to
+ be represented in the table. By default, take elements from
+ a predefined list with minimal information (symbol, atomic number,
+ name, mass).
+ """
+ sigSelectionChanged = qt.pyqtSignal(object)
+ """Signal emitted when the selection changes. Send
+ :class:`PeriodicTableItem` object representing selected
+ element
+ """
+
+ def __init__(self, parent=None, detailed=True, elements=None):
+ qt.QComboBox.__init__(self, parent)
+
+ # add all elements from global list
+ if elements is None:
+ elements = _defaultTableItems
+ for i, elmt in enumerate(elements):
+ if detailed:
+ txt = "%2s (%d) - %s" % (elmt.symbol, elmt.Z, elmt.name)
+ else:
+ txt = "%2s (%d)" % (elmt.symbol, elmt.Z)
+ self.insertItem(i, txt)
+
+ self.currentIndexChanged[int].connect(self.__selectionChanged)
+
+ def __selectionChanged(self, idx):
+ """Emit :attr:`sigSelectionChanged`"""
+ self.sigSelectionChanged.emit(_defaultTableItems[idx])
+
+ def getSelection(self):
+ """Get selected element
+
+ :return: Selected element
+ :rtype: PeriodicTableItem
+ """
+ return _defaultTableItems[self.currentIndex()]
+
+ def setSelection(self, symbol):
+ """Set selected item in combobox by giving the atomic symbol
+
+ :param symbol: Symbol of element to be selected
+ """
+ # accept PeriodicTableItem for getter/setter consistency
+ if isinstance(symbol, PeriodicTableItem):
+ symbol = symbol.symbol
+ symblist = [elmt.symbol for elmt in _defaultTableItems]
+ self.setCurrentIndex(symblist.index(symbol))
+
+
+class PeriodicList(qt.QTreeWidget):
+ """List of atomic elements in a :class:`QTreeView`
+
+ .. image:: img/PeriodicList.png
+
+ :param QWidget parent: Parent widget
+ :param bool detailed: True (default) display element symbol, Z and name.
+ False display only element symbol and Z.
+ :param single: *True* for single element selection with mouse click,
+ *False* for multiple element selection mode.
+ """
+ sigSelectionChanged = qt.pyqtSignal(object)
+ """When any element is selected/unselected in the widget, it emits
+ this signal and sends a list of currently selected
+ :class:`PeriodicTableItem` objects.
+ """
+
+ def __init__(self, parent=None, detailed=True, single=False, elements=None):
+ qt.QTreeWidget.__init__(self, parent)
+
+ self.detailed = detailed
+
+ headers = ["Z", "Symbol"]
+ if detailed:
+ headers.append("Name")
+ self.setColumnCount(3)
+ else:
+ self.setColumnCount(2)
+ self.setHeaderLabels(headers)
+ self.header().setStretchLastSection(False)
+
+ self.setRootIsDecorated(0)
+ self.itemClicked.connect(self.__selectionChanged)
+ self.setSelectionMode(qt.QAbstractItemView.SingleSelection if single
+ else qt.QAbstractItemView.ExtendedSelection)
+ self.__fill_widget(elements)
+ self.resizeColumnToContents(0)
+ self.resizeColumnToContents(1)
+ if detailed:
+ self.resizeColumnToContents(2)
+
+ def __fill_widget(self, elements):
+ """Fill tree widget with elements """
+ if elements is None:
+ elements = _defaultTableItems
+
+ self.tree_items = []
+
+ previous_item = None
+ for elmt in elements:
+ if previous_item is None:
+ item = qt.QTreeWidgetItem(self)
+ else:
+ item = qt.QTreeWidgetItem(self, previous_item)
+ item.setText(0, str(elmt.Z))
+ item.setText(1, elmt.symbol)
+ if self.detailed:
+ item.setText(2, elmt.name)
+ self.tree_items.append(item)
+ previous_item = item
+
+ def __selectionChanged(self, treeItem, column):
+ """Emit a :attr:`sigSelectionChanged` and send a list of
+ :class:`PeriodicTableItem` objects."""
+ self.sigSelectionChanged.emit(self.getSelection())
+
+ def getSelection(self):
+ """Get a list of selected elements, as a list of :class:`PeriodicTableItem`
+ objects.
+
+ :return: Selected elements
+ :rtype: List[PeriodicTableItem]"""
+ return [_defaultTableItems[idx] for idx in range(len(self.tree_items))
+ if self.tree_items[idx].isSelected()]
+
+ # setSelection is a bad name (name of a QTreeWidget method)
+ def setSelectedElements(self, symbolList):
+ """
+
+ :param symbolList: List of atomic symbols ["H", "He", "Li"...]
+ to be selected in the widget
+ """
+ # accept PeriodicTableItem for getter/setter consistency
+ if isinstance(symbolList[0], PeriodicTableItem):
+ symbolList = [elmt.symbol for elmt in symbolList]
+ for idx in range(len(self.tree_items)):
+ self.tree_items[idx].setSelected(_defaultTableItems[idx].symbol in symbolList)
diff --git a/src/silx/gui/widgets/PrintGeometryDialog.py b/src/silx/gui/widgets/PrintGeometryDialog.py
new file mode 100644
index 0000000..98ff8d1
--- /dev/null
+++ b/src/silx/gui/widgets/PrintGeometryDialog.py
@@ -0,0 +1,222 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+
+from silx.gui import qt
+from silx.gui.widgets.FloatEdit import FloatEdit
+
+
+class PrintGeometryWidget(qt.QWidget):
+ """Widget to specify the size and aspect ratio of an item
+ before sending it to the print preview dialog.
+
+ Use methods :meth:`setPrintGeometry` and :meth:`getPrintGeometry`
+ to interact with the widget.
+ """
+ def __init__(self, parent=None):
+ super(PrintGeometryWidget, self).__init__(parent)
+ self.mainLayout = qt.QGridLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(2)
+ hbox = qt.QWidget(self)
+ hboxLayout = qt.QHBoxLayout(hbox)
+ hboxLayout.setContentsMargins(0, 0, 0, 0)
+ hboxLayout.setSpacing(2)
+ label = qt.QLabel(self)
+ label.setText("Units")
+ label.setAlignment(qt.Qt.AlignCenter)
+ self._pageButton = qt.QRadioButton()
+ self._pageButton.setText("Page")
+ self._inchButton = qt.QRadioButton()
+ self._inchButton.setText("Inches")
+ self._cmButton = qt.QRadioButton()
+ self._cmButton.setText("Centimeters")
+ self._buttonGroup = qt.QButtonGroup(self)
+ self._buttonGroup.addButton(self._pageButton)
+ self._buttonGroup.addButton(self._inchButton)
+ self._buttonGroup.addButton(self._cmButton)
+ self._buttonGroup.setExclusive(True)
+
+ # units
+ self.mainLayout.addWidget(label, 0, 0, 1, 4)
+ hboxLayout.addWidget(self._pageButton)
+ hboxLayout.addWidget(self._inchButton)
+ hboxLayout.addWidget(self._cmButton)
+ self.mainLayout.addWidget(hbox, 1, 0, 1, 4)
+ self._pageButton.setChecked(True)
+
+ # xOffset
+ label = qt.QLabel(self)
+ label.setText("X Offset:")
+ self.mainLayout.addWidget(label, 2, 0)
+ self._xOffset = FloatEdit(self, 0.1)
+ self.mainLayout.addWidget(self._xOffset, 2, 1)
+
+ # yOffset
+ label = qt.QLabel(self)
+ label.setText("Y Offset:")
+ self.mainLayout.addWidget(label, 2, 2)
+ self._yOffset = FloatEdit(self, 0.1)
+ self.mainLayout.addWidget(self._yOffset, 2, 3)
+
+ # width
+ label = qt.QLabel(self)
+ label.setText("Width:")
+ self.mainLayout.addWidget(label, 3, 0)
+ self._width = FloatEdit(self, 0.9)
+ self.mainLayout.addWidget(self._width, 3, 1)
+
+ # height
+ label = qt.QLabel(self)
+ label.setText("Height:")
+ self.mainLayout.addWidget(label, 3, 2)
+ self._height = FloatEdit(self, 0.9)
+ self.mainLayout.addWidget(self._height, 3, 3)
+
+ # aspect ratio
+ self._aspect = qt.QCheckBox(self)
+ self._aspect.setText("Keep screen aspect ratio")
+ self._aspect.setChecked(True)
+ self.mainLayout.addWidget(self._aspect, 4, 1, 1, 2)
+
+ def getPrintGeometry(self):
+ """Return the print geometry dictionary.
+
+ See :meth:`setPrintGeometry` for documentation about the
+ print geometry dictionary."""
+ ddict = {}
+ if self._inchButton.isChecked():
+ ddict['units'] = "inches"
+ elif self._cmButton.isChecked():
+ ddict['units'] = "centimeters"
+ else:
+ ddict['units'] = "page"
+
+ ddict['xOffset'] = self._xOffset.value()
+ ddict['yOffset'] = self._yOffset.value()
+ ddict['width'] = self._width.value()
+ ddict['height'] = self._height.value()
+
+ if self._aspect.isChecked():
+ ddict['keepAspectRatio'] = True
+ else:
+ ddict['keepAspectRatio'] = False
+ return ddict
+
+ def setPrintGeometry(self, geometry=None):
+ """Set the print geometry.
+
+ The geometry parameters must be provided as a dictionary with
+ the following keys:
+
+ - *"xOffset"* (float)
+ - *"yOffset"* (float)
+ - *"width"* (float)
+ - *"height"* (float)
+ - *"units"*: possible values *"page", "inch", "cm"*
+ - *"keepAspectRatio"*: *True* or *False*
+
+ If *units* is *"page"*, the values should be floats in [0, 1.]
+ and are interpreted as a fraction of the page width or height.
+
+ :param dict geometry: Geometry parameters, as a dictionary."""
+ if geometry is None:
+ geometry = {}
+ oldDict = self.getPrintGeometry()
+ for key in ["units", "xOffset", "yOffset",
+ "width", "height", "keepAspectRatio"]:
+ geometry[key] = geometry.get(key, oldDict[key])
+
+ if geometry['units'].lower().startswith("inc"):
+ self._inchButton.setChecked(True)
+ elif geometry['units'].lower().startswith("c"):
+ self._cmButton.setChecked(True)
+ else:
+ self._pageButton.setChecked(True)
+
+ self._xOffset.setText("%s" % float(geometry['xOffset']))
+ self._yOffset.setText("%s" % float(geometry['yOffset']))
+ self._width.setText("%s" % float(geometry['width']))
+ self._height.setText("%s" % float(geometry['height']))
+ if geometry['keepAspectRatio']:
+ self._aspect.setChecked(True)
+ else:
+ self._aspect.setChecked(False)
+
+
+class PrintGeometryDialog(qt.QDialog):
+ """Dialog embedding a :class:`PrintGeometryWidget`.
+
+ Use methods :meth:`setPrintGeometry` and :meth:`getPrintGeometry`
+ to interact with the widget.
+
+ Execute method :meth:`exec` to run the dialog.
+ The return value of that method is *True* if the geometry was set
+ (*Ok* button clicked) or *False* if the user clicked the *Cancel*
+ button.
+ """
+
+ def __init__(self, parent=None):
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle("Set print size preferences")
+ layout = qt.QVBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setSpacing(0)
+ self.configurationWidget = PrintGeometryWidget(self)
+ hbox = qt.QWidget(self)
+ hboxLayout = qt.QHBoxLayout(hbox)
+ self.okButton = qt.QPushButton(hbox)
+ self.okButton.setText("Accept")
+ self.okButton.setAutoDefault(False)
+ self.rejectButton = qt.QPushButton(hbox)
+ self.rejectButton.setText("Dismiss")
+ self.rejectButton.setAutoDefault(False)
+ self.okButton.clicked.connect(self.accept)
+ self.rejectButton.clicked.connect(self.reject)
+ hboxLayout.setContentsMargins(0, 0, 0, 0)
+ hboxLayout.setSpacing(2)
+ # hboxLayout.addWidget(qt.HorizontalSpacer(hbox))
+ hboxLayout.addWidget(self.okButton)
+ hboxLayout.addWidget(self.rejectButton)
+ # hboxLayout.addWidget(qt.HorizontalSpacer(hbox))
+ layout.addWidget(self.configurationWidget)
+ layout.addWidget(hbox)
+
+ def setPrintGeometry(self, geometry):
+ """Return the print geometry dictionary.
+
+ See :meth:`PrintGeometryWidget.setPrintGeometry` for documentation on
+ print geometry dictionary.
+
+ :param dict geometry: Print geometry parameters dictionary.
+ """
+ self.configurationWidget.setPrintGeometry(geometry)
+
+ def getPrintGeometry(self):
+ """Return the print geometry dictionary.
+
+ See :meth:`PrintGeometryWidget.setPrintGeometry` for documentation on
+ print geometry dictionary."""
+ return self.configurationWidget.getPrintGeometry()
diff --git a/src/silx/gui/widgets/PrintPreview.py b/src/silx/gui/widgets/PrintPreview.py
new file mode 100644
index 0000000..53e0a1f
--- /dev/null
+++ b/src/silx/gui/widgets/PrintPreview.py
@@ -0,0 +1,697 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module implements a print preview dialog.
+
+The dialog provides methods to send images, pixmaps and SVG
+items to the page to be printed.
+
+The user can interactively move and resize the items.
+"""
+import sys
+import logging
+from silx.gui import qt, printer
+
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "11/07/2017"
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PrintPreviewDialog(qt.QDialog):
+ """Print preview dialog widget.
+ """
+ def __init__(self, parent=None, printer=None):
+
+ qt.QDialog.__init__(self, parent)
+ self.setWindowTitle("Print Preview")
+ self.setModal(False)
+ self.resize(400, 500)
+
+ self.mainLayout = qt.QVBoxLayout(self)
+ self.mainLayout.setContentsMargins(0, 0, 0, 0)
+ self.mainLayout.setSpacing(0)
+
+ self._buildToolbar()
+
+ self.printer = printer
+ # :class:`QPrinter` (paint device that paints on a printer).
+ # :meth:`showEvent` has been reimplemented to enforce printer
+ # setup.
+
+ self.printDialog = None
+ # :class:`QPrintDialog` (dialog for specifying the printer's
+ # configuration)
+
+ self.scene = None
+ # :class:`QGraphicsScene` (surface for managing
+ # 2D graphical items)
+
+ self.page = None
+ # :class:`QGraphicsRectItem` used as white background page on which
+ # to display the print preview.
+
+ self.view = None
+ # :class:`QGraphicsView` widget for displaying :attr:`scene`
+
+ self._svgItems = []
+ # List storing :class:`QSvgRenderer` items to be printed, added in
+ # :meth:`addSvgItem`, cleared in :meth:`_clearAll`.
+ # This ensures that there is a reference pointing to the items,
+ # which ensures they are not destroyed before being printed.
+
+ self._viewScale = 1.0
+ # Zoom level (1.0 is 100%)
+
+ self._toBeCleared = False
+ # Flag indicating that all items must be removed from :attr:`scene`
+ # and from :attr:`_svgItems`.
+ # Set to True after a successful printing. The widget is then hidden,
+ # and it will be cleared the next time it is shown.
+ # Reset to False after :meth:`_clearAll` has done its job.
+
+ def _buildToolbar(self):
+ toolBar = qt.QWidget(self)
+ # a layout for the toolbar
+ toolsLayout = qt.QHBoxLayout(toolBar)
+ toolsLayout.setContentsMargins(0, 0, 0, 0)
+ toolsLayout.setSpacing(0)
+
+ hideBut = qt.QPushButton("Hide", toolBar)
+ hideBut.setToolTip("Hide print preview dialog")
+ hideBut.clicked.connect(self.hide)
+
+ cancelBut = qt.QPushButton("Clear All", toolBar)
+ cancelBut.setToolTip("Remove all items")
+ cancelBut.clicked.connect(self._clearAll)
+
+ removeBut = qt.QPushButton("Remove",
+ toolBar)
+ removeBut.setToolTip("Remove selected item (use left click to select)")
+ removeBut.clicked.connect(self._remove)
+
+ setupBut = qt.QPushButton("Setup", toolBar)
+ setupBut.setToolTip("Select and configure a printer")
+ setupBut.clicked.connect(self.setup)
+
+ printBut = qt.QPushButton("Print", toolBar)
+ printBut.setToolTip("Print page and close print preview")
+ printBut.clicked.connect(self._print)
+
+ zoomPlusBut = qt.QPushButton("Zoom +", toolBar)
+ zoomPlusBut.clicked.connect(self._zoomPlus)
+
+ zoomMinusBut = qt.QPushButton("Zoom -", toolBar)
+ zoomMinusBut.clicked.connect(self._zoomMinus)
+
+ toolsLayout.addWidget(hideBut)
+ toolsLayout.addWidget(printBut)
+ toolsLayout.addWidget(cancelBut)
+ toolsLayout.addWidget(removeBut)
+ toolsLayout.addWidget(setupBut)
+ # toolsLayout.addStretch()
+ # toolsLayout.addWidget(marginLabel)
+ # toolsLayout.addWidget(self.marginSpin)
+ toolsLayout.addStretch()
+ # toolsLayout.addWidget(scaleLabel)
+ # toolsLayout.addWidget(self.scaleCombo)
+ toolsLayout.addWidget(zoomPlusBut)
+ toolsLayout.addWidget(zoomMinusBut)
+ # toolsLayout.addStretch()
+ self.toolBar = toolBar
+ self.mainLayout.addWidget(self.toolBar)
+
+ def _buildStatusBar(self):
+ """Create the status bar used to display the printer name
+ or output file name."""
+ # status bar
+ statusBar = qt.QStatusBar(self)
+ self.targetLabel = qt.QLabel(statusBar)
+ self._updateTargetLabel()
+ statusBar.addWidget(self.targetLabel)
+ self.mainLayout.addWidget(statusBar)
+
+ def _updateTargetLabel(self):
+ """Update printer name or file name shown in the status bar."""
+ if self.printer is None:
+ self.targetLabel.setText("Undefined printer")
+ return
+ if self.printer.outputFileName():
+ self.targetLabel.setText("File:" +
+ self.printer.outputFileName())
+ else:
+ self.targetLabel.setText("Printer:" +
+ self.printer.printerName())
+
+ def _updatePrinter(self):
+ """Resize :attr:`page`, :attr:`scene` and :attr:`view` to :attr:`printer`
+ width and height."""
+ printer = self.printer
+ assert printer is not None, \
+ "_updatePrinter should not be called unless a printer is defined"
+ if self.scene is None:
+ self.scene = qt.QGraphicsScene()
+ self.scene.setBackgroundBrush(qt.QColor(qt.Qt.lightGray))
+ self.scene.setSceneRect(qt.QRectF(0, 0, printer.width(), printer.height()))
+
+ if self.page is None:
+ self.page = qt.QGraphicsRectItem(0, 0, printer.width(), printer.height())
+ self.page.setBrush(qt.QColor(qt.Qt.white))
+ self.scene.addItem(self.page)
+
+ self.scene.setSceneRect(qt.QRectF(0, 0, printer.width(), printer.height()))
+ self.page.setPos(qt.QPointF(0.0, 0.0))
+ self.page.setRect(qt.QRectF(0, 0, printer.width(), printer.height()))
+
+ if self.view is None:
+ self.view = qt.QGraphicsView(self.scene)
+ self.mainLayout.addWidget(self.view)
+ self._buildStatusBar()
+ # self.view.scale(1./self._viewScale, 1./self._viewScale)
+ self.view.fitInView(self.page.rect(), qt.Qt.KeepAspectRatio)
+ self._viewScale = 1.00
+ self._updateTargetLabel()
+
+ # Public methods
+ def addImage(self, image, title=None, comment=None, commentPosition=None):
+ """Add an image to the print preview scene.
+
+ :param QImage image: Image to be added to the scene
+ :param str title: Title shown above (centered) the image
+ :param str comment: Comment displayed below the image
+ :param commentPosition: "CENTER" or "LEFT"
+ """
+ self.addPixmap(qt.QPixmap.fromImage(image),
+ title=title, comment=comment,
+ commentPosition=commentPosition)
+
+ def addPixmap(self, pixmap, title=None, comment=None, commentPosition=None):
+ """Add a pixmap to the print preview scene
+
+ :param QPixmap pixmap: Pixmap to be added to the scene
+ :param str title: Title shown above (centered) the pixmap
+ :param str comment: Comment displayed below the pixmap
+ :param commentPosition: "CENTER" or "LEFT"
+ """
+ if self._toBeCleared:
+ self._clearAll()
+ self.ensurePrinterIsSet()
+ if self.printer is None:
+ _logger.error("printer is not set, cannot add pixmap to page")
+ return
+ if title is None:
+ title = ' ' * 88
+ if comment is None:
+ comment = ' ' * 88
+ if commentPosition is None:
+ commentPosition = "CENTER"
+ rectItem = qt.QGraphicsRectItem(self.page)
+ rectItem.setRect(qt.QRectF(1, 1,
+ pixmap.width(), pixmap.height()))
+
+ pen = rectItem.pen()
+ color = qt.QColor(qt.Qt.red)
+ color.setAlpha(1)
+ pen.setColor(color)
+ rectItem.setPen(pen)
+ rectItem.setZValue(1)
+ rectItem.setFlag(qt.QGraphicsItem.ItemIsSelectable, True)
+ rectItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
+ rectItem.setFlag(qt.QGraphicsItem.ItemIsFocusable, False)
+
+ rectItemResizeRect = _GraphicsResizeRectItem(rectItem, self.scene)
+ rectItemResizeRect.setZValue(2)
+
+ pixmapItem = qt.QGraphicsPixmapItem(rectItem)
+ pixmapItem.setPixmap(pixmap)
+ pixmapItem.setZValue(0)
+
+ # I add the title
+ textItem = qt.QGraphicsTextItem(title, rectItem)
+ textItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
+ offset = 0.5 * textItem.boundingRect().width()
+ textItem.moveBy(0.5 * pixmap.width() - offset, -20)
+ textItem.setZValue(2)
+
+ # I add the comment
+ commentItem = qt.QGraphicsTextItem(comment, rectItem)
+ commentItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
+ offset = 0.5 * commentItem.boundingRect().width()
+ if commentPosition.upper() == "LEFT":
+ x = 1
+ else:
+ x = 0.5 * pixmap.width() - offset
+ commentItem.moveBy(x, pixmap.height() + 20)
+ commentItem.setZValue(2)
+
+ rectItem.moveBy(20, 40)
+
+ def addSvgItem(self, item, title=None,
+ comment=None, commentPosition=None,
+ viewBox=None, keepRatio=True):
+ """Add a SVG item to the scene.
+
+ :param QSvgRenderer item: SVG item to be added to the scene.
+ :param str title: Title shown above (centered) the SVG item.
+ :param str comment: Comment displayed below the SVG item.
+ :param str commentPosition: "CENTER" or "LEFT"
+ :param QRectF viewBox: Bounding box for the item on the print page
+ (xOffset, yOffset, width, height). If None, use original
+ item size.
+ :param bool keepRatio: If True, resizing the item will preserve its
+ original aspect ratio.
+ """
+ if not qt.HAS_SVG:
+ raise RuntimeError("Missing QtSvg library.")
+ if not isinstance(item, qt.QSvgRenderer):
+ raise TypeError("addSvgItem: QSvgRenderer expected")
+ if self._toBeCleared:
+ self._clearAll()
+ self.ensurePrinterIsSet()
+ if self.printer is None:
+ _logger.error("printer is not set, cannot add SvgItem to page")
+ return
+
+ if title is None:
+ title = 50 * ' '
+ if comment is None:
+ comment = 80 * ' '
+ if commentPosition is None:
+ commentPosition = "CENTER"
+
+ if viewBox is None:
+ if hasattr(item, "_viewBox"):
+ # PyMca compatibility: viewbox attached to item
+ viewBox = item._viewBox
+ else:
+ # try the original item viewbox
+ viewBox = item.viewBoxF()
+
+ svgItem = _GraphicsSvgRectItem(viewBox, self.page)
+ svgItem.setSvgRenderer(item)
+
+ svgItem.setCacheMode(qt.QGraphicsItem.NoCache)
+ svgItem.setZValue(0)
+ svgItem.setFlag(qt.QGraphicsItem.ItemIsSelectable, True)
+ svgItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
+ svgItem.setFlag(qt.QGraphicsItem.ItemIsFocusable, False)
+
+ rectItemResizeRect = _GraphicsResizeRectItem(svgItem, self.scene,
+ keepratio=keepRatio)
+ rectItemResizeRect.setZValue(2)
+
+ self._svgItems.append(item)
+
+ # Comment / legend
+ dummyComment = 80 * "1"
+ commentItem = qt.QGraphicsTextItem(dummyComment, svgItem)
+ commentItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
+ # we scale the text to have the legend box have the same width as the graph
+ scaleCalculationRect = qt.QRectF(commentItem.boundingRect())
+ scale = svgItem.boundingRect().width() / scaleCalculationRect.width()
+
+ commentItem.setPlainText(comment)
+ commentItem.setZValue(1)
+
+ commentItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
+ commentItem.setScale(scale)
+
+ # align
+ if commentPosition.upper() == "CENTER":
+ alignment = qt.Qt.AlignCenter
+ elif commentPosition.upper() == "RIGHT":
+ alignment = qt.Qt.AlignRight
+ else:
+ alignment = qt.Qt.AlignLeft
+ commentItem.setTextWidth(commentItem.boundingRect().width())
+ center_format = qt.QTextBlockFormat()
+ center_format.setAlignment(alignment)
+ cursor = commentItem.textCursor()
+ cursor.select(qt.QTextCursor.Document)
+ cursor.mergeBlockFormat(center_format)
+ cursor.clearSelection()
+ commentItem.setTextCursor(cursor)
+ if alignment == qt.Qt.AlignLeft:
+ deltax = 0
+ else:
+ deltax = (svgItem.boundingRect().width() - commentItem.boundingRect().width()) / 2.
+ commentItem.moveBy(svgItem.boundingRect().x() + deltax,
+ svgItem.boundingRect().y() + svgItem.boundingRect().height())
+
+ # Title
+ textItem = qt.QGraphicsTextItem(title, svgItem)
+ textItem.setTextInteractionFlags(qt.Qt.TextEditorInteraction)
+ textItem.setZValue(1)
+ textItem.setFlag(qt.QGraphicsItem.ItemIsMovable, True)
+
+ title_offset = 0.5 * textItem.boundingRect().width()
+ textItem.moveBy(svgItem.boundingRect().x() +
+ 0.5 * svgItem.boundingRect().width() - title_offset * scale,
+ svgItem.boundingRect().y())
+ textItem.setScale(scale)
+
+ def setup(self):
+ """Open a print dialog to ensure the :attr:`printer` is set.
+
+ If the setting fails or is cancelled, :attr:`printer` is reset to
+ *None*.
+ """
+ if self.printer is None:
+ self.printer = printer.getDefaultPrinter()
+ if self.printDialog is None:
+ self.printDialog = qt.QPrintDialog(self.printer, self)
+ if self.printDialog.exec():
+ if self.printer.width() <= 0 or self.printer.height() <= 0:
+ self.message = qt.QMessageBox(self)
+ self.message.setIcon(qt.QMessageBox.Critical)
+ self.message.setText("Unknown library error \non printer initialization")
+ self.message.setWindowTitle("Library Error")
+ self.message.setModal(0)
+ self.printer = None
+ return
+ self.printer.setFullPage(True)
+ self._updatePrinter()
+ else:
+ # printer setup cancelled, check for a possible previous configuration
+ if self.page is None:
+ # not initialized
+ self.printer = None
+
+ def ensurePrinterIsSet(self):
+ """If the printer is not already set, try to interactively
+ setup the printer using a QPrintDialog.
+ In case of failure, hide widget and log a warning.
+
+ :return: True if printer was set. False if it failed or if the
+ selection dialog was canceled.
+ """
+ if self.printer is None:
+ self.setup()
+ if self.printer is None:
+ self.hide()
+ _logger.warning("Printer setup failed or was cancelled, " +
+ "but printer is required.")
+ return self.printer is not None
+
+ def setOutputFileName(self, name):
+ """Set output filename.
+
+ Setting a non-empty name enables printing to file.
+
+ :param str name: File name (path)"""
+ self.printer.setOutputFileName(name)
+
+ # overloaded methods
+ def exec(self):
+ if self._toBeCleared:
+ self._clearAll()
+ return qt.QDialog.exec(self)
+
+ def exec_(self): # Qt5 compatibility
+ return self.exec()
+
+ def raise_(self):
+ if self._toBeCleared:
+ self._clearAll()
+ return qt.QDialog.raise_(self)
+
+ def showEvent(self, event):
+ """Reimplemented to force printer setup.
+ In case of failure, hide the widget."""
+ if self._toBeCleared:
+ self._clearAll()
+ self.ensurePrinterIsSet()
+
+ return super(PrintPreviewDialog, self).showEvent(event)
+
+ # button callbacks
+ def _print(self):
+ """Do the printing, hide the print preview dialog,
+ set :attr:`_toBeCleared` flag to True to trigger clearing the
+ next time the dialog is shown.
+
+ If the printer is not setup, do it first."""
+ printer = self.printer
+
+ painter = qt.QPainter()
+ if not painter.begin(printer) or printer is None:
+ _logger.error("Cannot initialize printer")
+ return
+ try:
+ self.scene.render(painter, qt.QRectF(0, 0, printer.width(), printer.height()),
+ qt.QRectF(self.page.rect().x(), self.page.rect().y(),
+ self.page.rect().width(), self.page.rect().height()),
+ qt.Qt.KeepAspectRatio)
+ painter.end()
+ self.hide()
+ self.accept()
+ self._toBeCleared = True
+ except: # FIXME
+ painter.end()
+ qt.QMessageBox.critical(self, "ERROR",
+ 'Printing problem:\n %s' % sys.exc_info()[1])
+ _logger.error('printing problem:\n %s' % sys.exc_info()[1])
+ return
+
+ def _zoomPlus(self):
+ self._viewScale *= 1.20
+ self.view.scale(1.20, 1.20)
+
+ def _zoomMinus(self):
+ self._viewScale *= 0.80
+ self.view.scale(0.80, 0.80)
+
+ def _clearAll(self):
+ """
+ Clear the print preview window, remove all items
+ but keep the page.
+ """
+ itemlist = self.scene.items()
+ keep = self.page
+ while len(itemlist) != 1:
+ if itemlist.index(keep) == 0:
+ self.scene.removeItem(itemlist[1])
+ else:
+ self.scene.removeItem(itemlist[0])
+ itemlist = self.scene.items()
+ self._svgItems = []
+ self._toBeCleared = False
+
+ def _remove(self):
+ """Remove selected item in :attr:`scene`.
+ """
+ itemlist = self.scene.items()
+
+ # this loop is not efficient if there are many items ...
+ for item in itemlist:
+ if item.isSelected():
+ self.scene.removeItem(item)
+
+
+class SingletonPrintPreviewDialog(PrintPreviewDialog):
+ """Singleton print preview dialog.
+
+ All widgets in a program that instantiate this class will share
+ a single print preview dialog. This enables sending
+ multiple images to a single page to be printed.
+ """
+ _instance = None
+
+ def __new__(self, *var, **kw):
+ if self._instance is None:
+ self._instance = PrintPreviewDialog(*var, **kw)
+ return self._instance
+
+
+class _GraphicsSvgRectItem(qt.QGraphicsRectItem):
+ """:class:`qt.QGraphicsRectItem` with an attached
+ :class:`qt.QSvgRenderer`, and with a painter redefined to render
+ the SVG item."""
+ def setSvgRenderer(self, renderer):
+ """
+
+ :param QSvgRenderer renderer: svg renderer
+ """
+ self._renderer = renderer
+
+ def paint(self, painter, *var, **kw):
+ self._renderer.render(painter, self.boundingRect())
+
+
+class _GraphicsResizeRectItem(qt.QGraphicsRectItem):
+ """Resizable QGraphicsRectItem."""
+ def __init__(self, parent=None, scene=None, keepratio=True):
+ qt.QGraphicsRectItem.__init__(self, parent)
+ rect = parent.boundingRect()
+ x = rect.x()
+ y = rect.y()
+ w = rect.width()
+ h = rect.height()
+ self._newRect = None
+ self.keepRatio = keepratio
+ self.setRect(qt.QRectF(x + w - 40, y + h - 40, 40, 40))
+ self.setAcceptHoverEvents(True)
+ pen = qt.QPen()
+ color = qt.QColor(qt.Qt.white)
+ color.setAlpha(0)
+ pen.setColor(color)
+ pen.setStyle(qt.Qt.NoPen)
+ self.setPen(pen)
+ self.setBrush(color)
+ self.setFlag(self.ItemIsMovable, True)
+ self.show()
+
+ def hoverEnterEvent(self, event):
+ if self.parentItem().isSelected():
+ self.parentItem().setSelected(False)
+ if self.keepRatio:
+ self.setCursor(qt.QCursor(qt.Qt.SizeFDiagCursor))
+ else:
+ self.setCursor(qt.QCursor(qt.Qt.SizeAllCursor))
+ self.setBrush(qt.QBrush(qt.Qt.yellow, qt.Qt.SolidPattern))
+ return qt.QGraphicsRectItem.hoverEnterEvent(self, event)
+
+ def hoverLeaveEvent(self, event):
+ self.setCursor(qt.QCursor(qt.Qt.ArrowCursor))
+ pen = qt.QPen()
+ color = qt.QColor(qt.Qt.white)
+ color.setAlpha(0)
+ pen.setColor(color)
+ pen.setStyle(qt.Qt.NoPen)
+ self.setPen(pen)
+ self.setBrush(color)
+ return qt.QGraphicsRectItem.hoverLeaveEvent(self, event)
+
+ def mousePressEvent(self, event):
+ if self._newRect is not None:
+ self._newRect = None
+ self._point0 = self.pos()
+ parent = self.parentItem()
+ scene = self.scene()
+ # following line prevents dragging along the previously selected
+ # item when resizing another one
+ scene.clearSelection()
+
+ rect = parent.boundingRect()
+ self._x = rect.x()
+ self._y = rect.y()
+ self._w = rect.width()
+ self._h = rect.height()
+ self._ratio = self._w / self._h
+ self._newRect = qt.QGraphicsRectItem(parent)
+ self._newRect.setRect(qt.QRectF(self._x,
+ self._y,
+ self._w,
+ self._h))
+ qt.QGraphicsRectItem.mousePressEvent(self, event)
+
+ def mouseMoveEvent(self, event):
+ point1 = self.pos()
+ deltax = point1.x() - self._point0.x()
+ deltay = point1.y() - self._point0.y()
+ if self.keepRatio:
+ r1 = (self._w + deltax) / self._w
+ r2 = (self._h + deltay) / self._h
+ if r1 < r2:
+ self._newRect.setRect(qt.QRectF(self._x,
+ self._y,
+ self._w + deltax,
+ (self._w + deltax) / self._ratio))
+ else:
+ self._newRect.setRect(qt.QRectF(self._x,
+ self._y,
+ (self._h + deltay) * self._ratio,
+ self._h + deltay))
+ else:
+ self._newRect.setRect(qt.QRectF(self._x,
+ self._y,
+ self._w + deltax,
+ self._h + deltay))
+ qt.QGraphicsRectItem.mouseMoveEvent(self, event)
+
+ def mouseReleaseEvent(self, event):
+ point1 = self.pos()
+ deltax = point1.x() - self._point0.x()
+ deltay = point1.y() - self._point0.y()
+ self.moveBy(-deltax, -deltay)
+ parent = self.parentItem()
+
+ # deduce scale from rectangle
+ if self.keepRatio:
+ scalex = self._newRect.rect().width() / self._w
+ scaley = scalex
+ else:
+ scalex = self._newRect.rect().width() / self._w
+ scaley = self._newRect.rect().height() / self._h
+
+ # apply the scale to the previous transformation matrix
+ previousTransform = parent.transform()
+ parent.setTransform(
+ previousTransform.scale(scalex, scaley))
+
+ self.scene().removeItem(self._newRect)
+ self._newRect = None
+ qt.QGraphicsRectItem.mouseReleaseEvent(self, event)
+
+
+def main():
+ """
+ """
+ if len(sys.argv) < 2:
+ print("give an image file as parameter please.")
+ sys.exit(1)
+
+ if len(sys.argv) > 2:
+ print("only one parameter please.")
+ sys.exit(1)
+
+ filename = sys.argv[1]
+ w = PrintPreviewDialog()
+ w.resize(400, 500)
+
+ comment = ""
+ for i in range(20):
+ comment += "Line number %d: En un lugar de La Mancha de cuyo nombre ...\n" % i
+
+ if filename[-3:] == "svg":
+ item = qt.QSvgRenderer(filename, w.page)
+ w.addSvgItem(item, title=filename,
+ comment=comment, commentPosition="CENTER")
+ else:
+ w.addPixmap(qt.QPixmap.fromImage(qt.QImage(filename)),
+ title=filename,
+ comment=comment,
+ commentPosition="CENTER")
+ w.addImage(qt.QImage(filename), comment=comment, commentPosition="LEFT")
+
+ sys.exit(w.exec())
+
+
+if __name__ == '__main__':
+ a = qt.QApplication(sys.argv)
+ main()
+ a.exec()
diff --git a/src/silx/gui/widgets/RangeSlider.py b/src/silx/gui/widgets/RangeSlider.py
new file mode 100644
index 0000000..61b73fc
--- /dev/null
+++ b/src/silx/gui/widgets/RangeSlider.py
@@ -0,0 +1,776 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides a :class:`RangeSlider` widget.
+
+.. image:: img/RangeSlider.png
+ :align: center
+"""
+from __future__ import absolute_import, division
+
+__authors__ = ["D. Naudet", "T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/11/2018"
+
+
+import numpy as numpy
+
+from silx.gui import qt, icons, colors
+from silx.gui.utils.image import convertArrayToQImage
+
+
+class StyleOptionRangeSlider(qt.QStyleOption):
+ def __init__(self):
+ super(StyleOptionRangeSlider, self).__init__()
+ self.minimum = None
+ self.maximum = None
+ self.sliderPosition1 = None
+ self.sliderPosition2 = None
+ self.handlerRect1 = None
+ self.handlerRect2 = None
+
+
+class RangeSlider(qt.QWidget):
+ """Range slider with 2 thumbs and an optional colored groove.
+
+ The position of the slider thumbs can be retrieved either as values
+ in the slider range or as a number of steps or pixels.
+
+ :param QWidget parent: See QWidget
+ """
+
+ _SLIDER_WIDTH = 10
+ """Width of the slider rectangle"""
+
+ _PIXMAP_VOFFSET = 7
+ """Vertical groove pixmap offset"""
+
+ sigRangeChanged = qt.Signal(float, float)
+ """Signal emitted when the value range has changed.
+
+ It provides the new range (min, max).
+ """
+
+ sigValueChanged = qt.Signal(float, float)
+ """Signal emitted when the value of the sliders has changed.
+
+ It provides the slider values (first, second).
+ """
+
+ sigPositionCountChanged = qt.Signal(object)
+ """This signal is emitted when the number of steps has changed.
+
+ It provides the new position count.
+ """
+
+ sigPositionChanged = qt.Signal(int, int)
+ """Signal emitted when the position of the sliders has changed.
+
+ It provides the slider positions in steps or pixels (first, second).
+ """
+
+ def __init__(self, parent=None):
+ self.__pixmap = None
+ self.__positionCount = None
+ self.__firstValue = 0.
+ self.__secondValue = 1.
+ self.__minValue = 0.
+ self.__maxValue = 1.
+ self.__hoverRect = qt.QRect()
+ self.__hoverControl = None
+
+ self.__focus = None
+ self.__moving = None
+
+ self.__icons = {
+ 'first': icons.getQIcon('previous'),
+ 'second': icons.getQIcon('next')
+ }
+
+ # call the super constructor AFTER defining all members that
+ # are used in the "paint" method
+ super(RangeSlider, self).__init__(parent)
+
+ self.setFocusPolicy(qt.Qt.ClickFocus)
+ self.setAttribute(qt.Qt.WA_Hover)
+
+ self.setMinimumSize(qt.QSize(50, 20))
+ self.setMaximumHeight(20)
+
+ # Broadcast value changed signal
+ self.sigValueChanged.connect(self.__emitPositionChanged)
+
+ def event(self, event):
+ t = event.type()
+ if t == qt.QEvent.HoverEnter or t == qt.QEvent.HoverLeave or t == qt.QEvent.HoverMove:
+ return self.__updateHoverControl(event.pos())
+ else:
+ return super(RangeSlider, self).event(event)
+
+ def __updateHoverControl(self, pos):
+ hoverControl, hoverRect = self.__findHoverControl(pos)
+ if hoverControl != self.__hoverControl:
+ self.update(self.__hoverRect)
+ self.update(hoverRect)
+ self.__hoverControl = hoverControl
+ self.__hoverRect = hoverRect
+ return True
+ return hoverControl is not None
+
+ def __findHoverControl(self, pos):
+ """Returns the control at the position and it's rect location"""
+ for name in ["first", "second"]:
+ rect = self.__sliderRect(name)
+ if rect.contains(pos):
+ return name, rect
+ rect = self.__drawArea()
+ if rect.contains(pos):
+ return "groove", rect
+ return None, qt.QRect()
+
+ # Position <-> Value conversion
+
+ def __positionToValue(self, position):
+ """Returns value corresponding to position
+
+ :param int position:
+ :rtype: float
+ """
+ min_, max_ = self.getMinimum(), self.getMaximum()
+ maxPos = self.__getCurrentPositionCount() - 1
+ return min_ + (max_ - min_) * int(position) / maxPos
+
+ def __valueToPosition(self, value):
+ """Returns closest position corresponding to value
+
+ :param float value:
+ :rtype: int
+ """
+ min_, max_ = self.getMinimum(), self.getMaximum()
+ maxPos = self.__getCurrentPositionCount() - 1
+ return int(0.5 + maxPos * (float(value) - min_) / (max_ - min_))
+
+ # Position (int) API
+
+ def __getCurrentPositionCount(self):
+ """Return current count (either position count or widget width
+
+ :rtype: int
+ """
+ count = self.getPositionCount()
+ if count is not None:
+ return count
+ else:
+ return max(2, self.width() - self._SLIDER_WIDTH)
+
+ def getPositionCount(self):
+ """Returns the number of positions.
+
+ :rtype: Union[int,None]"""
+ return self.__positionCount
+
+ def setPositionCount(self, count):
+ """Set the number of positions.
+
+ Slider values are eventually adjusted.
+
+ :param Union[int,None] count:
+ Either the number of possible positions or
+ None to allow any values.
+ :raise ValueError: If count <= 1
+ """
+ count = None if count is None else int(count)
+ if count != self.getPositionCount():
+ if count is not None and count <= 1:
+ raise ValueError("Position count must be higher than 1")
+ self.__positionCount = count
+ emit = self.__setValues(*self.getValues())
+ self.sigPositionCountChanged.emit(count)
+ if emit:
+ self.sigValueChanged.emit(*self.getValues())
+
+ def getFirstPosition(self):
+ """Returns first slider position
+
+ :rtype: int
+ """
+ return self.__valueToPosition(self.getFirstValue())
+
+ def setFirstPosition(self, position):
+ """Set the position of the first slider
+
+ The position is adjusted to valid values
+
+ :param int position:
+ """
+ self.setFirstValue(self.__positionToValue(position))
+
+ def getSecondPosition(self):
+ """Returns second slider position
+
+ :rtype: int
+ """
+ return self.__valueToPosition(self.getSecondValue())
+
+ def setSecondPosition(self, position):
+ """Set the position of the second slider
+
+ The position is adjusted to valid values
+
+ :param int position:
+ """
+ self.setSecondValue(self.__positionToValue(position))
+
+ def getPositions(self):
+ """Returns slider positions (first, second)
+
+ :rtype: List[int]
+ """
+ return self.getFirstPosition(), self.getSecondPosition()
+
+ def setPositions(self, first, second):
+ """Set the position of both sliders at once
+
+ First is clipped to the slider range: [0, max].
+ Second is clipped to valid values: [first, max]
+
+ :param int first:
+ :param int second:
+ """
+ self.setValues(self.__positionToValue(first),
+ self.__positionToValue(second))
+
+ # Value (float) API
+
+ def __emitPositionChanged(self, *args, **kwargs):
+ self.sigPositionChanged.emit(*self.getPositions())
+
+ def __rangeChanged(self):
+ """Handle change of value range"""
+ emit = self.__setValues(*self.getValues())
+ self.sigRangeChanged.emit(*self.getRange())
+ if emit:
+ self.sigValueChanged.emit(*self.getValues())
+
+ def getMinimum(self):
+ """Returns the minimum value of the slider range
+
+ :rtype: float
+ """
+ return self.__minValue
+
+ def setMinimum(self, minimum):
+ """Set the minimum value of the slider range.
+
+ It eventually adjusts maximum.
+ Slider positions remains unchanged and slider values are modified.
+
+ :param float minimum:
+ :raises ValueError:
+ """
+ minimum = float(minimum)
+ if minimum == self.getMaximum():
+ raise ValueError("min and max must be different")
+
+ if minimum != self.getMinimum():
+ if minimum > self.getMaximum():
+ self.__maxValue = minimum
+ self.__minValue = minimum
+ self.__rangeChanged()
+
+ def getMaximum(self):
+ """Returns the maximum value of the slider range
+
+ :rtype: float
+ """
+ return self.__maxValue
+
+ def setMaximum(self, maximum):
+ """Set the maximum value of the slider range
+
+ It eventually adjusts minimum.
+ Slider positions remains unchanged and slider values are modified.
+
+ :param float maximum:
+ :raises ValueError:
+ """
+ maximum = float(maximum)
+ if maximum == self.getMinimum():
+ raise ValueError("min and max must be different")
+
+ if maximum != self.getMaximum():
+ if maximum < self.getMinimum():
+ self.__minValue = maximum
+ self.__maxValue = maximum
+ self.__rangeChanged()
+
+ def getRange(self):
+ """Returns the range of values (min, max)
+
+ :rtype: List[float]
+ """
+ return self.getMinimum(), self.getMaximum()
+
+ def setRange(self, minimum, maximum):
+ """Set the range of values.
+
+ If maximum is lower than minimum, minimum is the only valid value.
+ Slider positions remains unchanged and slider values are modified.
+
+ :param float minimum:
+ :param float maximum:
+ :raises ValueError:
+ """
+ minimum, maximum = float(minimum), float(maximum)
+ if minimum == maximum:
+ raise ValueError("min and max must be different")
+ if minimum != self.getMinimum() or maximum != self.getMaximum():
+ self.__minValue = minimum
+ self.__maxValue = max(maximum, minimum)
+ self.__rangeChanged()
+
+ def getFirstValue(self):
+ """Returns the value of the first slider
+
+ :rtype: float
+ """
+ return self.__firstValue
+
+ def __clipFirstValue(self, value, max_=None):
+ """Clip first value to range and steps
+
+ :param float value:
+ :param float max_: Alternative maximum to use
+ """
+ if max_ is None:
+ max_ = self.getSecondValue()
+ value = min(max(self.getMinimum(), float(value)), max_)
+ if self.getPositionCount() is not None: # Clip to steps
+ value = self.__positionToValue(self.__valueToPosition(value))
+ return value
+
+ def setFirstValue(self, value):
+ """Set the value of the first slider
+
+ Value is clipped to valid values.
+
+ :param float value:
+ """
+ value = self.__clipFirstValue(value)
+ if value != self.getFirstValue():
+ self.__firstValue = value
+ self.update()
+ self.sigValueChanged.emit(*self.getValues())
+
+ def getSecondValue(self):
+ """Returns the value of the second slider
+
+ :rtype: float
+ """
+ return self.__secondValue
+
+ def __clipSecondValue(self, value):
+ """Clip second value to range and steps
+
+ :param float value:
+ """
+ value = min(max(self.getFirstValue(), float(value)), self.getMaximum())
+ if self.getPositionCount() is not None: # Clip to steps
+ value = self.__positionToValue(self.__valueToPosition(value))
+ return value
+
+ def setSecondValue(self, value):
+ """Set the value of the second slider
+
+ Value is clipped to valid values.
+
+ :param float value:
+ """
+ value = self.__clipSecondValue(value)
+ if value != self.getSecondValue():
+ self.__secondValue = value
+ self.update()
+ self.sigValueChanged.emit(*self.getValues())
+
+ def getValues(self):
+ """Returns value of both sliders at once
+
+ :return: (first value, second value)
+ :rtype: List[float]
+ """
+ return self.getFirstValue(), self.getSecondValue()
+
+ def setValues(self, first, second):
+ """Set values for both sliders at once
+
+ First is clipped to the slider range: [minimum, maximum].
+ Second is clipped to valid values: [first, maximum]
+
+ :param float first:
+ :param float second:
+ """
+ if self.__setValues(first, second):
+ self.sigValueChanged.emit(*self.getValues())
+
+ def __setValues(self, first, second):
+ """Set values for both sliders at once
+
+ First is clipped to the slider range: [minimum, maximum].
+ Second is clipped to valid values: [first, maximum]
+
+ :param float first:
+ :param float second:
+ :return: True if values has changed, False otherwise
+ :rtype: bool
+ """
+ first = self.__clipFirstValue(first, self.getMaximum())
+ second = self.__clipSecondValue(second)
+ values = first, second
+
+ if self.getValues() != values:
+ self.__firstValue, self.__secondValue = values
+ self.update()
+ return True
+ return False
+
+ # Groove API
+
+ def getGroovePixmap(self):
+ """Returns the pixmap displayed in the slider groove if any.
+
+ :rtype: Union[QPixmap,None]
+ """
+ return self.__pixmap
+
+ def setGroovePixmap(self, pixmap):
+ """Set the pixmap displayed in the slider groove.
+
+ :param Union[QPixmap,None] pixmap: The QPixmap to use or None to unset.
+ """
+ assert pixmap is None or isinstance(pixmap, qt.QPixmap)
+ self.__pixmap = pixmap
+ self.update()
+
+ def setGroovePixmapFromProfile(self, profile, colormap=None):
+ """Set the pixmap displayed in the slider groove from histogram values.
+
+ :param Union[numpy.ndarray,None] profile:
+ 1D array of values to display
+ :param Union[~silx.gui.colors.Colormap,str] colormap:
+ The colormap name or object to convert profile values to colors
+ """
+ if profile is None:
+ self.setSliderPixmap(None)
+ return
+
+ profile = numpy.array(profile, copy=False)
+
+ if profile.size == 0:
+ self.setSliderPixmap(None)
+ return
+
+ if colormap is None:
+ colormap = colors.Colormap()
+ elif isinstance(colormap, str):
+ colormap = colors.Colormap(name=colormap)
+ assert isinstance(colormap, colors.Colormap)
+
+ rgbImage = colormap.applyToData(profile.reshape(1, -1))[:, :, :3]
+ qimage = convertArrayToQImage(rgbImage)
+ qpixmap = qt.QPixmap.fromImage(qimage)
+ self.setGroovePixmap(qpixmap)
+
+ # Handle interaction
+
+ def mousePressEvent(self, event):
+ super(RangeSlider, self).mousePressEvent(event)
+
+ if event.buttons() == qt.Qt.LeftButton:
+ picked = None
+ for name in ('first', 'second'):
+ area = self.__sliderRect(name)
+ if area.contains(event.pos()):
+ picked = name
+ break
+
+ self.__moving = picked
+ self.__focus = picked
+ self.update()
+
+ def mouseMoveEvent(self, event):
+ super(RangeSlider, self).mouseMoveEvent(event)
+
+ if self.__moving is not None:
+ delta = self._SLIDER_WIDTH // 2
+ if self.__moving == 'first':
+ position = self.__xPixelToPosition(event.pos().x() + delta)
+ self.setFirstPosition(position)
+ else:
+ position = self.__xPixelToPosition(event.pos().x() - delta)
+ self.setSecondPosition(position)
+
+ def mouseReleaseEvent(self, event):
+ super(RangeSlider, self).mouseReleaseEvent(event)
+
+ if event.button() == qt.Qt.LeftButton and self.__moving is not None:
+ self.__moving = None
+ self.update()
+
+ def focusOutEvent(self, event):
+ if self.__focus is not None:
+ self.__focus = None
+ self.update()
+ super(RangeSlider, self).focusOutEvent(event)
+
+ def keyPressEvent(self, event):
+ key = event.key()
+ if event.modifiers() == qt.Qt.NoModifier and self.__focus is not None:
+ if key in (qt.Qt.Key_Left, qt.Qt.Key_Down):
+ if self.__focus == 'first':
+ self.setFirstPosition(self.getFirstPosition() - 1)
+ else:
+ self.setSecondPosition(self.getSecondPosition() - 1)
+ return # accept event
+ elif key in (qt.Qt.Key_Right, qt.Qt.Key_Up):
+ if self.__focus == 'first':
+ self.setFirstPosition(self.getFirstPosition() + 1)
+ else:
+ self.setSecondPosition(self.getSecondPosition() + 1)
+ return # accept event
+
+ super(RangeSlider, self).keyPressEvent(event)
+
+ # Handle resize
+
+ def resizeEvent(self, event):
+ super(RangeSlider, self).resizeEvent(event)
+
+ # If no step, signal position update when width change
+ if (self.getPositionCount() is None and
+ event.size().width() != event.oldSize().width()):
+ self.sigPositionChanged.emit(*self.getPositions())
+
+ # Handle repaint
+
+ def __xPixelToPosition(self, x):
+ """Convert position in pixel to slider position
+
+ :param int x: X in pixel coordinates
+ :rtype: int
+ """
+ sliderArea = self.__sliderAreaRect()
+ maxPos = self.__getCurrentPositionCount() - 1
+ position = maxPos * (x - sliderArea.left()) / (sliderArea.width() - 1)
+ return int(position + 0.5)
+
+ def __sliderRect(self, name):
+ """Returns rectangle corresponding to slider in pixels
+
+ :param str name: 'first' or 'second'
+ :rtype: QRect
+ :raise ValueError: If wrong name
+ """
+ assert name in ('first', 'second')
+ if name == 'first':
+ offset = - self._SLIDER_WIDTH
+ position = self.getFirstPosition()
+ elif name == 'second':
+ offset = 0
+ position = self.getSecondPosition()
+ else:
+ raise ValueError('Unknown name')
+
+ sliderArea = self.__sliderAreaRect()
+
+ maxPos = self.__getCurrentPositionCount() - 1
+ xOffset = int((sliderArea.width() - 1) * position / maxPos)
+ xPos = sliderArea.left() + xOffset + offset
+
+ return qt.QRect(xPos,
+ sliderArea.top(),
+ self._SLIDER_WIDTH,
+ sliderArea.height())
+
+ def __drawArea(self):
+ return self.rect().adjusted(self._SLIDER_WIDTH, 0,
+ -self._SLIDER_WIDTH, 0)
+
+ def __sliderAreaRect(self):
+ return self.__drawArea().adjusted(self._SLIDER_WIDTH // 2,
+ 0,
+ -self._SLIDER_WIDTH // 2 + 1,
+ 0)
+
+ def __pixMapRect(self):
+ return self.__sliderAreaRect().adjusted(0,
+ self._PIXMAP_VOFFSET,
+ -1,
+ -self._PIXMAP_VOFFSET)
+
+ def paintEvent(self, event):
+ painter = qt.QPainter(self)
+
+ style = qt.QApplication.style()
+
+ area = self.__drawArea()
+ if self.__pixmap is not None:
+ pixmapRect = self.__pixMapRect()
+
+ option = qt.QStyleOptionProgressBar()
+ option.initFrom(self)
+ option.rect = area
+ option.state = (qt.QStyle.State_Enabled if self.isEnabled()
+ else qt.QStyle.State_None)
+ style.drawControl(qt.QStyle.CE_ProgressBarGroove,
+ option,
+ painter,
+ self)
+
+ painter.save()
+ pen = painter.pen()
+ pen.setWidth(1)
+ pen.setColor(qt.Qt.black if self.isEnabled() else qt.Qt.gray)
+ painter.setPen(pen)
+ painter.drawRect(pixmapRect.adjusted(-1, -1, 0, 1))
+ painter.restore()
+
+ if self.isEnabled():
+ rect = area.adjusted(self._SLIDER_WIDTH // 2,
+ self._PIXMAP_VOFFSET,
+ -self._SLIDER_WIDTH // 2,
+ -self._PIXMAP_VOFFSET + 1)
+ painter.drawPixmap(rect,
+ self.__pixmap,
+ self.__pixmap.rect())
+ else:
+ option = StyleOptionRangeSlider()
+ option.initFrom(self)
+ option.rect = area
+ option.sliderPosition1 = self.__firstValue
+ option.sliderPosition2 = self.__secondValue
+ option.handlerRect1 = self.__sliderRect("first")
+ option.handlerRect2 = self.__sliderRect("second")
+ option.minimum = self.__minValue
+ option.maximum = self.__maxValue
+ option.state = (qt.QStyle.State_Enabled if self.isEnabled()
+ else qt.QStyle.State_None)
+ if self.__hoverControl == "groove":
+ option.state |= qt.QStyle.State_MouseOver
+ elif option.state & qt.QStyle.State_MouseOver:
+ option.state ^= qt.QStyle.State_MouseOver
+ self.drawRangeSliderBackground(painter, option, self)
+
+ # Avoid glitch when moving handles
+ hoverControl = self.__moving or self.__hoverControl
+
+ for name in ('first', 'second'):
+ rect = self.__sliderRect(name)
+ option = qt.QStyleOptionButton()
+ option.initFrom(self)
+ option.icon = self.__icons[name]
+ option.iconSize = rect.size() * 0.7
+ if hoverControl == name:
+ option.state |= qt.QStyle.State_MouseOver
+ elif option.state & qt.QStyle.State_MouseOver:
+ option.state ^= qt.QStyle.State_MouseOver
+ if self.__focus == name:
+ option.state |= qt.QStyle.State_HasFocus
+ elif option.state & qt.QStyle.State_HasFocus:
+ option.state ^= qt.QStyle.State_HasFocus
+ option.rect = rect
+ style.drawControl(
+ qt.QStyle.CE_PushButton, option, painter, self)
+
+ def sizeHint(self):
+ return qt.QSize(200, self.minimumHeight())
+
+ @classmethod
+ def drawRangeSliderBackground(cls, painter, option, widget):
+ """Draw the background of the RangeSlider widget into the painter.
+
+ :param qt.QPainter painter: A painter
+ :param StyleOptionRangeSlider option: Options to draw the widget
+ :param qt.QWidget: The widget which have to be drawn
+ """
+ painter.save()
+ painter.translate(0.5, 0.5)
+
+ backgroundRect = qt.QRect(option.rect)
+ if backgroundRect.height() > 8:
+ center = backgroundRect.center()
+ backgroundRect.setHeight(8)
+ backgroundRect.moveCenter(center)
+
+ selectedRangeRect = qt.QRect(backgroundRect)
+ selectedRangeRect.setLeft(option.handlerRect1.center().x())
+ selectedRangeRect.setRight(option.handlerRect2.center().x())
+
+ highlight = option.palette.color(qt.QPalette.Highlight)
+ activeHighlight = highlight
+ selectedOutline = option.palette.color(qt.QPalette.Highlight)
+
+ buttonColor = option.palette.button().color()
+ val = qt.qGray(buttonColor.rgb())
+ buttonColor = buttonColor.lighter(100 + max(1, (180 - val) // 6))
+ buttonColor.setHsv(buttonColor.hue(), (buttonColor.saturation() * 3) // 4, buttonColor.value())
+
+ grooveColor = qt.QColor()
+ grooveColor.setHsv(buttonColor.hue(),
+ min(255, (int)(buttonColor.saturation())),
+ min(255, (int)(buttonColor.value() * 0.9)))
+
+ selectedInnerContrastLine = qt.QColor(255, 255, 255, 30)
+
+ outline = option.palette.color(qt.QPalette.Window).darker(140)
+ if (option.state & qt.QStyle.State_HasFocus and option.state & qt.QStyle.State_KeyboardFocusChange):
+ outline = highlight.darker(125)
+ if outline.value() > 160:
+ outline.setHsl(highlight.hue(), highlight.saturation(), 160)
+
+ # Draw background groove
+ painter.setRenderHint(qt.QPainter.Antialiasing, True)
+ gradient = qt.QLinearGradient()
+ gradient.setStart(backgroundRect.center().x(), backgroundRect.top())
+ gradient.setFinalStop(backgroundRect.center().x(), backgroundRect.bottom())
+ painter.setPen(qt.QPen(outline))
+ gradient.setColorAt(0, grooveColor.darker(110))
+ gradient.setColorAt(1, grooveColor.lighter(110))
+ painter.setBrush(gradient)
+ painter.drawRoundedRect(backgroundRect.adjusted(1, 1, -2, -2), 1, 1)
+
+ # Draw slider background for the value
+ gradient = qt.QLinearGradient()
+ gradient.setStart(selectedRangeRect.center().x(), selectedRangeRect.top())
+ gradient.setFinalStop(selectedRangeRect.center().x(), selectedRangeRect.bottom())
+ painter.setRenderHint(qt.QPainter.Antialiasing, True)
+ painter.setPen(qt.QPen(selectedOutline))
+ gradient.setColorAt(0, activeHighlight)
+ gradient.setColorAt(1, activeHighlight.lighter(130))
+ painter.setBrush(gradient)
+ painter.drawRoundedRect(selectedRangeRect.adjusted(1, 1, -2, -2), 1, 1)
+ painter.setPen(selectedInnerContrastLine)
+ painter.setBrush(qt.Qt.NoBrush)
+ painter.drawRoundedRect(selectedRangeRect.adjusted(2, 2, -3, -3), 1, 1)
+
+ painter.restore()
diff --git a/src/silx/gui/widgets/TableWidget.py b/src/silx/gui/widgets/TableWidget.py
new file mode 100644
index 0000000..50eb9e2
--- /dev/null
+++ b/src/silx/gui/widgets/TableWidget.py
@@ -0,0 +1,626 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module provides table widgets handling cut, copy and paste for
+multiple cell selections. These actions can be triggered using keyboard
+shortcuts or through a context menu (right-click).
+
+:class:`TableView` is a subclass of :class:`QTableView`. The added features
+are made available to users after a model is added to the widget, using
+:meth:`TableView.setModel`.
+
+:class:`TableWidget` is a subclass of :class:`qt.QTableWidget`, a table view
+with a built-in standard data model. The added features are available as soon as
+the widget is initialized.
+
+The cut, copy and paste actions are implemented as QActions:
+
+ - :class:`CopySelectedCellsAction` (*Ctrl+C*)
+ - :class:`CopyAllCellsAction`
+ - :class:`CutSelectedCellsAction` (*Ctrl+X*)
+ - :class:`CutAllCellsAction`
+ - :class:`PasteCellsAction` (*Ctrl+V*)
+
+The copy actions are enabled by default. The cut and paste actions must be
+explicitly enabled, by passing parameters ``cut=True, paste=True`` when
+creating the widgets, or later by calling their :meth:`enableCut` and
+:meth:`enablePaste` methods.
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "03/07/2017"
+
+
+import sys
+from .. import qt
+
+
+if sys.platform.startswith("win"):
+ row_separator = "\r\n"
+else:
+ row_separator = "\n"
+
+col_separator = "\t"
+
+
+class CopySelectedCellsAction(qt.QAction):
+ """QAction to copy text from selected cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ If multiple cells are selected, the copied text will be a concatenation
+ of the texts in all selected cells, tabulated with tabulation and
+ newline characters.
+
+ If the cells are sparsely selected, the structure is preserved by
+ representing the unselected cells as empty strings in between two
+ tabulation characters.
+ Beware of pasting this data in another table widget, because depending
+ on how the paste is implemented, the empty cells may cause data in the
+ target table to be deleted, even though you didn't necessarily select the
+ corresponding cell in the origin table.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('CopySelectedCellsAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(CopySelectedCellsAction, self).__init__(table)
+ self.setText("Copy selection")
+ self.setToolTip("Copy selected cells into the clipboard.")
+ self.setShortcut(qt.QKeySequence.Copy)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.triggered.connect(self.copyCellsToClipboard)
+ self.table = table
+ self.cut = False
+ """:attr:`cut` can be set to True by classes inheriting this action,
+ to do a cut action."""
+
+ def copyCellsToClipboard(self):
+ """Concatenate the text content of all selected cells into a string
+ using tabulations and newlines to keep the table structure.
+ Put this text into the clipboard.
+ """
+ selected_idx = self.table.selectedIndexes()
+ if not selected_idx:
+ return
+ selected_idx_tuples = [(idx.row(), idx.column()) for idx in selected_idx]
+
+ selected_rows = [idx[0] for idx in selected_idx_tuples]
+ selected_columns = [idx[1] for idx in selected_idx_tuples]
+
+ data_model = self.table.model()
+
+ copied_text = ""
+ for row in range(min(selected_rows), max(selected_rows) + 1):
+ for col in range(min(selected_columns), max(selected_columns) + 1):
+ index = data_model.index(row, col)
+ cell_text = data_model.data(index)
+ flags = data_model.flags(index)
+
+ if (row, col) in selected_idx_tuples and cell_text is not None:
+ copied_text += cell_text
+ if self.cut and (flags & qt.Qt.ItemIsEditable):
+ data_model.setData(index, "")
+ copied_text += col_separator
+ # remove the right-most tabulation
+ copied_text = copied_text[:-len(col_separator)]
+ # add a newline
+ copied_text += row_separator
+ # remove final newline
+ copied_text = copied_text[:-len(row_separator)]
+
+ # put this text into clipboard
+ qapp = qt.QApplication.instance()
+ qapp.clipboard().setText(copied_text)
+
+
+class CopyAllCellsAction(qt.QAction):
+ """QAction to copy text from all cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ The copied text will be a concatenation
+ of the texts in all cells, tabulated with tabulation and
+ newline characters.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('CopyAllCellsAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(CopyAllCellsAction, self).__init__(table)
+ self.setText("Copy all")
+ self.setToolTip("Copy all cells into the clipboard.")
+ self.triggered.connect(self.copyCellsToClipboard)
+ self.table = table
+ self.cut = False
+
+ def copyCellsToClipboard(self):
+ """Concatenate the text content of all cells into a string
+ using tabulations and newlines to keep the table structure.
+ Put this text into the clipboard.
+ """
+ data_model = self.table.model()
+ copied_text = ""
+ for row in range(data_model.rowCount()):
+ for col in range(data_model.columnCount()):
+ index = data_model.index(row, col)
+ cell_text = data_model.data(index)
+ flags = data_model.flags(index)
+ if cell_text is not None:
+ copied_text += cell_text
+ if self.cut and (flags & qt.Qt.ItemIsEditable):
+ data_model.setData(index, "")
+ copied_text += col_separator
+ # remove the right-most tabulation
+ copied_text = copied_text[:-len(col_separator)]
+ # add a newline
+ copied_text += row_separator
+ # remove final newline
+ copied_text = copied_text[:-len(row_separator)]
+
+ # put this text into clipboard
+ qapp = qt.QApplication.instance()
+ qapp.clipboard().setText(copied_text)
+
+
+class CutSelectedCellsAction(CopySelectedCellsAction):
+ """QAction to cut text from selected cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ The text is deleted from the original table widget
+ (use :class:`CopySelectedCellsAction` to preserve the original data).
+
+ If multiple cells are selected, the cut text will be a concatenation
+ of the texts in all selected cells, tabulated with tabulation and
+ newline characters.
+
+ If the cells are sparsely selected, the structure is preserved by
+ representing the unselected cells as empty strings in between two
+ tabulation characters.
+ Beware of pasting this data in another table widget, because depending
+ on how the paste is implemented, the empty cells may cause data in the
+ target table to be deleted, even though you didn't necessarily select the
+ corresponding cell in the origin table.
+
+ :param table: :class:`QTableView` to which this action belongs."""
+ def __init__(self, table):
+ super(CutSelectedCellsAction, self).__init__(table)
+ self.setText("Cut selection")
+ self.setShortcut(qt.QKeySequence.Cut)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ # cutting is already implemented in CopySelectedCellsAction (but
+ # it is disabled), we just need to enable it
+ self.cut = True
+
+
+class CutAllCellsAction(CopyAllCellsAction):
+ """QAction to cut text from all cells in a :class:`QTableWidget`
+ into the clipboard.
+
+ The text is deleted from the original table widget
+ (use :class:`CopyAllCellsAction` to preserve the original data).
+
+ The cut text will be a concatenation
+ of the texts in all cells, tabulated with tabulation and
+ newline characters.
+
+ :param table: :class:`QTableView` to which this action belongs."""
+ def __init__(self, table):
+ super(CutAllCellsAction, self).__init__(table)
+ self.setText("Cut all")
+ self.setToolTip("Cut all cells into the clipboard.")
+ self.cut = True
+
+
+def _parseTextAsTable(text, row_separator=row_separator, col_separator=col_separator):
+ """Parse text into list of lists (2D sequence).
+
+ The input text must be tabulated using tabulation characters and
+ newlines to separate columns and rows.
+
+ :param text: text to be parsed
+ :param record_separator: String, or single character, to be interpreted
+ as a record/row separator.
+ :param field_separator: String, or single character, to be interpreted
+ as a field/column separator.
+ :return: 2D sequence of strings
+ """
+ rows = text.split(row_separator)
+ table_data = [row.split(col_separator) for row in rows]
+ return table_data
+
+
+class PasteCellsAction(qt.QAction):
+ """QAction to paste text from the clipboard into the table.
+
+ If the text contains tabulations and
+ newlines, they are interpreted as column and row separators.
+ In such a case, the text is split into multiple texts to be pasted
+ into multiple cells.
+
+ If a cell content is an empty string in the original text, it is
+ ignored: the destination cell's text will not be deleted.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('PasteCellsAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(PasteCellsAction, self).__init__(table)
+ self.table = table
+ self.setText("Paste")
+ self.setShortcut(qt.QKeySequence.Paste)
+ self.setShortcutContext(qt.Qt.WidgetShortcut)
+ self.setToolTip("Paste data. The selected cell is the top-left" +
+ "corner of the paste area.")
+ self.triggered.connect(self.pasteCellFromClipboard)
+
+ def pasteCellFromClipboard(self):
+ """Paste text from clipboard into the table.
+
+ :return: *True* in case of success, *False* if pasting data failed.
+ """
+ selected_idx = self.table.selectedIndexes()
+ if len(selected_idx) != 1:
+ msgBox = qt.QMessageBox(parent=self.table)
+ msgBox.setText("A single cell must be selected to paste data")
+ msgBox.exec()
+ return False
+
+ data_model = self.table.model()
+
+ selected_row = selected_idx[0].row()
+ selected_col = selected_idx[0].column()
+
+ qapp = qt.QApplication.instance()
+ clipboard_text = qapp.clipboard().text()
+ table_data = _parseTextAsTable(clipboard_text)
+
+ protected_cells = 0
+ out_of_range_cells = 0
+
+ # paste table data into cells, using selected cell as origin
+ for row_offset in range(len(table_data)):
+ for col_offset in range(len(table_data[row_offset])):
+ target_row = selected_row + row_offset
+ target_col = selected_col + col_offset
+
+ if target_row >= data_model.rowCount() or\
+ target_col >= data_model.columnCount():
+ out_of_range_cells += 1
+ continue
+
+ index = data_model.index(target_row, target_col)
+ flags = data_model.flags(index)
+
+ # ignore empty strings
+ if table_data[row_offset][col_offset] != "":
+ if not flags & qt.Qt.ItemIsEditable:
+ protected_cells += 1
+ continue
+ data_model.setData(index, table_data[row_offset][col_offset])
+ # item.setText(table_data[row_offset][col_offset])
+
+ if protected_cells or out_of_range_cells:
+ msgBox = qt.QMessageBox(parent=self.table)
+ msg = "Some data could not be inserted, "
+ msg += "due to out-of-range or write-protected cells."
+ msgBox.setText(msg)
+ msgBox.exec()
+ return False
+ return True
+
+
+class CopySingleCellAction(qt.QAction):
+ """QAction to copy text from a single cell in a modified
+ :class:`QTableWidget`.
+
+ This action relies on the fact that the text in the last clicked cell
+ are stored in :attr:`_last_cell_clicked` of the modified widget.
+
+ In most cases, :class:`CopySelectedCellsAction` handles single cells,
+ but if the selection mode of the widget has been set to NoSelection
+ it is necessary to use this class instead.
+
+ :param table: :class:`QTableView` to which this action belongs.
+ """
+ def __init__(self, table):
+ if not isinstance(table, qt.QTableView):
+ raise ValueError('CopySingleCellAction must be initialised ' +
+ 'with a QTableWidget.')
+ super(CopySingleCellAction, self).__init__(table)
+ self.setText("Copy cell")
+ self.setToolTip("Copy cell content into the clipboard.")
+ self.triggered.connect(self.copyCellToClipboard)
+ self.table = table
+
+ def copyCellToClipboard(self):
+ """
+ """
+ cell_text = self.table._text_last_cell_clicked
+ if cell_text is None:
+ return
+
+ # put this text into clipboard
+ qapp = qt.QApplication.instance()
+ qapp.clipboard().setText(cell_text)
+
+
+class TableWidget(qt.QTableWidget):
+ """:class:`QTableWidget` with a context menu displaying up to 5 actions:
+
+ - :class:`CopySelectedCellsAction`
+ - :class:`CopyAllCellsAction`
+ - :class:`CutSelectedCellsAction`
+ - :class:`CutAllCellsAction`
+ - :class:`PasteCellsAction`
+
+ These actions interact with the clipboard and can be used to copy data
+ to or from an external application, or another widget.
+
+ The cut and paste actions are disabled by default, due to the risk of
+ overwriting data (no *Undo* action is available). Use :meth:`enablePaste`
+ and :meth:`enableCut` to activate them.
+
+ .. image:: img/TableWidget.png
+
+ :param parent: Parent QWidget
+ :param bool cut: Enable cut action
+ :param bool paste: Enable paste action
+ """
+ def __init__(self, parent=None, cut=False, paste=False):
+ super(TableWidget, self).__init__(parent)
+ self._text_last_cell_clicked = None
+
+ self.copySelectedCellsAction = CopySelectedCellsAction(self)
+ self.copyAllCellsAction = CopyAllCellsAction(self)
+ self.copySingleCellAction = None
+ self.pasteCellsAction = None
+ self.cutSelectedCellsAction = None
+ self.cutAllCellsAction = None
+
+ self.addAction(self.copySelectedCellsAction)
+ self.addAction(self.copyAllCellsAction)
+ if cut:
+ self.enableCut()
+ if paste:
+ self.enablePaste()
+
+ self.setContextMenuPolicy(qt.Qt.ActionsContextMenu)
+
+ def mousePressEvent(self, event):
+ item = self.itemAt(event.pos())
+ if item is not None:
+ self._text_last_cell_clicked = item.text()
+ super(TableWidget, self).mousePressEvent(event)
+
+ def enablePaste(self):
+ """Enable paste action, to paste data from the clipboard into the
+ table.
+
+ .. warning::
+
+ This action can cause data to be overwritten.
+ There is currently no *Undo* action to retrieve lost data.
+ """
+ self.pasteCellsAction = PasteCellsAction(self)
+ self.addAction(self.pasteCellsAction)
+
+ def enableCut(self):
+ """Enable cut action.
+
+ .. warning::
+
+ This action can cause data to be deleted.
+ There is currently no *Undo* action to retrieve lost data."""
+ self.cutSelectedCellsAction = CutSelectedCellsAction(self)
+ self.cutAllCellsAction = CutAllCellsAction(self)
+ self.addAction(self.cutSelectedCellsAction)
+ self.addAction(self.cutAllCellsAction)
+
+ def setSelectionMode(self, mode):
+ """Overloaded from QTableWidget to disable cut/copy selection
+ actions in case mode is NoSelection
+
+ :param mode:
+ :return:
+ """
+ if mode == qt.QTableView.NoSelection:
+ self.copySelectedCellsAction.setVisible(False)
+ self.copySelectedCellsAction.setEnabled(False)
+ if self.cutSelectedCellsAction is not None:
+ self.cutSelectedCellsAction.setVisible(False)
+ self.cutSelectedCellsAction.setEnabled(False)
+ if self.copySingleCellAction is None:
+ self.copySingleCellAction = CopySingleCellAction(self)
+ self.insertAction(self.copySelectedCellsAction, # before first action
+ self.copySingleCellAction)
+ self.copySingleCellAction.setVisible(True)
+ self.copySingleCellAction.setEnabled(True)
+ else:
+ self.copySelectedCellsAction.setVisible(True)
+ self.copySelectedCellsAction.setEnabled(True)
+ if self.cutSelectedCellsAction is not None:
+ self.cutSelectedCellsAction.setVisible(True)
+ self.cutSelectedCellsAction.setEnabled(True)
+ if self.copySingleCellAction is not None:
+ self.copySingleCellAction.setVisible(False)
+ self.copySingleCellAction.setEnabled(False)
+ super(TableWidget, self).setSelectionMode(mode)
+
+
+class TableView(qt.QTableView):
+ """:class:`QTableView` with a context menu displaying up to 5 actions:
+
+ - :class:`CopySelectedCellsAction`
+ - :class:`CopyAllCellsAction`
+ - :class:`CutSelectedCellsAction`
+ - :class:`CutAllCellsAction`
+ - :class:`PasteCellsAction`
+
+ These actions interact with the clipboard and can be used to copy data
+ to or from an external application, or another widget.
+
+ The cut and paste actions are disabled by default, due to the risk of
+ overwriting data (no *Undo* action is available). Use :meth:`enablePaste`
+ and :meth:`enableCut` to activate them.
+
+ .. note::
+
+ These actions will be available only after a model is associated
+ with this view, using :meth:`setModel`.
+
+ :param parent: Parent QWidget
+ :param bool cut: Enable cut action
+ :param bool paste: Enable paste action
+ """
+ def __init__(self, parent=None, cut=False, paste=False):
+ super(TableView, self).__init__(parent)
+ self._text_last_cell_clicked = None
+
+ self.cut = cut
+ self.paste = paste
+
+ self.copySelectedCellsAction = None
+ self.copyAllCellsAction = None
+ self.copySingleCellAction = None
+ self.pasteCellsAction = None
+ self.cutSelectedCellsAction = None
+ self.cutAllCellsAction = None
+
+ def mousePressEvent(self, event):
+ qindex = self.indexAt(event.pos())
+ if self.copyAllCellsAction is not None: # model was set
+ self._text_last_cell_clicked = self.model().data(qindex)
+ super(TableView, self).mousePressEvent(event)
+
+ def setModel(self, model):
+ """Set the data model for the table view, activate the actions
+ and the context menu.
+
+ :param model: :class:`qt.QAbstractItemModel` object
+ """
+ super(TableView, self).setModel(model)
+
+ self.copySelectedCellsAction = CopySelectedCellsAction(self)
+ self.copyAllCellsAction = CopyAllCellsAction(self)
+ self.addAction(self.copySelectedCellsAction)
+ self.addAction(self.copyAllCellsAction)
+ if self.cut:
+ self.enableCut()
+ if self.paste:
+ self.enablePaste()
+
+ self.setContextMenuPolicy(qt.Qt.ActionsContextMenu)
+
+ def enablePaste(self):
+ """Enable paste action, to paste data from the clipboard into the
+ table.
+
+ .. warning::
+
+ This action can cause data to be overwritten.
+ There is currently no *Undo* action to retrieve lost data.
+ """
+ self.pasteCellsAction = PasteCellsAction(self)
+ self.addAction(self.pasteCellsAction)
+
+ def enableCut(self):
+ """Enable cut action.
+
+ .. warning::
+
+ This action can cause data to be deleted.
+ There is currently no *Undo* action to retrieve lost data.
+ """
+ self.cutSelectedCellsAction = CutSelectedCellsAction(self)
+ self.cutAllCellsAction = CutAllCellsAction(self)
+ self.addAction(self.cutSelectedCellsAction)
+ self.addAction(self.cutAllCellsAction)
+
+ def addAction(self, action):
+ # ensure the actions are not added multiple times:
+ # compare action type and parent widget with those of existing actions
+ for existing_action in self.actions():
+ if type(action) == type(existing_action):
+ if hasattr(action, "table") and\
+ action.table is existing_action.table:
+ return None
+ super(TableView, self).addAction(action)
+
+ def setSelectionMode(self, mode):
+ """Overloaded from QTableView to disable cut/copy selection
+ actions in case mode is NoSelection
+
+ :param mode:
+ :return:
+ """
+ if mode == qt.QTableView.NoSelection:
+ self.copySelectedCellsAction.setVisible(False)
+ self.copySelectedCellsAction.setEnabled(False)
+ if self.cutSelectedCellsAction is not None:
+ self.cutSelectedCellsAction.setVisible(False)
+ self.cutSelectedCellsAction.setEnabled(False)
+ if self.copySingleCellAction is None:
+ self.copySingleCellAction = CopySingleCellAction(self)
+ self.insertAction(self.copySelectedCellsAction, # before first action
+ self.copySingleCellAction)
+ self.copySingleCellAction.setVisible(True)
+ self.copySingleCellAction.setEnabled(True)
+ else:
+ self.copySelectedCellsAction.setVisible(True)
+ self.copySelectedCellsAction.setEnabled(True)
+ if self.cutSelectedCellsAction is not None:
+ self.cutSelectedCellsAction.setVisible(True)
+ self.cutSelectedCellsAction.setEnabled(True)
+ if self.copySingleCellAction is not None:
+ self.copySingleCellAction.setVisible(False)
+ self.copySingleCellAction.setEnabled(False)
+ super(TableView, self).setSelectionMode(mode)
+
+
+if __name__ == "__main__":
+ app = qt.QApplication([])
+
+ tablewidget = TableWidget()
+ tablewidget.setWindowTitle("TableWidget")
+ tablewidget.setColumnCount(10)
+ tablewidget.setRowCount(7)
+ tablewidget.enableCut()
+ tablewidget.enablePaste()
+ tablewidget.show()
+
+ tableview = TableView(cut=True, paste=True)
+ tableview.setWindowTitle("TableView")
+ model = qt.QStandardItemModel()
+ model.setColumnCount(10)
+ model.setRowCount(7)
+ tableview.setModel(model)
+ tableview.show()
+
+ app.exec()
diff --git a/silx/gui/widgets/ThreadPoolPushButton.py b/src/silx/gui/widgets/ThreadPoolPushButton.py
index 949b6ef..949b6ef 100644
--- a/silx/gui/widgets/ThreadPoolPushButton.py
+++ b/src/silx/gui/widgets/ThreadPoolPushButton.py
diff --git a/src/silx/gui/widgets/UrlSelectionTable.py b/src/silx/gui/widgets/UrlSelectionTable.py
new file mode 100644
index 0000000..bc75d32
--- /dev/null
+++ b/src/silx/gui/widgets/UrlSelectionTable.py
@@ -0,0 +1,169 @@
+# /*##########################################################################
+# Copyright (C) 2017-2021 European Synchrotron Radiation Facility
+#
+# This file is part of the PyMca X-ray Fluorescence Toolkit developed at
+# the ESRF by the Software group.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+#############################################################################*/
+"""Some widget construction to check if a sample moved"""
+
+__author__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "19/03/2018"
+
+from silx.gui import qt
+from collections import OrderedDict
+from silx.gui.widgets.TableWidget import TableWidget
+from silx.io.url import DataUrl
+import functools
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+class UrlSelectionTable(TableWidget):
+ """Table used to select the color channel to be displayed for each"""
+
+ COLUMS_INDEX = OrderedDict([
+ ('url', 0),
+ ('img A', 1),
+ ('img B', 2),
+ ])
+
+ sigImageAChanged = qt.Signal(str)
+ """Signal emitted when the image A change. Param is the image url path"""
+
+ sigImageBChanged = qt.Signal(str)
+ """Signal emitted when the image B change. Param is the image url path"""
+
+ def __init__(self, parent=None):
+ TableWidget.__init__(self, parent)
+ self.clear()
+
+ def clear(self):
+ qt.QTableWidget.clear(self)
+ self.setRowCount(0)
+ self.setColumnCount(len(self.COLUMS_INDEX))
+ self.setHorizontalHeaderLabels(list(self.COLUMS_INDEX.keys()))
+ self.verticalHeader().hide()
+ self.horizontalHeader().setSectionResizeMode(0,
+ qt.QHeaderView.Stretch)
+
+ self.setSortingEnabled(True)
+ self._checkBoxes = {}
+
+ def setUrls(self, urls: list) -> None:
+ """
+
+ :param urls: urls to be displayed
+ """
+ for url in urls:
+ self.addUrl(url=url)
+
+ def addUrl(self, url, **kwargs):
+ """
+
+ :param url:
+ :param args:
+ :return: index of the created items row
+ :rtype int
+ """
+ assert isinstance(url, DataUrl)
+ row = self.rowCount()
+ self.setRowCount(row + 1)
+
+ _item = qt.QTableWidgetItem()
+ _item.setText(os.path.basename(url.path()))
+ _item.setFlags(qt.Qt.ItemIsEnabled | qt.Qt.ItemIsSelectable)
+ self.setItem(row, self.COLUMS_INDEX['url'], _item)
+
+ widgetImgA = qt.QRadioButton(parent=self)
+ widgetImgA.setAutoExclusive(False)
+ self.setCellWidget(row, self.COLUMS_INDEX['img A'], widgetImgA)
+ callbackImgA = functools.partial(self._activeImgAChanged, url.path())
+ widgetImgA.toggled.connect(callbackImgA)
+
+ widgetImgB = qt.QRadioButton(parent=self)
+ widgetImgA.setAutoExclusive(False)
+ self.setCellWidget(row, self.COLUMS_INDEX['img B'], widgetImgB)
+ callbackImgB = functools.partial(self._activeImgBChanged, url.path())
+ widgetImgB.toggled.connect(callbackImgB)
+
+ self._checkBoxes[url.path()] = {'img A': widgetImgA,
+ 'img B': widgetImgB}
+ self.resizeColumnsToContents()
+ return row
+
+ def _activeImgAChanged(self, name):
+ self._updatecheckBoxes('img A', name)
+ self.sigImageAChanged.emit(name)
+
+ def _activeImgBChanged(self, name):
+ self._updatecheckBoxes('img B', name)
+ self.sigImageBChanged.emit(name)
+
+ def _updatecheckBoxes(self, whichImg, name):
+ assert name in self._checkBoxes
+ assert whichImg in self._checkBoxes[name]
+ if self._checkBoxes[name][whichImg].isChecked():
+ for radioUrl in self._checkBoxes:
+ if radioUrl != name:
+ self._checkBoxes[radioUrl][whichImg].blockSignals(True)
+ self._checkBoxes[radioUrl][whichImg].setChecked(False)
+ self._checkBoxes[radioUrl][whichImg].blockSignals(False)
+
+ def getSelection(self):
+ """
+
+ :return: url selected for img A and img B.
+ """
+ imgA = imgB = None
+ for radioUrl in self._checkBoxes:
+ if self._checkBoxes[radioUrl]['img A'].isChecked():
+ imgA = radioUrl
+ if self._checkBoxes[radioUrl]['img B'].isChecked():
+ imgB = radioUrl
+ return imgA, imgB
+
+ def setSelection(self, url_img_a, url_img_b):
+ """
+
+ :param ddict: key: image url, values: list of active channels
+ """
+ for radioUrl in self._checkBoxes:
+ for img in ('img A', 'img B'):
+ self._checkBoxes[radioUrl][img].blockSignals(True)
+ self._checkBoxes[radioUrl][img].setChecked(False)
+ self._checkBoxes[radioUrl][img].blockSignals(False)
+
+ self._checkBoxes[radioUrl][img].blockSignals(True)
+ self._checkBoxes[url_img_a]['img A'].setChecked(True)
+ self._checkBoxes[radioUrl][img].blockSignals(False)
+
+ self._checkBoxes[radioUrl][img].blockSignals(True)
+ self._checkBoxes[url_img_b]['img B'].setChecked(True)
+ self._checkBoxes[radioUrl][img].blockSignals(False)
+ self.sigImageAChanged.emit(url_img_a)
+ self.sigImageBChanged.emit(url_img_b)
+
+ def removeUrl(self, url):
+ raise NotImplementedError("")
diff --git a/src/silx/gui/widgets/WaitingPushButton.py b/src/silx/gui/widgets/WaitingPushButton.py
new file mode 100644
index 0000000..443dc9a
--- /dev/null
+++ b/src/silx/gui/widgets/WaitingPushButton.py
@@ -0,0 +1,241 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""WaitingPushButton module
+"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+from .. import qt
+from .. import icons
+
+
+class WaitingPushButton(qt.QPushButton):
+ """Button which allows to display a waiting status when, for example,
+ something is still computing.
+
+ The component is graphically disabled when it is in waiting. Then we
+ overwrite the enabled method to dissociate the 2 concepts:
+ graphically enabled/disabled, and enabled/disabled
+
+ .. image:: img/WaitingPushButton.png
+ """
+
+ def __init__(self, parent=None, text=None, icon=None):
+ """Constructor
+
+ :param str text: Text displayed on the button
+ :param qt.QIcon icon: Icon displayed on the button
+ :param qt.QWidget parent: Parent of the widget
+ """
+ if icon is not None:
+ qt.QPushButton.__init__(self, icon, text, parent)
+ elif text is not None:
+ qt.QPushButton.__init__(self, text, parent)
+ else:
+ qt.QPushButton.__init__(self, parent)
+
+ self.__waiting = False
+ self.__enabled = True
+ self.__icon = icon
+ self.__disabled_when_waiting = True
+ self.__waitingIcon = icons.getWaitIcon()
+
+ def sizeHint(self):
+ """Returns the recommended size for the widget.
+
+ This implementation of the recommended size always consider there is an
+ icon. In this way it avoid to update the layout when the waiting icon
+ is displayed.
+ """
+ self.ensurePolished()
+
+ w = 0
+ h = 0
+
+ opt = qt.QStyleOptionButton()
+ self.initStyleOption(opt)
+
+ # Content with icon
+ # no condition, assume that there is an icon to avoid blinking
+ # when the widget switch to waiting state
+ ih = opt.iconSize.height()
+ iw = opt.iconSize.width() + 4
+ w += iw
+ h = max(h, ih)
+
+ # Content with text
+ text = self.text()
+ isEmpty = text == ""
+ if isEmpty:
+ text = "XXXX"
+ fm = self.fontMetrics()
+ textSize = fm.size(qt.Qt.TextShowMnemonic, text)
+ if not isEmpty or w == 0:
+ w += textSize.width()
+ if not isEmpty or h == 0:
+ h = max(h, textSize.height())
+
+ # Content with menu indicator
+ opt.rect.setSize(qt.QSize(w, h)) # PM_MenuButtonIndicator depends on the height
+ if self.menu() is not None:
+ w += self.style().pixelMetric(qt.QStyle.PM_MenuButtonIndicator, opt, self)
+
+ contentSize = qt.QSize(w, h)
+ sizeHint = self.style().sizeFromContents(qt.QStyle.CT_PushButton, opt, contentSize, self)
+ if qt.BINDING in ('PySide2', 'PyQt5'): # Qt6: globalStrut not available
+ sizeHint = sizeHint.expandedTo(qt.QApplication.globalStrut())
+ return sizeHint
+
+ def setDisabledWhenWaiting(self, isDisabled):
+ """Enable or disable the auto disable behaviour when the button is waiting.
+
+ :param bool isDisabled: Enable the auto-disable behaviour
+ """
+ if self.__disabled_when_waiting == isDisabled:
+ return
+ self.__disabled_when_waiting = isDisabled
+ self.__updateVisibleEnabled()
+
+ def isDisabledWhenWaiting(self):
+ """Returns true if the button is auto disabled when it is waiting.
+
+ :rtype: bool
+ """
+ return self.__disabled_when_waiting
+
+ disabledWhenWaiting = qt.Property(bool, isDisabledWhenWaiting, setDisabledWhenWaiting)
+ """Property to enable/disable the auto disabled state when the button is waiting."""
+
+ def __setWaitingIcon(self, icon):
+ """Called when the waiting icon is updated. It is called every frames
+ of the animation.
+
+ :param qt.QIcon icon: The new waiting icon
+ """
+ qt.QPushButton.setIcon(self, icon)
+
+ def setIcon(self, icon):
+ """Set the button icon. If the button is waiting, the icon is not
+ visible directly, but will be visible when the waiting state will be
+ removed.
+
+ :param qt.QIcon icon: An icon
+ """
+ self.__icon = icon
+ self.__updateVisibleIcon()
+
+ def getIcon(self):
+ """Returns the icon set to the button. If the widget is waiting
+ it is not returning the visible icon, but the one requested by
+ the application (the one displayed when the widget is not in
+ waiting state).
+
+ :rtype: qt.QIcon
+ """
+ return self.__icon
+
+ icon = qt.Property(qt.QIcon, getIcon, setIcon)
+ """Property providing access to the icon."""
+
+ def __updateVisibleIcon(self):
+ """Update the visible icon according to the state of the widget."""
+ if not self.isWaiting():
+ icon = self.__icon
+ else:
+ icon = self.__waitingIcon.currentIcon()
+ if icon is None:
+ icon = qt.QIcon()
+ qt.QPushButton.setIcon(self, icon)
+
+ def setEnabled(self, enabled):
+ """Set the enabled state of the widget.
+
+ :param bool enabled: The enabled state
+ """
+ if self.__enabled == enabled:
+ return
+ self.__enabled = enabled
+ self.__updateVisibleEnabled()
+
+ def isEnabled(self):
+ """Returns the enabled state of the widget.
+
+ :rtype: bool
+ """
+ return self.__enabled
+
+ enabled = qt.Property(bool, isEnabled, setEnabled)
+ """Property providing access to the enabled state of the widget"""
+
+ def __updateVisibleEnabled(self):
+ """Update the visible enabled state according to the state of the
+ widget."""
+ if self.__disabled_when_waiting:
+ enabled = not self.isWaiting() and self.__enabled
+ else:
+ enabled = self.__enabled
+ qt.QPushButton.setEnabled(self, enabled)
+
+ def setWaiting(self, waiting):
+ """Set the waiting state of the widget.
+
+ :param bool waiting: Requested state"""
+ if self.__waiting == waiting:
+ return
+ self.__waiting = waiting
+
+ if self.__waiting:
+ self.__waitingIcon.register(self)
+ self.__waitingIcon.iconChanged.connect(self.__setWaitingIcon)
+ else:
+ # unregister only if the object is registred
+ self.__waitingIcon.unregister(self)
+ self.__waitingIcon.iconChanged.disconnect(self.__setWaitingIcon)
+
+ self.__updateVisibleEnabled()
+ self.__updateVisibleIcon()
+
+ def isWaiting(self):
+ """Returns true if the widget is in waiting state.
+
+ :rtype: bool"""
+ return self.__waiting
+
+ @qt.Slot()
+ def wait(self):
+ """Enable the waiting state."""
+ self.setWaiting(True)
+
+ @qt.Slot()
+ def stopWaiting(self):
+ """Disable the waiting state."""
+ self.setWaiting(False)
+
+ @qt.Slot()
+ def swapWaiting(self):
+ """Swap the waiting state."""
+ self.setWaiting(not self.isWaiting())
diff --git a/silx/gui/widgets/__init__.py b/src/silx/gui/widgets/__init__.py
index 9d0299d..9d0299d 100644
--- a/silx/gui/widgets/__init__.py
+++ b/src/silx/gui/widgets/__init__.py
diff --git a/silx/gui/widgets/setup.py b/src/silx/gui/widgets/setup.py
index e96ac8d..e96ac8d 100644
--- a/silx/gui/widgets/setup.py
+++ b/src/silx/gui/widgets/setup.py
diff --git a/src/silx/gui/widgets/test/__init__.py b/src/silx/gui/widgets/test/__init__.py
new file mode 100644
index 0000000..243dbc7
--- /dev/null
+++ b/src/silx/gui/widgets/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py b/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py
new file mode 100644
index 0000000..5df8df9
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_boxlayoutdockwidget.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for BoxLayoutDockWidget"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/03/2018"
+
+import unittest
+
+from silx.gui.widgets.BoxLayoutDockWidget import BoxLayoutDockWidget
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+
+
+class TestBoxLayoutDockWidget(TestCaseQt):
+ """Tests for BoxLayoutDockWidget"""
+
+ def setUp(self):
+ """Create and show a main window"""
+ self.window = qt.QMainWindow()
+ self.qWaitForWindowExposed(self.window)
+
+ def tearDown(self):
+ """Delete main window"""
+ self.window.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.window.close()
+ del self.window
+ self.qapp.processEvents()
+
+ def test(self):
+ """Test update of layout direction according to dock area"""
+ # Create a widget with a QBoxLayout
+ layout = qt.QBoxLayout(qt.QBoxLayout.LeftToRight)
+ layout.addWidget(qt.QLabel('First'))
+ layout.addWidget(qt.QLabel('Second'))
+ widget = qt.QWidget()
+ widget.setLayout(layout)
+
+ # Add it to a BoxLayoutDockWidget
+ dock = BoxLayoutDockWidget()
+ dock.setWidget(widget)
+
+ self.window.addDockWidget(qt.Qt.BottomDockWidgetArea, dock)
+ self.qapp.processEvents()
+ self.assertEqual(layout.direction(), qt.QBoxLayout.LeftToRight)
+
+ self.window.addDockWidget(qt.Qt.LeftDockWidgetArea, dock)
+ self.qapp.processEvents()
+ self.assertEqual(layout.direction(), qt.QBoxLayout.TopToBottom)
diff --git a/src/silx/gui/widgets/test/test_elidedlabel.py b/src/silx/gui/widgets/test/test_elidedlabel.py
new file mode 100644
index 0000000..693e43c
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_elidedlabel.py
@@ -0,0 +1,100 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for ElidedLabel"""
+
+__license__ = "MIT"
+__date__ = "08/06/2020"
+
+import unittest
+
+from silx.gui import qt
+from silx.gui.widgets.ElidedLabel import ElidedLabel
+from silx.gui.utils import testutils
+
+
+class TestElidedLabel(testutils.TestCaseQt):
+
+ def setUp(self):
+ self.label = ElidedLabel()
+ self.label.show()
+ self.qWaitForWindowExposed(self.label)
+
+ def tearDown(self):
+ self.label.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.label.close()
+ del self.label
+ self.qapp.processEvents()
+
+ def testElidedValue(self):
+ """Test elided text"""
+ raw = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
+ self.label.setText(raw)
+ self.label.setFixedWidth(30)
+ displayedText = qt.QLabel.text(self.label)
+ self.assertNotEqual(raw, displayedText)
+ self.assertIn("…", displayedText)
+ self.assertIn("m", displayedText)
+
+ def testNotElidedValue(self):
+ """Test elided text"""
+ raw = "mmmmmmm"
+ self.label.setText(raw)
+ self.label.setFixedWidth(200)
+ displayedText = qt.QLabel.text(self.label)
+ self.assertNotIn("…", displayedText)
+ self.assertEqual(raw, displayedText)
+
+ def testUpdateFromElidedToNotElided(self):
+ """Test tooltip when not elided"""
+ raw1 = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
+ raw2 = "nn"
+ self.label.setText(raw1)
+ self.label.setFixedWidth(30)
+ self.label.setText(raw2)
+ displayedTooltip = qt.QLabel.toolTip(self.label)
+ self.assertNotIn(raw1, displayedTooltip)
+ self.assertNotIn(raw2, displayedTooltip)
+
+ def testUpdateFromNotElidedToElided(self):
+ """Test tooltip when elided"""
+ raw1 = "nn"
+ raw2 = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
+ self.label.setText(raw1)
+ self.label.setFixedWidth(30)
+ self.label.setText(raw2)
+ displayedTooltip = qt.QLabel.toolTip(self.label)
+ self.assertNotIn(raw1, displayedTooltip)
+ self.assertIn(raw2, displayedTooltip)
+
+ def testUpdateFromElidedToElided(self):
+ """Test tooltip when elided"""
+ raw1 = "nnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn"
+ raw2 = "mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm"
+ self.label.setText(raw1)
+ self.label.setFixedWidth(30)
+ self.label.setText(raw2)
+ displayedTooltip = qt.QLabel.toolTip(self.label)
+ self.assertNotIn(raw1, displayedTooltip)
+ self.assertIn(raw2, displayedTooltip)
diff --git a/src/silx/gui/widgets/test/test_flowlayout.py b/src/silx/gui/widgets/test/test_flowlayout.py
new file mode 100644
index 0000000..85d7cfe
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_flowlayout.py
@@ -0,0 +1,66 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for FlowLayout"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "02/08/2018"
+
+import unittest
+
+from silx.gui.widgets.FlowLayout import FlowLayout
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+
+
+class TestFlowLayout(TestCaseQt):
+ """Tests for FlowLayout"""
+
+ def setUp(self):
+ """Create and show a widget"""
+ self.widget = qt.QWidget()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ """Delete widget"""
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ self.qapp.processEvents()
+
+ def test(self):
+ """Basic tests"""
+ layout = FlowLayout()
+ self.widget.setLayout(layout)
+
+ layout.addWidget(qt.QLabel('first'))
+ layout.addWidget(qt.QLabel('second'))
+ self.assertEqual(layout.count(), 2)
+
+ layout.setHorizontalSpacing(10)
+ self.assertEqual(layout.horizontalSpacing(), 10)
+ layout.setVerticalSpacing(5)
+ self.assertEqual(layout.verticalSpacing(), 5)
diff --git a/src/silx/gui/widgets/test/test_framebrowser.py b/src/silx/gui/widgets/test/test_framebrowser.py
new file mode 100644
index 0000000..8233622
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_framebrowser.py
@@ -0,0 +1,62 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "23/03/2018"
+
+
+import unittest
+
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.widgets.FrameBrowser import FrameBrowser
+
+
+class TestFrameBrowser(TestCaseQt):
+ """Test for FrameBrowser"""
+
+ def test(self):
+ """Test FrameBrowser"""
+ widget = FrameBrowser()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+ nFrames = 20
+ widget.setNFrames(nFrames)
+ self.assertEqual(widget.getRange(), (0, nFrames - 1))
+ self.assertEqual(widget.getValue(), 0)
+
+ range_ = -100, 100
+ widget.setRange(*range_)
+ self.assertEqual(widget.getRange(), range_)
+ self.assertEqual(widget.getValue(), range_[0])
+
+ widget.setValue(0)
+ self.assertEqual(widget.getValue(), 0)
+
+ widget.setValue(range_[1] + 100)
+ self.assertEqual(widget.getValue(), range_[1])
+
+ widget.setValue(range_[0] - 100)
+ self.assertEqual(widget.getValue(), range_[0])
diff --git a/src/silx/gui/widgets/test/test_hierarchicaltableview.py b/src/silx/gui/widgets/test/test_hierarchicaltableview.py
new file mode 100644
index 0000000..302086a
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_hierarchicaltableview.py
@@ -0,0 +1,103 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "07/04/2017"
+
+import unittest
+
+from .. import HierarchicalTableView
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+
+class TableModel(HierarchicalTableView.HierarchicalTableModel):
+
+ def __init__(self, parent):
+ HierarchicalTableView.HierarchicalTableModel.__init__(self, parent)
+ self.__content = {}
+
+ def rowCount(self, parent=qt.QModelIndex()):
+ return 3
+
+ def columnCount(self, parent=qt.QModelIndex()):
+ return 3
+
+ def setData1(self):
+ self.beginResetModel()
+
+ content = {}
+ content[0, 0] = ("title", True, (1, 3))
+ content[0, 1] = ("a", True, (2, 1))
+ content[1, 1] = ("b", False, (1, 2))
+ content[1, 2] = ("c", False, (1, 1))
+ content[2, 2] = ("d", False, (1, 1))
+ self.__content = content
+
+ self.endResetModel()
+
+ def data(self, index, role=qt.Qt.DisplayRole):
+ if not index.isValid():
+ return None
+ cell = self.__content.get((index.column(), index.row()), None)
+ if cell is None:
+ return None
+
+ if role == self.SpanRole:
+ return cell[2]
+ elif role == self.IsHeaderRole:
+ return cell[1]
+ elif role == qt.Qt.DisplayRole:
+ return cell[0]
+ return None
+
+
+class TestHierarchicalTableView(TestCaseQt):
+ """Test for HierarchicalTableView"""
+
+ def testEmpty(self):
+ widget = HierarchicalTableView.HierarchicalTableView()
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+ def testModel(self):
+ widget = HierarchicalTableView.HierarchicalTableView()
+ model = TableModel(widget)
+ # set the data before using the model into the widget
+ model.setData1()
+ widget.setModel(model)
+ span = widget.rowSpan(0, 0), widget.columnSpan(0, 0)
+ self.assertEqual(span, (1, 3))
+ widget.show()
+ self.qWaitForWindowExposed(widget)
+
+ def testModelUpdate(self):
+ widget = HierarchicalTableView.HierarchicalTableView()
+ model = TableModel(widget)
+ widget.setModel(model)
+ # set the data after using the model into the widget
+ model.setData1()
+ span = widget.rowSpan(0, 0), widget.columnSpan(0, 0)
+ self.assertEqual(span, (1, 3))
diff --git a/src/silx/gui/widgets/test/test_legendiconwidget.py b/src/silx/gui/widgets/test/test_legendiconwidget.py
new file mode 100644
index 0000000..fe320f6
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_legendiconwidget.py
@@ -0,0 +1,63 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for LegendIconWidget"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/10/2020"
+
+import unittest
+
+from silx.gui import qt
+from silx.gui.widgets.LegendIconWidget import LegendIconWidget
+from silx.gui.utils.testutils import TestCaseQt
+from silx.utils.testutils import ParametricTestCase
+
+
+class TestLegendIconWidget(TestCaseQt, ParametricTestCase):
+ """Tests for TestRangeSlider"""
+
+ def setUp(self):
+ self.widget = LegendIconWidget()
+ self.widget.show()
+ self.qWaitForWindowExposed(self.widget)
+
+ def tearDown(self):
+ self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.widget.close()
+ del self.widget
+ self.qapp.processEvents()
+
+ def testCreate(self):
+ self.qapp.processEvents()
+
+ def testColormap(self):
+ self.widget.setColormap("viridis")
+ self.qapp.processEvents()
+
+ def testSymbol(self):
+ self.widget.setSymbol("o")
+ self.widget.setSymbolColormap("viridis")
+ self.qapp.processEvents()
diff --git a/src/silx/gui/widgets/test/test_periodictable.py b/src/silx/gui/widgets/test/test_periodictable.py
new file mode 100644
index 0000000..de9e1af
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_periodictable.py
@@ -0,0 +1,148 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+import unittest
+
+from .. import PeriodicTable
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui import qt
+
+
+class TestPeriodicTable(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+
+ def testShow(self):
+ """basic test (instantiation done in setUp)"""
+ pt = PeriodicTable.PeriodicTable()
+ pt.show()
+ self.qWaitForWindowExposed(pt)
+
+ def testSelectable(self):
+ """basic test (instantiation done in setUp)"""
+ pt = PeriodicTable.PeriodicTable(selectable=True)
+ self.assertTrue(pt.selectable)
+
+ def testCustomElements(self):
+ PTI = PeriodicTable.ColoredPeriodicTableItem
+ my_items = [
+ PTI("Xx", 42, 43, 44, "xaxatorium", 1002.2,
+ bgcolor="#FF0000"),
+ PTI("Yy", 25, 22, 44, "yoyotrium", 8.8)
+ ]
+
+ pt = PeriodicTable.PeriodicTable(elements=my_items)
+
+ pt.setSelection(["He", "Xx"])
+ selection = pt.getSelection()
+ self.assertEqual(len(selection), 1) # "He" not found
+ self.assertEqual(selection[0].symbol, "Xx")
+ self.assertEqual(selection[0].Z, 42)
+ self.assertEqual(selection[0].col, 43)
+ self.assertAlmostEqual(selection[0].mass, 1002.2)
+ self.assertEqual(qt.QColor(selection[0].bgcolor),
+ qt.QColor(qt.Qt.red))
+
+ self.assertTrue(pt.isElementSelected("Xx"))
+ self.assertFalse(pt.isElementSelected("Yy"))
+ self.assertRaises(KeyError, pt.isElementSelected, "Yx")
+
+ def testVeryCustomElements(self):
+ class MyPTI(PeriodicTable.PeriodicTableItem):
+ def __init__(self, *args):
+ PeriodicTable.PeriodicTableItem.__init__(self, *args[:6])
+ self.my_feature = args[6]
+
+ my_items = [
+ MyPTI("Xx", 42, 43, 44, "xaxatorium", 1002.2, "spam"),
+ MyPTI("Yy", 25, 22, 44, "yoyotrium", 8.8, "eggs")
+ ]
+
+ pt = PeriodicTable.PeriodicTable(elements=my_items)
+
+ pt.setSelection(["Xx", "Yy"])
+ selection = pt.getSelection()
+ self.assertEqual(len(selection), 2)
+ self.assertEqual(selection[1].symbol, "Yy")
+ self.assertEqual(selection[1].Z, 25)
+ self.assertEqual(selection[1].col, 22)
+ self.assertEqual(selection[1].row, 44)
+ self.assertAlmostEqual(selection[0].mass, 1002.2)
+ self.assertAlmostEqual(selection[0].my_feature, "spam")
+
+
+class TestPeriodicCombo(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+ def setUp(self):
+ super(TestPeriodicCombo, self).setUp()
+ self.pc = PeriodicTable.PeriodicCombo()
+
+ def tearDown(self):
+ del self.pc
+ super(TestPeriodicCombo, self).tearDown()
+
+ def testShow(self):
+ """basic test (instantiation done in setUp)"""
+ self.pc.show()
+ self.qWaitForWindowExposed(self.pc)
+
+ def testSelect(self):
+ self.pc.setSelection("Sb")
+ selection = self.pc.getSelection()
+ self.assertIsInstance(selection,
+ PeriodicTable.PeriodicTableItem)
+ self.assertEqual(selection.symbol, "Sb")
+ self.assertEqual(selection.Z, 51)
+ self.assertEqual(selection.name, "antimony")
+
+
+class TestPeriodicList(TestCaseQt):
+ """Basic test for ArrayTableWidget with a numpy array"""
+ def setUp(self):
+ super(TestPeriodicList, self).setUp()
+ self.pl = PeriodicTable.PeriodicList()
+
+ def tearDown(self):
+ del self.pl
+ super(TestPeriodicList, self).tearDown()
+
+ def testShow(self):
+ """basic test (instantiation done in setUp)"""
+ self.pl.show()
+ self.qWaitForWindowExposed(self.pl)
+
+ def testSelect(self):
+ self.pl.setSelectedElements(["Li", "He", "Au"])
+ sel_elmts = self.pl.getSelection()
+
+ self.assertEqual(len(sel_elmts), 3,
+ "Wrong number of elements selected")
+ for e in sel_elmts:
+ self.assertIsInstance(e, PeriodicTable.PeriodicTableItem)
+ self.assertIn(e.symbol, ["Li", "He", "Au"])
+ self.assertIn(e.Z, [2, 3, 79])
+ self.assertIn(e.name, ["lithium", "helium", "gold"])
diff --git a/src/silx/gui/widgets/test/test_printpreview.py b/src/silx/gui/widgets/test/test_printpreview.py
new file mode 100644
index 0000000..8602666
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_printpreview.py
@@ -0,0 +1,63 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test PrintPreview"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "19/07/2017"
+
+
+import unittest
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.widgets.PrintPreview import PrintPreviewDialog
+from silx.gui import qt
+
+from silx.resources import resource_filename
+
+
+class TestPrintPreview(TestCaseQt):
+ def testShow(self):
+ p = qt.QPrinter()
+ d = PrintPreviewDialog(printer=p)
+ d.show()
+ self.qapp.processEvents()
+
+ def testAddImage(self):
+ p = qt.QPrinter()
+ d = PrintPreviewDialog(printer=p)
+ d.addImage(qt.QImage(resource_filename("gui/icons/clipboard.png")))
+ self.qapp.processEvents()
+
+ def testAddSvg(self):
+ p = qt.QPrinter()
+ d = PrintPreviewDialog(printer=p)
+ d.addSvgItem(qt.QSvgRenderer(resource_filename("gui/icons/clipboard.svg"), d.page))
+ self.qapp.processEvents()
+
+ def testAddPixmap(self):
+ p = qt.QPrinter()
+ d = PrintPreviewDialog(printer=p)
+ d.addPixmap(qt.QPixmap.fromImage(qt.QImage(resource_filename("gui/icons/clipboard.png"))))
+ self.qapp.processEvents()
diff --git a/src/silx/gui/widgets/test/test_rangeslider.py b/src/silx/gui/widgets/test/test_rangeslider.py
new file mode 100644
index 0000000..f829857
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_rangeslider.py
@@ -0,0 +1,103 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for RangeSlider"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "01/08/2018"
+
+import unittest
+
+from silx.gui import qt, colors
+from silx.gui.widgets.RangeSlider import RangeSlider
+from silx.gui.utils.testutils import TestCaseQt
+from silx.utils.testutils import ParametricTestCase
+
+
+class TestRangeSlider(TestCaseQt, ParametricTestCase):
+ """Tests for TestRangeSlider"""
+
+ def setUp(self):
+ self.slider = RangeSlider()
+ self.slider.show()
+ self.qWaitForWindowExposed(self.slider)
+
+ def tearDown(self):
+ self.slider.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.slider.close()
+ del self.slider
+ self.qapp.processEvents()
+
+ def testRangeValue(self):
+ """Test slider range and values"""
+
+ # Play with range
+ self.slider.setRange(1, 2)
+ self.assertEqual(self.slider.getRange(), (1., 2.))
+ self.assertEqual(self.slider.getValues(), (1., 1.))
+
+ self.slider.setMinimum(-1)
+ self.assertEqual(self.slider.getRange(), (-1., 2.))
+ self.assertEqual(self.slider.getValues(), (1., 1.))
+
+ self.slider.setMaximum(0)
+ self.assertEqual(self.slider.getRange(), (-1., 0.))
+ self.assertEqual(self.slider.getValues(), (0., 0.))
+
+ # Play with values
+ self.slider.setFirstValue(-2.)
+ self.assertEqual(self.slider.getValues(), (-1., 0.))
+
+ self.slider.setFirstValue(-0.5)
+ self.assertEqual(self.slider.getValues(), (-0.5, 0.))
+
+ self.slider.setSecondValue(2.)
+ self.assertEqual(self.slider.getValues(), (-0.5, 0.))
+
+ self.slider.setSecondValue(-0.1)
+ self.assertEqual(self.slider.getValues(), (-0.5, -0.1))
+
+ def testStepCount(self):
+ """Test related to step count"""
+ self.slider.setPositionCount(11)
+ self.assertEqual(self.slider.getPositionCount(), 11)
+ self.slider.setFirstValue(0.32)
+ self.assertEqual(self.slider.getFirstValue(), 0.3)
+ self.assertEqual(self.slider.getFirstPosition(), 3)
+
+ self.slider.setPositionCount(3) # Value is adjusted
+ self.assertEqual(self.slider.getValues(), (0.5, 1.))
+ self.assertEqual(self.slider.getPositions(), (1, 2))
+
+ def testGroove(self):
+ """Test Groove pixmap"""
+ profile = list(range(100))
+
+ for cmap in ('jet', colors.Colormap('viridis')):
+ with self.subTest(str(cmap)):
+ self.slider.setGroovePixmapFromProfile(profile, cmap)
+ pixmap = self.slider.getGroovePixmap()
+ self.assertIsInstance(pixmap, qt.QPixmap)
+ self.assertEqual(pixmap.width(), len(profile))
diff --git a/src/silx/gui/widgets/test/test_tablewidget.py b/src/silx/gui/widgets/test/test_tablewidget.py
new file mode 100644
index 0000000..09122ca
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_tablewidget.py
@@ -0,0 +1,50 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test TableWidget"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "05/12/2016"
+
+
+import unittest
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.widgets.TableWidget import TableWidget
+
+
+class TestTableWidget(TestCaseQt):
+ def setUp(self):
+ super(TestTableWidget, self).setUp()
+ self._result = []
+
+ def testShow(self):
+ table = TableWidget()
+ table.setColumnCount(10)
+ table.setRowCount(7)
+ table.enableCut()
+ table.enablePaste()
+ table.show()
+ table.hide()
+ self.qapp.processEvents()
diff --git a/src/silx/gui/widgets/test/test_threadpoolpushbutton.py b/src/silx/gui/widgets/test/test_threadpoolpushbutton.py
new file mode 100644
index 0000000..3808be0
--- /dev/null
+++ b/src/silx/gui/widgets/test/test_threadpoolpushbutton.py
@@ -0,0 +1,124 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+import time
+from silx.gui import qt
+from silx.gui.utils.testutils import TestCaseQt
+from silx.gui.utils.testutils import SignalListener
+from silx.gui.widgets.ThreadPoolPushButton import ThreadPoolPushButton
+from silx.utils.testutils import LoggingValidator
+
+
+class TestThreadPoolPushButton(TestCaseQt):
+
+ def setUp(self):
+ super(TestThreadPoolPushButton, self).setUp()
+ self._result = []
+
+ def waitForPendingOperations(self, object):
+ for i in range(50):
+ if not object.hasPendingOperations():
+ break
+ self.qWait(10)
+ else:
+ raise RuntimeError("Still waiting for a pending operation")
+
+ def _trace(self, name, delay=0):
+ self._result.append(name)
+ if delay != 0:
+ time.sleep(delay / 1000.0)
+
+ def _compute(self):
+ return "result"
+
+ def _computeFail(self):
+ raise Exception("exception")
+
+ def testExecute(self):
+ button = ThreadPoolPushButton()
+ button.setCallable(self._trace, "a", 0)
+ button.executeCallable()
+ time.sleep(0.1)
+ self.assertListEqual(self._result, ["a"])
+ self.waitForPendingOperations(button)
+
+ def testMultiExecution(self):
+ button = ThreadPoolPushButton()
+ button.setCallable(self._trace, "a", 0)
+ number = qt.silxGlobalThreadPool().maxThreadCount()
+ for _ in range(number):
+ button.executeCallable()
+ self.waitForPendingOperations(button)
+ self.assertListEqual(self._result, ["a"] * number)
+
+ def testSaturateThreadPool(self):
+ button = ThreadPoolPushButton()
+ button.setCallable(self._trace, "a", 100)
+ number = qt.silxGlobalThreadPool().maxThreadCount() * 2
+ for _ in range(number):
+ button.executeCallable()
+ self.waitForPendingOperations(button)
+ self.assertListEqual(self._result, ["a"] * number)
+
+ def testSuccess(self):
+ listener = SignalListener()
+ button = ThreadPoolPushButton()
+ button.setCallable(self._compute)
+ button.beforeExecuting.connect(listener.partial(test="be"))
+ button.started.connect(listener.partial(test="s"))
+ button.succeeded.connect(listener.partial(test="result"))
+ button.failed.connect(listener.partial(test="Unexpected exception"))
+ button.finished.connect(listener.partial(test="f"))
+ button.executeCallable()
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ result = listener.karguments(argumentName="test")
+ self.assertListEqual(result, ["be", "s", "result", "f"])
+
+ def testFail(self):
+ listener = SignalListener()
+ button = ThreadPoolPushButton()
+ button.setCallable(self._computeFail)
+ button.beforeExecuting.connect(listener.partial(test="be"))
+ button.started.connect(listener.partial(test="s"))
+ button.succeeded.connect(listener.partial(test="Unexpected success"))
+ button.failed.connect(listener.partial(test="exception"))
+ button.finished.connect(listener.partial(test="f"))
+ with LoggingValidator('silx.gui.widgets.ThreadPoolPushButton', error=1):
+ button.executeCallable()
+ self.qapp.processEvents()
+ time.sleep(0.1)
+ self.qapp.processEvents()
+ result = listener.karguments(argumentName="test")
+ self.assertListEqual(result, ["be", "s", "exception", "f"])
+ listener.clear()
diff --git a/silx/image/__init__.py b/src/silx/image/__init__.py
index 12bf320..12bf320 100644
--- a/silx/image/__init__.py
+++ b/src/silx/image/__init__.py
diff --git a/silx/image/_boundingbox.py b/src/silx/image/_boundingbox.py
index 1c086b1..1c086b1 100644
--- a/silx/image/_boundingbox.py
+++ b/src/silx/image/_boundingbox.py
diff --git a/silx/image/backprojection.py b/src/silx/image/backprojection.py
index 63f99ca..63f99ca 100644
--- a/silx/image/backprojection.py
+++ b/src/silx/image/backprojection.py
diff --git a/silx/image/bilinear.pyx b/src/silx/image/bilinear.pyx
index 14547f8..14547f8 100644
--- a/silx/image/bilinear.pyx
+++ b/src/silx/image/bilinear.pyx
diff --git a/silx/image/marchingsquares/__init__.py b/src/silx/image/marchingsquares/__init__.py
index a47a7f6..a47a7f6 100644
--- a/silx/image/marchingsquares/__init__.py
+++ b/src/silx/image/marchingsquares/__init__.py
diff --git a/silx/image/marchingsquares/_mergeimpl.pyx b/src/silx/image/marchingsquares/_mergeimpl.pyx
index 5a7a3b5..5a7a3b5 100644
--- a/silx/image/marchingsquares/_mergeimpl.pyx
+++ b/src/silx/image/marchingsquares/_mergeimpl.pyx
diff --git a/silx/image/marchingsquares/_skimage.py b/src/silx/image/marchingsquares/_skimage.py
index d49eeb0..d49eeb0 100644
--- a/silx/image/marchingsquares/_skimage.py
+++ b/src/silx/image/marchingsquares/_skimage.py
diff --git a/silx/image/marchingsquares/include/patterns.h b/src/silx/image/marchingsquares/include/patterns.h
index ff86cc1..ff86cc1 100644
--- a/silx/image/marchingsquares/include/patterns.h
+++ b/src/silx/image/marchingsquares/include/patterns.h
diff --git a/src/silx/image/marchingsquares/setup.py b/src/silx/image/marchingsquares/setup.py
new file mode 100644
index 0000000..95998ab
--- /dev/null
+++ b/src/silx/image/marchingsquares/setup.py
@@ -0,0 +1,51 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "23/04/2018"
+
+import os
+import numpy
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('marchingsquares', parent_package, top_path)
+ config.add_subpackage('test')
+
+ silx_include = os.path.join(top_path, "src", "silx", "utils", "include")
+ config.add_extension('_mergeimpl',
+ sources=['_mergeimpl.pyx'],
+ include_dirs=[numpy.get_include(), silx_include],
+ language='c++',
+ extra_link_args=['-fopenmp'],
+ extra_compile_args=['-fopenmp'])
+
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+ setup(configuration=configuration)
diff --git a/src/silx/image/marchingsquares/test/__init__.py b/src/silx/image/marchingsquares/test/__init__.py
new file mode 100644
index 0000000..776bb73
--- /dev/null
+++ b/src/silx/image/marchingsquares/test/__init__.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+#
+# Project: silx
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2012-2016 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
diff --git a/src/silx/image/marchingsquares/test/test_funcapi.py b/src/silx/image/marchingsquares/test/test_funcapi.py
new file mode 100644
index 0000000..d1be584
--- /dev/null
+++ b/src/silx/image/marchingsquares/test/test_funcapi.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+#
+# Project: silx
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2012-2016 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/04/2018"
+
+import unittest
+import numpy
+import silx.image.marchingsquares
+
+
+class MockMarchingSquares(object):
+
+ last = None
+
+ def __init__(self, image, mask=None):
+ MockMarchingSquares.last = self
+ self.events = []
+ self.events.append(("image", image))
+ self.events.append(("mask", mask))
+
+ def find_pixels(self, level):
+ self.events.append(("find_pixels", level))
+ return None
+
+ def find_contours(self, level):
+ self.events.append(("find_contours", level))
+ return None
+
+
+class TestFunctionalApi(unittest.TestCase):
+ """Test that the default functional API is called using the right
+ parameters to the right location."""
+
+ def setUp(self):
+ self.old_impl = silx.image.marchingsquares.MarchingSquaresMergeImpl
+ silx.image.marchingsquares.MarchingSquaresMergeImpl = MockMarchingSquares
+
+ def tearDown(self):
+ silx.image.marchingsquares.MarchingSquaresMergeImpl = self.old_impl
+ del self.old_impl
+
+ def test_default_find_contours(self):
+ image = numpy.ones((2, 2), dtype=numpy.float32)
+ mask = numpy.zeros((2, 2), dtype=numpy.int32)
+ level = 2.5
+ silx.image.marchingsquares.find_contours(image=image, level=level, mask=mask)
+ events = MockMarchingSquares.last.events
+ self.assertEqual(len(events), 3)
+ self.assertEqual(events[0][0], "image")
+ self.assertEqual(events[0][1][0, 0], 1)
+ self.assertEqual(events[1][0], "mask")
+ self.assertEqual(events[1][1][0, 0], 0)
+ self.assertEqual(events[2][0], "find_contours")
+ self.assertEqual(events[2][1], level)
+
+ def test_default_find_pixels(self):
+ image = numpy.ones((2, 2), dtype=numpy.float32)
+ mask = numpy.zeros((2, 2), dtype=numpy.int32)
+ level = 3.5
+ silx.image.marchingsquares.find_pixels(image=image, level=level, mask=mask)
+ events = MockMarchingSquares.last.events
+ self.assertEqual(len(events), 3)
+ self.assertEqual(events[0][0], "image")
+ self.assertEqual(events[0][1][0, 0], 1)
+ self.assertEqual(events[1][0], "mask")
+ self.assertEqual(events[1][1][0, 0], 0)
+ self.assertEqual(events[2][0], "find_pixels")
+ self.assertEqual(events[2][1], level)
diff --git a/src/silx/image/marchingsquares/test/test_mergeimpl.py b/src/silx/image/marchingsquares/test/test_mergeimpl.py
new file mode 100644
index 0000000..07b94b5
--- /dev/null
+++ b/src/silx/image/marchingsquares/test/test_mergeimpl.py
@@ -0,0 +1,264 @@
+# -*- coding: utf-8 -*-
+#
+# Project: silx
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2012-2016 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "18/04/2018"
+
+import unittest
+import numpy
+from .._mergeimpl import MarchingSquaresMergeImpl
+
+
+class TestMergeImplApi(unittest.TestCase):
+
+ def test_image_not_an_array(self):
+ bad_image = 1
+ self.assertRaises(ValueError, MarchingSquaresMergeImpl, bad_image)
+
+ def test_image_bad_dim(self):
+ bad_image = numpy.array([[[1.0]]])
+ self.assertRaises(ValueError, MarchingSquaresMergeImpl, bad_image)
+
+ def test_image_not_big_enough(self):
+ bad_image = numpy.array([[1.0, 1.0, 1.0, 1.0]])
+ self.assertRaises(ValueError, MarchingSquaresMergeImpl, bad_image)
+
+ def test_mask_not_an_array(self):
+ image = numpy.array([[1.0, 1.0], [1.0, 1.0]])
+ bad_mask = 1
+ self.assertRaises(ValueError, MarchingSquaresMergeImpl, image, bad_mask)
+
+ def test_mask_not_match(self):
+ image = numpy.array([[1.0, 1.0], [1.0, 1.0]])
+ bad_mask = numpy.array([[1.0, 1.0]])
+ self.assertRaises(ValueError, MarchingSquaresMergeImpl, image, bad_mask)
+
+ def test_ok_anyway_bad_type(self):
+ image = numpy.array([[1.0, 1.0], [1.0, 1.0]], dtype=numpy.int32)
+ mask = numpy.array([[1.0, 1.0], [1.0, 1.0]], dtype=numpy.float32)
+ MarchingSquaresMergeImpl(image, mask)
+
+ def test_find_contours_result(self):
+ image = numpy.zeros((2, 2))
+ image[0, 0] = 1
+ ms = MarchingSquaresMergeImpl(image)
+ polygons = ms.find_contours(0.5)
+ self.assertIsInstance(polygons, list)
+ self.assertTrue(len(polygons), 1)
+ self.assertIsInstance(polygons[0], numpy.ndarray)
+ self.assertEqual(polygons[0].shape[1], 2)
+ self.assertEqual(polygons[0].dtype.kind, "f")
+
+ def test_find_pixels_result(self):
+ image = numpy.zeros((2, 2))
+ image[0, 0] = 1
+ ms = MarchingSquaresMergeImpl(image)
+ pixels = ms.find_pixels(0.5)
+ self.assertIsInstance(pixels, numpy.ndarray)
+ self.assertEqual(pixels.shape[1], 2)
+ self.assertEqual(pixels.dtype.kind, "i")
+
+ def test_find_contours_empty_result(self):
+ image = numpy.zeros((2, 2))
+ ms = MarchingSquaresMergeImpl(image)
+ polygons = ms.find_contours(0.5)
+ self.assertIsInstance(polygons, list)
+ self.assertEqual(len(polygons), 0)
+
+ def test_find_pixels_empty_result(self):
+ image = numpy.zeros((2, 2))
+ ms = MarchingSquaresMergeImpl(image)
+ pixels = ms.find_pixels(0.5)
+ self.assertIsInstance(pixels, numpy.ndarray)
+ self.assertEqual(pixels.shape[1], 2)
+ self.assertEqual(pixels.shape[0], 0)
+ self.assertEqual(pixels.dtype.kind, "i")
+
+ def test_find_contours_yx_result(self):
+ image = numpy.zeros((2, 2))
+ image[1, 0] = 1
+ ms = MarchingSquaresMergeImpl(image)
+ polygons = ms.find_contours(0.5)
+ polygon = polygons[0]
+ self.assertTrue((polygon == (0.5, 0)).any())
+ self.assertTrue((polygon == (1, 0.5)).any())
+
+ def test_find_pixels_yx_result(self):
+ image = numpy.zeros((2, 2))
+ image[1, 0] = 1
+ ms = MarchingSquaresMergeImpl(image)
+ pixels = ms.find_pixels(0.5)
+ self.assertTrue((pixels == (1, 0)).any())
+
+
+class TestMergeImplContours(unittest.TestCase):
+
+ def test_merge_segments(self):
+ image = numpy.zeros((4, 4))
+ image[(2, 3), :] = 1
+ ms = MarchingSquaresMergeImpl(image)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 1)
+
+ def test_merge_segments_2(self):
+ image = numpy.zeros((4, 4))
+ image[(2, 3), :] = 1
+ image[2, 2] = 0
+ ms = MarchingSquaresMergeImpl(image)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 1)
+
+ def test_merge_tiles(self):
+ image = numpy.zeros((4, 4))
+ image[(2, 3), :] = 1
+ ms = MarchingSquaresMergeImpl(image, group_size=2)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 1)
+
+ def test_fully_masked(self):
+ image = numpy.zeros((5, 5))
+ image[(2, 3), :] = 1
+ mask = numpy.ones((5, 5))
+ ms = MarchingSquaresMergeImpl(image, mask)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 0)
+
+ def test_fully_masked_minmax(self):
+ """This invalidates all the tiles. The route is not the same."""
+ image = numpy.zeros((5, 5))
+ image[(2, 3), :] = 1
+ mask = numpy.ones((5, 5))
+ ms = MarchingSquaresMergeImpl(image, mask, group_size=2, use_minmax_cache=True)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 0)
+
+ def test_masked_segments(self):
+ image = numpy.zeros((5, 5))
+ image[(2, 3, 4), :] = 1
+ mask = numpy.zeros((5, 5))
+ mask[:, 2] = 1
+ ms = MarchingSquaresMergeImpl(image, mask)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 2)
+
+ def test_closed_polygon(self):
+ image = numpy.zeros((5, 5))
+ image[2, 2] = 1
+ image[1, 2] = 1
+ image[3, 2] = 1
+ image[2, 1] = 1
+ image[2, 3] = 1
+ mask = None
+ ms = MarchingSquaresMergeImpl(image, mask)
+ polygons = ms.find_contours(0.9)
+ self.assertEqual(len(polygons), 1)
+ self.assertEqual(list(polygons[0][0]), list(polygons[0][-1]))
+
+ def test_closed_polygon_between_tiles(self):
+ image = numpy.zeros((5, 5))
+ image[2, 2] = 1
+ image[1, 2] = 1
+ image[3, 2] = 1
+ image[2, 1] = 1
+ image[2, 3] = 1
+ mask = None
+ ms = MarchingSquaresMergeImpl(image, mask, group_size=2)
+ polygons = ms.find_contours(0.9)
+ self.assertEqual(len(polygons), 1)
+ self.assertEqual(list(polygons[0][0]), list(polygons[0][-1]))
+
+ def test_open_polygon(self):
+ image = numpy.zeros((5, 5))
+ image[2, 2] = 1
+ image[1, 2] = 1
+ image[3, 2] = 1
+ image[2, 1] = 1
+ image[2, 3] = 1
+ mask = numpy.zeros((5, 5))
+ mask[1, 1] = 1
+ ms = MarchingSquaresMergeImpl(image, mask)
+ polygons = ms.find_contours(0.9)
+ self.assertEqual(len(polygons), 1)
+ self.assertNotEqual(list(polygons[0][0]), list(polygons[0][-1]))
+
+ def test_ambiguous_pattern(self):
+ image = numpy.zeros((6, 8))
+ image[(3, 4), :] = 1
+ image[:, (0, -1)] = 0
+ image[3, 3] = -0.001
+ image[4, 4] = 0.0
+ mask = None
+ ms = MarchingSquaresMergeImpl(image, mask)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 2)
+
+ def test_ambiguous_pattern_2(self):
+ image = numpy.zeros((6, 8))
+ image[(3, 4), :] = 1
+ image[:, (0, -1)] = 0
+ image[3, 3] = +0.001
+ image[4, 4] = 0.0
+ mask = None
+ ms = MarchingSquaresMergeImpl(image, mask)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 1)
+
+ def count_closed_polygons(self, polygons):
+ closed = 0
+ for polygon in polygons:
+ if list(polygon[0]) == list(polygon[-1]):
+ closed += 1
+ return closed
+
+ def test_image(self):
+ # example from skimage
+ x, y = numpy.ogrid[-numpy.pi:numpy.pi:100j, -numpy.pi:numpy.pi:100j]
+ image = numpy.sin(numpy.exp((numpy.sin(x)**3 + numpy.cos(y)**2)))
+ mask = None
+ ms = MarchingSquaresMergeImpl(image, mask)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 11)
+ self.assertEqual(self.count_closed_polygons(polygons), 3)
+
+ def test_image_tiled(self):
+ # example from skimage
+ x, y = numpy.ogrid[-numpy.pi:numpy.pi:100j, -numpy.pi:numpy.pi:100j]
+ image = numpy.sin(numpy.exp((numpy.sin(x)**3 + numpy.cos(y)**2)))
+ mask = None
+ ms = MarchingSquaresMergeImpl(image, mask, group_size=50)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 11)
+ self.assertEqual(self.count_closed_polygons(polygons), 3)
+
+ def test_image_tiled_minmax(self):
+ # example from skimage
+ x, y = numpy.ogrid[-numpy.pi:numpy.pi:100j, -numpy.pi:numpy.pi:100j]
+ image = numpy.sin(numpy.exp((numpy.sin(x)**3 + numpy.cos(y)**2)))
+ mask = None
+ ms = MarchingSquaresMergeImpl(image, mask, group_size=50, use_minmax_cache=True)
+ polygons = ms.find_contours(0.5)
+ self.assertEqual(len(polygons), 11)
+ self.assertEqual(self.count_closed_polygons(polygons), 3)
diff --git a/silx/image/medianfilter.py b/src/silx/image/medianfilter.py
index 857f73d..857f73d 100644
--- a/silx/image/medianfilter.py
+++ b/src/silx/image/medianfilter.py
diff --git a/silx/image/phantomgenerator.py b/src/silx/image/phantomgenerator.py
index 10b249b..10b249b 100644
--- a/silx/image/phantomgenerator.py
+++ b/src/silx/image/phantomgenerator.py
diff --git a/silx/image/projection.py b/src/silx/image/projection.py
index 5c76c35..5c76c35 100644
--- a/silx/image/projection.py
+++ b/src/silx/image/projection.py
diff --git a/silx/image/reconstruction.py b/src/silx/image/reconstruction.py
index 875b66b..875b66b 100644
--- a/silx/image/reconstruction.py
+++ b/src/silx/image/reconstruction.py
diff --git a/silx/image/setup.py b/src/silx/image/setup.py
index 69d5b1b..69d5b1b 100644
--- a/silx/image/setup.py
+++ b/src/silx/image/setup.py
diff --git a/silx/image/shapes.pyx b/src/silx/image/shapes.pyx
index 9284811..9284811 100644
--- a/silx/image/shapes.pyx
+++ b/src/silx/image/shapes.pyx
diff --git a/silx/image/sift.py b/src/silx/image/sift.py
index cb1e6bd..cb1e6bd 100644
--- a/silx/image/sift.py
+++ b/src/silx/image/sift.py
diff --git a/src/silx/image/test/__init__.py b/src/silx/image/test/__init__.py
new file mode 100644
index 0000000..40b11a1
--- /dev/null
+++ b/src/silx/image/test/__init__.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+#
+# Project: silx
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2012-2018 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
diff --git a/src/silx/image/test/test_bb.py b/src/silx/image/test/test_bb.py
new file mode 100644
index 0000000..7427273
--- /dev/null
+++ b/src/silx/image/test/test_bb.py
@@ -0,0 +1,74 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic tests for Bounding box"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "27/09/2019"
+
+
+import unittest
+import numpy
+from silx.image._boundingbox import _BoundingBox
+
+
+class TestBB(unittest.TestCase):
+ """Some simple test on the bounding box class"""
+ def test_creation(self):
+ """test some constructors"""
+ pts = numpy.array([(0, 0), (10, 20), (20, 0)])
+ bb = _BoundingBox.from_points(pts)
+ self.assertTrue(bb.bottom_left == (0, 0))
+ self.assertTrue(bb.top_right == (20, 20))
+ pts = numpy.array([(0, 10), (10, 20), (45, 30), (35, 0)])
+ bb = _BoundingBox.from_points(pts)
+ self.assertTrue(bb.bottom_left == (0, 0))
+ print(bb.top_right)
+ self.assertTrue(bb.top_right == (45, 30))
+
+ def test_isIn_pt(self):
+ """test the isIn function with points"""
+ bb = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
+ self.assertTrue(bb.contains((10, 4)))
+ self.assertTrue(bb.contains((6, 2)))
+ self.assertTrue(bb.contains((12, 2)))
+ self.assertFalse(bb.contains((0, 0)))
+ self.assertFalse(bb.contains((20, 0)))
+ self.assertFalse(bb.contains((10, 0)))
+
+ def test_collide(self):
+ """test the collide function"""
+ bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
+ self.assertTrue(bb1.collide(_BoundingBox(bottom_left=(6, 2), top_right=(12, 6))))
+ bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
+ self.assertFalse(bb1.collide(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2))))
+
+ def test_isIn_bb(self):
+ """test the isIn function with other bounding box"""
+ bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
+ self.assertTrue(bb1.contains(_BoundingBox(bottom_left=(6, 2), top_right=(12, 6))))
+ bb1 = _BoundingBox(bottom_left=(6, 2), top_right=(12, 6))
+ self.assertTrue(bb1.contains(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2))))
+ self.assertFalse(_BoundingBox(bottom_left=(12, 2), top_right=(12, 2)).contains(bb1))
diff --git a/src/silx/image/test/test_bilinear.py b/src/silx/image/test/test_bilinear.py
new file mode 100644
index 0000000..20ceb58
--- /dev/null
+++ b/src/silx/image/test/test_bilinear.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+#
+# Project: silx (originally pyFAI)
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2012-2017 European Synchrotron Radiation Facility, Grenoble, France
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+__authors__ = ["J. Kieffer"]
+__license__ = "MIT"
+__date__ = "25/11/2020"
+
+import unittest
+import numpy
+import logging
+logger = logging.getLogger(__name__)
+from ..bilinear import BilinearImage
+
+
+class TestBilinear(unittest.TestCase):
+ """basic maximum search test"""
+ N = 1000
+
+ def test_max_search_round(self):
+ """test maximum search using random points: maximum is at the pixel center"""
+ a = numpy.arange(100) - 40.
+ b = numpy.arange(100) - 60.
+ ga = numpy.exp(-a * a / 4000)
+ gb = numpy.exp(-b * b / 6000)
+ gg = numpy.outer(ga, gb)
+ b = BilinearImage(gg)
+
+ self.assertAlmostEqual(b.maxi, 1, 2, "maxi is almost 1")
+ self.assertLess(b.mini, 0.3, "mini should be around 0.23")
+
+ ok = 0
+ for s in range(self.N):
+ i, j = numpy.random.randint(100), numpy.random.randint(100)
+ k, l = b.local_maxi((i, j))
+ if abs(k - 40) > 1e-4 or abs(l - 60) > 1e-4:
+ logger.warning("Wrong guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
+ else:
+ logger.debug("Good guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
+ ok += 1
+ logger.debug("Success rate: %.1f", 100. * ok / self.N)
+ self.assertEqual(ok, self.N, "Maximum is always found")
+
+ def test_max_search_half(self):
+ """test maximum search using random points: maximum is at a pixel edge"""
+ a = numpy.arange(100) - 40.5
+ b = numpy.arange(100) - 60.5
+ ga = numpy.exp(-a * a / 4000)
+ gb = numpy.exp(-b * b / 6000)
+ gg = numpy.outer(ga, gb)
+ b = BilinearImage(gg)
+ ok = 0
+ for s in range(self.N):
+ i, j = numpy.random.randint(100), numpy.random.randint(100)
+ k, l = b.local_maxi((i, j))
+ if abs(k - 40.5) > 0.5 or abs(l - 60.5) > 0.5:
+ logger.warning("Wrong guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
+ else:
+ logger.debug("Good guess maximum (%i,%i) -> (%.1f,%.1f)", i, j, k, l)
+ ok += 1
+ logger.debug("Success rate: %.1f", 100. * ok / self.N)
+ self.assertEqual(ok, self.N, "Maximum is always found")
+
+ def test_map(self):
+ N = 6
+ y, x = numpy.ogrid[:N,:N + 10]
+ img = x + y
+ b = BilinearImage(img)
+ x2d = numpy.zeros_like(y) + x
+ y2d = numpy.zeros_like(x) + y
+ res1 = b.map_coordinates((y2d, x2d))
+ self.assertEqual(abs(res1 - img).max(), 0, "images are the same (corners)")
+
+ x2d = numpy.zeros_like(y) + (x[:,:-1] + 0.5)
+ y2d = numpy.zeros_like(x[:,:-1]) + y
+ res1 = b.map_coordinates((y2d, x2d))
+ self.assertEqual(abs(res1 - img[:,:-1] - 0.5).max(), 0, "images are the same (middle)")
+
+ x2d = numpy.zeros_like(y[:-1,:]) + (x[:,:-1] + 0.5)
+ y2d = numpy.zeros_like(x[:,:-1]) + (y[:-1,:] + 0.5)
+ res1 = b.map_coordinates((y2d, x2d))
+ self.assertEqual(abs(res1 - img[:-1, 1:]).max(), 0, "images are the same (center)")
+
+ def test_mask_grad(self):
+ N = 100
+ img = numpy.arange(N * N).reshape(N, N)
+ # No mask on the boundaries, makes the test complicated, pixel always separated
+ masked = 2 * numpy.random.randint(0, int((N - 1) / 2), size=(2, N)) + 1
+ mask = numpy.zeros((N, N), dtype=numpy.uint8)
+ mask[(masked[0], masked[1])] = 1
+ self.assertLessEqual(mask.sum(), N, "At most N pixels are masked")
+
+ b = BilinearImage(img, mask=mask)
+ self.assertEqual(b.has_mask, True, "interpolator has mask")
+ self.assertEqual(b.maxi, N * N - 1, "maxi is N²-1")
+ self.assertEqual(b.mini, 0, "mini is 0")
+
+ y, x = numpy.ogrid[:N,:N]
+ x2d = numpy.zeros_like(y) + x
+ y2d = numpy.zeros_like(x) + y
+ res1 = b.map_coordinates((y2d, x2d))
+ self.assertEqual(numpy.nanmax(abs(res1 - img)), 0, "images are the same (corners), or Nan ")
+
+ x2d = numpy.zeros_like(y) + (x[:,:-1] + 0.5)
+ y2d = numpy.zeros_like(x[:,:-1]) + y
+ res1 = b.map_coordinates((y2d, x2d))
+ self.assertLessEqual(numpy.max(abs(res1 - img[:, 1:] + 1 / 2.)), 0.5, "images are the same (middle) +/- 0.5")
+
+ x2d = numpy.zeros_like(y[:-1]) + (x[:,:-1] + 0.5)
+ y2d = numpy.zeros_like(x[:,:-1]) + (y[:-1] + 0.5)
+ res1 = b.map_coordinates((y2d, x2d))
+ exp = 0.25 * (img[:-1,:-1] + img[:-1, 1:] + img[1:,:-1] + img[1:, 1:])
+ self.assertLessEqual(abs(res1 - exp).max(), N / 4, "images are almost the same (center)")
+
+ def test_profile_grad(self):
+ N = 100
+ img = numpy.arange(N * N).reshape(N, N)
+ b = BilinearImage(img)
+ res1 = b.profile_line((0, 0), (N - 1, N - 1))
+ l = numpy.ceil(numpy.sqrt(2) * N)
+ self.assertEqual(len(res1), l, "Profile has correct length")
+ self.assertLess((res1[:-2] - res1[1:-1]).std(), 1e-3, "profile is linear (excluding last point)")
+
+ def test_profile_gaus(self):
+ N = 100
+ x = numpy.arange(N) - N // 2.0
+ g = numpy.exp(-x * x / (N * N))
+ img = numpy.outer(g, g)
+ b = BilinearImage(img)
+ res_hor = b.profile_line((N // 2, 0), (N // 2, N - 1))
+ res_ver = b.profile_line((0, N // 2), (N - 1, N // 2))
+ self.assertEqual(len(res_hor), N, "Profile has correct length")
+ self.assertEqual(len(res_ver), N, "Profile has correct length")
+ self.assertLess(abs(res_hor - g).max(), 1e-5, "correct horizontal profile")
+ self.assertLess(abs(res_ver - g).max(), 1e-5, "correct vertical profile")
+
+ # Profile with linewidth=3
+ expected_profile = img[:, N // 2 - 1:N // 2 + 2].mean(axis=1)
+ res_hor = b.profile_line((N // 2, 0), (N // 2, N - 1), linewidth=3)
+ res_ver = b.profile_line((0, N // 2), (N - 1, N // 2), linewidth=3)
+
+ self.assertEqual(len(res_hor), N, "Profile has correct length")
+ self.assertEqual(len(res_ver), N, "Profile has correct length")
+ self.assertLess(abs(res_hor - expected_profile).max(), 1e-5,
+ "correct horizontal profile")
+ self.assertLess(abs(res_ver - expected_profile).max(), 1e-5,
+ "correct vertical profile")
diff --git a/src/silx/image/test/test_medianfilter.py b/src/silx/image/test/test_medianfilter.py
new file mode 100644
index 0000000..d3386a4
--- /dev/null
+++ b/src/silx/image/test/test_medianfilter.py
@@ -0,0 +1,64 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests that the different implementation of opencl (cpp, opencl) are
+ accessible
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "11/05/2017"
+
+import unittest
+from silx.image import medianfilter
+import numpy
+
+from silx.opencl.common import ocl
+
+
+class TestMedianFilterEngines(unittest.TestCase):
+ """Make sure we have access to all the different implementation of
+ median filter from image medfilt"""
+
+
+ IMG = numpy.arange(10000.).reshape(100, 100)
+
+ KERNEL = (1, 1)
+
+ def testCppMedFilt2d(self):
+ """test cpp engine for medfilt2d"""
+ res = medianfilter.medfilt2d(
+ image=TestMedianFilterEngines.IMG,
+ kernel_size=TestMedianFilterEngines.KERNEL,
+ engine='cpp')
+ self.assertTrue(numpy.array_equal(res, TestMedianFilterEngines.IMG))
+
+ @unittest.skipUnless(ocl, "PyOpenCl is missing")
+ def testOpenCLMedFilt2d(self):
+ """test cpp engine for medfilt2d"""
+ res = medianfilter.medfilt2d(
+ image=TestMedianFilterEngines.IMG,
+ kernel_size=TestMedianFilterEngines.KERNEL,
+ engine='opencl')
+ self.assertTrue(numpy.array_equal(res, TestMedianFilterEngines.IMG))
diff --git a/src/silx/image/test/test_shapes.py b/src/silx/image/test/test_shapes.py
new file mode 100644
index 0000000..63abc00
--- /dev/null
+++ b/src/silx/image/test/test_shapes.py
@@ -0,0 +1,354 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for polygon functions
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "15/02/2019"
+
+
+import logging
+import unittest
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.image import shapes
+
+_logger = logging.getLogger(__name__)
+
+
+class TestPolygonFill(ParametricTestCase):
+ """basic poylgon test"""
+
+ def test_squares(self):
+ """Test polygon fill for a square polygons"""
+ mask_shape = 4, 4
+ tests = {
+ # test name: [(row min, row max), (col min, col max)]
+ 'square in': [(1, 3), (1, 3)],
+ 'square out': [(1, 3), (1, 10)],
+ 'square around': [(-1, 5), (-1, 5)],
+ }
+
+ for test_name, (rows, cols) in tests.items():
+ with self.subTest(msg=test_name, rows=rows, cols=cols,
+ mask_shape=mask_shape):
+ ref_mask = numpy.zeros(mask_shape, dtype=numpy.uint8)
+ ref_mask[max(0, rows[0]):rows[1],
+ max(0, cols[0]):cols[1]] = True
+
+ vertices = [(rows[0], cols[0]), (rows[1], cols[0]),
+ (rows[1], cols[1]), (rows[0], cols[1])]
+ mask = shapes.polygon_fill_mask(vertices, ref_mask.shape)
+ is_equal = numpy.all(numpy.equal(ref_mask, mask))
+ if not is_equal:
+ _logger.debug('%s failed with mask != ref_mask:',
+ test_name)
+ _logger.debug('result:\n%s', str(mask))
+ _logger.debug('ref:\n%s', str(ref_mask))
+ self.assertTrue(is_equal)
+
+ def test_eight(self):
+ """Tests with eight shape with different rotation and direction"""
+ ref_mask = numpy.array((
+ (1, 1, 1, 1, 1, 0),
+ (0, 1, 1, 1, 0, 0),
+ (0, 0, 1, 0, 0, 0),
+ (0, 0, 1, 0, 0, 0),
+ (0, 1, 1, 1, 0, 0),
+ (0, 0, 0, 0, 0, 0)), dtype=numpy.uint8)
+ ref_mask_rot = numpy.asarray(numpy.logical_not(ref_mask),
+ dtype=numpy.uint8)
+ ref_mask_rot[:, -1] = 0
+ ref_mask_rot[-1, :] = 0
+
+ tests = {
+ 'dir 1': ([(0, 0), (5, 5), (5, 0), (0, 5)], ref_mask),
+ 'dir 1, rot 90': ([(5, 0), (0, 5), (5, 5), (0, 0)], ref_mask_rot),
+ 'dir 1, rot 180': ([(5, 5), (0, 0), (0, 5), (5, 0)], ref_mask),
+ 'dir 1, rot -90': ([(0, 5), (5, 0), (0, 0), (5, 5)], ref_mask_rot),
+ 'dir 2': ([(0, 0), (0, 5), (5, 0), (5, 5)], ref_mask),
+ 'dir 2, rot 90': ([(5, 0), (0, 0), (5, 5), (0, 5)], ref_mask_rot),
+ 'dir 2, rot 180': ([(5, 5), (5, 0), (0, 5), (0, 0)], ref_mask),
+ 'dir 2, rot -90': ([(0, 5), (5, 5), (0, 0), (5, 0)], ref_mask_rot),
+ }
+
+ for test_name, (vertices, ref_mask) in tests.items():
+ with self.subTest(msg=test_name):
+ mask = shapes.polygon_fill_mask(vertices, ref_mask.shape)
+ is_equal = numpy.all(numpy.equal(ref_mask, mask))
+ if not is_equal:
+ _logger.debug('%s failed with mask != ref_mask:',
+ test_name)
+ _logger.debug('result:\n%s', str(mask))
+ _logger.debug('ref:\n%s', str(ref_mask))
+ self.assertTrue(is_equal)
+
+ def test_shapes(self):
+ """Tests with shapes and reference mask"""
+ tests = {
+ # name: (
+ # polygon corners as a list of (row, col),
+ # ref_mask)
+ 'concave polygon': (
+ [(1, 1), (4, 3), (1, 5), (2, 3)],
+ numpy.array((
+ (0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 0, 1, 1, 1, 0, 0, 0),
+ (0, 0, 0, 1, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0, 0, 0)), dtype=numpy.uint8)),
+ 'concave polygon partly outside mask': (
+ [(-1, -1), (4, 3), (1, 5), (2, 3)],
+ numpy.array((
+ (1, 0, 0, 0, 0, 0),
+ (0, 1, 0, 0, 0, 0),
+ (0, 0, 1, 1, 1, 0),
+ (0, 0, 0, 1, 0, 0),
+ (0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0),
+ (0, 0, 0, 0, 0, 0)), dtype=numpy.uint8)),
+ 'polygon surrounding mask': (
+ [(-1, -1), (-1, 7), (7, 7), (7, -1), (0, -1),
+ (8, -2), (8, 8), (-2, 8)],
+ numpy.zeros((6, 6), dtype=numpy.uint8))
+ }
+
+ for test_name, (vertices, ref_mask) in tests.items():
+ with self.subTest(msg=test_name):
+ mask = shapes.polygon_fill_mask(vertices, ref_mask.shape)
+ is_equal = numpy.all(numpy.equal(ref_mask, mask))
+ if not is_equal:
+ _logger.debug('%s failed with mask != ref_mask:',
+ test_name)
+ _logger.debug('result:\n%s', str(mask))
+ _logger.debug('ref:\n%s', str(ref_mask))
+ self.assertTrue(is_equal)
+
+
+class TestDrawLine(ParametricTestCase):
+ """basic draw line test"""
+
+ def test_aligned_lines(self):
+ """Test drawing horizontal, vertical and diagonal lines"""
+
+ lines = { # test_name: (drow, dcol)
+ 'Horizontal line, col0 < col1': (0, 10),
+ 'Horizontal line, col0 > col1': (0, -10),
+ 'Vertical line, row0 < row1': (10, 0),
+ 'Vertical line, row0 > row1': (-10, 0),
+ 'Diagonal col0 < col1 and row0 < row1': (10, 10),
+ 'Diagonal col0 < col1 and row0 > row1': (-10, 10),
+ 'Diagonal col0 > col1 and row0 < row1': (10, -10),
+ 'Diagonal col0 > col1 and row0 > row1': (-10, -10),
+ }
+ row0, col0 = 1, 2 # Start point
+
+ for test_name, (drow, dcol) in lines.items():
+ row1 = row0 + drow
+ col1 = col0 + dcol
+ with self.subTest(msg=test_name, drow=drow, dcol=dcol):
+ # Build reference coordinates from drow and dcol
+ if drow == 0:
+ rows = row0 * numpy.ones(abs(dcol) + 1)
+ else:
+ step = 1 if drow > 0 else -1
+ rows = row0 + numpy.arange(0, drow + step, step)
+
+ if dcol == 0:
+ cols = col0 * numpy.ones(abs(drow) + 1)
+ else:
+ step = 1 if dcol > 0 else -1
+ cols = col0 + numpy.arange(0, dcol + step, step)
+ ref_coords = rows, cols
+
+ result = shapes.draw_line(row0, col0, row1, col1)
+ self.assertTrue(self.isEqual(test_name, result, ref_coords))
+
+ def test_noline(self):
+ """Test pt0 == pt1"""
+ for width in range(4):
+ with self.subTest(width=width):
+ result = shapes.draw_line(1, 2, 1, 2, width)
+ self.assertTrue(numpy.all(numpy.equal(result, [(1,), (2,)])))
+
+ def test_lines(self):
+ """Test lines not aligned with axes for 8 slopes and directions"""
+ row0, col0 = 1, 1
+
+ dy, dx = 3, 5
+ ref_coords = numpy.array(
+ [(0, 0), (1, 1), (1, 2), (2, 3), (2, 4), (3, 5)])
+
+ # Build lines for the 8 octants from this coordinantes
+ lines = { # name: (drow, dcol, ref_coords)
+ '1st octant': (dy, dx, ref_coords),
+ '2nd octant': (dx, dy, ref_coords[:, (1, 0)]), # invert x and y
+ '3rd octant': (dx, -dy, ref_coords[:, (1, 0)] * (1, -1)),
+ '4th octant': (dy, -dx, ref_coords * (1, -1)),
+ '5th octant': (-dy, -dx, ref_coords * (-1, -1)),
+ '6th octant': (-dx, -dy, ref_coords[:, (1, 0)] * (-1, -1)),
+ '7th octant': (-dx, dy, ref_coords[:, (1, 0)] * (-1, 1)),
+ '8th octant': (-dy, dx, ref_coords * (-1, 1))
+ }
+
+ # Test with different starting points with positive and negative coords
+ for row0, col0 in ((0, 0), (2, 3), (-4, 1), (-5, -6), (8, -7)):
+ for name, (drow, dcol, ref_coords) in lines.items():
+ row1 = row0 + drow
+ col1 = col0 + dcol
+ # Transpose from ((row0, col0), ...) to (rows, cols)
+ ref_coords = numpy.transpose(ref_coords + (row0, col0))
+
+ with self.subTest(msg=name,
+ pt0=(row0, col0), pt1=(row1, col1)):
+ result = shapes.draw_line(row0, col0, row1, col1)
+ self.assertTrue(self.isEqual(name, result, ref_coords))
+
+ def test_width(self):
+ """Test of line width"""
+
+ lines = { # test_name: row0, col0, row1, col1, width, ref
+ 'horizontal w=2':
+ (0, 0, 0, 1, 2, ((0, 1, 0, 1),
+ (0, 0, 1, 1))),
+ 'horizontal w=3':
+ (0, 0, 0, 1, 3, ((-1, 0, 1, -1, 0, 1),
+ (0, 0, 0, 1, 1, 1))),
+ 'vertical w=2':
+ (0, 0, 1, 0, 2, ((0, 0, 1, 1),
+ (0, 1, 0, 1))),
+ 'vertical w=3':
+ (0, 0, 1, 0, 3, ((0, 0, 0, 1, 1, 1),
+ (-1, 0, 1, -1, 0, 1))),
+ 'diagonal w=3':
+ (0, 0, 1, 1, 3, ((-1, 0, 1, 0, 1, 2),
+ (0, 0, 0, 1, 1, 1))),
+ '1st octant w=3':
+ (0, 0, 1, 2, 3,
+ numpy.array(((-1, 0), (0, 0), (1, 0),
+ (0, 1), (1, 1), (2, 1),
+ (0, 2), (1, 2), (2, 2))).T),
+ '2nd octant w=3':
+ (0, 0, 2, 1, 3,
+ numpy.array(((0, -1), (0, 0), (0, 1),
+ (1, 0), (1, 1), (1, 2),
+ (2, 0), (2, 1), (2, 2))).T),
+ }
+
+ for test_name, (row0, col0, row1, col1, width, ref) in lines.items():
+ with self.subTest(msg=test_name,
+ pt0=(row0, col0), pt1=(row1, col1), width=width):
+ result = shapes.draw_line(row0, col0, row1, col1, width)
+ self.assertTrue(self.isEqual(test_name, result, ref))
+
+ def isEqual(self, test_name, result, ref):
+ """Test equality of two numpy arrays and log them if different"""
+ is_equal = numpy.all(numpy.equal(result, ref))
+ if not is_equal:
+ _logger.debug('%s failed with result != ref:',
+ test_name)
+ _logger.debug('result:\n%s', str(result))
+ _logger.debug('ref:\n%s', str(ref))
+ return is_equal
+
+
+class TestCircleFill(ParametricTestCase):
+ """Tests for circle filling"""
+
+ def testCircle(self):
+ """Test circle_fill with different input parameters"""
+
+ square3x3 = numpy.array(((-1, -1, -1, 0, 0, 0, 1, 1, 1),
+ (-1, 0, 1, -1, 0, 1, -1, 0, 1)))
+
+ tests = [
+ # crow, ccol, radius, ref_coords = (ref_rows, ref_cols)
+ (0, 0, 1, ((0,), (0,))),
+ (10, 15, 1, ((10,), (15,))),
+ (0, 0, 1.5, square3x3),
+ (5, 10, 2, (5 + square3x3[0], 10 + square3x3[1])),
+ (10, 20, 3.5, (
+ 10 + numpy.array((-3, -3, -3,
+ -2, -2, -2, -2, -2,
+ -1, -1, -1, -1, -1, -1, -1,
+ 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1, 1,
+ 2, 2, 2, 2, 2,
+ 3, 3, 3)),
+ 20 + numpy.array((-1, 0, 1,
+ -2, -1, 0, 1, 2,
+ -3, -2, -1, 0, 1, 2, 3,
+ -3, -2, -1, 0, 1, 2, 3,
+ -3, -2, -1, 0, 1, 2, 3,
+ -2, -1, 0, 1, 2,
+ -1, 0, 1)))),
+ ]
+
+ for crow, ccol, radius, ref_coords in tests:
+ with self.subTest(crow=crow, ccol=ccol, radius=radius):
+ coords = shapes.circle_fill(crow, ccol, radius)
+ is_equal = numpy.all(numpy.equal(coords, ref_coords))
+ if not is_equal:
+ _logger.debug('result:\n%s', str(coords))
+ _logger.debug('ref:\n%s', str(ref_coords))
+ self.assertTrue(is_equal)
+
+
+class TestEllipseFill(unittest.TestCase):
+ """Tests for ellipse filling"""
+
+ def testPoint(self):
+ args = [1, 1, 1, 1]
+ result = shapes.ellipse_fill(*args)
+ expected = numpy.array(([1], [1]))
+ numpy.testing.assert_array_equal(result, expected)
+
+ def testTranslatedPoint(self):
+ args = [10, 10, 1, 1]
+ result = shapes.ellipse_fill(*args)
+ expected = numpy.array(([10], [10]))
+ numpy.testing.assert_array_equal(result, expected)
+
+ def testEllipse(self):
+ args = [0, 0, 20, 10]
+ rows, cols = shapes.ellipse_fill(*args)
+ self.assertEqual(len(rows), 617)
+ self.assertEqual(rows.mean(), 0)
+ self.assertAlmostEqual(rows.std(), 10.025575, places=3)
+ self.assertEqual(len(cols), 617)
+ self.assertEqual(cols.mean(), 0)
+ self.assertAlmostEqual(cols.std(), 4.897325, places=3)
+
+ def testTranslatedEllipse(self):
+ args = [0, 0, 20, 10]
+ expected_rows, expected_cols = shapes.ellipse_fill(*args)
+ args = [10, 50, 20, 10]
+ rows, cols = shapes.ellipse_fill(*args)
+ numpy.testing.assert_allclose(rows, expected_rows + 10)
+ numpy.testing.assert_allclose(cols, expected_cols + 50)
diff --git a/src/silx/image/test/test_tomography.py b/src/silx/image/test/test_tomography.py
new file mode 100644
index 0000000..f391a72
--- /dev/null
+++ b/src/silx/image/test/test_tomography.py
@@ -0,0 +1,54 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+Tests that the functions of tomography are valid
+"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "12/09/2017"
+
+import unittest
+import numpy
+from silx.test.utils import utilstest
+from silx.image import tomography
+
+class TestTomography(unittest.TestCase):
+ """
+
+ """
+
+ def setUp(self):
+ self.sinoTrueData = numpy.load(utilstest.getfile("sino500.npz"))["data"]
+
+ def testCalcCenterCentroid(self):
+ centerTD = tomography.calc_center_centroid(self.sinoTrueData)
+ self.assertTrue(numpy.isclose(centerTD, 256, rtol=0.01))
+
+ def testCalcCenterCorr(self):
+ centerTrueData = tomography.calc_center_corr(self.sinoTrueData,
+ fullrot=False,
+ props=1)
+ self.assertTrue(numpy.isclose(centerTrueData, 256, rtol=0.01))
diff --git a/silx/image/tomography.py b/src/silx/image/tomography.py
index 53855c1..53855c1 100644
--- a/silx/image/tomography.py
+++ b/src/silx/image/tomography.py
diff --git a/silx/image/utils.py b/src/silx/image/utils.py
index 996d010..996d010 100644
--- a/silx/image/utils.py
+++ b/src/silx/image/utils.py
diff --git a/silx/io/__init__.py b/src/silx/io/__init__.py
index b43d290..b43d290 100644
--- a/silx/io/__init__.py
+++ b/src/silx/io/__init__.py
diff --git a/src/silx/io/commonh5.py b/src/silx/io/commonh5.py
new file mode 100644
index 0000000..af4274f
--- /dev/null
+++ b/src/silx/io/commonh5.py
@@ -0,0 +1,1061 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module contains generic objects, emulating *h5py* groups, datasets and
+files. They are used in :mod:`spech5` and :mod:`fabioh5`.
+"""
+import collections
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+import weakref
+
+import h5py
+import numpy
+
+from . import utils
+
+__authors__ = ["V. Valls", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "02/07/2018"
+
+
+class _MappingProxyType(abc.MutableMapping):
+ """Read-only dictionary
+
+ This class is available since Python 3.3, but not on earlyer Python
+ versions.
+ """
+
+ def __init__(self, data):
+ self._data = data
+
+ def __getitem__(self, key):
+ return self._data[key]
+
+ def __len__(self):
+ return len(self._data)
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def get(self, key, default=None):
+ return self._data.get(key, default)
+
+ def __setitem__(self, key, value):
+ raise RuntimeError("Cannot modify read-only dictionary")
+
+ def __delitem__(self, key):
+ raise RuntimeError("Cannot modify read-only dictionary")
+
+ def pop(self, key):
+ raise RuntimeError("Cannot modify read-only dictionary")
+
+ def clear(self):
+ raise RuntimeError("Cannot modify read-only dictionary")
+
+ def update(self, key, value):
+ raise RuntimeError("Cannot modify read-only dictionary")
+
+ def setdefault(self, key):
+ raise RuntimeError("Cannot modify read-only dictionary")
+
+
+class Node(object):
+ """This is the base class for all :mod:`spech5` and :mod:`fabioh5`
+ classes. It represents a tree node, and knows its parent node
+ (:attr:`parent`).
+ The API mimics a *h5py* node, with following attributes: :attr:`file`,
+ :attr:`attrs`, :attr:`name`, and :attr:`basename`.
+ """
+
+ def __init__(self, name, parent=None, attrs=None):
+ self._set_parent(parent)
+ self.__basename = name
+ self.__attrs = {}
+ if attrs is not None:
+ self.__attrs.update(attrs)
+
+ def _set_basename(self, name):
+ self.__basename = name
+
+ @property
+ def h5_class(self):
+ """Returns the HDF5 class which is mimicked by this class.
+
+ :rtype: H5Type
+ """
+ raise NotImplementedError()
+
+ @property
+ def h5py_class(self):
+ """Returns the h5py classes which is mimicked by this class. It can be
+ one of `h5py.File, h5py.Group` or `h5py.Dataset`
+
+ This should not be used anymore. Prefer using `h5_class`
+
+ :rtype: Class
+ """
+ h5_class = self.h5_class
+ if h5_class == utils.H5Type.FILE:
+ return h5py.File
+ elif h5_class == utils.H5Type.GROUP:
+ return h5py.Group
+ elif h5_class == utils.H5Type.DATASET:
+ return h5py.Dataset
+ elif h5_class == utils.H5Type.SOFT_LINK:
+ return h5py.SoftLink
+ raise NotImplementedError()
+
+ @property
+ def parent(self):
+ """Returns the parent of the node.
+
+ :rtype: Node
+ """
+ if self.__parent is None:
+ parent = None
+ else:
+ parent = self.__parent()
+ if parent is None:
+ self.__parent = None
+ return parent
+
+ def _set_parent(self, parent):
+ """Set the parent of this node.
+
+ It do not update the parent object.
+
+ :param Node parent: New parent for this node
+ """
+ if parent is not None:
+ self.__parent = weakref.ref(parent)
+ else:
+ self.__parent = None
+
+ @property
+ def file(self):
+ """Returns the file node of this node.
+
+ :rtype: Node
+ """
+ node = self
+ while node.parent is not None:
+ node = node.parent
+ if isinstance(node, File):
+ return node
+ else:
+ return None
+
+ @property
+ def attrs(self):
+ """Returns HDF5 attributes of this node.
+
+ :rtype: dict
+ """
+ if self._is_editable():
+ return self.__attrs
+ else:
+ return _MappingProxyType(self.__attrs)
+
+ @property
+ def name(self):
+ """Returns the HDF5 name of this node.
+ """
+ parent = self.parent
+ if parent is None:
+ return "/"
+ if parent.name == "/":
+ return "/" + self.basename
+ return parent.name + "/" + self.basename
+
+ @property
+ def basename(self):
+ """Returns the HDF5 basename of this node.
+ """
+ return self.__basename
+
+ def _is_editable(self):
+ """Returns true if the file is editable or if the node is not linked
+ to a tree.
+
+ :rtype: bool
+ """
+ f = self.file
+ return f is None or f.mode == "w"
+
+
+class Dataset(Node):
+ """This class handles a numpy data object, as a mimicry of a
+ *h5py.Dataset*.
+ """
+
+ def __init__(self, name, data, parent=None, attrs=None):
+ Node.__init__(self, name, parent, attrs=attrs)
+ if data is not None:
+ self._check_data(data)
+ self.__data = data
+
+ def _check_data(self, data):
+ """Check that the data provided by the dataset is valid.
+
+ It is valid when it can be stored in a HDF5 using h5py.
+
+ :param numpy.ndarray data: Data associated to the dataset
+ :raises TypeError: In the case the data is not valid.
+ """
+ if isinstance(data, (str, bytes)):
+ return
+
+ chartype = data.dtype.char
+ if chartype == "U":
+ pass
+ elif chartype == "O":
+ d = h5py.special_dtype(vlen=data.dtype)
+ if d is not None:
+ return
+ d = h5py.special_dtype(ref=data.dtype)
+ if d is not None:
+ return
+ else:
+ return
+
+ msg = "Type of the dataset '%s' is not supported. Found '%s'."
+ raise TypeError(msg % (self.name, data.dtype))
+
+ def _set_data(self, data):
+ """Set the data exposed by the dataset.
+
+ It have to be called only one time before the data is used. It should
+ not be edited after use.
+
+ :param numpy.ndarray data: Data associated to the dataset
+ """
+ self._check_data(data)
+ self.__data = data
+
+ def _get_data(self):
+ """Returns the exposed data
+
+ :rtype: numpy.ndarray
+ """
+ return self.__data
+
+ @property
+ def h5_class(self):
+ """Returns the HDF5 class which is mimicked by this class.
+
+ :rtype: H5Type
+ """
+ return utils.H5Type.DATASET
+
+ @property
+ def dtype(self):
+ """Returns the numpy datatype exposed by this dataset.
+
+ :rtype: numpy.dtype
+ """
+ return self._get_data().dtype
+
+ @property
+ def shape(self):
+ """Returns the shape of the data exposed by this dataset.
+
+ :rtype: tuple
+ """
+ if isinstance(self._get_data(), numpy.ndarray):
+ return self._get_data().shape
+ else:
+ return tuple()
+
+ @property
+ def size(self):
+ """Returns the size of the data exposed by this dataset.
+
+ :rtype: int
+ """
+ if isinstance(self._get_data(), numpy.ndarray):
+ return self._get_data().size
+ else:
+ # It is returned as float64 1.0 by h5py
+ return numpy.float64(1.0)
+
+ def __len__(self):
+ """Returns the size of the data exposed by this dataset.
+
+ :rtype: int
+ """
+ if isinstance(self._get_data(), numpy.ndarray):
+ return len(self._get_data())
+ else:
+ # It is returned as float64 1.0 by h5py
+ raise TypeError("Attempt to take len() of scalar dataset")
+
+ def __getitem__(self, item):
+ """Returns the slice of the data exposed by this dataset.
+
+ :rtype: numpy.ndarray
+ """
+ if not isinstance(self._get_data(), numpy.ndarray):
+ if item == Ellipsis:
+ return numpy.array(self._get_data())
+ elif item == tuple():
+ return self._get_data()
+ else:
+ raise ValueError("Scalar can only be reached with an ellipsis or an empty tuple")
+ return self._get_data().__getitem__(item)
+
+ def __str__(self):
+ basename = self.name.split("/")[-1]
+ return '<HDF5-like dataset "%s": shape %s, type "%s">' % \
+ (basename, self.shape, self.dtype.str)
+
+ def __getslice__(self, i, j):
+ """Returns the slice of the data exposed by this dataset.
+
+ Deprecated but still in use for python 2.7
+
+ :rtype: numpy.ndarray
+ """
+ return self.__getitem__(slice(i, j, None))
+
+ @property
+ def value(self):
+ """Returns the data exposed by this dataset.
+
+ Deprecated by h5py. It is prefered to use indexing `[()]`.
+
+ :rtype: numpy.ndarray
+ """
+ return self._get_data()
+
+ @property
+ def compression(self):
+ """Returns compression as provided by `h5py.Dataset`.
+
+ There is no compression."""
+ return None
+
+ @property
+ def compression_opts(self):
+ """Returns compression options as provided by `h5py.Dataset`.
+
+ There is no compression."""
+ return None
+
+ @property
+ def chunks(self):
+ """Returns chunks as provided by `h5py.Dataset`.
+
+ There is no chunks."""
+ return None
+
+ @property
+ def is_virtual(self):
+ """Checks virtual data as provided by `h5py.Dataset`"""
+ return False
+
+ def virtual_sources(self):
+ """Returns virtual dataset sources as provided by `h5py.Dataset`.
+
+ :rtype: list"""
+ raise RuntimeError("Not a virtual dataset")
+
+ @property
+ def external(self):
+ """Returns external sources as provided by `h5py.Dataset`.
+
+ :rtype: list or None"""
+ return None
+
+ def __array__(self, dtype=None):
+ # Special case for (0,)*-shape datasets
+ if numpy.product(self.shape) == 0:
+ return self[()]
+ else:
+ return numpy.array(self[...], dtype=self.dtype if dtype is None else dtype)
+
+ def __iter__(self):
+ """Iterate over the first axis. TypeError if scalar."""
+ if len(self.shape) == 0:
+ raise TypeError("Can't iterate over a scalar dataset")
+ return self._get_data().__iter__()
+
+ # make comparisons and operations on the data
+ def __eq__(self, other):
+ """When comparing datasets, compare the actual data."""
+ if utils.is_dataset(other):
+ return self[()] == other[()]
+ return self[()] == other
+
+ def __add__(self, other):
+ return self[()] + other
+
+ def __radd__(self, other):
+ return other + self[()]
+
+ def __sub__(self, other):
+ return self[()] - other
+
+ def __rsub__(self, other):
+ return other - self[()]
+
+ def __mul__(self, other):
+ return self[()] * other
+
+ def __rmul__(self, other):
+ return other * self[()]
+
+ def __truediv__(self, other):
+ return self[()] / other
+
+ def __rtruediv__(self, other):
+ return other / self[()]
+
+ def __floordiv__(self, other):
+ return self[()] // other
+
+ def __rfloordiv__(self, other):
+ return other // self[()]
+
+ def __neg__(self):
+ return -self[()]
+
+ def __abs__(self):
+ return abs(self[()])
+
+ def __float__(self):
+ return float(self[()])
+
+ def __int__(self):
+ return int(self[()])
+
+ def __bool__(self):
+ if self[()]:
+ return True
+ return False
+
+ def __nonzero__(self):
+ # python 2
+ return self.__bool__()
+
+ def __ne__(self, other):
+ if utils.is_dataset(other):
+ return self[()] != other[()]
+ else:
+ return self[()] != other
+
+ def __lt__(self, other):
+ if utils.is_dataset(other):
+ return self[()] < other[()]
+ else:
+ return self[()] < other
+
+ def __le__(self, other):
+ if utils.is_dataset(other):
+ return self[()] <= other[()]
+ else:
+ return self[()] <= other
+
+ def __gt__(self, other):
+ if utils.is_dataset(other):
+ return self[()] > other[()]
+ else:
+ return self[()] > other
+
+ def __ge__(self, other):
+ if utils.is_dataset(other):
+ return self[()] >= other[()]
+ else:
+ return self[()] >= other
+
+ def __getattr__(self, item):
+ """Proxy to underlying numpy array methods.
+ """
+ data = self._get_data()
+ if hasattr(data, item):
+ return getattr(data, item)
+
+ raise AttributeError("Dataset has no attribute %s" % item)
+
+
+class DatasetProxy(Dataset):
+ """Virtual dataset providing content of another dataset"""
+
+ def __init__(self, name, target, parent=None):
+ Dataset.__init__(self, name, data=None, parent=parent)
+ if not utils.is_dataset(target):
+ raise TypeError("A Dataset is expected but %s found", target.__class__)
+ self.__target = target
+
+ @property
+ def shape(self):
+ return self.__target.shape
+
+ @property
+ def size(self):
+ return self.__target.size
+
+ @property
+ def dtype(self):
+ return self.__target.dtype
+
+ def _get_data(self):
+ return self.__target[...]
+
+ @property
+ def attrs(self):
+ return self.__target.attrs
+
+
+class _LinkToDataset(Dataset):
+ """Virtual dataset providing link to another dataset"""
+
+ def __init__(self, name, target, parent=None):
+ Dataset.__init__(self, name, data=None, parent=parent)
+ self.__target = target
+
+ def _get_data(self):
+ return self.__target._get_data()
+
+ @property
+ def attrs(self):
+ return self.__target.attrs
+
+
+class LazyLoadableDataset(Dataset):
+ """Abstract dataset which provides a lazy loading of the data.
+
+ The class has to be inherited and the :meth:`_create_data` method has to be
+ implemented to return the numpy data exposed by the dataset. This factory
+ method is only called once, when the data is needed.
+ """
+
+ def __init__(self, name, parent=None, attrs=None):
+ super(LazyLoadableDataset, self).__init__(name, None, parent, attrs=attrs)
+ self._is_initialized = False
+
+ def _create_data(self):
+ """
+ Factory to create the data exposed by the dataset when it is needed.
+
+ It has to be implemented for the class to work.
+
+ :rtype: numpy.ndarray
+ """
+ raise NotImplementedError()
+
+ def _get_data(self):
+ """Returns the data exposed by the dataset.
+
+ Overwrite Dataset method :meth:`_get_data` to implement the lazy
+ loading feature.
+
+ :rtype: numpy.ndarray
+ """
+ if not self._is_initialized:
+ data = self._create_data()
+ # is_initialized before set_data to avoid infinit initialization
+ # is case of wrong check of the data
+ self._is_initialized = True
+ self._set_data(data)
+ return super(LazyLoadableDataset, self)._get_data()
+
+
+class SoftLink(Node):
+ """This class is a tree node that mimics a *h5py.Softlink*.
+
+ In this implementation, the path to the target must be absolute.
+ """
+ def __init__(self, name, path, parent=None):
+ assert str(path).startswith("/") # TODO: h5py also allows a relative path
+
+ Node.__init__(self, name, parent)
+
+ # attr target defined for spech5 backward compatibility
+ self.target = str(path)
+
+ @property
+ def h5_class(self):
+ """Returns the HDF5 class which is mimicked by this class.
+
+ :rtype: H5Type
+ """
+ return utils.H5Type.SOFT_LINK
+
+ @property
+ def path(self):
+ """Soft link value. Not guaranteed to be a valid path."""
+ return self.target
+
+
+class Group(Node):
+ """This class mimics a `h5py.Group`."""
+
+ def __init__(self, name, parent=None, attrs=None):
+ Node.__init__(self, name, parent, attrs=attrs)
+ self.__items = collections.OrderedDict()
+
+ def _get_items(self):
+ """Returns the child items as a name-node dictionary.
+
+ :rtype: dict
+ """
+ return self.__items
+
+ def add_node(self, node):
+ """Add a child to this group.
+
+ :param Node node: Child to add to this group
+ """
+ self._get_items()[node.basename] = node
+ node._set_parent(self)
+
+ @property
+ def h5_class(self):
+ """Returns the HDF5 class which is mimicked by this class.
+
+ :rtype: H5Type
+ """
+ return utils.H5Type.GROUP
+
+ def _get(self, name, getlink):
+ """If getlink is True and name points to an existing SoftLink, this
+ SoftLink is returned. In all other situations, we try to return a
+ Group or Dataset, or we raise a KeyError if we fail."""
+ if "/" not in name:
+ result = self._get_items()[name]
+ elif name.startswith("/"):
+ root = self.file
+ if name == "/":
+ return root
+ result = root._get(name[1:], getlink)
+ else:
+ path = name.split("/")
+ result = self
+ for item_name in path:
+ if isinstance(result, SoftLink):
+ # traverse links
+ l_name, l_target = result.name, result.path
+ result = result.file.get(l_target)
+ if result is None:
+ raise KeyError(
+ "Unable to open object (broken SoftLink %s -> %s)" %
+ (l_name, l_target))
+ if not item_name:
+ # trailing "/" in name (legal for accessing Groups only)
+ if isinstance(result, Group):
+ continue
+ if not isinstance(result, Group):
+ raise KeyError("Unable to open object (Component not found)")
+ result = result._get_items()[item_name]
+
+ if isinstance(result, SoftLink) and not getlink:
+ link = result
+ target = result.file.get(link.path)
+ if result is None:
+ msg = "Unable to open object (broken SoftLink %s -> %s)"
+ raise KeyError(msg % (link.name, link.path))
+ # Convert SoftLink into typed group/dataset
+ if isinstance(target, Group):
+ result = _LinkToGroup(name=link.basename, target=target, parent=link.parent)
+ elif isinstance(target, Dataset):
+ result = _LinkToDataset(name=link.basename, target=target, parent=link.parent)
+ else:
+ raise TypeError("Unexpected target type %s" % type(target))
+
+ return result
+
+ def get(self, name, default=None, getclass=False, getlink=False):
+ """Retrieve an item or other information.
+
+ If getlink only is true, the returned value is always `h5py.HardLink`,
+ because this implementation do not use links. Like the original
+ implementation.
+
+ :param str name: name of the item
+ :param object default: default value returned if the name is not found
+ :param bool getclass: if true, the returned object is the class of the object found
+ :param bool getlink: if true, links object are returned instead of the target
+ :return: An object, else None
+ :rtype: object
+ """
+ if name not in self:
+ return default
+
+ node = self._get(name, getlink=True)
+ if isinstance(node, SoftLink) and not getlink:
+ # get target
+ try:
+ node = self._get(name, getlink=False)
+ except KeyError:
+ return default
+ elif not isinstance(node, SoftLink) and getlink:
+ # ExternalLink objects don't exist in silx, so it must be a HardLink
+ node = h5py.HardLink()
+
+ if getclass:
+ obj = utils.get_h5py_class(node)
+ if obj is None:
+ obj = node.__class__
+ else:
+ obj = node
+ return obj
+
+ def __setitem__(self, name, obj):
+ """Add an object to the group.
+
+ :param str name: Location on the group to store the object.
+ This path name must not exists.
+ :param object obj: Object to store on the file. According to the type,
+ the behaviour will not be the same.
+
+ - `commonh5.SoftLink`: Create the corresponding link.
+ - `numpy.ndarray`: The array is converted to a dataset object.
+ - `commonh5.Node`: A hard link should be created pointing to the
+ given object. This implementation uses a soft link.
+ If the node do not have parent it is connected to the tree
+ without using a link (that's a hard link behaviour).
+ - other object: Convert first the object with ndarray and then
+ store it. ValueError if the resulting array dtype is not
+ supported.
+ """
+ if name in self:
+ # From the h5py API
+ raise RuntimeError("Unable to create link (name already exists)")
+
+ elements = name.rsplit("/", 1)
+ if len(elements) == 1:
+ parent = self
+ basename = elements[0]
+ else:
+ group_path, basename = elements
+ if group_path in self:
+ parent = self[group_path]
+ else:
+ parent = self.create_group(group_path)
+
+ if isinstance(obj, SoftLink):
+ obj._set_basename(basename)
+ node = obj
+ elif isinstance(obj, Node):
+ if obj.parent is None:
+ obj._set_basename(basename)
+ node = obj
+ else:
+ node = SoftLink(basename, obj.name)
+ elif isinstance(obj, numpy.dtype):
+ node = Dataset(basename, data=obj)
+ elif isinstance(obj, numpy.ndarray):
+ node = Dataset(basename, data=obj)
+ else:
+ data = numpy.array(obj)
+ try:
+ node = Dataset(basename, data=data)
+ except TypeError as e:
+ raise ValueError(e.args[0])
+
+ parent.add_node(node)
+
+ def __getitem__(self, name):
+ """Return a child from his name.
+
+ :param str name: name of a member or a path throug members using '/'
+ separator. A '/' as a prefix access to the root item of the tree.
+ :rtype: Node
+ """
+ if name is None or name == "":
+ raise ValueError("No name")
+ return self._get(name, getlink=False)
+
+ def __contains__(self, name):
+ """Returns true if name is an existing child of this group.
+
+ :rtype: bool
+ """
+ if "/" not in name:
+ return name in self._get_items()
+
+ if name.startswith("/"):
+ # h5py allows to access any valid full path from any group
+ node = self.file
+ else:
+ node = self
+
+ name = name.lstrip("/")
+ basenames = name.split("/")
+ for basename in basenames:
+ if basename.strip() == "":
+ # presence of a trailing "/" in name
+ # (OK for groups, not for datasets)
+ if isinstance(node, SoftLink):
+ # traverse links
+ node = node.file.get(node.path, getlink=False)
+ if node is None:
+ # broken link
+ return False
+ if utils.is_dataset(node):
+ return False
+ continue
+ if basename not in node._get_items():
+ return False
+ node = node[basename]
+
+ return True
+
+ def __len__(self):
+ """Returns the number of children contained in this group.
+
+ :rtype: int
+ """
+ return len(self._get_items())
+
+ def __iter__(self):
+ """Iterate over member names"""
+ for x in self._get_items().__iter__():
+ yield x
+
+ def keys(self):
+ """Returns an iterator over the children's names in a group."""
+ return self._get_items().keys()
+
+ def values(self):
+ """Returns an iterator over the children nodes (groups and datasets)
+ in a group.
+
+ .. versionadded:: 0.6
+ """
+ return self._get_items().values()
+
+ def items(self):
+ """Returns items iterator containing name-node mapping.
+
+ :rtype: iterator
+ """
+ return self._get_items().items()
+
+ def visit(self, func, visit_links=False):
+ """Recursively visit all names in this group and subgroups.
+ See the documentation for `h5py.Group.visit` for more help.
+
+ :param func: Callable (function, method or callable object)
+ :type func: callable
+ """
+ origin_name = self.name
+ return self._visit(func, origin_name, visit_links)
+
+ def visititems(self, func, visit_links=False):
+ """Recursively visit names and objects in this group.
+ See the documentation for `h5py.Group.visititems` for more help.
+
+ :param func: Callable (function, method or callable object)
+ :type func: callable
+ :param bool visit_links: If *False*, ignore links. If *True*,
+ call `func(name)` for links and recurse into target groups.
+ """
+ origin_name = self.name
+ return self._visit(func, origin_name, visit_links,
+ visititems=True)
+
+ def _visit(self, func, origin_name,
+ visit_links=False, visititems=False):
+ """
+
+ :param origin_name: name of first group that initiated the recursion
+ This is used to compute the relative path from each item's
+ absolute path.
+ """
+ for member in self.values():
+ ret = None
+ if not isinstance(member, SoftLink) or visit_links:
+ relative_name = member.name[len(origin_name):]
+ # remove leading slash and unnecessary trailing slash
+ relative_name = relative_name.strip("/")
+ if visititems:
+ ret = func(relative_name, member)
+ else:
+ ret = func(relative_name)
+ if ret is not None:
+ return ret
+ if isinstance(member, Group):
+ member._visit(func, origin_name, visit_links, visititems)
+
+ def create_group(self, name):
+ """Create and return a new subgroup.
+
+ Name may be absolute or relative. Fails if the target name already
+ exists.
+
+ :param str name: Name of the new group
+ """
+ if not self._is_editable():
+ raise RuntimeError("File is not editable")
+ if name in self:
+ raise ValueError("Unable to create group (name already exists)")
+
+ if name.startswith("/"):
+ name = name[1:]
+ return self.file.create_group(name)
+
+ elements = name.split('/')
+ group = self
+ for basename in elements:
+ if basename in group:
+ group = group[basename]
+ if not isinstance(group, Group):
+ raise RuntimeError("Unable to create group (group parent is missing")
+ else:
+ node = Group(basename)
+ group.add_node(node)
+ group = node
+ return group
+
+ def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
+ """Create and return a sub dataset.
+
+ :param str name: Name of the dataset.
+ :param shape: Dataset shape. Use "()" for scalar datasets.
+ Required if "data" isn't provided.
+ :param dtype: Numpy dtype or string.
+ If omitted, dtype('f') will be used.
+ Required if "data" isn't provided; otherwise, overrides data
+ array's dtype.
+ :param numpy.ndarray data: Provide data to initialize the dataset.
+ If used, you can omit shape and dtype arguments.
+ :param kwds: Extra arguments. Nothing yet supported.
+ """
+ if not self._is_editable():
+ raise RuntimeError("File is not editable")
+ if len(kwds) > 0:
+ raise TypeError("Extra args provided, but nothing supported")
+ if "/" in name:
+ raise TypeError("Path are not supported")
+ if data is None:
+ if dtype is None:
+ dtype = numpy.float64
+ data = numpy.empty(shape=shape, dtype=dtype)
+ elif dtype is not None:
+ data = data.astype(dtype)
+ dataset = Dataset(name, data)
+ self.add_node(dataset)
+ return dataset
+
+
+class _LinkToGroup(Group):
+ """Virtual group providing link to another group"""
+
+ def __init__(self, name, target, parent=None):
+ Group.__init__(self, name, parent=parent)
+ self.__target = target
+
+ def _get_items(self):
+ return self.__target._get_items()
+
+ @property
+ def attrs(self):
+ return self.__target.attrs
+
+
+class LazyLoadableGroup(Group):
+ """Abstract group which provides a lazy loading of the child.
+
+ The class has to be inherited and the :meth:`_create_child` method has
+ to be implemented to add (:meth:`_add_node`) all children. This factory
+ is only called once, when children are needed.
+ """
+
+ def __init__(self, name, parent=None, attrs=None):
+ Group.__init__(self, name, parent, attrs)
+ self.__is_initialized = False
+
+ def _get_items(self):
+ """Returns the internal structure which contains the children.
+
+ It overwrite method :meth:`_get_items` to implement the lazy
+ loading feature.
+
+ :rtype: dict
+ """
+ if not self.__is_initialized:
+ self.__is_initialized = True
+ self._create_child()
+ return Group._get_items(self)
+
+ def _create_child(self):
+ """
+ Factory to create the child contained by the group when it is needed.
+
+ It has to be implemented to work.
+ """
+ raise NotImplementedError()
+
+
+class File(Group):
+ """This class is the special :class:`Group` that is the root node
+ of the tree structure. It mimics `h5py.File`."""
+
+ def __init__(self, name=None, mode=None, attrs=None):
+ """
+ Constructor
+
+ :param str name: File name if it exists
+ :param str mode: Access mode
+ - "r": Read-only. Methods :meth:`create_dataset` and
+ :meth:`create_group` are locked.
+ - "w": File is editable. Methods :meth:`create_dataset` and
+ :meth:`create_group` are available.
+ :param dict attrs: Default attributes
+ """
+ Group.__init__(self, name="", parent=None, attrs=attrs)
+ self._file_name = name
+ if mode is None:
+ mode = "r"
+ assert(mode in ["r", "w"])
+ self._mode = mode
+
+ @property
+ def filename(self):
+ return self._file_name
+
+ @property
+ def mode(self):
+ return self._mode
+
+ @property
+ def h5_class(self):
+ """Returns the :class:`h5py.File` class"""
+ return utils.H5Type.FILE
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def close(self):
+ """Close the object, and free up associated resources.
+ """
+ # should be implemented in subclass
+ pass
diff --git a/silx/io/configdict.py b/src/silx/io/configdict.py
index c028211..c028211 100644
--- a/silx/io/configdict.py
+++ b/src/silx/io/configdict.py
diff --git a/src/silx/io/convert.py b/src/silx/io/convert.py
new file mode 100644
index 0000000..ba9a254
--- /dev/null
+++ b/src/silx/io/convert.py
@@ -0,0 +1,335 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides classes and function to convert file formats supported
+by *silx* into HDF5 file. Currently, SPEC file and fabio images are the
+supported formats.
+
+Read the documentation of :mod:`silx.io.spech5`, :mod:`silx.io.fioh5` and :mod:`silx.io.fabioh5` for
+information on the structure of the output HDF5 files.
+
+Text strings are written to the HDF5 datasets as variable-length utf-8.
+
+.. warning::
+
+ The output format for text strings changed in silx version 0.7.0.
+ Prior to that, text was output as fixed-length ASCII.
+
+ To be on the safe side, when reading back a HDF5 file written with an
+ older version of silx, you can test for the presence of a *decode*
+ attribute. To ensure that you always work with unicode text::
+
+ >>> import h5py
+ >>> h5f = h5py.File("my_scans.h5", "r")
+ >>> title = h5f["/68.1/title"]
+ >>> if hasattr(title, "decode"):
+ ... title = title.decode()
+
+
+.. note:: This module has a dependency on the `h5py <http://www.h5py.org/>`_
+ library, which is not a mandatory dependency for `silx`. You might need
+ to install it if you don't already have it.
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/07/2018"
+
+
+import logging
+
+import h5py
+import numpy
+
+import silx.io
+from .utils import is_dataset, is_group, is_softlink, visitall
+from . import fabioh5
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _create_link(h5f, link_name, target_name,
+ link_type="soft", overwrite_data=False):
+ """Create a link in a HDF5 file
+
+ If member with name ``link_name`` already exists, delete it first or
+ ignore link depending on global param ``overwrite_data``.
+
+ :param h5f: :class:`h5py.File` object
+ :param link_name: Link path
+ :param target_name: Handle for target group or dataset
+ :param str link_type: "soft" or "hard"
+ :param bool overwrite_data: If True, delete existing member (group,
+ dataset or link) with the same name. Default is False.
+ """
+ if link_name not in h5f:
+ _logger.debug("Creating link " + link_name + " -> " + target_name)
+ elif overwrite_data:
+ _logger.warning("Overwriting " + link_name + " with link to " +
+ target_name)
+ del h5f[link_name]
+ else:
+ _logger.warning(link_name + " already exist. Cannot create link to " +
+ target_name)
+ return None
+
+ if link_type == "hard":
+ h5f[link_name] = h5f[target_name]
+ elif link_type == "soft":
+ h5f[link_name] = h5py.SoftLink(target_name)
+ else:
+ raise ValueError("link_type must be 'hard' or 'soft'")
+
+
+def _attr_utf8(attr_value):
+ """If attr_value is bytes, make sure we output utf-8
+
+ :param attr_value: String (possibly bytes if PY2)
+ :return: Attr ready to be written by h5py as utf8
+ """
+ if isinstance(attr_value, (bytes, str)):
+ out_attr_value = numpy.array(
+ attr_value,
+ dtype=h5py.special_dtype(vlen=str))
+ else:
+ out_attr_value = attr_value
+
+ return out_attr_value
+
+
+class Hdf5Writer(object):
+ """Converter class to write the content of a data file to a HDF5 file.
+ """
+ def __init__(self,
+ h5path='/',
+ overwrite_data=False,
+ link_type="soft",
+ create_dataset_args=None,
+ min_size=500):
+ """
+
+ :param h5path: Target path where the scan groups will be written
+ in the output HDF5 file.
+ :param bool overwrite_data:
+ See documentation of :func:`write_to_h5`
+ :param str link_type: ``"hard"`` or ``"soft"`` (default)
+ :param dict create_dataset_args: Dictionary of args you want to pass to
+ ``h5py.File.create_dataset``.
+ See documentation of :func:`write_to_h5`
+ :param int min_size:
+ See documentation of :func:`write_to_h5`
+ """
+ self.h5path = h5path
+ if not h5path.startswith("/"):
+ # target path must be absolute
+ self.h5path = "/" + h5path
+ if not self.h5path.endswith("/"):
+ self.h5path += "/"
+
+ self._h5f = None
+ """h5py.File object, assigned in :meth:`write`"""
+
+ if create_dataset_args is None:
+ create_dataset_args = {}
+ self.create_dataset_args = create_dataset_args
+
+ self.min_size = min_size
+
+ self.overwrite_data = overwrite_data # boolean
+
+ self.link_type = link_type
+ """'soft' or 'hard' """
+
+ self._links = []
+ """List of *(link_path, target_path)* tuples."""
+
+ def write(self, infile, h5f):
+ """Copy `infile` content to `h5f` file under `h5path`.
+
+ All the parameters needed for the conversion have been initialized
+ in the constructor.
+
+ External links in `infile` are ignored.
+
+ :param Union[commonh5.Group,h5py.Group] infile:
+ File/Class from which to read the content to copy from.
+ :param h5py.File h5f: File where to write the copied content to
+ """
+ # Recurse through all groups and datasets to add them to the HDF5
+ self._h5f = h5f
+ for name, item in visitall(infile):
+ self.append_member_to_h5(name, item)
+
+ # Handle the attributes of the root group
+ root_grp = h5f[self.h5path]
+ for key in infile.attrs:
+ if self.overwrite_data or key not in root_grp.attrs:
+ root_grp.attrs.create(key,
+ _attr_utf8(infile.attrs[key]))
+
+ # Handle links at the end, when their targets are created
+ for link_name, target_name in self._links:
+ _create_link(self._h5f, link_name, target_name,
+ link_type=self.link_type,
+ overwrite_data=self.overwrite_data)
+ self._links = []
+
+ def append_member_to_h5(self, h5like_name, obj):
+ """Add one group or one dataset to :attr:`h5f`"""
+ h5_name = self.h5path + h5like_name.lstrip("/")
+ if is_softlink(obj):
+ # links to be created after all groups and datasets
+ h5_target = self.h5path + obj.path.lstrip("/")
+ self._links.append((h5_name, h5_target))
+
+ elif is_dataset(obj):
+ _logger.debug("Saving dataset: " + h5_name)
+
+ member_initially_exists = h5_name in self._h5f
+
+ if self.overwrite_data and member_initially_exists:
+ _logger.warning("Overwriting dataset: " + h5_name)
+ del self._h5f[h5_name]
+
+ if self.overwrite_data or not member_initially_exists:
+ if isinstance(obj, fabioh5.FrameData) and len(obj.shape) > 2:
+ # special case of multiframe data
+ # write frame by frame to save memory usage low
+ ds = self._h5f.create_dataset(h5_name,
+ shape=obj.shape,
+ dtype=obj.dtype,
+ **self.create_dataset_args)
+ for i, frame in enumerate(obj):
+ ds[i] = frame
+ else:
+ # fancy arguments don't apply to small dataset
+ if obj.size < self.min_size:
+ ds = self._h5f.create_dataset(h5_name, data=obj[()])
+ else:
+ ds = self._h5f.create_dataset(h5_name, data=obj[()],
+ **self.create_dataset_args)
+ else:
+ ds = self._h5f[h5_name]
+
+ # add HDF5 attributes
+ for key in obj.attrs:
+ if self.overwrite_data or key not in ds.attrs:
+ ds.attrs.create(key,
+ _attr_utf8(obj.attrs[key]))
+
+ if not self.overwrite_data and member_initially_exists:
+ _logger.warning("Not overwriting existing dataset: " + h5_name)
+
+ elif is_group(obj):
+ if h5_name not in self._h5f:
+ _logger.debug("Creating group: " + h5_name)
+ grp = self._h5f.create_group(h5_name)
+ else:
+ grp = self._h5f[h5_name]
+
+ # add HDF5 attributes
+ for key in obj.attrs:
+ if self.overwrite_data or key not in grp.attrs:
+ grp.attrs.create(key,
+ _attr_utf8(obj.attrs[key]))
+ else:
+ _logger.warning("Unsuppored entity, ignoring: %s", h5_name)
+
+
+def write_to_h5(infile, h5file, h5path='/', mode="a",
+ overwrite_data=False, link_type="soft",
+ create_dataset_args=None, min_size=500):
+ """Write content of a h5py-like object into a HDF5 file.
+
+ Warning: External links in `infile` are ignored.
+
+ :param infile: Path of input file, :class:`commonh5.File`,
+ :class:`commonh5.Group`, :class:`h5py.File` or :class:`h5py.Group`
+ :param h5file: Path of output HDF5 file or HDF5 file handle
+ (`h5py.File` object)
+ :param str h5path: Target path in HDF5 file in which scan groups are created.
+ Default is root (``"/"``)
+ :param str mode: Can be ``"r+"`` (read/write, file must exist),
+ ``"w"`` (write, existing file is lost), ``"w-"`` (write, fail
+ if exists) or ``"a"`` (read/write if exists, create otherwise).
+ This parameter is ignored if ``h5file`` is a file handle.
+ :param bool overwrite_data: If ``True``, existing groups and datasets can be
+ overwritten, if ``False`` they are skipped. This parameter is only
+ relevant if ``file_mode`` is ``"r+"`` or ``"a"``.
+ :param str link_type: *"soft"* (default) or *"hard"*
+ :param dict create_dataset_args: Dictionary of args you want to pass to
+ ``h5py.File.create_dataset``. This allows you to specify filters and
+ compression parameters. Don't specify ``name`` and ``data``.
+ These arguments are only applied to datasets larger than 1MB.
+ :param int min_size: Minimum number of elements in a dataset to apply
+ chunking and compression. Default is 500.
+
+ The structure of the spec data in an HDF5 file is described in the
+ documentation of :mod:`silx.io.spech5`.
+ """
+ writer = Hdf5Writer(h5path=h5path,
+ overwrite_data=overwrite_data,
+ link_type=link_type,
+ create_dataset_args=create_dataset_args,
+ min_size=min_size)
+
+ # both infile and h5file can be either file handle or a file name: 4 cases
+ if not isinstance(h5file, h5py.File) and not is_group(infile):
+ with silx.io.open(infile) as h5pylike:
+ with h5py.File(h5file, mode) as h5f:
+ writer.write(h5pylike, h5f)
+ elif isinstance(h5file, h5py.File) and not is_group(infile):
+ with silx.io.open(infile) as h5pylike:
+ writer.write(h5pylike, h5file)
+ elif is_group(infile) and not isinstance(h5file, h5py.File):
+ with h5py.File(h5file, mode) as h5f:
+ writer.write(infile, h5f)
+ else:
+ writer.write(infile, h5file)
+
+
+def convert(infile, h5file, mode="w-", create_dataset_args=None):
+ """Convert a supported file into an HDF5 file, write scans into the
+ root group (``/``).
+
+ This is a convenience shortcut to call::
+
+ write_to_h5(h5like, h5file, h5path='/',
+ mode="w-", link_type="soft")
+
+ :param infile: Path of input file or :class:`commonh5.File` object
+ or :class:`commonh5.Group` object
+ :param h5file: Path of output HDF5 file, or h5py.File object
+ :param mode: Can be ``"w"`` (write, existing file is
+ lost), ``"w-"`` (write, fail if exists). This is ignored
+ if ``h5file`` is a file handle.
+ :param create_dataset_args: Dictionary of args you want to pass to
+ ``h5py.File.create_dataset``. This allows you to specify filters and
+ compression parameters. Don't specify ``name`` and ``data``.
+ """
+ if mode not in ["w", "w-"]:
+ raise IOError("File mode must be 'w' or 'w-'. Use write_to_h5" +
+ " to append data to an existing HDF5 file.")
+ write_to_h5(infile, h5file, h5path='/', mode=mode,
+ create_dataset_args=create_dataset_args)
diff --git a/src/silx/io/dictdump.py b/src/silx/io/dictdump.py
new file mode 100644
index 0000000..a24de42
--- /dev/null
+++ b/src/silx/io/dictdump.py
@@ -0,0 +1,843 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module offers a set of functions to dump a python dictionary indexed
+by text strings to following file formats: `HDF5, INI, JSON`
+"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+import json
+import logging
+import numpy
+import os.path
+import sys
+import h5py
+
+from .configdict import ConfigDict
+from .utils import is_group
+from .utils import is_dataset
+from .utils import is_link
+from .utils import is_softlink
+from .utils import is_externallink
+from .utils import is_file as is_h5_file_like
+from .utils import open as h5open
+from .utils import h5py_read_dataset
+from .utils import H5pyAttributesReadWrapper
+from silx.utils.deprecation import deprecated_warning
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/07/2018"
+
+logger = logging.getLogger(__name__)
+
+vlen_utf8 = h5py.special_dtype(vlen=str)
+vlen_bytes = h5py.special_dtype(vlen=bytes)
+
+UPDATE_MODE_VALID_EXISTING_VALUES = ("add", "replace", "modify")
+
+
+def _prepare_hdf5_write_value(array_like):
+ """Cast a python object into a numpy array in a HDF5 friendly format.
+
+ :param array_like: Input dataset in a type that can be digested by
+ ``numpy.array()`` (`str`, `list`, `numpy.ndarray`…)
+ :return: ``numpy.ndarray`` ready to be written as an HDF5 dataset
+ """
+ array = numpy.asarray(array_like)
+ if numpy.issubdtype(array.dtype, numpy.bytes_):
+ return numpy.array(array_like, dtype=vlen_bytes)
+ elif numpy.issubdtype(array.dtype, numpy.str_):
+ return numpy.array(array_like, dtype=vlen_utf8)
+ else:
+ return array
+
+
+class _SafeH5FileWrite:
+ """Context manager returning a :class:`h5py.File` object.
+
+ If this object is initialized with a file path, we open the file
+ and then we close it on exiting.
+
+ If a :class:`h5py.File` instance is provided to :meth:`__init__` rather
+ than a path, we assume that the user is responsible for closing the
+ file.
+
+ This behavior is well suited for handling h5py file in a recursive
+ function. The object is created in the initial call if a path is provided,
+ and it is closed only at the end when all the processing is finished.
+ """
+ def __init__(self, h5file, mode="w"):
+ """
+ :param h5file: HDF5 file path or :class:`h5py.File` instance
+ :param str mode: Can be ``"r+"`` (read/write, file must exist),
+ ``"w"`` (write, existing file is lost), ``"w-"`` (write, fail if
+ exists) or ``"a"`` (read/write if exists, create otherwise).
+ This parameter is ignored if ``h5file`` is a file handle.
+ """
+ self.raw_h5file = h5file
+ self.mode = mode
+
+ def __enter__(self):
+ if not isinstance(self.raw_h5file, h5py.File):
+ self.h5file = h5py.File(self.raw_h5file, self.mode)
+ self.close_when_finished = True
+ else:
+ self.h5file = self.raw_h5file
+ self.close_when_finished = False
+ return self.h5file
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.close_when_finished:
+ self.h5file.close()
+
+
+class _SafeH5FileRead:
+ """Context manager returning a :class:`h5py.File` or a
+ :class:`silx.io.spech5.SpecH5` or a :class:`silx.io.fabioh5.File` object.
+
+ The general behavior is the same as :class:`_SafeH5FileWrite` except
+ that SPEC files and all formats supported by fabio can also be opened,
+ but in read-only mode.
+ """
+ def __init__(self, h5file):
+ """
+
+ :param h5file: HDF5 file path or h5py.File-like object
+ """
+ self.raw_h5file = h5file
+
+ def __enter__(self):
+ if not is_h5_file_like(self.raw_h5file):
+ self.h5file = h5open(self.raw_h5file)
+ self.close_when_finished = True
+ else:
+ self.h5file = self.raw_h5file
+ self.close_when_finished = False
+
+ return self.h5file
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.close_when_finished:
+ self.h5file.close()
+
+
+def _normalize_h5_path(h5root, h5path):
+ """
+ :param h5root: File name or h5py-like File, Group or Dataset
+ :param str h5path: relative to ``h5root``
+ :returns 2-tuple: (File or file object, h5path)
+ """
+ if is_group(h5root):
+ group_name = h5root.name
+ if group_name == "/":
+ pass
+ elif h5path:
+ h5path = group_name + "/" + h5path
+ else:
+ h5path = group_name
+ h5file = h5root.file
+ elif is_dataset(h5root):
+ h5path = h5root.name
+ h5file = h5root.file
+ else:
+ h5file = h5root
+ if not h5path:
+ h5path = "/"
+ elif not h5path.endswith("/"):
+ h5path += "/"
+ return h5file, h5path
+
+
+def dicttoh5(treedict, h5file, h5path='/',
+ mode="w", overwrite_data=None,
+ create_dataset_args=None, update_mode=None):
+ """Write a nested dictionary to a HDF5 file, using keys as member names.
+
+ If a dictionary value is a sub-dictionary, a group is created. If it is
+ any other data type, it is cast into a numpy array and written as a
+ :mod:`h5py` dataset. Dictionary keys must be strings and cannot contain
+ the ``/`` character.
+
+ If dictionary keys are tuples they are interpreted to set h5 attributes.
+ The tuples should have the format (dataset_name, attr_name).
+
+ Existing HDF5 items can be deleted by providing the dictionary value
+ ``None``, provided that ``update_mode in ["modify", "replace"]``.
+
+ .. note::
+
+ This function requires `h5py <http://www.h5py.org/>`_ to be installed.
+
+ :param treedict: Nested dictionary/tree structure with strings or tuples as
+ keys and array-like objects as leafs. The ``"/"`` character can be used
+ to define sub trees. If tuples are used as keys they should have the
+ format (dataset_name,attr_name) and will add a 5h attribute with the
+ corresponding value.
+ :param h5file: File name or h5py-like File, Group or Dataset
+ :param h5path: Target path in the HDF5 file relative to ``h5file``.
+ Default is root (``"/"``)
+ :param mode: Can be ``"r+"`` (read/write, file must exist),
+ ``"w"`` (write, existing file is lost), ``"w-"`` (write, fail if
+ exists) or ``"a"`` (read/write if exists, create otherwise).
+ This parameter is ignored if ``h5file`` is a file handle.
+ :param overwrite_data: Deprecated. ``True`` is approximately equivalent
+ to ``update_mode="modify"`` and ``False`` is equivalent to
+ ``update_mode="add"``.
+ :param create_dataset_args: Dictionary of args you want to pass to
+ ``h5f.create_dataset``. This allows you to specify filters and
+ compression parameters. Don't specify ``name`` and ``data``.
+ :param update_mode: Can be ``add`` (default), ``modify`` or ``replace``.
+
+ * ``add``: Extend the existing HDF5 tree when possible. Existing HDF5
+ items (groups, datasets and attributes) remain untouched.
+ * ``modify``: Extend the existing HDF5 tree when possible, modify
+ existing attributes, modify same-sized dataset values and delete
+ HDF5 items with a ``None`` value in the dict tree.
+ * ``replace``: Replace the existing HDF5 tree. Items from the root of
+ the HDF5 tree that are not present in the root of the dict tree
+ will remain untouched.
+
+ Example::
+
+ from silx.io.dictdump import dicttoh5
+
+ city_area = {
+ "Europe": {
+ "France": {
+ "Isère": {
+ "Grenoble": 18.44,
+ ("Grenoble","unit"): "km2"
+ },
+ "Nord": {
+ "Tourcoing": 15.19,
+ ("Tourcoing","unit"): "km2"
+ },
+ },
+ },
+ }
+
+ create_ds_args = {'compression': "gzip",
+ 'shuffle': True,
+ 'fletcher32': True}
+
+ dicttoh5(city_area, "cities.h5", h5path="/area",
+ create_dataset_args=create_ds_args)
+ """
+
+ if overwrite_data is not None:
+ reason = (
+ "`overwrite_data=True` becomes `update_mode='modify'` and "
+ "`overwrite_data=False` becomes `update_mode='add'`"
+ )
+ deprecated_warning(
+ type_="argument",
+ name="overwrite_data",
+ reason=reason,
+ replacement="update_mode",
+ since_version="0.15",
+ )
+
+ if update_mode is None:
+ if overwrite_data:
+ update_mode = "modify"
+ else:
+ update_mode = "add"
+ else:
+ if update_mode not in UPDATE_MODE_VALID_EXISTING_VALUES:
+ raise ValueError((
+ "Argument 'update_mode' can only have values: {}"
+ "".format(UPDATE_MODE_VALID_EXISTING_VALUES)
+ ))
+ if overwrite_data is not None:
+ logger.warning("The argument `overwrite_data` is ignored")
+
+ if not isinstance(treedict, Mapping):
+ raise TypeError("'treedict' must be a dictionary")
+
+ h5file, h5path = _normalize_h5_path(h5file, h5path)
+
+ def _iter_treedict(attributes=False):
+ nonlocal treedict
+ for key, value in treedict.items():
+ if isinstance(key, tuple) == attributes:
+ yield key, value
+
+ change_allowed = update_mode in ("replace", "modify")
+
+ with _SafeH5FileWrite(h5file, mode=mode) as h5f:
+ # Create the root of the tree
+ if h5path in h5f:
+ if not is_group(h5f[h5path]):
+ if update_mode == "replace":
+ del h5f[h5path]
+ h5f.create_group(h5path)
+ else:
+ return
+ else:
+ h5f.create_group(h5path)
+
+ # Loop over all groups, links and datasets
+ for key, value in _iter_treedict(attributes=False):
+ h5name = h5path + str(key)
+ exists = h5name in h5f
+
+ if value is None:
+ # Delete HDF5 item
+ if exists and change_allowed:
+ del h5f[h5name]
+ exists = False
+ elif isinstance(value, Mapping):
+ # HDF5 group
+ if exists and update_mode == "replace":
+ del h5f[h5name]
+ exists = False
+ if value:
+ dicttoh5(value, h5f, h5name,
+ update_mode=update_mode,
+ create_dataset_args=create_dataset_args)
+ elif not exists:
+ h5f.create_group(h5name)
+ elif is_link(value):
+ # HDF5 link
+ if exists and update_mode == "replace":
+ del h5f[h5name]
+ exists = False
+ if not exists:
+ # Create link from h5py link object
+ h5f[h5name] = value
+ else:
+ # HDF5 dataset
+ if exists and not change_allowed:
+ continue
+ data = _prepare_hdf5_write_value(value)
+
+ # Edit the existing dataset
+ attrs_backup = None
+ if exists:
+ try:
+ h5f[h5name][()] = data
+ continue
+ except Exception:
+ # Delete the existing dataset
+ if update_mode != "replace":
+ if not is_dataset(h5f[h5name]):
+ continue
+ attrs_backup = dict(h5f[h5name].attrs)
+ del h5f[h5name]
+
+ # Create dataset
+ # can't apply filters on scalars (datasets with shape == ())
+ if data.shape == () or create_dataset_args is None:
+ h5f.create_dataset(h5name,
+ data=data)
+ else:
+ h5f.create_dataset(h5name,
+ data=data,
+ **create_dataset_args)
+ if attrs_backup:
+ h5f[h5name].attrs.update(attrs_backup)
+
+ # Loop over all attributes
+ for key, value in _iter_treedict(attributes=True):
+ if len(key) != 2:
+ raise ValueError("HDF5 attribute must be described by 2 values")
+ h5name = h5path + key[0]
+ attr_name = key[1]
+
+ if h5name not in h5f:
+ # Create an empty group to store the attribute
+ h5f.create_group(h5name)
+
+ h5a = h5f[h5name].attrs
+ exists = attr_name in h5a
+
+ if value is None:
+ # Delete HDF5 attribute
+ if exists and change_allowed:
+ del h5a[attr_name]
+ exists = False
+ else:
+ # Add/modify HDF5 attribute
+ if exists and not change_allowed:
+ continue
+ data = _prepare_hdf5_write_value(value)
+ h5a[attr_name] = data
+
+
+def _has_nx_class(treedict, key=""):
+ return key + "@NX_class" in treedict or \
+ (key, "NX_class") in treedict
+
+
+def _ensure_nx_class(treedict, parents=tuple()):
+ """Each group needs an "NX_class" attribute.
+ """
+ if _has_nx_class(treedict):
+ return
+ nparents = len(parents)
+ if nparents == 0:
+ treedict[("", "NX_class")] = "NXroot"
+ elif nparents == 1:
+ treedict[("", "NX_class")] = "NXentry"
+ else:
+ treedict[("", "NX_class")] = "NXcollection"
+
+
+def nexus_to_h5_dict(
+ treedict, parents=tuple(), add_nx_class=True, has_nx_class=False
+):
+ """The following conversions are applied:
+ * key with "{name}@{attr_name}" notation: key converted to 2-tuple
+ * key with ">{url}" notation: strip ">" and convert value to
+ h5py.SoftLink or h5py.ExternalLink
+
+ :param treedict: Nested dictionary/tree structure with strings as keys
+ and array-like objects as leafs. The ``"/"`` character can be used
+ to define sub tree. The ``"@"`` character is used to write attributes.
+ The ``">"`` prefix is used to define links.
+ :param parents: Needed to resolve up-links (tuple of HDF5 group names)
+ :param add_nx_class: Add "NX_class" attribute when missing
+ :param has_nx_class: The "NX_class" attribute is defined in the parent
+
+ :rtype dict:
+ """
+ if not isinstance(treedict, Mapping):
+ raise TypeError("'treedict' must be a dictionary")
+ copy = dict()
+ for key, value in treedict.items():
+ if "@" in key:
+ # HDF5 attribute
+ key = tuple(key.rsplit("@", 1))
+ elif key.startswith(">"):
+ # HDF5 link
+ if isinstance(value, str):
+ key = key[1:]
+ first, sep, second = value.partition("::")
+ if sep:
+ value = h5py.ExternalLink(first, second)
+ else:
+ if ".." in first:
+ # Up-links not supported: make absolute
+ parts = []
+ for p in list(parents) + first.split("/"):
+ if not p or p == ".":
+ continue
+ elif p == "..":
+ parts.pop(-1)
+ else:
+ parts.append(p)
+ first = "/" + "/".join(parts)
+ value = h5py.SoftLink(first)
+ elif is_link(value):
+ key = key[1:]
+ if isinstance(value, Mapping):
+ # HDF5 group
+ key_has_nx_class = add_nx_class and _has_nx_class(treedict, key)
+ copy[key] = nexus_to_h5_dict(
+ value,
+ parents=parents+(key,),
+ add_nx_class=add_nx_class,
+ has_nx_class=key_has_nx_class)
+ else:
+ # HDF5 dataset or link
+ copy[key] = value
+ if add_nx_class and not has_nx_class:
+ _ensure_nx_class(copy, parents)
+ return copy
+
+
+def h5_to_nexus_dict(treedict):
+ """The following conversions are applied:
+ * 2-tuple key: converted to string ("@" notation)
+ * h5py.Softlink value: converted to string (">" key prefix)
+ * h5py.ExternalLink value: converted to string (">" key prefix)
+
+ :param treedict: Nested dictionary/tree structure with strings as keys
+ and array-like objects as leafs. The ``"/"`` character can be used
+ to define sub tree.
+
+ :rtype dict:
+ """
+ copy = dict()
+ for key, value in treedict.items():
+ if isinstance(key, tuple):
+ if len(key) != 2:
+ raise ValueError("HDF5 attribute must be described by 2 values")
+ key = "%s@%s" % (key[0], key[1])
+ elif is_softlink(value):
+ key = ">" + key
+ value = value.path
+ elif is_externallink(value):
+ key = ">" + key
+ value = value.filename + "::" + value.path
+ if isinstance(value, Mapping):
+ copy[key] = h5_to_nexus_dict(value)
+ else:
+ copy[key] = value
+ return copy
+
+
+def _name_contains_string_in_list(name, strlist):
+ if strlist is None:
+ return False
+ for filter_str in strlist:
+ if filter_str in name:
+ return True
+ return False
+
+
+def _handle_error(mode: str, exception, msg: str, *args) -> None:
+ """Handle errors.
+
+ :param str mode: 'raise', 'log', 'ignore'
+ :param type exception: Exception class to use in 'raise' mode
+ :param str msg: Error message template
+ :param List[str] args: Arguments for error message template
+ """
+ if mode == 'ignore':
+ return # no-op
+ elif mode == 'log':
+ logger.error(msg, *args)
+ elif mode == 'raise':
+ raise exception(msg % args)
+ else:
+ raise ValueError("Unsupported error handling: %s" % mode)
+
+
+def h5todict(h5file,
+ path="/",
+ exclude_names=None,
+ asarray=True,
+ dereference_links=True,
+ include_attributes=False,
+ errors='raise'):
+ """Read a HDF5 file and return a nested dictionary with the complete file
+ structure and all data.
+
+ Example of usage::
+
+ from silx.io.dictdump import h5todict
+
+ # initialize dict with file header and scan header
+ header94 = h5todict("oleg.dat",
+ "/94.1/instrument/specfile")
+ # add positioners subdict
+ header94["positioners"] = h5todict("oleg.dat",
+ "/94.1/instrument/positioners")
+ # add scan data without mca data
+ header94["detector data"] = h5todict("oleg.dat",
+ "/94.1/measurement",
+ exclude_names="mca_")
+
+
+ .. note:: This function requires `h5py <http://www.h5py.org/>`_ to be
+ installed.
+
+ .. note:: If you write a dictionary to a HDF5 file with
+ :func:`dicttoh5` and then read it back with :func:`h5todict`, data
+ types are not preserved. All values are cast to numpy arrays before
+ being written to file, and they are read back as numpy arrays (or
+ scalars). In some cases, you may find that a list of heterogeneous
+ data types is converted to a numpy array of strings.
+
+ :param h5file: File name or h5py-like File, Group or Dataset
+ :param str path: Target path in the HDF5 file relative to ``h5file``
+ :param List[str] exclude_names: Groups and datasets whose name contains
+ a string in this list will be ignored. Default is None (ignore nothing)
+ :param bool asarray: True (default) to read scalar as arrays, False to
+ read them as scalar
+ :param bool dereference_links: True (default) to dereference links, False
+ to preserve the link itself
+ :param bool include_attributes: False (default)
+ :param str errors: Handling of errors (HDF5 access issue, broken link,...):
+ - 'raise' (default): Raise an exception
+ - 'log': Log as errors
+ - 'ignore': Ignore errors
+ :return: Nested dictionary
+ """
+ h5file, path = _normalize_h5_path(h5file, path)
+ with _SafeH5FileRead(h5file) as h5f:
+ ddict = {}
+ if path not in h5f:
+ _handle_error(
+ errors, KeyError, 'Path "%s" does not exist in file.', path)
+ return ddict
+
+ try:
+ root = h5f[path]
+ except KeyError as e:
+ if not isinstance(h5f.get(path, getlink=True), h5py.HardLink):
+ _handle_error(errors,
+ KeyError,
+ 'Cannot retrieve path "%s" (broken link)',
+ path)
+ else:
+ _handle_error(errors, KeyError, ', '.join(e.args))
+ return ddict
+
+ # Read the attributes of the group
+ if include_attributes:
+ attrs = H5pyAttributesReadWrapper(root.attrs)
+ for aname, avalue in attrs.items():
+ ddict[("", aname)] = avalue
+ # Read the children of the group
+ for key in root:
+ if _name_contains_string_in_list(key, exclude_names):
+ continue
+ h5name = path + "/" + key
+ # Preserve HDF5 link when requested
+ if not dereference_links:
+ lnk = h5f.get(h5name, getlink=True)
+ if is_link(lnk):
+ ddict[key] = lnk
+ continue
+
+ try:
+ h5obj = h5f[h5name]
+ except KeyError as e:
+ if not isinstance(h5f.get(h5name, getlink=True), h5py.HardLink):
+ _handle_error(errors,
+ KeyError,
+ 'Cannot retrieve path "%s" (broken link)',
+ h5name)
+ else:
+ _handle_error(errors, KeyError, ', '.join(e.args))
+ continue
+
+ if is_group(h5obj):
+ # Child is an HDF5 group
+ ddict[key] = h5todict(h5f,
+ h5name,
+ exclude_names=exclude_names,
+ asarray=asarray,
+ dereference_links=dereference_links,
+ include_attributes=include_attributes)
+ else:
+ # Child is an HDF5 dataset
+ try:
+ data = h5py_read_dataset(h5obj)
+ except OSError:
+ _handle_error(errors,
+ OSError,
+ 'Cannot retrieve dataset "%s"',
+ h5name)
+ else:
+ if asarray: # Convert HDF5 dataset to numpy array
+ data = numpy.array(data, copy=False)
+ ddict[key] = data
+ # Read the attributes of the child
+ if include_attributes:
+ attrs = H5pyAttributesReadWrapper(h5obj.attrs)
+ for aname, avalue in attrs.items():
+ ddict[(key, aname)] = avalue
+ return ddict
+
+
+def dicttonx(treedict, h5file, h5path="/", add_nx_class=None, **kw):
+ """
+ Write a nested dictionary to a HDF5 file, using string keys as member names.
+ The NeXus convention is used to identify attributes with ``"@"`` character,
+ therefore the dataset_names should not contain ``"@"``.
+
+ Similarly, links are identified by keys starting with the ``">"`` character.
+ The corresponding value can be a soft or external link.
+
+ :param treedict: Nested dictionary/tree structure with strings as keys
+ and array-like objects as leafs. The ``"/"`` character can be used
+ to define sub tree. The ``"@"`` character is used to write attributes.
+ The ``">"`` prefix is used to define links.
+ :param add_nx_class: Add "NX_class" attribute when missing. By default it
+ is ``True`` when ``update_mode`` is ``"add"`` or ``None``.
+
+ The named parameters are passed to dicttoh5.
+
+ Example::
+
+ import numpy
+ from silx.io.dictdump import dicttonx
+
+ gauss = {
+ "entry":{
+ "title":u"A plot of a gaussian",
+ "instrument": {
+ "@NX_class": u"NXinstrument",
+ "positioners": {
+ "@NX_class": u"NXCollection",
+ "x": numpy.arange(0,1.1,.1)
+ }
+ }
+ "plot": {
+ "y": numpy.array([0.08, 0.19, 0.39, 0.66, 0.9, 1.,
+ 0.9, 0.66, 0.39, 0.19, 0.08]),
+ ">x": "../instrument/positioners/x",
+ "@signal": "y",
+ "@axes": "x",
+ "@NX_class":u"NXdata",
+ "title:u"Gauss Plot",
+ },
+ "@NX_class": u"NXentry",
+ "default":"plot",
+ }
+ "@NX_class": u"NXroot",
+ "@default": "entry",
+ }
+
+ dicttonx(gauss,"test.h5")
+ """
+ h5file, h5path = _normalize_h5_path(h5file, h5path)
+ parents = tuple(p for p in h5path.split("/") if p)
+ if add_nx_class is None:
+ add_nx_class = kw.get("update_mode", None) in (None, "add")
+ nxtreedict = nexus_to_h5_dict(
+ treedict, parents=parents, add_nx_class=add_nx_class
+ )
+ dicttoh5(nxtreedict, h5file, h5path=h5path, **kw)
+
+
+def nxtodict(h5file, include_attributes=True, **kw):
+ """Read a HDF5 file and return a nested dictionary with the complete file
+ structure and all data.
+
+ As opposed to h5todict, all keys will be strings and no h5py objects are
+ present in the tree.
+
+ The named parameters are passed to h5todict.
+ """
+ nxtreedict = h5todict(h5file, include_attributes=include_attributes, **kw)
+ return h5_to_nexus_dict(nxtreedict)
+
+
+def dicttojson(ddict, jsonfile, indent=None, mode="w"):
+ """Serialize ``ddict`` as a JSON formatted stream to ``jsonfile``.
+
+ :param ddict: Dictionary (or any object compatible with ``json.dump``).
+ :param jsonfile: JSON file name or file-like object.
+ If a file name is provided, the function opens the file in the
+ specified mode and closes it again.
+ :param indent: If indent is a non-negative integer, then JSON array
+ elements and object members will be pretty-printed with that indent
+ level. An indent level of ``0`` will only insert newlines.
+ ``None`` (the default) selects the most compact representation.
+ :param mode: File opening mode (``w``, ``a``, ``w+``…)
+ """
+ if not hasattr(jsonfile, "write"):
+ jsonf = open(jsonfile, mode)
+ else:
+ jsonf = jsonfile
+
+ json.dump(ddict, jsonf, indent=indent)
+
+ if not hasattr(jsonfile, "write"):
+ jsonf.close()
+
+
+def dicttoini(ddict, inifile, mode="w"):
+ """Output dict as configuration file (similar to Microsoft Windows INI).
+
+ :param dict: Dictionary of configuration parameters
+ :param inifile: INI file name or file-like object.
+ If a file name is provided, the function opens the file in the
+ specified mode and closes it again.
+ :param mode: File opening mode (``w``, ``a``, ``w+``…)
+ """
+ if not hasattr(inifile, "write"):
+ inif = open(inifile, mode)
+ else:
+ inif = inifile
+
+ ConfigDict(initdict=ddict).write(inif)
+
+ if not hasattr(inifile, "write"):
+ inif.close()
+
+
+def dump(ddict, ffile, mode="w", fmat=None):
+ """Dump dictionary to a file
+
+ :param ddict: Dictionary with string keys
+ :param ffile: File name or file-like object with a ``write`` method
+ :param str fmat: Output format: ``"json"``, ``"hdf5"`` or ``"ini"``.
+ When None (the default), it uses the filename extension as the format.
+ Dumping to a HDF5 file requires `h5py <http://www.h5py.org/>`_ to be
+ installed.
+ :param str mode: File opening mode (``w``, ``a``, ``w+``…)
+ Default is *"w"*, write mode, overwrite if exists.
+ :raises IOError: if file format is not supported
+ """
+ if fmat is None:
+ # If file-like object get its name, else use ffile as filename
+ filename = getattr(ffile, 'name', ffile)
+ fmat = os.path.splitext(filename)[1][1:] # Strip extension leading '.'
+ fmat = fmat.lower()
+
+ if fmat == "json":
+ dicttojson(ddict, ffile, indent=2, mode=mode)
+ elif fmat in ["hdf5", "h5"]:
+ dicttoh5(ddict, ffile, mode=mode)
+ elif fmat in ["ini", "cfg"]:
+ dicttoini(ddict, ffile, mode=mode)
+ else:
+ raise IOError("Unknown format " + fmat)
+
+
+def load(ffile, fmat=None):
+ """Load dictionary from a file
+
+ When loading from a JSON or INI file, an OrderedDict is returned to
+ preserve the values' insertion order.
+
+ :param ffile: File name or file-like object with a ``read`` method
+ :param fmat: Input format: ``json``, ``hdf5`` or ``ini``.
+ When None (the default), it uses the filename extension as the format.
+ Loading from a HDF5 file requires `h5py <http://www.h5py.org/>`_ to be
+ installed.
+ :return: Dictionary (ordered dictionary for JSON and INI)
+ :raises IOError: if file format is not supported
+ """
+ must_be_closed = False
+ if not hasattr(ffile, "read"):
+ f = open(ffile, "r")
+ fname = ffile
+ must_be_closed = True
+ else:
+ f = ffile
+ fname = ffile.name
+
+ try:
+ if fmat is None: # Use file extension as format
+ fmat = os.path.splitext(fname)[1][1:] # Strip extension leading '.'
+ fmat = fmat.lower()
+
+ if fmat == "json":
+ return json.load(f, object_pairs_hook=OrderedDict)
+ if fmat in ["hdf5", "h5"]:
+ return h5todict(fname)
+ elif fmat in ["ini", "cfg"]:
+ return ConfigDict(filelist=[fname])
+ else:
+ raise IOError("Unknown format " + fmat)
+ finally:
+ if must_be_closed:
+ f.close()
diff --git a/src/silx/io/fabioh5.py b/src/silx/io/fabioh5.py
new file mode 100755
index 0000000..af9b29a
--- /dev/null
+++ b/src/silx/io/fabioh5.py
@@ -0,0 +1,1050 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides functions to read fabio images as an HDF5 file.
+
+ >>> import silx.io.fabioh5
+ >>> f = silx.io.fabioh5.File("foobar.edf")
+
+.. note:: This module has a dependency on the `h5py <http://www.h5py.org/>`_
+ and `fabio <https://github.com/silx-kit/fabio>`_ libraries,
+ which are not mandatory dependencies for `silx`.
+
+"""
+
+import collections
+import datetime
+import logging
+import numbers
+import os
+
+import fabio.file_series
+import numpy
+
+from . import commonh5
+from silx import version as silx_version
+import silx.utils.number
+import h5py
+
+
+_logger = logging.getLogger(__name__)
+
+
+_fabio_extensions = set([])
+
+
+def supported_extensions():
+ """Returns all extensions supported by fabio.
+
+ :returns: A set containing extensions like "*.edf".
+ :rtype: Set[str]
+ """
+ global _fabio_extensions
+ if len(_fabio_extensions) > 0:
+ return _fabio_extensions
+
+ formats = fabio.fabioformats.get_classes(reader=True)
+ all_extensions = set([])
+
+ for reader in formats:
+ if not hasattr(reader, "DEFAULT_EXTENSIONS"):
+ continue
+
+ ext = reader.DEFAULT_EXTENSIONS
+ ext = ["*.%s" % e for e in ext]
+ all_extensions.update(ext)
+
+ _fabio_extensions = set(all_extensions)
+ return _fabio_extensions
+
+
+class _FileSeries(fabio.file_series.file_series):
+ """
+ .. note:: Overwrite a function to fix an issue in fabio.
+ """
+ def jump(self, num):
+ """
+ Goto a position in sequence
+ """
+ assert num < len(self) and num >= 0, "num out of range"
+ self._current = num
+ return self[self._current]
+
+
+class FrameData(commonh5.LazyLoadableDataset):
+ """Expose a cube of image from a Fabio file using `FabioReader` as
+ cache."""
+
+ def __init__(self, name, fabio_reader, parent=None):
+ if fabio_reader.is_spectrum():
+ attrs = {"interpretation": "spectrum"}
+ else:
+ attrs = {"interpretation": "image"}
+ commonh5.LazyLoadableDataset.__init__(self, name, parent, attrs=attrs)
+ self.__fabio_reader = fabio_reader
+ self._shape = None
+ self._dtype = None
+
+ def _create_data(self):
+ return self.__fabio_reader.get_data()
+
+ def _update_cache(self):
+ if isinstance(self.__fabio_reader.fabio_file(),
+ fabio.file_series.file_series):
+ # Reading all the files is taking too much time
+ # Reach the information from the only first frame
+ first_image = self.__fabio_reader.fabio_file().first_image()
+ self._dtype = first_image.data.dtype
+ shape0 = self.__fabio_reader.frame_count()
+ shape1, shape2 = first_image.data.shape
+ self._shape = shape0, shape1, shape2
+ else:
+ self._dtype = super(commonh5.LazyLoadableDataset, self).dtype
+ self._shape = super(commonh5.LazyLoadableDataset, self).shape
+
+ @property
+ def dtype(self):
+ if self._dtype is None:
+ self._update_cache()
+ return self._dtype
+
+ @property
+ def shape(self):
+ if self._shape is None:
+ self._update_cache()
+ return self._shape
+
+ def __iter__(self):
+ for frame in self.__fabio_reader.iter_frames():
+ yield frame.data
+
+ def __getitem__(self, item):
+ # optimization for fetching a single frame if data not already loaded
+ if not self._is_initialized:
+ if isinstance(item, int) and \
+ isinstance(self.__fabio_reader.fabio_file(),
+ fabio.file_series.file_series):
+ if item < 0:
+ # negative indexing
+ item += len(self)
+ return self.__fabio_reader.fabio_file().jump_image(item).data
+ return super(FrameData, self).__getitem__(item)
+
+
+class RawHeaderData(commonh5.LazyLoadableDataset):
+ """Lazy loadable raw header"""
+
+ def __init__(self, name, fabio_reader, parent=None):
+ commonh5.LazyLoadableDataset.__init__(self, name, parent)
+ self.__fabio_reader = fabio_reader
+
+ def _create_data(self):
+ """Initialize hold data by merging all headers of each frames.
+ """
+ headers = []
+ types = set([])
+ for fabio_frame in self.__fabio_reader.iter_frames():
+ header = fabio_frame.header
+
+ data = []
+ for key, value in header.items():
+ data.append("%s: %s" % (str(key), str(value)))
+
+ data = "\n".join(data)
+ try:
+ line = data.encode("ascii")
+ types.add(numpy.string_)
+ except UnicodeEncodeError:
+ try:
+ line = data.encode("utf-8")
+ types.add(numpy.unicode_)
+ except UnicodeEncodeError:
+ # Fallback in void
+ line = numpy.void(data)
+ types.add(numpy.void)
+
+ headers.append(line)
+
+ if numpy.void in types:
+ dtype = numpy.void
+ elif numpy.unicode_ in types:
+ dtype = numpy.unicode_
+ else:
+ dtype = numpy.string_
+
+ if dtype == numpy.unicode_:
+ # h5py only support vlen unicode
+ dtype = h5py.special_dtype(vlen=str)
+
+ return numpy.array(headers, dtype=dtype)
+
+
+class MetadataGroup(commonh5.LazyLoadableGroup):
+ """Abstract class for groups containing a reference to a fabio image.
+ """
+
+ def __init__(self, name, metadata_reader, kind, parent=None, attrs=None):
+ commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
+ self.__metadata_reader = metadata_reader
+ self.__kind = kind
+
+ def _create_child(self):
+ keys = self.__metadata_reader.get_keys(self.__kind)
+ for name in keys:
+ data = self.__metadata_reader.get_value(self.__kind, name)
+ dataset = commonh5.Dataset(name, data)
+ self.add_node(dataset)
+
+ @property
+ def _metadata_reader(self):
+ return self.__metadata_reader
+
+
+class DetectorGroup(commonh5.LazyLoadableGroup):
+ """Define the detector group (sub group of instrument) using Fabio data.
+ """
+
+ def __init__(self, name, fabio_reader, parent=None, attrs=None):
+ if attrs is None:
+ attrs = {"NX_class": "NXdetector"}
+ commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
+ self.__fabio_reader = fabio_reader
+
+ def _create_child(self):
+ data = FrameData("data", self.__fabio_reader)
+ self.add_node(data)
+
+ # TODO we should add here Nexus informations we can extract from the
+ # metadata
+
+ others = MetadataGroup("others", self.__fabio_reader, kind=FabioReader.DEFAULT)
+ self.add_node(others)
+
+
+class ImageGroup(commonh5.LazyLoadableGroup):
+ """Define the image group (sub group of measurement) using Fabio data.
+ """
+
+ def __init__(self, name, fabio_reader, parent=None, attrs=None):
+ commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
+ self.__fabio_reader = fabio_reader
+
+ def _create_child(self):
+ basepath = self.parent.parent.name
+ data = commonh5.SoftLink("data", path=basepath + "/instrument/detector_0/data")
+ self.add_node(data)
+ detector = commonh5.SoftLink("info", path=basepath + "/instrument/detector_0")
+ self.add_node(detector)
+
+
+class NxDataPreviewGroup(commonh5.LazyLoadableGroup):
+ """Define the NxData group which is used as the default NXdata to show the
+ content of the file.
+ """
+
+ def __init__(self, name, fabio_reader, parent=None):
+ if fabio_reader.is_spectrum():
+ interpretation = "spectrum"
+ else:
+ interpretation = "image"
+ attrs = {
+ "NX_class": "NXdata",
+ "interpretation": interpretation,
+ "signal": "data",
+ }
+ commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
+ self.__fabio_reader = fabio_reader
+
+ def _create_child(self):
+ basepath = self.parent.name
+ data = commonh5.SoftLink("data", path=basepath + "/instrument/detector_0/data")
+ self.add_node(data)
+
+
+class SampleGroup(commonh5.LazyLoadableGroup):
+ """Define the image group (sub group of measurement) using Fabio data.
+ """
+
+ def __init__(self, name, fabio_reader, parent=None):
+ attrs = {"NXclass": "NXsample"}
+ commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
+ self.__fabio_reader = fabio_reader
+
+ def _create_child(self):
+ if self.__fabio_reader.has_ub_matrix():
+ scalar = {"interpretation": "scalar"}
+ data = self.__fabio_reader.get_unit_cell_abc()
+ data = commonh5.Dataset("unit_cell_abc", data, attrs=scalar)
+ self.add_node(data)
+ unit_cell_data = numpy.zeros((1, 6), numpy.float32)
+ unit_cell_data[0, :3] = data
+ data = self.__fabio_reader.get_unit_cell_alphabetagamma()
+ data = commonh5.Dataset("unit_cell_alphabetagamma", data, attrs=scalar)
+ self.add_node(data)
+ unit_cell_data[0, 3:] = data
+ data = commonh5.Dataset("unit_cell", unit_cell_data, attrs=scalar)
+ self.add_node(data)
+ data = self.__fabio_reader.get_ub_matrix()
+ data = commonh5.Dataset("ub_matrix", data, attrs=scalar)
+ self.add_node(data)
+
+
+class MeasurementGroup(commonh5.LazyLoadableGroup):
+ """Define the measurement group for fabio file.
+ """
+
+ def __init__(self, name, fabio_reader, parent=None, attrs=None):
+ commonh5.LazyLoadableGroup.__init__(self, name, parent, attrs)
+ self.__fabio_reader = fabio_reader
+
+ def _create_child(self):
+ keys = self.__fabio_reader.get_keys(FabioReader.COUNTER)
+
+ # create image measurement but take care that no other metadata use
+ # this name
+ for i in range(1000):
+ name = "image_%i" % i
+ if name not in keys:
+ data = ImageGroup(name, self.__fabio_reader)
+ self.add_node(data)
+ break
+ else:
+ raise Exception("image_i for 0..1000 already used")
+
+ # add all counters
+ for name in keys:
+ data = self.__fabio_reader.get_value(FabioReader.COUNTER, name)
+ dataset = commonh5.Dataset(name, data)
+ self.add_node(dataset)
+
+
+class FabioReader(object):
+ """Class which read and cache data and metadata from a fabio image."""
+
+ DEFAULT = 0
+ COUNTER = 1
+ POSITIONER = 2
+
+ def __init__(self, file_name=None, fabio_image=None, file_series=None):
+ """
+ Constructor
+
+ :param str file_name: File name of the image file to read
+ :param fabio.fabioimage.FabioImage fabio_image: An already openned
+ :class:`fabio.fabioimage.FabioImage` instance.
+ :param Union[list[str],fabio.file_series.file_series] file_series: An
+ list of file name or a :class:`fabio.file_series.file_series`
+ instance
+ """
+ self.__at_least_32bits = False
+ self.__signed_type = False
+
+ self.__load(file_name, fabio_image, file_series)
+ self.__counters = {}
+ self.__positioners = {}
+ self.__measurements = {}
+ self.__key_filters = set([])
+ self.__data = None
+ self.__frame_count = self.frame_count()
+ self._read()
+
+ def __load(self, file_name=None, fabio_image=None, file_series=None):
+ if file_name is not None and fabio_image:
+ raise TypeError("Parameters file_name and fabio_image are mutually exclusive.")
+ if file_name is not None and fabio_image:
+ raise TypeError("Parameters fabio_image and file_series are mutually exclusive.")
+
+ self.__must_be_closed = False
+
+ if file_name is not None:
+ self.__fabio_file = fabio.open(file_name)
+ self.__must_be_closed = True
+ elif fabio_image is not None:
+ if isinstance(fabio_image, fabio.fabioimage.FabioImage):
+ self.__fabio_file = fabio_image
+ else:
+ raise TypeError("FabioImage expected but %s found.", fabio_image.__class__)
+ elif file_series is not None:
+ if isinstance(file_series, list):
+ self.__fabio_file = _FileSeries(file_series)
+ elif isinstance(file_series, fabio.file_series.file_series):
+ self.__fabio_file = file_series
+ else:
+ raise TypeError("file_series or list expected but %s found.", file_series.__class__)
+
+ def close(self):
+ """Close the object, and free up associated resources.
+
+ The associated FabioImage is closed only if the object was created from
+ a filename by this class itself.
+
+ After calling this method, attempts to use the object (and children)
+ may fail.
+ """
+ if self.__must_be_closed:
+ # Make sure the API of fabio provide it a 'close' method
+ # TODO the test can be removed if fabio version >= 0.8
+ if hasattr(self.__fabio_file, "close"):
+ self.__fabio_file.close()
+ self.__fabio_file = None
+
+ def fabio_file(self):
+ return self.__fabio_file
+
+ def frame_count(self):
+ """Returns the number of frames available."""
+ if isinstance(self.__fabio_file, fabio.file_series.file_series):
+ return len(self.__fabio_file)
+ elif isinstance(self.__fabio_file, fabio.fabioimage.FabioImage):
+ return self.__fabio_file.nframes
+ else:
+ raise TypeError("Unsupported type %s", self.__fabio_file.__class__)
+
+ def iter_frames(self):
+ """Iter all the available frames.
+
+ A frame provides at least `data` and `header` attributes.
+ """
+ if isinstance(self.__fabio_file, fabio.file_series.file_series):
+ for file_number in range(len(self.__fabio_file)):
+ with self.__fabio_file.jump_image(file_number) as fabio_image:
+ # return the first frame only
+ assert(fabio_image.nframes == 1)
+ yield fabio_image
+ elif isinstance(self.__fabio_file, fabio.fabioimage.FabioImage):
+ for frame_count in range(self.__fabio_file.nframes):
+ if self.__fabio_file.nframes == 1:
+ yield self.__fabio_file
+ else:
+ yield self.__fabio_file.getframe(frame_count)
+ else:
+ raise TypeError("Unsupported type %s", self.__fabio_file.__class__)
+
+ def _create_data(self):
+ """Initialize hold data by merging all frames into a single cube.
+
+ Choose the cube size which fit the best the data. If some images are
+ smaller than expected, the empty space is set to 0.
+
+ The computation is cached into the class, and only done ones.
+ """
+ images = []
+ for fabio_frame in self.iter_frames():
+ images.append(fabio_frame.data)
+
+ # returns the data without extra dim in case of single frame
+ if len(images) == 1:
+ return images[0]
+
+ # get the max size
+ max_dim = max([i.ndim for i in images])
+ max_shape = [0] * max_dim
+ for image in images:
+ for dim in range(image.ndim):
+ if image.shape[dim] > max_shape[dim]:
+ max_shape[dim] = image.shape[dim]
+ max_shape = tuple(max_shape)
+
+ # fix smallest images
+ for index, image in enumerate(images):
+ if image.shape == max_shape:
+ continue
+ location = [slice(0, i) for i in image.shape]
+ while len(location) < max_dim:
+ location.append(0)
+ normalized_image = numpy.zeros(max_shape, dtype=image.dtype)
+ normalized_image[tuple(location)] = image
+ images[index] = normalized_image
+
+ # create a cube
+ return numpy.array(images)
+
+ def __get_dict(self, kind):
+ """Returns a dictionary from according to an expected kind"""
+ if kind == self.DEFAULT:
+ return self.__measurements
+ elif kind == self.COUNTER:
+ return self.__counters
+ elif kind == self.POSITIONER:
+ return self.__positioners
+ else:
+ raise Exception("Unexpected kind %s", kind)
+
+ def get_data(self):
+ """Returns a cube from all available data from frames
+
+ :rtype: numpy.ndarray
+ """
+ if self.__data is None:
+ self.__data = self._create_data()
+ return self.__data
+
+ def get_keys(self, kind):
+ """Get all available keys according to a kind of metadata.
+
+ :rtype: list
+ """
+ return self.__get_dict(kind).keys()
+
+ def get_value(self, kind, name):
+ """Get a metadata value according to the kind and the name.
+
+ :rtype: numpy.ndarray
+ """
+ value = self.__get_dict(kind)[name]
+ if not isinstance(value, numpy.ndarray):
+ if kind in [self.COUNTER, self.POSITIONER]:
+ # Force normalization for counters and positioners
+ old = self._set_vector_normalization(at_least_32bits=True, signed_type=True)
+ else:
+ old = None
+ value = self._convert_metadata_vector(value)
+ self.__get_dict(kind)[name] = value
+ if old is not None:
+ self._set_vector_normalization(*old)
+ return value
+
+ def _set_counter_value(self, frame_id, name, value):
+ """Set a counter metadata according to the frame id"""
+ if name not in self.__counters:
+ self.__counters[name] = [None] * self.__frame_count
+ self.__counters[name][frame_id] = value
+
+ def _set_positioner_value(self, frame_id, name, value):
+ """Set a positioner metadata according to the frame id"""
+ if name not in self.__positioners:
+ self.__positioners[name] = [None] * self.__frame_count
+ self.__positioners[name][frame_id] = value
+
+ def _set_measurement_value(self, frame_id, name, value):
+ """Set a measurement metadata according to the frame id"""
+ if name not in self.__measurements:
+ self.__measurements[name] = [None] * self.__frame_count
+ self.__measurements[name][frame_id] = value
+
+ def _enable_key_filters(self, fabio_file):
+ self.__key_filters.clear()
+ if hasattr(fabio_file, "RESERVED_HEADER_KEYS"):
+ # Provided in fabio 0.5
+ for key in fabio_file.RESERVED_HEADER_KEYS:
+ self.__key_filters.add(key.lower())
+
+ def _read(self):
+ """Read all metadata from the fabio file and store it into this
+ object."""
+
+ file_series = isinstance(self.__fabio_file, fabio.file_series.file_series)
+ if not file_series:
+ self._enable_key_filters(self.__fabio_file)
+
+ for frame_id, fabio_frame in enumerate(self.iter_frames()):
+ if file_series:
+ self._enable_key_filters(fabio_frame)
+ self._read_frame(frame_id, fabio_frame.header)
+
+ def _is_filtered_key(self, key):
+ """
+ If this function returns True, the :meth:`_read_key` while not be
+ called with this `key`while reading the metatdata frame.
+
+ :param str key: A key of the metadata
+ :rtype: bool
+ """
+ return key.lower() in self.__key_filters
+
+ def _read_frame(self, frame_id, header):
+ """Read all metadata from a frame and store it into this
+ object."""
+ for key, value in header.items():
+ if self._is_filtered_key(key):
+ continue
+ self._read_key(frame_id, key, value)
+
+ def _read_key(self, frame_id, name, value):
+ """Read a key from the metadata and cache it into this object."""
+ self._set_measurement_value(frame_id, name, value)
+
+ def _set_vector_normalization(self, at_least_32bits, signed_type):
+ previous = self.__at_least_32bits, self.__signed_type
+ self.__at_least_32bits = at_least_32bits
+ self.__signed_type = signed_type
+ return previous
+
+ def _normalize_vector_type(self, dtype):
+ """Normalize the """
+ if self.__at_least_32bits:
+ if numpy.issubdtype(dtype, numpy.signedinteger):
+ dtype = numpy.result_type(dtype, numpy.uint32)
+ if numpy.issubdtype(dtype, numpy.unsignedinteger):
+ dtype = numpy.result_type(dtype, numpy.uint32)
+ elif numpy.issubdtype(dtype, numpy.floating):
+ dtype = numpy.result_type(dtype, numpy.float32)
+ elif numpy.issubdtype(dtype, numpy.complexfloating):
+ dtype = numpy.result_type(dtype, numpy.complex64)
+ if self.__signed_type:
+ if numpy.issubdtype(dtype, numpy.unsignedinteger):
+ signed = numpy.dtype("%s%i" % ('i', dtype.itemsize))
+ dtype = numpy.result_type(dtype, signed)
+ return dtype
+
+ def _convert_metadata_vector(self, values):
+ """Convert a list of numpy data into a numpy array with the better
+ fitting type."""
+ converted = []
+ types = set([])
+ has_none = False
+ is_array = False
+ array = []
+
+ for v in values:
+ if v is None:
+ converted.append(None)
+ has_none = True
+ array.append(None)
+ else:
+ c = self._convert_value(v)
+ if c.shape != tuple():
+ array.append(v.split(" "))
+ is_array = True
+ else:
+ array.append(v)
+ converted.append(c)
+ types.add(c.dtype)
+
+ if has_none and len(types) == 0:
+ # That's a list of none values
+ return numpy.array([0] * len(values), numpy.int8)
+
+ result_type = numpy.result_type(*types)
+
+ if issubclass(result_type.type, numpy.string_):
+ # use the raw data to create the array
+ result = values
+ elif issubclass(result_type.type, numpy.unicode_):
+ # use the raw data to create the array
+ result = values
+ else:
+ result = converted
+
+ result_type = self._normalize_vector_type(result_type)
+
+ if has_none:
+ # Fix missing data according to the array type
+ if result_type.kind == "S":
+ none_value = b""
+ elif result_type.kind == "U":
+ none_value = u""
+ elif result_type.kind == "f":
+ none_value = numpy.float64("NaN")
+ elif result_type.kind == "i":
+ none_value = numpy.int64(0)
+ elif result_type.kind == "u":
+ none_value = numpy.int64(0)
+ elif result_type.kind == "b":
+ none_value = numpy.bool_(False)
+ else:
+ none_value = None
+
+ for index, r in enumerate(result):
+ if r is not None:
+ continue
+ result[index] = none_value
+ values[index] = none_value
+ array[index] = none_value
+
+ if result_type.kind in "uifd" and len(types) > 1 and len(values) > 1:
+ # Catch numerical precision
+ if is_array and len(array) > 1:
+ return numpy.array(array, dtype=result_type)
+ else:
+ return numpy.array(values, dtype=result_type)
+ return numpy.array(result, dtype=result_type)
+
+ def _convert_value(self, value):
+ """Convert a string into a numpy object (scalar or array).
+
+ The value is most of the time a string, but it can be python object
+ in case if TIFF decoder for example.
+ """
+ if isinstance(value, list):
+ # convert to a numpy array
+ return numpy.array(value)
+ if isinstance(value, dict):
+ # convert to a numpy associative array
+ key_dtype = numpy.min_scalar_type(list(value.keys()))
+ value_dtype = numpy.min_scalar_type(list(value.values()))
+ associative_type = [('key', key_dtype), ('value', value_dtype)]
+ assert key_dtype.kind != "O" and value_dtype.kind != "O"
+ return numpy.array(list(value.items()), dtype=associative_type)
+ if isinstance(value, numbers.Number):
+ dtype = numpy.min_scalar_type(value)
+ assert dtype.kind != "O"
+ return dtype.type(value)
+
+ if isinstance(value, bytes):
+ try:
+ value = value.decode('utf-8')
+ except UnicodeDecodeError:
+ return numpy.void(value)
+
+ if " " in value:
+ result = self._convert_list(value)
+ else:
+ result = self._convert_scalar_value(value)
+ return result
+
+ def _convert_scalar_value(self, value):
+ """Convert a string into a numpy int or float.
+
+ If it is not possible it returns a numpy string.
+ """
+ try:
+ numpy_type = silx.utils.number.min_numerical_convertible_type(value)
+ converted = numpy_type(value)
+ except ValueError:
+ converted = numpy.string_(value)
+ return converted
+
+ def _convert_list(self, value):
+ """Convert a string into a typed numpy array.
+
+ If it is not possible it returns a numpy string.
+ """
+ try:
+ numpy_values = []
+ values = value.split(" ")
+ types = set([])
+ for string_value in values:
+ v = self._convert_scalar_value(string_value)
+ numpy_values.append(v)
+ types.add(v.dtype.type)
+
+ result_type = numpy.result_type(*types)
+
+ if issubclass(result_type.type, (numpy.string_, bytes)):
+ # use the raw data to create the result
+ return numpy.string_(value)
+ elif issubclass(result_type.type, (numpy.unicode_, str)):
+ # use the raw data to create the result
+ return numpy.unicode_(value)
+ else:
+ if len(types) == 1:
+ return numpy.array(numpy_values, dtype=result_type)
+ else:
+ return numpy.array(values, dtype=result_type)
+ except ValueError:
+ return numpy.string_(value)
+
+ def has_sample_information(self):
+ """Returns true if there is information about the sample in the
+ file
+
+ :rtype: bool
+ """
+ return self.has_ub_matrix()
+
+ def has_ub_matrix(self):
+ """Returns true if a UB matrix is available.
+
+ :rtype: bool
+ """
+ return False
+
+ def is_spectrum(self):
+ """Returns true if the data should be interpreted as
+ MCA data.
+
+ :rtype: bool
+ """
+ return False
+
+
+class EdfFabioReader(FabioReader):
+ """Class which read and cache data and metadata from a fabio image.
+
+ It is mostly the same as FabioReader, but counter_mne and
+ motor_mne are parsed using a special way.
+ """
+
+ def __init__(self, file_name=None, fabio_image=None, file_series=None):
+ FabioReader.__init__(self, file_name, fabio_image, file_series)
+ self.__unit_cell_abc = None
+ self.__unit_cell_alphabetagamma = None
+ self.__ub_matrix = None
+
+ def _read_frame(self, frame_id, header):
+ """Overwrite the method to check and parse special keys: counter and
+ motors keys."""
+ self.__catch_keys = set([])
+ if "motor_pos" in header and "motor_mne" in header:
+ self.__catch_keys.add("motor_pos")
+ self.__catch_keys.add("motor_mne")
+ self._read_mnemonic_key(frame_id, "motor", header)
+ if "counter_pos" in header and "counter_mne" in header:
+ self.__catch_keys.add("counter_pos")
+ self.__catch_keys.add("counter_mne")
+ self._read_mnemonic_key(frame_id, "counter", header)
+ FabioReader._read_frame(self, frame_id, header)
+
+ def _is_filtered_key(self, key):
+ if key in self.__catch_keys:
+ return True
+ return FabioReader._is_filtered_key(self, key)
+
+ def _get_mnemonic_key(self, base_key, header):
+ mnemonic_values_key = base_key + "_mne"
+ mnemonic_values = header.get(mnemonic_values_key, "")
+ mnemonic_values = mnemonic_values.split()
+ pos_values_key = base_key + "_pos"
+ pos_values = header.get(pos_values_key, "")
+ pos_values = pos_values.split()
+
+ result = collections.OrderedDict()
+ nbitems = max(len(mnemonic_values), len(pos_values))
+ for i in range(nbitems):
+ if i < len(mnemonic_values):
+ mnemonic = mnemonic_values[i]
+ else:
+ # skip the element
+ continue
+
+ if i < len(pos_values):
+ pos = pos_values[i]
+ else:
+ pos = None
+
+ result[mnemonic] = pos
+ return result
+
+ def _read_mnemonic_key(self, frame_id, base_key, header):
+ """Parse a mnemonic key"""
+ is_counter = base_key == "counter"
+ is_positioner = base_key == "motor"
+ data = self._get_mnemonic_key(base_key, header)
+
+ for mnemonic, pos in data.items():
+ if is_counter:
+ self._set_counter_value(frame_id, mnemonic, pos)
+ elif is_positioner:
+ self._set_positioner_value(frame_id, mnemonic, pos)
+ else:
+ raise Exception("State unexpected (base_key: %s)" % base_key)
+
+ def _get_first_header(self):
+ """
+ ..note:: This function can be cached
+ """
+ fabio_file = self.fabio_file()
+ if isinstance(fabio_file, fabio.file_series.file_series):
+ return fabio_file.jump_image(0).header
+ return fabio_file.header
+
+ def has_ub_matrix(self):
+ """Returns true if a UB matrix is available.
+
+ :rtype: bool
+ """
+ header = self._get_first_header()
+ expected_keys = set(["UB_mne", "UB_pos", "sample_mne", "sample_pos"])
+ return expected_keys.issubset(header)
+
+ def parse_ub_matrix(self):
+ header = self._get_first_header()
+ ub_data = self._get_mnemonic_key("UB", header)
+ s_data = self._get_mnemonic_key("sample", header)
+ if len(ub_data) > 9:
+ _logger.warning("UB_mne and UB_pos contains more than expected keys.")
+ if len(s_data) > 6:
+ _logger.warning("sample_mne and sample_pos contains more than expected keys.")
+
+ data = numpy.array([s_data["U0"], s_data["U1"], s_data["U2"]], dtype=float)
+ unit_cell_abc = data
+
+ data = numpy.array([s_data["U3"], s_data["U4"], s_data["U5"]], dtype=float)
+ unit_cell_alphabetagamma = data
+
+ ub_matrix = numpy.array([[
+ [ub_data["UB0"], ub_data["UB1"], ub_data["UB2"]],
+ [ub_data["UB3"], ub_data["UB4"], ub_data["UB5"]],
+ [ub_data["UB6"], ub_data["UB7"], ub_data["UB8"]]]], dtype=float)
+
+ self.__unit_cell_abc = unit_cell_abc
+ self.__unit_cell_alphabetagamma = unit_cell_alphabetagamma
+ self.__ub_matrix = ub_matrix
+
+ def get_unit_cell_abc(self):
+ """Get a numpy array data as defined for the dataset unit_cell_abc
+ from the NXsample dataset.
+
+ :rtype: numpy.ndarray
+ """
+ if self.__unit_cell_abc is None:
+ self.parse_ub_matrix()
+ return self.__unit_cell_abc
+
+ def get_unit_cell_alphabetagamma(self):
+ """Get a numpy array data as defined for the dataset
+ unit_cell_alphabetagamma from the NXsample dataset.
+
+ :rtype: numpy.ndarray
+ """
+ if self.__unit_cell_alphabetagamma is None:
+ self.parse_ub_matrix()
+ return self.__unit_cell_alphabetagamma
+
+ def get_ub_matrix(self):
+ """Get a numpy array data as defined for the dataset ub_matrix
+ from the NXsample dataset.
+
+ :rtype: numpy.ndarray
+ """
+ if self.__ub_matrix is None:
+ self.parse_ub_matrix()
+ return self.__ub_matrix
+
+ def is_spectrum(self):
+ """Returns true if the data should be interpreted as
+ MCA data.
+ EDF files or file series, with two or more header names starting with
+ "MCA", should be interpreted as MCA data.
+
+ :rtype: bool
+ """
+ count = 0
+ for key in self._get_first_header():
+ if key.lower().startswith("mca"):
+ count += 1
+ if count >= 2:
+ return True
+ return False
+
+
+class File(commonh5.File):
+ """Class which handle a fabio image as a mimick of a h5py.File.
+ """
+
+ def __init__(self, file_name=None, fabio_image=None, file_series=None):
+ """
+ Constructor
+
+ :param str file_name: File name of the image file to read
+ :param fabio.fabioimage.FabioImage fabio_image: An already openned
+ :class:`fabio.fabioimage.FabioImage` instance.
+ :param Union[list[str],fabio.file_series.file_series] file_series: An
+ list of file name or a :class:`fabio.file_series.file_series`
+ instance
+ """
+ self.__fabio_reader = self.create_fabio_reader(file_name, fabio_image, file_series)
+ if fabio_image is not None:
+ file_name = fabio_image.filename
+ scan = self.create_scan_group(self.__fabio_reader)
+
+ attrs = {"NX_class": "NXroot",
+ "file_time": datetime.datetime.now().isoformat(),
+ "creator": "silx %s" % silx_version,
+ "default": scan.basename}
+ if file_name is not None:
+ attrs["file_name"] = file_name
+ commonh5.File.__init__(self, name=file_name, attrs=attrs)
+ self.add_node(scan)
+
+ def create_scan_group(self, fabio_reader):
+ """Factory to create the scan group.
+
+ :param FabioImage fabio_image: A Fabio image
+ :param FabioReader fabio_reader: A reader for the Fabio image
+ :rtype: commonh5.Group
+ """
+ nxdata = NxDataPreviewGroup("image", fabio_reader)
+ scan_attrs = {
+ "NX_class": "NXentry",
+ "default": nxdata.basename,
+ }
+ scan = commonh5.Group("scan_0", attrs=scan_attrs)
+ instrument = commonh5.Group("instrument", attrs={"NX_class": "NXinstrument"})
+ measurement = MeasurementGroup("measurement", fabio_reader, attrs={"NX_class": "NXcollection"})
+ file_ = commonh5.Group("file", attrs={"NX_class": "NXcollection"})
+ positioners = MetadataGroup("positioners", fabio_reader, FabioReader.POSITIONER, attrs={"NX_class": "NXpositioner"})
+ raw_header = RawHeaderData("scan_header", fabio_reader, self)
+ detector = DetectorGroup("detector_0", fabio_reader)
+
+ scan.add_node(instrument)
+ instrument.add_node(positioners)
+ instrument.add_node(file_)
+ instrument.add_node(detector)
+ file_.add_node(raw_header)
+ scan.add_node(measurement)
+ scan.add_node(nxdata)
+
+ if fabio_reader.has_sample_information():
+ sample = SampleGroup("sample", fabio_reader)
+ scan.add_node(sample)
+
+ return scan
+
+ def create_fabio_reader(self, file_name, fabio_image, file_series):
+ """Factory to create fabio reader.
+
+ :rtype: FabioReader"""
+ use_edf_reader = False
+ first_file_name = None
+ first_image = None
+
+ if isinstance(file_series, list):
+ first_file_name = file_series[0]
+ elif isinstance(file_series, fabio.file_series.file_series):
+ first_image = file_series.first_image()
+ elif fabio_image is not None:
+ first_image = fabio_image
+ else:
+ first_file_name = file_name
+
+ if first_file_name is not None:
+ _, ext = os.path.splitext(first_file_name)
+ ext = ext[1:]
+ edfimage = fabio.edfimage.EdfImage
+ if hasattr(edfimage, "DEFAULT_EXTENTIONS"):
+ # Typo on fabio 0.5
+ edf_extensions = edfimage.DEFAULT_EXTENTIONS
+ else:
+ edf_extensions = edfimage.DEFAULT_EXTENSIONS
+ use_edf_reader = ext in edf_extensions
+ elif first_image is not None:
+ use_edf_reader = isinstance(first_image, fabio.edfimage.EdfImage)
+ else:
+ assert(False)
+
+ if use_edf_reader:
+ reader = EdfFabioReader(file_name, fabio_image, file_series)
+ else:
+ reader = FabioReader(file_name, fabio_image, file_series)
+ return reader
+
+ def close(self):
+ """Close the object, and free up associated resources.
+
+ After calling this method, attempts to use the object (and children)
+ may fail.
+ """
+ self.__fabio_reader.close()
+ self.__fabio_reader = None
diff --git a/src/silx/io/fioh5.py b/src/silx/io/fioh5.py
new file mode 100644
index 0000000..75fe587
--- /dev/null
+++ b/src/silx/io/fioh5.py
@@ -0,0 +1,490 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2021 Timo Fuchs
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides a h5py-like API to access FioFile data.
+
+API description
++++++++++++++++
+
+Fiofile data structure exposed by this API:
+
+::
+
+ /
+ n.1/
+ title = "…"
+ start_time = "…"
+ instrument/
+ fiofile/
+ comments = "…"
+ parameter = "…"
+ comment = "…"
+ parameter/
+ parameter_name = value
+
+ measurement/
+ colname0 = …
+ colname1 = …
+ …
+
+
+The top level scan number ``n.1`` is determined from the filename as in
+``prefix_n.fio``. (e.g. ``eh1_sixc_00045.fio`` would give ``45.1``)
+If no number is available, will use the filename instead.
+
+``comments`` and ``parameter`` in group ``fiofile`` are the raw headers as they
+appear in the original file, as a string of lines separated by newline
+(``\\n``) characters. ``comment`` are the remaining comments,
+which were not parsed.
+
+
+
+The title is the content of the first comment header line
+(e.g ``"ascan ss1vo -4.55687 -0.556875 40 0.2"``).
+The start_time is parsed from the second comment line.
+
+Datasets are stored in the data format specified in the fio file header.
+
+Scan data (e.g. ``/1.1/measurement/colname0``) is accessed by column,
+the dataset name ``colname0`` being the column label as defined in the
+``Col …`` header line.
+
+If a ``/`` character is present in a column label or in a motor name in the
+original FIO file, it will be substituted with a ``%`` character in the
+corresponding dataset name.
+
+MCA data is not yet supported.
+
+This reader requires a fio file as defined in
+src/sardana/macroserver/recorders/storage.py of the Sardana project
+(https://github.com/sardana-org/sardana).
+
+
+Accessing data
+++++++++++++++
+
+Data and groups are accessed in :mod:`h5py` fashion::
+
+ from silx.io.fioh5 import FioH5
+
+ # Open a FioFile
+ fiofh5 = FioH5("test_00056.fio")
+
+ # using FioH5 as a regular group to access scans
+ scan1group = fiofh5["56.1"]
+ instrument_group = scan1group["instrument"]
+
+ # alternative: full path access
+ measurement_group = fiofh5["/56.1/measurement"]
+
+ # accessing a scan data column by name as a 1D numpy array
+ data_array = measurement_group["Pslit HGap"]
+
+
+:class:`FioH5` files and groups provide a :meth:`keys` method::
+
+ >>> fiofh5.keys()
+ ['96.1', '97.1', '98.1']
+ >>> fiofh5['96.1'].keys()
+ ['title', 'start_time', 'instrument', 'measurement']
+
+They can also be treated as iterators:
+
+.. code-block:: python
+
+ from silx.io import is_dataset
+
+ for scan_group in FioH5("test_00056.fio"):
+ dataset_names = [item.name in scan_group["measurement"] if
+ is_dataset(item)]
+ print("Found data columns in scan " + scan_group.name)
+ print(", ".join(dataset_names))
+
+You can test for existence of data or groups::
+
+ >>> "/1.1/measurement/Pslit HGap" in fiofh5
+ True
+ >>> "positioners" in fiofh5["/2.1/instrument"]
+ True
+ >>> "spam" in fiofh5["1.1"]
+ False
+
+"""
+
+__authors__ = ["T. Fuchs"]
+__license__ = "MIT"
+__date__ = "09/04/2021"
+
+
+import os
+
+import datetime
+import logging
+import io
+
+import h5py
+import numpy
+
+from silx import version as silx_version
+from . import commonh5
+
+from .spech5 import to_h5py_utf8
+
+logger1 = logging.getLogger(__name__)
+
+if h5py.version.version_tuple[0] < 3:
+ text_dtype = h5py.special_dtype(vlen=str) # old API
+else:
+ text_dtype = 'O' # variable-length string (supported as of h5py > 3.0)
+
+ABORTLINENO = 5
+
+dtypeConverter = {'STRING': text_dtype,
+ 'DOUBLE': 'f8',
+ 'FLOAT': 'f4',
+ 'INTEGER': 'i8',
+ 'BOOLEAN': '?'}
+
+
+def is_fiofile(filename):
+ """Test if a file is a FIO file, by checking if three consecutive lines
+ start with *!*. Tests up to ABORTLINENO lines at the start of the file.
+
+ :param str filename: File path
+ :return: *True* if file is a FIO file, *False* if it is not a FIO file
+ :rtype: bool
+ """
+ if not os.path.isfile(filename):
+ return False
+ # test for presence of three ! in first lines
+ with open(filename, "rb") as f:
+ chunk = f.read(2500)
+ count = 0
+ for i, line in enumerate(chunk.split(b"\n")):
+ if line.startswith(b"!"):
+ count += 1
+ if count >= 3:
+ return True
+ else:
+ count = 0
+ if i >= ABORTLINENO:
+ break
+ return False
+
+
+class FioFile(object):
+ """This class opens a FIO file and reads the data.
+
+ """
+
+ def __init__(self, filepath):
+ # parse filename
+ filename = os.path.basename(filepath)
+ fnowithsuffix = filename.split('_')[-1]
+ try:
+ self.scanno = int(fnowithsuffix.split('.')[0])
+ except Exception:
+ self.scanno = None
+ logger1.warning("Cannot parse scan number of file %s", filename)
+
+ with open(filepath, 'r') as fiof:
+
+ prev = 0
+ line_counter = 0
+
+ while(True):
+ line = fiof.readline()
+ if line.startswith('!'): # skip comments
+ prev = fiof.tell()
+ line_counter = 0
+ continue
+ if line.startswith('%c'): # comment section
+ line_counter = 0
+ self.commentsection = ''
+ line = fiof.readline()
+ while(not line.startswith('%')
+ and not line.startswith('!')):
+ self.commentsection += line
+ prev = fiof.tell()
+ line = fiof.readline()
+ if line.startswith('%p'): # parameter section
+ line_counter = 0
+ self.parameterssection = ''
+ line = fiof.readline()
+ while(not line.startswith('%')
+ and not line.startswith('!')):
+ self.parameterssection += line
+ prev = fiof.tell()
+ line = fiof.readline()
+ if line.startswith('%d'): # data type definitions
+ line_counter = 0
+ self.datacols = []
+ self.names = []
+ self.dtypes = []
+ line = fiof.readline()
+ while(line.startswith(' Col')):
+ splitline = line.split()
+ name = splitline[-2]
+ self.names.append(name)
+ dtype = dtypeConverter[splitline[-1]]
+ self.dtypes.append(dtype)
+ self.datacols.append((name, dtype))
+ prev = fiof.tell()
+ line = fiof.readline()
+ fiof.seek(prev)
+ break
+
+ line_counter += 1
+ if line_counter > ABORTLINENO:
+ raise IOError("Invalid fio file: Found no data "
+ "after %s lines" % ABORTLINENO)
+
+ self.data = numpy.loadtxt(fiof,
+ dtype={'names': tuple(self.names),
+ 'formats': tuple(self.dtypes)},
+ comments="!")
+
+ # ToDo: read only last line of file,
+ # which sometimes contains the end of acquisition timestamp.
+
+ self.parameter = {}
+
+ # parse parameter section:
+ try:
+ for line in self.parameterssection.splitlines():
+ param, value = line.split(' = ')
+ self.parameter[param] = value
+ except Exception:
+ logger1.warning("Cannot parse parameter section")
+
+ # parse default sardana comments: username and start time
+ try:
+ acquiMarker = "acquisition started at" # indicates timestamp
+ commentlines = self.commentsection.splitlines()
+ if len(commentlines) >= 2:
+ self.title = commentlines[0]
+ l2 = commentlines[1]
+ acqpos = l2.lower().find(acquiMarker)
+ if acqpos < 0:
+ raise Exception("acquisition str not found")
+
+ self.user = l2[:acqpos][4:].strip()
+ self.start_time = l2[acqpos+len(acquiMarker):].strip()
+ commentlines = commentlines[2:]
+ self.comments = "\n".join(commentlines[2:])
+
+ except Exception:
+ logger1.warning("Cannot parse default comment section")
+ self.comments = self.commentsection
+ self.user = ""
+ self.start_time = ""
+ self.title = ""
+
+
+class FioH5NodeDataset(commonh5.Dataset):
+ """This class inherits :class:`commonh5.Dataset`, to which it adds
+ little extra functionality. The main additional functionality is the
+ proxy behavior that allows to mimic the numpy array stored in this
+ class.
+ """
+
+ def __init__(self, name, data, parent=None, attrs=None):
+ # get proper value types, to inherit from numpy
+ # attributes (dtype, shape, size)
+ if isinstance(data, str):
+ # use unicode (utf-8 when saved to HDF5 output)
+ value = to_h5py_utf8(data)
+ elif isinstance(data, float):
+ # use 32 bits for float scalars
+ value = numpy.float32(data)
+ elif isinstance(data, int):
+ value = numpy.int_(data)
+ else:
+ # Enforce numpy array
+ array = numpy.array(data)
+ data_kind = array.dtype.kind
+
+ if data_kind in ["S", "U"]:
+ value = numpy.asarray(array,
+ dtype=text_dtype)
+ else:
+ value = array # numerical data is already the correct datatype
+ commonh5.Dataset.__init__(self, name, value, parent, attrs)
+
+ def __getattr__(self, item):
+ """Proxy to underlying numpy array methods.
+ """
+ if hasattr(self[()], item):
+ return getattr(self[()], item)
+
+ raise AttributeError("FioH5NodeDataset has no attribute %s" % item)
+
+
+class FioH5(commonh5.File):
+ """This class reads a FIO file and exposes it as a *h5py.File*.
+
+ It inherits :class:`silx.io.commonh5.Group` (via :class:`commonh5.File`),
+ which implements most of its API.
+ """
+
+ def __init__(self, filename, order=1):
+ """
+ :param filename: Path to FioFile in filesystem
+ :type filename: str
+ """
+ if isinstance(filename, io.IOBase):
+ # see https://github.com/silx-kit/silx/issues/858
+ filename = filename.name
+
+ if not is_fiofile(filename):
+ raise IOError("File %s is not a FIO file." % filename)
+
+ try:
+ fiof = FioFile(filename) # reads complete file
+ except Exception as e:
+ raise IOError("FIO file %s cannot be read.") from e
+
+ attrs = {"NX_class": to_h5py_utf8("NXroot"),
+ "file_time": to_h5py_utf8(
+ datetime.datetime.now().isoformat()),
+ "file_name": to_h5py_utf8(filename),
+ "creator": to_h5py_utf8("silx fioh5 %s" % silx_version)}
+ commonh5.File.__init__(self, filename, attrs=attrs)
+
+ if fiof.scanno is not None:
+ scan_key = "%s.%s" % (fiof.scanno, int(order))
+ else:
+ scan_key = os.path.splitext(os.path.basename(filename))[0]
+
+ scan_group = FioScanGroup(scan_key, parent=self, scan=fiof)
+ self.add_node(scan_group)
+
+
+class FioScanGroup(commonh5.Group):
+ def __init__(self, scan_key, parent, scan):
+ """
+
+ :param parent: parent Group
+ :param str scan_key: Scan key (e.g. "1.1")
+ :param scan: FioFile object
+ """
+ if hasattr(scan, 'user'):
+ userattr = to_h5py_utf8(scan.user)
+ else:
+ userattr = to_h5py_utf8('')
+ commonh5.Group.__init__(self, scan_key, parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXentry"),
+ "user": userattr})
+
+ # 'title', 'start_time' and 'user' are defaults
+ # in Sardana created files:
+ if hasattr(scan, 'title'):
+ title = scan.title
+ else:
+ title = scan_key # use scan number as default title
+ self.add_node(FioH5NodeDataset(name="title",
+ data=to_h5py_utf8(title),
+ parent=self))
+
+ if hasattr(scan, 'start_time'):
+ start_time = scan.start_time
+ self.add_node(FioH5NodeDataset(name="start_time",
+ data=to_h5py_utf8(start_time),
+ parent=self))
+
+ self.add_node(FioH5NodeDataset(name="comments",
+ data=to_h5py_utf8(scan.comments),
+ parent=self))
+
+ self.add_node(FioInstrumentGroup(parent=self, scan=scan))
+ self.add_node(FioMeasurementGroup(parent=self, scan=scan))
+
+
+class FioMeasurementGroup(commonh5.Group):
+ def __init__(self, parent, scan):
+ """
+
+ :param parent: parent Group
+ :param scan: FioFile object
+ """
+ commonh5.Group.__init__(self, name="measurement", parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")})
+
+ for label in scan.names:
+ safe_label = label.replace("/", "%")
+ self.add_node(FioH5NodeDataset(name=safe_label,
+ data=scan.data[label],
+ parent=self))
+
+
+class FioInstrumentGroup(commonh5.Group):
+ def __init__(self, parent, scan):
+ """
+
+ :param parent: parent Group
+ :param scan: FioFile object
+ """
+ commonh5.Group.__init__(self, name="instrument", parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXinstrument")})
+
+ self.add_node(FioParameterGroup(parent=self, scan=scan))
+ self.add_node(FioFileGroup(parent=self, scan=scan))
+ self.add_node(FioH5NodeDataset(name="comment",
+ data=to_h5py_utf8(scan.comments),
+ parent=self))
+
+
+class FioFileGroup(commonh5.Group):
+ def __init__(self, parent, scan):
+ """
+
+ :param parent: parent Group
+ :param scan: FioFile object
+ """
+ commonh5.Group.__init__(self, name="fiofile", parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")})
+
+ self.add_node(FioH5NodeDataset(name="comments",
+ data=to_h5py_utf8(scan.commentsection),
+ parent=self))
+
+ self.add_node(FioH5NodeDataset(name="parameter",
+ data=to_h5py_utf8(scan.parameterssection),
+ parent=self))
+
+
+class FioParameterGroup(commonh5.Group):
+ def __init__(self, parent, scan):
+ """
+
+ :param parent: parent Group
+ :param scan: FioFile object
+ """
+ commonh5.Group.__init__(self, name="parameter", parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")})
+
+ for label in scan.parameter:
+ safe_label = label.replace("/", "%")
+ self.add_node(FioH5NodeDataset(name=safe_label,
+ data=to_h5py_utf8(scan.parameter[label]),
+ parent=self))
diff --git a/src/silx/io/h5py_utils.py b/src/silx/io/h5py_utils.py
new file mode 100644
index 0000000..fb04152
--- /dev/null
+++ b/src/silx/io/h5py_utils.py
@@ -0,0 +1,440 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module provides utility methods on top of h5py, mainly to handle
+parallel writing and reading.
+"""
+
+__authors__ = ["W. de Nolf"]
+__license__ = "MIT"
+__date__ = "27/01/2020"
+
+
+import os
+import sys
+import traceback
+import logging
+import h5py
+
+from .._version import calc_hexversion
+from ..utils import retry as retry_mod
+from silx.utils.deprecation import deprecated_warning
+
+_logger = logging.getLogger(__name__)
+
+IS_WINDOWS = sys.platform == "win32"
+
+H5PY_HEX_VERSION = calc_hexversion(*h5py.version.version_tuple[:3])
+HDF5_HEX_VERSION = calc_hexversion(*h5py.version.hdf5_version_tuple[:3])
+
+HDF5_SWMR_VERSION = calc_hexversion(*h5py.get_config().swmr_min_hdf5_version[:3])
+HAS_SWMR = HDF5_HEX_VERSION >= HDF5_SWMR_VERSION
+
+HAS_TRACK_ORDER = H5PY_HEX_VERSION >= calc_hexversion(2, 9, 0)
+
+if h5py.version.hdf5_version_tuple[:2] == (1, 10):
+ HDF5_HAS_LOCKING_ARGUMENT = HDF5_HEX_VERSION >= calc_hexversion(1, 10, 7)
+else:
+ HDF5_HAS_LOCKING_ARGUMENT = HDF5_HEX_VERSION >= calc_hexversion(1, 12, 1)
+H5PY_HAS_LOCKING_ARGUMENT = H5PY_HEX_VERSION >= calc_hexversion(3, 5, 0)
+HAS_LOCKING_ARGUMENT = HDF5_HAS_LOCKING_ARGUMENT & H5PY_HAS_LOCKING_ARGUMENT
+
+LATEST_LIBVER_IS_V108 = HDF5_HEX_VERSION < calc_hexversion(1, 10, 0)
+
+
+def _libver_low_bound_is_v108(libver) -> bool:
+ if libver is None:
+ return True
+ if LATEST_LIBVER_IS_V108:
+ return True
+ if isinstance(libver, str):
+ low = libver
+ else:
+ low = libver[0]
+ if low == "latest":
+ return False
+ return low == "v108"
+
+
+def _hdf5_file_locking(mode="r", locking=None, swmr=None, libver=None, **_):
+ """Concurrent access by disabling file locking is not supported
+ in these cases:
+
+ * mode != "r": causes file corruption
+ * SWMR: does not work
+ * libver > v108 and file already locked: does not work
+ * windows and HDF5_HAS_LOCKING_ARGUMENT and file already locked: does not work
+
+ :param str or None mode: read-only by default
+ :param bool or None locking: by default it is disabled for `mode='r'`
+ and `swmr=False` and enabled for all
+ other modes.
+ :param bool or None swmr: try both modes when `mode='r'` and `swmr=None`
+ :param None or str or tuple libver:
+ :returns bool:
+ """
+ if locking is None:
+ locking = bool(mode != "r" or swmr)
+ if not locking:
+ if mode != "r":
+ raise ValueError("Locking is mandatory for HDF5 writing")
+ if swmr:
+ raise ValueError("Locking is mandatory for HDF5 SWMR mode")
+ if IS_WINDOWS and HDF5_HAS_LOCKING_ARGUMENT:
+ _logger.debug(
+ "Non-locking readers will fail when a writer has already locked the HDF5 file (this restriction applies to libhdf5 >= 1.12.1 or libhdf5 >= 1.10.7 on Windows)"
+ )
+ if not _libver_low_bound_is_v108(libver):
+ _logger.debug(
+ "Non-locking readers will fail when a writer has already locked the HDF5 file (this restriction applies to libver >= v110)"
+ )
+ return locking
+
+
+def _is_h5py_exception(e):
+ """
+ :param BaseException e:
+ :returns bool:
+ """
+ for frame in traceback.walk_tb(e.__traceback__):
+ if frame[0].f_locals.get("__package__", None) == "h5py":
+ return True
+ return False
+
+
+def _retry_h5py_error(e):
+ """
+ :param BaseException e:
+ :returns bool:
+ """
+ if _is_h5py_exception(e):
+ if isinstance(e, (OSError, RuntimeError)):
+ return True
+ elif isinstance(e, KeyError):
+ # For example this needs to be retried:
+ # KeyError: 'Unable to open object (bad object header version number)'
+ return "Unable to open object" in str(e)
+ elif isinstance(e, retry_mod.RetryError):
+ return True
+ return False
+
+
+def retry(**kw):
+ r"""Decorator for a method that needs to be executed until it not longer
+ fails on HDF5 IO. Mainly used for reading an HDF5 file that is being
+ written.
+
+ :param \**kw: see `silx.utils.retry`
+ """
+ kw.setdefault("retry_on_error", _retry_h5py_error)
+ return retry_mod.retry(**kw)
+
+
+def retry_contextmanager(**kw):
+ r"""Decorator to make a context manager from a method that needs to be
+ entered until it not longer fails on HDF5 IO. Mainly used for reading
+ an HDF5 file that is being written.
+
+ :param \**kw: see `silx.utils.retry_contextmanager`
+ """
+ kw.setdefault("retry_on_error", _retry_h5py_error)
+ return retry_mod.retry_contextmanager(**kw)
+
+
+def retry_in_subprocess(**kw):
+ r"""Same as `retry` but it also retries segmentation faults.
+
+ On Window you cannot use this decorator with the "@" syntax:
+
+ .. code-block:: python
+
+ def _method(*args, **kw):
+ ...
+
+ method = retry_in_subprocess()(_method)
+
+ :param \**kw: see `silx.utils.retry_in_subprocess`
+ """
+ kw.setdefault("retry_on_error", _retry_h5py_error)
+ return retry_mod.retry_in_subprocess(**kw)
+
+
+def group_has_end_time(h5item):
+ """Returns True when the HDF5 item is a Group with an "end_time"
+ dataset. A reader can use this as an indication that the Group
+ has been fully written (at least if the writer supports this).
+
+ :param Union[h5py.Group,h5py.Dataset] h5item:
+ :returns bool:
+ """
+ if isinstance(h5item, h5py.Group):
+ return "end_time" in h5item
+ else:
+ return False
+
+
+@retry_contextmanager()
+def open_item(filename, name, retry_invalid=False, validate=None, **open_options):
+ r"""Yield an HDF5 dataset or group (retry until it can be instantiated).
+
+ :param str filename:
+ :param bool retry_invalid: retry when item is missing or not valid
+ :param callable or None validate:
+ :param \**open_options: see `File.__init__`
+ :yields Dataset, Group or None:
+ """
+ with File(filename, **open_options) as h5file:
+ try:
+ item = h5file[name]
+ except KeyError as e:
+ if "doesn't exist" in str(e):
+ if retry_invalid:
+ raise retry_mod.RetryError
+ else:
+ item = None
+ else:
+ raise
+ if callable(validate) and item is not None:
+ if not validate(item):
+ if retry_invalid:
+ raise retry_mod.RetryError
+ else:
+ item = None
+ yield item
+
+
+def _top_level_names(filename, include_only=group_has_end_time, **open_options):
+ r"""Return all valid top-level HDF5 names.
+
+ :param str filename:
+ :param callable or None include_only:
+ :param \**open_options: see `File.__init__`
+ :returns list(str):
+ """
+ with File(filename, **open_options) as h5file:
+ try:
+ if callable(include_only):
+ return [name for name in h5file["/"] if include_only(h5file[name])]
+ else:
+ return list(h5file["/"])
+ except KeyError:
+ raise retry_mod.RetryError
+
+
+top_level_names = retry()(_top_level_names)
+safe_top_level_names = retry_in_subprocess()(_top_level_names)
+
+
+class Hdf5FileLockingManager:
+ """Manage HDF5 file locking in the current process through the HDF5_USE_FILE_LOCKING
+ environment variable.
+ """
+
+ def __init__(self) -> None:
+ self._hdf5_file_locking = None
+ self._nfiles_open = 0
+
+ def opened(self):
+ self._add_nopen(1)
+
+ def closed(self):
+ self._add_nopen(-1)
+ if not self._nfiles_open:
+ self._restore_locking_env()
+
+ def set_locking(self, locking):
+ if self._nfiles_open:
+ self._check_locking_env(locking)
+ else:
+ self._set_locking_env(locking)
+
+ def _add_nopen(self, v):
+ self._nfiles_open = max(self._nfiles_open + v, 0)
+
+ def _set_locking_env(self, enable):
+ self._backup_locking_env()
+ if enable:
+ os.environ["HDF5_USE_FILE_LOCKING"] = "TRUE"
+ elif enable is None:
+ try:
+ del os.environ["HDF5_USE_FILE_LOCKING"]
+ except KeyError:
+ pass
+ else:
+ os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
+
+ def _get_locking_env(self):
+ v = os.environ.get("HDF5_USE_FILE_LOCKING")
+ if v == "TRUE":
+ return True
+ elif v is None:
+ return None
+ else:
+ return False
+
+ def _check_locking_env(self, enable):
+ if enable != self._get_locking_env():
+ if enable:
+ raise RuntimeError(
+ "Close all HDF5 files before enabling HDF5 file locking"
+ )
+ else:
+ raise RuntimeError(
+ "Close all HDF5 files before disabling HDF5 file locking"
+ )
+
+ def _backup_locking_env(self):
+ v = os.environ.get("HDF5_USE_FILE_LOCKING")
+ if v is None:
+ self._hdf5_file_locking = None
+ else:
+ self._hdf5_file_locking = v == "TRUE"
+
+ def _restore_locking_env(self):
+ self._set_locking_env(self._hdf5_file_locking)
+ self._hdf5_file_locking = None
+
+
+class File(h5py.File):
+ """Takes care of HDF5 file locking and SWMR mode without the need
+ to handle those explicitely.
+
+ When file locking is managed through the HDF5_USE_FILE_LOCKING environment
+ variable, you cannot open different files simultaneously with different modes.
+ """
+
+ _SWMR_LIBVER = "latest"
+
+ if HAS_LOCKING_ARGUMENT:
+ _LOCKING_MGR = None
+ else:
+ _LOCKING_MGR = Hdf5FileLockingManager()
+
+ def __init__(
+ self,
+ filename,
+ mode=None,
+ locking=None,
+ enable_file_locking=None,
+ swmr=None,
+ libver=None,
+ **kwargs,
+ ):
+ r"""The arguments `locking` and `swmr` should not be
+ specified explicitly for normal use cases.
+
+ :param str filename:
+ :param str or None mode: read-only by default
+ :param bool or None locking: by default it is disabled for `mode='r'`
+ and `swmr=False` and enabled for all
+ other modes.
+ :param bool or None enable_file_locking: deprecated
+ :param bool or None swmr: try both modes when `mode='r'` and `swmr=None`
+ :param None or str or tuple libver:
+ :param \**kwargs: see `h5py.File.__init__`
+ """
+ # File locking behavior has changed in recent versions of libhdf5
+ if HDF5_HAS_LOCKING_ARGUMENT != H5PY_HAS_LOCKING_ARGUMENT:
+ _logger.critical(
+ "The version of libhdf5 ({}) used by h5py ({}) is not supported: "
+ "Do not expect file locking to work.".format(
+ h5py.version.hdf5_version, h5py.version.version
+ )
+ )
+
+ if mode is None:
+ mode = "r"
+ elif mode not in ("r", "w", "w-", "x", "a", "r+"):
+ raise ValueError("invalid mode {}".format(mode))
+ if not HAS_SWMR:
+ swmr = False
+ if swmr and libver is None:
+ libver = self._SWMR_LIBVER
+
+ if enable_file_locking is not None:
+ deprecated_warning(
+ type_="argument",
+ name="enable_file_locking",
+ replacement="locking",
+ since_version="1.0",
+ )
+ if locking is None:
+ locking = enable_file_locking
+ locking = _hdf5_file_locking(
+ mode=mode, locking=locking, swmr=swmr, libver=libver
+ )
+ if self._LOCKING_MGR is None:
+ kwargs.setdefault("locking", locking)
+ else:
+ self._LOCKING_MGR.set_locking(locking)
+
+ if HAS_TRACK_ORDER:
+ kwargs.setdefault("track_order", True)
+ try:
+ super().__init__(filename, mode=mode, swmr=swmr, libver=libver, **kwargs)
+ except OSError as e:
+ # wlock wSWMR rlock rSWMR OSError: Unable to open file (...)
+ # 1 TRUE FALSE FALSE FALSE -
+ # 2 TRUE FALSE FALSE TRUE -
+ # 3 TRUE FALSE TRUE FALSE unable to lock file, errno = 11, error message = 'Resource temporarily unavailable'
+ # 4 TRUE FALSE TRUE TRUE unable to lock file, errno = 11, error message = 'Resource temporarily unavailable'
+ # 5 TRUE TRUE FALSE FALSE file is already open for write (may use <h5clear file> to clear file consistency flags)
+ # 6 TRUE TRUE FALSE TRUE -
+ # 7 TRUE TRUE TRUE FALSE file is already open for write (may use <h5clear file> to clear file consistency flags)
+ # 8 TRUE TRUE TRUE TRUE -
+ if (
+ mode == "r"
+ and swmr is None
+ and "file is already open for write" in str(e)
+ ):
+ # Try reading in SWMR mode (situation 5 and 7)
+ swmr = True
+ if libver is None:
+ libver = self._SWMR_LIBVER
+ super().__init__(
+ filename, mode=mode, swmr=swmr, libver=libver, **kwargs
+ )
+ else:
+ raise
+ else:
+ self._file_open_callback()
+ try:
+ if mode != "r" and swmr:
+ # Try setting writer in SWMR mode
+ self.swmr_mode = True
+ except Exception:
+ self.close()
+ raise
+
+ def close(self):
+ super().close()
+ self._file_close_callback()
+
+ def _file_open_callback(self):
+ if self._LOCKING_MGR is not None:
+ self._LOCKING_MGR.opened()
+
+ def _file_close_callback(self):
+ if self._LOCKING_MGR is not None:
+ self._LOCKING_MGR.closed()
diff --git a/silx/io/nxdata/__init__.py b/src/silx/io/nxdata/__init__.py
index 5bfa442..5bfa442 100644
--- a/silx/io/nxdata/__init__.py
+++ b/src/silx/io/nxdata/__init__.py
diff --git a/src/silx/io/nxdata/_utils.py b/src/silx/io/nxdata/_utils.py
new file mode 100644
index 0000000..12318f1
--- /dev/null
+++ b/src/silx/io/nxdata/_utils.py
@@ -0,0 +1,183 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Utility functions used by NXdata validation and parsing."""
+
+import copy
+import logging
+
+import numpy
+
+from silx.io import is_dataset
+from silx.utils.deprecation import deprecated
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/04/2018"
+
+
+nxdata_logger = logging.getLogger("silx.io.nxdata")
+
+
+INTERPDIM = {"scalar": 0,
+ "spectrum": 1,
+ "image": 2,
+ "rgba-image": 3, # "hsla-image": 3, "cmyk-image": 3, # TODO
+ "vertex": 1} # 3D scatter: 1D signal + 3 axes (x, y, z) of same legth
+"""Number of signal dimensions associated to each possible @interpretation
+attribute.
+"""
+
+
+@deprecated(since_version="0.8.0", replacement="get_attr_as_unicode")
+def get_attr_as_string(*args, **kwargs):
+ return get_attr_as_unicode(*args, **kwargs)
+
+
+def get_attr_as_unicode(item, attr_name, default=None):
+ """Return item.attrs[attr_name] as unicode or as a
+ list of unicode.
+
+ Numpy arrays of strings or bytes returned by h5py are converted to
+ lists of unicode.
+
+ :param item: Group or dataset
+ :param attr_name: Attribute name
+ :param default: Value to be returned if attribute is not found.
+ :return: item.attrs[attr_name]
+ """
+ attr = item.attrs.get(attr_name, default)
+
+ if isinstance(attr, bytes):
+ # byte-string
+ return attr.decode("utf-8")
+ elif isinstance(attr, numpy.ndarray) and not attr.shape:
+ if isinstance(attr[()], bytes):
+ # byte string as ndarray scalar
+ return attr[()].decode("utf-8")
+ else:
+ # other scalar, possibly unicode
+ return attr[()]
+ elif isinstance(attr, numpy.ndarray) and len(attr.shape):
+ if hasattr(attr[0], "decode"):
+ # array of byte-strings
+ return [element.decode("utf-8") for element in attr]
+ else:
+ # other array, most likely unicode objects
+ return [element for element in attr]
+ else:
+ return copy.deepcopy(attr)
+
+
+def get_uncertainties_names(group, signal_name):
+ # Test consistency of @uncertainties
+ uncertainties_names = get_attr_as_unicode(group, "uncertainties")
+ if uncertainties_names is None:
+ uncertainties_names = get_attr_as_unicode(group[signal_name], "uncertainties")
+ if isinstance(uncertainties_names, str):
+ uncertainties_names = [uncertainties_names]
+ return uncertainties_names
+
+
+def get_signal_name(group):
+ """Return the name of the (main) signal in a NXdata group.
+ Return None if this info is missing (invalid NXdata).
+
+ """
+ signal_name = get_attr_as_unicode(group, "signal", default=None)
+ if signal_name is None:
+ nxdata_logger.info("NXdata group %s does not define a signal attr. "
+ "Testing legacy specification.", group.name)
+ for key in group:
+ if "signal" in group[key].attrs:
+ signal_name = key
+ signal_attr = group[key].attrs["signal"]
+ if signal_attr in [1, b"1", u"1"]:
+ # This is the main (default) signal
+ break
+ return signal_name
+
+
+def get_auxiliary_signals_names(group):
+ """Return list of auxiliary signals names"""
+ auxiliary_signals_names = get_attr_as_unicode(group, "auxiliary_signals",
+ default=[])
+ if isinstance(auxiliary_signals_names, (str, bytes)):
+ auxiliary_signals_names = [auxiliary_signals_names]
+ return auxiliary_signals_names
+
+
+def validate_auxiliary_signals(group, signal_name, auxiliary_signals_names):
+ """Check data dimensionality and size. Return False if invalid."""
+ issues = []
+ for asn in auxiliary_signals_names:
+ if asn not in group or not is_dataset(group[asn]):
+ issues.append(
+ "Cannot find auxiliary signal dataset '%s'" % asn)
+ elif group[signal_name].shape != group[asn].shape:
+ issues.append("Auxiliary signal dataset '%s' does not" % asn +
+ " have the same shape as the main signal.")
+ return issues
+
+
+def validate_number_of_axes(group, signal_name, num_axes):
+ issues = []
+ ndims = len(group[signal_name].shape)
+ if 1 < ndims < num_axes:
+ # ndim = 1 with several axes could be a scatter
+ issues.append(
+ "More @axes defined than there are " +
+ "signal dimensions: " +
+ "%d axes, %d dimensions." % (num_axes, ndims))
+
+ # case of less axes than dimensions: number of axes must match
+ # dimensionality defined by @interpretation
+ elif ndims > num_axes:
+ interpretation = get_attr_as_unicode(group[signal_name], "interpretation")
+ if interpretation is None:
+ interpretation = get_attr_as_unicode(group, "interpretation")
+ if interpretation is None:
+ issues.append("No @interpretation and not enough" +
+ " @axes defined.")
+
+ elif interpretation not in INTERPDIM:
+ issues.append("Unrecognized @interpretation=" + interpretation +
+ " for data with wrong number of defined @axes.")
+ elif interpretation == "rgba-image":
+ if ndims != 3 or group[signal_name].shape[-1] not in [3, 4]:
+ issues.append(
+ "Inconsistent RGBA Image. Expected 3 dimensions with " +
+ "last one of length 3 or 4. Got ndim=%d " % ndims +
+ "with last dimension of length %d." % group[signal_name].shape[-1])
+ if num_axes != 2:
+ issues.append(
+ "Inconsistent number of axes for RGBA Image. Expected "
+ "3, but got %d." % ndims)
+
+ elif num_axes != INTERPDIM[interpretation]:
+ issues.append(
+ "%d-D signal with @interpretation=%s " % (ndims, interpretation) +
+ "must define %d or %d axes." % (ndims, INTERPDIM[interpretation]))
+ return issues
diff --git a/src/silx/io/nxdata/parse.py b/src/silx/io/nxdata/parse.py
new file mode 100644
index 0000000..d00f65b
--- /dev/null
+++ b/src/silx/io/nxdata/parse.py
@@ -0,0 +1,1004 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides a collection of functions to work with h5py-like
+groups following the NeXus *NXdata* specification.
+
+See http://download.nexusformat.org/sphinx/classes/base_classes/NXdata.html
+
+The main class is :class:`NXdata`.
+You can also fetch the default NXdata in a NXroot or a NXentry with function
+:func:`get_default`.
+
+
+Other public functions:
+
+ - :func:`is_valid_nxdata`
+ - :func:`is_NXroot_with_default_NXdata`
+ - :func:`is_NXentry_with_default_NXdata`
+ - :func:`is_group_with_default_NXdata`
+
+"""
+
+import json
+import numpy
+
+from silx.io.utils import is_group, is_file, is_dataset, h5py_read_dataset
+
+from ._utils import get_attr_as_unicode, INTERPDIM, nxdata_logger, \
+ get_uncertainties_names, get_signal_name, \
+ get_auxiliary_signals_names, validate_auxiliary_signals, validate_number_of_axes
+
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/03/2020"
+
+
+class InvalidNXdataError(Exception):
+ pass
+
+
+class _SilxStyle(object):
+ """NXdata@SILX_style parser.
+
+ :param NXdata nxdata:
+ NXdata description for which to extract silx_style information.
+ """
+
+ def __init__(self, nxdata):
+ naxes = len(nxdata.axes)
+ self._axes_scale_types = [None] * naxes
+ self._signal_scale_type = None
+
+ stylestr = get_attr_as_unicode(nxdata.group, "SILX_style")
+ if stylestr is None:
+ return
+
+ try:
+ style = json.loads(stylestr)
+ except json.JSONDecodeError:
+ nxdata_logger.error(
+ "Ignoring SILX_style, cannot parse: %s", stylestr)
+ return
+
+ if not isinstance(style, dict):
+ nxdata_logger.error(
+ "Ignoring SILX_style, cannot parse: %s", stylestr)
+
+ if 'axes_scale_types' in style:
+ axes_scale_types = style['axes_scale_types']
+
+ if isinstance(axes_scale_types, str):
+ # Convert single argument to list
+ axes_scale_types = [axes_scale_types]
+
+ if not isinstance(axes_scale_types, list):
+ nxdata_logger.error(
+ "Ignoring SILX_style:axes_scale_types, not a list")
+ else:
+ for scale_type in axes_scale_types:
+ if scale_type not in ('linear', 'log'):
+ nxdata_logger.error(
+ "Ignoring SILX_style:axes_scale_types, invalid value: %s", str(scale_type))
+ break
+ else: # All values are valid
+ if len(axes_scale_types) > naxes:
+ nxdata_logger.error(
+ "Clipping SILX_style:axes_scale_types, too many values")
+ axes_scale_types = axes_scale_types[:naxes]
+ elif len(axes_scale_types) < naxes:
+ # Extend axes_scale_types with None to match number of axes
+ axes_scale_types = [None] * (naxes - len(axes_scale_types)) + axes_scale_types
+ self._axes_scale_types = tuple(axes_scale_types)
+
+ if 'signal_scale_type' in style:
+ scale_type = style['signal_scale_type']
+ if scale_type not in ('linear', 'log'):
+ nxdata_logger.error(
+ "Ignoring SILX_style:signal_scale_type, invalid value: %s", str(scale_type))
+ else:
+ self._signal_scale_type = scale_type
+
+ axes_scale_types = property(
+ lambda self: self._axes_scale_types,
+ doc="Tuple of NXdata axes scale types (None, 'linear' or 'log'). List[str]")
+
+ signal_scale_type = property(
+ lambda self: self._signal_scale_type,
+ doc="NXdata signal scale type (None, 'linear' or 'log'). str")
+
+
+class NXdata(object):
+ """NXdata parser.
+
+ .. note::
+
+ Before attempting to access any attribute or property,
+ you should check that :attr:`is_valid` is *True*.
+
+ :param group: h5py-like group following the NeXus *NXdata* specification.
+ :param boolean validate: Set this parameter to *False* to skip the initial
+ validation. This option is provided for optimisation purposes, for cases
+ where :meth:`silx.io.nxdata.is_valid_nxdata` has already been called
+ prior to instantiating this :class:`NXdata`.
+ """
+ def __init__(self, group, validate=True):
+ super(NXdata, self).__init__()
+ self._plot_style = None
+
+ self.group = group
+ """h5py-like group object with @NX_class=NXdata.
+ """
+
+ self.issues = []
+ """List of error messages for malformed NXdata."""
+
+ if validate:
+ self._validate()
+ self.is_valid = not self.issues
+ """Validity status for this NXdata.
+ If False, all properties and attributes will be None.
+ """
+
+ self._is_scatter = None
+ self._axes = None
+
+ self.signal = None
+ """Main signal dataset in this NXdata group.
+ In case more than one signal is present in this group,
+ the other ones can be found in :attr:`auxiliary_signals`.
+ """
+
+ self.signal_name = None
+ """Signal long name, as specified in the @long_name attribute of the
+ signal dataset. If not specified, the dataset name is used."""
+
+ self.signal_ndim = None
+ self.signal_is_0d = None
+ self.signal_is_1d = None
+ self.signal_is_2d = None
+ self.signal_is_3d = None
+
+ self.axes_names = None
+ """List of axes names in a NXdata group.
+
+ This attribute is similar to :attr:`axes_dataset_names` except that
+ if an axis dataset has a "@long_name" attribute, it will be used
+ instead of the dataset name.
+ """
+
+ if not self.is_valid:
+ nxdata_logger.debug("%s", self.issues)
+ else:
+ self.signal = self.group[self.signal_dataset_name]
+ self.signal_name = get_attr_as_unicode(self.signal, "long_name")
+
+ if self.signal_name is None:
+ self.signal_name = self.signal_dataset_name
+
+ # ndim will be available in very recent h5py versions only
+ self.signal_ndim = getattr(self.signal, "ndim",
+ len(self.signal.shape))
+
+ self.signal_is_0d = self.signal_ndim == 0
+ self.signal_is_1d = self.signal_ndim == 1
+ self.signal_is_2d = self.signal_ndim == 2
+ self.signal_is_3d = self.signal_ndim == 3
+
+ self.axes_names = []
+ # check if axis dataset defines @long_name
+ for _, dsname in enumerate(self.axes_dataset_names):
+ if dsname is not None and "long_name" in self.group[dsname].attrs:
+ self.axes_names.append(get_attr_as_unicode(self.group[dsname], "long_name"))
+ else:
+ self.axes_names.append(dsname)
+
+ # excludes scatters
+ self.signal_is_1d = self.signal_is_1d and len(self.axes) <= 1 # excludes n-D scatters
+
+ self._plot_style = _SilxStyle(self)
+
+ def _validate(self):
+ """Fill :attr:`issues` with error messages for each error found."""
+ if not is_group(self.group):
+ raise TypeError("group must be a h5py-like group")
+ if get_attr_as_unicode(self.group, "NX_class") != "NXdata":
+ self.issues.append("Group has no attribute @NX_class='NXdata'")
+ return
+
+ signal_name = get_signal_name(self.group)
+ if signal_name is None:
+ self.issues.append("No @signal attribute on the NXdata group, "
+ "and no dataset with a @signal=1 attr found")
+ # very difficult to do more consistency tests without signal
+ return
+
+ elif signal_name not in self.group or not is_dataset(self.group[signal_name]):
+ self.issues.append("Cannot find signal dataset '%s'" % signal_name)
+ return
+
+ auxiliary_signals_names = get_auxiliary_signals_names(self.group)
+ self.issues += validate_auxiliary_signals(self.group,
+ signal_name,
+ auxiliary_signals_names)
+
+ axes_names = get_attr_as_unicode(self.group, "axes")
+ if axes_names is None:
+ # try @axes on signal dataset (older NXdata specification)
+ axes_names = get_attr_as_unicode(self.group[signal_name], "axes")
+ if axes_names is not None:
+ # we expect a comma separated string
+ if hasattr(axes_names, "split"):
+ axes_names = axes_names.split(":")
+
+ if isinstance(axes_names, (str, bytes)):
+ axes_names = [axes_names]
+
+ if axes_names:
+ self.issues += validate_number_of_axes(self.group, signal_name,
+ num_axes=len(axes_names))
+
+ # Test consistency of @uncertainties
+ uncertainties_names = get_uncertainties_names(self.group, signal_name)
+ if uncertainties_names is not None:
+ if len(uncertainties_names) != len(axes_names):
+ if len(uncertainties_names) < len(axes_names):
+ # ignore the field to avoid index error in the axes loop
+ uncertainties_names = None
+ self.issues.append("@uncertainties does not define the same " +
+ "number of fields than @axes. Field ignored")
+ else:
+ self.issues.append("@uncertainties does not define the same " +
+ "number of fields than @axes")
+
+ # Test individual axes
+ is_scatter = True # true if all axes have the same size as the signal
+ signal_size = 1
+ for dim in self.group[signal_name].shape:
+ signal_size *= dim
+ polynomial_axes_names = []
+ for i, axis_name in enumerate(axes_names):
+
+ if axis_name == ".":
+ continue
+ if axis_name not in self.group or not is_dataset(self.group[axis_name]):
+ self.issues.append("Could not find axis dataset '%s'" % axis_name)
+ continue
+
+ axis_size = 1
+ for dim in self.group[axis_name].shape:
+ axis_size *= dim
+
+ if len(self.group[axis_name].shape) != 1:
+ # I don't know how to interpret n-D axes
+ self.issues.append("Axis %s is not 1D" % axis_name)
+ continue
+ else:
+ # for a 1-d axis,
+ fg_idx = self.group[axis_name].attrs.get("first_good", 0)
+ lg_idx = self.group[axis_name].attrs.get("last_good", len(self.group[axis_name]) - 1)
+ axis_len = lg_idx + 1 - fg_idx
+
+ if axis_len != signal_size:
+ if axis_len not in self.group[signal_name].shape + (1, 2):
+ self.issues.append(
+ "Axis %s number of elements does not " % axis_name +
+ "correspond to the length of any signal dimension,"
+ " it does not appear to be a constant or a linear calibration," +
+ " and this does not seem to be a scatter plot.")
+ continue
+ elif axis_len in (1, 2):
+ polynomial_axes_names.append(axis_name)
+ is_scatter = False
+ else:
+ if not is_scatter:
+ self.issues.append(
+ "Axis %s number of elements is equal " % axis_name +
+ "to the length of the signal, but this does not seem" +
+ " to be a scatter (other axes have different sizes)")
+ continue
+
+ # Test individual uncertainties
+ errors_name = axis_name + "_errors"
+ if errors_name not in self.group and uncertainties_names is not None:
+ errors_name = uncertainties_names[i]
+ if errors_name in self.group and axis_name not in polynomial_axes_names:
+ if self.group[errors_name].shape != self.group[axis_name].shape:
+ self.issues.append(
+ "Errors '%s' does not have the same " % errors_name +
+ "dimensions as axis '%s'." % axis_name)
+
+ # test dimensions of errors associated with signal
+
+ signal_errors = signal_name + "_errors"
+ if "errors" in self.group and is_dataset(self.group["errors"]):
+ errors = "errors"
+ elif signal_errors in self.group and is_dataset(self.group[signal_errors]):
+ errors = signal_errors
+ else:
+ errors = None
+ if errors:
+ if self.group[errors].shape != self.group[signal_name].shape:
+ # In principle just the same size should be enough but
+ # NeXus documentation imposes to have the same shape
+ self.issues.append(
+ "Dataset containing standard deviations must " +
+ "have the same dimensions as the signal.")
+
+ @property
+ def signal_dataset_name(self):
+ """Name of the main signal dataset."""
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+ signal_dataset_name = get_attr_as_unicode(self.group, "signal")
+ if signal_dataset_name is None:
+ # find a dataset with @signal == 1
+ for dsname in self.group:
+ signal_attr = self.group[dsname].attrs.get("signal")
+ if signal_attr in [1, b"1", u"1"]:
+ # This is the main (default) signal
+ signal_dataset_name = dsname
+ break
+ assert signal_dataset_name is not None
+ return signal_dataset_name
+
+ @property
+ def auxiliary_signals_dataset_names(self):
+ """Sorted list of names of the auxiliary signals datasets.
+
+ These are the names provided by the *@auxiliary_signals* attribute
+ on the NXdata group.
+
+ In case the NXdata group does not specify a *@signal* attribute
+ but has a dataset with an attribute *@signal=1*,
+ we look for datasets with attributes *@signal=2, @signal=3...*
+ (deprecated NXdata specification)."""
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+ signal_dataset_name = get_attr_as_unicode(self.group, "signal")
+ if signal_dataset_name is not None:
+ auxiliary_signals_names = get_attr_as_unicode(self.group, "auxiliary_signals")
+ if auxiliary_signals_names is not None:
+ if not isinstance(auxiliary_signals_names,
+ (tuple, list, numpy.ndarray)):
+ # tolerate a single string, but coerce into a list
+ return [auxiliary_signals_names]
+ return list(auxiliary_signals_names)
+ return []
+
+ # try old spec, @signal=1 (2, 3...) on dataset
+ numbered_names = []
+ for dsname in self.group:
+ if dsname == self.signal_dataset_name:
+ # main signal, not auxiliary
+ continue
+ ds = self.group[dsname]
+ signal_attr = ds.attrs.get("signal")
+ if signal_attr is not None and not is_dataset(ds):
+ nxdata_logger.warning("Item %s with @signal=%s is not a dataset (%s)",
+ dsname, signal_attr, type(ds))
+ continue
+ if signal_attr is not None:
+ try:
+ signal_number = int(signal_attr)
+ except (ValueError, TypeError):
+ nxdata_logger.warning("Could not parse attr @signal=%s on "
+ "dataset %s as an int",
+ signal_attr, dsname)
+ continue
+ numbered_names.append((signal_number, dsname))
+ return [a[1] for a in sorted(numbered_names)]
+
+ @property
+ def auxiliary_signals_names(self):
+ """List of names of the auxiliary signals.
+
+ Similar to :attr:`auxiliary_signals_dataset_names`, but the @long_name
+ is used when this attribute is present, instead of the dataset name.
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ signal_names = []
+ for asdn in self.auxiliary_signals_dataset_names:
+ if "long_name" in self.group[asdn].attrs:
+ signal_names.append(self.group[asdn].attrs["long_name"])
+ else:
+ signal_names.append(asdn)
+ return signal_names
+
+ @property
+ def auxiliary_signals(self):
+ """List of all auxiliary signal datasets."""
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ return [self.group[dsname] for dsname in self.auxiliary_signals_dataset_names]
+
+ @property
+ def interpretation(self):
+ """*@interpretation* attribute associated with the *signal*
+ dataset of the NXdata group. ``None`` if no interpretation
+ attribute is present.
+
+ The *interpretation* attribute provides information about the last
+ dimensions of the signal. The allowed values are:
+
+ - *"scalar"*: 0-D data to be plotted
+ - *"spectrum"*: 1-D data to be plotted
+ - *"image"*: 2-D data to be plotted
+ - *"vertex"*: 3-D data to be plotted
+
+ For example, a 3-D signal with interpretation *"spectrum"* should be
+ considered to be a 2-D array of 1-D data. A 3-D signal with
+ interpretation *"image"* should be interpreted as a 1-D array (a list)
+ of 2-D images. An n-D array with interpretation *"image"* should be
+ interpreted as an (n-2)-D array of images.
+
+ A warning message is logged if the returned interpretation is not one
+ of the allowed values, but no error is raised and the unknown
+ interpretation is returned anyway.
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ allowed_interpretations = [None, "scaler", "scalar", "spectrum", "image",
+ "rgba-image", # "hsla-image", "cmyk-image"
+ "vertex"]
+
+ interpretation = get_attr_as_unicode(self.signal, "interpretation")
+ if interpretation is None:
+ interpretation = get_attr_as_unicode(self.group, "interpretation")
+
+ if interpretation not in allowed_interpretations:
+ nxdata_logger.warning("Interpretation %s is not valid." % interpretation +
+ " Valid values: " + ", ".join(str(s) for s in allowed_interpretations))
+ return interpretation
+
+ @property
+ def axes(self):
+ """List of the axes datasets.
+
+ The list typically has as many elements as there are dimensions in the
+ signal dataset, the exception being scatter plots which use a 1D
+ signal and multiple 1D axes of the same size.
+
+ If an axis dataset applies to several dimensions of the signal, it
+ will be repeated in the list.
+
+ If a dimension of the signal has no dimension scale, `None` is
+ inserted in its position in the list.
+
+ .. note::
+
+ The *@axes* attribute should define as many entries as there
+ are dimensions in the signal, to avoid any ambiguity.
+ If this is not the case, this implementation relies on the existence
+ of an *@interpretation* (*spectrum* or *image*) attribute in the
+ *signal* dataset.
+
+ .. note::
+
+ If an axis dataset defines attributes @first_good or @last_good,
+ the output will be a numpy array resulting from slicing that
+ axis (*axis[first_good:last_good + 1]*).
+
+ :rtype: List[Dataset or 1D array or None]
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ if self._axes is not None:
+ # use cache
+ return self._axes
+ axes = []
+ for axis_name in self.axes_dataset_names:
+ if axis_name is None:
+ axes.append(None)
+ else:
+ axes.append(self.group[axis_name])
+
+ # keep only good range of axis data
+ for i, axis in enumerate(axes):
+ if axis is None:
+ continue
+ if "first_good" not in axis.attrs and "last_good" not in axis.attrs:
+ continue
+ fg_idx = axis.attrs.get("first_good", 0)
+ lg_idx = axis.attrs.get("last_good", len(axis) - 1)
+ axes[i] = axis[fg_idx:lg_idx + 1]
+
+ self._axes = axes
+ return self._axes
+
+ @property
+ def axes_dataset_names(self):
+ """List of axes dataset names.
+
+ If an axis dataset applies to several dimensions of the signal, its
+ name will be repeated in the list.
+
+ If a dimension of the signal has no dimension scale (i.e. there is a
+ "." in that position in the *@axes* array), `None` is inserted in the
+ output list in its position.
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ numbered_names = [] # used in case of @axis=0 (old spec)
+ axes_dataset_names = get_attr_as_unicode(self.group, "axes")
+ if axes_dataset_names is None:
+ # try @axes on signal dataset (older NXdata specification)
+ axes_dataset_names = get_attr_as_unicode(self.signal, "axes")
+ if axes_dataset_names is not None:
+ # we expect a comma separated string
+ if hasattr(axes_dataset_names, "split"):
+ axes_dataset_names = axes_dataset_names.split(":")
+ else:
+ # try @axis on the individual datasets (oldest NXdata specification)
+ for dsname in self.group:
+ if not is_dataset(self.group[dsname]):
+ continue
+ axis_attr = self.group[dsname].attrs.get("axis")
+ if axis_attr is not None:
+ try:
+ axis_num = int(axis_attr)
+ except (ValueError, TypeError):
+ nxdata_logger.warning("Could not interpret attr @axis as"
+ "int on dataset %s", dsname)
+ continue
+ numbered_names.append((axis_num, dsname))
+
+ ndims = len(self.signal.shape)
+ if axes_dataset_names is None:
+ if numbered_names:
+ axes_dataset_names = []
+ numbers = [a[0] for a in numbered_names]
+ names = [a[1] for a in numbered_names]
+ for i in range(ndims):
+ if i in numbers:
+ axes_dataset_names.append(names[numbers.index(i)])
+ else:
+ axes_dataset_names.append(None)
+ return axes_dataset_names
+ else:
+ return [None] * ndims
+
+ if isinstance(axes_dataset_names, (str, bytes)):
+ axes_dataset_names = [axes_dataset_names]
+
+ for i, axis_name in enumerate(axes_dataset_names):
+ if hasattr(axis_name, "decode"):
+ axis_name = axis_name.decode()
+ if axis_name == ".":
+ axes_dataset_names[i] = None
+
+ if len(axes_dataset_names) != ndims:
+ if self.is_scatter and ndims == 1:
+ # case of a 1D signal with arbitrary number of axes
+ return list(axes_dataset_names)
+ if self.interpretation != "rgba-image":
+ # @axes may only define 1 or 2 axes if @interpretation=spectrum/image.
+ # Use the existing names for the last few dims, and prepend with Nones.
+ assert len(axes_dataset_names) == INTERPDIM[self.interpretation]
+ all_dimensions_names = [None] * (ndims - INTERPDIM[self.interpretation])
+ for axis_name in axes_dataset_names:
+ all_dimensions_names.append(axis_name)
+ else:
+ # 2 axes applying to the first two dimensions.
+ # The 3rd signal dimension is expected to contain 3(4) RGB(A) values.
+ assert len(axes_dataset_names) == 2
+ all_dimensions_names = [axn for axn in axes_dataset_names]
+ all_dimensions_names.append(None)
+ return all_dimensions_names
+
+ return list(axes_dataset_names)
+
+ @property
+ def title(self):
+ """Plot title. If not found, returns an empty string.
+
+ This attribute does not appear in the NXdata specification, but it is
+ implemented in *nexpy* as a dataset named "title" inside the NXdata
+ group. This dataset is expected to contain text.
+
+ Because the *nexpy* approach could cause a conflict if the signal
+ dataset or an axis dataset happened to be called "title", we also
+ support providing the title as an attribute of the NXdata group.
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ title = self.group.get("title")
+ data_dataset_names = [self.signal_name] + self.axes_dataset_names
+ if (title is not None and is_dataset(title) and
+ "title" not in data_dataset_names):
+ return str(h5py_read_dataset(title))
+
+ title = self.group.attrs.get("title")
+ if title is None:
+ return ""
+ return str(title)
+
+ def get_axis_errors(self, axis_name):
+ """Return errors (uncertainties) associated with an axis.
+
+ If the axis has attributes @first_good or @last_good, the output
+ is trimmed accordingly (a numpy array will be returned rather than a
+ dataset).
+
+ :param str axis_name: Name of axis dataset. This dataset **must exist**.
+ :return: Dataset with axis errors, or None
+ :raise KeyError: if this group does not contain a dataset named axis_name
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ # ensure axis_name is decoded, before comparing it with decoded attributes
+ if hasattr(axis_name, "decode"):
+ axis_name = axis_name.decode("utf-8")
+ if axis_name not in self.group:
+ # tolerate axis_name given as @long_name
+ for item in self.group:
+ long_name = get_attr_as_unicode(self.group[item], "long_name")
+ if long_name is not None and long_name == axis_name:
+ axis_name = item
+ break
+
+ if axis_name not in self.group:
+ raise KeyError("group does not contain a dataset named '%s'" % axis_name)
+
+ len_axis = len(self.group[axis_name])
+
+ fg_idx = self.group[axis_name].attrs.get("first_good", 0)
+ lg_idx = self.group[axis_name].attrs.get("last_good", len_axis - 1)
+
+ # case of axisname_errors dataset present
+ errors_name = axis_name + "_errors"
+ if errors_name in self.group and is_dataset(self.group[errors_name]):
+ if fg_idx != 0 or lg_idx != (len_axis - 1):
+ return self.group[errors_name][fg_idx:lg_idx + 1]
+ else:
+ return self.group[errors_name]
+ # case of uncertainties dataset name provided in @uncertainties
+ uncertainties_names = get_attr_as_unicode(self.group, "uncertainties")
+ if uncertainties_names is None:
+ uncertainties_names = get_attr_as_unicode(self.signal, "uncertainties")
+ if isinstance(uncertainties_names, str):
+ uncertainties_names = [uncertainties_names]
+ if uncertainties_names is not None:
+ # take the uncertainty with the same index as the axis in @axes
+ axes_ds_names = get_attr_as_unicode(self.group, "axes")
+ if axes_ds_names is None:
+ axes_ds_names = get_attr_as_unicode(self.signal, "axes")
+ if isinstance(axes_ds_names, str):
+ axes_ds_names = [axes_ds_names]
+ elif isinstance(axes_ds_names, numpy.ndarray):
+ # transform numpy.ndarray into list
+ axes_ds_names = list(axes_ds_names)
+ assert isinstance(axes_ds_names, list)
+ if hasattr(axes_ds_names[0], "decode"):
+ axes_ds_names = [ax_name.decode("utf-8") for ax_name in axes_ds_names]
+ if axis_name not in axes_ds_names:
+ raise KeyError("group attr @axes does not mention a dataset " +
+ "named '%s'" % axis_name)
+ errors = self.group[uncertainties_names[list(axes_ds_names).index(axis_name)]]
+ if fg_idx == 0 and lg_idx == (len_axis - 1):
+ return errors # dataset
+ else:
+ return errors[fg_idx:lg_idx + 1] # numpy array
+ return None
+
+ @property
+ def errors(self):
+ """Return errors (uncertainties) associated with the signal values.
+
+ :return: Dataset with errors, or None
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ # case of signal
+ signal_errors = self.signal_dataset_name + "_errors"
+ if "errors" in self.group and is_dataset(self.group["errors"]):
+ errors = "errors"
+ elif signal_errors in self.group and is_dataset(self.group[signal_errors]):
+ errors = signal_errors
+ else:
+ return None
+ return self.group[errors]
+
+ @property
+ def plot_style(self):
+ """Information extracted from the optional SILX_style attribute
+
+ :raises: InvalidNXdataError
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ return self._plot_style
+
+ @property
+ def is_scatter(self):
+ """True if the signal is 1D and all the axes have the
+ same size as the signal."""
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ if self._is_scatter is not None:
+ return self._is_scatter
+ if not self.signal_is_1d:
+ self._is_scatter = False
+ else:
+ self._is_scatter = True
+ sigsize = 1
+ for dim in self.signal.shape:
+ sigsize *= dim
+ for axis in self.axes:
+ if axis is None:
+ continue
+ axis_size = 1
+ for dim in axis.shape:
+ axis_size *= dim
+ self._is_scatter = self._is_scatter and (axis_size == sigsize)
+ return self._is_scatter
+
+ @property
+ def is_x_y_value_scatter(self):
+ """True if this is a scatter with a signal and two axes."""
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ return self.is_scatter and len(self.axes) == 2
+
+ # we currently have no widget capable of plotting 4D data
+ @property
+ def is_unsupported_scatter(self):
+ """True if this is a scatter with a signal and more than 2 axes."""
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ return self.is_scatter and len(self.axes) > 2
+
+ @property
+ def is_curve(self):
+ """This property is True if the signal is 1D or :attr:`interpretation` is
+ *"spectrum"*, and there is at most one axis with a consistent length.
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ if self.signal_is_0d or self.interpretation not in [None, "spectrum"]:
+ return False
+ # the axis, if any, must be of the same length as the last dimension
+ # of the signal, or of length 2 (a + b *x scale)
+ if self.axes[-1] is not None and len(self.axes[-1]) not in [
+ self.signal.shape[-1], 2]:
+ return False
+ if self.interpretation is None:
+ # We no longer test whether x values are monotonic
+ # (in the past, in that case, we used to consider it a scatter)
+ return self.signal_is_1d
+ # everything looks good
+ return True
+
+ @property
+ def is_image(self):
+ """True if the signal is 2D, or 3D with last dimension of length 3 or 4
+ and interpretation *rgba-image*, or >2D with interpretation *image*.
+ The axes (if any) length must also be consistent with the signal shape.
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ if self.interpretation in ["scalar", "spectrum", "scaler"]:
+ return False
+ if self.signal_is_0d or self.signal_is_1d:
+ return False
+ if not self.signal_is_2d and \
+ self.interpretation not in ["image", "rgba-image"]:
+ return False
+ if self.signal_is_3d and self.interpretation == "rgba-image":
+ if self.signal.shape[-1] not in [3, 4]:
+ return False
+ img_axes = self.axes[0:2]
+ img_shape = self.signal.shape[0:2]
+ else:
+ img_axes = self.axes[-2:]
+ img_shape = self.signal.shape[-2:]
+ for i, axis in enumerate(img_axes):
+ if axis is not None and len(axis) not in [img_shape[i], 2]:
+ return False
+
+ return True
+
+ @property
+ def is_stack(self):
+ """True in the signal is at least 3D and interpretation is not
+ "scalar", "spectrum", "image" or "rgba-image".
+ The axes length must also be consistent with the last 3 dimensions
+ of the signal.
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ if self.signal_ndim < 3 or self.interpretation in [
+ "scalar", "scaler", "spectrum", "image", "rgba-image"]:
+ return False
+ stack_shape = self.signal.shape[-3:]
+ for i, axis in enumerate(self.axes[-3:]):
+ if axis is not None and len(axis) not in [stack_shape[i], 2]:
+ return False
+ return True
+
+ @property
+ def is_volume(self):
+ """True in the signal is exactly 3D and interpretation
+ "scalar", or nothing.
+
+ The axes length must also be consistent with the 3 dimensions
+ of the signal.
+ """
+ if not self.is_valid:
+ raise InvalidNXdataError("Unable to parse invalid NXdata")
+
+ if self.signal_ndim != 3:
+ return False
+ if self.interpretation not in [None, "scalar", "scaler"]:
+ # 'scaler' and 'scalar' for a three dimensional array indicate a scalar field in 3D
+ return False
+ volume_shape = self.signal.shape[-3:]
+ for i, axis in enumerate(self.axes[-3:]):
+ if axis is not None and len(axis) not in [volume_shape[i], 2]:
+ return False
+ return True
+
+
+def is_valid_nxdata(group): # noqa
+ """Check if a h5py group is a **valid** NX_data group.
+
+ :param group: h5py-like group
+ :return: True if this NXdata group is valid.
+ :raise TypeError: if group is not a h5py group, a spech5 group,
+ or a fabioh5 group
+ """
+ nxd = NXdata(group)
+ return nxd.is_valid
+
+
+def is_group_with_default_NXdata(group, validate=True):
+ """Return True if group defines a valid default
+ NXdata.
+
+ .. note::
+
+ See https://github.com/silx-kit/silx/issues/2215
+
+ :param group: h5py-like object.
+ :param bool validate: Set this to skip the NXdata validation, and only
+ check the existence of the group.
+ Parameter provided for optimisation purposes, to avoid double
+ validation if the validation is already performed separately."""
+ default_nxdata_name = group.attrs.get("default")
+ if default_nxdata_name is None or default_nxdata_name not in group:
+ return False
+
+ default_nxdata_group = group.get(default_nxdata_name)
+
+ if not is_group(default_nxdata_group):
+ return False
+
+ if not validate:
+ return True
+ else:
+ return is_valid_nxdata(default_nxdata_group)
+
+
+def is_NXentry_with_default_NXdata(group, validate=True):
+ """Return True if group is a valid NXentry defining a valid default
+ NXdata.
+
+ :param group: h5py-like object.
+ :param bool validate: Set this to skip the NXdata validation, and only
+ check the existence of the group.
+ Parameter provided for optimisation purposes, to avoid double
+ validation if the validation is already performed separately."""
+ if not is_group(group):
+ return False
+
+ if get_attr_as_unicode(group, "NX_class") != "NXentry":
+ return False
+
+ return is_group_with_default_NXdata(group, validate)
+
+
+def is_NXroot_with_default_NXdata(group, validate=True):
+ """Return True if group is a valid NXroot defining a default NXentry
+ defining a valid default NXdata.
+
+ .. note::
+
+ A NXroot group cannot directly define a default NXdata. If a
+ *@default* argument is present, it must point to a NXentry group.
+ This NXentry must define a valid NXdata for this function to return
+ True.
+
+ :param group: h5py-like object.
+ :param bool validate: Set this to False if you are sure that the target group
+ is valid NXdata (i.e. :func:`silx.io.nxdata.is_valid_nxdata(target_group)`
+ returns True). Parameter provided for optimisation purposes.
+ """
+ if not is_group(group):
+ return False
+
+ # A NXroot is supposed to be at the root of a data file, and @NX_class
+ # is therefore optional. We accept groups that are not located at the root
+ # if they have @NX_class=NXroot (use case: several nexus files archived
+ # in a single HDF5 file)
+ if get_attr_as_unicode(group, "NX_class") != "NXroot" and not is_file(group):
+ return False
+
+ default_nxentry_name = group.attrs.get("default")
+ if default_nxentry_name is None or default_nxentry_name not in group:
+ return False
+
+ default_nxentry_group = group.get(default_nxentry_name)
+ return is_NXentry_with_default_NXdata(default_nxentry_group,
+ validate=validate)
+
+
+def get_default(group, validate=True):
+ """Return a :class:`NXdata` object corresponding to the default NXdata group
+ in the group specified as parameter.
+
+ This function can find the NXdata if the group is already a NXdata, or
+ if it is a NXentry defining a default NXdata, or if it is a NXroot
+ defining such a default valid NXentry.
+
+ Return None if no valid NXdata could be found.
+
+ :param group: h5py-like group following the Nexus specification
+ (NXdata, NXentry or NXroot).
+ :param bool validate: Set this to False if you are sure that group
+ is valid NXdata (i.e. :func:`silx.io.nxdata.is_valid_nxdata(group)`
+ returns True). Parameter provided for optimisation purposes.
+ :return: :class:`NXdata` object or None
+ :raise TypeError: if group is not a h5py-like group
+ """
+ if not is_group(group):
+ raise TypeError("Provided parameter is not a h5py-like group")
+
+ if is_NXroot_with_default_NXdata(group, validate=validate):
+ default_entry = group[group.attrs["default"]]
+ default_data = default_entry[default_entry.attrs["default"]]
+ elif is_group_with_default_NXdata(group, validate=validate):
+ default_data = group[group.attrs["default"]]
+ elif not validate or is_valid_nxdata(group):
+ default_data = group
+ else:
+ return None
+
+ return NXdata(default_data, validate=False)
diff --git a/src/silx/io/nxdata/write.py b/src/silx/io/nxdata/write.py
new file mode 100644
index 0000000..9e84240
--- /dev/null
+++ b/src/silx/io/nxdata/write.py
@@ -0,0 +1,202 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+import os
+import logging
+
+import h5py
+import numpy
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/04/2018"
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _str_to_utf8(text):
+ return numpy.array(text, dtype=h5py.special_dtype(vlen=str))
+
+
+def save_NXdata(filename, signal, axes=None,
+ signal_name="data", axes_names=None,
+ signal_long_name=None, axes_long_names=None,
+ signal_errors=None, axes_errors=None,
+ title=None, interpretation=None,
+ nxentry_name="entry", nxdata_name=None):
+ """Write data to an NXdata group.
+
+ .. note::
+
+ No consistency checks are made regarding the dimensionality of the
+ signal and number of axes. The user is responsible for providing
+ meaningful data, that can be interpreted by visualization software.
+
+ :param str filename: Path to output file. If the file does not
+ exists, it is created.
+ :param numpy.ndarray signal: Signal array.
+ :param List[numpy.ndarray] axes: List of axes arrays.
+ :param str signal_name: Name of signal dataset, in output file
+ :param List[str] axes_names: List of dataset names for axes, in
+ output file
+ :param str signal_long_name: *@long_name* attribute for signal, or None.
+ :param axes_long_names: None, or list of long names
+ for axes
+ :type axes_long_names: List[str, None]
+ :param numpy.ndarray signal_errors: Array of errors associated with the
+ signal
+ :param axes_errors: List of arrays of errors
+ associated with each axis
+ :type axes_errors: List[numpy.ndarray, None]
+ :param str title: Graph title (saved as a "title" dataset) or None.
+ :param str interpretation: *@interpretation* attribute ("spectrum",
+ "image", "rgba-image" or None). This is only needed in cases of
+ ambiguous dimensionality, e.g. a 3D array which represents a RGBA
+ image rather than a stack.
+ :param str nxentry_name: Name of group in which the NXdata group
+ is created. By default, "/entry" is used.
+
+ .. note::
+
+ The Nexus format specification requires for NXdata groups
+ be part of a NXentry group.
+ The specified group should have attribute *@NX_class=NXentry*, in
+ order for the created file to be nexus compliant.
+ :param str nxdata_name: Name of NXdata group. If omitted (None), the
+ function creates a new group using the first available name ("data0",
+ or "data1"...).
+ Overwriting an existing group (or dataset) is not supported, you must
+ delete it yourself prior to calling this function if this is what you
+ want.
+ :return: True if save was successful, else False.
+ """
+ if h5py is None:
+ raise ImportError("h5py could not be imported, but is required by "
+ "save_NXdata function")
+
+ if axes_names is not None:
+ assert axes is not None, "Axes names defined, but missing axes arrays"
+ assert len(axes) == len(axes_names), \
+ "Mismatch between number of axes and axes_names"
+
+ if axes is not None and axes_names is None:
+ axes_names = []
+ for i, axis in enumerate(axes):
+ axes_names.append("dim%d" % i if axis is not None else ".")
+ if axes is None:
+ axes = []
+
+ # Open file in
+ if os.path.exists(filename):
+ errmsg = "Cannot write/append to existing path %s"
+ if not os.path.isfile(filename):
+ errmsg += " (not a file)"
+ _logger.error(errmsg, filename)
+ return False
+ if not os.access(filename, os.W_OK):
+ errmsg += " (no permission to write)"
+ _logger.error(errmsg, filename)
+ return False
+ mode = "r+"
+ else:
+ mode = "w-"
+
+ with h5py.File(filename, mode=mode) as h5f:
+ # get or create entry
+ if nxentry_name is not None:
+ entry = h5f.require_group(nxentry_name)
+ if "default" not in h5f.attrs:
+ # set this entry as default
+ h5f.attrs["default"] = _str_to_utf8(nxentry_name)
+ if "NX_class" not in entry.attrs:
+ entry.attrs["NX_class"] = u"NXentry"
+ else:
+ # write NXdata into the root of the file (invalid nexus!)
+ entry = h5f
+
+ # Create NXdata group
+ if nxdata_name is not None:
+ if nxdata_name in entry:
+ _logger.error("Cannot assign an NXdata group to an existing"
+ " group or dataset")
+ return False
+ else:
+ # no name specified, take one that is available
+ nxdata_name = "data0"
+ i = 1
+ while nxdata_name in entry:
+ _logger.info("%s item already exists in NXentry group," +
+ " trying %s", nxdata_name, "data%d" % i)
+ nxdata_name = "data%d" % i
+ i += 1
+
+ data_group = entry.create_group(nxdata_name)
+ data_group.attrs["NX_class"] = u"NXdata"
+ data_group.attrs["signal"] = _str_to_utf8(signal_name)
+ if axes:
+ data_group.attrs["axes"] = _str_to_utf8(axes_names)
+ if title:
+ # not in NXdata spec, but implemented by nexpy
+ data_group["title"] = title
+ # better way imho
+ data_group.attrs["title"] = _str_to_utf8(title)
+
+ signal_dataset = data_group.create_dataset(signal_name,
+ data=signal)
+ if signal_long_name:
+ signal_dataset.attrs["long_name"] = _str_to_utf8(signal_long_name)
+ if interpretation:
+ signal_dataset.attrs["interpretation"] = _str_to_utf8(interpretation)
+
+ for i, axis_array in enumerate(axes):
+ if axis_array is None:
+ assert axes_names[i] in [".", None], \
+ "Axis name defined for dim %d but no axis array" % i
+ continue
+ axis_dataset = data_group.create_dataset(axes_names[i],
+ data=axis_array)
+ if axes_long_names is not None:
+ axis_dataset.attrs["long_name"] = _str_to_utf8(axes_long_names[i])
+
+ if signal_errors is not None:
+ data_group.create_dataset("errors",
+ data=signal_errors)
+
+ if axes_errors is not None:
+ assert isinstance(axes_errors, (list, tuple)), \
+ "axes_errors must be a list or a tuple of ndarray or None"
+ assert len(axes_errors) == len(axes_names), \
+ "Mismatch between number of axes_errors and axes_names"
+ for i, axis_errors in enumerate(axes_errors):
+ if axis_errors is not None:
+ dsname = axes_names[i] + "_errors"
+ data_group.create_dataset(dsname,
+ data=axis_errors)
+ if "default" not in entry.attrs:
+ # set this NXdata as default
+ entry.attrs["default"] = nxdata_name
+
+ return True
diff --git a/silx/io/octaveh5.py b/src/silx/io/octaveh5.py
index 84fa726..84fa726 100644
--- a/silx/io/octaveh5.py
+++ b/src/silx/io/octaveh5.py
diff --git a/silx/io/rawh5.py b/src/silx/io/rawh5.py
index ceabbdb..ceabbdb 100644
--- a/silx/io/rawh5.py
+++ b/src/silx/io/rawh5.py
diff --git a/silx/io/setup.py b/src/silx/io/setup.py
index 9cafa17..9cafa17 100644
--- a/silx/io/setup.py
+++ b/src/silx/io/setup.py
diff --git a/src/silx/io/specfile.pyx b/src/silx/io/specfile.pyx
new file mode 100644
index 0000000..cb9e1a5
--- /dev/null
+++ b/src/silx/io/specfile.pyx
@@ -0,0 +1,1268 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+This module is a cython binding to wrap the C SpecFile library, to access
+SpecFile data within a python program.
+
+Documentation for the original C library SpecFile can be found on the ESRF
+website:
+`The manual for the SpecFile Library <http://ftp.esrf.fr/pub/scisoft/silx/doc/SpecFileManual.pdf>`_
+
+Examples
+========
+
+Start by importing :class:`SpecFile` and instantiate it:
+
+.. code-block:: python
+
+ from silx.io.specfile import SpecFile
+
+ sf = SpecFile("test.dat")
+
+A :class:`SpecFile` instance can be accessed like a dictionary to obtain a
+:class:`Scan` instance.
+
+If the key is a string representing two values
+separated by a dot (e.g. ``"1.2"``), they will be treated as the scan number
+(``#S`` header line) and the scan order::
+
+ # get second occurrence of scan "#S 1"
+ myscan = sf["1.2"]
+
+ # access scan data as a numpy array
+ nlines, ncolumns = myscan.data.shape
+
+If the key is an integer, it will be treated as a 0-based index::
+
+ first_scan = sf[0]
+ second_scan = sf[1]
+
+It is also possible to browse through all scans using :class:`SpecFile` as
+an iterator::
+
+ for scan in sf:
+ print(scan.scan_header_dict['S'])
+
+MCA spectra can be selectively loaded using an instance of :class:`MCA`
+provided by :class:`Scan`::
+
+ # Only one MCA spectrum is loaded in memory
+ second_mca = first_scan.mca[1]
+
+ # Iterating trough all MCA spectra in a scan:
+ for mca_data in first_scan.mca:
+ print(sum(mca_data))
+
+Classes
+=======
+
+- :class:`SpecFile`
+- :class:`Scan`
+- :class:`MCA`
+
+Exceptions
+==========
+
+- :class:`SfError`
+- :class:`SfErrMemoryAlloc`
+- :class:`SfErrFileOpen`
+- :class:`SfErrFileClose`
+- :class:`SfErrFileRead`
+- :class:`SfErrFileWrite`
+- :class:`SfErrLineNotFound`
+- :class:`SfErrScanNotFound`
+- :class:`SfErrHeaderNotFound`
+- :class:`SfErrLabelNotFound`
+- :class:`SfErrMotorNotFound`
+- :class:`SfErrPositionNotFound`
+- :class:`SfErrLineEmpty`
+- :class:`SfErrUserNotFound`
+- :class:`SfErrColNotFound`
+- :class:`SfErrMcaNotFound`
+
+"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "11/08/2017"
+
+import os.path
+import logging
+import numpy
+import re
+import sys
+
+_logger = logging.getLogger(__name__)
+
+cimport cython
+from libc.stdlib cimport free
+
+cimport silx.io.specfile_wrapper as specfile_wrapper
+
+
+SF_ERR_NO_ERRORS = 0
+SF_ERR_FILE_OPEN = 2
+SF_ERR_SCAN_NOT_FOUND = 7
+
+
+# custom errors
+class SfError(Exception):
+ """Base exception inherited by all exceptions raised when a
+ C function from the legacy SpecFile library returns an error
+ code.
+ """
+ pass
+
+class SfErrMemoryAlloc(SfError, MemoryError): pass
+class SfErrFileOpen(SfError, IOError): pass
+class SfErrFileClose(SfError, IOError): pass
+class SfErrFileRead(SfError, IOError): pass
+class SfErrFileWrite(SfError, IOError): pass
+class SfErrLineNotFound(SfError, KeyError): pass
+class SfErrScanNotFound(SfError, IndexError): pass
+class SfErrHeaderNotFound(SfError, KeyError): pass
+class SfErrLabelNotFound(SfError, KeyError): pass
+class SfErrMotorNotFound(SfError, KeyError): pass
+class SfErrPositionNotFound(SfError, KeyError): pass
+class SfErrLineEmpty(SfError, IOError): pass
+class SfErrUserNotFound(SfError, KeyError): pass
+class SfErrColNotFound(SfError, KeyError): pass
+class SfErrMcaNotFound(SfError, IndexError): pass
+
+
+ERRORS = {
+ 1: SfErrMemoryAlloc,
+ 2: SfErrFileOpen,
+ 3: SfErrFileClose,
+ 4: SfErrFileRead,
+ 5: SfErrFileWrite,
+ 6: SfErrLineNotFound,
+ 7: SfErrScanNotFound,
+ 8: SfErrHeaderNotFound,
+ 9: SfErrLabelNotFound,
+ 10: SfErrMotorNotFound,
+ 11: SfErrPositionNotFound,
+ 12: SfErrLineEmpty,
+ 13: SfErrUserNotFound,
+ 14: SfErrColNotFound,
+ 15: SfErrMcaNotFound,
+}
+
+
+class SfNoMcaError(SfError):
+ """Custom exception raised when ``SfNoMca()`` returns ``-1``
+ """
+ pass
+
+
+class MCA(object):
+ """
+
+ :param scan: Parent Scan instance
+ :type scan: :class:`Scan`
+
+ :var calibration: MCA calibration :math:`(a, b, c)` (as in
+ :math:`a + b x + c x²`) from ``#@CALIB`` scan header.
+ :type calibration: list of 3 floats, default ``[0., 1., 0.]``
+ :var channels: MCA channels list from ``#@CHANN`` scan header.
+ In the absence of a ``#@CHANN`` header, this attribute is a list
+ ``[0, …, N-1]`` where ``N`` is the length of the first spectrum.
+ In the absence of MCA spectra, this attribute defaults to ``None``.
+ :type channels: list of int
+
+ This class provides access to Multi-Channel Analysis data. A :class:`MCA`
+ instance can be indexed to access 1D numpy arrays representing single
+ MCA spectra.
+
+ To create a :class:`MCA` instance, you must provide a parent :class:`Scan`
+ instance, which in turn will provide a reference to the original
+ :class:`SpecFile` instance::
+
+ sf = SpecFile("/path/to/specfile.dat")
+ scan2 = Scan(sf, scan_index=2)
+ mcas_in_scan2 = MCA(scan2)
+ for i in len(mcas_in_scan2):
+ mca_data = mcas_in_scan2[i]
+ ... # do some something with mca_data (1D numpy array)
+
+ A more pythonic way to do the same work, without having to explicitly
+ instantiate ``scan`` and ``mcas_in_scan``, would be::
+
+ sf = SpecFile("specfilename.dat")
+ # scan2 from previous example can be referred to as sf[2]
+ # mcas_in_scan2 from previous example can be referred to as scan2.mca
+ for mca_data in sf[2].mca:
+ ... # do some something with mca_data (1D numpy array)
+
+ """
+ def __init__(self, scan):
+ self._scan = scan
+
+ # Header dict
+ self._header = scan.mca_header_dict
+
+ self.calibration = []
+ """List of lists of calibration values,
+ one list of 3 floats per MCA device or a single list applying to
+ all devices """
+ self._parse_calibration()
+
+ self.channels = []
+ """List of lists of channels,
+ one list of integers per MCA device or a single list applying to
+ all devices"""
+ self._parse_channels()
+
+ def _parse_channels(self):
+ """Fill :attr:`channels`"""
+ # Channels list
+ if "CHANN" in self._header:
+ chann_lines = self._header["CHANN"].split("\n")
+ all_chann_values = [chann_line.split() for chann_line in chann_lines]
+ for one_line_chann_values in all_chann_values:
+ length, start, stop, increment = map(int, one_line_chann_values)
+ self.channels.append(list(range(start, stop + 1, increment)))
+ elif len(self):
+ # in the absence of #@CHANN, use shape of first MCA
+ length = self[0].shape[0]
+ start, stop, increment = (0, length - 1, 1)
+ self.channels.append(list(range(start, stop + 1, increment)))
+
+ def _parse_calibration(self):
+ """Fill :attr:`calibration`"""
+ # Channels list
+ if "CALIB" in self._header:
+ calib_lines = self._header["CALIB"].split("\n")
+ all_calib_values = [calib_line.split() for calib_line in calib_lines]
+ for one_line_calib_values in all_calib_values:
+ self.calibration.append(list(map(float, one_line_calib_values)))
+ else:
+ # in the absence of #@calib, use default
+ self.calibration.append([0., 1., 0.])
+
+ def __len__(self):
+ """
+
+ :return: Number of mca in Scan
+ :rtype: int
+ """
+ return self._scan._specfile.number_of_mca(self._scan.index)
+
+ def __getitem__(self, key):
+ """Return a single MCA data line
+
+ :param key: 0-based index of MCA within Scan
+ :type key: int
+
+ :return: Single MCA
+ :rtype: 1D numpy array
+ """
+ if not len(self):
+ raise IndexError("No MCA spectrum found in this scan")
+
+ if isinstance(key, (int, long)):
+ mca_index = key
+ # allow negative index, like lists
+ if mca_index < 0:
+ mca_index = len(self) + mca_index
+ else:
+ raise TypeError("MCA index should be an integer (%s provided)" %
+ (type(key)))
+
+ if not 0 <= mca_index < len(self):
+ msg = "MCA index must be in range 0-%d" % (len(self) - 1)
+ raise IndexError(msg)
+
+ return self._scan._specfile.get_mca(self._scan.index,
+ mca_index)
+
+ def __iter__(self):
+ """Return the next MCA data line each time this method is called.
+
+ :return: Single MCA
+ :rtype: 1D numpy array
+ """
+ for mca_index in range(len(self)):
+ yield self._scan._specfile.get_mca(self._scan.index, mca_index)
+
+
+def _add_or_concatenate(dictionary, key, value):
+ """If key doesn't exist in dictionary, create a new ``key: value`` pair.
+ Else append/concatenate the new value to the existing one
+ """
+ try:
+ if key not in dictionary:
+ dictionary[key] = value
+ else:
+ dictionary[key] += "\n" + value
+ except TypeError:
+ raise TypeError("Parameter value must be a string.")
+
+
+class Scan(object):
+ """
+
+ :param specfile: Parent SpecFile from which this scan is extracted.
+ :type specfile: :class:`SpecFile`
+ :param scan_index: Unique index defining the scan in the SpecFile
+ :type scan_index: int
+
+ Interface to access a SpecFile scan
+
+ A scan is a block of descriptive header lines followed by a 2D data array.
+
+ Following three ways of accessing a scan are equivalent::
+
+ sf = SpecFile("/path/to/specfile.dat")
+
+ # Explicit class instantiation
+ scan2 = Scan(sf, scan_index=2)
+
+ # 0-based index on a SpecFile object
+ scan2 = sf[2]
+
+ # Using a "n.m" key (scan number starting with 1, scan order)
+ scan2 = sf["3.1"]
+ """
+ def __init__(self, specfile, scan_index):
+ self._specfile = specfile
+
+ self._index = scan_index
+ self._number = specfile.number(scan_index)
+ self._order = specfile.order(scan_index)
+
+ self._scan_header_lines = self._specfile.scan_header(self._index)
+ self._file_header_lines = self._specfile.file_header(self._index)
+
+ if self._file_header_lines == self._scan_header_lines:
+ self._file_header_lines = []
+ self._header = self._file_header_lines + self._scan_header_lines
+
+ self._scan_header_dict = {}
+ self._mca_header_dict = {}
+ for line in self._scan_header_lines:
+ match = re.search(r"#(\w+) *(.*)", line)
+ match_mca = re.search(r"#@(\w+) *(.*)", line)
+ if match:
+ hkey = match.group(1).lstrip("#").strip()
+ hvalue = match.group(2).strip()
+ _add_or_concatenate(self._scan_header_dict, hkey, hvalue)
+ elif match_mca:
+ hkey = match_mca.group(1).lstrip("#").strip()
+ hvalue = match_mca.group(2).strip()
+ _add_or_concatenate(self._mca_header_dict, hkey, hvalue)
+ else:
+ # this shouldn't happen
+ _logger.warning("Unable to parse scan header line " + line)
+
+ self._labels = []
+ if self.record_exists_in_hdr('L'):
+ try:
+ self._labels = self._specfile.labels(self._index)
+ except SfErrLineNotFound:
+ # SpecFile.labels raises an IndexError when encountering
+ # a Scan with no data, even if the header exists.
+ L_header = re.sub(r" {2,}", " ", # max. 2 spaces
+ self._scan_header_dict["L"])
+ self._labels = L_header.split(" ")
+
+ self._file_header_dict = {}
+ for line in self._file_header_lines:
+ match = re.search(r"#(\w+) *(.*)", line)
+ if match:
+ # header type
+ hkey = match.group(1).lstrip("#").strip()
+ hvalue = match.group(2).strip()
+ _add_or_concatenate(self._file_header_dict, hkey, hvalue)
+ else:
+ _logger.warning("Unable to parse file header line " + line)
+
+ self._motor_names = self._specfile.motor_names(self._index)
+ self._motor_positions = self._specfile.motor_positions(self._index)
+
+ self._data = None
+ self._mca = None
+
+ @cython.embedsignature(False)
+ @property
+ def index(self):
+ """Unique scan index 0 - len(specfile)-1
+
+ This attribute is implemented as a read-only property as changing
+ its value may cause nasty side-effects (such as loading data from a
+ different scan without updating the header accordingly."""
+ return self._index
+
+ @cython.embedsignature(False)
+ @property
+ def number(self):
+ """First value on #S line (as int)"""
+ return self._number
+
+ @cython.embedsignature(False)
+ @property
+ def order(self):
+ """Order can be > 1 if the same number is repeated in a specfile"""
+ return self._order
+
+ @cython.embedsignature(False)
+ @property
+ def header(self):
+ """List of raw header lines (as a list of strings).
+
+ This includes the file header, the scan header and possibly a MCA
+ header.
+ """
+ return self._header
+
+ @cython.embedsignature(False)
+ @property
+ def scan_header(self):
+ """List of raw scan header lines (as a list of strings).
+ """
+ return self._scan_header_lines
+
+ @cython.embedsignature(False)
+ @property
+ def file_header(self):
+ """List of raw file header lines (as a list of strings).
+ """
+ return self._file_header_lines
+
+ @cython.embedsignature(False)
+ @property
+ def scan_header_dict(self):
+ """
+ Dictionary of scan header strings, keys without the leading``#``
+ (e.g. ``scan_header_dict["S"]``).
+ Note: this does not include MCA header lines starting with ``#@``.
+ """
+ return self._scan_header_dict
+
+ @cython.embedsignature(False)
+ @property
+ def mca_header_dict(self):
+ """
+ Dictionary of MCA header strings, keys without the leading ``#@``
+ (e.g. ``mca_header_dict["CALIB"]``).
+ """
+ return self._mca_header_dict
+
+ @cython.embedsignature(False)
+ @property
+ def file_header_dict(self):
+ """
+ Dictionary of file header strings, keys without the leading ``#``
+ (e.g. ``file_header_dict["F"]``).
+ """
+ return self._file_header_dict
+
+ @cython.embedsignature(False)
+ @property
+ def labels(self):
+ """
+ List of data column headers from ``#L`` scan header
+ """
+ return self._labels
+
+ @cython.embedsignature(False)
+ @property
+ def data(self):
+ """Scan data as a 2D numpy.ndarray with the usual attributes
+ (e.g. data.shape).
+
+ The first index is the detector, the second index is the sample index.
+ """
+ if self._data is None:
+ self._data = numpy.transpose(self._specfile.data(self._index))
+
+ return self._data
+
+ @cython.embedsignature(False)
+ @property
+ def mca(self):
+ """MCA data in this scan.
+
+ Each multichannel analysis is a 1D numpy array. Metadata about
+ MCA data is to be found in :py:attr:`mca_header`.
+
+ :rtype: :class:`MCA`
+ """
+ if self._mca is None:
+ self._mca = MCA(self)
+ return self._mca
+
+ @cython.embedsignature(False)
+ @property
+ def motor_names(self):
+ """List of motor names from the ``#O`` file header line.
+ """
+ return self._motor_names
+
+ @cython.embedsignature(False)
+ @property
+ def motor_positions(self):
+ """List of motor positions as floats from the ``#P`` scan header line.
+ """
+ return self._motor_positions
+
+ def record_exists_in_hdr(self, record):
+ """Check whether a scan header line exists.
+
+ This should be used before attempting to retrieve header information
+ using a C function that may crash with a *segmentation fault* if the
+ header isn't defined in the SpecFile.
+
+ :param record: single upper case letter corresponding to the
+ header you want to test (e.g. ``L`` for labels)
+ :type record: str
+
+ :return: True or False
+ :rtype: boolean
+ """
+ for line in self._header:
+ if line.startswith("#" + record):
+ return True
+ return False
+
+ def data_line(self, line_index):
+ """Returns data for a given line of this scan.
+
+ .. note::
+
+ A data line returned by this method, corresponds to a data line
+ in the original specfile (a series of data points, one per
+ detector). In the :attr:`data` array, this line index corresponds
+ to the index in the second dimension (~ column) of the array.
+
+ :param line_index: Index of data line to retrieve (starting with 0)
+ :type line_index: int
+
+ :return: Line data as a 1D array of doubles
+ :rtype: numpy.ndarray
+ """
+ # attribute data corresponds to a transposed version of the original
+ # specfile data (where detectors correspond to columns)
+ return self.data[:, line_index]
+
+ def data_column_by_name(self, label):
+ """Returns a data column
+
+ :param label: Label of data column to retrieve, as defined on the
+ ``#L`` line of the scan header.
+ :type label: str
+
+ :return: Line data as a 1D array of doubles
+ :rtype: numpy.ndarray
+ """
+ try:
+ ret = self._specfile.data_column_by_name(self._index, label)
+ except SfErrLineNotFound:
+ # Could be a "#C Scan aborted after 0 points"
+ _logger.warning("Cannot get data column %s in scan %d.%d",
+ label, self.number, self.order)
+ ret = numpy.empty((0, ), numpy.double)
+ return ret
+
+ def motor_position_by_name(self, name):
+ """Returns the position for a given motor
+
+ :param name: Name of motor, as defined on the ``#O`` line of the
+ file header.
+ :type name: str
+
+ :return: Motor position
+ :rtype: float
+ """
+ return self._specfile.motor_position_by_name(self._index, name)
+
+
+def _string_to_char_star(string_):
+ """Convert a string to ASCII encoded bytes when using python3"""
+ if sys.version_info[0] >= 3 and not isinstance(string_, bytes):
+ return bytes(string_, "ascii")
+ return string_
+
+
+def is_specfile(filename):
+ """Test if a file is a SPEC file, by checking if one of the first two
+ lines starts with *#F* (SPEC file header) or *#S* (scan header).
+
+ :param str filename: File path
+ :return: *True* if file is a SPEC file, *False* if it is not a SPEC file
+ :rtype: bool
+ """
+ if not os.path.isfile(filename):
+ return False
+ # test for presence of #S or #F in first 10 lines
+ with open(filename, "rb") as f:
+ chunk = f.read(2500)
+ for i, line in enumerate(chunk.split(b"\n")):
+ if line.startswith(b"#S ") or line.startswith(b"#F "):
+ return True
+ if i >= 10:
+ break
+ return False
+
+
+cdef class SpecFile(object):
+ """
+
+ :param filename: Path of the SpecFile to read
+
+ This class wraps the main data and header access functions of the C
+ SpecFile library.
+ """
+
+ cdef:
+ specfile_wrapper.SpecFileHandle *handle
+ str filename
+
+ def __cinit__(self, filename):
+ cdef int error = 0
+ self.handle = NULL
+
+ if is_specfile(filename):
+ filename = _string_to_char_star(filename)
+ self.handle = specfile_wrapper.SfOpen(filename, &error)
+ if error:
+ self._handle_error(error)
+ else:
+ # handle_error takes care of raising the correct error,
+ # this causes the destructor to be called
+ self._handle_error(SF_ERR_FILE_OPEN)
+
+ def __init__(self, filename):
+ if not isinstance(filename, str):
+ # decode bytes to str in python 3, str to unicode in python 2
+ self.filename = filename.decode()
+ else:
+ self.filename = filename
+
+ def __dealloc__(self):
+ """Destructor: Calls SfClose(self.handle)"""
+ self.close()
+
+ def close(self):
+ """Close the file descriptor"""
+ # handle is NULL if SfOpen failed
+ if self.handle:
+ if specfile_wrapper.SfClose(self.handle):
+ _logger.warning("Error while closing SpecFile")
+ self.handle = NULL
+
+ def __len__(self):
+ """Return the number of scans in the SpecFile
+ """
+ return specfile_wrapper.SfScanNo(self.handle)
+
+ def __iter__(self):
+ """Return the next :class:`Scan` in a SpecFile each time this method
+ is called.
+
+ This usually happens when the python built-in function ``next()`` is
+ called with a :class:`SpecFile` instance as a parameter, or when a
+ :class:`SpecFile` instance is used as an iterator (e.g. in a ``for``
+ loop).
+ """
+ for scan_index in range(len(self)):
+ yield Scan(self, scan_index)
+
+ def __getitem__(self, key):
+ """Return a :class:`Scan` object.
+
+ This special method is called when a :class:`SpecFile` instance is
+ accessed as a dictionary (e.g. ``sf[key]``).
+
+ :param key: 0-based scan index or ``"n.m"`` key, where ``n`` is the scan
+ number defined on the ``#S`` header line and ``m`` is the order
+ :type key: int or str
+
+ :return: Scan defined by its 0-based index or its ``"n.m"`` key
+ :rtype: :class:`Scan`
+ """
+ msg = "The scan identification key can be an integer representing "
+ msg += "the unique scan index or a string 'N.M' with N being the scan"
+ msg += " number and M the order (eg '2.3')."
+
+ if isinstance(key, int):
+ scan_index = key
+ # allow negative index, like lists
+ if scan_index < 0:
+ scan_index = len(self) + scan_index
+ else:
+ try:
+ (number, order) = map(int, key.split("."))
+ scan_index = self.index(number, order)
+ except (ValueError, SfErrScanNotFound, KeyError):
+ # int() can raise a value error
+ raise KeyError(msg + "\nValid keys: '" +
+ "', '".join(self.keys()) + "'")
+ except AttributeError:
+ # e.g. "AttrErr: 'float' object has no attribute 'split'"
+ raise TypeError(msg)
+
+ if not 0 <= scan_index < len(self):
+ msg = "Scan index must be in range 0-%d" % (len(self) - 1)
+ raise IndexError(msg)
+
+ return Scan(self, scan_index)
+
+ def keys(self):
+ """Returns list of scan keys (eg ``['1.1', '2.1',...]``).
+
+ :return: list of scan keys
+ :rtype: list of strings
+ """
+ ret_list = []
+ list_of_numbers = self._list()
+ count = {}
+
+ for number in list_of_numbers:
+ if number not in count:
+ count[number] = 1
+ else:
+ count[number] += 1
+ ret_list.append(u'%d.%d' % (number, count[number]))
+
+ return ret_list
+
+ def __contains__(self, key):
+ """Return ``True`` if ``key`` is a valid scan key.
+ Valid keys can be a string such as ``"1.1"`` or a 0-based scan index.
+ """
+ return key in (self.keys() + list(range(len(self))))
+
+ def _get_error_string(self, error_code):
+ """Returns the error message corresponding to the error code.
+
+ :param code: Error code
+ :type code: int
+ :return: Human readable error message
+ :rtype: str
+ """
+ return (<bytes> specfile_wrapper.SfError(error_code)).decode()
+
+ def _handle_error(self, error_code):
+ """Inspect error code, raise adequate error type if necessary.
+
+ :param code: Error code
+ :type code: int
+ """
+ error_message = self._get_error_string(error_code)
+ if error_code in ERRORS:
+ raise ERRORS[error_code](error_message)
+
+ def index(self, scan_number, scan_order=1):
+ """Returns scan index from scan number and order.
+
+ :param scan_number: Scan number (possibly non-unique).
+ :type scan_number: int
+ :param scan_order: Scan order.
+ :type scan_order: int default 1
+
+ :return: Unique scan index
+ :rtype: int
+
+
+ Scan indices are increasing from ``0`` to ``len(self)-1`` in the
+ order in which they appear in the file.
+ Scan numbers are defined by users and are not necessarily unique.
+ The scan order for a given scan number increments each time the scan
+ number appears in a given file.
+ """
+ idx = specfile_wrapper.SfIndex(self.handle, scan_number, scan_order)
+ if idx == -1:
+ self._handle_error(SF_ERR_SCAN_NOT_FOUND)
+ return idx - 1
+
+ def number(self, scan_index):
+ """Returns scan number from scan index.
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: User defined scan number.
+ :rtype: int
+ """
+ idx = specfile_wrapper.SfNumber(self.handle, scan_index + 1)
+ if idx == -1:
+ self._handle_error(SF_ERR_SCAN_NOT_FOUND)
+ return idx
+
+ def order(self, scan_index):
+ """Returns scan order from scan index.
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: Scan order (sequential number incrementing each time a
+ non-unique occurrence of a scan number is encountered).
+ :rtype: int
+ """
+ ordr = specfile_wrapper.SfOrder(self.handle, scan_index + 1)
+ if ordr == -1:
+ self._handle_error(SF_ERR_SCAN_NOT_FOUND)
+ return ordr
+
+ def _list(self):
+ """see documentation of :meth:`list`
+ """
+ cdef:
+ long *scan_numbers
+ int error = SF_ERR_NO_ERRORS
+
+ scan_numbers = specfile_wrapper.SfList(self.handle, &error)
+ self._handle_error(error)
+
+ ret_list = []
+ for i in range(len(self)):
+ ret_list.append(scan_numbers[i])
+
+ free(scan_numbers)
+ return ret_list
+
+ def list(self):
+ """Returns list (1D numpy array) of scan numbers in SpecFile.
+
+ :return: list of scan numbers (from `` #S`` lines) in the same order
+ as in the original SpecFile (e.g ``[1, 1, 2, 3, …]``).
+ :rtype: numpy array
+ """
+ # this method is overloaded in specfilewrapper to output a string
+ # representation of the list
+ return self._list()
+
+ def data(self, scan_index):
+ """Returns data for the specified scan index.
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: Complete scan data as a 2D array of doubles
+ :rtype: numpy.ndarray
+ """
+ cdef:
+ double** mydata
+ long* data_info
+ int i, j
+ int error = SF_ERR_NO_ERRORS
+ long nlines, ncolumns, regular
+ double[:, :] ret_array
+
+ sfdata_error = specfile_wrapper.SfData(self.handle,
+ scan_index + 1,
+ &mydata,
+ &data_info,
+ &error)
+ if sfdata_error == -1 and not error:
+ # this has happened in some situations with empty scans (#1759)
+ _logger.warning("SfData returned -1 without an error."
+ " Assuming aborted scan.")
+
+ self._handle_error(error)
+
+ if <long>data_info != 0:
+ nlines = data_info[0]
+ ncolumns = data_info[1]
+ regular = data_info[2]
+ else:
+ nlines = 0
+ ncolumns = 0
+ regular = 0
+
+ ret_array = numpy.empty((nlines, ncolumns), dtype=numpy.double)
+
+ for i in range(nlines):
+ for j in range(ncolumns):
+ ret_array[i, j] = mydata[i][j]
+
+ specfile_wrapper.freeArrNZ(<void ***>&mydata, nlines)
+ free(data_info)
+ return numpy.asarray(ret_array)
+
+ def data_column_by_name(self, scan_index, label):
+ """Returns data column for the specified scan index and column label.
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+ :param label: Label of data column, as defined in the ``#L`` line
+ of the scan header.
+ :type label: str
+
+ :return: Data column as a 1D array of doubles
+ :rtype: numpy.ndarray
+ """
+ cdef:
+ double* data_column
+ long i, nlines
+ int error = SF_ERR_NO_ERRORS
+ double[:] ret_array
+
+ label = _string_to_char_star(label)
+
+ nlines = specfile_wrapper.SfDataColByName(self.handle,
+ scan_index + 1,
+ label,
+ &data_column,
+ &error)
+ self._handle_error(error)
+
+ if nlines == -1:
+ # this can happen on empty scans in some situations (see #1759)
+ _logger.warning("SfDataColByName returned -1 without an error."
+ " Assuming aborted scan.")
+ nlines = 0
+
+ ret_array = numpy.empty((nlines,), dtype=numpy.double)
+
+ for i in range(nlines):
+ ret_array[i] = data_column[i]
+
+ free(data_column)
+ return numpy.asarray(ret_array)
+
+ def scan_header(self, scan_index):
+ """Return list of scan header lines.
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: List of raw scan header lines
+ :rtype: list of str
+ """
+ cdef:
+ char** lines
+ int error = SF_ERR_NO_ERRORS
+
+ nlines = specfile_wrapper.SfHeader(self.handle,
+ scan_index + 1,
+ "", # no pattern matching
+ &lines,
+ &error)
+
+ self._handle_error(error)
+
+ lines_list = []
+ for i in range(nlines):
+ line = <bytes>lines[i].decode()
+ lines_list.append(line)
+
+ specfile_wrapper.freeArrNZ(<void***>&lines, nlines)
+ return lines_list
+
+ def file_header(self, scan_index=0):
+ """Return list of file header lines.
+
+ A file header contains all lines between a ``#F`` header line and
+ a ``#S`` header line (start of scan). We need to specify a scan
+ number because there can be more than one file header in a given file.
+ A file header applies to all subsequent scans, until a new file
+ header is defined.
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: List of raw file header lines
+ :rtype: list of str
+ """
+ cdef:
+ char** lines
+ int error = SF_ERR_NO_ERRORS
+
+ nlines = specfile_wrapper.SfFileHeader(self.handle,
+ scan_index + 1,
+ "", # no pattern matching
+ &lines,
+ &error)
+ self._handle_error(error)
+
+ lines_list = []
+ for i in range(nlines):
+ line = <bytes>lines[i].decode()
+ lines_list.append(line)
+
+ specfile_wrapper.freeArrNZ(<void***>&lines, nlines)
+ return lines_list
+
+ def columns(self, scan_index):
+ """Return number of columns in a scan from the ``#N`` header line
+ (without ``#N`` and scan number)
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: Number of columns in scan from ``#N`` line
+ :rtype: int
+ """
+ cdef:
+ int error = SF_ERR_NO_ERRORS
+
+ ncolumns = specfile_wrapper.SfNoColumns(self.handle,
+ scan_index + 1,
+ &error)
+ self._handle_error(error)
+
+ return ncolumns
+
+ def command(self, scan_index):
+ """Return ``#S`` line (without ``#S`` and scan number)
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: S line
+ :rtype: str
+ """
+ cdef:
+ int error = SF_ERR_NO_ERRORS
+
+ s_record = <bytes> specfile_wrapper.SfCommand(self.handle,
+ scan_index + 1,
+ &error)
+ self._handle_error(error)
+
+ return s_record.decode()
+
+ def date(self, scan_index=0):
+ """Return date from ``#D`` line
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: Date from ``#D`` line
+ :rtype: str
+ """
+ cdef:
+ int error = SF_ERR_NO_ERRORS
+
+ d_line = <bytes> specfile_wrapper.SfDate(self.handle,
+ scan_index + 1,
+ &error)
+ self._handle_error(error)
+
+ return d_line.decode()
+
+ def labels(self, scan_index):
+ """Return all labels from ``#L`` line
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: All labels from ``#L`` line
+ :rtype: list of strings
+ """
+ cdef:
+ char** all_labels
+ int error = SF_ERR_NO_ERRORS
+
+ nlabels = specfile_wrapper.SfAllLabels(self.handle,
+ scan_index + 1,
+ &all_labels,
+ &error)
+ self._handle_error(error)
+
+ labels_list = []
+ for i in range(nlabels):
+ labels_list.append(<bytes>all_labels[i].decode())
+
+ specfile_wrapper.freeArrNZ(<void***>&all_labels, nlabels)
+ return labels_list
+
+ def motor_names(self, scan_index=0):
+ """Return all motor names from ``#O`` lines
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.If not specified, defaults to 0 (meaning the
+ function returns motors names associated with the first scan).
+ This parameter makes a difference only if there are more than
+ on file header in the file, in which case the file header applies
+ to all following scans until a new file header appears.
+ :type scan_index: int
+
+ :return: All motor names
+ :rtype: list of strings
+ """
+ cdef:
+ char** all_motors
+ int error = SF_ERR_NO_ERRORS
+
+ nmotors = specfile_wrapper.SfAllMotors(self.handle,
+ scan_index + 1,
+ &all_motors,
+ &error)
+ self._handle_error(error)
+
+ motors_list = []
+ for i in range(nmotors):
+ motors_list.append(<bytes>all_motors[i].decode())
+
+ specfile_wrapper.freeArrNZ(<void***>&all_motors, nmotors)
+ return motors_list
+
+ def motor_positions(self, scan_index):
+ """Return all motor positions
+
+ :param scan_index: Unique scan index between ``0``
+ and ``len(self)-1``.
+ :type scan_index: int
+
+ :return: All motor positions
+ :rtype: list of double
+ """
+ cdef:
+ double* motor_positions
+ int error = SF_ERR_NO_ERRORS
+
+ nmotors = specfile_wrapper.SfAllMotorPos(self.handle,
+ scan_index + 1,
+ &motor_positions,
+ &error)
+ self._handle_error(error)
+
+ motor_positions_list = []
+ for i in range(nmotors):
+ motor_positions_list.append(motor_positions[i])
+
+ free(motor_positions)
+ return motor_positions_list
+
+ def motor_position_by_name(self, scan_index, name):
+ """Return motor position
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: Specified motor position
+ :rtype: double
+ """
+ cdef:
+ int error = SF_ERR_NO_ERRORS
+
+ name = _string_to_char_star(name)
+
+ motor_position = specfile_wrapper.SfMotorPosByName(self.handle,
+ scan_index + 1,
+ name,
+ &error)
+ self._handle_error(error)
+
+ return motor_position
+
+ def number_of_mca(self, scan_index):
+ """Return number of mca spectra in a scan.
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: Number of mca spectra.
+ :rtype: int
+ """
+ cdef:
+ int error = SF_ERR_NO_ERRORS
+
+ num_mca = specfile_wrapper.SfNoMca(self.handle,
+ scan_index + 1,
+ &error)
+ # error code updating isn't implemented in SfNoMCA
+ if num_mca == -1:
+ raise SfNoMcaError("Failed to retrieve number of MCA " +
+ "(SfNoMca returned -1)")
+ return num_mca
+
+ def mca_calibration(self, scan_index):
+ """Return MCA calibration in the form :math:`a + b x + c x²`
+
+ Raise a KeyError if there is no ``@CALIB`` line in the scan header.
+
+ :param scan_index: Unique scan index between ``0`` and
+ ``len(self)-1``.
+ :type scan_index: int
+
+ :return: MCA calibration as a list of 3 values :math:`(a, b, c)`
+ :rtype: list of floats
+ """
+ cdef:
+ int error = SF_ERR_NO_ERRORS
+ double* mca_calib
+
+ mca_calib_error = specfile_wrapper.SfMcaCalib(self.handle,
+ scan_index + 1,
+ &mca_calib,
+ &error)
+
+ # error code updating isn't implemented in SfMcaCalib
+ if mca_calib_error:
+ raise KeyError("MCA calibration line (@CALIB) not found")
+
+ mca_calib_list = []
+ for i in range(3):
+ mca_calib_list.append(mca_calib[i])
+
+ free(mca_calib)
+ return mca_calib_list
+
+ def get_mca(self, scan_index, mca_index):
+ """Return one MCA spectrum
+
+ :param scan_index: Unique scan index between ``0`` and ``len(self)-1``.
+ :type scan_index: int
+ :param mca_index: Index of MCA in the scan
+ :type mca_index: int
+
+ :return: MCA spectrum
+ :rtype: 1D numpy array
+ """
+ cdef:
+ int error = SF_ERR_NO_ERRORS
+ double* mca_data
+ long len_mca
+ double[:] ret_array
+
+ len_mca = specfile_wrapper.SfGetMca(self.handle,
+ scan_index + 1,
+ mca_index + 1,
+ &mca_data,
+ &error)
+ self._handle_error(error)
+
+ ret_array = numpy.empty((len_mca,), dtype=numpy.double)
+
+ for i in range(len_mca):
+ ret_array[i] = mca_data[i]
+
+ free(mca_data)
+ return numpy.asarray(ret_array)
diff --git a/silx/io/specfile/include/Lists.h b/src/silx/io/specfile/include/Lists.h
index 01164fb..01164fb 100644
--- a/silx/io/specfile/include/Lists.h
+++ b/src/silx/io/specfile/include/Lists.h
diff --git a/silx/io/specfile/include/SpecFile.h b/src/silx/io/specfile/include/SpecFile.h
index 9456e3f..9456e3f 100644
--- a/silx/io/specfile/include/SpecFile.h
+++ b/src/silx/io/specfile/include/SpecFile.h
diff --git a/silx/io/specfile/include/SpecFileCython.h b/src/silx/io/specfile/include/SpecFileCython.h
index 3225e13..3225e13 100644
--- a/silx/io/specfile/include/SpecFileCython.h
+++ b/src/silx/io/specfile/include/SpecFileCython.h
diff --git a/silx/io/specfile/include/SpecFileP.h b/src/silx/io/specfile/include/SpecFileP.h
index 97c3db6..97c3db6 100644
--- a/silx/io/specfile/include/SpecFileP.h
+++ b/src/silx/io/specfile/include/SpecFileP.h
diff --git a/silx/io/specfile/include/locale_management.h b/src/silx/io/specfile/include/locale_management.h
index 64562c5..64562c5 100644
--- a/silx/io/specfile/include/locale_management.h
+++ b/src/silx/io/specfile/include/locale_management.h
diff --git a/silx/io/specfile/src/locale_management.c b/src/silx/io/specfile/src/locale_management.c
index 0c5f7ca..0c5f7ca 100644
--- a/silx/io/specfile/src/locale_management.c
+++ b/src/silx/io/specfile/src/locale_management.c
diff --git a/silx/io/specfile/src/sfdata.c b/src/silx/io/specfile/src/sfdata.c
index 689f56d..689f56d 100644
--- a/silx/io/specfile/src/sfdata.c
+++ b/src/silx/io/specfile/src/sfdata.c
diff --git a/silx/io/specfile/src/sfheader.c b/src/silx/io/specfile/src/sfheader.c
index b669e33..b669e33 100644
--- a/silx/io/specfile/src/sfheader.c
+++ b/src/silx/io/specfile/src/sfheader.c
diff --git a/silx/io/specfile/src/sfindex.c b/src/silx/io/specfile/src/sfindex.c
index 320b086..320b086 100644
--- a/silx/io/specfile/src/sfindex.c
+++ b/src/silx/io/specfile/src/sfindex.c
diff --git a/silx/io/specfile/src/sfinit.c b/src/silx/io/specfile/src/sfinit.c
index ca2fa7f..ca2fa7f 100644
--- a/silx/io/specfile/src/sfinit.c
+++ b/src/silx/io/specfile/src/sfinit.c
diff --git a/silx/io/specfile/src/sflabel.c b/src/silx/io/specfile/src/sflabel.c
index 61cbb3f..61cbb3f 100644
--- a/silx/io/specfile/src/sflabel.c
+++ b/src/silx/io/specfile/src/sflabel.c
diff --git a/silx/io/specfile/src/sflists.c b/src/silx/io/specfile/src/sflists.c
index aca267f..aca267f 100644
--- a/silx/io/specfile/src/sflists.c
+++ b/src/silx/io/specfile/src/sflists.c
diff --git a/silx/io/specfile/src/sfmca.c b/src/silx/io/specfile/src/sfmca.c
index ad13bae..ad13bae 100644
--- a/silx/io/specfile/src/sfmca.c
+++ b/src/silx/io/specfile/src/sfmca.c
diff --git a/silx/io/specfile/src/sftools.c b/src/silx/io/specfile/src/sftools.c
index 9b78b67..9b78b67 100644
--- a/silx/io/specfile/src/sftools.c
+++ b/src/silx/io/specfile/src/sftools.c
diff --git a/silx/io/specfile/src/sfwrite.c b/src/silx/io/specfile/src/sfwrite.c
index c77f400..c77f400 100644
--- a/silx/io/specfile/src/sfwrite.c
+++ b/src/silx/io/specfile/src/sfwrite.c
diff --git a/silx/io/specfile_wrapper.pxd b/src/silx/io/specfile_wrapper.pxd
index 6770f7e..6770f7e 100644
--- a/silx/io/specfile_wrapper.pxd
+++ b/src/silx/io/specfile_wrapper.pxd
diff --git a/silx/io/specfilewrapper.py b/src/silx/io/specfilewrapper.py
index 01e185c..01e185c 100644
--- a/silx/io/specfilewrapper.py
+++ b/src/silx/io/specfilewrapper.py
diff --git a/src/silx/io/spech5.py b/src/silx/io/spech5.py
new file mode 100644
index 0000000..df2021c
--- /dev/null
+++ b/src/silx/io/spech5.py
@@ -0,0 +1,907 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides a h5py-like API to access SpecFile data.
+
+API description
++++++++++++++++
+
+Specfile data structure exposed by this API:
+
+::
+
+ /
+ 1.1/
+ title = "…"
+ start_time = "…"
+ instrument/
+ specfile/
+ file_header = "…"
+ scan_header = "…"
+ positioners/
+ motor_name = value
+ …
+ mca_0/
+ data = …
+ calibration = …
+ channels = …
+ preset_time = …
+ elapsed_time = …
+ live_time = …
+
+ mca_1/
+ …
+ …
+ measurement/
+ colname0 = …
+ colname1 = …
+ …
+ mca_0/
+ data -> /1.1/instrument/mca_0/data
+ info -> /1.1/instrument/mca_0/
+ …
+ sample/
+ ub_matrix = …
+ unit_cell = …
+ unit_cell_abc = …
+ unit_cell_alphabetagamma = …
+ 2.1/
+ …
+
+``file_header`` and ``scan_header`` are the raw headers as they
+appear in the original file, as a string of lines separated by newline (``\\n``) characters.
+
+The title is the content of the ``#S`` scan header line without the leading
+``#S`` and without the scan number (e.g ``"ascan ss1vo -4.55687 -0.556875 40 0.2"``).
+
+The start time is converted to ISO8601 format (``"2016-02-23T22:49:05Z"``),
+if the original date format is standard.
+
+Numeric datasets are stored in *float32* format, except for scalar integers
+which are stored as *int64*.
+
+Motor positions (e.g. ``/1.1/instrument/positioners/motor_name``) can be
+1D numpy arrays if they are measured as scan data, or else scalars as defined
+on ``#P`` scan header lines. A simple test is done to check if the motor name
+is also a data column header defined in the ``#L`` scan header line.
+
+Scan data (e.g. ``/1.1/measurement/colname0``) is accessed by column,
+the dataset name ``colname0`` being the column label as defined in the ``#L``
+scan header line.
+
+If a ``/`` character is present in a column label or in a motor name in the
+original SPEC file, it will be substituted with a ``%`` character in the
+corresponding dataset name.
+
+MCA data is exposed as a 2D numpy array containing all spectra for a given
+analyser. The number of analysers is calculated as the number of MCA spectra
+per scan data line. Demultiplexing is then performed to assign the correct
+spectra to a given analyser.
+
+MCA calibration is an array of 3 scalars, from the ``#@CALIB`` header line.
+It is identical for all MCA analysers, as there can be only one
+``#@CALIB`` line per scan.
+
+MCA channels is an array containing all channel numbers. This information is
+computed from the ``#@CHANN`` scan header line (if present), or computed from
+the shape of the first spectrum in a scan (``[0, … len(first_spectrum] - 1]``).
+
+Accessing data
+++++++++++++++
+
+Data and groups are accessed in :mod:`h5py` fashion::
+
+ from silx.io.spech5 import SpecH5
+
+ # Open a SpecFile
+ sfh5 = SpecH5("test.dat")
+
+ # using SpecH5 as a regular group to access scans
+ scan1group = sfh5["1.1"]
+ instrument_group = scan1group["instrument"]
+
+ # alternative: full path access
+ measurement_group = sfh5["/1.1/measurement"]
+
+ # accessing a scan data column by name as a 1D numpy array
+ data_array = measurement_group["Pslit HGap"]
+
+ # accessing all mca-spectra for one MCA device
+ mca_0_spectra = measurement_group["mca_0/data"]
+
+:class:`SpecH5` files and groups provide a :meth:`keys` method::
+
+ >>> sfh5.keys()
+ ['96.1', '97.1', '98.1']
+ >>> sfh5['96.1'].keys()
+ ['title', 'start_time', 'instrument', 'measurement']
+
+They can also be treated as iterators:
+
+.. code-block:: python
+
+ from silx.io import is_dataset
+
+ for scan_group in SpecH5("test.dat"):
+ dataset_names = [item.name in scan_group["measurement"] if
+ is_dataset(item)]
+ print("Found data columns in scan " + scan_group.name)
+ print(", ".join(dataset_names))
+
+You can test for existence of data or groups::
+
+ >>> "/1.1/measurement/Pslit HGap" in sfh5
+ True
+ >>> "positioners" in sfh5["/2.1/instrument"]
+ True
+ >>> "spam" in sfh5["1.1"]
+ False
+
+.. note::
+
+ Text used to be stored with a dtype ``numpy.string_`` in silx versions
+ prior to *0.7.0*. The type ``numpy.string_`` is a byte-string format.
+ The consequence of this is that you had to decode strings before using
+ them in **Python 3**::
+
+ >>> from silx.io.spech5 import SpecH5
+ >>> sfh5 = SpecH5("31oct98.dat")
+ >>> sfh5["/68.1/title"]
+ b'68 ascan tx3 -28.5 -24.5 20 0.5'
+ >>> sfh5["/68.1/title"].decode()
+ '68 ascan tx3 -28.5 -24.5 20 0.5'
+
+ From silx version *0.7.0* onwards, text is now stored as unicode. This
+ corresponds to the default text type in python 3, and to the *unicode*
+ type in Python 2.
+
+ To be on the safe side, you can test for the presence of a *decode*
+ attribute, to ensure that you always work with unicode text::
+
+ >>> title = sfh5["/68.1/title"]
+ >>> if hasattr(title, "decode"):
+ ... title = title.decode()
+
+"""
+
+import datetime
+import logging
+import re
+import io
+
+import h5py
+import numpy
+
+from silx import version as silx_version
+from .specfile import SpecFile, SfErrColNotFound
+from . import commonh5
+
+__authors__ = ["P. Knobel", "D. Naudet"]
+__license__ = "MIT"
+__date__ = "17/07/2018"
+
+logger1 = logging.getLogger(__name__)
+
+
+text_dtype = h5py.special_dtype(vlen=str)
+
+
+def to_h5py_utf8(str_list):
+ """Convert a string or a list of strings to a numpy array of
+ unicode strings that can be written to HDF5 as utf-8.
+
+ This ensures that the type will be consistent between python 2 and
+ python 3, if attributes or datasets are saved to an HDF5 file.
+ """
+ return numpy.array(str_list, dtype=text_dtype)
+
+
+def _get_number_of_mca_analysers(scan):
+ """
+ :param SpecFile sf: :class:`SpecFile` instance
+ """
+ number_of_mca_spectra = len(scan.mca)
+ # Scan.data is transposed
+ number_of_data_lines = scan.data.shape[1]
+
+ if not number_of_data_lines == 0:
+ # Number of MCA spectra must be a multiple of number of data lines
+ assert number_of_mca_spectra % number_of_data_lines == 0
+ return number_of_mca_spectra // number_of_data_lines
+ elif number_of_mca_spectra:
+ # Case of a scan without data lines, only MCA.
+ # Our only option is to assume that the number of analysers
+ # is the number of #@CHANN lines
+ return len(scan.mca.channels)
+ else:
+ return 0
+
+
+def _motor_in_scan(sf, scan_key, motor_name):
+ """
+ :param sf: :class:`SpecFile` instance
+ :param scan_key: Scan identification key (e.g. ``1.1``)
+ :param motor_name: Name of motor as defined in file header lines
+ :return: ``True`` if motor exists in scan, else ``False``
+ :raise: ``KeyError`` if scan_key not found in SpecFile
+ """
+ if scan_key not in sf:
+ raise KeyError("Scan key %s " % scan_key +
+ "does not exist in SpecFile %s" % sf.filename)
+ ret = motor_name in sf[scan_key].motor_names
+ if not ret and "%" in motor_name:
+ motor_name = motor_name.replace("%", "/")
+ ret = motor_name in sf[scan_key].motor_names
+ return ret
+
+
+def _column_label_in_scan(sf, scan_key, column_label):
+ """
+ :param sf: :class:`SpecFile` instance
+ :param scan_key: Scan identification key (e.g. ``1.1``)
+ :param column_label: Column label as defined in scan header
+ :return: ``True`` if data column label exists in scan, else ``False``
+ :raise: ``KeyError`` if scan_key not found in SpecFile
+ """
+ if scan_key not in sf:
+ raise KeyError("Scan key %s " % scan_key +
+ "does not exist in SpecFile %s" % sf.filename)
+ ret = column_label in sf[scan_key].labels
+ if not ret and "%" in column_label:
+ column_label = column_label.replace("%", "/")
+ ret = column_label in sf[scan_key].labels
+ return ret
+
+
+def _parse_UB_matrix(header_line):
+ """Parse G3 header line and return UB matrix
+
+ :param str header_line: G3 header line
+ :return: UB matrix
+ :raises ValueError: For malformed UB matrix header line
+ """
+ values = list(map(float, header_line.split())) # Can raise ValueError
+ if len(values) < 9:
+ raise ValueError("Not enough values in UB matrix")
+ return numpy.array(values).reshape((1, 3, 3))
+
+
+def _ub_matrix_in_scan(scan):
+ """Return True if scan header has a G3 line and all values are not 0.
+
+ :param scan: specfile.Scan instance
+ :return: True or False
+ """
+ header_line = scan.scan_header_dict.get("G3", None)
+ if header_line is None:
+ return False
+ try:
+ ub_matrix = _parse_UB_matrix(header_line)
+ except ValueError:
+ logger1.warning("Malformed G3 header line")
+ return False
+ return numpy.any(ub_matrix)
+
+
+def _parse_unit_cell(header_line):
+ """Parse G1 header line and return unit cell
+
+ :param str header_line: G1 header line
+ :return: unit cell
+ :raises ValueError: For malformed unit cell header line
+ """
+ values = list(map(float, header_line.split()[0:6])) # can raise ValueError
+ if len(values) < 6:
+ raise ValueError("Not enough values in unit cell")
+ return numpy.array(values).reshape((1, 6))
+
+
+def _unit_cell_in_scan(scan):
+ """Return True if scan header has a G1 line and all values are not 0.
+
+ :param scan: specfile.Scan instance
+ :return: True or False
+ """
+ header_line = scan.scan_header_dict.get("G1", None)
+ if header_line is None:
+ return False
+ try:
+ unit_cell = _parse_unit_cell(header_line)
+ except ValueError:
+ logger1.warning("Malformed G1 header line")
+ return False
+ return numpy.any(unit_cell)
+
+
+def _parse_ctime(ctime_lines, analyser_index=0):
+ """
+ :param ctime_lines: e.g ``@CTIME %f %f %f``, first word ``@CTIME`` optional
+ When multiple CTIME lines are present in a scan header, this argument
+ is a concatenation of them separated by a ``\\n`` character.
+ :param analyser_index: MCA device/analyser index, when multiple devices
+ are in a scan.
+ :return: (preset_time, live_time, elapsed_time)
+ """
+ ctime_lines = ctime_lines.lstrip("@CTIME ")
+ ctimes_lines_list = ctime_lines.split("\n")
+ if len(ctimes_lines_list) == 1:
+ # single @CTIME line for all devices
+ ctime_line = ctimes_lines_list[0]
+ else:
+ ctime_line = ctimes_lines_list[analyser_index]
+ if not len(ctime_line.split()) == 3:
+ raise ValueError("Incorrect format for @CTIME header line " +
+ '(expected "@CTIME %f %f %f").')
+ return list(map(float, ctime_line.split()))
+
+
+def spec_date_to_iso8601(date, zone=None):
+ """Convert SpecFile date to Iso8601.
+
+ :param date: Date (see supported formats below)
+ :type date: str
+ :param zone: Time zone as it appears in a ISO8601 date
+
+ Supported formats:
+
+ * ``DDD MMM dd hh:mm:ss YYYY``
+ * ``DDD YYYY/MM/dd hh:mm:ss YYYY``
+
+ where `DDD` is the abbreviated weekday, `MMM` is the month abbreviated
+ name, `MM` is the month number (zero padded), `dd` is the weekday number
+ (zero padded) `YYYY` is the year, `hh` the hour (zero padded), `mm` the
+ minute (zero padded) and `ss` the second (zero padded).
+ All names are expected to be in english.
+
+ Examples::
+
+ >>> spec_date_to_iso8601("Thu Feb 11 09:54:35 2016")
+ '2016-02-11T09:54:35'
+
+ >>> spec_date_to_iso8601("Sat 2015/03/14 03:53:50")
+ '2015-03-14T03:53:50'
+ """
+ months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul',
+ 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
+ days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
+
+ days_rx = '(?P<day>' + '|'.join(days) + ')'
+ months_rx = '(?P<month>' + '|'.join(months) + ')'
+ year_rx = r'(?P<year>\d{4})'
+ day_nb_rx = r'(?P<day_nb>[0-3 ]\d)'
+ month_nb_rx = r'(?P<month_nb>[0-1]\d)'
+ hh_rx = r'(?P<hh>[0-2]\d)'
+ mm_rx = r'(?P<mm>[0-5]\d)'
+ ss_rx = r'(?P<ss>[0-5]\d)'
+ tz_rx = r'(?P<tz>[+-]\d\d:\d\d){0,1}'
+
+ # date formats must have either month_nb (1..12) or month (Jan, Feb, ...)
+ re_tpls = ['{days} {months} {day_nb} {hh}:{mm}:{ss}{tz} {year}',
+ '{days} {year}/{month_nb}/{day_nb} {hh}:{mm}:{ss}{tz}']
+
+ grp_d = None
+
+ for rx in re_tpls:
+ full_rx = rx.format(days=days_rx,
+ months=months_rx,
+ year=year_rx,
+ day_nb=day_nb_rx,
+ month_nb=month_nb_rx,
+ hh=hh_rx,
+ mm=mm_rx,
+ ss=ss_rx,
+ tz=tz_rx)
+ m = re.match(full_rx, date)
+
+ if m:
+ grp_d = m.groupdict()
+ break
+
+ if not grp_d:
+ raise ValueError('Date format not recognized : {0}'.format(date))
+
+ year = grp_d['year']
+
+ month = grp_d.get('month_nb')
+
+ if not month:
+ month = '{0:02d}'.format(months.index(grp_d.get('month')) + 1)
+
+ day = grp_d['day_nb']
+
+ tz = grp_d['tz']
+ if not tz:
+ tz = zone
+
+ time = '{0}:{1}:{2}'.format(grp_d['hh'],
+ grp_d['mm'],
+ grp_d['ss'])
+
+ full_date = '{0}-{1}-{2}T{3}{4}'.format(year,
+ month,
+ day,
+ time,
+ tz if tz else '')
+ return full_date
+
+
+def _demultiplex_mca(scan, analyser_index):
+ """Return MCA data for a single analyser.
+
+ Each MCA spectrum is a 1D array. For each analyser, there is one
+ spectrum recorded per scan data line. When there are more than a single
+ MCA analyser in a scan, the data will be multiplexed. For instance if
+ there are 3 analysers, the consecutive spectra for the first analyser must
+ be accessed as ``mca[0], mca[3], mca[6]…``.
+
+ :param scan: :class:`Scan` instance containing the MCA data
+ :param analyser_index: 0-based index referencing the analyser
+ :type analyser_index: int
+ :return: 2D numpy array containing all spectra for one analyser
+ """
+ number_of_analysers = _get_number_of_mca_analysers(scan)
+ number_of_spectra = len(scan.mca)
+ number_of_spectra_per_analyser = number_of_spectra // number_of_analysers
+ len_spectrum = len(scan.mca[analyser_index])
+
+ mca_array = numpy.empty((number_of_spectra_per_analyser, len_spectrum))
+
+ for i in range(number_of_spectra_per_analyser):
+ mca_array[i, :] = scan.mca[analyser_index + i * number_of_analysers]
+
+ return mca_array
+
+
+# Node classes
+class SpecH5Dataset(object):
+ """This convenience class is to be inherited by all datasets, for
+ compatibility purpose with code that tests for
+ ``isinstance(obj, SpecH5Dataset)``.
+
+ This legacy behavior is deprecated. The correct way to test
+ if an object is a dataset is to use :meth:`silx.io.utils.is_dataset`.
+
+ Datasets must also inherit :class:`SpecH5NodeDataset` or
+ :class:`SpecH5LazyNodeDataset` which actually implement all the
+ API."""
+ pass
+
+
+class SpecH5NodeDataset(commonh5.Dataset, SpecH5Dataset):
+ """This class inherits :class:`commonh5.Dataset`, to which it adds
+ little extra functionality. The main additional functionality is the
+ proxy behavior that allows to mimic the numpy array stored in this
+ class.
+ """
+ def __init__(self, name, data, parent=None, attrs=None):
+ # get proper value types, to inherit from numpy
+ # attributes (dtype, shape, size)
+ if isinstance(data, str):
+ # use unicode (utf-8 when saved to HDF5 output)
+ value = to_h5py_utf8(data)
+ elif isinstance(data, float):
+ # use 32 bits for float scalars
+ value = numpy.float32(data)
+ elif isinstance(data, int):
+ value = numpy.int_(data)
+ else:
+ # Enforce numpy array
+ array = numpy.array(data)
+ data_kind = array.dtype.kind
+
+ if data_kind in ["S", "U"]:
+ value = numpy.asarray(array,
+ dtype=text_dtype)
+ elif data_kind in ["f"]:
+ value = numpy.asarray(array, dtype=numpy.float32)
+ else:
+ value = array
+ commonh5.Dataset.__init__(self, name, value, parent, attrs)
+
+ def __getattr__(self, item):
+ """Proxy to underlying numpy array methods.
+ """
+ if hasattr(self[()], item):
+ return getattr(self[()], item)
+
+ raise AttributeError("SpecH5Dataset has no attribute %s" % item)
+
+
+class SpecH5LazyNodeDataset(commonh5.LazyLoadableDataset, SpecH5Dataset):
+ """This class inherits :class:`commonh5.LazyLoadableDataset`,
+ to which it adds a proxy behavior that allows to mimic the numpy
+ array stored in this class.
+
+ The class has to be inherited and the :meth:`_create_data` method has to be
+ implemented to return the numpy data exposed by the dataset. This factory
+ method is only called once, when the data is needed.
+ """
+ def __getattr__(self, item):
+ """Proxy to underlying numpy array methods.
+ """
+ if hasattr(self[()], item):
+ return getattr(self[()], item)
+
+ raise AttributeError("SpecH5Dataset has no attribute %s" % item)
+
+ def _create_data(self):
+ """
+ Factory to create the data exposed by the dataset when it is needed.
+
+ It has to be implemented for the class to work.
+
+ :rtype: numpy.ndarray
+ """
+ raise NotImplementedError()
+
+
+class SpecH5Group(object):
+ """This convenience class is to be inherited by all groups, for
+ compatibility purposes with code that tests for
+ ``isinstance(obj, SpecH5Group)``.
+
+ This legacy behavior is deprecated. The correct way to test
+ if an object is a group is to use :meth:`silx.io.utils.is_group`.
+
+ Groups must also inherit :class:`silx.io.commonh5.Group`, which
+ actually implements all the methods and attributes."""
+ pass
+
+
+class SpecH5(commonh5.File, SpecH5Group):
+ """This class opens a SPEC file and exposes it as a *h5py.File*.
+
+ It inherits :class:`silx.io.commonh5.Group` (via :class:`commonh5.File`),
+ which implements most of its API.
+ """
+
+ def __init__(self, filename):
+ """
+ :param filename: Path to SpecFile in filesystem
+ :type filename: str
+ """
+ if isinstance(filename, io.IOBase):
+ # see https://github.com/silx-kit/silx/issues/858
+ filename = filename.name
+
+ self._sf = SpecFile(filename)
+
+ attrs = {"NX_class": to_h5py_utf8("NXroot"),
+ "file_time": to_h5py_utf8(
+ datetime.datetime.now().isoformat()),
+ "file_name": to_h5py_utf8(filename),
+ "creator": to_h5py_utf8("silx spech5 %s" % silx_version)}
+ commonh5.File.__init__(self, filename, attrs=attrs)
+
+ for scan_key in self._sf.keys():
+ scan = self._sf[scan_key]
+ scan_group = ScanGroup(scan_key, parent=self, scan=scan)
+ self.add_node(scan_group)
+
+ def close(self):
+ self._sf.close()
+ self._sf = None
+
+
+class ScanGroup(commonh5.Group, SpecH5Group):
+ def __init__(self, scan_key, parent, scan):
+ """
+
+ :param parent: parent Group
+ :param str scan_key: Scan key (e.g. "1.1")
+ :param scan: specfile.Scan object
+ """
+ commonh5.Group.__init__(self, scan_key, parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXentry")})
+
+ # take title in #S after stripping away scan number and spaces
+ s_hdr_line = scan.scan_header_dict["S"]
+ title = s_hdr_line.lstrip("0123456789").lstrip()
+ self.add_node(SpecH5NodeDataset(name="title",
+ data=to_h5py_utf8(title),
+ parent=self))
+
+ if "D" in scan.scan_header_dict:
+ try:
+ start_time_str = spec_date_to_iso8601(scan.scan_header_dict["D"])
+ except (IndexError, ValueError):
+ logger1.warning("Could not parse date format in scan %s header." +
+ " Using original date not converted to ISO-8601",
+ scan_key)
+ start_time_str = scan.scan_header_dict["D"]
+ elif "D" in scan.file_header_dict:
+ logger1.warning("No #D line in scan %s header. " +
+ "Using file header for start_time.",
+ scan_key)
+ try:
+ start_time_str = spec_date_to_iso8601(scan.file_header_dict["D"])
+ except (IndexError, ValueError):
+ logger1.warning("Could not parse date format in scan %s header. " +
+ "Using original date not converted to ISO-8601",
+ scan_key)
+ start_time_str = scan.file_header_dict["D"]
+ else:
+ logger1.warning("No #D line in %s header. Setting date to empty string.",
+ scan_key)
+ start_time_str = ""
+ self.add_node(SpecH5NodeDataset(name="start_time",
+ data=to_h5py_utf8(start_time_str),
+ parent=self))
+
+ self.add_node(InstrumentGroup(parent=self, scan=scan))
+ self.add_node(MeasurementGroup(parent=self, scan=scan))
+ if _unit_cell_in_scan(scan) or _ub_matrix_in_scan(scan):
+ self.add_node(SampleGroup(parent=self, scan=scan))
+
+
+class InstrumentGroup(commonh5.Group, SpecH5Group):
+ def __init__(self, parent, scan):
+ """
+
+ :param parent: parent Group
+ :param scan: specfile.Scan object
+ """
+ commonh5.Group.__init__(self, name="instrument", parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXinstrument")})
+
+ self.add_node(InstrumentSpecfileGroup(parent=self, scan=scan))
+ self.add_node(PositionersGroup(parent=self, scan=scan))
+
+ num_analysers = _get_number_of_mca_analysers(scan)
+ for anal_idx in range(num_analysers):
+ self.add_node(InstrumentMcaGroup(parent=self,
+ analyser_index=anal_idx,
+ scan=scan))
+
+
+class InstrumentSpecfileGroup(commonh5.Group, SpecH5Group):
+ def __init__(self, parent, scan):
+ commonh5.Group.__init__(self, name="specfile", parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")})
+ self.add_node(SpecH5NodeDataset(
+ name="file_header",
+ data=to_h5py_utf8(scan.file_header),
+ parent=self,
+ attrs={}))
+ self.add_node(SpecH5NodeDataset(
+ name="scan_header",
+ data=to_h5py_utf8(scan.scan_header),
+ parent=self,
+ attrs={}))
+
+
+class PositionersGroup(commonh5.Group, SpecH5Group):
+ def __init__(self, parent, scan):
+ commonh5.Group.__init__(self, name="positioners", parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection")})
+
+ dataset_info = [] # Store list of positioner's (name, value)
+ is_error = False # True if error encountered
+
+ for motor_name in scan.motor_names:
+ safe_motor_name = motor_name.replace("/", "%")
+ if motor_name in scan.labels and scan.data.shape[0] > 0:
+ # return a data column if one has the same label as the motor
+ motor_value = scan.data_column_by_name(motor_name)
+ else:
+ # Take value from #P scan header.
+ # (may return float("inf") if #P line is missing from scan hdr)
+ try:
+ motor_value = scan.motor_position_by_name(motor_name)
+ except SfErrColNotFound:
+ is_error = True
+ motor_value = float('inf')
+ dataset_info.append((safe_motor_name, motor_value))
+
+ if is_error: # Filter-out scalar values
+ logger1.warning("Mismatching number of elements in #P and #O: Ignoring")
+ dataset_info = [
+ (name, value) for name, value in dataset_info
+ if not isinstance(value, float)]
+
+ for name, value in dataset_info:
+ self.add_node(SpecH5NodeDataset(
+ name=name,
+ data=value,
+ parent=self))
+
+
+class InstrumentMcaGroup(commonh5.Group, SpecH5Group):
+ def __init__(self, parent, analyser_index, scan):
+ name = "mca_%d" % analyser_index
+ commonh5.Group.__init__(self, name=name, parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXdetector")})
+
+ mcaDataDataset = McaDataDataset(parent=self,
+ analyser_index=analyser_index,
+ scan=scan)
+ self.add_node(mcaDataDataset)
+ spectrum_length = mcaDataDataset.shape[-1]
+ mcaDataDataset = None
+
+ if len(scan.mca.channels) == 1:
+ # single @CALIB line applying to multiple devices
+ calibration_dataset = scan.mca.calibration[0]
+ channels_dataset = scan.mca.channels[0]
+ else:
+ calibration_dataset = scan.mca.calibration[analyser_index]
+ channels_dataset = scan.mca.channels[analyser_index]
+
+ channels_length = len(channels_dataset)
+ if (channels_length > 1) and (spectrum_length > 0):
+ logger1.info("Spectrum and channels length mismatch")
+ # this should always be the case
+ if channels_length > spectrum_length:
+ channels_dataset = channels_dataset[:spectrum_length]
+ elif channels_length < spectrum_length:
+ # only trust first channel and increment
+ channel0 = channels_dataset[0]
+ increment = channels_dataset[1] - channels_dataset[0]
+ channels_dataset = numpy.linspace(channel0,
+ channel0 + increment * spectrum_length,
+ spectrum_length, endpoint=False)
+
+ self.add_node(SpecH5NodeDataset(name="calibration",
+ data=calibration_dataset,
+ parent=self))
+ self.add_node(SpecH5NodeDataset(name="channels",
+ data=channels_dataset,
+ parent=self))
+
+ if "CTIME" in scan.mca_header_dict:
+ ctime_line = scan.mca_header_dict['CTIME']
+ preset_time, live_time, elapsed_time = _parse_ctime(ctime_line, analyser_index)
+ self.add_node(SpecH5NodeDataset(name="preset_time",
+ data=preset_time,
+ parent=self))
+ self.add_node(SpecH5NodeDataset(name="live_time",
+ data=live_time,
+ parent=self))
+ self.add_node(SpecH5NodeDataset(name="elapsed_time",
+ data=elapsed_time,
+ parent=self))
+
+
+class McaDataDataset(SpecH5LazyNodeDataset):
+ """Lazy loadable dataset for MCA data"""
+ def __init__(self, parent, analyser_index, scan):
+ commonh5.LazyLoadableDataset.__init__(
+ self, name="data", parent=parent,
+ attrs={"interpretation": to_h5py_utf8("spectrum"),})
+ self._scan = scan
+ self._analyser_index = analyser_index
+ self._shape = None
+ self._num_analysers = _get_number_of_mca_analysers(self._scan)
+
+ def _create_data(self):
+ return _demultiplex_mca(self._scan, self._analyser_index)
+
+ @property
+ def shape(self):
+ if self._shape is None:
+ num_spectra_in_file = len(self._scan.mca)
+ num_spectra_per_analyser = num_spectra_in_file // self._num_analysers
+ len_spectrum = len(self._scan.mca[self._analyser_index])
+ self._shape = num_spectra_per_analyser, len_spectrum
+ return self._shape
+
+ @property
+ def size(self):
+ return numpy.prod(self.shape, dtype=numpy.intp)
+
+ @property
+ def dtype(self):
+ # we initialize the data with numpy.empty() without specifying a dtype
+ # in _demultiplex_mca()
+ return numpy.empty((1, )).dtype
+
+ def __len__(self):
+ return self.shape[0]
+
+ def __getitem__(self, item):
+ # optimization for fetching a single spectrum if data not already loaded
+ if not self._is_initialized:
+ if isinstance(item, int):
+ if item < 0:
+ # negative indexing
+ item += len(self)
+ return self._scan.mca[self._analyser_index +
+ item * self._num_analysers]
+ # accessing a slice or element of a single spectrum [i, j:k]
+ try:
+ spectrum_idx, channel_idx_or_slice = item
+ assert isinstance(spectrum_idx, int)
+ except (ValueError, TypeError, AssertionError):
+ pass
+ else:
+ if spectrum_idx < 0:
+ item += len(self)
+ idx = self._analyser_index + spectrum_idx * self._num_analysers
+ return self._scan.mca[idx][channel_idx_or_slice]
+
+ return super(McaDataDataset, self).__getitem__(item)
+
+
+class MeasurementGroup(commonh5.Group, SpecH5Group):
+ def __init__(self, parent, scan):
+ """
+
+ :param parent: parent Group
+ :param scan: specfile.Scan object
+ """
+ commonh5.Group.__init__(self, name="measurement", parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXcollection"),})
+ for label in scan.labels:
+ safe_label = label.replace("/", "%")
+ self.add_node(SpecH5NodeDataset(name=safe_label,
+ data=scan.data_column_by_name(label),
+ parent=self))
+
+ num_analysers = _get_number_of_mca_analysers(scan)
+ for anal_idx in range(num_analysers):
+ self.add_node(MeasurementMcaGroup(parent=self, analyser_index=anal_idx))
+
+
+class MeasurementMcaGroup(commonh5.Group, SpecH5Group):
+ def __init__(self, parent, analyser_index):
+ basename = "mca_%d" % analyser_index
+ commonh5.Group.__init__(self, name=basename, parent=parent,
+ attrs={})
+
+ target_name = self.name.replace("measurement", "instrument")
+ self.add_node(commonh5.SoftLink(name="data",
+ path=target_name + "/data",
+ parent=self))
+ self.add_node(commonh5.SoftLink(name="info",
+ path=target_name,
+ parent=self))
+
+
+class SampleGroup(commonh5.Group, SpecH5Group):
+ def __init__(self, parent, scan):
+ """
+
+ :param parent: parent Group
+ :param scan: specfile.Scan object
+ """
+ commonh5.Group.__init__(self, name="sample", parent=parent,
+ attrs={"NX_class": to_h5py_utf8("NXsample"),})
+
+ if _unit_cell_in_scan(scan):
+ self.add_node(SpecH5NodeDataset(name="unit_cell",
+ data=_parse_unit_cell(scan.scan_header_dict["G1"]),
+ parent=self,
+ attrs={"interpretation": to_h5py_utf8("scalar")}))
+ self.add_node(SpecH5NodeDataset(name="unit_cell_abc",
+ data=_parse_unit_cell(scan.scan_header_dict["G1"])[0, 0:3],
+ parent=self,
+ attrs={"interpretation": to_h5py_utf8("scalar")}))
+ self.add_node(SpecH5NodeDataset(name="unit_cell_alphabetagamma",
+ data=_parse_unit_cell(scan.scan_header_dict["G1"])[0, 3:6],
+ parent=self,
+ attrs={"interpretation": to_h5py_utf8("scalar")}))
+ if _ub_matrix_in_scan(scan):
+ self.add_node(SpecH5NodeDataset(name="ub_matrix",
+ data=_parse_UB_matrix(scan.scan_header_dict["G3"]),
+ parent=self,
+ attrs={"interpretation": to_h5py_utf8("scalar")}))
diff --git a/silx/io/spectoh5.py b/src/silx/io/spectoh5.py
index fb3b739..fb3b739 100644
--- a/silx/io/spectoh5.py
+++ b/src/silx/io/spectoh5.py
diff --git a/src/silx/io/test/__init__.py b/src/silx/io/test/__init__.py
new file mode 100644
index 0000000..244d090
--- /dev/null
+++ b/src/silx/io/test/__init__.py
@@ -0,0 +1,23 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
diff --git a/src/silx/io/test/test_commonh5.py b/src/silx/io/test/test_commonh5.py
new file mode 100644
index 0000000..27f6e8c
--- /dev/null
+++ b/src/silx/io/test/test_commonh5.py
@@ -0,0 +1,285 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for commonh5 wrapper"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "21/09/2017"
+
+import logging
+import numpy
+import unittest
+import tempfile
+import shutil
+
+_logger = logging.getLogger(__name__)
+
+import silx.io
+import silx.io.utils
+import h5py
+
+try:
+ from .. import commonh5
+except ImportError:
+ commonh5 = None
+
+
+class _TestCommonFeatures(unittest.TestCase):
+ """Test common features supported by h5py and our implementation."""
+ __test__ = False # ignore abstract class tests
+
+ @classmethod
+ def createFile(cls):
+ return None
+
+ @classmethod
+ def setUpClass(cls):
+ # Set to None cause create_resource can raise an excpetion
+ cls.h5 = None
+ cls.h5 = cls.create_resource()
+ if cls.h5 is None:
+ raise unittest.SkipTest("File not created")
+
+ @classmethod
+ def create_resource(cls):
+ """Must be implemented"""
+ return None
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.h5 = None
+
+ def test_file(self):
+ node = self.h5
+ self.assertTrue(silx.io.is_file(node))
+ self.assertTrue(silx.io.is_group(node))
+ self.assertFalse(silx.io.is_dataset(node))
+ self.assertEqual(len(node.attrs), 0)
+
+ def test_group(self):
+ node = self.h5["group"]
+ self.assertFalse(silx.io.is_file(node))
+ self.assertTrue(silx.io.is_group(node))
+ self.assertFalse(silx.io.is_dataset(node))
+ self.assertEqual(len(node.attrs), 0)
+ class_ = self.h5.get("group", getclass=True)
+ classlink = self.h5.get("group", getlink=True, getclass=True)
+ self.assertEqual(class_, h5py.Group)
+ self.assertEqual(classlink, h5py.HardLink)
+
+ def test_dataset(self):
+ node = self.h5["group/dataset"]
+ self.assertFalse(silx.io.is_file(node))
+ self.assertFalse(silx.io.is_group(node))
+ self.assertTrue(silx.io.is_dataset(node))
+ self.assertEqual(len(node.attrs), 0)
+ class_ = self.h5.get("group/dataset", getclass=True)
+ classlink = self.h5.get("group/dataset", getlink=True, getclass=True)
+ self.assertEqual(class_, h5py.Dataset)
+ self.assertEqual(classlink, h5py.HardLink)
+
+ def test_soft_link(self):
+ node = self.h5["link/soft_link"]
+ self.assertEqual(node.name, "/link/soft_link")
+ class_ = self.h5.get("link/soft_link", getclass=True)
+ link = self.h5.get("link/soft_link", getlink=True)
+ classlink = self.h5.get("link/soft_link", getlink=True, getclass=True)
+ self.assertEqual(class_, h5py.Dataset)
+ self.assertTrue(isinstance(link, (h5py.SoftLink, commonh5.SoftLink)))
+ self.assertTrue(silx.io.utils.is_softlink(link))
+ self.assertEqual(classlink, h5py.SoftLink)
+
+ def test_external_link(self):
+ node = self.h5["link/external_link"]
+ self.assertEqual(node.name, "/target/dataset")
+ class_ = self.h5.get("link/external_link", getclass=True)
+ classlink = self.h5.get("link/external_link", getlink=True, getclass=True)
+ self.assertEqual(class_, h5py.Dataset)
+ self.assertEqual(classlink, h5py.ExternalLink)
+
+ def test_external_link_to_link(self):
+ node = self.h5["link/external_link_to_link"]
+ self.assertEqual(node.name, "/target/link")
+ class_ = self.h5.get("link/external_link_to_link", getclass=True)
+ classlink = self.h5.get("link/external_link_to_link", getlink=True, getclass=True)
+ self.assertEqual(class_, h5py.Dataset)
+ self.assertEqual(classlink, h5py.ExternalLink)
+
+ def test_create_groups(self):
+ c = self.h5.create_group(self.id() + "/a/b/c")
+ d = c.create_group("/" + self.id() + "/a/b/d")
+
+ self.assertRaises(ValueError, self.h5.create_group, self.id() + "/a/b/d")
+ self.assertEqual(c.name, "/" + self.id() + "/a/b/c")
+ self.assertEqual(d.name, "/" + self.id() + "/a/b/d")
+
+ def test_setitem_python_object_dataset(self):
+ group = self.h5.create_group(self.id())
+ group["a"] = 10
+ self.assertEqual(group["a"].dtype.kind, "i")
+
+ def test_setitem_numpy_dataset(self):
+ group = self.h5.create_group(self.id())
+ group["a"] = numpy.array([10, 20, 30])
+ self.assertEqual(group["a"].dtype.kind, "i")
+ self.assertEqual(group["a"].shape, (3,))
+
+ def test_setitem_link(self):
+ group = self.h5.create_group(self.id())
+ group["a"] = 10
+ group["b"] = group["a"]
+ self.assertEqual(group["b"].dtype.kind, "i")
+
+ def test_setitem_dataset_is_sub_group(self):
+ self.h5[self.id() + "/a"] = 10
+
+
+class TestCommonFeatures_h5py(_TestCommonFeatures):
+ """Check if h5py is compliant with what we expect."""
+ __test__ = True # because _TestCommonFeatures is ignored
+
+ @classmethod
+ def create_resource(cls):
+ cls.tmp_dir = tempfile.mkdtemp()
+
+ externalh5 = h5py.File(cls.tmp_dir + "/external.h5", mode="w")
+ externalh5["target/dataset"] = 50
+ externalh5["target/link"] = h5py.SoftLink("/target/dataset")
+ externalh5.close()
+
+ h5 = h5py.File(cls.tmp_dir + "/base.h5", mode="w")
+ h5["group/dataset"] = 50
+ h5["link/soft_link"] = h5py.SoftLink("/group/dataset")
+ h5["link/external_link"] = h5py.ExternalLink("external.h5", "/target/dataset")
+ h5["link/external_link_to_link"] = h5py.ExternalLink("external.h5", "/target/link")
+
+ return h5
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestCommonFeatures_h5py, cls).tearDownClass()
+ if hasattr(cls, "tmp_dir") and cls.tmp_dir is not None:
+ shutil.rmtree(cls.tmp_dir)
+
+
+class TestCommonFeatures_commonH5(_TestCommonFeatures):
+ """Check if commonh5 is compliant with h5py."""
+ __test__ = True # because _TestCommonFeatures is ignored
+
+ @classmethod
+ def create_resource(cls):
+ h5 = commonh5.File("base.h5", "w")
+ h5.create_group("group").create_dataset("dataset", data=numpy.int32(50))
+
+ link = h5.create_group("link")
+ link.add_node(commonh5.SoftLink("soft_link", "/group/dataset"))
+
+ return h5
+
+ def test_external_link(self):
+ # not applicable
+ pass
+
+ def test_external_link_to_link(self):
+ # not applicable
+ pass
+
+
+class TestSpecificCommonH5(unittest.TestCase):
+ """Test specific features from commonh5.
+
+ Test of shared features should be done by TestCommonFeatures."""
+
+ def setUp(self):
+ if commonh5 is None:
+ self.skipTest("silx.io.commonh5 is needed")
+
+ def test_node_attrs(self):
+ node = commonh5.Node("Foo", attrs={"a": 1})
+ self.assertEqual(node.attrs["a"], 1)
+ node.attrs["b"] = 8
+ self.assertEqual(node.attrs["b"], 8)
+ node.attrs["b"] = 2
+ self.assertEqual(node.attrs["b"], 2)
+
+ def test_node_readonly_attrs(self):
+ f = commonh5.File(name="Foo", mode="r")
+ node = commonh5.Node("Foo", attrs={"a": 1})
+ node.attrs["b"] = 8
+ f.add_node(node)
+ self.assertEqual(node.attrs["b"], 8)
+ try:
+ node.attrs["b"] = 1
+ self.fail()
+ except RuntimeError:
+ pass
+
+ def test_create_dataset(self):
+ f = commonh5.File(name="Foo", mode="w")
+ node = f.create_dataset("foo", data=numpy.array([1]))
+ self.assertIs(node.parent, f)
+ self.assertIs(f["foo"], node)
+
+ def test_create_group(self):
+ f = commonh5.File(name="Foo", mode="w")
+ node = f.create_group("foo")
+ self.assertIs(node.parent, f)
+ self.assertIs(f["foo"], node)
+
+ def test_readonly_create_dataset(self):
+ f = commonh5.File(name="Foo", mode="r")
+ try:
+ f.create_dataset("foo", data=numpy.array([1]))
+ self.fail()
+ except RuntimeError:
+ pass
+
+ def test_readonly_create_group(self):
+ f = commonh5.File(name="Foo", mode="r")
+ try:
+ f.create_group("foo")
+ self.fail()
+ except RuntimeError:
+ pass
+
+ def test_create_unicode_dataset(self):
+ f = commonh5.File(name="Foo", mode="w")
+ try:
+ f.create_dataset("foo", data=numpy.array(u"aaaa"))
+ self.fail()
+ except TypeError:
+ pass
+
+ def test_setitem_dataset(self):
+ self.h5 = commonh5.File(name="Foo", mode="w")
+ group = self.h5.create_group(self.id())
+ group["a"] = commonh5.Dataset(None, data=numpy.array(10))
+ self.assertEqual(group["a"].dtype.kind, "i")
+
+ def test_setitem_explicit_link(self):
+ self.h5 = commonh5.File(name="Foo", mode="w")
+ group = self.h5.create_group(self.id())
+ group["a"] = 10
+ group["b"] = commonh5.SoftLink(None, path="/" + self.id() + "/a")
+ self.assertEqual(group["b"].dtype.kind, "i")
diff --git a/src/silx/io/test/test_dictdump.py b/src/silx/io/test/test_dictdump.py
new file mode 100644
index 0000000..4cafa9b
--- /dev/null
+++ b/src/silx/io/test/test_dictdump.py
@@ -0,0 +1,1009 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for dicttoh5 module"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+from collections import OrderedDict
+import numpy
+import os
+import tempfile
+import unittest
+import h5py
+from copy import deepcopy
+
+from collections import defaultdict
+
+from silx.utils.testutils import LoggingValidator
+
+from ..configdict import ConfigDict
+from .. import dictdump
+from ..dictdump import dicttoh5, dicttojson, dump
+from ..dictdump import h5todict, load
+from ..dictdump import logger as dictdump_logger
+from ..utils import is_link
+from ..utils import h5py_read_dataset
+
+
+def tree():
+ """Tree data structure as a recursive nested dictionary"""
+ return defaultdict(tree)
+
+
+inhabitants = 160215
+
+city_attrs = tree()
+city_attrs["Europe"]["France"]["Grenoble"]["area"] = "18.44 km2"
+city_attrs["Europe"]["France"]["Grenoble"]["inhabitants"] = inhabitants
+city_attrs["Europe"]["France"]["Grenoble"]["coordinates"] = [45.1830, 5.7196]
+city_attrs["Europe"]["France"]["Tourcoing"]["area"]
+
+ext_attrs = tree()
+ext_attrs["ext_group"]["dataset"] = 10
+ext_filename = "ext.h5"
+
+link_attrs = tree()
+link_attrs["links"]["group"]["dataset"] = 10
+link_attrs["links"]["group"]["relative_softlink"] = h5py.SoftLink("dataset")
+link_attrs["links"]["relative_softlink"] = h5py.SoftLink("group/dataset")
+link_attrs["links"]["absolute_softlink"] = h5py.SoftLink("/links/group/dataset")
+link_attrs["links"]["external_link"] = h5py.ExternalLink(ext_filename, "/ext_group/dataset")
+
+
+class DictTestCase(unittest.TestCase):
+
+ def assertRecursiveEqual(self, expected, actual, nodes=tuple()):
+ err_msg = "\n\n Tree nodes: {}".format(nodes)
+ if isinstance(expected, dict):
+ self.assertTrue(isinstance(actual, dict), msg=err_msg)
+ self.assertEqual(
+ set(expected.keys()),
+ set(actual.keys()),
+ msg=err_msg
+ )
+ for k in actual:
+ self.assertRecursiveEqual(
+ expected[k],
+ actual[k],
+ nodes=nodes + (k,),
+ )
+ return
+ if isinstance(actual, numpy.ndarray):
+ actual = actual.tolist()
+ if isinstance(expected, numpy.ndarray):
+ expected = expected.tolist()
+
+ self.assertEqual(expected, actual, msg=err_msg)
+
+
+class H5DictTestCase(DictTestCase):
+
+ def _dictRoundTripNormalize(self, treedict):
+ """Convert the dictionary as expected from a round-trip
+ treedict -> dicttoh5 -> h5todict -> newtreedict
+ """
+ for key, value in list(treedict.items()):
+ if isinstance(value, dict):
+ self._dictRoundTripNormalize(value)
+
+ # Expand treedict[("group", "attr_name")]
+ # to treedict["group"]["attr_name"]
+ for key, value in list(treedict.items()):
+ if not isinstance(key, tuple):
+ continue
+ # Put the attribute inside the group
+ grpname, attr = key
+ if not grpname:
+ continue
+ group = treedict.setdefault(grpname, dict())
+ if isinstance(group, dict):
+ del treedict[key]
+ group[("", attr)] = value
+
+ def dictRoundTripNormalize(self, treedict):
+ treedict2 = deepcopy(treedict)
+ self._dictRoundTripNormalize(treedict2)
+ return treedict2
+
+
+class TestDictToH5(H5DictTestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5")
+ self.h5_ext_fname = os.path.join(self.tempdir, ext_filename)
+
+ def tearDown(self):
+ if os.path.exists(self.h5_fname):
+ os.unlink(self.h5_fname)
+ if os.path.exists(self.h5_ext_fname):
+ os.unlink(self.h5_ext_fname)
+ os.rmdir(self.tempdir)
+
+ def testH5CityAttrs(self):
+ filters = {'shuffle': True,
+ 'fletcher32': True}
+ dicttoh5(city_attrs, self.h5_fname, h5path='/city attributes',
+ mode="w", create_dataset_args=filters)
+
+ h5f = h5py.File(self.h5_fname, mode='r')
+
+ self.assertIn("Tourcoing/area", h5f["/city attributes/Europe/France"])
+ ds = h5f["/city attributes/Europe/France/Grenoble/inhabitants"]
+ self.assertEqual(ds[...], 160215)
+
+ # filters only apply to datasets that are not scalars (shape != () )
+ ds = h5f["/city attributes/Europe/France/Grenoble/coordinates"]
+ #self.assertEqual(ds.compression, "gzip")
+ self.assertTrue(ds.fletcher32)
+ self.assertTrue(ds.shuffle)
+
+ h5f.close()
+
+ ddict = load(self.h5_fname, fmat="hdf5")
+ self.assertAlmostEqual(
+ min(ddict["city attributes"]["Europe"]["France"]["Grenoble"]["coordinates"]),
+ 5.7196)
+
+ def testH5OverwriteDeprecatedApi(self):
+ dd = ConfigDict({'t': True})
+
+ dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a')
+ dd = ConfigDict({'t': False})
+ dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a',
+ overwrite_data=False)
+
+ res = h5todict(self.h5_fname)
+ assert(res['t'] == True)
+
+ dicttoh5(h5file=self.h5_fname, treedict=dd, mode='a',
+ overwrite_data=True)
+
+ res = h5todict(self.h5_fname)
+ assert(res['t'] == False)
+
+ def testAttributes(self):
+ """Any kind of attribute can be described"""
+ ddict = {
+ "group": {"datatset": "hmmm", ("", "group_attr"): 10},
+ "dataset": "aaaaaaaaaaaaaaa",
+ ("", "root_attr"): 11,
+ ("dataset", "dataset_attr"): 12,
+ ("group", "group_attr2"): 13,
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttoh5(ddict, h5file)
+ self.assertEqual(h5file["group"].attrs['group_attr'], 10)
+ self.assertEqual(h5file.attrs['root_attr'], 11)
+ self.assertEqual(h5file["dataset"].attrs['dataset_attr'], 12)
+ self.assertEqual(h5file["group"].attrs['group_attr2'], 13)
+
+ def testPathAttributes(self):
+ """A group is requested at a path"""
+ ddict = {
+ ("", "NX_class"): 'NXcollection',
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ # This should not warn
+ with LoggingValidator(dictdump_logger, warning=0):
+ dictdump.dicttoh5(ddict, h5file, h5path="foo/bar")
+
+ def testKeyOrder(self):
+ ddict1 = {
+ "d": "plow",
+ ("d", "a"): "ox",
+ }
+ ddict2 = {
+ ("d", "a"): "ox",
+ "d": "plow",
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttoh5(ddict1, h5file, h5path="g1")
+ dictdump.dicttoh5(ddict2, h5file, h5path="g2")
+ self.assertEqual(h5file["g1/d"].attrs['a'], "ox")
+ self.assertEqual(h5file["g2/d"].attrs['a'], "ox")
+
+ def testAttributeValues(self):
+ """Any NX data types can be used"""
+ ddict = {
+ ("", "bool"): True,
+ ("", "int"): 11,
+ ("", "float"): 1.1,
+ ("", "str"): "a",
+ ("", "boollist"): [True, False, True],
+ ("", "intlist"): [11, 22, 33],
+ ("", "floatlist"): [1.1, 2.2, 3.3],
+ ("", "strlist"): ["a", "bb", "ccc"],
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttoh5(ddict, h5file)
+ for k, expected in ddict.items():
+ result = h5file.attrs[k[1]]
+ if isinstance(expected, list):
+ if isinstance(expected[0], str):
+ numpy.testing.assert_array_equal(result, expected)
+ else:
+ numpy.testing.assert_array_almost_equal(result, expected)
+ else:
+ self.assertEqual(result, expected)
+
+ def testAttributeAlreadyExists(self):
+ """A duplicated attribute warns if overwriting is not enabled"""
+ ddict = {
+ "group": {"dataset": "hmmm", ("", "attr"): 10},
+ ("group", "attr"): 10,
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttoh5(ddict, h5file)
+ self.assertEqual(h5file["group"].attrs['attr'], 10)
+
+ def testFlatDict(self):
+ """Description of a tree with a single level of keys"""
+ ddict = {
+ "group/group/dataset": 10,
+ ("group/group/dataset", "attr"): 11,
+ ("group/group", "attr"): 12,
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttoh5(ddict, h5file)
+ self.assertEqual(h5file["group/group/dataset"][()], 10)
+ self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11)
+ self.assertEqual(h5file["group/group"].attrs['attr'], 12)
+
+ def testLinks(self):
+ with h5py.File(self.h5_ext_fname, "w") as h5file:
+ dictdump.dicttoh5(ext_attrs, h5file)
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttoh5(link_attrs, h5file)
+ with h5py.File(self.h5_fname, "r") as h5file:
+ self.assertEqual(h5file["links/group/dataset"][()], 10)
+ self.assertEqual(h5file["links/group/relative_softlink"][()], 10)
+ self.assertEqual(h5file["links/relative_softlink"][()], 10)
+ self.assertEqual(h5file["links/absolute_softlink"][()], 10)
+ self.assertEqual(h5file["links/external_link"][()], 10)
+
+ def testDumpNumpyArray(self):
+ ddict = {
+ 'darks': {
+ '0': numpy.array([[0, 0, 0], [0, 0, 0]], dtype=numpy.uint16)
+ }
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttoh5(ddict, h5file)
+ with h5py.File(self.h5_fname, "r") as h5file:
+ numpy.testing.assert_array_equal(h5py_read_dataset(h5file["darks"]["0"]),
+ ddict['darks']['0'])
+
+ def testOverwrite(self):
+ # Tree structure that will be tested
+ group1 = {
+ ("", "attr2"): "original2",
+ "dset1": 0,
+ "dset2": [0, 1],
+ ("dset1", "attr1"): "original1",
+ ("dset1", "attr2"): "original2",
+ ("dset2", "attr1"): "original1",
+ ("dset2", "attr2"): "original2",
+ }
+ group2 = {
+ "subgroup1": group1.copy(),
+ "subgroup2": group1.copy(),
+ ("subgroup1", "attr1"): "original1",
+ ("subgroup2", "attr1"): "original1"
+ }
+ group2.update(group1)
+ # initial HDF5 tree
+ otreedict = {
+ ('', 'attr1'): "original1",
+ ('', 'attr2'): "original2",
+ 'group1': group1,
+ 'group2': group2,
+ ('group1', 'attr1'): "original1",
+ ('group2', 'attr1'): "original1"
+ }
+ wtreedict = None # dumped dictionary
+ etreedict = None # expected HDF5 tree after dump
+
+ def reset_file():
+ dicttoh5(
+ otreedict,
+ h5file=self.h5_fname,
+ mode="w",
+ )
+
+ def append_file(update_mode):
+ dicttoh5(
+ wtreedict,
+ h5file=self.h5_fname,
+ mode="a",
+ update_mode=update_mode
+ )
+
+ def assert_file():
+ rtreedict = h5todict(
+ self.h5_fname,
+ include_attributes=True,
+ asarray=False
+ )
+ netreedict = self.dictRoundTripNormalize(etreedict)
+ try:
+ self.assertRecursiveEqual(netreedict, rtreedict)
+ except AssertionError:
+ from pprint import pprint
+ print("\nDUMP:")
+ pprint(wtreedict)
+ print("\nEXPECTED:")
+ pprint(netreedict)
+ print("\nHDF5:")
+ pprint(rtreedict)
+ raise
+
+ def assert_append(update_mode):
+ append_file(update_mode)
+ assert_file()
+
+ # Test wrong arguments
+ with self.assertRaises(ValueError):
+ dicttoh5(
+ otreedict,
+ h5file=self.h5_fname,
+ mode="w",
+ update_mode="wrong-value"
+ )
+
+ # No writing
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ assert_file()
+
+ # Write identical dictionary
+ wtreedict = deepcopy(otreedict)
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add", "modify", "replace"]:
+ assert_append(update_mode)
+
+ # Write empty dictionary
+ wtreedict = dict()
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add", "modify", "replace"]:
+ assert_append(update_mode)
+
+ # Modified dataset
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = dict()
+ wtreedict["group2"]["subgroup2"]["dset1"] = {"dset3": [10, 20]}
+ wtreedict["group2"]["subgroup2"]["dset2"] = [10, 20]
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ etreedict["group2"]["subgroup2"]["dset2"] = [10, 20]
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = dict()
+ etreedict["group2"]["subgroup2"]["dset1"] = {"dset3": [10, 20]}
+ etreedict["group2"]["subgroup2"]["dset2"] = [10, 20]
+ assert_append("replace")
+
+ # Modified group
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = [0, 1]
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add", "modify"]:
+ assert_append(update_mode)
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = [0, 1]
+ assert_append("replace")
+
+ # Modified attribute
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = dict()
+ wtreedict["group2"]["subgroup2"][("dset1", "attr1")] = "modified"
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ etreedict["group2"]["subgroup2"][("dset1", "attr1")] = "modified"
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = dict()
+ etreedict["group2"]["subgroup2"]["dset1"] = dict()
+ etreedict["group2"]["subgroup2"]["dset1"][("", "attr1")] = "modified"
+ assert_append("replace")
+
+ # Delete group
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = None
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ del etreedict["group2"]["subgroup2"]
+ del etreedict["group2"][("subgroup2", "attr1")]
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ assert_append("replace")
+
+ # Delete dataset
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = dict()
+ wtreedict["group2"]["subgroup2"]["dset2"] = None
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ del etreedict["group2"]["subgroup2"]["dset2"]
+ del etreedict["group2"]["subgroup2"][("dset2", "attr1")]
+ del etreedict["group2"]["subgroup2"][("dset2", "attr2")]
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = dict()
+ assert_append("replace")
+
+ # Delete attribute
+ wtreedict = dict()
+ wtreedict["group2"] = dict()
+ wtreedict["group2"]["subgroup2"] = dict()
+ wtreedict["group2"]["subgroup2"][("dset2", "attr1")] = None
+
+ reset_file()
+ etreedict = deepcopy(otreedict)
+ for update_mode in [None, "add"]:
+ assert_append(update_mode)
+
+ del etreedict["group2"]["subgroup2"][("dset2", "attr1")]
+ assert_append("modify")
+
+ etreedict["group2"] = dict()
+ del etreedict[("group2", "attr1")]
+ etreedict["group2"]["subgroup2"] = dict()
+ etreedict["group2"]["subgroup2"]["dset2"] = dict()
+ assert_append("replace")
+
+
+class TestH5ToDict(H5DictTestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.h5_fname = os.path.join(self.tempdir, "cityattrs.h5")
+ self.h5_ext_fname = os.path.join(self.tempdir, ext_filename)
+ dicttoh5(city_attrs, self.h5_fname)
+ dicttoh5(link_attrs, self.h5_fname, mode="a")
+ dicttoh5(ext_attrs, self.h5_ext_fname)
+
+ def tearDown(self):
+ if os.path.exists(self.h5_fname):
+ os.unlink(self.h5_fname)
+ if os.path.exists(self.h5_ext_fname):
+ os.unlink(self.h5_ext_fname)
+ os.rmdir(self.tempdir)
+
+ def testExcludeNames(self):
+ ddict = h5todict(self.h5_fname, path="/Europe/France",
+ exclude_names=["ourcoing", "inhab", "toto"])
+ self.assertNotIn("Tourcoing", ddict)
+ self.assertIn("Grenoble", ddict)
+
+ self.assertNotIn("inhabitants", ddict["Grenoble"])
+ self.assertIn("coordinates", ddict["Grenoble"])
+ self.assertIn("area", ddict["Grenoble"])
+
+ def testAsArrayTrue(self):
+ """Test with asarray=True, the default"""
+ ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble")
+ self.assertTrue(numpy.array_equal(ddict["inhabitants"], numpy.array(inhabitants)))
+
+ def testAsArrayFalse(self):
+ """Test with asarray=False"""
+ ddict = h5todict(self.h5_fname, path="/Europe/France/Grenoble", asarray=False)
+ self.assertEqual(ddict["inhabitants"], inhabitants)
+
+ def testDereferenceLinks(self):
+ ddict = h5todict(self.h5_fname, path="links", dereference_links=True)
+ self.assertTrue(ddict["absolute_softlink"], 10)
+ self.assertTrue(ddict["relative_softlink"], 10)
+ self.assertTrue(ddict["external_link"], 10)
+ self.assertTrue(ddict["group"]["relative_softlink"], 10)
+
+ def testPreserveLinks(self):
+ ddict = h5todict(self.h5_fname, path="links", dereference_links=False)
+ self.assertTrue(is_link(ddict["absolute_softlink"]))
+ self.assertTrue(is_link(ddict["relative_softlink"]))
+ self.assertTrue(is_link(ddict["external_link"]))
+ self.assertTrue(is_link(ddict["group"]["relative_softlink"]))
+
+ def testStrings(self):
+ ddict = {"dset_bytes": b"bytes",
+ "dset_utf8": "utf8",
+ "dset_2bytes": [b"bytes", b"bytes"],
+ "dset_2utf8": ["utf8", "utf8"],
+ ("", "attr_bytes"): b"bytes",
+ ("", "attr_utf8"): "utf8",
+ ("", "attr_2bytes"): [b"bytes", b"bytes"],
+ ("", "attr_2utf8"): ["utf8", "utf8"]}
+ dicttoh5(ddict, self.h5_fname, mode="w")
+ adict = h5todict(self.h5_fname, include_attributes=True, asarray=False)
+ self.assertEqual(ddict["dset_bytes"], adict["dset_bytes"])
+ self.assertEqual(ddict["dset_utf8"], adict["dset_utf8"])
+ self.assertEqual(ddict[("", "attr_bytes")], adict[("", "attr_bytes")])
+ self.assertEqual(ddict[("", "attr_utf8")], adict[("", "attr_utf8")])
+ numpy.testing.assert_array_equal(ddict["dset_2bytes"], adict["dset_2bytes"])
+ numpy.testing.assert_array_equal(ddict["dset_2utf8"], adict["dset_2utf8"])
+ numpy.testing.assert_array_equal(ddict[("", "attr_2bytes")], adict[("", "attr_2bytes")])
+ numpy.testing.assert_array_equal(ddict[("", "attr_2utf8")], adict[("", "attr_2utf8")])
+
+
+class TestDictToNx(H5DictTestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.h5_fname = os.path.join(self.tempdir, "nx.h5")
+ self.h5_ext_fname = os.path.join(self.tempdir, "nx_ext.h5")
+
+ def tearDown(self):
+ if os.path.exists(self.h5_fname):
+ os.unlink(self.h5_fname)
+ if os.path.exists(self.h5_ext_fname):
+ os.unlink(self.h5_ext_fname)
+ os.rmdir(self.tempdir)
+
+ def testAttributes(self):
+ """Any kind of attribute can be described"""
+ ddict = {
+ "group": {"dataset": 100, "@group_attr1": 10},
+ "dataset": 200,
+ "@root_attr": 11,
+ "dataset@dataset_attr": "12",
+ "group@group_attr2": 13,
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttonx(ddict, h5file)
+ self.assertEqual(h5file["group"].attrs['group_attr1'], 10)
+ self.assertEqual(h5file.attrs['root_attr'], 11)
+ self.assertEqual(h5file["dataset"].attrs['dataset_attr'], "12")
+ self.assertEqual(h5file["group"].attrs['group_attr2'], 13)
+
+ def testKeyOrder(self):
+ ddict1 = {
+ "d": "plow",
+ "d@a": "ox",
+ }
+ ddict2 = {
+ "d@a": "ox",
+ "d": "plow",
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttonx(ddict1, h5file, h5path="g1")
+ dictdump.dicttonx(ddict2, h5file, h5path="g2")
+ self.assertEqual(h5file["g1/d"].attrs['a'], "ox")
+ self.assertEqual(h5file["g2/d"].attrs['a'], "ox")
+
+ def testAttributeValues(self):
+ """Any NX data types can be used"""
+ ddict = {
+ "@bool": True,
+ "@int": 11,
+ "@float": 1.1,
+ "@str": "a",
+ "@boollist": [True, False, True],
+ "@intlist": [11, 22, 33],
+ "@floatlist": [1.1, 2.2, 3.3],
+ "@strlist": ["a", "bb", "ccc"],
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttonx(ddict, h5file)
+ for k, expected in ddict.items():
+ result = h5file.attrs[k[1:]]
+ if isinstance(expected, list):
+ if isinstance(expected[0], str):
+ numpy.testing.assert_array_equal(result, expected)
+ else:
+ numpy.testing.assert_array_almost_equal(result, expected)
+ else:
+ self.assertEqual(result, expected)
+
+ def testFlatDict(self):
+ """Description of a tree with a single level of keys"""
+ ddict = {
+ "group/group/dataset": 10,
+ "group/group/dataset@attr": 11,
+ "group/group@attr": 12,
+ }
+ with h5py.File(self.h5_fname, "w") as h5file:
+ dictdump.dicttonx(ddict, h5file)
+ self.assertEqual(h5file["group/group/dataset"][()], 10)
+ self.assertEqual(h5file["group/group/dataset"].attrs['attr'], 11)
+ self.assertEqual(h5file["group/group"].attrs['attr'], 12)
+
+ def testLinks(self):
+ ddict = {"ext_group": {"dataset": 10}}
+ dictdump.dicttonx(ddict, self.h5_ext_fname)
+ ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
+ ">relative_softlink": "group/dataset",
+ ">absolute_softlink": "/links/group/dataset",
+ ">external_link": "nx_ext.h5::/ext_group/dataset"}}
+ dictdump.dicttonx(ddict, self.h5_fname)
+ with h5py.File(self.h5_fname, "r") as h5file:
+ self.assertEqual(h5file["links/group/dataset"][()], 10)
+ self.assertEqual(h5file["links/group/relative_softlink"][()], 10)
+ self.assertEqual(h5file["links/relative_softlink"][()], 10)
+ self.assertEqual(h5file["links/absolute_softlink"][()], 10)
+ self.assertEqual(h5file["links/external_link"][()], 10)
+
+ def testUpLinks(self):
+ ddict = {"data": {"group": {"dataset": 10, ">relative_softlink": "dataset"}},
+ "links": {"group": {"subgroup": {">relative_softlink": "../../../data/group/dataset"}}}}
+ dictdump.dicttonx(ddict, self.h5_fname)
+ with h5py.File(self.h5_fname, "r") as h5file:
+ self.assertEqual(h5file["/links/group/subgroup/relative_softlink"][()], 10)
+
+ def testOverwrite(self):
+ entry_name = "entry"
+ wtreedict = {
+ "group1": {"a": 1, "b": 2},
+ "group2@attr3": "attr3",
+ "group2@attr4": "attr4",
+ "group2": {
+ "@attr1": "attr1",
+ "@attr2": "attr2",
+ "c": 3,
+ "d": 4,
+ "dataset4": 8,
+ "dataset4@units": "keV",
+ },
+ "group3": {"subgroup": {"e": 9, "f": 10}},
+ "dataset1": 5,
+ "dataset2": 6,
+ "dataset3": 7,
+ "dataset3@units": "mm",
+ }
+ esubtree = {
+ "@NX_class": "NXentry",
+ "group1": {"@NX_class": "NXcollection", "a": 1, "b": 2},
+ "group2": {
+ "@NX_class": "NXcollection",
+ "@attr1": "attr1",
+ "@attr2": "attr2",
+ "@attr3": "attr3",
+ "@attr4": "attr4",
+ "c": 3,
+ "d": 4,
+ "dataset4": 8,
+ "dataset4@units": "keV",
+ },
+ "group3": {
+ "@NX_class": "NXcollection",
+ "subgroup": {"@NX_class": "NXcollection", "e": 9, "f": 10},
+ },
+ "dataset1": 5,
+ "dataset2": 6,
+ "dataset3": 7,
+ "dataset3@units": "mm",
+ }
+ etreedict = {entry_name: esubtree}
+
+ def append_file(update_mode, add_nx_class):
+ dictdump.dicttonx(
+ wtreedict,
+ h5file=self.h5_fname,
+ mode="a",
+ h5path=entry_name,
+ update_mode=update_mode,
+ add_nx_class=add_nx_class
+ )
+
+ def assert_file():
+ rtreedict = dictdump.nxtodict(
+ self.h5_fname,
+ include_attributes=True,
+ asarray=False,
+ )
+ netreedict = self.dictRoundTripNormalize(etreedict)
+ try:
+ self.assertRecursiveEqual(netreedict, rtreedict)
+ except AssertionError:
+ from pprint import pprint
+ print("\nDUMP:")
+ pprint(wtreedict)
+ print("\nEXPECTED:")
+ pprint(netreedict)
+ print("\nHDF5:")
+ pprint(rtreedict)
+ raise
+
+ def assert_append(update_mode, add_nx_class=None):
+ append_file(update_mode, add_nx_class=add_nx_class)
+ assert_file()
+
+ # First to an empty file
+ assert_append(None)
+
+ # Add non-existing attributes/datasets/groups
+ wtreedict["group1"].pop("a")
+ wtreedict["group2"].pop("@attr1")
+ wtreedict["group2"]["@attr2"] = "attr3" # only for update
+ wtreedict["group2"]["@type"] = "test"
+ wtreedict["group2"]["dataset4"] = 9 # only for update
+ del wtreedict["group2"]["dataset4@units"]
+ wtreedict["group3"] = {}
+ esubtree["group2"]["@type"] = "test"
+ assert_append("add")
+
+ # Add update existing attributes and datasets
+ esubtree["group2"]["@attr2"] = "attr3"
+ esubtree["group2"]["dataset4"] = 9
+ assert_append("modify")
+
+ # Do not add missing NX_class by default when updating
+ wtreedict["group2"]["@NX_class"] = "NXprocess"
+ esubtree["group2"]["@NX_class"] = "NXprocess"
+ assert_append("modify")
+ del wtreedict["group2"]["@NX_class"]
+ assert_append("modify")
+
+ # Overwrite existing groups/datasets/attributes
+ esubtree["group1"].pop("a")
+ esubtree["group2"].pop("@attr1")
+ esubtree["group2"]["@NX_class"] = "NXcollection"
+ esubtree["group2"]["dataset4"] = 9
+ del esubtree["group2"]["dataset4@units"]
+ esubtree["group3"] = {"@NX_class": "NXcollection"}
+ assert_append("replace", add_nx_class=True)
+
+
+class TestNxToDict(H5DictTestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.h5_fname = os.path.join(self.tempdir, "nx.h5")
+ self.h5_ext_fname = os.path.join(self.tempdir, "nx_ext.h5")
+
+ def tearDown(self):
+ if os.path.exists(self.h5_fname):
+ os.unlink(self.h5_fname)
+ if os.path.exists(self.h5_ext_fname):
+ os.unlink(self.h5_ext_fname)
+ os.rmdir(self.tempdir)
+
+ def testAttributes(self):
+ """Any kind of attribute can be described"""
+ ddict = {
+ "group": {"dataset": 100, "@group_attr1": 10},
+ "dataset": 200,
+ "@root_attr": 11,
+ "dataset@dataset_attr": "12",
+ "group@group_attr2": 13,
+ }
+ dictdump.dicttonx(ddict, self.h5_fname)
+ ddict = dictdump.nxtodict(self.h5_fname, include_attributes=True)
+ self.assertEqual(ddict["group"]["@group_attr1"], 10)
+ self.assertEqual(ddict["@root_attr"], 11)
+ self.assertEqual(ddict["dataset@dataset_attr"], "12")
+ self.assertEqual(ddict["group"]["@group_attr2"], 13)
+
+ def testDereferenceLinks(self):
+ """Write links and dereference on read"""
+ ddict = {"ext_group": {"dataset": 10}}
+ dictdump.dicttonx(ddict, self.h5_ext_fname)
+ ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
+ ">relative_softlink": "group/dataset",
+ ">absolute_softlink": "/links/group/dataset",
+ ">external_link": "nx_ext.h5::/ext_group/dataset"}}
+ dictdump.dicttonx(ddict, self.h5_fname)
+
+ ddict = dictdump.h5todict(self.h5_fname, dereference_links=True)
+ self.assertTrue(ddict["links"]["absolute_softlink"], 10)
+ self.assertTrue(ddict["links"]["relative_softlink"], 10)
+ self.assertTrue(ddict["links"]["external_link"], 10)
+ self.assertTrue(ddict["links"]["group"]["relative_softlink"], 10)
+
+ def testPreserveLinks(self):
+ """Write/read links"""
+ ddict = {"ext_group": {"dataset": 10}}
+ dictdump.dicttonx(ddict, self.h5_ext_fname)
+ ddict = {"links": {"group": {"dataset": 10, ">relative_softlink": "dataset"},
+ ">relative_softlink": "group/dataset",
+ ">absolute_softlink": "/links/group/dataset",
+ ">external_link": "nx_ext.h5::/ext_group/dataset"}}
+ dictdump.dicttonx(ddict, self.h5_fname)
+
+ ddict = dictdump.nxtodict(self.h5_fname, dereference_links=False)
+ self.assertTrue(ddict["links"][">absolute_softlink"], "dataset")
+ self.assertTrue(ddict["links"][">relative_softlink"], "group/dataset")
+ self.assertTrue(ddict["links"][">external_link"], "/links/group/dataset")
+ self.assertTrue(ddict["links"]["group"][">relative_softlink"], "nx_ext.h5::/ext_group/datase")
+
+ def testNotExistingPath(self):
+ """Test converting not existing path"""
+ with h5py.File(self.h5_fname, 'a') as f:
+ f['data'] = 1
+
+ ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='ignore')
+ self.assertFalse(ddict)
+
+ with LoggingValidator(dictdump_logger, error=1):
+ ddict = h5todict(self.h5_fname, path="/I/am/not/a/path", errors='log')
+ self.assertFalse(ddict)
+
+ with self.assertRaises(KeyError):
+ h5todict(self.h5_fname, path="/I/am/not/a/path", errors='raise')
+
+ def testBrokenLinks(self):
+ """Test with broken links"""
+ with h5py.File(self.h5_fname, 'a') as f:
+ f["/Mars/BrokenSoftLink"] = h5py.SoftLink("/Idontexists")
+ f["/Mars/BrokenExternalLink"] = h5py.ExternalLink("notexistingfile.h5", "/Idontexists")
+
+ ddict = h5todict(self.h5_fname, path="/Mars", errors='ignore')
+ self.assertFalse(ddict)
+
+ with LoggingValidator(dictdump_logger, error=2):
+ ddict = h5todict(self.h5_fname, path="/Mars", errors='log')
+ self.assertFalse(ddict)
+
+ with self.assertRaises(KeyError):
+ h5todict(self.h5_fname, path="/Mars", errors='raise')
+
+
+class TestDictToJson(DictTestCase):
+ def setUp(self):
+ self.dir_path = tempfile.mkdtemp()
+ self.json_fname = os.path.join(self.dir_path, "cityattrs.json")
+
+ def tearDown(self):
+ os.unlink(self.json_fname)
+ os.rmdir(self.dir_path)
+
+ def testJsonCityAttrs(self):
+ self.json_fname = os.path.join(self.dir_path, "cityattrs.json")
+ dicttojson(city_attrs, self.json_fname, indent=3)
+
+ with open(self.json_fname, "r") as f:
+ json_content = f.read()
+ self.assertIn('"inhabitants": 160215', json_content)
+
+
+class TestDictToIni(DictTestCase):
+ def setUp(self):
+ self.dir_path = tempfile.mkdtemp()
+ self.ini_fname = os.path.join(self.dir_path, "test.ini")
+
+ def tearDown(self):
+ os.unlink(self.ini_fname)
+ os.rmdir(self.dir_path)
+
+ def testConfigDictIO(self):
+ """Ensure values and types of data is preserved when dictionary is
+ written to file and read back."""
+ testdict = {
+ 'simple_types': {
+ 'float': 1.0,
+ 'int': 1,
+ 'percent string': '5 % is too much',
+ 'backslash string': 'i can use \\',
+ 'empty_string': '',
+ 'nonestring': 'None',
+ 'nonetype': None,
+ 'interpstring': 'interpolation: %(percent string)s',
+ },
+ 'containers': {
+ 'list': [-1, 'string', 3.0, False, None],
+ 'array': numpy.array([1.0, 2.0, 3.0]),
+ 'dict': {
+ 'key1': 'Hello World',
+ 'key2': 2.0,
+ }
+ }
+ }
+
+ dump(testdict, self.ini_fname)
+
+ #read the data back
+ readdict = load(self.ini_fname)
+
+ testdictkeys = list(testdict.keys())
+ readkeys = list(readdict.keys())
+
+ self.assertTrue(len(readkeys) == len(testdictkeys),
+ "Number of read keys not equal")
+
+ self.assertEqual(readdict['simple_types']["interpstring"],
+ "interpolation: 5 % is too much")
+
+ testdict['simple_types']["interpstring"] = "interpolation: 5 % is too much"
+
+ for key in testdict["simple_types"]:
+ original = testdict['simple_types'][key]
+ read = readdict['simple_types'][key]
+ self.assertEqual(read, original,
+ "Read <%s> instead of <%s>" % (read, original))
+
+ for key in testdict["containers"]:
+ original = testdict["containers"][key]
+ read = readdict["containers"][key]
+ if key == 'array':
+ self.assertEqual(read.all(), original.all(),
+ "Read <%s> instead of <%s>" % (read, original))
+ else:
+ self.assertEqual(read, original,
+ "Read <%s> instead of <%s>" % (read, original))
+
+ def testConfigDictOrder(self):
+ """Ensure order is preserved when dictionary is
+ written to file and read back."""
+ test_dict = {'banana': 3, 'apple': 4, 'pear': 1, 'orange': 2}
+ # sort by key
+ test_ordered_dict1 = OrderedDict(sorted(test_dict.items(),
+ key=lambda t: t[0]))
+ # sort by value
+ test_ordered_dict2 = OrderedDict(sorted(test_dict.items(),
+ key=lambda t: t[1]))
+ # add the two ordered dict as sections of a third ordered dict
+ test_ordered_dict3 = OrderedDict()
+ test_ordered_dict3["section1"] = test_ordered_dict1
+ test_ordered_dict3["section2"] = test_ordered_dict2
+
+ # write to ini and read back as a ConfigDict (inherits OrderedDict)
+ dump(test_ordered_dict3,
+ self.ini_fname, fmat="ini")
+ read_instance = ConfigDict()
+ read_instance.read(self.ini_fname)
+
+ # loop through original and read-back dictionaries,
+ # test identical order for key/value pairs
+ for orig_key, section in zip(test_ordered_dict3.keys(),
+ read_instance.keys()):
+ self.assertEqual(orig_key, section)
+ for orig_key2, read_key in zip(test_ordered_dict3[section].keys(),
+ read_instance[section].keys()):
+ self.assertEqual(orig_key2, read_key)
+ self.assertEqual(test_ordered_dict3[section][orig_key2],
+ read_instance[section][read_key])
diff --git a/src/silx/io/test/test_fabioh5.py b/src/silx/io/test/test_fabioh5.py
new file mode 100755
index 0000000..c410024
--- /dev/null
+++ b/src/silx/io/test/test_fabioh5.py
@@ -0,0 +1,615 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for fabioh5 wrapper"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "02/07/2018"
+
+import os
+import logging
+import numpy
+import unittest
+import tempfile
+import shutil
+
+_logger = logging.getLogger(__name__)
+
+import fabio
+import h5py
+
+from .. import commonh5
+from .. import fabioh5
+
+
+class TestFabioH5(unittest.TestCase):
+
+ def setUp(self):
+
+ header = {
+ "integer": "-100",
+ "float": "1.0",
+ "string": "hi!",
+ "list_integer": "100 50 0",
+ "list_float": "1.0 2.0 3.5",
+ "string_looks_like_list": "2000 hi!",
+ }
+ data = numpy.array([[10, 11], [12, 13], [14, 15]], dtype=numpy.int64)
+ self.fabio_image = fabio.numpyimage.NumpyImage(data, header)
+ self.h5_image = fabioh5.File(fabio_image=self.fabio_image)
+
+ def test_main_groups(self):
+ self.assertEqual(self.h5_image.h5py_class, h5py.File)
+ self.assertEqual(self.h5_image["/"].h5py_class, h5py.File)
+ self.assertEqual(self.h5_image["/scan_0"].h5py_class, h5py.Group)
+ self.assertEqual(self.h5_image["/scan_0/instrument"].h5py_class, h5py.Group)
+ self.assertEqual(self.h5_image["/scan_0/measurement"].h5py_class, h5py.Group)
+
+ def test_wrong_path_syntax(self):
+ # result tested with a default h5py file
+ self.assertRaises(ValueError, lambda: self.h5_image[""])
+
+ def test_wrong_root_name(self):
+ # result tested with a default h5py file
+ self.assertRaises(KeyError, lambda: self.h5_image["/foo"])
+
+ def test_wrong_root_path(self):
+ # result tested with a default h5py file
+ self.assertRaises(KeyError, lambda: self.h5_image["/foo/foo"])
+
+ def test_wrong_name(self):
+ # result tested with a default h5py file
+ self.assertRaises(KeyError, lambda: self.h5_image["foo"])
+
+ def test_wrong_path(self):
+ # result tested with a default h5py file
+ self.assertRaises(KeyError, lambda: self.h5_image["foo/foo"])
+
+ def test_single_frame(self):
+ data = numpy.arange(2 * 3)
+ data.shape = 2, 3
+ fabio_image = fabio.edfimage.edfimage(data=data)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+
+ dataset = h5_image["/scan_0/instrument/detector_0/data"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertTrue(isinstance(dataset[()], numpy.ndarray))
+ self.assertEqual(dataset.dtype.kind, "i")
+ self.assertEqual(dataset.shape, (2, 3))
+ self.assertEqual(dataset[...][0, 0], 0)
+ self.assertEqual(dataset.attrs["interpretation"], "image")
+
+ def test_multi_frames(self):
+ data = numpy.arange(2 * 3)
+ data.shape = 2, 3
+ fabio_image = fabio.edfimage.edfimage(data=data)
+ fabio_image.append_frame(data=data)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+
+ dataset = h5_image["/scan_0/instrument/detector_0/data"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertTrue(isinstance(dataset[()], numpy.ndarray))
+ self.assertEqual(dataset.dtype.kind, "i")
+ self.assertEqual(dataset.shape, (2, 2, 3))
+ self.assertEqual(dataset[...][0, 0, 0], 0)
+ self.assertEqual(dataset.attrs["interpretation"], "image")
+
+ def test_heterogeneous_frames(self):
+ """Frames containing 2 images with different sizes and a cube"""
+ data1 = numpy.arange(2 * 3)
+ data1.shape = 2, 3
+ data2 = numpy.arange(2 * 5)
+ data2.shape = 2, 5
+ data3 = numpy.arange(2 * 5 * 1)
+ data3.shape = 2, 5, 1
+ fabio_image = fabio.edfimage.edfimage(data=data1)
+ fabio_image.append_frame(data=data2)
+ fabio_image.append_frame(data=data3)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+
+ dataset = h5_image["/scan_0/instrument/detector_0/data"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertTrue(isinstance(dataset[()], numpy.ndarray))
+ self.assertEqual(dataset.dtype.kind, "i")
+ self.assertEqual(dataset.shape, (3, 2, 5, 1))
+ self.assertEqual(dataset[...][0, 0, 0], 0)
+ self.assertEqual(dataset.attrs["interpretation"], "image")
+
+ def test_single_3d_frame(self):
+ """Image source contains a cube"""
+ data = numpy.arange(2 * 3 * 4)
+ data.shape = 2, 3, 4
+ # Do not provide the data to the constructor to avoid slicing of the
+ # data. In this way the result stay a cube, and not a multi-frame
+ fabio_image = fabio.edfimage.edfimage()
+ fabio_image.data = data
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+
+ dataset = h5_image["/scan_0/instrument/detector_0/data"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertTrue(isinstance(dataset[()], numpy.ndarray))
+ self.assertEqual(dataset.dtype.kind, "i")
+ self.assertEqual(dataset.shape, (2, 3, 4))
+ self.assertEqual(dataset[...][0, 0, 0], 0)
+ self.assertEqual(dataset.attrs["interpretation"], "image")
+
+ def test_metadata_int(self):
+ dataset = self.h5_image["/scan_0/instrument/detector_0/others/integer"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertEqual(dataset[()], -100)
+ self.assertEqual(dataset.dtype.kind, "i")
+ self.assertEqual(dataset.shape, (1,))
+
+ def test_metadata_float(self):
+ dataset = self.h5_image["/scan_0/instrument/detector_0/others/float"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertEqual(dataset[()], 1.0)
+ self.assertEqual(dataset.dtype.kind, "f")
+ self.assertEqual(dataset.shape, (1,))
+
+ def test_metadata_string(self):
+ dataset = self.h5_image["/scan_0/instrument/detector_0/others/string"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertEqual(dataset[()], numpy.string_("hi!"))
+ self.assertEqual(dataset.dtype.type, numpy.string_)
+ self.assertEqual(dataset.shape, (1,))
+
+ def test_metadata_list_integer(self):
+ dataset = self.h5_image["/scan_0/instrument/detector_0/others/list_integer"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertEqual(dataset.dtype.kind, "u")
+ self.assertEqual(dataset.shape, (1, 3))
+ self.assertEqual(dataset[0, 0], 100)
+ self.assertEqual(dataset[0, 1], 50)
+
+ def test_metadata_list_float(self):
+ dataset = self.h5_image["/scan_0/instrument/detector_0/others/list_float"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertEqual(dataset.dtype.kind, "f")
+ self.assertEqual(dataset.shape, (1, 3))
+ self.assertEqual(dataset[0, 0], 1.0)
+ self.assertEqual(dataset[0, 1], 2.0)
+
+ def test_metadata_list_looks_like_list(self):
+ dataset = self.h5_image["/scan_0/instrument/detector_0/others/string_looks_like_list"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertEqual(dataset[()], numpy.string_("2000 hi!"))
+ self.assertEqual(dataset.dtype.type, numpy.string_)
+ self.assertEqual(dataset.shape, (1,))
+
+ def test_float_32(self):
+ float_list = [u'1.2', u'1.3', u'1.4']
+ data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
+ fabio_image = None
+ for float_item in float_list:
+ header = {"float_item": float_item}
+ if fabio_image is None:
+ fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
+ else:
+ fabio_image.append_frame(data=data, header=header)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+ data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
+ # There is no equality between items
+ self.assertEqual(len(data), len(set(data)))
+ # At worst a float32
+ self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertLessEqual(data.dtype.itemsize, 32 / 8)
+
+ def test_float_64(self):
+ float_list = [
+ u'1469117129.082226',
+ u'1469117136.684986', u'1469117144.312749', u'1469117151.892507',
+ u'1469117159.474265', u'1469117167.100027', u'1469117174.815799',
+ u'1469117182.437561', u'1469117190.094326', u'1469117197.721089']
+ data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
+ fabio_image = None
+ for float_item in float_list:
+ header = {"time_of_day": float_item}
+ if fabio_image is None:
+ fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
+ else:
+ fabio_image.append_frame(data=data, header=header)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+ data = h5_image["/scan_0/instrument/detector_0/others/time_of_day"]
+ # There is no equality between items
+ self.assertEqual(len(data), len(set(data)))
+ # At least a float64
+ self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertGreaterEqual(data.dtype.itemsize, 64 / 8)
+
+ def test_mixed_float_size__scalar(self):
+ # We expect to have a precision of 32 bits
+ float_list = [u'1.2', u'1.3001']
+ expected_float_result = [1.2, 1.3001]
+ data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
+ fabio_image = None
+ for float_item in float_list:
+ header = {"float_item": float_item}
+ if fabio_image is None:
+ fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
+ else:
+ fabio_image.append_frame(data=data, header=header)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+ data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
+ # At worst a float32
+ self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertLessEqual(data.dtype.itemsize, 32 / 8)
+ for computed, expected in zip(data, expected_float_result):
+ numpy.testing.assert_almost_equal(computed, expected, 5)
+
+ def test_mixed_float_size__list(self):
+ # We expect to have a precision of 32 bits
+ float_list = [u'1.2 1.3001']
+ expected_float_result = numpy.array([[1.2, 1.3001]])
+ data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
+ fabio_image = None
+ for float_item in float_list:
+ header = {"float_item": float_item}
+ if fabio_image is None:
+ fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
+ else:
+ fabio_image.append_frame(data=data, header=header)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+ data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
+ # At worst a float32
+ self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertLessEqual(data.dtype.itemsize, 32 / 8)
+ for computed, expected in zip(data, expected_float_result):
+ numpy.testing.assert_almost_equal(computed, expected, 5)
+
+ def test_mixed_float_size__list_of_list(self):
+ # We expect to have a precision of 32 bits
+ float_list = [u'1.2 1.3001', u'1.3001 1.3001']
+ expected_float_result = numpy.array([[1.2, 1.3001], [1.3001, 1.3001]])
+ data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
+ fabio_image = None
+ for float_item in float_list:
+ header = {"float_item": float_item}
+ if fabio_image is None:
+ fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
+ else:
+ fabio_image.append_frame(data=data, header=header)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+ data = h5_image["/scan_0/instrument/detector_0/others/float_item"]
+ # At worst a float32
+ self.assertIn(data.dtype.kind, ['d', 'f'])
+ self.assertLessEqual(data.dtype.itemsize, 32 / 8)
+ for computed, expected in zip(data, expected_float_result):
+ numpy.testing.assert_almost_equal(computed, expected, 5)
+
+ def test_ub_matrix(self):
+ """Data from mediapix.edf"""
+ header = {}
+ header["UB_mne"] = 'UB0 UB1 UB2 UB3 UB4 UB5 UB6 UB7 UB8'
+ header["UB_pos"] = '1.99593e-16 2.73682e-16 -1.54 -1.08894 1.08894 1.6083e-16 1.08894 1.08894 9.28619e-17'
+ header["sample_mne"] = 'U0 U1 U2 U3 U4 U5'
+ header["sample_pos"] = '4.08 4.08 4.08 90 90 90'
+ data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
+ fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+ sample = h5_image["/scan_0/sample"]
+ self.assertIsNotNone(sample)
+ self.assertEqual(sample.attrs["NXclass"], "NXsample")
+
+ d = sample['unit_cell_abc']
+ expected = numpy.array([4.08, 4.08, 4.08])
+ self.assertIsNotNone(d)
+ self.assertEqual(d.shape, (3, ))
+ self.assertIn(d.dtype.kind, ['d', 'f'])
+ numpy.testing.assert_array_almost_equal(d[...], expected)
+
+ d = sample['unit_cell_alphabetagamma']
+ expected = numpy.array([90.0, 90.0, 90.0])
+ self.assertIsNotNone(d)
+ self.assertEqual(d.shape, (3, ))
+ self.assertIn(d.dtype.kind, ['d', 'f'])
+ numpy.testing.assert_array_almost_equal(d[...], expected)
+
+ d = sample['ub_matrix']
+ expected = numpy.array([[[1.99593e-16, 2.73682e-16, -1.54],
+ [-1.08894, 1.08894, 1.6083e-16],
+ [1.08894, 1.08894, 9.28619e-17]]])
+ self.assertIsNotNone(d)
+ self.assertEqual(d.shape, (1, 3, 3))
+ self.assertIn(d.dtype.kind, ['d', 'f'])
+ numpy.testing.assert_array_almost_equal(d[...], expected)
+
+ def test_interpretation_mca_edf(self):
+ """EDF files with two or more headers starting with "MCA"
+ must have @interpretation = "spectrum" an the data."""
+ header = {
+ "Title": "zapimage samy -4.975 -5.095 80 500 samz -4.091 -4.171 70 0",
+ "MCA a": -23.812,
+ "MCA b": 2.7107,
+ "MCA c": 8.1164e-06}
+
+ data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
+ fabio_image = fabio.edfimage.EdfImage(data=data, header=header)
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+
+ data_dataset = h5_image["/scan_0/measurement/image_0/data"]
+ self.assertEqual(data_dataset.attrs["interpretation"], "spectrum")
+
+ data_dataset = h5_image["/scan_0/instrument/detector_0/data"]
+ self.assertEqual(data_dataset.attrs["interpretation"], "spectrum")
+
+ data_dataset = h5_image["/scan_0/measurement/image_0/info/data"]
+ self.assertEqual(data_dataset.attrs["interpretation"], "spectrum")
+
+ def test_get_api(self):
+ result = self.h5_image.get("scan_0", getclass=True, getlink=True)
+ self.assertIs(result, h5py.HardLink)
+ result = self.h5_image.get("scan_0", getclass=False, getlink=True)
+ self.assertIsInstance(result, h5py.HardLink)
+ result = self.h5_image.get("scan_0", getclass=True, getlink=False)
+ self.assertIs(result, h5py.Group)
+ result = self.h5_image.get("scan_0", getclass=False, getlink=False)
+ self.assertIsInstance(result, commonh5.Group)
+
+ def test_detector_link(self):
+ detector1 = self.h5_image["/scan_0/instrument/detector_0"]
+ detector2 = self.h5_image["/scan_0/measurement/image_0/info"]
+ self.assertIsNot(detector1, detector2)
+ self.assertEqual(list(detector1.items()), list(detector2.items()))
+ self.assertEqual(self.h5_image.get(detector2.name, getlink=True).path, detector1.name)
+
+ def test_detector_data_link(self):
+ data1 = self.h5_image["/scan_0/instrument/detector_0/data"]
+ data2 = self.h5_image["/scan_0/measurement/image_0/data"]
+ self.assertIsNot(data1, data2)
+ self.assertIs(data1._get_data(), data2._get_data())
+ self.assertEqual(self.h5_image.get(data2.name, getlink=True).path, data1.name)
+
+ def test_dirty_header(self):
+ """Test that it does not fail"""
+ try:
+ header = {}
+ header["foo"] = b'abc'
+ data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
+ fabio_image = fabio.edfimage.edfimage(data=data, header=header)
+ header = {}
+ header["foo"] = b'a\x90bc\xFE'
+ fabio_image.append_frame(data=data, header=header)
+ except Exception as e:
+ _logger.error(e.args[0])
+ _logger.debug("Backtrace", exc_info=True)
+ self.skipTest("fabio do not allow to create the resource")
+
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+ scan_header_path = "/scan_0/instrument/file/scan_header"
+ self.assertIn(scan_header_path, h5_image)
+ data = h5_image[scan_header_path]
+ self.assertIsInstance(data[...], numpy.ndarray)
+
+ def test_unicode_header(self):
+ """Test that it does not fail"""
+ try:
+ header = {}
+ header["foo"] = b'abc'
+ data = numpy.array([[0, 0], [0, 0]], dtype=numpy.int8)
+ fabio_image = fabio.edfimage.edfimage(data=data, header=header)
+ header = {}
+ header["foo"] = u'abc\u2764'
+ fabio_image.append_frame(data=data, header=header)
+ except Exception as e:
+ _logger.error(e.args[0])
+ _logger.debug("Backtrace", exc_info=True)
+ self.skipTest("fabio do not allow to create the resource")
+
+ h5_image = fabioh5.File(fabio_image=fabio_image)
+ scan_header_path = "/scan_0/instrument/file/scan_header"
+ self.assertIn(scan_header_path, h5_image)
+ data = h5_image[scan_header_path]
+ self.assertIsInstance(data[...], numpy.ndarray)
+
+
+class TestFabioH5MultiFrames(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+
+ names = ["A", "B", "C", "D"]
+ values = [["32000", "-10", "5.0", "1"],
+ ["-32000", "-10", "5.0", "1"]]
+
+ fabio_file = None
+
+ for i in range(10):
+ header = {
+ "image_id": "%d" % i,
+ "integer": "-100",
+ "float": "1.0",
+ "string": "hi!",
+ "list_integer": "100 50 0",
+ "list_float": "1.0 2.0 3.5",
+ "string_looks_like_list": "2000 hi!",
+ "motor_mne": " ".join(names),
+ "motor_pos": " ".join(values[i % len(values)]),
+ "counter_mne": " ".join(names),
+ "counter_pos": " ".join(values[i % len(values)])
+ }
+ for iname, name in enumerate(names):
+ header[name] = values[i % len(values)][iname]
+
+ data = numpy.array([[i, 11], [12, 13], [14, 15]], dtype=numpy.int64)
+ if fabio_file is None:
+ fabio_file = fabio.edfimage.EdfImage(data=data, header=header)
+ else:
+ fabio_file.append_frame(data=data, header=header)
+
+ cls.fabio_file = fabio_file
+ cls.fabioh5 = fabioh5.File(fabio_image=fabio_file)
+
+ def test_others(self):
+ others = self.fabioh5["/scan_0/instrument/detector_0/others"]
+ dataset = others["A"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 1)
+ self.assertEqual(dataset.dtype.kind, "i")
+ dataset = others["B"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 1)
+ self.assertEqual(dataset.dtype.kind, "i")
+ dataset = others["C"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 1)
+ self.assertEqual(dataset.dtype.kind, "f")
+ dataset = others["D"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 1)
+ self.assertEqual(dataset.dtype.kind, "u")
+
+ def test_positioners(self):
+ counters = self.fabioh5["/scan_0/instrument/positioners"]
+ # At least 32 bits, no unsigned values
+ dataset = counters["A"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 4)
+ self.assertEqual(dataset.dtype.kind, "i")
+ dataset = counters["B"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 4)
+ self.assertEqual(dataset.dtype.kind, "i")
+ dataset = counters["C"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 4)
+ self.assertEqual(dataset.dtype.kind, "f")
+ dataset = counters["D"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 4)
+ self.assertEqual(dataset.dtype.kind, "i")
+
+ def test_counters(self):
+ counters = self.fabioh5["/scan_0/measurement"]
+ # At least 32 bits, no unsigned values
+ dataset = counters["A"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 4)
+ self.assertEqual(dataset.dtype.kind, "i")
+ dataset = counters["B"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 4)
+ self.assertEqual(dataset.dtype.kind, "i")
+ dataset = counters["C"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 4)
+ self.assertEqual(dataset.dtype.kind, "f")
+ dataset = counters["D"]
+ self.assertGreaterEqual(dataset.dtype.itemsize, 4)
+ self.assertEqual(dataset.dtype.kind, "i")
+
+
+class TestFabioH5WithEdf(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+
+ cls.tmp_directory = tempfile.mkdtemp()
+
+ cls.edf_filename = os.path.join(cls.tmp_directory, "test.edf")
+
+ header = {
+ "integer": "-100",
+ "float": "1.0",
+ "string": "hi!",
+ "list_integer": "100 50 0",
+ "list_float": "1.0 2.0 3.5",
+ "string_looks_like_list": "2000 hi!",
+ }
+ data = numpy.array([[10, 11], [12, 13], [14, 15]], dtype=numpy.int64)
+ fabio_image = fabio.edfimage.edfimage(data, header)
+ fabio_image.write(cls.edf_filename)
+
+ cls.fabio_image = fabio.open(cls.edf_filename)
+ cls.h5_image = fabioh5.File(fabio_image=cls.fabio_image)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.fabio_image = None
+ cls.h5_image = None
+ shutil.rmtree(cls.tmp_directory)
+
+ def test_reserved_format_metadata(self):
+ if fabio.hexversion < 327920: # 0.5.0 final
+ self.skipTest("fabio >= 0.5.0 final is needed")
+
+ # The EDF contains reserved keys in the header
+ self.assertIn("HeaderID", self.fabio_image.header)
+ # We do not expose them in FabioH5
+ self.assertNotIn("/scan_0/instrument/detector_0/others/HeaderID", self.h5_image)
+
+
+class _TestableFrameData(fabioh5.FrameData):
+ """Allow to test if the full data is reached."""
+ def _create_data(self):
+ raise RuntimeError("Not supposed to be called")
+
+
+class TestFabioH5WithFileSeries(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+
+ cls.tmp_directory = tempfile.mkdtemp()
+
+ cls.edf_filenames = []
+
+ for i in range(10):
+ filename = os.path.join(cls.tmp_directory, "test_%04d.edf" % i)
+ cls.edf_filenames.append(filename)
+
+ header = {
+ "image_id": "%d" % i,
+ "integer": "-100",
+ "float": "1.0",
+ "string": "hi!",
+ "list_integer": "100 50 0",
+ "list_float": "1.0 2.0 3.5",
+ "string_looks_like_list": "2000 hi!",
+ }
+ data = numpy.array([[i, 11], [12, 13], [14, 15]], dtype=numpy.int64)
+ fabio_image = fabio.edfimage.edfimage(data, header)
+ fabio_image.write(filename)
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tmp_directory)
+
+ def _testH5Image(self, h5_image):
+ # test data
+ dataset = h5_image["/scan_0/instrument/detector_0/data"]
+ self.assertEqual(dataset.h5py_class, h5py.Dataset)
+ self.assertTrue(isinstance(dataset[()], numpy.ndarray))
+ self.assertEqual(dataset.dtype.kind, "i")
+ self.assertEqual(dataset.shape, (10, 3, 2))
+ self.assertEqual(list(dataset[:, 0, 0]), list(range(10)))
+ self.assertEqual(dataset.attrs["interpretation"], "image")
+ # test metatdata
+ dataset = h5_image["/scan_0/instrument/detector_0/others/image_id"]
+ self.assertEqual(list(dataset[...]), list(range(10)))
+
+ def testFileList(self):
+ h5_image = fabioh5.File(file_series=self.edf_filenames)
+ self._testH5Image(h5_image)
+
+ def testFileSeries(self):
+ file_series = fabioh5._FileSeries(self.edf_filenames)
+ h5_image = fabioh5.File(file_series=file_series)
+ self._testH5Image(h5_image)
+
+ def testFrameDataCache(self):
+ file_series = fabioh5._FileSeries(self.edf_filenames)
+ reader = fabioh5.FabioReader(file_series=file_series)
+ frameData = _TestableFrameData("foo", reader)
+ self.assertEqual(frameData.dtype.kind, "i")
+ self.assertEqual(frameData.shape, (10, 3, 2))
diff --git a/src/silx/io/test/test_fioh5.py b/src/silx/io/test/test_fioh5.py
new file mode 100644
index 0000000..8ffb4ad
--- /dev/null
+++ b/src/silx/io/test/test_fioh5.py
@@ -0,0 +1,299 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2021 Timo Fuchs
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for fioh5"""
+import numpy
+import os
+import io
+import sys
+import tempfile
+import unittest
+import datetime
+import logging
+
+from silx.utils import testutils
+
+from .. import fioh5
+from ..fioh5 import (FioH5, FioH5NodeDataset, is_fiofile, logger1, dtypeConverter)
+
+import h5py
+
+__authors__ = ["T. Fuchs"]
+__license__ = "MIT"
+__date__ = "15/10/2021"
+
+fioftext = """
+!
+! Comments
+!
+%c
+ascan omega 180.0 180.5 3:10/1 4
+user username, acquisition started at Thu Dec 12 18:00:00 2021
+sweep motor lag: 1.0e-03
+channel 3: Detector
+!
+! Parameter
+!
+%p
+channel3_exposure = 1.000000e+00
+ScanName = ascan
+!
+! Data
+!
+%d
+ Col 1 omega(encoder) DOUBLE
+ Col 2 channel INTEGER
+ Col 3 filename STRING
+ Col 4 type STRING
+ Col 5 unix time DOUBLE
+ Col 6 enable BOOLEAN
+ Col 7 time_s FLOAT
+ 179.998418821 3 00001 exposure 1576165741.20308 1 1.243
+ 180.048418821 3 00002 exposure 1576165742.20308 1 1.243
+ 180.098418821 3 00003 exposure 1576165743.20308 1 1.243
+ 180.148418821 3 00004 exposure 1576165744.20308 1 1.243
+ 180.198418821 3 00005 exposure 1576165745.20308 1 1.243
+ 180.248418821 3 00006 exposure 1576165746.20308 1 1.243
+ 180.298418821 3 00007 exposure 1576165747.20308 1 1.243
+ 180.348418821 3 00008 exposure 1576165748.20308 1 1.243
+ 180.398418821 3 00009 exposure 1576165749.20308 1 1.243
+ 180.448418821 3 00010 exposure 1576165750.20308 1 1.243
+"""
+
+
+
+class TestFioH5(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.temp_dir = tempfile.TemporaryDirectory()
+ #fd, cls.fname = tempfile.mkstemp()
+ cls.fname_numbered = os.path.join(cls.temp_dir.name, "eh1scan_00005.fio")
+
+ with open(cls.fname_numbered, 'w') as fiof:
+ fiof.write(fioftext)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.temp_dir.cleanup()
+ del cls.temp_dir
+
+ def setUp(self):
+ self.fioh5 = FioH5(self.fname_numbered)
+
+ def tearDown(self):
+ self.fioh5.close()
+
+ def testScanNumber(self):
+ # scan number is parsed from the file name.
+ self.assertIn("/5.1", self.fioh5)
+ self.assertIn("5.1", self.fioh5)
+
+ def testContainsFile(self):
+ self.assertIn("/5.1/measurement", self.fioh5)
+ self.assertNotIn("25.2", self.fioh5)
+ # measurement is a child of a scan, full path would be required to
+ # access from root level
+ self.assertNotIn("measurement", self.fioh5)
+ # Groups may or may not have a trailing /
+ self.assertIn("/5.1/measurement/", self.fioh5)
+ self.assertIn("/5.1/measurement", self.fioh5)
+ # Datasets can't have a trailing /
+ self.assertIn("/5.1/measurement/omega(encoder)", self.fioh5)
+ self.assertNotIn("/5.1/measurement/omega(encoder)/", self.fioh5)
+ # No gamma
+ self.assertNotIn("/5.1/measurement/gamma", self.fioh5)
+
+ def testContainsGroup(self):
+ self.assertIn("measurement", self.fioh5["/5.1/"])
+ self.assertIn("measurement", self.fioh5["/5.1"])
+ self.assertIn("5.1", self.fioh5["/"])
+ self.assertNotIn("5.2", self.fioh5["/"])
+ self.assertIn("measurement/filename", self.fioh5["/5.1"])
+ # illegal trailing "/" after dataset name
+ self.assertNotIn("measurement/filename/",
+ self.fioh5["/5.1"])
+ # full path to element in group (OK)
+ self.assertIn("/5.1/measurement/filename",
+ self.fioh5["/5.1/measurement"])
+
+ def testDataType(self):
+ meas = self.fioh5["/5.1/measurement/"]
+ self.assertEqual(meas["omega(encoder)"].dtype, dtypeConverter['DOUBLE'])
+ self.assertEqual(meas["channel"].dtype, dtypeConverter['INTEGER'])
+ self.assertEqual(meas["filename"].dtype, dtypeConverter['STRING'])
+ self.assertEqual(meas["time_s"].dtype, dtypeConverter['FLOAT'])
+ self.assertEqual(meas["enable"].dtype, dtypeConverter['BOOLEAN'])
+
+ def testDataColumn(self):
+ self.assertAlmostEqual(sum(self.fioh5["/5.1/measurement/omega(encoder)"]),
+ 1802.23418821)
+ self.assertTrue(numpy.all(self.fioh5["/5.1/measurement/enable"]))
+
+ # --- comment section tests ---
+
+ def testComment(self):
+ # should hold the complete comment section
+ self.assertEqual(self.fioh5["/5.1/instrument/fiofile/comments"],
+"""ascan omega 180.0 180.5 3:10/1 4
+user username, acquisition started at Thu Dec 12 18:00:00 2021
+sweep motor lag: 1.0e-03
+channel 3: Detector
+""")
+
+ def testDate(self):
+ # there is no convention on how to format the time. So just check its existence.
+ self.assertEqual(self.fioh5["/5.1/start_time"],
+ u"Thu Dec 12 18:00:00 2021")
+
+ def testTitle(self):
+ self.assertEqual(self.fioh5["/5.1/title"],
+ u"ascan omega 180.0 180.5 3:10/1 4")
+
+
+ # --- parameter section tests ---
+
+ def testParameter(self):
+ # should hold the complete parameter section
+ self.assertEqual(self.fioh5["/5.1/instrument/fiofile/parameter"],
+"""channel3_exposure = 1.000000e+00
+ScanName = ascan
+""")
+
+ def testParsedParameter(self):
+ # no dtype is given, so everything is str.
+ self.assertEqual(self.fioh5["/5.1/instrument/parameter/channel3_exposure"],
+ u"1.000000e+00")
+ self.assertEqual(self.fioh5["/5.1/instrument/parameter/ScanName"], u"ascan")
+
+ def testNotFioH5(self):
+ testfilename = os.path.join(self.temp_dir.name, "eh1scan_00010.fio")
+ with open(testfilename, 'w') as fiof:
+ fiof.write("!Not a fio file!")
+
+ self.assertRaises(IOError, FioH5, testfilename)
+
+ self.assertTrue(is_fiofile(self.fname_numbered))
+ self.assertFalse(is_fiofile(testfilename))
+
+ os.unlink(testfilename)
+
+
+class TestUnnumberedFioH5(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.temp_dir = tempfile.TemporaryDirectory()
+ cls.fname_nosuffix = os.path.join(cls.temp_dir.name, "eh1scan_nosuffix.fio")
+
+ with open(cls.fname_nosuffix, 'w') as fiof:
+ fiof.write(fioftext)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.temp_dir.cleanup()
+ del cls.temp_dir
+
+ def setUp(self):
+ self.fioh5 = FioH5(self.fname_nosuffix)
+
+ def testLogMissingScanno(self):
+ with self.assertLogs(logger1,level='WARNING') as cm:
+ fioh5 = FioH5(self.fname_nosuffix)
+ self.assertIn("Cannot parse scan number of file", cm.output[0])
+
+ def testFallbackName(self):
+ self.assertIn("/eh1scan_nosuffix", self.fioh5)
+
+brokenHeaderText = """
+!
+! Comments
+!
+%c
+ascan omega 180.0 180.5 3:10/1 4
+user username, acquisited at Thu Dec 12 100 2021
+sweep motor lavgvf.0e-03
+channel 3: Detector
+!
+! Parameter
+!
+%p
+channel3_exposu65 1.000000e+00
+ScanName = ascan
+!
+! Data
+!
+%d
+ Col 1 omega(encoder) DOUBLE
+ Col 2 channel INTEGER
+ Col 3 filename STRING
+ Col 4 type STRING
+ Col 5 unix time DOUBLE
+ 179.998418821 3 00001 exposure 1576165741.20308
+ 180.048418821 3 00002 exposure 1576165742.20308
+ 180.098418821 3 00003 exposure 1576165743.20308
+ 180.148418821 3 00004 exposure 1576165744.20308
+ 180.198418821 3 00005 exposure 1576165745.20308
+ 180.248418821 3 00006 exposure 1576165746.20308
+ 180.298418821 3 00007 exposure 1576165747.20308
+ 180.348418821 3 00008 exposure 1576165748.20308
+ 180.398418821 3 00009 exposure 1576165749.20308
+ 180.448418821 3 00010 exposure 1576165750.20308
+"""
+
+class TestBrokenHeaderFioH5(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.temp_dir = tempfile.TemporaryDirectory()
+ cls.fname_numbered = os.path.join(cls.temp_dir.name, "eh1scan_00005.fio")
+
+ with open(cls.fname_numbered, 'w') as fiof:
+ fiof.write(brokenHeaderText)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.temp_dir.cleanup()
+ del cls.temp_dir
+
+ def setUp(self):
+ self.fioh5 = FioH5(self.fname_numbered)
+
+ def testLogBrokenHeader(self):
+ with self.assertLogs(logger1,level='WARNING') as cm:
+ fioh5 = FioH5(self.fname_numbered)
+ self.assertIn("Cannot parse parameter section", cm.output[0])
+ self.assertIn("Cannot parse default comment section", cm.output[1])
+
+ def testComment(self):
+ # should hold the complete comment section
+ self.assertEqual(self.fioh5["/5.1/instrument/fiofile/comments"],
+"""ascan omega 180.0 180.5 3:10/1 4
+user username, acquisited at Thu Dec 12 100 2021
+sweep motor lavgvf.0e-03
+channel 3: Detector
+""")
+
+ def testParameter(self):
+ # should hold the complete parameter section
+ self.assertEqual(self.fioh5["/5.1/instrument/fiofile/parameter"],
+"""channel3_exposu65 1.000000e+00
+ScanName = ascan
+""")
diff --git a/src/silx/io/test/test_h5py_utils.py b/src/silx/io/test/test_h5py_utils.py
new file mode 100644
index 0000000..ea46eca
--- /dev/null
+++ b/src/silx/io/test/test_h5py_utils.py
@@ -0,0 +1,451 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for h5py utilities"""
+
+__authors__ = ["W. de Nolf"]
+__license__ = "MIT"
+__date__ = "27/01/2020"
+
+
+import unittest
+import os
+import sys
+import time
+import shutil
+import logging
+import tempfile
+import multiprocessing
+from contextlib import contextmanager
+
+from .. import h5py_utils
+from ...utils.retry import RetryError, RetryTimeoutError
+
+IS_WINDOWS = sys.platform == "win32"
+logger = logging.getLogger()
+
+
+def _subprocess_context_main(queue, contextmgr, *args, **kw):
+ try:
+ with contextmgr(*args, **kw):
+ queue.put(None)
+ queue.get()
+ except Exception:
+ queue.put(None)
+ raise
+
+
+@contextmanager
+def _subprocess_context(contextmgr, *args, **kw):
+ print("\nSTART", os.getpid())
+ timeout = kw.pop("timeout", 10)
+ queue = multiprocessing.Queue(maxsize=1)
+ p = multiprocessing.Process(
+ target=_subprocess_context_main, args=(queue, contextmgr) + args, kwargs=kw
+ )
+ p.start()
+ try:
+ queue.get(timeout=timeout)
+ yield
+ finally:
+ queue.put(None)
+ p.join(timeout)
+ print(" EXIT", os.getpid())
+
+
+@contextmanager
+def _open_context(filename, **kw):
+ try:
+ print(os.getpid(), "OPEN", filename, kw)
+ with h5py_utils.File(filename, **kw) as f:
+ if kw.get("mode") == "w":
+ f["check"] = True
+ f.flush()
+ yield f
+ except Exception:
+ print(" ", os.getpid(), "FAILED", filename, kw)
+ raise
+ else:
+ print(" ", os.getpid(), "CLOSED", filename, kw)
+
+
+def _cause_segfault():
+ import ctypes
+
+ i = ctypes.c_char(b"a")
+ j = ctypes.pointer(i)
+ c = 0
+ while True:
+ j[c] = b"a"
+ c += 1
+
+
+def _top_level_names_test(txtfilename, *args, **kw):
+ sys.stderr = open(os.devnull, "w")
+
+ with open(txtfilename, mode="r") as f:
+ failcounter = int(f.readline().strip())
+
+ ncausefailure = kw.pop("ncausefailure")
+ faildelay = kw.pop("faildelay")
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ with open(txtfilename, mode="w") as f:
+ f.write(str(failcounter))
+ if failcounter % 2:
+ raise RetryError
+ else:
+ _cause_segfault()
+ return h5py_utils._top_level_names(*args, **kw)
+
+
+top_level_names_test = h5py_utils.retry_in_subprocess()(_top_level_names_test)
+
+
+def subtests(test):
+ def wrapper(self):
+ for subtest_options in self._subtests():
+ print("\n====SUB TEST===\n")
+ print(f"sub test options: {subtest_options}")
+ with self.subTest(str(subtest_options)):
+ test(self)
+
+ return wrapper
+
+
+class TestH5pyUtils(unittest.TestCase):
+ def setUp(self):
+ self.test_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.test_dir)
+
+ def _subtests(self):
+ self._subtest_options = {"mode": "w"}
+ self.filename_generator = self._filenames()
+ yield self._subtest_options
+ self._subtest_options = {"mode": "w", "libver": "latest"}
+ self.filename_generator = self._filenames()
+ yield
+
+ def _filenames(self):
+ i = 1
+ while True:
+ filename = os.path.join(self.test_dir, "file{}.h5".format(i))
+ with self._open_context(filename):
+ pass
+ yield filename
+ i += 1
+
+ def _new_filename(self):
+ return next(self.filename_generator)
+
+ @contextmanager
+ def _open_context(self, filename, **kwargs):
+ kw = dict(self._subtest_options)
+ kw.update(kwargs)
+ with _open_context(filename, **kw) as f:
+ yield f
+
+ @contextmanager
+ def _open_context_subprocess(self, filename, **kwargs):
+ kw = dict(self._subtest_options)
+ kw.update(kwargs)
+ with _subprocess_context(_open_context, filename, **kw):
+ yield
+
+ def _assert_hdf5_data(self, f):
+ self.assertTrue(f["check"][()])
+
+ def _validate_hdf5_data(self, filename, swmr=False):
+ with self._open_context(filename, mode="r") as f:
+ self.assertEqual(f.swmr_mode, swmr)
+ self._assert_hdf5_data(f)
+
+ @subtests
+ def test_modes_single_process(self):
+ """Test concurrent access to the different files from the same process"""
+ # When using HDF5_USE_FILE_LOCKING, open files with and without
+ # locking should raise an exception. HDF5_USE_FILE_LOCKING should
+ # be reset when all files are closed.
+
+ orig = os.environ.get("HDF5_USE_FILE_LOCKING")
+ filename1 = self._new_filename()
+ self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
+ filename2 = self._new_filename()
+ self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
+
+ with self._open_context(filename1, mode="r"):
+ locking1 = False
+ for mode in ["r", "w", "a"]:
+ locking2 = mode != "r"
+ raise_condition = not h5py_utils.HAS_LOCKING_ARGUMENT
+ raise_condition &= locking1 != locking2
+ with self.assertRaisesIf(raise_condition, RuntimeError):
+ with self._open_context(filename2, mode=mode):
+ pass
+ self._validate_hdf5_data(filename1)
+ self._validate_hdf5_data(filename2)
+ self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
+
+ with self._open_context(filename1, mode="a"):
+ locking1 = True
+ for mode in ["r", "w", "a"]:
+ locking2 = mode != "r"
+ raise_condition = not h5py_utils.HAS_LOCKING_ARGUMENT
+ raise_condition &= locking1 != locking2
+ with self.assertRaisesIf(raise_condition, RuntimeError):
+ with self._open_context(filename2, mode=mode):
+ pass
+ self._validate_hdf5_data(filename1)
+ self._validate_hdf5_data(filename2)
+ self.assertEqual(orig, os.environ.get("HDF5_USE_FILE_LOCKING"))
+
+ @property
+ def _libver_low_bound_is_v108(self):
+ libver = self._subtest_options.get("libver")
+ return h5py_utils._libver_low_bound_is_v108(libver)
+
+ @property
+ def _nonlocking_reader_before_writer(self):
+ """A non-locking reader must open the file before it is locked by a writer"""
+ if IS_WINDOWS and h5py_utils.HDF5_HAS_LOCKING_ARGUMENT:
+ return True
+ if not self._libver_low_bound_is_v108:
+ return True
+ return False
+
+ @contextmanager
+ def assertRaisesIf(self, condition, *args, **kw):
+ if condition:
+ with self.assertRaises(*args, **kw):
+ yield
+ else:
+ yield
+
+ @unittest.skipIf(
+ h5py_utils.HDF5_HAS_LOCKING_ARGUMENT != h5py_utils.H5PY_HAS_LOCKING_ARGUMENT,
+ "Versions of libhdf5 and h5py use incompatible file locking behaviour",
+ )
+ @subtests
+ def test_modes_multi_process(self):
+ """Test concurrent access to the same file from different processes"""
+ filename = self._new_filename()
+
+ nonlocking_reader_before_writer = self._nonlocking_reader_before_writer
+ writer_before_nonlocking_reader_exception = OSError
+ old_hdf5_on_windows = IS_WINDOWS and not h5py_utils.HDF5_HAS_LOCKING_ARGUMENT
+ locked_exception = OSError
+
+ # File locked by a writer
+ unexpected_access = old_hdf5_on_windows and self._libver_low_bound_is_v108
+ for wmode in ["w", "a"]:
+ with self._open_context_subprocess(filename, mode=wmode):
+ # Access by a second non-locking reader
+ with self.assertRaisesIf(
+ nonlocking_reader_before_writer,
+ writer_before_nonlocking_reader_exception,
+ ):
+ with self._open_context(filename, mode="r") as f:
+ self._assert_hdf5_data(f)
+ # No access by a second locking reader
+ if unexpected_access:
+ logger.warning("unexpected concurrent access by a locking reader")
+ with self.assertRaisesIf(not unexpected_access, locked_exception):
+ with self._open_context(filename, mode="r", locking=True) as f:
+ self._assert_hdf5_data(f)
+ # No access by a second writer
+ if unexpected_access:
+ logger.warning("unexpected concurrent access by a writer")
+ with self.assertRaisesIf(not unexpected_access, locked_exception):
+ with self._open_context(filename, mode="a") as f:
+ self._assert_hdf5_data(f)
+ # Check for file corruption
+ if not nonlocking_reader_before_writer:
+ self._validate_hdf5_data(filename)
+ self._validate_hdf5_data(filename)
+
+ # File locked by a reader
+ unexpected_access = old_hdf5_on_windows
+ with _subprocess_context(_open_context, filename, mode="r", locking=True):
+ # Access by a non-locking reader
+ with self._open_context(filename, mode="r") as f:
+ self._assert_hdf5_data(f)
+ # Access by a locking reader
+ with self._open_context(filename, mode="r", locking=True) as f:
+ self._assert_hdf5_data(f)
+ # No access by a second writer
+ if unexpected_access:
+ logger.warning("unexpected concurrent access by a writer")
+ raise_condition = not unexpected_access
+ with self.assertRaisesIf(raise_condition, locked_exception):
+ with self._open_context(filename, mode="a") as f:
+ self._assert_hdf5_data(f)
+ # Check for file corruption
+ self._validate_hdf5_data(filename)
+ self._validate_hdf5_data(filename)
+
+ # File open by a non-locking reader
+ with self._open_context_subprocess(filename, mode="r"):
+ # Access by a second non-locking reader
+ with self._open_context(filename, mode="r") as f:
+ self._assert_hdf5_data(f)
+ # Access by a second locking reader
+ with self._open_context(filename, mode="r", locking=True) as f:
+ self._assert_hdf5_data(f)
+ # Access by a second writer
+ with self._open_context(filename, mode="a") as f:
+ self._assert_hdf5_data(f)
+ # Check for file corruption
+ self._validate_hdf5_data(filename)
+ self._validate_hdf5_data(filename)
+
+ @subtests
+ @unittest.skipIf(not h5py_utils.HAS_SWMR, "SWMR not supported")
+ def test_modes_multi_process_swmr(self):
+ filename = self._new_filename()
+
+ with self._open_context(filename, mode="w", libver="latest") as f:
+ pass
+
+ # File open by SWMR writer
+ with self._open_context_subprocess(filename, mode="a", swmr=True):
+ with self._open_context(filename, mode="r") as f:
+ assert f.swmr_mode
+ self._assert_hdf5_data(f)
+ with self.assertRaises(OSError):
+ with self._open_context(filename, mode="a") as f:
+ pass
+ self._validate_hdf5_data(filename, swmr=True)
+
+ @subtests
+ def test_retry_defaults(self):
+ filename = self._new_filename()
+
+ names = h5py_utils.top_level_names(filename)
+ self.assertEqual(names, [])
+
+ names = h5py_utils.safe_top_level_names(filename)
+ self.assertEqual(names, [])
+
+ names = h5py_utils.top_level_names(filename, include_only=None)
+ self.assertEqual(names, ["check"])
+
+ names = h5py_utils.safe_top_level_names(filename, include_only=None)
+ self.assertEqual(names, ["check"])
+
+ with h5py_utils.open_item(filename, "/check", validate=lambda x: False) as item:
+ self.assertEqual(item, None)
+
+ with h5py_utils.open_item(filename, "/check", validate=None) as item:
+ self.assertTrue(item[()])
+
+ with self.assertRaises(RetryTimeoutError):
+ with h5py_utils.open_item(
+ filename,
+ "/check",
+ retry_timeout=0.1,
+ retry_invalid=True,
+ validate=lambda x: False,
+ ) as item:
+ pass
+
+ ncall = 0
+
+ def validate(item):
+ nonlocal ncall
+ if ncall >= 1:
+ return True
+ else:
+ ncall += 1
+ raise RetryError
+
+ with h5py_utils.open_item(
+ filename,
+ "/check",
+ validate=validate,
+ retry_timeout=1,
+ retry_invalid=True,
+ ) as item:
+ self.assertTrue(item[()])
+
+ @subtests
+ def test_retry_custom(self):
+ filename = self._new_filename()
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ @h5py_utils.retry_contextmanager()
+ def open_item(filename, name):
+ nonlocal failcounter
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ raise RetryError
+ with h5py_utils.File(filename) as h5file:
+ yield h5file[name]
+
+ failcounter = 0
+ kw = {"retry_timeout": sufficient_timeout}
+ with open_item(filename, "/check", **kw) as item:
+ self.assertTrue(item[()])
+
+ failcounter = 0
+ kw = {"retry_timeout": insufficient_timeout}
+ with self.assertRaises(RetryTimeoutError):
+ with open_item(filename, "/check", **kw) as item:
+ pass
+
+ @subtests
+ def test_retry_in_subprocess(self):
+ filename = self._new_filename()
+ txtfilename = os.path.join(self.test_dir, "failcounter.txt")
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ kw = {
+ "retry_timeout": sufficient_timeout,
+ "include_only": None,
+ "ncausefailure": ncausefailure,
+ "faildelay": faildelay,
+ }
+ with open(txtfilename, mode="w") as f:
+ f.write("0")
+ names = top_level_names_test(txtfilename, filename, **kw)
+ self.assertEqual(names, ["check"])
+
+ kw = {
+ "retry_timeout": insufficient_timeout,
+ "include_only": None,
+ "ncausefailure": ncausefailure,
+ "faildelay": faildelay,
+ }
+ with open(txtfilename, mode="w") as f:
+ f.write("0")
+ with self.assertRaises(RetryTimeoutError):
+ top_level_names_test(txtfilename, filename, **kw)
diff --git a/src/silx/io/test/test_nxdata.py b/src/silx/io/test/test_nxdata.py
new file mode 100644
index 0000000..9025d6d
--- /dev/null
+++ b/src/silx/io/test/test_nxdata.py
@@ -0,0 +1,563 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for NXdata parsing"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "24/03/2020"
+
+
+import tempfile
+import unittest
+import h5py
+import numpy
+
+from .. import nxdata
+
+
+text_dtype = h5py.special_dtype(vlen=str)
+
+
+class TestNXdata(unittest.TestCase):
+ def setUp(self):
+ tmp = tempfile.NamedTemporaryFile(prefix="nxdata_examples_", suffix=".h5", delete=True)
+ tmp.file.close()
+ self.h5fname = tmp.name
+ self.h5f = h5py.File(tmp.name, "w")
+
+ # SCALARS
+ g0d = self.h5f.create_group("scalars")
+
+ g0d0 = g0d.create_group("0D_scalar")
+ g0d0.attrs["NX_class"] = "NXdata"
+ g0d0.attrs["signal"] = "scalar"
+ g0d0.create_dataset("scalar", data=10)
+ g0d0.create_dataset("scalar_errors", data=0.1)
+
+ g0d1 = g0d.create_group("2D_scalars")
+ g0d1.attrs["NX_class"] = "NXdata"
+ g0d1.attrs["signal"] = "scalars"
+ ds = g0d1.create_dataset("scalars", data=numpy.arange(3 * 10).reshape((3, 10)))
+ ds.attrs["interpretation"] = "scalar"
+
+ g0d1 = g0d.create_group("4D_scalars")
+ g0d1.attrs["NX_class"] = "NXdata"
+ g0d1.attrs["signal"] = "scalars"
+ ds = g0d1.create_dataset("scalars", data=numpy.arange(2 * 2 * 3 * 10).reshape((2, 2, 3, 10)))
+ ds.attrs["interpretation"] = "scalar"
+
+ # SPECTRA
+ g1d = self.h5f.create_group("spectra")
+
+ g1d0 = g1d.create_group("1D_spectrum")
+ g1d0.attrs["NX_class"] = "NXdata"
+ g1d0.attrs["signal"] = "count"
+ g1d0.attrs["auxiliary_signals"] = numpy.array(["count2", "count3"],
+ dtype=text_dtype)
+ g1d0.attrs["axes"] = "energy_calib"
+ g1d0.attrs["uncertainties"] = numpy.array(["energy_errors", ],
+ dtype=text_dtype)
+ g1d0.create_dataset("count", data=numpy.arange(10))
+ g1d0.create_dataset("count2", data=0.5 * numpy.arange(10))
+ d = g1d0.create_dataset("count3", data=0.4 * numpy.arange(10))
+ d.attrs["long_name"] = "3rd counter"
+ g1d0.create_dataset("title", data="Title as dataset (like nexpy)")
+ g1d0.create_dataset("energy_calib", data=(10, 5)) # 10 * idx + 5
+ g1d0.create_dataset("energy_errors", data=3.14 * numpy.random.rand(10))
+
+ g1d1 = g1d.create_group("2D_spectra")
+ g1d1.attrs["NX_class"] = "NXdata"
+ g1d1.attrs["signal"] = "counts"
+ ds = g1d1.create_dataset("counts", data=numpy.arange(3 * 10).reshape((3, 10)))
+ ds.attrs["interpretation"] = "spectrum"
+
+ g1d2 = g1d.create_group("4D_spectra")
+ g1d2.attrs["NX_class"] = "NXdata"
+ g1d2.attrs["signal"] = "counts"
+ g1d2.attrs["axes"] = numpy.array(["energy", ], dtype=text_dtype)
+ ds = g1d2.create_dataset("counts", data=numpy.arange(2 * 2 * 3 * 10).reshape((2, 2, 3, 10)))
+ ds.attrs["interpretation"] = "spectrum"
+ ds = g1d2.create_dataset("errors", data=4.5 * numpy.random.rand(2, 2, 3, 10))
+ ds = g1d2.create_dataset("energy", data=5 + 10 * numpy.arange(15),
+ shuffle=True, compression="gzip")
+ ds.attrs["long_name"] = "Calibrated energy"
+ ds.attrs["first_good"] = 3
+ ds.attrs["last_good"] = 12
+ g1d2.create_dataset("energy_errors", data=10 * numpy.random.rand(15))
+
+ # IMAGES
+ g2d = self.h5f.create_group("images")
+
+ g2d0 = g2d.create_group("2D_regular_image")
+ g2d0.attrs["NX_class"] = "NXdata"
+ g2d0.attrs["signal"] = "image"
+ g2d0.attrs["auxiliary_signals"] = "image2"
+ g2d0.attrs["axes"] = numpy.array(["rows_calib", "columns_coordinates"],
+ dtype=text_dtype)
+ g2d0.create_dataset("image", data=numpy.arange(4 * 6).reshape((4, 6)))
+ g2d0.create_dataset("image2", data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds = g2d0.create_dataset("rows_calib", data=(10, 5))
+ ds.attrs["long_name"] = "Calibrated Y"
+ g2d0.create_dataset("columns_coordinates", data=0.5 + 0.02 * numpy.arange(6))
+
+ g2d1 = g2d.create_group("2D_irregular_data")
+ g2d1.attrs["NX_class"] = "NXdata"
+ g2d1.attrs["signal"] = "data"
+ g2d1.attrs["title"] = "Title as group attr"
+ g2d1.attrs["axes"] = numpy.array(["rows_coordinates", "columns_coordinates"],
+ dtype=text_dtype)
+ g2d1.create_dataset("data", data=numpy.arange(64 * 128).reshape((64, 128)))
+ g2d1.create_dataset("rows_coordinates", data=numpy.arange(64) + numpy.random.rand(64))
+ g2d1.create_dataset("columns_coordinates", data=numpy.arange(128) + 2.5 * numpy.random.rand(128))
+
+ g2d2 = g2d.create_group("3D_images")
+ g2d2.attrs["NX_class"] = "NXdata"
+ g2d2.attrs["signal"] = "images"
+ ds = g2d2.create_dataset("images", data=numpy.arange(2 * 4 * 6).reshape((2, 4, 6)))
+ ds.attrs["interpretation"] = "image"
+
+ g2d3 = g2d.create_group("5D_images")
+ g2d3.attrs["NX_class"] = "NXdata"
+ g2d3.attrs["signal"] = "images"
+ g2d3.attrs["axes"] = numpy.array(["rows_coordinates", "columns_coordinates"],
+ dtype=text_dtype)
+ ds = g2d3.create_dataset("images", data=numpy.arange(2 * 2 * 2 * 4 * 6).reshape((2, 2, 2, 4, 6)))
+ ds.attrs["interpretation"] = "image"
+ g2d3.create_dataset("rows_coordinates", data=5 + 10 * numpy.arange(4))
+ g2d3.create_dataset("columns_coordinates", data=0.5 + 0.02 * numpy.arange(6))
+
+ g2d4 = g2d.create_group("RGBA_image")
+ g2d4.attrs["NX_class"] = "NXdata"
+ g2d4.attrs["signal"] = "image"
+ g2d4.attrs["axes"] = numpy.array(["rows_calib", "columns_coordinates"],
+ dtype=text_dtype)
+ rgba_image = numpy.linspace(0, 1, num=7*8*3).reshape((7, 8, 3))
+ rgba_image[:, :, 1] = 1 - rgba_image[:, :, 1] # invert G channel to add some color
+ ds = g2d4.create_dataset("image", data=rgba_image)
+ ds.attrs["interpretation"] = "rgba-image"
+ ds = g2d4.create_dataset("rows_calib", data=(10, 5))
+ ds.attrs["long_name"] = "Calibrated Y"
+ g2d4.create_dataset("columns_coordinates", data=0.5+0.02*numpy.arange(8))
+
+ # SCATTER
+ g = self.h5f.create_group("scatters")
+
+ gd0 = g.create_group("x_y_scatter")
+ gd0.attrs["NX_class"] = "NXdata"
+ gd0.attrs["signal"] = "y"
+ gd0.attrs["axes"] = numpy.array(["x", ], dtype=text_dtype)
+ gd0.create_dataset("y", data=numpy.random.rand(128) - 0.5)
+ gd0.create_dataset("x", data=2 * numpy.random.rand(128))
+ gd0.create_dataset("x_errors", data=0.05 * numpy.random.rand(128))
+ gd0.create_dataset("errors", data=0.05 * numpy.random.rand(128))
+
+ gd1 = g.create_group("x_y_value_scatter")
+ gd1.attrs["NX_class"] = "NXdata"
+ gd1.attrs["signal"] = "values"
+ gd1.attrs["axes"] = numpy.array(["x", "y"], dtype=text_dtype)
+ gd1.create_dataset("values", data=3.14 * numpy.random.rand(128))
+ gd1.create_dataset("y", data=numpy.random.rand(128))
+ gd1.create_dataset("y_errors", data=0.02 * numpy.random.rand(128))
+ gd1.create_dataset("x", data=numpy.random.rand(128))
+ gd1.create_dataset("x_errors", data=0.02 * numpy.random.rand(128))
+
+ def tearDown(self):
+ self.h5f.close()
+
+ def testValidity(self):
+ for group in self.h5f:
+ for subgroup in self.h5f[group]:
+ self.assertTrue(
+ nxdata.is_valid_nxdata(self.h5f[group][subgroup]),
+ "%s/%s not found to be a valid NXdata group" % (group, subgroup))
+
+ def testScalars(self):
+ nxd = nxdata.NXdata(self.h5f["scalars/0D_scalar"])
+ self.assertTrue(nxd.signal_is_0d)
+ self.assertEqual(nxd.signal[()], 10)
+ self.assertEqual(nxd.axes_names, [])
+ self.assertEqual(nxd.axes_dataset_names, [])
+ self.assertEqual(nxd.axes, [])
+ self.assertIsNotNone(nxd.errors)
+ self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
+ self.assertIsNone(nxd.interpretation)
+
+ nxd = nxdata.NXdata(self.h5f["scalars/2D_scalars"])
+ self.assertTrue(nxd.signal_is_2d)
+ self.assertEqual(nxd.signal[1, 2], 12)
+ self.assertEqual(nxd.axes_names, [None, None])
+ self.assertEqual(nxd.axes_dataset_names, [None, None])
+ self.assertEqual(nxd.axes, [None, None])
+ self.assertIsNone(nxd.errors)
+ self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
+ self.assertEqual(nxd.interpretation, "scalar")
+
+ nxd = nxdata.NXdata(self.h5f["scalars/4D_scalars"])
+ self.assertFalse(nxd.signal_is_0d or nxd.signal_is_1d or
+ nxd.signal_is_2d or nxd.signal_is_3d)
+ self.assertEqual(nxd.signal[1, 0, 1, 4], 74)
+ self.assertEqual(nxd.axes_names, [None, None, None, None])
+ self.assertEqual(nxd.axes_dataset_names, [None, None, None, None])
+ self.assertEqual(nxd.axes, [None, None, None, None])
+ self.assertIsNone(nxd.errors)
+ self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
+ self.assertEqual(nxd.interpretation, "scalar")
+
+ def testSpectra(self):
+ nxd = nxdata.NXdata(self.h5f["spectra/1D_spectrum"])
+ self.assertTrue(nxd.signal_is_1d)
+ self.assertTrue(nxd.is_curve)
+ self.assertTrue(numpy.array_equal(numpy.array(nxd.signal),
+ numpy.arange(10)))
+ self.assertEqual(nxd.axes_names, ["energy_calib"])
+ self.assertEqual(nxd.axes_dataset_names, ["energy_calib"])
+ self.assertEqual(nxd.axes[0][0], 10)
+ self.assertEqual(nxd.axes[0][1], 5)
+ self.assertIsNone(nxd.errors)
+ self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
+ self.assertIsNone(nxd.interpretation)
+ self.assertEqual(nxd.title, "Title as dataset (like nexpy)")
+
+ self.assertEqual(nxd.auxiliary_signals_dataset_names,
+ ["count2", "count3"])
+ self.assertEqual(nxd.auxiliary_signals_names,
+ ["count2", "3rd counter"])
+ self.assertAlmostEqual(nxd.auxiliary_signals[1][2],
+ 0.8) # numpy.arange(10) * 0.4
+
+ nxd = nxdata.NXdata(self.h5f["spectra/2D_spectra"])
+ self.assertTrue(nxd.signal_is_2d)
+ self.assertTrue(nxd.is_curve)
+ self.assertEqual(nxd.axes_names, [None, None])
+ self.assertEqual(nxd.axes_dataset_names, [None, None])
+ self.assertEqual(nxd.axes, [None, None])
+ self.assertIsNone(nxd.errors)
+ self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
+ self.assertEqual(nxd.interpretation, "spectrum")
+
+ nxd = nxdata.NXdata(self.h5f["spectra/4D_spectra"])
+ self.assertFalse(nxd.signal_is_0d or nxd.signal_is_1d or
+ nxd.signal_is_2d or nxd.signal_is_3d)
+ self.assertTrue(nxd.is_curve)
+ self.assertEqual(nxd.axes_names,
+ [None, None, None, "Calibrated energy"])
+ self.assertEqual(nxd.axes_dataset_names,
+ [None, None, None, "energy"])
+ self.assertEqual(nxd.axes[:3], [None, None, None])
+ self.assertEqual(nxd.axes[3].shape, (10, )) # dataset shape (15, ) sliced [3:12]
+ self.assertIsNotNone(nxd.errors)
+ self.assertEqual(nxd.errors.shape, (2, 2, 3, 10))
+ self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
+ self.assertEqual(nxd.interpretation, "spectrum")
+ self.assertEqual(nxd.get_axis_errors("energy").shape,
+ (10,))
+ # test getting axis errors by long_name
+ self.assertTrue(numpy.array_equal(nxd.get_axis_errors("Calibrated energy"),
+ nxd.get_axis_errors("energy")))
+ self.assertTrue(numpy.array_equal(nxd.get_axis_errors(b"Calibrated energy"),
+ nxd.get_axis_errors("energy")))
+
+ def testImages(self):
+ nxd = nxdata.NXdata(self.h5f["images/2D_regular_image"])
+ self.assertTrue(nxd.signal_is_2d)
+ self.assertTrue(nxd.is_image)
+ self.assertEqual(nxd.axes_names, ["Calibrated Y", "columns_coordinates"])
+ self.assertEqual(list(nxd.axes_dataset_names),
+ ["rows_calib", "columns_coordinates"])
+ self.assertIsNone(nxd.errors)
+ self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
+ self.assertIsNone(nxd.interpretation)
+ self.assertEqual(len(nxd.auxiliary_signals), 1)
+ self.assertEqual(nxd.auxiliary_signals_names, ["image2"])
+
+ nxd = nxdata.NXdata(self.h5f["images/2D_irregular_data"])
+ self.assertTrue(nxd.signal_is_2d)
+ self.assertTrue(nxd.is_image)
+
+ self.assertEqual(nxd.axes_dataset_names, nxd.axes_names)
+ self.assertEqual(list(nxd.axes_dataset_names),
+ ["rows_coordinates", "columns_coordinates"])
+ self.assertEqual(len(nxd.axes), 2)
+ self.assertIsNone(nxd.errors)
+ self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
+ self.assertIsNone(nxd.interpretation)
+ self.assertEqual(nxd.title, "Title as group attr")
+
+ nxd = nxdata.NXdata(self.h5f["images/5D_images"])
+ self.assertTrue(nxd.is_image)
+ self.assertFalse(nxd.signal_is_0d or nxd.signal_is_1d or
+ nxd.signal_is_2d or nxd.signal_is_3d)
+ self.assertEqual(nxd.axes_names,
+ [None, None, None, 'rows_coordinates', 'columns_coordinates'])
+ self.assertEqual(nxd.axes_dataset_names,
+ [None, None, None, 'rows_coordinates', 'columns_coordinates'])
+ self.assertIsNone(nxd.errors)
+ self.assertFalse(nxd.is_scatter or nxd.is_x_y_value_scatter)
+ self.assertEqual(nxd.interpretation, "image")
+
+ nxd = nxdata.NXdata(self.h5f["images/RGBA_image"])
+ self.assertTrue(nxd.is_image)
+ self.assertEqual(nxd.interpretation, "rgba-image")
+ self.assertTrue(nxd.signal_is_3d)
+ self.assertEqual(nxd.axes_names, ["Calibrated Y",
+ "columns_coordinates",
+ None])
+ self.assertEqual(list(nxd.axes_dataset_names),
+ ["rows_calib", "columns_coordinates", None])
+
+ def testScatters(self):
+ nxd = nxdata.NXdata(self.h5f["scatters/x_y_scatter"])
+ self.assertTrue(nxd.signal_is_1d)
+ self.assertEqual(nxd.axes_names, ["x"])
+ self.assertEqual(nxd.axes_dataset_names,
+ ["x"])
+ self.assertIsNotNone(nxd.errors)
+ self.assertEqual(nxd.get_axis_errors("x").shape,
+ (128, ))
+ self.assertTrue(nxd.is_scatter)
+ self.assertFalse(nxd.is_x_y_value_scatter)
+ self.assertIsNone(nxd.interpretation)
+
+ nxd = nxdata.NXdata(self.h5f["scatters/x_y_value_scatter"])
+ self.assertFalse(nxd.signal_is_1d)
+ self.assertTrue(nxd.axes_dataset_names,
+ nxd.axes_names)
+ self.assertEqual(nxd.axes_dataset_names,
+ ["x", "y"])
+ self.assertEqual(nxd.get_axis_errors("x").shape,
+ (128, ))
+ self.assertEqual(nxd.get_axis_errors("y").shape,
+ (128, ))
+ self.assertEqual(len(nxd.axes), 2)
+ self.assertIsNone(nxd.errors)
+ self.assertTrue(nxd.is_scatter)
+ self.assertTrue(nxd.is_x_y_value_scatter)
+ self.assertIsNone(nxd.interpretation)
+
+
+class TestLegacyNXdata(unittest.TestCase):
+ def setUp(self):
+ tmp = tempfile.NamedTemporaryFile(prefix="nxdata_legacy_examples_",
+ suffix=".h5", delete=True)
+ tmp.file.close()
+ self.h5fname = tmp.name
+ self.h5f = h5py.File(tmp.name, "w")
+
+ def tearDown(self):
+ self.h5f.close()
+
+ def testSignalAttrOnDataset(self):
+ g = self.h5f.create_group("2D")
+ g.attrs["NX_class"] = "NXdata"
+
+ ds0 = g.create_dataset("image0",
+ data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds0.attrs["signal"] = 1
+ ds0.attrs["long_name"] = "My first image"
+
+ ds1 = g.create_dataset("image1",
+ data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds1.attrs["signal"] = "2"
+ ds1.attrs["long_name"] = "My 2nd image"
+
+ ds2 = g.create_dataset("image2",
+ data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds2.attrs["signal"] = 3
+
+ nxd = nxdata.NXdata(self.h5f["2D"])
+
+ self.assertEqual(nxd.signal_dataset_name, "image0")
+ self.assertEqual(nxd.signal_name, "My first image")
+ self.assertEqual(nxd.signal.shape,
+ (4, 6))
+
+ self.assertEqual(len(nxd.auxiliary_signals), 2)
+ self.assertEqual(nxd.auxiliary_signals[1].shape,
+ (4, 6))
+
+ self.assertEqual(nxd.auxiliary_signals_dataset_names,
+ ["image1", "image2"])
+ self.assertEqual(nxd.auxiliary_signals_names,
+ ["My 2nd image", "image2"])
+
+ def testAxesOnSignalDataset(self):
+ g = self.h5f.create_group("2D")
+ g.attrs["NX_class"] = "NXdata"
+
+ ds0 = g.create_dataset("image0",
+ data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds0.attrs["signal"] = 1
+ ds0.attrs["axes"] = "yaxis:xaxis"
+
+ ds1 = g.create_dataset("yaxis",
+ data=numpy.arange(4))
+ ds2 = g.create_dataset("xaxis",
+ data=numpy.arange(6))
+
+ nxd = nxdata.NXdata(self.h5f["2D"])
+
+ self.assertEqual(nxd.axes_dataset_names,
+ ["yaxis", "xaxis"])
+ self.assertTrue(numpy.array_equal(nxd.axes[0],
+ numpy.arange(4)))
+ self.assertTrue(numpy.array_equal(nxd.axes[1],
+ numpy.arange(6)))
+
+ def testAxesOnAxesDatasets(self):
+ g = self.h5f.create_group("2D")
+ g.attrs["NX_class"] = "NXdata"
+
+ ds0 = g.create_dataset("image0",
+ data=numpy.arange(4 * 6).reshape((4, 6)))
+ ds0.attrs["signal"] = 1
+ ds1 = g.create_dataset("yaxis",
+ data=numpy.arange(4))
+ ds1.attrs["axis"] = 0
+ ds2 = g.create_dataset("xaxis",
+ data=numpy.arange(6))
+ ds2.attrs["axis"] = "1"
+
+ nxd = nxdata.NXdata(self.h5f["2D"])
+ self.assertEqual(nxd.axes_dataset_names,
+ ["yaxis", "xaxis"])
+ self.assertTrue(numpy.array_equal(nxd.axes[0],
+ numpy.arange(4)))
+ self.assertTrue(numpy.array_equal(nxd.axes[1],
+ numpy.arange(6)))
+
+ def testAsciiUndefinedAxesAttrs(self):
+ """Some files may not be using utf8 for str attrs"""
+ g = self.h5f.create_group("bytes_attrs")
+ g.attrs["NX_class"] = b"NXdata"
+ g.attrs["signal"] = b"image0"
+ g.attrs["axes"] = b"yaxis", b"."
+
+ g.create_dataset("image0",
+ data=numpy.arange(4 * 6).reshape((4, 6)))
+ g.create_dataset("yaxis",
+ data=numpy.arange(4))
+
+ nxd = nxdata.NXdata(self.h5f["bytes_attrs"])
+ self.assertEqual(nxd.axes_dataset_names,
+ ["yaxis", None])
+
+
+class TestSaveNXdata(unittest.TestCase):
+ def setUp(self):
+ tmp = tempfile.NamedTemporaryFile(prefix="nxdata",
+ suffix=".h5", delete=True)
+ tmp.file.close()
+ self.h5fname = tmp.name
+
+ def testSimpleSave(self):
+ sig = numpy.array([0, 1, 2])
+ a0 = numpy.array([2, 3, 4])
+ a1 = numpy.array([3, 4, 5])
+ nxdata.save_NXdata(filename=self.h5fname,
+ signal=sig,
+ axes=[a0, a1],
+ signal_name="sig",
+ axes_names=["a0", "a1"],
+ nxentry_name="a",
+ nxdata_name="mydata")
+
+ h5f = h5py.File(self.h5fname, "r")
+ self.assertTrue(nxdata.is_valid_nxdata(h5f["a/mydata"]))
+
+ nxd = nxdata.NXdata(h5f["/a/mydata"])
+ self.assertTrue(numpy.array_equal(nxd.signal,
+ sig))
+ self.assertTrue(numpy.array_equal(nxd.axes[0],
+ a0))
+
+ h5f.close()
+
+ def testSimplestSave(self):
+ sig = numpy.array([0, 1, 2])
+ nxdata.save_NXdata(filename=self.h5fname,
+ signal=sig)
+
+ h5f = h5py.File(self.h5fname, "r")
+
+ self.assertTrue(nxdata.is_valid_nxdata(h5f["/entry/data0"]))
+
+ nxd = nxdata.NXdata(h5f["/entry/data0"])
+ self.assertTrue(numpy.array_equal(nxd.signal,
+ sig))
+ h5f.close()
+
+ def testSaveDefaultAxesNames(self):
+ sig = numpy.array([0, 1, 2])
+ a0 = numpy.array([2, 3, 4])
+ a1 = numpy.array([3, 4, 5])
+ nxdata.save_NXdata(filename=self.h5fname,
+ signal=sig,
+ axes=[a0, a1],
+ signal_name="sig",
+ axes_names=None,
+ axes_long_names=["a", "b"],
+ nxentry_name="a",
+ nxdata_name="mydata")
+
+ h5f = h5py.File(self.h5fname, "r")
+ self.assertTrue(nxdata.is_valid_nxdata(h5f["a/mydata"]))
+
+ nxd = nxdata.NXdata(h5f["/a/mydata"])
+ self.assertTrue(numpy.array_equal(nxd.signal,
+ sig))
+ self.assertTrue(numpy.array_equal(nxd.axes[0],
+ a0))
+ self.assertEqual(nxd.axes_dataset_names,
+ [u"dim0", u"dim1"])
+ self.assertEqual(nxd.axes_names,
+ [u"a", u"b"])
+
+ h5f.close()
+
+ def testSaveToExistingEntry(self):
+ h5f = h5py.File(self.h5fname, "w")
+ g = h5f.create_group("myentry")
+ g.attrs["NX_class"] = "NXentry"
+ h5f.close()
+
+ sig = numpy.array([0, 1, 2])
+ a0 = numpy.array([2, 3, 4])
+ a1 = numpy.array([3, 4, 5])
+ nxdata.save_NXdata(filename=self.h5fname,
+ signal=sig,
+ axes=[a0, a1],
+ signal_name="sig",
+ axes_names=["a0", "a1"],
+ nxentry_name="myentry",
+ nxdata_name="toto")
+
+ h5f = h5py.File(self.h5fname, "r")
+ self.assertTrue(nxdata.is_valid_nxdata(h5f["myentry/toto"]))
+
+ nxd = nxdata.NXdata(h5f["myentry/toto"])
+ self.assertTrue(numpy.array_equal(nxd.signal,
+ sig))
+ self.assertTrue(numpy.array_equal(nxd.axes[0],
+ a0))
+ h5f.close()
diff --git a/src/silx/io/test/test_octaveh5.py b/src/silx/io/test/test_octaveh5.py
new file mode 100644
index 0000000..1c3b3e0
--- /dev/null
+++ b/src/silx/io/test/test_octaveh5.py
@@ -0,0 +1,156 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+Tests for the octaveh5 module
+"""
+
+__authors__ = ["C. Nemoz", "H. Payno"]
+__license__ = "MIT"
+__date__ = "12/07/2016"
+
+import unittest
+import os
+import tempfile
+
+try:
+ from ..octaveh5 import Octaveh5
+except ImportError:
+ Octaveh5 = None
+
+
+@unittest.skipIf(Octaveh5 is None, "Could not import h5py")
+class TestOctaveH5(unittest.TestCase):
+ @staticmethod
+ def _get_struct_FT():
+ return {
+ 'NO_CHECK': 0.0, 'SHOWSLICE': 1.0, 'DOTOMO': 1.0, 'DATABASE': 0.0, 'ANGLE_OFFSET': 0.0,
+ 'VOLSELECTION_REMEMBER': 0.0, 'NUM_PART': 4.0, 'VOLOUTFILE': 0.0, 'RINGSCORRECTION': 0.0,
+ 'DO_TEST_SLICE': 1.0, 'ZEROOFFMASK': 1.0, 'VERSION': 'fastomo3 version 2.0',
+ 'CORRECT_SPIKES_THRESHOLD': 0.040000000000000001, 'SHOWPROJ': 0.0, 'HALF_ACQ': 0.0,
+ 'ANGLE_OFFSET_VALUE': 0.0, 'FIXEDSLICE': 'middle', 'VOLSELECT': 'total' }
+ @staticmethod
+ def _get_struct_PYHSTEXE():
+ return {
+ 'EXE': 'PyHST2_2015d', 'VERBOSE': 0.0, 'OFFV': 'PyHST2_2015d', 'TOMO': 0.0,
+ 'VERBOSE_FILE': 'pyhst_out.txt', 'DIR': '/usr/bin/', 'OFFN': 'pyhst2'}
+
+ @staticmethod
+ def _get_struct_FTAXIS():
+ return {
+ 'POSITION_VALUE': 12345.0, 'COR_ERROR': 0.0, 'FILESDURINGSCAN': 0.0, 'PLOTFIGURE': 1.0,
+ 'DIM1': 0.0, 'OVERSAMPLING': 5.0, 'TO_THE_CENTER': 1.0, 'POSITION': 'fixed',
+ 'COR_POSITION': 0.0, 'HA': 0.0 }
+
+ @staticmethod
+ def _get_struct_PAGANIN():
+ return {
+ 'MKEEP_MASK': 0.0, 'UNSHARP_SIGMA': 0.80000000000000004, 'DILATE': 2.0, 'UNSHARP_COEFF': 3.0,
+ 'MEDIANR': 4.0, 'DB': 500.0, 'MKEEP_ABS': 0.0, 'MODE': 0.0, 'THRESHOLD': 0.5,
+ 'MKEEP_BONE': 0.0, 'DB2': 100.0, 'MKEEP_CORR': 0.0, 'MKEEP_SOFT': 0.0 }
+
+ @staticmethod
+ def _get_struct_BEAMGEO():
+ return {'DIST': 55.0, 'SY': 0.0, 'SX': 0.0, 'TYPE': 'p'}
+
+
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.test_3_6_fname = os.path.join(self.tempdir, "silx_tmp_t00_octaveTest_3_6.h5")
+ self.test_3_8_fname = os.path.join(self.tempdir, "silx_tmp_t00_octaveTest_3_8.h5")
+
+ def tearDown(self):
+ if os.path.isfile(self.test_3_6_fname):
+ os.unlink(self.test_3_6_fname)
+ if os.path.isfile(self.test_3_8_fname):
+ os.unlink(self.test_3_8_fname)
+
+ def testWritedIsReaded(self):
+ """
+ Simple test to write and reaf the structure compatible with the octave h5 using structure.
+ This test is for # test for octave version > 3.8
+ """
+ writer = Octaveh5()
+
+ writer.open(self.test_3_8_fname, 'a')
+ # step 1 writing the file
+ writer.write('FT', self._get_struct_FT())
+ writer.write('PYHSTEXE', self._get_struct_PYHSTEXE())
+ writer.write('FTAXIS', self._get_struct_FTAXIS())
+ writer.write('PAGANIN', self._get_struct_PAGANIN())
+ writer.write('BEAMGEO', self._get_struct_BEAMGEO())
+ writer.close()
+
+ # step 2 reading the file
+ reader = Octaveh5().open(self.test_3_8_fname)
+ # 2.1 check FT
+ data_readed = reader.get('FT')
+ self.assertEqual(data_readed, self._get_struct_FT() )
+ # 2.2 check PYHSTEXE
+ data_readed = reader.get('PYHSTEXE')
+ self.assertEqual(data_readed, self._get_struct_PYHSTEXE() )
+ # 2.3 check FTAXIS
+ data_readed = reader.get('FTAXIS')
+ self.assertEqual(data_readed, self._get_struct_FTAXIS() )
+ # 2.4 check PAGANIN
+ data_readed = reader.get('PAGANIN')
+ self.assertEqual(data_readed, self._get_struct_PAGANIN() )
+ # 2.5 check BEAMGEO
+ data_readed = reader.get('BEAMGEO')
+ self.assertEqual(data_readed, self._get_struct_BEAMGEO() )
+ reader.close()
+
+ def testWritedIsReadedOldOctaveVersion(self):
+ """The same test as testWritedIsReaded but for octave version < 3.8
+ """
+ # test for octave version < 3.8
+ writer = Octaveh5(3.6)
+
+ writer.open(self.test_3_6_fname, 'a')
+
+ # step 1 writing the file
+ writer.write('FT', self._get_struct_FT())
+ writer.write('PYHSTEXE', self._get_struct_PYHSTEXE())
+ writer.write('FTAXIS', self._get_struct_FTAXIS())
+ writer.write('PAGANIN', self._get_struct_PAGANIN())
+ writer.write('BEAMGEO', self._get_struct_BEAMGEO())
+ writer.close()
+
+ # step 2 reading the file
+ reader = Octaveh5(3.6).open(self.test_3_6_fname)
+ # 2.1 check FT
+ data_readed = reader.get('FT')
+ self.assertEqual(data_readed, self._get_struct_FT() )
+ # 2.2 check PYHSTEXE
+ data_readed = reader.get('PYHSTEXE')
+ self.assertEqual(data_readed, self._get_struct_PYHSTEXE() )
+ # 2.3 check FTAXIS
+ data_readed = reader.get('FTAXIS')
+ self.assertEqual(data_readed, self._get_struct_FTAXIS() )
+ # 2.4 check PAGANIN
+ data_readed = reader.get('PAGANIN')
+ self.assertEqual(data_readed, self._get_struct_PAGANIN() )
+ # 2.5 check BEAMGEO
+ data_readed = reader.get('BEAMGEO')
+ self.assertEqual(data_readed, self._get_struct_BEAMGEO() )
+ reader.close()
diff --git a/src/silx/io/test/test_rawh5.py b/src/silx/io/test/test_rawh5.py
new file mode 100644
index 0000000..236484d
--- /dev/null
+++ b/src/silx/io/test/test_rawh5.py
@@ -0,0 +1,85 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for silx.gui.hdf5 module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "21/09/2017"
+
+
+import unittest
+import tempfile
+import numpy
+import shutil
+from ..import rawh5
+
+
+class TestNumpyFile(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tmpDirectory = tempfile.mkdtemp()
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tmpDirectory)
+
+ def testNumpyFile(self):
+ filename = "%s/%s.npy" % (self.tmpDirectory, self.id())
+ c = numpy.random.rand(5, 5)
+ numpy.save(filename, c)
+ h5 = rawh5.NumpyFile(filename)
+ self.assertIn("data", h5)
+ self.assertEqual(h5["data"].dtype.kind, "f")
+
+ def testNumpyZFile(self):
+ filename = "%s/%s.npz" % (self.tmpDirectory, self.id())
+ a = numpy.array(u"aaaaa")
+ b = numpy.array([1, 2, 3, 4])
+ c = numpy.random.rand(5, 5)
+ d = numpy.array(b"aaaaa")
+ e = numpy.array(u"i \u2661 my mother")
+ numpy.savez(filename, a, b=b, c=c, d=d, e=e)
+ h5 = rawh5.NumpyFile(filename)
+ self.assertIn("arr_0", h5)
+ self.assertIn("b", h5)
+ self.assertIn("c", h5)
+ self.assertIn("d", h5)
+ self.assertIn("e", h5)
+ self.assertEqual(h5["arr_0"].dtype.kind, "U")
+ self.assertEqual(h5["b"].dtype.kind, "i")
+ self.assertEqual(h5["c"].dtype.kind, "f")
+ self.assertEqual(h5["d"].dtype.kind, "S")
+ self.assertEqual(h5["e"].dtype.kind, "U")
+
+ def testNumpyZFileContainingDirectories(self):
+ filename = "%s/%s.npz" % (self.tmpDirectory, self.id())
+ data = {}
+ data['a/b/c'] = numpy.arange(10)
+ data['a/b/e'] = numpy.arange(10)
+ numpy.savez(filename, **data)
+ h5 = rawh5.NumpyFile(filename)
+ self.assertIn("a/b/c", h5)
+ self.assertIn("a/b/e", h5)
diff --git a/src/silx/io/test/test_specfile.py b/src/silx/io/test/test_specfile.py
new file mode 100644
index 0000000..44cb08c
--- /dev/null
+++ b/src/silx/io/test/test_specfile.py
@@ -0,0 +1,420 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for specfile wrapper"""
+
+__authors__ = ["P. Knobel", "V.A. Sole"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import locale
+import logging
+import numpy
+import os
+import sys
+import tempfile
+import unittest
+
+from silx.utils import testutils
+
+from ..specfile import SpecFile, Scan
+from .. import specfile
+
+
+logger1 = logging.getLogger(__name__)
+
+sftext = """#F /tmp/sf.dat
+#E 1455180875
+#D Thu Feb 11 09:54:35 2016
+#C imaging User = opid17
+#U00 user comment first line
+#U01 This is a dummy file to test SpecFile parsing
+#U02
+#U03 last line
+
+#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
+#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
+#o0 pshg mrtu mrtd
+#o2 ss1vo ss1ho ss1vg
+
+#J0 Seconds IA ion.mono Current
+#J1 xbpmc2 idgap1 Inorm
+
+#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
+#D Thu Feb 11 09:55:20 2016
+#T 0.2 (Seconds)
+#G0 0
+#G1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+#G3 0 0 0 0 0 0 0 0 0
+#G4 0
+#Q
+#P0 180.005 -0.66875 0.87125
+#P1 14.74255 16.197579 12.238283
+#UMI0 Current AutoM Shutter
+#UMI1 192.51 OFF FE open
+#UMI2 Refill in 39883 sec, Fill Mode: uniform multibunch / Message: Feb 11 08:00 Delivery:Next Refill at 21:00;
+#N 4
+#L first column second column 3rd_col
+-1.23 5.89 8
+8.478100E+01 5 1.56
+3.14 2.73 -3.14
+1.2 2.3 3.4
+
+#S 25 ascan c3th 1.33245 1.52245 40 0.15
+#D Thu Feb 11 10:00:31 2016
+#P0 80.005 -1.66875 1.87125
+#P1 4.74255 6.197579 2.238283
+#N 5
+#L column0 column1 col2 col3
+0.0 0.1 0.2 0.3
+1.0 1.1 1.2 1.3
+2.0 2.1 2.2 2.3
+3.0 3.1 3.2 3.3
+
+#S 26 yyyyyy
+#D Thu Feb 11 09:55:20 2016
+#P0 80.005 -1.66875 1.87125
+#P1 4.74255 6.197579 2.238283
+#N 4
+#L first column second column 3rd_col
+#C Sat Oct 31 15:51:47 1998. Scan aborted after 0 points.
+
+#F /tmp/sf.dat
+#E 1455180876
+#D Thu Feb 11 09:54:36 2016
+
+#S 1 aaaaaa
+#U first duplicate line
+#U second duplicate line
+#@MCADEV 1
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#N 3
+#L uno duo
+1 2
+@A 0 1 2
+3 4
+@A 3.1 4 5
+5 6
+@A 6 7.7 8
+"""
+
+
+loc = locale.getlocale(locale.LC_NUMERIC)
+try:
+ locale.setlocale(locale.LC_NUMERIC, 'de_DE.utf8')
+except locale.Error:
+ try_DE = False
+else:
+ try_DE = True
+ locale.setlocale(locale.LC_NUMERIC, loc)
+
+
+class TestSpecFile(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ fd, cls.fname1 = tempfile.mkstemp(text=False)
+ if sys.version_info < (3, ):
+ os.write(fd, sftext)
+ else:
+ os.write(fd, bytes(sftext, 'ascii'))
+ os.close(fd)
+
+ fd2, cls.fname2 = tempfile.mkstemp(text=False)
+ if sys.version_info < (3, ):
+ os.write(fd2, sftext[370:923])
+ else:
+ os.write(fd2, bytes(sftext[370:923], 'ascii'))
+ os.close(fd2)
+
+ fd3, cls.fname3 = tempfile.mkstemp(text=False)
+ txt = sftext[371:923]
+ if sys.version_info < (3, ):
+ os.write(fd3, txt)
+ else:
+ os.write(fd3, bytes(txt, 'ascii'))
+ os.close(fd3)
+
+ @classmethod
+ def tearDownClass(cls):
+ os.unlink(cls.fname1)
+ os.unlink(cls.fname2)
+ os.unlink(cls.fname3)
+
+ def setUp(self):
+ self.sf = SpecFile(self.fname1)
+ self.scan1 = self.sf[0]
+ self.scan1_2 = self.sf["1.2"]
+ self.scan25 = self.sf["25.1"]
+ self.empty_scan = self.sf["26.1"]
+
+ self.sf_no_fhdr = SpecFile(self.fname2)
+ self.scan1_no_fhdr = self.sf_no_fhdr[0]
+
+ self.sf_no_fhdr_crash = SpecFile(self.fname3)
+ self.scan1_no_fhdr_crash = self.sf_no_fhdr_crash[0]
+
+ def tearDown(self):
+ self.sf.close()
+ self.sf_no_fhdr.close()
+ self.sf_no_fhdr_crash.close()
+
+ def test_open(self):
+ self.assertIsInstance(self.sf, SpecFile)
+ with self.assertRaises(specfile.SfErrFileOpen):
+ SpecFile("doesnt_exist.dat")
+
+ # test filename types unicode and bytes
+ if sys.version_info[0] < 3:
+ try:
+ SpecFile(self.fname1)
+ except TypeError:
+ self.fail("failed to handle filename as python2 str")
+ try:
+ SpecFile(unicode(self.fname1))
+ except TypeError:
+ self.fail("failed to handle filename as python2 unicode")
+ else:
+ try:
+ SpecFile(self.fname1)
+ except TypeError:
+ self.fail("failed to handle filename as python3 str")
+ try:
+ SpecFile(bytes(self.fname1, 'utf-8'))
+ except TypeError:
+ self.fail("failed to handle filename as python3 bytes")
+
+ def test_number_of_scans(self):
+ self.assertEqual(4, len(self.sf))
+
+ def test_list_of_scan_indices(self):
+ self.assertEqual(self.sf.list(),
+ [1, 25, 26, 1])
+ self.assertEqual(self.sf.keys(),
+ ["1.1", "25.1", "26.1", "1.2"])
+
+ def test_index_number_order(self):
+ self.assertEqual(self.sf.index(1, 2), 3) # sf["1.2"]==sf[3]
+ self.assertEqual(self.sf.number(1), 25) # sf[1]==sf["25"]
+ self.assertEqual(self.sf.order(3), 2) # sf[3]==sf["1.2"]
+ with self.assertRaises(specfile.SfErrScanNotFound):
+ self.sf.index(3, 2)
+ with self.assertRaises(specfile.SfErrScanNotFound):
+ self.sf.index(99)
+
+ def assertRaisesRegex(self, *args, **kwargs):
+ # Python 2 compatibility
+ if sys.version_info.major >= 3:
+ return super(TestSpecFile, self).assertRaisesRegex(*args, **kwargs)
+ else:
+ return self.assertRaisesRegexp(*args, **kwargs)
+
+ def test_getitem(self):
+ self.assertIsInstance(self.sf[2], Scan)
+ self.assertIsInstance(self.sf["1.2"], Scan)
+ # int out of range
+ with self.assertRaisesRegex(IndexError, 'Scan index must be in ran'):
+ self.sf[107]
+ # float indexing not allowed
+ with self.assertRaisesRegex(TypeError, 'The scan identification k'):
+ self.sf[1.2]
+ # non existant scan with "N.M" indexing
+ with self.assertRaises(KeyError):
+ self.sf["3.2"]
+
+ def test_specfile_iterator(self):
+ i = 0
+ for scan in self.sf:
+ if i == 1:
+ self.assertEqual(scan.motor_positions,
+ self.sf[1].motor_positions)
+ i += 1
+ # number of returned scans
+ self.assertEqual(i, len(self.sf))
+
+ def test_scan_index(self):
+ self.assertEqual(self.scan1.index, 0)
+ self.assertEqual(self.scan1_2.index, 3)
+ self.assertEqual(self.scan25.index, 1)
+
+ def test_scan_headers(self):
+ self.assertEqual(self.scan25.scan_header_dict['S'],
+ "25 ascan c3th 1.33245 1.52245 40 0.15")
+ self.assertEqual(self.scan1.header[17], '#G0 0')
+ self.assertEqual(len(self.scan1.header), 29)
+ # parsing headers with long keys
+ self.assertEqual(self.scan1.scan_header_dict['UMI0'],
+ 'Current AutoM Shutter')
+ # parsing empty headers
+ self.assertEqual(self.scan1.scan_header_dict['Q'], '')
+ # duplicate headers: concatenated (with newline)
+ self.assertEqual(self.scan1_2.scan_header_dict["U"],
+ "first duplicate line\nsecond duplicate line")
+
+ def test_file_headers(self):
+ self.assertEqual(self.scan1.header[1],
+ '#E 1455180875')
+ self.assertEqual(self.scan1.file_header_dict['F'],
+ '/tmp/sf.dat')
+
+ def test_multiple_file_headers(self):
+ """Scan 1.2 is after the second file header, with a different
+ Epoch"""
+ self.assertEqual(self.scan1_2.header[1],
+ '#E 1455180876')
+
+ def test_scan_labels(self):
+ self.assertEqual(self.scan1.labels,
+ ['first column', 'second column', '3rd_col'])
+
+ def test_data(self):
+ # data_line() and data_col() take 1-based indices as arg
+ self.assertAlmostEqual(self.scan1.data_line(1)[2],
+ 1.56)
+ # tests for data transposition between original file and .data attr
+ self.assertAlmostEqual(self.scan1.data[2, 0],
+ 8)
+ self.assertEqual(self.scan1.data.shape, (3, 4))
+ self.assertAlmostEqual(numpy.sum(self.scan1.data), 113.631)
+
+ def test_data_column_by_name(self):
+ self.assertAlmostEqual(self.scan25.data_column_by_name("col2")[1],
+ 1.2)
+ # Scan.data is transposed after readinq, so column is the first index
+ self.assertAlmostEqual(numpy.sum(self.scan25.data_column_by_name("col2")),
+ numpy.sum(self.scan25.data[2, :]))
+ with self.assertRaises(specfile.SfErrColNotFound):
+ self.scan25.data_column_by_name("ygfxgfyxg")
+
+ def test_motors(self):
+ self.assertEqual(len(self.scan1.motor_names), 6)
+ self.assertEqual(len(self.scan1.motor_positions), 6)
+ self.assertAlmostEqual(sum(self.scan1.motor_positions),
+ 223.385912)
+ self.assertEqual(self.scan1.motor_names[1], 'MRTSlit UP')
+ self.assertAlmostEqual(
+ self.scan25.motor_position_by_name('MRTSlit UP'),
+ -1.66875)
+
+ def test_absence_of_file_header(self):
+ """We expect Scan.file_header to be an empty list in the absence
+ of a file header.
+ """
+ self.assertEqual(len(self.scan1_no_fhdr.motor_names), 0)
+ # motor positions can still be read in the scan header
+ # even in the absence of motor names
+ self.assertAlmostEqual(sum(self.scan1_no_fhdr.motor_positions),
+ 223.385912)
+ self.assertEqual(len(self.scan1_no_fhdr.header), 15)
+ self.assertEqual(len(self.scan1_no_fhdr.scan_header), 15)
+ self.assertEqual(len(self.scan1_no_fhdr.file_header), 0)
+
+ def test_crash_absence_of_file_header(self):
+ """Test no crash in absence of file header and no leading newline
+ character
+ """
+ self.assertEqual(len(self.scan1_no_fhdr_crash.motor_names), 0)
+ # motor positions can still be read in the scan header
+ # even in the absence of motor names
+ self.assertAlmostEqual(sum(self.scan1_no_fhdr_crash.motor_positions),
+ 223.385912)
+ self.assertEqual(len(self.scan1_no_fhdr_crash.scan_header), 15)
+ self.assertEqual(len(self.scan1_no_fhdr_crash.file_header), 0)
+
+ def test_mca(self):
+ self.assertEqual(len(self.scan1.mca), 0)
+ self.assertEqual(len(self.scan1_2.mca), 3)
+ self.assertEqual(self.scan1_2.mca[1][2], 5)
+ self.assertEqual(sum(self.scan1_2.mca[2]), 21.7)
+
+ # Negative indexing
+ self.assertEqual(sum(self.scan1_2.mca[len(self.scan1_2.mca) - 1]),
+ sum(self.scan1_2.mca[-1]))
+
+ # Test iterator
+ line_count, total_sum = (0, 0)
+ for mca_line in self.scan1_2.mca:
+ line_count += 1
+ total_sum += sum(mca_line)
+ self.assertEqual(line_count, 3)
+ self.assertAlmostEqual(total_sum, 36.8)
+
+ def test_mca_header(self):
+ self.assertEqual(self.scan1.mca_header_dict, {})
+ self.assertEqual(len(self.scan1_2.mca_header_dict), 4)
+ self.assertEqual(self.scan1_2.mca_header_dict["CALIB"], "1 2 3")
+ self.assertEqual(self.scan1_2.mca.calibration,
+ [[1., 2., 3.]])
+ # default calib in the absence of #@CALIB
+ self.assertEqual(self.scan25.mca.calibration,
+ [[0., 1., 0.]])
+ self.assertEqual(self.scan1_2.mca.channels,
+ [[0, 1, 2]])
+ # absence of #@CHANN and spectra
+ self.assertEqual(self.scan25.mca.channels,
+ [])
+
+ @testutils.validate_logging(specfile._logger.name, warning=1)
+ def test_empty_scan(self):
+ """Test reading a scan with no data points"""
+ self.assertEqual(len(self.empty_scan.labels),
+ 3)
+ col1 = self.empty_scan.data_column_by_name("second column")
+ self.assertEqual(col1.shape, (0, ))
+
+
+class TestSFLocale(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ fd, cls.fname = tempfile.mkstemp(text=False)
+ if sys.version_info < (3, ):
+ os.write(fd, sftext)
+ else:
+ os.write(fd, bytes(sftext, 'ascii'))
+ os.close(fd)
+
+ @classmethod
+ def tearDownClass(cls):
+ os.unlink(cls.fname)
+ locale.setlocale(locale.LC_NUMERIC, loc) # restore saved locale
+
+ def crunch_data(self):
+ self.sf3 = SpecFile(self.fname)
+ self.assertAlmostEqual(self.sf3[0].data_line(1)[2],
+ 1.56)
+ self.sf3.close()
+
+ @unittest.skipIf(not try_DE, "de_DE.utf8 locale not installed")
+ def test_locale_de_DE(self):
+ locale.setlocale(locale.LC_NUMERIC, 'de_DE.utf8')
+ self.crunch_data()
+
+ def test_locale_user(self):
+ locale.setlocale(locale.LC_NUMERIC, '') # use user's preferred locale
+ self.crunch_data()
+
+ def test_locale_C(self):
+ locale.setlocale(locale.LC_NUMERIC, 'C') # use default (C) locale
+ self.crunch_data()
diff --git a/src/silx/io/test/test_specfilewrapper.py b/src/silx/io/test/test_specfilewrapper.py
new file mode 100644
index 0000000..a1ba5f4
--- /dev/null
+++ b/src/silx/io/test/test_specfilewrapper.py
@@ -0,0 +1,195 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for old specfile wrapper"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "15/05/2017"
+
+import locale
+import logging
+import numpy
+import os
+import sys
+import tempfile
+import unittest
+
+logger1 = logging.getLogger(__name__)
+
+from ..specfilewrapper import Specfile
+
+sftext = """#F /tmp/sf.dat
+#E 1455180875
+#D Thu Feb 11 09:54:35 2016
+#C imaging User = opid17
+#U00 user comment first line
+#U01 This is a dummy file to test SpecFile parsing
+#U02
+#U03 last line
+
+#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
+#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
+#o0 pshg mrtu mrtd
+#o2 ss1vo ss1ho ss1vg
+
+#J0 Seconds IA ion.mono Current
+#J1 xbpmc2 idgap1 Inorm
+
+#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
+#D Thu Feb 11 09:55:20 2016
+#T 0.2 (Seconds)
+#G0 0
+#G1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+#G3 0 0 0 0 0 0 0 0 0
+#G4 0
+#Q
+#P0 180.005 -0.66875 0.87125
+#P1 14.74255 16.197579 12.238283
+#UMI0 Current AutoM Shutter
+#UMI1 192.51 OFF FE open
+#UMI2 Refill in 39883 sec, Fill Mode: uniform multibunch / Message: Feb 11 08:00 Delivery:Next Refill at 21:00;
+#N 4
+#L first column second column 3rd_col
+-1.23 5.89 8
+8.478100E+01 5 1.56
+3.14 2.73 -3.14
+1.2 2.3 3.4
+
+#S 25 ascan c3th 1.33245 1.52245 40 0.15
+#D Thu Feb 11 10:00:31 2016
+#P0 80.005 -1.66875 1.87125
+#P1 4.74255 6.197579 2.238283
+#N 5
+#L column0 column1 col2 col3
+0.0 0.1 0.2 0.3
+1.0 1.1 1.2 1.3
+2.0 2.1 2.2 2.3
+3.0 3.1 3.2 3.3
+
+#F /tmp/sf.dat
+#E 1455180876
+#D Thu Feb 11 09:54:36 2016
+
+#S 1 aaaaaa
+#U first duplicate line
+#U second duplicate line
+#@MCADEV 1
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#N 3
+#L uno duo
+1 2
+@A 0 1 2
+3 4
+@A 3.1 4 5
+5 6
+@A 6 7.7 8
+"""
+
+
+class TestSpecfilewrapper(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ fd, cls.fname1 = tempfile.mkstemp(text=False)
+ if sys.version_info < (3, ):
+ os.write(fd, sftext)
+ else:
+ os.write(fd, bytes(sftext, 'ascii'))
+ os.close(fd)
+
+ @classmethod
+ def tearDownClass(cls):
+ os.unlink(cls.fname1)
+
+ def setUp(self):
+ self.sf = Specfile(self.fname1)
+ self.scan1 = self.sf[0]
+ self.scan1_2 = self.sf.select("1.2")
+ self.scan25 = self.sf.select("25.1")
+
+ def tearDown(self):
+ self.sf.close()
+
+ def test_number_of_scans(self):
+ self.assertEqual(3, len(self.sf))
+
+ def test_list_of_scan_indices(self):
+ self.assertEqual(self.sf.list(),
+ '1,25,1')
+ self.assertEqual(self.sf.keys(),
+ ["1.1", "25.1", "1.2"])
+
+ def test_scan_headers(self):
+ self.assertEqual(self.scan25.header('S'),
+ ["#S 25 ascan c3th 1.33245 1.52245 40 0.15"])
+ self.assertEqual(self.scan1.header("G0"), ['#G0 0'])
+ # parsing headers with long keys
+ # parsing empty headers
+ self.assertEqual(self.scan1.header('Q'), ['#Q '])
+
+ def test_file_headers(self):
+ self.assertEqual(self.scan1.header("E"),
+ ['#E 1455180875'])
+ self.assertEqual(self.sf.title(),
+ "imaging")
+ self.assertEqual(self.sf.epoch(),
+ 1455180875)
+ self.assertEqual(self.sf.allmotors(),
+ ["Pslit HGap", "MRTSlit UP", "MRTSlit DOWN",
+ "Sslit1 VOff", "Sslit1 HOff", "Sslit1 VGap"])
+
+ def test_scan_labels(self):
+ self.assertEqual(self.scan1.alllabels(),
+ ['first column', 'second column', '3rd_col'])
+
+ def test_data(self):
+ self.assertAlmostEqual(self.scan1.dataline(3)[2],
+ -3.14)
+ self.assertAlmostEqual(self.scan1.datacol(1)[2],
+ 3.14)
+ # tests for data transposition between original file and .data attr
+ self.assertAlmostEqual(self.scan1.data()[2, 0],
+ 8)
+ self.assertEqual(self.scan1.data().shape, (3, 4))
+ self.assertAlmostEqual(numpy.sum(self.scan1.data()), 113.631)
+
+ def test_date(self):
+ self.assertEqual(self.scan1.date(),
+ "Thu Feb 11 09:55:20 2016")
+
+ def test_motors(self):
+ self.assertEqual(len(self.sf.allmotors()), 6)
+ self.assertEqual(len(self.scan1.allmotorpos()), 6)
+ self.assertAlmostEqual(sum(self.scan1.allmotorpos()),
+ 223.385912)
+ self.assertEqual(self.sf.allmotors()[1], 'MRTSlit UP')
+
+ def test_mca(self):
+ self.assertEqual(self.scan1_2.mca(2)[2], 5)
+ self.assertEqual(sum(self.scan1_2.mca(3)), 21.7)
+
+ def test_mca_header(self):
+ self.assertEqual(self.scan1_2.header("CALIB"),
+ ["#@CALIB 1 2 3"])
diff --git a/src/silx/io/test/test_spech5.py b/src/silx/io/test/test_spech5.py
new file mode 100644
index 0000000..1e67961
--- /dev/null
+++ b/src/silx/io/test/test_spech5.py
@@ -0,0 +1,929 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for spech5"""
+import numpy
+import os
+import io
+import sys
+import tempfile
+import unittest
+import datetime
+from functools import partial
+
+from silx.utils import testutils
+
+from .. import spech5
+from ..spech5 import (SpecH5, SpecH5Dataset, spec_date_to_iso8601)
+from .. import specfile
+
+import h5py
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "12/02/2018"
+
+sftext = """#F /tmp/sf.dat
+#E 1455180875
+#D Thu Feb 11 09:54:35 2016
+#C imaging User = opid17
+#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
+#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
+#o0 pshg mrtu mrtd
+#o2 ss1vo ss1ho ss1vg
+
+#J0 Seconds IA ion.mono Current
+#J1 xbpmc2 idgap1 Inorm
+
+#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
+#D Thu Feb 11 09:55:20 2016
+#T 0.2 (Seconds)
+#P0 180.005 -0.66875 0.87125
+#P1 14.74255 16.197579 12.238283
+#N 4
+#L MRTSlit UP second column 3rd_col
+-1.23 5.89 8
+8.478100E+01 5 1.56
+3.14 2.73 -3.14
+1.2 2.3 3.4
+
+#S 25 ascan c3th 1.33245 1.52245 40 0.15
+#D Sat 2015/03/14 03:53:50
+#P0 80.005 -1.66875 1.87125
+#P1 4.74255 6.197579 2.238283
+#N 5
+#L column0 column1 col2 col3
+0.0 0.1 0.2 0.3
+1.0 1.1 1.2 1.3
+2.0 2.1 2.2 2.3
+3.0 3.1 3.2 3.3
+
+#S 1 aaaaaa
+#D Thu Feb 11 10:00:32 2016
+#@MCADEV 1
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#@CTIME 123.4 234.5 345.6
+#N 3
+#L uno duo
+1 2
+@A 0 1 2
+@A 10 9 8
+@A 1 1 1.1
+3 4
+@A 3.1 4 5
+@A 7 6 5
+@A 1 1 1
+5 6
+@A 6 7.7 8
+@A 4 3 2
+@A 1 1 1
+
+#S 1000 bbbbb
+#G1 3.25 3.25 5.207 90 90 120 2.232368448 2.232368448 1.206680489 90 90 60 1 1 2 -1 2 2 26.132 7.41 -88.96 1.11 1.000012861 15.19 26.06 67.355 -88.96 1.11 1.000012861 15.11 0.723353 0.723353
+#G3 0.0106337923671 0.027529133 1.206191273 -1.43467075 0.7633438883 0.02401568018 -1.709143587 -2.097621783 0.02456954971
+#L a b
+1 2
+
+#S 1001 ccccc
+#G1 0. 0. 0. 0 0 0 2.232368448 2.232368448 1.206680489 90 90 60 1 1 2 -1 2 2 26.132 7.41 -88.96 1.11 1.000012861 15.19 26.06 67.355 -88.96 1.11 1.000012861 15.11 0.723353 0.723353
+#G3 0. 0. 0. 0. 0.0 0. 0. 0. 0.
+#L a b
+1 2
+
+"""
+
+
+class TestSpecDate(unittest.TestCase):
+ """
+ Test of the spec_date_to_iso8601 function.
+ """
+ # TODO : time zone tests
+ # TODO : error cases
+
+ @classmethod
+ def setUpClass(cls):
+ import locale
+ # FYI : not threadsafe
+ cls.locale_saved = locale.setlocale(locale.LC_TIME)
+ locale.setlocale(locale.LC_TIME, 'C')
+
+ @classmethod
+ def tearDownClass(cls):
+ import locale
+ # FYI : not threadsafe
+ locale.setlocale(locale.LC_TIME, cls.locale_saved)
+
+ def setUp(self):
+ # covering all week days
+ self.n_days = range(1, 10)
+ # covering all months
+ self.n_months = range(1, 13)
+
+ self.n_years = [1999, 2016, 2020]
+ self.n_seconds = [0, 5, 26, 59]
+ self.n_minutes = [0, 9, 42, 59]
+ self.n_hours = [0, 2, 17, 23]
+
+ self.formats = ['%a %b %d %H:%M:%S %Y', '%a %Y/%m/%d %H:%M:%S']
+
+ self.check_date_formats = partial(self.__check_date_formats,
+ year=self.n_years[0],
+ month=self.n_months[0],
+ day=self.n_days[0],
+ hour=self.n_hours[0],
+ minute=self.n_minutes[0],
+ second=self.n_seconds[0],
+ msg=None)
+
+ def __check_date_formats(self,
+ year,
+ month,
+ day,
+ hour,
+ minute,
+ second,
+ msg=None):
+ dt = datetime.datetime(year, month, day, hour, minute, second)
+ expected_date = dt.isoformat()
+
+ for i_fmt, fmt in enumerate(self.formats):
+ spec_date = dt.strftime(fmt)
+ iso_date = spec_date_to_iso8601(spec_date)
+ self.assertEqual(iso_date,
+ expected_date,
+ msg='Testing {0}. format={1}. '
+ 'Expected "{2}", got "{3} ({4})" (dt={5}).'
+ ''.format(msg,
+ i_fmt,
+ expected_date,
+ iso_date,
+ spec_date,
+ dt))
+
+ def testYearsNominal(self):
+ for year in self.n_years:
+ self.check_date_formats(year=year, msg='year')
+
+ def testMonthsNominal(self):
+ for month in self.n_months:
+ self.check_date_formats(month=month, msg='month')
+
+ def testDaysNominal(self):
+ for day in self.n_days:
+ self.check_date_formats(day=day, msg='day')
+
+ def testHoursNominal(self):
+ for hour in self.n_hours:
+ self.check_date_formats(hour=hour, msg='hour')
+
+ def testMinutesNominal(self):
+ for minute in self.n_minutes:
+ self.check_date_formats(minute=minute, msg='minute')
+
+ def testSecondsNominal(self):
+ for second in self.n_seconds:
+ self.check_date_formats(second=second, msg='second')
+
+
+class TestSpecH5(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ fd, cls.fname = tempfile.mkstemp()
+ if sys.version_info < (3, ):
+ os.write(fd, sftext)
+ else:
+ os.write(fd, bytes(sftext, 'ascii'))
+ os.close(fd)
+
+ @classmethod
+ def tearDownClass(cls):
+ os.unlink(cls.fname)
+
+ def setUp(self):
+ self.sfh5 = SpecH5(self.fname)
+
+ def tearDown(self):
+ self.sfh5.close()
+
+ def testContainsFile(self):
+ self.assertIn("/1.2/measurement", self.sfh5)
+ self.assertIn("/25.1", self.sfh5)
+ self.assertIn("25.1", self.sfh5)
+ self.assertNotIn("25.2", self.sfh5)
+ # measurement is a child of a scan, full path would be required to
+ # access from root level
+ self.assertNotIn("measurement", self.sfh5)
+ # Groups may or may not have a trailing /
+ self.assertIn("/1.2/measurement/mca_1/", self.sfh5)
+ self.assertIn("/1.2/measurement/mca_1", self.sfh5)
+ # Datasets can't have a trailing /
+ self.assertNotIn("/1.2/measurement/mca_0/info/calibration/ ", self.sfh5)
+ # No mca_8
+ self.assertNotIn("/1.2/measurement/mca_8/info/calibration", self.sfh5)
+ # Link
+ self.assertIn("/1.2/measurement/mca_0/info/calibration", self.sfh5)
+
+ def testContainsGroup(self):
+ self.assertIn("measurement", self.sfh5["/1.2/"])
+ self.assertIn("measurement", self.sfh5["/1.2"])
+ self.assertIn("25.1", self.sfh5["/"])
+ self.assertNotIn("25.2", self.sfh5["/"])
+ self.assertIn("instrument/positioners/Sslit1 HOff", self.sfh5["/1.1"])
+ # illegal trailing "/" after dataset name
+ self.assertNotIn("instrument/positioners/Sslit1 HOff/",
+ self.sfh5["/1.1"])
+ # full path to element in group (OK)
+ self.assertIn("/1.1/instrument/positioners/Sslit1 HOff",
+ self.sfh5["/1.1/instrument"])
+
+ def testDataColumn(self):
+ self.assertAlmostEqual(sum(self.sfh5["/1.2/measurement/duo"]),
+ 12.0)
+ self.assertAlmostEqual(
+ sum(self.sfh5["1.1"]["measurement"]["MRTSlit UP"]),
+ 87.891, places=4)
+
+ def testDate(self):
+ # start time is in Iso8601 format
+ self.assertEqual(self.sfh5["/1.1/start_time"],
+ u"2016-02-11T09:55:20")
+ self.assertEqual(self.sfh5["25.1/start_time"],
+ u"2015-03-14T03:53:50")
+
+ def assertRaisesRegex(self, *args, **kwargs):
+ # Python 2 compatibility
+ if sys.version_info.major >= 3:
+ return super(TestSpecH5, self).assertRaisesRegex(*args, **kwargs)
+ else:
+ return self.assertRaisesRegexp(*args, **kwargs)
+
+ def testDatasetInstanceAttr(self):
+ """The SpecH5Dataset objects must implement some dummy attributes
+ to improve compatibility with widgets dealing with h5py datasets."""
+ self.assertIsNone(self.sfh5["/1.1/start_time"].compression)
+ self.assertIsNone(self.sfh5["1.1"]["measurement"]["MRTSlit UP"].chunks)
+
+ # error message must be explicit
+ with self.assertRaisesRegex(
+ AttributeError,
+ "SpecH5Dataset has no attribute tOTo"):
+ dummy = self.sfh5["/1.1/start_time"].tOTo
+
+ def testGet(self):
+ """Test :meth:`SpecH5Group.get`"""
+ # default value of param *default* is None
+ self.assertIsNone(self.sfh5.get("toto"))
+ self.assertEqual(self.sfh5["25.1"].get("toto", default=-3),
+ -3)
+
+ self.assertEqual(self.sfh5.get("/1.1/start_time", default=-3),
+ u"2016-02-11T09:55:20")
+
+ def testGetClass(self):
+ """Test :meth:`SpecH5Group.get`"""
+ self.assertIs(self.sfh5["1.1"].get("start_time", getclass=True),
+ h5py.Dataset)
+ self.assertIs(self.sfh5["1.1"].get("instrument", getclass=True),
+ h5py.Group)
+
+ # spech5 does not define external link, so there is no way
+ # a group can *get* a SpecH5 class
+
+ def testGetApi(self):
+ result = self.sfh5.get("1.1", getclass=True, getlink=True)
+ self.assertIs(result, h5py.HardLink)
+ result = self.sfh5.get("1.1", getclass=False, getlink=True)
+ self.assertIsInstance(result, h5py.HardLink)
+ result = self.sfh5.get("1.1", getclass=True, getlink=False)
+ self.assertIs(result, h5py.Group)
+ result = self.sfh5.get("1.1", getclass=False, getlink=False)
+ self.assertIsInstance(result, spech5.SpecH5Group)
+
+ def testGetItemGroup(self):
+ group = self.sfh5["25.1"]["instrument"]
+ self.assertEqual(list(group["positioners"].keys()),
+ ["Pslit HGap", "MRTSlit UP", "MRTSlit DOWN",
+ "Sslit1 VOff", "Sslit1 HOff", "Sslit1 VGap"])
+ with self.assertRaises(KeyError):
+ group["Holy Grail"]
+
+ def testGetitemSpecH5(self):
+ self.assertEqual(self.sfh5["/1.2/instrument/positioners"],
+ self.sfh5["1.2"]["instrument"]["positioners"])
+
+ def testH5pyClass(self):
+ """Test :attr:`h5py_class` returns the corresponding h5py class
+ (h5py.File, h5py.Group, h5py.Dataset)"""
+ a_file = self.sfh5
+ self.assertIs(a_file.h5py_class,
+ h5py.File)
+
+ a_group = self.sfh5["/1.2/measurement"]
+ self.assertIs(a_group.h5py_class,
+ h5py.Group)
+
+ a_dataset = self.sfh5["/1.1/instrument/positioners/Sslit1 HOff"]
+ self.assertIs(a_dataset.h5py_class,
+ h5py.Dataset)
+
+ def testHeader(self):
+ file_header = self.sfh5["/1.2/instrument/specfile/file_header"]
+ scan_header = self.sfh5["/1.2/instrument/specfile/scan_header"]
+
+ # File header has 10 lines
+ self.assertEqual(len(file_header), 10)
+ # 1.2 has 9 scan & mca header lines
+ self.assertEqual(len(scan_header), 9)
+
+ # line 4 of file header
+ self.assertEqual(
+ file_header[3],
+ u"#C imaging User = opid17")
+ # line 4 of scan header
+ scan_header = self.sfh5["25.1/instrument/specfile/scan_header"]
+
+ self.assertEqual(
+ scan_header[3],
+ u"#P1 4.74255 6.197579 2.238283")
+
+ def testLinks(self):
+ self.assertTrue(numpy.array_equal(
+ self.sfh5["/1.2/measurement/mca_0/data"],
+ self.sfh5["/1.2/instrument/mca_0/data"])
+ )
+ self.assertTrue(numpy.array_equal(
+ self.sfh5["/1.2/measurement/mca_0/info/data"],
+ self.sfh5["/1.2/instrument/mca_0/data"])
+ )
+ self.assertTrue(numpy.array_equal(
+ self.sfh5["/1.2/measurement/mca_0/info/channels"],
+ self.sfh5["/1.2/instrument/mca_0/channels"])
+ )
+ self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/"].keys(),
+ self.sfh5["/1.2/instrument/mca_0/"].keys())
+
+ self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/preset_time"],
+ self.sfh5["/1.2/instrument/mca_0/preset_time"])
+ self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/live_time"],
+ self.sfh5["/1.2/instrument/mca_0/live_time"])
+ self.assertEqual(self.sfh5["/1.2/measurement/mca_0/info/elapsed_time"],
+ self.sfh5["/1.2/instrument/mca_0/elapsed_time"])
+
+ def testListScanIndices(self):
+ self.assertEqual(list(self.sfh5.keys()),
+ ["1.1", "25.1", "1.2", "1000.1", "1001.1"])
+ self.assertEqual(self.sfh5["1.2"].attrs,
+ {"NX_class": "NXentry", })
+
+ def testMcaAbsent(self):
+ def access_absent_mca():
+ """This must raise a KeyError, because scan 1.1 has no MCA"""
+ return self.sfh5["/1.1/measurement/mca_0/"]
+ self.assertRaises(KeyError, access_absent_mca)
+
+ def testMcaCalib(self):
+ mca0_calib = self.sfh5["/1.2/measurement/mca_0/info/calibration"]
+ mca1_calib = self.sfh5["/1.2/measurement/mca_1/info/calibration"]
+ self.assertEqual(mca0_calib.tolist(),
+ [1, 2, 3])
+ # calibration is unique in this scan and applies to all analysers
+ self.assertEqual(mca0_calib.tolist(),
+ mca1_calib.tolist())
+
+ def testMcaChannels(self):
+ mca0_chann = self.sfh5["/1.2/measurement/mca_0/info/channels"]
+ mca1_chann = self.sfh5["/1.2/measurement/mca_1/info/channels"]
+ self.assertEqual(mca0_chann.tolist(),
+ [0, 1, 2])
+ self.assertEqual(mca0_chann.tolist(),
+ mca1_chann.tolist())
+
+ def testMcaCtime(self):
+ """Tests for #@CTIME mca header"""
+ datasets = ["preset_time", "live_time", "elapsed_time"]
+ for ds in datasets:
+ self.assertNotIn("/1.1/instrument/mca_0/" + ds, self.sfh5)
+ self.assertIn("/1.2/instrument/mca_0/" + ds, self.sfh5)
+
+ mca0_preset_time = self.sfh5["/1.2/instrument/mca_0/preset_time"]
+ mca1_preset_time = self.sfh5["/1.2/instrument/mca_1/preset_time"]
+ self.assertLess(mca0_preset_time - 123.4,
+ 10**-5)
+ # ctime is unique in a this scan and applies to all analysers
+ self.assertEqual(mca0_preset_time,
+ mca1_preset_time)
+
+ mca0_live_time = self.sfh5["/1.2/instrument/mca_0/live_time"]
+ mca1_live_time = self.sfh5["/1.2/instrument/mca_1/live_time"]
+ self.assertLess(mca0_live_time - 234.5,
+ 10**-5)
+ self.assertEqual(mca0_live_time,
+ mca1_live_time)
+
+ mca0_elapsed_time = self.sfh5["/1.2/instrument/mca_0/elapsed_time"]
+ mca1_elapsed_time = self.sfh5["/1.2/instrument/mca_1/elapsed_time"]
+ self.assertLess(mca0_elapsed_time - 345.6,
+ 10**-5)
+ self.assertEqual(mca0_elapsed_time,
+ mca1_elapsed_time)
+
+ def testMcaData(self):
+ # sum 1st MCA in scan 1.2 over rows
+ mca_0_data = self.sfh5["/1.2/measurement/mca_0/data"]
+ for summed_row, expected in zip(mca_0_data.sum(axis=1).tolist(),
+ [3.0, 12.1, 21.7]):
+ self.assertAlmostEqual(summed_row, expected, places=4)
+
+ # sum 3rd MCA in scan 1.2 along both axis
+ mca_2_data = self.sfh5["1.2"]["measurement"]["mca_2"]["data"]
+ self.assertAlmostEqual(sum(sum(mca_2_data)), 9.1, places=5)
+ # attrs
+ self.assertEqual(mca_0_data.attrs, {"interpretation": "spectrum"})
+
+ def testMotorPosition(self):
+ positioners_group = self.sfh5["/1.1/instrument/positioners"]
+ # MRTSlit DOWN position is defined in #P0 san header line
+ self.assertAlmostEqual(float(positioners_group["MRTSlit DOWN"]),
+ 0.87125)
+ # MRTSlit UP position is defined in first data column
+ for a, b in zip(positioners_group["MRTSlit UP"].tolist(),
+ [-1.23, 8.478100E+01, 3.14, 1.2]):
+ self.assertAlmostEqual(float(a), b, places=4)
+
+ def testNumberMcaAnalysers(self):
+ """Scan 1.2 has 2 data columns + 3 mca spectra per data line."""
+ self.assertEqual(len(self.sfh5["1.2"]["measurement"]), 5)
+
+ def testTitle(self):
+ self.assertEqual(self.sfh5["/25.1/title"],
+ u"ascan c3th 1.33245 1.52245 40 0.15")
+
+ def testValues(self):
+ group = self.sfh5["/25.1"]
+ self.assertTrue(hasattr(group, "values"))
+ self.assertTrue(callable(group.values))
+ self.assertIn(self.sfh5["/25.1/title"],
+ self.sfh5["/25.1"].values())
+
+ # visit and visititems ignore links
+ def testVisit(self):
+ name_list = []
+ self.sfh5.visit(name_list.append)
+ self.assertIn('1.2/instrument/positioners/Pslit HGap', name_list)
+ self.assertIn("1.2/instrument/specfile/scan_header", name_list)
+ self.assertEqual(len(name_list), 117)
+
+ # test also visit of a subgroup, with various group name formats
+ name_list_leading_and_trailing_slash = []
+ self.sfh5['/1.2/instrument/'].visit(name_list_leading_and_trailing_slash.append)
+ name_list_leading_slash = []
+ self.sfh5['/1.2/instrument'].visit(name_list_leading_slash.append)
+ name_list_trailing_slash = []
+ self.sfh5['1.2/instrument/'].visit(name_list_trailing_slash.append)
+ name_list_no_slash = []
+ self.sfh5['1.2/instrument'].visit(name_list_no_slash.append)
+
+ # no differences expected in the output names
+ self.assertEqual(name_list_leading_and_trailing_slash,
+ name_list_leading_slash)
+ self.assertEqual(name_list_leading_slash,
+ name_list_trailing_slash)
+ self.assertEqual(name_list_leading_slash,
+ name_list_no_slash)
+ self.assertIn("positioners/Pslit HGap", name_list_no_slash)
+ self.assertIn("positioners", name_list_no_slash)
+
+ def testVisitItems(self):
+ dataset_name_list = []
+
+ def func_generator(l):
+ """return a function appending names to list l"""
+ def func(name, obj):
+ if isinstance(obj, SpecH5Dataset):
+ l.append(name)
+ return func
+
+ self.sfh5.visititems(func_generator(dataset_name_list))
+ self.assertIn('1.2/instrument/positioners/Pslit HGap', dataset_name_list)
+ self.assertEqual(len(dataset_name_list), 85)
+
+ # test also visit of a subgroup, with various group name formats
+ name_list_leading_and_trailing_slash = []
+ self.sfh5['/1.2/instrument/'].visititems(func_generator(name_list_leading_and_trailing_slash))
+ name_list_leading_slash = []
+ self.sfh5['/1.2/instrument'].visititems(func_generator(name_list_leading_slash))
+ name_list_trailing_slash = []
+ self.sfh5['1.2/instrument/'].visititems(func_generator(name_list_trailing_slash))
+ name_list_no_slash = []
+ self.sfh5['1.2/instrument'].visititems(func_generator(name_list_no_slash))
+
+ # no differences expected in the output names
+ self.assertEqual(name_list_leading_and_trailing_slash,
+ name_list_leading_slash)
+ self.assertEqual(name_list_leading_slash,
+ name_list_trailing_slash)
+ self.assertEqual(name_list_leading_slash,
+ name_list_no_slash)
+ self.assertIn("positioners/Pslit HGap", name_list_no_slash)
+
+ def testNotSpecH5(self):
+ fd, fname = tempfile.mkstemp()
+ os.write(fd, b"Not a spec file!")
+ os.close(fd)
+ self.assertRaises(specfile.SfErrFileOpen, SpecH5, fname)
+ self.assertRaises(IOError, SpecH5, fname)
+ os.unlink(fname)
+
+ def testSample(self):
+ self.assertNotIn("sample", self.sfh5["/1.1"])
+ self.assertIn("sample", self.sfh5["/1000.1"])
+ self.assertIn("ub_matrix", self.sfh5["/1000.1/sample"])
+ self.assertIn("unit_cell", self.sfh5["/1000.1/sample"])
+ self.assertIn("unit_cell_abc", self.sfh5["/1000.1/sample"])
+ self.assertIn("unit_cell_alphabetagamma", self.sfh5["/1000.1/sample"])
+
+ # All 0 values
+ self.assertNotIn("sample", self.sfh5["/1001.1"])
+ with self.assertRaises(KeyError):
+ self.sfh5["/1001.1/sample/unit_cell"]
+
+ @testutils.validate_logging(spech5.logger1.name, warning=2)
+ def testOpenFileDescriptor(self):
+ """Open a SpecH5 file from a file descriptor"""
+ with io.open(self.sfh5.filename) as f:
+ sfh5 = SpecH5(f)
+ self.assertIsNotNone(sfh5)
+ name_list = []
+ # check if the object is working
+ self.sfh5.visit(name_list.append)
+ sfh5.close()
+
+
+sftext_multi_mca_headers = """
+#S 1 aaaaaa
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#@CTIME 123.4 234.5 345.6
+#@MCA %16C
+#@CHANN 3 1 3 1
+#@CALIB 5.5 6.6 7.7
+#@CTIME 10 11 12
+#N 3
+#L uno duo
+1 2
+@A 0 1 2
+@A 10 9 8
+3 4
+@A 3.1 4 5
+@A 7 6 5
+5 6
+@A 6 7.7 8
+@A 4 3 2
+
+"""
+
+
+class TestSpecH5MultiMca(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ fd, cls.fname = tempfile.mkstemp(text=False)
+ if sys.version_info < (3, ):
+ os.write(fd, sftext_multi_mca_headers)
+ else:
+ os.write(fd, bytes(sftext_multi_mca_headers, 'ascii'))
+ os.close(fd)
+
+ @classmethod
+ def tearDownClass(cls):
+ os.unlink(cls.fname)
+
+ def setUp(self):
+ self.sfh5 = SpecH5(self.fname)
+
+ def tearDown(self):
+ self.sfh5.close()
+
+ def testMcaCalib(self):
+ mca0_calib = self.sfh5["/1.1/measurement/mca_0/info/calibration"]
+ mca1_calib = self.sfh5["/1.1/measurement/mca_1/info/calibration"]
+ self.assertEqual(mca0_calib.tolist(),
+ [1, 2, 3])
+ self.assertAlmostEqual(sum(mca1_calib.tolist()),
+ sum([5.5, 6.6, 7.7]),
+ places=5)
+
+ def testMcaChannels(self):
+ mca0_chann = self.sfh5["/1.1/measurement/mca_0/info/channels"]
+ mca1_chann = self.sfh5["/1.1/measurement/mca_1/info/channels"]
+ self.assertEqual(mca0_chann.tolist(),
+ [0., 1., 2.])
+ # @CHANN is unique in this scan and applies to all analysers
+ self.assertEqual(mca1_chann.tolist(),
+ [1., 2., 3.])
+
+ def testMcaCtime(self):
+ """Tests for #@CTIME mca header"""
+ mca0_preset_time = self.sfh5["/1.1/instrument/mca_0/preset_time"]
+ mca1_preset_time = self.sfh5["/1.1/instrument/mca_1/preset_time"]
+ self.assertLess(mca0_preset_time - 123.4,
+ 10**-5)
+ self.assertLess(mca1_preset_time - 10,
+ 10**-5)
+
+ mca0_live_time = self.sfh5["/1.1/instrument/mca_0/live_time"]
+ mca1_live_time = self.sfh5["/1.1/instrument/mca_1/live_time"]
+ self.assertLess(mca0_live_time - 234.5,
+ 10**-5)
+ self.assertLess(mca1_live_time - 11,
+ 10**-5)
+
+ mca0_elapsed_time = self.sfh5["/1.1/instrument/mca_0/elapsed_time"]
+ mca1_elapsed_time = self.sfh5["/1.1/instrument/mca_1/elapsed_time"]
+ self.assertLess(mca0_elapsed_time - 345.6,
+ 10**-5)
+ self.assertLess(mca1_elapsed_time - 12,
+ 10**-5)
+
+
+sftext_no_cols = r"""#F C:/DATA\test.mca
+#D Thu Jul 7 08:40:19 2016
+
+#S 1 31oct98.dat 22.1 If4
+#D Thu Jul 7 08:40:19 2016
+#C no data cols, one mca analyser, single spectrum
+#@MCA %16C
+#@CHANN 151 0 150 1
+#@CALIB 0 2 0
+@A 789 784 788 814 847 862 880 904 925 955 987 1015 1031 1070 1111 1139 \
+1203 1236 1290 1392 1492 1558 1688 1813 1977 2119 2346 2699 3121 3542 4102 4970 \
+6071 7611 10426 16188 28266 40348 50539 55555 56162 54162 47102 35718 24588 17034 12994 11444 \
+11808 13461 15687 18885 23827 31578 41999 49556 58084 59415 59456 55698 44525 28219 17680 12881 \
+9518 7415 6155 5246 4646 3978 3612 3299 3020 2761 2670 2472 2500 2310 2286 2106 \
+1989 1890 1782 1655 1421 1293 1135 990 879 757 672 618 532 488 445 424 \
+414 373 351 325 307 284 270 247 228 213 199 187 183 176 164 156 \
+153 140 142 130 118 118 103 101 97 86 90 86 87 81 75 82 \
+80 76 77 75 76 77 62 69 74 60 65 68 65 58 63 64 \
+63 59 60 56 57 60 55
+
+#S 2 31oct98.dat 22.1 If4
+#D Thu Jul 7 08:40:19 2016
+#C no data cols, one mca analyser, multiple spectra
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#@CTIME 123.4 234.5 345.6
+@A 0 1 2
+@A 10 9 8
+@A 1 1 1.1
+@A 3.1 4 5
+@A 7 6 5
+@A 1 1 1
+@A 6 7.7 8
+@A 4 3 2
+@A 1 1 1
+
+#S 3 31oct98.dat 22.1 If4
+#D Thu Jul 7 08:40:19 2016
+#C no data cols, 3 mca analysers, multiple spectra
+#@MCADEV 1
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#@CTIME 123.4 234.5 345.6
+#@MCADEV 2
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#@CTIME 123.4 234.5 345.6
+#@MCADEV 3
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#@CTIME 123.4 234.5 345.6
+@A 0 1 2
+@A 10 9 8
+@A 1 1 1.1
+@A 3.1 4 5
+@A 7 6 5
+@A 1 1 1
+@A 6 7.7 8
+@A 4 3 2
+@A 1 1 1
+"""
+
+
+class TestSpecH5NoDataCols(unittest.TestCase):
+ """Test reading SPEC files with only MCA data"""
+ @classmethod
+ def setUpClass(cls):
+ fd, cls.fname = tempfile.mkstemp()
+ if sys.version_info < (3, ):
+ os.write(fd, sftext_no_cols)
+ else:
+ os.write(fd, bytes(sftext_no_cols, 'ascii'))
+ os.close(fd)
+
+ @classmethod
+ def tearDownClass(cls):
+ os.unlink(cls.fname)
+
+ def setUp(self):
+ self.sfh5 = SpecH5(self.fname)
+
+ def tearDown(self):
+ self.sfh5.close()
+
+ def testScan1(self):
+ # 1.1: single analyser, single spectrum, 151 channels
+ self.assertIn("mca_0",
+ self.sfh5["1.1/instrument/"])
+ self.assertEqual(self.sfh5["1.1/instrument/mca_0/data"].shape,
+ (1, 151))
+ self.assertNotIn("mca_1",
+ self.sfh5["1.1/instrument/"])
+
+ def testScan2(self):
+ # 2.1: single analyser, 9 spectra, 3 channels
+ self.assertIn("mca_0",
+ self.sfh5["2.1/instrument/"])
+ self.assertEqual(self.sfh5["2.1/instrument/mca_0/data"].shape,
+ (9, 3))
+ self.assertNotIn("mca_1",
+ self.sfh5["2.1/instrument/"])
+
+ def testScan3(self):
+ # 3.1: 3 analysers, 3 spectra/analyser, 3 channels
+ for i in range(3):
+ self.assertIn("mca_%d" % i,
+ self.sfh5["3.1/instrument/"])
+ self.assertEqual(
+ self.sfh5["3.1/instrument/mca_%d/data" % i].shape,
+ (3, 3))
+
+ self.assertNotIn("mca_3",
+ self.sfh5["3.1/instrument/"])
+
+
+sf_text_slash = r"""#F /data/id09/archive/logspecfiles/laue/2016/scan_231_laue_16-11-29.dat
+#D Sat Dec 10 22:20:59 2016
+#O0 Pslit/HGap MRTSlit%UP
+
+#S 1 laue_16-11-29.log 231.1 PD3/A
+#D Sat Dec 10 22:20:59 2016
+#P0 180.005 -0.66875
+#N 2
+#L GONY/mm PD3%A
+-2.015 5.250424e-05
+-2.01 5.30798e-05
+-2.005 5.281903e-05
+-2 5.220436e-05
+"""
+
+
+class TestSpecH5SlashInLabels(unittest.TestCase):
+ """Test reading SPEC files with labels containing a / character
+
+ The / character must be substituted with a %
+ """
+ @classmethod
+ def setUpClass(cls):
+ fd, cls.fname = tempfile.mkstemp()
+ if sys.version_info < (3, ):
+ os.write(fd, sf_text_slash)
+ else:
+ os.write(fd, bytes(sf_text_slash, 'ascii'))
+ os.close(fd)
+
+ @classmethod
+ def tearDownClass(cls):
+ os.unlink(cls.fname)
+
+ def setUp(self):
+ self.sfh5 = SpecH5(self.fname)
+
+ def tearDown(self):
+ self.sfh5.close()
+
+ def testLabels(self):
+ """Ensure `/` is substituted with `%` and
+ ensure legitimate `%` in names are still working"""
+ self.assertEqual(list(self.sfh5["1.1/measurement/"].keys()),
+ ["GONY%mm", "PD3%A"])
+
+ # substituted "%"
+ self.assertIn("GONY%mm",
+ self.sfh5["1.1/measurement/"])
+ self.assertNotIn("GONY/mm",
+ self.sfh5["1.1/measurement/"])
+ self.assertAlmostEqual(self.sfh5["1.1/measurement/GONY%mm"][0],
+ -2.015, places=4)
+ # legitimate "%"
+ self.assertIn("PD3%A",
+ self.sfh5["1.1/measurement/"])
+
+ def testMotors(self):
+ """Ensure `/` is substituted with `%` and
+ ensure legitimate `%` in names are still working"""
+ self.assertEqual(list(self.sfh5["1.1/instrument/positioners"].keys()),
+ ["Pslit%HGap", "MRTSlit%UP"])
+ # substituted "%"
+ self.assertIn("Pslit%HGap",
+ self.sfh5["1.1/instrument/positioners"])
+ self.assertNotIn("Pslit/HGap",
+ self.sfh5["1.1/instrument/positioners"])
+ self.assertAlmostEqual(
+ self.sfh5["1.1/instrument/positioners/Pslit%HGap"],
+ 180.005, places=4)
+ # legitimate "%"
+ self.assertIn("MRTSlit%UP",
+ self.sfh5["1.1/instrument/positioners"])
+
+
+def testUnitCellUBMatrix(tmp_path):
+ """Test unit cell (#G1) and UB matrix (#G3)"""
+ file_path = tmp_path / "spec.dat"
+ file_path.write_bytes(bytes("""
+#S 1 OK
+#G1 0 1 2 3 4 5
+#G3 0 1 2 3 4 5 6 7 8
+""", encoding="ascii"))
+ with SpecH5(str(file_path)) as spech5:
+ assert numpy.array_equal(
+ spech5["/1.1/sample/ub_matrix"],
+ numpy.arange(9).reshape(1, 3, 3))
+ assert numpy.array_equal(
+ spech5["/1.1/sample/unit_cell"], [[0, 1, 2, 3, 4, 5]])
+ assert numpy.array_equal(
+ spech5["/1.1/sample/unit_cell_abc"], [0, 1, 2])
+ assert numpy.array_equal(
+ spech5["/1.1/sample/unit_cell_alphabetagamma"], [3, 4, 5])
+
+
+def testMalformedUnitCellUBMatrix(tmp_path):
+ """Test malformed unit cell (#G1) and UB matrix (#G3): 1 value"""
+ file_path = tmp_path / "spec.dat"
+ file_path.write_bytes(bytes("""
+#S 1 all malformed=0
+#G1 0
+#G3 0
+""", encoding="ascii"))
+ with SpecH5(str(file_path)) as spech5:
+ assert "sample" not in spech5["1.1"]
+
+
+def testMalformedUBMatrix(tmp_path):
+ """Test malformed UB matrix (#G3): all zeros"""
+ file_path = tmp_path / "spec.dat"
+ file_path.write_bytes(bytes("""
+#S 1 G3 all 0
+#G1 0 1 2 3 4 5
+#G3 0 0 0 0 0 0 0 0 0
+""", encoding="ascii"))
+ with SpecH5(str(file_path)) as spech5:
+ assert "ub_matrix" not in spech5["/1.1/sample"]
+ assert numpy.array_equal(
+ spech5["/1.1/sample/unit_cell"], [[0, 1, 2, 3, 4, 5]])
+ assert numpy.array_equal(
+ spech5["/1.1/sample/unit_cell_abc"], [0, 1, 2])
+ assert numpy.array_equal(
+ spech5["/1.1/sample/unit_cell_alphabetagamma"], [3, 4, 5])
+
+
+def testMalformedUnitCell(tmp_path):
+ """Test malformed unit cell (#G1): missing values"""
+ file_path = tmp_path / "spec.dat"
+ file_path.write_bytes(bytes("""
+#S 1 G1 malformed missing values
+#G1 0 1 2
+#G3 0 1 2 3 4 5 6 7 8
+""", encoding="ascii"))
+ with SpecH5(str(file_path)) as spech5:
+ assert "unit_cell" not in spech5["/1.1/sample"]
+ assert "unit_cell_abc" not in spech5["/1.1/sample"]
+ assert "unit_cell_alphabetagamma" not in spech5["/1.1/sample"]
+ assert numpy.array_equal(
+ spech5["/1.1/sample/ub_matrix"],
+ numpy.arange(9).reshape(1, 3, 3))
diff --git a/src/silx/io/test/test_spectoh5.py b/src/silx/io/test/test_spectoh5.py
new file mode 100644
index 0000000..66bf8d6
--- /dev/null
+++ b/src/silx/io/test/test_spectoh5.py
@@ -0,0 +1,183 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for SpecFile to HDF5 converter"""
+
+from numpy import array_equal
+import os
+import sys
+import tempfile
+import unittest
+
+import h5py
+
+from ..spech5 import SpecH5, SpecH5Group
+from ..convert import convert, write_to_h5
+from ..utils import h5py_read_dataset
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "12/02/2018"
+
+
+sfdata = b"""#F /tmp/sf.dat
+#E 1455180875
+#D Thu Feb 11 09:54:35 2016
+#C imaging User = opid17
+#O0 Pslit HGap MRTSlit UP MRTSlit DOWN
+#O1 Sslit1 VOff Sslit1 HOff Sslit1 VGap
+#o0 pshg mrtu mrtd
+#o2 ss1vo ss1ho ss1vg
+
+#J0 Seconds IA ion.mono Current
+#J1 xbpmc2 idgap1 Inorm
+
+#S 1 ascan ss1vo -4.55687 -0.556875 40 0.2
+#D Thu Feb 11 09:55:20 2016
+#T 0.2 (Seconds)
+#P0 180.005 -0.66875 0.87125
+#P1 14.74255 16.197579 12.238283
+#N 4
+#L MRTSlit UP second column 3rd_col
+-1.23 5.89 8
+8.478100E+01 5 1.56
+3.14 2.73 -3.14
+1.2 2.3 3.4
+
+#S 1 aaaaaa
+#D Thu Feb 11 10:00:32 2016
+#@MCADEV 1
+#@MCA %16C
+#@CHANN 3 0 2 1
+#@CALIB 1 2 3
+#N 3
+#L uno duo
+1 2
+@A 0 1 2
+@A 10 9 8
+3 4
+@A 3.1 4 5
+@A 7 6 5
+5 6
+@A 6 7.7 8
+@A 4 3 2
+"""
+
+
+class TestConvertSpecHDF5(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ fd, cls.spec_fname = tempfile.mkstemp(prefix="TestConvertSpecHDF5")
+ os.write(fd, sfdata)
+ os.close(fd)
+
+ fd, cls.h5_fname = tempfile.mkstemp(prefix="TestConvertSpecHDF5")
+ # Close and delete (we just need the name)
+ os.close(fd)
+ os.unlink(cls.h5_fname)
+
+ @classmethod
+ def tearDownClass(cls):
+ os.unlink(cls.spec_fname)
+
+ def setUp(self):
+ convert(self.spec_fname, self.h5_fname)
+
+ self.sfh5 = SpecH5(self.spec_fname)
+ self.h5f = h5py.File(self.h5_fname, "a")
+
+ def tearDown(self):
+ self.h5f.close()
+ self.sfh5.close()
+ os.unlink(self.h5_fname)
+
+ def testAppendToHDF5(self):
+ write_to_h5(self.sfh5, self.h5f, h5path="/foo/bar/spam")
+ self.assertTrue(
+ array_equal(self.h5f["/1.2/measurement/mca_1/data"],
+ self.h5f["/foo/bar/spam/1.2/measurement/mca_1/data"])
+ )
+
+ def testWriteSpecH5Group(self):
+ """Test passing a SpecH5Group as parameter, instead of a Spec filename
+ or a SpecH5."""
+ g = self.sfh5["1.1/instrument"]
+ self.assertIsInstance(g, SpecH5Group) # let's be paranoid
+ write_to_h5(g, self.h5f, h5path="my instruments")
+
+ self.assertAlmostEqual(self.h5f["my instruments/positioners/Sslit1 HOff"][tuple()],
+ 16.197579, places=4)
+
+ def testTitle(self):
+ """Test the value of a dataset"""
+ title12 = h5py_read_dataset(self.h5f["/1.2/title"])
+ self.assertEqual(title12,
+ u"aaaaaa")
+
+ def testAttrs(self):
+ # Test root group (file) attributes
+ self.assertEqual(self.h5f.attrs["NX_class"],
+ u"NXroot")
+ # Test dataset attributes
+ ds = self.h5f["/1.2/instrument/mca_1/data"]
+ self.assertTrue("interpretation" in ds.attrs)
+ self.assertEqual(list(ds.attrs.values()),
+ [u"spectrum"])
+ # Test group attributes
+ grp = self.h5f["1.1"]
+ self.assertEqual(grp.attrs["NX_class"],
+ u"NXentry")
+ self.assertEqual(len(list(grp.attrs.keys())),
+ 1)
+
+ def testHdf5HasSameMembers(self):
+ spec_member_list = []
+
+ def append_spec_members(name):
+ spec_member_list.append(name)
+ self.sfh5.visit(append_spec_members)
+
+ hdf5_member_list = []
+
+ def append_hdf5_members(name):
+ hdf5_member_list.append(name)
+ self.h5f.visit(append_hdf5_members)
+
+ # 1. For some reason, h5py visit method doesn't include the leading
+ # "/" character when it passes the member name to the function,
+ # even though an explicit the .name attribute of a member will
+ # have a leading "/"
+ spec_member_list = [m.lstrip("/") for m in spec_member_list]
+
+ self.assertEqual(set(hdf5_member_list),
+ set(spec_member_list))
+
+ def testLinks(self):
+ self.assertTrue(
+ array_equal(self.sfh5["/1.2/measurement/mca_0/data"],
+ self.h5f["/1.2/measurement/mca_0/data"])
+ )
+ self.assertTrue(
+ array_equal(self.h5f["/1.2/instrument/mca_1/channels"],
+ self.h5f["/1.2/measurement/mca_1/info/channels"])
+ )
diff --git a/src/silx/io/test/test_url.py b/src/silx/io/test/test_url.py
new file mode 100644
index 0000000..7346391
--- /dev/null
+++ b/src/silx/io/test/test_url.py
@@ -0,0 +1,217 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for url module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/01/2018"
+
+
+import unittest
+from ..url import DataUrl
+
+
+class TestDataUrl(unittest.TestCase):
+
+ def assertUrl(self, url, expected):
+ self.assertEqual(url.is_valid(), expected[0])
+ self.assertEqual(url.is_absolute(), expected[1])
+ self.assertEqual(url.scheme(), expected[2])
+ self.assertEqual(url.file_path(), expected[3])
+ self.assertEqual(url.data_path(), expected[4])
+ self.assertEqual(url.data_slice(), expected[5])
+
+ def test_fabio_absolute(self):
+ url = DataUrl("fabio:///data/image.edf?slice=2")
+ expected = [True, True, "fabio", "/data/image.edf", None, (2, )]
+ self.assertUrl(url, expected)
+
+ def test_fabio_absolute_windows(self):
+ url = DataUrl("fabio:///C:/data/image.edf?slice=2")
+ expected = [True, True, "fabio", "C:/data/image.edf", None, (2, )]
+ self.assertUrl(url, expected)
+
+ def test_silx_absolute(self):
+ url = DataUrl("silx:///data/image.h5?path=/data/dataset&slice=1,5")
+ expected = [True, True, "silx", "/data/image.h5", "/data/dataset", (1, 5)]
+ self.assertUrl(url, expected)
+
+ def test_commandline_shell_separator(self):
+ url = DataUrl("silx:///data/image.h5::path=/data/dataset&slice=1,5")
+ expected = [True, True, "silx", "/data/image.h5", "/data/dataset", (1, 5)]
+ self.assertUrl(url, expected)
+
+ def test_silx_absolute2(self):
+ url = DataUrl("silx:///data/image.edf?/scan_0/detector/data")
+ expected = [True, True, "silx", "/data/image.edf", "/scan_0/detector/data", None]
+ self.assertUrl(url, expected)
+
+ def test_silx_absolute_windows(self):
+ url = DataUrl("silx:///C:/data/image.h5?/scan_0/detector/data")
+ expected = [True, True, "silx", "C:/data/image.h5", "/scan_0/detector/data", None]
+ self.assertUrl(url, expected)
+
+ def test_silx_relative(self):
+ url = DataUrl("silx:./image.h5")
+ expected = [True, False, "silx", "./image.h5", None, None]
+ self.assertUrl(url, expected)
+
+ def test_fabio_relative(self):
+ url = DataUrl("fabio:./image.edf")
+ expected = [True, False, "fabio", "./image.edf", None, None]
+ self.assertUrl(url, expected)
+
+ def test_silx_relative2(self):
+ url = DataUrl("silx:image.h5")
+ expected = [True, False, "silx", "image.h5", None, None]
+ self.assertUrl(url, expected)
+
+ def test_fabio_relative2(self):
+ url = DataUrl("fabio:image.edf")
+ expected = [True, False, "fabio", "image.edf", None, None]
+ self.assertUrl(url, expected)
+
+ def test_file_relative(self):
+ url = DataUrl("image.edf")
+ expected = [True, False, None, "image.edf", None, None]
+ self.assertUrl(url, expected)
+
+ def test_file_relative2(self):
+ url = DataUrl("./foo/bar/image.edf")
+ expected = [True, False, None, "./foo/bar/image.edf", None, None]
+ self.assertUrl(url, expected)
+
+ def test_file_relative3(self):
+ url = DataUrl("foo/bar/image.edf")
+ expected = [True, False, None, "foo/bar/image.edf", None, None]
+ self.assertUrl(url, expected)
+
+ def test_file_absolute(self):
+ url = DataUrl("/data/image.edf")
+ expected = [True, True, None, "/data/image.edf", None, None]
+ self.assertUrl(url, expected)
+
+ def test_file_absolute_windows(self):
+ url = DataUrl("C:/data/image.edf")
+ expected = [True, True, None, "C:/data/image.edf", None, None]
+ self.assertUrl(url, expected)
+
+ def test_absolute_with_path(self):
+ url = DataUrl("/foo/foobar.h5?/foo/bar")
+ expected = [True, True, None, "/foo/foobar.h5", "/foo/bar", None]
+ self.assertUrl(url, expected)
+
+ def test_windows_file_data_slice(self):
+ url = DataUrl("C:/foo/foobar.h5?path=/foo/bar&slice=5,1")
+ expected = [True, True, None, "C:/foo/foobar.h5", "/foo/bar", (5, 1)]
+ self.assertUrl(url, expected)
+
+ def test_scheme_file_data_slice(self):
+ url = DataUrl("silx:/foo/foobar.h5?path=/foo/bar&slice=5,1")
+ expected = [True, True, "silx", "/foo/foobar.h5", "/foo/bar", (5, 1)]
+ self.assertUrl(url, expected)
+
+ def test_scheme_windows_file_data_slice(self):
+ url = DataUrl("silx:C:/foo/foobar.h5?path=/foo/bar&slice=5,1")
+ expected = [True, True, "silx", "C:/foo/foobar.h5", "/foo/bar", (5, 1)]
+ self.assertUrl(url, expected)
+
+ def test_empty(self):
+ url = DataUrl("")
+ expected = [False, False, None, "", None, None]
+ self.assertUrl(url, expected)
+
+ def test_unknown_scheme(self):
+ url = DataUrl("foo:/foo/foobar.h5?path=/foo/bar&slice=5,1")
+ expected = [False, True, "foo", "/foo/foobar.h5", "/foo/bar", (5, 1)]
+ self.assertUrl(url, expected)
+
+ def test_slice(self):
+ url = DataUrl("/a.h5?path=/b&slice=5,1")
+ expected = [True, True, None, "/a.h5", "/b", (5, 1)]
+ self.assertUrl(url, expected)
+
+ def test_slice2(self):
+ url = DataUrl("/a.h5?path=/b&slice=2:5")
+ expected = [True, True, None, "/a.h5", "/b", (slice(2, 5),)]
+ self.assertUrl(url, expected)
+
+ def test_slice3(self):
+ url = DataUrl("/a.h5?path=/b&slice=::2")
+ expected = [True, True, None, "/a.h5", "/b", (slice(None, None, 2),)]
+ self.assertUrl(url, expected)
+
+ def test_slice_ellipsis(self):
+ url = DataUrl("/a.h5?path=/b&slice=...")
+ expected = [True, True, None, "/a.h5", "/b", (Ellipsis, )]
+ self.assertUrl(url, expected)
+
+ def test_slice_slicing(self):
+ url = DataUrl("/a.h5?path=/b&slice=:")
+ expected = [True, True, None, "/a.h5", "/b", (slice(None), )]
+ self.assertUrl(url, expected)
+
+ def test_slice_missing_element(self):
+ url = DataUrl("/a.h5?path=/b&slice=5,,1")
+ expected = [False, True, None, "/a.h5", "/b", None]
+ self.assertUrl(url, expected)
+
+ def test_slice_no_elements(self):
+ url = DataUrl("/a.h5?path=/b&slice=")
+ expected = [False, True, None, "/a.h5", "/b", None]
+ self.assertUrl(url, expected)
+
+ def test_create_relative_url(self):
+ url = DataUrl(scheme="silx", file_path="./foo.h5", data_path="/", data_slice=(5, 1))
+ self.assertFalse(url.is_absolute())
+ url2 = DataUrl(url.path())
+ self.assertEqual(url, url2)
+
+ def test_create_absolute_url(self):
+ url = DataUrl(scheme="silx", file_path="/foo.h5", data_path="/", data_slice=(5, 1))
+ url2 = DataUrl(url.path())
+ self.assertEqual(url, url2)
+
+ def test_create_absolute_windows_url(self):
+ url = DataUrl(scheme="silx", file_path="C:/foo.h5", data_path="/", data_slice=(5, 1))
+ url2 = DataUrl(url.path())
+ self.assertEqual(url, url2)
+
+ def test_create_slice_url(self):
+ url = DataUrl(scheme="silx", file_path="/foo.h5", data_path="/", data_slice=(5, 1, Ellipsis, slice(None)))
+ url2 = DataUrl(url.path())
+ self.assertEqual(url, url2)
+
+ def test_wrong_url(self):
+ url = DataUrl(scheme="silx", file_path="/foo.h5", data_slice=(5, 1))
+ self.assertFalse(url.is_valid())
+
+ def test_path_creation(self):
+ """make sure the construction of path succeed and that we can
+ recreate a DataUrl from a path"""
+ for data_slice in (1, (1,)):
+ with self.subTest(data_slice=data_slice):
+ url = DataUrl(scheme="silx", file_path="/foo.h5", data_slice=data_slice)
+ path = url.path()
+ DataUrl(path=path)
diff --git a/src/silx/io/test/test_utils.py b/src/silx/io/test/test_utils.py
new file mode 100644
index 0000000..cc34100
--- /dev/null
+++ b/src/silx/io/test/test_utils.py
@@ -0,0 +1,923 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for utils module"""
+
+import io
+import numpy
+import os
+import re
+import shutil
+import tempfile
+import unittest
+import sys
+
+from .. import utils
+from ..._version import calc_hexversion
+import silx.io.url
+
+import h5py
+from ..utils import h5ls
+from silx.io import commonh5
+
+import fabio
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "03/12/2020"
+
+expected_spec1 = r"""#F .*
+#D .*
+
+#S 1 Ordinate1
+#D .*
+#N 2
+#L Abscissa Ordinate1
+1 4\.00
+2 5\.00
+3 6\.00
+"""
+
+expected_spec2 = expected_spec1 + r"""
+#S 2 Ordinate2
+#D .*
+#N 2
+#L Abscissa Ordinate2
+1 7\.00
+2 8\.00
+3 9\.00
+"""
+
+expected_spec2reg = r"""#F .*
+#D .*
+
+#S 1 Ordinate1
+#D .*
+#N 3
+#L Abscissa Ordinate1 Ordinate2
+1 4\.00 7\.00
+2 5\.00 8\.00
+3 6\.00 9\.00
+"""
+
+expected_spec2irr = expected_spec1 + r"""
+#S 2 Ordinate2
+#D .*
+#N 2
+#L Abscissa Ordinate2
+1 7\.00
+2 8\.00
+"""
+
+expected_csv = r"""Abscissa;Ordinate1;Ordinate2
+1;4\.00;7\.00e\+00
+2;5\.00;8\.00e\+00
+3;6\.00;9\.00e\+00
+"""
+
+expected_csv2 = r"""x;y0;y1
+1;4\.00;7\.00e\+00
+2;5\.00;8\.00e\+00
+3;6\.00;9\.00e\+00
+"""
+
+
+class TestSave(unittest.TestCase):
+ """Test saving curves as SpecFile:
+ """
+
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.spec_fname = os.path.join(self.tempdir, "savespec.dat")
+ self.csv_fname = os.path.join(self.tempdir, "savecsv.csv")
+ self.npy_fname = os.path.join(self.tempdir, "savenpy.npy")
+
+ self.x = [1, 2, 3]
+ self.xlab = "Abscissa"
+ self.y = [[4, 5, 6], [7, 8, 9]]
+ self.y_irr = [[4, 5, 6], [7, 8]]
+ self.ylabs = ["Ordinate1", "Ordinate2"]
+
+ def tearDown(self):
+ if os.path.isfile(self.spec_fname):
+ os.unlink(self.spec_fname)
+ if os.path.isfile(self.csv_fname):
+ os.unlink(self.csv_fname)
+ if os.path.isfile(self.npy_fname):
+ os.unlink(self.npy_fname)
+ shutil.rmtree(self.tempdir)
+
+ def test_save_csv(self):
+ utils.save1D(self.csv_fname, self.x, self.y,
+ xlabel=self.xlab, ylabels=self.ylabs,
+ filetype="csv", fmt=["%d", "%.2f", "%.2e"],
+ csvdelim=";", autoheader=True)
+
+ csvf = open(self.csv_fname)
+ actual_csv = csvf.read()
+ csvf.close()
+
+ self.assertRegex(actual_csv, expected_csv)
+
+ def test_save_npy(self):
+ """npy file is saved with numpy.save after building a numpy array
+ and converting it to a named record array"""
+ npyf = open(self.npy_fname, "wb")
+ utils.save1D(npyf, self.x, self.y,
+ xlabel=self.xlab, ylabels=self.ylabs)
+ npyf.close()
+
+ npy_recarray = numpy.load(self.npy_fname)
+
+ self.assertEqual(npy_recarray.shape, (3,))
+ self.assertTrue(numpy.array_equal(npy_recarray['Ordinate1'],
+ numpy.array((4, 5, 6))))
+
+ def test_savespec_filename(self):
+ """Save SpecFile using savespec()"""
+ utils.savespec(self.spec_fname, self.x, self.y[0], xlabel=self.xlab,
+ ylabel=self.ylabs[0], fmt=["%d", "%.2f"],
+ close_file=True, scan_number=1)
+
+ specf = open(self.spec_fname)
+ actual_spec = specf.read()
+ specf.close()
+ self.assertRegex(actual_spec, expected_spec1)
+
+ def test_savespec_file_handle(self):
+ """Save SpecFile using savespec(), passing a file handle"""
+ # first savespec: open, write file header, save y[0] as scan 1,
+ # return file handle
+ specf = utils.savespec(self.spec_fname, self.x, self.y[0],
+ xlabel=self.xlab, ylabel=self.ylabs[0],
+ fmt=["%d", "%.2f"], close_file=False)
+
+ # second savespec: save y[1] as scan 2, close file
+ utils.savespec(specf, self.x, self.y[1], xlabel=self.xlab,
+ ylabel=self.ylabs[1], fmt=["%d", "%.2f"],
+ write_file_header=False, close_file=True,
+ scan_number=2)
+
+ specf = open(self.spec_fname)
+ actual_spec = specf.read()
+ specf.close()
+ self.assertRegex(actual_spec, expected_spec2)
+
+ def test_save_spec_reg(self):
+ """Save SpecFile using save() on a regular pattern"""
+ utils.save1D(self.spec_fname, self.x, self.y, xlabel=self.xlab,
+ ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"])
+
+ specf = open(self.spec_fname)
+ actual_spec = specf.read()
+ specf.close()
+
+ self.assertRegex(actual_spec, expected_spec2reg)
+
+ def test_save_spec_irr(self):
+ """Save SpecFile using save() on an irregular pattern"""
+ # invalid test case ?!
+ return
+ utils.save1D(self.spec_fname, self.x, self.y_irr, xlabel=self.xlab,
+ ylabels=self.ylabs, filetype="spec", fmt=["%d", "%.2f"])
+
+ specf = open(self.spec_fname)
+ actual_spec = specf.read()
+ specf.close()
+ self.assertRegex(actual_spec, expected_spec2irr)
+
+ def test_save_csv_no_labels(self):
+ """Save csv using save(), with autoheader=True but
+ xlabel=None and ylabels=None
+ This is a non-regression test for bug #223"""
+ self.tempdir = tempfile.mkdtemp()
+ self.spec_fname = os.path.join(self.tempdir, "savespec.dat")
+ self.csv_fname = os.path.join(self.tempdir, "savecsv.csv")
+ self.npy_fname = os.path.join(self.tempdir, "savenpy.npy")
+
+ self.x = [1, 2, 3]
+ self.xlab = "Abscissa"
+ self.y = [[4, 5, 6], [7, 8, 9]]
+ self.ylabs = ["Ordinate1", "Ordinate2"]
+ utils.save1D(self.csv_fname, self.x, self.y,
+ autoheader=True, fmt=["%d", "%.2f", "%.2e"])
+
+ csvf = open(self.csv_fname)
+ actual_csv = csvf.read()
+ csvf.close()
+ self.assertRegex(actual_csv, expected_csv2)
+
+
+def assert_match_any_string_in_list(test, pattern, list_of_strings):
+ for string_ in list_of_strings:
+ if re.match(pattern, string_):
+ return True
+ return False
+
+
+class TestH5Ls(unittest.TestCase):
+ """Test displaying the following HDF5 file structure:
+
+ +foo
+ +bar
+ <HDF5 dataset "spam": shape (2, 2), type "<i8">
+ <HDF5 dataset "tmp": shape (3,), type "<i8">
+ <HDF5 dataset "data": shape (1,), type "<f8">
+
+ """
+
+ def assertMatchAnyStringInList(self, pattern, list_of_strings):
+ for string_ in list_of_strings:
+ if re.match(pattern, string_):
+ return None
+ raise AssertionError("regex pattern %s does not match any" % pattern +
+ " string in list " + str(list_of_strings))
+
+ def testHdf5(self):
+ fd, self.h5_fname = tempfile.mkstemp(text=False)
+ # Close and delete (we just want the name)
+ os.close(fd)
+ os.unlink(self.h5_fname)
+ self.h5f = h5py.File(self.h5_fname, "w")
+ self.h5f["/foo/bar/tmp"] = [1, 2, 3]
+ self.h5f["/foo/bar/spam"] = [[1, 2], [3, 4]]
+ self.h5f["/foo/data"] = [3.14]
+ self.h5f.close()
+
+ rep = h5ls(self.h5_fname)
+ lines = rep.split("\n")
+
+ self.assertIn("+foo", lines)
+ self.assertIn("\t+bar", lines)
+
+ match = r'\t\t<HDF5 dataset "tmp": shape \(3,\), type "<i[48]">'
+ self.assertMatchAnyStringInList(match, lines)
+ match = r'\t\t<HDF5 dataset "spam": shape \(2, 2\), type "<i[48]">'
+ self.assertMatchAnyStringInList(match, lines)
+ match = r'\t<HDF5 dataset "data": shape \(1,\), type "<f[48]">'
+ self.assertMatchAnyStringInList(match, lines)
+
+ os.unlink(self.h5_fname)
+
+ # Following test case disabled d/t errors on AppVeyor:
+ # os.unlink(spec_fname)
+ # PermissionError: [WinError 32] The process cannot access the file because
+ # it is being used by another process: 'C:\\...\\savespec.dat'
+
+ # def testSpec(self):
+ # tempdir = tempfile.mkdtemp()
+ # spec_fname = os.path.join(tempdir, "savespec.dat")
+ #
+ # x = [1, 2, 3]
+ # xlab = "Abscissa"
+ # y = [[4, 5, 6], [7, 8, 9]]
+ # ylabs = ["Ordinate1", "Ordinate2"]
+ # utils.save1D(spec_fname, x, y, xlabel=xlab,
+ # ylabels=ylabs, filetype="spec",
+ # fmt=["%d", "%.2f"])
+ #
+ # rep = h5ls(spec_fname)
+ # lines = rep.split("\n")
+ # self.assertIn("+1.1", lines)
+ # self.assertIn("\t+instrument", lines)
+ #
+ # self.assertMatchAnyStringInList(
+ # r'\t\t\t<SPEC dataset "file_header": shape \(\), type "|S60">',
+ # lines)
+ # self.assertMatchAnyStringInList(
+ # r'\t\t<SPEC dataset "Ordinate1": shape \(3L?,\), type "<f4">',
+ # lines)
+ #
+ # os.unlink(spec_fname)
+ # shutil.rmtree(tempdir)
+
+
+class TestOpen(unittest.TestCase):
+ """Test `silx.io.utils.open` function."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tmp_directory = tempfile.mkdtemp()
+ cls.createResources(cls.tmp_directory)
+
+ @classmethod
+ def createResources(cls, directory):
+
+ cls.h5_filename = os.path.join(directory, "test.h5")
+ h5 = h5py.File(cls.h5_filename, mode="w")
+ h5["group/group/dataset"] = 50
+ h5.close()
+
+ cls.spec_filename = os.path.join(directory, "test.dat")
+ utils.savespec(cls.spec_filename, [1], [1.1], xlabel="x", ylabel="y",
+ fmt=["%d", "%.2f"], close_file=True, scan_number=1)
+
+ cls.edf_filename = os.path.join(directory, "test.edf")
+ header = fabio.fabioimage.OrderedDict()
+ header["integer"] = "10"
+ data = numpy.array([[10, 50], [50, 10]])
+ fabiofile = fabio.edfimage.EdfImage(data, header)
+ fabiofile.write(cls.edf_filename)
+
+ cls.txt_filename = os.path.join(directory, "test.txt")
+ f = io.open(cls.txt_filename, "w+t")
+ f.write(u"Kikoo")
+ f.close()
+
+ cls.missing_filename = os.path.join(directory, "test.missing")
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tmp_directory)
+
+ def testH5(self):
+ f = utils.open(self.h5_filename)
+ self.assertIsNotNone(f)
+ self.assertIsInstance(f, h5py.File)
+ f.close()
+
+ def testH5With(self):
+ with utils.open(self.h5_filename) as f:
+ self.assertIsNotNone(f)
+ self.assertIsInstance(f, h5py.File)
+
+ def testH5_withPath(self):
+ f = utils.open(self.h5_filename + "::/group/group/dataset")
+ self.assertIsNotNone(f)
+ self.assertEqual(f.h5py_class, h5py.Dataset)
+ self.assertEqual(f[()], 50)
+ f.close()
+
+ def testH5With_withPath(self):
+ with utils.open(self.h5_filename + "::/group/group") as f:
+ self.assertIsNotNone(f)
+ self.assertEqual(f.h5py_class, h5py.Group)
+ self.assertIn("dataset", f)
+
+ def testSpec(self):
+ f = utils.open(self.spec_filename)
+ self.assertIsNotNone(f)
+ self.assertEqual(f.h5py_class, h5py.File)
+ f.close()
+
+ def testSpecWith(self):
+ with utils.open(self.spec_filename) as f:
+ self.assertIsNotNone(f)
+ self.assertEqual(f.h5py_class, h5py.File)
+
+ def testEdf(self):
+ f = utils.open(self.edf_filename)
+ self.assertIsNotNone(f)
+ self.assertEqual(f.h5py_class, h5py.File)
+ f.close()
+
+ def testEdfWith(self):
+ with utils.open(self.edf_filename) as f:
+ self.assertIsNotNone(f)
+ self.assertEqual(f.h5py_class, h5py.File)
+
+ def testUnsupported(self):
+ self.assertRaises(IOError, utils.open, self.txt_filename)
+
+ def testNotExists(self):
+ # load it
+ self.assertRaises(IOError, utils.open, self.missing_filename)
+
+ def test_silx_scheme(self):
+ url = silx.io.url.DataUrl(scheme="silx", file_path=self.h5_filename, data_path="/")
+ with utils.open(url.path()) as f:
+ self.assertIsNotNone(f)
+ self.assertTrue(silx.io.utils.is_file(f))
+
+ def test_fabio_scheme(self):
+ url = silx.io.url.DataUrl(scheme="fabio", file_path=self.edf_filename)
+ self.assertRaises(IOError, utils.open, url.path())
+
+ def test_bad_url(self):
+ url = silx.io.url.DataUrl(scheme="sil", file_path=self.h5_filename)
+ self.assertRaises(IOError, utils.open, url.path())
+
+ def test_sliced_url(self):
+ url = silx.io.url.DataUrl(file_path=self.h5_filename, data_slice=(5,))
+ self.assertRaises(IOError, utils.open, url.path())
+
+
+class TestNodes(unittest.TestCase):
+ """Test `silx.io.utils.is_` functions."""
+
+ def test_real_h5py_objects(self):
+ name = tempfile.mktemp(suffix=".h5")
+ try:
+ with h5py.File(name, "w") as h5file:
+ h5group = h5file.create_group("arrays")
+ h5dataset = h5group.create_dataset("scalar", data=10)
+
+ self.assertTrue(utils.is_file(h5file))
+ self.assertTrue(utils.is_group(h5file))
+ self.assertFalse(utils.is_dataset(h5file))
+
+ self.assertFalse(utils.is_file(h5group))
+ self.assertTrue(utils.is_group(h5group))
+ self.assertFalse(utils.is_dataset(h5group))
+
+ self.assertFalse(utils.is_file(h5dataset))
+ self.assertFalse(utils.is_group(h5dataset))
+ self.assertTrue(utils.is_dataset(h5dataset))
+ finally:
+ os.unlink(name)
+
+ def test_h5py_like_file(self):
+
+ class Foo(object):
+
+ def __init__(self):
+ self.h5_class = utils.H5Type.FILE
+
+ obj = Foo()
+ self.assertTrue(utils.is_file(obj))
+ self.assertTrue(utils.is_group(obj))
+ self.assertFalse(utils.is_dataset(obj))
+
+ def test_h5py_like_group(self):
+
+ class Foo(object):
+
+ def __init__(self):
+ self.h5_class = utils.H5Type.GROUP
+
+ obj = Foo()
+ self.assertFalse(utils.is_file(obj))
+ self.assertTrue(utils.is_group(obj))
+ self.assertFalse(utils.is_dataset(obj))
+
+ def test_h5py_like_dataset(self):
+
+ class Foo(object):
+
+ def __init__(self):
+ self.h5_class = utils.H5Type.DATASET
+
+ obj = Foo()
+ self.assertFalse(utils.is_file(obj))
+ self.assertFalse(utils.is_group(obj))
+ self.assertTrue(utils.is_dataset(obj))
+
+ def test_bad(self):
+
+ class Foo(object):
+
+ def __init__(self):
+ pass
+
+ obj = Foo()
+ self.assertFalse(utils.is_file(obj))
+ self.assertFalse(utils.is_group(obj))
+ self.assertFalse(utils.is_dataset(obj))
+
+ def test_bad_api(self):
+
+ class Foo(object):
+
+ def __init__(self):
+ self.h5_class = int
+
+ obj = Foo()
+ self.assertFalse(utils.is_file(obj))
+ self.assertFalse(utils.is_group(obj))
+ self.assertFalse(utils.is_dataset(obj))
+
+
+class TestGetData(unittest.TestCase):
+ """Test `silx.io.utils.get_data` function."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tmp_directory = tempfile.mkdtemp()
+ cls.createResources(cls.tmp_directory)
+
+ @classmethod
+ def createResources(cls, directory):
+
+ cls.h5_filename = os.path.join(directory, "test.h5")
+ h5 = h5py.File(cls.h5_filename, mode="w")
+ h5["group/group/scalar"] = 50
+ h5["group/group/array"] = [1, 2, 3, 4, 5]
+ h5["group/group/array2d"] = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
+ h5.close()
+
+ cls.spec_filename = os.path.join(directory, "test.dat")
+ utils.savespec(cls.spec_filename, [1], [1.1], xlabel="x", ylabel="y",
+ fmt=["%d", "%.2f"], close_file=True, scan_number=1)
+
+ cls.edf_filename = os.path.join(directory, "test.edf")
+ cls.edf_multiframe_filename = os.path.join(directory, "test_multi.edf")
+ header = fabio.fabioimage.OrderedDict()
+ header["integer"] = "10"
+ data = numpy.array([[10, 50], [50, 10]])
+ fabiofile = fabio.edfimage.EdfImage(data, header)
+ fabiofile.write(cls.edf_filename)
+ fabiofile.append_frame(data=data, header=header)
+ fabiofile.write(cls.edf_multiframe_filename)
+
+ cls.txt_filename = os.path.join(directory, "test.txt")
+ f = io.open(cls.txt_filename, "w+t")
+ f.write(u"Kikoo")
+ f.close()
+
+ cls.missing_filename = os.path.join(directory, "test.missing")
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tmp_directory)
+
+ def test_hdf5_scalar(self):
+ url = "silx:%s?/group/group/scalar" % self.h5_filename
+ data = utils.get_data(url=url)
+ self.assertEqual(data, 50)
+
+ def test_hdf5_array(self):
+ url = "silx:%s?/group/group/array" % self.h5_filename
+ data = utils.get_data(url=url)
+ self.assertEqual(data.shape, (5,))
+ self.assertEqual(data[0], 1)
+
+ def test_hdf5_array_slice(self):
+ url = "silx:%s?path=/group/group/array2d&slice=1" % self.h5_filename
+ data = utils.get_data(url=url)
+ self.assertEqual(data.shape, (5,))
+ self.assertEqual(data[0], 6)
+
+ def test_hdf5_array_slice_out_of_range(self):
+ url = "silx:%s?path=/group/group/array2d&slice=5" % self.h5_filename
+ # ValueError: h5py 2.x
+ # IndexError: h5py 3.x
+ self.assertRaises((ValueError, IndexError), utils.get_data, url)
+
+ def test_edf_using_silx(self):
+ url = "silx:%s?/scan_0/instrument/detector_0/data" % self.edf_filename
+ data = utils.get_data(url=url)
+ self.assertEqual(data.shape, (2, 2))
+ self.assertEqual(data[0, 0], 10)
+
+ def test_fabio_frame(self):
+ url = "fabio:%s?slice=1" % self.edf_multiframe_filename
+ data = utils.get_data(url=url)
+ self.assertEqual(data.shape, (2, 2))
+ self.assertEqual(data[0, 0], 10)
+
+ def test_fabio_singleframe(self):
+ url = "fabio:%s?slice=0" % self.edf_filename
+ data = utils.get_data(url=url)
+ self.assertEqual(data.shape, (2, 2))
+ self.assertEqual(data[0, 0], 10)
+
+ def test_fabio_too_much_frames(self):
+ url = "fabio:%s?slice=..." % self.edf_multiframe_filename
+ self.assertRaises(ValueError, utils.get_data, url)
+
+ def test_fabio_no_frame(self):
+ url = "fabio:%s" % self.edf_filename
+ data = utils.get_data(url=url)
+ self.assertEqual(data.shape, (2, 2))
+ self.assertEqual(data[0, 0], 10)
+
+ def test_unsupported_scheme(self):
+ url = "foo:/foo/bar"
+ self.assertRaises(ValueError, utils.get_data, url)
+
+ def test_no_scheme(self):
+ url = "%s?path=/group/group/array2d&slice=5" % self.h5_filename
+ self.assertRaises((ValueError, IOError), utils.get_data, url)
+
+ def test_file_not_exists(self):
+ url = "silx:/foo/bar"
+ self.assertRaises(IOError, utils.get_data, url)
+
+
+def _h5_py_version_older_than(version):
+ v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]]
+ r_majeur, r_mineur, r_micro = [int(i) for i in version.split('.')]
+ return calc_hexversion(v_majeur, v_mineur, v_micro) >= calc_hexversion(r_majeur, r_mineur, r_micro)
+
+
+@unittest.skipUnless(_h5_py_version_older_than('2.9.0'), 'h5py version < 2.9.0')
+class TestRawFileToH5(unittest.TestCase):
+ """Test conversion of .vol file to .h5 external dataset"""
+
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self._vol_file = os.path.join(self.tempdir, 'test_vol.vol')
+ self._file_info = os.path.join(self.tempdir, 'test_vol.info.vol')
+ self._dataset_shape = 100, 20, 5
+ data = numpy.random.random(self._dataset_shape[0] *
+ self._dataset_shape[1] *
+ self._dataset_shape[2]).astype(dtype=numpy.float32).reshape(self._dataset_shape)
+ numpy.save(file=self._vol_file, arr=data)
+ # those are storing into .noz file
+ assert os.path.exists(self._vol_file + '.npy')
+ os.rename(self._vol_file + '.npy', self._vol_file)
+ self.h5_file = os.path.join(self.tempdir, 'test_h5.h5')
+ self.external_dataset_path = '/root/my_external_dataset'
+ self._data_url = silx.io.url.DataUrl(file_path=self.h5_file,
+ data_path=self.external_dataset_path)
+ with open(self._file_info, 'w') as _fi:
+ _fi.write('NUM_X = %s\n' % self._dataset_shape[2])
+ _fi.write('NUM_Y = %s\n' % self._dataset_shape[1])
+ _fi.write('NUM_Z = %s\n' % self._dataset_shape[0])
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def check_dataset(self, h5_file, data_path, shape):
+ """Make sure the external dataset is valid"""
+ with h5py.File(h5_file, 'r') as _file:
+ return data_path in _file and _file[data_path].shape == shape
+
+ def test_h5_file_not_existing(self):
+ """Test that can create a file with external dataset from scratch"""
+ utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ dtype=numpy.float32)
+ self.assertTrue(self.check_dataset(h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape))
+ os.remove(self.h5_file)
+ utils.vol_to_h5_external_dataset(vol_file=self._vol_file,
+ output_url=self._data_url,
+ info_file=self._file_info)
+ self.assertTrue(self.check_dataset(h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape))
+
+ def test_h5_file_existing(self):
+ """Test that can add the external dataset from an existing file"""
+ with h5py.File(self.h5_file, 'w') as _file:
+ _file['/root/dataset1'] = numpy.zeros((100, 100))
+ _file['/root/group/dataset2'] = numpy.ones((100, 100))
+ utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ dtype=numpy.float32)
+ self.assertTrue(self.check_dataset(h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape))
+
+ def test_vol_file_not_existing(self):
+ """Make sure error is raised if .vol file does not exists"""
+ os.remove(self._vol_file)
+ utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ dtype=numpy.float32)
+
+ self.assertTrue(self.check_dataset(h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape))
+
+ def test_conflicts(self):
+ """Test several conflict cases"""
+ # test if path already exists
+ utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ dtype=numpy.float32)
+ with self.assertRaises(ValueError):
+ utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ overwrite=False,
+ dtype=numpy.float32)
+
+ utils.rawfile_to_h5_external_dataset(bin_file=self._vol_file,
+ output_url=self._data_url,
+ shape=(100, 20, 5),
+ overwrite=True,
+ dtype=numpy.float32)
+
+ self.assertTrue(self.check_dataset(h5_file=self.h5_file,
+ data_path=self.external_dataset_path,
+ shape=self._dataset_shape))
+
+
+class TestH5Strings(unittest.TestCase):
+ """Test HDF5 str and bytes writing and reading"""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.vlenstr = h5py.special_dtype(vlen=str)
+ cls.vlenbytes = h5py.special_dtype(vlen=bytes)
+ try:
+ cls.unicode = unicode
+ except NameError:
+ cls.unicode = str
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir)
+
+ def setUp(self):
+ self.file = h5py.File(os.path.join(self.tempdir, 'file.h5'), mode="w")
+
+ def tearDown(self):
+ self.file.close()
+
+ @classmethod
+ def _make_array(cls, value, n):
+ if isinstance(value, bytes):
+ dtype = cls.vlenbytes
+ elif isinstance(value, cls.unicode):
+ dtype = cls.vlenstr
+ else:
+ return numpy.array([value] * n)
+ return numpy.array([value] * n, dtype=dtype)
+
+ @classmethod
+ def _get_charset(cls, value):
+ if isinstance(value, bytes):
+ return h5py.h5t.CSET_ASCII
+ elif isinstance(value, cls.unicode):
+ return h5py.h5t.CSET_UTF8
+ else:
+ return None
+
+ def _check_dataset(self, value, result=None):
+ # Write+read scalar
+ if result:
+ decode_ascii = True
+ else:
+ decode_ascii = False
+ result = value
+ charset = self._get_charset(value)
+ self.file["data"] = value
+ data = utils.h5py_read_dataset(self.file["data"], decode_ascii=decode_ascii)
+ assert type(data) == type(result), data
+ assert data == result, data
+ if charset:
+ assert self.file["data"].id.get_type().get_cset() == charset
+
+ # Write+read variable length
+ self.file["vlen_data"] = self._make_array(value, 2)
+ data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii, index=0)
+ assert type(data) == type(result), data
+ assert data == result, data
+ data = utils.h5py_read_dataset(self.file["vlen_data"], decode_ascii=decode_ascii)
+ numpy.testing.assert_array_equal(data, [result] * 2)
+ if charset:
+ assert self.file["vlen_data"].id.get_type().get_cset() == charset
+
+ def _check_attribute(self, value, result=None):
+ if result:
+ decode_ascii = True
+ else:
+ decode_ascii = False
+ result = value
+ self.file.attrs["data"] = value
+ data = utils.h5py_read_attribute(self.file.attrs, "data", decode_ascii=decode_ascii)
+ assert type(data) == type(result), data
+ assert data == result, data
+
+ self.file.attrs["vlen_data"] = self._make_array(value, 2)
+ data = utils.h5py_read_attribute(self.file.attrs, "vlen_data", decode_ascii=decode_ascii)
+ assert type(data[0]) == type(result), data[0]
+ assert data[0] == result, data[0]
+ numpy.testing.assert_array_equal(data, [result] * 2)
+
+ data = utils.h5py_read_attributes(self.file.attrs, decode_ascii=decode_ascii)["vlen_data"]
+ assert type(data[0]) == type(result), data[0]
+ assert data[0] == result, data[0]
+ numpy.testing.assert_array_equal(data, [result] * 2)
+
+ def test_dataset_ascii_bytes(self):
+ self._check_dataset(b"abc")
+
+ def test_attribute_ascii_bytes(self):
+ self._check_attribute(b"abc")
+
+ def test_dataset_ascii_bytes_decode(self):
+ self._check_dataset(b"abc", result="abc")
+
+ def test_attribute_ascii_bytes_decode(self):
+ self._check_attribute(b"abc", result="abc")
+
+ def test_dataset_ascii_str(self):
+ self._check_dataset("abc")
+
+ def test_attribute_ascii_str(self):
+ self._check_attribute("abc")
+
+ def test_dataset_utf8_str(self):
+ self._check_dataset("\u0101bc")
+
+ def test_attribute_utf8_str(self):
+ self._check_attribute("\u0101bc")
+
+ def test_dataset_utf8_bytes(self):
+ # 0xC481 is the byte representation of U+0101
+ self._check_dataset(b"\xc4\x81bc")
+
+ def test_attribute_utf8_bytes(self):
+ # 0xC481 is the byte representation of U+0101
+ self._check_attribute(b"\xc4\x81bc")
+
+ def test_dataset_utf8_bytes_decode(self):
+ # 0xC481 is the byte representation of U+0101
+ self._check_dataset(b"\xc4\x81bc", result="\u0101bc")
+
+ def test_attribute_utf8_bytes_decode(self):
+ # 0xC481 is the byte representation of U+0101
+ self._check_attribute(b"\xc4\x81bc", result="\u0101bc")
+
+ def test_dataset_latin1_bytes(self):
+ # extended ascii character 0xE4
+ self._check_dataset(b"\xe423")
+
+ def test_attribute_latin1_bytes(self):
+ # extended ascii character 0xE4
+ self._check_attribute(b"\xe423")
+
+ def test_dataset_latin1_bytes_decode(self):
+ # U+DCE4: surrogate for extended ascii character 0xE4
+ self._check_dataset(b"\xe423", result="\udce423")
+
+ def test_attribute_latin1_bytes_decode(self):
+ # U+DCE4: surrogate for extended ascii character 0xE4
+ self._check_attribute(b"\xe423", result="\udce423")
+
+ def test_dataset_no_string(self):
+ self._check_dataset(numpy.int64(10))
+
+ def test_attribute_no_string(self):
+ self._check_attribute(numpy.int64(10))
+
+
+def test_visitall_hdf5(tmp_path):
+ """visit HDF5 file content not following links"""
+ external_filepath = tmp_path / "external.h5"
+ with h5py.File(external_filepath, mode="w") as h5file:
+ h5file["target/dataset"] = 50
+
+ filepath = tmp_path / "base.h5"
+ with h5py.File(filepath, mode="w") as h5file:
+ h5file["group/dataset"] = 50
+ h5file["link/soft_link"] = h5py.SoftLink("/group/dataset")
+ h5file["link/external_link"] = h5py.ExternalLink("external.h5", "/target/dataset")
+
+ with h5py.File(filepath, mode="r") as h5file:
+ visited_items = {}
+ for path, item in utils.visitall(h5file):
+ if isinstance(item, h5py.Dataset):
+ content = item[()]
+ elif isinstance(item, h5py.Group):
+ content = None
+ elif isinstance(item, h5py.SoftLink):
+ content = item.path
+ elif isinstance(item, h5py.ExternalLink):
+ content = item.filename, item.path
+ else:
+ raise AssertionError("Item should not be present: %s" % path)
+ visited_items[path] = (item.__class__, content)
+
+ assert visited_items == {
+ "/group": (h5py.Group, None),
+ "/group/dataset": (h5py.Dataset, 50),
+ "/link": (h5py.Group, None),
+ "/link/soft_link": (h5py.SoftLink, "/group/dataset"),
+ "/link/external_link": (h5py.ExternalLink, ("external.h5", "/target/dataset")),
+ }
+
+def test_visitall_commonh5():
+ """Visit commonh5 File object"""
+ fobj = commonh5.File("filename.file", mode="w")
+ group = fobj.create_group("group")
+ dataset = group.create_dataset("dataset", data=numpy.array(50))
+ group["soft_link"] = dataset # Create softlink
+
+ visited_items = dict(utils.visitall(fobj))
+ assert len(visited_items) == 3
+ assert visited_items["/group"] is group
+ assert visited_items["/group/dataset"] is dataset
+ soft_link = visited_items["/group/soft_link"]
+ assert isinstance(soft_link, commonh5.SoftLink)
+ assert soft_link.path == "/group/dataset"
diff --git a/src/silx/io/test/test_write_to_h5.py b/src/silx/io/test/test_write_to_h5.py
new file mode 100644
index 0000000..06149c9
--- /dev/null
+++ b/src/silx/io/test/test_write_to_h5.py
@@ -0,0 +1,118 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Test silx.io.convert.write_to_h5"""
+
+
+import h5py
+import numpy
+from silx.io import spech5
+
+from silx.io.convert import write_to_h5
+from silx.io.dictdump import h5todict
+from silx.io import commonh5
+from silx.io.spech5 import SpecH5
+
+
+def test_with_commonh5(tmp_path):
+ """Test write_to_h5 with commonh5 input"""
+ fobj = commonh5.File("filename.txt", mode="w")
+ group = fobj.create_group("group")
+ dataset = group.create_dataset("dataset", data=numpy.array(50))
+ group["soft_link"] = dataset # Create softlink
+
+ output_filepath = tmp_path / "output.h5"
+ write_to_h5(fobj, str(output_filepath))
+
+ assert h5todict(str(output_filepath)) == {
+ 'group': {'dataset': numpy.array(50), 'soft_link': numpy.array(50)},
+ }
+ with h5py.File(output_filepath, mode="r") as h5file:
+ soft_link = h5file.get("/group/soft_link", getlink=True)
+ assert isinstance(soft_link, h5py.SoftLink)
+ assert soft_link.path == "/group/dataset"
+
+
+def test_with_hdf5(tmp_path):
+ """Test write_to_h5 with HDF5 file input"""
+ filepath = tmp_path / "base.h5"
+ with h5py.File(filepath, mode="w") as h5file:
+ h5file["group/dataset"] = 50
+ h5file["group/soft_link"] = h5py.SoftLink("/group/dataset")
+ h5file["group/external_link"] = h5py.ExternalLink("base.h5", "/group/dataset")
+
+ output_filepath = tmp_path / "output.h5"
+ write_to_h5(str(filepath), str(output_filepath))
+ assert h5todict(str(output_filepath)) == {
+ 'group': {'dataset': 50, 'soft_link': 50},
+ }
+ with h5py.File(output_filepath, mode="r") as h5file:
+ soft_link = h5file.get("group/soft_link", getlink=True)
+ assert isinstance(soft_link, h5py.SoftLink)
+ assert soft_link.path == "/group/dataset"
+
+
+def test_with_spech5(tmp_path):
+ """Test write_to_h5 with SpecH5 input"""
+ filepath = tmp_path / "file.spec"
+ filepath.write_bytes(
+ bytes(
+"""#F /tmp/sf.dat
+
+#S 1 cmd
+#L a b
+1 2
+""",
+ encoding='ascii')
+ )
+
+ output_filepath = tmp_path / "output.h5"
+ with spech5.SpecH5(str(filepath)) as spech5file:
+ write_to_h5(spech5file, str(output_filepath))
+ print(h5todict(str(output_filepath)))
+
+ def assert_equal(item1, item2):
+ if isinstance(item1, dict):
+ assert tuple(item1.keys()) == tuple(item2.keys())
+ for key in item1.keys():
+ assert_equal(item1[key], item2[key])
+ else:
+ numpy.array_equal(item1, item2)
+
+ assert_equal(h5todict(str(output_filepath)), {
+ '1.1': {
+ 'instrument': {
+ 'positioners': {},
+ 'specfile': {
+ 'file_header': ['#F /tmp/sf.dat'],
+ 'scan_header': ['#S 1 cmd', '#L a b'],
+ },
+ },
+ 'measurement': {
+ 'a': [1.],
+ 'b': [2.],
+ },
+ 'start_time': '',
+ 'title': 'cmd',
+ },
+ })
diff --git a/src/silx/io/url.py b/src/silx/io/url.py
new file mode 100644
index 0000000..a3c49e6
--- /dev/null
+++ b/src/silx/io/url.py
@@ -0,0 +1,388 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""URL module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "29/01/2018"
+
+import logging
+from collections.abc import Iterable
+import urllib.parse
+
+
+_logger = logging.getLogger(__name__)
+
+
+class DataUrl(object):
+ """Non-mutable object to parse a string representing a resource data
+ locator.
+
+ It supports:
+
+ - path to file and path inside file to the data
+ - data slicing
+ - fabio or silx access to the data
+ - absolute and relative file access
+
+ >>> # fabio access using absolute path
+ >>> DataUrl("fabio:///data/image.edf?slice=2")
+ >>> DataUrl("fabio:///C:/data/image.edf?slice=2")
+
+ >>> # silx access using absolute path
+ >>> DataUrl("silx:///data/image.h5?path=/data/dataset&slice=1,5")
+ >>> DataUrl("silx:///data/image.edf?path=/scan_0/detector/data")
+ >>> DataUrl("silx:///C:/data/image.edf?path=/scan_0/detector/data")
+
+ >>> # `path=` can be omited if there is no other query keys
+ >>> DataUrl("silx:///data/image.h5?/data/dataset")
+ >>> # is the same as
+ >>> DataUrl("silx:///data/image.h5?path=/data/dataset")
+
+ >>> # `::` can be used instead of `?` which can be useful with shell in
+ >>> # command lines
+ >>> DataUrl("silx:///data/image.h5::/data/dataset")
+ >>> # is the same as
+ >>> DataUrl("silx:///data/image.h5?/data/dataset")
+
+ >>> # Relative path access
+ >>> DataUrl("silx:./image.h5")
+ >>> DataUrl("fabio:./image.edf")
+ >>> DataUrl("silx:image.h5")
+ >>> DataUrl("fabio:image.edf")
+
+ >>> # Is also support parsing of file access for convenience
+ >>> DataUrl("./foo/bar/image.edf")
+ >>> DataUrl("C:/data/")
+
+ :param str path: Path representing a link to a data. If specified, other
+ arguments are not used.
+ :param str file_path: Link to the file containing the the data.
+ None if there is no data selection.
+ :param str data_path: Data selection applyed to the data file selected.
+ None if there is no data selection.
+ :param Tuple[int,slice,Ellipse] data_slice: Slicing applyed of the selected
+ data. None if no slicing applyed.
+ :param Union[str,None] scheme: Scheme of the URL. "silx", "fabio"
+ is supported. Other strings can be provided, but :meth:`is_valid` will
+ be false.
+ """
+ def __init__(self, path=None, file_path=None, data_path=None, data_slice=None, scheme=None):
+ self.__is_valid = False
+ if path is not None:
+ assert(file_path is None)
+ assert(data_path is None)
+ assert(data_slice is None)
+ assert(scheme is None)
+ self.__parse_from_path(path)
+ else:
+ self.__file_path = file_path
+ self.__data_path = data_path
+ self.__data_slice = data_slice
+ self.__scheme = scheme
+ self.__path = None
+ self.__check_validity()
+
+ def __eq__(self, other):
+ if not isinstance(other, DataUrl):
+ return False
+ if self.is_valid() != other.is_valid():
+ return False
+ if self.is_valid():
+ if self.__scheme != other.scheme():
+ return False
+ if self.__file_path != other.file_path():
+ return False
+ if self.__data_path != other.data_path():
+ return False
+ if self.__data_slice != other.data_slice():
+ return False
+ return True
+ else:
+ return self.__path == other.path()
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ def __repr__(self):
+ return str(self)
+
+ def __str__(self):
+ if self.is_valid() or self.__path is None:
+ def quote_string(string):
+ if isinstance(string, str):
+ return "'%s'" % string
+ else:
+ return string
+
+ template = "DataUrl(valid=%s, scheme=%s, file_path=%s, data_path=%s, data_slice=%s)"
+ return template % (self.__is_valid,
+ quote_string(self.__scheme),
+ quote_string(self.__file_path),
+ quote_string(self.__data_path),
+ self.__data_slice)
+ else:
+ template = "DataUrl(valid=%s, string=%s)"
+ return template % (self.__is_valid, self.__path)
+
+ def __check_validity(self):
+ """Check the validity of the attributes."""
+ if self.__file_path in [None, ""]:
+ self.__is_valid = False
+ return
+
+ if self.__scheme is None:
+ self.__is_valid = True
+ elif self.__scheme == "fabio":
+ self.__is_valid = self.__data_path is None
+ elif self.__scheme == "silx":
+ # If there is a slice you must have a data path
+ # But you can have a data path without slice
+ slice_implies_data = (self.__data_path is None and self.__data_slice is None) or self.__data_path is not None
+ self.__is_valid = slice_implies_data
+ else:
+ self.__is_valid = False
+
+ @staticmethod
+ def _parse_slice(slice_string):
+ """Parse a slicing sequence and return an associated tuple.
+
+ It supports a sequence of `...`, `:`, and integers separated by a coma.
+
+ :rtype: tuple
+ """
+ def str_to_slice(string):
+ if string == "...":
+ return Ellipsis
+ elif ':' in string:
+ if string == ":":
+ return slice(None)
+ else:
+ def get_value(my_str):
+ if my_str in ('', None):
+ return None
+ else:
+ return int(my_str)
+ sss = string.split(':')
+ start = get_value(sss[0])
+ stop = get_value(sss[1] if len(sss) > 1 else None)
+ step = get_value(sss[2] if len(sss) > 2 else None)
+ return slice(start, stop, step)
+ else:
+ return int(string)
+
+ if slice_string == "":
+ raise ValueError("An empty slice is not valid")
+
+ tokens = slice_string.split(",")
+ data_slice = []
+ for t in tokens:
+ try:
+ data_slice.append(str_to_slice(t))
+ except ValueError:
+ raise ValueError("'%s' is not a valid slicing" % t)
+ return tuple(data_slice)
+
+ def __parse_from_path(self, path):
+ """Parse the path and initialize attributes.
+
+ :param str path: Path representing the URL.
+ """
+ self.__path = path
+ # only replace if ? not here already. Otherwise can mess sith
+ # data_slice if == ::2 for example
+ if '?' not in path:
+ path = path.replace("::", "?", 1)
+ url = urllib.parse.urlparse(path)
+
+ is_valid = True
+
+ if len(url.scheme) <= 2:
+ # Windows driver
+ scheme = None
+ pos = self.__path.index(url.path)
+ file_path = self.__path[0:pos] + url.path
+ else:
+ scheme = url.scheme if url.scheme != "" else None
+ file_path = url.path
+
+ # Check absolute windows path
+ if len(file_path) > 2 and file_path[0] == '/':
+ if file_path[1] == ":" or file_path[2] == ":":
+ file_path = file_path[1:]
+
+ self.__scheme = scheme
+ self.__file_path = file_path
+
+ query = urllib.parse.parse_qsl(url.query, keep_blank_values=True)
+ if len(query) == 1 and query[0][1] == "":
+ # there is no query keys
+ data_path = query[0][0]
+ data_slice = None
+ else:
+ merged_query = {}
+ for name, value in query:
+ if name in query:
+ merged_query[name].append(value)
+ else:
+ merged_query[name] = [value]
+
+ def pop_single_value(merged_query, name):
+ if name in merged_query:
+ values = merged_query.pop(name)
+ if len(values) > 1:
+ _logger.warning("More than one query key named '%s'. The last one is used.", name)
+ value = values[-1]
+ else:
+ value = None
+ return value
+
+ data_path = pop_single_value(merged_query, "path")
+ data_slice = pop_single_value(merged_query, "slice")
+ if data_slice is not None:
+ try:
+ data_slice = self._parse_slice(data_slice)
+ except ValueError:
+ is_valid = False
+ data_slice = None
+
+ for key in merged_query.keys():
+ _logger.warning("Query key %s unsupported. Key skipped.", key)
+
+ self.__data_path = data_path
+ self.__data_slice = data_slice
+
+ if is_valid:
+ self.__check_validity()
+ else:
+ self.__is_valid = False
+
+ def is_valid(self):
+ """Returns true if the URL is valid. Else attributes can be None.
+
+ :rtype: bool
+ """
+ return self.__is_valid
+
+ def path(self):
+ """Returns the string representing the URL.
+
+ :rtype: str
+ """
+ if self.__path is not None:
+ return self.__path
+
+ def slice_to_string(data_slice):
+ if data_slice == Ellipsis:
+ return "..."
+ elif data_slice == slice(None):
+ return ":"
+ elif isinstance(data_slice, int):
+ return str(data_slice)
+ else:
+ raise TypeError("Unexpected slicing type. Found %s" % type(data_slice))
+
+ if self.__data_path is not None and self.__data_slice is None:
+ query = self.__data_path
+ else:
+ queries = []
+ if self.__data_path is not None:
+ queries.append("path=" + self.__data_path)
+ if self.__data_slice is not None:
+ if isinstance(self.__data_slice, Iterable):
+ data_slice = ",".join([slice_to_string(s) for s in self.__data_slice])
+ else:
+ data_slice = slice_to_string(self.__data_slice)
+ queries.append("slice=" + data_slice)
+ query = "&".join(queries)
+
+ path = ""
+ if self.__file_path is not None:
+ path += self.__file_path
+
+ if query != "":
+ path = path + "?" + query
+
+ if self.__scheme is not None:
+ if self.is_absolute():
+ if path.startswith("/"):
+ path = self.__scheme + "://" + path
+ else:
+ path = self.__scheme + ":///" + path
+ else:
+ path = self.__scheme + ":" + path
+
+ return path
+
+ def is_absolute(self):
+ """Returns true if the file path is an absolute path.
+
+ :rtype: bool
+ """
+ file_path = self.file_path()
+ if file_path is None:
+ return False
+ if len(file_path) > 0:
+ if file_path[0] == "/":
+ return True
+ if len(file_path) > 2:
+ # Windows
+ if file_path[1] == ":" or file_path[2] == ":":
+ return True
+ elif len(file_path) > 1:
+ # Windows
+ if file_path[1] == ":":
+ return True
+ return False
+
+ def file_path(self):
+ """Returns the path to the file containing the data.
+
+ :rtype: str
+ """
+ return self.__file_path
+
+ def data_path(self):
+ """Returns the path inside the file to the data.
+
+ :rtype: str
+ """
+ return self.__data_path
+
+ def data_slice(self):
+ """Returns the slicing applied to the data.
+
+ It is a tuple containing numbers, slice or ellipses.
+
+ :rtype: Tuple[int, slice, Ellipse]
+ """
+ return self.__data_slice
+
+ def scheme(self):
+ """Returns the scheme. It can be None if no scheme is specified.
+
+ :rtype: Union[str, None]
+ """
+ return self.__scheme
diff --git a/src/silx/io/utils.py b/src/silx/io/utils.py
new file mode 100644
index 0000000..642c6fb
--- /dev/null
+++ b/src/silx/io/utils.py
@@ -0,0 +1,1185 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+""" I/O utility functions"""
+
+__authors__ = ["P. Knobel", "V. Valls"]
+__license__ = "MIT"
+__date__ = "03/12/2020"
+
+import enum
+import os.path
+import sys
+import time
+import logging
+import collections
+import urllib.parse
+
+import numpy
+
+from silx.utils.proxy import Proxy
+import silx.io.url
+from .._version import calc_hexversion
+
+import h5py
+import h5py.h5t
+import h5py.h5a
+
+try:
+ import h5pyd
+except ImportError as e:
+ h5pyd = None
+
+logger = logging.getLogger(__name__)
+
+NEXUS_HDF5_EXT = [".h5", ".nx5", ".nxs", ".hdf", ".hdf5", ".cxi"]
+"""List of possible extensions for HDF5 file formats."""
+
+
+class H5Type(enum.Enum):
+ """Identify a set of HDF5 concepts"""
+ DATASET = 1
+ GROUP = 2
+ FILE = 3
+ SOFT_LINK = 4
+ EXTERNAL_LINK = 5
+ HARD_LINK = 6
+
+
+_CLASSES_TYPE = None
+"""Store mapping between classes and types"""
+
+string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
+
+builtin_open = open
+
+
+def supported_extensions(flat_formats=True):
+ """Returns the list file extensions supported by `silx.open`.
+
+ The result filter out formats when the expected module is not available.
+
+ :param bool flat_formats: If true, also include flat formats like npy or
+ edf (while the expected module is available)
+ :returns: A dictionary indexed by file description and containing a set of
+ extensions (an extension is a string like "\\*.ext").
+ :rtype: Dict[str, Set[str]]
+ """
+ formats = collections.OrderedDict()
+ formats["HDF5 files"] = set(["*.h5", "*.hdf", "*.hdf5"])
+ formats["NeXus files"] = set(["*.nx", "*.nxs", "*.h5", "*.hdf", "*.hdf5"])
+ formats["NeXus layout from spec files"] = set(["*.dat", "*.spec", "*.mca"])
+ if flat_formats:
+ try:
+ from silx.io import fabioh5
+ except ImportError:
+ fabioh5 = None
+ if fabioh5 is not None:
+ formats["NeXus layout from fabio files"] = set(fabioh5.supported_extensions())
+
+ extensions = ["*.npz"]
+ if flat_formats:
+ extensions.append("*.npy")
+
+ formats["Numpy binary files"] = set(extensions)
+ formats["Coherent X-Ray Imaging files"] = set(["*.cxi"])
+ formats["FIO files"] = set(["*.fio"])
+ return formats
+
+
+def save1D(fname, x, y, xlabel=None, ylabels=None, filetype=None,
+ fmt="%.7g", csvdelim=";", newline="\n", header="",
+ footer="", comments="#", autoheader=False):
+ """Saves any number of curves to various formats: `Specfile`, `CSV`,
+ `txt` or `npy`. All curves must have the same number of points and share
+ the same ``x`` values.
+
+ :param fname: Output file path, or file handle open in write mode.
+ If ``fname`` is a path, file is opened in ``w`` mode. Existing file
+ with a same name will be overwritten.
+ :param x: 1D-Array (or list) of abscissa values.
+ :param y: 2D-array (or list of lists) of ordinates values. First index
+ is the curve index, second index is the sample index. The length
+ of the second dimension (number of samples) must be equal to
+ ``len(x)``. ``y`` can be a 1D-array in case there is only one curve
+ to be saved.
+ :param filetype: Filetype: ``"spec", "csv", "txt", "ndarray"``.
+ If ``None``, filetype is detected from file name extension
+ (``.dat, .csv, .txt, .npy``).
+ :param xlabel: Abscissa label
+ :param ylabels: List of `y` labels
+ :param fmt: Format string for data. You can specify a short format
+ string that defines a single format for both ``x`` and ``y`` values,
+ or a list of two different format strings (e.g. ``["%d", "%.7g"]``).
+ Default is ``"%.7g"``.
+ This parameter does not apply to the `npy` format.
+ :param csvdelim: String or character separating columns in `txt` and
+ `CSV` formats. The user is responsible for ensuring that this
+ delimiter is not used in data labels when writing a `CSV` file.
+ :param newline: String or character separating lines/records in `txt`
+ format (default is line break character ``\\n``).
+ :param header: String that will be written at the beginning of the file in
+ `txt` format.
+ :param footer: String that will be written at the end of the file in `txt`
+ format.
+ :param comments: String that will be prepended to the ``header`` and
+ ``footer`` strings, to mark them as comments. Default: ``#``.
+ :param autoheader: In `CSV` or `txt`, ``True`` causes the first header
+ line to be written as a standard CSV header line with column labels
+ separated by the specified CSV delimiter.
+
+ When saving to Specfile format, each curve is saved as a separate scan
+ with two data columns (``x`` and ``y``).
+
+ `CSV` and `txt` formats are similar, except that the `txt` format allows
+ user defined header and footer text blocks, whereas the `CSV` format has
+ only a single header line with columns labels separated by field
+ delimiters and no footer. The `txt` format also allows defining a record
+ separator different from a line break.
+
+ The `npy` format is written with ``numpy.save`` and can be read back with
+ ``numpy.load``. If ``xlabel`` and ``ylabels`` are undefined, data is saved
+ as a regular 2D ``numpy.ndarray`` (contatenation of ``x`` and ``y``). If
+ both ``xlabel`` and ``ylabels`` are defined, the data is saved as a
+ ``numpy.recarray`` after being transposed and having labels assigned to
+ columns.
+ """
+
+ available_formats = ["spec", "csv", "txt", "ndarray"]
+
+ if filetype is None:
+ exttypes = {".dat": "spec",
+ ".csv": "csv",
+ ".txt": "txt",
+ ".npy": "ndarray"}
+ outfname = (fname if not hasattr(fname, "name") else
+ fname.name)
+ fileext = os.path.splitext(outfname)[1]
+ if fileext in exttypes:
+ filetype = exttypes[fileext]
+ else:
+ raise IOError("File type unspecified and could not be " +
+ "inferred from file extension (not in " +
+ "txt, dat, csv, npy)")
+ else:
+ filetype = filetype.lower()
+
+ if filetype not in available_formats:
+ raise IOError("File type %s is not supported" % (filetype))
+
+ # default column headers
+ if xlabel is None:
+ xlabel = "x"
+ if ylabels is None:
+ if numpy.array(y).ndim > 1:
+ ylabels = ["y%d" % i for i in range(len(y))]
+ else:
+ ylabels = ["y"]
+ elif isinstance(ylabels, (list, tuple)):
+ # if ylabels is provided as a list, every element must
+ # be a string
+ ylabels = [ylabel if isinstance(ylabel, string_types) else "y%d" % i
+ for ylabel in ylabels]
+
+ if filetype.lower() == "spec":
+ # Check if we have regular data:
+ ref = len(x)
+ regular = True
+ for one_y in y:
+ regular &= len(one_y) == ref
+ if regular:
+ if isinstance(fmt, (list, tuple)) and len(fmt) < (len(ylabels) + 1):
+ fmt = fmt + [fmt[-1] * (1 + len(ylabels) - len(fmt))]
+ specf = savespec(fname, x, y, xlabel, ylabels, fmt=fmt,
+ scan_number=1, mode="w", write_file_header=True,
+ close_file=False)
+ else:
+ y_array = numpy.asarray(y)
+ # make sure y_array is a 2D array even for a single curve
+ if y_array.ndim == 1:
+ y_array.shape = 1, -1
+ elif y_array.ndim not in [1, 2]:
+ raise IndexError("y must be a 1D or 2D array")
+
+ # First curve
+ specf = savespec(fname, x, y_array[0], xlabel, ylabels[0], fmt=fmt,
+ scan_number=1, mode="w", write_file_header=True,
+ close_file=False)
+ # Other curves
+ for i in range(1, y_array.shape[0]):
+ specf = savespec(specf, x, y_array[i], xlabel, ylabels[i],
+ fmt=fmt, scan_number=i + 1, mode="w",
+ write_file_header=False, close_file=False)
+
+ # close file if we created it
+ if not hasattr(fname, "write"):
+ specf.close()
+
+ else:
+ autoheader_line = xlabel + csvdelim + csvdelim.join(ylabels)
+ if xlabel is not None and ylabels is not None and filetype == "csv":
+ # csv format: optional single header line with labels, no footer
+ if autoheader:
+ header = autoheader_line + newline
+ else:
+ header = ""
+ comments = ""
+ footer = ""
+ newline = "\n"
+ elif filetype == "txt" and autoheader:
+ # Comments string is added at the beginning of header string in
+ # savetxt(). We add another one after the first header line and
+ # before the rest of the header.
+ if header:
+ header = autoheader_line + newline + comments + header
+ else:
+ header = autoheader_line + newline
+
+ # Concatenate x and y in a single 2D array
+ X = numpy.vstack((x, y))
+
+ if filetype.lower() in ["csv", "txt"]:
+ X = X.transpose()
+ savetxt(fname, X, fmt=fmt, delimiter=csvdelim,
+ newline=newline, header=header, footer=footer,
+ comments=comments)
+
+ elif filetype.lower() == "ndarray":
+ if xlabel is not None and ylabels is not None:
+ labels = [xlabel] + ylabels
+
+ # .transpose is needed here because recarray labels
+ # apply to columns
+ X = numpy.core.records.fromrecords(X.transpose(),
+ names=labels)
+ numpy.save(fname, X)
+
+
+# Replace with numpy.savetxt when dropping support of numpy < 1.7.0
+def savetxt(fname, X, fmt="%.7g", delimiter=";", newline="\n",
+ header="", footer="", comments="#"):
+ """``numpy.savetxt`` backport of header and footer arguments from
+ numpy=1.7.0.
+
+ See ``numpy.savetxt`` help:
+ http://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.savetxt.html
+ """
+ if not hasattr(fname, "name"):
+ ffile = builtin_open(fname, 'wb')
+ else:
+ ffile = fname
+
+ if header:
+ if sys.version_info[0] >= 3:
+ header = header.encode("utf-8")
+ ffile.write(header)
+
+ numpy.savetxt(ffile, X, fmt, delimiter, newline)
+
+ if footer:
+ footer = (comments + footer.replace(newline, newline + comments) +
+ newline)
+ if sys.version_info[0] >= 3:
+ footer = footer.encode("utf-8")
+ ffile.write(footer)
+
+ if not hasattr(fname, "name"):
+ ffile.close()
+
+
+def savespec(specfile, x, y, xlabel="X", ylabel="Y", fmt="%.7g",
+ scan_number=1, mode="w", write_file_header=True,
+ close_file=False):
+ """Saves one curve to a SpecFile.
+
+ The curve is saved as a scan with two data columns. To save multiple
+ curves to a single SpecFile, call this function for each curve by
+ providing the same file handle each time.
+
+ :param specfile: Output SpecFile name, or file handle open in write
+ or append mode. If a file name is provided, a new file is open in
+ write mode (existing file with the same name will be lost)
+ :param x: 1D-Array (or list) of abscissa values
+ :param y: 1D-array (or list), or list of them of ordinates values.
+ All dataset must have the same length as x
+ :param xlabel: Abscissa label (default ``"X"``)
+ :param ylabel: Ordinate label, may be a list of labels when multiple curves
+ are to be saved together.
+ :param fmt: Format string for data. You can specify a short format
+ string that defines a single format for both ``x`` and ``y`` values,
+ or a list of two different format strings (e.g. ``["%d", "%.7g"]``).
+ Default is ``"%.7g"``.
+ :param scan_number: Scan number (default 1).
+ :param mode: Mode for opening file: ``w`` (default), ``a``, ``r+``,
+ ``w+``, ``a+``. This parameter is only relevant if ``specfile`` is a
+ path.
+ :param write_file_header: If ``True``, write a file header before writing
+ the scan (``#F`` and ``#D`` line).
+ :param close_file: If ``True``, close the file after saving curve.
+ :return: ``None`` if ``close_file`` is ``True``, else return the file
+ handle.
+ """
+ # Make sure we use binary mode for write
+ # (issue with windows: write() replaces \n with os.linesep in text mode)
+ if "b" not in mode:
+ first_letter = mode[0]
+ assert first_letter in "rwa"
+ mode = mode.replace(first_letter, first_letter + "b")
+
+ x_array = numpy.asarray(x)
+ y_array = numpy.asarray(y)
+ if y_array.ndim > 2:
+ raise IndexError("Y columns must have be packed as 1D")
+
+ if y_array.shape[-1] != x_array.shape[0]:
+ raise IndexError("X and Y columns must have the same length")
+
+ if y_array.ndim == 2:
+ assert isinstance(ylabel, (list, tuple))
+ assert y_array.shape[0] == len(ylabel)
+ labels = (xlabel, *ylabel)
+ else:
+ labels = (xlabel, ylabel)
+ data = numpy.vstack((x_array, y_array))
+ ncol = data.shape[0]
+ assert len(labels) == ncol
+
+ print(xlabel, ylabel, fmt, ncol, x_array, y_array)
+ if isinstance(fmt, string_types) and fmt.count("%") == 1:
+ full_fmt_string = " ".join([fmt] * ncol)
+ elif isinstance(fmt, (list, tuple)) and len(fmt) == ncol:
+ full_fmt_string = " ".join(fmt)
+ else:
+ raise ValueError("`fmt` must be a single format string or a list of " +
+ "format strings with as many format as ncolumns")
+
+ if not hasattr(specfile, "write"):
+ f = builtin_open(specfile, mode)
+ else:
+ f = specfile
+
+ current_date = "#D %s" % (time.ctime(time.time()))
+ if write_file_header:
+ lines = [ "#F %s" % f.name, current_date, ""]
+ else:
+ lines = [""]
+
+ lines += [ "#S %d %s" % (scan_number, labels[1]),
+ current_date,
+ "#N %d" % ncol,
+ "#L " + " ".join(labels)]
+
+ for i in data.T:
+ lines.append(full_fmt_string % tuple(i))
+ lines.append("")
+ output = "\n".join(lines)
+ f.write(output.encode())
+
+ if close_file:
+ f.close()
+ return None
+ return f
+
+
+def h5ls(h5group, lvl=0):
+ """Return a simple string representation of a HDF5 tree structure.
+
+ :param h5group: Any :class:`h5py.Group` or :class:`h5py.File` instance,
+ or a HDF5 file name
+ :param lvl: Number of tabulations added to the group. ``lvl`` is
+ incremented as we recursively process sub-groups.
+ :return: String representation of an HDF5 tree structure
+
+
+ Group names and dataset representation are printed preceded by a number of
+ tabulations corresponding to their depth in the tree structure.
+ Datasets are represented as :class:`h5py.Dataset` objects.
+
+ Example::
+
+ >>> print(h5ls("Downloads/sample.h5"))
+ +fields
+ +fieldB
+ <HDF5 dataset "z": shape (256, 256), type "<f4">
+ +fieldE
+ <HDF5 dataset "x": shape (256, 256), type "<f4">
+ <HDF5 dataset "y": shape (256, 256), type "<f4">
+
+ .. note:: This function requires `h5py <http://www.h5py.org/>`_ to be
+ installed.
+ """
+ h5repr = ''
+ if is_group(h5group):
+ h5f = h5group
+ elif isinstance(h5group, string_types):
+ h5f = open(h5group) # silx.io.open
+ else:
+ raise TypeError("h5group must be a hdf5-like group object or a file name.")
+
+ for key in h5f.keys():
+ # group
+ if hasattr(h5f[key], 'keys'):
+ h5repr += '\t' * lvl + '+' + key
+ h5repr += '\n'
+ h5repr += h5ls(h5f[key], lvl + 1)
+ # dataset
+ else:
+ h5repr += '\t' * lvl
+ h5repr += str(h5f[key])
+ h5repr += '\n'
+
+ if isinstance(h5group, string_types):
+ h5f.close()
+
+ return h5repr
+
+
+def _open_local_file(filename):
+ """
+ Load a file as an `h5py.File`-like object.
+
+ Format supported:
+ - h5 files, if `h5py` module is installed
+ - SPEC files exposed as a NeXus layout
+ - raster files exposed as a NeXus layout (if `fabio` is installed)
+ - fio files exposed as a NeXus layout
+ - Numpy files ('npy' and 'npz' files)
+
+ The file is opened in read-only mode.
+
+ :param str filename: A filename
+ :raises: IOError if the file can't be loaded as an h5py.File like object
+ :rtype: h5py.File
+ """
+ if not os.path.isfile(filename):
+ raise IOError("Filename '%s' must be a file path" % filename)
+
+ debugging_info = []
+ try:
+ _, extension = os.path.splitext(filename)
+
+ if extension in [".npz", ".npy"]:
+ try:
+ from . import rawh5
+ return rawh5.NumpyFile(filename)
+ except (IOError, ValueError) as e:
+ debugging_info.append((sys.exc_info(),
+ "File '%s' can't be read as a numpy file." % filename))
+
+ if h5py.is_hdf5(filename):
+ try:
+ return h5py.File(filename, "r")
+ except OSError:
+ return h5py.File(filename, "r", libver='latest', swmr=True)
+
+ try:
+ from . import fabioh5
+ return fabioh5.File(filename)
+ except ImportError:
+ debugging_info.append((sys.exc_info(), "fabioh5 can't be loaded."))
+ except Exception:
+ debugging_info.append((sys.exc_info(),
+ "File '%s' can't be read as fabio file." % filename))
+
+ try:
+ from . import spech5
+ return spech5.SpecH5(filename)
+ except ImportError:
+ debugging_info.append((sys.exc_info(),
+ "spech5 can't be loaded."))
+ except IOError:
+ debugging_info.append((sys.exc_info(),
+ "File '%s' can't be read as spec file." % filename))
+
+ try:
+ from . import fioh5
+ return fioh5.FioH5(filename)
+ except IOError:
+ debugging_info.append((sys.exc_info(),
+ "File '%s' can't be read as fio file." % filename))
+
+ finally:
+ for exc_info, message in debugging_info:
+ logger.debug(message, exc_info=exc_info)
+
+ raise IOError("File '%s' can't be read as HDF5" % filename)
+
+
+class _MainNode(Proxy):
+ """A main node is a sub node of the HDF5 tree which is responsible of the
+ closure of the file.
+
+ It is a proxy to the sub node, plus support context manager and `close`
+ method usually provided by `h5py.File`.
+
+ :param h5_node: Target to the proxy.
+ :param h5_file: Main file. This object became the owner of this file.
+ """
+
+ def __init__(self, h5_node, h5_file):
+ super(_MainNode, self).__init__(h5_node)
+ self.__file = h5_file
+ self.__class = get_h5_class(h5_node)
+
+ @property
+ def h5_class(self):
+ """Returns the HDF5 class which is mimicked by this class.
+
+ :rtype: H5Type
+ """
+ return self.__class
+
+ @property
+ def h5py_class(self):
+ """Returns the h5py classes which is mimicked by this class. It can be
+ one of `h5py.File, h5py.Group` or `h5py.Dataset`.
+
+ :rtype: h5py class
+ """
+ return h5type_to_h5py_class(self.__class)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def close(self):
+ """Close the file"""
+ self.__file.close()
+ self.__file = None
+
+
+def open(filename): # pylint:disable=redefined-builtin
+ """
+ Open a file as an `h5py`-like object.
+
+ Format supported:
+ - h5 files, if `h5py` module is installed
+ - SPEC files exposed as a NeXus layout
+ - raster files exposed as a NeXus layout (if `fabio` is installed)
+ - fio files exposed as a NeXus layout
+ - Numpy files ('npy' and 'npz' files)
+
+ The filename can be trailled an HDF5 path using the separator `::`. In this
+ case the object returned is a proxy to the target node, implementing the
+ `close` function and supporting `with` context.
+
+ The file is opened in read-only mode.
+
+ :param str filename: A filename which can containt an HDF5 path by using
+ `::` separator.
+ :raises: IOError if the file can't be loaded or path can't be found
+ :rtype: h5py-like node
+ """
+ url = silx.io.url.DataUrl(filename)
+
+ if url.scheme() in [None, "file", "silx"]:
+ # That's a local file
+ if not url.is_valid():
+ raise IOError("URL '%s' is not valid" % filename)
+ h5_file = _open_local_file(url.file_path())
+ elif url.scheme() in ["fabio"]:
+ raise IOError("URL '%s' containing fabio scheme is not supported" % filename)
+ else:
+ # That's maybe an URL supported by h5pyd
+ uri = urllib.parse.urlparse(filename)
+ if h5pyd is None:
+ raise IOError("URL '%s' unsupported. Try to install h5pyd." % filename)
+ path = uri.path
+ endpoint = "%s://%s" % (uri.scheme, uri.netloc)
+ if path.startswith("/"):
+ path = path[1:]
+ return h5pyd.File(path, 'r', endpoint=endpoint)
+
+ if url.data_slice():
+ raise IOError("URL '%s' containing slicing is not supported" % filename)
+
+ if url.data_path() in [None, "/", ""]:
+ # The full file is requested
+ return h5_file
+ else:
+ # Only a children is requested
+ if url.data_path() not in h5_file:
+ msg = "File '%s' does not contain path '%s'." % (filename, url.data_path())
+ raise IOError(msg)
+ node = h5_file[url.data_path()]
+ proxy = _MainNode(node, h5_file)
+ return proxy
+
+
+def _get_classes_type():
+ """Returns a mapping between Python classes and HDF5 concepts.
+
+ This function allow an lazy initialization to avoid recurssive import
+ of modules.
+ """
+ global _CLASSES_TYPE
+ from . import commonh5
+
+ if _CLASSES_TYPE is not None:
+ return _CLASSES_TYPE
+
+ _CLASSES_TYPE = collections.OrderedDict()
+
+ _CLASSES_TYPE[commonh5.Dataset] = H5Type.DATASET
+ _CLASSES_TYPE[commonh5.File] = H5Type.FILE
+ _CLASSES_TYPE[commonh5.Group] = H5Type.GROUP
+ _CLASSES_TYPE[commonh5.SoftLink] = H5Type.SOFT_LINK
+
+ _CLASSES_TYPE[h5py.Dataset] = H5Type.DATASET
+ _CLASSES_TYPE[h5py.File] = H5Type.FILE
+ _CLASSES_TYPE[h5py.Group] = H5Type.GROUP
+ _CLASSES_TYPE[h5py.SoftLink] = H5Type.SOFT_LINK
+ _CLASSES_TYPE[h5py.HardLink] = H5Type.HARD_LINK
+ _CLASSES_TYPE[h5py.ExternalLink] = H5Type.EXTERNAL_LINK
+
+ if h5pyd is not None:
+ _CLASSES_TYPE[h5pyd.Dataset] = H5Type.DATASET
+ _CLASSES_TYPE[h5pyd.File] = H5Type.FILE
+ _CLASSES_TYPE[h5pyd.Group] = H5Type.GROUP
+ _CLASSES_TYPE[h5pyd.SoftLink] = H5Type.SOFT_LINK
+ _CLASSES_TYPE[h5pyd.HardLink] = H5Type.HARD_LINK
+ _CLASSES_TYPE[h5pyd.ExternalLink] = H5Type.EXTERNAL_LINK
+
+ return _CLASSES_TYPE
+
+
+def get_h5_class(obj=None, class_=None):
+ """
+ Returns the HDF5 type relative to the object or to the class.
+
+ :param obj: Instance of an object
+ :param class_: A class
+ :rtype: H5Type
+ """
+ if class_ is None:
+ class_ = obj.__class__
+
+ classes = _get_classes_type()
+ t = classes.get(class_, None)
+ if t is not None:
+ return t
+
+ if obj is not None:
+ if hasattr(obj, "h5_class"):
+ return obj.h5_class
+
+ for referencedClass_, type_ in classes.items():
+ if issubclass(class_, referencedClass_):
+ classes[class_] = type_
+ return type_
+
+ classes[class_] = None
+ return None
+
+
+def h5type_to_h5py_class(type_):
+ """
+ Returns an h5py class from an H5Type. None if nothing found.
+
+ :param H5Type type_:
+ :rtype: H5py class
+ """
+ if type_ == H5Type.FILE:
+ return h5py.File
+ if type_ == H5Type.GROUP:
+ return h5py.Group
+ if type_ == H5Type.DATASET:
+ return h5py.Dataset
+ if type_ == H5Type.SOFT_LINK:
+ return h5py.SoftLink
+ if type_ == H5Type.HARD_LINK:
+ return h5py.HardLink
+ if type_ == H5Type.EXTERNAL_LINK:
+ return h5py.ExternalLink
+ return None
+
+
+def get_h5py_class(obj):
+ """Returns the h5py class from an object.
+
+ If it is an h5py object or an h5py-like object, an h5py class is returned.
+ If the object is not an h5py-like object, None is returned.
+
+ :param obj: An object
+ :return: An h5py object
+ """
+ if hasattr(obj, "h5py_class"):
+ return obj.h5py_class
+ type_ = get_h5_class(obj)
+ return h5type_to_h5py_class(type_)
+
+
+def is_file(obj):
+ """
+ True is the object is an h5py.File-like object.
+
+ :param obj: An object
+ """
+ t = get_h5_class(obj)
+ return t == H5Type.FILE
+
+
+def is_group(obj):
+ """
+ True if the object is a h5py.Group-like object. A file is a group.
+
+ :param obj: An object
+ """
+ t = get_h5_class(obj)
+ return t in [H5Type.GROUP, H5Type.FILE]
+
+
+def is_dataset(obj):
+ """
+ True if the object is a h5py.Dataset-like object.
+
+ :param obj: An object
+ """
+ t = get_h5_class(obj)
+ return t == H5Type.DATASET
+
+
+def is_softlink(obj):
+ """
+ True if the object is a h5py.SoftLink-like object.
+
+ :param obj: An object
+ """
+ t = get_h5_class(obj)
+ return t == H5Type.SOFT_LINK
+
+
+def is_externallink(obj):
+ """
+ True if the object is a h5py.ExternalLink-like object.
+
+ :param obj: An object
+ """
+ t = get_h5_class(obj)
+ return t == H5Type.EXTERNAL_LINK
+
+
+def is_link(obj):
+ """
+ True if the object is a h5py link-like object.
+
+ :param obj: An object
+ """
+ t = get_h5_class(obj)
+ return t in {H5Type.SOFT_LINK, H5Type.EXTERNAL_LINK}
+
+
+def _visitall(item, path=''):
+ """Helper function for func:`visitall`.
+
+ :param item: Item to visit
+ :param str path: Relative path of the item
+ """
+ if not is_group(item):
+ return
+
+ for name, child_item in item.items():
+ if isinstance(child_item, (h5py.Group, h5py.Dataset)):
+ link = item.get(name, getlink=True)
+ else:
+ link = child_item
+ child_path = '/'.join((path, name))
+
+ ret = link if link is not None and is_link(link) else child_item
+ yield child_path, ret
+ yield from _visitall(child_item, child_path)
+
+
+def visitall(item):
+ """Visit entity recursively including links.
+
+ It does not follow links.
+ This is a generator yielding (relative path, object) for visited items.
+
+ :param item: The item to visit.
+ """
+ yield from _visitall(item, '')
+
+
+def get_data(url):
+ """Returns a numpy data from an URL.
+
+ Examples:
+
+ >>> # 1st frame from an EDF using silx.io.open
+ >>> data = silx.io.get_data("silx:/users/foo/image.edf::/scan_0/instrument/detector_0/data[0]")
+
+ >>> # 1st frame from an EDF using fabio
+ >>> data = silx.io.get_data("fabio:/users/foo/image.edf::[0]")
+
+ Yet 2 schemes are supported by the function.
+
+ - If `silx` scheme is used, the file is opened using
+ :meth:`silx.io.open`
+ and the data is reach using usually NeXus paths.
+ - If `fabio` scheme is used, the file is opened using :meth:`fabio.open`
+ from the FabIO library.
+ No data path have to be specified, but each frames can be accessed
+ using the data slicing.
+ This shortcut of :meth:`silx.io.open` allow to have a faster access to
+ the data.
+
+ .. seealso:: :class:`silx.io.url.DataUrl`
+
+ :param Union[str,silx.io.url.DataUrl]: A data URL
+ :rtype: Union[numpy.ndarray, numpy.generic]
+ :raises ImportError: If the mandatory library to read the file is not
+ available.
+ :raises ValueError: If the URL is not valid or do not match the data
+ :raises IOError: If the file is not found or in case of internal error of
+ :meth:`fabio.open` or :meth:`silx.io.open`. In this last case more
+ informations are displayed in debug mode.
+ """
+ if not isinstance(url, silx.io.url.DataUrl):
+ url = silx.io.url.DataUrl(url)
+
+ if not url.is_valid():
+ raise ValueError("URL '%s' is not valid" % url.path())
+
+ if not os.path.exists(url.file_path()):
+ raise IOError("File '%s' not found" % url.file_path())
+
+ if url.scheme() == "silx":
+ data_path = url.data_path()
+ data_slice = url.data_slice()
+
+ with open(url.file_path()) as h5:
+ if data_path not in h5:
+ raise ValueError("Data path from URL '%s' not found" % url.path())
+ data = h5[data_path]
+
+ if not silx.io.is_dataset(data):
+ raise ValueError("Data path from URL '%s' is not a dataset" % url.path())
+
+ if data_slice is not None:
+ data = h5py_read_dataset(data, index=data_slice)
+ else:
+ # works for scalar and array
+ data = h5py_read_dataset(data)
+
+ elif url.scheme() == "fabio":
+ import fabio
+ data_slice = url.data_slice()
+ if data_slice is None:
+ data_slice = (0,)
+ if data_slice is None or len(data_slice) != 1:
+ raise ValueError("Fabio slice expect a single frame, but %s found" % data_slice)
+ index = data_slice[0]
+ if not isinstance(index, int):
+ raise ValueError("Fabio slice expect a single integer, but %s found" % data_slice)
+
+ try:
+ fabio_file = fabio.open(url.file_path())
+ except Exception:
+ logger.debug("Error while opening %s with fabio", url.file_path(), exc_info=True)
+ raise IOError("Error while opening %s with fabio (use debug for more information)" % url.path())
+
+ if fabio_file.nframes == 1:
+ if index != 0:
+ raise ValueError("Only a single frame available. Slice %s out of range" % index)
+ data = fabio_file.data
+ else:
+ data = fabio_file.getframe(index).data
+
+ # There is no explicit close
+ fabio_file = None
+
+ else:
+ raise ValueError("Scheme '%s' not supported" % url.scheme())
+
+ return data
+
+
+def rawfile_to_h5_external_dataset(bin_file, output_url, shape, dtype,
+ overwrite=False):
+ """
+ Create a HDF5 dataset at `output_url` pointing to the given vol_file.
+
+ Either `shape` or `info_file` must be provided.
+
+ :param str bin_file: Path to the .vol file
+ :param DataUrl output_url: HDF5 URL where to save the external dataset
+ :param tuple shape: Shape of the volume
+ :param numpy.dtype dtype: Data type of the volume elements (default: float32)
+ :param bool overwrite: True to allow overwriting (default: False).
+ """
+ assert isinstance(output_url, silx.io.url.DataUrl)
+ assert isinstance(shape, (tuple, list))
+ v_majeur, v_mineur, v_micro = [int(i) for i in h5py.version.version.split('.')[:3]]
+ if calc_hexversion(v_majeur, v_mineur, v_micro)< calc_hexversion(2,9,0):
+ raise Exception('h5py >= 2.9 should be installed to access the '
+ 'external feature.')
+
+ with h5py.File(output_url.file_path(), mode="a") as _h5_file:
+ if output_url.data_path() in _h5_file:
+ if overwrite is False:
+ raise ValueError('data_path already exists')
+ else:
+ logger.warning('will overwrite path %s' % output_url.data_path())
+ del _h5_file[output_url.data_path()]
+ external = [(bin_file, 0, h5py.h5f.UNLIMITED)]
+ _h5_file.create_dataset(output_url.data_path(),
+ shape,
+ dtype=dtype,
+ external=external)
+
+
+def vol_to_h5_external_dataset(vol_file, output_url, info_file=None,
+ vol_dtype=numpy.float32, overwrite=False):
+ """
+ Create a HDF5 dataset at `output_url` pointing to the given vol_file.
+
+ If the vol_file.info containing the shape is not on the same folder as the
+ vol-file then you should specify her location.
+
+ :param str vol_file: Path to the .vol file
+ :param DataUrl output_url: HDF5 URL where to save the external dataset
+ :param Union[str,None] info_file:
+ .vol.info file name written by pyhst and containing the shape information
+ :param numpy.dtype vol_dtype: Data type of the volume elements (default: float32)
+ :param bool overwrite: True to allow overwriting (default: False).
+ :raises ValueError: If fails to read shape from the .vol.info file
+ """
+ _info_file = info_file
+ if _info_file is None:
+ _info_file = vol_file + '.info'
+ if not os.path.exists(_info_file):
+ logger.error('info_file not given and %s does not exists, please'
+ 'specify .vol.info file' % _info_file)
+ return
+
+ def info_file_to_dict():
+ ddict = {}
+ with builtin_open(info_file, "r") as _file:
+ lines = _file.readlines()
+ for line in lines:
+ if not '=' in line:
+ continue
+ l = line.rstrip().replace(' ', '')
+ l = l.split('#')[0]
+ key, value = l.split('=')
+ ddict[key.lower()] = value
+ return ddict
+
+ ddict = info_file_to_dict()
+ if 'num_x' not in ddict or 'num_y' not in ddict or 'num_z' not in ddict:
+ raise ValueError(
+ 'Unable to retrieve volume shape from %s' % info_file)
+
+ dimX = int(ddict['num_x'])
+ dimY = int(ddict['num_y'])
+ dimZ = int(ddict['num_z'])
+ shape = (dimZ, dimY, dimX)
+
+ return rawfile_to_h5_external_dataset(bin_file=vol_file,
+ output_url=output_url,
+ shape=shape,
+ dtype=vol_dtype,
+ overwrite=overwrite)
+
+
+def h5py_decode_value(value, encoding="utf-8", errors="surrogateescape"):
+ """Keep bytes when value cannot be decoded
+
+ :param value: bytes or array of bytes
+ :param encoding str:
+ :param errors str:
+ """
+ try:
+ if numpy.isscalar(value):
+ return value.decode(encoding, errors=errors)
+ str_item = [b.decode(encoding, errors=errors) for b in value.flat]
+ return numpy.array(str_item, dtype=object).reshape(value.shape)
+ except UnicodeDecodeError:
+ return value
+
+
+def h5py_encode_value(value, encoding="utf-8", errors="surrogateescape"):
+ """Keep string when value cannot be encoding
+
+ :param value: string or array of strings
+ :param encoding str:
+ :param errors str:
+ """
+ try:
+ if numpy.isscalar(value):
+ return value.encode(encoding, errors=errors)
+ bytes_item = [s.encode(encoding, errors=errors) for s in value.flat]
+ return numpy.array(bytes_item, dtype=object).reshape(value.shape)
+ except UnicodeEncodeError:
+ return value
+
+
+class H5pyDatasetReadWrapper:
+ """Wrapper to handle H5T_STRING decoding on-the-fly when reading
+ a dataset. Uniform behaviour for h5py 2.x and h5py 3.x
+
+ h5py abuses H5T_STRING with ASCII character set
+ to store `bytes`: dset[()] = b"..."
+ Therefore an H5T_STRING with ASCII encoding is not decoded by default.
+ """
+
+ H5PY_AUTODECODE_NONASCII = int(h5py.version.version.split(".")[0]) < 3
+
+ def __init__(self, dset, decode_ascii=False):
+ """
+ :param h5py.Dataset dset:
+ :param bool decode_ascii:
+ """
+ try:
+ string_info = h5py.h5t.check_string_dtype(dset.dtype)
+ except AttributeError:
+ # h5py < 2.10
+ try:
+ idx = dset.id.get_type().get_cset()
+ except AttributeError:
+ # Not an H5T_STRING
+ encoding = None
+ else:
+ encoding = ["ascii", "utf-8"][idx]
+ else:
+ # h5py >= 2.10
+ try:
+ encoding = string_info.encoding
+ except AttributeError:
+ # Not an H5T_STRING
+ encoding = None
+ if encoding == "ascii" and not decode_ascii:
+ encoding = None
+ if encoding != "ascii" and self.H5PY_AUTODECODE_NONASCII:
+ # Decoding is already done by the h5py library
+ encoding = None
+ if encoding == "ascii":
+ # ASCII can be decoded as UTF-8
+ encoding = "utf-8"
+ self._encoding = encoding
+ self._dset = dset
+
+ def __getitem__(self, args):
+ value = self._dset[args]
+ if self._encoding:
+ return h5py_decode_value(value, encoding=self._encoding)
+ else:
+ return value
+
+
+class H5pyAttributesReadWrapper:
+ """Wrapper to handle H5T_STRING decoding on-the-fly when reading
+ an attribute. Uniform behaviour for h5py 2.x and h5py 3.x
+
+ h5py abuses H5T_STRING with ASCII character set
+ to store `bytes`: dset[()] = b"..."
+ Therefore an H5T_STRING with ASCII encoding is not decoded by default.
+ """
+
+ H5PY_AUTODECODE = int(h5py.version.version.split(".")[0]) >= 3
+
+ def __init__(self, attrs, decode_ascii=False):
+ """
+ :param h5py.Dataset dset:
+ :param bool decode_ascii:
+ """
+ self._attrs = attrs
+ self._decode_ascii = decode_ascii
+
+ def __getitem__(self, args):
+ value = self._attrs[args]
+
+ # Get the string encoding (if a string)
+ try:
+ dtype = self._attrs.get_id(args).dtype
+ except AttributeError:
+ # h5py < 2.10
+ attr_id = h5py.h5a.open(self._attrs._id, self._attrs._e(args))
+ try:
+ idx = attr_id.get_type().get_cset()
+ except AttributeError:
+ # Not an H5T_STRING
+ return value
+ else:
+ encoding = ["ascii", "utf-8"][idx]
+ else:
+ # h5py >= 2.10
+ try:
+ encoding = h5py.h5t.check_string_dtype(dtype).encoding
+ except AttributeError:
+ # Not an H5T_STRING
+ return value
+
+ if self.H5PY_AUTODECODE:
+ if encoding == "ascii" and not self._decode_ascii:
+ # Undo decoding by the h5py library
+ return h5py_encode_value(value, encoding="utf-8")
+ else:
+ if encoding == "ascii" and self._decode_ascii:
+ # Decode ASCII as UTF-8 for consistency
+ return h5py_decode_value(value, encoding="utf-8")
+
+ # Decoding is already done by the h5py library
+ return value
+
+ def items(self):
+ for k in self._attrs.keys():
+ yield k, self[k]
+
+
+def h5py_read_dataset(dset, index=tuple(), decode_ascii=False):
+ """Read data from dataset object. UTF-8 strings will be
+ decoded while ASCII strings will only be decoded when
+ `decode_ascii=True`.
+
+ :param h5py.Dataset dset:
+ :param index: slicing (all by default)
+ :param bool decode_ascii:
+ """
+ return H5pyDatasetReadWrapper(dset, decode_ascii=decode_ascii)[index]
+
+
+def h5py_read_attribute(attrs, name, decode_ascii=False):
+ """Read data from attributes. UTF-8 strings will be
+ decoded while ASCII strings will only be decoded when
+ `decode_ascii=True`.
+
+ :param h5py.AttributeManager attrs:
+ :param str name: attribute name
+ :param bool decode_ascii:
+ """
+ return H5pyAttributesReadWrapper(attrs, decode_ascii=decode_ascii)[name]
+
+
+def h5py_read_attributes(attrs, decode_ascii=False):
+ """Read data from attributes. UTF-8 strings will be
+ decoded while ASCII strings will only be decoded when
+ `decode_ascii=True`.
+
+ :param h5py.AttributeManager attrs:
+ :param bool decode_ascii:
+ """
+ return dict(H5pyAttributesReadWrapper(attrs, decode_ascii=decode_ascii).items())
diff --git a/silx/math/__init__.py b/src/silx/math/__init__.py
index d8b7d81..d8b7d81 100644
--- a/silx/math/__init__.py
+++ b/src/silx/math/__init__.py
diff --git a/src/silx/math/_colormap.pyx b/src/silx/math/_colormap.pyx
new file mode 100644
index 0000000..70857f0
--- /dev/null
+++ b/src/silx/math/_colormap.pyx
@@ -0,0 +1,571 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides :func:`cmap` which applies a colormap to a dataset.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/05/2018"
+
+
+import os
+cimport cython
+from cython.parallel import prange
+cimport numpy as cnumpy
+from libc.math cimport frexp, sinh, sqrt
+from .math_compatibility cimport asinh, isnan, isfinite, lrint, INFINITY, NAN
+
+import logging
+import numbers
+
+import numpy
+
+__all__ = ['cmap']
+
+_logger = logging.getLogger(__name__)
+
+
+cdef int DEFAULT_NUM_THREADS
+if hasattr(os, 'sched_getaffinity'):
+ DEFAULT_NUM_THREADS = min(4, len(os.sched_getaffinity(0)))
+elif os.cpu_count() is not None:
+ DEFAULT_NUM_THREADS = min(4, os.cpu_count())
+else: # Fallback
+ DEFAULT_NUM_THREADS = 1
+# Number of threads to use for the computation (initialized to up to 4)
+
+cdef int USE_OPENMP_THRESHOLD = 1000
+"""OpenMP is not used for arrays with less elements than this threshold"""
+
+# Supported data types
+ctypedef fused data_types:
+ cnumpy.uint8_t
+ cnumpy.int8_t
+ cnumpy.uint16_t
+ cnumpy.int16_t
+ cnumpy.uint32_t
+ cnumpy.int32_t
+ cnumpy.uint64_t
+ cnumpy.int64_t
+ float
+ double
+ long double
+
+
+# Data types using a LUT to apply the colormap
+ctypedef fused lut_types:
+ cnumpy.uint8_t
+ cnumpy.int8_t
+ cnumpy.uint16_t
+ cnumpy.int16_t
+
+
+# Data types using default colormap implementation
+ctypedef fused default_types:
+ cnumpy.uint32_t
+ cnumpy.int32_t
+ cnumpy.uint64_t
+ cnumpy.int64_t
+ float
+ double
+ long double
+
+
+# Supported colors/output types
+ctypedef fused image_types:
+ cnumpy.uint8_t
+ float
+
+
+# Normalization
+
+ctypedef double (*NormalizationFunction)(double) nogil
+
+
+cdef class Normalization:
+ """Base class for colormap normalization"""
+
+ def apply(self, data, double vmin, double vmax):
+ """Apply normalization.
+
+ :param Union[float,numpy.ndarray] data:
+ :param float vmin: Lower bound of the range
+ :param float vmax: Upper bound of the range
+ :rtype: Union[float,numpy.ndarray]
+ """
+ cdef int length
+ cdef double[:] result
+
+ if isinstance(data, numbers.Real):
+ return self.apply_double(<double> data, vmin, vmax)
+ else:
+ data = numpy.array(data, copy=False)
+ length = <int> data.size
+ result = numpy.empty(length, dtype=numpy.float64)
+ data1d = numpy.ravel(data)
+ for index in range(length):
+ result[index] = self.apply_double(
+ <double> data1d[index], vmin, vmax)
+ return numpy.array(result).reshape(data.shape)
+
+ def revert(self, data, double vmin, double vmax):
+ """Revert normalization.
+
+ :param Union[float,numpy.ndarray] data:
+ :param float vmin: Lower bound of the range
+ :param float vmax: Upper bound of the range
+ :rtype: Union[float,numpy.ndarray]
+ """
+ cdef int length
+ cdef double[:] result
+
+ if isinstance(data, numbers.Real):
+ return self.revert_double(<double> data, vmin, vmax)
+ else:
+ data = numpy.array(data, copy=False)
+ length = <int> data.size
+ result = numpy.empty(length, dtype=numpy.float64)
+ data1d = numpy.ravel(data)
+ for index in range(length):
+ result[index] = self.revert_double(
+ <double> data1d[index], vmin, vmax)
+ return numpy.array(result).reshape(data.shape)
+
+ cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ """Apply normalization to a floating point value
+
+ Override in subclass
+
+ :param float value:
+ :param float vmin: Lower bound of the range
+ :param float vmax: Upper bound of the range
+ """
+ return value
+
+ cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ """Apply inverse of normalization to a floating point value
+
+ Override in subclass
+
+ :param float value:
+ :param float vmin: Lower bound of the range
+ :param float vmax: Upper bound of the range
+ """
+ return value
+
+
+cdef class LinearNormalization(Normalization):
+ """Linear normalization"""
+
+ cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ return value
+
+ cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ return value
+
+
+cdef class LogarithmicNormalization(Normalization):
+ """Logarithmic normalization using a fast log approximation"""
+ cdef:
+ readonly int lutsize
+ readonly double[::1] lut # LUT used for fast log approximation
+
+ def __cinit__(self, int lutsize=4096):
+ # Initialize log approximation LUT
+ self.lutsize = lutsize
+ self.lut = numpy.log2(
+ numpy.linspace(0.5, 1., lutsize + 1,
+ endpoint=True).astype(numpy.float64))
+ # index_lut can overflow of 1
+ self.lut[lutsize] = self.lut[lutsize - 1]
+
+ def __dealloc__(self):
+ self.lut = None
+
+ @cython.wraparound(False)
+ @cython.boundscheck(False)
+ @cython.nonecheck(False)
+ @cython.cdivision(True)
+ cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ """Return log10(value) fast approximation based on LUT"""
+ cdef double result = NAN # if value < 0.0 or value == NAN
+ cdef int exponent, index_lut
+ cdef double mantissa # in [0.5, 1) unless value == 0 NaN or +/-inf
+
+ if value <= 0.0 or not isfinite(value):
+ if value == 0.0:
+ result = - INFINITY
+ elif value > 0.0: # i.e., value = +INFINITY
+ result = value # i.e. +INFINITY
+ else:
+ mantissa = frexp(value, &exponent)
+ index_lut = lrint(self.lutsize * 2 * (mantissa - 0.5))
+ # 1/log2(10) = 0.30102999566398114
+ result = 0.30102999566398114 * (<double> exponent +
+ self.lut[index_lut])
+ return result
+
+ cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ return 10**value
+
+
+cdef class ArcsinhNormalization(Normalization):
+ """Inverse hyperbolic sine normalization"""
+
+ cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ return asinh(value)
+
+ cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ return sinh(value)
+
+
+cdef class SqrtNormalization(Normalization):
+ """Square root normalization"""
+
+ cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ return sqrt(value)
+
+ cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ return value**2
+
+
+cdef class PowerNormalization(Normalization):
+ """Gamma correction:
+
+ Linear normalization to [0, 1] followed by power normalization.
+
+ :param gamma: Gamma correction factor
+ """
+
+ cdef:
+ readonly double gamma
+
+ def __cinit__(self, double gamma):
+ self.gamma = gamma
+
+ def __init__(self, gamma):
+ # Needed for multiple inheritance to work
+ pass
+
+ cdef double apply_double(self, double value, double vmin, double vmax) nogil:
+ if vmin == vmax:
+ return 0.
+ elif value <= vmin:
+ return 0.
+ elif value >= vmax:
+ return 1.
+ else:
+ return ((value - vmin) / (vmax - vmin))**self.gamma
+
+ cdef double revert_double(self, double value, double vmin, double vmax) nogil:
+ if value <= 0.:
+ return vmin
+ elif value >= 1.:
+ return vmax
+ else:
+ return vmin + (vmax - vmin) * value**(1.0/self.gamma)
+
+
+# Colormap
+
+@cython.wraparound(False)
+@cython.boundscheck(False)
+@cython.nonecheck(False)
+@cython.cdivision(True)
+cdef image_types[:, ::1] compute_cmap(
+ default_types[:] data,
+ image_types[:, ::1] colors,
+ Normalization normalization,
+ double vmin,
+ double vmax,
+ image_types[::1] nan_color):
+ """Apply colormap to data.
+
+ :param data: Input data
+ :param colors: Colors look-up-table
+ :param vmin: Lower bound of the colormap range
+ :param vmax: Upper bound of the colormap range
+ :param nan_color: Color to use for NaN value
+ :param normalization: Normalization to apply
+ :return: Data converted to colors
+ """
+ cdef image_types[:, ::1] output
+ cdef double scale, value, normalized_vmin, normalized_vmax
+ cdef int length, nb_channels, nb_colors
+ cdef int channel, index, lut_index, num_threads
+
+ nb_colors = <int> colors.shape[0]
+ nb_channels = <int> colors.shape[1]
+ length = <int> data.size
+
+ output = numpy.empty((length, nb_channels),
+ dtype=numpy.array(colors, copy=False).dtype)
+
+ normalized_vmin = normalization.apply_double(vmin, vmin, vmax)
+ normalized_vmax = normalization.apply_double(vmax, vmin, vmax)
+
+ if not isfinite(normalized_vmin) or not isfinite(normalized_vmax):
+ raise ValueError('Colormap range is not valid')
+
+ if normalized_vmin == normalized_vmax:
+ scale = 0.
+ else:
+ scale = nb_colors / (normalized_vmax - normalized_vmin)
+
+ if length < USE_OPENMP_THRESHOLD:
+ num_threads = 1
+ else:
+ num_threads = min(
+ DEFAULT_NUM_THREADS,
+ int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS)))
+
+ with nogil:
+ for index in prange(length, num_threads=num_threads):
+ value = normalization.apply_double(
+ <double> data[index], vmin, vmax)
+
+ # Handle NaN
+ if isnan(value):
+ for channel in range(nb_channels):
+ output[index, channel] = nan_color[channel]
+ continue
+
+ if value <= normalized_vmin:
+ lut_index = 0
+ elif value >= normalized_vmax:
+ lut_index = nb_colors - 1
+ else:
+ lut_index = <int>((value - normalized_vmin) * scale)
+ # Index can overflow of 1
+ if lut_index >= nb_colors:
+ lut_index = nb_colors - 1
+
+ for channel in range(nb_channels):
+ output[index, channel] = colors[lut_index, channel]
+
+ return output
+
+@cython.wraparound(False)
+@cython.boundscheck(False)
+@cython.nonecheck(False)
+@cython.cdivision(True)
+cdef image_types[:, ::1] compute_cmap_with_lut(
+ lut_types[:] data,
+ image_types[:, ::1] colors,
+ Normalization normalization,
+ double vmin,
+ double vmax,
+ image_types[::1] nan_color):
+ """Convert data to colors using look-up table to speed the process.
+
+ Only supports data of types: uint8, uint16, int8, int16.
+
+ :param data: Input data
+ :param colors: Colors look-up-table
+ :param vmin: Lower bound of the colormap range
+ :param vmax: Upper bound of the colormap range
+ :param nan_color: Color to use for NaN values
+ :param normalization: Normalization to apply
+ :return: The generated image
+ """
+ cdef image_types[:, ::1] output
+ cdef double[:] values
+ cdef image_types[:, ::1] lut
+ cdef int type_min, type_max
+ cdef int nb_channels, length
+ cdef int channel, index, lut_index, num_threads
+
+ length = <int> data.size
+ nb_channels = <int> colors.shape[1]
+
+ if lut_types is cnumpy.int8_t:
+ type_min = -128
+ type_max = 127
+ elif lut_types is cnumpy.uint8_t:
+ type_min = 0
+ type_max = 255
+ elif lut_types is cnumpy.int16_t:
+ type_min = -32768
+ type_max = 32767
+ else: # uint16_t
+ type_min = 0
+ type_max = 65535
+
+ colors_dtype = numpy.array(colors).dtype
+
+ values = numpy.arange(type_min, type_max + 1, dtype=numpy.float64)
+ lut = compute_cmap(
+ values, colors, normalization, vmin, vmax, nan_color)
+
+ output = numpy.empty((length, nb_channels), dtype=colors_dtype)
+
+ if length < USE_OPENMP_THRESHOLD:
+ num_threads = 1
+ else:
+ num_threads = min(
+ DEFAULT_NUM_THREADS,
+ int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS)))
+
+ with nogil:
+ # Apply LUT
+ for index in prange(length, num_threads=num_threads):
+ lut_index = data[index] - type_min
+ for channel in range(nb_channels):
+ output[index, channel] = lut[lut_index, channel]
+
+ return output
+
+
+# Normalizations without parameters
+_BASIC_NORMALIZATIONS = {
+ 'linear': LinearNormalization(),
+ 'log': LogarithmicNormalization(),
+ 'arcsinh': ArcsinhNormalization(),
+ 'sqrt': SqrtNormalization(),
+ }
+
+
+@cython.wraparound(False)
+@cython.boundscheck(False)
+@cython.nonecheck(False)
+@cython.cdivision(True)
+def _cmap(data_types[:] data,
+ image_types[:, ::1] colors,
+ Normalization normalization,
+ double vmin,
+ double vmax,
+ image_types[::1] nan_color):
+ """Implementation of colormap.
+
+ Use :func:`cmap`.
+
+ :param data: Input data
+ :param colors: Colors look-up-table
+ :param normalization: Normalization object to apply
+ :param vmin: Lower bound of the colormap range
+ :param vmax: Upper bound of the colormap range
+ :param nan_color: Color to use for NaN value.
+ :return: The generated image
+ """
+ cdef image_types[:, ::1] output
+
+ # Proxy for calling the right implementation depending on data type
+ if data_types in lut_types: # Use LUT implementation
+ output = compute_cmap_with_lut(
+ data, colors, normalization, vmin, vmax, nan_color)
+
+ elif data_types in default_types: # Use default implementation
+ output = compute_cmap(
+ data, colors, normalization, vmin, vmax, nan_color)
+
+ else:
+ raise ValueError('Unsupported data type')
+
+ return numpy.array(output, copy=False)
+
+
+def cmap(data not None,
+ colors not None,
+ double vmin,
+ double vmax,
+ normalization='linear',
+ nan_color=None):
+ """Convert data to colors with provided colors look-up table.
+
+ :param numpy.ndarray data: The input data
+ :param numpy.ndarray colors: Color look-up table as a 2D array.
+ It MUST be of type uint8 or float32
+ :param vmin: Data value to map to the beginning of colormap.
+ :param vmax: Data value to map to the end of the colormap.
+ :param Union[str,Normalization] normalization:
+ Either a :class:`Normalization` instance or a str in:
+
+ - 'linear' (default)
+ - 'log'
+ - 'arcsinh'
+ - 'sqrt'
+ - 'gamma'
+
+ :param nan_color: Color to use for NaN value.
+ Default: A color with all channels set to 0
+ :return: Array of colors. The shape of the
+ returned array is that of data array + the last dimension of colors.
+ The dtype of the returned array is that of the colors array.
+ :rtype: numpy.ndarray
+ :raises ValueError: If data of colors dtype is not supported
+ """
+ cdef int nb_channels
+ cdef Normalization norm
+
+ # Make data a numpy array of native endian type (no need for contiguity)
+ data = numpy.array(data, copy=False)
+ if data.dtype.kind not in ('b', 'i', 'u', 'f'):
+ raise ValueError("Unsupported data dtype: %s" % data.dtype)
+ native_endian_dtype = data.dtype.newbyteorder('N')
+ if native_endian_dtype.kind == 'f' and native_endian_dtype.itemsize == 2:
+ native_endian_dtype = "=f4" # Use native float32 instead of float16
+ data = numpy.array(data, copy=False, dtype=native_endian_dtype)
+
+ # Make colors a contiguous array of native endian type
+ colors = numpy.array(colors, copy=False)
+ if colors.dtype.kind == 'f':
+ colors_dtype = numpy.dtype('float32')
+ elif colors.dtype.kind in ('b', 'i', 'u'):
+ colors_dtype = numpy.dtype('uint8')
+ else:
+ raise ValueError("Unsupported colors dtype: %s" % colors.dtype)
+ if (colors_dtype.kind != colors.dtype.kind or
+ colors_dtype.itemsize != colors.dtype.itemsize):
+ # Do not warn if only endianness has changed
+ _logger.warning("Casting colors from %s to %s", colors.dtype, colors_dtype)
+ nb_channels = colors.shape[colors.ndim - 1]
+ colors = numpy.ascontiguousarray(colors, dtype=colors_dtype)
+
+ # Make normalization a Normalization object
+ if isinstance(normalization, str):
+ norm = _BASIC_NORMALIZATIONS.get(normalization, None)
+ if norm is None:
+ raise ValueError('Unsupported normalization %s' % normalization)
+ else:
+ norm = normalization
+
+ # Check nan_color
+ if nan_color is None:
+ nan_color = numpy.zeros((nb_channels,), dtype=colors.dtype)
+ else:
+ nan_color = numpy.ascontiguousarray(
+ nan_color, dtype=colors.dtype).reshape(-1)
+ assert nan_color.shape == (nb_channels,)
+
+ image = _cmap(
+ data.reshape(-1),
+ colors.reshape(-1, nb_channels),
+ norm,
+ vmin,
+ vmax,
+ nan_color)
+ image.shape = data.shape + (nb_channels,)
+
+ return image
diff --git a/silx/math/calibration.py b/src/silx/math/calibration.py
index 658e2dc..658e2dc 100644
--- a/silx/math/calibration.py
+++ b/src/silx/math/calibration.py
diff --git a/silx/math/chistogramnd.pyx b/src/silx/math/chistogramnd.pyx
index 8484f35..8484f35 100644
--- a/silx/math/chistogramnd.pyx
+++ b/src/silx/math/chistogramnd.pyx
diff --git a/silx/math/chistogramnd_lut.pyx b/src/silx/math/chistogramnd_lut.pyx
index 3a3f05e..3a3f05e 100644
--- a/silx/math/chistogramnd_lut.pyx
+++ b/src/silx/math/chistogramnd_lut.pyx
diff --git a/src/silx/math/colormap.py b/src/silx/math/colormap.py
new file mode 100644
index 0000000..43b8949
--- /dev/null
+++ b/src/silx/math/colormap.py
@@ -0,0 +1,450 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""This module provides helper functions for applying colormaps to datasets"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "25/08/2021"
+
+
+import collections
+import warnings
+import numpy
+
+from ..resources import resource_filename as _resource_filename
+from .combo import min_max as _min_max
+from . import _colormap
+from ._colormap import cmap # noqa
+
+
+__all__ = ["apply_colormap", "cmap"]
+
+
+_LUT_DESCRIPTION = collections.namedtuple("_LUT_DESCRIPTION", ["source", "cursor_color"])
+"""Description of a LUT for internal purpose."""
+
+
+_AVAILABLE_LUTS = collections.OrderedDict([
+ ('gray', _LUT_DESCRIPTION('builtin', '#ff66ff')),
+ ('reversed gray', _LUT_DESCRIPTION('builtin', '#ff66ff')),
+ ('red', _LUT_DESCRIPTION('builtin', '#00ff00')),
+ ('green', _LUT_DESCRIPTION('builtin', '#ff66ff')),
+ ('blue', _LUT_DESCRIPTION('builtin', '#ffff00')),
+ ('viridis', _LUT_DESCRIPTION('resource', '#ff66ff')),
+ ('cividis', _LUT_DESCRIPTION('resource', '#ff66ff')),
+ ('magma', _LUT_DESCRIPTION('resource', '#00ff00')),
+ ('inferno', _LUT_DESCRIPTION('resource', '#00ff00')),
+ ('plasma', _LUT_DESCRIPTION('resource', '#00ff00')),
+ ('temperature', _LUT_DESCRIPTION('builtin', '#ff66ff')),
+])
+"""Description for internal porpose of all the default LUT provided by the library."""
+
+
+# Colormap loader
+
+_COLORMAP_CACHE = {}
+"""Cache already used colormaps as name: color LUT"""
+
+
+def array_to_rgba8888(colors):
+ """Convert colors from a numpy array using float (0..1) int or uint
+ (0..255) to uint8 RGBA.
+
+ :param numpy.ndarray colors: Array of float int or uint colors to convert
+ :return: colors as uint8
+ :rtype: numpy.ndarray
+ """
+ assert len(colors.shape) == 2
+ assert colors.shape[1] in (3, 4)
+
+ if colors.dtype == numpy.uint8:
+ pass
+ elif colors.dtype.kind == 'f':
+ # Each bin is [N, N+1[ except the last one: [255, 256]
+ colors = numpy.clip(colors.astype(numpy.float64) * 256, 0., 255.)
+ colors = colors.astype(numpy.uint8)
+ elif colors.dtype.kind in 'iu':
+ colors = numpy.clip(colors, 0, 255)
+ colors = colors.astype(numpy.uint8)
+
+ if colors.shape[1] == 3:
+ tmp = numpy.empty((len(colors), 4), dtype=numpy.uint8)
+ tmp[:, 0:3] = colors
+ tmp[:, 3] = 255
+ colors = tmp
+
+ return colors
+
+
+def _create_colormap_lut(name):
+ """Returns the color LUT corresponding to a colormap name
+
+ :param str name: Name of the colormap to load
+ :returns: Corresponding table of colors
+ :rtype: numpy.ndarray
+ :raise ValueError: If no colormap corresponds to name
+ """
+ description = _AVAILABLE_LUTS.get(name)
+ if description is not None:
+ if description.source == "builtin":
+ # Build colormap LUT
+ lut = numpy.zeros((256, 4), dtype=numpy.uint8)
+ lut[:, 3] = 255
+
+ if name == 'gray':
+ lut[:, :3] = numpy.arange(256, dtype=numpy.uint8).reshape(-1, 1)
+ elif name == 'reversed gray':
+ lut[:, :3] = numpy.arange(255, -1, -1, dtype=numpy.uint8).reshape(-1, 1)
+ elif name == 'red':
+ lut[:, 0] = numpy.arange(256, dtype=numpy.uint8)
+ elif name == 'green':
+ lut[:, 1] = numpy.arange(256, dtype=numpy.uint8)
+ elif name == 'blue':
+ lut[:, 2] = numpy.arange(256, dtype=numpy.uint8)
+ elif name == 'temperature':
+ # Red
+ lut[128:192, 0] = numpy.arange(2, 255, 4, dtype=numpy.uint8)
+ lut[192:, 0] = 255
+ # Green
+ lut[:64, 1] = numpy.arange(0, 255, 4, dtype=numpy.uint8)
+ lut[64:192, 1] = 255
+ lut[192:, 1] = numpy.arange(252, -1, -4, dtype=numpy.uint8)
+ # Blue
+ lut[:64, 2] = 255
+ lut[64:128, 2] = numpy.arange(254, 0, -4, dtype=numpy.uint8)
+ else:
+ raise RuntimeError("Built-in colormap not implemented")
+ return lut
+
+ elif description.source == "resource":
+ # Load colormap LUT
+ colors = numpy.load(_resource_filename("gui/colormaps/%s.npy" % name))
+ # Convert to uint8 and add alpha channel
+ lut = array_to_rgba8888(colors)
+ return lut
+
+ else:
+ raise RuntimeError("Internal LUT source '%s' unsupported" % description.source)
+
+ raise ValueError("Unknown colormap '%s'" % name)
+
+
+def register_colormap(name, lut, cursor_color='#000000'):
+ """Register a custom colormap LUT
+
+ It can override existing LUT names.
+
+ :param str name: Name of the LUT as defined to configure colormaps
+ :param numpy.ndarray lut: The custom LUT to register.
+ Nx3 or Nx4 numpy array of RGB(A) colors,
+ either uint8 or float in [0, 1].
+ :param str cursor_color: Color used to display overlay over images using
+ colormap with this LUT.
+ """
+ description = _LUT_DESCRIPTION('user', cursor_color)
+ colors = array_to_rgba8888(lut)
+ _AVAILABLE_LUTS[name] = description
+
+ # Register the cache as the LUT was already loaded
+ _COLORMAP_CACHE[name] = colors
+
+
+def get_registered_colormaps():
+ """Returns currently registered colormap names"""
+ return tuple(_AVAILABLE_LUTS.keys())
+
+
+def get_colormap_cursor_color(name):
+ """Get a color suitable for overlay over a colormap.
+
+ :param str name: The name of the colormap.
+ :return: Name of the color.
+ :rtype: str
+ """
+ description = _AVAILABLE_LUTS.get(name, None)
+ if description is not None:
+ color = description.cursor_color
+ if color is not None:
+ return color
+ return 'black'
+
+
+def get_colormap_lut(name):
+ """Returns the color LUT corresponding to a colormap name
+
+ :param str name: Name of the colormap to load
+ :returns: Corresponding table of colors
+ :rtype: numpy.ndarray
+ :raise ValueError: If no colormap corresponds to name
+ """
+ name = str(name)
+ if name not in _COLORMAP_CACHE:
+ lut = _create_colormap_lut(name)
+ _COLORMAP_CACHE[name] = lut
+ return _COLORMAP_CACHE[name]
+
+
+# Normalizations
+
+class _NormalizationMixIn:
+ """Colormap normalization mix-in class"""
+
+ DEFAULT_RANGE = 0, 1
+ """Fallback for (vmin, vmax)"""
+
+ def is_valid(self, value):
+ """Check if a value is in the valid range for this normalization.
+
+ Override in subclass.
+
+ :param Union[float,numpy.ndarray] value:
+ :rtype: Union[bool,numpy.ndarray]
+ """
+ if isinstance(value, collections.abc.Iterable):
+ return numpy.ones_like(value, dtype=numpy.bool_)
+ else:
+ return True
+
+ def autoscale(self, data, mode):
+ """Returns range for given data and autoscale mode.
+
+ :param Union[None,numpy.ndarray] data:
+ :param str mode: Autoscale mode: 'minmax' or 'stddev3'
+ :returns: Range as (min, max)
+ :rtype: Tuple[float,float]
+ """
+ data = None if data is None else numpy.array(data, copy=False)
+ if data is None or data.size == 0:
+ return self.DEFAULT_RANGE
+
+ if mode == "minmax":
+ vmin, vmax = self.autoscale_minmax(data)
+ elif mode == "stddev3":
+ dmin, dmax = self.autoscale_minmax(data)
+ stdmin, stdmax = self.autoscale_mean3std(data)
+ if dmin is None:
+ vmin = stdmin
+ elif stdmin is None:
+ vmin = dmin
+ else:
+ vmin = max(dmin, stdmin)
+
+ if dmax is None:
+ vmax = stdmax
+ elif stdmax is None:
+ vmax = dmax
+ else:
+ vmax = min(dmax, stdmax)
+
+ else:
+ raise ValueError('Unsupported mode: %s' % mode)
+
+ # Check returned range and handle fallbacks
+ if vmin is None or not numpy.isfinite(vmin):
+ vmin = self.DEFAULT_RANGE[0]
+ if vmax is None or not numpy.isfinite(vmax):
+ vmax = self.DEFAULT_RANGE[1]
+ if vmax < vmin:
+ vmax = vmin
+ return float(vmin), float(vmax)
+
+ def autoscale_minmax(self, data):
+ """Autoscale using min/max
+
+ :param numpy.ndarray data:
+ :returns: (vmin, vmax)
+ :rtype: Tuple[float,float]
+ """
+ data = data[self.is_valid(data)]
+ if data.size == 0:
+ return None, None
+ result = _min_max(data, min_positive=False, finite=True)
+ return result.minimum, result.maximum
+
+ def autoscale_mean3std(self, data):
+ """Autoscale using mean+/-3std
+
+ This implementation only works for normalization that do NOT
+ use the data range.
+ Override this method for normalization using the range.
+
+ :param numpy.ndarray data:
+ :returns: (vmin, vmax)
+ :rtype: Tuple[float,float]
+ """
+ # Use [0, 1] as data range for normalization not using range
+ normdata = self.apply(data, 0., 1.)
+ if normdata.dtype.kind == 'f': # Replaces inf by NaN
+ normdata[numpy.isfinite(normdata) == False] = numpy.nan
+ if normdata.size == 0: # Fallback
+ return None, None
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore nanmean "Mean of empty slice" warning and
+ # nanstd "Degrees of freedom <= 0 for slice" warning
+ mean, std = numpy.nanmean(normdata), numpy.nanstd(normdata)
+
+ return self.revert(mean - 3 * std, 0., 1.), self.revert(mean + 3 * std, 0., 1.)
+
+
+class _LinearNormalizationMixIn(_NormalizationMixIn):
+ """Colormap normalization mix-in class specific to autoscale taken from initial range"""
+
+ def autoscale_mean3std(self, data):
+ """Autoscale using mean+/-3std
+
+ Do the autoscale on the data itself, not the normalized data.
+
+ :param numpy.ndarray data:
+ :returns: (vmin, vmax)
+ :rtype: Tuple[float,float]
+ """
+ if data.dtype.kind == 'f': # Replaces inf by NaN
+ data = numpy.array(data, copy=True) # Work on a copy
+ data[numpy.isfinite(data) == False] = numpy.nan
+ if data.size == 0: # Fallback
+ return None, None
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', category=RuntimeWarning)
+ # Ignore nanmean "Mean of empty slice" warning and
+ # nanstd "Degrees of freedom <= 0 for slice" warning
+ mean, std = numpy.nanmean(data), numpy.nanstd(data)
+ return mean - 3 * std, mean + 3 * std
+
+
+class LinearNormalization(_colormap.LinearNormalization, _LinearNormalizationMixIn):
+ """Linear normalization"""
+ def __init__(self):
+ _colormap.LinearNormalization.__init__(self)
+ _LinearNormalizationMixIn.__init__(self)
+
+
+class LogarithmicNormalization(_colormap.LogarithmicNormalization, _NormalizationMixIn):
+ """Logarithm normalization"""
+
+ DEFAULT_RANGE = 1, 10
+
+ def __init__(self):
+ _colormap.LogarithmicNormalization.__init__(self)
+ _NormalizationMixIn.__init__(self)
+
+ def is_valid(self, value):
+ return value > 0.
+
+ def autoscale_minmax(self, data):
+ result = _min_max(data, min_positive=True, finite=True)
+ return result.min_positive, result.maximum
+
+
+class SqrtNormalization(_colormap.SqrtNormalization, _NormalizationMixIn):
+ """Square root normalization"""
+
+ DEFAULT_RANGE = 0, 1
+
+ def __init__(self):
+ _colormap.SqrtNormalization.__init__(self)
+ _NormalizationMixIn.__init__(self)
+
+ def is_valid(self, value):
+ return value >= 0.
+
+
+class GammaNormalization(_colormap.PowerNormalization, _LinearNormalizationMixIn):
+ """Gamma correction normalization:
+
+ Linear normalization to [0, 1] followed by power normalization.
+
+ :param gamma: Gamma correction factor
+ """
+ def __init__(self, gamma):
+ _colormap.PowerNormalization.__init__(self, gamma)
+ _LinearNormalizationMixIn.__init__(self)
+
+
+# Backward compatibility
+PowerNormalization = GammaNormalization
+
+
+class ArcsinhNormalization(_colormap.ArcsinhNormalization, _NormalizationMixIn):
+ """Inverse hyperbolic sine normalization"""
+
+ def __init__(self):
+ _colormap.ArcsinhNormalization.__init__(self)
+ _NormalizationMixIn.__init__(self)
+
+
+# Colormap function
+
+_BASIC_NORMALIZATIONS = {
+ "linear": LinearNormalization(),
+ "log": LogarithmicNormalization(),
+ "sqrt": SqrtNormalization(),
+ "arcsinh": ArcsinhNormalization(),
+}
+
+_DEFAULT_NAN_COLOR = 255, 255, 255, 0
+
+def apply_colormap(data,
+ colormap: str,
+ norm: str="linear",
+ autoscale: str="minmax",
+ vmin=None,
+ vmax=None,
+ gamma=1.0):
+ """Apply colormap to data with given normalization and autoscale.
+
+ :param numpy.ndarray data: Data on which to apply the colormap
+ :param str colormap: Name of the colormap to use
+ :param str norm: Normalization to use
+ :param str autoscale: Autoscale mode: "minmax" (default) or "stddev3"
+ :param vmin: Lower bound, None (default) to autoscale
+ :param vmax: Upper bound, None (default) to autoscale
+ :param float gamma:
+ Gamma correction parameter (used only for "gamma" normalization)
+ :returns: Array of colors
+ """
+ colors = get_colormap_lut(colormap)
+
+ if norm == "gamma":
+ normalizer = GammaNormalization(gamma)
+ else:
+ normalizer = _BASIC_NORMALIZATIONS[norm]
+
+ if vmin is None or vmax is None:
+ auto_vmin, auto_vmax = normalizer.autoscale(data, autoscale)
+ if vmin is None: # Set vmin respecting provided vmax
+ vmin = auto_vmin if vmax is None else min(auto_vmin, vmax)
+ if vmax is None:
+ vmax = max(auto_vmax, vmin) # Handle max_ <= 0 for log scale
+
+ return _colormap.cmap(
+ data,
+ colors,
+ vmin,
+ vmax,
+ normalizer,
+ _DEFAULT_NAN_COLOR,
+ )
diff --git a/silx/math/combo.pyx b/src/silx/math/combo.pyx
index e24edda..e24edda 100644
--- a/silx/math/combo.pyx
+++ b/src/silx/math/combo.pyx
diff --git a/silx/math/fft/__init__.py b/src/silx/math/fft/__init__.py
index ea12cd6..ea12cd6 100644
--- a/silx/math/fft/__init__.py
+++ b/src/silx/math/fft/__init__.py
diff --git a/silx/math/fft/basefft.py b/src/silx/math/fft/basefft.py
index 854ca37..854ca37 100644
--- a/silx/math/fft/basefft.py
+++ b/src/silx/math/fft/basefft.py
diff --git a/silx/math/fft/clfft.py b/src/silx/math/fft/clfft.py
index dad8ec1..dad8ec1 100644
--- a/silx/math/fft/clfft.py
+++ b/src/silx/math/fft/clfft.py
diff --git a/silx/math/fft/cufft.py b/src/silx/math/fft/cufft.py
index 848f3e6..848f3e6 100644
--- a/silx/math/fft/cufft.py
+++ b/src/silx/math/fft/cufft.py
diff --git a/silx/math/fft/fft.py b/src/silx/math/fft/fft.py
index eb0d73b..eb0d73b 100644
--- a/silx/math/fft/fft.py
+++ b/src/silx/math/fft/fft.py
diff --git a/silx/math/fft/fftw.py b/src/silx/math/fft/fftw.py
index ff6966c..ff6966c 100644
--- a/silx/math/fft/fftw.py
+++ b/src/silx/math/fft/fftw.py
diff --git a/silx/math/fft/npfft.py b/src/silx/math/fft/npfft.py
index 20351de..20351de 100644
--- a/silx/math/fft/npfft.py
+++ b/src/silx/math/fft/npfft.py
diff --git a/silx/math/fft/setup.py b/src/silx/math/fft/setup.py
index 76bb864..76bb864 100644
--- a/silx/math/fft/setup.py
+++ b/src/silx/math/fft/setup.py
diff --git a/src/silx/math/fft/test/__init__.py b/src/silx/math/fft/test/__init__.py
new file mode 100644
index 0000000..ad9836c
--- /dev/null
+++ b/src/silx/math/fft/test/__init__.py
@@ -0,0 +1,23 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
diff --git a/src/silx/math/fft/test/test_fft.py b/src/silx/math/fft/test/test_fft.py
new file mode 100644
index 0000000..19becb8
--- /dev/null
+++ b/src/silx/math/fft/test/test_fft.py
@@ -0,0 +1,257 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of the FFT module"""
+
+import numpy as np
+import unittest
+import logging
+import pytest
+try:
+ from scipy.misc import ascent
+ __have_scipy = True
+except ImportError:
+ __have_scipy = False
+from silx.utils.testutils import ParametricTestCase
+from silx.math.fft.fft import FFT
+from silx.math.fft.clfft import __have_clfft__
+from silx.math.fft.cufft import __have_cufft__
+from silx.math.fft.fftw import __have_fftw__
+
+
+logger = logging.getLogger(__name__)
+
+
+class TransformInfos(object):
+ def __init__(self):
+ self.dimensions = [
+ "1D",
+ "batched_1D",
+ "2D",
+ "batched_2D",
+ "3D",
+ ]
+ self.modes = {
+ "R2C": np.float32,
+ "R2C_double": np.float64,
+ "C2C": np.complex64,
+ "C2C_double": np.complex128,
+ }
+ self.sizes = {
+ "1D": [(128,), (127,)],
+ "2D": [(128, 128), (128, 127), (127, 128), (127, 127)],
+ "3D": [(64, 64, 64), (64, 64, 63), (64, 63, 64), (63, 64, 64),
+ (64, 63, 63), (63, 64, 63), (63, 63, 64), (63, 63, 63)]
+ }
+ self.axes = {
+ "1D": None,
+ "batched_1D": (-1,),
+ "2D": None,
+ "batched_2D": (-2, -1),
+ "3D": None,
+ }
+ self.sizes["batched_1D"] = self.sizes["2D"]
+ self.sizes["batched_2D"] = self.sizes["3D"]
+
+
+class Data(object):
+ def __init__(self):
+ self.data = ascent().astype("float32")
+ self.data1d = self.data[:, 0] # non-contiguous data
+ self.data3d = np.tile(self.data[:64, :64], (64, 1, 1))
+ self.data_refs = {
+ 1: self.data1d,
+ 2: self.data,
+ 3: self.data3d,
+ }
+
+
+@unittest.skipUnless(__have_scipy, "scipy is missing")
+@pytest.mark.usefixtures("test_options_class_attr")
+class TestFFT(ParametricTestCase):
+ """Test cuda/opencl/fftw backends of FFT"""
+
+ def setUp(self):
+ self.tol = {
+ np.dtype("float32"): 1e-3,
+ np.dtype("float64"): 1e-9,
+ np.dtype("complex64"): 1e-3,
+ np.dtype("complex128"): 1e-9,
+ }
+ self.transform_infos = TransformInfos()
+ self.test_data = Data()
+
+ @staticmethod
+ def calc_mae(arr1, arr2):
+ """
+ Compute the Max Absolute Error between two arrays
+ """
+ return np.max(np.abs(arr1 - arr2))
+
+ @unittest.skipIf(not __have_cufft__,
+ "cuda back-end requires pycuda and scikit-cuda")
+ def test_cuda(self):
+ import pycuda.autoinit
+
+ # Error is higher when using cuda. fast_math mode ?
+ self.tol[np.dtype("float32")] *= 2
+
+ self.__run_tests(backend="cuda")
+
+ @unittest.skipIf(not __have_clfft__,
+ "opencl back-end requires pyopencl and gpyfft")
+ def test_opencl(self):
+ from silx.opencl.common import ocl
+ if ocl is not None:
+ self.__run_tests(backend="opencl", ctx=ocl.create_context())
+
+ @unittest.skipIf(not __have_fftw__,
+ "fftw back-end requires pyfftw")
+ def test_fftw(self):
+ self.__run_tests(backend="fftw")
+
+ def __run_tests(self, backend, **extra_args):
+ """Run all tests with the given backend
+
+ :param str backend:
+ :param dict extra_args: Additional arguments to provide to FFT
+ """
+ for trdim in self.transform_infos.dimensions:
+ for mode in self.transform_infos.modes:
+ for size in self.transform_infos.sizes[trdim]:
+ with self.subTest(trdim=trdim, mode=mode, size=size):
+ self.__test(backend, trdim, mode, size, **extra_args)
+
+ def __test(self, backend, trdim, mode, size, **extra_args):
+ """Compare given backend with numpy for given conditions"""
+ logger.debug("backend: %s, trdim: %s, mode: %s, size: %s",
+ backend, trdim, mode, str(size))
+ if size == "3D" and self.test_options.TEST_LOW_MEM:
+ self.skipTest("low mem")
+
+ ndim = len(size)
+ input_data = self.test_data.data_refs[ndim].astype(
+ self.transform_infos.modes[mode])
+ tol = self.tol[np.dtype(input_data.dtype)]
+ if trdim == "3D":
+ tol *= 10 # Error is relatively high in high dimensions
+ # It seems that cuda has problems with C2D batched 1D
+ if trdim == "batched_1D" and backend == "cuda" and mode == "C2C":
+ tol *= 10
+
+ # Python < 3.5 does not want to mix **extra_args with existing kwargs
+ fft_args = {
+ "template": input_data,
+ "axes": self.transform_infos.axes[trdim],
+ "backend": backend,
+ }
+ fft_args.update(extra_args)
+ F = FFT(
+ **fft_args
+ )
+ F_np = FFT(
+ template=input_data,
+ axes=self.transform_infos.axes[trdim],
+ backend="numpy"
+ )
+
+ # Forward FFT
+ res = F.fft(input_data)
+ res_np = F_np.fft(input_data)
+ mae = self.calc_mae(res, res_np)
+ all_close = np.allclose(res, res_np, atol=tol, rtol=tol),
+ self.assertTrue(
+ all_close,
+ "FFT %s:%s, MAE(%s, numpy) = %f (tol = %.2e)" % (mode, trdim, backend, mae, tol)
+ )
+
+ # Inverse FFT
+ res2 = F.ifft(res)
+ mae = self.calc_mae(res2, input_data)
+ self.assertTrue(
+ mae < tol,
+ "IFFT %s:%s, MAE(%s, numpy) = %f" % (mode, trdim, backend, mae)
+ )
+
+
+@unittest.skipUnless(__have_scipy, "scipy is missing")
+class TestNumpyFFT(ParametricTestCase):
+ """
+ Test the Numpy backend individually.
+ """
+
+ def setUp(self):
+ transforms = {
+ "1D": {
+ True: (np.fft.rfft, np.fft.irfft),
+ False: (np.fft.fft, np.fft.ifft),
+ },
+ "2D": {
+ True: (np.fft.rfft2, np.fft.irfft2),
+ False: (np.fft.fft2, np.fft.ifft2),
+ },
+ "3D": {
+ True: (np.fft.rfftn, np.fft.irfftn),
+ False: (np.fft.fftn, np.fft.ifftn),
+ },
+ }
+ transforms["batched_1D"] = transforms["1D"]
+ transforms["batched_2D"] = transforms["2D"]
+ self.transforms = transforms
+ self.transform_infos = TransformInfos()
+ self.test_data = Data()
+
+ def test(self):
+ """Test the numpy backend against native fft.
+
+ Results should be exactly the same.
+ """
+ for trdim in self.transform_infos.dimensions:
+ for mode in self.transform_infos.modes:
+ for size in self.transform_infos.sizes[trdim]:
+ with self.subTest(trdim=trdim, mode=mode, size=size):
+ self.__test(trdim, mode, size)
+
+ def __test(self, trdim, mode, size):
+ logger.debug("trdim: %s, mode: %s, size: %s", trdim, mode, str(size))
+ ndim = len(size)
+ input_data = self.test_data.data_refs[ndim].astype(
+ self.transform_infos.modes[mode])
+ np_fft, np_ifft = self.transforms[trdim][np.isrealobj(input_data)]
+
+ F = FFT(
+ template=input_data,
+ axes=self.transform_infos.axes[trdim],
+ backend="numpy"
+ )
+ # Test FFT
+ res = F.fft(input_data)
+ ref = np_fft(input_data)
+ self.assertTrue(np.allclose(res, ref))
+
+ # Test IFFT
+ res2 = F.ifft(res)
+ ref2 = np_ifft(ref)
+ self.assertTrue(np.allclose(res2, ref2))
diff --git a/silx/math/fit/__init__.py b/src/silx/math/fit/__init__.py
index 29e6a9e..29e6a9e 100644
--- a/silx/math/fit/__init__.py
+++ b/src/silx/math/fit/__init__.py
diff --git a/silx/math/fit/bgtheories.py b/src/silx/math/fit/bgtheories.py
index 631c43e..631c43e 100644
--- a/silx/math/fit/bgtheories.py
+++ b/src/silx/math/fit/bgtheories.py
diff --git a/silx/math/fit/filters.pyx b/src/silx/math/fit/filters.pyx
index da1f6f5..da1f6f5 100644
--- a/silx/math/fit/filters.pyx
+++ b/src/silx/math/fit/filters.pyx
diff --git a/silx/math/fit/filters/include/filters.h b/src/silx/math/fit/filters/include/filters.h
index 1ee9a95..1ee9a95 100644
--- a/silx/math/fit/filters/include/filters.h
+++ b/src/silx/math/fit/filters/include/filters.h
diff --git a/silx/math/fit/filters/src/smoothnd.c b/src/silx/math/fit/filters/src/smoothnd.c
index cb96961..cb96961 100644
--- a/silx/math/fit/filters/src/smoothnd.c
+++ b/src/silx/math/fit/filters/src/smoothnd.c
diff --git a/silx/math/fit/filters/src/snip1d.c b/src/silx/math/fit/filters/src/snip1d.c
index 994a272..994a272 100644
--- a/silx/math/fit/filters/src/snip1d.c
+++ b/src/silx/math/fit/filters/src/snip1d.c
diff --git a/silx/math/fit/filters/src/snip2d.c b/src/silx/math/fit/filters/src/snip2d.c
index 235759c..235759c 100644
--- a/silx/math/fit/filters/src/snip2d.c
+++ b/src/silx/math/fit/filters/src/snip2d.c
diff --git a/silx/math/fit/filters/src/snip3d.c b/src/silx/math/fit/filters/src/snip3d.c
index cf48ee4..cf48ee4 100644
--- a/silx/math/fit/filters/src/snip3d.c
+++ b/src/silx/math/fit/filters/src/snip3d.c
diff --git a/silx/math/fit/filters/src/strip.c b/src/silx/math/fit/filters/src/strip.c
index dec0742..dec0742 100644
--- a/silx/math/fit/filters/src/strip.c
+++ b/src/silx/math/fit/filters/src/strip.c
diff --git a/silx/math/fit/filters_wrapper.pxd b/src/silx/math/fit/filters_wrapper.pxd
index e4f7c72..e4f7c72 100644
--- a/silx/math/fit/filters_wrapper.pxd
+++ b/src/silx/math/fit/filters_wrapper.pxd
diff --git a/src/silx/math/fit/fitmanager.py b/src/silx/math/fit/fitmanager.py
new file mode 100644
index 0000000..226e047
--- /dev/null
+++ b/src/silx/math/fit/fitmanager.py
@@ -0,0 +1,1087 @@
+# coding: utf-8
+# /*#########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ##########################################################################*/
+"""
+This module provides a tool to perform advanced fitting. The actual fit relies
+on :func:`silx.math.fit.leastsq`.
+
+This module deals with:
+
+ - handling of the model functions (using a set of default functions or
+ loading custom user functions)
+ - handling of estimation function, that are used to determine the number
+ of parameters to be fitted for functions with unknown number of
+ parameters (such as the sum of a variable number of gaussian curves),
+ and find reasonable initial parameters for input to the iterative
+ fitting algorithm
+ - handling of custom derivative functions that can be passed as a
+ parameter to :func:`silx.math.fit.leastsq`
+ - providing different background models
+
+"""
+from collections import OrderedDict
+import logging
+import numpy
+from numpy.linalg.linalg import LinAlgError
+import os
+import sys
+
+from .filters import strip, smooth1d
+from .leastsq import leastsq
+from .fittheory import FitTheory
+from . import bgtheories
+
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "16/01/2017"
+
+_logger = logging.getLogger(__name__)
+
+
+class FitManager(object):
+ """
+ Fit functions manager
+
+ :param x: Abscissa data. If ``None``, :attr:`xdata` is set to
+ ``numpy.array([0.0, 1.0, 2.0, ..., len(y)-1])``
+ :type x: Sequence or numpy array or None
+ :param y: The dependant data ``y = f(x)``. ``y`` must have the same
+ shape as ``x`` if ``x`` is not ``None``.
+ :type y: Sequence or numpy array or None
+ :param sigmay: The uncertainties in the ``ydata`` array. These can be
+ used as weights in the least-squares problem, if ``weight_flag``
+ is ``True``.
+ If ``None``, the uncertainties are assumed to be 1, unless
+ ``weight_flag`` is ``True``, in which case the square-root
+ of ``y`` is used.
+ :type sigmay: Sequence or numpy array or None
+ :param weight_flag: If this parameter is ``True`` and ``sigmay``
+ uncertainties are not specified, the square root of ``y`` is used
+ as weights in the least-squares problem. If ``False``, the
+ uncertainties are set to 1.
+ :type weight_flag: boolean
+ """
+ def __init__(self, x=None, y=None, sigmay=None, weight_flag=False):
+ """
+ """
+ self.fitconfig = {
+ 'WeightFlag': weight_flag,
+ 'fitbkg': 'No Background',
+ 'fittheory': None,
+ # Next few parameters are defined for compatibility with legacy theories
+ # which take the background as argument for their estimation function
+ 'StripWidth': 2,
+ 'StripIterations': 5000,
+ 'StripThresholdFactor': 1.0,
+ 'SmoothingFlag': False
+ }
+ """Dictionary of fit configuration parameters.
+ These parameters can be modified using the :meth:`configure` method.
+
+ Keys are:
+
+ - 'fitbkg': name of the function used for fitting a low frequency
+ background signal
+ - 'FwhmPoints': default full width at half maximum value for the
+ peaks'.
+ - 'Sensitivity': Sensitivity parameter for the peak detection
+ algorithm (:func:`silx.math.fit.peak_search`)
+ """
+
+ self.theories = OrderedDict()
+ """Dictionary of fit theories, defining functions to be fitted
+ to individual peaks.
+
+ Keys are descriptive theory names (e.g "Gaussians" or "Step up").
+ Values are :class:`silx.math.fit.fittheory.FitTheory` objects with
+ the following attributes:
+
+ - *"function"* is the fit function for an individual peak
+ - *"parameters"* is a sequence of parameter names
+ - *"estimate"* is the parameter estimation function
+ - *"configure"* is the function returning the configuration dict
+ for the theory in the format described in the :attr:` fitconfig`
+ documentation
+ - *"derivative"* (optional) is a custom derivative function, whose
+ signature is described in the documentation of
+ :func:`silx.math.fit.leastsq.leastsq`
+ (``model_deriv(xdata, parameters, index)``).
+ - *"description"* is a description string
+ """
+
+ self.selectedtheory = None
+ """Name of currently selected theory. This name matches a key in
+ :attr:`theories`."""
+
+ self.bgtheories = OrderedDict()
+ """Dictionary of background theories.
+
+ See :attr:`theories` for documentation on theories.
+ """
+
+ # Load default theories (constant, linear, strip)
+ self.loadbgtheories(bgtheories)
+
+ self.selectedbg = 'No Background'
+ """Name of currently selected background theory. This name must be
+ an existing key in :attr:`bgtheories`."""
+
+ self.fit_results = []
+ """This list stores detailed information about all fit parameters.
+ It is initialized in :meth:`estimate` and completed with final fit
+ values in :meth:`runfit`.
+
+ Each fit parameter is stored as a dictionary with following fields:
+
+ - 'name': Parameter name.
+ - 'estimation': Estimated value.
+ - 'group': Group number. Group 0 corresponds to the background
+ function parameters. Group ``n`` (for ``n>0``) corresponds to
+ the fit function parameters for the n-th peak.
+ - 'code': Constraint code
+
+ - 0 - FREE
+ - 1 - POSITIVE
+ - 2 - QUOTED
+ - 3 - FIXED
+ - 4 - FACTOR
+ - 5 - DELTA
+ - 6 - SUM
+
+ - 'cons1':
+
+ - Ignored if 'code' is FREE, POSITIVE or FIXED.
+ - Min value of the parameter if code is QUOTED
+ - Index of fitted parameter to which 'cons2' is related
+ if code is FACTOR, DELTA or SUM.
+
+ - 'cons2':
+
+ - Ignored if 'code' is FREE, POSITIVE or FIXED.
+ - Max value of the parameter if QUOTED
+ - Factor to apply to related parameter with index 'cons1' if
+ 'code' is FACTOR
+ - Difference with parameter with index 'cons1' if
+ 'code' is DELTA
+ - Sum obtained when adding parameter with index 'cons1' if
+ 'code' is SUM
+
+ - 'fitresult': Fitted value.
+ - 'sigma': Standard deviation for the parameter estimate
+ - 'xmin': Lower limit of the ``x`` data range on which the fit
+ was performed
+ - 'xmax': Upeer limit of the ``x`` data range on which the fit
+ was performed
+ """
+
+ self.parameter_names = []
+ """This list stores all fit parameter names: background function
+ parameters and fit function parameters for every peak. It is filled
+ in :meth:`estimate`.
+
+ It is the responsibility of the estimate function defined in
+ :attr:`theories` to determine how many parameters are needed,
+ based on how many peaks are detected and how many parameters are needed
+ to fit an individual peak.
+ """
+
+ self.setdata(x, y, sigmay)
+
+ ##################
+ # Public methods #
+ ##################
+ def addbackground(self, bgname, bgtheory):
+ """Add a new background theory to dictionary :attr:`bgtheories`.
+
+ :param bgname: String with the name describing the function
+ :param bgtheory: :class:`FitTheory` object
+ :type bgtheory: :class:`silx.math.fit.fittheory.FitTheory`
+ """
+ self.bgtheories[bgname] = bgtheory
+
+ def addtheory(self, name, theory=None,
+ function=None, parameters=None,
+ estimate=None, configure=None, derivative=None,
+ description=None, pymca_legacy=False):
+ """Add a new theory to dictionary :attr:`theories`.
+
+ You can pass a name and a :class:`FitTheory` object as arguments, or
+ alternatively provide all arguments necessary to instantiate a new
+ :class:`FitTheory` object.
+
+ See :meth:`loadtheories` for more information on estimation functions,
+ configuration functions and custom derivative functions.
+
+ :param name: String with the name describing the function
+ :param theory: :class:`FitTheory` object, defining a fit function and
+ associated information (estimation function, description…).
+ If this parameter is provided, all other parameters, except for
+ ``name``, are ignored.
+ :type theory: :class:`silx.math.fit.fittheory.FitTheory`
+ :param callable function: Mandatory argument if ``theory`` is not provided.
+ See documentation for :attr:`silx.math.fit.fittheory.FitTheory.function`.
+ :param List[str] parameters: Mandatory argument if ``theory`` is not provided.
+ See documentation for :attr:`silx.math.fit.fittheory.FitTheory.parameters`.
+ :param callable estimate: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.estimate`
+ :param callable configure: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.configure`
+ :param callable derivative: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.derivative`
+ :param str description: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.description`
+ :param config_widget: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.config_widget`
+ :param bool pymca_legacy: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.pymca_legacy`
+ """
+ if theory is not None:
+ self.theories[name] = theory
+
+ elif function is not None and parameters is not None:
+ self.theories[name] = FitTheory(
+ description=description,
+ function=function,
+ parameters=parameters,
+ estimate=estimate,
+ configure=configure,
+ derivative=derivative,
+ pymca_legacy=pymca_legacy
+ )
+
+ else:
+ raise TypeError("You must supply a FitTheory object or define " +
+ "a fit function and its parameters.")
+
+ def addbgtheory(self, name, theory=None,
+ function=None, parameters=None,
+ estimate=None, configure=None,
+ derivative=None, description=None):
+ """Add a new theory to dictionary :attr:`bgtheories`.
+
+ You can pass a name and a :class:`FitTheory` object as arguments, or
+ alternatively provide all arguments necessary to instantiate a new
+ :class:`FitTheory` object.
+
+ :param name: String with the name describing the function
+ :param theory: :class:`FitTheory` object, defining a fit function and
+ associated information (estimation function, description…).
+ If this parameter is provided, all other parameters, except for
+ ``name``, are ignored.
+ :type theory: :class:`silx.math.fit.fittheory.FitTheory`
+ :param function function: Mandatory argument if ``theory`` is not provided.
+ See documentation for :attr:`silx.math.fit.fittheory.FitTheory.function`.
+ :param list[str] parameters: Mandatory argument if ``theory`` is not provided.
+ See documentation for :attr:`silx.math.fit.fittheory.FitTheory.parameters`.
+ :param function estimate: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.estimate`
+ :param function configure: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.configure`
+ :param function derivative: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.derivative`
+ :param str description: See documentation for
+ :attr:`silx.math.fit.fittheory.FitTheory.description`
+ """
+ if theory is not None:
+ self.bgtheories[name] = theory
+
+ elif function is not None and parameters is not None:
+ self.bgtheories[name] = FitTheory(
+ description=description,
+ function=function,
+ parameters=parameters,
+ estimate=estimate,
+ configure=configure,
+ derivative=derivative,
+ is_background=True
+ )
+
+ else:
+ raise TypeError("You must supply a FitTheory object or define " +
+ "a background function and its parameters.")
+
+ def configure(self, **kw):
+ """Configure the current theory by filling or updating the
+ :attr:`fitconfig` dictionary.
+ Call the custom configuration function, if any. This allows the user
+ to modify the behavior of the custom fit function or the custom
+ estimate function.
+
+ This methods accepts only named parameters. All ``**kw`` parameters
+ are expected to be fields of :attr:`fitconfig` to be updated, unless
+ they have a special meaning for the custom configuration function
+ of the currently selected theory..
+
+ This method returns the modified config dictionary returned by the
+ custom configuration function.
+ """
+ # inspect **kw to find known keys, update them in self.fitconfig
+ for key in self.fitconfig:
+ if key in kw:
+ self.fitconfig[key] = kw[key]
+
+ # initialize dict with existing config dict
+ result = {}
+ result.update(self.fitconfig)
+
+ if "WeightFlag" in kw:
+ if kw["WeightFlag"]:
+ self.enableweight()
+ else:
+ self.disableweight()
+
+ if self.selectedtheory is None:
+ return result
+
+ # Apply custom configuration function
+ custom_config_fun = self.theories[self.selectedtheory].configure
+ if custom_config_fun is not None:
+ result.update(custom_config_fun(**kw))
+
+ custom_bg_config_fun = self.bgtheories[self.selectedbg].configure
+ if custom_bg_config_fun is not None:
+ result.update(custom_bg_config_fun(**kw))
+
+ # Update self.fitconfig with custom config
+ for key in self.fitconfig:
+ if key in result:
+ self.fitconfig[key] = result[key]
+
+ result.update(self.fitconfig)
+ return result
+
+ def estimate(self, callback=None):
+ """
+ Fill :attr:`fit_results` with an estimation of the fit parameters.
+
+ At first, the background parameters are estimated, if a background
+ model has been specified.
+ Then, a custom estimation function related to the model function is
+ called.
+
+ This process determines the number of needed fit parameters and
+ provides an initial estimation for them, to serve as an input for the
+ actual iterative fitting performed in :meth:`runfit`.
+
+ :param callback: Optional callback function, conforming to the
+ signature ``callback(data)`` with ``data`` being a dictionary.
+ This callback function is called before and after the estimation
+ process, and is given a dictionary containing the values of
+ :attr:`state` (``'Estimate in progress'`` or ``'Ready to Fit'``)
+ and :attr:`chisq`.
+ This is used for instance in :mod:`silx.gui.fit.FitWidget` to
+ update a widget displaying a status message.
+ :return: Estimated parameters
+ """
+ self.state = 'Estimate in progress'
+ self.chisq = None
+
+ if callback is not None:
+ callback(data={'chisq': self.chisq,
+ 'status': self.state})
+
+ CONS = {0: 'FREE',
+ 1: 'POSITIVE',
+ 2: 'QUOTED',
+ 3: 'FIXED',
+ 4: 'FACTOR',
+ 5: 'DELTA',
+ 6: 'SUM',
+ 7: 'IGNORE'}
+
+ # Filter-out not finite data
+ xwork = self.xdata[self._finite_mask]
+ ywork = self.ydata[self._finite_mask]
+
+ # estimate the background
+ bg_params, bg_constraints = self.estimate_bkg(xwork, ywork)
+
+ # estimate the function
+ try:
+ fun_params, fun_constraints = self.estimate_fun(xwork, ywork)
+ except LinAlgError:
+ self.state = 'Estimate failed'
+ if callback is not None:
+ callback(data={'status': self.state})
+ raise
+
+ # build the names
+ self.parameter_names = []
+
+ for bg_param_name in self.bgtheories[self.selectedbg].parameters:
+ self.parameter_names.append(bg_param_name)
+
+ fun_param_names = self.theories[self.selectedtheory].parameters
+ param_index, peak_index = 0, 0
+ while param_index < len(fun_params):
+ peak_index += 1
+ for fun_param_name in fun_param_names:
+ self.parameter_names.append(fun_param_name + "%d" % peak_index)
+ param_index += 1
+
+ self.fit_results = []
+ nb_fun_params_per_group = len(fun_param_names)
+ group_number = 0
+ xmin = min(xwork)
+ xmax = max(xwork)
+ nb_bg_params = len(bg_params)
+ for (pindex, pname) in enumerate(self.parameter_names):
+ # First come background parameters
+ if pindex < nb_bg_params:
+ estimation_value = bg_params[pindex]
+ constraint_code = CONS[int(bg_constraints[pindex][0])]
+ cons1 = bg_constraints[pindex][1]
+ cons2 = bg_constraints[pindex][2]
+ # then come peak function parameters
+ else:
+ fun_param_index = pindex - nb_bg_params
+
+ # increment group_number for each new fitted peak
+ if (fun_param_index % nb_fun_params_per_group) == 0:
+ group_number += 1
+
+ estimation_value = fun_params[fun_param_index]
+ constraint_code = CONS[int(fun_constraints[fun_param_index][0])]
+ # cons1 is the index of another fit parameter. In the global
+ # fit_results, we must adjust the index to account for the bg
+ # params added to the start of the list.
+ cons1 = fun_constraints[fun_param_index][1]
+ if constraint_code in ["FACTOR", "DELTA", "SUM"]:
+ cons1 += nb_bg_params
+ cons2 = fun_constraints[fun_param_index][2]
+
+ self.fit_results.append({'name': pname,
+ 'estimation': estimation_value,
+ 'group': group_number,
+ 'code': constraint_code,
+ 'cons1': cons1,
+ 'cons2': cons2,
+ 'fitresult': 0.0,
+ 'sigma': 0.0,
+ 'xmin': xmin,
+ 'xmax': xmax})
+
+ self.state = 'Ready to Fit'
+ self.chisq = None
+ self.niter = 0
+
+ if callback is not None:
+ callback(data={'chisq': self.chisq,
+ 'status': self.state})
+ return numpy.append(bg_params, fun_params)
+
+ def fit(self):
+ """Convenience method to call :meth:`estimate` followed by :meth:`runfit`.
+
+ :return: Output of :meth:`runfit`"""
+ self.estimate()
+ return self.runfit()
+
+ def gendata(self, x=None, paramlist=None, estimated=False):
+ """Return a data array using the currently selected fit function
+ and the fitted parameters.
+
+ :param x: Independent variable where the function is calculated.
+ If ``None``, use :attr:`xdata`.
+ :param paramlist: List of dictionaries, each dictionary item being a
+ fit parameter. The dictionary's format is documented in
+ :attr:`fit_results`.
+ If ``None`` (default), use parameters from :attr:`fit_results`.
+ :param estimated: If *True*, use estimated parameters.
+ :return: :meth:`fitfunction` calculated for parameters whose code is
+ not set to ``"IGNORE"``.
+
+ This calculates :meth:`fitfunction` on `x` data using fit parameters
+ from a list of parameter dictionaries, if field ``code`` is not set
+ to ``"IGNORE"``.
+ """
+ x = self.xdata if x is None else numpy.array(x, copy=False)
+
+ if paramlist is None:
+ paramlist = self.fit_results
+ active_params = []
+ for param in paramlist:
+ if param['code'] not in ['IGNORE', 7]:
+ if not estimated:
+ active_params.append(param['fitresult'])
+ else:
+ active_params.append(param['estimation'])
+
+ # Mask x with not finite (support nD x)
+ finite_mask = numpy.all(numpy.isfinite(x), axis=tuple(range(1, x.ndim)))
+
+ if numpy.all(finite_mask): # All values are finite: fast path
+ return self.fitfunction(numpy.array(x, copy=True), *active_params)
+
+ else: # Only run fitfunction on finite data and complete result with NaNs
+ # Create result with same number as elements as x, filling holes with NaNs
+ result = numpy.full((x.shape[0],), numpy.nan, dtype=numpy.float64)
+ result[finite_mask] = self.fitfunction(
+ numpy.array(x[finite_mask], copy=True), *active_params)
+ return result
+
+ def get_estimation(self):
+ """Return the list of fit parameter names."""
+ if self.state not in ["Ready to fit", "Fit in progress", "Ready"]:
+ _logger.warning("get_estimation() called before estimate() completed")
+ return [param["estimation"] for param in self.fit_results]
+
+ def get_names(self):
+ """Return the list of fit parameter estimations."""
+ if self.state not in ["Ready to fit", "Fit in progress", "Ready"]:
+ msg = "get_names() called before estimate() completed, "
+ msg += "names are not populated at this stage"
+ _logger.warning(msg)
+ return [param["name"] for param in self.fit_results]
+
+ def get_fitted_parameters(self):
+ """Return the list of fitted parameters."""
+ if self.state not in ["Ready"]:
+ msg = "get_fitted_parameters() called before runfit() completed, "
+ msg += "results are not available a this stage"
+ _logger.warning(msg)
+ return [param["fitresult"] for param in self.fit_results]
+
+ def loadtheories(self, theories):
+ """Import user defined fit functions defined in an external Python
+ source file, and save them in :attr:`theories`.
+
+ An example of such a file can be found in the sources of
+ :mod:`silx.math.fit.fittheories`. It must contain a
+ dictionary named ``THEORY`` with the following structure::
+
+ THEORY = {
+ 'theory_name_1':
+ FitTheory(description='Description of theory 1',
+ function=fitfunction1,
+ parameters=('param name 1', 'param name 2', …),
+ estimate=estimation_function1,
+ configure=configuration_function1,
+ derivative=derivative_function1),
+ 'theory_name_2':
+ FitTheory(…),
+ }
+
+ See documentation of :mod:`silx.math.fit.fittheories` and
+ :mod:`silx.math.fit.fittheory` for more
+ information on designing your fit functions file.
+
+ This method can also load user defined functions in the legacy
+ format used in *PyMca*.
+
+ :param theories: Name of python source file, or module containing the
+ definition of fit functions.
+ :raise: ImportError if theories cannot be imported
+ """
+ from types import ModuleType
+ if isinstance(theories, ModuleType):
+ theories_module = theories
+ else:
+ # if theories is not a module, it must be a string
+ string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
+ if not isinstance(theories, string_types):
+ raise ImportError("theory must be a python module, a module" +
+ "name or a python filename")
+ # if theories is a filename
+ if os.path.isfile(theories):
+ sys.path.append(os.path.dirname(theories))
+ f = os.path.basename(os.path.splitext(theories)[0])
+ theories_module = __import__(f)
+ # if theories is a module name
+ else:
+ theories_module = __import__(theories)
+
+ if hasattr(theories_module, "INIT"):
+ theories.INIT()
+
+ if not hasattr(theories_module, "THEORY"):
+ msg = "File %s does not contain a THEORY dictionary" % theories
+ raise ImportError(msg)
+
+ elif isinstance(theories_module.THEORY, dict):
+ # silx format for theory definition
+ for theory_name, fittheory in list(theories_module.THEORY.items()):
+ self.addtheory(theory_name, fittheory)
+ else:
+ self._load_legacy_theories(theories_module)
+
+ def loadbgtheories(self, theories):
+ """Import user defined background functions defined in an external Python
+ module (source file), and save them in :attr:`theories`.
+
+ An example of such a file can be found in the sources of
+ :mod:`silx.math.fit.fittheories`. It must contain a
+ dictionary named ``THEORY`` with the following structure::
+
+ THEORY = {
+ 'theory_name_1':
+ FitTheory(description='Description of theory 1',
+ function=fitfunction1,
+ parameters=('param name 1', 'param name 2', …),
+ estimate=estimation_function1,
+ configure=configuration_function1,
+ 'theory_name_2':
+ FitTheory(…),
+ }
+
+ See documentation of :mod:`silx.math.fit.bgtheories` and
+ :mod:`silx.math.fit.fittheory` for more
+ information on designing your background functions file.
+
+ :param theories: Module or name of python source file containing the
+ definition of background functions.
+ :raise: ImportError if theories cannot be imported
+ """
+ from types import ModuleType
+ if isinstance(theories, ModuleType):
+ theories_module = theories
+ else:
+ # if theories is not a module, it must be a string
+ string_types = (basestring,) if sys.version_info[0] == 2 else (str,) # noqa
+ if not isinstance(theories, string_types):
+ raise ImportError("theory must be a python module, a module" +
+ "name or a python filename")
+ # if theories is a filename
+ if os.path.isfile(theories):
+ sys.path.append(os.path.dirname(theories))
+ f = os.path.basename(os.path.splitext(theories)[0])
+ theories_module = __import__(f)
+ # if theories is a module name
+ else:
+ theories_module = __import__(theories)
+
+ if hasattr(theories_module, "INIT"):
+ theories.INIT()
+
+ if not hasattr(theories_module, "THEORY"):
+ msg = "File %s does not contain a THEORY dictionary" % theories
+ raise ImportError(msg)
+
+ elif isinstance(theories_module.THEORY, dict):
+ # silx format for theory definition
+ for theory_name, fittheory in list(theories_module.THEORY.items()):
+ self.addbgtheory(theory_name, fittheory)
+
+ def setbackground(self, theory):
+ """Choose a background type from within :attr:`bgtheories`.
+
+ This updates :attr:`selectedbg`.
+
+ :param theory: The name of the background to be used.
+ :raise: KeyError if ``theory`` is not a key of :attr:`bgtheories``.
+ """
+ if theory in self.bgtheories:
+ self.selectedbg = theory
+ else:
+ msg = "No theory with name %s in bgtheories.\n" % theory
+ msg += "Available theories: %s\n" % self.bgtheories.keys()
+ raise KeyError(msg)
+
+ # run configure to apply our fitconfig to the selected theory
+ # through its custom config function
+ self.configure(**self.fitconfig)
+
+ def setdata(self, x, y, sigmay=None, xmin=None, xmax=None):
+ """Set data attributes:
+
+ - ``xdata0``, ``ydata0`` and ``sigmay0`` store the initial data
+ and uncertainties. These attributes are not modified after
+ initialization.
+ - ``xdata``, ``ydata`` and ``sigmay`` store the data after
+ removing values where ``xdata < xmin`` or ``xdata > xmax``.
+ These attributes may be modified at a latter stage by filters.
+
+ :param x: Abscissa data. If ``None``, :attr:`xdata`` is set to
+ ``numpy.array([0.0, 1.0, 2.0, ..., len(y)-1])``
+ :type x: Sequence or numpy array or None
+ :param y: The dependant data ``y = f(x)``. ``y`` must have the same
+ shape as ``x`` if ``x`` is not ``None``.
+ :type y: Sequence or numpy array or None
+ :param sigmay: The uncertainties in the ``ydata`` array. These are
+ used as weights in the least-squares problem.
+ If ``None``, the uncertainties are assumed to be 1.
+ :type sigmay: Sequence or numpy array or None
+ :param xmin: Lower value of x values to use for fitting
+ :param xmax: Upper value of x values to use for fitting
+ """
+ if y is None:
+ self.xdata0 = numpy.array([], numpy.float64)
+ self.ydata0 = numpy.array([], numpy.float64)
+ # self.sigmay0 = numpy.array([], numpy.float64)
+ self.xdata = numpy.array([], numpy.float64)
+ self.ydata = numpy.array([], numpy.float64)
+ # self.sigmay = numpy.array([], numpy.float64)
+
+ else:
+ self.ydata0 = numpy.array(y)
+ self.ydata = numpy.array(y)
+ if x is None:
+ self.xdata0 = numpy.arange(len(self.ydata0))
+ self.xdata = numpy.arange(len(self.ydata0))
+ else:
+ self.xdata0 = numpy.array(x)
+ self.xdata = numpy.array(x)
+
+ # default weight
+ if sigmay is None:
+ self.sigmay0 = None
+ self.sigmay = numpy.sqrt(self.ydata) if self.fitconfig["WeightFlag"] else None
+ else:
+ self.sigmay0 = numpy.array(sigmay)
+ self.sigmay = numpy.array(sigmay) if self.fitconfig["WeightFlag"] else None
+
+ # take the data between limits, using boolean array indexing
+ if (xmin is not None or xmax is not None) and len(self.xdata):
+ xmin = xmin if xmin is not None else min(self.xdata)
+ xmax = xmax if xmax is not None else max(self.xdata)
+ bool_array = (self.xdata >= xmin) & (self.xdata <= xmax)
+ self.xdata = self.xdata[bool_array]
+ self.ydata = self.ydata[bool_array]
+ self.sigmay = self.sigmay[bool_array] if sigmay is not None else None
+
+ self._finite_mask = numpy.logical_and(
+ numpy.all(numpy.isfinite(self.xdata), axis=tuple(range(1, self.xdata.ndim))),
+ numpy.isfinite(self.ydata))
+
+ def enableweight(self):
+ """This method can be called to set :attr:`sigmay`. If :attr:`sigmay0` was filled with
+ actual uncertainties in :meth:`setdata`, use these values.
+ Else, use ``sqrt(self.ydata)``.
+ """
+ if self.sigmay0 is None:
+ self.sigmay = numpy.sqrt(self.ydata) if self.fitconfig["WeightFlag"] else None
+ else:
+ self.sigmay = self.sigmay0
+
+ def disableweight(self):
+ """This method can be called to set :attr:`sigmay` equal to ``None``.
+ As a result, :func:`leastsq` will consider that the weights in the
+ least square problem are 1 for all samples."""
+ self.sigmay = None
+
+ def settheory(self, theory):
+ """Pick a theory from :attr:`theories`.
+
+ :param theory: Name of the theory to be used.
+ :raise: KeyError if ``theory`` is not a key of :attr:`theories`.
+ """
+ if theory is None:
+ self.selectedtheory = None
+ elif theory in self.theories:
+ self.selectedtheory = theory
+ else:
+ msg = "No theory with name %s in theories.\n" % theory
+ msg += "Available theories: %s\n" % self.theories.keys()
+ raise KeyError(msg)
+
+ # run configure to apply our fitconfig to the selected theory
+ # through its custom config function
+ self.configure(**self.fitconfig)
+
+ def runfit(self, callback=None):
+ """Run the actual fitting and fill :attr:`fit_results` with fit results.
+
+ Before running this method, :attr:`fit_results` must already be
+ populated with a list of all parameters and their estimated values.
+ For this, run :meth:`estimate` beforehand.
+
+ :param callback: Optional callback function, conforming to the
+ signature ``callback(data)`` with ``data`` being a dictionary.
+ This callback function is called before and after the estimation
+ process, and is given a dictionary containing the values of
+ :attr:`state` (``'Fit in progress'`` or ``'Ready'``)
+ and :attr:`chisq`.
+ This is used for instance in :mod:`silx.gui.fit.FitWidget` to
+ update a widget displaying a status message.
+ :return: Tuple ``(fitted parameters, uncertainties, infodict)``.
+ *infodict* is the dictionary returned by
+ :func:`silx.math.fit.leastsq` when called with option
+ ``full_output=True``. Uncertainties is a sequence of uncertainty
+ values associated with each fitted parameter.
+ """
+ # self.dataupdate()
+
+ self.state = 'Fit in progress'
+ self.chisq = None
+
+ if callback is not None:
+ callback(data={'chisq': self.chisq,
+ 'status': self.state})
+
+ param_val = []
+ param_constraints = []
+ # Initial values are set to the ones computed in estimate()
+ for param in self.fit_results:
+ param_val.append(param['estimation'])
+ param_constraints.append([param['code'], param['cons1'], param['cons2']])
+
+ # Filter-out not finite data
+ ywork = self.ydata[self._finite_mask]
+ xwork = self.xdata[self._finite_mask]
+
+ try:
+ params, covariance_matrix, infodict = leastsq(
+ self.fitfunction, # bg + actual model function
+ xwork, ywork, param_val,
+ sigma=self.sigmay,
+ constraints=param_constraints,
+ model_deriv=self.theories[self.selectedtheory].derivative,
+ full_output=True, left_derivative=True)
+ except LinAlgError:
+ self.state = 'Fit failed'
+ callback(data={'status': self.state})
+ raise
+
+ sigmas = infodict['uncertainties']
+
+ for i, param in enumerate(self.fit_results):
+ if param['code'] != 'IGNORE':
+ param['fitresult'] = params[i]
+ param['sigma'] = sigmas[i]
+
+ self.chisq = infodict["reduced_chisq"]
+ self.niter = infodict["niter"]
+ self.state = 'Ready'
+
+ if callback is not None:
+ callback(data={'chisq': self.chisq,
+ 'status': self.state})
+
+ return params, sigmas, infodict
+
+ ###################
+ # Private methods #
+ ###################
+ def fitfunction(self, x, *pars):
+ """Function to be fitted.
+
+ This is the sum of the selected background function plus
+ the selected fit model function.
+
+ :param x: Independent variable where the function is calculated.
+ :param pars: Sequence of all fit parameters. The first few parameters
+ are background parameters, then come the peak function parameters.
+ :return: Output of the fit function with ``x`` as input and ``pars``
+ as fit parameters.
+ """
+ result = numpy.zeros(numpy.shape(x), numpy.float64)
+
+ if self.selectedbg is not None:
+ bg_pars_list = self.bgtheories[self.selectedbg].parameters
+ nb_bg_pars = len(bg_pars_list)
+
+ bgfun = self.bgtheories[self.selectedbg].function
+ result += bgfun(x, self.ydata, *pars[0:nb_bg_pars])
+ else:
+ nb_bg_pars = 0
+
+ selectedfun = self.theories[self.selectedtheory].function
+ result += selectedfun(x, *pars[nb_bg_pars:])
+
+ return result
+
+ def estimate_bkg(self, x, y):
+ """Estimate background parameters using the function defined in
+ the current fit configuration.
+
+ To change the selected background model, attribute :attr:`selectdbg`
+ must be changed using method :meth:`setbackground`.
+
+ The actual background function to be used is
+ referenced in :attr:`bgtheories`
+
+ :param x: Sequence of x data
+ :param y: sequence of y data
+ :return: Tuple of two sequences and one data array
+ ``(estimated_param, constraints, bg_data)``:
+
+ - ``estimated_param`` is a list of estimated values for each
+ background parameter.
+ - ``constraints`` is a 2D sequence of dimension ``(n_parameters, 3)``
+
+ - ``constraints[i][0]``: Constraint code.
+ See explanation about codes in :attr:`fit_results`
+
+ - ``constraints[i][1]``
+ See explanation about 'cons1' in :attr:`fit_results`
+ documentation.
+
+ - ``constraints[i][2]``
+ See explanation about 'cons2' in :attr:`fit_results`
+ documentation.
+ """
+ background_estimate_function = self.bgtheories[self.selectedbg].estimate
+ if background_estimate_function is not None:
+ return background_estimate_function(x, y)
+ else:
+ return [], []
+
+ def estimate_fun(self, x, y):
+ """Estimate fit parameters using the function defined in
+ the current fit configuration.
+
+ :param x: Sequence of x data
+ :param y: sequence of y data
+ :param bg: Background signal, to be subtracted from ``y`` before fitting.
+ :return: Tuple of two sequences ``(estimated_param, constraints)``:
+
+ - ``estimated_param`` is a list of estimated values for each
+ background parameter.
+ - ``constraints`` is a 2D sequence of dimension (n_parameters, 3)
+
+ - ``constraints[i][0]``: Constraint code.
+ See explanation about codes in :attr:`fit_results`
+
+ - ``constraints[i][1]``
+ See explanation about 'cons1' in :attr:`fit_results`
+ documentation.
+
+ - ``constraints[i][2]``
+ See explanation about 'cons2' in :attr:`fit_results`
+ documentation.
+ :raise: ``TypeError`` if estimation function is not callable
+
+ """
+ estimatefunction = self.theories[self.selectedtheory].estimate
+ if hasattr(estimatefunction, '__call__'):
+ if not self.theories[self.selectedtheory].pymca_legacy:
+ return estimatefunction(x, y)
+ else:
+ # legacy pymca estimate functions have a different signature
+ if self.fitconfig["fitbkg"] == "No Background":
+ bg = numpy.zeros_like(y)
+ else:
+ if self.fitconfig["SmoothingFlag"]:
+ y = smooth1d(y)
+ bg = strip(y,
+ w=self.fitconfig["StripWidth"],
+ niterations=self.fitconfig["StripIterations"],
+ factor=self.fitconfig["StripThresholdFactor"])
+ # fitconfig can be filled by user defined config function
+ xscaling = self.fitconfig.get('Xscaling', 1.0)
+ yscaling = self.fitconfig.get('Yscaling', 1.0)
+ return estimatefunction(x, y, bg, xscaling, yscaling)
+ else:
+ raise TypeError("Estimation function in attribute " +
+ "theories[%s]" % self.selectedtheory +
+ " must be callable.")
+
+ def _load_legacy_theories(self, theories_module):
+ """Load theories from a custom module in the old PyMca format.
+
+ See PyMca5.PyMcaMath.fitting.SpecfitFunctions for an example.
+ """
+ mandatory_attributes = ["THEORY", "PARAMETERS",
+ "FUNCTION", "ESTIMATE"]
+ err_msg = "Custom fit function file must define: "
+ err_msg += ", ".join(mandatory_attributes)
+ for attr in mandatory_attributes:
+ if not hasattr(theories_module, attr):
+ raise ImportError(err_msg)
+
+ derivative = theories_module.DERIVATIVE if hasattr(theories_module, "DERIVATIVE") else None
+ configure = theories_module.CONFIGURE if hasattr(theories_module, "CONFIGURE") else None
+ estimate = theories_module.ESTIMATE if hasattr(theories_module, "ESTIMATE") else None
+ if isinstance(theories_module.THEORY, (list, tuple)):
+ # multiple fit functions
+ for i in range(len(theories_module.THEORY)):
+ deriv = derivative[i] if derivative is not None else None
+ config = configure[i] if configure is not None else None
+ estim = estimate[i] if estimate is not None else None
+ self.addtheory(theories_module.THEORY[i],
+ FitTheory(
+ theories_module.FUNCTION[i],
+ theories_module.PARAMETERS[i],
+ estim,
+ config,
+ deriv,
+ pymca_legacy=True))
+ else:
+ # single fit function
+ self.addtheory(theories_module.THEORY,
+ FitTheory(
+ theories_module.FUNCTION,
+ theories_module.PARAMETERS,
+ estimate,
+ configure,
+ derivative,
+ pymca_legacy=True))
+
+
+def test():
+ from .functions import sum_gauss
+ from . import fittheories
+ from . import bgtheories
+
+ # Create synthetic data with a sum of gaussian functions
+ x = numpy.arange(1000).astype(numpy.float64)
+
+ p = [1000, 100., 250,
+ 255, 690., 45,
+ 1500, 800.5, 95]
+ y = 0.5 * x + 13 + sum_gauss(x, *p)
+
+ # Fitting
+ fit = FitManager()
+ # more sensitivity necessary to resolve
+ # overlapping peaks at x=690 and x=800.5
+ fit.setdata(x=x, y=y)
+ fit.loadtheories(fittheories)
+ fit.settheory('Gaussians')
+ fit.loadbgtheories(bgtheories)
+ fit.setbackground('Linear')
+ fit.estimate()
+ fit.runfit()
+
+ print("Searched parameters = ", p)
+ print("Obtained parameters : ")
+ dummy_list = []
+ for param in fit.fit_results:
+ print(param['name'], ' = ', param['fitresult'])
+ dummy_list.append(param['fitresult'])
+ print("chisq = ", fit.chisq)
+
+ # Plot
+ constant, slope = dummy_list[:2]
+ p1 = dummy_list[2:]
+ print(p1)
+ y2 = slope * x + constant + sum_gauss(x, *p1)
+
+ try:
+ from silx.gui import qt
+ from silx.gui.plot.PlotWindow import PlotWindow
+ app = qt.QApplication([])
+ pw = PlotWindow(control=True)
+ pw.addCurve(x, y, "Original")
+ pw.addCurve(x, y2, "Fit result")
+ pw.legendsDockWidget.show()
+ pw.show()
+ app.exec()
+ except ImportError:
+ _logger.warning("Could not import qt to display fit result as curve")
+
+
+if __name__ == "__main__":
+ test()
diff --git a/src/silx/math/fit/fittheories.py b/src/silx/math/fit/fittheories.py
new file mode 100644
index 0000000..5461416
--- /dev/null
+++ b/src/silx/math/fit/fittheories.py
@@ -0,0 +1,1374 @@
+# coding: utf-8
+#/*##########################################################################
+#
+# Copyright (c) 2004-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+########################################################################### */
+"""This modules provides a set of fit functions and associated
+estimation functions in a format that can be imported into a
+:class:`silx.math.fit.FitManager` instance.
+
+These functions are well suited for fitting multiple gaussian shaped peaks
+typically found in spectroscopy data. The estimation functions are designed
+to detect how many peaks are present in the data, and provide an initial
+estimate for their height, their center location and their full-width
+at half maximum (fwhm).
+
+The limitation of these estimation algorithms is that only gaussians having a
+similar fwhm can be detected by the peak search algorithm.
+This *search fwhm* can be defined by the user, if
+he knows the characteristics of his data, or can be automatically estimated
+based on the fwhm of the largest peak in the data.
+
+The source code of this module can serve as template for defining your own
+fit functions.
+
+The functions to be imported by :meth:`FitManager.loadtheories` are defined by
+a dictionary :const:`THEORY`: with the following structure::
+
+ from silx.math.fit.fittheory import FitTheory
+
+ THEORY = {
+ 'theory_name_1': FitTheory(
+ description='Description of theory 1',
+ function=fitfunction1,
+ parameters=('param name 1', 'param name 2', …),
+ estimate=estimation_function1,
+ configure=configuration_function1,
+ derivative=derivative_function1),
+
+ 'theory_name_2': FitTheory(…),
+ }
+
+.. note::
+
+ Consider using an OrderedDict instead of a regular dictionary, when
+ defining your own theory dictionary, if the order matters to you.
+ This will likely be the case if you intend to load a selection of
+ functions in a GUI such as :class:`silx.gui.fit.FitManager`.
+
+Theory names can be customized (e.g. ``gauss, lorentz, splitgauss``…).
+
+The mandatory parameters for :class:`FitTheory` are ``function`` and
+``parameters``.
+
+You can also define an ``INIT`` function that will be executed by
+:meth:`FitManager.loadtheories`.
+
+See the documentation of :class:`silx.math.fit.fittheory.FitTheory`
+for more information.
+
+Module members:
+---------------
+"""
+import numpy
+from collections import OrderedDict
+import logging
+
+from silx.math.fit import functions
+from silx.math.fit.peaks import peak_search, guess_fwhm
+from silx.math.fit.filters import strip, savitsky_golay
+from silx.math.fit.leastsq import leastsq
+from silx.math.fit.fittheory import FitTheory
+
+_logger = logging.getLogger(__name__)
+
+__authors__ = ["V.A. Sole", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "15/05/2017"
+
+
+DEFAULT_CONFIG = {
+ 'NoConstraintsFlag': False,
+ 'PositiveFwhmFlag': True,
+ 'PositiveHeightAreaFlag': True,
+ 'SameFwhmFlag': False,
+ 'QuotedPositionFlag': False, # peak not outside data range
+ 'QuotedEtaFlag': False, # force 0 < eta < 1
+ # Peak detection
+ 'AutoScaling': False,
+ 'Yscaling': 1.0,
+ 'FwhmPoints': 8,
+ 'AutoFwhm': True,
+ 'Sensitivity': 2.5,
+ 'ForcePeakPresence': True,
+ # Hypermet
+ 'HypermetTails': 15,
+ 'QuotedFwhmFlag': 0,
+ 'MaxFwhm2InputRatio': 1.5,
+ 'MinFwhm2InputRatio': 0.4,
+ # short tail parameters
+ 'MinGaussArea4ShortTail': 50000.,
+ 'InitialShortTailAreaRatio': 0.050,
+ 'MaxShortTailAreaRatio': 0.100,
+ 'MinShortTailAreaRatio': 0.0010,
+ 'InitialShortTailSlopeRatio': 0.70,
+ 'MaxShortTailSlopeRatio': 2.00,
+ 'MinShortTailSlopeRatio': 0.50,
+ # long tail parameters
+ 'MinGaussArea4LongTail': 1000.0,
+ 'InitialLongTailAreaRatio': 0.050,
+ 'MaxLongTailAreaRatio': 0.300,
+ 'MinLongTailAreaRatio': 0.010,
+ 'InitialLongTailSlopeRatio': 20.0,
+ 'MaxLongTailSlopeRatio': 50.0,
+ 'MinLongTailSlopeRatio': 5.0,
+ # step tail
+ 'MinGaussHeight4StepTail': 5000.,
+ 'InitialStepTailHeightRatio': 0.002,
+ 'MaxStepTailHeightRatio': 0.0100,
+ 'MinStepTailHeightRatio': 0.0001,
+ # Hypermet constraints
+ # position in range [estimated position +- estimated fwhm/2]
+ 'HypermetQuotedPositionFlag': True,
+ 'DeltaPositionFwhmUnits': 0.5,
+ 'SameSlopeRatioFlag': 1,
+ 'SameAreaRatioFlag': 1,
+ # Strip bg removal
+ 'StripBackgroundFlag': True,
+ 'SmoothingFlag': True,
+ 'SmoothingWidth': 5,
+ 'StripWidth': 2,
+ 'StripIterations': 5000,
+ 'StripThresholdFactor': 1.0}
+"""This dictionary defines default configuration parameters that have effects
+on fit functions and estimation functions, mainly on fit constraints.
+This dictionary is accessible as attribute :attr:`FitTheories.config`,
+which can be modified by configuration functions defined in
+:const:`CONFIGURE`.
+"""
+
+CFREE = 0
+CPOSITIVE = 1
+CQUOTED = 2
+CFIXED = 3
+CFACTOR = 4
+CDELTA = 5
+CSUM = 6
+CIGNORED = 7
+
+
+class FitTheories(object):
+ """Class wrapping functions from :class:`silx.math.fit.functions`
+ and providing estimate functions for all of these fit functions."""
+ def __init__(self, config=None):
+ if config is None:
+ self.config = DEFAULT_CONFIG
+ else:
+ self.config = config
+
+ def ahypermet(self, x, *pars):
+ """
+ Wrapping of :func:`silx.math.fit.functions.sum_ahypermet` without
+ the tail flags in the function signature.
+
+ Depending on the value of `self.config['HypermetTails']`, one can
+ activate or deactivate the various terms of the hypermet function.
+
+ `self.config['HypermetTails']` must be an integer between 0 and 15.
+ It is a set of 4 binary flags, one for activating each one of the
+ hypermet terms: *gaussian function, short tail, long tail, step*.
+
+ For example, 15 can be expressed as ``1111`` in base 2, so a flag of
+ 15 means all terms are active.
+ """
+ g_term = self.config['HypermetTails'] & 1
+ st_term = (self.config['HypermetTails'] >> 1) & 1
+ lt_term = (self.config['HypermetTails'] >> 2) & 1
+ step_term = (self.config['HypermetTails'] >> 3) & 1
+ return functions.sum_ahypermet(x, *pars,
+ gaussian_term=g_term, st_term=st_term,
+ lt_term=lt_term, step_term=step_term)
+
+ def poly(self, x, *pars):
+ """Order n polynomial.
+ The order of the polynomial is defined by the number of
+ coefficients (``*pars``).
+
+ """
+ p = numpy.poly1d(pars)
+ return p(x)
+
+ @staticmethod
+ def estimate_poly(x, y, n=2):
+ """Estimate polynomial coefficients for a degree n polynomial.
+
+ """
+ pcoeffs = numpy.polyfit(x, y, n)
+ constraints = numpy.zeros((n + 1, 3), numpy.float64)
+ return pcoeffs, constraints
+
+ def estimate_quadratic(self, x, y):
+ """Estimate quadratic coefficients
+
+ """
+ return self.estimate_poly(x, y, n=2)
+
+ def estimate_cubic(self, x, y):
+ """Estimate coefficients for a degree 3 polynomial
+
+ """
+ return self.estimate_poly(x, y, n=3)
+
+ def estimate_quartic(self, x, y):
+ """Estimate coefficients for a degree 4 polynomial
+
+ """
+ return self.estimate_poly(x, y, n=4)
+
+ def estimate_quintic(self, x, y):
+ """Estimate coefficients for a degree 5 polynomial
+
+ """
+ return self.estimate_poly(x, y, n=5)
+
+ def strip_bg(self, y):
+ """Return the strip background of y, using parameters from
+ :attr:`config` dictionary (*StripBackgroundFlag, StripWidth,
+ StripIterations, StripThresholdFactor*)"""
+ remove_strip_bg = self.config.get('StripBackgroundFlag', False)
+ if remove_strip_bg:
+ if self.config['SmoothingFlag']:
+ y = savitsky_golay(y, self.config['SmoothingWidth'])
+ strip_width = self.config['StripWidth']
+ strip_niterations = self.config['StripIterations']
+ strip_thr_factor = self.config['StripThresholdFactor']
+ return strip(y, w=strip_width,
+ niterations=strip_niterations,
+ factor=strip_thr_factor)
+ else:
+ return numpy.zeros_like(y)
+
+ def guess_yscaling(self, y):
+ """Estimate scaling for y prior to peak search.
+ A smoothing filter is applied to y to estimate the noise level
+ (chi-squared)
+
+ :param y: Data array
+ :return: Scaling factor
+ """
+ # ensure y is an array
+ yy = numpy.array(y, copy=False)
+
+ # smooth
+ convolution_kernel = numpy.ones(shape=(3,)) / 3.
+ ysmooth = numpy.convolve(y, convolution_kernel, mode="same")
+
+ # remove zeros
+ idx_array = numpy.fabs(y) > 0.0
+ yy = yy[idx_array]
+ ysmooth = ysmooth[idx_array]
+
+ # compute scaling factor
+ chisq = numpy.mean((yy - ysmooth)**2 / numpy.fabs(yy))
+ if chisq > 0:
+ return 1. / chisq
+ else:
+ return 1.0
+
+ def peak_search(self, y, fwhm, sensitivity):
+ """Search for peaks in y array, after padding the array and
+ multiplying its value by a scaling factor.
+
+ :param y: 1-D data array
+ :param int fwhm: Typical full width at half maximum for peaks,
+ in number of points. This parameter is used for to discriminate between
+ true peaks and background fluctuations.
+ :param float sensitivity: Sensitivity parameter. This is a threshold factor
+ for peak detection. Only peaks larger than the standard deviation
+ of the noise multiplied by this sensitivity parameter are detected.
+ :return: List of peak indices
+ """
+ # add padding
+ ysearch = numpy.ones((len(y) + 2 * fwhm,), numpy.float64)
+ ysearch[0:fwhm] = y[0]
+ ysearch[-1:-fwhm - 1:-1] = y[len(y)-1]
+ ysearch[fwhm:fwhm + len(y)] = y[:]
+
+ scaling = self.guess_yscaling(y) if self.config["AutoScaling"] else self.config["Yscaling"]
+
+ if len(ysearch) > 1.5 * fwhm:
+ peaks = peak_search(scaling * ysearch,
+ fwhm=fwhm, sensitivity=sensitivity)
+ return [peak_index - fwhm for peak_index in peaks
+ if 0 <= peak_index - fwhm < len(y)]
+ else:
+ return []
+
+ def estimate_height_position_fwhm(self, x, y):
+ """Estimation of *Height, Position, FWHM* of peaks, for gaussian-like
+ curves.
+
+ This functions finds how many parameters are needed, based on the
+ number of peaks detected. Then it estimates the fit parameters
+ with a few iterations of fitting gaussian functions.
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each peak are:
+ *Height, Position, FWHM*.
+ Fit constraints depend on :attr:`config`.
+ """
+ fittedpar = []
+
+ bg = self.strip_bg(y)
+
+ if self.config['AutoFwhm']:
+ search_fwhm = guess_fwhm(y)
+ else:
+ search_fwhm = int(float(self.config['FwhmPoints']))
+ search_sens = float(self.config['Sensitivity'])
+
+ if search_fwhm < 3:
+ _logger.warning("Setting peak fwhm to 3 (lower limit)")
+ search_fwhm = 3
+ self.config['FwhmPoints'] = 3
+
+ if search_sens < 1:
+ _logger.warning("Setting peak search sensitivity to 1. " +
+ "(lower limit to filter out noise peaks)")
+ search_sens = 1
+ self.config['Sensitivity'] = 1
+
+ npoints = len(y)
+
+ # Find indices of peaks in data array
+ peaks = self.peak_search(y,
+ fwhm=search_fwhm,
+ sensitivity=search_sens)
+
+ if not len(peaks):
+ forcepeak = int(float(self.config.get('ForcePeakPresence', 0)))
+ if forcepeak:
+ delta = y - bg
+ # get index of global maximum
+ # (first one if several samples are equal to this value)
+ peaks = [numpy.nonzero(delta == delta.max())[0][0]]
+
+ # Find index of largest peak in peaks array
+ index_largest_peak = 0
+ if len(peaks) > 0:
+ # estimate fwhm as 5 * sampling interval
+ sig = 5 * abs(x[npoints - 1] - x[0]) / npoints
+ peakpos = x[int(peaks[0])]
+ if abs(peakpos) < 1.0e-16:
+ peakpos = 0.0
+ param = numpy.array(
+ [y[int(peaks[0])] - bg[int(peaks[0])], peakpos, sig])
+ height_largest_peak = param[0]
+ peak_index = 1
+ for i in peaks[1:]:
+ param2 = numpy.array(
+ [y[int(i)] - bg[int(i)], x[int(i)], sig])
+ param = numpy.concatenate((param, param2))
+ if param2[0] > height_largest_peak:
+ height_largest_peak = param2[0]
+ index_largest_peak = peak_index
+ peak_index += 1
+
+ # Subtract background
+ xw = x
+ yw = y - bg
+
+ cons = numpy.zeros((len(param), 3), numpy.float64)
+
+ # peak height must be positive
+ cons[0:len(param):3, 0] = CPOSITIVE
+ # force peaks to stay around their position
+ cons[1:len(param):3, 0] = CQUOTED
+
+ # set possible peak range to estimated peak +- guessed fwhm
+ if len(xw) > search_fwhm:
+ fwhmx = numpy.fabs(xw[int(search_fwhm)] - xw[0])
+ cons[1:len(param):3, 1] = param[1:len(param):3] - 0.5 * fwhmx
+ cons[1:len(param):3, 2] = param[1:len(param):3] + 0.5 * fwhmx
+ else:
+ shape = [max(1, int(x)) for x in (param[1:len(param):3])]
+ cons[1:len(param):3, 1] = min(xw) * numpy.ones(
+ shape,
+ numpy.float64)
+ cons[1:len(param):3, 2] = max(xw) * numpy.ones(
+ shape,
+ numpy.float64)
+
+ # ensure fwhm is positive
+ cons[2:len(param):3, 0] = CPOSITIVE
+
+ # run a quick iterative fit (4 iterations) to improve
+ # estimations
+ fittedpar, _, _ = leastsq(functions.sum_gauss, xw, yw, param,
+ max_iter=4, constraints=cons.tolist(),
+ full_output=True)
+
+ # set final constraints based on config parameters
+ cons = numpy.zeros((len(fittedpar), 3), numpy.float64)
+ peak_index = 0
+ for i in range(len(peaks)):
+ # Setup height area constrains
+ if not self.config['NoConstraintsFlag']:
+ if self.config['PositiveHeightAreaFlag']:
+ cons[peak_index, 0] = CPOSITIVE
+ cons[peak_index, 1] = 0
+ cons[peak_index, 2] = 0
+ peak_index += 1
+
+ # Setup position constrains
+ if not self.config['NoConstraintsFlag']:
+ if self.config['QuotedPositionFlag']:
+ cons[peak_index, 0] = CQUOTED
+ cons[peak_index, 1] = min(x)
+ cons[peak_index, 2] = max(x)
+ peak_index += 1
+
+ # Setup positive FWHM constrains
+ if not self.config['NoConstraintsFlag']:
+ if self.config['PositiveFwhmFlag']:
+ cons[peak_index, 0] = CPOSITIVE
+ cons[peak_index, 1] = 0
+ cons[peak_index, 2] = 0
+ if self.config['SameFwhmFlag']:
+ if i != index_largest_peak:
+ cons[peak_index, 0] = CFACTOR
+ cons[peak_index, 1] = 3 * index_largest_peak + 2
+ cons[peak_index, 2] = 1.0
+ peak_index += 1
+
+ return fittedpar, cons
+
+ def estimate_agauss(self, x, y):
+ """Estimation of *Area, Position, FWHM* of peaks, for gaussian-like
+ curves.
+
+ This functions uses :meth:`estimate_height_position_fwhm`, then
+ converts the height parameters to area under the curve with the
+ formula ``area = sqrt(2*pi) * height * fwhm / (2 * sqrt(2 * log(2))``
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each peak are:
+ *Area, Position, FWHM*.
+ Fit constraints depend on :attr:`config`.
+ """
+ fittedpar, cons = self.estimate_height_position_fwhm(x, y)
+ # get the number of found peaks
+ npeaks = len(fittedpar) // 3
+ for i in range(npeaks):
+ height = fittedpar[3 * i]
+ fwhm = fittedpar[3 * i + 2]
+ # Replace height with area in fittedpar
+ fittedpar[3 * i] = numpy.sqrt(2 * numpy.pi) * height * fwhm / (
+ 2.0 * numpy.sqrt(2 * numpy.log(2)))
+ return fittedpar, cons
+
+ def estimate_alorentz(self, x, y):
+ """Estimation of *Area, Position, FWHM* of peaks, for Lorentzian
+ curves.
+
+ This functions uses :meth:`estimate_height_position_fwhm`, then
+ converts the height parameters to area under the curve with the
+ formula ``area = height * fwhm * 0.5 * pi``
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each peak are:
+ *Area, Position, FWHM*.
+ Fit constraints depend on :attr:`config`.
+ """
+ fittedpar, cons = self.estimate_height_position_fwhm(x, y)
+ # get the number of found peaks
+ npeaks = len(fittedpar) // 3
+ for i in range(npeaks):
+ height = fittedpar[3 * i]
+ fwhm = fittedpar[3 * i + 2]
+ # Replace height with area in fittedpar
+ fittedpar[3 * i] = (height * fwhm * 0.5 * numpy.pi)
+ return fittedpar, cons
+
+ def estimate_splitgauss(self, x, y):
+ """Estimation of *Height, Position, FWHM1, FWHM2* of peaks, for
+ asymmetric gaussian-like curves.
+
+ This functions uses :meth:`estimate_height_position_fwhm`, then
+ adds a second (identical) estimation of FWHM to the fit parameters
+ for each peak, and the corresponding constraint.
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each peak are:
+ *Height, Position, FWHM1, FWHM2*.
+ Fit constraints depend on :attr:`config`.
+ """
+ fittedpar, cons = self.estimate_height_position_fwhm(x, y)
+ # get the number of found peaks
+ npeaks = len(fittedpar) // 3
+ estimated_parameters = []
+ estimated_constraints = numpy.zeros((4 * npeaks, 3), numpy.float64)
+ for i in range(npeaks):
+ for j in range(3):
+ estimated_parameters.append(fittedpar[3 * i + j])
+ # fwhm2 estimate = fwhm1
+ estimated_parameters.append(fittedpar[3 * i + 2])
+ # height
+ estimated_constraints[4 * i, 0] = cons[3 * i, 0]
+ estimated_constraints[4 * i, 1] = cons[3 * i, 1]
+ estimated_constraints[4 * i, 2] = cons[3 * i, 2]
+ # position
+ estimated_constraints[4 * i + 1, 0] = cons[3 * i + 1, 0]
+ estimated_constraints[4 * i + 1, 1] = cons[3 * i + 1, 1]
+ estimated_constraints[4 * i + 1, 2] = cons[3 * i + 1, 2]
+ # fwhm1
+ estimated_constraints[4 * i + 2, 0] = cons[3 * i + 2, 0]
+ estimated_constraints[4 * i + 2, 1] = cons[3 * i + 2, 1]
+ estimated_constraints[4 * i + 2, 2] = cons[3 * i + 2, 2]
+ # fwhm2
+ estimated_constraints[4 * i + 3, 0] = cons[3 * i + 2, 0]
+ estimated_constraints[4 * i + 3, 1] = cons[3 * i + 2, 1]
+ estimated_constraints[4 * i + 3, 2] = cons[3 * i + 2, 2]
+ if cons[3 * i + 2, 0] == CFACTOR:
+ # convert indices of related parameters
+ # (this happens if SameFwhmFlag == True)
+ estimated_constraints[4 * i + 2, 1] = \
+ int(cons[3 * i + 2, 1] / 3) * 4 + 2
+ estimated_constraints[4 * i + 3, 1] = \
+ int(cons[3 * i + 2, 1] / 3) * 4 + 3
+ return estimated_parameters, estimated_constraints
+
+ def estimate_pvoigt(self, x, y):
+ """Estimation of *Height, Position, FWHM, eta* of peaks, for
+ pseudo-Voigt curves.
+
+ Pseudo-Voigt are a sum of a gaussian curve *G(x)* and a lorentzian
+ curve *L(x)* with the same height, center, fwhm parameters:
+ ``y(x) = eta * G(x) + (1-eta) * L(x)``
+
+ This functions uses :meth:`estimate_height_position_fwhm`, then
+ adds a constant estimation of *eta* (0.5) to the fit parameters
+ for each peak, and the corresponding constraint.
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each peak are:
+ *Height, Position, FWHM, eta*.
+ Constraint for the eta parameter can be set to QUOTED (0.--1.)
+ by setting :attr:`config`['QuotedEtaFlag'] to ``True``.
+ If this is not the case, the constraint code is set to FREE.
+ """
+ fittedpar, cons = self.estimate_height_position_fwhm(x, y)
+ npeaks = len(fittedpar) // 3
+ newpar = []
+ newcons = numpy.zeros((4 * npeaks, 3), numpy.float64)
+ # find out related parameters proper index
+ if not self.config['NoConstraintsFlag']:
+ if self.config['SameFwhmFlag']:
+ j = 0
+ # get the index of the free FWHM
+ for i in range(npeaks):
+ if cons[3 * i + 2, 0] != 4:
+ j = i
+ for i in range(npeaks):
+ if i != j:
+ cons[3 * i + 2, 1] = 4 * j + 2
+ for i in range(npeaks):
+ newpar.append(fittedpar[3 * i])
+ newpar.append(fittedpar[3 * i + 1])
+ newpar.append(fittedpar[3 * i + 2])
+ newpar.append(0.5)
+ # height
+ newcons[4 * i, 0] = cons[3 * i, 0]
+ newcons[4 * i, 1] = cons[3 * i, 1]
+ newcons[4 * i, 2] = cons[3 * i, 2]
+ # position
+ newcons[4 * i + 1, 0] = cons[3 * i + 1, 0]
+ newcons[4 * i + 1, 1] = cons[3 * i + 1, 1]
+ newcons[4 * i + 1, 2] = cons[3 * i + 1, 2]
+ # fwhm
+ newcons[4 * i + 2, 0] = cons[3 * i + 2, 0]
+ newcons[4 * i + 2, 1] = cons[3 * i + 2, 1]
+ newcons[4 * i + 2, 2] = cons[3 * i + 2, 2]
+ # Eta constrains
+ newcons[4 * i + 3, 0] = CFREE
+ newcons[4 * i + 3, 1] = 0
+ newcons[4 * i + 3, 2] = 0
+ if self.config['QuotedEtaFlag']:
+ newcons[4 * i + 3, 0] = CQUOTED
+ newcons[4 * i + 3, 1] = 0.0
+ newcons[4 * i + 3, 2] = 1.0
+ return newpar, newcons
+
+ def estimate_splitpvoigt(self, x, y):
+ """Estimation of *Height, Position, FWHM1, FWHM2, eta* of peaks, for
+ asymmetric pseudo-Voigt curves.
+
+ This functions uses :meth:`estimate_height_position_fwhm`, then
+ adds an identical FWHM2 parameter and a constant estimation of
+ *eta* (0.5) to the fit parameters for each peak, and the corresponding
+ constraints.
+
+ Constraint for the eta parameter can be set to QUOTED (0.--1.)
+ by setting :attr:`config`['QuotedEtaFlag'] to ``True``.
+ If this is not the case, the constraint code is set to FREE.
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each peak are:
+ *Height, Position, FWHM1, FWHM2, eta*.
+ """
+ fittedpar, cons = self.estimate_height_position_fwhm(x, y)
+ npeaks = len(fittedpar) // 3
+ newpar = []
+ newcons = numpy.zeros((5 * npeaks, 3), numpy.float64)
+ # find out related parameters proper index
+ if not self.config['NoConstraintsFlag']:
+ if self.config['SameFwhmFlag']:
+ j = 0
+ # get the index of the free FWHM
+ for i in range(npeaks):
+ if cons[3 * i + 2, 0] != 4:
+ j = i
+ for i in range(npeaks):
+ if i != j:
+ cons[3 * i + 2, 1] = 4 * j + 2
+ for i in range(npeaks):
+ # height
+ newpar.append(fittedpar[3 * i])
+ # position
+ newpar.append(fittedpar[3 * i + 1])
+ # fwhm1
+ newpar.append(fittedpar[3 * i + 2])
+ # fwhm2 estimate equal to fwhm1
+ newpar.append(fittedpar[3 * i + 2])
+ # eta
+ newpar.append(0.5)
+ # constraint codes
+ # ----------------
+ # height
+ newcons[5 * i, 0] = cons[3 * i, 0]
+ # position
+ newcons[5 * i + 1, 0] = cons[3 * i + 1, 0]
+ # fwhm1
+ newcons[5 * i + 2, 0] = cons[3 * i + 2, 0]
+ # fwhm2
+ newcons[5 * i + 3, 0] = cons[3 * i + 2, 0]
+ # cons 1
+ # ------
+ newcons[5 * i, 1] = cons[3 * i, 1]
+ newcons[5 * i + 1, 1] = cons[3 * i + 1, 1]
+ newcons[5 * i + 2, 1] = cons[3 * i + 2, 1]
+ newcons[5 * i + 3, 1] = cons[3 * i + 2, 1]
+ # cons 2
+ # ------
+ newcons[5 * i, 2] = cons[3 * i, 2]
+ newcons[5 * i + 1, 2] = cons[3 * i + 1, 2]
+ newcons[5 * i + 2, 2] = cons[3 * i + 2, 2]
+ newcons[5 * i + 3, 2] = cons[3 * i + 2, 2]
+
+ if cons[3 * i + 2, 0] == CFACTOR:
+ # fwhm2 connstraint depends on fwhm1
+ newcons[5 * i + 3, 1] = newcons[5 * i + 2, 1] + 1
+ # eta constraints
+ newcons[5 * i + 4, 0] = CFREE
+ newcons[5 * i + 4, 1] = 0
+ newcons[5 * i + 4, 2] = 0
+ if self.config['QuotedEtaFlag']:
+ newcons[5 * i + 4, 0] = CQUOTED
+ newcons[5 * i + 4, 1] = 0.0
+ newcons[5 * i + 4, 2] = 1.0
+ return newpar, newcons
+
+ def estimate_apvoigt(self, x, y):
+ """Estimation of *Area, Position, FWHM1, eta* of peaks, for
+ pseudo-Voigt curves.
+
+ This functions uses :meth:`estimate_pvoigt`, then converts the height
+ parameter to area.
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each peak are:
+ *Area, Position, FWHM, eta*.
+ """
+ fittedpar, cons = self.estimate_pvoigt(x, y)
+ npeaks = len(fittedpar) // 4
+ # Assume 50% of the area is determined by the gaussian and 50% by
+ # the Lorentzian.
+ for i in range(npeaks):
+ height = fittedpar[4 * i]
+ fwhm = fittedpar[4 * i + 2]
+ fittedpar[4 * i] = 0.5 * (height * fwhm * 0.5 * numpy.pi) +\
+ 0.5 * (height * fwhm / (2.0 * numpy.sqrt(2 * numpy.log(2)))
+ ) * numpy.sqrt(2 * numpy.pi)
+ return fittedpar, cons
+
+ def estimate_ahypermet(self, x, y):
+ """Estimation of *area, position, fwhm, st_area_r, st_slope_r,
+ lt_area_r, lt_slope_r, step_height_r* of peaks, for hypermet curves.
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each peak are:
+ *area, position, fwhm, st_area_r, st_slope_r,
+ lt_area_r, lt_slope_r, step_height_r* .
+ """
+ yscaling = self.config.get('Yscaling', 1.0)
+ if yscaling == 0:
+ yscaling = 1.0
+ fittedpar, cons = self.estimate_height_position_fwhm(x, y)
+ npeaks = len(fittedpar) // 3
+ newpar = []
+ newcons = numpy.zeros((8 * npeaks, 3), numpy.float64)
+ main_peak = 0
+ # find out related parameters proper index
+ if not self.config['NoConstraintsFlag']:
+ if self.config['SameFwhmFlag']:
+ j = 0
+ # get the index of the free FWHM
+ for i in range(npeaks):
+ if cons[3 * i + 2, 0] != 4:
+ j = i
+ for i in range(npeaks):
+ if i != j:
+ cons[3 * i + 2, 1] = 8 * j + 2
+ main_peak = j
+ for i in range(npeaks):
+ if fittedpar[3 * i] > fittedpar[3 * main_peak]:
+ main_peak = i
+
+ for i in range(npeaks):
+ height = fittedpar[3 * i]
+ position = fittedpar[3 * i + 1]
+ fwhm = fittedpar[3 * i + 2]
+ area = (height * fwhm / (2.0 * numpy.sqrt(2 * numpy.log(2)))
+ ) * numpy.sqrt(2 * numpy.pi)
+ # the gaussian parameters
+ newpar.append(area)
+ newpar.append(position)
+ newpar.append(fwhm)
+ # print "area, pos , fwhm = ",area,position,fwhm
+ # Avoid zero derivatives because of not calculating contribution
+ g_term = 1
+ st_term = 1
+ lt_term = 1
+ step_term = 1
+ if self.config['HypermetTails'] != 0:
+ g_term = self.config['HypermetTails'] & 1
+ st_term = (self.config['HypermetTails'] >> 1) & 1
+ lt_term = (self.config['HypermetTails'] >> 2) & 1
+ step_term = (self.config['HypermetTails'] >> 3) & 1
+ if g_term == 0:
+ # fix the gaussian parameters
+ newcons[8 * i, 0] = CFIXED
+ newcons[8 * i + 1, 0] = CFIXED
+ newcons[8 * i + 2, 0] = CFIXED
+ # the short tail parameters
+ if ((area * yscaling) <
+ self.config['MinGaussArea4ShortTail']) | \
+ (st_term == 0):
+ newpar.append(0.0)
+ newpar.append(0.0)
+ newcons[8 * i + 3, 0] = CFIXED
+ newcons[8 * i + 3, 1] = 0.0
+ newcons[8 * i + 3, 2] = 0.0
+ newcons[8 * i + 4, 0] = CFIXED
+ newcons[8 * i + 4, 1] = 0.0
+ newcons[8 * i + 4, 2] = 0.0
+ else:
+ newpar.append(self.config['InitialShortTailAreaRatio'])
+ newpar.append(self.config['InitialShortTailSlopeRatio'])
+ newcons[8 * i + 3, 0] = CQUOTED
+ newcons[8 * i + 3, 1] = self.config['MinShortTailAreaRatio']
+ newcons[8 * i + 3, 2] = self.config['MaxShortTailAreaRatio']
+ newcons[8 * i + 4, 0] = CQUOTED
+ newcons[8 * i + 4, 1] = self.config['MinShortTailSlopeRatio']
+ newcons[8 * i + 4, 2] = self.config['MaxShortTailSlopeRatio']
+ # the long tail parameters
+ if ((area * yscaling) <
+ self.config['MinGaussArea4LongTail']) | \
+ (lt_term == 0):
+ newpar.append(0.0)
+ newpar.append(0.0)
+ newcons[8 * i + 5, 0] = CFIXED
+ newcons[8 * i + 5, 1] = 0.0
+ newcons[8 * i + 5, 2] = 0.0
+ newcons[8 * i + 6, 0] = CFIXED
+ newcons[8 * i + 6, 1] = 0.0
+ newcons[8 * i + 6, 2] = 0.0
+ else:
+ newpar.append(self.config['InitialLongTailAreaRatio'])
+ newpar.append(self.config['InitialLongTailSlopeRatio'])
+ newcons[8 * i + 5, 0] = CQUOTED
+ newcons[8 * i + 5, 1] = self.config['MinLongTailAreaRatio']
+ newcons[8 * i + 5, 2] = self.config['MaxLongTailAreaRatio']
+ newcons[8 * i + 6, 0] = CQUOTED
+ newcons[8 * i + 6, 1] = self.config['MinLongTailSlopeRatio']
+ newcons[8 * i + 6, 2] = self.config['MaxLongTailSlopeRatio']
+ # the step parameters
+ if ((height * yscaling) <
+ self.config['MinGaussHeight4StepTail']) | \
+ (step_term == 0):
+ newpar.append(0.0)
+ newcons[8 * i + 7, 0] = CFIXED
+ newcons[8 * i + 7, 1] = 0.0
+ newcons[8 * i + 7, 2] = 0.0
+ else:
+ newpar.append(self.config['InitialStepTailHeightRatio'])
+ newcons[8 * i + 7, 0] = CQUOTED
+ newcons[8 * i + 7, 1] = self.config['MinStepTailHeightRatio']
+ newcons[8 * i + 7, 2] = self.config['MaxStepTailHeightRatio']
+ # if self.config['NoConstraintsFlag'] == 1:
+ # newcons=numpy.zeros((8*npeaks, 3),numpy.float64)
+ if npeaks > 0:
+ if g_term:
+ if self.config['PositiveHeightAreaFlag']:
+ for i in range(npeaks):
+ newcons[8 * i, 0] = CPOSITIVE
+ if self.config['PositiveFwhmFlag']:
+ for i in range(npeaks):
+ newcons[8 * i + 2, 0] = CPOSITIVE
+ if self.config['SameFwhmFlag']:
+ for i in range(npeaks):
+ if i != main_peak:
+ newcons[8 * i + 2, 0] = CFACTOR
+ newcons[8 * i + 2, 1] = 8 * main_peak + 2
+ newcons[8 * i + 2, 2] = 1.0
+ if self.config['HypermetQuotedPositionFlag']:
+ for i in range(npeaks):
+ delta = self.config['DeltaPositionFwhmUnits'] * fwhm
+ newcons[8 * i + 1, 0] = CQUOTED
+ newcons[8 * i + 1, 1] = newpar[8 * i + 1] - delta
+ newcons[8 * i + 1, 2] = newpar[8 * i + 1] + delta
+ if self.config['SameSlopeRatioFlag']:
+ for i in range(npeaks):
+ if i != main_peak:
+ newcons[8 * i + 4, 0] = CFACTOR
+ newcons[8 * i + 4, 1] = 8 * main_peak + 4
+ newcons[8 * i + 4, 2] = 1.0
+ newcons[8 * i + 6, 0] = CFACTOR
+ newcons[8 * i + 6, 1] = 8 * main_peak + 6
+ newcons[8 * i + 6, 2] = 1.0
+ if self.config['SameAreaRatioFlag']:
+ for i in range(npeaks):
+ if i != main_peak:
+ newcons[8 * i + 3, 0] = CFACTOR
+ newcons[8 * i + 3, 1] = 8 * main_peak + 3
+ newcons[8 * i + 3, 2] = 1.0
+ newcons[8 * i + 5, 0] = CFACTOR
+ newcons[8 * i + 5, 1] = 8 * main_peak + 5
+ newcons[8 * i + 5, 2] = 1.0
+ return newpar, newcons
+
+ def estimate_stepdown(self, x, y):
+ """Estimation of parameters for stepdown curves.
+
+ The functions estimates gaussian parameters for the derivative of
+ the data, takes the largest gaussian peak and uses its estimated
+ parameters to define the center of the step and its fwhm. The
+ estimated amplitude returned is simply ``max(y) - min(y)``.
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit newconstraints.
+ Parameters to be estimated for each stepdown are:
+ *height, centroid, fwhm* .
+ """
+ crappyfilter = [-0.25, -0.75, 0.0, 0.75, 0.25]
+ cutoff = len(crappyfilter) // 2
+ y_deriv = numpy.convolve(y,
+ crappyfilter,
+ mode="valid")
+
+ # make the derivative's peak have the same amplitude as the step
+ if max(y_deriv) > 0:
+ y_deriv = y_deriv * max(y) / max(y_deriv)
+
+ fittedpar, newcons = self.estimate_height_position_fwhm(
+ x[cutoff:-cutoff], y_deriv)
+
+ data_amplitude = max(y) - min(y)
+
+ # use parameters from largest gaussian found
+ if len(fittedpar):
+ npeaks = len(fittedpar) // 3
+ largest_index = 0
+ largest = [data_amplitude,
+ fittedpar[3 * largest_index + 1],
+ fittedpar[3 * largest_index + 2]]
+ for i in range(npeaks):
+ if fittedpar[3 * i] > largest[0]:
+ largest_index = i
+ largest = [data_amplitude,
+ fittedpar[3 * largest_index + 1],
+ fittedpar[3 * largest_index + 2]]
+ else:
+ # no peak was found
+ largest = [data_amplitude, # height
+ x[len(x)//2], # center: middle of x range
+ self.config["FwhmPoints"] * (x[1] - x[0])] # fwhm: default value
+
+ # Setup constrains
+ newcons = numpy.zeros((3, 3), numpy.float64)
+ if not self.config['NoConstraintsFlag']:
+ # Setup height constrains
+ if self.config['PositiveHeightAreaFlag']:
+ newcons[0, 0] = CPOSITIVE
+ newcons[0, 1] = 0
+ newcons[0, 2] = 0
+
+ # Setup position constrains
+ if self.config['QuotedPositionFlag']:
+ newcons[1, 0] = CQUOTED
+ newcons[1, 1] = min(x)
+ newcons[1, 2] = max(x)
+
+ # Setup positive FWHM constrains
+ if self.config['PositiveFwhmFlag']:
+ newcons[2, 0] = CPOSITIVE
+ newcons[2, 1] = 0
+ newcons[2, 2] = 0
+
+ return largest, newcons
+
+ def estimate_slit(self, x, y):
+ """Estimation of parameters for slit curves.
+
+ The functions estimates stepup and stepdown parameters for the largest
+ steps, and uses them for calculating the center (middle between stepup
+ and stepdown), the height (maximum amplitude in data), the fwhm
+ (distance between the up- and down-step centers) and the beamfwhm
+ (average of FWHM for up- and down-step).
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each slit are:
+ *height, position, fwhm, beamfwhm* .
+ """
+ largestup, cons = self.estimate_stepup(x, y)
+ largestdown, cons = self.estimate_stepdown(x, y)
+ fwhm = numpy.fabs(largestdown[1] - largestup[1])
+ beamfwhm = 0.5 * (largestup[2] + largestdown[1])
+ beamfwhm = min(beamfwhm, fwhm / 10.0)
+ beamfwhm = max(beamfwhm, (max(x) - min(x)) * 3.0 / len(x))
+
+ y_minus_bg = y - self.strip_bg(y)
+ height = max(y_minus_bg)
+
+ i1 = numpy.nonzero(y_minus_bg >= 0.5 * height)[0]
+ xx = numpy.take(x, i1)
+ position = (xx[0] + xx[-1]) / 2.0
+ fwhm = xx[-1] - xx[0]
+ largest = [height, position, fwhm, beamfwhm]
+ cons = numpy.zeros((4, 3), numpy.float64)
+ # Setup constrains
+ if not self.config['NoConstraintsFlag']:
+ # Setup height constrains
+ if self.config['PositiveHeightAreaFlag']:
+ cons[0, 0] = CPOSITIVE
+ cons[0, 1] = 0
+ cons[0, 2] = 0
+
+ # Setup position constrains
+ if self.config['QuotedPositionFlag']:
+ cons[1, 0] = CQUOTED
+ cons[1, 1] = min(x)
+ cons[1, 2] = max(x)
+
+ # Setup positive FWHM constrains
+ if self.config['PositiveFwhmFlag']:
+ cons[2, 0] = CPOSITIVE
+ cons[2, 1] = 0
+ cons[2, 2] = 0
+
+ # Setup positive FWHM constrains
+ if self.config['PositiveFwhmFlag']:
+ cons[3, 0] = CPOSITIVE
+ cons[3, 1] = 0
+ cons[3, 2] = 0
+ return largest, cons
+
+ def estimate_stepup(self, x, y):
+ """Estimation of parameters for a single step up curve.
+
+ The functions estimates gaussian parameters for the derivative of
+ the data, takes the largest gaussian peak and uses its estimated
+ parameters to define the center of the step and its fwhm. The
+ estimated amplitude returned is simply ``max(y) - min(y)``.
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ Parameters to be estimated for each stepup are:
+ *height, centroid, fwhm* .
+ """
+ crappyfilter = [0.25, 0.75, 0.0, -0.75, -0.25]
+ cutoff = len(crappyfilter) // 2
+ y_deriv = numpy.convolve(y, crappyfilter, mode="valid")
+ if max(y_deriv) > 0:
+ y_deriv = y_deriv * max(y) / max(y_deriv)
+
+ fittedpar, cons = self.estimate_height_position_fwhm(
+ x[cutoff:-cutoff], y_deriv)
+
+ # for height, use the data amplitude after removing the background
+ data_amplitude = max(y) - min(y)
+
+ # find params of the largest gaussian found
+ if len(fittedpar):
+ npeaks = len(fittedpar) // 3
+ largest_index = 0
+ largest = [data_amplitude,
+ fittedpar[3 * largest_index + 1],
+ fittedpar[3 * largest_index + 2]]
+ for i in range(npeaks):
+ if fittedpar[3 * i] > largest[0]:
+ largest_index = i
+ largest = [fittedpar[3 * largest_index],
+ fittedpar[3 * largest_index + 1],
+ fittedpar[3 * largest_index + 2]]
+ else:
+ # no peak was found
+ largest = [data_amplitude, # height
+ x[len(x)//2], # center: middle of x range
+ self.config["FwhmPoints"] * (x[1] - x[0])] # fwhm: default value
+
+ newcons = numpy.zeros((3, 3), numpy.float64)
+ # Setup constrains
+ if not self.config['NoConstraintsFlag']:
+ # Setup height constraints
+ if self.config['PositiveHeightAreaFlag']:
+ newcons[0, 0] = CPOSITIVE
+ newcons[0, 1] = 0
+ newcons[0, 2] = 0
+
+ # Setup position constraints
+ if self.config['QuotedPositionFlag']:
+ newcons[1, 0] = CQUOTED
+ newcons[1, 1] = min(x)
+ newcons[1, 2] = max(x)
+
+ # Setup positive FWHM constraints
+ if self.config['PositiveFwhmFlag']:
+ newcons[2, 0] = CPOSITIVE
+ newcons[2, 1] = 0
+ newcons[2, 2] = 0
+
+ return largest, newcons
+
+ def estimate_periodic_gauss(self, x, y):
+ """Estimation of parameters for periodic gaussian curves:
+ *number of peaks, distance between peaks, height, position of the
+ first peak, fwhm*
+
+ The functions detects all peaks, then computes the parameters the
+ following way:
+
+ - *distance*: average of distances between detected peaks
+ - *height*: average height of detected peaks
+ - *fwhm*: fwhm of the highest peak (in number of samples) if
+ field ``'AutoFwhm'`` in :attr:`config` is ``True``, else take
+ the default value (field ``'FwhmPoints'`` in :attr:`config`)
+
+ :param x: Array of abscissa values
+ :param y: Array of ordinate values (``y = f(x)``)
+ :return: Tuple of estimated fit parameters and fit constraints.
+ """
+ yscaling = self.config.get('Yscaling', 1.0)
+ if yscaling == 0:
+ yscaling = 1.0
+
+ bg = self.strip_bg(y)
+
+ if self.config['AutoFwhm']:
+ search_fwhm = guess_fwhm(y)
+ else:
+ search_fwhm = int(float(self.config['FwhmPoints']))
+ search_sens = float(self.config['Sensitivity'])
+
+ if search_fwhm < 3:
+ search_fwhm = 3
+
+ if search_sens < 1:
+ search_sens = 1
+
+ if len(y) > 1.5 * search_fwhm:
+ peaks = peak_search(yscaling * y, fwhm=search_fwhm,
+ sensitivity=search_sens)
+ else:
+ peaks = []
+ npeaks = len(peaks)
+ if not npeaks:
+ fittedpar = []
+ cons = numpy.zeros((len(fittedpar), 3), numpy.float64)
+ return fittedpar, cons
+
+ fittedpar = [0.0, 0.0, 0.0, 0.0, 0.0]
+
+ # The number of peaks
+ fittedpar[0] = npeaks
+
+ # The separation between peaks in x units
+ delta = 0.0
+ height = 0.0
+ for i in range(npeaks):
+ height += y[int(peaks[i])] - bg[int(peaks[i])]
+ if i != npeaks - 1:
+ delta += (x[int(peaks[i + 1])] - x[int(peaks[i])])
+
+ # delta between peaks
+ if npeaks > 1:
+ fittedpar[1] = delta / (npeaks - 1)
+
+ # starting height
+ fittedpar[2] = height / npeaks
+
+ # position of the first peak
+ fittedpar[3] = x[int(peaks[0])]
+
+ # Estimate the fwhm
+ fittedpar[4] = search_fwhm
+
+ # setup constraints
+ cons = numpy.zeros((5, 3), numpy.float64)
+ cons[0, 0] = CFIXED # the number of gaussians
+ if npeaks == 1:
+ cons[1, 0] = CFIXED # the delta between peaks
+ else:
+ cons[1, 0] = CFREE
+ j = 2
+ # Setup height area constrains
+ if not self.config['NoConstraintsFlag']:
+ if self.config['PositiveHeightAreaFlag']:
+ # POSITIVE = 1
+ cons[j, 0] = CPOSITIVE
+ cons[j, 1] = 0
+ cons[j, 2] = 0
+ j += 1
+
+ # Setup position constrains
+ if not self.config['NoConstraintsFlag']:
+ if self.config['QuotedPositionFlag']:
+ # QUOTED = 2
+ cons[j, 0] = CQUOTED
+ cons[j, 1] = min(x)
+ cons[j, 2] = max(x)
+ j += 1
+
+ # Setup positive FWHM constrains
+ if not self.config['NoConstraintsFlag']:
+ if self.config['PositiveFwhmFlag']:
+ # POSITIVE=1
+ cons[j, 0] = CPOSITIVE
+ cons[j, 1] = 0
+ cons[j, 2] = 0
+ j += 1
+ return fittedpar, cons
+
+ def configure(self, **kw):
+ """Add new / unknown keyword arguments to :attr:`config`,
+ update entries in :attr:`config` if the parameter name is a existing
+ key.
+
+ :param kw: Dictionary of keyword arguments.
+ :return: Configuration dictionary :attr:`config`
+ """
+ if not kw.keys():
+ return self.config
+ for key in kw.keys():
+ notdone = 1
+ # take care of lower / upper case problems ...
+ for config_key in self.config.keys():
+ if config_key.lower() == key.lower():
+ self.config[config_key] = kw[key]
+ notdone = 0
+ if notdone:
+ self.config[key] = kw[key]
+ return self.config
+
+fitfuns = FitTheories()
+
+THEORY = OrderedDict((
+ ('Gaussians',
+ FitTheory(description='Gaussian functions',
+ function=functions.sum_gauss,
+ parameters=('Height', 'Position', 'FWHM'),
+ estimate=fitfuns.estimate_height_position_fwhm,
+ configure=fitfuns.configure)),
+ ('Lorentz',
+ FitTheory(description='Lorentzian functions',
+ function=functions.sum_lorentz,
+ parameters=('Height', 'Position', 'FWHM'),
+ estimate=fitfuns.estimate_height_position_fwhm,
+ configure=fitfuns.configure)),
+ ('Area Gaussians',
+ FitTheory(description='Gaussian functions (area)',
+ function=functions.sum_agauss,
+ parameters=('Area', 'Position', 'FWHM'),
+ estimate=fitfuns.estimate_agauss,
+ configure=fitfuns.configure)),
+ ('Area Lorentz',
+ FitTheory(description='Lorentzian functions (area)',
+ function=functions.sum_alorentz,
+ parameters=('Area', 'Position', 'FWHM'),
+ estimate=fitfuns.estimate_alorentz,
+ configure=fitfuns.configure)),
+ ('Pseudo-Voigt Line',
+ FitTheory(description='Pseudo-Voigt functions',
+ function=functions.sum_pvoigt,
+ parameters=('Height', 'Position', 'FWHM', 'Eta'),
+ estimate=fitfuns.estimate_pvoigt,
+ configure=fitfuns.configure)),
+ ('Area Pseudo-Voigt',
+ FitTheory(description='Pseudo-Voigt functions (area)',
+ function=functions.sum_apvoigt,
+ parameters=('Area', 'Position', 'FWHM', 'Eta'),
+ estimate=fitfuns.estimate_apvoigt,
+ configure=fitfuns.configure)),
+ ('Split Gaussian',
+ FitTheory(description='Asymmetric gaussian functions',
+ function=functions.sum_splitgauss,
+ parameters=('Height', 'Position', 'LowFWHM',
+ 'HighFWHM'),
+ estimate=fitfuns.estimate_splitgauss,
+ configure=fitfuns.configure)),
+ ('Split Lorentz',
+ FitTheory(description='Asymmetric lorentzian functions',
+ function=functions.sum_splitlorentz,
+ parameters=('Height', 'Position', 'LowFWHM', 'HighFWHM'),
+ estimate=fitfuns.estimate_splitgauss,
+ configure=fitfuns.configure)),
+ ('Split Pseudo-Voigt',
+ FitTheory(description='Asymmetric pseudo-Voigt functions',
+ function=functions.sum_splitpvoigt,
+ parameters=('Height', 'Position', 'LowFWHM',
+ 'HighFWHM', 'Eta'),
+ estimate=fitfuns.estimate_splitpvoigt,
+ configure=fitfuns.configure)),
+ ('Step Down',
+ FitTheory(description='Step down function',
+ function=functions.sum_stepdown,
+ parameters=('Height', 'Position', 'FWHM'),
+ estimate=fitfuns.estimate_stepdown,
+ configure=fitfuns.configure)),
+ ('Step Up',
+ FitTheory(description='Step up function',
+ function=functions.sum_stepup,
+ parameters=('Height', 'Position', 'FWHM'),
+ estimate=fitfuns.estimate_stepup,
+ configure=fitfuns.configure)),
+ ('Slit',
+ FitTheory(description='Slit function',
+ function=functions.sum_slit,
+ parameters=('Height', 'Position', 'FWHM', 'BeamFWHM'),
+ estimate=fitfuns.estimate_slit,
+ configure=fitfuns.configure)),
+ ('Atan',
+ FitTheory(description='Arctan step up function',
+ function=functions.atan_stepup,
+ parameters=('Height', 'Position', 'Width'),
+ estimate=fitfuns.estimate_stepup,
+ configure=fitfuns.configure)),
+ ('Hypermet',
+ FitTheory(description='Hypermet functions',
+ function=fitfuns.ahypermet, # customized version of functions.sum_ahypermet
+ parameters=('G_Area', 'Position', 'FWHM', 'ST_Area',
+ 'ST_Slope', 'LT_Area', 'LT_Slope', 'Step_H'),
+ estimate=fitfuns.estimate_ahypermet,
+ configure=fitfuns.configure)),
+ # ('Periodic Gaussians',
+ # FitTheory(description='Periodic gaussian functions',
+ # function=functions.periodic_gauss,
+ # parameters=('N', 'Delta', 'Height', 'Position', 'FWHM'),
+ # estimate=fitfuns.estimate_periodic_gauss,
+ # configure=fitfuns.configure))
+ ('Degree 2 Polynomial',
+ FitTheory(description='Degree 2 polynomial'
+ '\ny = a*x^2 + b*x +c',
+ function=fitfuns.poly,
+ parameters=['a', 'b', 'c'],
+ estimate=fitfuns.estimate_quadratic)),
+ ('Degree 3 Polynomial',
+ FitTheory(description='Degree 3 polynomial'
+ '\ny = a*x^3 + b*x^2 + c*x + d',
+ function=fitfuns.poly,
+ parameters=['a', 'b', 'c', 'd'],
+ estimate=fitfuns.estimate_cubic)),
+ ('Degree 4 Polynomial',
+ FitTheory(description='Degree 4 polynomial'
+ '\ny = a*x^4 + b*x^3 + c*x^2 + d*x + e',
+ function=fitfuns.poly,
+ parameters=['a', 'b', 'c', 'd', 'e'],
+ estimate=fitfuns.estimate_quartic)),
+ ('Degree 5 Polynomial',
+ FitTheory(description='Degree 5 polynomial'
+ '\ny = a*x^5 + b*x^4 + c*x^3 + d*x^2 + e*x + f',
+ function=fitfuns.poly,
+ parameters=['a', 'b', 'c', 'd', 'e', 'f'],
+ estimate=fitfuns.estimate_quintic)),
+))
+"""Dictionary of fit theories: fit functions and their associated estimation
+function, parameters list, configuration function and description.
+"""
+
+
+def test(a):
+ from silx.math.fit import fitmanager
+ x = numpy.arange(1000).astype(numpy.float64)
+ p = [1500, 100., 50.0,
+ 1500, 700., 50.0]
+ y_synthetic = functions.sum_gauss(x, *p) + 1
+
+ fit = fitmanager.FitManager(x, y_synthetic)
+ fit.addtheory('Gaussians', functions.sum_gauss, ['Height', 'Position', 'FWHM'],
+ a.estimate_height_position_fwhm)
+ fit.settheory('Gaussians')
+ fit.setbackground('Linear')
+
+ fit.estimate()
+ fit.runfit()
+
+ y_fit = fit.gendata()
+
+ print("Fit parameter names: %s" % str(fit.get_names()))
+ print("Theoretical parameters: %s" % str(numpy.append([1, 0], p)))
+ print("Fitted parameters: %s" % str(fit.get_fitted_parameters()))
+
+ try:
+ from silx.gui import qt
+ from silx.gui.plot import plot1D
+ app = qt.QApplication([])
+
+ # Offset of 1 to see the difference in log scale
+ plot1D(x, (y_synthetic + 1, y_fit), "Input data + 1, Fit")
+
+ app.exec()
+ except ImportError:
+ _logger.warning("Unable to load qt binding, can't plot results.")
+
+
+if __name__ == "__main__":
+ test(fitfuns)
diff --git a/silx/math/fit/fittheory.py b/src/silx/math/fit/fittheory.py
index fa42e6b..fa42e6b 100644
--- a/silx/math/fit/fittheory.py
+++ b/src/silx/math/fit/fittheory.py
diff --git a/silx/math/fit/functions.pyx b/src/silx/math/fit/functions.pyx
index 1f78563..1f78563 100644
--- a/silx/math/fit/functions.pyx
+++ b/src/silx/math/fit/functions.pyx
diff --git a/silx/math/fit/functions/include/functions.h b/src/silx/math/fit/functions/include/functions.h
index de4209b..de4209b 100644
--- a/silx/math/fit/functions/include/functions.h
+++ b/src/silx/math/fit/functions/include/functions.h
diff --git a/silx/math/fit/functions/src/funs.c b/src/silx/math/fit/functions/src/funs.c
index aae173f..aae173f 100644
--- a/silx/math/fit/functions/src/funs.c
+++ b/src/silx/math/fit/functions/src/funs.c
diff --git a/silx/math/fit/functions_wrapper.pxd b/src/silx/math/fit/functions_wrapper.pxd
index 780116c..780116c 100644
--- a/silx/math/fit/functions_wrapper.pxd
+++ b/src/silx/math/fit/functions_wrapper.pxd
diff --git a/silx/math/fit/leastsq.py b/src/silx/math/fit/leastsq.py
index 3df1a35..3df1a35 100644
--- a/silx/math/fit/leastsq.py
+++ b/src/silx/math/fit/leastsq.py
diff --git a/silx/math/fit/peaks.pyx b/src/silx/math/fit/peaks.pyx
index a4fce89..a4fce89 100644
--- a/silx/math/fit/peaks.pyx
+++ b/src/silx/math/fit/peaks.pyx
diff --git a/silx/math/fit/peaks/include/peaks.h b/src/silx/math/fit/peaks/include/peaks.h
index bd25d96..bd25d96 100644
--- a/silx/math/fit/peaks/include/peaks.h
+++ b/src/silx/math/fit/peaks/include/peaks.h
diff --git a/silx/math/fit/peaks/src/peaks.c b/src/silx/math/fit/peaks/src/peaks.c
index 65cb4f6..65cb4f6 100644
--- a/silx/math/fit/peaks/src/peaks.c
+++ b/src/silx/math/fit/peaks/src/peaks.c
diff --git a/silx/math/fit/peaks_wrapper.pxd b/src/silx/math/fit/peaks_wrapper.pxd
index 4c77dc6..4c77dc6 100644
--- a/silx/math/fit/peaks_wrapper.pxd
+++ b/src/silx/math/fit/peaks_wrapper.pxd
diff --git a/silx/math/fit/setup.py b/src/silx/math/fit/setup.py
index 649387f..649387f 100644
--- a/silx/math/fit/setup.py
+++ b/src/silx/math/fit/setup.py
diff --git a/src/silx/math/fit/test/__init__.py b/src/silx/math/fit/test/__init__.py
new file mode 100644
index 0000000..745efe3
--- /dev/null
+++ b/src/silx/math/fit/test/__init__.py
@@ -0,0 +1,23 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
diff --git a/src/silx/math/fit/test/test_bgtheories.py b/src/silx/math/fit/test/test_bgtheories.py
new file mode 100644
index 0000000..6620d38
--- /dev/null
+++ b/src/silx/math/fit/test/test_bgtheories.py
@@ -0,0 +1,154 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+import copy
+import unittest
+import numpy
+import random
+
+from silx.math.fit import bgtheories
+from silx.math.fit.functions import sum_gauss
+
+
+class TestBgTheories(unittest.TestCase):
+ """
+ """
+ def setUp(self):
+ self.x = numpy.arange(100)
+ self.y = 10 + 0.05 * self.x + sum_gauss(self.x, 10., 45., 15.)
+ # add a very narrow high amplitude peak to test strip and snip
+ self.y += sum_gauss(self.x, 100., 75., 2.)
+ self.narrow_peak_index = list(self.x).index(75)
+ random.seed()
+
+ def tearDown(self):
+ pass
+
+ def testTheoriesAttrs(self):
+ for theory_name in bgtheories.THEORY:
+ self.assertIsInstance(theory_name, str)
+ self.assertTrue(hasattr(bgtheories.THEORY[theory_name],
+ "function"))
+ self.assertTrue(hasattr(bgtheories.THEORY[theory_name].function,
+ "__call__"))
+ # Ensure legacy functions are not renamed accidentally
+ self.assertTrue(
+ {"No Background", "Constant", "Linear", "Strip", "Snip"}.issubset(
+ set(bgtheories.THEORY)))
+
+ def testNoBg(self):
+ nobgfun = bgtheories.THEORY["No Background"].function
+ self.assertTrue(numpy.array_equal(nobgfun(self.x, self.y),
+ numpy.zeros_like(self.x)))
+ # default estimate
+ self.assertEqual(bgtheories.THEORY["No Background"].estimate(self.x, self.y),
+ ([], []))
+
+ def testConstant(self):
+ consfun = bgtheories.THEORY["Constant"].function
+ c = random.random() * 100
+ self.assertTrue(numpy.array_equal(consfun(self.x, self.y, c),
+ c * numpy.ones_like(self.x)))
+ # default estimate
+ esti_par, cons = bgtheories.THEORY["Constant"].estimate(self.x, self.y)
+ self.assertEqual(cons,
+ [[0, 0, 0]])
+ self.assertAlmostEqual(esti_par,
+ min(self.y))
+
+ def testLinear(self):
+ linfun = bgtheories.THEORY["Linear"].function
+ a = random.random() * 100
+ b = random.random() * 100
+ self.assertTrue(numpy.array_equal(linfun(self.x, self.y, a, b),
+ a + b * self.x))
+ # default estimate
+ esti_par, cons = bgtheories.THEORY["Linear"].estimate(self.x, self.y)
+
+ self.assertEqual(cons,
+ [[0, 0, 0], [0, 0, 0]])
+ self.assertAlmostEqual(esti_par[0], 10, places=3)
+ self.assertAlmostEqual(esti_par[1], 0.05, places=3)
+
+ def testStrip(self):
+ stripfun = bgtheories.THEORY["Strip"].function
+ anchors = sorted(random.sample(list(self.x), 4))
+ anchors_indices = [list(self.x).index(a) for a in anchors]
+
+ # we really want to strip away the narrow peak
+ anchors_indices_copy = copy.deepcopy(anchors_indices)
+ for idx in anchors_indices_copy:
+ if abs(idx - self.narrow_peak_index) < 5:
+ anchors_indices.remove(idx)
+ anchors.remove(self.x[idx])
+
+ width = 2
+ niter = 1000
+ bgtheories.THEORY["Strip"].configure(AnchorsList=anchors, AnchorsFlag=True)
+
+ bg = stripfun(self.x, self.y, width, niter)
+
+ # assert peak amplitude has been decreased
+ self.assertLess(bg[self.narrow_peak_index],
+ self.y[self.narrow_peak_index])
+
+ # default estimate
+ for i in anchors_indices:
+ self.assertEqual(bg[i], self.y[i])
+
+ # estimated parameters are equal to the default ones in the config dict
+ bgtheories.THEORY["Strip"].configure(StripWidth=7, StripIterations=8)
+ esti_par, cons = bgtheories.THEORY["Strip"].estimate(self.x, self.y)
+ self.assertTrue(numpy.array_equal(cons, [[3, 0, 0], [3, 0, 0]]))
+ self.assertEqual(esti_par, [7, 8])
+
+ def testSnip(self):
+ snipfun = bgtheories.THEORY["Snip"].function
+ anchors = sorted(random.sample(list(self.x), 4))
+ anchors_indices = [list(self.x).index(a) for a in anchors]
+
+ # we want to strip away the narrow peak, so remove nearby anchors
+ anchors_indices_copy = copy.deepcopy(anchors_indices)
+ for idx in anchors_indices_copy:
+ if abs(idx - self.narrow_peak_index) < 5:
+ anchors_indices.remove(idx)
+ anchors.remove(self.x[idx])
+
+ width = 16
+ bgtheories.THEORY["Snip"].configure(AnchorsList=anchors, AnchorsFlag=True)
+ bg = snipfun(self.x, self.y, width)
+
+ # assert peak amplitude has been decreased
+ self.assertLess(bg[self.narrow_peak_index],
+ self.y[self.narrow_peak_index],
+ "Snip didn't decrease the peak amplitude.")
+
+ # anchored data must remain fixed
+ for i in anchors_indices:
+ self.assertEqual(bg[i], self.y[i])
+
+ # estimated parameters are equal to the default ones in the config dict
+ bgtheories.THEORY["Snip"].configure(SnipWidth=7)
+ esti_par, cons = bgtheories.THEORY["Snip"].estimate(self.x, self.y)
+ self.assertTrue(numpy.array_equal(cons, [[3, 0, 0]]))
+ self.assertEqual(esti_par, [7])
diff --git a/src/silx/math/fit/test/test_filters.py b/src/silx/math/fit/test/test_filters.py
new file mode 100644
index 0000000..8314bdc
--- /dev/null
+++ b/src/silx/math/fit/test/test_filters.py
@@ -0,0 +1,122 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+import numpy
+import unittest
+from silx.math.fit import filters
+from silx.math.fit import functions
+from silx.test.utils import add_relative_noise
+
+
+class TestSmooth(unittest.TestCase):
+ """
+ Unit tests of smoothing functions.
+
+ Test that the difference between a synthetic curve with 5% added random
+ noise and the result of smoothing that signal is less than 5%. We compare
+ the sum of all samples in each curve.
+ """
+ def setUp(self):
+ x = numpy.arange(5000)
+ # (height1, center1, fwhm1, beamfwhm...)
+ slit_params = (50, 500, 200, 100,
+ 50, 600, 80, 30,
+ 20, 2000, 150, 150,
+ 50, 2250, 110, 100,
+ 40, 3000, 50, 10,
+ 23, 4980, 250, 20)
+
+ self.y1 = functions.sum_slit(x, *slit_params)
+ # 5% noise
+ self.y1 = add_relative_noise(self.y1, 5.)
+
+ # (height1, center1, fwhm1...)
+ step_params = (50, 500, 200,
+ 50, 600, 80,
+ 20, 2000, 150,
+ 50, 2250, 110,
+ 40, 3000, 50,
+ 23, 4980, 250,)
+
+ self.y2 = functions.sum_stepup(x, *step_params)
+ # 5% noise
+ self.y2 = add_relative_noise(self.y2, 5.)
+
+ self.y3 = functions.sum_stepdown(x, *step_params)
+ # 5% noise
+ self.y3 = add_relative_noise(self.y3, 5.)
+
+ def tearDown(self):
+ pass
+
+ def testSavitskyGolay(self):
+ npts = 25
+ for y in [self.y1, self.y2, self.y3]:
+ smoothed_y = filters.savitsky_golay(y, npoints=npts)
+
+ # we added +-5% of random noise. The difference must be much lower
+ # than 5%.
+ diff = abs(sum(smoothed_y) - sum(y)) / sum(y)
+ self.assertLess(diff, 0.05,
+ "Difference between data with 5%% noise and " +
+ "smoothed data is > 5%% (%f %%)" % (diff * 100))
+
+ # Try various smoothing levels
+ npts += 25
+
+ def testSmooth1d(self):
+ """Test the 1D smoothing against the formula
+ ys[i] = (y[i-1] + 2 * y[i] + y[i+1]) / 4 (for 1 < i < n-1)"""
+ smoothed_y = filters.smooth1d(self.y1)
+
+ for i in range(1, len(self.y1) - 1):
+ self.assertAlmostEqual(4 * smoothed_y[i],
+ self.y1[i-1] + 2 * self.y1[i] + self.y1[i+1])
+
+ def testSmooth2d(self):
+ """Test that a 2D smoothing is the same as two successive and
+ orthogonal 1D smoothings"""
+ x = numpy.arange(10000)
+
+ noise = 2 * numpy.random.random(10000) - 1
+ noise *= 0.05
+ y = x * (1 + noise)
+
+ y.shape = (100, 100)
+
+ smoothed_y = filters.smooth2d(y)
+
+ intermediate_smooth = numpy.zeros_like(y)
+ expected_smooth = numpy.zeros_like(y)
+ # smooth along first dimension
+ for i in range(0, y.shape[0]):
+ intermediate_smooth[i, :] = filters.smooth1d(y[i, :])
+
+ # smooth along second dimension
+ for j in range(0, y.shape[1]):
+ expected_smooth[:, j] = filters.smooth1d(intermediate_smooth[:, j])
+
+ for i in range(0, y.shape[0]):
+ for j in range(0, y.shape[1]):
+ self.assertAlmostEqual(smoothed_y[i, j],
+ expected_smooth[i, j])
diff --git a/src/silx/math/fit/test/test_fit.py b/src/silx/math/fit/test/test_fit.py
new file mode 100644
index 0000000..00f04e2
--- /dev/null
+++ b/src/silx/math/fit/test/test_fit.py
@@ -0,0 +1,373 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+Nominal tests of the leastsq function.
+"""
+
+import unittest
+
+import numpy
+import sys
+
+from silx.utils import testutils
+from silx.math.fit.leastsq import _logger as fitlogger
+
+
+class Test_leastsq(unittest.TestCase):
+ """
+ Unit tests of the leastsq function.
+ """
+
+ ndims = None
+
+ def setUp(self):
+ try:
+ from silx.math.fit import leastsq
+ self.instance = leastsq
+ except ImportError:
+ self.instance = None
+
+ def myexp(x):
+ # put a (bad) filter to avoid over/underflows
+ # with no python looping
+ with numpy.errstate(invalid='ignore'):
+ return numpy.exp(x*numpy.less(abs(x), 250)) - \
+ 1.0 * numpy.greater_equal(abs(x), 250)
+
+ self.my_exp = myexp
+
+ def gauss(x, *params):
+ params = numpy.array(params, copy=False, dtype=numpy.float64)
+ result = params[0] + params[1] * x
+ for i in range(2, len(params), 3):
+ p = params[i:(i+3)]
+ dummy = 2.3548200450309493*(x - p[1])/p[2]
+ result += p[0] * self.my_exp(-0.5 * dummy * dummy)
+ return result
+
+ self.gauss = gauss
+
+ def gauss_derivative(x, params, idx):
+ if idx == 0:
+ return numpy.ones(len(x), numpy.float64)
+ if idx == 1:
+ return x
+ gaussian_peak = (idx - 2) // 3
+ gaussian_parameter = (idx - 2) % 3
+ actual_idx = 2 + 3 * gaussian_peak
+ p = params[actual_idx:(actual_idx+3)]
+ if gaussian_parameter == 0:
+ return self.gauss(x, *[0, 0, 1.0, p[1], p[2]])
+ if gaussian_parameter == 1:
+ tmp = self.gauss(x, *[0, 0, p[0], p[1], p[2]])
+ tmp *= 2.3548200450309493*(x - p[1])/p[2]
+ return tmp * 2.3548200450309493/p[2]
+ if gaussian_parameter == 2:
+ tmp = self.gauss(x, *[0, 0, p[0], p[1], p[2]])
+ tmp *= 2.3548200450309493*(x - p[1])/p[2]
+ return tmp * 2.3548200450309493*(x - p[1])/(p[2]*p[2])
+
+ self.gauss_derivative = gauss_derivative
+
+ def tearDown(self):
+ self.instance = None
+ self.gauss = None
+ self.gauss_derivative = None
+ self.my_exp = None
+ self.model_function = None
+ self.model_derivative = None
+
+ def testImport(self):
+ self.assertTrue(self.instance is not None,
+ "Cannot import leastsq from silx.math.fit")
+
+ def testUnconstrainedFitNoWeight(self):
+ parameters_actual = [10.5, 2, 1000.0, 20., 15]
+ x = numpy.arange(10000.)
+ y = self.gauss(x, *parameters_actual)
+ parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ model_function = self.gauss
+
+ fittedpar, cov = self.instance(model_function, x, y, parameters_estimate)
+ test_condition = numpy.allclose(parameters_actual, fittedpar)
+ if not test_condition:
+ msg = "Unsuccessfull fit\n"
+ for i in range(len(fittedpar)):
+ msg += "Expected %g obtained %g\n" % (parameters_actual[i],
+ fittedpar[i])
+ self.assertTrue(test_condition, msg)
+
+ def testUnconstrainedFitWeight(self):
+ parameters_actual = [10.5,2,1000.0,20.,15]
+ x = numpy.arange(10000.)
+ y = self.gauss(x, *parameters_actual)
+ sigma = numpy.sqrt(y)
+ parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ model_function = self.gauss
+
+ fittedpar, cov = self.instance(model_function, x, y,
+ parameters_estimate,
+ sigma=sigma)
+ test_condition = numpy.allclose(parameters_actual, fittedpar)
+ if not test_condition:
+ msg = "Unsuccessfull fit\n"
+ for i in range(len(fittedpar)):
+ msg += "Expected %g obtained %g\n" % (parameters_actual[i],
+ fittedpar[i])
+ self.assertTrue(test_condition, msg)
+
+ def testDerivativeFunction(self):
+ parameters_actual = [10.5, 2, 10000.0, 20., 150, 5000, 900., 300]
+ x = numpy.arange(10000.)
+ y = self.gauss(x, *parameters_actual)
+ delta = numpy.sqrt(numpy.finfo(numpy.float64).eps)
+ for i in range(len(parameters_actual)):
+ p = parameters_actual * 1
+ if p[i] == 0:
+ delta_par = delta
+ else:
+ delta_par = p[i] * delta
+ if i > 2:
+ p[0] = 0.0
+ p[1] = 0.0
+ p[i] += delta_par
+ yPlus = self.gauss(x, *p)
+ p[i] = parameters_actual[i] - delta_par
+ yMinus = self.gauss(x, *p)
+ numerical_derivative = (yPlus - yMinus) / (2 * delta_par)
+ #numerical_derivative = (self.gauss(x, *p) - y) / delta_par
+ p[i] = parameters_actual[i]
+ derivative = self.gauss_derivative(x, p, i)
+ diff = numerical_derivative - derivative
+ test_condition = numpy.allclose(numerical_derivative,
+ derivative, atol=5.0e-6)
+ if not test_condition:
+ msg = "Error calculating derivative of parameter %d." % i
+ msg += "\n diff min = %g diff max = %g" % (diff.min(), diff.max())
+ self.assertTrue(test_condition, msg)
+
+ def testConstrainedFit(self):
+ CFREE = 0
+ CPOSITIVE = 1
+ CQUOTED = 2
+ CFIXED = 3
+ CFACTOR = 4
+ CDELTA = 5
+ CSUM = 6
+ parameters_actual = [10.5, 2, 10000.0, 20., 150, 5000, 900., 300]
+ x = numpy.arange(10000.)
+ y = self.gauss(x, *parameters_actual)
+ parameters_estimate = [0.0, 1.0, 900.0, 25., 10, 400, 850, 200]
+ model_function = self.gauss
+ model_deriv = self.gauss_derivative
+ constraints_all_free = [[0, 0, 0]] * len(parameters_actual)
+ constraints_all_positive = [[1, 0, 0]] * len(parameters_actual)
+ constraints_delta_position = [[0, 0, 0]] * len(parameters_actual)
+ constraints_delta_position[6] = [CDELTA, 3, 880]
+ constraints_sum_position = constraints_all_positive * 1
+ constraints_sum_position[6] = [CSUM, 3, 920]
+ constraints_factor = constraints_delta_position * 1
+ constraints_factor[2] = [CFACTOR, 5, 2]
+ constraints_list = [None,
+ constraints_all_free,
+ constraints_all_positive,
+ constraints_delta_position,
+ constraints_sum_position]
+
+ # for better code coverage, the warning recommending to set full_output
+ # to True when using constraints should be shown at least once
+ full_output = True
+ for index, constraints in enumerate(constraints_list):
+ if index == 2:
+ full_output = None
+ elif index == 3:
+ full_output = 0
+ for model_deriv in [None, self.gauss_derivative]:
+ for sigma in [None, numpy.sqrt(y)]:
+ fittedpar, cov = self.instance(model_function, x, y,
+ parameters_estimate,
+ sigma=sigma,
+ constraints=constraints,
+ model_deriv=model_deriv,
+ full_output=full_output)[:2]
+ full_output = True
+
+ test_condition = numpy.allclose(parameters_actual, fittedpar)
+ if not test_condition:
+ msg = "Unsuccessfull fit\n"
+ for i in range(len(fittedpar)):
+ msg += "Expected %g obtained %g\n" % (parameters_actual[i],
+ fittedpar[i])
+ self.assertTrue(test_condition, msg)
+
+ def testUnconstrainedFitAnalyticalDerivative(self):
+ parameters_actual = [10.5, 2, 1000.0, 20., 15]
+ x = numpy.arange(10000.)
+ y = self.gauss(x, *parameters_actual)
+ sigma = numpy.sqrt(y)
+ parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ model_function = self.gauss
+ model_deriv = self.gauss_derivative
+
+ fittedpar, cov = self.instance(model_function, x, y,
+ parameters_estimate,
+ sigma=sigma,
+ model_deriv=model_deriv)
+ test_condition = numpy.allclose(parameters_actual, fittedpar)
+ if not test_condition:
+ msg = "Unsuccessfull fit\n"
+ for i in range(len(fittedpar)):
+ msg += "Expected %g obtained %g\n" % (parameters_actual[i],
+ fittedpar[i])
+ self.assertTrue(test_condition, msg)
+
+ @testutils.validate_logging(fitlogger.name, warning=2)
+ def testBadlyShapedData(self):
+ parameters_actual = [10.5, 2, 1000.0, 20., 15]
+ x = numpy.arange(10000.).reshape(1000, 10)
+ y = self.gauss(x, *parameters_actual)
+ sigma = numpy.sqrt(y)
+ parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ model_function = self.gauss
+
+ for check_finite in [True, False]:
+ fittedpar, cov = self.instance(model_function, x, y,
+ parameters_estimate,
+ sigma=sigma,
+ check_finite=check_finite)
+ test_condition = numpy.allclose(parameters_actual, fittedpar)
+ if not test_condition:
+ msg = "Unsuccessfull fit\n"
+ for i in range(len(fittedpar)):
+ msg += "Expected %g obtained %g\n" % (parameters_actual[i],
+ fittedpar[i])
+ self.assertTrue(test_condition, msg)
+
+ @testutils.validate_logging(fitlogger.name, warning=3)
+ def testDataWithNaN(self):
+ parameters_actual = [10.5, 2, 1000.0, 20., 15]
+ x = numpy.arange(10000.).reshape(1000, 10)
+ y = self.gauss(x, *parameters_actual)
+ sigma = numpy.sqrt(y)
+ parameters_estimate = [0.0, 1.0, 900.0, 25., 10]
+ model_function = self.gauss
+ x[500] = numpy.inf
+ # check default behavior
+ try:
+ self.instance(model_function, x, y,
+ parameters_estimate,
+ sigma=sigma)
+ except ValueError:
+ info = "%s" % sys.exc_info()[1]
+ self.assertTrue("array must not contain inf" in info)
+
+ # check requested behavior
+ try:
+ self.instance(model_function, x, y,
+ parameters_estimate,
+ sigma=sigma,
+ check_finite=True)
+ except ValueError:
+ info = "%s" % sys.exc_info()[1]
+ self.assertTrue("array must not contain inf" in info)
+
+ fittedpar, cov = self.instance(model_function, x, y,
+ parameters_estimate,
+ sigma=sigma,
+ check_finite=False)
+ test_condition = numpy.allclose(parameters_actual, fittedpar)
+ if not test_condition:
+ msg = "Unsuccessfull fit\n"
+ for i in range(len(fittedpar)):
+ msg += "Expected %g obtained %g\n" % (parameters_actual[i],
+ fittedpar[i])
+ self.assertTrue(test_condition, msg)
+
+ # testing now with ydata containing NaN
+ x = numpy.arange(10000.).reshape(1000, 10)
+ y[500] = numpy.nan
+ fittedpar, cov = self.instance(model_function, x, y,
+ parameters_estimate,
+ sigma=sigma,
+ check_finite=False)
+
+ test_condition = numpy.allclose(parameters_actual, fittedpar)
+ if not test_condition:
+ msg = "Unsuccessfull fit\n"
+ for i in range(len(fittedpar)):
+ msg += "Expected %g obtained %g\n" % (parameters_actual[i],
+ fittedpar[i])
+ self.assertTrue(test_condition, msg)
+
+ # testing now with sigma containing NaN
+ sigma[300] = numpy.nan
+ fittedpar, cov = self.instance(model_function, x, y,
+ parameters_estimate,
+ sigma=sigma,
+ check_finite=False)
+ test_condition = numpy.allclose(parameters_actual, fittedpar)
+ if not test_condition:
+ msg = "Unsuccessfull fit\n"
+ for i in range(len(fittedpar)):
+ msg += "Expected %g obtained %g\n" % (parameters_actual[i],
+ fittedpar[i])
+ self.assertTrue(test_condition, msg)
+
+ def testUncertainties(self):
+ """Test for validity of uncertainties in returned full-output
+ dictionary. This is a non-regression test for pull request #197"""
+ parameters_actual = [10.5, 2, 1000.0, 20., 15, 2001.0, 30.1, 16]
+ x = numpy.arange(10000.)
+ y = self.gauss(x, *parameters_actual)
+ parameters_estimate = [0.0, 1.0, 900.0, 25., 10., 1500., 20., 2.0]
+
+ # test that uncertainties are not 0.
+ fittedpar, cov, infodict = self.instance(self.gauss, x, y, parameters_estimate,
+ full_output=True)
+ uncertainties = infodict["uncertainties"]
+ self.assertEqual(len(uncertainties), len(parameters_actual))
+ self.assertEqual(len(uncertainties), len(fittedpar))
+ for uncertainty in uncertainties:
+ self.assertNotAlmostEqual(uncertainty, 0.)
+
+ # set constraint FIXED for half the parameters.
+ # This should cause leastsq to return 100% uncertainty.
+ parameters_estimate = [10.6, 2.1, 1000.1, 20.1, 15.1, 2001.1, 30.2, 16.1]
+ CFIXED = 3
+ CFREE = 0
+ constraints = []
+ for i in range(len(parameters_estimate)):
+ if i % 2:
+ constraints.append([CFIXED, 0, 0])
+ else:
+ constraints.append([CFREE, 0, 0])
+ fittedpar, cov, infodict = self.instance(self.gauss, x, y, parameters_estimate,
+ constraints=constraints,
+ full_output=True)
+ uncertainties = infodict["uncertainties"]
+ for i in range(len(parameters_estimate)):
+ if i % 2:
+ # test that all FIXED parameters have 100% uncertainty
+ self.assertAlmostEqual(uncertainties[i],
+ parameters_estimate[i])
diff --git a/src/silx/math/fit/test/test_fitmanager.py b/src/silx/math/fit/test/test_fitmanager.py
new file mode 100644
index 0000000..4ab56a5
--- /dev/null
+++ b/src/silx/math/fit/test/test_fitmanager.py
@@ -0,0 +1,498 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+Tests for fitmanager module
+"""
+
+import unittest
+import numpy
+import os.path
+
+from silx.math.fit import fitmanager
+from silx.math.fit import fittheories
+from silx.math.fit import bgtheories
+from silx.math.fit.fittheory import FitTheory
+from silx.math.fit.functions import sum_gauss, sum_stepdown, sum_stepup
+
+from silx.utils.testutils import ParametricTestCase
+from silx.test.utils import temp_dir
+
+custom_function_definition = """
+import copy
+from silx.math.fit.fittheory import FitTheory
+
+CONFIG = {'d': 1.}
+
+def myfun(x, a, b, c):
+ "Model function"
+ return (a * x**2 + b * x + c) / CONFIG['d']
+
+def myesti(x, y):
+ "Initial parameters for iterative fit (a, b, c) = (1, 1, 1)"
+ return (1., 1., 1.), ((0, 0, 0), (0, 0, 0), (0, 0, 0))
+
+def myconfig(d=1., **kw):
+ "This function can modify CONFIG"
+ CONFIG["d"] = d
+ return CONFIG
+
+def myderiv(x, parameters, index):
+ "Custom derivative (does not work, causes singular matrix)"
+ pars_plus = copy.copy(parameters)
+ pars_plus[index] *= 1.0001
+
+ pars_minus = parameters
+ pars_minus[index] *= copy.copy(0.9999)
+
+ delta_fun = myfun(x, *pars_plus) - myfun(x, *pars_minus)
+ delta_par = parameters[index] * 0.0001 * 2
+
+ return delta_fun / delta_par
+
+THEORY = {
+ 'my fit theory':
+ FitTheory(function=myfun,
+ parameters=('A', 'B', 'C'),
+ estimate=myesti,
+ configure=myconfig,
+ derivative=myderiv)
+}
+
+"""
+
+old_custom_function_definition = """
+CONFIG = {'d': 1.0}
+
+def myfun(x, a, b, c):
+ "Model function"
+ return (a * x**2 + b * x + c) / CONFIG['d']
+
+def myesti(x, y, bg, xscalinq, yscaling):
+ "Initial parameters for iterative fit (a, b, c) = (1, 1, 1)"
+ return (1., 1., 1.), ((0, 0, 0), (0, 0, 0), (0, 0, 0))
+
+def myconfig(**kw):
+ "Update or complete CONFIG dictionary"
+ for key in kw:
+ CONFIG[key] = kw[key]
+ return CONFIG
+
+THEORY = ['my fit theory']
+PARAMETERS = [('A', 'B', 'C')]
+FUNCTION = [myfun]
+ESTIMATE = [myesti]
+CONFIGURE = [myconfig]
+
+"""
+
+
+def _order_of_magnitude(x):
+ return numpy.log10(x).round()
+
+
+class TestFitmanager(ParametricTestCase):
+ """
+ Unit tests of multi-peak functions.
+ """
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def testFitManager(self):
+ """Test fit manager on synthetic data using a gaussian function
+ and a linear background"""
+ # Create synthetic data with a sum of gaussian functions
+ x = numpy.arange(1000).astype(numpy.float64)
+
+ p = [1000, 100., 250,
+ 255, 650., 45,
+ 1500, 800.5, 95]
+ linear_bg = 2.65 * x + 13
+ y = linear_bg + sum_gauss(x, *p)
+
+ y_with_nans = numpy.array(y)
+ y_with_nans[::10] = numpy.nan
+
+ x_with_nans = numpy.array(x)
+ x_with_nans[5::15] = numpy.nan
+
+ tests = {
+ 'all finite': (x, y),
+ 'y with NaNs': (x, y_with_nans),
+ 'x with NaNs': (x_with_nans, y),
+ }
+
+ for name, (xdata, ydata) in tests.items():
+ with self.subTest(name=name):
+ # Fitting
+ fit = fitmanager.FitManager()
+ fit.setdata(x=xdata, y=ydata)
+ fit.loadtheories(fittheories)
+ # Use one of the default fit functions
+ fit.settheory('Gaussians')
+ fit.setbackground('Linear')
+ fit.estimate()
+ fit.runfit()
+
+ # fit.fit_results[]
+
+ # first 2 parameters are related to the linear background
+ self.assertEqual(fit.fit_results[0]["name"], "Constant")
+ self.assertAlmostEqual(fit.fit_results[0]["fitresult"], 13)
+ self.assertEqual(fit.fit_results[1]["name"], "Slope")
+ self.assertAlmostEqual(fit.fit_results[1]["fitresult"], 2.65)
+
+ for i, param in enumerate(fit.fit_results[2:]):
+ param_number = i // 3 + 1
+ if i % 3 == 0:
+ self.assertEqual(param["name"],
+ "Height%d" % param_number)
+ elif i % 3 == 1:
+ self.assertEqual(param["name"],
+ "Position%d" % param_number)
+ elif i % 3 == 2:
+ self.assertEqual(param["name"],
+ "FWHM%d" % param_number)
+
+ self.assertAlmostEqual(param["fitresult"],
+ p[i])
+ self.assertAlmostEqual(_order_of_magnitude(param["estimation"]),
+ _order_of_magnitude(p[i]))
+
+ def testLoadCustomFitFunction(self):
+ """Test FitManager using a custom fit function defined in an external
+ file and imported with FitManager.loadtheories"""
+ # Create synthetic data with a sum of gaussian functions
+ x = numpy.arange(100).astype(numpy.float64)
+
+ # a, b, c are the fit parameters
+ # d is a known scaling parameter that is set using configure()
+ a, b, c, d = 1.5, 2.5, 3.5, 4.5
+ y = (a * x**2 + b * x + c) / d
+
+ # Fitting
+ fit = fitmanager.FitManager()
+ fit.setdata(x=x, y=y)
+
+ # Create a temporary function definition file, and import it
+ with temp_dir() as tmpDir:
+ tmpfile = os.path.join(tmpDir, 'customfun.py')
+ # custom_function_definition
+ fd = open(tmpfile, "w")
+ fd.write(custom_function_definition)
+ fd.close()
+ fit.loadtheories(tmpfile)
+ tmpfile_pyc = os.path.join(tmpDir, 'customfun.pyc')
+ if os.path.exists(tmpfile_pyc):
+ os.unlink(tmpfile_pyc)
+ os.unlink(tmpfile)
+
+ fit.settheory('my fit theory')
+ # Test configure
+ fit.configure(d=4.5)
+ fit.estimate()
+ fit.runfit()
+
+ self.assertEqual(fit.fit_results[0]["name"],
+ "A1")
+ self.assertAlmostEqual(fit.fit_results[0]["fitresult"],
+ 1.5)
+ self.assertEqual(fit.fit_results[1]["name"],
+ "B1")
+ self.assertAlmostEqual(fit.fit_results[1]["fitresult"],
+ 2.5)
+ self.assertEqual(fit.fit_results[2]["name"],
+ "C1")
+ self.assertAlmostEqual(fit.fit_results[2]["fitresult"],
+ 3.5)
+
+ def testLoadOldCustomFitFunction(self):
+ """Test FitManager using a custom fit function defined in an external
+ file and imported with FitManager.loadtheories (legacy PyMca format)"""
+ # Create synthetic data with a sum of gaussian functions
+ x = numpy.arange(100).astype(numpy.float64)
+
+ # a, b, c are the fit parameters
+ # d is a known scaling parameter that is set using configure()
+ a, b, c, d = 1.5, 2.5, 3.5, 4.5
+ y = (a * x**2 + b * x + c) / d
+
+ # Fitting
+ fit = fitmanager.FitManager()
+ fit.setdata(x=x, y=y)
+
+ # Create a temporary function definition file, and import it
+ with temp_dir() as tmpDir:
+ tmpfile = os.path.join(tmpDir, 'oldcustomfun.py')
+ # custom_function_definition
+ fd = open(tmpfile, "w")
+ fd.write(old_custom_function_definition)
+ fd.close()
+ fit.loadtheories(tmpfile)
+ tmpfile_pyc = os.path.join(tmpDir, 'oldcustomfun.pyc')
+ if os.path.exists(tmpfile_pyc):
+ os.unlink(tmpfile_pyc)
+ os.unlink(tmpfile)
+
+ fit.settheory('my fit theory')
+ fit.configure(d=4.5)
+ fit.estimate()
+ fit.runfit()
+
+ self.assertEqual(fit.fit_results[0]["name"],
+ "A1")
+ self.assertAlmostEqual(fit.fit_results[0]["fitresult"],
+ 1.5)
+ self.assertEqual(fit.fit_results[1]["name"],
+ "B1")
+ self.assertAlmostEqual(fit.fit_results[1]["fitresult"],
+ 2.5)
+ self.assertEqual(fit.fit_results[2]["name"],
+ "C1")
+ self.assertAlmostEqual(fit.fit_results[2]["fitresult"],
+ 3.5)
+
+ def testAddTheory(self, estimate=True):
+ """Test FitManager using a custom fit function imported with
+ FitManager.addtheory"""
+ # Create synthetic data with a sum of gaussian functions
+ x = numpy.arange(100).astype(numpy.float64)
+
+ # a, b, c are the fit parameters
+ # d is a known scaling parameter that is set using configure()
+ a, b, c, d = -3.14, 1234.5, 10000, 4.5
+ y = (a * x**2 + b * x + c) / d
+
+ # Fitting
+ fit = fitmanager.FitManager()
+ fit.setdata(x=x, y=y)
+
+ # Define and add the fit theory
+ CONFIG = {'d': 1.}
+
+ def myfun(x_, a_, b_, c_):
+ """"Model function"""
+ return (a_ * x_**2 + b_ * x_ + c_) / CONFIG['d']
+
+ def myesti(x_, y_):
+ """"Initial parameters for iterative fit:
+ (a, b, c) = (1, 1, 1)
+ Constraints all set to 0 (FREE)"""
+ return (1., 1., 1.), ((0, 0, 0), (0, 0, 0), (0, 0, 0))
+
+ def myconfig(d_=1., **kw):
+ """This function can modify CONFIG"""
+ CONFIG["d"] = d_
+ return CONFIG
+
+ def myderiv(x_, parameters, index):
+ """Custom derivative"""
+ pars_plus = numpy.array(parameters, copy=True)
+ pars_plus[index] *= 1.001
+
+ pars_minus = numpy.array(parameters, copy=True)
+ pars_minus[index] *= 0.999
+
+ delta_fun = myfun(x_, *pars_plus) - myfun(x_, *pars_minus)
+ delta_par = parameters[index] * 0.001 * 2
+
+ return delta_fun / delta_par
+
+ fit.addtheory("polynomial",
+ FitTheory(function=myfun,
+ parameters=["A", "B", "C"],
+ estimate=myesti if estimate else None,
+ configure=myconfig,
+ derivative=myderiv))
+
+ fit.settheory('polynomial')
+ fit.configure(d_=4.5)
+ fit.estimate()
+ params1, sigmas, infodict = fit.runfit()
+
+ self.assertEqual(fit.fit_results[0]["name"],
+ "A1")
+ self.assertAlmostEqual(fit.fit_results[0]["fitresult"],
+ -3.14)
+ self.assertEqual(fit.fit_results[1]["name"],
+ "B1")
+ # params1[1] is the same as fit.fit_results[1]["fitresult"]
+ self.assertAlmostEqual(params1[1],
+ 1234.5)
+ self.assertEqual(fit.fit_results[2]["name"],
+ "C1")
+ self.assertAlmostEqual(params1[2],
+ 10000)
+
+ # change configuration scaling factor and check that the fit returns
+ # different values
+ fit.configure(d_=5.)
+ fit.estimate()
+ params2, sigmas, infodict = fit.runfit()
+ for p1, p2 in zip(params1, params2):
+ self.assertFalse(numpy.array_equal(p1, p2),
+ "Fit parameters are equal even though the " +
+ "configuration has been changed")
+
+ def testNoEstimate(self):
+ """Ensure that the in the absence of the estimation function,
+ the default estimation function :meth:`FitTheory.default_estimate`
+ is used."""
+ self.testAddTheory(estimate=False)
+
+ def testStep(self):
+ """Test fit manager on a step function with a more complex estimate
+ function than the gaussian (convolution filter)"""
+ for theory_name, theory_fun in (('Step Down', sum_stepdown),
+ ('Step Up', sum_stepup)):
+ # Create synthetic data with a sum of gaussian functions
+ x = numpy.arange(1000).astype(numpy.float64)
+
+ # ('Height', 'Position', 'FWHM')
+ p = [1000, 439, 250]
+
+ constantbg = 13
+ y = theory_fun(x, *p) + constantbg
+
+ # Fitting
+ fit = fitmanager.FitManager()
+ fit.setdata(x=x, y=y)
+ fit.loadtheories(fittheories)
+ fit.settheory(theory_name)
+ fit.setbackground('Constant')
+
+ fit.estimate()
+
+ params, sigmas, infodict = fit.runfit()
+
+ # first parameter is the constant background
+ self.assertAlmostEqual(params[0], 13, places=5)
+ for i, param in enumerate(params[1:]):
+ self.assertAlmostEqual(param, p[i], places=5)
+ self.assertAlmostEqual(_order_of_magnitude(fit.fit_results[i+1]["estimation"]),
+ _order_of_magnitude(p[i]))
+
+
+def quadratic(x, a, b, c):
+ return a * x**2 + b * x + c
+
+
+def cubic(x, a, b, c, d):
+ return a * x**3 + b * x**2 + c * x + d
+
+
+class TestPolynomials(unittest.TestCase):
+ """Test polynomial fit theories and fit background"""
+ def setUp(self):
+ self.x = numpy.arange(100).astype(numpy.float64)
+
+ def testQuadraticBg(self):
+ gaussian_params = [100, 45, 8]
+ poly_params = [0.05, -2, 3]
+ p = numpy.poly1d(poly_params)
+
+ y = p(self.x) + sum_gauss(self.x, *gaussian_params)
+
+ fm = fitmanager.FitManager(self.x, y)
+ fm.loadbgtheories(bgtheories)
+ fm.loadtheories(fittheories)
+ fm.settheory("Gaussians")
+ fm.setbackground("Degree 2 Polynomial")
+ esti_params = fm.estimate()
+ fit_params = fm.runfit()[0]
+
+ for p, pfit in zip(poly_params + gaussian_params, fit_params):
+ self.assertAlmostEqual(p,
+ pfit)
+
+ def testCubicBg(self):
+ gaussian_params = [1000, 45, 8]
+ poly_params = [0.0005, -0.05, 3, -4]
+ p = numpy.poly1d(poly_params)
+
+ y = p(self.x) + sum_gauss(self.x, *gaussian_params)
+
+ fm = fitmanager.FitManager(self.x, y)
+ fm.loadtheories(fittheories)
+ fm.settheory("Gaussians")
+ fm.setbackground("Degree 3 Polynomial")
+ esti_params = fm.estimate()
+ fit_params = fm.runfit()[0]
+
+ for p, pfit in zip(poly_params + gaussian_params, fit_params):
+ self.assertAlmostEqual(p,
+ pfit)
+
+ def testQuarticcBg(self):
+ gaussian_params = [10000, 69, 25]
+ poly_params = [5e-10, 0.0005, 0.005, 2, 4]
+ p = numpy.poly1d(poly_params)
+
+ y = p(self.x) + sum_gauss(self.x, *gaussian_params)
+
+ fm = fitmanager.FitManager(self.x, y)
+ fm.loadtheories(fittheories)
+ fm.settheory("Gaussians")
+ fm.setbackground("Degree 4 Polynomial")
+ esti_params = fm.estimate()
+ fit_params = fm.runfit()[0]
+
+ for p, pfit in zip(poly_params + gaussian_params, fit_params):
+ self.assertAlmostEqual(p,
+ pfit,
+ places=5)
+
+ def _testPoly(self, poly_params, theory, places=5):
+ p = numpy.poly1d(poly_params)
+
+ y = p(self.x)
+
+ fm = fitmanager.FitManager(self.x, y)
+ fm.loadbgtheories(bgtheories)
+ fm.loadtheories(fittheories)
+ fm.settheory(theory)
+ esti_params = fm.estimate()
+ fit_params = fm.runfit()[0]
+
+ for p, pfit in zip(poly_params, fit_params):
+ self.assertAlmostEqual(p, pfit, places=places)
+
+ def testQuadratic(self):
+ self._testPoly([0.05, -2, 3],
+ "Degree 2 Polynomial")
+
+ def testCubic(self):
+ self._testPoly([0.0005, -0.05, 3, -4],
+ "Degree 3 Polynomial")
+
+ def testQuartic(self):
+ self._testPoly([1, -2, 3, -4, -5],
+ "Degree 4 Polynomial")
+
+ def testQuintic(self):
+ self._testPoly([1, -2, 3, -4, -5, 6],
+ "Degree 5 Polynomial",
+ places=4)
diff --git a/src/silx/math/fit/test/test_functions.py b/src/silx/math/fit/test/test_functions.py
new file mode 100644
index 0000000..7e3ff63
--- /dev/null
+++ b/src/silx/math/fit/test/test_functions.py
@@ -0,0 +1,259 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+Tests for functions module
+"""
+
+import unittest
+import numpy
+import math
+
+from silx.math.fit import functions
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "21/07/2016"
+
+class Test_functions(unittest.TestCase):
+ """
+ Unit tests of multi-peak functions.
+ """
+ def setUp(self):
+ self.x = numpy.arange(11)
+
+ # height, center, sigma1, sigma2
+ (h, c, s1, s2) = (7., 5., 3., 2.1)
+ self.g_params = {
+ "height": h,
+ "center": c,
+ #"sigma": s,
+ "fwhm1": 2 * math.sqrt(2 * math.log(2)) * s1,
+ "fwhm2": 2 * math.sqrt(2 * math.log(2)) * s2,
+ "area1": h * s1 * math.sqrt(2 * math.pi)
+ }
+ # result of `7 * scipy.signal.gaussian(11, 3)`
+ self.scipy_gaussian = numpy.array(
+ [1.74546546, 2.87778603, 4.24571462, 5.60516182, 6.62171628,
+ 7., 6.62171628, 5.60516182, 4.24571462, 2.87778603,
+ 1.74546546]
+ )
+
+ # result of:
+ # numpy.concatenate((7 * scipy.signal.gaussian(11, 3)[0:5],
+ # 7 * scipy.signal.gaussian(11, 2.1)[5:11]))
+ self.scipy_asym_gaussian = numpy.array(
+ [1.74546546, 2.87778603, 4.24571462, 5.60516182, 6.62171628,
+ 7., 6.24968751, 4.44773692, 2.52313452, 1.14093853, 0.41124877]
+ )
+
+ def tearDown(self):
+ pass
+
+ def testGauss(self):
+ """Compare sum_gauss with scipy.signals.gaussian"""
+ y = functions.sum_gauss(self.x,
+ self.g_params["height"],
+ self.g_params["center"],
+ self.g_params["fwhm1"])
+
+ for i in range(11):
+ self.assertAlmostEqual(y[i], self.scipy_gaussian[i])
+
+ def testAGauss(self):
+ """Compare sum_agauss with scipy.signals.gaussian"""
+ y = functions.sum_agauss(self.x,
+ self.g_params["area1"],
+ self.g_params["center"],
+ self.g_params["fwhm1"])
+ for i in range(11):
+ self.assertAlmostEqual(y[i], self.scipy_gaussian[i])
+
+ def testFastAGauss(self):
+ """Compare sum_fastagauss with scipy.signals.gaussian
+ Limit precision to 3 decimal places."""
+ y = functions.sum_fastagauss(self.x,
+ self.g_params["area1"],
+ self.g_params["center"],
+ self.g_params["fwhm1"])
+ for i in range(11):
+ self.assertAlmostEqual(y[i], self.scipy_gaussian[i], 3)
+
+
+ def testSplitGauss(self):
+ """Compare sum_splitgauss with scipy.signals.gaussian"""
+ y = functions.sum_splitgauss(self.x,
+ self.g_params["height"],
+ self.g_params["center"],
+ self.g_params["fwhm1"],
+ self.g_params["fwhm2"])
+ for i in range(11):
+ self.assertAlmostEqual(y[i], self.scipy_asym_gaussian[i])
+
+ def testErf(self):
+ """Compare erf with math.erf"""
+ # scalars
+ self.assertAlmostEqual(functions.erf(0.14), math.erf(0.14), places=5)
+ self.assertAlmostEqual(functions.erf(0), math.erf(0), places=5)
+ self.assertAlmostEqual(functions.erf(-0.74), math.erf(-0.74), places=5)
+
+ # lists
+ x = [-5, -2, -1.5, -0.6, 0, 0.1, 2, 3]
+ erfx = functions.erf(x)
+ for i in range(len(x)):
+ self.assertAlmostEqual(erfx[i],
+ math.erf(x[i]),
+ places=5)
+
+ # ndarray
+ x = numpy.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
+ erfx = functions.erf(x)
+ for i in range(x.shape[0]):
+ for j in range(x.shape[1]):
+ self.assertAlmostEqual(erfx[i, j],
+ math.erf(x[i, j]),
+ places=5)
+
+ def testErfc(self):
+ """Compare erf with math.erf"""
+ # scalars
+ self.assertAlmostEqual(functions.erfc(0.14), math.erfc(0.14), places=5)
+ self.assertAlmostEqual(functions.erfc(0), math.erfc(0), places=5)
+ self.assertAlmostEqual(functions.erfc(-0.74), math.erfc(-0.74), places=5)
+
+ # lists
+ x = [-5, -2, -1.5, -0.6, 0, 0.1, 2, 3]
+ erfcx = functions.erfc(x)
+ for i in range(len(x)):
+ self.assertAlmostEqual(erfcx[i], math.erfc(x[i]), places=5)
+
+ # ndarray
+ x = numpy.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
+ erfcx = functions.erfc(x)
+ for i in range(x.shape[0]):
+ for j in range(x.shape[1]):
+ self.assertAlmostEqual(erfcx[i, j], math.erfc(x[i, j]), places=5)
+
+ def testAtanStepUp(self):
+ """Compare atan_stepup with math.atan
+
+ atan_stepup(x, a, b, c) = a * (0.5 + (arctan((x - b) / c) / pi))"""
+ x0 = numpy.arange(100) / 6.33
+ y0 = functions.atan_stepup(x0, 11.1, 22.2, 3.33)
+
+ for x, y in zip(x0, y0):
+ self.assertAlmostEqual(
+ 11.1 * (0.5 + math.atan((x - 22.2) / 3.33) / math.pi),
+ y
+ )
+
+ def testStepUp(self):
+ """sanity check for step up:
+
+ - derivative must be largest around the step center
+ - max value must be close to height parameter
+
+ """
+ x0 = numpy.arange(1000)
+ center = 444
+ height = 1234
+ fwhm = 210
+ y0 = functions.sum_stepup(x0, height, center, fwhm)
+
+ self.assertLess(max(y0), height)
+ self.assertAlmostEqual(max(y0), height, places=1)
+ self.assertAlmostEqual(min(y0), 0, places=1)
+
+ deriv0 = _numerical_derivative(functions.sum_stepup, x0, [height, center, fwhm])
+
+ # Test center position within +- 1 sample of max derivative
+ index_max_deriv = numpy.argmax(deriv0)
+ self.assertLess(abs(index_max_deriv - center),
+ 1)
+
+ def testStepDown(self):
+ """sanity check for step down:
+
+ - absolute value of derivative must be largest around the step center
+ - max value must be close to height parameter
+
+ """
+ x0 = numpy.arange(1000)
+ center = 444
+ height = 1234
+ fwhm = 210
+ y0 = functions.sum_stepdown(x0, height, center, fwhm)
+
+ self.assertLess(max(y0), height)
+ self.assertAlmostEqual(max(y0), height, places=1)
+ self.assertAlmostEqual(min(y0), 0, places=1)
+
+ deriv0 = _numerical_derivative(functions.sum_stepdown, x0, [height, center, fwhm])
+
+ # Test center position within +- 1 sample of max derivative
+ index_min_deriv = numpy.argmax(-deriv0)
+ self.assertLess(abs(index_min_deriv - center),
+ 1)
+
+ def testSlit(self):
+ """sanity check for slit:
+
+ - absolute value of derivative must be largest around the step center
+ - max value must be close to height parameter
+
+ """
+ x0 = numpy.arange(1000)
+ center = 444
+ height = 1234
+ fwhm = 210
+ beamfwhm = 30
+ y0 = functions.sum_slit(x0, height, center, fwhm, beamfwhm)
+
+ self.assertAlmostEqual(max(y0), height, places=1)
+ self.assertAlmostEqual(min(y0), 0, places=1)
+
+ deriv0 = _numerical_derivative(functions.sum_slit, x0, [height, center, fwhm, beamfwhm])
+
+ # Test step up center position (center - fwhm/2) within +- 1 sample of max derivative
+ index_max_deriv = numpy.argmax(deriv0)
+ self.assertLess(abs(index_max_deriv - (center - fwhm/2)),
+ 1)
+ # Test step down center position (center + fwhm/2) within +- 1 sample of min derivative
+ index_min_deriv = numpy.argmin(deriv0)
+ self.assertLess(abs(index_min_deriv - (center + fwhm/2)),
+ 1)
+
+
+def _numerical_derivative(f, x, params=[], delta_factor=0.0001):
+ """Compute the numerical derivative of ``f`` for all values of ``x``.
+
+ :param f: function
+ :param x: Array of evenly spaced abscissa values
+ :param params: list of additional parameters
+ :return: Array of derivative values
+ """
+ deltax = (x[1] - x[0]) * delta_factor
+ y_plus = f(x + deltax, *params)
+ y_minus = f(x - deltax, *params)
+
+ return (y_plus - y_minus) / (2 * deltax)
diff --git a/src/silx/math/fit/test/test_peaks.py b/src/silx/math/fit/test/test_peaks.py
new file mode 100644
index 0000000..495c70d
--- /dev/null
+++ b/src/silx/math/fit/test/test_peaks.py
@@ -0,0 +1,132 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+Tests for peaks module
+"""
+
+import unittest
+import numpy
+import math
+
+from silx.math.fit import functions
+from silx.math.fit import peaks
+
+class Test_peak_search(unittest.TestCase):
+ """
+ Unit tests of peak_search on various types of multi-peak functions.
+ """
+ def setUp(self):
+ self.x = numpy.arange(5000)
+ # (height1, center1, fwhm1, ...)
+ self.h_c_fwhm = (50, 500, 100,
+ 50, 600, 80,
+ 20, 2000, 100,
+ 50, 2250, 110,
+ 40, 3000, 99,
+ 23, 4980, 80)
+ # (height1, center1, fwhm1, eta1 ...)
+ self.h_c_fwhm_eta = (50, 500, 100, 0.4,
+ 50, 600, 80, 0.5,
+ 20, 2000, 100, 0.6,
+ 50, 2250, 110, 0.7,
+ 40, 3000, 99, 0.8,
+ 23, 4980, 80, 0.3,)
+ # (height1, center1, fwhm11, fwhm21, ...)
+ self.h_c_fwhm_fwhm = (50, 500, 100, 85,
+ 50, 600, 80, 110,
+ 20, 2000, 100, 100,
+ 50, 2250, 110, 99,
+ 40, 3000, 99, 110,
+ 23, 4980, 80, 80,)
+ # (height1, center1, fwhm11, fwhm21, eta1 ...)
+ self.h_c_fwhm_fwhm_eta = (50, 500, 100, 85, 0.4,
+ 50, 600, 80, 110, 0.5,
+ 20, 2000, 100, 100, 0.6,
+ 50, 2250, 110, 99, 0.7,
+ 40, 3000, 99, 110, 0.8,
+ 23, 4980, 80, 80, 0.3,)
+ # (area1, center1, fwhm1, ...)
+ self.a_c_fwhm = (2550, 500, 100,
+ 2000, 600, 80,
+ 500, 2000, 100,
+ 4000, 2250, 110,
+ 2300, 3000, 99,
+ 3333, 4980, 80)
+ # (area1, center1, fwhm1, eta1 ...)
+ self.a_c_fwhm_eta = (500, 500, 100, 0.4,
+ 500, 600, 80, 0.5,
+ 200, 2000, 100, 0.6,
+ 500, 2250, 110, 0.7,
+ 400, 3000, 99, 0.8,
+ 230, 4980, 80, 0.3,)
+ # (area, position, fwhm, st_area_r, st_slope_r, lt_area_r, lt_slope_r, step_height_r)
+ self.hypermet_params = (1000, 500, 200, 0.2, 100, 0.3, 100, 0.05,
+ 1000, 1000, 200, 0.2, 100, 0.3, 100, 0.05,
+ 1000, 2000, 200, 0.2, 100, 0.3, 100, 0.05,
+ 1000, 2350, 200, 0.2, 100, 0.3, 100, 0.05,
+ 1000, 3000, 200, 0.2, 100, 0.3, 100, 0.05,
+ 1000, 4900, 200, 0.2, 100, 0.3, 100, 0.05,)
+
+
+ def tearDown(self):
+ pass
+
+ def get_peaks(self, function, params):
+ """
+
+ :param function: Multi-peak function
+ :param params: Parameter for this function
+ :return: list of (peak, relevance) tuples
+ """
+ y = function(self.x, *params)
+ return peaks.peak_search(y=y, fwhm=100, relevance_info=True)
+
+ def testPeakSearch_various_functions(self):
+ """Run peak search on a variety of synthetic functions, and
+ check that result falls within +-25 samples of the actual peak
+ (reasonable delta considering a fwhm of ~100 samples) and effects
+ of overlapping peaks)."""
+ f_p = ((functions.sum_gauss, self.h_c_fwhm ),
+ (functions.sum_lorentz, self.h_c_fwhm),
+ (functions.sum_pvoigt, self.h_c_fwhm_eta),
+ (functions.sum_splitgauss, self.h_c_fwhm_fwhm),
+ (functions.sum_splitlorentz, self.h_c_fwhm_fwhm),
+ (functions.sum_splitpvoigt, self.h_c_fwhm_fwhm_eta),
+ (functions.sum_agauss, self.a_c_fwhm),
+ (functions.sum_fastagauss, self.a_c_fwhm),
+ (functions.sum_alorentz, self.a_c_fwhm),
+ (functions.sum_apvoigt, self.a_c_fwhm_eta),
+ (functions.sum_ahypermet, self.hypermet_params),
+ (functions.sum_fastahypermet, self.hypermet_params),)
+
+ for function, params in f_p:
+ peaks = self.get_peaks(function, params)
+
+ self.assertEqual(len(peaks), 6,
+ "Wrong number of peaks detected")
+
+ for i in range(6):
+ theoretical_peak_index = params[i*(len(params)//6) + 1]
+ found_peak_index = peaks[i][0]
+ self.assertLess(abs(found_peak_index - theoretical_peak_index), 25)
diff --git a/silx/math/histogram.py b/src/silx/math/histogram.py
index af9ee68..af9ee68 100644
--- a/silx/math/histogram.py
+++ b/src/silx/math/histogram.py
diff --git a/silx/math/histogramnd/include/histogramnd_c.h b/src/silx/math/histogramnd/include/histogramnd_c.h
index abe464f..abe464f 100644
--- a/silx/math/histogramnd/include/histogramnd_c.h
+++ b/src/silx/math/histogramnd/include/histogramnd_c.h
diff --git a/silx/math/histogramnd/include/msvc/stdint.h b/src/silx/math/histogramnd/include/msvc/stdint.h
index e236bb0..e236bb0 100644
--- a/silx/math/histogramnd/include/msvc/stdint.h
+++ b/src/silx/math/histogramnd/include/msvc/stdint.h
diff --git a/silx/math/histogramnd/include/templates.h b/src/silx/math/histogramnd/include/templates.h
index 490eed3..490eed3 100644
--- a/silx/math/histogramnd/include/templates.h
+++ b/src/silx/math/histogramnd/include/templates.h
diff --git a/silx/math/histogramnd/src/histogramnd_c.c b/src/silx/math/histogramnd/src/histogramnd_c.c
index fc9d77e..fc9d77e 100644
--- a/silx/math/histogramnd/src/histogramnd_c.c
+++ b/src/silx/math/histogramnd/src/histogramnd_c.c
diff --git a/silx/math/histogramnd/src/histogramnd_template.c b/src/silx/math/histogramnd/src/histogramnd_template.c
index 0276bb4..0276bb4 100644
--- a/silx/math/histogramnd/src/histogramnd_template.c
+++ b/src/silx/math/histogramnd/src/histogramnd_template.c
diff --git a/silx/math/histogramnd_c.pxd b/src/silx/math/histogramnd_c.pxd
index 35db529..35db529 100644
--- a/silx/math/histogramnd_c.pxd
+++ b/src/silx/math/histogramnd_c.pxd
diff --git a/silx/math/include/math_compatibility.h b/src/silx/math/include/math_compatibility.h
index 3d69c0c..3d69c0c 100644
--- a/silx/math/include/math_compatibility.h
+++ b/src/silx/math/include/math_compatibility.h
diff --git a/silx/math/interpolate.pyx b/src/silx/math/interpolate.pyx
index c79224a..c79224a 100644
--- a/silx/math/interpolate.pyx
+++ b/src/silx/math/interpolate.pyx
diff --git a/silx/math/marchingcubes.pyx b/src/silx/math/marchingcubes.pyx
index 0409691..0409691 100644
--- a/silx/math/marchingcubes.pyx
+++ b/src/silx/math/marchingcubes.pyx
diff --git a/silx/math/marchingcubes/mc.hpp b/src/silx/math/marchingcubes/mc.hpp
index 82eced9..82eced9 100644
--- a/silx/math/marchingcubes/mc.hpp
+++ b/src/silx/math/marchingcubes/mc.hpp
diff --git a/silx/math/marchingcubes/mc_lut.cpp b/src/silx/math/marchingcubes/mc_lut.cpp
index 7998f1b..7998f1b 100644
--- a/silx/math/marchingcubes/mc_lut.cpp
+++ b/src/silx/math/marchingcubes/mc_lut.cpp
diff --git a/silx/math/math_compatibility.pxd b/src/silx/math/math_compatibility.pxd
index ddaa550..ddaa550 100644
--- a/silx/math/math_compatibility.pxd
+++ b/src/silx/math/math_compatibility.pxd
diff --git a/silx/math/mc.pxd b/src/silx/math/mc.pxd
index b1c81e7..b1c81e7 100644
--- a/silx/math/mc.pxd
+++ b/src/silx/math/mc.pxd
diff --git a/silx/math/medianfilter/__init__.py b/src/silx/math/medianfilter/__init__.py
index 2b05f06..2b05f06 100644
--- a/silx/math/medianfilter/__init__.py
+++ b/src/silx/math/medianfilter/__init__.py
diff --git a/silx/math/medianfilter/include/median_filter.hpp b/src/silx/math/medianfilter/include/median_filter.hpp
index 7e42980..7e42980 100644
--- a/silx/math/medianfilter/include/median_filter.hpp
+++ b/src/silx/math/medianfilter/include/median_filter.hpp
diff --git a/silx/math/medianfilter/median_filter.pxd b/src/silx/math/medianfilter/median_filter.pxd
index 2fc0283..2fc0283 100644
--- a/silx/math/medianfilter/median_filter.pxd
+++ b/src/silx/math/medianfilter/median_filter.pxd
diff --git a/silx/math/medianfilter/medianfilter.pyx b/src/silx/math/medianfilter/medianfilter.pyx
index fe05a78..fe05a78 100644
--- a/silx/math/medianfilter/medianfilter.pyx
+++ b/src/silx/math/medianfilter/medianfilter.pyx
diff --git a/silx/math/medianfilter/setup.py b/src/silx/math/medianfilter/setup.py
index d228357..d228357 100644
--- a/silx/math/medianfilter/setup.py
+++ b/src/silx/math/medianfilter/setup.py
diff --git a/src/silx/math/medianfilter/test/__init__.py b/src/silx/math/medianfilter/test/__init__.py
new file mode 100644
index 0000000..71f8e95
--- /dev/null
+++ b/src/silx/math/medianfilter/test/__init__.py
@@ -0,0 +1,23 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
diff --git a/src/silx/math/medianfilter/test/benchmark.py b/src/silx/math/medianfilter/test/benchmark.py
new file mode 100644
index 0000000..81e893e
--- /dev/null
+++ b/src/silx/math/medianfilter/test/benchmark.py
@@ -0,0 +1,122 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2017-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests of the median filter"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "02/05/2017"
+
+from silx.gui import qt
+from silx.math.medianfilter import medfilt2d as medfilt2d_silx
+import numpy
+import numpy.random
+from timeit import Timer
+from silx.gui.plot import Plot1D
+import logging
+
+try:
+ import scipy
+except:
+ scipy = None
+else:
+ import scipy.ndimage
+
+try:
+ import PyMca5.PyMca as pymca
+except:
+ pymca = None
+else:
+ from PyMca5.PyMca.median import medfilt2d as medfilt2d_pymca
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class BenchmarkMedianFilter(object):
+ """Simple benchmark of the median fiter silx vs scipy"""
+
+ NB_ITER = 3
+
+ def __init__(self, imageWidth, kernels):
+ self.img = numpy.random.rand(imageWidth, imageWidth)
+ self.kernels = kernels
+
+ self.run()
+
+ def run(self):
+ self.execTime = {}
+ for kernel in self.kernels:
+ self.execTime[kernel] = self.bench(kernel)
+
+ def bench(self, width):
+ def execSilx():
+ medfilt2d_silx(self.img, width)
+
+ def execScipy():
+ scipy.ndimage.median_filter(input=self.img,
+ size=width,
+ mode='nearest')
+
+ def execPymca():
+ medfilt2d_pymca(self.img, width)
+
+ execTime = {}
+
+ t = Timer(execSilx)
+ execTime["silx"] = t.timeit(BenchmarkMedianFilter.NB_ITER)
+ logger.info(
+ 'exec time silx (kernel size = %s) is %s' % (width, execTime["silx"]))
+
+ if scipy is not None:
+ t = Timer(execScipy)
+ execTime["scipy"] = t.timeit(BenchmarkMedianFilter.NB_ITER)
+ logger.info(
+ 'exec time scipy (kernel size = %s) is %s' % (width, execTime["scipy"]))
+ if pymca is not None:
+ t = Timer(execPymca)
+ execTime["pymca"] = t.timeit(BenchmarkMedianFilter.NB_ITER)
+ logger.info(
+ 'exec time pymca (kernel size = %s) is %s' % (width, execTime["pymca"]))
+
+ return execTime
+
+ def getExecTimeFor(self, id):
+ res = []
+ for k in self.kernels:
+ res.append(self.execTime[k][id])
+ return res
+
+
+app = qt.QApplication([])
+kernels = [3, 5, 7, 11, 15]
+benchmark = BenchmarkMedianFilter(imageWidth=1000, kernels=kernels)
+plot = Plot1D()
+plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("silx"), legend='silx')
+if scipy is not None:
+ plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("scipy"), legend='scipy')
+if pymca is not None:
+ plot.addCurve(x=kernels, y=benchmark.getExecTimeFor("pymca"), legend='pymca')
+plot.show()
+app.exec()
+del app
diff --git a/src/silx/math/medianfilter/test/test_medianfilter.py b/src/silx/math/medianfilter/test/test_medianfilter.py
new file mode 100644
index 0000000..a4e3021
--- /dev/null
+++ b/src/silx/math/medianfilter/test/test_medianfilter.py
@@ -0,0 +1,722 @@
+# coding: utf-8
+# ##########################################################################
+# Copyright (C) 2017-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################
+"""Tests of the median filter"""
+
+__authors__ = ["H. Payno"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+import unittest
+import numpy
+from silx.math.medianfilter import medfilt2d, medfilt1d
+from silx.math.medianfilter.medianfilter import reflect, mirror
+from silx.math.medianfilter.medianfilter import MODES as silx_mf_modes
+from silx.utils.testutils import ParametricTestCase
+try:
+ import scipy
+ import scipy.misc
+except:
+ scipy = None
+else:
+ import scipy.ndimage
+
+import logging
+_logger = logging.getLogger(__name__)
+
+RANDOM_FLOAT_MAT = numpy.array([
+ [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
+ [0.76532598, 0.02839148, 0.05272484, 0.65166994, 0.42161216],
+ [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.28773158],
+ [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
+ [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165]])
+
+RANDOM_INT_MAT = numpy.array([
+ [0, 5, 2, 6, 1],
+ [2, 3, 1, 7, 1],
+ [9, 8, 6, 7, 8],
+ [5, 6, 8, 2, 4]])
+
+
+class TestMedianFilterNearest(ParametricTestCase):
+ """Unit tests for the median filter in nearest mode"""
+
+ def testFilter3_100(self):
+ """Test median filter on a 10x10 matrix with a 3x3 kernel."""
+ dataIn = numpy.arange(100, dtype=numpy.int32)
+ dataIn = dataIn.reshape((10, 10))
+
+ dataOut = medfilt2d(image=dataIn,
+ kernel_size=(3, 3),
+ conditional=False,
+ mode='nearest')
+ self.assertTrue(dataOut[0, 0] == 1)
+ self.assertTrue(dataOut[9, 0] == 90)
+ self.assertTrue(dataOut[9, 9] == 98)
+
+ self.assertTrue(dataOut[0, 9] == 9)
+ self.assertTrue(dataOut[0, 4] == 5)
+ self.assertTrue(dataOut[9, 4] == 93)
+ self.assertTrue(dataOut[4, 4] == 44)
+
+ def testFilter3_9(self):
+ "Test median filter on a 3x3 matrix with a 3x3 kernel."
+ dataIn = numpy.array([0, -1, 1,
+ 12, 6, -2,
+ 100, 4, 12],
+ dtype=numpy.int16)
+ dataIn = dataIn.reshape((3, 3))
+ dataOut = medfilt2d(image=dataIn,
+ kernel_size=(3, 3),
+ conditional=False,
+ mode='nearest')
+ self.assertTrue(dataOut.shape == dataIn.shape)
+ self.assertTrue(dataOut[1, 1] == 4)
+ self.assertTrue(dataOut[0, 0] == 0)
+ self.assertTrue(dataOut[0, 1] == 0)
+ self.assertTrue(dataOut[1, 0] == 6)
+
+ def testFilterWidthOne(self):
+ """Make sure a filter of one by one give the same result as the input
+ """
+ dataIn = numpy.arange(100, dtype=numpy.int32)
+ dataIn = dataIn.reshape((10, 10))
+
+ dataOut = medfilt2d(image=dataIn,
+ kernel_size=(1, 1),
+ conditional=False,
+ mode='nearest')
+
+ self.assertTrue(numpy.array_equal(dataIn, dataOut))
+
+ def testFilter3_1d(self):
+ """Test binding and result of the 1d filter"""
+ self.assertTrue(numpy.array_equal(
+ medfilt1d(RANDOM_INT_MAT[0], kernel_size=3, conditional=False,
+ mode='nearest'),
+ [0, 2, 5, 2, 1])
+ )
+
+ def testFilter3Conditionnal(self):
+ """Test that the conditional filter apply correctly in a 10x10 matrix
+ with a 3x3 kernel
+ """
+ dataIn = numpy.arange(100, dtype=numpy.int32)
+ dataIn = dataIn.reshape((10, 10))
+
+ dataOut = medfilt2d(image=dataIn,
+ kernel_size=(3, 3),
+ conditional=True,
+ mode='nearest')
+ self.assertTrue(dataOut[0, 0] == 1)
+ self.assertTrue(dataOut[0, 1] == 1)
+ self.assertTrue(numpy.array_equal(dataOut[1:8, 1:8], dataIn[1:8, 1:8]))
+ self.assertTrue(dataOut[9, 9] == 98)
+
+ def testFilter3_1D(self):
+ """Simple test of a 3x3 median filter on a 1D array"""
+ dataIn = numpy.arange(100, dtype=numpy.int32)
+
+ dataOut = medfilt2d(image=dataIn,
+ kernel_size=(5),
+ conditional=False,
+ mode='nearest')
+
+ self.assertTrue(dataOut[0] == 0)
+ self.assertTrue(dataOut[9] == 9)
+ self.assertTrue(dataOut[99] == 99)
+
+ def testNaNs(self):
+ """Test median filter on image with NaNs in nearest mode"""
+ # Data with a NaN in first corner
+ nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner[0, 0] = numpy.nan
+ output = medfilt2d(
+ nan_corner, kernel_size=3, conditional=False, mode='nearest')
+ self.assertEqual(output[0, 0], 10)
+ self.assertEqual(output[0, 1], 2)
+ self.assertEqual(output[1, 0], 11)
+ self.assertEqual(output[1, 1], 12)
+
+ # Data with some NaNs
+ some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans[0, 1] = numpy.nan
+ some_nans[1, 1] = numpy.nan
+ some_nans[1, 0] = numpy.nan
+ output = medfilt2d(
+ some_nans, kernel_size=3, conditional=False, mode='nearest')
+ self.assertEqual(output[0, 0], 0)
+ self.assertEqual(output[0, 1], 2)
+ self.assertEqual(output[1, 0], 20)
+ self.assertEqual(output[1, 1], 20)
+
+
+class TestMedianFilterReflect(ParametricTestCase):
+ """Unit test for the median filter in reflect mode"""
+
+ def testArange9(self):
+ """Test from a 3x3 window to RANDOM_FLOAT_MAT"""
+ img = numpy.arange(9, dtype=numpy.int32)
+ img = img.reshape(3, 3)
+ kernel = (3, 3)
+ res = medfilt2d(image=img,
+ kernel_size=kernel,
+ conditional=False,
+ mode='reflect')
+ self.assertTrue(
+ numpy.array_equal(res.ravel(), [1, 2, 2, 3, 4, 5, 6, 6, 7]))
+
+ def testRandom10(self):
+ """Test a (5, 3) window to a RANDOM_FLOAT_MAT"""
+ kernel = (5, 3)
+
+ thRes = numpy.array([
+ [0.23067427, 0.56049024, 0.56049024, 0.4440632, 0.42161216],
+ [0.23067427, 0.62717157, 0.56049024, 0.56049024, 0.46372299],
+ [0.62717157, 0.62717157, 0.56049024, 0.56049024, 0.4440632],
+ [0.76532598, 0.68382382, 0.56049024, 0.56049024, 0.42161216],
+ [0.81025249, 0.68382382, 0.56049024, 0.68382382, 0.46372299]])
+
+ res = medfilt2d(image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=False,
+ mode='reflect')
+
+ self.assertTrue(numpy.array_equal(thRes, res))
+
+ def testApplyReflect1D(self):
+ """Test the reflect function used for the median filter in reflect mode
+ """
+ # test for inside values
+ self.assertTrue(reflect(2, 3) == 2)
+ # test for boundaries values
+ self.assertTrue(reflect(3, 3) == 2)
+ self.assertTrue(reflect(4, 3) == 1)
+ self.assertTrue(reflect(5, 3) == 0)
+ self.assertTrue(reflect(6, 3) == 0)
+ self.assertTrue(reflect(7, 3) == 1)
+ self.assertTrue(reflect(-1, 3) == 0)
+ self.assertTrue(reflect(-2, 3) == 1)
+ self.assertTrue(reflect(-3, 3) == 2)
+ self.assertTrue(reflect(-4, 3) == 2)
+ self.assertTrue(reflect(-5, 3) == 1)
+ self.assertTrue(reflect(-6, 3) == 0)
+ self.assertTrue(reflect(-7, 3) == 0)
+
+ def testRandom10Conditionnal(self):
+ """Test the median filter in reflect mode and with the conditionnal
+ option"""
+ kernel = (3, 1)
+
+ thRes = numpy.array([
+ [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
+ [0.23067427, 0.62717157, 0.56049024, 0.44406320, 0.42161216],
+ [0.76532598, 0.20303021, 0.56049024, 0.46372299, 0.42161216],
+ [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.33623165],
+ [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165]])
+
+ res = medfilt2d(image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=True,
+ mode='reflect')
+ self.assertTrue(numpy.array_equal(thRes, res))
+
+ def testNaNs(self):
+ """Test median filter on image with NaNs in reflect mode"""
+ # Data with a NaN in first corner
+ nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner[0, 0] = numpy.nan
+ output = medfilt2d(
+ nan_corner, kernel_size=3, conditional=False, mode='reflect')
+ self.assertEqual(output[0, 0], 10)
+ self.assertEqual(output[0, 1], 2)
+ self.assertEqual(output[1, 0], 11)
+ self.assertEqual(output[1, 1], 12)
+
+ # Data with some NaNs
+ some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans[0, 1] = numpy.nan
+ some_nans[1, 1] = numpy.nan
+ some_nans[1, 0] = numpy.nan
+ output = medfilt2d(
+ some_nans, kernel_size=3, conditional=False, mode='reflect')
+ self.assertEqual(output[0, 0], 0)
+ self.assertEqual(output[0, 1], 2)
+ self.assertEqual(output[1, 0], 20)
+ self.assertEqual(output[1, 1], 20)
+
+ def testFilter3_1d(self):
+ """Test binding and result of the 1d filter"""
+ self.assertTrue(numpy.array_equal(
+ medfilt1d(RANDOM_INT_MAT[0], kernel_size=5, conditional=False,
+ mode='reflect'),
+ [2, 2, 2, 2, 2])
+ )
+
+
+class TestMedianFilterMirror(ParametricTestCase):
+ """Unit test for the median filter in mirror mode
+ """
+
+ def testApplyMirror1D(self):
+ """Test the reflect function used for the median filter in mirror mode
+ """
+ # test for inside values
+ self.assertTrue(mirror(2, 3) == 2)
+ # test for boundaries values
+ self.assertTrue(mirror(4, 4) == 2)
+ self.assertTrue(mirror(5, 4) == 1)
+ self.assertTrue(mirror(6, 4) == 0)
+ self.assertTrue(mirror(7, 4) == 1)
+ self.assertTrue(mirror(8, 4) == 2)
+ self.assertTrue(mirror(-1, 4) == 1)
+ self.assertTrue(mirror(-2, 4) == 2)
+ self.assertTrue(mirror(-3, 4) == 3)
+ self.assertTrue(mirror(-4, 4) == 2)
+ self.assertTrue(mirror(-5, 4) == 1)
+ self.assertTrue(mirror(-6, 4) == 0)
+
+ def testRandom10(self):
+ """Test a (5, 3) window to a random array"""
+ kernel = (3, 5)
+
+ thRes = numpy.array([
+ [0.05272484, 0.40555336, 0.42161216, 0.42161216, 0.42161216],
+ [0.56049024, 0.56049024, 0.4440632, 0.4440632, 0.4440632],
+ [0.56049024, 0.46372299, 0.46372299, 0.46372299, 0.46372299],
+ [0.68382382, 0.56049024, 0.56049024, 0.46372299, 0.56049024],
+ [0.68382382, 0.46372299, 0.68382382, 0.46372299, 0.68382382]])
+
+ res = medfilt2d(image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=False,
+ mode='mirror')
+
+ self.assertTrue(numpy.array_equal(thRes, res))
+
+ def testRandom10Conditionnal(self):
+ """Test the median filter in reflect mode and with the conditionnal
+ option"""
+ kernel = (1, 3)
+
+ thRes = numpy.array([
+ [0.62717157, 0.62717157, 0.62717157, 0.70278975, 0.40555336],
+ [0.02839148, 0.05272484, 0.05272484, 0.42161216, 0.65166994],
+ [0.74219128, 0.56049024, 0.56049024, 0.44406320, 0.44406320],
+ [0.20303021, 0.68382382, 0.46372299, 0.68382382, 0.46372299],
+ [0.07813661, 0.81651256, 0.81651256, 0.81651256, 0.84220106]])
+
+ res = medfilt2d(image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=True,
+ mode='mirror')
+
+ self.assertTrue(numpy.array_equal(thRes, res))
+
+ def testNaNs(self):
+ """Test median filter on image with NaNs in mirror mode"""
+ # Data with a NaN in first corner
+ nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner[0, 0] = numpy.nan
+ output = medfilt2d(
+ nan_corner, kernel_size=3, conditional=False, mode='mirror')
+ self.assertEqual(output[0, 0], 11)
+ self.assertEqual(output[0, 1], 11)
+ self.assertEqual(output[1, 0], 11)
+ self.assertEqual(output[1, 1], 12)
+
+ # Data with some NaNs
+ some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans[0, 1] = numpy.nan
+ some_nans[1, 1] = numpy.nan
+ some_nans[1, 0] = numpy.nan
+ output = medfilt2d(
+ some_nans, kernel_size=3, conditional=False, mode='mirror')
+ self.assertEqual(output[0, 0], 0)
+ self.assertEqual(output[0, 1], 12)
+ self.assertEqual(output[1, 0], 21)
+ self.assertEqual(output[1, 1], 20)
+
+ def testFilter3_1d(self):
+ """Test binding and result of the 1d filter"""
+ self.assertTrue(numpy.array_equal(
+ medfilt1d(RANDOM_INT_MAT[0], kernel_size=5, conditional=False,
+ mode='mirror'),
+ [2, 5, 2, 5, 2])
+ )
+
+class TestMedianFilterShrink(ParametricTestCase):
+ """Unit test for the median filter in mirror mode
+ """
+
+ def testRandom_3x3(self):
+ """Test the median filter in shrink mode and with the conditionnal
+ option"""
+ kernel = (3, 3)
+
+ thRes = numpy.array([
+ [0.62717157, 0.62717157, 0.62717157, 0.65166994, 0.65166994],
+ [0.62717157, 0.56049024, 0.56049024, 0.44406320, 0.44406320],
+ [0.74219128, 0.56049024, 0.46372299, 0.46372299, 0.46372299],
+ [0.74219128, 0.68382382, 0.56049024, 0.56049024, 0.46372299],
+ [0.81025249, 0.81025249, 0.68382382, 0.81281709, 0.81281709]])
+
+ res = medfilt2d(image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=False,
+ mode='shrink')
+
+ self.assertTrue(numpy.array_equal(thRes, res))
+
+ def testBounds(self):
+ """Test the median filter in shrink mode with 3 different kernels
+ which should return the same result due to the large values of kernels
+ used.
+ """
+ kernel1 = (1, 9)
+ kernel2 = (1, 11)
+ kernel3 = (1, 21)
+
+ thRes = numpy.array([[2, 2, 2, 2, 2],
+ [2, 2, 2, 2, 2],
+ [8, 8, 8, 8, 8],
+ [5, 5, 5, 5, 5]])
+
+ resK1 = medfilt2d(image=RANDOM_INT_MAT,
+ kernel_size=kernel1,
+ conditional=False,
+ mode='shrink')
+
+ resK2 = medfilt2d(image=RANDOM_INT_MAT,
+ kernel_size=kernel2,
+ conditional=False,
+ mode='shrink')
+
+ resK3 = medfilt2d(image=RANDOM_INT_MAT,
+ kernel_size=kernel3,
+ conditional=False,
+ mode='shrink')
+
+ self.assertTrue(numpy.array_equal(resK1, thRes))
+ self.assertTrue(numpy.array_equal(resK2, resK1))
+ self.assertTrue(numpy.array_equal(resK3, resK1))
+
+ def testRandom_3x3Conditionnal(self):
+ """Test the median filter in reflect mode and with the conditionnal
+ option"""
+ kernel = (3, 3)
+
+ thRes = numpy.array([
+ [0.05564293, 0.62717157, 0.62717157, 0.40555336, 0.65166994],
+ [0.62717157, 0.56049024, 0.05272484, 0.65166994, 0.42161216],
+ [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.46372299],
+ [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
+ [0.81025249, 0.81025249, 0.81651256, 0.81281709, 0.81281709]])
+
+ res = medfilt2d(image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=True,
+ mode='shrink')
+
+ self.assertTrue(numpy.array_equal(res, thRes))
+
+ def testRandomInt(self):
+ """Test 3x3 kernel on RANDOM_INT_MAT
+ """
+ kernel = (3, 3)
+
+ thRes = numpy.array([[3, 2, 5, 2, 6],
+ [5, 3, 6, 6, 7],
+ [6, 6, 6, 6, 7],
+ [8, 8, 7, 7, 7]])
+
+ resK1 = medfilt2d(image=RANDOM_INT_MAT,
+ kernel_size=kernel,
+ conditional=False,
+ mode='shrink')
+
+ self.assertTrue(numpy.array_equal(resK1, thRes))
+
+ def testNaNs(self):
+ """Test median filter on image with NaNs in shrink mode"""
+ # Data with a NaN in first corner
+ nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner[0, 0] = numpy.nan
+ output = medfilt2d(
+ nan_corner, kernel_size=3, conditional=False, mode='shrink')
+ self.assertEqual(output[0, 0], 10)
+ self.assertEqual(output[0, 1], 10)
+ self.assertEqual(output[1, 0], 11)
+ self.assertEqual(output[1, 1], 12)
+
+ # Data with some NaNs
+ some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans[0, 1] = numpy.nan
+ some_nans[1, 1] = numpy.nan
+ some_nans[1, 0] = numpy.nan
+ output = medfilt2d(
+ some_nans, kernel_size=3, conditional=False, mode='shrink')
+ self.assertEqual(output[0, 0], 0)
+ self.assertEqual(output[0, 1], 2)
+ self.assertEqual(output[1, 0], 20)
+ self.assertEqual(output[1, 1], 20)
+
+ def testFilter3_1d(self):
+ """Test binding and result of the 1d filter"""
+ self.assertTrue(numpy.array_equal(
+ medfilt1d(RANDOM_INT_MAT[0], kernel_size=3, conditional=False,
+ mode='shrink'),
+ [5, 2, 5, 2, 6])
+ )
+
+class TestMedianFilterConstant(ParametricTestCase):
+ """Unit test for the median filter in constant mode
+ """
+
+ def testRandom10(self):
+ """Test a (5, 3) window to a random array"""
+ kernel = (3, 5)
+
+ thRes = numpy.array([
+ [0., 0.02839148, 0.05564293, 0.02839148, 0.],
+ [0.05272484, 0.40555336, 0.4440632, 0.42161216, 0.28773158],
+ [0.05272484, 0.44406320, 0.46372299, 0.42161216, 0.28773158],
+ [0.20303021, 0.46372299, 0.56049024, 0.44406320, 0.33623165],
+ [0., 0.07813661, 0.33623165, 0.07813661, 0.]])
+
+ res = medfilt2d(image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=False,
+ mode='constant')
+
+ self.assertTrue(numpy.array_equal(thRes, res))
+
+ RANDOM_FLOAT_MAT = numpy.array([
+ [0.05564293, 0.62717157, 0.75002406, 0.40555336, 0.70278975],
+ [0.76532598, 0.02839148, 0.05272484, 0.65166994, 0.42161216],
+ [0.23067427, 0.74219128, 0.56049024, 0.44406320, 0.28773158],
+ [0.81025249, 0.20303021, 0.68382382, 0.46372299, 0.81281709],
+ [0.94691602, 0.07813661, 0.81651256, 0.84220106, 0.33623165]])
+
+ def testRandom10Conditionnal(self):
+ """Test the median filter in reflect mode and with the conditionnal
+ option"""
+ kernel = (1, 3)
+
+ print(RANDOM_FLOAT_MAT)
+
+ thRes = numpy.array([
+ [0.05564293, 0.62717157, 0.62717157, 0.70278975, 0.40555336],
+ [0.02839148, 0.05272484, 0.05272484, 0.42161216, 0.42161216],
+ [0.23067427, 0.56049024, 0.56049024, 0.44406320, 0.28773158],
+ [0.20303021, 0.68382382, 0.46372299, 0.68382382, 0.46372299],
+ [0.07813661, 0.81651256, 0.81651256, 0.81651256, 0.33623165]])
+
+ res = medfilt2d(image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=True,
+ mode='constant')
+
+ self.assertTrue(numpy.array_equal(thRes, res))
+
+ def testNaNs(self):
+ """Test median filter on image with NaNs in constant mode"""
+ # Data with a NaN in first corner
+ nan_corner = numpy.arange(100.).reshape(10, 10)
+ nan_corner[0, 0] = numpy.nan
+ output = medfilt2d(nan_corner,
+ kernel_size=3,
+ conditional=False,
+ mode='constant',
+ cval=0)
+ self.assertEqual(output[0, 0], 0)
+ self.assertEqual(output[0, 1], 2)
+ self.assertEqual(output[1, 0], 10)
+ self.assertEqual(output[1, 1], 12)
+
+ # Data with some NaNs
+ some_nans = numpy.arange(100.).reshape(10, 10)
+ some_nans[0, 1] = numpy.nan
+ some_nans[1, 1] = numpy.nan
+ some_nans[1, 0] = numpy.nan
+ output = medfilt2d(some_nans,
+ kernel_size=3,
+ conditional=False,
+ mode='constant',
+ cval=0)
+ self.assertEqual(output[0, 0], 0)
+ self.assertEqual(output[0, 1], 0)
+ self.assertEqual(output[1, 0], 0)
+ self.assertEqual(output[1, 1], 20)
+
+ def testFilter3_1d(self):
+ """Test binding and result of the 1d filter"""
+ self.assertTrue(numpy.array_equal(
+ medfilt1d(RANDOM_INT_MAT[0], kernel_size=5, conditional=False,
+ mode='constant'),
+ [0, 2, 2, 2, 1])
+ )
+
+class TestGeneralExecution(ParametricTestCase):
+ """Some general test on median filter application"""
+
+ def testTypes(self):
+ """Test that all needed types have their implementation of the median
+ filter
+ """
+ for mode in silx_mf_modes:
+ for testType in [numpy.float32, numpy.float64, numpy.int16,
+ numpy.uint16, numpy.int32, numpy.int64,
+ numpy.uint64]:
+ with self.subTest(mode=mode, type=testType):
+ data = (numpy.random.rand(10, 10) * 65000).astype(testType)
+ out = medfilt2d(image=data,
+ kernel_size=(3, 3),
+ conditional=False,
+ mode=mode)
+ self.assertTrue(out.dtype.type is testType)
+
+ def testInputDataIsNotModify(self):
+ """Make sure input data is not modify by the median filter"""
+ dataIn = numpy.arange(100, dtype=numpy.int32)
+ dataIn = dataIn.reshape((10, 10))
+ dataInCopy = dataIn.copy()
+
+ for mode in silx_mf_modes:
+ with self.subTest(mode=mode):
+ medfilt2d(image=dataIn,
+ kernel_size=(3, 3),
+ conditional=False,
+ mode=mode)
+ self.assertTrue(numpy.array_equal(dataIn, dataInCopy))
+
+ def testAllNaNs(self):
+ """Test median filter on image all NaNs"""
+ all_nans = numpy.empty((10, 10), dtype=numpy.float32)
+ all_nans[:] = numpy.nan
+
+ for mode in silx_mf_modes:
+ for conditional in (True, False):
+ with self.subTest(mode=mode, conditional=conditional):
+ output = medfilt2d(
+ all_nans,
+ kernel_size=3,
+ conditional=conditional,
+ mode=mode,
+ cval=numpy.nan)
+ self.assertTrue(numpy.all(numpy.isnan(output)))
+
+ def testConditionalWithNaNs(self):
+ """Test that NaNs are propagated through conditional median filter"""
+ for mode in silx_mf_modes:
+ with self.subTest(mode=mode):
+ image = numpy.ones((10, 10), dtype=numpy.float32)
+ nan_mask = numpy.zeros_like(image, dtype=bool)
+ nan_mask[0, 0] = True
+ nan_mask[4, :] = True
+ nan_mask[6, 4] = True
+ image[nan_mask] = numpy.nan
+ output = medfilt2d(
+ image,
+ kernel_size=3,
+ conditional=True,
+ mode=mode)
+ out_isnan = numpy.isnan(output)
+ self.assertTrue(numpy.all(out_isnan[nan_mask]))
+ self.assertFalse(
+ numpy.any(out_isnan[numpy.logical_not(nan_mask)]))
+
+
+def _getScipyAndSilxCommonModes():
+ """return the mode which are comparable between silx and scipy"""
+ modes = silx_mf_modes.copy()
+ del modes['shrink']
+ return modes
+
+
+@unittest.skipUnless(scipy is not None, "scipy not available")
+class TestVsScipy(ParametricTestCase):
+ """Compare scipy.ndimage.median_filter vs silx.math.medianfilter
+ on comparable
+ """
+ def testWithArange(self):
+ """Test vs scipy with different kernels on arange matrix"""
+ data = numpy.arange(10000, dtype=numpy.int32)
+ data = data.reshape(100, 100)
+
+ kernels = [(3, 7), (7, 5), (1, 1), (3, 3)]
+ modesToTest = _getScipyAndSilxCommonModes()
+ for kernel in kernels:
+ for mode in modesToTest:
+ with self.subTest(kernel=kernel, mode=mode):
+ resScipy = scipy.ndimage.median_filter(input=data,
+ size=kernel,
+ mode=mode)
+ resSilx = medfilt2d(image=data,
+ kernel_size=kernel,
+ conditional=False,
+ mode=mode)
+
+ self.assertTrue(numpy.array_equal(resScipy, resSilx))
+
+ def testRandomMatrice(self):
+ """Test vs scipy with different kernels on RANDOM_FLOAT_MAT"""
+ kernels = [(3, 7), (7, 5), (1, 1), (3, 3)]
+ modesToTest = _getScipyAndSilxCommonModes()
+ for kernel in kernels:
+ for mode in modesToTest:
+ with self.subTest(kernel=kernel, mode=mode):
+ resScipy = scipy.ndimage.median_filter(input=RANDOM_FLOAT_MAT,
+ size=kernel,
+ mode=mode)
+
+ resSilx = medfilt2d(image=RANDOM_FLOAT_MAT,
+ kernel_size=kernel,
+ conditional=False,
+ mode=mode)
+
+ self.assertTrue(numpy.array_equal(resScipy, resSilx))
+
+ def testAscentOrLena(self):
+ """Test vs scipy with """
+ if hasattr(scipy.misc, 'ascent'):
+ img = scipy.misc.ascent()
+ else:
+ img = scipy.misc.lena()
+
+ kernels = [(3, 1), (3, 5), (5, 9), (9, 3)]
+ modesToTest = _getScipyAndSilxCommonModes()
+
+ for kernel in kernels:
+ for mode in modesToTest:
+ with self.subTest(kernel=kernel, mode=mode):
+ resScipy = scipy.ndimage.median_filter(input=img,
+ size=kernel,
+ mode=mode)
+
+ resSilx = medfilt2d(image=img,
+ kernel_size=kernel,
+ conditional=False,
+ mode=mode)
+
+ self.assertTrue(numpy.array_equal(resScipy, resSilx))
diff --git a/src/silx/math/setup.py b/src/silx/math/setup.py
new file mode 100644
index 0000000..1c30e6e
--- /dev/null
+++ b/src/silx/math/setup.py
@@ -0,0 +1,99 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+
+__authors__ = ["D. Naudet"]
+__license__ = "MIT"
+__date__ = "27/03/2017"
+
+import os.path
+
+import numpy
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('math', parent_package, top_path)
+ config.add_subpackage('test')
+ config.add_subpackage('fit')
+ config.add_subpackage('medianfilter')
+ config.add_subpackage('fft')
+
+ # =====================================
+ # histogramnd
+ # =====================================
+ histo_src = [os.path.join('histogramnd', 'src', 'histogramnd_c.c'),
+ 'chistogramnd.pyx']
+ histo_inc = [os.path.join('histogramnd', 'include'),
+ numpy.get_include()]
+
+ config.add_extension('chistogramnd',
+ sources=histo_src,
+ include_dirs=histo_inc,
+ language='c')
+
+ # =====================================
+ # histogramnd_lut
+ # =====================================
+ config.add_extension('chistogramnd_lut',
+ sources=['chistogramnd_lut.pyx'],
+ include_dirs=histo_inc,
+ language='c')
+ # =====================================
+ # marching cubes
+ # =====================================
+ mc_src = [os.path.join('marchingcubes', 'mc_lut.cpp'),
+ 'marchingcubes.pyx']
+ config.add_extension('marchingcubes',
+ sources=mc_src,
+ include_dirs=['marchingcubes', numpy.get_include()],
+ language='c++')
+
+ # min/max
+ config.add_extension('combo',
+ sources=['combo.pyx'],
+ include_dirs=['include'],
+ language='c')
+
+ config.add_extension('_colormap',
+ sources=["_colormap.pyx"],
+ language='c',
+ include_dirs=['include', numpy.get_include()],
+ extra_link_args=['-fopenmp'],
+ extra_compile_args=['-fopenmp'])
+
+ config.add_extension('interpolate',
+ sources=["interpolate.pyx"],
+ language='c',
+ include_dirs=['include', numpy.get_include()],
+ extra_link_args=['-fopenmp'],
+ extra_compile_args=['-fopenmp'])
+
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/src/silx/math/test/__init__.py b/src/silx/math/test/__init__.py
new file mode 100644
index 0000000..ad9836c
--- /dev/null
+++ b/src/silx/math/test/__init__.py
@@ -0,0 +1,23 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
diff --git a/src/silx/math/test/benchmark_combo.py b/src/silx/math/test/benchmark_combo.py
new file mode 100644
index 0000000..c12f590
--- /dev/null
+++ b/src/silx/math/test/benchmark_combo.py
@@ -0,0 +1,192 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Benchmarks of the combo module"""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import logging
+import os.path
+import time
+import unittest
+
+import numpy
+
+from silx.test.utils import temp_dir
+from silx.utils.testutils import ParametricTestCase
+
+from silx.math import combo
+
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.DEBUG)
+
+
+class TestBenchmarkMinMax(ParametricTestCase):
+ """Benchmark of min max combo"""
+
+ DTYPES = ('float32', 'float64',
+ 'int8', 'int16', 'int32', 'int64',
+ 'uint8', 'uint16', 'uint32', 'uint64')
+
+ ARANGE = 'ascent', 'descent', 'random'
+
+ EXPONENT = 3, 4, 5, 6, 7
+
+ def test_benchmark_min_max(self):
+ """Benchmark min_max without min positive.
+
+ Compares with:
+
+ - numpy.nanmin, numpy.nanmax and
+ - numpy.argmin, numpy.argmax
+
+ It runs bench for different types, different data size and 3
+ data sets: increasing , decreasing and random data.
+ """
+ durations = {'min/max': [], 'argmin/max': [], 'combo': []}
+
+ _logger.info('Benchmark against argmin/argmax and nanmin/nanmax')
+
+ for dtype in self.DTYPES:
+ for arange in self.ARANGE:
+ for exponent in self.EXPONENT:
+ size = 10**exponent
+ with self.subTest(dtype=dtype, size=size, arange=arange):
+ if arange == 'ascent':
+ data = numpy.arange(0, size, 1, dtype=dtype)
+ elif arange == 'descent':
+ data = numpy.arange(size, 0, -1, dtype=dtype)
+ else:
+ if dtype in ('float32', 'float64'):
+ data = numpy.random.random(size)
+ else:
+ data = numpy.random.randint(10**6, size=size)
+ data = numpy.array(data, dtype=dtype)
+
+ start = time.time()
+ ref_min = numpy.nanmin(data)
+ ref_max = numpy.nanmax(data)
+ durations['min/max'].append(time.time() - start)
+
+ start = time.time()
+ ref_argmin = numpy.argmin(data)
+ ref_argmax = numpy.argmax(data)
+ durations['argmin/max'].append(time.time() - start)
+
+ start = time.time()
+ result = combo.min_max(data, min_positive=False)
+ durations['combo'].append(time.time() - start)
+
+ _logger.info(
+ '%s-%s-10**%d\tx%.2f argmin/max x%.2f min/max',
+ dtype, arange, exponent,
+ durations['argmin/max'][-1] / durations['combo'][-1],
+ durations['min/max'][-1] / durations['combo'][-1])
+
+ self.assertEqual(result.minimum, ref_min)
+ self.assertEqual(result.maximum, ref_max)
+ self.assertEqual(result.argmin, ref_argmin)
+ self.assertEqual(result.argmax, ref_argmax)
+
+ self.show_results('min/max', durations, 'combo')
+
+ def test_benchmark_min_pos(self):
+ """Benchmark min_max wit min positive.
+
+ Compares with:
+
+ - numpy.nanmin(data[data > 0]); numpy.nanmin(pos); numpy.nanmax(pos)
+
+ It runs bench for different types, different data size and 3
+ data sets: increasing , decreasing and random data.
+ """
+ durations = {'min/max': [], 'combo': []}
+
+ _logger.info('Benchmark against min, max, positive min')
+
+ for dtype in self.DTYPES:
+ for arange in self.ARANGE:
+ for exponent in self.EXPONENT:
+ size = 10**exponent
+ with self.subTest(dtype=dtype, size=size, arange=arange):
+ if arange == 'ascent':
+ data = numpy.arange(0, size, 1, dtype=dtype)
+ elif arange == 'descent':
+ data = numpy.arange(size, 0, -1, dtype=dtype)
+ else:
+ if dtype in ('float32', 'float64'):
+ data = numpy.random.random(size)
+ else:
+ data = numpy.random.randint(10**6, size=size)
+ data = numpy.array(data, dtype=dtype)
+
+ start = time.time()
+ ref_min_positive = numpy.nanmin(data[data > 0])
+ ref_min = numpy.nanmin(data)
+ ref_max = numpy.nanmax(data)
+ durations['min/max'].append(time.time() - start)
+
+ start = time.time()
+ result = combo.min_max(data, min_positive=True)
+ durations['combo'].append(time.time() - start)
+
+ _logger.info(
+ '%s-%s-10**%d\tx%.2f min/minpos/max',
+ dtype, arange, exponent,
+ durations['min/max'][-1] / durations['combo'][-1])
+
+ self.assertEqual(result.min_positive, ref_min_positive)
+ self.assertEqual(result.minimum, ref_min)
+ self.assertEqual(result.maximum, ref_max)
+
+ self.show_results('min/max/min positive', durations, 'combo')
+
+ def show_results(self, title, durations, ref_key):
+ try:
+ from matplotlib import pyplot
+ except ImportError:
+ _logger.warning('matplotlib not available')
+ return
+
+ pyplot.title(title)
+ pyplot.xlabel('-'.join(self.DTYPES))
+ pyplot.ylabel('duration (sec)')
+ for label, values in durations.items():
+ pyplot.semilogy(values, label=label)
+ pyplot.legend()
+ pyplot.show()
+
+ pyplot.title(title)
+ pyplot.xlabel('-'.join(self.DTYPES))
+ pyplot.ylabel('Duration ratio')
+ ref = numpy.array(durations[ref_key])
+ for label, values in durations.items():
+ values = numpy.array(values)
+ pyplot.plot(values/ref, label=label + ' / ' + ref_key)
+ pyplot.legend()
+ pyplot.show()
diff --git a/silx/math/test/histo_benchmarks.py b/src/silx/math/test/histo_benchmarks.py
index 7d3216d..7d3216d 100644
--- a/silx/math/test/histo_benchmarks.py
+++ b/src/silx/math/test/histo_benchmarks.py
diff --git a/src/silx/math/test/test_HistogramndLut_nominal.py b/src/silx/math/test/test_HistogramndLut_nominal.py
new file mode 100644
index 0000000..52e003c
--- /dev/null
+++ b/src/silx/math/test/test_HistogramndLut_nominal.py
@@ -0,0 +1,571 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+Nominal tests of the HistogramndLut function.
+"""
+
+import unittest
+
+import numpy as np
+
+from silx.math import HistogramndLut
+
+
+def _get_bin_edges(histo_range, n_bins, n_dims):
+ edges = []
+ for i_dim in range(n_dims):
+ edges.append(histo_range[i_dim, 0] +
+ np.arange(n_bins[i_dim] + 1) *
+ (histo_range[i_dim, 1] - histo_range[i_dim, 0]) /
+ n_bins[i_dim])
+ return tuple(edges)
+
+
+# ==============================================================
+# ==============================================================
+# ==============================================================
+
+
+class _TestHistogramndLut_nominal(unittest.TestCase):
+ """
+ Unit tests of the HistogramndLut class.
+ """
+ __test__ = False # ignore abstract class
+
+ ndims = None
+
+ def setUp(self):
+ ndims = self.ndims
+ if ndims is None:
+ self.skipTest("Abstract class")
+ self.tested_dim = ndims-1
+
+ if ndims is None:
+ raise ValueError('ndims class member not set.')
+
+ sample = np.array([5.5, -3.3,
+ 0., -0.5,
+ 3.3, 8.8,
+ -7.7, 6.0,
+ -4.0])
+
+ weights = np.array([500.5, -300.3,
+ 0.01, -0.5,
+ 300.3, 800.8,
+ -700.7, 600.6,
+ -400.4])
+
+ n_elems = len(sample)
+
+ if ndims == 1:
+ shape = (n_elems,)
+ else:
+ shape = (n_elems, ndims)
+
+ self.sample = np.zeros(shape=shape, dtype=sample.dtype)
+ if ndims == 1:
+ self.sample = sample
+ else:
+ self.sample[..., ndims-1] = sample
+
+ self.weights = weights
+
+ # the tests are performed along one dimension,
+ # all the other bins indices along the other dimensions
+ # are expected to be 2
+ # (e.g : when testing a 2D sample : [0, x] will go into
+ # bin [2, y] because of the bin ranges [-2, 2] and n_bins = 4
+ # for the first dimension)
+ self.other_axes_index = 2
+ self.histo_range = np.repeat([[-2., 2.]], ndims, axis=0)
+ self.histo_range[ndims-1] = [-4., 6.]
+
+ self.n_bins = np.array([4]*ndims)
+ self.n_bins[ndims-1] = 5
+
+ if ndims == 1:
+ def fill_histo(h, v, dim, op=None):
+ if op:
+ h[:] = op(h[:], v)
+ else:
+ h[:] = v
+ self.fill_histo = fill_histo
+ else:
+ def fill_histo(h, v, dim, op=None):
+ idx = [self.other_axes_index]*len(h.shape)
+ idx[dim] = slice(0, None)
+ idx = tuple(idx)
+ if op:
+ h[idx] = op(h[idx], v)
+ else:
+ h[idx] = v
+ self.fill_histo = fill_histo
+
+ def test_nominal_bin_edges(self):
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ bin_edges = instance.bins_edges
+
+ expected_edges = _get_bin_edges(self.histo_range,
+ self.n_bins,
+ self.ndims)
+
+ for i_edges, edges in enumerate(expected_edges):
+ self.assertTrue(np.array_equal(bin_edges[i_edges],
+ expected_edges[i_edges]),
+ msg='Testing bin_edges for dim {0}'
+ ''.format(i_edges+1))
+
+ def test_nominal_histo_range(self):
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ histo_range = instance.histo_range
+
+ self.assertTrue(np.array_equal(histo_range, self.histo_range))
+
+ def test_nominal_last_bin_closed(self):
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ last_bin_closed = instance.last_bin_closed
+
+ self.assertEqual(last_bin_closed, False)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins,
+ last_bin_closed=True)
+
+ last_bin_closed = instance.last_bin_closed
+
+ self.assertEqual(last_bin_closed, True)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins,
+ last_bin_closed=False)
+
+ last_bin_closed = instance.last_bin_closed
+
+ self.assertEqual(last_bin_closed, False)
+
+ def test_nominal_n_bins_array(self):
+
+ test_n_bins = np.arange(self.ndims) + 10
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ test_n_bins)
+
+ n_bins = instance.n_bins
+
+ self.assertTrue(np.array_equal(test_n_bins, n_bins))
+
+ def test_nominal_n_bins_scalar(self):
+
+ test_n_bins = 10
+ expected_n_bins = np.array([test_n_bins] * self.ndims)
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ test_n_bins)
+
+ n_bins = instance.n_bins
+
+ self.assertTrue(np.array_equal(expected_n_bins, n_bins))
+
+ def test_nominal_histo_ref(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ instance.accumulate(self.weights)
+
+ histo = instance.histo()
+ w_histo = instance.weighted_histo()
+ histo_ref = instance.histo(copy=False)
+ w_histo_ref = instance.weighted_histo(copy=False)
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+ self.assertTrue(np.array_equal(histo_ref, expected_h))
+ self.assertTrue(np.array_equal(w_histo_ref, expected_c))
+
+ histo_ref[0, ...] = histo_ref[0, ...] + 10
+ w_histo_ref[0, ...] = w_histo_ref[0, ...] + 20
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+ self.assertFalse(np.array_equal(histo_ref, expected_h))
+ self.assertFalse(np.array_equal(w_histo_ref, expected_c))
+
+ histo_2 = instance.histo()
+ w_histo_2 = instance.weighted_histo()
+
+ self.assertFalse(np.array_equal(histo_2, expected_h))
+ self.assertFalse(np.array_equal(w_histo_2, expected_c))
+ self.assertTrue(np.array_equal(histo_2, histo_ref))
+ self.assertTrue(np.array_equal(w_histo_2, w_histo_ref))
+
+ def test_nominal_accumulate_once(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ instance.accumulate(self.weights)
+
+ histo = instance.histo()
+ w_histo = instance.weighted_histo()
+
+ self.assertEqual(w_histo.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+ self.assertTrue(np.array_equal(instance.histo(), expected_h))
+ self.assertTrue(np.array_equal(instance.weighted_histo(),
+ expected_c))
+
+ def test_nominal_accumulate_twice(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ # calling accumulate twice
+ expected_h *= 2
+ expected_c *= 2
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ instance.accumulate(self.weights)
+
+ instance.accumulate(self.weights)
+
+ histo = instance.histo()
+ w_histo = instance.weighted_histo()
+
+ self.assertEqual(w_histo.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+ self.assertTrue(np.array_equal(instance.histo(), expected_h))
+ self.assertTrue(np.array_equal(instance.weighted_histo(),
+ expected_c))
+
+ def test_nominal_apply_lut_once(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ histo, w_histo = instance.apply_lut(self.weights)
+
+ self.assertEqual(w_histo.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+ self.assertEqual(instance.histo(), None)
+ self.assertEqual(instance.weighted_histo(), None)
+
+ def test_nominal_apply_lut_twice(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ # calling apply_lut twice
+ expected_h *= 2
+ expected_c *= 2
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ histo, w_histo = instance.apply_lut(self.weights)
+ histo_2, w_histo_2 = instance.apply_lut(self.weights,
+ histo=histo,
+ weighted_histo=w_histo)
+
+ self.assertEqual(id(histo), id(histo_2))
+ self.assertEqual(id(w_histo), id(w_histo_2))
+ self.assertEqual(w_histo.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+ self.assertEqual(instance.histo(), None)
+ self.assertEqual(instance.weighted_histo(), None)
+
+ def test_nominal_accumulate_last_bin_closed(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 2])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 1101.1])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins,
+ last_bin_closed=True)
+
+ instance.accumulate(self.weights)
+
+ histo = instance.histo()
+ w_histo = instance.weighted_histo()
+
+ self.assertEqual(w_histo.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+
+ def test_nominal_accumulate_weight_min_max(self):
+ """
+ """
+ weight_min = -299.9
+ weight_max = 499.9
+
+ expected_h_tpl = np.array([0, 1, 1, 1, 0])
+ expected_c_tpl = np.array([0., -0.5, 0.01, 300.3, 0.])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ instance.accumulate(self.weights,
+ weight_min=weight_min,
+ weight_max=weight_max)
+
+ histo = instance.histo()
+ w_histo = instance.weighted_histo()
+
+ self.assertEqual(w_histo.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+
+ def test_nominal_accumulate_forced_int32(self):
+ """
+ double weights, int32 weighted_histogram
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700, 0, 0, 300, 500])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins,
+ dtype=np.int32)
+
+ instance.accumulate(self.weights)
+
+ histo = instance.histo()
+ w_histo = instance.weighted_histo()
+
+ self.assertEqual(w_histo.dtype, np.int32)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+
+ def test_nominal_accumulate_forced_float32(self):
+ """
+ int32 weights, float32 weighted_histogram
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700., 0., 0., 300., 500.])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.float32)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins,
+ dtype=np.float32)
+
+ instance.accumulate(self.weights.astype(np.int32))
+
+ histo = instance.histo()
+ w_histo = instance.weighted_histo()
+
+ self.assertEqual(w_histo.dtype, np.float32)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+
+ def test_nominal_accumulate_int32(self):
+ """
+ int32 weights
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700, 0, 0, 300, 500])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.int32)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ instance.accumulate(self.weights.astype(np.int32))
+
+ histo = instance.histo()
+ w_histo = instance.weighted_histo()
+
+ self.assertEqual(w_histo.dtype, np.int32)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+
+ def test_nominal_accumulate_int32_double(self):
+ """
+ int32 weights
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700, 0, 0, 300, 500])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.int32)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ instance = HistogramndLut(self.sample,
+ self.histo_range,
+ self.n_bins)
+
+ instance.accumulate(self.weights.astype(np.int32))
+ instance.accumulate(self.weights)
+
+ histo = instance.histo()
+ w_histo = instance.weighted_histo()
+
+ expected_h *= 2
+ expected_c *= 2
+
+ self.assertEqual(w_histo.dtype, np.int32)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(w_histo, expected_c))
+
+ def testNoneNativeTypes(self):
+ type = self.sample.dtype.newbyteorder("B")
+ sampleB = self.sample.astype(type)
+
+ type = self.sample.dtype.newbyteorder("L")
+ sampleL = self.sample.astype(type)
+
+ histo_inst = HistogramndLut(sampleB,
+ self.histo_range,
+ self.n_bins)
+
+ histo_inst = HistogramndLut(sampleL,
+ self.histo_range,
+ self.n_bins)
+
+
+class TestHistogramndLut_nominal_1d(_TestHistogramndLut_nominal):
+ __test__ = True # because _TestHistogramndLut_nominal is ignored
+ ndims = 1
+
+
+class TestHistogramndLut_nominal_2d(_TestHistogramndLut_nominal):
+ __test__ = True # because _TestHistogramndLut_nominal is ignored
+ ndims = 2
+
+
+class TestHistogramndLut_nominal_3d(_TestHistogramndLut_nominal):
+ __test__ = True # because _TestHistogramndLut_nominal is ignored
+ ndims = 3
diff --git a/src/silx/math/test/test_calibration.py b/src/silx/math/test/test_calibration.py
new file mode 100644
index 0000000..7158293
--- /dev/null
+++ b/src/silx/math/test/test_calibration.py
@@ -0,0 +1,145 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests of the calibration module"""
+
+from __future__ import division
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "14/05/2018"
+
+
+import unittest
+
+import numpy
+
+from silx.math.calibration import NoCalibration, LinearCalibration, \
+ ArrayCalibration, FunctionCalibration
+
+
+X = numpy.array([3.14, 2.73, 1337])
+
+
+class TestNoCalibration(unittest.TestCase):
+ def setUp(self):
+ self.calib = NoCalibration()
+
+ def testIsAffine(self):
+ self.assertTrue(self.calib.is_affine())
+
+ def testSlope(self):
+ self.assertEqual(self.calib.get_slope(), 1.)
+
+ def testYIntercept(self):
+ self.assertEqual(self.calib(0.),
+ 0.)
+
+ def testCall(self):
+ self.assertTrue(numpy.array_equal(self.calib(X), X))
+
+
+class TestLinearCalibration(unittest.TestCase):
+ def setUp(self):
+ self.y_intercept = 1.5
+ self.slope = 2.5
+ self.calib = LinearCalibration(y_intercept=self.y_intercept,
+ slope=self.slope)
+
+ def testIsAffine(self):
+ self.assertTrue(self.calib.is_affine())
+
+ def testSlope(self):
+ self.assertEqual(self.calib.get_slope(), self.slope)
+
+ def testYIntercept(self):
+ self.assertEqual(self.calib(0.),
+ self.y_intercept)
+
+ def testCall(self):
+ self.assertTrue(numpy.array_equal(self.calib(X),
+ self.y_intercept + self.slope * X))
+
+
+class TestArrayCalibration(unittest.TestCase):
+ def setUp(self):
+ self.arr = numpy.array([45.2, 25.3, 666., -8.])
+ self.calib = ArrayCalibration(self.arr)
+ self.affine_calib = ArrayCalibration([0.1, 0.2, 0.3])
+
+ def testIsAffine(self):
+ self.assertFalse(self.calib.is_affine())
+ self.assertTrue(self.affine_calib.is_affine())
+
+ def testSlope(self):
+ with self.assertRaises(AttributeError):
+ self.calib.get_slope()
+ self.assertEqual(self.affine_calib.get_slope(),
+ 0.1)
+
+ def testYIntercept(self):
+ self.assertEqual(self.calib(0),
+ self.arr[0])
+
+ def testCall(self):
+ with self.assertRaises(ValueError):
+ # X is an array with a different shape
+ self.calib(X)
+
+ with self.assertRaises(ValueError):
+ # floats are not valid indices
+ self.calib(3.14)
+
+ self.assertTrue(
+ numpy.array_equal(self.calib([1, 2, 3, 4]),
+ self.arr))
+
+ for idx, value in enumerate(self.arr):
+ self.assertEqual(self.calib(idx), value)
+
+
+class TestFunctionCalibration(unittest.TestCase):
+ def setUp(self):
+ self.non_affine_fun = numpy.sin
+ self.non_affine_calib = FunctionCalibration(self.non_affine_fun)
+
+ self.affine_fun = lambda x: 52. * x + 0.01
+ self.affine_calib = FunctionCalibration(self.affine_fun,
+ is_affine=True)
+
+ def testIsAffine(self):
+ self.assertFalse(self.non_affine_calib.is_affine())
+ self.assertTrue(self.affine_calib.is_affine())
+
+ def testSlope(self):
+ with self.assertRaises(AttributeError):
+ self.non_affine_calib.get_slope()
+ self.assertAlmostEqual(self.affine_calib.get_slope(),
+ 52.)
+
+ def testCall(self):
+ for x in X:
+ self.assertAlmostEqual(self.non_affine_calib(x),
+ self.non_affine_fun(x))
+ self.assertAlmostEqual(self.affine_calib(x),
+ self.affine_fun(x))
diff --git a/src/silx/math/test/test_colormap.py b/src/silx/math/test/test_colormap.py
new file mode 100644
index 0000000..0b0ec59
--- /dev/null
+++ b/src/silx/math/test/test_colormap.py
@@ -0,0 +1,269 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Test for colormap mapping implementation"""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "16/05/2018"
+
+
+import logging
+import sys
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+from silx.math import colormap
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestNormalization(ParametricTestCase):
+ """Test silx.math.colormap.Normalization sub classes"""
+
+ def _testCodec(self, normalization, rtol=1e-5):
+ """Test apply/revert for normalizations"""
+ test_data = (numpy.arange(1, 10, dtype=numpy.int32),
+ numpy.linspace(1., 100., 1000, dtype=numpy.float32),
+ numpy.linspace(-1., 1., 100, dtype=numpy.float32),
+ 1.,
+ 1)
+
+ for index in range(len(test_data)):
+ with self.subTest(normalization=normalization, data_index=index):
+ data = test_data[index]
+ normalized = normalization.apply(data, 1., 100.)
+ result = normalization.revert(normalized, 1., 100.)
+
+ self.assertTrue(numpy.array_equal(
+ numpy.isnan(normalized), numpy.isnan(result)))
+
+ if isinstance(data, numpy.ndarray):
+ notNaN = numpy.logical_not(numpy.isnan(result))
+ data = data[notNaN]
+ result = result[notNaN]
+ self.assertTrue(numpy.allclose(data, result, rtol=rtol))
+
+ def testLinearNormalization(self):
+ """Test for LinearNormalization"""
+ normalization = colormap.LinearNormalization()
+ self._testCodec(normalization)
+
+ def testLogarithmicNormalization(self):
+ """Test for LogarithmicNormalization"""
+ normalization = colormap.LogarithmicNormalization()
+ # relative tolerance is higher because of the log approximation
+ self._testCodec(normalization, rtol=1e-3)
+
+ # Specific extra tests
+ self.assertTrue(numpy.isnan(normalization.apply(-1., 1., 100.)))
+ self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 1., 100.)))
+ self.assertEqual(normalization.apply(numpy.inf, 1., 100.), numpy.inf)
+ self.assertEqual(normalization.apply(0, 1., 100.), - numpy.inf)
+
+ def testArcsinhNormalization(self):
+ """Test for ArcsinhNormalization"""
+ self._testCodec(colormap.ArcsinhNormalization())
+
+ def testSqrtNormalization(self):
+ """Test for SqrtNormalization"""
+ normalization = colormap.SqrtNormalization()
+ self._testCodec(normalization)
+
+ # Specific extra tests
+ self.assertTrue(numpy.isnan(normalization.apply(-1., 0., 100.)))
+ self.assertTrue(numpy.isnan(normalization.apply(numpy.nan, 0., 100.)))
+ self.assertEqual(normalization.apply(numpy.inf, 0., 100.), numpy.inf)
+ self.assertEqual(normalization.apply(0, 0., 100.), 0.)
+
+
+class TestColormap(ParametricTestCase):
+ """Test silx.math.colormap.cmap"""
+
+ NORMALIZATIONS = (
+ 'linear',
+ 'log',
+ 'arcsinh',
+ 'sqrt',
+ colormap.LinearNormalization(),
+ colormap.LogarithmicNormalization(),
+ colormap.GammaNormalization(2.),
+ colormap.GammaNormalization(0.5))
+
+ @staticmethod
+ def ref_colormap(data, colors, vmin, vmax, normalization, nan_color):
+ """Reference implementation of colormap
+
+ :param numpy.ndarray data: Data to convert
+ :param numpy.ndarray colors: Color look-up-table
+ :param float vmin: Lower bound of the colormap range
+ :param float vmax: Upper bound of the colormap range
+ :param str normalization: Normalization to use
+ :param Union[numpy.ndarray, None] nan_color: Color to use for NaN
+ """
+ norm_functions = {'linear': lambda v: v,
+ 'log': numpy.log10,
+ 'arcsinh': numpy.arcsinh,
+ 'sqrt': numpy.sqrt}
+
+ if isinstance(normalization, str):
+ norm_function = norm_functions[normalization]
+ else:
+ def norm_function(value):
+ return normalization.apply(value, vmin, vmax)
+
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ # Ignore divide by zero and invalid value encountered in log10, sqrt
+ norm_data, vmin, vmax = map(norm_function, (data, vmin, vmax))
+
+ if normalization == 'arcsinh' and sys.platform == 'win32':
+ # There is a difference of behavior of numpy.arcsinh
+ # between Windows and other OS for results of infinite values
+ # This makes Windows behaves as Linux and MacOS
+ norm_data[data == numpy.inf] = numpy.inf
+ norm_data[data == -numpy.inf] = -numpy.inf
+
+ nb_colors = len(colors)
+ scale = nb_colors / (vmax - vmin)
+
+ # Substraction must be done in float to avoid overflow with uint
+ indices = numpy.clip(scale * (norm_data - float(vmin)),
+ 0, nb_colors - 1)
+ indices[numpy.isnan(indices)] = nb_colors # Use an extra index for NaN
+ indices = indices.astype('uint')
+
+ # Add NaN color to array
+ if nan_color is None:
+ nan_color = (0,) * colors.shape[-1]
+ colors = numpy.append(colors, numpy.atleast_2d(nan_color), axis=0)
+
+ return colors[indices]
+
+ def _test(self, data, colors, vmin, vmax, normalization, nan_color):
+ """Run test of colormap against alternative implementation
+
+ :param numpy.ndarray data: Data to convert
+ :param numpy.ndarray colors: Color look-up-table
+ :param float vmin: Lower bound of the colormap range
+ :param float vmax: Upper bound of the colormap range
+ :param str normalization: Normalization to use
+ :param Union[numpy.ndarray, None] nan_color: Color to use for NaN
+ """
+ image = colormap.cmap(
+ data, colors, vmin, vmax, normalization, nan_color)
+
+ ref_image = self.ref_colormap(
+ data, colors, vmin, vmax, normalization, nan_color)
+
+ self.assertTrue(numpy.allclose(ref_image, image))
+ self.assertEqual(image.dtype, colors.dtype)
+ self.assertEqual(image.shape, data.shape + (colors.shape[-1],))
+
+ def test(self):
+ """Test all dtypes with finite data
+
+ Test all supported types and endianness
+ """
+ colors = numpy.zeros((256, 4), dtype=numpy.uint8)
+ colors[:, 0] = numpy.arange(len(colors))
+ colors[:, 3] = 255
+
+ # Generates (u)int and floats types
+ dtypes = [e + k + i for e in '<>' for k in 'uif' for i in '1248'
+ if k != 'f' or i != '1']
+ dtypes.append(numpy.dtype(numpy.longdouble).name) # Add long double
+
+ for normalization in self.NORMALIZATIONS:
+ for dtype in dtypes:
+ with self.subTest(dtype=dtype, normalization=normalization):
+ _logger.info('normalization: %s, dtype: %s',
+ normalization, dtype)
+ data = numpy.arange(-5, 15, dtype=dtype).reshape(4, 5)
+
+ self._test(data, colors, 1, 10, normalization, None)
+
+ def test_not_finite(self):
+ """Test float data with not finite values"""
+ colors = numpy.zeros((256, 4), dtype=numpy.uint8)
+ colors[:, 0] = numpy.arange(len(colors))
+ colors[:, 3] = 255
+
+ test_data = { # message: data
+ 'no finite values': (float('inf'), float('-inf'), float('nan')),
+ 'only NaN': (float('nan'), float('nan'), float('nan')),
+ 'mix finite/not finite': (float('inf'), float('-inf'), 1., float('nan')),
+ }
+
+ for normalization in self.NORMALIZATIONS:
+ for msg, data in test_data.items():
+ with self.subTest(msg, normalization=normalization):
+ _logger.info('normalization: %s, %s', normalization, msg)
+ data = numpy.array(data, dtype=numpy.float64)
+ self._test(data, colors, 1, 10, normalization, (0, 0, 0, 0))
+
+ def test_errors(self):
+ """Test raising exception for bad vmin, vmax, normalization parameters
+ """
+ colors = numpy.zeros((256, 4), dtype=numpy.uint8)
+ colors[:, 0] = numpy.arange(len(colors))
+ colors[:, 3] = 255
+
+ data = numpy.arange(10, dtype=numpy.float64)
+
+ test_params = [ # (vmin, vmax, normalization)
+ (-1., 2., 'log'),
+ (0., 1., 'log'),
+ (1., 0., 'log'),
+ (-1., 1., 'sqrt'),
+ (1., -1., 'sqrt'),
+ ]
+
+ for vmin, vmax, normalization in test_params:
+ with self.subTest(
+ vmin=vmin, vmax=vmax, normalization=normalization):
+ _logger.info('normalization: %s, range: [%f, %f]',
+ normalization, vmin, vmax)
+ with self.assertRaises(ValueError):
+ self._test(data, colors, vmin, vmax, normalization, None)
+
+
+def test_apply_colormap():
+ """Basic test of silx.math.colormap.apply_colormap"""
+ data = numpy.arange(256)
+ expected_colors = numpy.empty((256, 4), dtype=numpy.uint8)
+ expected_colors[:, :3] = numpy.arange(256, dtype=numpy.uint8).reshape(256, 1)
+ expected_colors[:, 3] = 255
+ colors = colormap.apply_colormap(
+ data,
+ colormap="gray",
+ norm="linear",
+ autoscale="minmax",
+ vmin=None,
+ vmax=None,
+ gamma=1.0)
+ assert numpy.array_equal(colors, expected_colors)
diff --git a/src/silx/math/test/test_combo.py b/src/silx/math/test/test_combo.py
new file mode 100644
index 0000000..9a96923
--- /dev/null
+++ b/src/silx/math/test/test_combo.py
@@ -0,0 +1,207 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2020 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests of the combo module"""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+
+from silx.math.combo import min_max
+
+
+class TestMinMax(ParametricTestCase):
+ """Tests of min max combo"""
+
+ FLOATING_DTYPES = 'float32', 'float64'
+ if hasattr(numpy, "float128"):
+ FLOATING_DTYPES += ('float128',)
+ SIGNED_INT_DTYPES = 'int8', 'int16', 'int32', 'int64'
+ UNSIGNED_INT_DTYPES = 'uint8', 'uint16', 'uint32', 'uint64'
+ DTYPES = FLOATING_DTYPES + SIGNED_INT_DTYPES + UNSIGNED_INT_DTYPES
+
+ def _numpy_min_max(self, data, min_positive=False, finite=False):
+ """Reference numpy implementation of min_max
+
+ :param numpy.ndarray data: Data set to use for test
+ :param bool min_positive: True to test with positive min
+ :param bool finite: True to only test finite values
+ """
+ data = numpy.array(data, copy=False)
+ if data.size == 0:
+ raise ValueError('Zero-sized array')
+
+ minimum = None
+ argmin = None
+ maximum = None
+ argmax = None
+ min_pos = None
+ argmin_pos = None
+
+ if finite:
+ filtered_data = data[numpy.isfinite(data)]
+ else:
+ filtered_data = data
+
+ if filtered_data.size > 0:
+ if numpy.all(numpy.isnan(filtered_data)):
+ minimum = numpy.nan
+ argmin = 0
+ maximum = numpy.nan
+ argmax = 0
+ else:
+ minimum = numpy.nanmin(filtered_data)
+ # nanargmin equivalent
+ argmin = numpy.where(data == minimum)[0][0]
+ maximum = numpy.nanmax(filtered_data)
+ # nanargmax equivalent
+ argmax = numpy.where(data == maximum)[0][0]
+
+ if min_positive:
+ with numpy.errstate(invalid='ignore'):
+ # Ignore invalid value encountered in greater
+ pos_data = filtered_data[filtered_data > 0]
+ if pos_data.size > 0:
+ min_pos = numpy.min(pos_data)
+ argmin_pos = numpy.where(data == min_pos)[0][0]
+
+ return minimum, min_pos, maximum, argmin, argmin_pos, argmax
+
+ def _test_min_max(self, data, min_positive, finite=False):
+ """Compare min_max with numpy for the given dataset
+
+ :param numpy.ndarray data: Data set to use for test
+ :param bool min_positive: True to test with positive min
+ :param bool finite: True to only test finite values
+ """
+ minimum, min_pos, maximum, argmin, argmin_pos, argmax = \
+ self._numpy_min_max(data, min_positive, finite)
+
+ result = min_max(data, min_positive, finite)
+
+ self.assertSimilar(minimum, result.minimum)
+ self.assertSimilar(min_pos, result.min_positive)
+ self.assertSimilar(maximum, result.maximum)
+ self.assertSimilar(argmin, result.argmin)
+ self.assertSimilar(argmin_pos, result.argmin_positive)
+ self.assertSimilar(argmax, result.argmax)
+
+ def assertSimilar(self, a, b):
+ """Assert that a and b are both None or NaN or that a == b."""
+ self.assertTrue((a is None and b is None) or
+ (numpy.isnan(a) and numpy.isnan(b)) or
+ a == b)
+
+ def test_different_datasets(self):
+ """Test min_max with different numpy.arange datasets."""
+ size = 1000
+
+ for dtype in self.DTYPES:
+
+ tests = {
+ '0 to N': (0, 1),
+ 'N-1 to 0': (size - 1, -1)}
+ if dtype not in self.UNSIGNED_INT_DTYPES:
+ tests['N/2 to -N/2'] = size // 2, -1
+ tests['0 to -N'] = 0, -1
+
+ for name, (start, step) in tests.items():
+ for min_positive in (True, False):
+ with self.subTest(dtype=dtype,
+ min_positive=min_positive,
+ data=name):
+ data = numpy.arange(
+ start, start + step * size, step, dtype=dtype)
+
+ self._test_min_max(data, min_positive)
+
+ def test_nodata(self):
+ """Test min_max with None and empty array"""
+ for dtype in self.DTYPES:
+ with self.subTest(dtype=dtype):
+ with self.assertRaises(TypeError):
+ min_max(None)
+
+ data = numpy.array((), dtype=dtype)
+ with self.assertRaises(ValueError):
+ min_max(data)
+
+ NAN_TEST_DATA = [
+ (float('nan'), float('nan')), # All NaNs
+ (float('nan'), 1.0), # NaN first and positive
+ (float('nan'), -1.0), # NaN first and negative
+ (1.0, 2.0, float('nan')), # NaN last and positive
+ (-1.0, -2.0, float('nan')), # NaN last and negative
+ (1.0, float('nan'), -1.0), # Some NaN
+ ]
+
+ def test_nandata(self):
+ """Test min_max with NaN in data"""
+ for dtype in self.FLOATING_DTYPES:
+ for data in self.NAN_TEST_DATA:
+ with self.subTest(dtype=dtype, data=data):
+ data = numpy.array(data, dtype=dtype)
+ self._test_min_max(data, min_positive=True)
+
+ INF_TEST_DATA = [
+ [float('inf')] * 3, # All +inf
+ [float('-inf')] * 3, # All -inf
+ (float('inf'), float('-inf')), # + and - inf
+ (float('inf'), float('-inf'), float('nan')), # +/-inf, nan last
+ (float('nan'), float('-inf'), float('inf')), # +/-inf, nan first
+ (float('inf'), float('nan'), float('-inf')), # +/-inf, nan center
+ ]
+
+ def test_infdata(self):
+ """Test min_max with inf."""
+ for dtype in self.FLOATING_DTYPES:
+ for data in self.INF_TEST_DATA:
+ with self.subTest(dtype=dtype, data=data):
+ data = numpy.array(data, dtype=dtype)
+ self._test_min_max(data, min_positive=True)
+
+ def test_finite(self):
+ """Test min_max with finite=True"""
+ tests = [
+ (-1., 2., 0.), # Basic test
+ (float('nan'), float('inf'), float('-inf')), # NaN + Inf
+ (float('nan'), float('inf'), -2, float('-inf')), # NaN + Inf + 1 value
+ (float('inf'), -3, -2), # values + inf
+ ]
+ tests += self.INF_TEST_DATA
+ tests += self.NAN_TEST_DATA
+
+ for dtype in self.FLOATING_DTYPES:
+ for data in tests:
+ with self.subTest(dtype=dtype, data=data):
+ data = numpy.array(data, dtype=dtype)
+ self._test_min_max(data, min_positive=True, finite=True)
diff --git a/src/silx/math/test/test_histogramnd_error.py b/src/silx/math/test/test_histogramnd_error.py
new file mode 100644
index 0000000..22304cb
--- /dev/null
+++ b/src/silx/math/test/test_histogramnd_error.py
@@ -0,0 +1,519 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+
+__authors__ = ["D. Naudet"]
+__license__ = "MIT"
+__date__ = "01/02/2016"
+
+"""
+Tests of the histogramnd function, error cases.
+"""
+import sys
+import platform
+import unittest
+
+import numpy as np
+
+from silx.math.chistogramnd import chistogramnd as histogramnd
+from silx.math import Histogramnd
+
+
+# ==============================================================
+# ==============================================================
+# ==============================================================
+
+
+class _Test_chistogramnd_errors(unittest.TestCase):
+ """
+ Unit tests of the chistogramnd error cases.
+ """
+ __test__ = False # ignore abstract class
+
+ def setUp(self):
+ self.skipTest("Abstract class")
+
+ def test_weights_shape(self):
+ """
+ """
+
+ for err_w_shape in self.err_weights_shapes:
+ test_msg = ('Testing invalid weights shape : {0}'
+ ''.format(err_w_shape))
+
+ err_weights = np.random.randint(0,
+ high=10,
+ size=err_w_shape)
+ err_weights = err_weights.astype(np.double)
+
+ ex_str = None
+ try:
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=err_weights)[0:2]
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str,
+ '<weights> must be an array whose length '
+ 'is equal to the number of samples.')
+
+ def test_histo_range_shape(self):
+ """
+ """
+ n_dims = 1 if len(self.s_shape) == 1 else self.s_shape[1]
+ expected_txt_tpl = ('<histo_range> error : expected {n_dims} sets '
+ 'of lower and upper bin edges, '
+ 'got the following instead : {histo_range}. '
+ '(provided <sample> contains '
+ '{n_dims}D values)')
+
+ for err_histo_range in self.err_histo_range_shapes:
+ test_msg = ('Testing invalid histo_range shape : {0}'
+ ''.format(err_histo_range))
+
+ expected_txt = expected_txt_tpl.format(histo_range=err_histo_range,
+ n_dims=n_dims)
+
+ ex_str = None
+ try:
+ histo, cumul = histogramnd(self.sample,
+ err_histo_range,
+ self.n_bins,
+ weights=self.weights)[0:2]
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str, expected_txt, msg=test_msg)
+
+ def test_nbins_shape(self):
+ """
+ """
+
+ expected_txt = ('n_bins must be either a scalar (same number '
+ 'of bins for all dimensions) or '
+ 'an array (number of bins for each '
+ 'dimension).')
+
+ for err_n_bins in self.err_n_bins_shapes:
+ test_msg = ('Testing invalid n_bins shape : {0}'
+ ''.format(err_n_bins))
+
+ ex_str = None
+ try:
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ err_n_bins,
+ weights=self.weights)[0:2]
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str, expected_txt, msg=test_msg)
+
+ def test_nbins_values(self):
+ """
+ """
+ expected_txt = ('<n_bins> : only positive values allowed.')
+
+ for err_n_bins in self.err_n_bins_values:
+ test_msg = ('Testing invalid n_bins value : {0}'
+ ''.format(err_n_bins))
+
+ ex_str = None
+ try:
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ err_n_bins,
+ weights=self.weights)[0:2]
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str, expected_txt, msg=test_msg)
+
+ def test_histo_shape(self):
+ """
+ """
+ for err_h_shape in self.err_histo_shapes:
+
+ # windows & python 2.7 : numpy shapes are long values
+ if platform.system() == 'Windows':
+ version = (sys.version_info.major, sys.version_info.minor)
+ if version <= (2, 7):
+ err_h_shape = tuple([long(val) for val in err_h_shape])
+
+ test_msg = ('Testing invalid histo shape : {0}'
+ ''.format(err_h_shape))
+
+ expected_txt = ('Provided <histo> array doesn\'t have '
+ 'a shape compatible with <n_bins> '
+ ': should be {0} instead of {1}.'
+ ''.format(self.h_shape, err_h_shape))
+
+ histo = np.zeros(shape=err_h_shape, dtype=np.uint32)
+
+ ex_str = None
+ try:
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ histo=histo)[0:2]
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str, expected_txt, msg=test_msg)
+
+ def test_histo_dtype(self):
+ """
+ """
+ for err_h_dtype in self.err_histo_dtypes:
+ test_msg = ('Testing invalid histo dtype : {0}'
+ ''.format(err_h_dtype))
+
+ histo = np.zeros(shape=self.h_shape, dtype=err_h_dtype)
+
+ expected_txt = ('Provided <histo> array doesn\'t have '
+ 'the expected type '
+ ': should be {0} instead of {1}.'
+ ''.format(np.uint32, histo.dtype))
+
+ ex_str = None
+ try:
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ histo=histo)[0:2]
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str, expected_txt, msg=test_msg)
+
+ def test_weighted_histo_shape(self):
+ """
+ """
+ # using the same values as histo
+ for err_h_shape in self.err_histo_shapes:
+
+ # windows & python 2.7 : numpy shapes are long values
+ if platform.system() == 'Windows':
+ version = (sys.version_info.major, sys.version_info.minor)
+ if version <= (2, 7):
+ err_h_shape = tuple([long(val) for val in err_h_shape])
+
+ test_msg = ('Testing invalid weighted_histo shape : {0}'
+ ''.format(err_h_shape))
+
+ expected_txt = ('Provided <weighted_histo> array doesn\'t have '
+ 'a shape compatible with <n_bins> '
+ ': should be {0} instead of {1}.'
+ ''.format(self.h_shape, err_h_shape))
+
+ cumul = np.zeros(shape=err_h_shape, dtype=np.double)
+
+ ex_str = None
+ try:
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ weighted_histo=cumul)[0:2]
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str, expected_txt, msg=test_msg)
+
+ def test_cumul_dtype(self):
+ """
+ """
+ # using the same values as histo
+ for err_h_dtype in self.err_histo_dtypes:
+ test_msg = ('Testing invalid weighted_histo dtype : {0}'
+ ''.format(err_h_dtype))
+
+ cumul = np.zeros(shape=self.h_shape, dtype=err_h_dtype)
+
+ expected_txt = ('Provided <weighted_histo> array doesn\'t have '
+ 'the expected type '
+ ': should be {0} or {1} instead of {2}.'
+ ''.format(np.float64, np.float32, cumul.dtype))
+
+ ex_str = None
+ try:
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ weighted_histo=cumul)[0:2]
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str, expected_txt, msg=test_msg)
+
+ def test_wh_histo_dtype(self):
+ """
+ """
+ # using the same values as histo
+ for err_h_dtype in self.err_histo_dtypes:
+ test_msg = ('Testing invalid wh_dtype dtype : {0}'
+ ''.format(err_h_dtype))
+
+ expected_txt = ('<wh_dtype> type not supported : {0}.'
+ ''.format(err_h_dtype))
+
+ ex_str = None
+ try:
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ wh_dtype=err_h_dtype)[0:2]
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str, expected_txt, msg=test_msg)
+
+ def test_unmanaged_dtypes(self):
+ """
+ """
+ for err_unmanaged_dtype in self.err_unmanaged_dtypes:
+ test_msg = ('Testing unmanaged dtypes : {0}'
+ ''.format(err_unmanaged_dtype))
+
+ sample = self.sample.astype(err_unmanaged_dtype[0])
+ weights = self.weights.astype(err_unmanaged_dtype[1])
+
+ expected_txt = ('Case not supported - sample:{0} '
+ 'and weights:{1}.'
+ ''.format(sample.dtype,
+ weights.dtype))
+
+ ex_str = None
+ try:
+ histogramnd(sample,
+ self.histo_range,
+ self.n_bins,
+ weights=weights)
+ except TypeError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str, msg=test_msg)
+ self.assertEqual(ex_str, expected_txt, msg=test_msg)
+
+ def test_uncontiguous_histo(self):
+ """
+ """
+ # non contiguous array
+ shape = np.array(self.n_bins, ndmin=1)
+ shape[0] *= 2
+ histo_tmp = np.zeros(shape)
+ histo = histo_tmp[::2, ...]
+
+ expected_txt = ('<histo> must be a C_CONTIGUOUS numpy array.')
+
+ ex_str = None
+ try:
+ histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ histo=histo)
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str)
+ self.assertEqual(ex_str, expected_txt)
+
+ def test_uncontiguous_weighted_histo(self):
+ """
+ """
+ # non contiguous array
+ shape = np.array(self.n_bins, ndmin=1)
+ shape[0] *= 2
+ cumul_tmp = np.zeros(shape)
+ cumul = cumul_tmp[::2, ...]
+
+ expected_txt = ('<weighted_histo> must be a C_CONTIGUOUS numpy array.')
+
+ ex_str = None
+ try:
+ histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ weighted_histo=cumul)
+ except ValueError as ex:
+ ex_str = str(ex)
+
+ self.assertIsNotNone(ex_str)
+ self.assertEqual(ex_str, expected_txt)
+
+
+class Test_chistogramnd_1D_errors(_Test_chistogramnd_errors):
+ """
+ Unit tests of the 1D histogramnd error cases.
+ """
+ __test__ = True # because _Test_chistogramnd_errors is ignored
+
+ def setUp(self):
+ # nominal values
+ self.n_elements = 1000
+ self.s_shape = (self.n_elements,)
+ self.w_shape = (self.n_elements,)
+
+ self.histo_range = [0., 100.]
+ self.n_bins = 10
+
+ self.h_shape = (self.n_bins,)
+
+ self.sample = np.random.randint(0,
+ high=10,
+ size=self.s_shape)
+ self.sample = self.sample.astype(np.double)
+
+ self.weights = np.random.randint(0,
+ high=10,
+ size=self.w_shape)
+ self.weights = self.weights.astype(np.double)
+
+ self.err_weights_shapes = ((self.n_elements+1,),
+ (self.n_elements-1,),
+ (self.n_elements-1, 3))
+ self.err_histo_range_shapes = ([0.],
+ [0., 1., 2.],
+ [[0.], [1.]])
+ self.err_n_bins_shapes = ([10, 2],
+ [[10], [2]])
+ self.err_n_bins_values = (0,
+ [-10],
+ None)
+ self.err_histo_shapes = ((self.n_bins+1,),
+ (self.n_bins-1,),
+ (self.n_bins, self.n_bins))
+ # these are used for testing the histo parameter as well
+ # as the weighted_histo parameter.
+ self.err_histo_dtypes = (np.uint16,
+ np.float16)
+
+ self.err_unmanaged_dtypes = ((np.double, np.uint16),
+ (np.uint16, np.double),
+ (np.uint16, np.uint16))
+
+class Test_chistogramnd_ND_range(unittest.TestCase):
+ """
+
+ """
+
+ def test_invalid_histo_range(self):
+ data = np.random.random((60, 60))
+ nbins = 10
+
+ with self.assertRaises(ValueError):
+ histo_range = data.min(), np.inf
+
+ Histogramnd(sample=data.ravel(),
+ histo_range=histo_range,
+ n_bins=nbins)
+
+ histo_range = data.min(), np.nan
+
+ Histogramnd(sample=data.ravel(),
+ histo_range=histo_range,
+ n_bins=nbins)
+
+
+class Test_chistogramnd_ND_errors(_Test_chistogramnd_errors):
+ """
+ Unit tests of the 3D histogramnd error cases.
+ """
+ __test__ = True # because _Test_chistogramnd_errors is ignored
+
+ def setUp(self):
+ # nominal values
+ self.n_elements = 1000
+ self.s_shape = (self.n_elements, 3)
+ self.w_shape = (self.n_elements,)
+
+ self.histo_range = [[0., 100.], [0., 100.], [0., 100.]]
+ self.n_bins = (10, 20, 30)
+
+ self.h_shape = self.n_bins
+
+ self.sample = np.random.randint(0,
+ high=10,
+ size=self.s_shape)
+ self.sample = self.sample.astype(np.double)
+
+ self.weights = np.random.randint(0,
+ high=10,
+ size=self.w_shape)
+ self.weights = self.weights.astype(np.double)
+
+ self.err_weights_shapes = ((self.n_elements+1,),
+ (self.n_elements-1,),
+ (self.n_elements-1, 3))
+ self.err_histo_range_shapes = ([0.],
+ [0., 1.],
+ [[0., 10.], [0., 10.]],
+ [0., 10., 0, 10., 0, 10.])
+ self.err_n_bins_shapes = ([10, 2],
+ [[10], [20], [30]])
+ self.err_n_bins_values = (0,
+ [-10],
+ [10, 20, -4],
+ None,
+ [10, None, 30])
+ self.err_histo_shapes = ((self.n_bins[0]+1,
+ self.n_bins[1],
+ self.n_bins[2]),
+ (self.n_bins[0],
+ self.n_bins[1],
+ self.n_bins[2]-1),
+ (self.n_bins[0],
+ self.n_bins[1]),
+ (self.n_bins[1],
+ self.n_bins[0],
+ self.n_bins[2]),
+ (self.n_bins[0],
+ self.n_bins[1],
+ self.n_bins[2],
+ 10)
+ )
+ # these are used for testing the histo parameter as well
+ # as the weighted_histo parameter.
+ self.err_histo_dtypes = (np.uint16,
+ np.float16)
+
+ self.err_unmanaged_dtypes = ((np.double, np.uint16),
+ (np.uint16, np.double),
+ (np.uint16, np.uint16))
diff --git a/src/silx/math/test/test_histogramnd_nominal.py b/src/silx/math/test/test_histogramnd_nominal.py
new file mode 100644
index 0000000..031a772
--- /dev/null
+++ b/src/silx/math/test/test_histogramnd_nominal.py
@@ -0,0 +1,937 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+Nominal tests of the histogramnd function.
+"""
+
+import unittest
+import pytest
+
+import numpy as np
+
+from silx.math.chistogramnd import chistogramnd as histogramnd
+from silx.math import Histogramnd
+
+
+def _get_bin_edges(histo_range, n_bins, n_dims):
+ edges = []
+ for i_dim in range(n_dims):
+ edges.append(histo_range[i_dim, 0] +
+ np.arange(n_bins[i_dim] + 1) *
+ (histo_range[i_dim, 1] - histo_range[i_dim, 0]) /
+ n_bins[i_dim])
+ return tuple(edges)
+
+
+# ==============================================================
+# ==============================================================
+# ==============================================================
+
+
+class _Test_chistogramnd_nominal(unittest.TestCase):
+ """
+ Unit tests of the histogramnd function.
+ """
+ __test__ = False # ignore abstract classe
+
+ ndims = None
+
+ def setUp(self):
+ if type(self).__name__.startswith("_"):
+ self.skipTest("Abstract class")
+ ndims = self.ndims
+ self.tested_dim = ndims-1
+
+ if ndims is None:
+ raise ValueError('ndims class member not set.')
+
+ sample = np.array([5.5, -3.3,
+ 0., -0.5,
+ 3.3, 8.8,
+ -7.7, 6.0,
+ -4.0])
+
+ weights = np.array([500.5, -300.3,
+ 0.01, -0.5,
+ 300.3, 800.8,
+ -700.7, 600.6,
+ -400.4])
+
+ n_elems = len(sample)
+
+ if ndims == 1:
+ shape = (n_elems,)
+ else:
+ shape = (n_elems, ndims)
+
+ self.sample = np.zeros(shape=shape, dtype=sample.dtype)
+ if ndims == 1:
+ self.sample = sample
+ else:
+ self.sample[..., ndims-1] = sample
+
+ self.weights = weights
+
+ # the tests are performed along one dimension,
+ # all the other bins indices along the other dimensions
+ # are expected to be 2
+ # (e.g : when testing a 2D sample : [0, x] will go into
+ # bin [2, y] because of the bin ranges [-2, 2] and n_bins = 4
+ # for the first dimension)
+ self.other_axes_index = 2
+ self.histo_range = np.repeat([[-2., 2.]], ndims, axis=0)
+ self.histo_range[ndims-1] = [-4., 6.]
+
+ self.n_bins = np.array([4]*ndims)
+ self.n_bins[ndims-1] = 5
+
+ if ndims == 1:
+ def fill_histo(h, v, dim, op=None):
+ if op:
+ h[:] = op(h[:], v)
+ else:
+ h[:] = v
+ self.fill_histo = fill_histo
+ else:
+ def fill_histo(h, v, dim, op=None):
+ idx = [self.other_axes_index]*len(h.shape)
+ idx[dim] = slice(0, None)
+ idx = tuple(idx)
+ if op:
+ h[idx] = op(h[idx], v)
+ else:
+ h[idx] = v
+ self.fill_histo = fill_histo
+
+ def test_nominal(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul, bin_edges = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)
+
+ expected_edges = _get_bin_edges(self.histo_range,
+ self.n_bins,
+ self.ndims)
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ for i_edges, edges in enumerate(expected_edges):
+ self.assertTrue(np.array_equal(bin_edges[i_edges],
+ expected_edges[i_edges]),
+ msg='Testing bin_edges for dim {0}'
+ ''.format(i_edges+1))
+
+ def test_nominal_wh_dtype(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.float32)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul, bin_edges = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ wh_dtype=np.float32)
+
+ self.assertEqual(cumul.dtype, np.float32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.allclose(cumul, expected_c))
+
+ def test_nominal_uncontiguous_sample(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ shape = list(self.sample.shape)
+ shape[0] *= 2
+ sample = np.zeros(shape, dtype=self.sample.dtype)
+ uncontig_sample = sample[::2, ...]
+ uncontig_sample[:] = self.sample
+
+ self.assertFalse(uncontig_sample.flags['C_CONTIGUOUS'],
+ msg='Making sure the array is not contiguous.')
+
+ histo, cumul, bin_edges = histogramnd(uncontig_sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ def test_nominal_uncontiguous_weights(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ shape = list(self.weights.shape)
+ shape[0] *= 2
+ weights = np.zeros(shape, dtype=self.weights.dtype)
+ uncontig_weights = weights[::2, ...]
+ uncontig_weights[:] = self.weights
+
+ self.assertFalse(uncontig_weights.flags['C_CONTIGUOUS'],
+ msg='Making sure the array is not contiguous.')
+
+ histo, cumul, bin_edges = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=uncontig_weights)
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ def test_nominal_wo_weights(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=None)[0:2]
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(cumul is None)
+
+ def test_nominal_wo_weights_w_cumul(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ # creating an array of ones just to make sure that
+ # it is not cleared by histogramnd
+ cumul_in = np.ones(self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=None,
+ weighted_histo=cumul_in)[0:2]
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(cumul is None)
+ self.assertTrue(np.array_equal(cumul_in,
+ np.ones(shape=self.n_bins,
+ dtype=np.double)))
+
+ def test_nominal_wo_weights_w_histo(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ # creating an array of ones just to make sure that
+ # it is not cleared by histogramnd
+ histo_in = np.ones(self.n_bins, dtype=np.uint32)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=None,
+ histo=histo_in)[0:2]
+
+ self.assertTrue(np.array_equal(histo, expected_h + 1))
+ self.assertTrue(cumul is None)
+ self.assertEqual(id(histo), id(histo_in))
+
+ def test_nominal_last_bin_closed(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 2])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 1101.1])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True)[0:2]
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ def test_int32_weights_double_weights_range(self):
+ """
+ """
+ weight_min = -299.9 # ===> will be cast to -299
+ weight_max = 499.9 # ===> will be cast to 499
+
+ expected_h_tpl = np.array([0, 1, 1, 1, 0])
+ expected_c_tpl = np.array([0., 0., 0., 300., 0.])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights.astype(np.int32),
+ weight_min=weight_min,
+ weight_max=weight_max)[0:2]
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ def test_reuse_histo(self):
+ """
+ """
+
+ expected_h_tpl = np.array([2, 3, 2, 2, 2])
+ expected_c_tpl = np.array([0.0, -7007, -5.0, 0.1, 3003.0])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)[0:2]
+
+ sample_2 = self.sample[:]
+ if len(sample_2.shape) == 1:
+ idx = (slice(0, None),)
+ else:
+ idx = slice(0, None), self.tested_dim
+
+ sample_2[idx] += 2
+
+ histo_2, cumul = histogramnd(sample_2, # <==== !!
+ self.histo_range,
+ self.n_bins,
+ weights=10 * self.weights, # <==== !!
+ histo=histo)[0:2]
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+ self.assertEqual(id(histo), id(histo_2))
+
+ def test_reuse_cumul(self):
+ """
+ """
+
+ expected_h_tpl = np.array([0, 2, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -7007.5, -4.99, 300.4, 3503.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)[0:2]
+
+ sample_2 = self.sample[:]
+ if len(sample_2.shape) == 1:
+ idx = (slice(0, None),)
+ else:
+ idx = slice(0, None), self.tested_dim
+
+ sample_2[idx] += 2
+
+ histo, cumul_2 = histogramnd(sample_2, # <==== !!
+ self.histo_range,
+ self.n_bins,
+ weights=10 * self.weights, # <==== !!
+ weighted_histo=cumul)[0:2]
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
+ self.assertEqual(id(cumul), id(cumul_2))
+
+ def test_reuse_cumul_float(self):
+ """
+ """
+
+ expected_h_tpl = np.array([0, 2, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -7007.5, -4.99, 300.4, 3503.5],
+ dtype=np.float32)
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)[0:2]
+
+ # converting the cumul array to float
+ cumul = cumul.astype(np.float32)
+
+ sample_2 = self.sample[:]
+ if len(sample_2.shape) == 1:
+ idx = (slice(0, None),)
+ else:
+ idx = slice(0, None), self.tested_dim
+
+ sample_2[idx] += 2
+
+ histo, cumul_2 = histogramnd(sample_2, # <==== !!
+ self.histo_range,
+ self.n_bins,
+ weights=10 * self.weights, # <==== !!
+ weighted_histo=cumul)[0:2]
+
+ self.assertEqual(cumul.dtype, np.float32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertEqual(id(cumul), id(cumul_2))
+ self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
+
+class _Test_Histogramnd_nominal(unittest.TestCase):
+ """
+ Unit tests of the Histogramnd class.
+ """
+ __test__ = False # ignore abstract class
+
+ ndims = None
+
+ def setUp(self):
+ ndims = self.ndims
+ if ndims is None:
+ self.skipTest("Abstract class")
+ self.tested_dim = ndims-1
+
+ if ndims is None:
+ raise ValueError('ndims class member not set.')
+
+ sample = np.array([5.5, -3.3,
+ 0., -0.5,
+ 3.3, 8.8,
+ -7.7, 6.0,
+ -4.0])
+
+ weights = np.array([500.5, -300.3,
+ 0.01, -0.5,
+ 300.3, 800.8,
+ -700.7, 600.6,
+ -400.4])
+
+ n_elems = len(sample)
+
+ if ndims == 1:
+ shape = (n_elems,)
+ else:
+ shape = (n_elems, ndims)
+
+ self.sample = np.zeros(shape=shape, dtype=sample.dtype)
+ if ndims == 1:
+ self.sample = sample
+ else:
+ self.sample[..., ndims-1] = sample
+
+ self.weights = weights
+
+ # the tests are performed along one dimension,
+ # all the other bins indices along the other dimensions
+ # are expected to be 2
+ # (e.g : when testing a 2D sample : [0, x] will go into
+ # bin [2, y] because of the bin ranges [-2, 2] and n_bins = 4
+ # for the first dimension)
+ self.other_axes_index = 2
+ self.histo_range = np.repeat([[-2., 2.]], ndims, axis=0)
+ self.histo_range[ndims-1] = [-4., 6.]
+
+ self.n_bins = np.array([4]*ndims)
+ self.n_bins[ndims-1] = 5
+
+ if ndims == 1:
+ def fill_histo(h, v, dim, op=None):
+ if op:
+ h[:] = op(h[:], v)
+ else:
+ h[:] = v
+ self.fill_histo = fill_histo
+ else:
+ def fill_histo(h, v, dim, op=None):
+ idx = [self.other_axes_index]*len(h.shape)
+ idx[dim] = slice(0, None)
+ idx = tuple(idx)
+ if op:
+ h[idx] = op(h[idx], v)
+ else:
+ h[idx] = v
+ self.fill_histo = fill_histo
+
+ def test_nominal(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo = Histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)
+
+ histo, cumul, bin_edges = histo
+
+ expected_edges = _get_bin_edges(self.histo_range,
+ self.n_bins,
+ self.ndims)
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ for i_edges, edges in enumerate(expected_edges):
+ self.assertTrue(np.array_equal(bin_edges[i_edges],
+ expected_edges[i_edges]),
+ msg='Testing bin_edges for dim {0}'
+ ''.format(i_edges+1))
+
+ def test_nominal_wh_dtype(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.float32)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul, bin_edges = Histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ wh_dtype=np.float32)
+
+ self.assertEqual(cumul.dtype, np.float32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.allclose(cumul, expected_c))
+
+ def test_nominal_uncontiguous_sample(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ shape = list(self.sample.shape)
+ shape[0] *= 2
+ sample = np.zeros(shape, dtype=self.sample.dtype)
+ uncontig_sample = sample[::2, ...]
+ uncontig_sample[:] = self.sample
+
+ self.assertFalse(uncontig_sample.flags['C_CONTIGUOUS'],
+ msg='Making sure the array is not contiguous.')
+
+ histo, cumul, bin_edges = Histogramnd(uncontig_sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ def test_nominal_uncontiguous_weights(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ shape = list(self.weights.shape)
+ shape[0] *= 2
+ weights = np.zeros(shape, dtype=self.weights.dtype)
+ uncontig_weights = weights[::2, ...]
+ uncontig_weights[:] = self.weights
+
+ self.assertFalse(uncontig_weights.flags['C_CONTIGUOUS'],
+ msg='Making sure the array is not contiguous.')
+
+ histo, cumul, bin_edges = Histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=uncontig_weights)
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ def test_nominal_wo_weights(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+
+ histo, cumul = Histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=None)[0:2]
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(cumul is None)
+
+ def test_nominal_last_bin_closed(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 2])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 1101.1])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul = Histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True)[0:2]
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ def test_int32_weights_double_weights_range(self):
+ """
+ """
+ weight_min = -299.9 # ===> will be cast to -299
+ weight_max = 499.9 # ===> will be cast to 499
+
+ expected_h_tpl = np.array([0, 1, 1, 1, 0])
+ expected_c_tpl = np.array([0., 0., 0., 300., 0.])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo, cumul = Histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights.astype(np.int32),
+ weight_min=weight_min,
+ weight_max=weight_max)[0:2]
+
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ def test_nominal_no_sample(self):
+ """
+ """
+
+ histo_inst = Histogramnd(None,
+ self.histo_range,
+ self.n_bins)
+
+ histo, weighted_histo, edges = histo_inst
+
+ self.assertIsNone(histo)
+ self.assertIsNone(weighted_histo)
+ self.assertIsNone(edges)
+ self.assertIsNone(histo_inst.histo)
+ self.assertIsNone(histo_inst.weighted_histo)
+ self.assertIsNone(histo_inst.edges)
+
+ def test_empty_init_accumulate(self):
+ """
+ """
+ expected_h_tpl = np.array([2, 1, 1, 1, 1])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo_inst = Histogramnd(None,
+ self.histo_range,
+ self.n_bins)
+
+ histo_inst.accumulate(self.sample,
+ weights=self.weights)
+
+ histo = histo_inst.histo
+ cumul = histo_inst.weighted_histo
+ bin_edges = histo_inst.edges
+
+ expected_edges = _get_bin_edges(self.histo_range,
+ self.n_bins,
+ self.ndims)
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertEqual(histo.dtype, np.uint32)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ for i_edges, edges in enumerate(expected_edges):
+ self.assertTrue(np.array_equal(bin_edges[i_edges],
+ expected_edges[i_edges]),
+ msg='Testing bin_edges for dim {0}'
+ ''.format(i_edges+1))
+
+ def test_accumulate(self):
+ """
+ """
+
+ expected_h_tpl = np.array([2, 3, 2, 2, 2])
+ expected_c_tpl = np.array([-700.7, -7007.5, -4.99, 300.4, 3503.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo_inst = Histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)
+
+ sample_2 = self.sample[:]
+ if len(sample_2.shape) == 1:
+ idx = (slice(0, None),)
+ else:
+ idx = slice(0, None), self.tested_dim
+
+ sample_2[idx] += 2
+
+ histo_inst.accumulate(sample_2, # <==== !!
+ weights=10 * self.weights) # <==== !!
+
+ histo = histo_inst.histo
+ cumul = histo_inst.weighted_histo
+ bin_edges = histo_inst.edges
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
+
+ def test_accumulate_no_weights(self):
+ """
+ """
+
+ expected_h_tpl = np.array([2, 3, 2, 2, 2])
+ expected_c_tpl = np.array([-700.7, -0.5, 0.01, 300.3, 500.5])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo_inst = Histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)
+
+ sample_2 = self.sample[:]
+ if len(sample_2.shape) == 1:
+ idx = (slice(0, None),)
+ else:
+ idx = slice(0, None), self.tested_dim
+
+ sample_2[idx] += 2
+
+ histo_inst.accumulate(sample_2) # <==== !!
+
+ histo = histo_inst.histo
+ cumul = histo_inst.weighted_histo
+ bin_edges = histo_inst.edges
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.allclose(cumul, expected_c, rtol=10e-15))
+
+ def test_accumulate_no_weights_at_init(self):
+ """
+ """
+
+ expected_h_tpl = np.array([2, 3, 2, 2, 2])
+ expected_c_tpl = np.array([0.0, -700.7, -0.5, 0.01, 300.3])
+
+ expected_h = np.zeros(shape=self.n_bins, dtype=np.double)
+ expected_c = np.zeros(shape=self.n_bins, dtype=np.double)
+
+ self.fill_histo(expected_h, expected_h_tpl, self.ndims-1)
+ self.fill_histo(expected_c, expected_c_tpl, self.ndims-1)
+
+ histo_inst = Histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=None) # <==== !!
+
+ cumul = histo_inst.weighted_histo
+ self.assertIsNone(cumul)
+
+ sample_2 = self.sample[:]
+ if len(sample_2.shape) == 1:
+ idx = (slice(0, None),)
+ else:
+ idx = slice(0, None), self.tested_dim
+
+ sample_2[idx] += 2
+
+ histo_inst.accumulate(sample_2,
+ weights=self.weights) # <==== !!
+
+ histo = histo_inst.histo
+ cumul = histo_inst.weighted_histo
+ bin_edges = histo_inst.edges
+
+ self.assertEqual(cumul.dtype, np.float64)
+ self.assertTrue(np.array_equal(histo, expected_h))
+ self.assertTrue(np.array_equal(cumul, expected_c))
+
+ def testNoneNativeTypes(self):
+ type = self.sample.dtype.newbyteorder("B")
+ sampleB = self.sample.astype(type)
+
+ type = self.sample.dtype.newbyteorder("L")
+ sampleL = self.sample.astype(type)
+
+ histo_inst = Histogramnd(sampleB,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)
+
+ histo_inst = Histogramnd(sampleL,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights)
+
+
+class Test_chistogram_nominal_1d(_Test_chistogramnd_nominal):
+ __test__ = True # because _Test_chistogramnd_nominal is ignored
+ ndims = 1
+
+
+class Test_chistogram_nominal_2d(_Test_chistogramnd_nominal):
+ __test__ = True # because _Test_chistogramnd_nominal is ignored
+ ndims = 2
+
+
+class Test_chistogram_nominal_3d(_Test_chistogramnd_nominal):
+ __test__ = True # because _Test_chistogramnd_nominal is ignored
+ ndims = 3
+
+
+class Test_Histogramnd_nominal_1d(_Test_Histogramnd_nominal):
+ __test__ = True # because _Test_chistogramnd_nominal is ignored
+ ndims = 1
+
+
+class Test_Histogramnd_nominal_2d(_Test_Histogramnd_nominal):
+ __test__ = True # because _Test_chistogramnd_nominal is ignored
+ ndims = 2
+
+
+class Test_Histogramnd_nominal_3d(_Test_Histogramnd_nominal):
+ __test__ = True # because _Test_chistogramnd_nominal is ignored
+ ndims = 3
diff --git a/src/silx/math/test/test_histogramnd_vs_np.py b/src/silx/math/test/test_histogramnd_vs_np.py
new file mode 100644
index 0000000..d6a8d19
--- /dev/null
+++ b/src/silx/math/test/test_histogramnd_vs_np.py
@@ -0,0 +1,826 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""
+Tests for the histogramnd function.
+Results are compared to numpy's histogramdd.
+"""
+
+import unittest
+import operator
+
+import numpy as np
+
+from silx.math.chistogramnd import chistogramnd as histogramnd
+
+# ==============================================================
+# ==============================================================
+# ==============================================================
+
+_RTOL_DICT = {np.float64: 10**-13,
+ np.float32: 10**-5}
+
+# ==============================================================
+# ==============================================================
+# ==============================================================
+
+
+def _add_values_to_array_if_missing(array, values, n_values):
+ max_in_col = np.any(array[:, ...] == values, axis=0)
+
+ if len(array.shape) == 1:
+ if not max_in_col:
+ rnd_idx = np.random.randint(0,
+ high=len(array)-1,
+ size=(n_values,))
+ array[rnd_idx] = values
+ else:
+ for i in range(len(max_in_col)):
+ if not max_in_col[i]:
+ rnd_idx = np.random.randint(0,
+ high=len(array)-1,
+ size=(n_values,))
+ array[rnd_idx, i] = values[i]
+
+
+def _get_values_index(array, values, op=operator.lt):
+ idx = op(array[:, ...], values)
+ if array.ndim > 1:
+ idx = np.all(idx, axis=1)
+ return np.where(idx)[0]
+
+
+def _get_in_range_indices(array,
+ minvalues,
+ maxvalues,
+ minop=operator.ge,
+ maxop=operator.lt):
+ idx = np.logical_and(minop(array, minvalues),
+ maxop(array, maxvalues))
+ if array.ndim > 1:
+ idx = np.all(idx, axis=1)
+ return np.where(idx)[0]
+
+
+class _TestHistogramnd(unittest.TestCase):
+ """
+ Unit tests of the histogramnd function.
+ """
+ __test__ = False # ignore abstract class
+
+ sample_rng = None
+ weights_rng = None
+ n_dims = None
+
+ filter_min = None
+ filter_max = None
+
+ histo_range = None
+ n_bins = None
+
+ dtype_sample = None
+ dtype_weights = None
+
+ def generate_data(self):
+
+ self.longMessage = True
+
+ int_min = 0
+ int_max = 100000
+ n_elements = 10**5
+
+ if self.n_dims == 1:
+ shape = (n_elements,)
+ else:
+ shape = (n_elements, self.n_dims,)
+
+ self.rng_state = np.random.get_state()
+
+ self.state_msg = ('Current RNG state :\n'
+ '{0}'.format(self.rng_state))
+
+ sample = np.random.randint(int_min,
+ high=int_max,
+ size=shape)
+
+ sample = sample.astype(self.dtype_sample)
+ sample = (self.sample_rng[0] +
+ (sample-int_min) *
+ (self.sample_rng[1]-self.sample_rng[0]) /
+ (int_max-int_min)).astype(self.dtype_sample)
+
+ weights = np.random.randint(int_min,
+ high=int_max,
+ size=(n_elements,))
+ weights = weights.astype(self.dtype_weights)
+ weights = (self.weights_rng[0] +
+ (weights-int_min) *
+ (self.weights_rng[1]-self.weights_rng[0]) /
+ (int_max-int_min)).astype(self.dtype_weights)
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # the bins range are cast to the same type as the sample
+ # in order to get the same results as numpy
+ # (which doesnt cast the range)
+ self.histo_range = np.array(self.histo_range).astype(self.dtype_sample)
+
+ # adding some values that are equal to the max
+ # in order to test the opened/closed last bin
+ bins_max = [b[1] for b in self.histo_range]
+ _add_values_to_array_if_missing(sample,
+ bins_max,
+ 100)
+
+ # adding some values that are equal to the min weight value
+ # in order to test the filters
+ _add_values_to_array_if_missing(weights,
+ self.weights_rng[0],
+ 100)
+
+ # adding some values that are equal to the max weight value
+ # in order to test the filters
+ _add_values_to_array_if_missing(weights,
+ self.weights_rng[1],
+ 100)
+
+ return sample, weights
+
+ def setUp(self):
+ if type(self).__name__.startswith("_"):
+ self.skipTest("Abstract class")
+ self.sample, self.weights = self.generate_data()
+ self.rtol = _RTOL_DICT.get(self.dtype_weights, None)
+
+ def array_compare(self, ar_a, ar_b):
+ if self.rtol is None:
+ return np.array_equal(ar_a, ar_b)
+ return np.allclose(ar_a, ar_b, self.rtol)
+
+ def test_bin_ranges(self):
+ """
+
+ """
+ result_c = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True)
+
+ result_np = np.histogramdd(self.sample,
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ for i_edges, edges in enumerate(result_c[2]):
+ # allclose for now until I can try with the latest version (TBD)
+ # of numpy
+ self.assertTrue(np.allclose(edges,
+ result_np[1][i_edges]),
+ msg='{0}. Testing bin_edges for dim {1}.'
+ ''.format(self.state_msg, i_edges+1))
+
+ def test_last_bin_closed(self):
+ """
+
+ """
+ result_c = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True)
+
+ result_np = np.histogramdd(self.sample,
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ result_np_w = np.histogramdd(self.sample,
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights)
+
+ # comparing "hits"
+ hits_cmp = np.array_equal(result_c[0],
+ result_np[0])
+ # comparing weights
+ weights_cmp = np.array_equal(result_c[1],
+ result_np_w[0])
+
+ self.assertTrue(hits_cmp, msg=self.state_msg)
+ self.assertTrue(weights_cmp, msg=self.state_msg)
+
+ bins_min = [rng[0] for rng in self.histo_range]
+ bins_max = [rng[1] for rng in self.histo_range]
+ inrange_idx = _get_in_range_indices(self.sample,
+ bins_min,
+ bins_max,
+ minop=operator.ge,
+ maxop=operator.le)
+
+ self.assertEqual(result_c[0].sum(), inrange_idx.shape[0],
+ msg=self.state_msg)
+
+ # we have to sum the weights using the same precision as the
+ # histogramnd function
+ weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
+ self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
+ msg=self.state_msg)
+
+ def test_last_bin_open(self):
+ """
+
+ """
+ result_c = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=False)
+
+ bins_max = [rng[1] for rng in self.histo_range]
+ filtered_idx = _get_values_index(self.sample, bins_max)
+
+ result_np = np.histogramdd(self.sample[filtered_idx],
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ result_np_w = np.histogramdd(self.sample[filtered_idx],
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights[filtered_idx])
+
+ # comparing "hits"
+ hits_cmp = np.array_equal(result_c[0], result_np[0])
+ # comparing weights
+ weights_cmp = np.array_equal(result_c[1],
+ result_np_w[0])
+
+ self.assertTrue(hits_cmp, msg=self.state_msg)
+ self.assertTrue(weights_cmp, msg=self.state_msg)
+
+ bins_min = [rng[0] for rng in self.histo_range]
+ bins_max = [rng[1] for rng in self.histo_range]
+ inrange_idx = _get_in_range_indices(self.sample,
+ bins_min,
+ bins_max,
+ minop=operator.ge,
+ maxop=operator.lt)
+
+ self.assertEqual(result_c[0].sum(), len(inrange_idx),
+ msg=self.state_msg)
+ # we have to sum the weights using the same precision as the
+ # histogramnd function
+ weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
+ self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
+ msg=self.state_msg)
+
+ def test_filter_min(self):
+ """
+
+ """
+ result_c = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ weight_min=self.filter_min)
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ filter_min = self.dtype_weights(self.filter_min)
+
+ weight_idx = _get_values_index(self.weights,
+ filter_min, # <------ !!!
+ operator.ge)
+
+ result_np = np.histogramdd(self.sample[weight_idx],
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ result_np_w = np.histogramdd(self.sample[weight_idx],
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights[weight_idx])
+
+ # comparing "hits"
+ hits_cmp = np.array_equal(result_c[0],
+ result_np[0])
+ # comparing weights
+ weights_cmp = np.array_equal(result_c[1], result_np_w[0])
+
+ self.assertTrue(hits_cmp, msg=self.state_msg)
+ self.assertTrue(weights_cmp, msg=self.state_msg)
+
+ bins_min = [rng[0] for rng in self.histo_range]
+ bins_max = [rng[1] for rng in self.histo_range]
+ inrange_idx = _get_in_range_indices(self.sample[weight_idx],
+ bins_min,
+ bins_max,
+ minop=operator.ge,
+ maxop=operator.le)
+
+ inrange_idx = weight_idx[inrange_idx]
+
+ self.assertEqual(result_c[0].sum(), len(inrange_idx),
+ msg=self.state_msg)
+
+ # we have to sum the weights using the same precision as the
+ # histogramnd function
+ weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
+ self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
+ msg=self.state_msg)
+
+ def test_filter_max(self):
+ """
+
+ """
+ result_c = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ weight_max=self.filter_max)
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ filter_max = self.dtype_weights(self.filter_max)
+
+ weight_idx = _get_values_index(self.weights,
+ filter_max, # <------ !!!
+ operator.le)
+
+ result_np = np.histogramdd(self.sample[weight_idx],
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ result_np_w = np.histogramdd(self.sample[weight_idx],
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights[weight_idx])
+
+ # comparing "hits"
+ hits_cmp = np.array_equal(result_c[0],
+ result_np[0])
+ # comparing weights
+ weights_cmp = np.array_equal(result_c[1], result_np_w[0])
+
+ self.assertTrue(hits_cmp, msg=self.state_msg)
+ self.assertTrue(weights_cmp, msg=self.state_msg)
+
+ bins_min = [rng[0] for rng in self.histo_range]
+ bins_max = [rng[1] for rng in self.histo_range]
+ inrange_idx = _get_in_range_indices(self.sample[weight_idx],
+ bins_min,
+ bins_max,
+ minop=operator.ge,
+ maxop=operator.le)
+
+ inrange_idx = weight_idx[inrange_idx]
+
+ self.assertEqual(result_c[0].sum(), len(inrange_idx),
+ msg=self.state_msg)
+
+ # we have to sum the weights using the same precision as the
+ # histogramnd function
+ weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
+ self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
+ msg=self.state_msg)
+
+ def test_filter_minmax(self):
+ """
+
+ """
+ result_c = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ weight_min=self.filter_min,
+ weight_max=self.filter_max)
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ filter_min = self.dtype_weights(self.filter_min)
+ filter_max = self.dtype_weights(self.filter_max)
+
+ weight_idx = _get_in_range_indices(self.weights,
+ filter_min, # <------ !!!
+ filter_max, # <------ !!!
+ minop=operator.ge,
+ maxop=operator.le)
+
+ result_np = np.histogramdd(self.sample[weight_idx],
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ result_np_w = np.histogramdd(self.sample[weight_idx],
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights[weight_idx])
+
+ # comparing "hits"
+ hits_cmp = np.array_equal(result_c[0],
+ result_np[0])
+ # comparing weights
+ weights_cmp = np.array_equal(result_c[1], result_np_w[0])
+
+ self.assertTrue(hits_cmp)
+ self.assertTrue(weights_cmp)
+
+ bins_min = [rng[0] for rng in self.histo_range]
+ bins_max = [rng[1] for rng in self.histo_range]
+ inrange_idx = _get_in_range_indices(self.sample[weight_idx],
+ bins_min,
+ bins_max,
+ minop=operator.ge,
+ maxop=operator.le)
+
+ inrange_idx = weight_idx[inrange_idx]
+
+ self.assertEqual(result_c[0].sum(), len(inrange_idx),
+ msg=self.state_msg)
+
+ # we have to sum the weights using the same precision as the
+ # histogramnd function
+ weights_sum = self.weights[inrange_idx].astype(result_c[1].dtype).sum()
+ self.assertTrue(self.array_compare(result_c[1].sum(), weights_sum),
+ msg=self.state_msg)
+
+ def test_reuse_histo(self):
+ """
+
+ """
+ result_c_1 = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True)
+
+ result_np_1 = np.histogramdd(self.sample,
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ np.histogramdd(self.sample,
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights)
+
+ sample_2, weights_2 = self.generate_data()
+
+ result_c_2 = histogramnd(sample_2,
+ self.histo_range,
+ self.n_bins,
+ weights=weights_2,
+ last_bin_closed=True,
+ histo=result_c_1[0])
+
+ result_np_2 = np.histogramdd(sample_2,
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ result_np_w_2 = np.histogramdd(sample_2,
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=weights_2)
+
+ # comparing "hits"
+ hits_cmp = np.array_equal(result_c_2[0],
+ result_np_1[0] +
+ result_np_2[0])
+ # comparing weights
+ weights_cmp = np.array_equal(result_c_2[1],
+ result_np_w_2[0])
+
+ self.assertTrue(hits_cmp, msg=self.state_msg)
+ self.assertTrue(weights_cmp, msg=self.state_msg)
+
+ def test_reuse_cumul(self):
+ """
+
+ """
+ result_c = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True)
+
+ np.histogramdd(self.sample,
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ result_np_w = np.histogramdd(self.sample,
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights)
+
+ sample_2, weights_2 = self.generate_data()
+
+ result_c_2 = histogramnd(sample_2,
+ self.histo_range,
+ self.n_bins,
+ weights=weights_2,
+ last_bin_closed=True,
+ weighted_histo=result_c[1])
+
+ result_np_2 = np.histogramdd(sample_2,
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ result_np_w_2 = np.histogramdd(sample_2,
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=weights_2)
+
+ # comparing "hits"
+ hits_cmp = np.array_equal(result_c_2[0],
+ result_np_2[0])
+ # comparing weights
+
+ self.assertTrue(hits_cmp, msg=self.state_msg)
+ self.assertTrue(self.array_compare(result_c_2[1],
+ result_np_w[0] + result_np_w_2[0]),
+ msg=self.state_msg)
+
+ def test_reuse_cumul_float(self):
+ """
+
+ """
+ n_bins = np.array(self.n_bins, ndmin=1)
+ if len(self.sample.shape) == 2:
+ if len(n_bins) == self.sample.shape[1]:
+ shp = tuple([x for x in n_bins])
+ else:
+ shp = (self.n_bins,) * self.sample.shape[1]
+ cumul = np.zeros(shp, dtype=np.float32)
+ else:
+ shp = (self.n_bins,)
+ cumul = np.zeros(shp, dtype=np.float32)
+
+ result_c_1 = histogramnd(self.sample,
+ self.histo_range,
+ self.n_bins,
+ weights=self.weights,
+ last_bin_closed=True,
+ weighted_histo=cumul)
+
+ result_np_1 = np.histogramdd(self.sample,
+ bins=self.n_bins,
+ range=self.histo_range)
+
+ result_np_w_1 = np.histogramdd(self.sample,
+ bins=self.n_bins,
+ range=self.histo_range,
+ weights=self.weights)
+
+ # comparing "hits"
+ hits_cmp = np.array_equal(result_c_1[0],
+ result_np_1[0])
+
+ self.assertTrue(hits_cmp, msg=self.state_msg)
+ self.assertEqual(result_c_1[1].dtype, np.float32, msg=self.state_msg)
+
+ bins_min = [rng[0] for rng in self.histo_range]
+ bins_max = [rng[1] for rng in self.histo_range]
+ inrange_idx = _get_in_range_indices(self.sample,
+ bins_min,
+ bins_max,
+ minop=operator.ge,
+ maxop=operator.le)
+ weights_sum = \
+ self.weights[inrange_idx].astype(np.float32).sum(dtype=np.float64)
+ self.assertTrue(np.allclose(result_c_1[1].sum(dtype=np.float64),
+ weights_sum), msg=self.state_msg)
+ self.assertTrue(np.allclose(result_c_1[1].sum(dtype=np.float64),
+ result_np_w_1[0].sum(dtype=np.float64)),
+ msg=self.state_msg)
+
+
+class _TestHistogramnd_1d(_TestHistogramnd):
+ """
+ Unit tests of the 1D histogramnd function.
+ """
+ sample_rng = [-55., 100.]
+ weights_rng = [-70., 150.]
+ n_dims = 1
+ filter_min = -15.6
+ filter_max = 85.7
+
+ histo_range = [[-30.2, 90.3]]
+ n_bins = 30
+
+ dtype = None
+
+
+class _TestHistogramnd_2d(_TestHistogramnd):
+ """
+ Unit tests of the 1D histogramnd function.
+ """
+ sample_rng = [-50.2, 100.99]
+ weights_rng = [70., 150.]
+ n_dims = 2
+ filter_min = 81.7
+ filter_max = 135.3
+
+ histo_range = [[10., 90.], [20., 70.]]
+ n_bins = 30
+
+ dtype = None
+
+
+class _TestHistogramnd_3d(_TestHistogramnd):
+ """
+ Unit tests of the 1D histogramnd function.
+ """
+ sample_rng = [10.2, 200.9]
+ weights_rng = [0., 100.]
+ n_dims = 3
+ filter_min = 31.5
+ filter_max = 83.7
+
+ histo_range = [[30.8, 150.2], [20.1, 90.9], [10.1, 195.]]
+ n_bins = 30
+
+ dtype = None
+
+
+# ################################################################
+# ################################################################
+# ################################################################
+# ################################################################
+
+
+class TestHistogramnd_1d_double_double(_TestHistogramnd_1d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.double
+ dtype_weights = np.double
+
+
+class TestHistogramnd_1d_double_float(_TestHistogramnd_1d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.double
+ dtype_weights = np.float32
+
+
+class TestHistogramnd_1d_double_int32(_TestHistogramnd_1d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.double
+ dtype_weights = np.int32
+
+
+class TestHistogramnd_1d_float_double(_TestHistogramnd_1d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.float32
+ dtype_weights = np.double
+
+
+class TestHistogramnd_1d_float_float(_TestHistogramnd_1d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.float32
+ dtype_weights = np.float32
+
+
+class TestHistogramnd_1d_float_int32(_TestHistogramnd_1d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.float32
+ dtype_weights = np.int32
+
+
+class TestHistogramnd_1d_int32_double(_TestHistogramnd_1d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.int32
+ dtype_weights = np.double
+
+
+class TestHistogramnd_1d_int32_float(_TestHistogramnd_1d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.int32
+ dtype_weights = np.float32
+
+
+class TestHistogramnd_1d_int32_int32(_TestHistogramnd_1d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.int32
+ dtype_weights = np.int32
+
+
+class TestHistogramnd_2d_double_double(_TestHistogramnd_2d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.double
+ dtype_weights = np.double
+
+
+class TestHistogramnd_2d_double_float(_TestHistogramnd_2d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.double
+ dtype_weights = np.float32
+
+
+class TestHistogramnd_2d_double_int32(_TestHistogramnd_2d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.double
+ dtype_weights = np.int32
+
+
+class TestHistogramnd_2d_float_double(_TestHistogramnd_2d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.float32
+ dtype_weights = np.double
+
+
+class TestHistogramnd_2d_float_float(_TestHistogramnd_2d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.float32
+ dtype_weights = np.float32
+
+
+class TestHistogramnd_2d_float_int32(_TestHistogramnd_2d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.float32
+ dtype_weights = np.int32
+
+
+class TestHistogramnd_2d_int32_double(_TestHistogramnd_2d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.int32
+ dtype_weights = np.double
+
+
+class TestHistogramnd_2d_int32_float(_TestHistogramnd_2d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.int32
+ dtype_weights = np.float32
+
+
+class TestHistogramnd_2d_int32_int32(_TestHistogramnd_2d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.int32
+ dtype_weights = np.int32
+
+
+class TestHistogramnd_3d_double_double(_TestHistogramnd_3d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.double
+ dtype_weights = np.double
+
+
+class TestHistogramnd_3d_double_float(_TestHistogramnd_3d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.double
+ dtype_weights = np.float32
+
+
+class TestHistogramnd_3d_double_int32(_TestHistogramnd_3d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.double
+ dtype_weights = np.int32
+
+
+class TestHistogramnd_3d_float_double(_TestHistogramnd_3d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.float32
+ dtype_weights = np.double
+
+
+class TestHistogramnd_3d_float_float(_TestHistogramnd_3d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.float32
+ dtype_weights = np.float32
+
+
+class TestHistogramnd_3d_float_int32(_TestHistogramnd_3d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.float32
+ dtype_weights = np.int32
+
+
+class TestHistogramnd_3d_int32_double(_TestHistogramnd_3d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.int32
+ dtype_weights = np.double
+
+
+class TestHistogramnd_3d_int32_float(_TestHistogramnd_3d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.int32
+ dtype_weights = np.float32
+
+
+class TestHistogramnd_3d_int32_int32(_TestHistogramnd_3d):
+ __test__ = True # because _TestHistogramnd is ignored
+ dtype_sample = np.int32
+ dtype_weights = np.int32
diff --git a/src/silx/math/test/test_interpolate.py b/src/silx/math/test/test_interpolate.py
new file mode 100644
index 0000000..146449d
--- /dev/null
+++ b/src/silx/math/test/test_interpolate.py
@@ -0,0 +1,125 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Test for interpolate module"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "11/07/2019"
+
+
+import unittest
+
+import numpy
+try:
+ from scipy.interpolate import interpn
+except ImportError:
+ interpn = None
+
+from silx.utils.testutils import ParametricTestCase
+from silx.math import interpolate
+
+
+@unittest.skipUnless(interpn is not None, "scipy missing")
+class TestInterp3d(ParametricTestCase):
+ """Test silx.math.interpolate.interp3d"""
+
+ @staticmethod
+ def ref_interp3d(data, points):
+ """Reference implementation of interp3d based on scipy
+
+ :param numpy.ndarray data: 3D floating dataset
+ :param numpy.ndarray points: Array of points of shape (N, 3)
+ """
+ return interpn(
+ [numpy.arange(dim, dtype=data.dtype) for dim in data.shape],
+ data,
+ points,
+ method='linear')
+
+ def test_random_data(self):
+ """Test interp3d with random data"""
+ size = 32
+ npoints = 10
+
+ ref_data = numpy.random.random((size, size, size))
+ ref_points = numpy.random.random(npoints*3).reshape(npoints, 3) * (size -1)
+
+ for dtype in (numpy.float32, numpy.float64):
+ data = ref_data.astype(dtype)
+ points = ref_points.astype(dtype)
+ ref_result = self.ref_interp3d(data, points)
+
+ for method in (u'linear', u'linear_omp'):
+ with self.subTest(method=method):
+ result = interpolate.interp3d(data, points, method=method)
+ self.assertTrue(numpy.allclose(ref_result, result))
+
+ def test_notfinite_data(self):
+ """Test interp3d with NaN and inf"""
+ data = numpy.ones((3, 3, 3), dtype=numpy.float64)
+ data[0, 0, 0] = numpy.nan
+ data[2, 2, 2] = numpy.inf
+ points = numpy.array([(0.5, 0.5, 0.5),
+ (1.5, 1.5, 1.5)])
+
+ for method in (u'linear', u'linear_omp'):
+ with self.subTest(method=method):
+ result = interpolate.interp3d(
+ data, points, method=method)
+ self.assertTrue(numpy.isnan(result[0]))
+ self.assertTrue(result[1] == numpy.inf)
+
+ def test_points_outside(self):
+ """Test interp3d with points outside the volume"""
+ data = numpy.ones((4, 4, 4), dtype=numpy.float64)
+ points = numpy.array([(-0.1, -0.1, -0.1),
+ (3.1, 3.1, 3.1),
+ (-0.1, 1., 1.),
+ (1., 1., 3.1)])
+
+ for method in (u'linear', u'linear_omp'):
+ for fill_value in (numpy.nan, 0., -1.):
+ with self.subTest(method=method):
+ result = interpolate.interp3d(
+ data, points, method=method, fill_value=fill_value)
+ if numpy.isnan(fill_value):
+ self.assertTrue(numpy.all(numpy.isnan(result)))
+ else:
+ self.assertTrue(numpy.all(numpy.equal(result, fill_value)))
+
+ def test_integer_points(self):
+ """Test interp3d with integer points coord"""
+ data = numpy.arange(4**3, dtype=numpy.float64).reshape(4, 4, 4)
+ points = numpy.array([(0., 0., 0.),
+ (0., 0., 1.),
+ (2., 3., 0.),
+ (3., 3., 3.)])
+
+ ref_result = data[tuple(points.T.astype(numpy.int32))]
+
+ for method in (u'linear', u'linear_omp'):
+ with self.subTest(method=method):
+ result = interpolate.interp3d(data, points, method=method)
+ self.assertTrue(numpy.allclose(ref_result, result))
diff --git a/src/silx/math/test/test_marchingcubes.py b/src/silx/math/test/test_marchingcubes.py
new file mode 100644
index 0000000..5e2b193
--- /dev/null
+++ b/src/silx/math/test/test_marchingcubes.py
@@ -0,0 +1,174 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests of the marchingcubes module"""
+
+from __future__ import division
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+import unittest
+
+import numpy
+
+from silx.utils.testutils import ParametricTestCase
+
+from silx.math import marchingcubes
+
+
+class TestMarchingCubes(ParametricTestCase):
+ """Tests of marching cubes"""
+
+ def assertAllClose(self, array1, array2, msg=None,
+ rtol=1e-05, atol=1e-08):
+ """Assert that the 2 numpy.ndarrays are almost equal.
+
+ :param str msg: Message to provide when assert fails
+ :param float rtol: Relative tolerance, see :func:`numpy.allclose`
+ :param float atol: Absolute tolerance, see :func:`numpy.allclose`
+ """
+ if not numpy.allclose(array1, array2, rtol, atol):
+ raise self.failureException(msg)
+
+ def test_cube(self):
+ """Unit tests with a single cube"""
+
+ # No isosurface
+ cube_zero = numpy.zeros((2, 2, 2), dtype=numpy.float32)
+
+ result = marchingcubes.MarchingCubes(cube_zero, 1.)
+ self.assertEqual(result.shape, cube_zero.shape)
+ self.assertEqual(result.isolevel, 1.)
+ self.assertEqual(result.invert_normals, True)
+
+ vertices, normals, indices = result
+ self.assertEqual(len(vertices), 0)
+ self.assertEqual(len(normals), 0)
+ self.assertEqual(len(indices), 0)
+
+ # Cube array dimensions: shape = (dim 0, dim 1, dim2)
+ #
+ # dim 0 (Z)
+ # ^
+ # |
+ # 4 +------+ 5
+ # /| /|
+ # / | / |
+ # 6 +------+ 7|
+ # | | | |
+ # |0 +---|--+ 1 -> dim 2 (X)
+ # | / | /
+ # |/ |/
+ # 2 +------+ 3
+ # /
+ # dim 1 (Y)
+
+ # isosurface perpendicular to dim 0 (Z)
+ cube = numpy.array(
+ (((0., 0.), (0., 0.)),
+ ((1., 1.), (1., 1.))), dtype=numpy.float32)
+ level = 0.5
+ vertices, normals, indices = marchingcubes.MarchingCubes(
+ cube, level, invert_normals=False)
+ self.assertAllClose(vertices[:, 0], level)
+ self.assertAllClose(normals, (1., 0., 0.))
+ self.assertEqual(len(indices), 2)
+
+ # isosurface perpendicular to dim 1 (Y)
+ cube = numpy.array(
+ (((0., 0.), (1., 1.)),
+ ((0., 0.), (1., 1.))), dtype=numpy.float32)
+ level = 0.2
+ vertices, normals, indices = marchingcubes.MarchingCubes(cube, level)
+ self.assertAllClose(vertices[:, 1], level)
+ self.assertAllClose(normals, (0., -1., 0.))
+ self.assertEqual(len(indices), 2)
+
+ # isosurface perpendicular to dim 2 (X)
+ cube = numpy.array(
+ (((0., 1.), (0., 1.)),
+ ((0., 1.), (0., 1.))), dtype=numpy.float32)
+ level = 0.9
+ vertices, normals, indices = marchingcubes.MarchingCubes(
+ cube, level, invert_normals=False)
+ self.assertAllClose(vertices[:, 2], level)
+ self.assertAllClose(normals, (0., 0., 1.))
+ self.assertEqual(len(indices), 2)
+
+ # isosurface normal in dim1, dim 0 (Y, Z) plane
+ cube = numpy.array(
+ (((0., 0.), (0., 0.)),
+ ((0., 0.), (1., 1.))), dtype=numpy.float32)
+ level = 0.5
+ vertices, normals, indices = marchingcubes.MarchingCubes(cube, level)
+ self.assertAllClose(normals[:, 2], 0.)
+ self.assertEqual(len(indices), 2)
+
+ def test_sampling(self):
+ """Test different sampling, comparing to reference without sampling"""
+ isolevel = 0.5
+ size = 9
+ chessboard = numpy.zeros((size, size, size), dtype=numpy.float32)
+ chessboard.reshape(-1)[::2] = 1 # OK as long as dimensions are odd
+
+ ref_result = marchingcubes.MarchingCubes(chessboard, isolevel)
+
+ samplings = [
+ (2, 1, 1),
+ (1, 2, 1),
+ (1, 1, 2),
+ (2, 2, 2),
+ (3, 3, 3),
+ (1, 3, 1),
+ (1, 1, 3),
+ ]
+
+ for sampling in samplings:
+ with self.subTest(sampling=sampling):
+ sampling = numpy.array(sampling)
+
+ data = 1e6 * numpy.ones(
+ sampling * size, dtype=numpy.float32)
+ # Copy ref chessboard in data according to sampling
+ data[::sampling[0], ::sampling[1], ::sampling[2]] = chessboard
+
+ result = marchingcubes.MarchingCubes(data, isolevel,
+ sampling=sampling)
+ # Compare vertices normalized with shape
+ self.assertAllClose(
+ ref_result.get_vertices() / ref_result.shape,
+ result.get_vertices() / result.shape,
+ atol=0., rtol=0.)
+
+ # Compare normals
+ # This comparison only works for normals aligned with axes
+ # otherwise non uniform sampling would make different normals
+ self.assertAllClose(ref_result.get_normals(),
+ result.get_normals(),
+ atol=0., rtol=0.)
+
+ self.assertAllClose(ref_result.get_indices(),
+ result.get_indices(),
+ atol=0., rtol=0.)
diff --git a/silx/opencl/__init__.py b/src/silx/opencl/__init__.py
index fbd1f88..fbd1f88 100644
--- a/silx/opencl/__init__.py
+++ b/src/silx/opencl/__init__.py
diff --git a/silx/opencl/backprojection.py b/src/silx/opencl/backprojection.py
index 65a9836..65a9836 100644
--- a/silx/opencl/backprojection.py
+++ b/src/silx/opencl/backprojection.py
diff --git a/silx/opencl/codec/__init__.py b/src/silx/opencl/codec/__init__.py
index e69de29..e69de29 100644
--- a/silx/opencl/codec/__init__.py
+++ b/src/silx/opencl/codec/__init__.py
diff --git a/silx/opencl/codec/byte_offset.py b/src/silx/opencl/codec/byte_offset.py
index 9a52427..9a52427 100644
--- a/silx/opencl/codec/byte_offset.py
+++ b/src/silx/opencl/codec/byte_offset.py
diff --git a/silx/opencl/codec/setup.py b/src/silx/opencl/codec/setup.py
index 4a5c1e5..4a5c1e5 100644
--- a/silx/opencl/codec/setup.py
+++ b/src/silx/opencl/codec/setup.py
diff --git a/src/silx/opencl/codec/test/__init__.py b/src/silx/opencl/codec/test/__init__.py
new file mode 100644
index 0000000..325c2c7
--- /dev/null
+++ b/src/silx/opencl/codec/test/__init__.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+#
+# Project: silx
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2013-2017 European Synchrotron Radiation Facility, Grenoble, France
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
diff --git a/src/silx/opencl/codec/test/test_byte_offset.py b/src/silx/opencl/codec/test/test_byte_offset.py
new file mode 100644
index 0000000..4b2d5a3
--- /dev/null
+++ b/src/silx/opencl/codec/test/test_byte_offset.py
@@ -0,0 +1,303 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Project: Byte-offset decompression in OpenCL
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2013-2020 European Synchrotron Radiation Facility,
+# Grenoble, France
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following
+# conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+
+"""
+Test suite for byte-offset decompression
+"""
+
+from __future__ import division, print_function
+
+__authors__ = ["Jérôme Kieffer"]
+__contact__ = "jerome.kieffer@esrf.eu"
+__license__ = "MIT"
+__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "02/03/2021"
+
+import sys
+import time
+import logging
+import numpy
+from silx.opencl.common import ocl, pyopencl
+from silx.opencl.codec import byte_offset
+import fabio
+import unittest
+logger = logging.getLogger(__name__)
+
+
+@unittest.skipUnless(ocl and pyopencl,
+ "PyOpenCl is missing")
+class TestByteOffset(unittest.TestCase):
+
+ @staticmethod
+ def _create_test_data(shape, nexcept, lam=200):
+ """Create test (image, compressed stream) pair.
+
+ :param shape: Shape of test image
+ :param int nexcept: Number of exceptions in the image
+ :param lam: Expectation of interval argument for numpy.random.poisson
+ :return: (reference image array, compressed stream)
+ """
+ size = numpy.prod(shape)
+ ref = numpy.random.poisson(lam, numpy.prod(shape))
+ exception_loc = numpy.random.randint(0, size, size=nexcept)
+ exception_value = numpy.random.randint(0, 1000000, size=nexcept)
+ ref[exception_loc] = exception_value
+ ref.shape = shape
+
+ raw = fabio.compression.compByteOffset(ref)
+ return ref, raw
+
+ def test_decompress(self):
+ """
+ tests the byte offset decompression on GPU
+ """
+ ref, raw = self._create_test_data(shape=(91, 97), nexcept=229)
+ # ref, raw = self._create_test_data(shape=(7, 9), nexcept=0)
+
+ size = numpy.prod(ref.shape)
+
+ try:
+ bo = byte_offset.ByteOffset(raw_size=len(raw), dec_size=size, profile=True)
+ except (RuntimeError, pyopencl.RuntimeError) as err:
+ logger.warning(err)
+ if sys.platform == "darwin":
+ raise unittest.SkipTest("Byte-offset decompression is known to be buggy on MacOS-CPU")
+ else:
+ raise err
+ print(bo.block_size)
+
+ t0 = time.time()
+ res_cy = fabio.compression.decByteOffset(raw)
+ t1 = time.time()
+ res_cl = bo.decode(raw)
+ t2 = time.time()
+ delta_cy = abs(ref.ravel() - res_cy).max()
+ delta_cl = abs(ref.ravel() - res_cl.get()).max()
+
+ logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1))
+ bo.log_profile()
+ # print(ref)
+ # print(res_cl.get())
+ self.assertEqual(delta_cy, 0, "Checks fabio works")
+ self.assertEqual(delta_cl, 0, "Checks opencl works")
+
+ def test_many_decompress(self, ntest=10):
+ """
+ tests the byte offset decompression on GPU, many images to ensure there
+ is not leaking in memory
+ """
+ shape = (991, 997)
+ size = numpy.prod(shape)
+ ref, raw = self._create_test_data(shape=shape, nexcept=0, lam=100)
+
+ try:
+ bo = byte_offset.ByteOffset(len(raw), size, profile=True)
+ except (RuntimeError, pyopencl.RuntimeError) as err:
+ logger.warning(err)
+ if sys.platform == "darwin":
+ raise unittest.SkipTest("Byte-offset decompression is known to be buggy on MacOS-CPU")
+ else:
+ raise err
+ t0 = time.time()
+ res_cy = fabio.compression.decByteOffset(raw)
+ t1 = time.time()
+ res_cl = bo(raw)
+ t2 = time.time()
+ delta_cy = abs(ref.ravel() - res_cy).max()
+ delta_cl = abs(ref.ravel() - res_cl.get()).max()
+ self.assertEqual(delta_cy, 0, "Checks fabio works")
+ self.assertEqual(delta_cl, 0, "Checks opencl works")
+ logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1))
+
+ for i in range(ntest):
+ ref, raw = self._create_test_data(shape=shape, nexcept=2729, lam=200)
+
+ t0 = time.time()
+ res_cy = fabio.compression.decByteOffset(raw)
+ t1 = time.time()
+ res_cl = bo(raw)
+ t2 = time.time()
+ delta_cy = abs(ref.ravel() - res_cy).max()
+ delta_cl = abs(ref.ravel() - res_cl.get()).max()
+ self.assertEqual(delta_cy, 0, "Checks fabio works #%i" % i)
+ self.assertEqual(delta_cl, 0, "Checks opencl works #%i" % i)
+
+ logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1))
+ bo.log_profile(stats=True)
+
+ def test_encode(self):
+ """Test byte offset compression"""
+ ref, raw = self._create_test_data(shape=(2713, 2719), nexcept=2729)
+
+ try:
+ bo = byte_offset.ByteOffset(len(raw), ref.size, profile=True)
+ except (RuntimeError, pyopencl.RuntimeError) as err:
+ logger.warning(err)
+ raise err
+
+ t0 = time.time()
+ compressed_array = bo.encode(ref)
+ t1 = time.time()
+
+ compressed_stream = compressed_array.get().tobytes()
+ self.assertEqual(raw, compressed_stream)
+
+ logger.debug("Global execution time: OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0))
+ bo.log_profile()
+
+ def test_encode_to_array(self):
+ """Test byte offset compression while providing an out array"""
+
+ ref, raw = self._create_test_data(shape=(2713, 2719), nexcept=2729)
+
+ try:
+ bo = byte_offset.ByteOffset(profile=True)
+ except (RuntimeError, pyopencl.RuntimeError) as err:
+ logger.warning(err)
+ raise err
+ # Test with out buffer too small
+ out = pyopencl.array.empty(bo.queue, (10,), numpy.int8)
+ with self.assertRaises(ValueError):
+ bo.encode(ref, out)
+
+ # Test with out buffer too big
+ out = pyopencl.array.empty(bo.queue, (len(raw) + 10,), numpy.int8)
+
+ compressed_array = bo.encode(ref, out)
+
+ # Get size from returned array
+ compressed_size = compressed_array.size
+ self.assertEqual(compressed_size, len(raw))
+
+ # Get data from out array, read it from bo object queue
+ out_bo_queue = out.with_queue(bo.queue)
+ compressed_stream = out_bo_queue.get().tobytes()[:compressed_size]
+ self.assertEqual(raw, compressed_stream)
+
+ def test_encode_to_bytes(self):
+ """Test byte offset compression to bytes"""
+ ref, raw = self._create_test_data(shape=(2713, 2719), nexcept=2729)
+
+ try:
+ bo = byte_offset.ByteOffset(profile=True)
+ except (RuntimeError, pyopencl.RuntimeError) as err:
+ logger.warning(err)
+ raise err
+
+ t0 = time.time()
+ res_fabio = fabio.compression.compByteOffset(ref)
+ t1 = time.time()
+ compressed_stream = bo.encode_to_bytes(ref)
+ t2 = time.time()
+
+ self.assertEqual(raw, compressed_stream)
+
+ logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1))
+ bo.log_profile()
+
+ def test_encode_to_bytes_from_array(self):
+ """Test byte offset compression to bytes from a pyopencl array.
+ """
+ ref, raw = self._create_test_data(shape=(2713, 2719), nexcept=2729)
+
+ try:
+ bo = byte_offset.ByteOffset(profile=True)
+ except (RuntimeError, pyopencl.RuntimeError) as err:
+ logger.warning(err)
+ raise err
+
+ d_ref = pyopencl.array.to_device(
+ bo.queue, ref.astype(numpy.int32).ravel())
+
+ t0 = time.time()
+ res_fabio = fabio.compression.compByteOffset(ref)
+ t1 = time.time()
+ compressed_stream = bo.encode_to_bytes(d_ref)
+ t2 = time.time()
+
+ self.assertEqual(raw, compressed_stream)
+
+ logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1))
+ bo.log_profile()
+
+ def test_many_encode(self, ntest=10):
+ """Test byte offset compression with many image"""
+ shape = (991, 997)
+ ref, raw = self._create_test_data(shape=shape, nexcept=0, lam=100)
+
+ try:
+ bo = byte_offset.ByteOffset(profile=False)
+ except (RuntimeError, pyopencl.RuntimeError) as err:
+ logger.warning(err)
+ raise err
+
+ bo_durations = []
+
+ t0 = time.time()
+ res_fabio = fabio.compression.compByteOffset(ref)
+ t1 = time.time()
+ compressed_stream = bo.encode_to_bytes(ref)
+ t2 = time.time()
+ bo_durations.append(1000.0 * (t2 - t1))
+
+ self.assertEqual(raw, compressed_stream)
+ logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1))
+
+ for i in range(ntest):
+ ref, raw = self._create_test_data(shape=shape, nexcept=2729, lam=200)
+
+ t0 = time.time()
+ res_fabio = fabio.compression.compByteOffset(ref)
+ t1 = time.time()
+ compressed_stream = bo.encode_to_bytes(ref)
+ t2 = time.time()
+ bo_durations.append(1000.0 * (t2 - t1))
+
+ self.assertEqual(raw, compressed_stream)
+ logger.debug("Global execution time: fabio %.3fms, OpenCL: %.3fms.",
+ 1000.0 * (t1 - t0),
+ 1000.0 * (t2 - t1))
+
+ logger.debug("OpenCL execution time: Mean: %fms, Min: %fms, Max: %fms",
+ numpy.mean(bo_durations),
+ numpy.min(bo_durations),
+ numpy.max(bo_durations))
diff --git a/src/silx/opencl/common.py b/src/silx/opencl/common.py
new file mode 100644
index 0000000..60849d6
--- /dev/null
+++ b/src/silx/opencl/common.py
@@ -0,0 +1,692 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Project: S I L X project
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2012-2021 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following
+# conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+#
+
+__author__ = "Jerome Kieffer"
+__contact__ = "Jerome.Kieffer@ESRF.eu"
+__license__ = "MIT"
+__copyright__ = "2012-2017 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "29/09/2021"
+__status__ = "stable"
+__all__ = ["ocl", "pyopencl", "mf", "release_cl_buffers", "allocate_cl_buffers",
+ "measure_workgroup_size", "kernel_workgroup_size"]
+
+import os
+import logging
+
+import numpy
+
+from .utils import get_opencl_code
+
+logger = logging.getLogger(__name__)
+
+if os.environ.get("SILX_OPENCL") in ["0", "False"]:
+ logger.info("Use of OpenCL has been disabled from environment variable: SILX_OPENCL=0")
+ pyopencl = None
+else:
+ try:
+ import pyopencl
+ except ImportError:
+ logger.warning("Unable to import pyOpenCl. Please install it from: https://pypi.org/project/pyopencl")
+ pyopencl = None
+ else:
+ try:
+ pyopencl.get_platforms()
+ except pyopencl.LogicError:
+ logger.warning("The module pyOpenCL has been imported but can't be used here")
+ pyopencl = None
+ else:
+ import pyopencl.array as array
+ mf = pyopencl.mem_flags
+
+if pyopencl is None:
+
+ # Define default mem flags
+ class mf(object):
+ WRITE_ONLY = 1
+ READ_ONLY = 1
+ READ_WRITE = 1
+
+FLOP_PER_CORE = {"GPU": 64, # GPU, Fermi at least perform 64 flops per cycle/multicore, G80 were at 24 or 48 ...
+ "CPU": 4, # CPU, at least intel's have 4 operation per cycle
+ "ACC": 8} # ACC: the Xeon-phi (MIC) appears to be able to process 8 Flops per hyperthreaded-core
+
+# Sources : https://en.wikipedia.org/wiki/CUDA
+NVIDIA_FLOP_PER_CORE = {(1, 0): 24, # Guessed !
+ (1, 1): 24, # Measured on G98 [Quadro NVS 295]
+ (1, 2): 24, # Guessed !
+ (1, 3): 24, # measured on a GT285 (GT200)
+ (2, 0): 64, # Measured on a 580 (GF110)
+ (2, 1): 96, # Measured on Quadro2000 GF106GL
+ (3, 0): 384, # Guessed!
+ (3, 5): 384, # Measured on K20
+ (3, 7): 384, # K80: Guessed!
+ (5, 0): 256, # Maxwell 4 warps/SM 2 flops/ CU
+ (5, 2): 256, # Titan-X
+ (5, 3): 256, # TX1
+ (6, 0): 128, # GP100
+ (6, 1): 128, # GP104
+ (6, 2): 128, # ?
+ (7, 0): 128, # Volta # measured on Telsa V100
+ (7, 2): 128, # Volta ?
+ (7, 5): 128, # Turing # measured on RTX 6000
+ (8, 0): 128, # Ampere # measured on Tesla A100
+ (8, 6): 256, # Ampere # measured on RTX A5000
+ }
+
+AMD_FLOP_PER_CORE = 160 # Measured on a M7820 10 core, 700MHz 1120GFlops
+
+
+class Device(object):
+ """
+ Simple class that contains the structure of an OpenCL device
+ """
+
+ def __init__(self, name="None", dtype=None, version=None, driver_version=None,
+ extensions="", memory=None, available=None,
+ cores=None, frequency=None, flop_core=None, idx=0, workgroup=1):
+ """
+ Simple container with some important data for the OpenCL device description.
+
+ :param name: name of the device
+ :param dtype: device type: CPU/GPU/ACC...
+ :param version: driver version
+ :param driver_version:
+ :param extensions: List of opencl extensions
+ :param memory: maximum memory available on the device
+ :param available: is the device deactivated or not
+ :param cores: number of SM/cores
+ :param frequency: frequency of the device
+ :param flop_core: Flopating Point operation per core per cycle
+ :param idx: index of the device within the platform
+ :param workgroup: max workgroup size
+ """
+ self.name = name.strip()
+ self.type = dtype
+ self.version = version
+ self.driver_version = driver_version
+ self.extensions = extensions.split()
+ self.memory = memory
+ self.available = available
+ self.cores = cores
+ self.frequency = frequency
+ self.id = idx
+ self.max_work_group_size = workgroup
+ if not flop_core:
+ flop_core = FLOP_PER_CORE.get(dtype, 1)
+ if cores and frequency:
+ self.flops = cores * frequency * flop_core
+ else:
+ self.flops = flop_core
+
+ def __repr__(self):
+ return "%s" % self.name
+
+ def pretty_print(self):
+ """
+ Complete device description
+
+ :return: string
+ """
+ lst = ["Name\t\t:\t%s" % self.name,
+ "Type\t\t:\t%s" % self.type,
+ "Memory\t\t:\t%.3f MB" % (self.memory / 2.0 ** 20),
+ "Cores\t\t:\t%s CU" % self.cores,
+ "Frequency\t:\t%s MHz" % self.frequency,
+ "Speed\t\t:\t%.3f GFLOPS" % (self.flops / 1000.),
+ "Version\t\t:\t%s" % self.version,
+ "Available\t:\t%s" % self.available]
+ return os.linesep.join(lst)
+
+ def set_unavailable(self):
+ """Use this method to flag a faulty device
+ """
+ self.available = False
+
+
+class Platform(object):
+ """
+ Simple class that contains the structure of an OpenCL platform
+ """
+
+ def __init__(self, name="None", vendor="None", version=None, extensions=None, idx=0):
+ """
+ Class containing all descriptions of a platform and all devices description within that platform.
+
+ :param name: platform name
+ :param vendor: name of the brand/vendor
+ :param version:
+ :param extensions: list of the extension provided by the platform to all of its devices
+ :param idx: index of the platform
+ """
+ self.name = name.strip()
+ self.vendor = vendor.strip()
+ self.version = version
+ self.extensions = extensions.split()
+ self.devices = []
+ self.id = idx
+
+ def __repr__(self):
+ return "%s" % self.name
+
+ def add_device(self, device):
+ """
+ Add new device to the platform
+
+ :param device: Device instance
+ """
+ self.devices.append(device)
+
+ def get_device(self, key):
+ """
+ Return a device according to key
+
+ :param key: identifier for a device, either it's id (int) or it's name
+ :type key: int or str
+ """
+ out = None
+ try:
+ devid = int(key)
+ except ValueError:
+ for a_dev in self.devices:
+ if a_dev.name == key:
+ out = a_dev
+ else:
+ if len(self.devices) > devid > 0:
+ out = self.devices[devid]
+ return out
+
+
+def _measure_workgroup_size(device_or_context, fast=False):
+ """Mesure the maximal work group size of the given device
+
+ DEPRECATED since not perfectly correct !
+
+ :param device_or_context: instance of pyopencl.Device or pyopencl.Context
+ or 2-tuple (platformid,deviceid)
+ :param fast: ask the kernel the valid value, don't probe it
+ :return: maximum size for the workgroup
+ """
+ if isinstance(device_or_context, pyopencl.Device):
+ try:
+ ctx = pyopencl.Context(devices=[device_or_context])
+ except pyopencl._cl.LogicError as error:
+ platform = device_or_context.platform
+ platformid = pyopencl.get_platforms().index(platform)
+ deviceid = platform.get_devices().index(device_or_context)
+ ocl.platforms[platformid].devices[deviceid].set_unavailable()
+ raise RuntimeError("Unable to create context on %s/%s: %s" % (platform, device_or_context, error))
+ else:
+ device = device_or_context
+ elif isinstance(device_or_context, pyopencl.Context):
+ ctx = device_or_context
+ device = device_or_context.devices[0]
+ elif isinstance(device_or_context, (tuple, list)) and len(device_or_context) == 2:
+ ctx = ocl.create_context(platformid=device_or_context[0],
+ deviceid=device_or_context[1])
+ device = ctx.devices[0]
+ else:
+ raise RuntimeError("""given parameter device_or_context is not an
+ instanciation of a device or a context""")
+ shape = device.max_work_group_size
+ # get the context
+
+ assert ctx is not None
+ queue = pyopencl.CommandQueue(ctx)
+
+ max_valid_wg = 1
+ data = numpy.random.random(shape).astype(numpy.float32)
+ d_data = pyopencl.array.to_device(queue, data)
+ d_data_1 = pyopencl.array.empty_like(d_data)
+ d_data_1.fill(numpy.float32(1.0))
+
+ program = pyopencl.Program(ctx, get_opencl_code("addition")).build()
+ if fast:
+ max_valid_wg = program.addition.get_work_group_info(pyopencl.kernel_work_group_info.WORK_GROUP_SIZE, device)
+ else:
+ maxi = int(round(numpy.log2(shape)))
+ for i in range(maxi + 1):
+ d_res = pyopencl.array.empty_like(d_data)
+ wg = 1 << i
+ try:
+ evt = program.addition(
+ queue, (shape,), (wg,),
+ d_data.data, d_data_1.data, d_res.data, numpy.int32(shape))
+ evt.wait()
+ except Exception as error:
+ logger.info("%s on device %s for WG=%s/%s", error, device.name, wg, shape)
+ program = queue = d_res = d_data_1 = d_data = None
+ break
+ else:
+ res = d_res.get()
+ good = numpy.allclose(res, data + 1)
+ if good:
+ if wg > max_valid_wg:
+ max_valid_wg = wg
+ else:
+ logger.warning("ArithmeticError on %s for WG=%s/%s", wg, device.name, shape)
+
+ return max_valid_wg
+
+
+def _is_nvidia_gpu(vendor, devtype):
+ return (vendor == "NVIDIA Corporation") and (devtype == "GPU")
+
+
+class OpenCL(object):
+ """
+ Simple class that wraps the structure ocl_tools_extended.h
+
+ This is a static class.
+ ocl should be the only instance and shared among all python modules.
+ """
+
+ platforms = []
+ nb_devices = 0
+ context_cache = {} # key: 2-tuple of int, value: context
+ if pyopencl:
+ platform = device = pypl = devtype = extensions = pydev = None
+ for idx, platform in enumerate(pyopencl.get_platforms()):
+ pypl = Platform(platform.name, platform.vendor, platform.version, platform.extensions, idx)
+ for idd, device in enumerate(platform.get_devices()):
+ ####################################################
+ # Nvidia does not report int64 atomics (we are using) ...
+ # this is a hack around as any nvidia GPU with double-precision supports int64 atomics
+ ####################################################
+ extensions = device.extensions
+ if (pypl.vendor == "NVIDIA Corporation") and ('cl_khr_fp64' in extensions):
+ extensions += ' cl_khr_int64_base_atomics cl_khr_int64_extended_atomics'
+ try:
+ devtype = pyopencl.device_type.to_string(device.type).upper()
+ except ValueError:
+ # pocl does not describe itself as a CPU !
+ devtype = "CPU"
+ if len(devtype) > 3:
+ if "GPU" in devtype:
+ devtype = "GPU"
+ elif "ACC" in devtype:
+ devtype = "ACC"
+ elif "CPU" in devtype:
+ devtype = "CPU"
+ else:
+ devtype = devtype[:3]
+ if _is_nvidia_gpu(device.vendor, devtype) and ("compute_capability_major_nv" in dir(device)):
+ try:
+ comput_cap = device.compute_capability_major_nv, device.compute_capability_minor_nv
+ except pyopencl.LogicError:
+ flop_core = FLOP_PER_CORE["GPU"]
+ else:
+ flop_core = NVIDIA_FLOP_PER_CORE.get(comput_cap, FLOP_PER_CORE["GPU"])
+ elif (pypl.vendor == "Advanced Micro Devices, Inc.") and (devtype == "GPU"):
+ flop_core = AMD_FLOP_PER_CORE
+ elif devtype == "CPU":
+ flop_core = FLOP_PER_CORE.get(devtype, 1)
+ else:
+ flop_core = 1
+ workgroup = device.max_work_group_size
+ if (devtype == "CPU") and (pypl.vendor == "Apple"):
+ logger.info("For Apple's OpenCL on CPU: Measuring actual valid max_work_goup_size.")
+ workgroup = _measure_workgroup_size(device, fast=True)
+ if (devtype == "GPU") and os.environ.get("GPU") == "False":
+ # Environment variable to disable GPU devices
+ continue
+ pydev = Device(device.name, devtype, device.version, device.driver_version, extensions,
+ device.global_mem_size, bool(device.available), device.max_compute_units,
+ device.max_clock_frequency, flop_core, idd, workgroup)
+ pypl.add_device(pydev)
+ nb_devices += 1
+ platforms.append(pypl)
+ del platform, device, pypl, devtype, extensions, pydev
+
+ def __repr__(self):
+ out = ["OpenCL devices:"]
+ for platformid, platform in enumerate(self.platforms):
+ deviceids = ["(%s,%s) %s" % (platformid, deviceid, dev.name)
+ for deviceid, dev in enumerate(platform.devices)]
+ out.append("[%s] %s: " % (platformid, platform.name) + ", ".join(deviceids))
+ return os.linesep.join(out)
+
+ def get_platform(self, key):
+ """
+ Return a platform according
+
+ :param key: identifier for a platform, either an Id (int) or it's name
+ :type key: int or str
+ """
+ out = None
+ try:
+ platid = int(key)
+ except ValueError:
+ for a_plat in self.platforms:
+ if a_plat.name == key:
+ out = a_plat
+ else:
+ if len(self.platforms) > platid > 0:
+ out = self.platforms[platid]
+ return out
+
+ def select_device(self, dtype="ALL", memory=None, extensions=None, best=True, **kwargs):
+ """
+ Select a device based on few parameters (at the end, keep the one with most memory)
+
+ :param dtype: "gpu" or "cpu" or "all" ....
+ :param memory: minimum amount of memory (int)
+ :param extensions: list of extensions to be present
+ :param best: shall we look for the
+ :returns: A tuple of plateform ID and device ID, else None if nothing
+ found
+ """
+ if extensions is None:
+ extensions = []
+ if "type" in kwargs:
+ dtype = kwargs["type"].upper()
+ else:
+ dtype = dtype.upper()
+ if len(dtype) > 3:
+ dtype = dtype[:3]
+ best_found = None
+ for platformid, platform in enumerate(self.platforms):
+ for deviceid, device in enumerate(platform.devices):
+ if not device.available:
+ continue
+ if (dtype in ["ALL", "DEF"]) or (device.type == dtype):
+ if (memory is None) or (memory <= device.memory):
+ found = True
+ for ext in extensions:
+ if ext not in device.extensions:
+ found = False
+ if found:
+ if not best:
+ return platformid, deviceid
+ else:
+ if not best_found:
+ best_found = platformid, deviceid, device.flops
+ elif best_found[2] < device.flops:
+ best_found = platformid, deviceid, device.flops
+ if best_found:
+ return best_found[0], best_found[1]
+
+ # Nothing found
+ return None
+
+ def create_context(self, devicetype="ALL", useFp64=False, platformid=None,
+ deviceid=None, cached=True, memory=None, extensions=None):
+ """
+ Choose a device and initiate a context.
+
+ Devicetypes can be GPU,gpu,CPU,cpu,DEF,ACC,ALL.
+ Suggested are GPU,CPU.
+ For each setting to work there must be such an OpenCL device and properly installed.
+ E.g.: If Nvidia driver is installed, GPU will succeed but CPU will fail.
+ The AMD SDK kit is required for CPU via OpenCL.
+ :param devicetype: string in ["cpu","gpu", "all", "acc"]
+ :param useFp64: boolean specifying if double precision will be used: deprecated use extensions=["cl_khr_fp64"]
+ :param platformid: integer
+ :param deviceid: integer
+ :param cached: True if we want to cache the context
+ :param memory: minimum amount of memory of the device
+ :param extensions: list of extensions to be present
+ :return: OpenCL context on the selected device
+ """
+ if extensions is None:
+ extensions = []
+ if useFp64:
+ logger.warning("Deprecation: please select your device using the extension name!, i.e. extensions=['cl_khr_fp64']")
+ extensions.append('cl_khr_fp64')
+
+ if (platformid is not None) and (deviceid is not None):
+ platformid = int(platformid)
+ deviceid = int(deviceid)
+ elif "PYOPENCL_CTX" in os.environ:
+ pyopencl_ctx = [int(i) if i.isdigit() else 0 for i in os.environ["PYOPENCL_CTX"].split(":")]
+ pyopencl_ctx += [0] * (2 - len(pyopencl_ctx)) # pad with 0
+ platformid, deviceid = pyopencl_ctx
+ else:
+ ids = ocl.select_device(type=devicetype, extensions=extensions)
+ if ids:
+ platformid, deviceid = ids
+ ctx = None
+ if (platformid is not None) and (deviceid is not None):
+ if (platformid, deviceid) in self.context_cache:
+ ctx = self.context_cache[(platformid, deviceid)]
+ else:
+ try:
+ ctx = pyopencl.Context(devices=[pyopencl.get_platforms()[platformid].get_devices()[deviceid]])
+ except pyopencl._cl.LogicError as error:
+ self.platforms[platformid].devices[deviceid].set_unavailable()
+ logger.warning("Unable to create context on %s/%s: %s", platformid, deviceid, error)
+ ctx = None
+ else:
+ if cached:
+ self.context_cache[(platformid, deviceid)] = ctx
+ if ctx is None:
+ logger.warning("Last chance to get an OpenCL device ... probably not the one requested")
+ ctx = pyopencl.create_some_context(interactive=False)
+ return ctx
+
+ def device_from_context(self, context):
+ """
+ Retrieves the Device from the context
+
+ :param context: OpenCL context
+ :return: instance of Device
+ """
+ odevice = context.devices[0]
+ oplat = odevice.platform
+ device_id = oplat.get_devices().index(odevice)
+ platform_id = pyopencl.get_platforms().index(oplat)
+ return self.platforms[platform_id].devices[device_id]
+
+
+if pyopencl:
+ ocl = OpenCL()
+ if ocl.nb_devices == 0:
+ ocl = None
+else:
+ ocl = None
+
+
+def release_cl_buffers(cl_buffers):
+ """
+ :param cl_buffers: the buffer you want to release
+ :type cl_buffers: dict(str, pyopencl.Buffer)
+
+ This method release the memory of the buffers store in the dict
+ """
+ for key, buffer_ in cl_buffers.items():
+ if buffer_ is not None:
+ if isinstance(buffer_, pyopencl.array.Array):
+ try:
+ buffer_.data.release()
+ except pyopencl.LogicError:
+ logger.error("Error while freeing buffer %s", key)
+ else:
+ try:
+ buffer_.release()
+ except pyopencl.LogicError:
+ logger.error("Error while freeing buffer %s", key)
+ cl_buffers[key] = None
+ return cl_buffers
+
+
+def allocate_cl_buffers(buffers, device=None, context=None):
+ """
+ :param buffers: the buffers info use to create the pyopencl.Buffer
+ :type buffers: list(std, flag, numpy.dtype, int)
+ :param device: one of the context device
+ :param context: opencl contextdevice
+ :return: a dict containing the instanciated pyopencl.Buffer
+ :rtype: dict(str, pyopencl.Buffer)
+
+ This method instanciate the pyopencl.Buffer from the buffers
+ description.
+ """
+ mem = {}
+ if device is None:
+ device = ocl.device_from_context(context)
+
+ # check if enough memory is available on the device
+ ualloc = 0
+ for _, _, dtype, size in buffers:
+ ualloc += numpy.dtype(dtype).itemsize * size
+ memory = device.memory
+ logger.info("%.3fMB are needed on device which has %.3fMB",
+ ualloc / 1.0e6, memory / 1.0e6)
+ if ualloc >= memory:
+ memError = "Fatal error in allocate_buffers."
+ memError += "Not enough device memory for buffers"
+ memError += "(%lu requested, %lu available)" % (ualloc, memory)
+ raise MemoryError(memError) # noqa
+
+ # do the allocation
+ try:
+ for name, flag, dtype, size in buffers:
+ mem[name] = pyopencl.Buffer(context, flag,
+ numpy.dtype(dtype).itemsize * size)
+ except pyopencl.MemoryError as error:
+ release_cl_buffers(mem)
+ raise MemoryError(error)
+
+ return mem
+
+
+def allocate_texture(ctx, shape, hostbuf=None, support_1D=False):
+ """
+ Allocate an OpenCL image ("texture").
+
+ :param ctx: OpenCL context
+ :param shape: Shape of the image. Note that pyopencl and OpenCL < 1.2
+ do not support 1D images, so 1D images are handled as 2D with one row
+ :param support_1D: force the image to be 1D if the shape has only one dim
+ """
+ if len(shape) == 1 and not(support_1D):
+ shape = (1,) + shape
+ return pyopencl.Image(
+ ctx,
+ pyopencl.mem_flags.READ_ONLY | pyopencl.mem_flags.USE_HOST_PTR,
+ pyopencl.ImageFormat(
+ pyopencl.channel_order.INTENSITY,
+ pyopencl.channel_type.FLOAT
+ ),
+ hostbuf=numpy.zeros(shape[::-1], dtype=numpy.float32)
+ )
+
+
+def check_textures_availability(ctx):
+ """
+ Check whether textures are supported on the current OpenCL context.
+
+ :param ctx: OpenCL context
+ """
+ try:
+ dummy_texture = allocate_texture(ctx, (16, 16))
+ # Need to further access some attributes (pocl)
+ dummy_height = dummy_texture.height
+ textures_available = True
+ del dummy_texture, dummy_height
+ except (pyopencl.RuntimeError, pyopencl.LogicError):
+ textures_available = False
+ # Nvidia Fermi GPUs (compute capability 2.X) do not support opencl read_imagef
+ # There is no way to detect this until a kernel is compiled
+ try:
+ cc = ctx.devices[0].compute_capability_major_nv
+ textures_available &= (cc >= 3)
+ except (pyopencl.LogicError, AttributeError): # probably not a Nvidia GPU
+ pass
+ #
+ return textures_available
+
+
+def measure_workgroup_size(device):
+ """Measure the actual size of the workgroup
+
+ :param device: device or context or 2-tuple with indexes
+ :return: the actual measured workgroup size
+
+ if device is "all", returns a dict with all devices with their ids as keys.
+ """
+ if (ocl is None) or (device is None):
+ return None
+
+ if isinstance(device, tuple) and (len(device) == 2):
+ # this is probably a tuple (platformid, deviceid)
+ device = ocl.create_context(platformid=device[0], deviceid=device[1])
+
+ if device == "all":
+ res = {}
+ for pid, platform in enumerate(ocl.platforms):
+ for did, _devices in enumerate(platform.devices):
+ tup = (pid, did)
+ res[tup] = measure_workgroup_size(tup)
+ else:
+ res = _measure_workgroup_size(device)
+ return res
+
+
+def query_kernel_info(program, kernel, what="WORK_GROUP_SIZE"):
+ """Extract the compile time information from a kernel
+
+ :param program: OpenCL program
+ :param kernel: kernel or name of the kernel
+ :param what: what is the query about ?
+ :return: int or 3-int for the workgroup size.
+
+ Possible information available are:
+ * 'COMPILE_WORK_GROUP_SIZE': Returns the work-group size specified inside the kernel (__attribute__((reqd_work_gr oup_size(X, Y, Z))))
+ * 'GLOBAL_WORK_SIZE': maximum global size that can be used to execute a kernel #OCL2.1!
+ * 'LOCAL_MEM_SIZE': amount of local memory in bytes being used by the kernel
+ * 'PREFERRED_WORK_GROUP_SIZE_MULTIPLE': preferred multiple of workgroup size for launch. This is a performance hint.
+ * 'PRIVATE_MEM_SIZE' Returns the minimum amount of private memory, in bytes, used by each workitem in the kernel
+ * 'WORK_GROUP_SIZE': maximum work-group size that can be used to execute a kernel on a specific device given by device
+
+ Further information on:
+ https://www.khronos.org/registry/OpenCL/sdk/1.1/docs/man/xhtml/clGetKernelWorkGroupInfo.html
+
+ """
+ assert isinstance(program, pyopencl.Program)
+ if not isinstance(kernel, pyopencl.Kernel):
+ kernel_name = kernel
+ assert kernel in (k.function_name for k in program.all_kernels()), "the kernel exists"
+ kernel = program.__getattr__(kernel_name)
+
+ device = program.devices[0]
+ query_wg = getattr(pyopencl.kernel_work_group_info, what)
+ return kernel.get_work_group_info(query_wg, device)
+
+
+def kernel_workgroup_size(program, kernel):
+ """Extract the compile time maximum workgroup size
+
+ :param program: OpenCL program
+ :param kernel: kernel or name of the kernel
+ :return: the maximum acceptable workgroup size for the given kernel
+ """
+ return query_kernel_info(program, kernel, what="WORK_GROUP_SIZE")
diff --git a/src/silx/opencl/conftest.py b/src/silx/opencl/conftest.py
new file mode 100644
index 0000000..1fdc516
--- /dev/null
+++ b/src/silx/opencl/conftest.py
@@ -0,0 +1,5 @@
+import pytest
+
+@pytest.mark.usefixtures("use_opencl")
+def setup_module(module):
+ pass
diff --git a/silx/opencl/convolution.py b/src/silx/opencl/convolution.py
index 15ef931..15ef931 100644
--- a/silx/opencl/convolution.py
+++ b/src/silx/opencl/convolution.py
diff --git a/silx/opencl/image.py b/src/silx/opencl/image.py
index 65e2d5e..65e2d5e 100644
--- a/silx/opencl/image.py
+++ b/src/silx/opencl/image.py
diff --git a/silx/opencl/linalg.py b/src/silx/opencl/linalg.py
index a64122a..a64122a 100644
--- a/silx/opencl/linalg.py
+++ b/src/silx/opencl/linalg.py
diff --git a/silx/opencl/medfilt.py b/src/silx/opencl/medfilt.py
index d4e425b..d4e425b 100644
--- a/silx/opencl/medfilt.py
+++ b/src/silx/opencl/medfilt.py
diff --git a/silx/opencl/processing.py b/src/silx/opencl/processing.py
index 8b81f7f..8b81f7f 100644
--- a/silx/opencl/processing.py
+++ b/src/silx/opencl/processing.py
diff --git a/silx/opencl/projection.py b/src/silx/opencl/projection.py
index c02faf6..c02faf6 100644
--- a/silx/opencl/projection.py
+++ b/src/silx/opencl/projection.py
diff --git a/silx/opencl/reconstruction.py b/src/silx/opencl/reconstruction.py
index 2c84aee..2c84aee 100644
--- a/silx/opencl/reconstruction.py
+++ b/src/silx/opencl/reconstruction.py
diff --git a/silx/opencl/setup.py b/src/silx/opencl/setup.py
index 10fb1be..10fb1be 100644
--- a/silx/opencl/setup.py
+++ b/src/silx/opencl/setup.py
diff --git a/silx/opencl/sinofilter.py b/src/silx/opencl/sinofilter.py
index d608744..d608744 100644
--- a/silx/opencl/sinofilter.py
+++ b/src/silx/opencl/sinofilter.py
diff --git a/silx/opencl/sparse.py b/src/silx/opencl/sparse.py
index 514589a..514589a 100644
--- a/silx/opencl/sparse.py
+++ b/src/silx/opencl/sparse.py
diff --git a/silx/opencl/statistics.py b/src/silx/opencl/statistics.py
index a96ee33..a96ee33 100644
--- a/silx/opencl/statistics.py
+++ b/src/silx/opencl/statistics.py
diff --git a/src/silx/opencl/test/__init__.py b/src/silx/opencl/test/__init__.py
new file mode 100644
index 0000000..92cda4a
--- /dev/null
+++ b/src/silx/opencl/test/__init__.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+#
+# Project: silx
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2012-2016 European Synchrotron Radiation Facility, Grenoble, France
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
diff --git a/src/silx/opencl/test/test_addition.py b/src/silx/opencl/test/test_addition.py
new file mode 100644
index 0000000..3b668bf
--- /dev/null
+++ b/src/silx/opencl/test/test_addition.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Project: Sift implementation in Python + OpenCL
+# https://github.com/silx-kit/silx
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following
+# conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+
+"""
+Simple test of an addition
+"""
+
+__authors__ = ["Henri Payno, Jérôme Kieffer"]
+__contact__ = "jerome.kieffer@esrf.eu"
+__license__ = "MIT"
+__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "30/11/2020"
+
+import logging
+import numpy
+import pytest
+
+import unittest
+from ..common import ocl, _measure_workgroup_size, query_kernel_info
+if ocl:
+ import pyopencl
+ import pyopencl.array
+from ..utils import get_opencl_code
+logger = logging.getLogger(__name__)
+
+
+@unittest.skipUnless(ocl, "PyOpenCl is missing")
+class TestAddition(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestAddition, cls).setUpClass()
+ if ocl:
+ cls.ctx = ocl.create_context()
+ if logger.getEffectiveLevel() <= logging.INFO:
+ cls.PROFILE = True
+ cls.queue = pyopencl.CommandQueue(
+ cls.ctx,
+ properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
+ else:
+ cls.PROFILE = False
+ cls.queue = pyopencl.CommandQueue(cls.ctx)
+ cls.max_valid_wg = 0
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestAddition, cls).tearDownClass()
+ print("Maximum valid workgroup size %s on device %s" % (cls.max_valid_wg, cls.ctx.devices[0]))
+ cls.ctx = None
+ cls.queue = None
+
+ def setUp(self):
+ if ocl is None:
+ return
+ self.shape = 4096
+ self.data = numpy.random.random(self.shape).astype(numpy.float32)
+ self.d_array_img = pyopencl.array.to_device(self.queue, self.data)
+ self.d_array_5 = pyopencl.array.empty_like(self.d_array_img)
+ self.d_array_5.fill(-5)
+ self.program = pyopencl.Program(self.ctx, get_opencl_code("addition")).build()
+
+ def tearDown(self):
+ self.img = self.data = None
+ self.d_array_img = self.d_array_5 = self.program = None
+
+ def test_add(self):
+ """
+ tests the addition kernel
+ """
+ maxi = int(round(numpy.log2(self.shape)))
+ for i in range(maxi):
+ d_array_result = pyopencl.array.empty_like(self.d_array_img)
+ wg = 1 << i
+ try:
+ evt = self.program.addition(self.queue, (self.shape,), (wg,),
+ self.d_array_img.data, self.d_array_5.data, d_array_result.data, numpy.int32(self.shape))
+ evt.wait()
+ except Exception as error:
+ max_valid_wg = self.program.addition.get_work_group_info(pyopencl.kernel_work_group_info.WORK_GROUP_SIZE, self.ctx.devices[0])
+ msg = "Error %s on WG=%s: %s" % (error, wg, max_valid_wg)
+ self.assertLess(max_valid_wg, wg, msg)
+ break
+ else:
+ res = d_array_result.get()
+ good = numpy.allclose(res, self.data - 5)
+ if good and wg > self.max_valid_wg:
+ self.__class__.max_valid_wg = wg
+ self.assertTrue(good, "calculation is correct for WG=%s" % wg)
+
+ def test_measurement(self):
+ """
+ tests that all devices are working properly ... lengthy and error prone
+ """
+ for platform in ocl.platforms:
+ for did, device in enumerate(platform.devices):
+ meas = _measure_workgroup_size((platform.id, device.id))
+ self.assertEqual(meas, device.max_work_group_size,
+ "Workgroup size for %s/%s: %s == %s" % (platform, device, meas, device.max_work_group_size))
+
+ def test_query(self):
+ """
+ tests that all devices are working properly ... lengthy and error prone
+ """
+ for what in ("COMPILE_WORK_GROUP_SIZE",
+ "LOCAL_MEM_SIZE",
+ "PREFERRED_WORK_GROUP_SIZE_MULTIPLE",
+ "PRIVATE_MEM_SIZE",
+ "WORK_GROUP_SIZE"):
+ logger.info("%s: %s", what, query_kernel_info(program=self.program, kernel="addition", what=what))
+
+ # Not all ICD work properly ....
+ #self.assertEqual(3, len(query_kernel_info(program=self.program, kernel="addition", what="COMPILE_WORK_GROUP_SIZE")), "3D kernel")
+
+ min_wg = query_kernel_info(program=self.program, kernel="addition", what="PREFERRED_WORK_GROUP_SIZE_MULTIPLE")
+ max_wg = query_kernel_info(program=self.program, kernel="addition", what="WORK_GROUP_SIZE")
+ self.assertEqual(max_wg % min_wg, 0, msg="max_wg is a multiple of min_wg")
diff --git a/src/silx/opencl/test/test_array_utils.py b/src/silx/opencl/test/test_array_utils.py
new file mode 100644
index 0000000..325a6c3
--- /dev/null
+++ b/src/silx/opencl/test/test_array_utils.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of the OpenCL array_utils"""
+
+from __future__ import division, print_function
+
+__authors__ = ["Pierre paleo"]
+__license__ = "MIT"
+__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "14/06/2017"
+
+
+import time
+import logging
+import numpy as np
+import unittest
+try:
+ import mako
+except ImportError:
+ mako = None
+from ..common import ocl
+if ocl:
+ import pyopencl as cl
+ import pyopencl.array as parray
+ from .. import linalg
+from ..utils import get_opencl_code
+from silx.test.utils import utilstest
+
+logger = logging.getLogger(__name__)
+try:
+ from scipy.ndimage.filters import laplace
+ _has_scipy = True
+except ImportError:
+ _has_scipy = False
+
+
+
+@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
+class TestCpy2d(unittest.TestCase):
+
+ def setUp(self):
+ if ocl is None:
+ return
+ self.ctx = ocl.create_context()
+ if logger.getEffectiveLevel() <= logging.INFO:
+ self.PROFILE = True
+ self.queue = cl.CommandQueue(
+ self.ctx,
+ properties=cl.command_queue_properties.PROFILING_ENABLE)
+ else:
+ self.PROFILE = False
+ self.queue = cl.CommandQueue(self.ctx)
+ self.allocate_arrays()
+ self.program = cl.Program(self.ctx, get_opencl_code("array_utils")).build()
+
+ def allocate_arrays(self):
+ """
+ Allocate various types of arrays for the tests
+ """
+ self.prng_state = np.random.get_state()
+ # Generate arrays of random shape
+ self.shape1 = np.random.randint(20, high=512, size=(2,))
+ self.shape2 = np.random.randint(20, high=512, size=(2,))
+ self.array1 = np.random.rand(*self.shape1).astype(np.float32)
+ self.array2 = np.random.rand(*self.shape2).astype(np.float32)
+ self.d_array1 = parray.to_device(self.queue, self.array1)
+ self.d_array2 = parray.to_device(self.queue, self.array2)
+ # Generate random offsets
+ offset1_y = np.random.randint(2, high=min(self.shape1[0], self.shape2[0]) - 10)
+ offset1_x = np.random.randint(2, high=min(self.shape1[1], self.shape2[1]) - 10)
+ offset2_y = np.random.randint(2, high=min(self.shape1[0], self.shape2[0]) - 10)
+ offset2_x = np.random.randint(2, high=min(self.shape1[1], self.shape2[1]) - 10)
+ self.offset1 = (offset1_y, offset1_x)
+ self.offset2 = (offset2_y, offset2_x)
+ # Compute the size of the rectangle to transfer
+ size_y = np.random.randint(2, high=min(self.shape1[0], self.shape2[0]) - max(offset1_y, offset2_y) + 1)
+ size_x = np.random.randint(2, high=min(self.shape1[1], self.shape2[1]) - max(offset1_x, offset2_x) + 1)
+ self.transfer_shape = (size_y, size_x)
+
+ def tearDown(self):
+ self.array1 = None
+ self.array2 = None
+ self.d_array1.data.release()
+ self.d_array2.data.release()
+ self.d_array1 = None
+ self.d_array2 = None
+ self.ctx = None
+ self.queue = None
+
+ def compare(self, result, reference):
+ errmax = np.max(np.abs(result - reference))
+ logger.info("Max error = %e" % (errmax))
+ self.assertTrue(errmax == 0, str("Max error is too high"))#. PRNG state was %s" % str(self.prng_state)))
+
+ @unittest.skipUnless(ocl and mako, "pyopencl is missing")
+ def test_cpy2d(self):
+ """
+ Test rectangular transfer of self.d_array1 to self.d_array2
+ """
+ # Reference
+ o1 = self.offset1
+ o2 = self.offset2
+ T = self.transfer_shape
+ logger.info("""Testing D->D rectangular copy with (N1_y, N1_x) = %s,
+ (N2_y, N2_x) = %s:
+ array2[%d:%d, %d:%d] = array1[%d:%d, %d:%d]""" %
+ (
+ str(self.shape1), str(self.shape2),
+ o2[0], o2[0] + T[0],
+ o2[1], o2[1] + T[1],
+ o1[0], o1[0] + T[0],
+ o1[1], o1[1] + T[1]
+ )
+ )
+ self.array2[o2[0]:o2[0] + T[0], o2[1]:o2[1] + T[1]] = self.array1[o1[0]:o1[0] + T[0], o1[1]:o1[1] + T[1]]
+ kernel_args = (
+ self.d_array2.data,
+ self.d_array1.data,
+ np.int32(self.shape2[1]),
+ np.int32(self.shape1[1]),
+ np.int32(self.offset2[::-1]),
+ np.int32(self.offset1[::-1]),
+ np.int32(self.transfer_shape[::-1])
+ )
+ wg = None
+ ndrange = self.transfer_shape[::-1]
+ self.program.cpy2d(self.queue, ndrange, wg, *kernel_args)
+ res = self.d_array2.get()
+ self.compare(res, self.array2)
diff --git a/src/silx/opencl/test/test_backprojection.py b/src/silx/opencl/test/test_backprojection.py
new file mode 100644
index 0000000..96d56fa
--- /dev/null
+++ b/src/silx/opencl/test/test_backprojection.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of the filtered backprojection module"""
+
+from __future__ import division, print_function
+
+__authors__ = ["Pierre paleo"]
+__license__ = "MIT"
+__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "19/01/2018"
+
+
+import time
+import logging
+import numpy as np
+import unittest
+from math import pi
+try:
+ import mako
+except ImportError:
+ mako = None
+from ..common import ocl
+if ocl:
+ from .. import backprojection
+ from ...image.tomography import compute_fourier_filter
+from silx.test.utils import utilstest
+
+logger = logging.getLogger(__name__)
+
+
+def generate_coords(img_shp, center=None):
+ """
+ Return two 2D arrays containing the indexes of an image.
+ The zero is at the center of the image.
+ """
+ l_r, l_c = float(img_shp[0]), float(img_shp[1])
+ R, C = np.mgrid[:l_r, :l_c]
+ if center is None:
+ center0, center1 = l_r / 2., l_c / 2.
+ else:
+ center0, center1 = center
+ R = R + 0.5 - center0
+ C = C + 0.5 - center1
+ return R, C
+
+
+def clip_circle(img, center=None, radius=None):
+ """
+ Puts zeros outside the inscribed circle of the image support.
+ """
+ R, C = generate_coords(img.shape, center)
+ M = R * R + C * C
+ res = np.zeros_like(img)
+ if radius is None:
+ radius = img.shape[0] / 2. - 1
+ mask = M < radius * radius
+ res[mask] = img[mask]
+ return res
+
+
+@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
+class TestFBP(unittest.TestCase):
+
+ def setUp(self):
+ if ocl is None:
+ return
+ self.getfiles()
+ self.fbp = backprojection.Backprojection(self.sino.shape, profile=True)
+ if self.fbp.compiletime_workgroup_size < 16 * 16:
+ self.skipTest("Current implementation of OpenCL backprojection is "
+ "not supported on this platform yet")
+ # Astra does not use the same backprojector implementation.
+ # Therefore, we cannot expect results to be the "same" (up to float32
+ # numerical error)
+ self.tol = 5e-2
+ if not(self.fbp._use_textures) or self.fbp.device.type == "CPU":
+ # Precision is less when using CPU
+ # (either CPU textures or "manual" linear interpolation)
+ self.tol *= 2
+
+ def tearDown(self):
+ self.sino = None
+ # self.fbp.log_profile()
+ self.fbp = None
+
+ def getfiles(self):
+ # load sinogram of 512x512 MRI phantom
+ self.sino = np.load(utilstest.getfile("sino500.npz"))["data"]
+ # load reconstruction made with ASTRA FBP (with filter designed in spatial domain)
+ self.reference_rec = np.load(utilstest.getfile("rec_astra_500.npz"))["data"]
+
+ def measure(self):
+ "Common measurement of timings"
+ t1 = time.time()
+ try:
+ result = self.fbp.filtered_backprojection(self.sino)
+ except RuntimeError as msg:
+ logger.error(msg)
+ return
+ t2 = time.time()
+ return t2 - t1, result
+
+ def compare(self, res):
+ """
+ Compare a result with the reference reconstruction.
+ Only the valid reconstruction zone (inscribed circle) is taken into
+ account
+ """
+ res_clipped = clip_circle(res)
+ ref_clipped = clip_circle(self.reference_rec)
+ delta = abs(res_clipped - ref_clipped)
+ bad = delta > 1
+ logger.debug("Absolute difference: %s with %s outlier pixels out of %s"
+ "", delta.max(), bad.sum(), np.prod(bad.shape))
+ return delta.max()
+
+ @unittest.skipUnless(ocl and mako, "pyopencl is missing")
+ def test_fbp(self):
+ """
+ tests FBP
+ """
+ # Test single reconstruction
+ # --------------------------
+ t, res = self.measure()
+ if t is None:
+ logger.info("test_fp: skipped")
+ else:
+ logger.info("test_backproj: time = %.3fs" % t)
+ err = self.compare(res)
+ msg = str("Max error = %e" % err)
+ logger.info(msg)
+ self.assertTrue(err < self.tol, "Max error is too high")
+
+ # Test multiple reconstructions
+ # -----------------------------
+ res0 = np.copy(res)
+ for i in range(10):
+ res = self.fbp.filtered_backprojection(self.sino)
+ errmax = np.max(np.abs(res - res0))
+ self.assertTrue(errmax < 1.e-6, "Max error is too high")
+
+ @unittest.skipUnless(ocl and mako, "pyopencl is missing")
+ def test_fbp_filters(self):
+ """
+ Test the different available filters of silx FBP.
+ """
+ avail_filters = [
+ "ramlak", "shepp-logan", "cosine", "hamming",
+ "hann"
+ ]
+ # Create a Dirac delta function at a single angle view.
+ # As the filters are radially invarant:
+ # - backprojection yields an image where each line is a Dirac.
+ # - FBP yields an image where each line is the spatial filter
+ # One can simply filter "dirac" without backprojecting it, but this
+ # test will also ensure that backprojection behaves well.
+ dirac = np.zeros_like(self.sino)
+ na, dw = dirac.shape
+ dirac[0, dw//2] = na / pi * 2
+
+ for filter_name in avail_filters:
+ B = backprojection.Backprojection(dirac.shape, filter_name=filter_name)
+ r = B(dirac)
+ # Check that radial invariance is kept
+ std0 = np.max(np.abs(np.std(r, axis=0)))
+ self.assertTrue(
+ std0 < 5.e-6,
+ "Something wrong with FBP(filter=%s)" % filter_name
+ )
+ # Check that the filter is retrieved
+ r_f = np.fft.fft(np.fft.fftshift(r[0])).real / 2. # filter factor
+ ref_filter_f = compute_fourier_filter(dw, filter_name)
+ errmax = np.max(np.abs(r_f - ref_filter_f))
+ logger.info("FBP filter %s: max error=%e" % (filter_name, errmax))
+ self.assertTrue(
+ errmax < 1.e-3,
+ "Something wrong with FBP(filter=%s)" % filter_name
+ )
+
+ @unittest.skipUnless(ocl and mako, "pyopencl is missing")
+ def test_fbp_oddsize(self):
+ # Generate a 513-sinogram.
+ # The padded width will be nextpow(513*2).
+ # silx [0.10, 0.10.1] will give 1029, which makes R2C transform fail.
+ sino = np.pad(self.sino, ((0, 0), (1, 0)), mode='edge')
+ B = backprojection.Backprojection(sino.shape, axis_position=self.fbp.axis_pos+1)
+ res = B(sino)
+ # Compare with self.reference_rec. Tolerance is high as backprojector
+ # is not fully shift-invariant.
+ errmax = np.max(np.abs(clip_circle(res[1:, 1:] - self.reference_rec)))
+ self.assertLess(
+ errmax, 1.e-1,
+ "Something wrong with FBP on odd-sized sinogram"
+ )
diff --git a/src/silx/opencl/test/test_convolution.py b/src/silx/opencl/test/test_convolution.py
new file mode 100644
index 0000000..6a2759d
--- /dev/null
+++ b/src/silx/opencl/test/test_convolution.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+"""
+Test of the Convolution class.
+"""
+
+from __future__ import division, print_function
+
+__authors__ = ["Pierre Paleo"]
+__contact__ = "pierre.paleo@esrf.fr"
+__license__ = "MIT"
+__copyright__ = "2019 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "01/08/2019"
+
+import pytest
+import logging
+from itertools import product
+import numpy as np
+from silx.image.utils import gaussian_kernel
+
+try:
+ from scipy.ndimage import convolve, convolve1d
+ from scipy.misc import ascent
+
+ scipy_convolve = convolve
+ scipy_convolve1d = convolve1d
+except ImportError:
+ scipy_convolve = None
+import unittest
+from ..common import ocl, check_textures_availability
+
+if ocl:
+ import pyopencl as cl
+ import pyopencl.array as parray
+ from silx.opencl.convolution import Convolution
+logger = logging.getLogger(__name__)
+
+
+class ConvolutionData:
+
+ def __init__(self, param):
+ self.param = param
+ self.mode = param["boundary_handling"]
+ logger.debug(
+ """
+ Testing convolution with boundary_handling=%s,
+ use_textures=%s, input_device=%s, output_device=%s
+ """
+ % (
+ self.mode,
+ param["use_textures"],
+ param["input_on_device"],
+ param["output_on_device"],
+ )
+ )
+
+ @classmethod
+ def setUpClass(cls):
+ cls.image = np.ascontiguousarray(ascent()[:, :511], dtype="f")
+ cls.data1d = cls.image[0]
+ cls.data2d = cls.image
+ cls.data3d = np.tile(cls.image[224:-224, 224:-224], (62, 1, 1))
+ cls.kernel1d = gaussian_kernel(1.0)
+ cls.kernel2d = np.outer(cls.kernel1d, cls.kernel1d)
+ cls.kernel3d = np.multiply.outer(cls.kernel2d, cls.kernel1d)
+ cls.ctx = ocl.create_context()
+ cls.tol = {
+ "1D": 1e-4,
+ "2D": 1e-3,
+ "3D": 1e-3,
+ }
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.data1d = cls.data2d = cls.data3d = cls.image = None
+ cls.kernel1d = cls.kernel2d = cls.kernel3d = None
+
+ @staticmethod
+ def compare(arr1, arr2):
+ return np.max(np.abs(arr1 - arr2))
+
+ @staticmethod
+ def print_err(conv):
+ errmsg = str(
+ """
+ Something wrong with %s
+ mode=%s, texture=%s
+ """
+ % (conv.use_case_desc, conv.mode, conv.use_textures)
+ )
+ return errmsg
+
+ def instantiate_convol(self, shape, kernel, axes=None):
+ if self.mode == "constant":
+ if not (self.param["use_textures"]) or (
+ self.param["use_textures"]
+ and not (check_textures_availability(self.ctx))
+ ):
+ pytest.skip("mode=constant not implemented without textures")
+ C = Convolution(
+ shape,
+ kernel,
+ mode=self.mode,
+ ctx=self.ctx,
+ axes=axes,
+ extra_options={"dont_use_textures": not (self.param["use_textures"])},
+ )
+ return C
+
+ def get_data_and_kernel(self, test_name):
+ dims = {
+ "test_1D": (1, 1),
+ "test_separable_2D": (2, 1),
+ "test_separable_3D": (3, 1),
+ "test_nonseparable_2D": (2, 2),
+ "test_nonseparable_3D": (3, 3),
+ }
+ dim_data = {1: self.data1d, 2: self.data2d, 3: self.data3d}
+ dim_kernel = {
+ 1: self.kernel1d,
+ 2: self.kernel2d,
+ 3: self.kernel3d,
+ }
+ dd, kd = dims[test_name]
+ return dim_data[dd], dim_kernel[kd]
+
+ def get_reference_function(self, test_name):
+ ref_func = {
+ "test_1D": lambda x, y: scipy_convolve1d(x, y, mode=self.mode),
+ "test_separable_2D": lambda x, y: scipy_convolve1d(
+ scipy_convolve1d(x, y, mode=self.mode, axis=1),
+ y,
+ mode=self.mode,
+ axis=0,
+ ),
+ "test_separable_3D": lambda x, y: scipy_convolve1d(
+ scipy_convolve1d(
+ scipy_convolve1d(x, y, mode=self.mode, axis=2),
+ y,
+ mode=self.mode,
+ axis=1,
+ ),
+ y,
+ mode=self.mode,
+ axis=0,
+ ),
+ "test_nonseparable_2D": lambda x, y: scipy_convolve(x, y, mode=self.mode),
+ "test_nonseparable_3D": lambda x, y: scipy_convolve(x, y, mode=self.mode),
+ }
+ return ref_func[test_name]
+
+ def template_test(self, test_name):
+ data, kernel = self.get_data_and_kernel(test_name)
+ conv = self.instantiate_convol(data.shape, kernel)
+ if self.param["input_on_device"]:
+ data_ref = parray.to_device(conv.queue, data)
+ else:
+ data_ref = data
+ if self.param["output_on_device"]:
+ d_res = parray.empty_like(conv.data_out)
+ d_res.fill(0)
+ res = d_res
+ else:
+ res = None
+ res = conv(data_ref, output=res)
+ if self.param["output_on_device"]:
+ res = res.get()
+ ref_func = self.get_reference_function(test_name)
+ ref = ref_func(data, kernel)
+ metric = self.compare(res, ref)
+ logger.info("%s: max error = %.2e" % (test_name, metric))
+ tol = self.tol[str("%dD" % kernel.ndim)]
+ assert metric < tol, self.print_err(conv)
+
+
+def convolution_data_params():
+ boundary_handlings = ["reflect", "nearest", "wrap", "constant"]
+ use_textures = [True, False]
+ input_on_devices = [True, False]
+ output_on_devices = [True, False]
+ param_vals = list(
+ product(boundary_handlings, use_textures, input_on_devices, output_on_devices)
+ )
+ params = []
+ for boundary_handling, use_texture, input_dev, output_dev in param_vals:
+ param={
+ "boundary_handling": boundary_handling,
+ "input_on_device": input_dev,
+ "output_on_device": output_dev,
+ "use_textures": use_texture,
+ }
+ params.append(param)
+
+ return params
+
+
+@pytest.fixture(scope="module", params=convolution_data_params())
+def convolution_data(request):
+ """Provide a set of convolution data
+
+ The module scope allows to test each function during a single setup of each
+ convolution data
+ """
+ cdata = None
+ try:
+ cdata = ConvolutionData(request.param)
+ cdata.setUpClass()
+ yield cdata
+ finally:
+ cdata.tearDownClass()
+
+
+@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
+@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
+def test_1D(convolution_data):
+ convolution_data.template_test("test_1D")
+
+@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
+@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
+def test_separable_2D(convolution_data):
+ convolution_data.template_test("test_separable_2D")
+
+@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
+@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
+def test_separable_3D(convolution_data):
+ convolution_data.template_test("test_separable_3D")
+
+@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
+@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
+def test_nonseparable_2D(convolution_data):
+ convolution_data.template_test("test_nonseparable_2D")
+
+@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
+@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
+def test_nonseparable_3D(convolution_data):
+ convolution_data.template_test("test_nonseparable_3D")
+
+@pytest.mark.skipif(ocl is None, reason="OpenCL is missing")
+@pytest.mark.skipif(scipy_convolve is None, reason="scipy is missing")
+def test_batched_2D(convolution_data):
+ """
+ Test batched (nonseparable) 2D convolution on 3D data.
+ In this test: batch along "z" (axis 0)
+ """
+ data = convolution_data.data3d
+ kernel = convolution_data.kernel2d
+ conv = convolution_data.instantiate_convol(data.shape, kernel, axes=(0,))
+ res = conv(data) # 3D
+ ref = scipy_convolve(data[0], kernel, mode=convolution_data.mode) # 2D
+
+ std = np.std(res, axis=0)
+ std_max = np.max(np.abs(std))
+ assert std_max < convolution_data.tol["2D"], convolution_data.print_err(conv)
+ metric = convolution_data.compare(res[0], ref)
+ logger.info("test_nonseparable_3D: max error = %.2e" % metric)
+ assert metric < convolution_data.tol["2D"], convolution_data.print_err(conv)
diff --git a/src/silx/opencl/test/test_doubleword.py b/src/silx/opencl/test/test_doubleword.py
new file mode 100644
index 0000000..a33cf5a
--- /dev/null
+++ b/src/silx/opencl/test/test_doubleword.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+# coding: utf-8
+#
+# Project: The silx project
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2021-2021 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+"test suite for OpenCL code"
+
+__author__ = "Jérôme Kieffer"
+__contact__ = "Jerome.Kieffer@ESRF.eu"
+__license__ = "MIT"
+__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "31/05/2021"
+
+import unittest
+import numpy
+import logging
+import platform
+
+logger = logging.getLogger(__name__)
+try:
+ import pyopencl
+except ImportError as error:
+ logger.warning("OpenCL module (pyopencl) is not present, skip tests. %s.", error)
+ pyopencl = None
+
+from .. import ocl
+if ocl is not None:
+ from ..utils import read_cl_file
+ from .. import pyopencl
+ import pyopencl.array
+ from pyopencl.elementwise import ElementwiseKernel
+
+EPS32 = numpy.finfo("float32").eps
+EPS64 = numpy.finfo("float64").eps
+
+
+@unittest.skipUnless(ocl, "PyOpenCl is missing")
+class TestDoubleWord(unittest.TestCase):
+ """
+ Test the kernels for compensated math in OpenCL
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ if pyopencl is None or ocl is None:
+ raise unittest.SkipTest("OpenCL module (pyopencl) is not present or no device available")
+
+ cls.ctx = ocl.create_context(devicetype="GPU")
+ cls.queue = pyopencl.CommandQueue(cls.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
+
+ # this is running 32 bits OpenCL woth POCL
+ if (platform.machine() in ("i386", "i686", "x86_64") and (tuple.__itemsize__ == 4) and
+ cls.ctx.devices[0].platform.name == 'Portable Computing Language'):
+ cls.args = "-DX87_VOLATILE=volatile"
+ else:
+ cls.args = ""
+ size = 1024
+ cls.a = 1.0 + numpy.random.random(size)
+ cls.b = 1.0 + numpy.random.random(size)
+ cls.ah = cls.a.astype(numpy.float32)
+ cls.bh = cls.b.astype(numpy.float32)
+ cls.al = (cls.a - cls.ah).astype(numpy.float32)
+ cls.bl = (cls.b - cls.bh).astype(numpy.float32)
+ cls.doubleword = read_cl_file("doubleword.cl")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.queue = None
+ cls.ctx = None
+ cls.a = cls.al = cls.ah = None
+ cls.b = cls.bl = cls.bh = None
+ cls.doubleword = None
+
+ def test_fast_sum2(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *a, float *b, float *res_h, float *res_l",
+ "float2 tmp = fast_fp_plus_fp(a[i], b[i]); res_h[i] = tmp.s0; res_l[i] = tmp.s1",
+ preamble=self.doubleword)
+ a_g = pyopencl.array.to_device(self.queue, self.ah)
+ b_g = pyopencl.array.to_device(self.queue, self.bl)
+ res_l = pyopencl.array.empty_like(a_g)
+ res_h = pyopencl.array.empty_like(a_g)
+ test_kernel(a_g, b_g, res_h, res_l)
+ self.assertEqual(abs(self.ah + self.bl - res_h.get()).max(), 0, "Major matches")
+ self.assertGreater(abs(self.ah.astype(numpy.float64) + self.bl - res_h.get()).max(), 0, "Exact mismatches")
+ self.assertEqual(abs(self.ah.astype(numpy.float64) + self.bl - (res_h.get().astype(numpy.float64) + res_l.get())).max(), 0, "Exact matches")
+
+ def test_sum2(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *a, float *b, float *res_h, float *res_l",
+ "float2 tmp = fp_plus_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ a_g = pyopencl.array.to_device(self.queue, self.ah)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(a_g)
+ res_h = pyopencl.array.empty_like(a_g)
+ test_kernel(a_g, b_g, res_h, res_l)
+ self.assertEqual(abs(self.ah + self.bh - res_h.get()).max(), 0, "Major matches")
+ self.assertGreater(abs(self.ah.astype(numpy.float64) + self.bh - res_h.get()).max(), 0, "Exact mismatches")
+ self.assertEqual(abs(self.ah.astype(numpy.float64) + self.bh - (res_h.get().astype(numpy.float64) + res_l.get())).max(), 0, "Exact matches")
+
+ def test_prod2(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *a, float *b, float *res_h, float *res_l",
+ "float2 tmp = fp_times_fp(a[i],b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ a_g = pyopencl.array.to_device(self.queue, self.ah)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(a_g)
+ res_h = pyopencl.array.empty_like(a_g)
+ test_kernel(a_g, b_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertEqual(abs(self.ah * self.bh - res_m).max(), 0, "Major matches")
+ self.assertGreater(abs(self.ah.astype(numpy.float64) * self.bh - res_m).max(), 0, "Exact mismatches")
+ self.assertEqual(abs(self.ah.astype(numpy.float64) * self.bh - res).max(), 0, "Exact matches")
+
+ def test_dw_plus_fp(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *b, float *res_h, float *res_l",
+ "float2 tmp = dw_plus_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(b_g)
+ res_h = pyopencl.array.empty_like(b_g)
+ test_kernel(ah_g, al_g, b_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a + self.bh - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a + self.bh - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.ah.astype(numpy.float64) + self.al + self.bh - res).max(), 2 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_plus_dw(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
+ "float2 tmp = dw_plus_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ bh_g = pyopencl.array.to_device(self.queue, self.bh)
+ bl_g = pyopencl.array.to_device(self.queue, self.bl)
+ res_l = pyopencl.array.empty_like(bh_g)
+ res_h = pyopencl.array.empty_like(bh_g)
+ test_kernel(ah_g, al_g, bh_g, bl_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a + self.b - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a + self.b - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a + self.b - res).max(), 3 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_times_fp(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *b, float *res_h, float *res_l",
+ "float2 tmp = dw_times_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(b_g)
+ res_h = pyopencl.array.empty_like(b_g)
+ test_kernel(ah_g, al_g, b_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a * self.bh - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a * self.bh - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a * self.bh - res).max(), 2 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_times_dw(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
+ "float2 tmp = dw_times_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ bh_g = pyopencl.array.to_device(self.queue, self.bh)
+ bl_g = pyopencl.array.to_device(self.queue, self.bl)
+ res_l = pyopencl.array.empty_like(bh_g)
+ res_h = pyopencl.array.empty_like(bh_g)
+ test_kernel(ah_g, al_g, bh_g, bl_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a * self.b - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a * self.b - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a * self.b - res).max(), 5 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_div_fp(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *b, float *res_h, float *res_l",
+ "float2 tmp = dw_div_fp((float2)(ah[i], al[i]),b[i]); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ b_g = pyopencl.array.to_device(self.queue, self.bh)
+ res_l = pyopencl.array.empty_like(b_g)
+ res_h = pyopencl.array.empty_like(b_g)
+ test_kernel(ah_g, al_g, b_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a / self.bh - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a / self.bh - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a / self.bh - res).max(), 3 * EPS32 ** 2, "Exact matches")
+
+ def test_dw_div_dw(self):
+ test_kernel = ElementwiseKernel(self.ctx,
+ "float *ah, float *al, float *bh, float *bl, float *res_h, float *res_l",
+ "float2 tmp = dw_div_dw((float2)(ah[i], al[i]),(float2)(bh[i], bl[i])); res_h[i]=tmp.s0; res_l[i]=tmp.s1;",
+ preamble=self.doubleword)
+ ah_g = pyopencl.array.to_device(self.queue, self.ah)
+ al_g = pyopencl.array.to_device(self.queue, self.al)
+ bh_g = pyopencl.array.to_device(self.queue, self.bh)
+ bl_g = pyopencl.array.to_device(self.queue, self.bl)
+ res_l = pyopencl.array.empty_like(bh_g)
+ res_h = pyopencl.array.empty_like(bh_g)
+ test_kernel(ah_g, al_g, bh_g, bl_g, res_h, res_l)
+ res_m = res_h.get()
+ res = res_h.get().astype(numpy.float64) + res_l.get()
+ self.assertLess(abs(self.a / self.b - res_m).max(), EPS32, "Major matches")
+ self.assertGreater(abs(self.a / self.b - res_m).max(), EPS64, "Exact mismatches")
+ self.assertLess(abs(self.a / self.b - res).max(), 6 * EPS32 ** 2, "Exact matches")
diff --git a/src/silx/opencl/test/test_image.py b/src/silx/opencl/test/test_image.py
new file mode 100644
index 0000000..73c771b
--- /dev/null
+++ b/src/silx/opencl/test/test_image.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Project: image manipulation in OpenCL
+# https://github.com/silx-kit/silx
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following
+# conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+
+"""
+Simple test of image manipulation
+"""
+
+from __future__ import division, print_function
+
+__authors__ = ["Jérôme Kieffer"]
+__contact__ = "jerome.kieffer@esrf.eu"
+__license__ = "MIT"
+__copyright__ = "2017 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "13/02/2018"
+
+import logging
+import numpy
+
+import unittest
+from ..common import ocl, _measure_workgroup_size
+if ocl:
+ import pyopencl
+ import pyopencl.array
+from ...test.utils import utilstest
+from ..image import ImageProcessing
+logger = logging.getLogger(__name__)
+try:
+ from PIL import Image
+except ImportError:
+ Image = None
+
+
+@unittest.skipUnless(ocl and Image, "PyOpenCl/Image is missing")
+class TestImage(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestImage, cls).setUpClass()
+ if ocl:
+ cls.ctx = ocl.create_context()
+ cls.lena = utilstest.getfile("lena.png")
+ cls.data = numpy.asarray(Image.open(cls.lena))
+ cls.ip = ImageProcessing(ctx=cls.ctx, template=cls.data, profile=True)
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestImage, cls).tearDownClass()
+ cls.ctx = None
+ cls.lena = None
+ cls.data = None
+ if logger.level <= logging.INFO:
+ logger.warning("\n".join(cls.ip.log_profile()))
+ cls.ip = None
+
+ def setUp(self):
+ if ocl is None:
+ return
+ self.data = numpy.asarray(Image.open(self.lena))
+
+ def tearDown(self):
+ self.img = self.data = None
+
+ @unittest.skipUnless(ocl, "pyopencl is missing")
+ def test_cast(self):
+ """
+ tests the cast kernel
+ """
+ res = self.ip.to_float(self.data)
+ self.assertEqual(res.shape, self.data.shape, "shape")
+ self.assertEqual(res.dtype, numpy.float32, "dtype")
+ self.assertEqual(abs(res - self.data).max(), 0, "content")
+
+ @unittest.skipUnless(ocl, "pyopencl is missing")
+ def test_normalize(self):
+ """
+ tests that all devices are working properly ...
+ """
+ tmp = pyopencl.array.empty(self.ip.ctx, self.data.shape, "float32")
+ res = self.ip.to_float(self.data, out=tmp)
+ res2 = self.ip.normalize(tmp, -100, 100, copy=False)
+ norm = (self.data.astype(numpy.float32) - self.data.min()) / (self.data.max() - self.data.min())
+ ref2 = 200 * norm - 100
+ self.assertLess(abs(res2 - ref2).max(), 3e-5, "content")
+
+ @unittest.skipUnless(ocl, "pyopencl is missing")
+ def test_histogram(self):
+ """
+ Test on a greyscaled image ... of Lena :)
+ """
+ lena_bw = (0.2126 * self.data[:, :, 0] +
+ 0.7152 * self.data[:, :, 1] +
+ 0.0722 * self.data[:, :, 2]).astype("int32")
+ ref = numpy.histogram(lena_bw, 255)
+ ip = ImageProcessing(ctx=self.ctx, template=lena_bw, profile=True)
+ res = ip.histogram(lena_bw, 255)
+ ip.log_profile()
+ delta = (ref[0] - res[0])
+ deltap = (ref[1] - res[1])
+ self.assertEqual(delta.sum(), 0, "errors are self-compensated")
+ self.assertLessEqual(abs(delta).max(), 1, "errors are small")
+ self.assertLessEqual(abs(deltap).max(), 3e-5, "errors on position are small: %s" % (abs(deltap).max()))
diff --git a/src/silx/opencl/test/test_kahan.py b/src/silx/opencl/test/test_kahan.py
new file mode 100644
index 0000000..9e4a1e3
--- /dev/null
+++ b/src/silx/opencl/test/test_kahan.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python
+# coding: utf-8
+#
+# Project: OpenCL numerical library
+# https://github.com/silx-kit/silx
+#
+# Copyright (C) 2015-2021 European Synchrotron Radiation Facility, Grenoble, France
+#
+# Principal author: Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+"test suite for OpenCL code"
+
+__author__ = "Jérôme Kieffer"
+__contact__ = "Jerome.Kieffer@ESRF.eu"
+__license__ = "MIT"
+__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "17/05/2021"
+
+
+import unittest
+import numpy
+import logging
+import platform
+
+logger = logging.getLogger(__name__)
+try:
+ import pyopencl
+except ImportError as error:
+ logger.warning("OpenCL module (pyopencl) is not present, skip tests. %s.", error)
+ pyopencl = None
+
+from .. import ocl
+if ocl is not None:
+ from ..utils import read_cl_file
+ from .. import pyopencl
+ import pyopencl.array
+
+
+class TestKahan(unittest.TestCase):
+ """
+ Test the kernels for compensated math in OpenCL
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ if pyopencl is None or ocl is None:
+ raise unittest.SkipTest("OpenCL module (pyopencl) is not present or no device available")
+
+ cls.ctx = ocl.create_context(devicetype="GPU")
+ cls.queue = pyopencl.CommandQueue(cls.ctx, properties=pyopencl.command_queue_properties.PROFILING_ENABLE)
+
+ # this is running 32 bits OpenCL woth POCL
+ if (platform.machine() in ("i386", "i686", "x86_64") and (tuple.__itemsize__ == 4) and
+ cls.ctx.devices[0].platform.name == 'Portable Computing Language'):
+ cls.args = "-DX87_VOLATILE=volatile"
+ else:
+ cls.args = ""
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.queue = None
+ cls.ctx = None
+
+ @staticmethod
+ def dummy_sum(ary, dtype=None):
+ "perform the actual sum in a dummy way "
+ if dtype is None:
+ dtype = ary.dtype.type
+ sum_ = dtype(0)
+ for i in ary:
+ sum_ += i
+ return sum_
+
+ def test_kahan(self):
+ # simple test
+ N = 26
+ data = (1 << (N - 1 - numpy.arange(N))).astype(numpy.float32)
+
+ ref64 = numpy.sum(data, dtype=numpy.float64)
+ ref32 = self.dummy_sum(data)
+ if (ref64 == ref32):
+ logger.warning("Kahan: invalid tests as float32 provides the same result as float64")
+ # Dummy kernel to evaluate
+ src = """
+ kernel void summation(global float* data,
+ int size,
+ global float* result)
+ {
+ float2 acc = (float2)(0.0f, 0.0f);
+ for (int i=0; i<size; i++)
+ {
+ acc = kahan_sum(acc, data[i]);
+ }
+ result[0] = acc.s0;
+ result[1] = acc.s1;
+ }
+ """
+ prg = pyopencl.Program(self.ctx, read_cl_file("kahan.cl") + src).build(self.args)
+ ones_d = pyopencl.array.to_device(self.queue, data)
+ res_d = pyopencl.array.empty(self.queue, 2, numpy.float32)
+ res_d.fill(0)
+ evt = prg.summation(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ evt.wait()
+ res = res_d.get().sum(dtype=numpy.float64)
+ self.assertEqual(ref64, res, "test_kahan")
+
+ def test_dot16(self):
+ # simple test
+ N = 16
+ data = (1 << (N - 1 - numpy.arange(N))).astype(numpy.float32)
+
+ ref64 = numpy.dot(data.astype(numpy.float64), data.astype(numpy.float64))
+ ref32 = numpy.dot(data, data)
+ if (ref64 == ref32):
+ logger.warning("dot16: invalid tests as float32 provides the same result as float64")
+ # Dummy kernel to evaluate
+ src = """
+ kernel void test_dot16(global float* data,
+ int size,
+ global float* result)
+ {
+ float2 acc = (float2)(0.0f, 0.0f);
+ float16 data16 = (float16) (data[0],data[1],data[2],data[3],data[4],
+ data[5],data[6],data[7],data[8],data[9],
+ data[10],data[11],data[12],data[13],data[14],data[15]);
+ acc = comp_dot16(data16, data16);
+ result[0] = acc.s0;
+ result[1] = acc.s1;
+ }
+
+ kernel void test_dot8(global float* data,
+ int size,
+ global float* result)
+ {
+ float2 acc = (float2)(0.0f, 0.0f);
+ float8 data0 = (float8) (data[0],data[2],data[4],data[6],data[8],data[10],data[12],data[14]);
+ float8 data1 = (float8) (data[1],data[3],data[5],data[7],data[9],data[11],data[13],data[15]);
+ acc = comp_dot8(data0, data1);
+ result[0] = acc.s0;
+ result[1] = acc.s1;
+ }
+
+ kernel void test_dot4(global float* data,
+ int size,
+ global float* result)
+ {
+ float2 acc = (float2)(0.0f, 0.0f);
+ float4 data0 = (float4) (data[0],data[4],data[8],data[12]);
+ float4 data1 = (float4) (data[3],data[7],data[11],data[15]);
+ acc = comp_dot4(data0, data1);
+ result[0] = acc.s0;
+ result[1] = acc.s1;
+ }
+
+ kernel void test_dot3(global float* data,
+ int size,
+ global float* result)
+ {
+ float2 acc = (float2)(0.0f, 0.0f);
+ float3 data0 = (float3) (data[0],data[4],data[12]);
+ float3 data1 = (float3) (data[3],data[11],data[15]);
+ acc = comp_dot3(data0, data1);
+ result[0] = acc.s0;
+ result[1] = acc.s1;
+ }
+
+ kernel void test_dot2(global float* data,
+ int size,
+ global float* result)
+ {
+ float2 acc = (float2)(0.0f, 0.0f);
+ float2 data0 = (float2) (data[0],data[14]);
+ float2 data1 = (float2) (data[1],data[15]);
+ acc = comp_dot2(data0, data1);
+ result[0] = acc.s0;
+ result[1] = acc.s1;
+ }
+
+ """
+
+ prg = pyopencl.Program(self.ctx, read_cl_file("kahan.cl") + src).build(self.args)
+ ones_d = pyopencl.array.to_device(self.queue, data)
+ res_d = pyopencl.array.empty(self.queue, 2, numpy.float32)
+ res_d.fill(0)
+ evt = prg.test_dot16(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ evt.wait()
+ res = res_d.get().sum(dtype="float64")
+ self.assertEqual(ref64, res, "test_dot16")
+
+ res_d.fill(0)
+ data0 = data[0::2]
+ data1 = data[1::2]
+ ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
+ ref32 = numpy.dot(data0, data1)
+ if (ref64 == ref32):
+ logger.warning("dot8: invalid tests as float32 provides the same result as float64")
+ evt = prg.test_dot8(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ evt.wait()
+ res = res_d.get().sum(dtype="float64")
+ self.assertEqual(ref64, res, "test_dot8")
+
+ res_d.fill(0)
+ data0 = data[0::4]
+ data1 = data[3::4]
+ ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
+ ref32 = numpy.dot(data0, data1)
+ if (ref64 == ref32):
+ logger.warning("dot4: invalid tests as float32 provides the same result as float64")
+ evt = prg.test_dot4(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ evt.wait()
+ res = res_d.get().sum(dtype="float64")
+ self.assertEqual(ref64, res, "test_dot4")
+
+ res_d.fill(0)
+ data0 = numpy.array([data[0], data[4], data[12]])
+ data1 = numpy.array([data[3], data[11], data[15]])
+ ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
+ ref32 = numpy.dot(data0, data1)
+ if (ref64 == ref32):
+ logger.warning("dot3: invalid tests as float32 provides the same result as float64")
+ evt = prg.test_dot3(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ evt.wait()
+ res = res_d.get().sum(dtype="float64")
+ self.assertEqual(ref64, res, "test_dot3")
+
+ res_d.fill(0)
+ data0 = numpy.array([data[0], data[14]])
+ data1 = numpy.array([data[1], data[15]])
+ ref64 = numpy.dot(data0.astype(numpy.float64), data1.astype(numpy.float64))
+ ref32 = numpy.dot(data0, data1)
+ if (ref64 == ref32):
+ logger.warning("dot2: invalid tests as float32 provides the same result as float64")
+ evt = prg.test_dot2(self.queue, (1,), (1,), ones_d.data, numpy.int32(N), res_d.data)
+ evt.wait()
+ res = res_d.get().sum(dtype="float64")
+ self.assertEqual(ref64, res, "test_dot2")
diff --git a/src/silx/opencl/test/test_linalg.py b/src/silx/opencl/test/test_linalg.py
new file mode 100644
index 0000000..a997a36
--- /dev/null
+++ b/src/silx/opencl/test/test_linalg.py
@@ -0,0 +1,204 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of the linalg module"""
+
+from __future__ import division, print_function
+
+__authors__ = ["Pierre paleo"]
+__license__ = "MIT"
+__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "01/08/2019"
+
+
+import time
+import logging
+import numpy as np
+import unittest
+try:
+ import mako
+except ImportError:
+ mako = None
+from ..common import ocl
+if ocl:
+ import pyopencl as cl
+ import pyopencl.array as parray
+ from .. import linalg
+from silx.test.utils import utilstest
+
+logger = logging.getLogger(__name__)
+try:
+ from scipy.ndimage.filters import laplace
+ _has_scipy = True
+except ImportError:
+ _has_scipy = False
+
+
+# TODO move this function in math or image ?
+def gradient(img):
+ '''
+ Compute the gradient of an image as a numpy array
+ Code from https://github.com/emmanuelle/tomo-tv/
+ '''
+ shape = [img.ndim, ] + list(img.shape)
+ gradient = np.zeros(shape, dtype=img.dtype)
+ slice_all = [0, slice(None, -1),]
+ for d in range(img.ndim):
+ gradient[tuple(slice_all)] = np.diff(img, axis=d)
+ slice_all[0] = d + 1
+ slice_all.insert(1, slice(None))
+ return gradient
+
+
+# TODO move this function in math or image ?
+def divergence(grad):
+ '''
+ Compute the divergence of a gradient
+ Code from https://github.com/emmanuelle/tomo-tv/
+ '''
+ res = np.zeros(grad.shape[1:])
+ for d in range(grad.shape[0]):
+ this_grad = np.rollaxis(grad[d], d)
+ this_res = np.rollaxis(res, d)
+ this_res[:-1] += this_grad[:-1]
+ this_res[1:-1] -= this_grad[:-2]
+ this_res[-1] -= this_grad[-2]
+ return res
+
+
+@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
+class TestLinAlg(unittest.TestCase):
+
+ def setUp(self):
+ if ocl is None:
+ return
+ self.getfiles()
+ self.la = linalg.LinAlg(self.image.shape)
+ self.allocate_arrays()
+
+ def allocate_arrays(self):
+ """
+ Allocate various types of arrays for the tests
+ """
+ # numpy images
+ self.grad = np.zeros(self.image.shape, dtype=np.complex64)
+ self.grad2 = np.zeros((2,) + self.image.shape, dtype=np.float32)
+ self.grad_ref = gradient(self.image)
+ self.div_ref = divergence(self.grad_ref)
+ self.image2 = np.zeros_like(self.image)
+ # Device images
+ self.gradient_parray = parray.empty(self.la.queue, self.image.shape, np.complex64)
+ self.gradient_parray.fill(0)
+ # we should be using cl.Buffer(self.la.ctx, cl.mem_flags.READ_WRITE, size=self.image.nbytes*2),
+ # but platforms not suporting openCL 1.2 have a problem with enqueue_fill_buffer,
+ # so we use the parray "fill" utility
+ self.gradient_buffer = self.gradient_parray.data
+ # Do the same for image
+ self.image_parray = parray.to_device(self.la.queue, self.image)
+ self.image_buffer = self.image_parray.data
+ # Refs
+ tmp = np.zeros(self.image.shape, dtype=np.complex64)
+ tmp.real = np.copy(self.grad_ref[0])
+ tmp.imag = np.copy(self.grad_ref[1])
+ self.grad_ref_parray = parray.to_device(self.la.queue, tmp)
+ self.grad_ref_buffer = self.grad_ref_parray.data
+
+ def tearDown(self):
+ self.image = None
+ self.image2 = None
+ self.grad = None
+ self.grad2 = None
+ self.grad_ref = None
+ self.div_ref = None
+ self.gradient_parray.data.release()
+ self.gradient_parray = None
+ self.gradient_buffer = None
+ self.image_parray.data.release()
+ self.image_parray = None
+ self.image_buffer = None
+ self.grad_ref_parray.data.release()
+ self.grad_ref_parray = None
+ self.grad_ref_buffer = None
+
+ def getfiles(self):
+ # load 512x512 MRI phantom - TODO include Lena or ascent once a .npz is available
+ self.image = np.load(utilstest.getfile("Brain512.npz"))["data"]
+
+ def compare(self, result, reference, abstol, name):
+ errmax = np.max(np.abs(result - reference))
+ logger.info("%s: Max error = %e" % (name, errmax))
+ self.assertTrue(errmax < abstol, str("%s: Max error is too high" % name))
+
+ @unittest.skipUnless(ocl and mako, "pyopencl is missing")
+ def test_gradient(self):
+ arrays = {
+ "numpy.ndarray": self.image,
+ "buffer": self.image_buffer,
+ "parray": self.image_parray
+ }
+ for desc, image in arrays.items():
+ # Test with dst on host (numpy.ndarray)
+ res = self.la.gradient(image, return_to_host=True)
+ self.compare(res, self.grad_ref, 1e-6, str("gradient[src=%s, dst=numpy.ndarray]" % desc))
+ # Test with dst on device (pyopencl.Buffer)
+ self.la.gradient(image, dst=self.gradient_buffer)
+ cl.enqueue_copy(self.la.queue, self.grad, self.gradient_buffer)
+ self.grad2[0] = self.grad.real
+ self.grad2[1] = self.grad.imag
+ self.compare(self.grad2, self.grad_ref, 1e-6, str("gradient[src=%s, dst=buffer]" % desc))
+ # Test with dst on device (pyopencl.Array)
+ self.la.gradient(image, dst=self.gradient_parray)
+ self.grad = self.gradient_parray.get()
+ self.grad2[0] = self.grad.real
+ self.grad2[1] = self.grad.imag
+ self.compare(self.grad2, self.grad_ref, 1e-6, str("gradient[src=%s, dst=parray]" % desc))
+
+ @unittest.skipUnless(ocl and mako, "pyopencl is missing")
+ def test_divergence(self):
+ arrays = {
+ "numpy.ndarray": self.grad_ref,
+ "buffer": self.grad_ref_buffer,
+ "parray": self.grad_ref_parray
+ }
+ for desc, grad in arrays.items():
+ # Test with dst on host (numpy.ndarray)
+ res = self.la.divergence(grad, return_to_host=True)
+ self.compare(res, self.div_ref, 1e-6, str("divergence[src=%s, dst=numpy.ndarray]" % desc))
+ # Test with dst on device (pyopencl.Buffer)
+ self.la.divergence(grad, dst=self.image_buffer)
+ cl.enqueue_copy(self.la.queue, self.image2, self.image_buffer)
+ self.compare(self.image2, self.div_ref, 1e-6, str("divergence[src=%s, dst=buffer]" % desc))
+ # Test with dst on device (pyopencl.Array)
+ self.la.divergence(grad, dst=self.image_parray)
+ self.image2 = self.image_parray.get()
+ self.compare(self.image2, self.div_ref, 1e-6, str("divergence[src=%s, dst=parray]" % desc))
+
+ @unittest.skipUnless(ocl and mako and _has_scipy, "pyopencl and/or scipy is missing")
+ def test_laplacian(self):
+ laplacian_ref = laplace(self.image)
+ # Laplacian = div(grad)
+ self.la.gradient(self.image)
+ laplacian_ocl = self.la.divergence(self.la.d_gradient, return_to_host=True)
+ self.compare(laplacian_ocl, laplacian_ref, 1e-6, "laplacian")
diff --git a/src/silx/opencl/test/test_medfilt.py b/src/silx/opencl/test/test_medfilt.py
new file mode 100644
index 0000000..339e0f2
--- /dev/null
+++ b/src/silx/opencl/test/test_medfilt.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Project: Median filter of images + OpenCL
+# https://github.com/silx-kit/silx
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following
+# conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+
+"""
+Simple test of the median filter
+"""
+
+from __future__ import division, print_function
+
+__authors__ = ["Jérôme Kieffer"]
+__contact__ = "jerome.kieffer@esrf.eu"
+__license__ = "MIT"
+__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "05/07/2018"
+
+
+import sys
+import time
+import logging
+import numpy
+import unittest
+from collections import namedtuple
+try:
+ import mako
+except ImportError:
+ mako = None
+from ..common import ocl
+if ocl:
+ import pyopencl
+ import pyopencl.array
+ from .. import medfilt
+
+logger = logging.getLogger(__name__)
+
+Result = namedtuple("Result", ["size", "error", "sp_time", "oc_time"])
+
+try:
+ from scipy.misc import ascent
+except:
+ def ascent():
+ """Dummy image from random data"""
+ return numpy.random.random((512, 512))
+try:
+ from scipy.ndimage import filters
+ median_filter = filters.median_filter
+ HAS_SCIPY = True
+except:
+ HAS_SCIPY = False
+ from silx.math import medfilt2d as median_filter
+
+@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
+class TestMedianFilter(unittest.TestCase):
+
+ def setUp(self):
+ if ocl is None:
+ return
+ self.data = ascent().astype(numpy.float32)
+ self.medianfilter = medfilt.MedianFilter2D(self.data.shape, devicetype="gpu")
+
+ def tearDown(self):
+ self.data = None
+ self.medianfilter = None
+
+ def measure(self, size):
+ "Common measurement of accuracy and timings"
+ t0 = time.time()
+ if HAS_SCIPY:
+ ref = median_filter(self.data, size, mode="nearest")
+ else:
+ ref = median_filter(self.data, size)
+ t1 = time.time()
+ try:
+ got = self.medianfilter.medfilt2d(self.data, size)
+ except RuntimeError as msg:
+ logger.error(msg)
+ return
+ t2 = time.time()
+ delta = abs(got - ref).max()
+ return Result(size, delta, t1 - t0, t2 - t1)
+
+ @unittest.skipUnless(ocl and mako, "pyopencl is missing")
+ def test_medfilt(self):
+ """
+ tests the median filter kernel
+ """
+ r = self.measure(size=11)
+ if r is None:
+ logger.info("test_medfilt: size: %s: skipped")
+ else:
+ logger.info("test_medfilt: size: %s error %s, t_ref: %.3fs, t_ocl: %.3fs" % r)
+ self.assertEqual(r.error, 0, 'Results are correct')
+
+ def benchmark(self, limit=36):
+ "Run some benchmarking"
+ try:
+ import PyQt5
+ from ...gui.matplotlib import pylab
+ from ...gui.utils import update_fig
+ except:
+ pylab = None
+
+ def update_fig(*ag, **kwarg):
+ pass
+
+ fig = pylab.figure()
+ fig.suptitle("Median filter of an image 512x512")
+ sp = fig.add_subplot(1, 1, 1)
+ sp.set_title(self.medianfilter.ctx.devices[0].name)
+ sp.set_xlabel("Window width & height")
+ sp.set_ylabel("Execution time (s)")
+ sp.set_xlim(2, limit + 1)
+ sp.set_ylim(0, 4)
+ data_size = []
+ data_scipy = []
+ data_opencl = []
+ plot_sp = sp.plot(data_size, data_scipy, "-or", label="scipy")[0]
+ plot_opencl = sp.plot(data_size, data_opencl, "-ob", label="opencl")[0]
+ sp.legend(loc=2)
+ fig.show()
+ update_fig(fig)
+ for s in range(3, limit, 2):
+ r = self.measure(s)
+ print(r)
+ if r.error == 0:
+ data_size.append(s)
+ data_scipy.append(r.sp_time)
+ data_opencl.append(r.oc_time)
+ plot_sp.set_data(data_size, data_scipy)
+ plot_opencl.set_data(data_size, data_opencl)
+ update_fig(fig)
+ fig.show()
+ input()
+
+
+def benchmark():
+ testSuite = unittest.TestSuite()
+ testSuite.addTest(TestMedianFilter("benchmark"))
+ return testSuite
diff --git a/src/silx/opencl/test/test_projection.py b/src/silx/opencl/test/test_projection.py
new file mode 100644
index 0000000..13db5f4
--- /dev/null
+++ b/src/silx/opencl/test/test_projection.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of the forward projection module"""
+
+from __future__ import division, print_function
+
+__authors__ = ["Pierre paleo"]
+__license__ = "MIT"
+__copyright__ = "2013-2017 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "19/01/2018"
+
+
+import time
+import logging
+import numpy as np
+import unittest
+try:
+ import mako
+except ImportError:
+ mako = None
+from ..common import ocl
+if ocl:
+ from .. import projection
+from silx.test.utils import utilstest
+
+logger = logging.getLogger(__name__)
+
+
+@unittest.skipUnless(ocl and mako, "PyOpenCl is missing")
+class TestProj(unittest.TestCase):
+
+ def setUp(self):
+ if ocl is None:
+ return
+ # ~ if sys.platform.startswith('darwin'):
+ # ~ self.skipTest("Projection is not implemented on CPU for OS X yet")
+ self.getfiles()
+ n_angles = self.sino.shape[0]
+ self.proj = projection.Projection(self.phantom.shape, n_angles)
+ if self.proj.compiletime_workgroup_size < 16 * 16:
+ self.skipTest("Current implementation of OpenCL projection is not supported on this platform yet")
+
+ def tearDown(self):
+ self.phantom = None
+ self.sino = None
+ self.proj = None
+
+ def getfiles(self):
+ # load 512x512 MRI phantom
+ self.phantom = np.load(utilstest.getfile("Brain512.npz"))["data"]
+ # load sinogram computed with PyHST
+ self.sino = np.load(utilstest.getfile("sino500_pyhst.npz"))["data"]
+
+ def measure(self):
+ "Common measurement of timings"
+ t1 = time.time()
+ try:
+ result = self.proj.projection(self.phantom)
+ except RuntimeError as msg:
+ logger.error(msg)
+ return
+ t2 = time.time()
+ return t2 - t1, result
+
+ def compare(self, res):
+ """
+ Compare a result with the reference reconstruction.
+ Only the valid reconstruction zone (inscribed circle) is taken into account
+ """
+ # Compare with the original phantom.
+ # TODO: compare a standard projection
+ ref = self.sino
+ return np.max(np.abs(res - ref))
+
+ @unittest.skipUnless(ocl and mako, "pyopencl is missing")
+ def test_proj(self):
+ """
+ tests Projection
+ """
+ # Test single reconstruction
+ # --------------------------
+ t, res = self.measure()
+ if t is None:
+ logger.info("test_proj: skipped")
+ else:
+ logger.info("test_proj: time = %.3fs" % t)
+ err = self.compare(res)
+ msg = str("Max error = %e" % err)
+ logger.info(msg)
+ # Interpolation differs at some lines, giving relative error of 10/50000
+ self.assertTrue(err < 20., "Max error is too high")
+ # Test multiple reconstructions
+ # -----------------------------
+ res0 = np.copy(res)
+ for i in range(10):
+ res = self.proj.projection(self.phantom)
+ errmax = np.max(np.abs(res - res0))
+ self.assertTrue(errmax < 1.e-6, "Max error is too high")
diff --git a/src/silx/opencl/test/test_sparse.py b/src/silx/opencl/test/test_sparse.py
new file mode 100644
index 0000000..1d26b36
--- /dev/null
+++ b/src/silx/opencl/test/test_sparse.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2018-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test of the sparse module"""
+
+import numpy as np
+import unittest
+import logging
+from itertools import product
+from ..common import ocl
+if ocl:
+ import pyopencl.array as parray
+ from silx.opencl.sparse import CSR
+try:
+ import scipy.sparse as sp
+except ImportError:
+ sp = None
+logger = logging.getLogger(__name__)
+
+
+
+def generate_sparse_random_data(
+ shape=(1000,),
+ data_min=0, data_max=100,
+ density=0.1,
+ use_only_integers=True,
+ dtype="f"):
+ """
+ Generate random sparse data where.
+
+ Parameters
+ ------------
+ shape: tuple
+ Output data shape.
+ data_min: int or float
+ Minimum value of data
+ data_max: int or float
+ Maximum value of data
+ density: float
+ Density of non-zero elements in the output data.
+ Low value of density mean low number of non-zero elements.
+ use_only_integers: bool
+ If set to True, the output data items will be primarily integers,
+ possibly casted to float if dtype is a floating-point type.
+ This can be used for ease of debugging.
+ dtype: str or numpy.dtype
+ Output data type
+ """
+ mask = np.random.binomial(1, density, size=shape)
+ if use_only_integers:
+ d = np.random.randint(data_min, high=data_max, size=shape)
+ else:
+ d = data_min + (data_max - data_min) * np.random.rand(*shape)
+ return (d * mask).astype(dtype)
+
+
+
+@unittest.skipUnless(ocl and sp, "PyOpenCl/scipy is missing")
+class TestCSR(unittest.TestCase):
+ """Test CSR format"""
+
+ def setUp(self):
+ # Test possible configurations
+ input_on_device = [False, True]
+ output_on_device = [False, True]
+ dtypes = [np.float32, np.int32, np.uint16]
+ self._test_configs = list(product(input_on_device, output_on_device, dtypes))
+
+
+ def compute_ref_sparsification(self, array):
+ ref_sparse = sp.csr_matrix(array)
+ return ref_sparse
+
+
+ def test_sparsification(self):
+ for input_on_device, output_on_device, dtype in self._test_configs:
+ self._test_sparsification(input_on_device, output_on_device, dtype)
+
+
+ def _test_sparsification(self, input_on_device, output_on_device, dtype):
+ current_config = "input on device: %s, output on device: %s, dtype: %s" % (
+ str(input_on_device), str(output_on_device), str(dtype)
+ )
+ logger.debug("CSR: %s" % current_config)
+ # Generate data and reference CSR
+ array = generate_sparse_random_data(shape=(512, 511), dtype=dtype)
+ ref_sparse = self.compute_ref_sparsification(array)
+ # Sparsify on device
+ csr = CSR(array.shape, dtype=dtype)
+ if input_on_device:
+ # The array has to be flattened
+ arr = parray.to_device(csr.queue, array.ravel())
+ else:
+ arr = array
+ if output_on_device:
+ d_data = parray.empty_like(csr.data)
+ d_indices = parray.empty_like(csr.indices)
+ d_indptr = parray.empty_like(csr.indptr)
+ d_data.fill(0)
+ d_indices.fill(0)
+ d_indptr.fill(0)
+ output = (d_data, d_indices, d_indptr)
+ else:
+ output = None
+ data, indices, indptr = csr.sparsify(arr, output=output)
+ if output_on_device:
+ data = data.get()
+ indices = indices.get()
+ indptr = indptr.get()
+ # Compare
+ nnz = ref_sparse.nnz
+ self.assertTrue(
+ np.allclose(data[:nnz], ref_sparse.data),
+ "something wrong with sparsified data (%s)"
+ % current_config
+ )
+ self.assertTrue(
+ np.allclose(indices[:nnz], ref_sparse.indices),
+ "something wrong with sparsified indices (%s)"
+ % current_config
+ )
+ self.assertTrue(
+ np.allclose(indptr, ref_sparse.indptr),
+ "something wrong with sparsified indices pointers (indptr) (%s)"
+ % current_config
+ )
+
+
+ def test_desparsification(self):
+ for input_on_device, output_on_device, dtype in self._test_configs:
+ self._test_desparsification(input_on_device, output_on_device, dtype)
+
+
+ def _test_desparsification(self, input_on_device, output_on_device, dtype):
+ current_config = "input on device: %s, output on device: %s, dtype: %s" % (
+ str(input_on_device), str(output_on_device), str(dtype)
+ )
+ logger.debug("CSR: %s" % current_config)
+ # Generate data and reference CSR
+ array = generate_sparse_random_data(shape=(512, 511), dtype=dtype)
+ ref_sparse = self.compute_ref_sparsification(array)
+ # De-sparsify on device
+ csr = CSR(array.shape, dtype=dtype, max_nnz=ref_sparse.nnz)
+ if input_on_device:
+ data = parray.to_device(csr.queue, ref_sparse.data)
+ indices = parray.to_device(csr.queue, ref_sparse.indices)
+ indptr = parray.to_device(csr.queue, ref_sparse.indptr)
+ else:
+ data = ref_sparse.data
+ indices = ref_sparse.indices
+ indptr = ref_sparse.indptr
+ if output_on_device:
+ d_arr = parray.empty_like(csr.array)
+ d_arr.fill(0)
+ output = d_arr
+ else:
+ output = None
+ arr = csr.densify(data, indices, indptr, output=output)
+ if output_on_device:
+ arr = arr.get()
+ # Compare
+ self.assertTrue(
+ np.allclose(arr.reshape(array.shape), array),
+ "something wrong with densified data (%s)"
+ % current_config
+ )
diff --git a/src/silx/opencl/test/test_stats.py b/src/silx/opencl/test/test_stats.py
new file mode 100644
index 0000000..859271d
--- /dev/null
+++ b/src/silx/opencl/test/test_stats.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Project: Sift implementation in Python + OpenCL
+# https://github.com/silx-kit/silx
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use,
+# copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following
+# conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+
+"""
+Simple test of an addition
+"""
+__authors__ = ["Henri Payno, Jérôme Kieffer"]
+__contact__ = "jerome.kieffer@esrf.eu"
+__license__ = "MIT"
+__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
+__date__ = "19/05/2021"
+
+import logging
+import time
+import numpy
+
+import unittest
+from ..common import ocl
+if ocl:
+ import pyopencl
+ import pyopencl.array
+ from ..statistics import StatResults, Statistics
+from ..utils import get_opencl_code
+logger = logging.getLogger(__name__)
+
+
+@unittest.skipUnless(ocl, "PyOpenCl is missing")
+class TestStatistics(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.size = 1 << 20 # 1 million elements
+ cls.data = numpy.random.randint(0, 65000, cls.size).astype("uint16")
+ fdata = cls.data.astype("float64")
+ t0 = time.perf_counter()
+ std = fdata.std()
+ cls.ref = StatResults(fdata.min(), fdata.max(), float(fdata.size),
+ fdata.sum(), fdata.mean(), std ** 2,
+ std)
+ t1 = time.perf_counter()
+ cls.ref_time = t1 - t0
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.size = cls.ref = cls.data = cls.ref_time = None
+
+ @classmethod
+ def validate(cls, res):
+ return (
+ (res.min == cls.ref.min) and
+ (res.max == cls.ref.max) and
+ (res.cnt == cls.ref.cnt) and
+ abs(res.mean - cls.ref.mean) < 0.01 and
+ abs(res.std - cls.ref.std) < 0.1)
+
+ def test_measurement(self):
+ """
+ tests that all devices are working properly ...
+ """
+ logger.info("Reference results: %s", self.ref)
+ for pid, platform in enumerate(ocl.platforms):
+ for did, device in enumerate(platform.devices):
+ try:
+ s = Statistics(template=self.data, platformid=pid, deviceid=did)
+ except Exception as err:
+ failed_init = True
+ res = StatResults(0, 0, 0, 0, 0, 0, 0)
+ print(err)
+ else:
+ failed_init = False
+ for comp in ("single", "double", "comp"):
+ t0 = time.perf_counter()
+ res = s(self.data, comp=comp)
+ t1 = time.perf_counter()
+ logger.info("Runtime on %s/%s : %.3fms x%.1f", platform, device, 1000 * (t1 - t0), self.ref_time / (t1 - t0))
+
+ if failed_init or not self.validate(res):
+ logger.error("failed_init %s; Computation modes %s", failed_init, comp)
+ logger.error("Failed on platform %s device %s", platform, device)
+ logger.error("Reference results: %s", self.ref)
+ logger.error("Faulty results: %s", res)
+ self.assertTrue(False, f"Stat calculation failed on {platform},{device} in mode {comp}")
diff --git a/silx/opencl/utils.py b/src/silx/opencl/utils.py
index 575e018..575e018 100644
--- a/silx/opencl/utils.py
+++ b/src/silx/opencl/utils.py
diff --git a/silx/resources/__init__.py b/src/silx/resources/__init__.py
index 5346f48..5346f48 100644
--- a/silx/resources/__init__.py
+++ b/src/silx/resources/__init__.py
diff --git a/silx/resources/gui/colormaps/cividis.npy b/src/silx/resources/gui/colormaps/cividis.npy
index 8e118ef..8e118ef 100644
--- a/silx/resources/gui/colormaps/cividis.npy
+++ b/src/silx/resources/gui/colormaps/cividis.npy
Binary files differ
diff --git a/silx/resources/gui/colormaps/inferno.npy b/src/silx/resources/gui/colormaps/inferno.npy
index 3b00d2e..3b00d2e 100644
--- a/silx/resources/gui/colormaps/inferno.npy
+++ b/src/silx/resources/gui/colormaps/inferno.npy
Binary files differ
diff --git a/silx/resources/gui/colormaps/magma.npy b/src/silx/resources/gui/colormaps/magma.npy
index 3f8f4dc..3f8f4dc 100644
--- a/silx/resources/gui/colormaps/magma.npy
+++ b/src/silx/resources/gui/colormaps/magma.npy
Binary files differ
diff --git a/silx/resources/gui/colormaps/plasma.npy b/src/silx/resources/gui/colormaps/plasma.npy
index 6af2fd0..6af2fd0 100644
--- a/silx/resources/gui/colormaps/plasma.npy
+++ b/src/silx/resources/gui/colormaps/plasma.npy
Binary files differ
diff --git a/silx/resources/gui/colormaps/viridis.npy b/src/silx/resources/gui/colormaps/viridis.npy
index 141877e..141877e 100644
--- a/silx/resources/gui/colormaps/viridis.npy
+++ b/src/silx/resources/gui/colormaps/viridis.npy
Binary files differ
diff --git a/silx/resources/gui/icons/3d-plane-normal-x.png b/src/silx/resources/gui/icons/3d-plane-normal-x.png
index bf8cf45..bf8cf45 100644
--- a/silx/resources/gui/icons/3d-plane-normal-x.png
+++ b/src/silx/resources/gui/icons/3d-plane-normal-x.png
Binary files differ
diff --git a/silx/resources/gui/icons/3d-plane-normal-x.svg b/src/silx/resources/gui/icons/3d-plane-normal-x.svg
index 203bd84..203bd84 100644
--- a/silx/resources/gui/icons/3d-plane-normal-x.svg
+++ b/src/silx/resources/gui/icons/3d-plane-normal-x.svg
diff --git a/silx/resources/gui/icons/3d-plane-normal-y.png b/src/silx/resources/gui/icons/3d-plane-normal-y.png
index 733b92a..733b92a 100644
--- a/silx/resources/gui/icons/3d-plane-normal-y.png
+++ b/src/silx/resources/gui/icons/3d-plane-normal-y.png
Binary files differ
diff --git a/silx/resources/gui/icons/3d-plane-normal-y.svg b/src/silx/resources/gui/icons/3d-plane-normal-y.svg
index 78d8ebd..78d8ebd 100644
--- a/silx/resources/gui/icons/3d-plane-normal-y.svg
+++ b/src/silx/resources/gui/icons/3d-plane-normal-y.svg
diff --git a/silx/resources/gui/icons/3d-plane-normal-z.png b/src/silx/resources/gui/icons/3d-plane-normal-z.png
index 0ab61e6..0ab61e6 100644
--- a/silx/resources/gui/icons/3d-plane-normal-z.png
+++ b/src/silx/resources/gui/icons/3d-plane-normal-z.png
Binary files differ
diff --git a/silx/resources/gui/icons/3d-plane-normal-z.svg b/src/silx/resources/gui/icons/3d-plane-normal-z.svg
index 5ac7d86..5ac7d86 100644
--- a/silx/resources/gui/icons/3d-plane-normal-z.svg
+++ b/src/silx/resources/gui/icons/3d-plane-normal-z.svg
diff --git a/silx/resources/gui/icons/3d-plane-pan.png b/src/silx/resources/gui/icons/3d-plane-pan.png
index 79b8ace..79b8ace 100644
--- a/silx/resources/gui/icons/3d-plane-pan.png
+++ b/src/silx/resources/gui/icons/3d-plane-pan.png
Binary files differ
diff --git a/silx/resources/gui/icons/3d-plane-pan.svg b/src/silx/resources/gui/icons/3d-plane-pan.svg
index 73df5fc..73df5fc 100644
--- a/silx/resources/gui/icons/3d-plane-pan.svg
+++ b/src/silx/resources/gui/icons/3d-plane-pan.svg
diff --git a/silx/resources/gui/icons/3d-plane.png b/src/silx/resources/gui/icons/3d-plane.png
index 6181d42..6181d42 100644
--- a/silx/resources/gui/icons/3d-plane.png
+++ b/src/silx/resources/gui/icons/3d-plane.png
Binary files differ
diff --git a/silx/resources/gui/icons/3d-plane.svg b/src/silx/resources/gui/icons/3d-plane.svg
index 830db78..830db78 100644
--- a/silx/resources/gui/icons/3d-plane.svg
+++ b/src/silx/resources/gui/icons/3d-plane.svg
diff --git a/silx/resources/gui/icons/add-range-horizontal.png b/src/silx/resources/gui/icons/add-range-horizontal.png
index 14bdd18..14bdd18 100644
--- a/silx/resources/gui/icons/add-range-horizontal.png
+++ b/src/silx/resources/gui/icons/add-range-horizontal.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-range-horizontal.svg b/src/silx/resources/gui/icons/add-range-horizontal.svg
index 0470609..0470609 100644
--- a/silx/resources/gui/icons/add-range-horizontal.svg
+++ b/src/silx/resources/gui/icons/add-range-horizontal.svg
diff --git a/silx/resources/gui/icons/add-shape-arc.png b/src/silx/resources/gui/icons/add-shape-arc.png
index 07afaab..07afaab 100644
--- a/silx/resources/gui/icons/add-shape-arc.png
+++ b/src/silx/resources/gui/icons/add-shape-arc.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-arc.svg b/src/silx/resources/gui/icons/add-shape-arc.svg
index a0a8cfc..a0a8cfc 100644
--- a/silx/resources/gui/icons/add-shape-arc.svg
+++ b/src/silx/resources/gui/icons/add-shape-arc.svg
diff --git a/silx/resources/gui/icons/add-shape-circle.png b/src/silx/resources/gui/icons/add-shape-circle.png
index 722c08a..722c08a 100644
--- a/silx/resources/gui/icons/add-shape-circle.png
+++ b/src/silx/resources/gui/icons/add-shape-circle.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-circle.svg b/src/silx/resources/gui/icons/add-shape-circle.svg
index 871d8ee..871d8ee 100644
--- a/silx/resources/gui/icons/add-shape-circle.svg
+++ b/src/silx/resources/gui/icons/add-shape-circle.svg
diff --git a/silx/resources/gui/icons/add-shape-cross.png b/src/silx/resources/gui/icons/add-shape-cross.png
index 2e5eb60..2e5eb60 100644
--- a/silx/resources/gui/icons/add-shape-cross.png
+++ b/src/silx/resources/gui/icons/add-shape-cross.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-cross.svg b/src/silx/resources/gui/icons/add-shape-cross.svg
index c08ef33..c08ef33 100644
--- a/silx/resources/gui/icons/add-shape-cross.svg
+++ b/src/silx/resources/gui/icons/add-shape-cross.svg
diff --git a/silx/resources/gui/icons/add-shape-diagonal.png b/src/silx/resources/gui/icons/add-shape-diagonal.png
index 3696db2..3696db2 100644
--- a/silx/resources/gui/icons/add-shape-diagonal.png
+++ b/src/silx/resources/gui/icons/add-shape-diagonal.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-diagonal.svg b/src/silx/resources/gui/icons/add-shape-diagonal.svg
index 42f9414..42f9414 100644
--- a/silx/resources/gui/icons/add-shape-diagonal.svg
+++ b/src/silx/resources/gui/icons/add-shape-diagonal.svg
diff --git a/silx/resources/gui/icons/add-shape-ellipse.png b/src/silx/resources/gui/icons/add-shape-ellipse.png
index c3f2290..c3f2290 100644
--- a/silx/resources/gui/icons/add-shape-ellipse.png
+++ b/src/silx/resources/gui/icons/add-shape-ellipse.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-ellipse.svg b/src/silx/resources/gui/icons/add-shape-ellipse.svg
index 5c466ae..5c466ae 100644
--- a/silx/resources/gui/icons/add-shape-ellipse.svg
+++ b/src/silx/resources/gui/icons/add-shape-ellipse.svg
diff --git a/silx/resources/gui/icons/add-shape-horizontal.png b/src/silx/resources/gui/icons/add-shape-horizontal.png
index d217af5..d217af5 100644
--- a/silx/resources/gui/icons/add-shape-horizontal.png
+++ b/src/silx/resources/gui/icons/add-shape-horizontal.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-horizontal.svg b/src/silx/resources/gui/icons/add-shape-horizontal.svg
index 72d2b9b..72d2b9b 100644
--- a/silx/resources/gui/icons/add-shape-horizontal.svg
+++ b/src/silx/resources/gui/icons/add-shape-horizontal.svg
diff --git a/silx/resources/gui/icons/add-shape-point.png b/src/silx/resources/gui/icons/add-shape-point.png
index fa2111a..fa2111a 100644
--- a/silx/resources/gui/icons/add-shape-point.png
+++ b/src/silx/resources/gui/icons/add-shape-point.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-point.svg b/src/silx/resources/gui/icons/add-shape-point.svg
index c5ed941..c5ed941 100644
--- a/silx/resources/gui/icons/add-shape-point.svg
+++ b/src/silx/resources/gui/icons/add-shape-point.svg
diff --git a/silx/resources/gui/icons/add-shape-polygon.png b/src/silx/resources/gui/icons/add-shape-polygon.png
index ba7f040..ba7f040 100644
--- a/silx/resources/gui/icons/add-shape-polygon.png
+++ b/src/silx/resources/gui/icons/add-shape-polygon.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-polygon.svg b/src/silx/resources/gui/icons/add-shape-polygon.svg
index 07dcd46..07dcd46 100644
--- a/silx/resources/gui/icons/add-shape-polygon.svg
+++ b/src/silx/resources/gui/icons/add-shape-polygon.svg
diff --git a/silx/resources/gui/icons/add-shape-rectangle.png b/src/silx/resources/gui/icons/add-shape-rectangle.png
index 6246ce6..6246ce6 100644
--- a/silx/resources/gui/icons/add-shape-rectangle.png
+++ b/src/silx/resources/gui/icons/add-shape-rectangle.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-rectangle.svg b/src/silx/resources/gui/icons/add-shape-rectangle.svg
index d35260b..d35260b 100644
--- a/silx/resources/gui/icons/add-shape-rectangle.svg
+++ b/src/silx/resources/gui/icons/add-shape-rectangle.svg
diff --git a/silx/resources/gui/icons/add-shape-unknown.png b/src/silx/resources/gui/icons/add-shape-unknown.png
index 3578e29..3578e29 100644
--- a/silx/resources/gui/icons/add-shape-unknown.png
+++ b/src/silx/resources/gui/icons/add-shape-unknown.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-unknown.svg b/src/silx/resources/gui/icons/add-shape-unknown.svg
index 04758cf..04758cf 100644
--- a/silx/resources/gui/icons/add-shape-unknown.svg
+++ b/src/silx/resources/gui/icons/add-shape-unknown.svg
diff --git a/silx/resources/gui/icons/add-shape-vertical.png b/src/silx/resources/gui/icons/add-shape-vertical.png
index 19050d9..19050d9 100644
--- a/silx/resources/gui/icons/add-shape-vertical.png
+++ b/src/silx/resources/gui/icons/add-shape-vertical.png
Binary files differ
diff --git a/silx/resources/gui/icons/add-shape-vertical.svg b/src/silx/resources/gui/icons/add-shape-vertical.svg
index 9f9070d..9f9070d 100644
--- a/silx/resources/gui/icons/add-shape-vertical.svg
+++ b/src/silx/resources/gui/icons/add-shape-vertical.svg
diff --git a/silx/resources/gui/icons/add.png b/src/silx/resources/gui/icons/add.png
index 80c6400..80c6400 100644
--- a/silx/resources/gui/icons/add.png
+++ b/src/silx/resources/gui/icons/add.png
Binary files differ
diff --git a/silx/resources/gui/icons/add.svg b/src/silx/resources/gui/icons/add.svg
index 19c1a6d..19c1a6d 100644
--- a/silx/resources/gui/icons/add.svg
+++ b/src/silx/resources/gui/icons/add.svg
diff --git a/src/silx/resources/gui/icons/aggregation-mode.png b/src/silx/resources/gui/icons/aggregation-mode.png
new file mode 100644
index 0000000..2b66dda
--- /dev/null
+++ b/src/silx/resources/gui/icons/aggregation-mode.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/aggregation-mode.svg b/src/silx/resources/gui/icons/aggregation-mode.svg
new file mode 100644
index 0000000..bd155a8
--- /dev/null
+++ b/src/silx/resources/gui/icons/aggregation-mode.svg
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg id="svg8295" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata8301"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata>
+
+<ellipse id="path1413" cx="16.039" cy="5.8112" rx="10.713" ry="1.5738" fill="none" stroke="#f7941e" stroke-miterlimit="10" stroke-width="2.2022"/><path id="rect1415" d="m12.37 12.67h7.1053v16.875l-7.1053-3.0197z" fill="#f7941e"/><g id="g1439" transform="matrix(1.1504 0 0 1.1504 -2.127 -2.4569)" fill="none" stroke="#f7941e" stroke-width="2.5509"><path id="path1417" d="m6.5027 6.7799 8.7244 12.585"/><path id="path1417-3" d="m24.955 6.9343-8.7244 12.585"/></g><rect id="rect1435" x="10.505" y="11.515" width="11.191" height="3.1974" fill="#f7941e"/></svg>
diff --git a/silx/resources/gui/icons/arrow-keys.png b/src/silx/resources/gui/icons/arrow-keys.png
index bf83e29..bf83e29 100644
--- a/silx/resources/gui/icons/arrow-keys.png
+++ b/src/silx/resources/gui/icons/arrow-keys.png
Binary files differ
diff --git a/silx/resources/gui/icons/arrow-keys.svg b/src/silx/resources/gui/icons/arrow-keys.svg
index 64aade5..64aade5 100644
--- a/silx/resources/gui/icons/arrow-keys.svg
+++ b/src/silx/resources/gui/icons/arrow-keys.svg
diff --git a/silx/resources/gui/icons/axis.png b/src/silx/resources/gui/icons/axis.png
index aa29525..aa29525 100644
--- a/silx/resources/gui/icons/axis.png
+++ b/src/silx/resources/gui/icons/axis.png
Binary files differ
diff --git a/silx/resources/gui/icons/axis.svg b/src/silx/resources/gui/icons/axis.svg
index 4ea7ddc..4ea7ddc 100644
--- a/silx/resources/gui/icons/axis.svg
+++ b/src/silx/resources/gui/icons/axis.svg
diff --git a/silx/resources/gui/icons/backend-opengl.png b/src/silx/resources/gui/icons/backend-opengl.png
index ff81f64..ff81f64 100644
--- a/silx/resources/gui/icons/backend-opengl.png
+++ b/src/silx/resources/gui/icons/backend-opengl.png
Binary files differ
diff --git a/silx/resources/gui/icons/backend-opengl.svg b/src/silx/resources/gui/icons/backend-opengl.svg
index 41d79b8..41d79b8 100644
--- a/silx/resources/gui/icons/backend-opengl.svg
+++ b/src/silx/resources/gui/icons/backend-opengl.svg
diff --git a/silx/resources/gui/icons/camera.png b/src/silx/resources/gui/icons/camera.png
index ec3e62c..ec3e62c 100644
--- a/silx/resources/gui/icons/camera.png
+++ b/src/silx/resources/gui/icons/camera.png
Binary files differ
diff --git a/silx/resources/gui/icons/camera.svg b/src/silx/resources/gui/icons/camera.svg
index e53858a..e53858a 100644
--- a/silx/resources/gui/icons/camera.svg
+++ b/src/silx/resources/gui/icons/camera.svg
diff --git a/silx/resources/gui/icons/clipboard.png b/src/silx/resources/gui/icons/clipboard.png
index 03b6297..03b6297 100644
--- a/silx/resources/gui/icons/clipboard.png
+++ b/src/silx/resources/gui/icons/clipboard.png
Binary files differ
diff --git a/silx/resources/gui/icons/clipboard.svg b/src/silx/resources/gui/icons/clipboard.svg
index 7754fd1..7754fd1 100644
--- a/silx/resources/gui/icons/clipboard.svg
+++ b/src/silx/resources/gui/icons/clipboard.svg
diff --git a/silx/resources/gui/icons/close.png b/src/silx/resources/gui/icons/close.png
index 181b3fd..181b3fd 100755
--- a/silx/resources/gui/icons/close.png
+++ b/src/silx/resources/gui/icons/close.png
Binary files differ
diff --git a/silx/resources/gui/icons/close.svg b/src/silx/resources/gui/icons/close.svg
index 3b96e8f..3b96e8f 100644
--- a/silx/resources/gui/icons/close.svg
+++ b/src/silx/resources/gui/icons/close.svg
diff --git a/silx/resources/gui/icons/colorbar.png b/src/silx/resources/gui/icons/colorbar.png
index 1b0e416..1b0e416 100644
--- a/silx/resources/gui/icons/colorbar.png
+++ b/src/silx/resources/gui/icons/colorbar.png
Binary files differ
diff --git a/silx/resources/gui/icons/colorbar.svg b/src/silx/resources/gui/icons/colorbar.svg
index 035e619..035e619 100644
--- a/silx/resources/gui/icons/colorbar.svg
+++ b/src/silx/resources/gui/icons/colorbar.svg
diff --git a/silx/resources/gui/icons/colormap-histogram.png b/src/silx/resources/gui/icons/colormap-histogram.png
index a199adb..a199adb 100644
--- a/silx/resources/gui/icons/colormap-histogram.png
+++ b/src/silx/resources/gui/icons/colormap-histogram.png
Binary files differ
diff --git a/silx/resources/gui/icons/colormap-histogram.svg b/src/silx/resources/gui/icons/colormap-histogram.svg
index d5a0996..d5a0996 100644
--- a/silx/resources/gui/icons/colormap-histogram.svg
+++ b/src/silx/resources/gui/icons/colormap-histogram.svg
diff --git a/silx/resources/gui/icons/colormap-none.png b/src/silx/resources/gui/icons/colormap-none.png
index 5441fa5..5441fa5 100644
--- a/silx/resources/gui/icons/colormap-none.png
+++ b/src/silx/resources/gui/icons/colormap-none.png
Binary files differ
diff --git a/silx/resources/gui/icons/colormap-none.svg b/src/silx/resources/gui/icons/colormap-none.svg
index 3136d62..3136d62 100644
--- a/silx/resources/gui/icons/colormap-none.svg
+++ b/src/silx/resources/gui/icons/colormap-none.svg
diff --git a/silx/resources/gui/icons/colormap-norm-arcsinh.png b/src/silx/resources/gui/icons/colormap-norm-arcsinh.png
index 653102d..653102d 100644
--- a/silx/resources/gui/icons/colormap-norm-arcsinh.png
+++ b/src/silx/resources/gui/icons/colormap-norm-arcsinh.png
Binary files differ
diff --git a/silx/resources/gui/icons/colormap-norm-arcsinh.svg b/src/silx/resources/gui/icons/colormap-norm-arcsinh.svg
index 961df04..961df04 100644
--- a/silx/resources/gui/icons/colormap-norm-arcsinh.svg
+++ b/src/silx/resources/gui/icons/colormap-norm-arcsinh.svg
diff --git a/silx/resources/gui/icons/colormap-norm-gamma.png b/src/silx/resources/gui/icons/colormap-norm-gamma.png
index 3fe9c3e..3fe9c3e 100644
--- a/silx/resources/gui/icons/colormap-norm-gamma.png
+++ b/src/silx/resources/gui/icons/colormap-norm-gamma.png
Binary files differ
diff --git a/silx/resources/gui/icons/colormap-norm-gamma.svg b/src/silx/resources/gui/icons/colormap-norm-gamma.svg
index b43355e..b43355e 100644
--- a/silx/resources/gui/icons/colormap-norm-gamma.svg
+++ b/src/silx/resources/gui/icons/colormap-norm-gamma.svg
diff --git a/silx/resources/gui/icons/colormap-norm-linear.png b/src/silx/resources/gui/icons/colormap-norm-linear.png
index 60d2fe1..60d2fe1 100644
--- a/silx/resources/gui/icons/colormap-norm-linear.png
+++ b/src/silx/resources/gui/icons/colormap-norm-linear.png
Binary files differ
diff --git a/silx/resources/gui/icons/colormap-norm-linear.svg b/src/silx/resources/gui/icons/colormap-norm-linear.svg
index eabfa23..eabfa23 100644
--- a/silx/resources/gui/icons/colormap-norm-linear.svg
+++ b/src/silx/resources/gui/icons/colormap-norm-linear.svg
diff --git a/silx/resources/gui/icons/colormap-norm-log.png b/src/silx/resources/gui/icons/colormap-norm-log.png
index 2486255..2486255 100644
--- a/silx/resources/gui/icons/colormap-norm-log.png
+++ b/src/silx/resources/gui/icons/colormap-norm-log.png
Binary files differ
diff --git a/silx/resources/gui/icons/colormap-norm-log.svg b/src/silx/resources/gui/icons/colormap-norm-log.svg
index 69d6b96..69d6b96 100644
--- a/silx/resources/gui/icons/colormap-norm-log.svg
+++ b/src/silx/resources/gui/icons/colormap-norm-log.svg
diff --git a/silx/resources/gui/icons/colormap-norm-sqrt.png b/src/silx/resources/gui/icons/colormap-norm-sqrt.png
index d1b3ef5..d1b3ef5 100644
--- a/silx/resources/gui/icons/colormap-norm-sqrt.png
+++ b/src/silx/resources/gui/icons/colormap-norm-sqrt.png
Binary files differ
diff --git a/silx/resources/gui/icons/colormap-norm-sqrt.svg b/src/silx/resources/gui/icons/colormap-norm-sqrt.svg
index 4d239e4..4d239e4 100644
--- a/silx/resources/gui/icons/colormap-norm-sqrt.svg
+++ b/src/silx/resources/gui/icons/colormap-norm-sqrt.svg
diff --git a/silx/resources/gui/icons/colormap-range.png b/src/silx/resources/gui/icons/colormap-range.png
index 6225375..6225375 100644
--- a/silx/resources/gui/icons/colormap-range.png
+++ b/src/silx/resources/gui/icons/colormap-range.png
Binary files differ
diff --git a/silx/resources/gui/icons/colormap-range.svg b/src/silx/resources/gui/icons/colormap-range.svg
index 0e70311..0e70311 100644
--- a/silx/resources/gui/icons/colormap-range.svg
+++ b/src/silx/resources/gui/icons/colormap-range.svg
diff --git a/silx/resources/gui/icons/colormap.png b/src/silx/resources/gui/icons/colormap.png
index 48a6e52..48a6e52 100755
--- a/silx/resources/gui/icons/colormap.png
+++ b/src/silx/resources/gui/icons/colormap.png
Binary files differ
diff --git a/silx/resources/gui/icons/colormap.svg b/src/silx/resources/gui/icons/colormap.svg
index 03c9672..03c9672 100644
--- a/silx/resources/gui/icons/colormap.svg
+++ b/src/silx/resources/gui/icons/colormap.svg
diff --git a/silx/resources/gui/icons/compare-align-auto.png b/src/silx/resources/gui/icons/compare-align-auto.png
index 0a716e7..0a716e7 100644
--- a/silx/resources/gui/icons/compare-align-auto.png
+++ b/src/silx/resources/gui/icons/compare-align-auto.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/compare-align-auto.svg b/src/silx/resources/gui/icons/compare-align-auto.svg
new file mode 100644
index 0000000..29160a0
--- /dev/null
+++ b/src/silx/resources/gui/icons/compare-align-auto.svg
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg id="svg44" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata50"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/></cc:Work></rdf:RDF></metadata><path id="path36" d="m10.992 6.764s4.839-0.584 5.992 4.366" fill="none" stroke="#FFF" stroke-miterlimit="10" stroke-width="1.2"/><g id="g4597" transform="matrix(.89618 0 0 .89618 33.643 30.672)"><rect id="rect2-6" x="-34.289" y="-27.796" width="26.026" height="26.026" ry="0" fill="#fab058" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.5622"/><text id="text4553" x="-33.067287" y="-5.5593224" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551" x="-33.067287" y="-5.5593224" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">A</tspan></text>
+</g><g id="g4602" transform="matrix(.50611 .17057 -.17057 .50611 -5.8136 18.919)"><rect id="rect2-6-4" x="33.767" y="-32.267" width="26.026" height="26.026" ry="0" fill="#d5fa58" stroke="#000" stroke-miterlimit="2" stroke-width="2.6213"/><text id="text4553-1" x="36.864368" y="-10.030853" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551-7" x="36.864368" y="-10.030853" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">B</tspan></text>
+</g></svg>
diff --git a/silx/resources/gui/icons/compare-align-center.png b/src/silx/resources/gui/icons/compare-align-center.png
index bb2e8c1..bb2e8c1 100644
--- a/silx/resources/gui/icons/compare-align-center.png
+++ b/src/silx/resources/gui/icons/compare-align-center.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/compare-align-center.svg b/src/silx/resources/gui/icons/compare-align-center.svg
new file mode 100644
index 0000000..e93957c
--- /dev/null
+++ b/src/silx/resources/gui/icons/compare-align-center.svg
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg id="svg44" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata50"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata><g id="g4597" transform="matrix(.89618 0 0 .89618 35.067 29.248)"><rect id="rect2-6" x="-34.289" y="-27.796" width="26.026" height="26.026" ry="0" fill="#fab058" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.5622"/>
+</g><g id="g4602" transform="matrix(.70181 0 0 .70181 -16.83 29.513)"><rect id="rect2-6-4" x="33.767" y="-32.267" width="26.026" height="26.026" ry="0" fill="#d5fa58" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.9948"/><text id="text4553-1" x="36.864368" y="-10.030853" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551-7" x="36.864368" y="-10.030853" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">B</tspan></text>
+</g></svg>
diff --git a/silx/resources/gui/icons/compare-align-origin.png b/src/silx/resources/gui/icons/compare-align-origin.png
index e209ce2..e209ce2 100644
--- a/silx/resources/gui/icons/compare-align-origin.png
+++ b/src/silx/resources/gui/icons/compare-align-origin.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/compare-align-origin.svg b/src/silx/resources/gui/icons/compare-align-origin.svg
new file mode 100644
index 0000000..e5cd921
--- /dev/null
+++ b/src/silx/resources/gui/icons/compare-align-origin.svg
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg id="svg44" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata50"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata><g id="g4597" transform="matrix(.89618 0 0 .89618 35.067 29.248)"><rect id="rect2-6" x="-34.289" y="-27.796" width="26.026" height="26.026" ry="0" fill="#fab058" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.5622"/>
+</g><g id="g4602" transform="matrix(.70181 0 0 .70181 -19.285 27.058)"><rect id="rect2-6-4" x="33.767" y="-32.267" width="26.026" height="26.026" ry="0" fill="#d5fa58" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.9948"/><text id="text4553-1" x="36.864368" y="-10.030853" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551-7" x="36.864368" y="-10.030853" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">B</tspan></text>
+</g></svg>
diff --git a/silx/resources/gui/icons/compare-align-stretch.png b/src/silx/resources/gui/icons/compare-align-stretch.png
index 707bcd1..707bcd1 100644
--- a/silx/resources/gui/icons/compare-align-stretch.png
+++ b/src/silx/resources/gui/icons/compare-align-stretch.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/compare-align-stretch.svg b/src/silx/resources/gui/icons/compare-align-stretch.svg
new file mode 100644
index 0000000..6b8db1d
--- /dev/null
+++ b/src/silx/resources/gui/icons/compare-align-stretch.svg
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg id="svg44" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata50"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata><g id="g4597" transform="matrix(.89618 0 0 .89618 35.067 29.248)"><rect id="rect2-6" x="-34.289" y="-27.796" width="26.026" height="26.026" ry="0" fill="#d5fa58" stroke="#000" stroke-miterlimit="2.1" stroke-width="1.5622"/>
+</g><g id="g4602" transform="matrix(.70866 0 0 .70866 -17.151 29.645)"><text id="text4553-1" x="36.864368" y="-10.030853" fill="#000000" font-family="Scriptina" font-size="25.304px" letter-spacing="0px" stroke-width=".79075px" word-spacing="0px" style="line-height:125%" xml:space="preserve"><tspan id="tspan4551-7" x="36.864368" y="-10.030853" font-family="sans-serif" font-weight="bold" stroke-width=".79075px">B</tspan></text>
+</g><path id="rect969" d="m6.3051 6.1695h5.4237l-5.4237 5.4237z" color="#000000" fill="#f0f"/><path id="rect969-5" d="m26.034 6.1695h-5.4237l5.4237 5.4237z" color="#000000" fill="#f0f"/><path id="rect969-5-3" d="m26.034 25.763h-5.4237l5.4237-5.4237z" color="#000000" fill="#f0f"/><path id="rect969-5-3-5" d="m6.3051 25.763h5.4237l-5.4237-5.4237z" color="#000000" fill="#f0f"/></svg>
diff --git a/silx/resources/gui/icons/compare-keypoints.png b/src/silx/resources/gui/icons/compare-keypoints.png
index 0f93111..0f93111 100644
--- a/silx/resources/gui/icons/compare-keypoints.png
+++ b/src/silx/resources/gui/icons/compare-keypoints.png
Binary files differ
diff --git a/silx/resources/gui/icons/compare-keypoints.svg b/src/silx/resources/gui/icons/compare-keypoints.svg
index 9282526..9282526 100644
--- a/silx/resources/gui/icons/compare-keypoints.svg
+++ b/src/silx/resources/gui/icons/compare-keypoints.svg
diff --git a/silx/resources/gui/icons/compare-mode-a-minus-b.png b/src/silx/resources/gui/icons/compare-mode-a-minus-b.png
index 75f8db3..75f8db3 100644
--- a/silx/resources/gui/icons/compare-mode-a-minus-b.png
+++ b/src/silx/resources/gui/icons/compare-mode-a-minus-b.png
Binary files differ
diff --git a/silx/resources/gui/icons/compare-mode-a-minus-b.svg b/src/silx/resources/gui/icons/compare-mode-a-minus-b.svg
index bd4afbf..bd4afbf 100644
--- a/silx/resources/gui/icons/compare-mode-a-minus-b.svg
+++ b/src/silx/resources/gui/icons/compare-mode-a-minus-b.svg
diff --git a/silx/resources/gui/icons/compare-mode-a.png b/src/silx/resources/gui/icons/compare-mode-a.png
index f1158f9..f1158f9 100644
--- a/silx/resources/gui/icons/compare-mode-a.png
+++ b/src/silx/resources/gui/icons/compare-mode-a.png
Binary files differ
diff --git a/silx/resources/gui/icons/compare-mode-a.svg b/src/silx/resources/gui/icons/compare-mode-a.svg
index 57be2f3..57be2f3 100644
--- a/silx/resources/gui/icons/compare-mode-a.svg
+++ b/src/silx/resources/gui/icons/compare-mode-a.svg
diff --git a/silx/resources/gui/icons/compare-mode-b.png b/src/silx/resources/gui/icons/compare-mode-b.png
index 58cd231..58cd231 100644
--- a/silx/resources/gui/icons/compare-mode-b.png
+++ b/src/silx/resources/gui/icons/compare-mode-b.png
Binary files differ
diff --git a/silx/resources/gui/icons/compare-mode-b.svg b/src/silx/resources/gui/icons/compare-mode-b.svg
index 929c2c0..929c2c0 100644
--- a/silx/resources/gui/icons/compare-mode-b.svg
+++ b/src/silx/resources/gui/icons/compare-mode-b.svg
diff --git a/silx/resources/gui/icons/compare-mode-hline.png b/src/silx/resources/gui/icons/compare-mode-hline.png
index 2a9d403..2a9d403 100644
--- a/silx/resources/gui/icons/compare-mode-hline.png
+++ b/src/silx/resources/gui/icons/compare-mode-hline.png
Binary files differ
diff --git a/silx/resources/gui/icons/compare-mode-hline.svg b/src/silx/resources/gui/icons/compare-mode-hline.svg
index 9f9a2d4..9f9a2d4 100644
--- a/silx/resources/gui/icons/compare-mode-hline.svg
+++ b/src/silx/resources/gui/icons/compare-mode-hline.svg
diff --git a/silx/resources/gui/icons/compare-mode-rb-channel.png b/src/silx/resources/gui/icons/compare-mode-rb-channel.png
index 689c8a6..689c8a6 100644
--- a/silx/resources/gui/icons/compare-mode-rb-channel.png
+++ b/src/silx/resources/gui/icons/compare-mode-rb-channel.png
Binary files differ
diff --git a/silx/resources/gui/icons/compare-mode-rb-channel.svg b/src/silx/resources/gui/icons/compare-mode-rb-channel.svg
index 9bcadd7..9bcadd7 100644
--- a/silx/resources/gui/icons/compare-mode-rb-channel.svg
+++ b/src/silx/resources/gui/icons/compare-mode-rb-channel.svg
diff --git a/silx/resources/gui/icons/compare-mode-rbneg-channel.png b/src/silx/resources/gui/icons/compare-mode-rbneg-channel.png
index 53b339f..53b339f 100644
--- a/silx/resources/gui/icons/compare-mode-rbneg-channel.png
+++ b/src/silx/resources/gui/icons/compare-mode-rbneg-channel.png
Binary files differ
diff --git a/silx/resources/gui/icons/compare-mode-rbneg-channel.svg b/src/silx/resources/gui/icons/compare-mode-rbneg-channel.svg
index 8c23e65..8c23e65 100644
--- a/silx/resources/gui/icons/compare-mode-rbneg-channel.svg
+++ b/src/silx/resources/gui/icons/compare-mode-rbneg-channel.svg
diff --git a/silx/resources/gui/icons/compare-mode-vline.png b/src/silx/resources/gui/icons/compare-mode-vline.png
index fe12d9d..fe12d9d 100644
--- a/silx/resources/gui/icons/compare-mode-vline.png
+++ b/src/silx/resources/gui/icons/compare-mode-vline.png
Binary files differ
diff --git a/silx/resources/gui/icons/compare-mode-vline.svg b/src/silx/resources/gui/icons/compare-mode-vline.svg
index 06d74b3..06d74b3 100644
--- a/silx/resources/gui/icons/compare-mode-vline.svg
+++ b/src/silx/resources/gui/icons/compare-mode-vline.svg
diff --git a/silx/resources/gui/icons/crop.png b/src/silx/resources/gui/icons/crop.png
index 65dd7d1..65dd7d1 100755
--- a/silx/resources/gui/icons/crop.png
+++ b/src/silx/resources/gui/icons/crop.png
Binary files differ
diff --git a/silx/resources/gui/icons/crop.svg b/src/silx/resources/gui/icons/crop.svg
index acb27ec..acb27ec 100644
--- a/silx/resources/gui/icons/crop.svg
+++ b/src/silx/resources/gui/icons/crop.svg
diff --git a/silx/resources/gui/icons/crosshair.png b/src/silx/resources/gui/icons/crosshair.png
index 1d516a3..1d516a3 100644
--- a/silx/resources/gui/icons/crosshair.png
+++ b/src/silx/resources/gui/icons/crosshair.png
Binary files differ
diff --git a/silx/resources/gui/icons/crosshair.svg b/src/silx/resources/gui/icons/crosshair.svg
index e96ef83..e96ef83 100644
--- a/silx/resources/gui/icons/crosshair.svg
+++ b/src/silx/resources/gui/icons/crosshair.svg
diff --git a/silx/resources/gui/icons/cube-back.png b/src/silx/resources/gui/icons/cube-back.png
index 2e326df..2e326df 100644
--- a/silx/resources/gui/icons/cube-back.png
+++ b/src/silx/resources/gui/icons/cube-back.png
Binary files differ
diff --git a/silx/resources/gui/icons/cube-back.svg b/src/silx/resources/gui/icons/cube-back.svg
index d1d79a5..d1d79a5 100644
--- a/silx/resources/gui/icons/cube-back.svg
+++ b/src/silx/resources/gui/icons/cube-back.svg
diff --git a/silx/resources/gui/icons/cube-bottom.png b/src/silx/resources/gui/icons/cube-bottom.png
index 0b2aaaf..0b2aaaf 100644
--- a/silx/resources/gui/icons/cube-bottom.png
+++ b/src/silx/resources/gui/icons/cube-bottom.png
Binary files differ
diff --git a/silx/resources/gui/icons/cube-bottom.svg b/src/silx/resources/gui/icons/cube-bottom.svg
index f3d9cbc..f3d9cbc 100644
--- a/silx/resources/gui/icons/cube-bottom.svg
+++ b/src/silx/resources/gui/icons/cube-bottom.svg
diff --git a/silx/resources/gui/icons/cube-front.png b/src/silx/resources/gui/icons/cube-front.png
index 9165bd5..9165bd5 100644
--- a/silx/resources/gui/icons/cube-front.png
+++ b/src/silx/resources/gui/icons/cube-front.png
Binary files differ
diff --git a/silx/resources/gui/icons/cube-front.svg b/src/silx/resources/gui/icons/cube-front.svg
index 11f4fa2..11f4fa2 100644
--- a/silx/resources/gui/icons/cube-front.svg
+++ b/src/silx/resources/gui/icons/cube-front.svg
diff --git a/silx/resources/gui/icons/cube-left.png b/src/silx/resources/gui/icons/cube-left.png
index c84ad8e..c84ad8e 100644
--- a/silx/resources/gui/icons/cube-left.png
+++ b/src/silx/resources/gui/icons/cube-left.png
Binary files differ
diff --git a/silx/resources/gui/icons/cube-left.svg b/src/silx/resources/gui/icons/cube-left.svg
index 7d0ee95..7d0ee95 100644
--- a/silx/resources/gui/icons/cube-left.svg
+++ b/src/silx/resources/gui/icons/cube-left.svg
diff --git a/silx/resources/gui/icons/cube-right.png b/src/silx/resources/gui/icons/cube-right.png
index 6a913bb..6a913bb 100644
--- a/silx/resources/gui/icons/cube-right.png
+++ b/src/silx/resources/gui/icons/cube-right.png
Binary files differ
diff --git a/silx/resources/gui/icons/cube-right.svg b/src/silx/resources/gui/icons/cube-right.svg
index c98e3e1..c98e3e1 100644
--- a/silx/resources/gui/icons/cube-right.svg
+++ b/src/silx/resources/gui/icons/cube-right.svg
diff --git a/silx/resources/gui/icons/cube-rotate.png b/src/silx/resources/gui/icons/cube-rotate.png
index e2dc795..e2dc795 100644
--- a/silx/resources/gui/icons/cube-rotate.png
+++ b/src/silx/resources/gui/icons/cube-rotate.png
Binary files differ
diff --git a/silx/resources/gui/icons/cube-rotate.svg b/src/silx/resources/gui/icons/cube-rotate.svg
index 44cdfe4..44cdfe4 100644
--- a/silx/resources/gui/icons/cube-rotate.svg
+++ b/src/silx/resources/gui/icons/cube-rotate.svg
diff --git a/silx/resources/gui/icons/cube-top.png b/src/silx/resources/gui/icons/cube-top.png
index d6c1a62..d6c1a62 100644
--- a/silx/resources/gui/icons/cube-top.png
+++ b/src/silx/resources/gui/icons/cube-top.png
Binary files differ
diff --git a/silx/resources/gui/icons/cube-top.svg b/src/silx/resources/gui/icons/cube-top.svg
index 1bc0e2c..1bc0e2c 100644
--- a/silx/resources/gui/icons/cube-top.svg
+++ b/src/silx/resources/gui/icons/cube-top.svg
diff --git a/silx/resources/gui/icons/cube.png b/src/silx/resources/gui/icons/cube.png
index 0dae173..0dae173 100644
--- a/silx/resources/gui/icons/cube.png
+++ b/src/silx/resources/gui/icons/cube.png
Binary files differ
diff --git a/silx/resources/gui/icons/cube.svg b/src/silx/resources/gui/icons/cube.svg
index 19e4f9c..19e4f9c 100644
--- a/silx/resources/gui/icons/cube.svg
+++ b/src/silx/resources/gui/icons/cube.svg
diff --git a/silx/resources/gui/icons/description-description.png b/src/silx/resources/gui/icons/description-description.png
index 36aaf0c..36aaf0c 100644
--- a/silx/resources/gui/icons/description-description.png
+++ b/src/silx/resources/gui/icons/description-description.png
Binary files differ
diff --git a/silx/resources/gui/icons/description-description.svg b/src/silx/resources/gui/icons/description-description.svg
index 8f953ce..8f953ce 100644
--- a/silx/resources/gui/icons/description-description.svg
+++ b/src/silx/resources/gui/icons/description-description.svg
diff --git a/silx/resources/gui/icons/description-error.png b/src/silx/resources/gui/icons/description-error.png
index 053f725..053f725 100644
--- a/silx/resources/gui/icons/description-error.png
+++ b/src/silx/resources/gui/icons/description-error.png
Binary files differ
diff --git a/silx/resources/gui/icons/description-error.svg b/src/silx/resources/gui/icons/description-error.svg
index 50dd7ac..50dd7ac 100644
--- a/silx/resources/gui/icons/description-error.svg
+++ b/src/silx/resources/gui/icons/description-error.svg
diff --git a/silx/resources/gui/icons/description-name.png b/src/silx/resources/gui/icons/description-name.png
index dae65f1..dae65f1 100644
--- a/silx/resources/gui/icons/description-name.png
+++ b/src/silx/resources/gui/icons/description-name.png
Binary files differ
diff --git a/silx/resources/gui/icons/description-name.svg b/src/silx/resources/gui/icons/description-name.svg
index ffbcb58..ffbcb58 100644
--- a/silx/resources/gui/icons/description-name.svg
+++ b/src/silx/resources/gui/icons/description-name.svg
diff --git a/silx/resources/gui/icons/description-program.png b/src/silx/resources/gui/icons/description-program.png
index 72e25d2..72e25d2 100644
--- a/silx/resources/gui/icons/description-program.png
+++ b/src/silx/resources/gui/icons/description-program.png
Binary files differ
diff --git a/silx/resources/gui/icons/description-program.svg b/src/silx/resources/gui/icons/description-program.svg
index 8c04541..8c04541 100644
--- a/silx/resources/gui/icons/description-program.svg
+++ b/src/silx/resources/gui/icons/description-program.svg
diff --git a/silx/resources/gui/icons/description-title.png b/src/silx/resources/gui/icons/description-title.png
index 274b6dd..274b6dd 100644
--- a/silx/resources/gui/icons/description-title.png
+++ b/src/silx/resources/gui/icons/description-title.png
Binary files differ
diff --git a/silx/resources/gui/icons/description-title.svg b/src/silx/resources/gui/icons/description-title.svg
index 9c3eee1..9c3eee1 100644
--- a/silx/resources/gui/icons/description-title.svg
+++ b/src/silx/resources/gui/icons/description-title.svg
diff --git a/silx/resources/gui/icons/description-value.png b/src/silx/resources/gui/icons/description-value.png
index a73ebf1..a73ebf1 100644
--- a/silx/resources/gui/icons/description-value.png
+++ b/src/silx/resources/gui/icons/description-value.png
Binary files differ
diff --git a/silx/resources/gui/icons/description-value.svg b/src/silx/resources/gui/icons/description-value.svg
index 8371771..8371771 100644
--- a/silx/resources/gui/icons/description-value.svg
+++ b/src/silx/resources/gui/icons/description-value.svg
diff --git a/silx/resources/gui/icons/document-open.png b/src/silx/resources/gui/icons/document-open.png
index 15ca326..15ca326 100755
--- a/silx/resources/gui/icons/document-open.png
+++ b/src/silx/resources/gui/icons/document-open.png
Binary files differ
diff --git a/silx/resources/gui/icons/document-open.svg b/src/silx/resources/gui/icons/document-open.svg
index 0046cfd..0046cfd 100644
--- a/silx/resources/gui/icons/document-open.svg
+++ b/src/silx/resources/gui/icons/document-open.svg
diff --git a/silx/resources/gui/icons/document-print.png b/src/silx/resources/gui/icons/document-print.png
index d4a3633..d4a3633 100755
--- a/silx/resources/gui/icons/document-print.png
+++ b/src/silx/resources/gui/icons/document-print.png
Binary files differ
diff --git a/silx/resources/gui/icons/document-print.svg b/src/silx/resources/gui/icons/document-print.svg
index 3ff1099..3ff1099 100644
--- a/silx/resources/gui/icons/document-print.svg
+++ b/src/silx/resources/gui/icons/document-print.svg
diff --git a/silx/resources/gui/icons/document-save.png b/src/silx/resources/gui/icons/document-save.png
index 5229d2b..5229d2b 100755
--- a/silx/resources/gui/icons/document-save.png
+++ b/src/silx/resources/gui/icons/document-save.png
Binary files differ
diff --git a/silx/resources/gui/icons/document-save.svg b/src/silx/resources/gui/icons/document-save.svg
index 5134a42..5134a42 100644
--- a/silx/resources/gui/icons/document-save.svg
+++ b/src/silx/resources/gui/icons/document-save.svg
diff --git a/silx/resources/gui/icons/draw-brush.png b/src/silx/resources/gui/icons/draw-brush.png
index 6184079..6184079 100755
--- a/silx/resources/gui/icons/draw-brush.png
+++ b/src/silx/resources/gui/icons/draw-brush.png
Binary files differ
diff --git a/silx/resources/gui/icons/draw-brush.svg b/src/silx/resources/gui/icons/draw-brush.svg
index b371236..b371236 100644
--- a/silx/resources/gui/icons/draw-brush.svg
+++ b/src/silx/resources/gui/icons/draw-brush.svg
diff --git a/silx/resources/gui/icons/draw-pencil.png b/src/silx/resources/gui/icons/draw-pencil.png
index be47b74..be47b74 100755
--- a/silx/resources/gui/icons/draw-pencil.png
+++ b/src/silx/resources/gui/icons/draw-pencil.png
Binary files differ
diff --git a/silx/resources/gui/icons/draw-pencil.svg b/src/silx/resources/gui/icons/draw-pencil.svg
index 255bdc5..255bdc5 100644
--- a/silx/resources/gui/icons/draw-pencil.svg
+++ b/src/silx/resources/gui/icons/draw-pencil.svg
diff --git a/silx/resources/gui/icons/draw-rubber.png b/src/silx/resources/gui/icons/draw-rubber.png
index b1b24c1..b1b24c1 100755
--- a/silx/resources/gui/icons/draw-rubber.png
+++ b/src/silx/resources/gui/icons/draw-rubber.png
Binary files differ
diff --git a/silx/resources/gui/icons/draw-rubber.svg b/src/silx/resources/gui/icons/draw-rubber.svg
index 29a98b5..29a98b5 100644
--- a/silx/resources/gui/icons/draw-rubber.svg
+++ b/src/silx/resources/gui/icons/draw-rubber.svg
diff --git a/silx/resources/gui/icons/edit-copy.png b/src/silx/resources/gui/icons/edit-copy.png
index 8fe3281..8fe3281 100644
--- a/silx/resources/gui/icons/edit-copy.png
+++ b/src/silx/resources/gui/icons/edit-copy.png
Binary files differ
diff --git a/silx/resources/gui/icons/edit-copy.svg b/src/silx/resources/gui/icons/edit-copy.svg
index 6100075..6100075 100644
--- a/silx/resources/gui/icons/edit-copy.svg
+++ b/src/silx/resources/gui/icons/edit-copy.svg
diff --git a/silx/resources/gui/icons/eye.png b/src/silx/resources/gui/icons/eye.png
index a2d1c23..a2d1c23 100644
--- a/silx/resources/gui/icons/eye.png
+++ b/src/silx/resources/gui/icons/eye.png
Binary files differ
diff --git a/silx/resources/gui/icons/eye.svg b/src/silx/resources/gui/icons/eye.svg
index 7658d86..7658d86 100644
--- a/silx/resources/gui/icons/eye.svg
+++ b/src/silx/resources/gui/icons/eye.svg
diff --git a/silx/resources/gui/icons/first.png b/src/silx/resources/gui/icons/first.png
index fe3b87c..fe3b87c 100644
--- a/silx/resources/gui/icons/first.png
+++ b/src/silx/resources/gui/icons/first.png
Binary files differ
diff --git a/silx/resources/gui/icons/first.svg b/src/silx/resources/gui/icons/first.svg
index bb3b5d8..bb3b5d8 100644
--- a/silx/resources/gui/icons/first.svg
+++ b/src/silx/resources/gui/icons/first.svg
diff --git a/silx/resources/gui/icons/folder.png b/src/silx/resources/gui/icons/folder.png
index 61c8f55..61c8f55 100755
--- a/silx/resources/gui/icons/folder.png
+++ b/src/silx/resources/gui/icons/folder.png
Binary files differ
diff --git a/silx/resources/gui/icons/folder.svg b/src/silx/resources/gui/icons/folder.svg
index 5c3b194..5c3b194 100644
--- a/silx/resources/gui/icons/folder.svg
+++ b/src/silx/resources/gui/icons/folder.svg
diff --git a/silx/resources/gui/icons/image-mask.png b/src/silx/resources/gui/icons/image-mask.png
index 44032e0..44032e0 100644
--- a/silx/resources/gui/icons/image-mask.png
+++ b/src/silx/resources/gui/icons/image-mask.png
Binary files differ
diff --git a/silx/resources/gui/icons/image-mask.svg b/src/silx/resources/gui/icons/image-mask.svg
index 1309376..1309376 100644
--- a/silx/resources/gui/icons/image-mask.svg
+++ b/src/silx/resources/gui/icons/image-mask.svg
diff --git a/silx/resources/gui/icons/image-select-add.png b/src/silx/resources/gui/icons/image-select-add.png
index 8a89cc3..8a89cc3 100755
--- a/silx/resources/gui/icons/image-select-add.png
+++ b/src/silx/resources/gui/icons/image-select-add.png
Binary files differ
diff --git a/silx/resources/gui/icons/image-select-add.svg b/src/silx/resources/gui/icons/image-select-add.svg
index 1856bd8..1856bd8 100644
--- a/silx/resources/gui/icons/image-select-add.svg
+++ b/src/silx/resources/gui/icons/image-select-add.svg
diff --git a/silx/resources/gui/icons/image-select-box.png b/src/silx/resources/gui/icons/image-select-box.png
index ffc9ddc..ffc9ddc 100755
--- a/silx/resources/gui/icons/image-select-box.png
+++ b/src/silx/resources/gui/icons/image-select-box.png
Binary files differ
diff --git a/silx/resources/gui/icons/image-select-box.svg b/src/silx/resources/gui/icons/image-select-box.svg
index 421cee9..421cee9 100644
--- a/silx/resources/gui/icons/image-select-box.svg
+++ b/src/silx/resources/gui/icons/image-select-box.svg
diff --git a/silx/resources/gui/icons/image-select-brush.png b/src/silx/resources/gui/icons/image-select-brush.png
index 33c4d1e..33c4d1e 100755
--- a/silx/resources/gui/icons/image-select-brush.png
+++ b/src/silx/resources/gui/icons/image-select-brush.png
Binary files differ
diff --git a/silx/resources/gui/icons/image-select-brush.svg b/src/silx/resources/gui/icons/image-select-brush.svg
index 8f88b4b..8f88b4b 100644
--- a/silx/resources/gui/icons/image-select-brush.svg
+++ b/src/silx/resources/gui/icons/image-select-brush.svg
diff --git a/silx/resources/gui/icons/image-select-erase-rubber.png b/src/silx/resources/gui/icons/image-select-erase-rubber.png
index 175eb11..175eb11 100755
--- a/silx/resources/gui/icons/image-select-erase-rubber.png
+++ b/src/silx/resources/gui/icons/image-select-erase-rubber.png
Binary files differ
diff --git a/silx/resources/gui/icons/image-select-erase-rubber.svg b/src/silx/resources/gui/icons/image-select-erase-rubber.svg
index b6fb880..b6fb880 100644
--- a/silx/resources/gui/icons/image-select-erase-rubber.svg
+++ b/src/silx/resources/gui/icons/image-select-erase-rubber.svg
diff --git a/silx/resources/gui/icons/image-select-erase.png b/src/silx/resources/gui/icons/image-select-erase.png
index d5d1a5b..d5d1a5b 100755
--- a/silx/resources/gui/icons/image-select-erase.png
+++ b/src/silx/resources/gui/icons/image-select-erase.png
Binary files differ
diff --git a/silx/resources/gui/icons/image-select-erase.svg b/src/silx/resources/gui/icons/image-select-erase.svg
index afb105b..afb105b 100644
--- a/silx/resources/gui/icons/image-select-erase.svg
+++ b/src/silx/resources/gui/icons/image-select-erase.svg
diff --git a/silx/resources/gui/icons/image.png b/src/silx/resources/gui/icons/image.png
index 484caa0..484caa0 100755
--- a/silx/resources/gui/icons/image.png
+++ b/src/silx/resources/gui/icons/image.png
Binary files differ
diff --git a/silx/resources/gui/icons/image.svg b/src/silx/resources/gui/icons/image.svg
index 5789160..5789160 100644
--- a/silx/resources/gui/icons/image.svg
+++ b/src/silx/resources/gui/icons/image.svg
diff --git a/silx/resources/gui/icons/item-0dim.png b/src/silx/resources/gui/icons/item-0dim.png
index e0f75bf..e0f75bf 100644
--- a/silx/resources/gui/icons/item-0dim.png
+++ b/src/silx/resources/gui/icons/item-0dim.png
Binary files differ
diff --git a/silx/resources/gui/icons/item-0dim.svg b/src/silx/resources/gui/icons/item-0dim.svg
index 9a86c3a..9a86c3a 100644
--- a/silx/resources/gui/icons/item-0dim.svg
+++ b/src/silx/resources/gui/icons/item-0dim.svg
diff --git a/silx/resources/gui/icons/item-1dim.png b/src/silx/resources/gui/icons/item-1dim.png
index 49622bc..49622bc 100644
--- a/silx/resources/gui/icons/item-1dim.png
+++ b/src/silx/resources/gui/icons/item-1dim.png
Binary files differ
diff --git a/silx/resources/gui/icons/item-1dim.svg b/src/silx/resources/gui/icons/item-1dim.svg
index a422e31..a422e31 100644
--- a/silx/resources/gui/icons/item-1dim.svg
+++ b/src/silx/resources/gui/icons/item-1dim.svg
diff --git a/silx/resources/gui/icons/item-2dim.png b/src/silx/resources/gui/icons/item-2dim.png
index 6dafb6b..6dafb6b 100644
--- a/silx/resources/gui/icons/item-2dim.png
+++ b/src/silx/resources/gui/icons/item-2dim.png
Binary files differ
diff --git a/silx/resources/gui/icons/item-2dim.svg b/src/silx/resources/gui/icons/item-2dim.svg
index 8e80fd0..8e80fd0 100644
--- a/silx/resources/gui/icons/item-2dim.svg
+++ b/src/silx/resources/gui/icons/item-2dim.svg
diff --git a/silx/resources/gui/icons/item-3dim.png b/src/silx/resources/gui/icons/item-3dim.png
index b9ec4f5..b9ec4f5 100644
--- a/silx/resources/gui/icons/item-3dim.png
+++ b/src/silx/resources/gui/icons/item-3dim.png
Binary files differ
diff --git a/silx/resources/gui/icons/item-3dim.svg b/src/silx/resources/gui/icons/item-3dim.svg
index 2220ee3..2220ee3 100644
--- a/silx/resources/gui/icons/item-3dim.svg
+++ b/src/silx/resources/gui/icons/item-3dim.svg
diff --git a/silx/resources/gui/icons/item-ndim.png b/src/silx/resources/gui/icons/item-ndim.png
index 65dd21c..65dd21c 100644
--- a/silx/resources/gui/icons/item-ndim.png
+++ b/src/silx/resources/gui/icons/item-ndim.png
Binary files differ
diff --git a/silx/resources/gui/icons/item-ndim.svg b/src/silx/resources/gui/icons/item-ndim.svg
index a00e1b3..a00e1b3 100644
--- a/silx/resources/gui/icons/item-ndim.svg
+++ b/src/silx/resources/gui/icons/item-ndim.svg
diff --git a/silx/resources/gui/icons/item-none.png b/src/silx/resources/gui/icons/item-none.png
index 42f7f88..42f7f88 100644
--- a/silx/resources/gui/icons/item-none.png
+++ b/src/silx/resources/gui/icons/item-none.png
Binary files differ
diff --git a/silx/resources/gui/icons/item-none.svg b/src/silx/resources/gui/icons/item-none.svg
index 08a5b51..08a5b51 100644
--- a/silx/resources/gui/icons/item-none.svg
+++ b/src/silx/resources/gui/icons/item-none.svg
diff --git a/silx/resources/gui/icons/item-object.png b/src/silx/resources/gui/icons/item-object.png
index f8e3283..f8e3283 100644
--- a/silx/resources/gui/icons/item-object.png
+++ b/src/silx/resources/gui/icons/item-object.png
Binary files differ
diff --git a/silx/resources/gui/icons/item-object.svg b/src/silx/resources/gui/icons/item-object.svg
index 4f36bbe..4f36bbe 100644
--- a/silx/resources/gui/icons/item-object.svg
+++ b/src/silx/resources/gui/icons/item-object.svg
diff --git a/silx/resources/gui/icons/last.png b/src/silx/resources/gui/icons/last.png
index 4418006..4418006 100644
--- a/silx/resources/gui/icons/last.png
+++ b/src/silx/resources/gui/icons/last.png
Binary files differ
diff --git a/silx/resources/gui/icons/last.svg b/src/silx/resources/gui/icons/last.svg
index df8d7d3..df8d7d3 100644
--- a/silx/resources/gui/icons/last.svg
+++ b/src/silx/resources/gui/icons/last.svg
diff --git a/silx/resources/gui/icons/layer-nx.png b/src/silx/resources/gui/icons/layer-nx.png
index a1587b2..a1587b2 100644
--- a/silx/resources/gui/icons/layer-nx.png
+++ b/src/silx/resources/gui/icons/layer-nx.png
Binary files differ
diff --git a/silx/resources/gui/icons/layer-nx.svg b/src/silx/resources/gui/icons/layer-nx.svg
index c177985..c177985 100644
--- a/silx/resources/gui/icons/layer-nx.svg
+++ b/src/silx/resources/gui/icons/layer-nx.svg
diff --git a/silx/resources/gui/icons/mask-clear-all.png b/src/silx/resources/gui/icons/mask-clear-all.png
index 2d6cf55..2d6cf55 100644
--- a/silx/resources/gui/icons/mask-clear-all.png
+++ b/src/silx/resources/gui/icons/mask-clear-all.png
Binary files differ
diff --git a/silx/resources/gui/icons/mask-clear-all.svg b/src/silx/resources/gui/icons/mask-clear-all.svg
index 7db5055..7db5055 100644
--- a/silx/resources/gui/icons/mask-clear-all.svg
+++ b/src/silx/resources/gui/icons/mask-clear-all.svg
diff --git a/silx/resources/gui/icons/mask-clear.png b/src/silx/resources/gui/icons/mask-clear.png
index 940b607..940b607 100644
--- a/silx/resources/gui/icons/mask-clear.png
+++ b/src/silx/resources/gui/icons/mask-clear.png
Binary files differ
diff --git a/silx/resources/gui/icons/mask-clear.svg b/src/silx/resources/gui/icons/mask-clear.svg
index 77410c2..77410c2 100644
--- a/silx/resources/gui/icons/mask-clear.svg
+++ b/src/silx/resources/gui/icons/mask-clear.svg
diff --git a/silx/resources/gui/icons/mask-invert.png b/src/silx/resources/gui/icons/mask-invert.png
index f1cc339..f1cc339 100644
--- a/silx/resources/gui/icons/mask-invert.png
+++ b/src/silx/resources/gui/icons/mask-invert.png
Binary files differ
diff --git a/silx/resources/gui/icons/mask-invert.svg b/src/silx/resources/gui/icons/mask-invert.svg
index 8fb0c17..8fb0c17 100644
--- a/silx/resources/gui/icons/mask-invert.svg
+++ b/src/silx/resources/gui/icons/mask-invert.svg
diff --git a/silx/resources/gui/icons/math-amplitude.png b/src/silx/resources/gui/icons/math-amplitude.png
index ae31474..ae31474 100644
--- a/silx/resources/gui/icons/math-amplitude.png
+++ b/src/silx/resources/gui/icons/math-amplitude.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-amplitude.svg b/src/silx/resources/gui/icons/math-amplitude.svg
index 497f62e..497f62e 100644
--- a/silx/resources/gui/icons/math-amplitude.svg
+++ b/src/silx/resources/gui/icons/math-amplitude.svg
diff --git a/silx/resources/gui/icons/math-average.png b/src/silx/resources/gui/icons/math-average.png
index 675cd62..675cd62 100755
--- a/silx/resources/gui/icons/math-average.png
+++ b/src/silx/resources/gui/icons/math-average.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-average.svg b/src/silx/resources/gui/icons/math-average.svg
index 418f1eb..418f1eb 100644
--- a/silx/resources/gui/icons/math-average.svg
+++ b/src/silx/resources/gui/icons/math-average.svg
diff --git a/silx/resources/gui/icons/math-derive.png b/src/silx/resources/gui/icons/math-derive.png
index 2a31042..2a31042 100755
--- a/silx/resources/gui/icons/math-derive.png
+++ b/src/silx/resources/gui/icons/math-derive.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-derive.svg b/src/silx/resources/gui/icons/math-derive.svg
index 7c77457..7c77457 100644
--- a/silx/resources/gui/icons/math-derive.svg
+++ b/src/silx/resources/gui/icons/math-derive.svg
diff --git a/silx/resources/gui/icons/math-energy.png b/src/silx/resources/gui/icons/math-energy.png
index 341f483..341f483 100755
--- a/silx/resources/gui/icons/math-energy.png
+++ b/src/silx/resources/gui/icons/math-energy.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-energy.svg b/src/silx/resources/gui/icons/math-energy.svg
index 58b2aec..58b2aec 100644
--- a/silx/resources/gui/icons/math-energy.svg
+++ b/src/silx/resources/gui/icons/math-energy.svg
diff --git a/silx/resources/gui/icons/math-fit.png b/src/silx/resources/gui/icons/math-fit.png
index c4fcd30..c4fcd30 100755
--- a/silx/resources/gui/icons/math-fit.png
+++ b/src/silx/resources/gui/icons/math-fit.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-fit.svg b/src/silx/resources/gui/icons/math-fit.svg
index dbb2716..dbb2716 100644
--- a/silx/resources/gui/icons/math-fit.svg
+++ b/src/silx/resources/gui/icons/math-fit.svg
diff --git a/silx/resources/gui/icons/math-imaginary.png b/src/silx/resources/gui/icons/math-imaginary.png
index 6327beb..6327beb 100644
--- a/silx/resources/gui/icons/math-imaginary.png
+++ b/src/silx/resources/gui/icons/math-imaginary.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-imaginary.svg b/src/silx/resources/gui/icons/math-imaginary.svg
index c60a636..c60a636 100644
--- a/silx/resources/gui/icons/math-imaginary.svg
+++ b/src/silx/resources/gui/icons/math-imaginary.svg
diff --git a/silx/resources/gui/icons/math-mean.png b/src/silx/resources/gui/icons/math-mean.png
index fb4a210..fb4a210 100644
--- a/silx/resources/gui/icons/math-mean.png
+++ b/src/silx/resources/gui/icons/math-mean.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-mean.svg b/src/silx/resources/gui/icons/math-mean.svg
index f6b121d..f6b121d 100644
--- a/silx/resources/gui/icons/math-mean.svg
+++ b/src/silx/resources/gui/icons/math-mean.svg
diff --git a/silx/resources/gui/icons/math-normalize.png b/src/silx/resources/gui/icons/math-normalize.png
index 14db904..14db904 100755
--- a/silx/resources/gui/icons/math-normalize.png
+++ b/src/silx/resources/gui/icons/math-normalize.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-normalize.svg b/src/silx/resources/gui/icons/math-normalize.svg
index f6c0465..f6c0465 100644
--- a/silx/resources/gui/icons/math-normalize.svg
+++ b/src/silx/resources/gui/icons/math-normalize.svg
diff --git a/silx/resources/gui/icons/math-peak-reset.png b/src/silx/resources/gui/icons/math-peak-reset.png
index ec0932b..ec0932b 100755
--- a/silx/resources/gui/icons/math-peak-reset.png
+++ b/src/silx/resources/gui/icons/math-peak-reset.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-peak-reset.svg b/src/silx/resources/gui/icons/math-peak-reset.svg
index 7185046..7185046 100644
--- a/silx/resources/gui/icons/math-peak-reset.svg
+++ b/src/silx/resources/gui/icons/math-peak-reset.svg
diff --git a/silx/resources/gui/icons/math-peak-search.png b/src/silx/resources/gui/icons/math-peak-search.png
index 28db259..28db259 100755
--- a/silx/resources/gui/icons/math-peak-search.png
+++ b/src/silx/resources/gui/icons/math-peak-search.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/math-peak-search.svg b/src/silx/resources/gui/icons/math-peak-search.svg
new file mode 100644
index 0000000..40a71be
--- /dev/null
+++ b/src/silx/resources/gui/icons/math-peak-search.svg
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><defs><filter id="a" color-interpolation-filters="sRGB"><feGaussianBlur stdDeviation="1.2128746"/></filter></defs><path d="m4.356 26.781c0.66-0.935 1.841-0.809 2.729-1.399 0.703-0.467 0.856-1.623 0.992-2.349 0.218-1.165-0.362-4.839 1.218-5.27 1.004-0.274 1.677-0.422 2.422-1.176 1.721-1.742 1.883-4.988 2.669-7.182 0.504-1.407 1.142-1.524 1.711-0.079 0.35 0.886 0.697 1.771 1.017 2.668 0.689 1.934 1.256 3.931 1.737 5.926 0.45 1.865 0.957 3.707 1.576 5.523 0.279 0.821 0.38 1.479 1.177 1.893 1.154 0.598 1.675-0.925 1.896-1.673 0.278-0.937 0.439-1.908 0.69-2.854 0.455-1.711 0.864 0.714 1.019 1.371 0.442 1.884 0.466 3.932 1.071 5.769 0.181 0.549 1.05 0.314 0.867-0.238-0.398-1.209-0.782-9.396-2.967-8.609-1.242 0.448-1.363 3.699-1.672 4.738-0.364 1.226-1.034-0.032-1.215-0.635-0.366-1.225-0.775-2.429-1.108-3.664-0.629-2.33-1.193-4.659-1.927-6.96-0.276-0.867-1.45-6-3.046-5.583-2.015 0.528-2.388 4.501-2.846 6.112-0.615 2.163-1.571 3.309-3.726 3.896-0.864 0.236-1.143 0.979-1.28 1.771-0.3 1.735 0.738 5.357-1.488 6.215-1.107 0.426-1.578 0.317-2.295 1.332-0.334 0.478 0.447 0.927 0.779 0.457z"/><g transform="translate(1.6271 .13559)" filter="url(#a)"><path d="m2.1425 16.187c-0.417 0.236-1.12 0.115-1.557-0.271-0.442-0.39-0.455-0.906-0.039-1.147l7.33-4.184c0.422-0.242 1.121-0.119 1.56 0.27 0.44 0.392 0.457 0.901 0.035 1.146l-7.329 4.186z" stroke="#00a651" stroke-miterlimit="10" stroke-width=".1"/><path d="m14.176 2.8136c-1.8408-0.22181-3.7106 0.0891-5.25 0.96875-1.5391 0.88172-2.4552 2.2584-2.5625 3.75-0.10727 1.4916 0.57148 3.0357 1.9375 4.25 2.7388 2.4255 7.203 2.9807 10.281 1.2188 1.5391-0.87925 2.4546-2.2587 2.5625-3.75 0.10787-1.4913-0.5729-3.0355-1.9375-4.25-1.3708-1.2142-3.1904-1.9657-5.0312-2.1875zm-0.15625 1.5625c1.5617 0.18769 3.0903 0.77817 4.1875 1.75 1.0904 0.97048 1.5071 2.0373 1.4375 3-0.06963 0.96271-0.62261 1.8827-1.8125 2.5625-2.3797 1.3621-6.3401 0.90923-8.5312-1.0312-1.092-0.9707-1.5067-2.0686-1.4375-3.0312 0.06923-0.96267 0.62157-1.849 1.8125-2.5312 1.1906-0.68035 2.7821-0.90644 4.3437-0.71875z" color="#000000" style="block-progression:tb;text-indent:0;text-transform:none"/><path d="m30.572 31.718c0.247 0.361 0.019 0.865-0.506 1.109-0.531 0.246-1.174 0.141-1.42-0.221l-4.346-6.416c-0.255-0.369-0.025-0.869 0.502-1.111 0.533-0.244 1.163-0.146 1.422 0.227l4.348 6.412z" stroke="#00a651" stroke-miterlimit="10" stroke-width=".1"/><path d="m21.551 15.595c-0.87491 0.08975-1.7393 0.30814-2.5625 0.6875-1.6444 0.76154-2.8268 2.0268-3.3438 3.4688-0.51696 1.4419-0.34202 3.0547 0.59375 4.4375v0.03125c1.8808 2.7617 5.9597 3.6148 9.25 2.0938 1.6461-0.76046 2.8267-2.0267 3.3438-3.4688 0.5171-1.442 0.34525-3.0565-0.59375-4.4375-1.4049-2.0747-4.0628-3.0818-6.6875-2.8125zm0.15625 1.5c2.128-0.19847 4.2576 0.6445 5.2812 2.1562 0.684 1.006 0.7729 2.0713 0.40625 3.0938s-1.2054 1.9812-2.5312 2.5938c-2.6497 1.2249-6.0018 0.45384-7.375-1.5625-0.68223-1.0082-0.80429-2.1019-0.4375-3.125s1.2379-1.9803 2.5625-2.5938c0.66327-0.30564 1.3844-0.49634 2.0938-0.5625z" color="#000000" style="block-progression:tb;text-indent:0;text-transform:none"/></g><g stroke="#00a651" stroke-miterlimit="10"><path d="m3.222 15.385c-0.417 0.236-1.12 0.115-1.557-0.271-0.442-0.39-0.455-0.906-0.039-1.147l7.33-4.184c0.422-0.242 1.121-0.119 1.56 0.27 0.44 0.392 0.457 0.901 0.035 1.146l-7.329 4.186z" fill="#00a651" stroke-width=".1"/><path d="m19.291 11.538c-2.729 1.562-6.936 1.054-9.401-1.129-2.458-2.185-2.241-5.219 0.489-6.783 2.73-1.56 6.936-1.054 9.404 1.132 2.455 2.185 2.237 5.221-0.492 6.78z" fill="none" stroke-width="1.5"/><path d="m31.651 30.916c0.247 0.361 0.019 0.865-0.506 1.109-0.531 0.246-1.174 0.141-1.42-0.221l-4.346-6.416c-0.255-0.369-0.025-0.869 0.502-1.111 0.533-0.244 1.163-0.146 1.422 0.227l4.348 6.412z" fill="#00a651" stroke-width=".1"/><path d="m28.693 18.014c1.623 2.387 0.53 5.436-2.442 6.809-2.97 1.373-6.686 0.547-8.313-1.842-1.618-2.391-0.526-5.438 2.443-6.813 2.973-1.37 6.693-0.545 8.312 1.846z" fill="none" stroke-width="1.5"/></g></svg>
diff --git a/silx/resources/gui/icons/math-peak.png b/src/silx/resources/gui/icons/math-peak.png
index 604776d..604776d 100755
--- a/silx/resources/gui/icons/math-peak.png
+++ b/src/silx/resources/gui/icons/math-peak.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-peak.svg b/src/silx/resources/gui/icons/math-peak.svg
index 94a7b1d..94a7b1d 100644
--- a/silx/resources/gui/icons/math-peak.svg
+++ b/src/silx/resources/gui/icons/math-peak.svg
diff --git a/silx/resources/gui/icons/math-phase-color-log.png b/src/silx/resources/gui/icons/math-phase-color-log.png
index 647d634..647d634 100644
--- a/silx/resources/gui/icons/math-phase-color-log.png
+++ b/src/silx/resources/gui/icons/math-phase-color-log.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-phase-color-log.svg b/src/silx/resources/gui/icons/math-phase-color-log.svg
index 803bebf..803bebf 100644
--- a/silx/resources/gui/icons/math-phase-color-log.svg
+++ b/src/silx/resources/gui/icons/math-phase-color-log.svg
diff --git a/silx/resources/gui/icons/math-phase-color.png b/src/silx/resources/gui/icons/math-phase-color.png
index d24d335..d24d335 100644
--- a/silx/resources/gui/icons/math-phase-color.png
+++ b/src/silx/resources/gui/icons/math-phase-color.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-phase-color.svg b/src/silx/resources/gui/icons/math-phase-color.svg
index 65bd287..65bd287 100644
--- a/silx/resources/gui/icons/math-phase-color.svg
+++ b/src/silx/resources/gui/icons/math-phase-color.svg
diff --git a/silx/resources/gui/icons/math-phase.png b/src/silx/resources/gui/icons/math-phase.png
index da3867a..da3867a 100644
--- a/silx/resources/gui/icons/math-phase.png
+++ b/src/silx/resources/gui/icons/math-phase.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-phase.svg b/src/silx/resources/gui/icons/math-phase.svg
index 275eb69..275eb69 100644
--- a/silx/resources/gui/icons/math-phase.svg
+++ b/src/silx/resources/gui/icons/math-phase.svg
diff --git a/silx/resources/gui/icons/math-real.png b/src/silx/resources/gui/icons/math-real.png
index fbe4868..fbe4868 100644
--- a/silx/resources/gui/icons/math-real.png
+++ b/src/silx/resources/gui/icons/math-real.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-real.svg b/src/silx/resources/gui/icons/math-real.svg
index 2f0d6d8..2f0d6d8 100644
--- a/silx/resources/gui/icons/math-real.svg
+++ b/src/silx/resources/gui/icons/math-real.svg
diff --git a/silx/resources/gui/icons/math-sigma.png b/src/silx/resources/gui/icons/math-sigma.png
index ecbd054..ecbd054 100755
--- a/silx/resources/gui/icons/math-sigma.png
+++ b/src/silx/resources/gui/icons/math-sigma.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-sigma.svg b/src/silx/resources/gui/icons/math-sigma.svg
index bbe8347..bbe8347 100644
--- a/silx/resources/gui/icons/math-sigma.svg
+++ b/src/silx/resources/gui/icons/math-sigma.svg
diff --git a/silx/resources/gui/icons/math-smooth.png b/src/silx/resources/gui/icons/math-smooth.png
index 06eda41..06eda41 100755
--- a/silx/resources/gui/icons/math-smooth.png
+++ b/src/silx/resources/gui/icons/math-smooth.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-smooth.svg b/src/silx/resources/gui/icons/math-smooth.svg
index 21a90f2..21a90f2 100644
--- a/silx/resources/gui/icons/math-smooth.svg
+++ b/src/silx/resources/gui/icons/math-smooth.svg
diff --git a/silx/resources/gui/icons/math-square-amplitude.png b/src/silx/resources/gui/icons/math-square-amplitude.png
index 2da16f2..2da16f2 100644
--- a/silx/resources/gui/icons/math-square-amplitude.png
+++ b/src/silx/resources/gui/icons/math-square-amplitude.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-square-amplitude.svg b/src/silx/resources/gui/icons/math-square-amplitude.svg
index 7c18730..7c18730 100644
--- a/silx/resources/gui/icons/math-square-amplitude.svg
+++ b/src/silx/resources/gui/icons/math-square-amplitude.svg
diff --git a/silx/resources/gui/icons/math-substract.png b/src/silx/resources/gui/icons/math-substract.png
index cf7627c..cf7627c 100755
--- a/silx/resources/gui/icons/math-substract.png
+++ b/src/silx/resources/gui/icons/math-substract.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-substract.svg b/src/silx/resources/gui/icons/math-substract.svg
index 620b439..620b439 100644
--- a/silx/resources/gui/icons/math-substract.svg
+++ b/src/silx/resources/gui/icons/math-substract.svg
diff --git a/silx/resources/gui/icons/math-swap-sign.png b/src/silx/resources/gui/icons/math-swap-sign.png
index 8e67e81..8e67e81 100755
--- a/silx/resources/gui/icons/math-swap-sign.png
+++ b/src/silx/resources/gui/icons/math-swap-sign.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-swap-sign.svg b/src/silx/resources/gui/icons/math-swap-sign.svg
index 55fb7aa..55fb7aa 100644
--- a/silx/resources/gui/icons/math-swap-sign.svg
+++ b/src/silx/resources/gui/icons/math-swap-sign.svg
diff --git a/silx/resources/gui/icons/math-ymin-to-zero.png b/src/silx/resources/gui/icons/math-ymin-to-zero.png
index 3366e35..3366e35 100755
--- a/silx/resources/gui/icons/math-ymin-to-zero.png
+++ b/src/silx/resources/gui/icons/math-ymin-to-zero.png
Binary files differ
diff --git a/silx/resources/gui/icons/math-ymin-to-zero.svg b/src/silx/resources/gui/icons/math-ymin-to-zero.svg
index ca218d1..ca218d1 100644
--- a/silx/resources/gui/icons/math-ymin-to-zero.svg
+++ b/src/silx/resources/gui/icons/math-ymin-to-zero.svg
diff --git a/silx/resources/gui/icons/median-filter.png b/src/silx/resources/gui/icons/median-filter.png
index ef47103..ef47103 100644
--- a/silx/resources/gui/icons/median-filter.png
+++ b/src/silx/resources/gui/icons/median-filter.png
Binary files differ
diff --git a/silx/resources/gui/icons/median-filter.svg b/src/silx/resources/gui/icons/median-filter.svg
index e908860..e908860 100644
--- a/silx/resources/gui/icons/median-filter.svg
+++ b/src/silx/resources/gui/icons/median-filter.svg
diff --git a/silx/resources/gui/icons/next.png b/src/silx/resources/gui/icons/next.png
index 1137720..1137720 100644
--- a/silx/resources/gui/icons/next.png
+++ b/src/silx/resources/gui/icons/next.png
Binary files differ
diff --git a/silx/resources/gui/icons/next.svg b/src/silx/resources/gui/icons/next.svg
index a906fc3..a906fc3 100644
--- a/silx/resources/gui/icons/next.svg
+++ b/src/silx/resources/gui/icons/next.svg
diff --git a/silx/resources/gui/icons/normal.png b/src/silx/resources/gui/icons/normal.png
index dd80045..dd80045 100755
--- a/silx/resources/gui/icons/normal.png
+++ b/src/silx/resources/gui/icons/normal.png
Binary files differ
diff --git a/silx/resources/gui/icons/normal.svg b/src/silx/resources/gui/icons/normal.svg
index 306f67d..306f67d 100644
--- a/silx/resources/gui/icons/normal.svg
+++ b/src/silx/resources/gui/icons/normal.svg
diff --git a/silx/resources/gui/icons/nxdata-axis-add.png b/src/silx/resources/gui/icons/nxdata-axis-add.png
index d75dca0..d75dca0 100644
--- a/silx/resources/gui/icons/nxdata-axis-add.png
+++ b/src/silx/resources/gui/icons/nxdata-axis-add.png
Binary files differ
diff --git a/silx/resources/gui/icons/nxdata-axis-add.svg b/src/silx/resources/gui/icons/nxdata-axis-add.svg
index fe20deb..fe20deb 100644
--- a/silx/resources/gui/icons/nxdata-axis-add.svg
+++ b/src/silx/resources/gui/icons/nxdata-axis-add.svg
diff --git a/silx/resources/gui/icons/nxdata-axis-remove.png b/src/silx/resources/gui/icons/nxdata-axis-remove.png
index 20ad063..20ad063 100644
--- a/silx/resources/gui/icons/nxdata-axis-remove.png
+++ b/src/silx/resources/gui/icons/nxdata-axis-remove.png
Binary files differ
diff --git a/silx/resources/gui/icons/nxdata-axis-remove.svg b/src/silx/resources/gui/icons/nxdata-axis-remove.svg
index 9c45d96..9c45d96 100644
--- a/silx/resources/gui/icons/nxdata-axis-remove.svg
+++ b/src/silx/resources/gui/icons/nxdata-axis-remove.svg
diff --git a/silx/resources/gui/icons/nxdata-create.png b/src/silx/resources/gui/icons/nxdata-create.png
index 839fe47..839fe47 100644
--- a/silx/resources/gui/icons/nxdata-create.png
+++ b/src/silx/resources/gui/icons/nxdata-create.png
Binary files differ
diff --git a/silx/resources/gui/icons/nxdata-create.svg b/src/silx/resources/gui/icons/nxdata-create.svg
index f508402..f508402 100644
--- a/silx/resources/gui/icons/nxdata-create.svg
+++ b/src/silx/resources/gui/icons/nxdata-create.svg
diff --git a/silx/resources/gui/icons/nxdata-remove.png b/src/silx/resources/gui/icons/nxdata-remove.png
index 40da64a..40da64a 100644
--- a/silx/resources/gui/icons/nxdata-remove.png
+++ b/src/silx/resources/gui/icons/nxdata-remove.png
Binary files differ
diff --git a/silx/resources/gui/icons/nxdata-remove.svg b/src/silx/resources/gui/icons/nxdata-remove.svg
index 9b6c9d0..9b6c9d0 100644
--- a/silx/resources/gui/icons/nxdata-remove.svg
+++ b/src/silx/resources/gui/icons/nxdata-remove.svg
diff --git a/silx/resources/gui/icons/pan.png b/src/silx/resources/gui/icons/pan.png
index 8fd0a86..8fd0a86 100644
--- a/silx/resources/gui/icons/pan.png
+++ b/src/silx/resources/gui/icons/pan.png
Binary files differ
diff --git a/silx/resources/gui/icons/pan.svg b/src/silx/resources/gui/icons/pan.svg
index 7425124..7425124 100644
--- a/silx/resources/gui/icons/pan.svg
+++ b/src/silx/resources/gui/icons/pan.svg
diff --git a/silx/resources/gui/icons/pixel-intensities.png b/src/silx/resources/gui/icons/pixel-intensities.png
index 63b1bcc..63b1bcc 100644
--- a/silx/resources/gui/icons/pixel-intensities.png
+++ b/src/silx/resources/gui/icons/pixel-intensities.png
Binary files differ
diff --git a/silx/resources/gui/icons/pixel-intensities.svg b/src/silx/resources/gui/icons/pixel-intensities.svg
index bfed7cf..bfed7cf 100644
--- a/silx/resources/gui/icons/pixel-intensities.svg
+++ b/src/silx/resources/gui/icons/pixel-intensities.svg
diff --git a/silx/resources/gui/icons/plot-grid.png b/src/silx/resources/gui/icons/plot-grid.png
index 38884c4..38884c4 100755
--- a/silx/resources/gui/icons/plot-grid.png
+++ b/src/silx/resources/gui/icons/plot-grid.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-grid.svg b/src/silx/resources/gui/icons/plot-grid.svg
index 435c99a..435c99a 100644
--- a/silx/resources/gui/icons/plot-grid.svg
+++ b/src/silx/resources/gui/icons/plot-grid.svg
diff --git a/silx/resources/gui/icons/plot-roi-above.png b/src/silx/resources/gui/icons/plot-roi-above.png
index e994668..e994668 100644
--- a/silx/resources/gui/icons/plot-roi-above.png
+++ b/src/silx/resources/gui/icons/plot-roi-above.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-roi-above.svg b/src/silx/resources/gui/icons/plot-roi-above.svg
index 1a6fcfe..1a6fcfe 100644
--- a/silx/resources/gui/icons/plot-roi-above.svg
+++ b/src/silx/resources/gui/icons/plot-roi-above.svg
diff --git a/silx/resources/gui/icons/plot-roi-below.png b/src/silx/resources/gui/icons/plot-roi-below.png
index 5a92476..5a92476 100644
--- a/silx/resources/gui/icons/plot-roi-below.png
+++ b/src/silx/resources/gui/icons/plot-roi-below.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-roi-below.svg b/src/silx/resources/gui/icons/plot-roi-below.svg
index 6d1edc5..6d1edc5 100644
--- a/silx/resources/gui/icons/plot-roi-below.svg
+++ b/src/silx/resources/gui/icons/plot-roi-below.svg
diff --git a/silx/resources/gui/icons/plot-roi-between.png b/src/silx/resources/gui/icons/plot-roi-between.png
index 5daadbd..5daadbd 100644
--- a/silx/resources/gui/icons/plot-roi-between.png
+++ b/src/silx/resources/gui/icons/plot-roi-between.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-roi-between.svg b/src/silx/resources/gui/icons/plot-roi-between.svg
index bdd835d..bdd835d 100644
--- a/silx/resources/gui/icons/plot-roi-between.svg
+++ b/src/silx/resources/gui/icons/plot-roi-between.svg
diff --git a/silx/resources/gui/icons/plot-roi-reset.png b/src/silx/resources/gui/icons/plot-roi-reset.png
index 4bf6129..4bf6129 100755
--- a/silx/resources/gui/icons/plot-roi-reset.png
+++ b/src/silx/resources/gui/icons/plot-roi-reset.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-roi-reset.svg b/src/silx/resources/gui/icons/plot-roi-reset.svg
index c398dfe..c398dfe 100644
--- a/silx/resources/gui/icons/plot-roi-reset.svg
+++ b/src/silx/resources/gui/icons/plot-roi-reset.svg
diff --git a/silx/resources/gui/icons/plot-roi.png b/src/silx/resources/gui/icons/plot-roi.png
index b34ff7c..b34ff7c 100755
--- a/silx/resources/gui/icons/plot-roi.png
+++ b/src/silx/resources/gui/icons/plot-roi.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-roi.svg b/src/silx/resources/gui/icons/plot-roi.svg
index 6b8a549..6b8a549 100644
--- a/silx/resources/gui/icons/plot-roi.svg
+++ b/src/silx/resources/gui/icons/plot-roi.svg
diff --git a/silx/resources/gui/icons/plot-symbols.png b/src/silx/resources/gui/icons/plot-symbols.png
index 75ec785..75ec785 100644
--- a/silx/resources/gui/icons/plot-symbols.png
+++ b/src/silx/resources/gui/icons/plot-symbols.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-symbols.svg b/src/silx/resources/gui/icons/plot-symbols.svg
index 48a4d22..48a4d22 100644
--- a/silx/resources/gui/icons/plot-symbols.svg
+++ b/src/silx/resources/gui/icons/plot-symbols.svg
diff --git a/silx/resources/gui/icons/plot-toggle-points.png b/src/silx/resources/gui/icons/plot-toggle-points.png
index 33b579a..33b579a 100755
--- a/silx/resources/gui/icons/plot-toggle-points.png
+++ b/src/silx/resources/gui/icons/plot-toggle-points.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-toggle-points.svg b/src/silx/resources/gui/icons/plot-toggle-points.svg
index 23fd107..23fd107 100644
--- a/silx/resources/gui/icons/plot-toggle-points.svg
+++ b/src/silx/resources/gui/icons/plot-toggle-points.svg
diff --git a/silx/resources/gui/icons/plot-widget.png b/src/silx/resources/gui/icons/plot-widget.png
index c0495a5..c0495a5 100755
--- a/silx/resources/gui/icons/plot-widget.png
+++ b/src/silx/resources/gui/icons/plot-widget.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-widget.svg b/src/silx/resources/gui/icons/plot-widget.svg
index 106c274..106c274 100644
--- a/silx/resources/gui/icons/plot-widget.svg
+++ b/src/silx/resources/gui/icons/plot-widget.svg
diff --git a/silx/resources/gui/icons/plot-window-image.png b/src/silx/resources/gui/icons/plot-window-image.png
index a95edb5..a95edb5 100755
--- a/silx/resources/gui/icons/plot-window-image.png
+++ b/src/silx/resources/gui/icons/plot-window-image.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-window-image.svg b/src/silx/resources/gui/icons/plot-window-image.svg
index 5a42355..5a42355 100644
--- a/silx/resources/gui/icons/plot-window-image.svg
+++ b/src/silx/resources/gui/icons/plot-window-image.svg
diff --git a/silx/resources/gui/icons/plot-window.png b/src/silx/resources/gui/icons/plot-window.png
index ea7eb3b..ea7eb3b 100755
--- a/silx/resources/gui/icons/plot-window.png
+++ b/src/silx/resources/gui/icons/plot-window.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-window.svg b/src/silx/resources/gui/icons/plot-window.svg
index f7a3329..f7a3329 100644
--- a/silx/resources/gui/icons/plot-window.svg
+++ b/src/silx/resources/gui/icons/plot-window.svg
diff --git a/silx/resources/gui/icons/plot-xauto.png b/src/silx/resources/gui/icons/plot-xauto.png
index 2c79723..2c79723 100755
--- a/silx/resources/gui/icons/plot-xauto.png
+++ b/src/silx/resources/gui/icons/plot-xauto.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-xauto.svg b/src/silx/resources/gui/icons/plot-xauto.svg
index 0baa9a9..0baa9a9 100644
--- a/silx/resources/gui/icons/plot-xauto.svg
+++ b/src/silx/resources/gui/icons/plot-xauto.svg
diff --git a/silx/resources/gui/icons/plot-xlog.png b/src/silx/resources/gui/icons/plot-xlog.png
index 1e0a843..1e0a843 100755
--- a/silx/resources/gui/icons/plot-xlog.png
+++ b/src/silx/resources/gui/icons/plot-xlog.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-xlog.svg b/src/silx/resources/gui/icons/plot-xlog.svg
index 172ef84..172ef84 100644
--- a/silx/resources/gui/icons/plot-xlog.svg
+++ b/src/silx/resources/gui/icons/plot-xlog.svg
diff --git a/silx/resources/gui/icons/plot-yauto.png b/src/silx/resources/gui/icons/plot-yauto.png
index e5e34f0..e5e34f0 100755
--- a/silx/resources/gui/icons/plot-yauto.png
+++ b/src/silx/resources/gui/icons/plot-yauto.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-yauto.svg b/src/silx/resources/gui/icons/plot-yauto.svg
index 7bc7e29..7bc7e29 100644
--- a/silx/resources/gui/icons/plot-yauto.svg
+++ b/src/silx/resources/gui/icons/plot-yauto.svg
diff --git a/silx/resources/gui/icons/plot-ydown.png b/src/silx/resources/gui/icons/plot-ydown.png
index f857097..f857097 100755
--- a/silx/resources/gui/icons/plot-ydown.png
+++ b/src/silx/resources/gui/icons/plot-ydown.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-ydown.svg b/src/silx/resources/gui/icons/plot-ydown.svg
index ea35cea..ea35cea 100644
--- a/silx/resources/gui/icons/plot-ydown.svg
+++ b/src/silx/resources/gui/icons/plot-ydown.svg
diff --git a/silx/resources/gui/icons/plot-ylog.png b/src/silx/resources/gui/icons/plot-ylog.png
index a705f40..a705f40 100755
--- a/silx/resources/gui/icons/plot-ylog.png
+++ b/src/silx/resources/gui/icons/plot-ylog.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-ylog.svg b/src/silx/resources/gui/icons/plot-ylog.svg
index b0d6e58..b0d6e58 100644
--- a/silx/resources/gui/icons/plot-ylog.svg
+++ b/src/silx/resources/gui/icons/plot-ylog.svg
diff --git a/silx/resources/gui/icons/plot-yup.png b/src/silx/resources/gui/icons/plot-yup.png
index bfef167..bfef167 100755
--- a/silx/resources/gui/icons/plot-yup.png
+++ b/src/silx/resources/gui/icons/plot-yup.png
Binary files differ
diff --git a/silx/resources/gui/icons/plot-yup.svg b/src/silx/resources/gui/icons/plot-yup.svg
index dfdc19f..dfdc19f 100644
--- a/silx/resources/gui/icons/plot-yup.svg
+++ b/src/silx/resources/gui/icons/plot-yup.svg
diff --git a/silx/resources/gui/icons/pointing-hand.png b/src/silx/resources/gui/icons/pointing-hand.png
index abc9cd1..abc9cd1 100644
--- a/silx/resources/gui/icons/pointing-hand.png
+++ b/src/silx/resources/gui/icons/pointing-hand.png
Binary files differ
diff --git a/silx/resources/gui/icons/pointing-hand.svg b/src/silx/resources/gui/icons/pointing-hand.svg
index 2a755da..2a755da 100644
--- a/silx/resources/gui/icons/pointing-hand.svg
+++ b/src/silx/resources/gui/icons/pointing-hand.svg
diff --git a/silx/resources/gui/icons/previous.png b/src/silx/resources/gui/icons/previous.png
index 9f436ce..9f436ce 100644
--- a/silx/resources/gui/icons/previous.png
+++ b/src/silx/resources/gui/icons/previous.png
Binary files differ
diff --git a/silx/resources/gui/icons/previous.svg b/src/silx/resources/gui/icons/previous.svg
index 0f6bcad..0f6bcad 100644
--- a/silx/resources/gui/icons/previous.svg
+++ b/src/silx/resources/gui/icons/previous.svg
diff --git a/silx/resources/gui/icons/process-working.mng b/src/silx/resources/gui/icons/process-working.mng
index 842ea5a..842ea5a 100644
--- a/silx/resources/gui/icons/process-working.mng
+++ b/src/silx/resources/gui/icons/process-working.mng
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/00.png b/src/silx/resources/gui/icons/process-working/00.png
index a787ab7..a787ab7 100644
--- a/silx/resources/gui/icons/process-working/00.png
+++ b/src/silx/resources/gui/icons/process-working/00.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/01.png b/src/silx/resources/gui/icons/process-working/01.png
index 297ed4e..297ed4e 100644
--- a/silx/resources/gui/icons/process-working/01.png
+++ b/src/silx/resources/gui/icons/process-working/01.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/02.png b/src/silx/resources/gui/icons/process-working/02.png
index f2c3a59..f2c3a59 100644
--- a/silx/resources/gui/icons/process-working/02.png
+++ b/src/silx/resources/gui/icons/process-working/02.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/03.png b/src/silx/resources/gui/icons/process-working/03.png
index 75a4b85..75a4b85 100644
--- a/silx/resources/gui/icons/process-working/03.png
+++ b/src/silx/resources/gui/icons/process-working/03.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/04.png b/src/silx/resources/gui/icons/process-working/04.png
index 12fe098..12fe098 100644
--- a/silx/resources/gui/icons/process-working/04.png
+++ b/src/silx/resources/gui/icons/process-working/04.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/05.png b/src/silx/resources/gui/icons/process-working/05.png
index ec0b2bf..ec0b2bf 100644
--- a/silx/resources/gui/icons/process-working/05.png
+++ b/src/silx/resources/gui/icons/process-working/05.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/06.png b/src/silx/resources/gui/icons/process-working/06.png
index 9dca9eb..9dca9eb 100644
--- a/silx/resources/gui/icons/process-working/06.png
+++ b/src/silx/resources/gui/icons/process-working/06.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/07.png b/src/silx/resources/gui/icons/process-working/07.png
index ca8a18c..ca8a18c 100644
--- a/silx/resources/gui/icons/process-working/07.png
+++ b/src/silx/resources/gui/icons/process-working/07.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/08.png b/src/silx/resources/gui/icons/process-working/08.png
index abd1210..abd1210 100644
--- a/silx/resources/gui/icons/process-working/08.png
+++ b/src/silx/resources/gui/icons/process-working/08.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/09.png b/src/silx/resources/gui/icons/process-working/09.png
index a0f362f..a0f362f 100644
--- a/silx/resources/gui/icons/process-working/09.png
+++ b/src/silx/resources/gui/icons/process-working/09.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/10.png b/src/silx/resources/gui/icons/process-working/10.png
index cc8b968..cc8b968 100644
--- a/silx/resources/gui/icons/process-working/10.png
+++ b/src/silx/resources/gui/icons/process-working/10.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/11.png b/src/silx/resources/gui/icons/process-working/11.png
index f5da609..f5da609 100644
--- a/silx/resources/gui/icons/process-working/11.png
+++ b/src/silx/resources/gui/icons/process-working/11.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/12.png b/src/silx/resources/gui/icons/process-working/12.png
index 92e2159..92e2159 100644
--- a/silx/resources/gui/icons/process-working/12.png
+++ b/src/silx/resources/gui/icons/process-working/12.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/13.png b/src/silx/resources/gui/icons/process-working/13.png
index 6e9e8d7..6e9e8d7 100644
--- a/silx/resources/gui/icons/process-working/13.png
+++ b/src/silx/resources/gui/icons/process-working/13.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/14.png b/src/silx/resources/gui/icons/process-working/14.png
index 3f2141b..3f2141b 100644
--- a/silx/resources/gui/icons/process-working/14.png
+++ b/src/silx/resources/gui/icons/process-working/14.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/15.png b/src/silx/resources/gui/icons/process-working/15.png
index 1043659..1043659 100644
--- a/silx/resources/gui/icons/process-working/15.png
+++ b/src/silx/resources/gui/icons/process-working/15.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/16.png b/src/silx/resources/gui/icons/process-working/16.png
index a8d89fc..a8d89fc 100644
--- a/silx/resources/gui/icons/process-working/16.png
+++ b/src/silx/resources/gui/icons/process-working/16.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/17.png b/src/silx/resources/gui/icons/process-working/17.png
index 5b68f03..5b68f03 100644
--- a/silx/resources/gui/icons/process-working/17.png
+++ b/src/silx/resources/gui/icons/process-working/17.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/18.png b/src/silx/resources/gui/icons/process-working/18.png
index cf0ff96..cf0ff96 100644
--- a/silx/resources/gui/icons/process-working/18.png
+++ b/src/silx/resources/gui/icons/process-working/18.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/19.png b/src/silx/resources/gui/icons/process-working/19.png
index 661effd..661effd 100644
--- a/silx/resources/gui/icons/process-working/19.png
+++ b/src/silx/resources/gui/icons/process-working/19.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/20.png b/src/silx/resources/gui/icons/process-working/20.png
index e1c77aa..e1c77aa 100644
--- a/silx/resources/gui/icons/process-working/20.png
+++ b/src/silx/resources/gui/icons/process-working/20.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/21.png b/src/silx/resources/gui/icons/process-working/21.png
index 10861e7..10861e7 100644
--- a/silx/resources/gui/icons/process-working/21.png
+++ b/src/silx/resources/gui/icons/process-working/21.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/22.png b/src/silx/resources/gui/icons/process-working/22.png
index 38907dc..38907dc 100644
--- a/silx/resources/gui/icons/process-working/22.png
+++ b/src/silx/resources/gui/icons/process-working/22.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/23.png b/src/silx/resources/gui/icons/process-working/23.png
index 7ec4915..7ec4915 100644
--- a/silx/resources/gui/icons/process-working/23.png
+++ b/src/silx/resources/gui/icons/process-working/23.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/24.png b/src/silx/resources/gui/icons/process-working/24.png
index 2e90357..2e90357 100644
--- a/silx/resources/gui/icons/process-working/24.png
+++ b/src/silx/resources/gui/icons/process-working/24.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/25.png b/src/silx/resources/gui/icons/process-working/25.png
index 6ffa1a7..6ffa1a7 100644
--- a/silx/resources/gui/icons/process-working/25.png
+++ b/src/silx/resources/gui/icons/process-working/25.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/26.png b/src/silx/resources/gui/icons/process-working/26.png
index b8ae153..b8ae153 100644
--- a/silx/resources/gui/icons/process-working/26.png
+++ b/src/silx/resources/gui/icons/process-working/26.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/27.png b/src/silx/resources/gui/icons/process-working/27.png
index 4d3c716..4d3c716 100644
--- a/silx/resources/gui/icons/process-working/27.png
+++ b/src/silx/resources/gui/icons/process-working/27.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/28.png b/src/silx/resources/gui/icons/process-working/28.png
index dd88b9c..dd88b9c 100644
--- a/silx/resources/gui/icons/process-working/28.png
+++ b/src/silx/resources/gui/icons/process-working/28.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/29.png b/src/silx/resources/gui/icons/process-working/29.png
index 985e115..985e115 100644
--- a/silx/resources/gui/icons/process-working/29.png
+++ b/src/silx/resources/gui/icons/process-working/29.png
Binary files differ
diff --git a/silx/resources/gui/icons/process-working/30.png b/src/silx/resources/gui/icons/process-working/30.png
index f2c6d3d..f2c6d3d 100644
--- a/silx/resources/gui/icons/process-working/30.png
+++ b/src/silx/resources/gui/icons/process-working/30.png
Binary files differ
diff --git a/silx/resources/gui/icons/profile-clear.png b/src/silx/resources/gui/icons/profile-clear.png
index 5f2159d..5f2159d 100644
--- a/silx/resources/gui/icons/profile-clear.png
+++ b/src/silx/resources/gui/icons/profile-clear.png
Binary files differ
diff --git a/silx/resources/gui/icons/profile-clear.svg b/src/silx/resources/gui/icons/profile-clear.svg
index b1593d2..b1593d2 100644
--- a/silx/resources/gui/icons/profile-clear.svg
+++ b/src/silx/resources/gui/icons/profile-clear.svg
diff --git a/silx/resources/gui/icons/profile1D.png b/src/silx/resources/gui/icons/profile1D.png
index 65991fe..65991fe 100644
--- a/silx/resources/gui/icons/profile1D.png
+++ b/src/silx/resources/gui/icons/profile1D.png
Binary files differ
diff --git a/silx/resources/gui/icons/profile1D.svg b/src/silx/resources/gui/icons/profile1D.svg
index c332345..c332345 100644
--- a/silx/resources/gui/icons/profile1D.svg
+++ b/src/silx/resources/gui/icons/profile1D.svg
diff --git a/silx/resources/gui/icons/profile2D.png b/src/silx/resources/gui/icons/profile2D.png
index 8478931..8478931 100644
--- a/silx/resources/gui/icons/profile2D.png
+++ b/src/silx/resources/gui/icons/profile2D.png
Binary files differ
diff --git a/silx/resources/gui/icons/profile2D.svg b/src/silx/resources/gui/icons/profile2D.svg
index e682b3c..e682b3c 100644
--- a/silx/resources/gui/icons/profile2D.svg
+++ b/src/silx/resources/gui/icons/profile2D.svg
diff --git a/silx/resources/gui/icons/remove.png b/src/silx/resources/gui/icons/remove.png
index 235338c..235338c 100755
--- a/silx/resources/gui/icons/remove.png
+++ b/src/silx/resources/gui/icons/remove.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/remove.svg b/src/silx/resources/gui/icons/remove.svg
new file mode 100644
index 0000000..9d3bbce
--- /dev/null
+++ b/src/silx/resources/gui/icons/remove.svg
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="d" x1="22.414" x2="22.473" y1="21.502" y2="21.502" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m22.443 24.002c0.038 0 0.038-5 0-5s-0.038 5 0 5z" fill="url(#d)"/><path d="m8.293 10.899c5.462 5.68 10.925 11.36 16.387 17.04-0.814-0.847 0.851-4.115 0-5-5.462-5.68-10.924-11.36-16.387-17.04 0.814 0.847-0.851 4.116 0 5z" fill="#ed1c24"/><path d="m24.452 5.675c-5.434 5.658-10.869 11.317-16.304 16.975-0.851 0.886 0.814 4.152 0 5 5.435-5.658 10.869-11.317 16.304-16.976 0.851-0.884-0.814-4.152 0-4.999z" fill="#ed1c24"/></svg>
diff --git a/silx/resources/gui/icons/rm.png b/src/silx/resources/gui/icons/rm.png
index ecff08b..ecff08b 100644
--- a/silx/resources/gui/icons/rm.png
+++ b/src/silx/resources/gui/icons/rm.png
Binary files differ
diff --git a/silx/resources/gui/icons/rm.svg b/src/silx/resources/gui/icons/rm.svg
index 7cc515e..7cc515e 100644
--- a/silx/resources/gui/icons/rm.svg
+++ b/src/silx/resources/gui/icons/rm.svg
diff --git a/silx/resources/gui/icons/rotate-3d.png b/src/silx/resources/gui/icons/rotate-3d.png
index 4cf8403..4cf8403 100644
--- a/silx/resources/gui/icons/rotate-3d.png
+++ b/src/silx/resources/gui/icons/rotate-3d.png
Binary files differ
diff --git a/silx/resources/gui/icons/rotate-3d.svg b/src/silx/resources/gui/icons/rotate-3d.svg
index 32a4327..32a4327 100644
--- a/silx/resources/gui/icons/rotate-3d.svg
+++ b/src/silx/resources/gui/icons/rotate-3d.svg
diff --git a/silx/resources/gui/icons/rudder.png b/src/silx/resources/gui/icons/rudder.png
index ad45338..ad45338 100755
--- a/silx/resources/gui/icons/rudder.png
+++ b/src/silx/resources/gui/icons/rudder.png
Binary files differ
diff --git a/silx/resources/gui/icons/rudder.svg b/src/silx/resources/gui/icons/rudder.svg
index 6c8c742..6c8c742 100644
--- a/silx/resources/gui/icons/rudder.svg
+++ b/src/silx/resources/gui/icons/rudder.svg
diff --git a/silx/resources/gui/icons/selected.png b/src/silx/resources/gui/icons/selected.png
index 451d7c7..451d7c7 100755
--- a/silx/resources/gui/icons/selected.png
+++ b/src/silx/resources/gui/icons/selected.png
Binary files differ
diff --git a/silx/resources/gui/icons/selected.svg b/src/silx/resources/gui/icons/selected.svg
index d73d849..d73d849 100644
--- a/silx/resources/gui/icons/selected.svg
+++ b/src/silx/resources/gui/icons/selected.svg
diff --git a/silx/resources/gui/icons/shape-circle-solid.png b/src/silx/resources/gui/icons/shape-circle-solid.png
index f43d736..f43d736 100755
--- a/silx/resources/gui/icons/shape-circle-solid.png
+++ b/src/silx/resources/gui/icons/shape-circle-solid.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-circle-solid.svg b/src/silx/resources/gui/icons/shape-circle-solid.svg
index 600170f..600170f 100644
--- a/silx/resources/gui/icons/shape-circle-solid.svg
+++ b/src/silx/resources/gui/icons/shape-circle-solid.svg
diff --git a/silx/resources/gui/icons/shape-circle.png b/src/silx/resources/gui/icons/shape-circle.png
index 3d21824..3d21824 100755
--- a/silx/resources/gui/icons/shape-circle.png
+++ b/src/silx/resources/gui/icons/shape-circle.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-circle.svg b/src/silx/resources/gui/icons/shape-circle.svg
index 45a2a0d..45a2a0d 100644
--- a/silx/resources/gui/icons/shape-circle.svg
+++ b/src/silx/resources/gui/icons/shape-circle.svg
diff --git a/silx/resources/gui/icons/shape-cross.png b/src/silx/resources/gui/icons/shape-cross.png
index 72106a4..72106a4 100644
--- a/silx/resources/gui/icons/shape-cross.png
+++ b/src/silx/resources/gui/icons/shape-cross.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-cross.svg b/src/silx/resources/gui/icons/shape-cross.svg
index cba6638..cba6638 100644
--- a/silx/resources/gui/icons/shape-cross.svg
+++ b/src/silx/resources/gui/icons/shape-cross.svg
diff --git a/silx/resources/gui/icons/shape-diagonal-directed.png b/src/silx/resources/gui/icons/shape-diagonal-directed.png
index f2405b4..f2405b4 100644
--- a/silx/resources/gui/icons/shape-diagonal-directed.png
+++ b/src/silx/resources/gui/icons/shape-diagonal-directed.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-diagonal-directed.svg b/src/silx/resources/gui/icons/shape-diagonal-directed.svg
index 24e1b12..24e1b12 100644
--- a/silx/resources/gui/icons/shape-diagonal-directed.svg
+++ b/src/silx/resources/gui/icons/shape-diagonal-directed.svg
diff --git a/silx/resources/gui/icons/shape-diagonal.png b/src/silx/resources/gui/icons/shape-diagonal.png
index f71bcb0..f71bcb0 100755
--- a/silx/resources/gui/icons/shape-diagonal.png
+++ b/src/silx/resources/gui/icons/shape-diagonal.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-diagonal.svg b/src/silx/resources/gui/icons/shape-diagonal.svg
index 4580c06..4580c06 100644
--- a/silx/resources/gui/icons/shape-diagonal.svg
+++ b/src/silx/resources/gui/icons/shape-diagonal.svg
diff --git a/silx/resources/gui/icons/shape-ellipse-solid.png b/src/silx/resources/gui/icons/shape-ellipse-solid.png
index 31bcb4c..31bcb4c 100755
--- a/silx/resources/gui/icons/shape-ellipse-solid.png
+++ b/src/silx/resources/gui/icons/shape-ellipse-solid.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-ellipse-solid.svg b/src/silx/resources/gui/icons/shape-ellipse-solid.svg
index b740a23..b740a23 100644
--- a/silx/resources/gui/icons/shape-ellipse-solid.svg
+++ b/src/silx/resources/gui/icons/shape-ellipse-solid.svg
diff --git a/silx/resources/gui/icons/shape-ellipse.png b/src/silx/resources/gui/icons/shape-ellipse.png
index f524f7e..f524f7e 100644
--- a/silx/resources/gui/icons/shape-ellipse.png
+++ b/src/silx/resources/gui/icons/shape-ellipse.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-ellipse.svg b/src/silx/resources/gui/icons/shape-ellipse.svg
index e5aeeaa..e5aeeaa 100644
--- a/silx/resources/gui/icons/shape-ellipse.svg
+++ b/src/silx/resources/gui/icons/shape-ellipse.svg
diff --git a/silx/resources/gui/icons/shape-horizontal.png b/src/silx/resources/gui/icons/shape-horizontal.png
index 0ea55e2..0ea55e2 100755
--- a/silx/resources/gui/icons/shape-horizontal.png
+++ b/src/silx/resources/gui/icons/shape-horizontal.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-horizontal.svg b/src/silx/resources/gui/icons/shape-horizontal.svg
index 053a590..053a590 100644
--- a/silx/resources/gui/icons/shape-horizontal.svg
+++ b/src/silx/resources/gui/icons/shape-horizontal.svg
diff --git a/silx/resources/gui/icons/shape-polygon.png b/src/silx/resources/gui/icons/shape-polygon.png
index efbb449..efbb449 100755
--- a/silx/resources/gui/icons/shape-polygon.png
+++ b/src/silx/resources/gui/icons/shape-polygon.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-polygon.svg b/src/silx/resources/gui/icons/shape-polygon.svg
index b30503c..b30503c 100644
--- a/silx/resources/gui/icons/shape-polygon.svg
+++ b/src/silx/resources/gui/icons/shape-polygon.svg
diff --git a/silx/resources/gui/icons/shape-rectangle.png b/src/silx/resources/gui/icons/shape-rectangle.png
index c523c72..c523c72 100755
--- a/silx/resources/gui/icons/shape-rectangle.png
+++ b/src/silx/resources/gui/icons/shape-rectangle.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-rectangle.svg b/src/silx/resources/gui/icons/shape-rectangle.svg
index caade30..caade30 100644
--- a/silx/resources/gui/icons/shape-rectangle.svg
+++ b/src/silx/resources/gui/icons/shape-rectangle.svg
diff --git a/silx/resources/gui/icons/shape-square.png b/src/silx/resources/gui/icons/shape-square.png
index 667b758..667b758 100755
--- a/silx/resources/gui/icons/shape-square.png
+++ b/src/silx/resources/gui/icons/shape-square.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-square.svg b/src/silx/resources/gui/icons/shape-square.svg
index de6eda8..de6eda8 100644
--- a/silx/resources/gui/icons/shape-square.svg
+++ b/src/silx/resources/gui/icons/shape-square.svg
diff --git a/silx/resources/gui/icons/shape-vertical.png b/src/silx/resources/gui/icons/shape-vertical.png
index 384c4ae..384c4ae 100755
--- a/silx/resources/gui/icons/shape-vertical.png
+++ b/src/silx/resources/gui/icons/shape-vertical.png
Binary files differ
diff --git a/silx/resources/gui/icons/shape-vertical.svg b/src/silx/resources/gui/icons/shape-vertical.svg
index 8cf3b97..8cf3b97 100644
--- a/silx/resources/gui/icons/shape-vertical.svg
+++ b/src/silx/resources/gui/icons/shape-vertical.svg
diff --git a/src/silx/resources/gui/icons/side-histograms.png b/src/silx/resources/gui/icons/side-histograms.png
new file mode 100644
index 0000000..6416ceb
--- /dev/null
+++ b/src/silx/resources/gui/icons/side-histograms.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/side-histograms.svg b/src/silx/resources/gui/icons/side-histograms.svg
new file mode 100644
index 0000000..7c366f0
--- /dev/null
+++ b/src/silx/resources/gui/icons/side-histograms.svg
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg id="svg10460" version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"><metadata id="metadata10466"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata><rect id="rect10458-6" transform="rotate(90)" x="22.448" y="-18.66" width="5.7532" height="14.843" fill="none" stroke="#f7941e" stroke-miterlimit="10" stroke-width="2.5"/><rect id="rect10458-6-5" transform="rotate(90)" x="22.448" y="-28.183" width="5.753" height="5.753" fill="none" stroke="#f7941e" stroke-miterlimit="10" stroke-width="2.5"/><rect id="rect10458-6-3" transform="scale(-1)" x="-28.107" y="-18.643" width="5.7532" height="14.843" fill="none" stroke="#f7941e" stroke-miterlimit="10" stroke-width="2.5"/></svg>
diff --git a/silx/resources/gui/icons/silx.png b/src/silx/resources/gui/icons/silx.png
index 9d7ffc9..9d7ffc9 100755
--- a/silx/resources/gui/icons/silx.png
+++ b/src/silx/resources/gui/icons/silx.png
Binary files differ
diff --git a/silx/resources/gui/icons/silx.svg b/src/silx/resources/gui/icons/silx.svg
index e666b73..e666b73 100644
--- a/silx/resources/gui/icons/silx.svg
+++ b/src/silx/resources/gui/icons/silx.svg
diff --git a/silx/resources/gui/icons/slice-cross.png b/src/silx/resources/gui/icons/slice-cross.png
index 106362e..106362e 100644
--- a/silx/resources/gui/icons/slice-cross.png
+++ b/src/silx/resources/gui/icons/slice-cross.png
Binary files differ
diff --git a/silx/resources/gui/icons/slice-cross.svg b/src/silx/resources/gui/icons/slice-cross.svg
index 271a656..271a656 100644
--- a/silx/resources/gui/icons/slice-cross.svg
+++ b/src/silx/resources/gui/icons/slice-cross.svg
diff --git a/silx/resources/gui/icons/slice-horizontal.png b/src/silx/resources/gui/icons/slice-horizontal.png
index d16b74c..d16b74c 100644
--- a/silx/resources/gui/icons/slice-horizontal.png
+++ b/src/silx/resources/gui/icons/slice-horizontal.png
Binary files differ
diff --git a/silx/resources/gui/icons/slice-horizontal.svg b/src/silx/resources/gui/icons/slice-horizontal.svg
index 9402bc6..9402bc6 100644
--- a/silx/resources/gui/icons/slice-horizontal.svg
+++ b/src/silx/resources/gui/icons/slice-horizontal.svg
diff --git a/silx/resources/gui/icons/slice-vertical.png b/src/silx/resources/gui/icons/slice-vertical.png
index 6fc99b3..6fc99b3 100644
--- a/silx/resources/gui/icons/slice-vertical.png
+++ b/src/silx/resources/gui/icons/slice-vertical.png
Binary files differ
diff --git a/silx/resources/gui/icons/slice-vertical.svg b/src/silx/resources/gui/icons/slice-vertical.svg
index d9d67a4..d9d67a4 100644
--- a/silx/resources/gui/icons/slice-vertical.svg
+++ b/src/silx/resources/gui/icons/slice-vertical.svg
diff --git a/silx/resources/gui/icons/sliders-off.png b/src/silx/resources/gui/icons/sliders-off.png
index 463f4ec..463f4ec 100755
--- a/silx/resources/gui/icons/sliders-off.png
+++ b/src/silx/resources/gui/icons/sliders-off.png
Binary files differ
diff --git a/silx/resources/gui/icons/sliders-off.svg b/src/silx/resources/gui/icons/sliders-off.svg
index 1ed69ce..1ed69ce 100644
--- a/silx/resources/gui/icons/sliders-off.svg
+++ b/src/silx/resources/gui/icons/sliders-off.svg
diff --git a/silx/resources/gui/icons/sliders-on.png b/src/silx/resources/gui/icons/sliders-on.png
index e720d15..e720d15 100755
--- a/silx/resources/gui/icons/sliders-on.png
+++ b/src/silx/resources/gui/icons/sliders-on.png
Binary files differ
diff --git a/silx/resources/gui/icons/sliders-on.svg b/src/silx/resources/gui/icons/sliders-on.svg
index b70d36c..b70d36c 100644
--- a/silx/resources/gui/icons/sliders-on.svg
+++ b/src/silx/resources/gui/icons/sliders-on.svg
diff --git a/silx/resources/gui/icons/spec.png b/src/silx/resources/gui/icons/spec.png
index 876ab1b..876ab1b 100755
--- a/silx/resources/gui/icons/spec.png
+++ b/src/silx/resources/gui/icons/spec.png
Binary files differ
diff --git a/silx/resources/gui/icons/spec.svg b/src/silx/resources/gui/icons/spec.svg
index 26d9d5f..26d9d5f 100644
--- a/silx/resources/gui/icons/spec.svg
+++ b/src/silx/resources/gui/icons/spec.svg
diff --git a/silx/resources/gui/icons/stats-active-items.png b/src/silx/resources/gui/icons/stats-active-items.png
index 9974aa0..9974aa0 100644
--- a/silx/resources/gui/icons/stats-active-items.png
+++ b/src/silx/resources/gui/icons/stats-active-items.png
Binary files differ
diff --git a/silx/resources/gui/icons/stats-active-items.svg b/src/silx/resources/gui/icons/stats-active-items.svg
index 8312178..8312178 100644
--- a/silx/resources/gui/icons/stats-active-items.svg
+++ b/src/silx/resources/gui/icons/stats-active-items.svg
diff --git a/silx/resources/gui/icons/stats-visible-data.png b/src/silx/resources/gui/icons/stats-visible-data.png
index 9353117..9353117 100644
--- a/silx/resources/gui/icons/stats-visible-data.png
+++ b/src/silx/resources/gui/icons/stats-visible-data.png
Binary files differ
diff --git a/silx/resources/gui/icons/stats-visible-data.svg b/src/silx/resources/gui/icons/stats-visible-data.svg
index e56a42c..e56a42c 100644
--- a/silx/resources/gui/icons/stats-visible-data.svg
+++ b/src/silx/resources/gui/icons/stats-visible-data.svg
diff --git a/silx/resources/gui/icons/stats-whole-data.png b/src/silx/resources/gui/icons/stats-whole-data.png
index 3fab615..3fab615 100644
--- a/silx/resources/gui/icons/stats-whole-data.png
+++ b/src/silx/resources/gui/icons/stats-whole-data.png
Binary files differ
diff --git a/silx/resources/gui/icons/stats-whole-data.svg b/src/silx/resources/gui/icons/stats-whole-data.svg
index 5e5b9f9..5e5b9f9 100644
--- a/silx/resources/gui/icons/stats-whole-data.svg
+++ b/src/silx/resources/gui/icons/stats-whole-data.svg
diff --git a/silx/resources/gui/icons/stats-whole-items.png b/src/silx/resources/gui/icons/stats-whole-items.png
index d3c24d0..d3c24d0 100644
--- a/silx/resources/gui/icons/stats-whole-items.png
+++ b/src/silx/resources/gui/icons/stats-whole-items.png
Binary files differ
diff --git a/silx/resources/gui/icons/stats-whole-items.svg b/src/silx/resources/gui/icons/stats-whole-items.svg
index c0e55bc..c0e55bc 100644
--- a/silx/resources/gui/icons/stats-whole-items.svg
+++ b/src/silx/resources/gui/icons/stats-whole-items.svg
diff --git a/silx/resources/gui/icons/tree-collapse-all.png b/src/silx/resources/gui/icons/tree-collapse-all.png
index 6ecf8b5..6ecf8b5 100644
--- a/silx/resources/gui/icons/tree-collapse-all.png
+++ b/src/silx/resources/gui/icons/tree-collapse-all.png
Binary files differ
diff --git a/silx/resources/gui/icons/tree-collapse-all.svg b/src/silx/resources/gui/icons/tree-collapse-all.svg
index ecdd800..ecdd800 100644
--- a/silx/resources/gui/icons/tree-collapse-all.svg
+++ b/src/silx/resources/gui/icons/tree-collapse-all.svg
diff --git a/silx/resources/gui/icons/tree-expand-all.png b/src/silx/resources/gui/icons/tree-expand-all.png
index 97f2aa5..97f2aa5 100644
--- a/silx/resources/gui/icons/tree-expand-all.png
+++ b/src/silx/resources/gui/icons/tree-expand-all.png
Binary files differ
diff --git a/silx/resources/gui/icons/tree-expand-all.svg b/src/silx/resources/gui/icons/tree-expand-all.svg
index 586c269..586c269 100644
--- a/silx/resources/gui/icons/tree-expand-all.svg
+++ b/src/silx/resources/gui/icons/tree-expand-all.svg
diff --git a/silx/resources/gui/icons/tree-sort.png b/src/silx/resources/gui/icons/tree-sort.png
index 2e759b6..2e759b6 100644
--- a/silx/resources/gui/icons/tree-sort.png
+++ b/src/silx/resources/gui/icons/tree-sort.png
Binary files differ
diff --git a/silx/resources/gui/icons/tree-sort.svg b/src/silx/resources/gui/icons/tree-sort.svg
index b813d60..b813d60 100644
--- a/silx/resources/gui/icons/tree-sort.svg
+++ b/src/silx/resources/gui/icons/tree-sort.svg
diff --git a/silx/resources/gui/icons/view-1d.png b/src/silx/resources/gui/icons/view-1d.png
index b420a5c..b420a5c 100644
--- a/silx/resources/gui/icons/view-1d.png
+++ b/src/silx/resources/gui/icons/view-1d.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-1d.svg b/src/silx/resources/gui/icons/view-1d.svg
index a2ad9cc..a2ad9cc 100644
--- a/silx/resources/gui/icons/view-1d.svg
+++ b/src/silx/resources/gui/icons/view-1d.svg
diff --git a/silx/resources/gui/icons/view-2d-stack.png b/src/silx/resources/gui/icons/view-2d-stack.png
index 6571a23..6571a23 100644
--- a/silx/resources/gui/icons/view-2d-stack.png
+++ b/src/silx/resources/gui/icons/view-2d-stack.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-2d-stack.svg b/src/silx/resources/gui/icons/view-2d-stack.svg
index 922d745..922d745 100644
--- a/silx/resources/gui/icons/view-2d-stack.svg
+++ b/src/silx/resources/gui/icons/view-2d-stack.svg
diff --git a/silx/resources/gui/icons/view-2d.png b/src/silx/resources/gui/icons/view-2d.png
index 3704867..3704867 100644
--- a/silx/resources/gui/icons/view-2d.png
+++ b/src/silx/resources/gui/icons/view-2d.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-2d.svg b/src/silx/resources/gui/icons/view-2d.svg
index 10f4cc0..10f4cc0 100644
--- a/silx/resources/gui/icons/view-2d.svg
+++ b/src/silx/resources/gui/icons/view-2d.svg
diff --git a/silx/resources/gui/icons/view-3d.png b/src/silx/resources/gui/icons/view-3d.png
index 4a38b19..4a38b19 100644
--- a/silx/resources/gui/icons/view-3d.png
+++ b/src/silx/resources/gui/icons/view-3d.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-3d.svg b/src/silx/resources/gui/icons/view-3d.svg
index 7e417ae..7e417ae 100644
--- a/silx/resources/gui/icons/view-3d.svg
+++ b/src/silx/resources/gui/icons/view-3d.svg
diff --git a/silx/resources/gui/icons/view-fullscreen.png b/src/silx/resources/gui/icons/view-fullscreen.png
index 7c891c7..7c891c7 100755
--- a/silx/resources/gui/icons/view-fullscreen.png
+++ b/src/silx/resources/gui/icons/view-fullscreen.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-fullscreen.svg b/src/silx/resources/gui/icons/view-fullscreen.svg
index cc389ce..cc389ce 100644
--- a/silx/resources/gui/icons/view-fullscreen.svg
+++ b/src/silx/resources/gui/icons/view-fullscreen.svg
diff --git a/silx/resources/gui/icons/view-hdf5.png b/src/silx/resources/gui/icons/view-hdf5.png
index efdf7c7..efdf7c7 100644
--- a/silx/resources/gui/icons/view-hdf5.png
+++ b/src/silx/resources/gui/icons/view-hdf5.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-hdf5.svg b/src/silx/resources/gui/icons/view-hdf5.svg
index 265db72..265db72 100644
--- a/silx/resources/gui/icons/view-hdf5.svg
+++ b/src/silx/resources/gui/icons/view-hdf5.svg
diff --git a/silx/resources/gui/icons/view-nexus.png b/src/silx/resources/gui/icons/view-nexus.png
index ab36b1a..ab36b1a 100644
--- a/silx/resources/gui/icons/view-nexus.png
+++ b/src/silx/resources/gui/icons/view-nexus.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-nexus.svg b/src/silx/resources/gui/icons/view-nexus.svg
index 4bfff81..4bfff81 100644
--- a/silx/resources/gui/icons/view-nexus.svg
+++ b/src/silx/resources/gui/icons/view-nexus.svg
diff --git a/silx/resources/gui/icons/view-nofullscreen.png b/src/silx/resources/gui/icons/view-nofullscreen.png
index d61625e..d61625e 100755
--- a/silx/resources/gui/icons/view-nofullscreen.png
+++ b/src/silx/resources/gui/icons/view-nofullscreen.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-nofullscreen.svg b/src/silx/resources/gui/icons/view-nofullscreen.svg
index 003ba53..003ba53 100644
--- a/silx/resources/gui/icons/view-nofullscreen.svg
+++ b/src/silx/resources/gui/icons/view-nofullscreen.svg
diff --git a/silx/resources/gui/icons/view-raw.png b/src/silx/resources/gui/icons/view-raw.png
index a0fb23d..a0fb23d 100644
--- a/silx/resources/gui/icons/view-raw.png
+++ b/src/silx/resources/gui/icons/view-raw.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-raw.svg b/src/silx/resources/gui/icons/view-raw.svg
index ff15da3..ff15da3 100644
--- a/silx/resources/gui/icons/view-raw.svg
+++ b/src/silx/resources/gui/icons/view-raw.svg
diff --git a/silx/resources/gui/icons/view-refresh.png b/src/silx/resources/gui/icons/view-refresh.png
index 1a8c064..1a8c064 100755
--- a/silx/resources/gui/icons/view-refresh.png
+++ b/src/silx/resources/gui/icons/view-refresh.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-refresh.svg b/src/silx/resources/gui/icons/view-refresh.svg
index 9593dd8..9593dd8 100644
--- a/silx/resources/gui/icons/view-refresh.svg
+++ b/src/silx/resources/gui/icons/view-refresh.svg
diff --git a/silx/resources/gui/icons/view-text.png b/src/silx/resources/gui/icons/view-text.png
index 5bfde30..5bfde30 100644
--- a/silx/resources/gui/icons/view-text.png
+++ b/src/silx/resources/gui/icons/view-text.png
Binary files differ
diff --git a/silx/resources/gui/icons/view-text.svg b/src/silx/resources/gui/icons/view-text.svg
index fbf0a7c..fbf0a7c 100644
--- a/silx/resources/gui/icons/view-text.svg
+++ b/src/silx/resources/gui/icons/view-text.svg
diff --git a/silx/resources/gui/icons/window-new.png b/src/silx/resources/gui/icons/window-new.png
index d26703b..d26703b 100755
--- a/silx/resources/gui/icons/window-new.png
+++ b/src/silx/resources/gui/icons/window-new.png
Binary files differ
diff --git a/silx/resources/gui/icons/window-new.svg b/src/silx/resources/gui/icons/window-new.svg
index 114f26c..114f26c 100644
--- a/silx/resources/gui/icons/window-new.svg
+++ b/src/silx/resources/gui/icons/window-new.svg
diff --git a/silx/resources/gui/icons/zoom-back.png b/src/silx/resources/gui/icons/zoom-back.png
index 14d7951..14d7951 100644
--- a/silx/resources/gui/icons/zoom-back.png
+++ b/src/silx/resources/gui/icons/zoom-back.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/zoom-back.svg b/src/silx/resources/gui/icons/zoom-back.svg
new file mode 100644
index 0000000..da40620
--- /dev/null
+++ b/src/silx/resources/gui/icons/zoom-back.svg
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="d" x1="20.887" x2="23.374" y1="21.759" y2="18.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><radialGradient id="c" cx="13.206" cy="8.4126" r="9.1344" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="b" x1="4.605" x2="18.267" y1="12.302" y2="12.302" gradientUnits="userSpaceOnUse"><stop stop-color="#FFF" offset="0"/><stop offset="1"/></linearGradient><polygon points="28.606 22.356 26.103 25.571 15.723 17.758 18.174 14.502" fill="url(#d)" stroke="#808285" stroke-miterlimit="10" stroke-width=".2"/><circle cx="11.483" cy="12.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><path d="m17.967 12.302c0 3.594-3.039 6.507-6.518 6.507s-6.544-2.913-6.544-6.507 3.065-6.507 6.544-6.507 6.518 2.913 6.518 6.507z" fill="url(#c)" stroke="url(#b)" stroke-miterlimit="10" stroke-width=".6"/><path d="m10.992 6.764s4.839-0.584 5.992 4.366" fill="none" stroke="#FFF" stroke-miterlimit="10" stroke-width="1.2"/><g transform="matrix(-1 0 0 1 23.132 -18.833)" fill="#F00" stroke="#F00" stroke-miterlimit="10"><path d="m4.7543 24.006c-10.964-0.107-10.073 10.653-10.266 10.974 0 0 0.193-7.139 10.267-6.713v-4.261z" stroke-width=".1"/><path d="m4.7543 22.329c0-0.17 0.122-0.243 0.271-0.16l6.169 3.403c0.149 0.083 0.157 0.23 0.018 0.328l-6.204 4.348c-0.14 0.098-0.254 0.038-0.254-0.132v-7.787z"/></g></svg>
diff --git a/silx/resources/gui/icons/zoom-in.png b/src/silx/resources/gui/icons/zoom-in.png
index a133948..a133948 100755
--- a/silx/resources/gui/icons/zoom-in.png
+++ b/src/silx/resources/gui/icons/zoom-in.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/zoom-in.svg b/src/silx/resources/gui/icons/zoom-in.svg
new file mode 100644
index 0000000..3c79364
--- /dev/null
+++ b/src/silx/resources/gui/icons/zoom-in.svg
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="e" x1="19.887" x2="22.374" y1="23.759" y2="20.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><polygon points="27.606 24.356 25.103 27.571 14.723 19.758 17.174 16.502" fill="url(#e)" stroke="#808285" stroke-miterlimit="10" stroke-width=".1"/><circle cx="10.483" cy="14.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><radialGradient id="f" cx="12.253" cy="10.413" r="9.1342" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="g" x1="3.4521" x2="17.514" y1="14.302" y2="14.302" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m17.014 14.301c0 3.594-3.038 6.507-6.517 6.507s-6.544-2.913-6.544-6.507 3.065-6.507 6.544-6.507 6.517 2.914 6.517 6.507z" fill="url(#f)" stroke="url(#g)" stroke-miterlimit="10"/><path d="m9.177 9.151s4.405-1.127 6.307 3.42" fill="none" stroke="#fff" stroke-miterlimit="10"/><g fill="#00a651" stroke="#00a651" stroke-miterlimit="10"><rect x="24.483" y="7.225" width="1.239" height="8.379"/><rect x="20.913" y="10.796" width="8.38" height="1.237"/></g></svg>
diff --git a/silx/resources/gui/icons/zoom-original.png b/src/silx/resources/gui/icons/zoom-original.png
index 5d78149..5d78149 100755
--- a/silx/resources/gui/icons/zoom-original.png
+++ b/src/silx/resources/gui/icons/zoom-original.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/zoom-original.svg b/src/silx/resources/gui/icons/zoom-original.svg
new file mode 100644
index 0000000..250fe97
--- /dev/null
+++ b/src/silx/resources/gui/icons/zoom-original.svg
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="e" x1="20.888" x2="23.375" y1="23.759" y2="20.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><polygon points="28.606 24.356 26.103 27.571 15.723 19.758 18.174 16.502" fill="url(#e)" stroke="#808285" stroke-miterlimit="10" stroke-width=".1"/><circle cx="11.483" cy="14.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><radialGradient id="f" cx="13.253" cy="10.413" r="9.1342" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="g" x1="4.4521" x2="18.514" y1="14.302" y2="14.302" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m18.014 14.301c0 3.594-3.038 6.507-6.517 6.507s-6.544-2.913-6.544-6.507 3.065-6.507 6.544-6.507 6.517 2.914 6.517 6.507z" fill="url(#f)" stroke="url(#g)" stroke-miterlimit="10"/><path d="m10.177 9.151s4.405-1.127 6.307 3.42" fill="none" stroke="#fff" stroke-miterlimit="10"/><g fill="#ed1c24" stroke="#ed1c24" stroke-miterlimit="10" stroke-width="2.5"><line x1="7.257" x2="25.712" y1="24.906" y2="6.518"/><line x1="7.392" x2="25.575" y1="6.371" y2="25.053"/></g></svg>
diff --git a/silx/resources/gui/icons/zoom-out.png b/src/silx/resources/gui/icons/zoom-out.png
index 3110fa8..3110fa8 100755
--- a/silx/resources/gui/icons/zoom-out.png
+++ b/src/silx/resources/gui/icons/zoom-out.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/zoom-out.svg b/src/silx/resources/gui/icons/zoom-out.svg
new file mode 100644
index 0000000..a646b5d
--- /dev/null
+++ b/src/silx/resources/gui/icons/zoom-out.svg
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="e" x1="19.887" x2="22.374" y1="22.759" y2="19.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><polygon points="27.606 23.356 25.103 26.571 14.723 18.758 17.174 15.502" fill="url(#e)" stroke="#808285" stroke-miterlimit="10" stroke-width=".1"/><circle cx="10.483" cy="13.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><radialGradient id="f" cx="12.253" cy="9.4126" r="9.1342" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="g" x1="3.4521" x2="17.514" y1="13.302" y2="13.302" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m17.014 13.301c0 3.594-3.038 6.507-6.517 6.507s-6.544-2.914-6.544-6.507 3.065-6.507 6.544-6.507 6.517 2.914 6.517 6.507z" fill="url(#f)" stroke="url(#g)" stroke-miterlimit="10"/><path d="m9.177 8.151s4.405-1.127 6.307 3.42" fill="none" stroke="#fff" stroke-miterlimit="10"/><rect x="20.304" y="7.802" width="7.377" height=".988" fill="#ed1c24" stroke="#ed1c24" stroke-miterlimit="10" stroke-width="2"/></svg>
diff --git a/silx/resources/gui/icons/zoom.png b/src/silx/resources/gui/icons/zoom.png
index 7847c0a..7847c0a 100755
--- a/silx/resources/gui/icons/zoom.png
+++ b/src/silx/resources/gui/icons/zoom.png
Binary files differ
diff --git a/src/silx/resources/gui/icons/zoom.svg b/src/silx/resources/gui/icons/zoom.svg
new file mode 100644
index 0000000..deeceb7
--- /dev/null
+++ b/src/silx/resources/gui/icons/zoom.svg
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg version="1.1" viewBox="0 0 32 32" xml:space="preserve" xmlns="http://www.w3.org/2000/svg"><linearGradient id="e" x1="20.887" x2="23.374" y1="21.759" y2="18.443" gradientUnits="userSpaceOnUse"><stop stop-color="#280e19" offset="0"/><stop stop-color="#382229" offset=".1053"/><stop stop-color="#f9edf5" offset=".9839"/></linearGradient><polygon points="28.606 22.356 26.103 25.571 15.723 17.758 18.174 14.502" fill="url(#e)" stroke="#808285" stroke-miterlimit="10" stroke-width=".2"/><circle cx="11.483" cy="12.302" r="7" fill="none" stroke="#776569" stroke-miterlimit="10"/><radialGradient id="f" cx="13.206" cy="8.4126" r="9.1344" gradientUnits="userSpaceOnUse"><stop stop-color="#00aeef" stop-opacity=".6" offset="0"/><stop stop-color="#9bd1ea" stop-opacity=".5317" offset=".6832"/><stop stop-color="#e6e7e8" stop-opacity=".5" offset="1"/></radialGradient><linearGradient id="g" x1="4.605" x2="18.267" y1="12.302" y2="12.302" gradientUnits="userSpaceOnUse"><stop stop-color="#fff" offset="0"/><stop offset="1"/></linearGradient><path d="m17.967 12.302c0 3.594-3.039 6.507-6.518 6.507s-6.544-2.913-6.544-6.507 3.065-6.507 6.544-6.507 6.518 2.913 6.518 6.507z" fill="url(#f)" stroke="url(#g)" stroke-miterlimit="10" stroke-width=".6"/><path d="m10.992 6.764s4.839-0.584 5.992 4.366" fill="none" stroke="#fff" stroke-miterlimit="10" stroke-width="1.2"/></svg>
diff --git a/silx/resources/gui/logo/silx.png b/src/silx/resources/gui/logo/silx.png
index d6f9733..d6f9733 100644
--- a/silx/resources/gui/logo/silx.png
+++ b/src/silx/resources/gui/logo/silx.png
Binary files differ
diff --git a/silx/resources/gui/logo/silx.svg b/src/silx/resources/gui/logo/silx.svg
index f7eb6cb..f7eb6cb 100644
--- a/silx/resources/gui/logo/silx.svg
+++ b/src/silx/resources/gui/logo/silx.svg
diff --git a/silx/resources/opencl/addition.cl b/src/silx/resources/opencl/addition.cl
index 35d7996..35d7996 100644
--- a/silx/resources/opencl/addition.cl
+++ b/src/silx/resources/opencl/addition.cl
diff --git a/silx/resources/opencl/array_utils.cl b/src/silx/resources/opencl/array_utils.cl
index 6f78921..6f78921 100644
--- a/silx/resources/opencl/array_utils.cl
+++ b/src/silx/resources/opencl/array_utils.cl
diff --git a/silx/resources/opencl/backproj.cl b/src/silx/resources/opencl/backproj.cl
index da15131..da15131 100644
--- a/silx/resources/opencl/backproj.cl
+++ b/src/silx/resources/opencl/backproj.cl
diff --git a/silx/resources/opencl/backproj_helper.cl b/src/silx/resources/opencl/backproj_helper.cl
index b1590f8..b1590f8 100644
--- a/silx/resources/opencl/backproj_helper.cl
+++ b/src/silx/resources/opencl/backproj_helper.cl
diff --git a/silx/resources/opencl/bitonic.cl b/src/silx/resources/opencl/bitonic.cl
index 4096ce8..4096ce8 100644
--- a/silx/resources/opencl/bitonic.cl
+++ b/src/silx/resources/opencl/bitonic.cl
diff --git a/silx/resources/opencl/codec/byte_offset.cl b/src/silx/resources/opencl/codec/byte_offset.cl
index 56a24c4..56a24c4 100644
--- a/silx/resources/opencl/codec/byte_offset.cl
+++ b/src/silx/resources/opencl/codec/byte_offset.cl
diff --git a/silx/resources/opencl/convolution.cl b/src/silx/resources/opencl/convolution.cl
index 629b8fc..629b8fc 100644
--- a/silx/resources/opencl/convolution.cl
+++ b/src/silx/resources/opencl/convolution.cl
diff --git a/silx/resources/opencl/convolution_textures.cl b/src/silx/resources/opencl/convolution_textures.cl
index 517a67c..517a67c 100644
--- a/silx/resources/opencl/convolution_textures.cl
+++ b/src/silx/resources/opencl/convolution_textures.cl
diff --git a/silx/resources/opencl/doubleword.cl b/src/silx/resources/opencl/doubleword.cl
index a0ebfda..a0ebfda 100644
--- a/silx/resources/opencl/doubleword.cl
+++ b/src/silx/resources/opencl/doubleword.cl
diff --git a/silx/resources/opencl/image/cast.cl b/src/silx/resources/opencl/image/cast.cl
index 9e23a82..9e23a82 100644
--- a/silx/resources/opencl/image/cast.cl
+++ b/src/silx/resources/opencl/image/cast.cl
diff --git a/silx/resources/opencl/image/histogram.cl b/src/silx/resources/opencl/image/histogram.cl
index 7fb1485..7fb1485 100644
--- a/silx/resources/opencl/image/histogram.cl
+++ b/src/silx/resources/opencl/image/histogram.cl
diff --git a/silx/resources/opencl/image/map.cl b/src/silx/resources/opencl/image/map.cl
index 804a5a1..804a5a1 100644
--- a/silx/resources/opencl/image/map.cl
+++ b/src/silx/resources/opencl/image/map.cl
diff --git a/silx/resources/opencl/image/max_min.cl b/src/silx/resources/opencl/image/max_min.cl
index 246cd48..246cd48 100644
--- a/silx/resources/opencl/image/max_min.cl
+++ b/src/silx/resources/opencl/image/max_min.cl
diff --git a/silx/resources/opencl/kahan.cl b/src/silx/resources/opencl/kahan.cl
index c23d84d..c23d84d 100644
--- a/silx/resources/opencl/kahan.cl
+++ b/src/silx/resources/opencl/kahan.cl
diff --git a/silx/resources/opencl/linalg.cl b/src/silx/resources/opencl/linalg.cl
index 8710528..8710528 100644
--- a/silx/resources/opencl/linalg.cl
+++ b/src/silx/resources/opencl/linalg.cl
diff --git a/silx/resources/opencl/medfilt.cl b/src/silx/resources/opencl/medfilt.cl
index 0680230..0680230 100644
--- a/silx/resources/opencl/medfilt.cl
+++ b/src/silx/resources/opencl/medfilt.cl
diff --git a/silx/resources/opencl/preprocess.cl b/src/silx/resources/opencl/preprocess.cl
index de35c48..de35c48 100644
--- a/silx/resources/opencl/preprocess.cl
+++ b/src/silx/resources/opencl/preprocess.cl
diff --git a/silx/resources/opencl/proj.cl b/src/silx/resources/opencl/proj.cl
index 2a6d870..2a6d870 100644
--- a/silx/resources/opencl/proj.cl
+++ b/src/silx/resources/opencl/proj.cl
diff --git a/silx/resources/opencl/sparse.cl b/src/silx/resources/opencl/sparse.cl
index c99a0e9..c99a0e9 100644
--- a/silx/resources/opencl/sparse.cl
+++ b/src/silx/resources/opencl/sparse.cl
diff --git a/silx/resources/opencl/statistics.cl b/src/silx/resources/opencl/statistics.cl
index 47d925b..47d925b 100644
--- a/silx/resources/opencl/statistics.cl
+++ b/src/silx/resources/opencl/statistics.cl
diff --git a/src/silx/setup.py b/src/silx/setup.py
new file mode 100644
index 0000000..5e2bd0d
--- /dev/null
+++ b/src/silx/setup.py
@@ -0,0 +1,54 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/07/2018"
+
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('silx', parent_package, top_path)
+ config.add_subpackage('gui')
+ config.add_subpackage('io')
+ config.add_subpackage('math')
+ config.add_subpackage('image')
+ config.add_subpackage('opencl')
+ config.add_subpackage('resources')
+ config.add_subpackage('sx')
+ config.add_subpackage('test')
+ config.add_subpackage('third_party')
+ config.add_subpackage('utils')
+ config.add_subpackage('app')
+ config.add_subpackage("examples", "../../examples")
+
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/silx/sx/__init__.py b/src/silx/sx/__init__.py
index 97a3460..97a3460 100644
--- a/silx/sx/__init__.py
+++ b/src/silx/sx/__init__.py
diff --git a/src/silx/sx/_plot.py b/src/silx/sx/_plot.py
new file mode 100644
index 0000000..b44c042
--- /dev/null
+++ b/src/silx/sx/_plot.py
@@ -0,0 +1,625 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This module adds convenient functions to use plot widgets from the console.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "06/11/2018"
+
+
+import collections
+try:
+ from collections import abc
+except ImportError: # Python2 support
+ import collections as abc
+import logging
+import weakref
+
+import numpy
+
+from ..utils.weakref import WeakList
+from ..gui import qt
+from ..gui.plot import Plot1D, Plot2D, ScatterView
+from ..gui.plot import items
+from ..gui import colors
+from ..gui.plot.tools import roi
+from ..gui.plot.items import roi as roi_items
+from ..gui.plot.tools.toolbars import InteractiveModeToolBar
+
+_logger = logging.getLogger(__name__)
+
+_plots = WeakList()
+"""List of widgets created through plot and imshow"""
+
+
+def plot(*args, **kwargs):
+ """
+ Plot curves in a :class:`~silx.gui.plot.PlotWindow.Plot1D` widget.
+
+ How to use:
+
+ >>> from silx import sx
+ >>> import numpy
+
+ Plot a single curve given some values:
+
+ >>> values = numpy.random.random(100)
+ >>> plot_1curve = sx.plot(values, title='Random data')
+
+ Plot a single curve given the x and y values:
+
+ >>> angles = numpy.linspace(0, numpy.pi, 100)
+ >>> sin_a = numpy.sin(angles)
+ >>> plot_sinus = sx.plot(angles, sin_a, xlabel='angle (radian)', ylabel='sin(a)')
+
+ Plot many curves by giving a 2D array, provided xn, yn arrays:
+
+ >>> plot_curves = sx.plot(x0, y0, x1, y1, x2, y2, ...)
+
+ Plot curve with style giving a style string:
+
+ >>> plot_styled = sx.plot(x0, y0, 'ro-', x1, y1, 'b.')
+
+ Supported symbols:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ Supported types of line:
+
+ - ' ' no line
+ - '-' solid line
+ - '--' dashed line
+ - '-.' dash-dot line
+ - ':' dotted line
+
+ If provided, the names arguments color, linestyle, linewidth and marker
+ override any style provided to a curve.
+
+ This function supports a subset of `matplotlib.pyplot.plot
+ <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.plot>`_
+ arguments.
+
+ :param str color: Color to use for all curves (default: None)
+ :param str linestyle: Type of line to use for all curves (default: None)
+ :param float linewidth: With of all the curves (default: 1)
+ :param str marker: Symbol to use for all the curves (default: None)
+ :param str title: The title of the Plot widget (default: None)
+ :param str xlabel: The label of the X axis (default: None)
+ :param str ylabel: The label of the Y axis (default: None)
+ :return: The widget plotting the curve(s)
+ :rtype: silx.gui.plot.Plot1D
+ """
+ plt = Plot1D()
+ if 'title' in kwargs:
+ plt.setGraphTitle(kwargs['title'])
+ if 'xlabel' in kwargs:
+ plt.getXAxis().setLabel(kwargs['xlabel'])
+ if 'ylabel' in kwargs:
+ plt.getYAxis().setLabel(kwargs['ylabel'])
+
+ color = kwargs.get('color')
+ linestyle = kwargs.get('linestyle')
+ linewidth = kwargs.get('linewidth')
+ marker = kwargs.get('marker')
+
+ # Parse args and store curves as (x, y, style string)
+ args = list(args)
+ curves = []
+ while args:
+ first_arg = args.pop(0) # Process an arg
+
+ if len(args) == 0:
+ # Last curve defined as (y,)
+ curves.append((numpy.arange(len(first_arg)), first_arg, None))
+ else:
+ second_arg = args.pop(0)
+ if isinstance(second_arg, str):
+ # curve defined as (y, style)
+ y = first_arg
+ style = second_arg
+ curves.append((numpy.arange(len(y)), y, style))
+ else: # second_arg must be an array-like
+ x = first_arg
+ y = second_arg
+ if len(args) >= 1 and isinstance(args[0], str):
+ # Curve defined as (x, y, style)
+ style = args.pop(0)
+ curves.append((x, y, style))
+ else:
+ # Curve defined as (x, y)
+ curves.append((x, y, None))
+
+ for index, curve in enumerate(curves):
+ x, y, style = curve
+
+ # Default style
+ curve_symbol, curve_linestyle, curve_color = None, None, None
+
+ # Parse style
+ if style:
+ # Handle color first
+ possible_colors = [c for c in colors.COLORDICT if style.startswith(c)]
+ if possible_colors: # Take the longest string matching a color name
+ curve_color = possible_colors[0]
+ for c in possible_colors[1:]:
+ if len(c) > len(curve_color):
+ curve_color = c
+ style = style[len(curve_color):]
+
+ if style:
+ # Run twice to handle inversion symbol/linestyle
+ for _i in range(2):
+ # Handle linestyle
+ for line in (' ', '--', '-', '-.', ':'):
+ if style.endswith(line):
+ curve_linestyle = line
+ style = style[:-len(line)]
+ break
+
+ # Handle symbol
+ for curve_marker in ('o', '.', ',', '+', 'x', 'd', 's'):
+ if style.endswith(curve_marker):
+ curve_symbol = style[-1]
+ style = style[:-1]
+ break
+
+ # As in matplotlib, marker, linestyle and color override other style
+ plt.addCurve(x, y,
+ legend=('curve_%d' % index),
+ symbol=marker or curve_symbol,
+ linestyle=linestyle or curve_linestyle,
+ linewidth=linewidth,
+ color=color or curve_color)
+
+ plt.show()
+ _plots.insert(0, plt)
+ return plt
+
+
+def imshow(data=None, cmap=None, norm=colors.Colormap.LINEAR,
+ vmin=None, vmax=None,
+ aspect=False,
+ origin='upper', scale=(1., 1.),
+ title='', xlabel='X', ylabel='Y'):
+ """
+ Plot an image in a :class:`~silx.gui.plot.PlotWindow.Plot2D` widget.
+
+ How to use:
+
+ >>> from silx import sx
+ >>> import numpy
+
+ >>> data = numpy.random.random(1024 * 1024).reshape(1024, 1024)
+ >>> plt = sx.imshow(data, title='Random data')
+
+ By default, the image origin is displayed in the upper left
+ corner of the plot. To invert the Y axis, and place the image origin
+ in the lower left corner of the plot, use the *origin* parameter:
+
+ >>> plt = sx.imshow(data, origin='lower')
+
+ This function supports a subset of `matplotlib.pyplot.imshow
+ <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow>`_
+ arguments.
+
+ :param data: data to plot as an image
+ :type data: numpy.ndarray-like with 2 dimensions
+ :param str cmap: The name of the colormap to use for the plot. It also
+ supports a numpy array containing a RGB LUT, or a `colors.Colormap`
+ instance.
+ :param str norm: The normalization of the colormap:
+ 'linear' (default) or 'log'
+ :param float vmin: The value to use for the min of the colormap
+ :param float vmax: The value to use for the max of the colormap
+ :param bool aspect: True to keep aspect ratio (Default: False)
+ :param origin: Either image origin as the Y axis orientation:
+ 'upper' (default) or 'lower'
+ or the coordinates (ox, oy) of the image origin in the plot.
+ :type origin: str or 2-tuple of floats
+ :param scale: (sx, sy) The scale of the image in the plot
+ (i.e., the size of the image's pixel in plot coordinates)
+ :type scale: 2-tuple of floats
+ :param str title: The title of the Plot widget
+ :param str xlabel: The label of the X axis
+ :param str ylabel: The label of the Y axis
+ :return: The widget plotting the image
+ :rtype: silx.gui.plot.Plot2D
+ """
+ plt = Plot2D()
+ plt.setGraphTitle(title)
+ plt.getXAxis().setLabel(xlabel)
+ plt.getYAxis().setLabel(ylabel)
+
+ # Update default colormap with input parameters
+ colormap = plt.getDefaultColormap()
+ if isinstance(cmap, colors.Colormap):
+ colormap = cmap
+ plt.setDefaultColormap(colormap)
+ elif isinstance(cmap, numpy.ndarray):
+ colormap.setColors(cmap)
+ elif cmap is not None:
+ colormap.setName(cmap)
+ assert norm in colors.Colormap.NORMALIZATIONS
+ colormap.setNormalization(norm)
+ colormap.setVMin(vmin)
+ colormap.setVMax(vmax)
+
+ # Handle aspect
+ if aspect in (None, False, 'auto', 'normal'):
+ plt.setKeepDataAspectRatio(False)
+ elif aspect in (True, 'equal') or aspect == 1:
+ plt.setKeepDataAspectRatio(True)
+ else:
+ _logger.warning(
+ 'imshow: Unhandled aspect argument: %s', str(aspect))
+
+ # Handle matplotlib-like origin
+ if origin in ('upper', 'lower'):
+ plt.setYAxisInverted(origin == 'upper')
+ origin = 0., 0. # Set origin to the definition of silx
+
+ if data is not None:
+ data = numpy.array(data, copy=True)
+
+ assert data.ndim in (2, 3) # data or RGB(A)
+ if data.ndim == 3:
+ assert data.shape[-1] in (3, 4) # RGB(A) image
+
+ plt.addImage(data, origin=origin, scale=scale)
+
+ plt.show()
+ _plots.insert(0, plt)
+ return plt
+
+
+def scatter(x=None, y=None, value=None, size=None,
+ marker=None,
+ cmap=None, norm=colors.Colormap.LINEAR,
+ vmin=None, vmax=None):
+ """
+ Plot scattered data in a :class:`~silx.gui.plot.ScatterView` widget.
+
+ How to use:
+
+ >>> from silx import sx
+ >>> import numpy
+
+ >>> x = numpy.random.random(100)
+ >>> y = numpy.random.random(100)
+ >>> values = numpy.random.random(100)
+ >>> plt = sx.scatter(x, y, values, cmap='viridis')
+
+ Supported symbols:
+
+ - 'o' circle
+ - '.' point
+ - ',' pixel
+ - '+' cross
+ - 'x' x-cross
+ - 'd' diamond
+ - 's' square
+
+ This function supports a subset of `matplotlib.pyplot.scatter
+ <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`_
+ arguments.
+
+ :param numpy.ndarray x: 1D array-like of x coordinates
+ :param numpy.ndarray y: 1D array-like of y coordinates
+ :param numpy.ndarray value: 1D array-like of data values
+ :param float size: Size^2 of the markers
+ :param str marker: Symbol used to represent the points
+ :param str cmap: The name of the colormap to use for the plot
+ :param str norm: The normalization of the colormap:
+ 'linear' (default) or 'log'
+ :param float vmin: The value to use for the min of the colormap
+ :param float vmax: The value to use for the max of the colormap
+ :return: The widget plotting the scatter plot
+ :rtype: silx.gui.plot.ScatterView.ScatterView
+ """
+ plt = ScatterView()
+
+ # Update default colormap with input parameters
+ colormap = plt.getPlotWidget().getDefaultColormap()
+ if cmap is not None:
+ colormap.setName(cmap)
+ assert norm in colors.Colormap.NORMALIZATIONS
+ colormap.setNormalization(norm)
+ colormap.setVMin(vmin)
+ colormap.setVMax(vmax)
+ plt.getPlotWidget().setDefaultColormap(colormap)
+
+ if x is not None and y is not None: # Add a scatter plot
+ x = numpy.array(x, copy=True).reshape(-1)
+ y = numpy.array(y, copy=True).reshape(-1)
+ assert len(x) == len(y)
+
+ if value is None:
+ value = numpy.ones(len(x), dtype=numpy.float32)
+
+ elif isinstance(value, abc.Iterable):
+ value = numpy.array(value, copy=True).reshape(-1)
+ assert len(x) == len(value)
+
+ else:
+ value = numpy.ones(len(x), dtype=numpy.float64) * value
+
+ plt.setData(x, y, value)
+ item = plt.getScatterItem()
+ if marker is not None:
+ item.setSymbol(marker)
+ if size is not None:
+ item.setSymbolSize(numpy.sqrt(size))
+
+ plt.resetZoom()
+
+ plt.show()
+ _plots.insert(0, plt.getPlotWidget())
+ return plt
+
+
+class _GInputResult(tuple):
+ """Object storing :func:`ginput` result
+
+ :param position: Selected point coordinates in the plot (x, y)
+ :param Item item: Plot item under the selected position
+ :param indices: Selected indices in the data of the item.
+ For a curve it is a list of indices, for an image it is (row, column)
+ :param data: Value of data at selected indices.
+ For a curve it is an array of values, for an image it is a single value
+ """
+
+ def __new__(cls, position, item, indices, data):
+ return super(_GInputResult, cls).__new__(cls, position)
+
+ def __init__(self, position, item, indices, data):
+ self._itemRef = weakref.ref(item) if item is not None else None
+ self._indices = numpy.array(indices, copy=True)
+ if isinstance(data, abc.Iterable):
+ self._data = numpy.array(data, copy=True)
+ else:
+ self._data = data
+
+ def getItem(self):
+ """Returns the item at the selected position if any.
+
+ :return: plot item under the selected postion.
+ It is None if there was no item at that position or if
+ it is no more in the plot.
+ :rtype: silx.gui.plot.items.Item"""
+ return None if self._itemRef is None else self._itemRef()
+
+ def getIndices(self):
+ """Returns indices in data array at the select position
+
+ :return: 1D array of indices for curve and (row, column) for images
+ :rtype: numpy.ndarray
+ """
+ return numpy.array(self._indices, copy=True)
+
+ def getData(self):
+ """Returns data value at the selected position.
+
+ For curves, an array of (x, y) values close to the point is returned.
+ For images, either a single value or a RGB(A) array is returned.
+
+ :return: 2D array of (x, y) data values for curves (Nx2),
+ a single value for data images and RGB(A) array for images.
+ """
+ if isinstance(self._data, numpy.ndarray):
+ return numpy.array(self._data, copy=True)
+ else:
+ return self._data
+
+
+class _GInputHandler(roi.InteractiveRegionOfInterestManager):
+ """Implements :func:`ginput`
+
+ :param PlotWidget plot:
+ :param int n: Max number of points to request
+ :param float timeout: Timeout in seconds
+ """
+
+ def __init__(self, plot, n, timeout):
+ super(_GInputHandler, self).__init__(plot)
+
+ self._timeout = timeout
+ self.__selections = collections.OrderedDict()
+
+ window = plot.window() # Retrieve window containing PlotWidget
+ statusBar = window.statusBar()
+ self.sigMessageChanged.connect(statusBar.showMessage)
+ self.setMaxRois(n)
+ self.setValidationMode(self.ValidationMode.AUTO_ENTER)
+ self.sigRoiAdded.connect(self.__added)
+ self.sigRoiAboutToBeRemoved.connect(self.__removed)
+
+ def exec(self):
+ """Request user inputs
+
+ :return: List of selection points information
+ """
+ plot = self.parent()
+ if plot is None:
+ return
+
+ window = plot.window() # Retrieve window containing PlotWidget
+
+ # Add ROI point interactive mode action
+ for toolbar in window.findChildren(qt.QToolBar):
+ if isinstance(toolbar, InteractiveModeToolBar):
+ break
+ else: # Add a toolbar
+ toolbar = qt.QToolBar()
+ window.addToolBar(toolbar)
+ toolbar.addAction(self.getInteractionModeAction(roi_items.PointROI))
+
+ super(_GInputHandler, self).exec(roiClass=roi_items.PointROI, timeout=self._timeout)
+
+ if isinstance(toolbar, InteractiveModeToolBar):
+ toolbar.removeAction(self.getInteractionModeAction(roi_items.PointROI))
+ else:
+ toolbar.setParent(None)
+
+ return tuple(self.__selections.values())
+
+ def exec_(self): # Qt5-like compatibility
+ return self.exec()
+
+ def __updateSelection(self, roi):
+ """Perform picking and update selection list
+
+ :param RegionOfInterest roi:
+ """
+ plot = self.parent()
+ if plot is None:
+ return # No plot, abort
+
+ if not isinstance(roi, roi_items.PointROI):
+ # Only handle points
+ raise RuntimeError("Unexpected item")
+
+ x, y = roi.getPosition()
+ xPixel, yPixel = plot.dataToPixel(x, y, axis='left', check=False)
+
+ # Pick item at selected position
+ pickingResult = plot._pickTopMost(
+ xPixel, yPixel,
+ lambda item: isinstance(item, (items.ImageBase, items.Curve)))
+
+ if pickingResult is None:
+ result = _GInputResult((x, y),
+ item=None,
+ indices=numpy.array((), dtype=int),
+ data=None)
+ else:
+ item = pickingResult.getItem()
+ indices = pickingResult.getIndices(copy=True)
+
+ if isinstance(item, items.Curve):
+ xData = item.getXData(copy=False)[indices]
+ yData = item.getYData(copy=False)[indices]
+ result = _GInputResult((x, y),
+ item=item,
+ indices=indices,
+ data=numpy.array((xData, yData)).T)
+
+ elif isinstance(item, items.ImageBase):
+ row, column = indices[0][0], indices[1][0]
+ data = item.getData(copy=False)[row, column]
+ result = _GInputResult((x, y),
+ item=item,
+ indices=(row, column),
+ data=data)
+
+ self.__selections[roi] = result
+
+ def __added(self, roi):
+ """Handle new ROI added
+
+ :param RegionOfInterest roi:
+ """
+ if isinstance(roi, roi_items.PointROI):
+ # Only handle points
+ roi.setName('%d' % len(self.__selections))
+ self.__updateSelection(roi)
+ roi.sigRegionChanged.connect(self.__regionChanged)
+
+ def __removed(self, roi):
+ """Handle ROI removed"""
+ if self.__selections.pop(roi, None) is not None:
+ roi.sigRegionChanged.disconnect(self.__regionChanged)
+
+ def __regionChanged(self):
+ """Handle update of a ROI"""
+ roi = self.sender()
+ self.__updateSelection(roi)
+
+
+def ginput(n=1, timeout=30, plot=None):
+ """Get input points on a plot.
+
+ If no plot is provided, it uses a plot widget created with
+ either :func:`silx.sx.plot` or :func:`silx.sx.imshow`.
+
+ How to use:
+
+ >>> from silx import sx
+
+ >>> sx.imshow(image) # Plot the image
+ >>> sx.ginput(1) # Request selection on the image plot
+ ((0.598, 1.234))
+
+ How to get more information about the selected positions:
+
+ >>> positions = sx.ginput(1)
+
+ >>> positions[0].getData() # Returns value(s) at selected position
+
+ >>> positions[0].getIndices() # Returns data indices at selected position
+
+ >>> positions[0].getItem() # Returns plot item at selected position
+
+ :param int n: Number of points the user need to select
+ :param float timeout: Timeout in seconds before ginput returns
+ event if selection is not completed
+ :param silx.gui.plot.PlotWidget.PlotWidget plot: An optional PlotWidget
+ from which to get input
+ :return: List of clicked points coordinates (x, y) in plot
+ :raise ValueError: If provided plot is not a PlotWidget
+ """
+ if plot is None:
+ # Select most recent visible plot widget
+ for widget in _plots:
+ if widget.isVisible():
+ plot = widget
+ break
+ else: # If no plot widget is visible, take the most recent one
+ try:
+ plot = _plots[0]
+ except IndexError:
+ pass
+ else:
+ plot.show()
+
+ if plot is None:
+ _logger.warning('No plot available to perform ginput, create one')
+ plot = Plot1D()
+ plot.show()
+ _plots.insert(0, plot)
+
+ plot.raise_() # So window becomes the top level one
+
+ _logger.info('Performing ginput with plot widget %s', str(plot))
+ handler = _GInputHandler(plot, n, timeout)
+ points = handler.exec()
+
+ return points
diff --git a/silx/sx/_plot3d.py b/src/silx/sx/_plot3d.py
index 444d9e0..444d9e0 100644
--- a/silx/sx/_plot3d.py
+++ b/src/silx/sx/_plot3d.py
diff --git a/src/silx/test/__init__.py b/src/silx/test/__init__.py
new file mode 100644
index 0000000..d9d3e42
--- /dev/null
+++ b/src/silx/test/__init__.py
@@ -0,0 +1,53 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""This package provides test of the root modules
+"""
+
+import logging
+
+
+try:
+ import pytest
+except ImportError:
+ logging.getLogger(__name__).error(
+ "pytest is required to run the tests, please install it.")
+ raise
+
+def run_tests(module: str='silx', verbosity: int=0, args=()):
+ """Run tests
+
+ :param module: Name of the silx module to test (default: 'silx')
+ :param verbosity: Requested level of verbosity
+ :param args: List of extra arguments to pass to `pytest`
+ """
+ return pytest.main([
+ '--pyargs',
+ module,
+ '--verbosity',
+ str(verbosity),
+ '-o python_files=["test/test*.py","test/Test*.py"]',
+ '-o python_classes=["Test"]',
+ '-o python_functions=["Test"]',
+ ] + list(args))
diff --git a/src/silx/test/test_resources.py b/src/silx/test/test_resources.py
new file mode 100644
index 0000000..4030271
--- /dev/null
+++ b/src/silx/test/test_resources.py
@@ -0,0 +1,187 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for resource files management."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/03/2019"
+
+
+import os
+import unittest
+import shutil
+import tempfile
+
+import silx.resources
+
+
+class TestResources(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestResources, cls).setUpClass()
+
+ cls.tmpDirectory = tempfile.mkdtemp(prefix="resource_")
+ os.mkdir(os.path.join(cls.tmpDirectory, "gui"))
+ destination_dir = os.path.join(cls.tmpDirectory, "gui", "icons")
+ os.mkdir(destination_dir)
+ source = silx.resources.resource_filename("gui/icons/zoom-in.png")
+ destination = os.path.join(destination_dir, "foo.png")
+ shutil.copy(source, destination)
+ source = silx.resources.resource_filename("gui/icons/zoom-out.svg")
+ destination = os.path.join(destination_dir, "close.png")
+ shutil.copy(source, destination)
+
+ @classmethod
+ def tearDownClass(cls):
+ super(TestResources, cls).tearDownClass()
+ shutil.rmtree(cls.tmpDirectory)
+
+ def setUp(self):
+ # Store the original configuration
+ self._oldResources = dict(silx.resources._RESOURCE_DIRECTORIES)
+ unittest.TestCase.setUp(self)
+
+ def tearDown(self):
+ unittest.TestCase.tearDown(self)
+ # Restiture the original configuration
+ silx.resources._RESOURCE_DIRECTORIES = self._oldResources
+
+ def test_resource_dir(self):
+ """Get a resource directory"""
+ icons_dirname = silx.resources.resource_filename('gui/icons/')
+ self.assertTrue(os.path.isdir(icons_dirname))
+
+ def test_resource_file(self):
+ """Get a resource file name"""
+ filename = silx.resources.resource_filename('gui/icons/colormap.png')
+ self.assertTrue(os.path.isfile(filename))
+
+ def test_resource_nonexistent(self):
+ """Get a non existent resource"""
+ filename = silx.resources.resource_filename('non_existent_file.txt')
+ self.assertFalse(os.path.exists(filename))
+
+ def test_isdir(self):
+ self.assertTrue(silx.resources.is_dir('gui/icons'))
+
+ def test_not_isdir(self):
+ self.assertFalse(silx.resources.is_dir('gui/icons/colormap.png'))
+
+ def test_list_dir(self):
+ result = silx.resources.list_dir('gui/icons')
+ self.assertTrue(len(result) > 10)
+
+ # With prefixed resources
+
+ def test_resource_dir_with_prefix(self):
+ """Get a resource directory"""
+ icons_dirname = silx.resources.resource_filename('silx:gui/icons/')
+ self.assertTrue(os.path.isdir(icons_dirname))
+
+ def test_resource_file_with_prefix(self):
+ """Get a resource file name"""
+ filename = silx.resources.resource_filename('silx:gui/icons/colormap.png')
+ self.assertTrue(os.path.isfile(filename))
+
+ def test_resource_nonexistent_with_prefix(self):
+ """Get a non existent resource"""
+ filename = silx.resources.resource_filename('silx:non_existent_file.txt')
+ self.assertFalse(os.path.exists(filename))
+
+ def test_isdir_with_prefix(self):
+ self.assertTrue(silx.resources.is_dir('silx:gui/icons'))
+
+ def test_not_isdir_with_prefix(self):
+ self.assertFalse(silx.resources.is_dir('silx:gui/icons/colormap.png'))
+
+ def test_list_dir_with_prefix(self):
+ result = silx.resources.list_dir('silx:gui/icons')
+ self.assertTrue(len(result) > 10)
+
+ # Test new repository
+
+ def test_repository_not_exists(self):
+ """The resource from 'test' is available"""
+ self.assertRaises(ValueError, silx.resources.resource_filename, 'test:foo.png')
+
+ def test_adding_test_directory(self):
+ """The resource from 'test' is available"""
+ silx.resources.register_resource_directory("test", "silx.test.resources", forced_path=self.tmpDirectory)
+ path = silx.resources.resource_filename('test:gui/icons/foo.png')
+ self.assertTrue(os.path.exists(path))
+
+ def test_adding_test_directory_no_override(self):
+ """The resource from 'silx' is still available"""
+ silx.resources.register_resource_directory("test", "silx.test.resources", forced_path=self.tmpDirectory)
+ filename1 = silx.resources.resource_filename('gui/icons/close.png')
+ filename2 = silx.resources.resource_filename('silx:gui/icons/close.png')
+ filename3 = silx.resources.resource_filename('test:gui/icons/close.png')
+ self.assertTrue(os.path.isfile(filename1))
+ self.assertTrue(os.path.isfile(filename2))
+ self.assertTrue(os.path.isfile(filename3))
+ self.assertEqual(filename1, filename2)
+ self.assertNotEqual(filename1, filename3)
+
+ def test_adding_test_directory_non_existing(self):
+ """A resource while not exists in test is not available anyway it exists
+ in silx"""
+ silx.resources.register_resource_directory("test", "silx.test.resources", forced_path=self.tmpDirectory)
+ resource_name = "gui/icons/colormap.png"
+ path = silx.resources.resource_filename('test:' + resource_name)
+ path2 = silx.resources.resource_filename('silx:' + resource_name)
+ self.assertFalse(os.path.exists(path))
+ self.assertTrue(os.path.exists(path2))
+
+
+class TestResourcesWithoutPkgResources(TestResources):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestResourcesWithoutPkgResources, cls).setUpClass()
+ cls._old = silx.resources.pkg_resources
+ silx.resources.pkg_resources = None
+
+ @classmethod
+ def tearDownClass(cls):
+ silx.resources.pkg_resources = cls._old
+ del cls._old
+ super(TestResourcesWithoutPkgResources, cls).tearDownClass()
+
+
+class TestResourcesWithCustomDirectory(TestResources):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestResourcesWithCustomDirectory, cls).setUpClass()
+ cls._old = silx.resources._RESOURCES_DIR
+ base = os.path.dirname(silx.resources.__file__)
+ silx.resources._RESOURCES_DIR = base
+
+ @classmethod
+ def tearDownClass(cls):
+ silx.resources._RESOURCES_DIR = cls._old
+ del cls._old
+ super(TestResourcesWithCustomDirectory, cls).tearDownClass()
diff --git a/src/silx/test/test_sx.py b/src/silx/test/test_sx.py
new file mode 100644
index 0000000..9836285
--- /dev/null
+++ b/src/silx/test/test_sx.py
@@ -0,0 +1,265 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+__authors__ = ["T. Vincent", "P. Knobel"]
+__license__ = "MIT"
+__date__ = "06/11/2018"
+
+
+import numpy
+import pytest
+
+from silx.gui import qt
+from silx.gui.colors import rgba
+from silx.gui.colors import Colormap
+
+
+@pytest.fixture(scope="module")
+def sx(qapp):
+ """Lazy loading to avoid it to create QApplication before qapp fixture"""
+ from silx import sx
+ if sx._IS_NOTEBOOK:
+ pytest.skip("notebook context")
+ if sx._NO_DISPLAY:
+ pytest.skip("no DISPLAY specified")
+ yield sx
+
+
+def test_plot(sx, qapp_utils):
+ """Test plot function"""
+ y = numpy.random.random(100)
+ x = numpy.arange(len(y)) * 0.5
+
+ # Nothing
+ plt = sx.plot()
+ qapp_utils.exposeAndClose(plt)
+
+ # y
+ plt = sx.plot(y, title='y')
+ qapp_utils.exposeAndClose(plt)
+
+ # y, style
+ plt = sx.plot(y, 'blued ', title='y, "blued "')
+ qapp_utils.exposeAndClose(plt)
+
+ # x, y
+ plt = sx.plot(x, y, title='x, y')
+ qapp_utils.exposeAndClose(plt)
+
+ # x, y, style
+ plt = sx.plot(x, y, 'ro-', xlabel='x', title='x, y, "ro-"')
+ qapp_utils.exposeAndClose(plt)
+
+ # x, y, style, y
+ plt = sx.plot(x, y, 'ro-', y ** 2, xlabel='x', ylabel='y',
+ title='x, y, "ro-", y ** 2')
+ qapp_utils.exposeAndClose(plt)
+
+ # x, y, style, y, style
+ plt = sx.plot(x, y, 'ro-', y ** 2, 'b--',
+ title='x, y, "ro-", y ** 2, "b--"')
+ qapp_utils.exposeAndClose(plt)
+
+ # x, y, style, x, y, style
+ plt = sx.plot(x, y, 'ro-', x, y ** 2, 'b--',
+ title='x, y, "ro-", x, y ** 2, "b--"')
+ qapp_utils.exposeAndClose(plt)
+
+ # x, y, x, y
+ plt = sx.plot(x, y, x, y ** 2, title='x, y, x, y ** 2')
+ qapp_utils.exposeAndClose(plt)
+
+
+def test_imshow(sx, qapp_utils):
+ """Test imshow function"""
+ img = numpy.arange(100.).reshape(10, 10) + 1
+
+ # Nothing
+ plt = sx.imshow()
+ qapp_utils.exposeAndClose(plt)
+
+ # image
+ plt = sx.imshow(img)
+ qapp_utils.exposeAndClose(plt)
+
+ # image, named cmap
+ plt = sx.imshow(img, cmap='jet', title='jet cmap')
+ qapp_utils.exposeAndClose(plt)
+
+ # image, custom colormap
+ plt = sx.imshow(img, cmap=Colormap(), title='custom colormap')
+ qapp_utils.exposeAndClose(plt)
+
+ # image, log cmap
+ plt = sx.imshow(img, norm='log', title='log cmap')
+ qapp_utils.exposeAndClose(plt)
+
+ # image, fixed range
+ plt = sx.imshow(img, vmin=10, vmax=20,
+ title='[10,20] cmap')
+ qapp_utils.exposeAndClose(plt)
+
+ # image, keep ratio
+ plt = sx.imshow(img, aspect=True,
+ title='keep ratio')
+ qapp_utils.exposeAndClose(plt)
+
+ # image, change origin and scale
+ plt = sx.imshow(img, origin=(10, 10), scale=(2, 2),
+ title='origin=(10, 10), scale=(2, 2)')
+ qapp_utils.exposeAndClose(plt)
+
+ # image, origin='lower'
+ plt = sx.imshow(img, origin='upper', title='origin="lower"')
+ qapp_utils.exposeAndClose(plt)
+
+
+def test_scatter(sx, qapp_utils):
+ """Test scatter function"""
+ x = numpy.arange(100)
+ y = numpy.arange(100)
+ values = numpy.arange(100)
+
+ # simple scatter
+ plt = sx.scatter(x, y, values)
+ qapp_utils.exposeAndClose(plt)
+
+ # No value
+ plt = sx.scatter(x, y, values)
+ qapp_utils.exposeAndClose(plt)
+
+ # single value
+ plt = sx.scatter(x, y, 10.)
+ qapp_utils.exposeAndClose(plt)
+
+ # set size
+ plt = sx.scatter(x, y, values, size=20)
+ qapp_utils.exposeAndClose(plt)
+
+ # set colormap
+ plt = sx.scatter(x, y, values, cmap='jet')
+ qapp_utils.exposeAndClose(plt)
+
+ # set colormap range
+ plt = sx.scatter(x, y, values, vmin=2, vmax=50)
+ qapp_utils.exposeAndClose(plt)
+
+ # set colormap normalisation
+ plt = sx.scatter(x, y, values, norm='log')
+ qapp_utils.exposeAndClose(plt)
+
+
+@pytest.mark.parametrize("plot_kind", ["plot", "imshow", "scatter"])
+def test_ginput(sx, qapp, qapp_utils, plot_kind):
+ """Test ginput function
+
+ This does NOT perform interactive tests
+ """
+ create_plot = getattr(sx, plot_kind)
+ plt = create_plot()
+ qapp_utils.qWaitForWindowExposed(plt)
+ qapp.processEvents()
+
+ result = sx.ginput(1, timeout=0.1)
+ assert len(result) == 0
+
+ plt.setAttribute(qt.Qt.WA_DeleteOnClose)
+ plt.close()
+
+
+@pytest.mark.usefixtures("use_opengl")
+def test_contour3d(sx, qapp_utils):
+ """Test contour3d function"""
+ coords = numpy.linspace(-10, 10, 64)
+ z = coords.reshape(-1, 1, 1)
+ y = coords.reshape(1, -1, 1)
+ x = coords.reshape(1, 1, -1)
+ data = numpy.sin(x * y * z) / (x * y * z)
+
+ # Just data
+ window = sx.contour3d(data)
+
+ isosurfaces = window.getIsosurfaces()
+ assert len(isosurfaces) == 1
+
+ if not window.getPlot3DWidget().isValid():
+ del window, isosurfaces # Release widget reference
+ pytest.skip("OpenGL context is not valid")
+
+ # N contours + color
+ colors = ['red', 'green', 'blue']
+ window = sx.contour3d(data, copy=False, contours=len(colors),
+ color=colors)
+
+ isosurfaces = window.getIsosurfaces()
+ assert len(isosurfaces) == len(colors)
+ for iso, color in zip(isosurfaces, colors):
+ assert rgba(iso.getColor()) == rgba(color)
+
+ # by isolevel, single color
+ contours = 0.2, 0.5
+ window = sx.contour3d(data, copy=False, contours=contours,
+ color='yellow')
+
+ isosurfaces = window.getIsosurfaces()
+ assert len(isosurfaces) == len(contours)
+ for iso, level in zip(isosurfaces, contours):
+ assert iso.getLevel() == level
+ assert rgba(iso.getColor()) == rgba('yellow')
+
+ # Single isolevel, colormap
+ window = sx.contour3d(data, copy=False, contours=0.5,
+ colormap='gray', vmin=0.6, opacity=0.4)
+
+ isosurfaces = window.getIsosurfaces()
+ assert len(isosurfaces) == 1
+ assert isosurfaces[0].getLevel() == 0.5
+ assert rgba(isosurfaces[0].getColor()) == (0., 0., 0., 0.4)
+
+
+@pytest.mark.usefixtures("use_opengl")
+def test_points3d(sx, qapp_utils):
+ """Test points3d function"""
+ x = numpy.random.random(1024)
+ y = numpy.random.random(1024)
+ z = numpy.random.random(1024)
+ values = numpy.random.random(1024)
+
+ # 3D positions, no value
+ window = sx.points3d(x, y, z)
+
+ if not window.getSceneWidget().isValid():
+ del window # Release widget reference
+ pytest.skip("OpenGL context is not valid")
+
+ # 3D positions, values
+ window = sx.points3d(x, y, z, values, mode='2dsquare',
+ colormap='magma', vmin=0.4, vmax=0.5)
+
+ # 2D positions, no value
+ window = sx.points3d(x, y)
+
+ # 2D positions, values
+ window = sx.points3d(x, y, values=values, mode=',',
+ colormap='magma', vmin=0.4, vmax=0.5)
diff --git a/src/silx/test/test_version.py b/src/silx/test/test_version.py
new file mode 100644
index 0000000..80084f9
--- /dev/null
+++ b/src/silx/test/test_version.py
@@ -0,0 +1,38 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2015-2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Basic test of top-level package import and existence of version info."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/02/2016"
+
+import unittest
+
+import silx
+
+
+class TestVersion(unittest.TestCase):
+ def test_version(self):
+ self.assertTrue(isinstance(silx.version, str))
diff --git a/src/silx/test/utils.py b/src/silx/test/utils.py
new file mode 100644
index 0000000..0c2d5bf
--- /dev/null
+++ b/src/silx/test/utils.py
@@ -0,0 +1,198 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Utilities for writing tests.
+
+- :func:`temp_dir` provides a with context to create/delete a temporary
+ directory.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "03/01/2019"
+
+
+import sys
+import contextlib
+import os
+import numpy
+import shutil
+import tempfile
+from ..resources import ExternalResources
+
+
+utilstest = ExternalResources(project="silx",
+ url_base="http://www.silx.org/pub/silx/",
+ env_key="SILX_DATA",
+ timeout=60)
+"This is the instance to be used. Singleton-like feature provided by module"
+
+
+class _TestOptions(object):
+
+ def __init__(self):
+ self.WITH_QT_TEST = True
+ """Qt tests are included"""
+
+ self.WITH_QT_TEST_REASON = ""
+ """Reason for Qt tests are disabled if any"""
+
+ self.WITH_OPENCL_TEST = True
+ """OpenCL tests are included"""
+
+ self.WITH_OPENCL_TEST_REASON = ""
+ """Reason for OpenCL tests are disabled if any"""
+
+ self.WITH_GL_TEST = True
+ """OpenGL tests are included"""
+
+ self.WITH_GL_TEST_REASON = ""
+ """Reason for OpenGL tests are disabled if any"""
+
+ self.TEST_LOW_MEM = False
+ """Skip tests using too much memory"""
+
+ self.TEST_LOW_MEM_REASON = ""
+ """Reason for low_memory tests are disabled if any"""
+
+ def configure(self, parsed_options=None):
+ """Configure the TestOptions class from the command line arguments and the
+ environment variables
+ """
+ if parsed_options is not None and not parsed_options.gui:
+ self.WITH_QT_TEST = False
+ self.WITH_QT_TEST_REASON = "Skipped by command line"
+ elif os.environ.get('WITH_QT_TEST', 'True') == 'False':
+ self.WITH_QT_TEST = False
+ self.WITH_QT_TEST_REASON = "Skipped by WITH_QT_TEST env var"
+ elif sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
+ self.WITH_QT_TEST = False
+ self.WITH_QT_TEST_REASON = "DISPLAY env variable not set"
+
+ if parsed_options is not None and not parsed_options.opencl:
+ self.WITH_OPENCL_TEST_REASON = "Skipped by command line"
+ self.WITH_OPENCL_TEST = False
+ elif os.environ.get('SILX_OPENCL', 'True') == 'False':
+ self.WITH_OPENCL_TEST_REASON = "Skipped by SILX_OPENCL env var"
+ self.WITH_OPENCL_TEST = False
+
+ if not self.WITH_OPENCL_TEST:
+ # That's an easy way to skip OpenCL tests
+ # It disable the use of OpenCL on the full silx project
+ os.environ['SILX_OPENCL'] = "False"
+
+ if parsed_options is not None and not parsed_options.opengl:
+ self.WITH_GL_TEST = False
+ self.WITH_GL_TEST_REASON = "Skipped by command line"
+ elif os.environ.get('WITH_GL_TEST', 'True') == 'False':
+ self.WITH_GL_TEST = False
+ self.WITH_GL_TEST_REASON = "Skipped by WITH_GL_TEST env var"
+ elif sys.platform.startswith('linux') and not os.environ.get('DISPLAY', ''):
+ self.WITH_GL_TEST = False
+ self.WITH_GL_TEST_REASON = "DISPLAY env variable not set"
+ else:
+ try:
+ import OpenGL
+ except ImportError:
+ self.WITH_GL_TEST = False
+ self.WITH_GL_TEST_REASON = "OpenGL package not available"
+
+ if parsed_options is not None and parsed_options.low_mem:
+ self.TEST_LOW_MEM = True
+ self.TEST_LOW_MEM_REASON = "Skipped by command line"
+ elif os.environ.get('SILX_TEST_LOW_MEM', 'True') == 'False':
+ self.TEST_LOW_MEM = True
+ self.TEST_LOW_MEM_REASON = "Skipped by SILX_TEST_LOW_MEM env var"
+
+ if self.WITH_QT_TEST:
+ try:
+ from silx.gui import qt
+ except ImportError:
+ self.WITH_QT_TEST = False
+ self.WITH_QT_TEST_REASON = "Qt is not installed"
+ else:
+ if sys.platform == "win32" and qt.qVersion() == "5.9.2":
+ self.SKIP_TEST_FOR_ISSUE_936 = True
+
+
+# Temporary directory context #################################################
+
+@contextlib.contextmanager
+def temp_dir():
+ """with context providing a temporary directory.
+
+ >>> import os.path
+ >>> with temp_dir() as tmp:
+ ... print(os.path.isdir(tmp)) # Use tmp directory
+ """
+ tmp_dir = tempfile.mkdtemp()
+ try:
+ yield tmp_dir
+ finally:
+ shutil.rmtree(tmp_dir)
+
+
+# Synthetic data and random noise #############################################
+def add_gaussian_noise(y, stdev=1., mean=0.):
+ """Add random gaussian noise to synthetic data.
+
+ :param ndarray y: Array of synthetic data
+ :param float mean: Mean of the gaussian distribution of noise.
+ :param float stdev: Standard deviation of the gaussian distribution of
+ noise.
+ :return: Array of data with noise added
+ """
+ noise = numpy.random.normal(mean, stdev, size=y.size)
+ noise.shape = y.shape
+ return y + noise
+
+
+def add_poisson_noise(y):
+ """Add random noise from a poisson distribution to synthetic data.
+
+ :param ndarray y: Array of synthetic data
+ :return: Array of data with noise added
+ """
+ yn = numpy.random.poisson(y)
+ yn.shape = y.shape
+ return yn
+
+
+def add_relative_noise(y, max_noise=5.):
+ """Add relative random noise to synthetic data. The maximum noise level
+ is given in percents.
+
+ An array of noise in the interval [-max_noise, max_noise] (continuous
+ uniform distribution) is generated, and applied to the data the
+ following way:
+
+ :math:`yn = y * (1. + noise / 100.)`
+
+ :param ndarray y: Array of synthetic data
+ :param float max_noise: Maximum percentage of noise
+ :return: Array of data with noise added
+ """
+ noise = max_noise * (2 * numpy.random.random(size=y.size) - 1)
+ noise.shape = y.shape
+ return y * (1. + noise / 100.)
diff --git a/silx/third_party/EdfFile.py b/src/silx/third_party/EdfFile.py
index 0606d1c..0606d1c 100644
--- a/silx/third_party/EdfFile.py
+++ b/src/silx/third_party/EdfFile.py
diff --git a/silx/third_party/TiffIO.py b/src/silx/third_party/TiffIO.py
index 7526a75..7526a75 100644
--- a/silx/third_party/TiffIO.py
+++ b/src/silx/third_party/TiffIO.py
diff --git a/silx/third_party/__init__.py b/src/silx/third_party/__init__.py
index 156563c..156563c 100644
--- a/silx/third_party/__init__.py
+++ b/src/silx/third_party/__init__.py
diff --git a/silx/third_party/scipy_spatial.py b/src/silx/third_party/scipy_spatial.py
index 9885154..9885154 100644
--- a/silx/third_party/scipy_spatial.py
+++ b/src/silx/third_party/scipy_spatial.py
diff --git a/src/silx/third_party/setup.py b/src/silx/third_party/setup.py
new file mode 100644
index 0000000..47686ea
--- /dev/null
+++ b/src/silx/third_party/setup.py
@@ -0,0 +1,49 @@
+# coding: ascii
+#
+# JK: Numpy.distutils which imports this does not handle utf-8 in version<1.12
+#
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["Valentin Valls"]
+__license__ = "MIT"
+__date__ = "23/04/2018"
+
+import os
+from numpy.distutils.misc_util import Configuration
+
+
+def configuration(parent_package='', top_path=None):
+ config = Configuration('third_party', parent_package, top_path)
+ # includes _local only if it is available
+ local_path = os.path.join(top_path, "src", "silx", "third_party", "_local")
+ if os.path.exists(local_path):
+ config.add_subpackage('_local')
+ config.add_subpackage('_local.scipy_spatial')
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+ setup(configuration=configuration)
diff --git a/src/silx/utils/ExternalResources.py b/src/silx/utils/ExternalResources.py
new file mode 100644
index 0000000..b79d6ff
--- /dev/null
+++ b/src/silx/utils/ExternalResources.py
@@ -0,0 +1,321 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Helper to access to external resources.
+"""
+
+__authors__ = ["Thomas Vincent", "J. Kieffer"]
+__license__ = "MIT"
+__date__ = "08/03/2019"
+
+
+import os
+import threading
+import json
+import logging
+import tempfile
+import unittest
+import urllib.request
+import urllib.error
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalResources(object):
+ """Utility class which allows to download test-data from www.silx.org
+ and manage the temporary data during the tests.
+
+ """
+
+ def __init__(self, project,
+ url_base,
+ env_key=None,
+ timeout=60):
+ """Constructor of the class
+
+ :param str project: name of the project, like "silx"
+ :param str url_base: base URL for the data, like "http://www.silx.org/pub"
+ :param str env_key: name of the environment variable which contains the
+ test_data directory, like "SILX_DATA".
+ If None (default), then the name of the
+ environment variable is built from the project argument:
+ "<PROJECT>_DATA".
+ The environment variable is optional: in case it is not set,
+ a directory in the temporary folder is used.
+ :param timeout: time in seconds before it breaks
+ """
+ self.project = project
+ self._initialized = False
+ self.sem = threading.Semaphore()
+
+ self.env_key = env_key or (self.project.upper() + "_TESTDATA")
+ self.url_base = url_base
+ self.all_data = set()
+ self.timeout = timeout
+ self._data_home = None
+
+ @property
+ def data_home(self):
+ """Returns the data_home path and make sure it exists in the file
+ system."""
+ if self._data_home is not None:
+ return self._data_home
+
+ data_home = os.environ.get(self.env_key)
+ if data_home is None:
+ try:
+ import getpass
+ name = getpass.getuser()
+ except Exception:
+ if "getlogin" in dir(os):
+ name = os.getlogin()
+ elif "USER" in os.environ:
+ name = os.environ["USER"]
+ elif "USERNAME" in os.environ:
+ name = os.environ["USERNAME"]
+ else:
+ name = "uid" + str(os.getuid())
+
+ basename = "%s_testdata_%s" % (self.project, name)
+ data_home = os.path.join(tempfile.gettempdir(), basename)
+ if not os.path.exists(data_home):
+ os.makedirs(data_home)
+ self._data_home = data_home
+ return data_home
+
+ def _initialize_data(self):
+ """Initialize for downloading test data"""
+ if not self._initialized:
+ with self.sem:
+ if not self._initialized:
+ self.testdata = os.path.join(self.data_home, "all_testdata.json")
+ if os.path.exists(self.testdata):
+ with open(self.testdata) as f:
+ self.all_data = set(json.load(f))
+ self._initialized = True
+
+ def clean_up(self):
+ pass
+
+ def getfile(self, filename):
+ """Downloads the requested file from web-server available
+ at https://www.silx.org/pub/silx/
+
+ :param: relative name of the image.
+ :return: full path of the locally saved file.
+ """
+ logger.debug("ExternalResources.getfile('%s')", filename)
+
+ if not self._initialized:
+ self._initialize_data()
+
+ fullfilename = os.path.abspath(os.path.join(self.data_home, filename))
+
+ if not os.path.isfile(fullfilename):
+ logger.debug("Trying to download image %s, timeout set to %ss",
+ filename, self.timeout)
+ dictProxies = {}
+ if "http_proxy" in os.environ:
+ dictProxies['http'] = os.environ["http_proxy"]
+ dictProxies['https'] = os.environ["http_proxy"]
+ if "https_proxy" in os.environ:
+ dictProxies['https'] = os.environ["https_proxy"]
+ if dictProxies:
+ proxy_handler = urllib.request.ProxyHandler(dictProxies)
+ opener = urllib.request.build_opener(proxy_handler).open
+ else:
+ opener = urllib.request.urlopen
+
+ logger.debug("wget %s/%s", self.url_base, filename)
+ try:
+ data = opener("%s/%s" % (self.url_base, filename),
+ data=None, timeout=self.timeout).read()
+ logger.info("Image %s successfully downloaded.", filename)
+ except urllib.error.URLError:
+ raise unittest.SkipTest("network unreachable.")
+
+ if not os.path.isdir(os.path.dirname(fullfilename)):
+ # Create sub-directory if needed
+ os.makedirs(os.path.dirname(fullfilename))
+
+ try:
+ with open(fullfilename, "wb") as outfile:
+ outfile.write(data)
+ except IOError:
+ raise IOError("unable to write downloaded \
+ data to disk at %s" % self.data_home)
+
+ if not os.path.isfile(fullfilename):
+ raise RuntimeError(
+ """Could not automatically download test images %s!
+ If you are behind a firewall, please set both environment variable http_proxy and https_proxy.
+ This even works under windows !
+ Otherwise please try to download the images manually from
+ %s/%s""" % (filename, self.url_base, filename))
+
+ if filename not in self.all_data:
+ self.all_data.add(filename)
+ image_list = list(self.all_data)
+ image_list.sort()
+ try:
+ with open(self.testdata, "w") as fp:
+ json.dump(image_list, fp, indent=4)
+ except IOError:
+ logger.debug("Unable to save JSON list")
+
+ return fullfilename
+
+ def getdir(self, dirname):
+ """Downloads the requested tarball from the server
+ https://www.silx.org/pub/silx/
+ and unzips it into the data directory
+
+ :param: relative name of the image.
+ :return: list of files with their full path.
+ """
+ lodn = dirname.lower()
+ if (lodn.endswith("tar") or lodn.endswith("tgz") or
+ lodn.endswith("tbz2") or lodn.endswith("tar.gz") or
+ lodn.endswith("tar.bz2")):
+ import tarfile
+ engine = tarfile.TarFile.open
+ elif lodn.endswith("zip"):
+ import zipfile
+ engine = zipfile.ZipFile
+ else:
+ raise RuntimeError("Unsupported archive format. Only tar and zip "
+ "are currently supported")
+ full_path = self.getfile(dirname)
+ with engine(full_path, mode="r") as fd:
+ output = os.path.join(self.data_home, dirname + "__content")
+ fd.extractall(output)
+ if lodn.endswith("zip"):
+ result = [os.path.join(output, i) for i in fd.namelist()]
+ else:
+ result = [os.path.join(output, i) for i in fd.getnames()]
+ return result
+
+ def get_file_and_repack(self, filename):
+ """
+ Download the requested file, decompress and repack it to bz2 and gz.
+
+ :param str filename: name of the image.
+ :rtype: str
+ :return: full path of the locally saved file
+ """
+ if not self._initialized:
+ self._initialize_data()
+ if filename not in self.all_data:
+ self.all_data.add(filename)
+ image_list = list(self.all_data)
+ image_list.sort()
+ try:
+ with open(self.testdata, "w") as fp:
+ json.dump(image_list, fp, indent=4)
+ except IOError:
+ logger.debug("Unable to save JSON list")
+ baseimage = os.path.basename(filename)
+ logger.info("UtilsTest.getimage('%s')" % baseimage)
+
+ if not os.path.exists(self.data_home):
+ os.makedirs(self.data_home)
+ fullimagename = os.path.abspath(os.path.join(self.data_home, baseimage))
+
+ if baseimage.endswith(".bz2"):
+ bzip2name = baseimage
+ basename = baseimage[:-4]
+ gzipname = basename + ".gz"
+ elif baseimage.endswith(".gz"):
+ gzipname = baseimage
+ basename = baseimage[:-3]
+ bzip2name = basename + ".bz2"
+ else:
+ basename = baseimage
+ gzipname = baseimage + "gz2"
+ bzip2name = basename + ".bz2"
+
+ fullimagename_gz = os.path.abspath(os.path.join(self.data_home, gzipname))
+ fullimagename_raw = os.path.abspath(os.path.join(self.data_home, basename))
+ fullimagename_bz2 = os.path.abspath(os.path.join(self.data_home, bzip2name))
+
+ # The files are recreated from the bz2 file
+ if not os.path.isfile(fullimagename_bz2):
+ self.getfile(bzip2name)
+ if not os.path.isfile(fullimagename_bz2):
+ raise RuntimeError(
+ """Could not automatically download test images %s!
+ If you are behind a firewall, please set the environment variable http_proxy.
+ Otherwise please try to download the images manually from
+ %s""" % (self.url_base, filename))
+
+ try:
+ import bz2
+ except ImportError:
+ raise RuntimeError("bz2 library is needed to decompress data")
+ try:
+ import gzip
+ except ImportError:
+ gzip = None
+
+ raw_file_exists = os.path.isfile(fullimagename_raw)
+ gz_file_exists = os.path.isfile(fullimagename_gz)
+ if not raw_file_exists or not gz_file_exists:
+ with open(fullimagename_bz2, "rb") as f:
+ data = f.read()
+ decompressed = bz2.decompress(data)
+
+ if not raw_file_exists:
+ try:
+ with open(fullimagename_raw, "wb") as fullimage:
+ fullimage.write(decompressed)
+ except IOError:
+ raise IOError("unable to write decompressed \
+ data to disk at %s" % self.data_home)
+
+ if not gz_file_exists:
+ if gzip is None:
+ raise RuntimeError("gzip library is expected to recompress data")
+ try:
+ gzip.open(fullimagename_gz, "wb").write(decompressed)
+ except IOError:
+ raise IOError("unable to write gzipped \
+ data to disk at %s" % self.data_home)
+
+ return fullimagename
+
+ def download_all(self, imgs=None):
+ """Download all data needed for the test/benchmarks
+
+ :param imgs: list of files to download, by default all
+ :return: list of path with all files
+ """
+ if not self._initialized:
+ self._initialize_data()
+ if not imgs:
+ imgs = self.all_data
+ res = []
+ for fn in imgs:
+ logger.info("Downloading from silx.org: %s", fn)
+ res.append(self.getfile(fn))
+ return res
diff --git a/silx/utils/__init__.py b/src/silx/utils/__init__.py
index f803a5f..f803a5f 100644
--- a/silx/utils/__init__.py
+++ b/src/silx/utils/__init__.py
diff --git a/src/silx/utils/_have_openmp.pxd b/src/silx/utils/_have_openmp.pxd
new file mode 100644
index 0000000..89a385c
--- /dev/null
+++ b/src/silx/utils/_have_openmp.pxd
@@ -0,0 +1,49 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2018 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+"""
+Store in a Cython module if it was compiled with OpenMP
+
+You have to patch the setup module like that:
+
+.. code-block:: python
+
+ silx_include = os.path.join(top_path, "src", ""silx", "utils", "include")
+ config.add_extension('my_extension',
+ include_dirs=[silx_include],
+ ...)
+
+Then you can include it like that in your Cython module:
+
+.. code-block:: python
+
+ include "../../utils/_have_openmp.pxi"
+
+"""
+
+
+cdef extern from "silx_store_openmp.h":
+ int COMPILED_WITH_OPENMP
+_COMPILED_WITH_OPENMP = COMPILED_WITH_OPENMP
diff --git a/src/silx/utils/array_like.py b/src/silx/utils/array_like.py
new file mode 100644
index 0000000..0cf4857
--- /dev/null
+++ b/src/silx/utils/array_like.py
@@ -0,0 +1,595 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Functions and classes for array-like objects, implementing common numpy
+array features for datasets or nested sequences, while trying to avoid copying
+data.
+
+Classes:
+
+ - :class:`DatasetView`: Similar to a numpy view, to access
+ a h5py dataset as if it was transposed, without casting it into a
+ numpy array (this lets h5py handle reading the data from the
+ file into memory, as needed).
+ - :class:`ListOfImages`: Similar to a numpy view, to access
+ a list of 2D numpy arrays as if it was a 3D array (possibly transposed),
+ without casting it into a numpy array.
+
+Functions:
+
+ - :func:`is_array`
+ - :func:`is_list_of_arrays`
+ - :func:`is_nested_sequence`
+ - :func:`get_shape`
+ - :func:`get_dtype`
+ - :func:`get_concatenated_dtype`
+
+"""
+
+from __future__ import absolute_import, print_function, division
+
+import sys
+
+import numpy
+import numbers
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "26/04/2017"
+
+
+def is_array(obj):
+ """Return True if object implements necessary attributes to be
+ considered similar to a numpy array.
+
+ Attributes needed are "shape", "dtype", "__getitem__"
+ and "__array__".
+
+ :param obj: Array-like object (numpy array, h5py dataset...)
+ :return: boolean
+ """
+ # add more required attribute if necessary
+ for attr in ("shape", "dtype", "__array__", "__getitem__"):
+ if not hasattr(obj, attr):
+ return False
+ return True
+
+
+def is_list_of_arrays(obj):
+ """Return True if object is a sequence of numpy arrays,
+ e.g. a list of images as 2D arrays.
+
+ :param obj: list of arrays
+ :return: boolean"""
+ # object must not be a numpy array
+ if is_array(obj):
+ return False
+
+ # object must have a __len__ method
+ if not hasattr(obj, "__len__"):
+ return False
+
+ # all elements in sequence must be arrays
+ for arr in obj:
+ if not is_array(arr):
+ return False
+
+ return True
+
+
+def is_nested_sequence(obj):
+ """Return True if object is a nested sequence.
+
+ A simple 1D sequence is considered to be a nested sequence.
+
+ Numpy arrays and h5py datasets are not considered to be nested sequences.
+
+ To test if an object is a nested sequence in a more general sense,
+ including arrays and datasets, use::
+
+ is_nested_sequence(obj) or is_array(obj)
+
+ :param obj: nested sequence (numpy array, h5py dataset...)
+ :return: boolean"""
+ # object must not be a numpy array
+ if is_array(obj):
+ return False
+
+ if not hasattr(obj, "__len__"):
+ return False
+
+ # obj must not be a list of (lists of) numpy arrays
+ subsequence = obj
+ while hasattr(subsequence, "__len__"):
+ if is_array(subsequence):
+ return False
+ # strings cause infinite loops
+ if isinstance(subsequence, (str, bytes)):
+ return True
+ subsequence = subsequence[0]
+
+ # object has __len__ and is not an array
+ return True
+
+
+def get_shape(array_like):
+ """Return shape of an array like object.
+
+ In case the object is a nested sequence but not an array or dataset
+ (list of lists, tuples...), the size of each dimension is assumed to be
+ uniform, and is deduced from the length of the first sequence.
+
+ :param array_like: Array like object: numpy array, hdf5 dataset,
+ multi-dimensional sequence
+ :return: Shape of array, as a tuple of integers
+ """
+ if hasattr(array_like, "shape"):
+ return array_like.shape
+
+ shape = []
+ subsequence = array_like
+ while hasattr(subsequence, "__len__"):
+ shape.append(len(subsequence))
+ # strings cause infinite loops
+ if isinstance(subsequence, (str, bytes)):
+ break
+ subsequence = subsequence[0]
+
+ return tuple(shape)
+
+
+def get_dtype(array_like):
+ """Return dtype of an array like object.
+
+ In the case of a nested sequence, the type of the first value
+ is inspected.
+
+ :param array_like: Array like object: numpy array, hdf5 dataset,
+ multi-dimensional nested sequence
+ :return: numpy dtype of object
+ """
+ if hasattr(array_like, "dtype"):
+ return array_like.dtype
+
+ subsequence = array_like
+ while hasattr(subsequence, "__len__"):
+ # strings cause infinite loops
+ if isinstance(subsequence, (str, bytes)):
+ break
+ subsequence = subsequence[0]
+
+ return numpy.dtype(type(subsequence))
+
+
+def get_concatenated_dtype(arrays):
+ """Return dtype of array resulting of concatenation
+ of a list of arrays (without actually concatenating
+ them).
+
+ :param arrays: list of numpy arrays
+ :return: resulting dtype after concatenating arrays
+ """
+ dtypes = {a.dtype for a in arrays}
+ dummy = []
+ for dt in dtypes:
+ dummy.append(numpy.zeros((1, 1), dtype=dt))
+ return numpy.array(dummy).dtype
+
+
+class ListOfImages(object):
+ """This class provides a way to access values and slices in a stack of
+ images stored as a list of 2D numpy arrays, without creating a 3D numpy
+ array first.
+
+ A transposition can be specified, as a 3-tuple of dimensions in the wanted
+ order. For example, to transpose from ``xyz`` ``(0, 1, 2)`` into ``yzx``,
+ the transposition tuple is ``(1, 2, 0)``
+
+ All the 2D arrays in the list must have the same shape.
+
+ The global dtype of the stack of images is the one that would be obtained
+ by casting the list of 2D arrays into a 3D numpy array.
+
+ :param images: list of 2D numpy arrays, or :class:`ListOfImages` object
+ :param transposition: Tuple of dimension numbers in the wanted order
+ """
+ def __init__(self, images, transposition=None):
+ """
+
+ """
+ super(ListOfImages, self).__init__()
+
+ # if images is a ListOfImages instance, get the underlying data
+ # as a list of 2D arrays
+ if isinstance(images, ListOfImages):
+ images = images.images
+
+ # test stack of images is as expected
+ assert is_list_of_arrays(images), \
+ "Image stack must be a list of arrays"
+ image0_shape = images[0].shape
+ for image in images:
+ assert image.ndim == 2, \
+ "Images must be 2D numpy arrays"
+ assert image.shape == image0_shape, \
+ "All images must have the same shape"
+
+ self.images = images
+ """List of images"""
+
+ self.shape = (len(images), ) + image0_shape
+ """Tuple of array dimensions"""
+ self.dtype = get_concatenated_dtype(images)
+ """Data-type of the global array"""
+ self.ndim = 3
+ """Number of array dimensions"""
+
+ self.size = len(images) * image0_shape[0] * image0_shape[1]
+ """Number of elements in the array."""
+
+ self.transposition = list(range(self.ndim))
+ """List of dimension indices, in an order depending on the
+ specified transposition. By default this is simply
+ [0, ..., self.ndim], but it can be changed by specifying a different
+ ``transposition`` parameter at initialization.
+
+ Use :meth:`transpose`, to create a new :class:`ListOfImages`
+ with a different :attr:`transposition`.
+ """
+
+ if transposition is not None:
+ assert len(transposition) == self.ndim
+ assert set(transposition) == set(list(range(self.ndim))), \
+ "Transposition must be a sequence containing all dimensions"
+ self.transposition = transposition
+ self.__sort_shape()
+
+ def __sort_shape(self):
+ """Sort shape in the order defined in :attr:`transposition`
+ """
+ new_shape = tuple(self.shape[dim] for dim in self.transposition)
+ self.shape = new_shape
+
+ def __sort_indices(self, indices):
+ """Return array indices sorted in the order needed
+ to access data in the original non-transposed images.
+
+ :param indices: Tuple of ndim indices, in the order needed
+ to access the transposed view
+ :return: Sorted tuple of indices, to access original data
+ """
+ assert len(indices) == self.ndim
+ sorted_indices = tuple(idx for (_, idx) in
+ sorted(zip(self.transposition, indices)))
+ return sorted_indices
+
+ def __array__(self, dtype=None):
+ """Cast the images into a numpy array, and return it.
+
+ If a transposition has been done on this images, return
+ a transposed view of a numpy array."""
+ return numpy.transpose(numpy.array(self.images, dtype=dtype),
+ self.transposition)
+
+ def __len__(self):
+ return self.shape[0]
+
+ def transpose(self, transposition=None):
+ """Return a re-ordered (dimensions permutated)
+ :class:`ListOfImages`.
+
+ The returned object refers to
+ the same images but with a different :attr:`transposition`.
+
+ :param List[int] transposition: List/tuple of dimension numbers in the
+ wanted order.
+ If ``None`` (default), reverse the dimensions.
+ :return: new :class:`ListOfImages` object
+ """
+ # by default, reverse the dimensions
+ if transposition is None:
+ transposition = list(reversed(self.transposition))
+
+ # If this ListOfImages is already transposed, sort new transposition
+ # relative to old transposition
+ elif list(self.transposition) != list(range(self.ndim)):
+ transposition = [self.transposition[i] for i in transposition]
+
+ return ListOfImages(self.images,
+ transposition)
+
+ @property
+ def T(self):
+ """
+ Same as self.transpose()
+
+ :return: DatasetView with dimensions reversed."""
+ return self.transpose()
+
+ def __getitem__(self, item):
+ """Handle a subset of numpy indexing with regards to the dimension
+ order as specified in :attr:`transposition`
+
+ Following features are **not supported**:
+
+ - fancy indexing using numpy arrays
+ - using ellipsis objects
+
+ :param item: Index
+ :return: value or slice as a numpy array
+ """
+ # 1-D slicing -> n-D slicing (n=1)
+ if not hasattr(item, "__len__"):
+ # first dimension index is given
+ item = [item]
+ # following dimensions are indexed with : (all elements)
+ item += [slice(None) for _i in range(self.ndim - 1)]
+
+ # n-dimensional slicing
+ if len(item) != self.ndim:
+ raise IndexError(
+ "N-dim slicing requires a tuple of N indices/slices. " +
+ "Needed dimensions: %d" % self.ndim)
+
+ # get list of indices sorted in the original images order
+ sorted_indices = self.__sort_indices(item)
+ list_idx, array_idx = sorted_indices[0], sorted_indices[1:]
+
+ images_selection = self.images[list_idx]
+
+ # now we must transpose the output data
+ output_dimensions = []
+ frozen_dimensions = []
+ for i, idx in enumerate(item):
+ # slices and sequences
+ if not isinstance(idx, numbers.Integral):
+ output_dimensions.append(self.transposition[i])
+ # regular integer index
+ else:
+ # whenever a dimension is fixed (indexed by an integer)
+ # the number of output dimension is reduced
+ frozen_dimensions.append(self.transposition[i])
+
+ # decrement output dimensions that are above frozen dimensions
+ for frozen_dim in reversed(sorted(frozen_dimensions)):
+ for i, out_dim in enumerate(output_dimensions):
+ if out_dim > frozen_dim:
+ output_dimensions[i] -= 1
+
+ assert (len(output_dimensions) + len(frozen_dimensions)) == self.ndim
+ assert set(output_dimensions) == set(range(len(output_dimensions)))
+
+ # single list elements selected
+ if isinstance(images_selection, numpy.ndarray):
+ return numpy.transpose(images_selection[array_idx],
+ axes=output_dimensions)
+ # muliple list elements selected
+ else:
+ # apply selection first
+ output_stack = []
+ for img in images_selection:
+ output_stack.append(img[array_idx])
+ # then cast into a numpy array, and transpose
+ return numpy.transpose(numpy.array(output_stack),
+ axes=output_dimensions)
+
+ def min(self):
+ """
+ :return: Global minimum value
+ """
+ min_value = self.images[0].min()
+ if len(self.images) > 1:
+ for img in self.images[1:]:
+ min_value = min(min_value, img.min())
+ return min_value
+
+ def max(self):
+ """
+ :return: Global maximum value
+ """
+ max_value = self.images[0].max()
+ if len(self.images) > 1:
+ for img in self.images[1:]:
+ max_value = max(max_value, img.max())
+ return max_value
+
+
+class DatasetView(object):
+ """This class provides a way to transpose a dataset without
+ casting it into a numpy array. This way, the dataset in a file need not
+ necessarily be integrally read into memory to view it in a different
+ transposition.
+
+ .. note::
+ The performances depend a lot on the way the dataset was written
+ to file. Depending on the chunking strategy, reading a complete 2D slice
+ in an unfavorable direction may still require the entire dataset to
+ be read from disk.
+
+ :param dataset: h5py dataset
+ :param transposition: List of dimensions sorted in the order of
+ transposition (relative to the original h5py dataset)
+ """
+ def __init__(self, dataset, transposition=None):
+ """
+
+ """
+ super(DatasetView, self).__init__()
+ self.dataset = dataset
+ """original dataset"""
+
+ self.shape = dataset.shape
+ """Tuple of array dimensions"""
+ self.dtype = dataset.dtype
+ """Data-type of the array’s element"""
+ self.ndim = len(dataset.shape)
+ """Number of array dimensions"""
+
+ size = 0
+ if self.ndim:
+ size = 1
+ for dimsize in self.shape:
+ size *= dimsize
+ self.size = size
+ """Number of elements in the array."""
+
+ self.transposition = list(range(self.ndim))
+ """List of dimension indices, in an order depending on the
+ specified transposition. By default this is simply
+ [0, ..., self.ndim], but it can be changed by specifying a different
+ `transposition` parameter at initialization.
+
+ Use :meth:`transpose`, to create a new :class:`DatasetView`
+ with a different :attr:`transposition`.
+ """
+
+ if transposition is not None:
+ assert len(transposition) == self.ndim
+ assert set(transposition) == set(list(range(self.ndim))), \
+ "Transposition must be a list containing all dimensions"
+ self.transposition = transposition
+ self.__sort_shape()
+
+ def __sort_shape(self):
+ """Sort shape in the order defined in :attr:`transposition`
+ """
+ new_shape = tuple(self.shape[dim] for dim in self.transposition)
+ self.shape = new_shape
+
+ def __sort_indices(self, indices):
+ """Return array indices sorted in the order needed
+ to access data in the original non-transposed dataset.
+
+ :param indices: Tuple of ndim indices, in the order needed
+ to access the view
+ :return: Sorted tuple of indices, to access original data
+ """
+ assert len(indices) == self.ndim
+ sorted_indices = tuple(idx for (_, idx) in
+ sorted(zip(self.transposition, indices)))
+ return sorted_indices
+
+ def __getitem__(self, item):
+ """Handle fancy indexing with regards to the dimension order as
+ specified in :attr:`transposition`
+
+ The supported fancy-indexing syntax is explained at
+ http://docs.h5py.org/en/latest/high/dataset.html#fancy-indexing.
+
+ Additional restrictions exist if the data has been transposed:
+
+ - numpy boolean array indexing is not supported
+ - ellipsis objects are not supported
+
+ :param item: Index, possibly fancy index (must be supported by h5py)
+ :return: Sliced numpy array or numpy scalar
+ """
+ # no transposition, let the original dataset handle indexing
+ if self.transposition == list(range(self.ndim)):
+ return self.dataset[item]
+
+ # 1-D slicing: create a list of indices to switch to n-D slicing
+ if not hasattr(item, "__len__"):
+ # first dimension index (list index) is given
+ item = [item]
+ # following dimensions are indexed with slices representing all elements
+ item += [slice(None) for _i in range(self.ndim - 1)]
+
+ # n-dimensional slicing
+ if len(item) != self.ndim:
+ raise IndexError(
+ "N-dim slicing requires a tuple of N indices/slices. " +
+ "Needed dimensions: %d" % self.ndim)
+
+ # get list of indices sorted in the original dataset order
+ sorted_indices = self.__sort_indices(item)
+
+ output_data_not_transposed = self.dataset[sorted_indices]
+
+ # now we must transpose the output data
+ output_dimensions = []
+ frozen_dimensions = []
+ for i, idx in enumerate(item):
+ # slices and sequences
+ if not isinstance(idx, int):
+ output_dimensions.append(self.transposition[i])
+ # regular integer index
+ else:
+ # whenever a dimension is fixed (indexed by an integer)
+ # the number of output dimension is reduced
+ frozen_dimensions.append(self.transposition[i])
+
+ # decrement output dimensions that are above frozen dimensions
+ for frozen_dim in reversed(sorted(frozen_dimensions)):
+ for i, out_dim in enumerate(output_dimensions):
+ if out_dim > frozen_dim:
+ output_dimensions[i] -= 1
+
+ assert (len(output_dimensions) + len(frozen_dimensions)) == self.ndim
+ assert set(output_dimensions) == set(range(len(output_dimensions)))
+
+ return numpy.transpose(output_data_not_transposed,
+ axes=output_dimensions)
+
+ def __array__(self, dtype=None):
+ """Cast the dataset into a numpy array, and return it.
+
+ If a transposition has been done on this dataset, return
+ a transposed view of a numpy array."""
+ return numpy.transpose(numpy.array(self.dataset, dtype=dtype),
+ self.transposition)
+
+ def __len__(self):
+ return self.shape[0]
+
+ def transpose(self, transposition=None):
+ """Return a re-ordered (dimensions permutated)
+ :class:`DatasetView`.
+
+ The returned object refers to
+ the same dataset but with a different :attr:`transposition`.
+
+ :param List[int] transposition: List of dimension numbers in the wanted order.
+ If ``None`` (default), reverse the dimensions.
+ :return: Transposed DatasetView
+ """
+ # by default, reverse the dimensions
+ if transposition is None:
+ transposition = list(reversed(self.transposition))
+
+ # If this DatasetView is already transposed, sort new transposition
+ # relative to old transposition
+ elif list(self.transposition) != list(range(self.ndim)):
+ transposition = [self.transposition[i] for i in transposition]
+
+ return DatasetView(self.dataset,
+ transposition)
+
+ @property
+ def T(self):
+ """
+ Same as self.transpose()
+
+ :return: DatasetView with dimensions reversed."""
+ return self.transpose()
diff --git a/src/silx/utils/debug.py b/src/silx/utils/debug.py
new file mode 100644
index 0000000..3d50fc9
--- /dev/null
+++ b/src/silx/utils/debug.py
@@ -0,0 +1,100 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+
+import inspect
+import types
+import logging
+
+
+debug_logger = logging.getLogger("silx.DEBUG")
+
+_indent = 0
+
+
+def log_method(func, class_name=None):
+ """Decorator to inject a warning log before an after any function/method.
+
+ .. code-block:: python
+
+ @log_method
+ def foo():
+ return None
+
+ :param callable func: The function to patch
+ :param str class_name: In case a method, provide the class name
+ """
+ def wrapper(*args, **kwargs):
+ global _indent
+
+ indent = " " * _indent
+ if class_name is not None:
+ name = "%s.%s" % (class_name, func.__name__)
+ else:
+ name = "%s" % func.__name__
+
+ debug_logger.warning("%s%s" % (indent, name))
+ _indent += 1
+ result = func(*args, **kwargs)
+ _indent -= 1
+ debug_logger.warning("%sreturn (%s)" % (indent, name))
+ return result
+ return wrapper
+
+
+def log_all_methods(base_class):
+ """Decorator to inject a warning log before an after any method provided by
+ a class.
+
+ .. code-block:: python
+
+ @log_all_methods
+ class Foo(object):
+
+ def a(self):
+ return None
+
+ def b(self):
+ return self.a()
+
+ Here is the output when calling the `b` method.
+
+ .. code-block::
+
+ WARNING:silx.DEBUG:_Foobar.b
+ WARNING:silx.DEBUG: _Foobar.a
+ WARNING:silx.DEBUG: return (_Foobar.a)
+ WARNING:silx.DEBUG:return (_Foobar.b)
+
+ :param class base_class: The class to patch
+ """
+ methodTypes = (types.MethodType, types.FunctionType, types.BuiltinFunctionType, types.BuiltinMethodType)
+ for name, func in inspect.getmembers(base_class):
+ if isinstance(func, methodTypes):
+ if func.__name__ not in ["__subclasshook__", "__new__"]:
+ # patching __new__ in Python2 break the object, then we skip it
+ setattr(base_class, name, log_method(func, base_class.__name__))
+
+ return base_class
diff --git a/silx/utils/deprecation.py b/src/silx/utils/deprecation.py
index 7b19ee5..7b19ee5 100644
--- a/silx/utils/deprecation.py
+++ b/src/silx/utils/deprecation.py
diff --git a/silx/utils/enum.py b/src/silx/utils/enum.py
index fece575..fece575 100644
--- a/silx/utils/enum.py
+++ b/src/silx/utils/enum.py
diff --git a/silx/utils/exceptions.py b/src/silx/utils/exceptions.py
index addba89..addba89 100644
--- a/silx/utils/exceptions.py
+++ b/src/silx/utils/exceptions.py
diff --git a/silx/utils/files.py b/src/silx/utils/files.py
index 1982c0d..1982c0d 100644
--- a/silx/utils/files.py
+++ b/src/silx/utils/files.py
diff --git a/src/silx/utils/html.py b/src/silx/utils/html.py
new file mode 100644
index 0000000..9b39b95
--- /dev/null
+++ b/src/silx/utils/html.py
@@ -0,0 +1,37 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "19/09/2016"
+
+from .deprecation import deprecated_warning
+
+deprecated_warning(type_='module',
+ name=__name__,
+ replacement='html',
+ since_version='0.15.0')
+
+from html import escape # noqa
diff --git a/silx/utils/include/silx_store_openmp.h b/src/silx/utils/include/silx_store_openmp.h
index f04f630..f04f630 100644
--- a/silx/utils/include/silx_store_openmp.h
+++ b/src/silx/utils/include/silx_store_openmp.h
diff --git a/silx/utils/launcher.py b/src/silx/utils/launcher.py
index c46256a..c46256a 100644
--- a/silx/utils/launcher.py
+++ b/src/silx/utils/launcher.py
diff --git a/silx/utils/number.py b/src/silx/utils/number.py
index f852a39..f852a39 100755
--- a/silx/utils/number.py
+++ b/src/silx/utils/number.py
diff --git a/silx/utils/property.py b/src/silx/utils/property.py
index 10d5d98..10d5d98 100644
--- a/silx/utils/property.py
+++ b/src/silx/utils/property.py
diff --git a/src/silx/utils/proxy.py b/src/silx/utils/proxy.py
new file mode 100644
index 0000000..d8821c2
--- /dev/null
+++ b/src/silx/utils/proxy.py
@@ -0,0 +1,208 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Module containing proxy objects"""
+
+from __future__ import absolute_import, print_function, division
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "02/10/2017"
+
+
+import functools
+
+
+class Proxy(object):
+ """Create a proxy of an object.
+
+ Provides default methods and property using :meth:`__getattr__` and special
+ method by redefining them one by one.
+ Special methods are defined as properties, as a result if the `obj` method
+ is not defined, the property code fail and the special method will not be
+ visible.
+ """
+
+ __slots__ = ["__obj", "__weakref__"]
+
+ def __init__(self, obj):
+ object.__setattr__(self, "_Proxy__obj", obj)
+
+ __class__ = property(lambda self: self.__obj.__class__)
+
+ def __getattr__(self, name):
+ return getattr(self.__obj, name)
+
+ __setattr__ = property(lambda self: self.__obj.__setattr__)
+ __delattr__ = property(lambda self: self.__obj.__delattr__)
+
+ # binary comparator methods
+
+ __lt__ = property(lambda self: self.__obj.__lt__)
+ __le__ = property(lambda self: self.__obj.__le__)
+ __eq__ = property(lambda self: self.__obj.__eq__)
+ __ne__ = property(lambda self: self.__obj.__ne__)
+ __gt__ = property(lambda self: self.__obj.__gt__)
+ __ge__ = property(lambda self: self.__obj.__ge__)
+
+ # binary numeric methods
+
+ __add__ = property(lambda self: self.__obj.__add__)
+ __radd__ = property(lambda self: self.__obj.__radd__)
+ __iadd__ = property(lambda self: self.__obj.__iadd__)
+ __sub__ = property(lambda self: self.__obj.__sub__)
+ __rsub__ = property(lambda self: self.__obj.__rsub__)
+ __isub__ = property(lambda self: self.__obj.__isub__)
+ __mul__ = property(lambda self: self.__obj.__mul__)
+ __rmul__ = property(lambda self: self.__obj.__rmul__)
+ __imul__ = property(lambda self: self.__obj.__imul__)
+
+ __truediv__ = property(lambda self: self.__obj.__truediv__)
+ __rtruediv__ = property(lambda self: self.__obj.__rtruediv__)
+ __itruediv__ = property(lambda self: self.__obj.__itruediv__)
+ __floordiv__ = property(lambda self: self.__obj.__floordiv__)
+ __rfloordiv__ = property(lambda self: self.__obj.__rfloordiv__)
+ __ifloordiv__ = property(lambda self: self.__obj.__ifloordiv__)
+ __mod__ = property(lambda self: self.__obj.__mod__)
+ __rmod__ = property(lambda self: self.__obj.__rmod__)
+ __imod__ = property(lambda self: self.__obj.__imod__)
+ __divmod__ = property(lambda self: self.__obj.__divmod__)
+ __rdivmod__ = property(lambda self: self.__obj.__rdivmod__)
+ __pow__ = property(lambda self: self.__obj.__pow__)
+ __rpow__ = property(lambda self: self.__obj.__rpow__)
+ __ipow__ = property(lambda self: self.__obj.__ipow__)
+ __lshift__ = property(lambda self: self.__obj.__lshift__)
+ __rlshift__ = property(lambda self: self.__obj.__rlshift__)
+ __ilshift__ = property(lambda self: self.__obj.__ilshift__)
+ __rshift__ = property(lambda self: self.__obj.__rshift__)
+ __rrshift__ = property(lambda self: self.__obj.__rrshift__)
+ __irshift__ = property(lambda self: self.__obj.__irshift__)
+
+ # binary logical methods
+
+ __and__ = property(lambda self: self.__obj.__and__)
+ __rand__ = property(lambda self: self.__obj.__rand__)
+ __iand__ = property(lambda self: self.__obj.__iand__)
+ __xor__ = property(lambda self: self.__obj.__xor__)
+ __rxor__ = property(lambda self: self.__obj.__rxor__)
+ __ixor__ = property(lambda self: self.__obj.__ixor__)
+ __or__ = property(lambda self: self.__obj.__or__)
+ __ror__ = property(lambda self: self.__obj.__ror__)
+ __ior__ = property(lambda self: self.__obj.__ior__)
+
+ # unary methods
+
+ __neg__ = property(lambda self: self.__obj.__neg__)
+ __pos__ = property(lambda self: self.__obj.__pos__)
+ __abs__ = property(lambda self: self.__obj.__abs__)
+ __invert__ = property(lambda self: self.__obj.__invert__)
+ __floor__ = property(lambda self: self.__obj.__floor__)
+ __ceil__ = property(lambda self: self.__obj.__ceil__)
+ __round__ = property(lambda self: self.__obj.__round__)
+
+ # cast
+
+ __repr__ = property(lambda self: self.__obj.__repr__)
+ __str__ = property(lambda self: self.__obj.__str__)
+ __complex__ = property(lambda self: self.__obj.__complex__)
+ __int__ = property(lambda self: self.__obj.__int__)
+ __float__ = property(lambda self: self.__obj.__float__)
+ __hash__ = property(lambda self: self.__obj.__hash__)
+ __bytes__ = property(lambda self: self.__obj.__bytes__)
+ __bool__ = property(lambda self: lambda: bool(self.__obj))
+ __format__ = property(lambda self: self.__obj.__format__)
+
+ # container
+
+ __len__ = property(lambda self: self.__obj.__len__)
+ __length_hint__ = property(lambda self: self.__obj.__length_hint__)
+ __getitem__ = property(lambda self: self.__obj.__getitem__)
+ __missing__ = property(lambda self: self.__obj.__missing__)
+ __setitem__ = property(lambda self: self.__obj.__setitem__)
+ __delitem__ = property(lambda self: self.__obj.__delitem__)
+ __iter__ = property(lambda self: self.__obj.__iter__)
+ __reversed__ = property(lambda self: self.__obj.__reversed__)
+ __contains__ = property(lambda self: self.__obj.__contains__)
+
+ # pickle
+
+ __reduce__ = property(lambda self: self.__obj.__reduce__)
+ __reduce_ex__ = property(lambda self: self.__obj.__reduce_ex__)
+
+ # async
+
+ __await__ = property(lambda self: self.__obj.__await__)
+ __aiter__ = property(lambda self: self.__obj.__aiter__)
+ __anext__ = property(lambda self: self.__obj.__anext__)
+ __aenter__ = property(lambda self: self.__obj.__aenter__)
+ __aexit__ = property(lambda self: self.__obj.__aexit__)
+
+ # other
+
+ __index__ = property(lambda self: self.__obj.__index__)
+
+ __next__ = property(lambda self: self.__obj.__next__)
+
+ __enter__ = property(lambda self: self.__obj.__enter__)
+ __exit__ = property(lambda self: self.__obj.__exit__)
+
+ __concat__ = property(lambda self: self.__obj.__concat__)
+ __iconcat__ = property(lambda self: self.__obj.__iconcat__)
+
+ __call__ = property(lambda self: self.__obj.__call__)
+
+
+def _docstring(dest, origin):
+ """Implementation of docstring decorator.
+
+ It patches dest.__doc__.
+ """
+ if not isinstance(dest, type) and isinstance(origin, type):
+ # func is not a class, but origin is, get the method with the same name
+ try:
+ origin = getattr(origin, dest.__name__)
+ except AttributeError:
+ raise ValueError(
+ "origin class has no %s method" % dest.__name__)
+
+ dest.__doc__ = origin.__doc__
+ return dest
+
+
+def docstring(origin):
+ """Decorator to initialize the docstring from another source.
+
+ This is useful to duplicate a docstring for inheritance and composition.
+
+ If origin is a method or a function, it copies its docstring.
+ If origin is a class, the docstring is copied from the method
+ of that class which has the same name as the method/function
+ being decorated.
+
+ :param origin:
+ The method, function or class from which to get the docstring
+ :raises ValueError:
+ If the origin class has not method n case the
+ """
+ return functools.partial(_docstring, origin=origin)
diff --git a/silx/utils/retry.py b/src/silx/utils/retry.py
index adc43bc..adc43bc 100644
--- a/silx/utils/retry.py
+++ b/src/silx/utils/retry.py
diff --git a/silx/utils/setup.py b/src/silx/utils/setup.py
index 1f3e09a..1f3e09a 100644
--- a/silx/utils/setup.py
+++ b/src/silx/utils/setup.py
diff --git a/src/silx/utils/test/__init__.py b/src/silx/utils/test/__init__.py
new file mode 100755
index 0000000..14fd940
--- /dev/null
+++ b/src/silx/utils/test/__init__.py
@@ -0,0 +1,24 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
diff --git a/src/silx/utils/test/test_array_like.py b/src/silx/utils/test/test_array_like.py
new file mode 100644
index 0000000..a0b4b7b
--- /dev/null
+++ b/src/silx/utils/test/test_array_like.py
@@ -0,0 +1,430 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for array_like module"""
+
+__authors__ = ["P. Knobel"]
+__license__ = "MIT"
+__date__ = "09/01/2017"
+
+import h5py
+import numpy
+import os
+import tempfile
+import unittest
+
+from ..array_like import DatasetView, ListOfImages
+from ..array_like import get_dtype, get_concatenated_dtype, get_shape,\
+ is_array, is_nested_sequence, is_list_of_arrays
+
+
+class TestTransposedDatasetView(unittest.TestCase):
+
+ def setUp(self):
+ # dataset attributes
+ self.ndim = 3
+ self.original_shape = (5, 10, 20)
+ self.size = 1
+ for dim in self.original_shape:
+ self.size *= dim
+
+ self.volume = numpy.arange(self.size).reshape(self.original_shape)
+
+ self.tempdir = tempfile.mkdtemp()
+ self.h5_fname = os.path.join(self.tempdir, "tempfile.h5")
+ with h5py.File(self.h5_fname, "w") as f:
+ f["volume"] = self.volume
+
+ self.h5f = h5py.File(self.h5_fname, "r")
+
+ self.all_permutations = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0),
+ (2, 0, 1), (2, 1, 0)]
+
+ def tearDown(self):
+ self.h5f.close()
+ os.unlink(self.h5_fname)
+ os.rmdir(self.tempdir)
+
+ def _testSize(self, obj):
+ """These assertions apply to all following test cases"""
+ self.assertEqual(obj.ndim, self.ndim)
+ self.assertEqual(obj.size, self.size)
+ size_from_shape = 1
+ for dim in obj.shape:
+ size_from_shape *= dim
+ self.assertEqual(size_from_shape, self.size)
+
+ for dim in self.original_shape:
+ self.assertIn(dim, obj.shape)
+
+ def testNoTransposition(self):
+ """no transposition (transposition = (0, 1, 2))"""
+ a = DatasetView(self.h5f["volume"])
+
+ self.assertEqual(a.shape, self.original_shape)
+ self._testSize(a)
+
+ # reversing the dimensions twice results in no change
+ rtrans = list(reversed(range(self.ndim)))
+ self.assertTrue(numpy.array_equal(
+ a,
+ a.transpose(rtrans).transpose(rtrans)))
+
+ for i in range(a.shape[0]):
+ for j in range(a.shape[1]):
+ for k in range(a.shape[2]):
+ self.assertEqual(self.h5f["volume"][i, j, k],
+ a[i, j, k])
+
+ def _testTransposition(self, transposition):
+ """test transposed dataset
+
+ :param tuple transposition: List of dimensions (0... n-1) sorted
+ in the desired order
+ """
+ a = DatasetView(self.h5f["volume"],
+ transposition=transposition)
+ self._testSize(a)
+
+ # sort shape of transposed object, to hopefully find the original shape
+ sorted_shape = tuple(dim_size for (_, dim_size) in
+ sorted(zip(transposition, a.shape)))
+ self.assertEqual(sorted_shape, self.original_shape)
+
+ a_as_array = numpy.array(self.h5f["volume"]).transpose(transposition)
+
+ # test the __array__ method
+ self.assertTrue(numpy.array_equal(
+ numpy.array(a),
+ a_as_array))
+
+ # test slicing
+ for selection in [(2, slice(None), slice(None)),
+ (slice(None), 1, slice(0, 8)),
+ (slice(0, 3), slice(None), 3),
+ (1, 3, slice(None)),
+ (slice(None), 2, 1),
+ (4, slice(1, 9, 2), 2)]:
+ self.assertIsInstance(a[selection], numpy.ndarray)
+ self.assertTrue(numpy.array_equal(
+ a[selection],
+ a_as_array[selection]))
+
+ # test the DatasetView.__getitem__ for single values
+ # (step adjusted to test at least 3 indices in each dimension)
+ for i in range(0, a.shape[0], a.shape[0] // 3):
+ for j in range(0, a.shape[1], a.shape[1] // 3):
+ for k in range(0, a.shape[2], a.shape[2] // 3):
+ sorted_indices = tuple(idx for (_, idx) in
+ sorted(zip(transposition, [i, j, k])))
+ viewed_value = a[i, j, k]
+ corresponding_original_value = self.h5f["volume"][sorted_indices]
+ self.assertEqual(viewed_value,
+ corresponding_original_value)
+
+ # reversing the dimensions twice results in no change
+ rtrans = list(reversed(range(self.ndim)))
+ self.assertTrue(numpy.array_equal(
+ a,
+ a.transpose(rtrans).transpose(rtrans)))
+
+ # test .T property
+ self.assertTrue(numpy.array_equal(
+ a.T,
+ a.transpose(rtrans)))
+
+ def testTransposition012(self):
+ """transposition = (0, 1, 2)
+ (should be the same as testNoTransposition)"""
+ self._testTransposition((0, 1, 2))
+
+ def testTransposition021(self):
+ """transposition = (0, 2, 1)"""
+ self._testTransposition((0, 2, 1))
+
+ def testTransposition102(self):
+ """transposition = (1, 0, 2)"""
+ self._testTransposition((1, 0, 2))
+
+ def testTransposition120(self):
+ """transposition = (1, 2, 0)"""
+ self._testTransposition((1, 2, 0))
+
+ def testTransposition201(self):
+ """transposition = (2, 0, 1)"""
+ self._testTransposition((2, 0, 1))
+
+ def testTransposition210(self):
+ """transposition = (2, 1, 0)"""
+ self._testTransposition((2, 1, 0))
+
+ def testAllDoubleTranspositions(self):
+ for trans1 in self.all_permutations:
+ for trans2 in self.all_permutations:
+ self._testDoubleTransposition(trans1, trans2)
+
+ def _testDoubleTransposition(self, transposition1, transposition2):
+ a = DatasetView(self.h5f["volume"],
+ transposition=transposition1).transpose(transposition2)
+
+ b = self.volume.transpose(transposition1).transpose(transposition2)
+
+ self.assertTrue(numpy.array_equal(a, b),
+ "failed with double transposition %s %s" % (transposition1, transposition2))
+
+ def test1DIndex(self):
+ a = DatasetView(self.h5f["volume"])
+ self.assertTrue(numpy.array_equal(self.volume[1],
+ a[1]))
+
+ b = DatasetView(self.h5f["volume"], transposition=(1, 0, 2))
+ self.assertTrue(numpy.array_equal(self.volume[:, 1, :],
+ b[1]))
+
+
+class TestTransposedListOfImages(unittest.TestCase):
+ def setUp(self):
+ # images attributes
+ self.ndim = 3
+ self.original_shape = (5, 10, 20)
+ self.size = 1
+ for dim in self.original_shape:
+ self.size *= dim
+
+ volume = numpy.arange(self.size).reshape(self.original_shape)
+
+ self.images = []
+ for i in range(self.original_shape[0]):
+ self.images.append(
+ volume[i])
+
+ self.images_as_3D_array = numpy.array(self.images)
+
+ self.all_permutations = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0),
+ (2, 0, 1), (2, 1, 0)]
+
+ def tearDown(self):
+ pass
+
+ def _testSize(self, obj):
+ """These assertions apply to all following test cases"""
+ self.assertEqual(obj.ndim, self.ndim)
+ self.assertEqual(obj.size, self.size)
+ size_from_shape = 1
+ for dim in obj.shape:
+ size_from_shape *= dim
+ self.assertEqual(size_from_shape, self.size)
+
+ for dim in self.original_shape:
+ self.assertIn(dim, obj.shape)
+
+ def testNoTransposition(self):
+ """no transposition (transposition = (0, 1, 2))"""
+ a = ListOfImages(self.images)
+
+ self.assertEqual(a.shape, self.original_shape)
+ self._testSize(a)
+
+ for i in range(a.shape[0]):
+ for j in range(a.shape[1]):
+ for k in range(a.shape[2]):
+ self.assertEqual(self.images[i][j, k],
+ a[i, j, k])
+
+ # reversing the dimensions twice results in no change
+ rtrans = list(reversed(range(self.ndim)))
+ self.assertTrue(numpy.array_equal(
+ a,
+ a.transpose(rtrans).transpose(rtrans)))
+
+ # test .T property
+ self.assertTrue(numpy.array_equal(
+ a.T,
+ a.transpose(rtrans)))
+
+ def _testTransposition(self, transposition):
+ """test transposed dataset
+
+ :param tuple transposition: List of dimensions (0... n-1) sorted
+ in the desired order
+ """
+ a = ListOfImages(self.images,
+ transposition=transposition)
+ self._testSize(a)
+
+ # sort shape of transposed object, to hopefully find the original shape
+ sorted_shape = tuple(dim_size for (_, dim_size) in
+ sorted(zip(transposition, a.shape)))
+ self.assertEqual(sorted_shape, self.original_shape)
+
+ a_as_array = numpy.array(self.images).transpose(transposition)
+
+ # test the DatasetView.__array__ method
+ self.assertTrue(numpy.array_equal(
+ numpy.array(a),
+ a_as_array))
+
+ # test slicing
+ for selection in [(2, slice(None), slice(None)),
+ (slice(None), 1, slice(0, 8)),
+ (slice(0, 3), slice(None), 3),
+ (1, 3, slice(None)),
+ (slice(None), 2, 1),
+ (4, slice(1, 9, 2), 2)]:
+ self.assertIsInstance(a[selection], numpy.ndarray)
+ self.assertTrue(numpy.array_equal(
+ a[selection],
+ a_as_array[selection]))
+
+ # test the DatasetView.__getitem__ for single values
+ # (step adjusted to test at least 3 indices in each dimension)
+ for i in range(0, a.shape[0], a.shape[0] // 3):
+ for j in range(0, a.shape[1], a.shape[1] // 3):
+ for k in range(0, a.shape[2], a.shape[2] // 3):
+ viewed_value = a[i, j, k]
+ sorted_indices = tuple(idx for (_, idx) in
+ sorted(zip(transposition, [i, j, k])))
+ corresponding_original_value = self.images[sorted_indices[0]][sorted_indices[1:]]
+ self.assertEqual(viewed_value,
+ corresponding_original_value)
+
+ # reversing the dimensions twice results in no change
+ rtrans = list(reversed(range(self.ndim)))
+ self.assertTrue(numpy.array_equal(
+ a,
+ a.transpose(rtrans).transpose(rtrans)))
+
+ # test .T property
+ self.assertTrue(numpy.array_equal(
+ a.T,
+ a.transpose(rtrans)))
+
+ def _testDoubleTransposition(self, transposition1, transposition2):
+ a = ListOfImages(self.images,
+ transposition=transposition1).transpose(transposition2)
+
+ b = self.images_as_3D_array.transpose(transposition1).transpose(transposition2)
+
+ self.assertTrue(numpy.array_equal(a, b),
+ "failed with double transposition %s %s" % (transposition1, transposition2))
+
+ def testTransposition012(self):
+ """transposition = (0, 1, 2)
+ (should be the same as testNoTransposition)"""
+ self._testTransposition((0, 1, 2))
+
+ def testTransposition021(self):
+ """transposition = (0, 2, 1)"""
+ self._testTransposition((0, 2, 1))
+
+ def testTransposition102(self):
+ """transposition = (1, 0, 2)"""
+ self._testTransposition((1, 0, 2))
+
+ def testTransposition120(self):
+ """transposition = (1, 2, 0)"""
+ self._testTransposition((1, 2, 0))
+
+ def testTransposition201(self):
+ """transposition = (2, 0, 1)"""
+ self._testTransposition((2, 0, 1))
+
+ def testTransposition210(self):
+ """transposition = (2, 1, 0)"""
+ self._testTransposition((2, 1, 0))
+
+ def testAllDoubleTranspositions(self):
+ for trans1 in self.all_permutations:
+ for trans2 in self.all_permutations:
+ self._testDoubleTransposition(trans1, trans2)
+
+ def test1DIndex(self):
+ a = ListOfImages(self.images)
+ self.assertTrue(numpy.array_equal(self.images[1],
+ a[1]))
+
+ b = ListOfImages(self.images, transposition=(1, 0, 2))
+ self.assertTrue(numpy.array_equal(self.images_as_3D_array[:, 1, :],
+ b[1]))
+
+
+class TestFunctions(unittest.TestCase):
+ """Test functions to guess the dtype and shape of an array_like
+ object"""
+ def testListOfLists(self):
+ l = [[0, 1, 2], [2, 3, 4]]
+ self.assertEqual(get_dtype(l),
+ numpy.dtype(int))
+ self.assertEqual(get_shape(l),
+ (2, 3))
+ self.assertTrue(is_nested_sequence(l))
+ self.assertFalse(is_array(l))
+ self.assertFalse(is_list_of_arrays(l))
+
+ l = [[0., 1.], [2., 3.]]
+ self.assertEqual(get_dtype(l),
+ numpy.dtype(float))
+ self.assertEqual(get_shape(l),
+ (2, 2))
+ self.assertTrue(is_nested_sequence(l))
+ self.assertFalse(is_array(l))
+ self.assertFalse(is_list_of_arrays(l))
+
+ # concatenated dtype of int and float
+ l = [numpy.array([[0, 1, 2], [2, 3, 4]]),
+ numpy.array([[0., 1., 2.], [2., 3., 4.]])]
+
+ self.assertEqual(get_concatenated_dtype(l),
+ numpy.array(l).dtype)
+ self.assertEqual(get_shape(l),
+ (2, 2, 3))
+ self.assertFalse(is_nested_sequence(l))
+ self.assertFalse(is_array(l))
+ self.assertTrue(is_list_of_arrays(l))
+
+ def testNumpyArray(self):
+ a = numpy.array([[0, 1], [2, 3]])
+ self.assertEqual(get_dtype(a),
+ a.dtype)
+ self.assertFalse(is_nested_sequence(a))
+ self.assertTrue(is_array(a))
+ self.assertFalse(is_list_of_arrays(a))
+
+ def testH5pyDataset(self):
+ a = numpy.array([[0, 1], [2, 3]])
+
+ tempdir = tempfile.mkdtemp()
+ h5_fname = os.path.join(tempdir, "tempfile.h5")
+ with h5py.File(h5_fname, "w") as h5f:
+ h5f["dataset"] = a
+ d = h5f["dataset"]
+
+ self.assertEqual(get_dtype(d),
+ numpy.dtype(int))
+ self.assertFalse(is_nested_sequence(d))
+ self.assertTrue(is_array(d))
+ self.assertFalse(is_list_of_arrays(d))
+
+ os.unlink(h5_fname)
+ os.rmdir(tempdir)
diff --git a/src/silx/utils/test/test_debug.py b/src/silx/utils/test/test_debug.py
new file mode 100644
index 0000000..09f4b01
--- /dev/null
+++ b/src/silx/utils/test/test_debug.py
@@ -0,0 +1,88 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for debug module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "27/02/2018"
+
+
+import unittest
+from silx.utils import debug
+from silx.utils import testutils
+
+
+@debug.log_all_methods
+class _Foobar(object):
+
+ def a(self):
+ return None
+
+ def b(self):
+ return self.a()
+
+ def random_args(self, *args, **kwargs):
+ return args, kwargs
+
+ def named_args(self, a, b):
+ return a + 1, b + 1
+
+
+class TestDebug(unittest.TestCase):
+ """Tests for debug module."""
+
+ def logB(self):
+ """
+ Can be used to check the log output using:
+ `./run_tests.py silx.utils.test.test_debug.TestDebug.logB -v`
+ """
+ print()
+ test = _Foobar()
+ test.b()
+
+ @testutils.validate_logging(debug.debug_logger.name, warning=2)
+ def testMethod(self):
+ test = _Foobar()
+ test.a()
+
+ @testutils.validate_logging(debug.debug_logger.name, warning=4)
+ def testInterleavedMethod(self):
+ test = _Foobar()
+ test.b()
+
+ @testutils.validate_logging(debug.debug_logger.name, warning=2)
+ def testNamedArgument(self):
+ # Arguments arre still provided to the patched method
+ test = _Foobar()
+ result = test.named_args(10, 11)
+ self.assertEqual(result, (11, 12))
+
+ @testutils.validate_logging(debug.debug_logger.name, warning=2)
+ def testRandomArguments(self):
+ # Arguments arre still provided to the patched method
+ test = _Foobar()
+ result = test.random_args("foo", 50, a=10, b=100)
+ self.assertEqual(result[0], ("foo", 50))
+ self.assertEqual(result[1], {"a": 10, "b": 100})
diff --git a/src/silx/utils/test/test_deprecation.py b/src/silx/utils/test/test_deprecation.py
new file mode 100644
index 0000000..d52cb26
--- /dev/null
+++ b/src/silx/utils/test/test_deprecation.py
@@ -0,0 +1,96 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for html module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import unittest
+from .. import deprecation
+from silx.utils import testutils
+
+
+class TestDeprecation(unittest.TestCase):
+ """Tests for deprecation module."""
+
+ @deprecation.deprecated
+ def deprecatedWithoutParam(self):
+ pass
+
+ @deprecation.deprecated(reason="r", replacement="r", since_version="v")
+ def deprecatedWithParams(self):
+ pass
+
+ @deprecation.deprecated(reason="r", replacement="r", since_version="v", only_once=True)
+ def deprecatedOnlyOnce(self):
+ pass
+
+ @deprecation.deprecated(reason="r", replacement="r", since_version="v", only_once=False)
+ def deprecatedEveryTime(self):
+ pass
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=1)
+ def testAnnotationWithoutParam(self):
+ self.deprecatedWithoutParam()
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=1)
+ def testAnnotationWithParams(self):
+ self.deprecatedWithParams()
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=3)
+ def testLoggedEveryTime(self):
+ """Logged everytime cause it is 3 different locations"""
+ self.deprecatedOnlyOnce()
+ self.deprecatedOnlyOnce()
+ self.deprecatedOnlyOnce()
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=1)
+ def testLoggedSingleTime(self):
+ def log():
+ self.deprecatedOnlyOnce()
+ log()
+ log()
+ log()
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=3)
+ def testLoggedEveryTime2(self):
+ self.deprecatedEveryTime()
+ self.deprecatedEveryTime()
+ self.deprecatedEveryTime()
+
+ @testutils.validate_logging(deprecation.depreclog.name, warning=1)
+ def testWarning(self):
+ deprecation.deprecated_warning(type_="t", name="n")
+
+ def testBacktrace(self):
+ loggingValidator = testutils.LoggingValidator(deprecation.depreclog.name)
+ with loggingValidator:
+ self.deprecatedEveryTime()
+ message = loggingValidator.records[0].getMessage()
+ filename = __file__.replace(".pyc", ".py")
+ self.assertTrue(filename in message)
+ self.assertTrue("testBacktrace" in message)
diff --git a/src/silx/utils/test/test_enum.py b/src/silx/utils/test/test_enum.py
new file mode 100644
index 0000000..808304a
--- /dev/null
+++ b/src/silx/utils/test/test_enum.py
@@ -0,0 +1,85 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests of Enum class with extra class methods"""
+
+from __future__ import absolute_import
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "29/04/2019"
+
+
+import sys
+import unittest
+
+import enum
+from silx.utils.enum import Enum
+
+
+class TestEnum(unittest.TestCase):
+ """Tests for enum module."""
+
+ def test(self):
+ """Test with Enum"""
+ class Success(Enum):
+ A = 1
+ B = 'B'
+ self._check_enum_content(Success)
+
+ @unittest.skipIf(sys.version_info.major <= 2, 'Python3 only')
+ def test(self):
+ """Test Enum with member redefinition"""
+ with self.assertRaises(TypeError):
+ class Failure(Enum):
+ A = 1
+ A = 'B'
+
+ def test_unique(self):
+ """Test with enum.unique"""
+ with self.assertRaises(ValueError):
+ @enum.unique
+ class Failure(Enum):
+ A = 1
+ B = 1
+
+ @enum.unique
+ class Success(Enum):
+ A = 1
+ B = 'B'
+ self._check_enum_content(Success)
+
+ def _check_enum_content(self, enum_):
+ """Check that the content of an enum is: <A: 1, B: 2>.
+
+ :param Enum enum_:
+ """
+ self.assertEqual(enum_.members(), (enum_.A, enum_.B))
+ self.assertEqual(enum_.names(), ('A', 'B'))
+ self.assertEqual(enum_.values(), (1, 'B'))
+
+ self.assertEqual(enum_.from_value(1), enum_.A)
+ self.assertEqual(enum_.from_value('B'), enum_.B)
+ with self.assertRaises(ValueError):
+ enum_.from_value(3)
diff --git a/src/silx/utils/test/test_external_resources.py b/src/silx/utils/test/test_external_resources.py
new file mode 100644
index 0000000..1fedda3
--- /dev/null
+++ b/src/silx/utils/test/test_external_resources.py
@@ -0,0 +1,89 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Test for resource files management."""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "08/03/2019"
+
+
+import os
+import unittest
+import shutil
+import socket
+import urllib.request
+import urllib.error
+
+from silx.utils.ExternalResources import ExternalResources
+
+
+def isSilxWebsiteAvailable():
+ try:
+ urllib.request.urlopen('http://www.silx.org', timeout=1)
+ return True
+ except urllib.error.URLError:
+ return False
+ except socket.timeout:
+ # This exception is still received in Python 2.7
+ return False
+
+
+class TestExternalResources(unittest.TestCase):
+ """This is a test for the ExternalResources"""
+
+ @classmethod
+ def setUpClass(cls):
+ if not isSilxWebsiteAvailable():
+ raise unittest.SkipTest("Network or silx website not available")
+
+ def setUp(self):
+ self.resources = ExternalResources("toto", "http://www.silx.org/pub/silx/")
+
+ def tearDown(self):
+ if self.resources.data_home:
+ shutil.rmtree(self.resources.data_home)
+ self.resources = None
+
+ def test_download(self):
+ "test the download from silx.org"
+ f = self.resources.getfile("lena.png")
+ self.assertTrue(os.path.exists(f))
+ di = self.resources.getdir("source.tar.gz")
+ for fi in di:
+ self.assertTrue(os.path.exists(fi))
+
+ def test_download_all(self):
+ "test the download of all files from silx.org"
+ filename = self.resources.getfile("lena.png")
+ directory = "source.tar.gz"
+ filelist = self.resources.getdir(directory)
+ # download file and remove it to create a json mapping file
+ os.remove(filename)
+ directory_path = os.path.commonprefix(filelist)
+ # Make sure we will rmtree a dangerous path like "/"
+ self.assertIn(self.resources.data_home, directory_path)
+ shutil.rmtree(directory_path)
+ filelist = self.resources.download_all()
+ self.assertGreater(len(filelist), 1, "At least 2 items were downloaded")
diff --git a/src/silx/utils/test/test_launcher.py b/src/silx/utils/test/test_launcher.py
new file mode 100644
index 0000000..9e9024c
--- /dev/null
+++ b/src/silx/utils/test/test_launcher.py
@@ -0,0 +1,191 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for html module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "17/01/2018"
+
+
+import sys
+import unittest
+from silx.utils.testutils import ParametricTestCase
+from .. import launcher
+
+
+class CallbackMock():
+
+ def __init__(self, result=None):
+ self._execute_count = 0
+ self._execute_argv = None
+ self._result = result
+
+ def execute(self, argv):
+ self._execute_count = self._execute_count + 1
+ self._execute_argv = argv
+ return self._result
+
+ def __call__(self, argv):
+ return self.execute(argv)
+
+
+class TestLauncherCommand(unittest.TestCase):
+ """Tests for launcher class."""
+
+ def testEnv(self):
+ command = launcher.LauncherCommand("foo")
+ old = sys.argv
+ params = ["foo", "bar"]
+ with command.get_env(params):
+ self.assertEqual(params, sys.argv)
+ self.assertEqual(sys.argv, old)
+
+ def testEnvWhileException(self):
+ command = launcher.LauncherCommand("foo")
+ old = sys.argv
+ params = ["foo", "bar"]
+ try:
+ with command.get_env(params):
+ raise RuntimeError()
+ except RuntimeError:
+ pass
+ self.assertEqual(sys.argv, old)
+
+ def testExecute(self):
+ params = ["foo", "bar"]
+ callback = CallbackMock(result=42)
+ command = launcher.LauncherCommand("foo", function=callback)
+ status = command.execute(params)
+ self.assertEqual(callback._execute_count, 1)
+ self.assertEqual(callback._execute_argv, params)
+ self.assertEqual(status, 42)
+
+
+class TestModuleCommand(ParametricTestCase):
+
+ def setUp(self):
+ module_name = "silx.utils.test.test_launcher_command"
+ command = launcher.LauncherCommand("foo", module_name=module_name)
+ self.command = command
+
+ def testHelp(self):
+ status = self.command.execute(["--help"])
+ self.assertEqual(status, 0)
+
+ def testException(self):
+ try:
+ self.command.execute(["exception"])
+ self.fail()
+ except RuntimeError:
+ pass
+
+ def testCall(self):
+ status = self.command.execute([])
+ self.assertEqual(status, 0)
+
+ def testError(self):
+ status = self.command.execute(["error"])
+ self.assertEqual(status, -1)
+
+
+class TestLauncher(ParametricTestCase):
+ """Tests for launcher class."""
+
+ def testCallCommand(self):
+ callback = CallbackMock(result=42)
+ runner = launcher.Launcher(prog="prog")
+ command = launcher.LauncherCommand("foo", function=callback)
+ runner.add_command(command=command)
+ status = runner.execute(["prog", "foo", "param1", "param2"])
+ self.assertEqual(status, 42)
+ self.assertEqual(callback._execute_argv, ["prog foo", "param1", "param2"])
+ self.assertEqual(callback._execute_count, 1)
+
+ def testAddCommand(self):
+ runner = launcher.Launcher(prog="prog")
+ module_name = "silx.utils.test.test_launcher_command"
+ runner.add_command("foo", module_name=module_name)
+ status = runner.execute(["prog", "foo"])
+ self.assertEqual(status, 0)
+
+ def testCallHelpOnCommand(self):
+ callback = CallbackMock(result=42)
+ runner = launcher.Launcher(prog="prog")
+ command = launcher.LauncherCommand("foo", function=callback)
+ runner.add_command(command=command)
+ status = runner.execute(["prog", "--help", "foo"])
+ self.assertEqual(status, 42)
+ self.assertEqual(callback._execute_argv, ["prog foo", "--help"])
+ self.assertEqual(callback._execute_count, 1)
+
+ def testCallHelpOnCommand2(self):
+ callback = CallbackMock(result=42)
+ runner = launcher.Launcher(prog="prog")
+ command = launcher.LauncherCommand("foo", function=callback)
+ runner.add_command(command=command)
+ status = runner.execute(["prog", "help", "foo"])
+ self.assertEqual(status, 42)
+ self.assertEqual(callback._execute_argv, ["prog foo", "--help"])
+ self.assertEqual(callback._execute_count, 1)
+
+ def testCallHelpOnUnknownCommand(self):
+ callback = CallbackMock(result=42)
+ runner = launcher.Launcher(prog="prog")
+ command = launcher.LauncherCommand("foo", function=callback)
+ runner.add_command(command=command)
+ status = runner.execute(["prog", "help", "foo2"])
+ self.assertEqual(status, -1)
+
+ def testNotAvailableCommand(self):
+ callback = CallbackMock(result=42)
+ runner = launcher.Launcher(prog="prog")
+ command = launcher.LauncherCommand("foo", function=callback)
+ runner.add_command(command=command)
+ status = runner.execute(["prog", "foo2", "param1", "param2"])
+ self.assertEqual(status, -1)
+ self.assertEqual(callback._execute_count, 0)
+
+ def testCallHelp(self):
+ callback = CallbackMock(result=42)
+ runner = launcher.Launcher(prog="prog")
+ command = launcher.LauncherCommand("foo", function=callback)
+ runner.add_command(command=command)
+ status = runner.execute(["prog", "help"])
+ self.assertEqual(status, 0)
+ self.assertEqual(callback._execute_count, 0)
+
+ def testCommonCommands(self):
+ runner = launcher.Launcher()
+ tests = [
+ ["prog"],
+ ["prog", "--help"],
+ ["prog", "--version"],
+ ["prog", "help", "--help"],
+ ["prog", "help", "help"],
+ ]
+ for arguments in tests:
+ with self.subTest(args=tests):
+ status = runner.execute(arguments)
+ self.assertEqual(status, 0)
diff --git a/silx/utils/test/test_launcher_command.py b/src/silx/utils/test/test_launcher_command.py
index ccf4601..ccf4601 100644
--- a/silx/utils/test/test_launcher_command.py
+++ b/src/silx/utils/test/test_launcher_command.py
diff --git a/src/silx/utils/test/test_number.py b/src/silx/utils/test/test_number.py
new file mode 100644
index 0000000..3eb8e91
--- /dev/null
+++ b/src/silx/utils/test/test_number.py
@@ -0,0 +1,175 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for silx.uitls.number module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "05/06/2018"
+
+import logging
+import numpy
+import unittest
+import pkg_resources
+from silx.utils import number
+from silx.utils import testutils
+
+_logger = logging.getLogger(__name__)
+
+
+class TestConversionTypes(testutils.ParametricTestCase):
+
+ def testEmptyFail(self):
+ self.assertRaises(ValueError, number.min_numerical_convertible_type, "")
+
+ def testStringFail(self):
+ self.assertRaises(ValueError, number.min_numerical_convertible_type, "a")
+
+ def testInteger(self):
+ dtype = number.min_numerical_convertible_type("1456")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.unsignedinteger))
+
+ def testTrailledInteger(self):
+ dtype = number.min_numerical_convertible_type(" \t\n\r1456\t\n\r")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.unsignedinteger))
+
+ def testPositiveInteger(self):
+ dtype = number.min_numerical_convertible_type("+1456")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.unsignedinteger))
+
+ def testNegativeInteger(self):
+ dtype = number.min_numerical_convertible_type("-1456")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.signedinteger))
+
+ def testIntegerExponential(self):
+ dtype = number.min_numerical_convertible_type("14e10")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
+
+ def testIntegerPositiveExponential(self):
+ dtype = number.min_numerical_convertible_type("14e+10")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
+
+ def testIntegerNegativeExponential(self):
+ dtype = number.min_numerical_convertible_type("14e-10")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
+
+ def testNumberDecimal(self):
+ dtype = number.min_numerical_convertible_type("14.5")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
+
+ def testPositiveNumberDecimal(self):
+ dtype = number.min_numerical_convertible_type("+14.5")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
+
+ def testNegativeNumberDecimal(self):
+ dtype = number.min_numerical_convertible_type("-14.5")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
+
+ def testDecimal(self):
+ dtype = number.min_numerical_convertible_type(".50")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
+
+ def testPositiveDecimal(self):
+ dtype = number.min_numerical_convertible_type("+.5")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
+
+ def testNegativeDecimal(self):
+ dtype = number.min_numerical_convertible_type("-.5")
+ self.assertTrue(numpy.issubdtype(dtype, numpy.floating))
+
+ def testMantissa16(self):
+ dtype = number.min_numerical_convertible_type("1.50")
+ self.assertEqual(dtype, numpy.float16)
+
+ def testFloat32(self):
+ dtype = number.min_numerical_convertible_type("-23.172")
+ self.assertEqual(dtype, numpy.float32)
+
+ def testMantissa32(self):
+ dtype = number.min_numerical_convertible_type("1400.50")
+ self.assertEqual(dtype, numpy.float32)
+
+ def testMantissa64(self):
+ dtype = number.min_numerical_convertible_type("10000.000010")
+ self.assertEqual(dtype, numpy.float64)
+
+ def testMantissa80(self):
+ self.skipIfFloat80NotSupported()
+ dtype = number.min_numerical_convertible_type("1000000000.00001013")
+
+ if pkg_resources.parse_version(numpy.version.version) <= pkg_resources.parse_version("1.10.4"):
+ # numpy 1.8.2 -> Debian 8
+ # Checking a float128 precision with numpy 1.8.2 using abs(diff) is not working.
+ # It looks like the difference is done using float64 (diff == 0.0)
+ expected = (numpy.longdouble, numpy.float64)
+ else:
+ expected = (numpy.longdouble, )
+ self.assertIn(dtype, expected)
+
+ def testExponent32(self):
+ dtype = number.min_numerical_convertible_type("14.0e30")
+ self.assertEqual(dtype, numpy.float32)
+
+ def testExponent64(self):
+ dtype = number.min_numerical_convertible_type("14.0e300")
+ self.assertEqual(dtype, numpy.float64)
+
+ def testExponent80(self):
+ self.skipIfFloat80NotSupported()
+ dtype = number.min_numerical_convertible_type("14.0e3000")
+ self.assertEqual(dtype, numpy.longdouble)
+
+ def testFloat32ToString(self):
+ value = str(numpy.float32(numpy.pi))
+ dtype = number.min_numerical_convertible_type(value)
+ self.assertIn(dtype, (numpy.float32, numpy.float64))
+
+ def skipIfFloat80NotSupported(self):
+ if number.is_longdouble_64bits():
+ self.skipTest("float-80bits not supported")
+
+ def testLosePrecisionUsingFloat80(self):
+ self.skipIfFloat80NotSupported()
+ if pkg_resources.parse_version(numpy.version.version) <= pkg_resources.parse_version("1.10.4"):
+ self.skipTest("numpy > 1.10.4 expected")
+ # value does not fit even in a 128 bits mantissa
+ value = "1.0340282366920938463463374607431768211456"
+ func = testutils.validate_logging(number._logger.name, warning=1)
+ func = func(number.min_numerical_convertible_type)
+ dtype = func(value)
+ self.assertIn(dtype, (numpy.longdouble, ))
+
+ def testMillisecondEpochTime(self):
+ datetimes = ['1465803236.495412',
+ '1465803236.999362',
+ '1465803237.504311',
+ '1465803238.009261',
+ '1465803238.512211',
+ '1465803239.016160',
+ '1465803239.520110',
+ '1465803240.026059',
+ '1465803240.529009']
+ for datetime in datetimes:
+ with self.subTest(datetime=datetime):
+ dtype = number.min_numerical_convertible_type(datetime)
+ self.assertEqual(dtype, numpy.float64)
diff --git a/src/silx/utils/test/test_proxy.py b/src/silx/utils/test/test_proxy.py
new file mode 100644
index 0000000..e165267
--- /dev/null
+++ b/src/silx/utils/test/test_proxy.py
@@ -0,0 +1,330 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for weakref module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "02/10/2017"
+
+
+import unittest
+import pickle
+import numpy
+from silx.utils.proxy import Proxy, docstring
+
+
+class Thing(object):
+
+ def __init__(self, value):
+ self.value = value
+
+ def __getitem__(self, selection):
+ return selection + 1
+
+ def method(self, value):
+ return value + 2
+
+
+class InheritedProxy(Proxy):
+ """Inheriting the proxy allow to specialisze methods"""
+
+ def __init__(self, obj, value):
+ Proxy.__init__(self, obj)
+ self.value = value + 2
+
+ def __getitem__(self, selection):
+ return selection + 3
+
+ def method(self, value):
+ return value + 4
+
+
+class TestProxy(unittest.TestCase):
+ """Test that the proxy behave as expected"""
+
+ def text_init(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ self.assertTrue(isinstance(p, Thing))
+ self.assertTrue(isinstance(p, Proxy))
+
+ # methods and properties
+
+ def test_has_special_method(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ self.assertTrue(hasattr(p, "__getitem__"))
+
+ def test_missing_special_method(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ self.assertFalse(hasattr(p, "__and__"))
+
+ def test_method(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ self.assertEqual(p.method(10), obj.method(10))
+
+ def test_property(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ self.assertEqual(p.value, obj.value)
+
+ # special functions
+
+ def test_getitem(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ self.assertEqual(p[10], obj[10])
+
+ def test_setitem(self):
+ obj = numpy.array([10, 20, 30])
+ p = Proxy(obj)
+ p[0] = 20
+ self.assertEqual(obj[0], 20)
+
+ def test_slice(self):
+ obj = numpy.arange(20)
+ p = Proxy(obj)
+ expected = obj[0:10:2]
+ result = p[0:10:2]
+ self.assertEqual(list(result), list(expected))
+
+ # binary comparator methods
+
+ def test_lt(self):
+ obj = numpy.array([20])
+ p = Proxy(obj)
+ expected = obj < obj
+ result = p < p
+ self.assertEqual(result, expected)
+
+ # binary numeric methods
+
+ def test_add(self):
+ obj = numpy.array([20])
+ proxy = Proxy(obj)
+ expected = obj + obj
+ result = proxy + proxy
+ self.assertEqual(result, expected)
+
+ def test_iadd(self):
+ expected = numpy.array([20])
+ expected += 10
+ obj = numpy.array([20])
+ result = Proxy(obj)
+ result += 10
+ self.assertEqual(result, expected)
+
+ def test_radd(self):
+ obj = numpy.array([20])
+ p = Proxy(obj)
+ expected = 10 + obj
+ result = 10 + p
+ self.assertEqual(result, expected)
+
+ # binary logical methods
+
+ def test_and(self):
+ obj = numpy.array([20])
+ p = Proxy(obj)
+ expected = obj & obj
+ result = p & p
+ self.assertEqual(result, expected)
+
+ def test_iand(self):
+ expected = numpy.array([20])
+ expected &= 10
+ obj = numpy.array([20])
+ result = Proxy(obj)
+ result &= 10
+ self.assertEqual(result, expected)
+
+ def test_rand(self):
+ obj = numpy.array([20])
+ p = Proxy(obj)
+ expected = 10 & obj
+ result = 10 & p
+ self.assertEqual(result, expected)
+
+ # unary methods
+
+ def test_neg(self):
+ obj = numpy.array([20])
+ p = Proxy(obj)
+ expected = -obj
+ result = -p
+ self.assertEqual(result, expected)
+
+ def test_round(self):
+ obj = 20.5
+ p = Proxy(obj)
+ expected = round(obj)
+ result = round(p)
+ self.assertEqual(result, expected)
+
+ # cast
+
+ def test_bool(self):
+ obj = True
+ p = Proxy(obj)
+ if p:
+ pass
+ else:
+ self.fail()
+
+ def test_str(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ expected = str(obj)
+ result = str(p)
+ self.assertEqual(result, expected)
+
+ def test_repr(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ expected = repr(obj)
+ result = repr(p)
+ self.assertEqual(result, expected)
+
+ def test_text_bool(self):
+ obj = ""
+ p = Proxy(obj)
+ if p:
+ self.fail()
+ else:
+ pass
+
+ def test_text_str(self):
+ obj = "a"
+ p = Proxy(obj)
+ expected = str(obj)
+ result = str(p)
+ self.assertEqual(result, expected)
+
+ def test_text_repr(self):
+ obj = "a"
+ p = Proxy(obj)
+ expected = repr(obj)
+ result = repr(p)
+ self.assertEqual(result, expected)
+
+ def test_hash(self):
+ obj = [0, 1, 2]
+ p = Proxy(obj)
+ with self.assertRaises(TypeError):
+ hash(p)
+ obj = (0, 1, 2)
+ p = Proxy(obj)
+ hash(p)
+
+
+class TestInheritedProxy(unittest.TestCase):
+ """Test that inheriting the Proxy class behave as expected"""
+
+ # methods and properties
+
+ def test_method(self):
+ obj = Thing(10)
+ p = InheritedProxy(obj, 11)
+ self.assertEqual(p.method(10), 11 + 3)
+
+ def test_property(self):
+ obj = Thing(10)
+ p = InheritedProxy(obj, 11)
+ self.assertEqual(p.value, 11 + 2)
+
+ # special functions
+
+ def test_getitem(self):
+ obj = Thing(10)
+ p = InheritedProxy(obj, 11)
+ self.assertEqual(p[12], 12 + 3)
+
+
+class TestPickle(unittest.TestCase):
+
+ def test_dumps(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ expected = pickle.dumps(obj)
+ result = pickle.dumps(p)
+ self.assertEqual(result, expected)
+
+ def test_loads(self):
+ obj = Thing(10)
+ p = Proxy(obj)
+ obj2 = pickle.loads(pickle.dumps(p))
+ self.assertTrue(isinstance(obj2, Thing))
+ self.assertFalse(isinstance(obj2, Proxy))
+ self.assertEqual(obj.value, obj2.value)
+
+
+class TestDocstring(unittest.TestCase):
+ """Test docstring decorator"""
+
+ class Base(object):
+ def method(self):
+ """Docstring"""
+ pass
+
+ def test_inheritance(self):
+ class Derived(TestDocstring.Base):
+ @docstring(TestDocstring.Base)
+ def method(self):
+ pass
+
+ self.assertEqual(Derived.method.__doc__,
+ TestDocstring.Base.method.__doc__)
+
+ def test_composition(self):
+ class Composed(object):
+ def __init__(self):
+ self._base = TestDocstring.Base()
+
+ @docstring(TestDocstring.Base)
+ def method(self):
+ return self._base.method()
+
+ @docstring(TestDocstring.Base.method)
+ def renamed(self):
+ return self._base.method()
+
+ self.assertEqual(Composed.method.__doc__,
+ TestDocstring.Base.method.__doc__)
+
+ self.assertEqual(Composed.renamed.__doc__,
+ TestDocstring.Base.method.__doc__)
+
+ def test_function(self):
+ def f():
+ """Docstring"""
+ pass
+
+ @docstring(f)
+ def g():
+ pass
+
+ self.assertEqual(f.__doc__, g.__doc__)
diff --git a/src/silx/utils/test/test_retry.py b/src/silx/utils/test/test_retry.py
new file mode 100644
index 0000000..39bfdcf
--- /dev/null
+++ b/src/silx/utils/test/test_retry.py
@@ -0,0 +1,169 @@
+# coding: utf-8
+# /*##########################################################################
+# Copyright (C) 2016-2017 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ############################################################################*/
+"""Tests for retry utilities"""
+
+__authors__ = ["W. de Nolf"]
+__license__ = "MIT"
+__date__ = "05/02/2020"
+
+
+import unittest
+import os
+import sys
+import time
+import tempfile
+
+from .. import retry
+
+
+def _cause_segfault():
+ import ctypes
+
+ i = ctypes.c_char(b"a")
+ j = ctypes.pointer(i)
+ c = 0
+ while True:
+ j[c] = b"a"
+ c += 1
+
+
+def _submain(filename, kwcheck=None, ncausefailure=0, faildelay=0):
+ assert filename
+ assert kwcheck
+ sys.stderr = open(os.devnull, "w")
+
+ with open(filename, mode="r") as f:
+ failcounter = int(f.readline().strip())
+
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ with open(filename, mode="w") as f:
+ f.write(str(failcounter))
+ if failcounter % 2:
+ raise retry.RetryError
+ else:
+ _cause_segfault()
+ return True
+
+
+_wsubmain = retry.retry_in_subprocess()(_submain)
+
+
+class TestRetry(unittest.TestCase):
+ def setUp(self):
+ self.test_dir = tempfile.mkdtemp()
+ self.ctr_file = os.path.join(self.test_dir, "failcounter.txt")
+
+ def tearDown(self):
+ if os.path.exists(self.ctr_file):
+ os.unlink(self.ctr_file)
+ os.rmdir(self.test_dir)
+
+ def test_retry(self):
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ @retry.retry()
+ def method(check, kwcheck=None):
+ assert check
+ assert kwcheck
+ nonlocal failcounter
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ raise retry.RetryError
+ return True
+
+ failcounter = 0
+ kw = {
+ "kwcheck": True,
+ "retry_timeout": sufficient_timeout,
+ }
+ self.assertTrue(method(True, **kw))
+
+ failcounter = 0
+ kw = {
+ "kwcheck": True,
+ "retry_timeout": insufficient_timeout,
+ }
+ with self.assertRaises(retry.RetryTimeoutError):
+ method(True, **kw)
+
+ def test_retry_contextmanager(self):
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ @retry.retry_contextmanager()
+ def context(check, kwcheck=None):
+ assert check
+ assert kwcheck
+ nonlocal failcounter
+ if failcounter < ncausefailure:
+ time.sleep(faildelay)
+ failcounter += 1
+ raise retry.RetryError
+ yield True
+
+ failcounter = 0
+ kw = {"kwcheck": True, "retry_timeout": sufficient_timeout}
+ with context(True, **kw) as result:
+ self.assertTrue(result)
+
+ failcounter = 0
+ kw = {"kwcheck": True, "retry_timeout": insufficient_timeout}
+ with self.assertRaises(retry.RetryTimeoutError):
+ with context(True, **kw) as result:
+ pass
+
+ def test_retry_in_subprocess(self):
+ ncausefailure = 3
+ faildelay = 0.1
+ sufficient_timeout = ncausefailure * (faildelay + 10)
+ insufficient_timeout = ncausefailure * faildelay * 0.5
+
+ kw = {
+ "ncausefailure": ncausefailure,
+ "faildelay": faildelay,
+ "kwcheck": True,
+ "retry_timeout": sufficient_timeout,
+ }
+ with open(self.ctr_file, mode="w") as f:
+ f.write("0")
+ self.assertTrue(_wsubmain(self.ctr_file, **kw))
+
+ kw = {
+ "ncausefailure": ncausefailure,
+ "faildelay": faildelay,
+ "kwcheck": True,
+ "retry_timeout": insufficient_timeout,
+ }
+ with open(self.ctr_file, mode="w") as f:
+ f.write("0")
+ with self.assertRaises(retry.RetryTimeoutError):
+ _wsubmain(self.ctr_file, **kw)
diff --git a/src/silx/utils/test/test_testutils.py b/src/silx/utils/test/test_testutils.py
new file mode 100755
index 0000000..4f07c4e
--- /dev/null
+++ b/src/silx/utils/test/test_testutils.py
@@ -0,0 +1,94 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for testutils module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "18/11/2019"
+
+
+import unittest
+import logging
+from .. import testutils
+
+
+class TestLoggingValidator(unittest.TestCase):
+ """Tests for LoggingValidator"""
+
+ def testRight(self):
+ logger = logging.getLogger(__name__ + "testRight")
+ listener = testutils.LoggingValidator(logger, error=1)
+ with listener:
+ logger.error("expected")
+ logger.info("ignored")
+
+ def testCustomLevel(self):
+ logger = logging.getLogger(__name__ + "testCustomLevel")
+ listener = testutils.LoggingValidator(logger, error=1)
+ with listener:
+ logger.error("expected")
+ logger.log(666, "custom level have to be ignored")
+
+ def testWrong(self):
+ logger = logging.getLogger(__name__ + "testWrong")
+ listener = testutils.LoggingValidator(logger, error=1)
+ with self.assertRaises(RuntimeError):
+ with listener:
+ logger.error("expected")
+ logger.error("not expected")
+
+ def testManyErrors(self):
+ logger = logging.getLogger(__name__ + "testManyErrors")
+ listener = testutils.LoggingValidator(logger, error=1, warning=2)
+ with self.assertRaises(RuntimeError):
+ with listener:
+ pass
+
+ def testCanBeChecked(self):
+ logger = logging.getLogger(__name__ + "testCanBreak")
+ listener = testutils.LoggingValidator(logger, error=1, warning=2)
+ with self.assertRaises(RuntimeError):
+ with listener:
+ logger.error("aaa")
+ logger.warning("aaa")
+ self.assertFalse(listener.can_be_checked())
+ logger.error("aaa")
+ # Here we know that it's already wrong without a big cost
+ self.assertTrue(listener.can_be_checked())
+
+ def testWithAs(self):
+ logger = logging.getLogger(__name__ + "testCanBreak")
+ with testutils.LoggingValidator(logger) as listener:
+ logger.error("aaa")
+ self.assertIsNotNone(listener)
+
+ def testErrorMessage(self):
+ logger = logging.getLogger(__name__ + "testCanBreak")
+ listener = testutils.LoggingValidator(logger, error=1, warning=2)
+ with self.assertRaisesRegex(RuntimeError, "aaabbb"):
+ with listener:
+ logger.error("aaa")
+ logger.warning("aaabbb")
+ logger.error("aaa")
diff --git a/src/silx/utils/test/test_weakref.py b/src/silx/utils/test/test_weakref.py
new file mode 100644
index 0000000..06f3adf
--- /dev/null
+++ b/src/silx/utils/test/test_weakref.py
@@ -0,0 +1,315 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2019 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Tests for weakref module"""
+
+__authors__ = ["V. Valls"]
+__license__ = "MIT"
+__date__ = "15/09/2016"
+
+
+import unittest
+from .. import weakref
+
+
+class Dummy(object):
+ """Dummy class to use it as geanie pig"""
+ def inc(self, a):
+ return a + 1
+
+ def __lt__(self, other):
+ return True
+
+
+def dummy_inc(a):
+ """Dummy function to use it as geanie pig"""
+ return a + 1
+
+
+class TestWeakMethod(unittest.TestCase):
+ """Tests for weakref.WeakMethod"""
+
+ def testMethod(self):
+ dummy = Dummy()
+ callable_ = weakref.WeakMethod(dummy.inc)
+ self.assertEqual(callable_()(10), 11)
+
+ def testMethodWithDeadObject(self):
+ dummy = Dummy()
+ callable_ = weakref.WeakMethod(dummy.inc)
+ dummy = None
+ self.assertIsNone(callable_())
+
+ def testMethodWithDeadFunction(self):
+ dummy = Dummy()
+ dummy.inc2 = lambda self, a: a + 1
+ callable_ = weakref.WeakMethod(dummy.inc2)
+ dummy.inc2 = None
+ self.assertIsNone(callable_())
+
+ def testFunction(self):
+ callable_ = weakref.WeakMethod(dummy_inc)
+ self.assertEqual(callable_()(10), 11)
+
+ def testDeadFunction(self):
+ def inc(a):
+ return a + 1
+ callable_ = weakref.WeakMethod(inc)
+ inc = None
+ self.assertIsNone(callable_())
+
+ def testLambda(self):
+ store = lambda a: a + 1 # noqa: E731
+ callable_ = weakref.WeakMethod(store)
+ self.assertEqual(callable_()(10), 11)
+
+ def testDeadLambda(self):
+ callable_ = weakref.WeakMethod(lambda a: a + 1)
+ self.assertIsNone(callable_())
+
+ def testCallbackOnDeadObject(self):
+ self.__count = 0
+
+ def callback(ref):
+ self.__count += 1
+ self.assertIs(callable_, ref)
+ dummy = Dummy()
+ callable_ = weakref.WeakMethod(dummy.inc, callback)
+ dummy = None
+ self.assertEqual(self.__count, 1)
+
+ def testCallbackOnDeadMethod(self):
+ self.__count = 0
+
+ def callback(ref):
+ self.__count += 1
+ self.assertIs(callable_, ref)
+ dummy = Dummy()
+ dummy.inc2 = lambda self, a: a + 1
+ callable_ = weakref.WeakMethod(dummy.inc2, callback)
+ dummy.inc2 = None
+ self.assertEqual(self.__count, 1)
+
+ def testCallbackOnDeadFunction(self):
+ self.__count = 0
+
+ def callback(ref):
+ self.__count += 1
+ self.assertIs(callable_, ref)
+ store = lambda a: a + 1 # noqa: E731
+ callable_ = weakref.WeakMethod(store, callback)
+ store = None
+ self.assertEqual(self.__count, 1)
+
+ def testEquals(self):
+ dummy = Dummy()
+ callable1 = weakref.WeakMethod(dummy.inc)
+ callable2 = weakref.WeakMethod(dummy.inc)
+ self.assertEqual(callable1, callable2)
+
+ def testInSet(self):
+ callable_set = set([])
+ dummy = Dummy()
+ callable_set.add(weakref.WeakMethod(dummy.inc))
+ callable_ = weakref.WeakMethod(dummy.inc)
+ self.assertIn(callable_, callable_set)
+
+ def testInDict(self):
+ callable_dict = {}
+ dummy = Dummy()
+ callable_dict[weakref.WeakMethod(dummy.inc)] = 10
+ callable_ = weakref.WeakMethod(dummy.inc)
+ self.assertEqual(callable_dict.get(callable_), 10)
+
+
+class TestWeakMethodProxy(unittest.TestCase):
+
+ def testMethod(self):
+ dummy = Dummy()
+ callable_ = weakref.WeakMethodProxy(dummy.inc)
+ self.assertEqual(callable_(10), 11)
+
+ def testMethodWithDeadObject(self):
+ dummy = Dummy()
+ method = weakref.WeakMethodProxy(dummy.inc)
+ dummy = None
+ self.assertRaises(ReferenceError, method, 9)
+
+
+class TestWeakList(unittest.TestCase):
+ """Tests for weakref.WeakList"""
+
+ def setUp(self):
+ self.list = weakref.WeakList()
+ self.object1 = Dummy()
+ self.object2 = Dummy()
+ self.list.append(self.object1)
+ self.list.append(self.object2)
+
+ def testAppend(self):
+ obj = Dummy()
+ self.list.append(obj)
+ self.assertEqual(len(self.list), 3)
+ obj = None
+ self.assertEqual(len(self.list), 2)
+
+ def testRemove(self):
+ self.list.remove(self.object1)
+ self.assertEqual(len(self.list), 1)
+
+ def testPop(self):
+ obj = self.list.pop(0)
+ self.assertIs(obj, self.object1)
+ self.assertEqual(len(self.list), 1)
+
+ def testGetItem(self):
+ self.assertIs(self.object1, self.list[0])
+
+ def testGetItemSlice(self):
+ objects = self.list[:]
+ self.assertEqual(len(objects), 2)
+ self.assertIs(self.object1, objects[0])
+ self.assertIs(self.object2, objects[1])
+
+ def testIter(self):
+ obj_list = list(self.list)
+ self.assertEqual(len(obj_list), 2)
+ self.assertIs(self.object1, obj_list[0])
+
+ def testLen(self):
+ self.assertEqual(len(self.list), 2)
+
+ def testSetItem(self):
+ obj = Dummy()
+ self.list[0] = obj
+ self.assertIsNot(self.object1, self.list[0])
+ obj = None
+ self.assertEqual(len(self.list), 1)
+
+ def testSetItemSlice(self):
+ obj = Dummy()
+ self.list[:] = [obj, obj]
+ self.assertEqual(len(self.list), 2)
+ self.assertIs(obj, self.list[0])
+ self.assertIs(obj, self.list[1])
+ obj = None
+ self.assertEqual(len(self.list), 0)
+
+ def testDelItem(self):
+ del self.list[0]
+ self.assertEqual(len(self.list), 1)
+ self.assertIs(self.object2, self.list[0])
+
+ def testDelItemSlice(self):
+ del self.list[:]
+ self.assertEqual(len(self.list), 0)
+
+ def testContains(self):
+ self.assertIn(self.object1, self.list)
+
+ def testAdd(self):
+ others = [Dummy()]
+ l = self.list + others
+ self.assertIs(l[0], self.object1)
+ self.assertEqual(len(l), 3)
+ others = None
+ self.assertEqual(len(l), 2)
+
+ def testExtend(self):
+ others = [Dummy()]
+ self.list.extend(others)
+ self.assertIs(self.list[0], self.object1)
+ self.assertEqual(len(self.list), 3)
+ others = None
+ self.assertEqual(len(self.list), 2)
+
+ def testIadd(self):
+ others = [Dummy()]
+ self.list += others
+ self.assertIs(self.list[0], self.object1)
+ self.assertEqual(len(self.list), 3)
+ others = None
+ self.assertEqual(len(self.list), 2)
+
+ def testMul(self):
+ l = self.list * 2
+ self.assertIs(l[0], self.object1)
+ self.assertEqual(len(l), 4)
+ self.object1 = None
+ self.assertEqual(len(l), 2)
+ self.assertIs(l[0], self.object2)
+ self.assertIs(l[1], self.object2)
+
+ def testImul(self):
+ self.list *= 2
+ self.assertIs(self.list[0], self.object1)
+ self.assertEqual(len(self.list), 4)
+ self.object1 = None
+ self.assertEqual(len(self.list), 2)
+ self.assertIs(self.list[0], self.object2)
+ self.assertIs(self.list[1], self.object2)
+
+ def testCount(self):
+ self.list.append(self.object2)
+ self.assertEqual(self.list.count(self.object1), 1)
+ self.assertEqual(self.list.count(self.object2), 2)
+
+ def testIndex(self):
+ self.assertEqual(self.list.index(self.object1), 0)
+ self.assertEqual(self.list.index(self.object2), 1)
+
+ def testInsert(self):
+ obj = Dummy()
+ self.list.insert(1, obj)
+ self.assertEqual(len(self.list), 3)
+ self.assertIs(self.list[1], obj)
+ obj = None
+ self.assertEqual(len(self.list), 2)
+
+ def testReverse(self):
+ self.list.reverse()
+ self.assertEqual(len(self.list), 2)
+ self.assertIs(self.list[0], self.object2)
+ self.assertIs(self.list[1], self.object1)
+
+ def testReverted(self):
+ new_list = reversed(self.list)
+ self.assertEqual(len(new_list), 2)
+ self.assertIs(self.list[1], self.object2)
+ self.assertIs(self.list[0], self.object1)
+ self.assertIs(new_list[0], self.object2)
+ self.assertIs(new_list[1], self.object1)
+ self.object1 = None
+ self.assertEqual(len(new_list), 1)
+
+ def testStr(self):
+ self.assertNotEqual(self.list.__str__(), "[]")
+
+ def testRepr(self):
+ self.assertNotEqual(self.list.__repr__(), "[]")
+
+ def testSort(self):
+ # only a coverage
+ self.list.sort()
+ self.assertEqual(len(self.list), 2)
diff --git a/src/silx/utils/testutils.py b/src/silx/utils/testutils.py
new file mode 100755
index 0000000..4177e33
--- /dev/null
+++ b/src/silx/utils/testutils.py
@@ -0,0 +1,351 @@
+# coding: utf-8
+# /*##########################################################################
+#
+# Copyright (c) 2016-2021 European Synchrotron Radiation Facility
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+#
+# ###########################################################################*/
+"""Utilities for writing tests.
+
+- :class:`ParametricTestCase` provides a :meth:`TestCase.subTest` replacement
+ for Python < 3.4
+- :class:`LoggingValidator` with context or the :func:`validate_logging`
+ decorator enables testing the number of logging messages of different levels.
+"""
+
+__authors__ = ["T. Vincent"]
+__license__ = "MIT"
+__date__ = "26/01/2018"
+
+
+import contextlib
+import functools
+import logging
+import sys
+import unittest
+from . import deprecation
+
+
+_logger = logging.getLogger(__name__)
+
+
+if sys.hexversion >= 0x030400F0: # Python >= 3.4
+ class ParametricTestCase(unittest.TestCase):
+ pass
+else:
+ class ParametricTestCase(unittest.TestCase):
+ """TestCase with subTest support for Python < 3.4.
+
+ Add subTest method to support parametric tests.
+ API is the same, but behavior differs:
+ If a subTest fails, the following ones are not run.
+ """
+
+ _subtest_msg = None # Class attribute to provide a default value
+
+ @contextlib.contextmanager
+ def subTest(self, msg=None, **params):
+ """Use as unittest.TestCase.subTest method in Python >= 3.4."""
+ # Format arguments as: '[msg] (key=value, ...)'
+ param_str = ', '.join(['%s=%s' % (k, v) for k, v in params.items()])
+ self._subtest_msg = '[%s] (%s)' % (msg or '', param_str)
+ yield
+ self._subtest_msg = None
+
+ def shortDescription(self):
+ short_desc = super(ParametricTestCase, self).shortDescription()
+ if self._subtest_msg is not None:
+ # Append subTest message to shortDescription
+ short_desc = ' '.join(
+ [msg for msg in (short_desc, self._subtest_msg) if msg])
+
+ return short_desc if short_desc else None
+
+
+def parameterize(test_case_class, *args, **kwargs):
+ """Create a suite containing all tests taken from the given
+ subclass, passing them the parameters.
+
+ .. code-block:: python
+
+ class TestParameterizedCase(unittest.TestCase):
+ def __init__(self, methodName='runTest', foo=None):
+ unittest.TestCase.__init__(self, methodName)
+ self.foo = foo
+
+ def suite():
+ testSuite = unittest.TestSuite()
+ testSuite.addTest(parameterize(TestParameterizedCase, foo=10))
+ testSuite.addTest(parameterize(TestParameterizedCase, foo=50))
+ return testSuite
+ """
+ test_loader = unittest.TestLoader()
+ test_names = test_loader.getTestCaseNames(test_case_class)
+ suite = unittest.TestSuite()
+ for name in test_names:
+ suite.addTest(test_case_class(name, *args, **kwargs))
+ return suite
+
+
+class LoggingRuntimeError(RuntimeError):
+ """Raised when the `LoggingValidator` fails"""
+
+ def __init__(self, msg, records):
+ super(LoggingRuntimeError, self).__init__(msg)
+ self.records = records
+
+ def __str__(self):
+ return super(LoggingRuntimeError, self).__str__() + " -> " + str(self.records)
+
+
+class LoggingValidator(logging.Handler):
+ """Context checking the number of logging messages from a specified Logger.
+
+ It disables propagation of logging message while running.
+
+ This is meant to be used as a with statement, for example:
+
+ >>> with LoggingValidator(logger, error=2, warning=0):
+ >>> pass # Run tests here expecting 2 ERROR and no WARNING from logger
+ ...
+
+ :param logger: Name or instance of the logger to test.
+ (Default: root logger)
+ :type logger: str or :class:`logging.Logger`
+ :param int critical: Expected number of CRITICAL messages.
+ Default: Do not check.
+ :param int error: Expected number of ERROR messages.
+ Default: Do not check.
+ :param int warning: Expected number of WARNING messages.
+ Default: Do not check.
+ :param int info: Expected number of INFO messages.
+ Default: Do not check.
+ :param int debug: Expected number of DEBUG messages.
+ Default: Do not check.
+ :param int notset: Expected number of NOTSET messages.
+ Default: Do not check.
+ :raises RuntimeError: If the message counts are the expected ones.
+ """
+
+ def __init__(self, logger=None, critical=None, error=None,
+ warning=None, info=None, debug=None, notset=None):
+ if logger is None:
+ logger = logging.getLogger()
+ elif not isinstance(logger, logging.Logger):
+ logger = logging.getLogger(logger)
+ self.logger = logger
+
+ self.records = []
+
+ self.expected_count_by_level = {
+ logging.CRITICAL: critical,
+ logging.ERROR: error,
+ logging.WARNING: warning,
+ logging.INFO: info,
+ logging.DEBUG: debug,
+ logging.NOTSET: notset
+ }
+
+ self._expected_count = sum([v for k, v in self.expected_count_by_level.items() if v is not None])
+ """Amount of any logging expected"""
+
+ super(LoggingValidator, self).__init__()
+
+ def __enter__(self):
+ """Context (i.e., with) support"""
+ self.records = [] # Reset recorded LogRecords
+ self.logger.addHandler(self)
+ self.logger.propagate = False
+ # ensure no log message is ignored
+ self.entry_level = self.logger.level * 1
+ self.logger.setLevel(logging.DEBUG)
+ self.entry_disabled = self.logger.disabled
+ self.logger.disabled = False
+ return self
+
+ def can_be_checked(self):
+ """Returns True if this listener have received enough messages to
+ be valid, and then checked.
+
+ This can be useful for asynchronous wait of messages. It allows process
+ an early break, instead of waiting much time in an active loop.
+ """
+ return len(self.records) >= self._expected_count
+
+ def get_count_by_level(self):
+ """Returns the current message count by level.
+ """
+ count = {
+ logging.CRITICAL: 0,
+ logging.ERROR: 0,
+ logging.WARNING: 0,
+ logging.INFO: 0,
+ logging.DEBUG: 0,
+ logging.NOTSET: 0
+ }
+ for record in self.records:
+ level = record.levelno
+ if level in count:
+ count[level] = count[level] + 1
+ return count
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ """Context (i.e., with) support"""
+ self.logger.removeHandler(self)
+ self.logger.propagate = True
+ self.logger.setLevel(self.entry_level)
+ self.logger.disabled = self.entry_disabled
+
+ count_by_level = self.get_count_by_level()
+
+ # Remove keys which does not matter
+ ignored = [r for r, v in self.expected_count_by_level.items() if v is None]
+ expected_count_by_level = dict(self.expected_count_by_level)
+ for i in ignored:
+ del count_by_level[i]
+ del expected_count_by_level[i]
+
+ if count_by_level != expected_count_by_level:
+ # Re-send record logs through logger as they where masked
+ # to help debug
+ message = ""
+ for level in count_by_level.keys():
+ if message != "":
+ message += ", "
+ count = count_by_level[level]
+ expected_count = expected_count_by_level[level]
+ message += "%d %s (got %d)" % (expected_count, logging.getLevelName(level), count)
+
+ raise LoggingRuntimeError(
+ 'Expected %s' % message, records=list(self.records))
+
+ def emit(self, record):
+ """Override :meth:`logging.Handler.emit`"""
+ self.records.append(record)
+
+
+def validate_logging(logger=None, critical=None, error=None,
+ warning=None, info=None, debug=None, notset=None):
+ """Decorator checking number of logging messages.
+
+ Propagation of logging messages is disabled by this decorator.
+
+ In case the expected number of logging messages is not found, it raises
+ a RuntimeError.
+
+ >>> class Test(unittest.TestCase):
+ ... @validate_logging('module_logger_name', error=2, warning=0)
+ ... def test(self):
+ ... pass # Test expecting 2 ERROR and 0 WARNING messages
+
+ :param logger: Name or instance of the logger to test.
+ (Default: root logger)
+ :type logger: str or :class:`logging.Logger`
+ :param int critical: Expected number of CRITICAL messages.
+ Default: Do not check.
+ :param int error: Expected number of ERROR messages.
+ Default: Do not check.
+ :param int warning: Expected number of WARNING messages.
+ Default: Do not check.
+ :param int info: Expected number of INFO messages.
+ Default: Do not check.
+ :param int debug: Expected number of DEBUG messages.
+ Default: Do not check.
+ :param int notset: Expected number of NOTSET messages.
+ Default: Do not check.
+ """
+ def decorator(func):
+ test_context = LoggingValidator(
+ logger, critical, error, warning, info, debug, notset)
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ with test_context:
+ result = func(*args, **kwargs)
+ return result
+ return wrapper
+ return decorator
+
+
+# Backward compatibility
+class TestLogging(LoggingValidator):
+ def __init__(self, *args, **kwargs):
+ deprecation.deprecated_warning(
+ "Class",
+ "TestLogging",
+ since_version="1.0.0",
+ replacement="LoggingValidator")
+ super().__init__(*args, **kwargs)
+
+
+@deprecation.deprecated(since_version="1.0.0", replacement="validate_logging")
+def test_logging(*args, **kwargs):
+ return validate_logging(*args, **kwargs)
+
+
+# Simulate missing library context
+class EnsureImportError(object):
+ """This context manager allows to simulate the unavailability
+ of a library, even if it is actually available. It ensures that
+ an ImportError is raised if the code inside the context tries to
+ import the module.
+
+ It can be used to test that a correct fallback library is used,
+ or that the expected error code is returned.
+
+ Trivial example::
+
+ from silx.utils.testutils import EnsureImportError
+
+ with EnsureImportError("h5py"):
+ try:
+ import h5py
+ except ImportError:
+ print("Good")
+
+ .. note::
+
+ This context manager does not remove the library from the namespace,
+ if it is already imported. It only ensures that any attempt to import
+ it again will cause an ImportError to be raised.
+ """
+ def __init__(self, name):
+ """
+
+ :param str name: Name of module to be hidden (e.g. "h5py")
+ """
+ self.module_name = name
+
+ def __enter__(self):
+ """Simulate failed import by setting sys.modules[name]=None"""
+ if self.module_name not in sys.modules:
+ self._delete_on_exit = True
+ self._backup = None
+ else:
+ self._delete_on_exit = False
+ self._backup = sys.modules[self.module_name]
+ sys.modules[self.module_name] = None
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Restore previous state"""
+ if self._delete_on_exit:
+ del sys.modules[self.module_name]
+ else:
+ sys.modules[self.module_name] = self._backup
diff --git a/silx/utils/weakref.py b/src/silx/utils/weakref.py
index 06646e8..06646e8 100644
--- a/silx/utils/weakref.py
+++ b/src/silx/utils/weakref.py
diff --git a/version.py b/version.py
deleted file mode 100644
index cc29c6d..0000000
--- a/version.py
+++ /dev/null
@@ -1,120 +0,0 @@
-#!/usr/bin/env python3
-# coding: utf-8
-# /*##########################################################################
-#
-# Copyright (c) 2015-2021 European Synchrotron Radiation Facility
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
-#
-# ###########################################################################*/
-"""Unique place where the version number is defined.
-
-provides:
-* version = "1.2.3" or "1.2.3-beta4"
-* version_info = named tuple (1,2,3,"beta",4)
-* hexversion: 0x010203B4
-* strictversion = "1.2.3b4
-* debianversion = "1.2.3~beta4"
-* calc_hexversion: the function to transform a version_tuple into an integer
-
-This is called hexversion since it only really looks meaningful when viewed as the
-result of passing it to the built-in hex() function.
-The version_info value may be used for a more human-friendly encoding of the same information.
-
-The hexversion is a 32-bit number with the following layout:
-Bits (big endian order) Meaning
-1-8 PY_MAJOR_VERSION (the 2 in 2.1.0a3)
-9-16 PY_MINOR_VERSION (the 1 in 2.1.0a3)
-17-24 PY_MICRO_VERSION (the 0 in 2.1.0a3)
-25-28 PY_RELEASE_LEVEL (0xA for alpha, 0xB for beta, 0xC for release candidate and 0xF for final)
-29-32 PY_RELEASE_SERIAL (the 3 in 2.1.0a3, zero for final releases)
-
-Thus 2.1.0a3 is hexversion 0x020100a3.
-
-"""
-
-from __future__ import absolute_import, print_function, division
-__authors__ = ["Jérôme Kieffer"]
-__license__ = "MIT"
-__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
-__date__ = "30/09/2020"
-__status__ = "production"
-__docformat__ = 'restructuredtext'
-__all__ = ["date", "version_info", "strictversion", "hexversion", "debianversion",
- "calc_hexversion"]
-
-RELEASE_LEVEL_VALUE = {"dev": 0,
- "alpha": 10,
- "beta": 11,
- "candidate": 12,
- "final": 15}
-
-PRERELEASE_NORMALIZED_NAME = {"dev": "a",
- "alpha": "a",
- "beta": "b",
- "candidate": "rc"}
-
-MAJOR = 0
-MINOR = 15
-MICRO = 2
-RELEV = "final" # <16
-SERIAL = 0 # <16
-
-date = __date__
-
-from collections import namedtuple
-_version_info = namedtuple("version_info", ["major", "minor", "micro", "releaselevel", "serial"])
-
-version_info = _version_info(MAJOR, MINOR, MICRO, RELEV, SERIAL)
-
-strictversion = version = debianversion = "%d.%d.%d" % version_info[:3]
-if version_info.releaselevel != "final":
- _prerelease = PRERELEASE_NORMALIZED_NAME[version_info[3]]
- version += "-%s%s" % (_prerelease, version_info[-1])
- debianversion += "~adev%i" % version_info[-1] if RELEV == "dev" else "~%s%i" % (_prerelease, version_info[-1])
- strictversion += _prerelease + str(version_info[-1])
-
-
-def calc_hexversion(major=0, minor=0, micro=0, releaselevel="dev", serial=0):
- """Calculate the hexadecimal version number from the tuple version_info:
-
- :param major: integer
- :param minor: integer
- :param micro: integer
- :param relev: integer or string
- :param serial: integer
- :return: integer always increasing with revision numbers
- """
- try:
- releaselevel = int(releaselevel)
- except ValueError:
- releaselevel = RELEASE_LEVEL_VALUE.get(releaselevel, 0)
-
- hex_version = int(serial)
- hex_version |= releaselevel * 1 << 4
- hex_version |= int(micro) * 1 << 8
- hex_version |= int(minor) * 1 << 16
- hex_version |= int(major) * 1 << 24
- return hex_version
-
-
-hexversion = calc_hexversion(*version_info)
-
-if __name__ == "__main__":
- print(version)